From 0c58490f873c5ef39e0860f74624203f138f4b92 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 21 Aug 2025 14:35:25 +0200 Subject: [PATCH] feat: bot sac & ppo save & load --- bot/Cargo.toml | 20 ----------- bot/scripts/train.sh | 4 +-- bot/src/burnrl/environment.rs | 4 +-- bot/src/burnrl/main.rs | 34 +++++++++++------- bot/src/burnrl/ppo_model.rs | 65 +++++++++++++++++++++++++++++++-- bot/src/burnrl/sac_model.rs | 67 ++++++++++++++++++----------------- bot/src/strategy/dqnburn.rs | 1 - doc/refs/geminiQuestions.md | 35 +++--------------- 8 files changed, 127 insertions(+), 103 deletions(-) diff --git a/bot/Cargo.toml b/bot/Cargo.toml index c775179..2de6307 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -9,26 +9,6 @@ edition = "2021" name = "burn_train" path = "src/burnrl/main.rs" -[[bin]] -name = "train_dqn_burn_valid" -path = "src/burnrl/dqn_valid/main.rs" - -[[bin]] -name = "train_dqn_burn_big" -path = "src/burnrl/dqn_big/main.rs" - -[[bin]] -name = "train_dqn_burn" -path = "src/burnrl/dqn/main.rs" - -[[bin]] -name = "train_sac_burn" -path = "src/burnrl/sac/main.rs" - -[[bin]] -name = "train_ppo_burn" -path = "src/burnrl/ppo/main.rs" - [[bin]] name = "train_dqn_simple" path = "src/dqn_simple/main.rs" diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index b9f7f2a..a9f5e81 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -3,8 +3,8 @@ ROOT="$(cd "$(dirname "$0")" && pwd)/../.." LOGS_DIR="$ROOT/bot/models/logs" -CFG_SIZE=18 -ALGO="dqn" +CFG_SIZE=17 +ALGO="sac" BINBOT=burn_train # BINBOT=train_ppo_burn # BINBOT=train_dqn_burn diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs index 9805451..c74cf64 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -155,10 +155,10 @@ impl Environment for TrictracEnvironment { self.goodmoves_count as f32 / self.step_count as f32 }; self.best_ratio = self.best_ratio.max(self.goodmoves_ratio); - let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 { + let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 { let path = "bot/models/logs/debug.log"; if let Ok(mut out) = std::fs::File::create(path) { - write!(out, "{:?}", history); + write!(out, "{history:?}").expect("could not write history log"); } "!!!!" } else { diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index a911e06..ce76b4d 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -29,8 +29,10 @@ fn main() { batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100) + // SAC min_probability: 1e-9, + // DQN eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) eps_end: 0.05, // 0.05 // eps_decay higher = epsilon decrease slower @@ -38,6 +40,7 @@ fn main() { // epsilon is updated at the start of each episode eps_decay: 2000.0, // 1000 ? + // PPO lambda: 0.95, epsilon_clip: 0.2, critic_weight: 0.5, @@ -48,7 +51,7 @@ fn main() { match algo.as_str() { "dqn" => { - let agent = dqn_model::run::(&conf, false); + let _agent = dqn_model::run::(&conf, false); println!("> Chargement du modèle pour test"); let loaded_model = dqn_model::load_model(conf.dense_size, &path); let loaded_agent: burn_rl::agent::DQN = @@ -58,23 +61,30 @@ fn main() { demo_model(loaded_agent); } "dqn_big" => { - let agent = dqn_big_model::run::(&conf, false); + let _agent = dqn_big_model::run::(&conf, false); } "dqn_valid" => { - let agent = dqn_valid_model::run::(&conf, false); + let _agent = dqn_valid_model::run::(&conf, false); } "sac" => { - let agent = sac_model::run::(&conf, false); - // println!("> Chargement du modèle pour test"); - // let loaded_model = sac_model::load_model(conf.dense_size, &path); - // let loaded_agent: burn_rl::agent::SAC = - // burn_rl::agent::SAC::new(loaded_model.unwrap()); - // - // println!("> Test avec le modèle chargé"); - // demo_model(loaded_agent); + let _agent = sac_model::run::(&conf, false); + println!("> Chargement du modèle pour test"); + let loaded_model = sac_model::load_model(conf.dense_size, &path); + let loaded_agent: burn_rl::agent::SAC = + burn_rl::agent::SAC::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); } "ppo" => { - let agent = ppo_model::run::(&conf, false); + let _agent = ppo_model::run::(&conf, false); + println!("> Chargement du modèle pour test"); + let loaded_model = ppo_model::load_model(conf.dense_size, &path); + let loaded_agent: burn_rl::agent::PPO = + burn_rl::agent::PPO::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); } &_ => { dbg!("unknown algo {algo}"); diff --git a/bot/src/burnrl/ppo_model.rs b/bot/src/burnrl/ppo_model.rs index 8546b04..ea0b055 100644 --- a/bot/src/burnrl/ppo_model.rs +++ b/bot/src/burnrl/ppo_model.rs @@ -1,13 +1,17 @@ use crate::burnrl::environment::TrictracEnvironment; use crate::burnrl::utils::Config; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::Module; use burn::nn::{Initializer, Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::{relu, softmax}; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; +use std::env; +use std::fs; use std::time::SystemTime; #[derive(Module, Debug)] @@ -57,7 +61,10 @@ const MEMORY_SIZE: usize = 512; type MyAgent = PPO>; #[allow(unused)] -pub fn run, B: AutodiffBackend>( +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( conf: &Config, visualized: bool, // ) -> PPO> { @@ -126,9 +133,61 @@ pub fn run, B: AutodiffBackend>( memory.clear(); } - let valid_agent = agent.valid(model); if let Some(path) = &conf.save_path { - // save_model(???, path); + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + let tmp_path = env::temp_dir().join("tmp_model.mpk"); + + // Save the trained model (backend B) to a temporary file + recorder + .record(model.clone().into_record(), tmp_path.clone()) + .expect("Failed to save temporary model"); + + // Create a new model instance with the target backend (NdArray) + let model_to_save: Net> = Net::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + // Load the record from the temporary file into the new model + let record = recorder + .load(tmp_path.clone(), &device) + .expect("Failed to load temporary model"); + let model_with_loaded_weights = model_to_save.load_record(record); + + // Clean up the temporary file + fs::remove_file(tmp_path).expect("Failed to remove temporary model file"); + + save_model(&model_with_loaded_weights, path); } + let valid_agent = agent.valid(model); valid_agent } + +pub fn save_model(model: &Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + diff --git a/bot/src/burnrl/sac_model.rs b/bot/src/burnrl/sac_model.rs index bc7c87d..67db72a 100644 --- a/bot/src/burnrl/sac_model.rs +++ b/bot/src/burnrl/sac_model.rs @@ -96,7 +96,10 @@ const MEMORY_SIZE: usize = 4096; type MyAgent = SAC>; #[allow(unused)] -pub fn run, B: AutodiffBackend>( +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( conf: &Config, visualized: bool, ) -> impl Agent { @@ -105,9 +108,9 @@ pub fn run, B: AutodiffBackend>( let state_dim = <::StateType as State>::size(); let action_dim = <::ActionType as Action>::size(); - let mut actor = Actor::::new(state_dim, conf.dense_size, action_dim); - let mut critic_1 = Critic::::new(state_dim, conf.dense_size, action_dim); - let mut critic_2 = Critic::::new(state_dim, conf.dense_size, action_dim); + let actor = Actor::::new(state_dim, conf.dense_size, action_dim); + let critic_1 = Critic::::new(state_dim, conf.dense_size, action_dim); + let critic_2 = Critic::::new(state_dim, conf.dense_size, action_dim); let mut nets = SACNets::, Critic>::new(actor, critic_1, critic_2); let mut agent = MyAgent::default(); @@ -134,8 +137,6 @@ pub fn run, B: AutodiffBackend>( optimizer_config.init(), ); - let mut policy_net = agent.model().clone(); - let mut step = 0_usize; for episode in 0..conf.num_episodes { @@ -186,33 +187,35 @@ pub fn run, B: AutodiffBackend>( let valid_agent = agent.valid(nets.actor); if let Some(path) = &conf.save_path { - // save_model(???, path); + if let Some(model) = valid_agent.model() { + save_model(model, path); + } } valid_agent } -// pub fn save_model(model: ???, path: &String) { -// let recorder = CompactRecorder::new(); -// let model_path = format!("{path}.mpk"); -// println!("info: Modèle de validation sauvegardé : {model_path}"); -// recorder -// .record(model.clone().into_record(), model_path.into()) -// .unwrap(); -// } -// -// pub fn load_model(dense_size: usize, path: &String) -> Option>> { -// let model_path = format!("{path}.mpk"); -// // println!("Chargement du modèle depuis : {model_path}"); -// -// CompactRecorder::new() -// .load(model_path.into(), &NdArrayDevice::default()) -// .map(|record| { -// Actor::new( -// ::StateType::size(), -// dense_size, -// ::ActionType::size(), -// ) -// .load_record(record) -// }) -// .ok() -// } +pub fn save_model(model: &Actor>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Actor::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} \ No newline at end of file diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs index e513860..1f317d0 100644 --- a/bot/src/strategy/dqnburn.rs +++ b/bot/src/strategy/dqnburn.rs @@ -8,7 +8,6 @@ use store::MoveRules; use crate::burnrl::dqn_model; use crate::burnrl::environment; -use crate::burnrl::utils; use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; type DqnBurnNetwork = dqn_model::Net>; diff --git a/doc/refs/geminiQuestions.md b/doc/refs/geminiQuestions.md index 46c33d8..2801fe2 100644 --- a/doc/refs/geminiQuestions.md +++ b/doc/refs/geminiQuestions.md @@ -1,4 +1,4 @@ -# Description du projet et question +# Description du projet Je développe un jeu de TricTrac () dans le langage rust. Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions. @@ -12,35 +12,8 @@ Plus précisément, l'état du jeu est défini par le struct GameState dans stor 'bot/src/strategy/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste. Je cherche maintenant à ajouter des stratégies de bot plus fortes en entrainant un agent/bot par reinforcement learning. +J'utilise la bibliothèque burn (). -Une première version avec DQN fonctionne (entraînement avec `cargo run -bin=train_dqn`) -Il gagne systématiquement contre le bot par défaut 'dummy' : `cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy`. +Une version utilisant l'algorithme DQN peut être lancée avec `cargo run --bin=burn_train -- dqn`). Elle effectue un entraînement, sauvegarde les données du modèle obtenu puis recharge le modèle depuis le disque pour tester l'agent. L'entraînement est fait dans la fonction 'run' du fichier bot/src/burnrl/dqn_model.rs, la sauvegarde du modèle dans la fonction 'save_model' et le chargement dans la fonction 'load_model'. -Une version, toujours DQN, mais en utilisant la bibliothèque burn () est en cours de développement. - -L'entraînement du modèle se passe dans la fonction "main" du fichier bot/src/burnrl/main.rs. On peut lancer l'exécution avec 'just trainbot'. - -Voici la sortie de l'entraînement lancé avec 'just trainbot' : - -``` -> Entraînement -> {"episode": 0, "reward": -1692.3148, "duration": 1000} -> {"episode": 1, "reward": -361.6962, "duration": 1000} -> {"episode": 2, "reward": -126.1013, "duration": 1000} -> {"episode": 3, "reward": -36.8000, "duration": 1000} -> {"episode": 4, "reward": -21.4997, "duration": 1000} -> {"episode": 5, "reward": -8.3000, "duration": 1000} -> {"episode": 6, "reward": 3.1000, "duration": 1000} -> {"episode": 7, "reward": -21.5998, "duration": 1000} -> {"episode": 8, "reward": -10.1999, "duration": 1000} -> {"episode": 9, "reward": 3.1000, "duration": 1000} -> {"episode": 10, "reward": 14.5002, "duration": 1000} -> {"episode": 11, "reward": 10.7000, "duration": 1000} -> {"episode": 12, "reward": -0.7000, "duration": 1000} - -thread 'main' has overflowed its stack -fatal runtime error: stack overflow -error: Recipe `trainbot` was terminated on line 25 by signal 6 -``` - -Au bout du 12ème épisode (plus de 6 heures sur ma machine), l'entraînement s'arrête avec une erreur stack overlow. Peux-tu m'aider à diagnostiquer d'où peut provenir le problème ? Y a-t-il des outils qui permettent de détecter les zones de code qui utilisent le plus la stack ? Pour information j'ai vu ce rapport de bug , donc peut-être que le problème vient du paquet 'burl-rl'. +J'essaie de faire l'équivalent avec les algorithmes PPO (fichier bot/src/burnrl/ppo_model.rs) et SAC (fichier bot/src/burnrl/sac_model.rs) : les fonctions 'run' sont implémentées mais pas les fonctions 'save_model' et 'load_model'. Peux-tu les implémenter ?