From 1b58ca4ccc3220a98e5d6f9e753186116f2ed8aa Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 17:07:34 +0200 Subject: [PATCH] refact dqn burn demo --- bot/src/dqn/burnrl/main.rs | 44 ++++---------------- bot/src/dqn/burnrl/utils.rs | 39 ++++++++++++++++-- bot/src/strategy/dqn.rs | 82 ++++++++++++++++++------------------- 3 files changed, 83 insertions(+), 82 deletions(-) diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 7b4584c..8408e6a 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -1,9 +1,10 @@ -use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::Module; -use burn::record::{CompactRecorder, Recorder}; +use bot::dqn::burnrl::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; use burn_rl::agent::DQN; -use burn_rl::base::{Action, Agent, ElemType, Environment, State}; +use burn_rl::base::ElemType; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; @@ -25,12 +26,9 @@ fn main() { println!("> Sauvegarde du modèle de validation"); - let path = "models/burn_dqn_50".to_string(); + let path = "models/burn_dqn_40".to_string(); save_model(valid_agent.model().as_ref().unwrap(), &path); - // println!("> Test avec le modèle entraîné"); - // demo_model::(valid_agent); - println!("> Chargement du modèle pour test"); let loaded_model = load_model(conf.dense_size, &path); let loaded_agent = DQN::new(loaded_model); @@ -38,31 +36,3 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } - -fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{}_model.mpk", path); - println!("Modèle de validation sauvegardé : {}", model_path); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { - let model_path = format!("{}_model.mpk", path); - println!("Chargement du modèle depuis : {}", model_path); - - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) -} diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index ba04cb6..66fa850 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -1,12 +1,45 @@ -use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; +use crate::dqn::burnrl::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; use crate::dqn::dqn_common::get_valid_action_indices; -use burn::module::{Param, ParamId}; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::backend::Backend; use burn::tensor::cast::ToElement; use burn::tensor::Tensor; use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{ElemType, Environment, State}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { + let model_path = format!("{path}_model.mpk"); + println!("Chargement du modèle depuis : {model_path}"); + + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + + let record = recorder + .load(model_path.into(), &device) + .expect("Impossible de charger le modèle"); + + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) +} pub fn demo_model>(agent: DQN) { let mut env = TrictracEnvironment::new(true); diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 109a9cf..34fb853 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -114,50 +114,48 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement - if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Move { - dice_order, - from1, - from2, - } = action - { - let dicevals = self.game.dice.values; - let (mut dice1, mut dice2) = if dice_order { - (dicevals.0, dicevals.1) - } else { - (dicevals.1, dicevals.0) - }; + if let Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) = self.get_dqn_action() + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; - if from1 == 0 { - // empty move - dice1 = 0; - } - let mut to1 = from1 + dice1 as usize; - if 24 < to1 { - // sortie - to1 = 0; - } - if from2 == 0 { - // empty move - dice2 = 0; - } - let mut to2 = from2 + dice2 as usize; - if 24 < to2 { - // sortie - to2 = 0; - } - - let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); - - let chosen_move = if self.color == Color::White { - (checker_move1, checker_move2) - } else { - (checker_move1.mirror(), checker_move2.mirror()) - }; - - return chosen_move; + if from1 == 0 { + // empty move + dice1 = 0; } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; + + return chosen_move; } // Fallback : utiliser la stratégie par défaut