diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs index 9465ec1..f337289 100644 --- a/bot/src/burnrl/dqn_model.rs +++ b/bot/src/burnrl/dqn_model.rs @@ -1,15 +1,16 @@ +use crate::burnrl::utils::soft_update_linear; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::DQN; use burn_rl::agent::{DQNModel, DQNTrainingConfig}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use crate::burnrl::utils::soft_update_linear; -#[derive(Module, Debug, Clone)] +#[derive(Module, Debug)] pub struct Net { linear_0: Linear, linear_1: Linear, @@ -18,11 +19,11 @@ pub struct Net { impl Net { #[allow(unused)] - pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self { + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { Self { - linear_0: LinearConfig::new(input_size, dense_size).init(device), - linear_1: LinearConfig::new(dense_size, dense_size).init(device), - linear_2: LinearConfig::new(dense_size, output_size).init(device), + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), } } @@ -33,7 +34,7 @@ impl Net { impl Model, Tensor> for Net { fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input.clone())); + let layer_0_output = relu(self.linear_0.forward(input)); let layer_1_output = relu(self.linear_1.forward(layer_0_output)); relu(self.linear_2.forward(layer_1_output)) @@ -45,8 +46,8 @@ impl Model, Tensor> for Net { } impl DQNModel for Net { - fn soft_update(self, that: &Self, tau: ElemType) -> Self { - let (linear_0, linear_1, linear_2) = self.consume(); + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); Self { linear_0: soft_update_linear(linear_0, &that.linear_0, tau), @@ -71,13 +72,11 @@ pub fn run( visualized: bool, ) -> impl Agent { let mut env = E::new(visualized); - let device = Default::default(); let model = Net::::new( - ::size(), + <::StateType as State>::size(), DENSE_SIZE, - ::size(), - &device, + <::ActionType as Action>::size(), ); let mut agent = MyAgent::new(model); @@ -108,7 +107,7 @@ pub fn run( let snapshot = env.step(action); episode_reward += - >::into(snapshot.reward().clone()); + <::RewardType as Into>::into(snapshot.reward().clone()); memory.push( state, @@ -119,7 +118,8 @@ pub fn run( ); if config.batch_size < memory.len() { - policy_net = agent.train(policy_net, &memory, &mut optimizer, &config); + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); } step += 1; @@ -138,5 +138,16 @@ pub fn run( } } } + + // Save + let path = "models/burn_dqn".to_string(); + let inference_network = agent.model().clone().into_record(); + let recorder = CompactRecorder::new(); + let model_path = format!("{}_model.burn", path); + println!("Modèle sauvegardé : {}", model_path); + recorder + .record(inference_network, model_path.into()) + .unwrap(); + agent.valid() -} \ No newline at end of file +} diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs index e6faf80..669d3b4 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -199,15 +199,6 @@ impl Environment for TrictracEnvironment { } impl TrictracEnvironment { - pub fn valid_actions(&self) -> Vec { - dqn_common::get_valid_actions(&self.game) - .into_iter() - .map(|a| TrictracAction { - index: a.to_action_index() as u32, - }) - .collect() - } - /// Convertit une action burn-rl vers une action Trictrac fn convert_action( &self, @@ -389,4 +380,4 @@ impl TrictracEnvironment { } reward } -} \ No newline at end of file +} diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index aa657ac..6e55928 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,61 +1,13 @@ use bot::burnrl::{dqn_model, environment, utils::demo_model}; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::Module; -use burn::record::{CompactRecorder, Recorder}; -use burn_rl::agent::DQN; -use burn_rl::base::{Action, Agent, ElemType, Environment, State}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::base::ElemType; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; fn main() { - println!("> Entraînement"); let num_episodes = 3; let agent = dqn_model::run::(num_episodes, false); //true); - let valid_agent = agent.valid(); - - println!("> Sauvegarde du modèle de validation"); - save_model(valid_agent.model().as_ref().unwrap()); - - println!("> Chargement du modèle pour test"); - let loaded_model = load_model(); - let loaded_agent = DQN::new(loaded_model); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent, |env| env.valid_actions()); + demo_model::(agent); } - -fn save_model(model: &dqn_model::Net>) { - let path = "models/burn_dqn".to_string(); - let recorder = CompactRecorder::new(); - let model_path = format!("{}_model.burn", path); - println!("Modèle de validation sauvegardé : {}", model_path); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -fn load_model() -> dqn_model::Net> { - // TODO : reprendre le DENSE_SIZE de dqn_model.rs - const DENSE_SIZE: usize = 128; - - let path = "models/burn_dqn".to_string(); - let model_path = format!("{}_model.burn", path); - println!("Chargement du modèle depuis : {}", model_path); - - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - DENSE_SIZE, - ::ActionType::size(), - &device, - ) - .load_record(record) -} \ No newline at end of file diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs index d17df4a..bc8d836 100644 --- a/bot/src/burnrl/utils.rs +++ b/bot/src/burnrl/utils.rs @@ -1,60 +1,21 @@ -use burn::module::{Module, Param, ParamId}; +use burn::module::{Param, ParamId}; use burn::nn::Linear; use burn::tensor::backend::Backend; use burn::tensor::Tensor; -use burn_rl::base::{Action, ElemType, Environment, State}; -use burn_rl::agent::DQN; +use burn_rl::base::{Agent, ElemType, Environment}; -pub fn demo_model( - agent: DQN, - mut get_valid_actions: F, -) where - E: Environment, - M: Module + burn_rl::agent::DQNModel, - B: Backend, - F: FnMut(&E) -> Vec, - ::ActionType: PartialEq, -{ +pub fn demo_model(agent: impl Agent) { let mut env = E::new(true); let mut state = env.state(); let mut done = false; - let mut total_reward = 0.0; - let mut steps = 0; - while !done { - let model = agent.model().as_ref().unwrap(); - let state_tensor = E::StateType::to_tensor(&state).unsqueeze(); - let q_values = model.infer(state_tensor); - - let valid_actions = get_valid_actions(&env); - if valid_actions.is_empty() { - break; // No valid actions, end of episode + if let Some(action) = agent.react(&state) { + let snapshot = env.step(action); + state = *snapshot.state(); + // println!("{:?}", state); + done = snapshot.done(); } - - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions.contains(&E::ActionType::from(index as u32)) { - masked_q_values = - masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY); - } - } - - let action_index = masked_q_values.argmax(1).into_scalar() as u32; - let action = E::ActionType::from(action_index); - - let snapshot = env.step(action); - state = *snapshot.state(); - total_reward += - <::RewardType as Into>::into(snapshot.reward().clone()); - steps += 1; - done = snapshot.done() || steps >= E::MAX_STEPS; } - println!( - "Episode terminé. Récompense totale: {:.2}, Étapes: {}", - total_reward, steps - ); } fn soft_update_tensor( diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index e01f406..81aa5f1 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -1,4 +1,4 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use store::MoveRules; #[derive(Debug)] diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 6f22fac..779ce3d 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,4 +1,4 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use std::path::Path; use store::MoveRules; diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 1993a5d..9a24ae6 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min}; use serde::{Deserialize, Serialize}; -use store::{CheckerMove, Dice}; +use store::{CheckerMove, Dice, GameEvent, PlayerId}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -259,7 +259,7 @@ impl SimpleNeuralNetwork { /// Obtient les actions valides pour l'état de jeu actuel pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - + use crate::PointsRules; use store::TurnStage; let mut valid_actions = Vec::new();