diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 5df0623..ecda4d0 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -17,6 +17,10 @@ path = "src/burnrl/dqn_big/main.rs"
 name = "train_dqn_burn"
 path = "src/burnrl/dqn/main.rs"
 
+[[bin]]
+name = "train_sac_burn"
+path = "src/burnrl/sac/main.rs"
+
 [[bin]]
 name = "train_ppo_burn"
 path = "src/burnrl/ppo/main.rs"
diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh
index d310bbe..4c02189 100755
--- a/bot/scripts/train.sh
+++ b/bot/scripts/train.sh
@@ -4,7 +4,8 @@
 ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
 LOGS_DIR="$ROOT/bot/models/logs"
 CFG_SIZE=12
-BINBOT=train_ppo_burn
+BINBOT=train_sac_burn
+# BINBOT=train_ppo_burn
 # BINBOT=train_dqn_burn
 # BINBOT=train_dqn_burn_big
 # BINBOT=train_dqn_burn_before
diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs
index 0afacb4..13e2c8e 100644
--- a/bot/src/burnrl/mod.rs
+++ b/bot/src/burnrl/mod.rs
@@ -4,3 +4,5 @@ pub mod dqn_valid;
 pub mod environment;
 pub mod environment_big;
 pub mod environment_valid;
+pub mod ppo;
+pub mod sac;
diff --git a/bot/src/burnrl/ppo/main.rs b/bot/src/burnrl/ppo/main.rs
index 3633e29..798c2aa 100644
--- a/bot/src/burnrl/ppo/main.rs
+++ b/bot/src/burnrl/ppo/main.rs
@@ -13,18 +13,18 @@ type Env = environment::TrictracEnvironment;
 fn main() {
     // println!("> Entraînement");
 
-    // See also MEMORY_SIZE in dqn_model.rs : 8192
+    // See also MEMORY_SIZE in ppo_model.rs : 8192
     let conf = ppo_model::PpoConfig {
         // defaults
         num_episodes: 50, // 40
         max_steps: 1000,  // 1000 max steps by episode
-        dense_size: 256,  // 128 neural network complexity (default 128)
-        gamma: 0.9999,    // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
+        dense_size: 128,  // 128 neural network complexity (default 128)
+        gamma: 0.999,     // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
         // plus lente moins sensible aux coups de chance
         learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
         // converger
         batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
-        clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100)
+        clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
 
         lambda: 0.95,
         epsilon_clip: 0.2,
diff --git a/bot/src/burnrl/sac/main.rs b/bot/src/burnrl/sac/main.rs
new file mode 100644
index 0000000..2f72c32
--- /dev/null
+++ b/bot/src/burnrl/sac/main.rs
@@ -0,0 +1,45 @@
+use bot::burnrl::environment;
+use bot::burnrl::sac::{sac_model, utils::demo_model};
+use burn::backend::{Autodiff, NdArray};
+use burn_rl::agent::SAC;
+use burn_rl::base::ElemType;
+
+type Backend = Autodiff<NdArray<ElemType>>;
+type Env = environment::TrictracEnvironment;
+
+fn main() {
+    // println!("> Entraînement");
+
+    // See also MEMORY_SIZE in sac_model.rs : 4096
+    let conf = sac_model::SacConfig {
+        // defaults
+        num_episodes: 50, // 40
+        max_steps: 1000,  // 1000 max steps by episode
+        dense_size: 256,  // 128 neural network complexity (default 128)
+
+        gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
+        tau: 0.005,   // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
+        // plus lente moins sensible aux coups de chance
+        learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
+        // converger
+        batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
+        clip_grad: 1.0, // 1.0 limite max de correction à apporter au gradient
+        min_probability: 1e-9,
+    };
+    println!("{conf}----------");
+    let valid_agent = sac_model::run::<Env, Backend>(&conf, false); //true);
+
+    // let valid_agent = agent.valid();
+
+    // println!("> Sauvegarde du modèle de validation");
+    //
+    // let path = "bot/models/burnrl_dqn".to_string();
+    // save_model(valid_agent.model().as_ref().unwrap(), &path);
+    //
+    // println!("> Chargement du modèle pour test");
+    // let loaded_model = load_model(conf.dense_size, &path);
+    // let loaded_agent = DQN::new(loaded_model.unwrap());
+    //
+    // println!("> Test avec le modèle chargé");
+    // demo_model(loaded_agent);
+}
diff --git a/bot/src/burnrl/sac/mod.rs b/bot/src/burnrl/sac/mod.rs
new file mode 100644
index 0000000..77e721a
--- /dev/null
+++ b/bot/src/burnrl/sac/mod.rs
@@ -0,0 +1,2 @@
+pub mod sac_model;
+pub mod utils;
diff --git a/bot/src/burnrl/sac/sac_model.rs b/bot/src/burnrl/sac/sac_model.rs
new file mode 100644
index 0000000..96b2e24
--- /dev/null
+++ b/bot/src/burnrl/sac/sac_model.rs
@@ -0,0 +1,233 @@
+use crate::burnrl::environment::TrictracEnvironment;
+use crate::burnrl::sac::utils::soft_update_linear;
+use burn::module::Module;
+use burn::nn::{Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::tensor::activation::{relu, softmax};
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use std::fmt;
+use std::time::SystemTime;
+
+#[derive(Module, Debug)]
+pub struct Actor<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Actor<B> {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        softmax(self.linear_2.forward(layer_1_output), 1)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> SACActor<B> for Actor<B> {}
+
+#[derive(Module, Debug)]
+pub struct Critic<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Critic<B> {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+
+    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
+        (self.linear_0, self.linear_1, self.linear_2)
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        self.linear_2.forward(layer_1_output)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> SACCritic<B> for Critic<B> {
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
+
+        Self {
+            linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
+            linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
+            linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
+        }
+    }
+}
+
+#[allow(unused)]
+const MEMORY_SIZE: usize = 4096;
+
+pub struct SacConfig {
+    pub max_steps: usize,
+    pub num_episodes: usize,
+    pub dense_size: usize,
+
+    pub gamma: f32,
+    pub tau: f32,
+    pub learning_rate: f32,
+    pub batch_size: usize,
+    pub clip_grad: f32,
+    pub min_probability: f32,
+}
+
+impl Default for SacConfig {
+    fn default() -> Self {
+        Self {
+            max_steps: 2000,
+            num_episodes: 1000,
+            dense_size: 32,
+
+            gamma: 0.999,
+            tau: 0.005,
+            learning_rate: 0.001,
+            batch_size: 32,
+            clip_grad: 1.0,
+            min_probability: 1e-9,
+        }
+    }
+}
+
+impl fmt::Display for SacConfig {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let mut s = String::new();
+        s.push_str(&format!("max_steps={:?}\n", self.max_steps));
+        s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
+        s.push_str(&format!("dense_size={:?}\n", self.dense_size));
+        s.push_str(&format!("gamma={:?}\n", self.gamma));
+        s.push_str(&format!("tau={:?}\n", self.tau));
+        s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
+        s.push_str(&format!("batch_size={:?}\n", self.batch_size));
+        s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
+        s.push_str(&format!("min_probability={:?}\n", self.min_probability));
+        write!(f, "{s}")
+    }
+}
+
+type MyAgent<E, B> = SAC<E, B, Actor<B>>;
+
+#[allow(unused)]
+pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
+    conf: &SacConfig,
+    visualized: bool,
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+    env.as_mut().max_steps = conf.max_steps;
+    let state_dim = <<E as Environment>::StateType as State>::size();
+    let action_dim = <<E as Environment>::ActionType as Action>::size();
+
+    let mut actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
+    let mut critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
+    let mut critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
+    let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
+
+    let mut agent = MyAgent::default();
+
+    let config = SACTrainingConfig {
+        gamma: conf.gamma,
+        tau: conf.tau,
+        learning_rate: conf.learning_rate,
+        min_probability: conf.min_probability,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };
+
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+
+    let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
+
+    let mut optimizer = SACOptimizer::new(
+        optimizer_config.clone().init(),
+        optimizer_config.clone().init(),
+        optimizer_config.clone().init(),
+        optimizer_config.init(),
+    );
+
+    let mut policy_net = agent.model().clone();
+
+    let mut step = 0_usize;
+
+    for episode in 0..conf.num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut state = env.state();
+        let mut now = SystemTime::now();
+
+        while !episode_done {
+            if let Some(action) = MyAgent::<E, B>::react_with_model(&state, &nets.actor) {
+                let snapshot = env.step(action);
+
+                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
+                    snapshot.reward().clone(),
+                );
+
+                memory.push(
+                    state,
+                    *snapshot.state(),
+                    action,
+                    snapshot.reward().clone(),
+                    snapshot.done(),
+                );
+
+                if config.batch_size < memory.len() {
+                    nets = agent.train::<MEMORY_SIZE>(nets, &memory, &mut optimizer, &config);
+                }
+
+                step += 1;
+                episode_duration += 1;
+
+                if snapshot.done() || episode_duration >= conf.max_steps {
+                    env.reset();
+                    episode_done = true;
+
+                    println!(
+                        "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
+                        now.elapsed().unwrap().as_secs()
+                    );
+                    now = SystemTime::now();
+                } else {
+                    state = *snapshot.state();
+                }
+            }
+        }
+    }
+
+    agent.valid(nets.actor)
+}
diff --git a/bot/src/burnrl/sac/utils.rs b/bot/src/burnrl/sac/utils.rs
new file mode 100644
index 0000000..ac6059d
--- /dev/null
+++ b/bot/src/burnrl/sac/utils.rs
@@ -0,0 +1,78 @@
+use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
+use crate::burnrl::sac::sac_model;
+use crate::training_common::get_valid_action_indices;
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
+use burn::module::{Module, Param, ParamId};
+use burn::nn::Linear;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
+use burn::tensor::Tensor;
+// use burn_rl::agent::{SACModel, SAC};
+use burn_rl::base::{Agent, ElemType, Environment};
+
+// pub fn save_model(model: &sac_model::Net<NdArray<ElemType>>, path: &String) {
+//     let recorder = CompactRecorder::new();
+//     let model_path = format!("{path}.mpk");
+//     println!("Modèle de validation sauvegardé : {model_path}");
+//     recorder
+//         .record(model.clone().into_record(), model_path.into())
+//         .unwrap();
+// }
+//
+// pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
+//     let model_path = format!("{path}.mpk");
+//     // println!("Chargement du modèle depuis : {model_path}");
+//
+//     CompactRecorder::new()
+//         .load(model_path.into(), &NdArrayDevice::default())
+//         .map(|record| {
+//             dqn_model::Net::new(
+//                 <TrictracEnvironment as Environment>::StateType::size(),
+//                 dense_size,
+//                 <TrictracEnvironment as Environment>::ActionType::size(),
+//             )
+//             .load_record(record)
+//         })
+//         .ok()
+// }
+//
+
+pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+    let mut env = E::new(true);
+    let mut state = env.state();
+    let mut done = false;
+    while !done {
+        if let Some(action) = agent.react(&state) {
+            let snapshot = env.step(action);
+            state = *snapshot.state();
+            done = snapshot.done();
+        }
+    }
+}
+
+fn soft_update_tensor<const N: usize, B: Backend>(
+    this: &Param<Tensor<B, N>>,
+    that: &Param<Tensor<B, N>>,
+    tau: ElemType,
+) -> Param<Tensor<B, N>> {
+    let that_weight = that.val();
+    let this_weight = this.val();
+    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
+
+    Param::initialized(ParamId::new(), new_weight)
+}
+
+pub fn soft_update_linear<B: Backend>(
+    this: Linear<B>,
+    that: &Linear<B>,
+    tau: ElemType,
+) -> Linear<B> {
+    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
+    let bias = match (&this.bias, &that.bias) {
+        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
+        _ => None,
+    };
+
+    Linear::<B> { weight, bias }
+}