diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs index f337289..9465ec1 100644 --- a/bot/src/burnrl/dqn_model.rs +++ b/bot/src/burnrl/dqn_model.rs @@ -1,16 +1,15 @@ -use crate::burnrl::utils::soft_update_linear; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; -use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::DQN; use burn_rl::agent::{DQNModel, DQNTrainingConfig}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; +use crate::burnrl::utils::soft_update_linear; -#[derive(Module, Debug)] +#[derive(Module, Debug, Clone)] pub struct Net { linear_0: Linear, linear_1: Linear, @@ -19,11 +18,11 @@ pub struct Net { impl Net { #[allow(unused)] - pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self { Self { - linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), - linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), - linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + linear_0: LinearConfig::new(input_size, dense_size).init(device), + linear_1: LinearConfig::new(dense_size, dense_size).init(device), + linear_2: LinearConfig::new(dense_size, output_size).init(device), } } @@ -34,7 +33,7 @@ impl Net { impl Model, Tensor> for Net { fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input)); + let layer_0_output = relu(self.linear_0.forward(input.clone())); let layer_1_output = relu(self.linear_1.forward(layer_0_output)); relu(self.linear_2.forward(layer_1_output)) @@ -46,8 +45,8 @@ impl Model, Tensor> for Net { } impl DQNModel for Net { - fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { - let (linear_0, linear_1, linear_2) = this.consume(); + fn soft_update(self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = self.consume(); Self { linear_0: soft_update_linear(linear_0, &that.linear_0, tau), @@ -72,11 +71,13 @@ pub fn run( visualized: bool, ) -> impl Agent { let mut env = E::new(visualized); + let device = Default::default(); let model = Net::::new( - <::StateType as State>::size(), + ::size(), DENSE_SIZE, - <::ActionType as Action>::size(), + ::size(), + &device, ); let mut agent = MyAgent::new(model); @@ -107,7 +108,7 @@ pub fn run( let snapshot = env.step(action); episode_reward += - <::RewardType as Into>::into(snapshot.reward().clone()); + >::into(snapshot.reward().clone()); memory.push( state, @@ -118,8 +119,7 @@ pub fn run( ); if config.batch_size < memory.len() { - policy_net = - agent.train::(policy_net, &memory, &mut optimizer, &config); + policy_net = agent.train(policy_net, &memory, &mut optimizer, &config); } step += 1; @@ -138,16 +138,5 @@ pub fn run( } } } - - // Save - let path = "models/burn_dqn".to_string(); - let inference_network = agent.model().clone().into_record(); - let recorder = CompactRecorder::new(); - let model_path = format!("{}_model.burn", path); - println!("Modèle sauvegardé : {}", model_path); - recorder - .record(inference_network, model_path.into()) - .unwrap(); - agent.valid() -} +} \ No newline at end of file diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs index 669d3b4..e6faf80 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -199,6 +199,15 @@ impl Environment for TrictracEnvironment { } impl TrictracEnvironment { + pub fn valid_actions(&self) -> Vec { + dqn_common::get_valid_actions(&self.game) + .into_iter() + .map(|a| TrictracAction { + index: a.to_action_index() as u32, + }) + .collect() + } + /// Convertit une action burn-rl vers une action Trictrac fn convert_action( &self, @@ -380,4 +389,4 @@ impl TrictracEnvironment { } reward } -} +} \ No newline at end of file diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index 6e55928..aa657ac 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,13 +1,61 @@ use bot::burnrl::{dqn_model, environment, utils::demo_model}; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::base::ElemType; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::Module; +use burn::record::{CompactRecorder, Recorder}; +use burn_rl::agent::DQN; +use burn_rl::base::{Action, Agent, ElemType, Environment, State}; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; fn main() { + println!("> Entraînement"); let num_episodes = 3; let agent = dqn_model::run::(num_episodes, false); //true); - demo_model::(agent); + let valid_agent = agent.valid(); + + println!("> Sauvegarde du modèle de validation"); + save_model(valid_agent.model().as_ref().unwrap()); + + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(); + let loaded_agent = DQN::new(loaded_model); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent, |env| env.valid_actions()); } + +fn save_model(model: &dqn_model::Net>) { + let path = "models/burn_dqn".to_string(); + let recorder = CompactRecorder::new(); + let model_path = format!("{}_model.burn", path); + println!("Modèle de validation sauvegardé : {}", model_path); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +fn load_model() -> dqn_model::Net> { + // TODO : reprendre le DENSE_SIZE de dqn_model.rs + const DENSE_SIZE: usize = 128; + + let path = "models/burn_dqn".to_string(); + let model_path = format!("{}_model.burn", path); + println!("Chargement du modèle depuis : {}", model_path); + + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + + let record = recorder + .load(model_path.into(), &device) + .expect("Impossible de charger le modèle"); + + dqn_model::Net::new( + ::StateType::size(), + DENSE_SIZE, + ::ActionType::size(), + &device, + ) + .load_record(record) +} \ No newline at end of file diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs index bc8d836..d17df4a 100644 --- a/bot/src/burnrl/utils.rs +++ b/bot/src/burnrl/utils.rs @@ -1,21 +1,60 @@ -use burn::module::{Param, ParamId}; +use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; use burn::tensor::backend::Backend; use burn::tensor::Tensor; -use burn_rl::base::{Agent, ElemType, Environment}; +use burn_rl::base::{Action, ElemType, Environment, State}; +use burn_rl::agent::DQN; -pub fn demo_model(agent: impl Agent) { +pub fn demo_model( + agent: DQN, + mut get_valid_actions: F, +) where + E: Environment, + M: Module + burn_rl::agent::DQNModel, + B: Backend, + F: FnMut(&E) -> Vec, + ::ActionType: PartialEq, +{ let mut env = E::new(true); let mut state = env.state(); let mut done = false; + let mut total_reward = 0.0; + let mut steps = 0; + while !done { - if let Some(action) = agent.react(&state) { - let snapshot = env.step(action); - state = *snapshot.state(); - // println!("{:?}", state); - done = snapshot.done(); + let model = agent.model().as_ref().unwrap(); + let state_tensor = E::StateType::to_tensor(&state).unsqueeze(); + let q_values = model.infer(state_tensor); + + let valid_actions = get_valid_actions(&env); + if valid_actions.is_empty() { + break; // No valid actions, end of episode } + + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions.contains(&E::ActionType::from(index as u32)) { + masked_q_values = + masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY); + } + } + + let action_index = masked_q_values.argmax(1).into_scalar() as u32; + let action = E::ActionType::from(action_index); + + let snapshot = env.step(action); + state = *snapshot.state(); + total_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + steps += 1; + done = snapshot.done() || steps >= E::MAX_STEPS; } + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + total_reward, steps + ); } fn soft_update_tensor( diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index 81aa5f1..e01f406 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -1,4 +1,4 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use store::MoveRules; #[derive(Debug)] diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 779ce3d..6f22fac 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,4 +1,4 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use std::path::Path; use store::MoveRules; diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 9a24ae6..1993a5d 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min}; use serde::{Deserialize, Serialize}; -use store::{CheckerMove, Dice, GameEvent, PlayerId}; +use store::{CheckerMove, Dice}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -259,7 +259,7 @@ impl SimpleNeuralNetwork { /// Obtient les actions valides pour l'état de jeu actuel pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - use crate::PointsRules; + use store::TurnStage; let mut valid_actions = Vec::new();