From b92c9eb7ffad5742efef968b56d0b27cd60a4602 Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Fri, 25 Jul 2025 17:26:02 +0200
Subject: [PATCH] fix: convert_action from_action_index

---
 bot/src/burnrl/environment.rs | 11 +++++
 bot/src/burnrl/main.rs        |  9 ++--
 bot/src/burnrl/utils.rs       | 38 +++++++++++++++-
 bot/src/burnrl/utils_wip.rs   | 85 +++++++++++++++++++++++++++++++++++
 4 files changed, 139 insertions(+), 4 deletions(-)
 create mode 100644 bot/src/burnrl/utils_wip.rs

diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs
index 669d3b4..8ccb600 100644
--- a/bot/src/burnrl/environment.rs
+++ b/bot/src/burnrl/environment.rs
@@ -92,6 +92,7 @@ impl Environment for TrictracEnvironment {
     type RewardType = f32;
 
     const MAX_STEPS: usize = 1000; // Max limit to avoid endless games
+    // const MAX_STEPS: usize = 5; // Max limit to avoid endless games
 
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -139,6 +140,7 @@ impl Environment for TrictracEnvironment {
 
         // Convert the burn-rl action into a Trictrac action
        let trictrac_action = self.convert_action(action, &self.game);
+        // println!("chosen action: {:?} -> {:?}", action, trictrac_action);
 
         let mut reward = 0.0;
         let mut terminated = false;
@@ -204,6 +206,15 @@ impl TrictracEnvironment {
         &self,
         action: TrictracAction,
         game_state: &GameState,
+    ) -> Option<dqn_common::TrictracAction> {
+        dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
+    }
+
+    /// Converts an index within the valid actions into a Trictrac action
+    fn convert_valid_action_index(
+        &self,
+        action: TrictracAction,
+        game_state: &GameState,
     ) -> Option<dqn_common::TrictracAction> {
         use dqn_common::get_valid_actions;
 
diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs
index 41a29e2..9755ce1 100644
--- a/bot/src/burnrl/main.rs
+++ b/bot/src/burnrl/main.rs
@@ -10,7 +10,7 @@ type Env = environment::TrictracEnvironment;
 fn main() {
     println!("> Training");
 
-    let num_episodes = 3;
+    let num_episodes = 10;
     let agent = dqn_model::run::<Env>(num_episodes, false); //true);
 
     let valid_agent = agent.valid();
@@ -18,6 +18,9 @@
     println!("> Saving the validation model");
     save_model(valid_agent.model().as_ref().unwrap());
 
+    println!("> Testing with the trained model");
+    demo_model::<Env>(valid_agent);
+
     println!("> Loading the model for testing");
     let loaded_model = load_model();
     let loaded_agent = DQN::new(loaded_model);
@@ -29,7 +32,7 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
     let path = "models/burn_dqn".to_string();
     let recorder = CompactRecorder::new();
 
-    let model_path = format!("{}_model.burn", path);
+    let model_path = format!("{}_model.mpk", path);
     println!("Validation model saved: {}", model_path);
     recorder
         .record(model.clone().into_record(), model_path.into())
@@ -41,7 +44,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
     const DENSE_SIZE: usize = 128;
     let path = "models/burn_dqn".to_string();
 
-    let model_path = format!("{}_model.burn", path);
+    let model_path = format!("{}_model.mpk", path);
     println!("Loading model from: {}", model_path);
 
     let device = NdArrayDevice::default();
diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs
index bc8d836..1815c08 100644
--- a/bot/src/burnrl/utils.rs
+++ b/bot/src/burnrl/utils.rs
@@ -9,10 +9,46 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
     let mut state = env.state();
     let mut done = false;
     while !done {
+        // // Get q values for current state
+        // let model = agent.model().as_ref().unwrap();
+        // let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
+        // let q_values = model.infer(state_tensor);
+        //
+        // // Get valid actions
+        // let valid_actions = get_valid_actions(&state);
+        // if valid_actions.is_empty() {
+        //     break; // No valid actions, end of episode
+        // }
+        //
+        // // Set q values of non valid actions to the lowest
+        // let mut masked_q_values = q_values.clone();
+        // let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
+        // for (index, q_value) in q_values_vec.iter().enumerate() {
+        //     if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+        //         masked_q_values = masked_q_values.clone().mask_fill(
+        //             masked_q_values.clone().equal_elem(*q_value),
+        //             f32::NEG_INFINITY,
+        //         );
+        //     }
+        // }
+        //
+        // // Get action with the highest q-value
+        // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        // let action = E::ActionType::from(action_index);
+        //
+        // // Execute action
+        // let snapshot = env.step(action);
+        // state = *snapshot.state();
+        // // println!("{:?}", state);
+        // done = snapshot.done();
+
         if let Some(action) = agent.react(&state) {
+            // println!("before : {:?}", state);
+            // println!("action : {:?}", action);
             let snapshot = env.step(action);
             state = *snapshot.state();
-            // println!("{:?}", state);
+            // println!("after : {:?}", state);
+            // done = true;
             done = snapshot.done();
         }
     }
diff --git a/bot/src/burnrl/utils_wip.rs b/bot/src/burnrl/utils_wip.rs
new file mode 100644
index 0000000..dcf08a2
--- /dev/null
+++ b/bot/src/burnrl/utils_wip.rs
@@ -0,0 +1,85 @@
+use burn::module::{Module, Param, ParamId};
+use burn::nn::Linear;
+use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
+use burn::tensor::Tensor;
+use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
+
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
+    E: Environment,
+    M: Module<B> + burn_rl::agent::DQNModel<B>,
+    B: Backend,
+    F: FnMut(&E) -> Vec<E::ActionType>,
+    <E as Environment>::ActionType: PartialEq,
+{
+    let mut env = E::new(true);
+    let mut state = env.state();
+    let mut done = false;
+    let mut total_reward = 0.0;
+    let mut steps = 0;
+
+    while !done {
+        let model = agent.model().as_ref().unwrap();
+        let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
+        let q_values = model.infer(state_tensor);
+
+        let valid_actions = get_valid_actions(&env);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
+        }
+
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
+
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
+            }
+        }
+
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        let action = E::ActionType::from(action_index);
+
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        total_reward +=
+            <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        steps += 1;
+        done = snapshot.done() || steps >= E::MAX_STEPS;
+    }
+    println!(
+        "Episode finished. Total reward: {:.2}, Steps: {}",
+        total_reward, steps
+    );
+}
+
+fn soft_update_tensor<B: Backend, const D: usize>(
+    this: &Param<Tensor<B, D>>,
+    that: &Param<Tensor<B, D>>,
+    tau: ElemType,
+) -> Param<Tensor<B, D>> {
+    let that_weight = that.val();
+    let this_weight = this.val();
+    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
+
+    Param::initialized(ParamId::new(), new_weight)
+}
+
+pub fn soft_update_linear<B: Backend>(
+    this: Linear<B>,
+    that: &Linear<B>,
+    tau: ElemType,
+) -> Linear<B> {
+    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
+    let bias = match (&this.bias, &that.bias) {
+        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
+        _ => None,
+    };
+
+    Linear::<B> { weight, bias }
+}
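
A note on the change above: convert_action now decodes the chosen action's index directly with dqn_common::TrictracAction::from_action_index, while the previous behaviour (kept as convert_valid_action_index) interpreted it as an index into the list of currently valid actions. With a global index, invalid actions have to be masked out before the argmax, which is what demo_model in utils_wip.rs does with mask_fill. Below is a minimal, self-contained sketch of that masking step on plain f32 slices; best_valid_action_index is a hypothetical helper written for illustration, not part of the patch.

// Illustrative only: mirrors the mask_fill + argmax logic of demo_model in
// utils_wip.rs. Invalid indices are forced to f32::NEG_INFINITY so the argmax
// can only select a valid action; the winning index could then be decoded with
// TrictracAction::from_action_index.
fn best_valid_action_index(q_values: &[f32], valid: &[usize]) -> Option<usize> {
    q_values
        .iter()
        .enumerate()
        .map(|(i, &q)| (i, if valid.contains(&i) { q } else { f32::NEG_INFINITY }))
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(i, _)| i)
}

fn main() {
    // Toy Q-values over a 5-action space where only actions 1 and 3 are legal.
    let q = [0.2, -0.1, 0.9, 0.4, 0.0];
    assert_eq!(best_valid_action_index(&q, &[1, 3]), Some(3));
}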