feat: bot sac & ppo save & load

This commit is contained in:
Henri Bourcereau 2025-08-21 14:35:25 +02:00
parent afeb3561e0
commit 0c58490f87
8 changed files with 127 additions and 103 deletions

View file

@@ -155,10 +155,10 @@ impl Environment for TrictracEnvironment {
self.goodmoves_count as f32 / self.step_count as f32
};
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
let path = "bot/models/logs/debug.log";
if let Ok(mut out) = std::fs::File::create(path) {
write!(out, "{:?}", history);
write!(out, "{history:?}").expect("could not write history log");
}
"!!!!"
} else {

View file

@@ -29,8 +29,10 @@ fn main() {
batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100)
// SAC
min_probability: 1e-9,
// DQN
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower
@@ -38,6 +40,7 @@ fn main() {
// epsilon is updated at the start of each episode
eps_decay: 2000.0, // 1000 ?
// PPO
lambda: 0.95,
epsilon_clip: 0.2,
critic_weight: 0.5,
@@ -48,7 +51,7 @@ fn main() {
match algo.as_str() {
"dqn" => {
let agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false);
let _agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = dqn_model::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
@@ -58,23 +61,30 @@ fn main() {
demo_model(loaded_agent);
}
"dqn_big" => {
let agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false);
let _agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false);
}
"dqn_valid" => {
let agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false);
let _agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false);
}
"sac" => {
let agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false);
// println!("> Chargement du modèle pour test");
// let loaded_model = sac_model::load_model(conf.dense_size, &path);
// let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
// burn_rl::agent::SAC::new(loaded_model.unwrap());
//
// println!("> Test avec le modèle chargé");
// demo_model(loaded_agent);
let _agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = sac_model::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
burn_rl::agent::SAC::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"ppo" => {
let agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false);
let _agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo_model::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
&_ => {
dbg!("unknown algo {algo}");

View file

@@ -1,13 +1,17 @@
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime;
#[derive(Module, Debug)]
@@ -57,7 +61,10 @@ const MEMORY_SIZE: usize = 512;
type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)]
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> PPO<E, B, Net<B>> {
@@ -126,9 +133,61 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
memory.clear();
}
let valid_agent = agent.valid(model);
if let Some(path) = &conf.save_path {
// save_model(???, path);
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
}
let valid_agent = agent.valid(model);
valid_agent
}
/// Persists a trained PPO network to disk with burn's `CompactRecorder`.
///
/// The record is written to `<path>.mpk`. Takes `&str` rather than `&String`
/// (clippy `ptr_arg`); existing `&String` call sites still compile via deref
/// coercion.
///
/// # Panics
/// Panics if the record cannot be serialized to disk — a silent failure here
/// would make the whole training run unrecoverable.
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &str) {
    let recorder = CompactRecorder::new();
    let model_path = format!("{path}.mpk");
    recorder
        .record(model.clone().into_record(), model_path.clone().into())
        .expect("failed to save PPO model record");
    // Report success only after the record call returned Ok.
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads a previously saved PPO network from `<path>.mpk`.
///
/// Rebuilds a `Net` sized for `TrictracEnvironment` (state/action dims) with
/// the given `dense_size`, then restores the recorded weights into it.
/// `dense_size` must match the value used when the model was saved —
/// TODO confirm: a mismatch would surface as a record-loading error.
///
/// Returns `None` when the file is missing or cannot be deserialized, so the
/// caller can decide to train from scratch instead.
pub fn load_model(dense_size: usize, path: &str) -> Option<Net<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .map(|record| {
            Net::new(
                <TrictracEnvironment as Environment>::StateType::size(),
                dense_size,
                <TrictracEnvironment as Environment>::ActionType::size(),
            )
            .load_record(record)
        })
        .ok()
}

View file

@@ -96,7 +96,10 @@ const MEMORY_SIZE: usize = 4096;
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
#[allow(unused)]
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
) -> impl Agent<E> {
@@ -105,9 +108,9 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
let state_dim = <<E as Environment>::StateType as State>::size();
let action_dim = <<E as Environment>::ActionType as Action>::size();
let mut actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
let mut critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
let mut agent = MyAgent::default();
@@ -134,8 +137,6 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
optimizer_config.init(),
);
let mut policy_net = agent.model().clone();
let mut step = 0_usize;
for episode in 0..conf.num_episodes {
@@ -186,33 +187,35 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
let valid_agent = agent.valid(nets.actor);
if let Some(path) = &conf.save_path {
// save_model(???, path);
if let Some(model) = valid_agent.model() {
save_model(model, path);
}
}
valid_agent
}
// pub fn save_model(model: ???, path: &String) {
// let recorder = CompactRecorder::new();
// let model_path = format!("{path}.mpk");
// println!("info: Modèle de validation sauvegardé : {model_path}");
// recorder
// .record(model.clone().into_record(), model_path.into())
// .unwrap();
// }
//
// pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
// let model_path = format!("{path}.mpk");
// // println!("Chargement du modèle depuis : {model_path}");
//
// CompactRecorder::new()
// .load(model_path.into(), &NdArrayDevice::default())
// .map(|record| {
// Actor::new(
// <TrictracEnvironment as Environment>::StateType::size(),
// dense_size,
// <TrictracEnvironment as Environment>::ActionType::size(),
// )
// .load_record(record)
// })
// .ok()
// }
/// Persists a trained SAC actor network to disk with burn's `CompactRecorder`.
///
/// The record is written to `<path>.mpk`. Takes `&str` rather than `&String`
/// (clippy `ptr_arg`); existing `&String` call sites still compile via deref
/// coercion.
///
/// # Panics
/// Panics if the record cannot be serialized to disk — a silent failure here
/// would make the whole training run unrecoverable.
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &str) {
    let recorder = CompactRecorder::new();
    let model_path = format!("{path}.mpk");
    recorder
        .record(model.clone().into_record(), model_path.clone().into())
        .expect("failed to save SAC actor record");
    // Report success only after the record call returned Ok.
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads a previously saved SAC actor network from `<path>.mpk`.
///
/// Rebuilds an `Actor` sized for `TrictracEnvironment` (state/action dims)
/// with the given `dense_size`, then restores the recorded weights into it.
/// `dense_size` must match the value used when the model was saved —
/// TODO confirm: a mismatch would surface as a record-loading error.
///
/// Returns `None` when the file is missing or cannot be deserialized, so the
/// caller can decide to train from scratch instead.
pub fn load_model(dense_size: usize, path: &str) -> Option<Actor<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .map(|record| {
            Actor::new(
                <TrictracEnvironment as Environment>::StateType::size(),
                dense_size,
                <TrictracEnvironment as Environment>::ActionType::size(),
            )
            .load_record(record)
        })
        .ok()
}