From 18e85744d695b590978d1371fad40e520e328f55 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 20 Aug 2025 13:09:57 +0200 Subject: [PATCH] refacto: burnrl --- bot/Cargo.toml | 4 + bot/src/burnrl/dqn/main.rs | 54 -------- bot/src/burnrl/dqn/mod.rs | 2 - bot/src/burnrl/dqn/utils.rs | 112 ---------------- bot/src/burnrl/dqn_big/main.rs | 54 -------- bot/src/burnrl/dqn_big/mod.rs | 2 - bot/src/burnrl/dqn_big/utils.rs | 112 ---------------- .../dqn_model.rs => dqn_big_model.rs} | 112 ++++++++-------- bot/src/burnrl/{dqn => }/dqn_model.rs | 107 +++++++--------- bot/src/burnrl/dqn_valid/main.rs | 53 -------- bot/src/burnrl/dqn_valid/mod.rs | 2 - bot/src/burnrl/dqn_valid/utils.rs | 112 ---------------- .../dqn_model.rs => dqn_valid_model.rs} | 112 +++++++--------- bot/src/burnrl/environment.rs | 19 +-- bot/src/burnrl/environment_big.rs | 11 +- bot/src/burnrl/main.rs | 58 +++++++++ bot/src/burnrl/mod.rs | 11 +- bot/src/burnrl/ppo/main.rs | 52 -------- bot/src/burnrl/ppo/mod.rs | 2 - bot/src/burnrl/ppo/utils.rs | 88 ------------- bot/src/burnrl/{ppo => }/ppo_model.rs | 64 +-------- bot/src/burnrl/sac/main.rs | 45 ------- bot/src/burnrl/sac/mod.rs | 2 - bot/src/burnrl/sac/utils.rs | 78 ----------- bot/src/burnrl/{sac => }/sac_model.rs | 85 +++++------- bot/src/burnrl/utils.rs | 121 ++++++++++++++++++ bot/src/strategy/dqnburn.rs | 5 +- 27 files changed, 387 insertions(+), 1092 deletions(-) delete mode 100644 bot/src/burnrl/dqn/main.rs delete mode 100644 bot/src/burnrl/dqn/mod.rs delete mode 100644 bot/src/burnrl/dqn/utils.rs delete mode 100644 bot/src/burnrl/dqn_big/main.rs delete mode 100644 bot/src/burnrl/dqn_big/mod.rs delete mode 100644 bot/src/burnrl/dqn_big/utils.rs rename bot/src/burnrl/{dqn_valid/dqn_model.rs => dqn_big_model.rs} (70%) rename bot/src/burnrl/{dqn => }/dqn_model.rs (71%) delete mode 100644 bot/src/burnrl/dqn_valid/main.rs delete mode 100644 bot/src/burnrl/dqn_valid/mod.rs delete mode 100644 bot/src/burnrl/dqn_valid/utils.rs rename bot/src/burnrl/{dqn_big/dqn_model.rs => dqn_valid_model.rs} (67%) create mode 100644 bot/src/burnrl/main.rs delete mode 100644 bot/src/burnrl/ppo/main.rs delete mode 100644 bot/src/burnrl/ppo/mod.rs delete mode 100644 bot/src/burnrl/ppo/utils.rs rename bot/src/burnrl/{ppo => }/ppo_model.rs (71%) delete mode 100644 bot/src/burnrl/sac/main.rs delete mode 100644 bot/src/burnrl/sac/mod.rs delete mode 100644 bot/src/burnrl/sac/utils.rs rename bot/src/burnrl/{sac => }/sac_model.rs (80%) create mode 100644 bot/src/burnrl/utils.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index ecda4d0..20c4e93 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -5,6 +5,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "burn_demo" +path = "src/burnrl/main.rs" + [[bin]] name = "train_dqn_burn_valid" path = "src/burnrl/dqn_valid/main.rs" diff --git a/bot/src/burnrl/dqn/main.rs b/bot/src/burnrl/dqn/main.rs deleted file mode 100644 index fb55c60..0000000 --- a/bot/src/burnrl/dqn/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -use bot::burnrl::dqn::{ - dqn_model, - utils::{demo_model, load_model, save_model}, -}; -use bot::burnrl::environment; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::DQN; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; - -fn main() { - // println!("> Entraînement"); - - // See also MEMORY_SIZE in dqn_model.rs : 8192 - let conf = dqn_model::DqnConfig { - // defaults - num_episodes: 50, // 40 - 
min_steps: 1000.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 1000, // 1000 max steps by episode - dense_size: 256, // 128 neural network complexity (default 128) - eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) - eps_end: 0.05, // 0.05 - // eps_decay higher = epsilon decrease slower - // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); - // epsilon is updated at the start of each episode - eps_decay: 2000.0, // 1000 ? - - gamma: 0.9999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.0005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. - clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100) - }; - println!("{conf}----------"); - let agent = dqn_model::run::(&conf, false); //true); - - let valid_agent = agent.valid(); - - println!("> Sauvegarde du modèle de validation"); - - let path = "bot/models/burnrl_dqn".to_string(); - save_model(valid_agent.model().as_ref().unwrap(), &path); - - println!("> Chargement du modèle pour test"); - let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); -} diff --git a/bot/src/burnrl/dqn/mod.rs b/bot/src/burnrl/dqn/mod.rs deleted file mode 100644 index 27fcc58..0000000 --- a/bot/src/burnrl/dqn/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod dqn_model; -pub mod utils; diff --git a/bot/src/burnrl/dqn/utils.rs b/bot/src/burnrl/dqn/utils.rs deleted file mode 100644 index 77e2402..0000000 --- a/bot/src/burnrl/dqn/utils.rs +++ /dev/null @@ -1,112 +0,0 @@ -use crate::burnrl::dqn::dqn_model; -use crate::burnrl::environment::{TrictracAction, TrictracEnvironment}; -use crate::training_common::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - -pub fn demo_model>(agent: DQN) { - let mut env = TrictracEnvironment::new(true); - let mut done = false; - while !done { - // let action = match infer_action(&agent, &env, state) { - let action = match infer_action(&agent, &env) { - Some(value) => value, - None => break, - }; - // Execute action - let snapshot = env.step(action); - done = 
snapshot.done(); - } -} - -fn infer_action>( - agent: &DQN, - env: &TrictracEnvironment, -) -> Option { - let state = env.state(); - // Get q-values - let q_values = agent - .model() - .as_ref() - .unwrap() - .infer(state.to_tensor().unsqueeze()); - // Get valid actions - let valid_actions_indices = get_valid_action_indices(&env.game); - if valid_actions_indices.is_empty() { - return None; // No valid actions, end of episode - } - // Set non valid actions q-values to lowest - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions_indices.contains(&index) { - masked_q_values = masked_q_values.clone().mask_fill( - masked_q_values.clone().equal_elem(*q_value), - f32::NEG_INFINITY, - ); - } - } - // Get best action (highest q-value) - let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); - let action = TrictracAction::from(action_index); - Some(action) -} - -fn soft_update_tensor( - this: &Param>, - that: &Param>, - tau: ElemType, -) -> Param> { - let that_weight = that.val(); - let this_weight = this.val(); - let new_weight = this_weight * (1.0 - tau) + that_weight * tau; - - Param::initialized(ParamId::new(), new_weight) -} - -pub fn soft_update_linear( - this: Linear, - that: &Linear, - tau: ElemType, -) -> Linear { - let weight = soft_update_tensor(&this.weight, &that.weight, tau); - let bias = match (&this.bias, &that.bias) { - (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), - _ => None, - }; - - Linear:: { weight, bias } -} diff --git a/bot/src/burnrl/dqn_big/main.rs b/bot/src/burnrl/dqn_big/main.rs deleted file mode 100644 index a8c5c9f..0000000 --- a/bot/src/burnrl/dqn_big/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -use bot::burnrl::dqn_big::{ - dqn_model, - utils::{demo_model, load_model, save_model}, -}; -use bot::burnrl::environment_big; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::DQN; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment_big::TrictracEnvironment; - -fn main() { - // println!("> Entraînement"); - - // See also MEMORY_SIZE in dqn_model.rs : 8192 - let conf = dqn_model::DqnConfig { - // defaults - num_episodes: 40, // 40 - min_steps: 2000.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 4000, // 1000 max steps by episode - dense_size: 128, // 128 neural network complexity (default 128) - eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) - eps_end: 0.05, // 0.05 - // eps_decay higher = epsilon decrease slower - // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); - // epsilon is updated at the start of each episode - eps_decay: 1000.0, // 1000 ? - - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
- clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) - }; - println!("{conf}----------"); - let agent = dqn_model::run::(&conf, false); //true); - - let valid_agent = agent.valid(); - - println!("> Sauvegarde du modèle de validation"); - - let path = "models/burn_dqn_40".to_string(); - save_model(valid_agent.model().as_ref().unwrap(), &path); - - println!("> Chargement du modèle pour test"); - let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); -} diff --git a/bot/src/burnrl/dqn_big/mod.rs b/bot/src/burnrl/dqn_big/mod.rs deleted file mode 100644 index 27fcc58..0000000 --- a/bot/src/burnrl/dqn_big/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod dqn_model; -pub mod utils; diff --git a/bot/src/burnrl/dqn_big/utils.rs b/bot/src/burnrl/dqn_big/utils.rs deleted file mode 100644 index fa8de44..0000000 --- a/bot/src/burnrl/dqn_big/utils.rs +++ /dev/null @@ -1,112 +0,0 @@ -use crate::burnrl::dqn_big::dqn_model; -use crate::burnrl::environment_big::{TrictracAction, TrictracEnvironment}; -use crate::training_common_big::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}_model.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}_model.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - -pub fn demo_model>(agent: DQN) { - let mut env = TrictracEnvironment::new(true); - let mut done = false; - while !done { - // let action = match infer_action(&agent, &env, state) { - let action = match infer_action(&agent, &env) { - Some(value) => value, - None => break, - }; - // Execute action - let snapshot = env.step(action); - done = snapshot.done(); - } -} - -fn infer_action>( - agent: &DQN, - env: &TrictracEnvironment, -) -> Option { - let state = env.state(); - // Get q-values - let q_values = agent - .model() - .as_ref() - .unwrap() - .infer(state.to_tensor().unsqueeze()); - // Get valid actions - let valid_actions_indices = get_valid_action_indices(&env.game); - if valid_actions_indices.is_empty() { - return None; // No valid actions, end of episode - } - // Set non valid actions q-values to lowest - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions_indices.contains(&index) { - masked_q_values = masked_q_values.clone().mask_fill( - masked_q_values.clone().equal_elem(*q_value), - f32::NEG_INFINITY, - ); - } - } - // Get best action (highest q-value) - let action_index = 
masked_q_values.argmax(1).into_scalar().to_u32(); - let action = TrictracAction::from(action_index); - Some(action) -} - -fn soft_update_tensor( - this: &Param>, - that: &Param>, - tau: ElemType, -) -> Param> { - let that_weight = that.val(); - let this_weight = this.val(); - let new_weight = this_weight * (1.0 - tau) + that_weight * tau; - - Param::initialized(ParamId::new(), new_weight) -} - -pub fn soft_update_linear( - this: Linear, - that: &Linear, - tau: ElemType, -) -> Linear { - let weight = soft_update_tensor(&this.weight, &that.weight, tau); - let bias = match (&this.bias, &that.bias) { - (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), - _ => None, - }; - - Linear:: { weight, bias } -} diff --git a/bot/src/burnrl/dqn_valid/dqn_model.rs b/bot/src/burnrl/dqn_big_model.rs similarity index 70% rename from bot/src/burnrl/dqn_valid/dqn_model.rs rename to bot/src/burnrl/dqn_big_model.rs index 9d53a2f..7e8951f 100644 --- a/bot/src/burnrl/dqn_valid/dqn_model.rs +++ b/bot/src/burnrl/dqn_big_model.rs @@ -1,15 +1,16 @@ -use crate::burnrl::dqn_valid::utils::soft_update_linear; -use crate::burnrl::environment::TrictracEnvironment; +use crate::burnrl::environment_big::TrictracEnvironment; +use crate::burnrl::utils::{soft_update_linear, Config}; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::DQN; use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; -use std::fmt; +use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; use std::time::SystemTime; #[derive(Module, Debug)] @@ -62,66 +63,18 @@ impl DQNModel for Net { #[allow(unused)] const MEMORY_SIZE: usize = 8192; -pub struct DqnConfig { - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - pub eps_start: f64, - pub eps_end: f64, - pub eps_decay: f64, - - pub gamma: f32, - pub tau: f32, - pub learning_rate: f32, - pub batch_size: usize, - pub clip_grad: f32, -} - -impl fmt::Display for DqnConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("max_steps={:?}\n", self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("eps_start={:?}\n", self.eps_start)); - s.push_str(&format!("eps_end={:?}\n", self.eps_end)); - s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("tau={:?}\n", self.tau)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("batch_size={:?}\n", self.batch_size)); - s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); - write!(f, "{s}") - } -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - max_steps: 2000, - num_episodes: 1000, - dense_size: 256, - eps_start: 0.9, - eps_end: 0.05, - eps_decay: 1000.0, - - gamma: 0.999, - tau: 0.005, - learning_rate: 0.001, - batch_size: 32, - clip_grad: 100.0, - } - } -} - type MyAgent = DQN>; #[allow(unused)] -pub fn run, B: AutodiffBackend>( - conf: &DqnConfig, +// pub fn run, B: AutodiffBackend>( +pub fn run< + E: Environment + AsMut, + B: 
AutodiffBackend, +>( + conf: &Config, visualized: bool, -) -> DQN> { - // ) -> impl Agent { + // ) -> DQN> { +) -> impl Agent { let mut env = E::new(visualized); env.as_mut().max_steps = conf.max_steps; @@ -189,8 +142,13 @@ pub fn run, B: AutodiffBackend>( if snapshot.done() || episode_duration >= conf.max_steps { let envmut = env.as_mut(); + let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32) + * 100.0) + .round() as u32; println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}", + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + goodmoves_ratio, envmut.pointrolls_count, now.elapsed().unwrap().as_secs(), ); @@ -202,5 +160,35 @@ pub fn run, B: AutodiffBackend>( } } } - agent + let valid_agent = agent.valid(); + if let Some(path) = &conf.save_path { + save_model(valid_agent.model().as_ref().unwrap(), path); + } + valid_agent +} + +pub fn save_model(model: &Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() } diff --git a/bot/src/burnrl/dqn/dqn_model.rs b/bot/src/burnrl/dqn_model.rs similarity index 71% rename from bot/src/burnrl/dqn/dqn_model.rs rename to bot/src/burnrl/dqn_model.rs index 204cef0..efec37e 100644 --- a/bot/src/burnrl/dqn/dqn_model.rs +++ b/bot/src/burnrl/dqn_model.rs @@ -1,15 +1,16 @@ -use crate::burnrl::dqn::utils::soft_update_linear; use crate::burnrl::environment::TrictracEnvironment; +use crate::burnrl::utils::{soft_update_linear, Config}; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::DQN; use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; -use std::fmt; +use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; use std::time::SystemTime; #[derive(Module, Debug)] @@ -62,69 +63,18 @@ impl DQNModel for Net { #[allow(unused)] const MEMORY_SIZE: usize = 8192; -pub struct DqnConfig { - pub min_steps: f32, - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - pub eps_start: f64, - pub eps_end: f64, - pub eps_decay: f64, - - pub gamma: f32, - pub tau: f32, - pub learning_rate: f32, - pub batch_size: usize, - pub clip_grad: f32, -} - -impl fmt::Display for DqnConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("min_steps={:?}\n", self.min_steps)); - s.push_str(&format!("max_steps={:?}\n", 
self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("eps_start={:?}\n", self.eps_start)); - s.push_str(&format!("eps_end={:?}\n", self.eps_end)); - s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("tau={:?}\n", self.tau)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("batch_size={:?}\n", self.batch_size)); - s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); - write!(f, "{s}") - } -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - min_steps: 250.0, - max_steps: 2000, - num_episodes: 1000, - dense_size: 256, - eps_start: 0.9, - eps_end: 0.05, - eps_decay: 1000.0, - - gamma: 0.999, - tau: 0.005, - learning_rate: 0.001, - batch_size: 32, - clip_grad: 100.0, - } - } -} - type MyAgent = DQN>; #[allow(unused)] -pub fn run, B: AutodiffBackend>( - conf: &DqnConfig, +// pub fn run, B: AutodiffBackend>( +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( + conf: &Config, visualized: bool, -) -> DQN> { - // ) -> impl Agent { + // ) -> DQN> { +) -> impl Agent { let mut env = E::new(visualized); // env.as_mut().min_steps = conf.min_steps; env.as_mut().max_steps = conf.max_steps; @@ -203,7 +153,6 @@ pub fn run, B: AutodiffBackend>( envmut.pointrolls_count, now.elapsed().unwrap().as_secs(), ); - if goodmoves_ratio < 5 && 10 < episode {} env.reset(); episode_done = true; now = SystemTime::now(); @@ -212,5 +161,35 @@ pub fn run, B: AutodiffBackend>( } } } - agent + let valid_agent = agent.valid(); + if let Some(path) = &conf.save_path { + save_model(valid_agent.model().as_ref().unwrap(), path); + } + valid_agent +} + +pub fn save_model(model: &Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() } diff --git a/bot/src/burnrl/dqn_valid/main.rs b/bot/src/burnrl/dqn_valid/main.rs deleted file mode 100644 index b049372..0000000 --- a/bot/src/burnrl/dqn_valid/main.rs +++ /dev/null @@ -1,53 +0,0 @@ -use bot::burnrl::dqn_valid::{ - dqn_model, - utils::{demo_model, load_model, save_model}, -}; -use bot::burnrl::environment; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::DQN; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; - -fn main() { - // println!("> Entraînement"); - - // See also MEMORY_SIZE in dqn_model.rs : 8192 - let conf = dqn_model::DqnConfig { - // defaults - num_episodes: 100, // 40 - max_steps: 1000, // 1000 max steps by episode - dense_size: 256, // 128 neural network complexity (default 128) - eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) - eps_end: 0.05, // 0.05 - // eps_decay higher = epsilon decrease slower - // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); - // epsilon is updated at the start of each episode - 
eps_decay: 2000.0, // 1000 ? - - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. - clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) - }; - println!("{conf}----------"); - let agent = dqn_model::run::(&conf, false); //true); - - let valid_agent = agent.valid(); - - println!("> Sauvegarde du modèle de validation"); - - let path = "bot/models/burn_dqn_valid_40".to_string(); - save_model(valid_agent.model().as_ref().unwrap(), &path); - - println!("> Chargement du modèle pour test"); - let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); -} diff --git a/bot/src/burnrl/dqn_valid/mod.rs b/bot/src/burnrl/dqn_valid/mod.rs deleted file mode 100644 index 27fcc58..0000000 --- a/bot/src/burnrl/dqn_valid/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod dqn_model; -pub mod utils; diff --git a/bot/src/burnrl/dqn_valid/utils.rs b/bot/src/burnrl/dqn_valid/utils.rs deleted file mode 100644 index 2e87e2a..0000000 --- a/bot/src/burnrl/dqn_valid/utils.rs +++ /dev/null @@ -1,112 +0,0 @@ -use crate::burnrl::dqn_valid::dqn_model; -use crate::burnrl::environment_valid::{TrictracAction, TrictracEnvironment}; -use crate::training_common::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}_model.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}_model.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - -pub fn demo_model>(agent: DQN) { - let mut env = TrictracEnvironment::new(true); - let mut done = false; - while !done { - // let action = match infer_action(&agent, &env, state) { - let action = match infer_action(&agent, &env) { - Some(value) => value, - None => break, - }; - // Execute action - let snapshot = env.step(action); - done = snapshot.done(); - } -} - -fn infer_action>( - agent: &DQN, - env: &TrictracEnvironment, -) -> Option { - let state = env.state(); - // Get q-values - let q_values = agent - .model() - .as_ref() - .unwrap() - .infer(state.to_tensor().unsqueeze()); - // Get valid actions - let valid_actions_indices = get_valid_action_indices(&env.game); - if valid_actions_indices.is_empty() { - return None; // No valid actions, end of 
episode - } - // Set non valid actions q-values to lowest - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions_indices.contains(&index) { - masked_q_values = masked_q_values.clone().mask_fill( - masked_q_values.clone().equal_elem(*q_value), - f32::NEG_INFINITY, - ); - } - } - // Get best action (highest q-value) - let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); - let action = TrictracAction::from(action_index); - Some(action) -} - -fn soft_update_tensor( - this: &Param>, - that: &Param>, - tau: ElemType, -) -> Param> { - let that_weight = that.val(); - let this_weight = this.val(); - let new_weight = this_weight * (1.0 - tau) + that_weight * tau; - - Param::initialized(ParamId::new(), new_weight) -} - -pub fn soft_update_linear( - this: Linear, - that: &Linear, - tau: ElemType, -) -> Linear { - let weight = soft_update_tensor(&this.weight, &that.weight, tau); - let bias = match (&this.bias, &that.bias) { - (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), - _ => None, - }; - - Linear:: { weight, bias } -} diff --git a/bot/src/burnrl/dqn_big/dqn_model.rs b/bot/src/burnrl/dqn_valid_model.rs similarity index 67% rename from bot/src/burnrl/dqn_big/dqn_model.rs rename to bot/src/burnrl/dqn_valid_model.rs index 1ccafef..6198100 100644 --- a/bot/src/burnrl/dqn_big/dqn_model.rs +++ b/bot/src/burnrl/dqn_valid_model.rs @@ -1,15 +1,16 @@ -use crate::burnrl::dqn_big::utils::soft_update_linear; -use crate::burnrl::environment_big::TrictracEnvironment; +use crate::burnrl::environment_valid::TrictracEnvironment; +use crate::burnrl::utils::{soft_update_linear, Config}; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::DQN; use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; -use std::fmt; +use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; use std::time::SystemTime; #[derive(Module, Debug)] @@ -62,71 +63,19 @@ impl DQNModel for Net { #[allow(unused)] const MEMORY_SIZE: usize = 8192; -pub struct DqnConfig { - pub min_steps: f32, - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - pub eps_start: f64, - pub eps_end: f64, - pub eps_decay: f64, - - pub gamma: f32, - pub tau: f32, - pub learning_rate: f32, - pub batch_size: usize, - pub clip_grad: f32, -} - -impl fmt::Display for DqnConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("min_steps={:?}\n", self.min_steps)); - s.push_str(&format!("max_steps={:?}\n", self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("eps_start={:?}\n", self.eps_start)); - s.push_str(&format!("eps_end={:?}\n", self.eps_end)); - s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("tau={:?}\n", self.tau)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("batch_size={:?}\n", 
self.batch_size)); - s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); - write!(f, "{s}") - } -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - min_steps: 250.0, - max_steps: 2000, - num_episodes: 1000, - dense_size: 256, - eps_start: 0.9, - eps_end: 0.05, - eps_decay: 1000.0, - - gamma: 0.999, - tau: 0.005, - learning_rate: 0.001, - batch_size: 32, - clip_grad: 100.0, - } - } -} - type MyAgent = DQN>; #[allow(unused)] -pub fn run, B: AutodiffBackend>( - conf: &DqnConfig, +// pub fn run, B: AutodiffBackend>( +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( + conf: &Config, visualized: bool, -) -> DQN> { - // ) -> impl Agent { + // ) -> DQN> { +) -> impl Agent { let mut env = E::new(visualized); - env.as_mut().min_steps = conf.min_steps; env.as_mut().max_steps = conf.max_steps; let model = Net::::new( @@ -194,8 +143,7 @@ pub fn run, B: AutodiffBackend>( if snapshot.done() || episode_duration >= conf.max_steps { let envmut = env.as_mut(); println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", - envmut.goodmoves_count, + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}", envmut.pointrolls_count, now.elapsed().unwrap().as_secs(), ); @@ -207,5 +155,35 @@ pub fn run, B: AutodiffBackend>( } } } - agent + let valid_agent = agent.valid(); + if let Some(path) = &conf.save_path { + save_model(valid_agent.model().as_ref().unwrap(), path); + } + valid_agent +} + +pub fn save_model(model: &Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() } diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs index 1d8e80d..9805451 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -139,6 +139,7 @@ impl Environment for TrictracEnvironment { fn reset(&mut self) -> Snapshot { // Réinitialiser le jeu + let history = self.game.history.clone(); self.game = GameState::new(false); self.game.init_player("DQN Agent"); self.game.init_player("Opponent"); @@ -157,18 +158,18 @@ impl Environment for TrictracEnvironment { let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 { let path = "bot/models/logs/debug.log"; if let Ok(mut out) = std::fs::File::create(path) { - write!(out, "{:?}", self.game.history); + write!(out, "{:?}", history); } "!!!!" 
} else { "" }; - println!( - "info: correct moves: {} ({}%) {}", - self.goodmoves_count, - (100.0 * self.goodmoves_ratio).round() as u32, - warning - ); + // println!( + // "info: correct moves: {} ({}%) {}", + // self.goodmoves_count, + // (100.0 * self.goodmoves_ratio).round() as u32, + // warning + // ); self.step_count = 0; self.pointrolls_count = 0; self.goodmoves_count = 0; @@ -369,7 +370,7 @@ impl TrictracEnvironment { if self.game.validate(&dice_event) { self.game.consume(&dice_event); let (points, adv_points) = self.game.dice_points; - reward += REWARD_RATIO * (points - adv_points) as f32; + reward += REWARD_RATIO * (points as f32 - adv_points as f32); if points > 0 { is_rollpoint = true; // println!("info: rolled for {reward}"); @@ -479,7 +480,7 @@ impl TrictracEnvironment { PointsRules::new(&opponent_color, &self.game.board, self.game.dice); let (points, adv_points) = points_rules.get_points(dice_roll_count); // Récompense proportionnelle aux points - reward -= REWARD_RATIO * (points - adv_points) as f32; + reward -= REWARD_RATIO * (points as f32 - adv_points as f32); } } } diff --git a/bot/src/burnrl/environment_big.rs b/bot/src/burnrl/environment_big.rs index b362fc1..1bba2bd 100644 --- a/bot/src/burnrl/environment_big.rs +++ b/bot/src/burnrl/environment_big.rs @@ -89,7 +89,6 @@ pub struct TrictracEnvironment { current_state: TrictracState, episode_reward: f32, pub step_count: usize, - pub min_steps: f32, pub max_steps: usize, pub pointrolls_count: usize, pub goodmoves_count: usize, @@ -122,7 +121,6 @@ impl Environment for TrictracEnvironment { current_state, episode_reward: 0.0, step_count: 0, - min_steps: 250.0, max_steps: 2000, pointrolls_count: 0, goodmoves_count: 0, @@ -196,9 +194,10 @@ impl Environment for TrictracEnvironment { } // Vérifier si la partie est terminée - let max_steps = self.min_steps - + (self.max_steps as f32 - self.min_steps) - * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + // let max_steps = self.max_steps + // let max_steps = self.min_steps + // + (self.max_steps as f32 - self.min_steps) + // * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); if done { @@ -211,7 +210,7 @@ impl Environment for TrictracEnvironment { } } } - let terminated = done || self.step_count >= max_steps.round() as usize; + let terminated = done || self.step_count >= self.max_steps; // Mettre à jour l'état self.current_state = TrictracState::from_game_state(&self.game); diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs new file mode 100644 index 0000000..24759f0 --- /dev/null +++ b/bot/src/burnrl/main.rs @@ -0,0 +1,58 @@ +use bot::burnrl::sac_model as burn_model; +// use bot::burnrl::dqn_big_model as burn_model; +// use bot::burnrl::dqn_model as burn_model; +// use bot::burnrl::environment_big::TrictracEnvironment; +use bot::burnrl::environment::TrictracEnvironment; +use bot::burnrl::utils::{demo_model, Config}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::SAC as MyAgent; +// use burn_rl::agent::DQN as MyAgent; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = TrictracEnvironment; + +fn main() { + let path = "bot/models/burnrl_dqn".to_string(); + let conf = Config { + save_path: Some(path.clone()), + num_episodes: 30, // 40 + max_steps: 1000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + + gamma: 0.9999, // 0.999 discount factor. 
Plus élevé = encourage stratégies à long terme + tau: 0.0005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. + clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100) + + min_probability: 1e-9, + + eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + lambda: 0.95, + epsilon_clip: 0.2, + critic_weight: 0.5, + entropy_weight: 0.01, + epochs: 8, + }; + println!("{conf}----------"); + + let agent = burn_model::run::(&conf, false); //true); + + // println!("> Chargement du modèle pour test"); + // let loaded_model = burn_model::load_model(conf.dense_size, &path); + // let loaded_agent: MyAgent = MyAgent::new(loaded_model.unwrap()); + // + // println!("> Test avec le modèle chargé"); + // demo_model(loaded_agent); + + // demo_model::(agent); +} diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs index 13e2c8e..7b719ee 100644 --- a/bot/src/burnrl/mod.rs +++ b/bot/src/burnrl/mod.rs @@ -1,8 +1,9 @@ -pub mod dqn; -pub mod dqn_big; -pub mod dqn_valid; +pub mod dqn_big_model; +pub mod dqn_model; +pub mod dqn_valid_model; pub mod environment; pub mod environment_big; pub mod environment_valid; -pub mod ppo; -pub mod sac; +pub mod ppo_model; +pub mod sac_model; +pub mod utils; diff --git a/bot/src/burnrl/ppo/main.rs b/bot/src/burnrl/ppo/main.rs deleted file mode 100644 index 798c2aa..0000000 --- a/bot/src/burnrl/ppo/main.rs +++ /dev/null @@ -1,52 +0,0 @@ -use bot::burnrl::environment; -use bot::burnrl::ppo::{ - ppo_model, - utils::{demo_model, load_model, save_model}, -}; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::PPO; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; - -fn main() { - // println!("> Entraînement"); - - // See also MEMORY_SIZE in ppo_model.rs : 8192 - let conf = ppo_model::PpoConfig { - // defaults - num_episodes: 50, // 40 - max_steps: 1000, // 1000 max steps by episode - dense_size: 128, // 128 neural network complexity (default 128) - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
- clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) - - lambda: 0.95, - epsilon_clip: 0.2, - critic_weight: 0.5, - entropy_weight: 0.01, - epochs: 8, - }; - println!("{conf}----------"); - let valid_agent = ppo_model::run::(&conf, false); //true); - - // let valid_agent = agent.valid(model); - - println!("> Sauvegarde du modèle de validation"); - - let path = "bot/models/burnrl_ppo".to_string(); - panic!("how to do that : save model"); - // save_model(valid_agent.model().as_ref().unwrap(), &path); - - // println!("> Chargement du modèle pour test"); - // let loaded_model = load_model(conf.dense_size, &path); - // let loaded_agent = PPO::new(loaded_model.unwrap()); - // - // println!("> Test avec le modèle chargé"); - // demo_model(loaded_agent); -} diff --git a/bot/src/burnrl/ppo/mod.rs b/bot/src/burnrl/ppo/mod.rs deleted file mode 100644 index 1b442d8..0000000 --- a/bot/src/burnrl/ppo/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod ppo_model; -pub mod utils; diff --git a/bot/src/burnrl/ppo/utils.rs b/bot/src/burnrl/ppo/utils.rs deleted file mode 100644 index 9457217..0000000 --- a/bot/src/burnrl/ppo/utils.rs +++ /dev/null @@ -1,88 +0,0 @@ -use crate::burnrl::environment::{TrictracAction, TrictracEnvironment}; -use crate::burnrl::ppo::ppo_model; -use crate::training_common::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -use burn_rl::agent::{PPOModel, PPO}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &ppo_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - ppo_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - -pub fn demo_model>(agent: PPO) { - let mut env = TrictracEnvironment::new(true); - let mut done = false; - while !done { - // let action = match infer_action(&agent, &env, state) { - let action = match infer_action(&agent, &env) { - Some(value) => value, - None => break, - }; - // Execute action - let snapshot = env.step(action); - done = snapshot.done(); - } -} - -fn infer_action>( - agent: &PPO, - env: &TrictracEnvironment, -) -> Option { - let state = env.state(); - panic!("how to do that ?"); - None - // Get q-values - // let q_values = agent - // .model() - // .as_ref() - // .unwrap() - // .infer(state.to_tensor().unsqueeze()); - // // Get valid actions - // let valid_actions_indices = get_valid_action_indices(&env.game); - // if valid_actions_indices.is_empty() { - // return None; // No valid actions, end of episode - // } - // // Set non valid actions q-values to lowest - // let mut masked_q_values = q_values.clone(); - // let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - // for (index, q_value) in q_values_vec.iter().enumerate() { - // if !valid_actions_indices.contains(&index) { - // 
masked_q_values = masked_q_values.clone().mask_fill( - // masked_q_values.clone().equal_elem(*q_value), - // f32::NEG_INFINITY, - // ); - // } - // } - // // Get best action (highest q-value) - // let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); - // let action = TrictracAction::from(action_index); - // Some(action) -} diff --git a/bot/src/burnrl/ppo/ppo_model.rs b/bot/src/burnrl/ppo_model.rs similarity index 71% rename from bot/src/burnrl/ppo/ppo_model.rs rename to bot/src/burnrl/ppo_model.rs index dc0b5ca..8546b04 100644 --- a/bot/src/burnrl/ppo/ppo_model.rs +++ b/bot/src/burnrl/ppo_model.rs @@ -1,4 +1,5 @@ use crate::burnrl::environment::TrictracEnvironment; +use crate::burnrl::utils::Config; use burn::module::Module; use burn::nn::{Initializer, Linear, LinearConfig}; use burn::optim::AdamWConfig; @@ -7,7 +8,6 @@ use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use std::fmt; use std::time::SystemTime; #[derive(Module, Debug)] @@ -54,64 +54,11 @@ impl PPOModel for Net {} #[allow(unused)] const MEMORY_SIZE: usize = 512; -pub struct PpoConfig { - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - - pub gamma: f32, - pub lambda: f32, - pub epsilon_clip: f32, - pub critic_weight: f32, - pub entropy_weight: f32, - pub learning_rate: f32, - pub epochs: usize, - pub batch_size: usize, - pub clip_grad: f32, -} - -impl fmt::Display for PpoConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("max_steps={:?}\n", self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("lambda={:?}\n", self.lambda)); - s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip)); - s.push_str(&format!("critic_weight={:?}\n", self.critic_weight)); - s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("epochs={:?}\n", self.epochs)); - s.push_str(&format!("batch_size={:?}\n", self.batch_size)); - write!(f, "{s}") - } -} - -impl Default for PpoConfig { - fn default() -> Self { - Self { - max_steps: 2000, - num_episodes: 1000, - dense_size: 256, - - gamma: 0.99, - lambda: 0.95, - epsilon_clip: 0.2, - critic_weight: 0.5, - entropy_weight: 0.01, - learning_rate: 0.001, - epochs: 8, - batch_size: 8, - clip_grad: 100.0, - } - } -} type MyAgent = PPO>; #[allow(unused)] pub fn run, B: AutodiffBackend>( - conf: &PpoConfig, + conf: &Config, visualized: bool, // ) -> PPO> { ) -> impl Agent { @@ -179,6 +126,9 @@ pub fn run, B: AutodiffBackend>( memory.clear(); } - agent.valid(model) - // agent + let valid_agent = agent.valid(model); + if let Some(path) = &conf.save_path { + // save_model(???, path); + } + valid_agent } diff --git a/bot/src/burnrl/sac/main.rs b/bot/src/burnrl/sac/main.rs deleted file mode 100644 index 2f72c32..0000000 --- a/bot/src/burnrl/sac/main.rs +++ /dev/null @@ -1,45 +0,0 @@ -use bot::burnrl::environment; -use bot::burnrl::sac::{sac_model, utils::demo_model}; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::SAC; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; - -fn main() { - // println!("> 
Entraînement"); - - // See also MEMORY_SIZE in dqn_model.rs : 8192 - let conf = sac_model::SacConfig { - // defaults - num_episodes: 50, // 40 - max_steps: 1000, // 1000 max steps by episode - dense_size: 256, // 128 neural network complexity (default 128) - - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. - clip_grad: 1.0, // 1.0 limite max de correction à apporter au gradient - min_probability: 1e-9, - }; - println!("{conf}----------"); - let valid_agent = sac_model::run::(&conf, false); //true); - - // let valid_agent = agent.valid(); - - // println!("> Sauvegarde du modèle de validation"); - // - // let path = "bot/models/burnrl_dqn".to_string(); - // save_model(valid_agent.model().as_ref().unwrap(), &path); - // - // println!("> Chargement du modèle pour test"); - // let loaded_model = load_model(conf.dense_size, &path); - // let loaded_agent = DQN::new(loaded_model.unwrap()); - // - // println!("> Test avec le modèle chargé"); - // demo_model(loaded_agent); -} diff --git a/bot/src/burnrl/sac/mod.rs b/bot/src/burnrl/sac/mod.rs deleted file mode 100644 index 77e721a..0000000 --- a/bot/src/burnrl/sac/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod sac_model; -pub mod utils; diff --git a/bot/src/burnrl/sac/utils.rs b/bot/src/burnrl/sac/utils.rs deleted file mode 100644 index ac6059d..0000000 --- a/bot/src/burnrl/sac/utils.rs +++ /dev/null @@ -1,78 +0,0 @@ -use crate::burnrl::environment::{TrictracAction, TrictracEnvironment}; -use crate::burnrl::sac::sac_model; -use crate::training_common::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -// use burn_rl::agent::{SACModel, SAC}; -use burn_rl::base::{Agent, ElemType, Environment}; - -// pub fn save_model(model: &sac_model::Net>, path: &String) { -// let recorder = CompactRecorder::new(); -// let model_path = format!("{path}.mpk"); -// println!("Modèle de validation sauvegardé : {model_path}"); -// recorder -// .record(model.clone().into_record(), model_path.into()) -// .unwrap(); -// } -// -// pub fn load_model(dense_size: usize, path: &String) -> Option>> { -// let model_path = format!("{path}.mpk"); -// // println!("Chargement du modèle depuis : {model_path}"); -// -// CompactRecorder::new() -// .load(model_path.into(), &NdArrayDevice::default()) -// .map(|record| { -// dqn_model::Net::new( -// ::StateType::size(), -// dense_size, -// ::ActionType::size(), -// ) -// .load_record(record) -// }) -// .ok() -// } -// - -pub fn demo_model(agent: impl Agent) { - let mut env = E::new(true); - let mut state = env.state(); - let mut done = false; - while !done { - if let Some(action) = agent.react(&state) { - let snapshot = env.step(action); - state = *snapshot.state(); - done = snapshot.done(); - } - } -} - -fn soft_update_tensor( - this: &Param>, - that: &Param>, - tau: ElemType, -) -> Param> { - let that_weight = that.val(); - let this_weight = this.val(); - let new_weight = this_weight * (1.0 - tau) + that_weight 
* tau; - - Param::initialized(ParamId::new(), new_weight) -} - -pub fn soft_update_linear( - this: Linear, - that: &Linear, - tau: ElemType, -) -> Linear { - let weight = soft_update_tensor(&this.weight, &that.weight, tau); - let bias = match (&this.bias, &that.bias) { - (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), - _ => None, - }; - - Linear:: { weight, bias } -} diff --git a/bot/src/burnrl/sac/sac_model.rs b/bot/src/burnrl/sac_model.rs similarity index 80% rename from bot/src/burnrl/sac/sac_model.rs rename to bot/src/burnrl/sac_model.rs index 96b2e24..bc7c87d 100644 --- a/bot/src/burnrl/sac/sac_model.rs +++ b/bot/src/burnrl/sac_model.rs @@ -1,14 +1,15 @@ use crate::burnrl::environment::TrictracEnvironment; -use crate::burnrl::sac::utils::soft_update_linear; +use crate::burnrl::utils::{soft_update_linear, Config}; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::{relu, softmax}; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use std::fmt; use std::time::SystemTime; #[derive(Module, Debug)] @@ -92,57 +93,11 @@ impl SACCritic for Critic { #[allow(unused)] const MEMORY_SIZE: usize = 4096; -pub struct SacConfig { - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - - pub gamma: f32, - pub tau: f32, - pub learning_rate: f32, - pub batch_size: usize, - pub clip_grad: f32, - pub min_probability: f32, -} - -impl Default for SacConfig { - fn default() -> Self { - Self { - max_steps: 2000, - num_episodes: 1000, - dense_size: 32, - - gamma: 0.999, - tau: 0.005, - learning_rate: 0.001, - batch_size: 32, - clip_grad: 1.0, - min_probability: 1e-9, - } - } -} - -impl fmt::Display for SacConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("max_steps={:?}\n", self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("tau={:?}\n", self.tau)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("batch_size={:?}\n", self.batch_size)); - s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); - s.push_str(&format!("min_probability={:?}\n", self.min_probability)); - write!(f, "{s}") - } -} - type MyAgent = SAC>; #[allow(unused)] pub fn run, B: AutodiffBackend>( - conf: &SacConfig, + conf: &Config, visualized: bool, ) -> impl Agent { let mut env = E::new(visualized); @@ -229,5 +184,35 @@ pub fn run, B: AutodiffBackend>( } } - agent.valid(nets.actor) + let valid_agent = agent.valid(nets.actor); + if let Some(path) = &conf.save_path { + // save_model(???, path); + } + valid_agent } + +// pub fn save_model(model: ???, path: &String) { +// let recorder = CompactRecorder::new(); +// let model_path = format!("{path}.mpk"); +// println!("info: Modèle de validation sauvegardé : {model_path}"); +// recorder +// .record(model.clone().into_record(), model_path.into()) +// .unwrap(); +// } +// +// pub fn load_model(dense_size: usize, path: &String) -> Option>> { +// let model_path = format!("{path}.mpk"); +// // 
println!("Chargement du modèle depuis : {model_path}"); +// +// CompactRecorder::new() +// .load(model_path.into(), &NdArrayDevice::default()) +// .map(|record| { +// Actor::new( +// ::StateType::size(), +// dense_size, +// ::ActionType::size(), +// ) +// .load_record(record) +// }) +// .ok() +// } diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs new file mode 100644 index 0000000..21c6cec --- /dev/null +++ b/bot/src/burnrl/utils.rs @@ -0,0 +1,121 @@ +use burn::module::{Param, ParamId}; +use burn::nn::Linear; +use burn::tensor::backend::Backend; +use burn::tensor::Tensor; +use burn_rl::base::{Agent, ElemType, Environment}; + +pub struct Config { + pub save_path: Option, + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, + + // for SAC + pub min_probability: f32, + + // for DQN + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + // for PPO + pub lambda: f32, + pub epsilon_clip: f32, + pub critic_weight: f32, + pub entropy_weight: f32, + pub epochs: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + save_path: None, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + min_probability: 1e-9, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + lambda: 0.95, + epsilon_clip: 0.2, + critic_weight: 0.5, + entropy_weight: 0.01, + epochs: 8, + } + } +} + +impl std::fmt::Display for Config { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut s = String::new(); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + s.push_str(&format!("min_probability={:?}\n", self.min_probability)); + s.push_str(&format!("lambda={:?}\n", self.lambda)); + s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip)); + s.push_str(&format!("critic_weight={:?}\n", self.critic_weight)); + s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight)); + s.push_str(&format!("epochs={:?}\n", self.epochs)); + write!(f, "{s}") + } +} + +pub fn demo_model(agent: impl Agent) { + let mut env = E::new(true); + let mut state = env.state(); + let mut done = false; + while !done { + if let Some(action) = agent.react(&state) { + let snapshot = env.step(action); + state = *snapshot.state(); + done = snapshot.done(); + } + } +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: ElemType, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: ElemType, +) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), 
Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs index 3d25c2b..e513860 100644 --- a/bot/src/strategy/dqnburn.rs +++ b/bot/src/strategy/dqnburn.rs @@ -6,8 +6,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use log::info; use store::MoveRules; -use crate::burnrl::dqn::{dqn_model, utils}; +use crate::burnrl::dqn_model; use crate::burnrl::environment; +use crate::burnrl::utils; use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; type DqnBurnNetwork = dqn_model::Net>; @@ -40,7 +41,7 @@ impl DqnBurnStrategy { pub fn new_with_model(model_path: &String) -> Self { info!("Loading model {model_path:?}"); let mut strategy = Self::new(); - strategy.model = utils::load_model(256, model_path); + strategy.model = dqn_model::load_model(256, model_path); strategy }