feat: bot sac & ppo save & load

This commit is contained in:
Henri Bourcereau 2025-08-21 14:35:25 +02:00
parent afeb3561e0
commit 0c58490f87
8 changed files with 127 additions and 103 deletions

View file

@ -9,26 +9,6 @@ edition = "2021"
name = "burn_train" name = "burn_train"
path = "src/burnrl/main.rs" path = "src/burnrl/main.rs"
[[bin]]
name = "train_dqn_burn_valid"
path = "src/burnrl/dqn_valid/main.rs"
[[bin]]
name = "train_dqn_burn_big"
path = "src/burnrl/dqn_big/main.rs"
[[bin]]
name = "train_dqn_burn"
path = "src/burnrl/dqn/main.rs"
[[bin]]
name = "train_sac_burn"
path = "src/burnrl/sac/main.rs"
[[bin]]
name = "train_ppo_burn"
path = "src/burnrl/ppo/main.rs"
[[bin]] [[bin]]
name = "train_dqn_simple" name = "train_dqn_simple"
path = "src/dqn_simple/main.rs" path = "src/dqn_simple/main.rs"

View file

@ -3,8 +3,8 @@
ROOT="$(cd "$(dirname "$0")" && pwd)/../.." ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
LOGS_DIR="$ROOT/bot/models/logs" LOGS_DIR="$ROOT/bot/models/logs"
CFG_SIZE=18 CFG_SIZE=17
ALGO="dqn" ALGO="sac"
BINBOT=burn_train BINBOT=burn_train
# BINBOT=train_ppo_burn # BINBOT=train_ppo_burn
# BINBOT=train_dqn_burn # BINBOT=train_dqn_burn

View file

@ -155,10 +155,10 @@ impl Environment for TrictracEnvironment {
self.goodmoves_count as f32 / self.step_count as f32 self.goodmoves_count as f32 / self.step_count as f32
}; };
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio); self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 { let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
let path = "bot/models/logs/debug.log"; let path = "bot/models/logs/debug.log";
if let Ok(mut out) = std::fs::File::create(path) { if let Ok(mut out) = std::fs::File::create(path) {
write!(out, "{:?}", history); write!(out, "{history:?}").expect("could not write history log");
} }
"!!!!" "!!!!"
} else { } else {

View file

@ -29,8 +29,10 @@ fn main() {
batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100) clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100)
// SAC
min_probability: 1e-9, min_probability: 1e-9,
// DQN
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05 eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower // eps_decay higher = epsilon decrease slower
@ -38,6 +40,7 @@ fn main() {
// epsilon is updated at the start of each episode // epsilon is updated at the start of each episode
eps_decay: 2000.0, // 1000 ? eps_decay: 2000.0, // 1000 ?
// PPO
lambda: 0.95, lambda: 0.95,
epsilon_clip: 0.2, epsilon_clip: 0.2,
critic_weight: 0.5, critic_weight: 0.5,
@ -48,7 +51,7 @@ fn main() {
match algo.as_str() { match algo.as_str() {
"dqn" => { "dqn" => {
let agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false); let _agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test"); println!("> Chargement du modèle pour test");
let loaded_model = dqn_model::load_model(conf.dense_size, &path); let loaded_model = dqn_model::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> = let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
@ -58,23 +61,30 @@ fn main() {
demo_model(loaded_agent); demo_model(loaded_agent);
} }
"dqn_big" => { "dqn_big" => {
let agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false); let _agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false);
} }
"dqn_valid" => { "dqn_valid" => {
let agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false); let _agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false);
} }
"sac" => { "sac" => {
let agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false); let _agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false);
// println!("> Chargement du modèle pour test"); println!("> Chargement du modèle pour test");
// let loaded_model = sac_model::load_model(conf.dense_size, &path); let loaded_model = sac_model::load_model(conf.dense_size, &path);
// let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> = let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
// burn_rl::agent::SAC::new(loaded_model.unwrap()); burn_rl::agent::SAC::new(loaded_model.unwrap());
//
// println!("> Test avec le modèle chargé"); println!("> Test avec le modèle chargé");
// demo_model(loaded_agent); demo_model(loaded_agent);
} }
"ppo" => { "ppo" => {
let agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false); let _agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo_model::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
} }
&_ => { &_ => {
dbg!("unknown algo {algo}"); dbg!("unknown algo {algo}");

View file

@ -1,13 +1,17 @@
use crate::burnrl::environment::TrictracEnvironment; use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::Config; use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module; use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig}; use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig; use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax}; use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor; use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO}; use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime; use std::time::SystemTime;
#[derive(Module, Debug)] #[derive(Module, Debug)]
@ -57,7 +61,10 @@ const MEMORY_SIZE: usize = 512;
type MyAgent<E, B> = PPO<E, B, Net<B>>; type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)] #[allow(unused)]
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>( pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config, conf: &Config,
visualized: bool, visualized: bool,
// ) -> PPO<E, B, Net<B>> { // ) -> PPO<E, B, Net<B>> {
@ -126,9 +133,61 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
memory.clear(); memory.clear();
} }
let valid_agent = agent.valid(model);
if let Some(path) = &conf.save_path { if let Some(path) = &conf.save_path {
// save_model(???, path); let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
} }
let valid_agent = agent.valid(model);
valid_agent valid_agent
} }
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}

View file

@ -96,7 +96,10 @@ const MEMORY_SIZE: usize = 4096;
type MyAgent<E, B> = SAC<E, B, Actor<B>>; type MyAgent<E, B> = SAC<E, B, Actor<B>>;
#[allow(unused)] #[allow(unused)]
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>( pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config, conf: &Config,
visualized: bool, visualized: bool,
) -> impl Agent<E> { ) -> impl Agent<E> {
@ -105,9 +108,9 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
let state_dim = <<E as Environment>::StateType as State>::size(); let state_dim = <<E as Environment>::StateType as State>::size();
let action_dim = <<E as Environment>::ActionType as Action>::size(); let action_dim = <<E as Environment>::ActionType as Action>::size();
let mut actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim); let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
let mut critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim); let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim); let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2); let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
let mut agent = MyAgent::default(); let mut agent = MyAgent::default();
@ -134,8 +137,6 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
optimizer_config.init(), optimizer_config.init(),
); );
let mut policy_net = agent.model().clone();
let mut step = 0_usize; let mut step = 0_usize;
for episode in 0..conf.num_episodes { for episode in 0..conf.num_episodes {
@ -186,33 +187,35 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
let valid_agent = agent.valid(nets.actor); let valid_agent = agent.valid(nets.actor);
if let Some(path) = &conf.save_path { if let Some(path) = &conf.save_path {
// save_model(???, path); if let Some(model) = valid_agent.model() {
save_model(model, path);
}
} }
valid_agent valid_agent
} }
// pub fn save_model(model: ???, path: &String) { pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
// let recorder = CompactRecorder::new(); let recorder = CompactRecorder::new();
// let model_path = format!("{path}.mpk"); let model_path = format!("{path}.mpk");
// println!("info: Modèle de validation sauvegardé : {model_path}"); println!("info: Modèle de validation sauvegardé : {model_path}");
// recorder recorder
// .record(model.clone().into_record(), model_path.into()) .record(model.clone().into_record(), model_path.into())
// .unwrap(); .unwrap();
// } }
//
// pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> { pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
// let model_path = format!("{path}.mpk"); let model_path = format!("{path}.mpk");
// // println!("Chargement du modèle depuis : {model_path}"); // println!("Chargement du modèle depuis : {model_path}");
//
// CompactRecorder::new() CompactRecorder::new()
// .load(model_path.into(), &NdArrayDevice::default()) .load(model_path.into(), &NdArrayDevice::default())
// .map(|record| { .map(|record| {
// Actor::new( Actor::new(
// <TrictracEnvironment as Environment>::StateType::size(), <TrictracEnvironment as Environment>::StateType::size(),
// dense_size, dense_size,
// <TrictracEnvironment as Environment>::ActionType::size(), <TrictracEnvironment as Environment>::ActionType::size(),
// ) )
// .load_record(record) .load_record(record)
// }) })
// .ok() .ok()
// } }

View file

@ -8,7 +8,6 @@ use store::MoveRules;
use crate::burnrl::dqn_model; use crate::burnrl::dqn_model;
use crate::burnrl::environment; use crate::burnrl::environment;
use crate::burnrl::utils;
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>; type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;

View file

@ -1,4 +1,4 @@
# Description du projet et question # Description du projet
Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust. Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust.
Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions. Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions.
@ -12,35 +12,8 @@ Plus précisément, l'état du jeu est défini par le struct GameState dans stor
'bot/src/strategy/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste. 'bot/src/strategy/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste.
Je cherche maintenant à ajouter des stratégies de bot plus fortes en entrainant un agent/bot par reinforcement learning. Je cherche maintenant à ajouter des stratégies de bot plus fortes en entrainant un agent/bot par reinforcement learning.
J'utilise la bibliothèque burn (<https://burn.dev/>).
Une première version avec DQN fonctionne (entraînement avec `cargo run -bin=train_dqn`) Une version utilisant l'algorithme DQN peut être lancée avec `cargo run --bin=burn_train -- dqn`). Elle effectue un entraînement, sauvegarde les données du modèle obtenu puis recharge le modèle depuis le disque pour tester l'agent. L'entraînement est fait dans la fonction 'run' du fichier bot/src/burnrl/dqn_model.rs, la sauvegarde du modèle dans la fonction 'save_model' et le chargement dans la fonction 'load_model'.
Il gagne systématiquement contre le bot par défaut 'dummy' : `cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy`.
Une version, toujours DQN, mais en utilisant la bibliothèque burn (<https://burn.dev/>) est en cours de développement. J'essaie de faire l'équivalent avec les algorithmes PPO (fichier bot/src/burnrl/ppo_model.rs) et SAC (fichier bot/src/burnrl/sac_model.rs) : les fonctions 'run' sont implémentées mais pas les fonctions 'save_model' et 'load_model'. Peux-tu les implémenter ?
L'entraînement du modèle se passe dans la fonction "main" du fichier bot/src/burnrl/main.rs. On peut lancer l'exécution avec 'just trainbot'.
Voici la sortie de l'entraînement lancé avec 'just trainbot' :
```
> Entraînement
> {"episode": 0, "reward": -1692.3148, "duration": 1000}
> {"episode": 1, "reward": -361.6962, "duration": 1000}
> {"episode": 2, "reward": -126.1013, "duration": 1000}
> {"episode": 3, "reward": -36.8000, "duration": 1000}
> {"episode": 4, "reward": -21.4997, "duration": 1000}
> {"episode": 5, "reward": -8.3000, "duration": 1000}
> {"episode": 6, "reward": 3.1000, "duration": 1000}
> {"episode": 7, "reward": -21.5998, "duration": 1000}
> {"episode": 8, "reward": -10.1999, "duration": 1000}
> {"episode": 9, "reward": 3.1000, "duration": 1000}
> {"episode": 10, "reward": 14.5002, "duration": 1000}
> {"episode": 11, "reward": 10.7000, "duration": 1000}
> {"episode": 12, "reward": -0.7000, "duration": 1000}
thread 'main' has overflowed its stack
fatal runtime error: stack overflow
error: Recipe `trainbot` was terminated on line 25 by signal 6
```
Au bout du 12ème épisode (plus de 6 heures sur ma machine), l'entraînement s'arrête avec une erreur stack overlow. Peux-tu m'aider à diagnostiquer d'où peut provenir le problème ? Y a-t-il des outils qui permettent de détecter les zones de code qui utilisent le plus la stack ? Pour information j'ai vu ce rapport de bug <https://github.com/yunjhongwu/burn-rl-examples/issues/40> , donc peut-être que le problème vient du paquet 'burl-rl'.