From 8f41cc1412e32e3665718ff854d5fd32c06b3cbf Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Thu, 21 Aug 2025 17:39:45 +0200
Subject: [PATCH] feat: bot all algos

---
 bot/scripts/train.sh                          |  17 +-
 bot/src/burnrl/{dqn_model.rs => algos/dqn.rs} |   0
 .../{dqn_big_model.rs => algos/dqn_big.rs}    |   0
 .../dqn_valid.rs}                             |   0
 bot/src/burnrl/algos/mod.rs                   |   9 +
 bot/src/burnrl/{ppo_model.rs => algos/ppo.rs} |   4 +-
 bot/src/burnrl/algos/ppo_big.rs               | 191 +++++++++++++++
 bot/src/burnrl/algos/ppo_valid.rs             | 191 +++++++++++++++
 bot/src/burnrl/{sac_model.rs => algos/sac.rs} |   0
 bot/src/burnrl/algos/sac_big.rs               | 222 ++++++++++++++++++
 bot/src/burnrl/algos/sac_valid.rs             | 222 ++++++++++++++++++
 bot/src/burnrl/main.rs                        |  78 +++++-
 bot/src/burnrl/mod.rs                         |   6 +-
 bot/src/strategy/dqnburn.rs                   |   6 +-
 justfile                                      |   8 +-
 store/src/board.rs                            |   4 +-
 store/src/game.rs                             |   6 +-
 store/src/game_rules_points.rs                |   4 +-
 18 files changed, 929 insertions(+), 39 deletions(-)
 rename bot/src/burnrl/{dqn_model.rs => algos/dqn.rs} (100%)
 rename bot/src/burnrl/{dqn_big_model.rs => algos/dqn_big.rs} (100%)
 rename bot/src/burnrl/{dqn_valid_model.rs => algos/dqn_valid.rs} (100%)
 create mode 100644 bot/src/burnrl/algos/mod.rs
 rename bot/src/burnrl/{ppo_model.rs => algos/ppo.rs} (99%)
 create mode 100644 bot/src/burnrl/algos/ppo_big.rs
 create mode 100644 bot/src/burnrl/algos/ppo_valid.rs
 rename bot/src/burnrl/{sac_model.rs => algos/sac.rs} (100%)
 create mode 100644 bot/src/burnrl/algos/sac_big.rs
 create mode 100644 bot/src/burnrl/algos/sac_valid.rs

diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh
index a9f5e81..87a3770 100755
--- a/bot/scripts/train.sh
+++ b/bot/scripts/train.sh
@@ -1,10 +1,9 @@
-#!/usr/bin/env sh
+#!/usr/bin/env bash
 
 ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
 LOGS_DIR="$ROOT/bot/models/logs"
 CFG_SIZE=17
 
-ALGO="sac"
 BINBOT=burn_train
 # BINBOT=train_ppo_burn
 # BINBOT=train_dqn_burn
@@ -15,6 +14,7 @@ OPPONENT="random"
 PLOT_EXT="png"
 
 train() {
+  ALGO=$1
   cargo build --release --bin=$BINBOT
   NAME="$(date +%Y-%m-%d_%H:%M:%S)"
   LOGS="$LOGS_DIR/$ALGO/$NAME.out"
@@ -23,6 +23,7 @@ train() {
 }
 
 plot() {
+  ALGO=$1
   NAME=$(ls -rt "$LOGS_DIR/$ALGO" | tail -n 1)
   LOGS="$LOGS_DIR/$ALGO/$NAME"
   cfgs=$(head -n $CFG_SIZE "$LOGS")
@@ -37,8 +38,14 @@ plot() {
   feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$ALGO/$NAME.$PLOT_EXT"
 }
 
-if [ "$1" = "plot" ]; then
-  plot
+if [[ -z "$1" ]]; then
+  echo "Usage : train [plot] <algo>"
+elif [ "$1" = "plot" ]; then
+  if [[ -z "$2" ]]; then
+    echo "Usage : train [plot] <algo>"
+  else
+    plot $2
+  fi
 else
-  train
+  train $1
 fi
diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/algos/dqn.rs
similarity index 100%
rename from bot/src/burnrl/dqn_model.rs
rename to bot/src/burnrl/algos/dqn.rs
diff --git a/bot/src/burnrl/dqn_big_model.rs b/bot/src/burnrl/algos/dqn_big.rs
similarity index 100%
rename from bot/src/burnrl/dqn_big_model.rs
rename to bot/src/burnrl/algos/dqn_big.rs
diff --git a/bot/src/burnrl/dqn_valid_model.rs b/bot/src/burnrl/algos/dqn_valid.rs
similarity index 100%
rename from bot/src/burnrl/dqn_valid_model.rs
rename to bot/src/burnrl/algos/dqn_valid.rs
diff --git a/bot/src/burnrl/algos/mod.rs b/bot/src/burnrl/algos/mod.rs
new file mode 100644
index 0000000..af13327
--- /dev/null
+++ b/bot/src/burnrl/algos/mod.rs
@@ -0,0 +1,9 @@
+pub mod dqn;
+pub mod dqn_big;
+pub mod dqn_valid;
+pub mod ppo;
+pub mod ppo_big;
+pub mod ppo_valid;
+pub mod sac;
+pub mod sac_big;
+pub mod sac_valid;
diff --git a/bot/src/burnrl/ppo_model.rs b/bot/src/burnrl/algos/ppo.rs
similarity index 99%
rename from bot/src/burnrl/ppo_model.rs
rename to bot/src/burnrl/algos/ppo.rs
index ea0b055..df6818c 100644
--- a/bot/src/burnrl/ppo_model.rs
+++ b/bot/src/burnrl/algos/ppo.rs
@@ -161,8 +161,7 @@ pub fn run<
         save_model(&model_with_loaded_weights, path);
     }
 
-    let valid_agent = agent.valid(model);
-    valid_agent
+    agent.valid(model)
 }
 
 pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
@@ -190,4 +189,3 @@ pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
         })
         .ok()
 }
-
diff --git a/bot/src/burnrl/algos/ppo_big.rs b/bot/src/burnrl/algos/ppo_big.rs
new file mode 100644
--- /dev/null
+++ b/bot/src/burnrl/algos/ppo_big.rs
@@ -0,0 +1,191 @@
+use crate::burnrl::environment_big::TrictracEnvironment;
+use crate::burnrl::utils::Config;
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
+use burn::module::Module;
+use burn::nn::{Initializer, Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::activation::{relu, softmax};
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use std::env;
+use std::fs;
+use std::time::SystemTime;
+
+#[derive(Module, Debug)]
+pub struct Net<B: Backend> {
+    linear: Linear<B>,
+    linear_actor: Linear<B>,
+    linear_critic: Linear<B>,
+}
+
+impl<B: Backend> Net<B> {
+    #[allow(unused)]
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        let initializer = Initializer::XavierUniform { gain: 1.0 };
+        Self {
+            linear: LinearConfig::new(input_size, dense_size)
+                .with_initializer(initializer.clone())
+                .init(&Default::default()),
+            linear_actor: LinearConfig::new(dense_size, output_size)
+                .with_initializer(initializer.clone())
+                .init(&Default::default()),
+            linear_critic: LinearConfig::new(dense_size, 1)
+                .with_initializer(initializer)
+                .init(&Default::default()),
+        }
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
+        let layer_0_output = relu(self.linear.forward(input));
+        let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
+        let values = self.linear_critic.forward(layer_0_output);
+
+        PPOOutput::<B>::new(policies, values)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear.forward(input));
+        softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
+    }
+}
+
+impl<B: Backend> PPOModel<B> for Net<B> {}
+#[allow(unused)]
+const MEMORY_SIZE: usize = 512;
+
+type MyAgent<E, B> = PPO<E, B, Net<B>>;
+
+#[allow(unused)]
+pub fn run<
+    E: Environment + AsMut<TrictracEnvironment>,
+    B: AutodiffBackend,
+>(
+    conf: &Config,
+    visualized: bool,
+    // ) -> PPO<E, B, Net<B>> {
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+    env.as_mut().max_steps = conf.max_steps;
+
+    let mut model = Net::<B>::new(
+        <<E as Environment>::StateType as State>::size(),
+        conf.dense_size,
+        <<E as Environment>::ActionType as Action>::size(),
+    );
+    let agent = MyAgent::default();
+    let config = PPOTrainingConfig {
+        gamma: conf.gamma,
+        lambda: conf.lambda,
+        epsilon_clip: conf.epsilon_clip,
+        critic_weight: conf.critic_weight,
+        entropy_weight: conf.entropy_weight,
+        learning_rate: conf.learning_rate,
+        epochs: conf.epochs,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };
+
+    let mut optimizer = AdamWConfig::new()
+        .with_grad_clipping(config.clip_grad.clone())
+        .init();
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+    for episode in 0..conf.num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut now = SystemTime::now();
+
+        env.reset();
+        while !episode_done {
+            let state = env.state();
+            if let Some(action) = MyAgent::<E, B>::react_with_model(&state, &model) {
+                let snapshot = env.step(action);
+                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
+                    snapshot.reward().clone(),
+                );
+
+                memory.push(
+                    state,
+                    *snapshot.state(),
+                    action,
+                    snapshot.reward().clone(),
+                    snapshot.done(),
+                );
+
+                episode_duration += 1;
+                episode_done = snapshot.done() || episode_duration >= conf.max_steps;
+            }
+        }
+        println!(
+            "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
+            now.elapsed().unwrap().as_secs(),
+        );
+
+        now = SystemTime::now();
+        model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
+        memory.clear();
+    }
+
+    if let Some(path) = &conf.save_path {
+        let device = NdArrayDevice::default();
+        let recorder = CompactRecorder::new();
+        let tmp_path = env::temp_dir().join("tmp_model.mpk");
+
+        // Save the trained model (backend B) to a temporary file
+        recorder
+            .record(model.clone().into_record(), tmp_path.clone())
+            .expect("Failed to save temporary model");
+
+        // Create a new model instance with the target backend (NdArray)
+        let model_to_save: Net<NdArray<ElemType>> = Net::new(
+            <<E as Environment>::StateType as State>::size(),
+            conf.dense_size,
+            <<E as Environment>::ActionType as Action>::size(),
+        );
+
+        // Load the record from the temporary file into the new model
+        let record = recorder
+            .load(tmp_path.clone(), &device)
+            .expect("Failed to load temporary model");
+        let model_with_loaded_weights = model_to_save.load_record(record);
+
+        // Clean up the temporary file
+        fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
+
+        save_model(&model_with_loaded_weights, path);
+    }
+    agent.valid(model)
+}
+
+pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}.mpk");
+    println!("info: Modèle de validation sauvegardé : {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
+    let model_path = format!("{path}.mpk");
+    // println!("Chargement du modèle depuis : {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            Net::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
diff --git a/bot/src/burnrl/algos/ppo_valid.rs b/bot/src/burnrl/algos/ppo_valid.rs
new file mode 100644
index 0000000..8a391fb
--- /dev/null
+++ b/bot/src/burnrl/algos/ppo_valid.rs
@@ -0,0 +1,191 @@
+use crate::burnrl::environment_valid::TrictracEnvironment;
+use crate::burnrl::utils::Config;
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
+use burn::module::Module;
+use burn::nn::{Initializer, Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::activation::{relu, softmax};
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use std::env;
+use std::fs;
+use std::time::SystemTime;
+
+#[derive(Module, Debug)]
+pub struct Net<B: Backend> {
+    linear: Linear<B>,
+    linear_actor: Linear<B>,
+    linear_critic: Linear<B>,
+}
+
+impl<B: Backend> Net<B> {
+    #[allow(unused)]
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        let initializer = Initializer::XavierUniform { gain: 1.0 };
+        Self {
+            linear: LinearConfig::new(input_size, dense_size)
+                .with_initializer(initializer.clone())
+                .init(&Default::default()),
+            linear_actor: LinearConfig::new(dense_size, output_size)
+                .with_initializer(initializer.clone())
+                .init(&Default::default()),
+            linear_critic: LinearConfig::new(dense_size, 1)
+                .with_initializer(initializer)
+                .init(&Default::default()),
+        }
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
+        let layer_0_output = relu(self.linear.forward(input));
+        let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
+        let values = self.linear_critic.forward(layer_0_output);
+
+        PPOOutput::<B>::new(policies, values)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear.forward(input));
+        softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
+    }
+}
+
+impl<B: Backend> PPOModel<B> for Net<B> {}
+#[allow(unused)]
+const MEMORY_SIZE: usize = 512;
+
+type MyAgent<E, B> = PPO<E, B, Net<B>>;
+
+#[allow(unused)]
+pub fn run<
+    E: Environment + AsMut<TrictracEnvironment>,
+    B: AutodiffBackend,
+>(
+    conf: &Config,
+    visualized: bool,
+    // ) -> PPO<E, B, Net<B>> {
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+    env.as_mut().max_steps = conf.max_steps;
+
+    let mut model = Net::<B>::new(
+        <<E as Environment>::StateType as State>::size(),
+        conf.dense_size,
+        <<E as Environment>::ActionType as Action>::size(),
+    );
+    let agent = MyAgent::default();
+    let config = PPOTrainingConfig {
+        gamma: conf.gamma,
+        lambda: conf.lambda,
+        epsilon_clip: conf.epsilon_clip,
+        critic_weight: conf.critic_weight,
+        entropy_weight: conf.entropy_weight,
+        learning_rate: conf.learning_rate,
+        epochs: conf.epochs,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };
+
+    let mut optimizer = AdamWConfig::new()
+        .with_grad_clipping(config.clip_grad.clone())
+        .init();
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+    for episode in 0..conf.num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut now = SystemTime::now();
+
+        env.reset();
+        while !episode_done {
+            let state = env.state();
+            if let Some(action) = MyAgent::<E, B>::react_with_model(&state, &model) {
+                let snapshot = env.step(action);
+                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
+                    snapshot.reward().clone(),
+                );
+
+                memory.push(
+                    state,
+                    *snapshot.state(),
+                    action,
+                    snapshot.reward().clone(),
+                    snapshot.done(),
+                );
+
+                episode_duration += 1;
+                episode_done = snapshot.done() || episode_duration >= conf.max_steps;
+            }
+        }
+        println!(
+            "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
+            now.elapsed().unwrap().as_secs(),
+        );
+
+        now = SystemTime::now();
+        model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
+        memory.clear();
+    }
+
+    if let Some(path) = &conf.save_path {
+        let device = NdArrayDevice::default();
+        let recorder = CompactRecorder::new();
+        let tmp_path = env::temp_dir().join("tmp_model.mpk");
+
+        // Save the trained model (backend B) to a temporary file
+        recorder
+            .record(model.clone().into_record(), tmp_path.clone())
+            .expect("Failed to save temporary model");
+
+        // Create a new model instance with the target backend (NdArray)
+        let model_to_save: Net<NdArray<ElemType>> = Net::new(
+            <<E as Environment>::StateType as State>::size(),
+            conf.dense_size,
+            <<E as Environment>::ActionType as Action>::size(),
+        );
+
+        // Load the record from the temporary file into the new model
+        let record = recorder
+            .load(tmp_path.clone(), &device)
+            .expect("Failed to load temporary model");
+        let model_with_loaded_weights = model_to_save.load_record(record);
+
+        // Clean up the temporary file
+        fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
+
+        save_model(&model_with_loaded_weights, path);
+    }
+    agent.valid(model)
+}
+
+pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}.mpk");
+    println!("info: Modèle de validation sauvegardé : {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
+    let model_path = format!("{path}.mpk");
+    // println!("Chargement du modèle depuis : {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            Net::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
diff --git a/bot/src/burnrl/sac_model.rs b/bot/src/burnrl/algos/sac.rs
similarity index 100%
rename from bot/src/burnrl/sac_model.rs
rename to bot/src/burnrl/algos/sac.rs
diff --git a/bot/src/burnrl/algos/sac_big.rs b/bot/src/burnrl/algos/sac_big.rs
new file mode 100644
index 0000000..1361b42
--- /dev/null
+++ b/bot/src/burnrl/algos/sac_big.rs
@@ -0,0 +1,222 @@
+use crate::burnrl::environment_big::TrictracEnvironment;
+use crate::burnrl::utils::{soft_update_linear, Config};
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
+use burn::module::Module;
+use burn::nn::{Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::activation::{relu, softmax};
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use std::time::SystemTime;
+
+#[derive(Module, Debug)]
+pub struct Actor<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Actor<B> {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        softmax(self.linear_2.forward(layer_1_output), 1)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> SACActor<B> for Actor<B> {}
+
+#[derive(Module, Debug)]
+pub struct Critic<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Critic<B> {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+
+    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
+        (self.linear_0, self.linear_1, self.linear_2)
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        self.linear_2.forward(layer_1_output)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> SACCritic<B> for Critic<B> {
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
+
+        Self {
+            linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
+            linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
+            linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
+        }
+    }
+}
+
+#[allow(unused)]
+const MEMORY_SIZE: usize = 4096;
+
+type MyAgent<E, B> = SAC<E, B, Actor<B>>;
+
+#[allow(unused)]
+pub fn run<
+    E: Environment + AsMut<TrictracEnvironment>,
+    B: AutodiffBackend,
+>(
+    conf: &Config,
+    visualized: bool,
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+    env.as_mut().max_steps = conf.max_steps;
+    let state_dim = <<E as Environment>::StateType as State>::size();
+    let action_dim = <<E as Environment>::ActionType as Action>::size();
+
+    let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
+    let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
+    let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
+    let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
+
+    let mut agent = MyAgent::default();
+
+    let config = SACTrainingConfig {
+        gamma: conf.gamma,
+        tau: conf.tau,
+        learning_rate: conf.learning_rate,
+        min_probability: conf.min_probability,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };
+
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+
+    let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
+
+    let mut optimizer = SACOptimizer::new(
+        optimizer_config.clone().init(),
+        optimizer_config.clone().init(),
+        optimizer_config.clone().init(),
+        optimizer_config.init(),
+    );
+
+    let mut step = 0_usize;
+
+    for episode in 0..conf.num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut state = env.state();
+        let mut now = SystemTime::now();
+
+        while !episode_done {
+            if let Some(action) = MyAgent::<E, B>::react_with_model(&state, &nets.actor) {
+                let snapshot = env.step(action);
+
+                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
+                    snapshot.reward().clone(),
+                );
+
+                memory.push(
+                    state,
+                    *snapshot.state(),
+                    action,
+                    snapshot.reward().clone(),
+                    snapshot.done(),
+                );
+
+                if config.batch_size < memory.len() {
+                    nets = agent.train::<MEMORY_SIZE>(nets, &memory, &mut optimizer, &config);
+                }
+
+                step += 1;
+                episode_duration += 1;
+
+                if snapshot.done() || episode_duration >= conf.max_steps {
+                    env.reset();
+                    episode_done = true;
+
+                    println!(
+                        "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
+                        now.elapsed().unwrap().as_secs()
+                    );
+                    now = SystemTime::now();
+                } else {
+                    state = *snapshot.state();
+                }
+            }
+        }
+    }
+
+    let valid_agent = agent.valid(nets.actor);
+    if let Some(path) = &conf.save_path {
+        if let Some(model) = valid_agent.model() {
+            save_model(model, path);
+        }
+    }
+    valid_agent
+}
+
+pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}.mpk");
+    println!("info: Modèle de validation sauvegardé : {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
+    let model_path = format!("{path}.mpk");
+    // println!("Chargement du modèle depuis : {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            Actor::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
+
diff --git a/bot/src/burnrl/algos/sac_valid.rs b/bot/src/burnrl/algos/sac_valid.rs
new file mode 100644
index 0000000..81523c4
--- /dev/null
+++ b/bot/src/burnrl/algos/sac_valid.rs
@@ -0,0 +1,222 @@
+use crate::burnrl::environment_valid::TrictracEnvironment;
+use crate::burnrl::utils::{soft_update_linear, Config};
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
+use burn::module::Module;
+use burn::nn::{Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::activation::{relu, softmax};
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use std::time::SystemTime;
+
+#[derive(Module, Debug)]
+pub struct Actor<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Actor<B> {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        softmax(self.linear_2.forward(layer_1_output), 1)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> SACActor<B> for Actor<B> {}
+
+#[derive(Module, Debug)]
+pub struct Critic<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Critic<B> {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+
+    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
+        (self.linear_0, self.linear_1, self.linear_2)
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        self.linear_2.forward(layer_1_output)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> SACCritic<B> for Critic<B> {
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
+
+        Self {
+            linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
+            linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
+            linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
+        }
+    }
+}
+
+#[allow(unused)]
+const MEMORY_SIZE: usize = 4096;
+
+type MyAgent<E, B> = SAC<E, B, Actor<B>>;
+
+#[allow(unused)]
+pub fn run<
+    E: Environment + AsMut<TrictracEnvironment>,
+    B: AutodiffBackend,
+>(
+    conf: &Config,
+    visualized: bool,
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+    env.as_mut().max_steps = conf.max_steps;
+    let state_dim = <<E as Environment>::StateType as State>::size();
+    let action_dim = <<E as Environment>::ActionType as Action>::size();
+
+    let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
+    let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
+    let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
+    let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
+
+    let mut agent = MyAgent::default();
+
+    let config = SACTrainingConfig {
+        gamma: conf.gamma,
+        tau: conf.tau,
+        learning_rate: conf.learning_rate,
+        min_probability: conf.min_probability,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };
+
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+
+    let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
+
+    let mut optimizer = SACOptimizer::new(
+        optimizer_config.clone().init(),
+        optimizer_config.clone().init(),
+        optimizer_config.clone().init(),
+        optimizer_config.init(),
+    );
+
+    let mut step = 0_usize;
+
+    for episode in 0..conf.num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut state = env.state();
+        let mut now = SystemTime::now();
+
+        while !episode_done {
+            if let Some(action) = MyAgent::<E, B>::react_with_model(&state, &nets.actor) {
+                let snapshot = env.step(action);
+
+                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
+                    snapshot.reward().clone(),
+                );
+
+                memory.push(
+                    state,
+                    *snapshot.state(),
+                    action,
+                    snapshot.reward().clone(),
+                    snapshot.done(),
+                );
+
+                if config.batch_size < memory.len() {
+                    nets = agent.train::<MEMORY_SIZE>(nets, &memory, &mut optimizer, &config);
+                }
+
+                step += 1;
+                episode_duration += 1;
+
+                if snapshot.done() || episode_duration >= conf.max_steps {
+                    env.reset();
+                    episode_done = true;
+
+                    println!(
+                        "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
+                        now.elapsed().unwrap().as_secs()
+                    );
+                    now = SystemTime::now();
+                } else {
+                    state = *snapshot.state();
+                }
+            }
+        }
+    }
+
+    let valid_agent = agent.valid(nets.actor);
+    if let Some(path) = &conf.save_path {
+        if let Some(model) = valid_agent.model() {
+            save_model(model, path);
+        }
+    }
+    valid_agent
+}
+
+pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}.mpk");
+    println!("info: Modèle de validation sauvegardé : {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
+    let model_path = format!("{path}.mpk");
+    // println!("Chargement du modèle depuis : {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            Actor::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
+
diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs
index ce76b4d..d289dd6 100644
--- a/bot/src/burnrl/main.rs
+++ b/bot/src/burnrl/main.rs
@@ -1,8 +1,10 @@
+use bot::burnrl::algos::{
+    dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid,
+};
 use bot::burnrl::environment::TrictracEnvironment;
 use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig;
 use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
 use bot::burnrl::utils::{demo_model, Config};
-use bot::burnrl::{dqn_big_model, dqn_model, dqn_valid_model, ppo_model, sac_model};
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::base::ElemType;
 use std::env;
@@ -51,9 +53,9 @@ fn main() {
     match algo.as_str() {
         "dqn" => {
-            let _agent = dqn_model::run::(&conf, false);
+            let _agent = dqn::run::(&conf, false);
             println!("> Chargement du modèle pour test");
-            let loaded_model = dqn_model::load_model(conf.dense_size, &path);
+            let loaded_model = dqn::load_model(conf.dense_size, &path);
             let loaded_agent: burn_rl::agent::DQN =
                 burn_rl::agent::DQN::new(loaded_model.unwrap());
@@ -61,33 +63,87 @@ fn main() {
             demo_model(loaded_agent);
         }
         "dqn_big" => {
-            let _agent = dqn_big_model::run::(&conf, false);
+            let _agent = dqn_big::run::(&conf, false);
+            println!("> Chargement du modèle pour test");
+            let loaded_model = dqn_big::load_model(conf.dense_size, &path);
+            let loaded_agent: burn_rl::agent::DQN =
+                burn_rl::agent::DQN::new(loaded_model.unwrap());
+
+            println!("> Test avec le modèle chargé");
+            demo_model(loaded_agent);
         }
         "dqn_valid" => {
-            let _agent = dqn_valid_model::run::(&conf, false);
+            let _agent = dqn_valid::run::(&conf, false);
+            println!("> Chargement du modèle pour test");
+            let loaded_model = dqn_valid::load_model(conf.dense_size, &path);
+            let loaded_agent: burn_rl::agent::DQN =
+                burn_rl::agent::DQN::new(loaded_model.unwrap());
+
+            println!("> Test avec le modèle chargé");
+            demo_model(loaded_agent);
         }
         "sac" => {
-            let _agent = sac_model::run::(&conf, false);
+            let _agent = sac::run::(&conf, false);
             println!("> Chargement du modèle pour test");
-            let loaded_model = sac_model::load_model(conf.dense_size, &path);
+            let loaded_model = sac::load_model(conf.dense_size, &path);
             let loaded_agent: burn_rl::agent::SAC =
                 burn_rl::agent::SAC::new(loaded_model.unwrap());
 
             println!("> Test avec le modèle chargé");
             demo_model(loaded_agent);
         }
-        "ppo" => {
-            let _agent = ppo_model::run::(&conf, false);
+        "sac_big" => {
+            let _agent = sac_big::run::(&conf, false);
             println!("> Chargement du modèle pour test");
-            let loaded_model = ppo_model::load_model(conf.dense_size, &path);
+            let loaded_model = sac_big::load_model(conf.dense_size, &path);
+            let loaded_agent: burn_rl::agent::SAC =
+                burn_rl::agent::SAC::new(loaded_model.unwrap());
+
+            println!("> Test avec le modèle chargé");
+            demo_model(loaded_agent);
+        }
+        "sac_valid" => {
+            let _agent = sac_valid::run::(&conf, false);
+            println!("> Chargement du modèle pour test");
+            let loaded_model = sac_valid::load_model(conf.dense_size, &path);
+            let loaded_agent: burn_rl::agent::SAC =
+                burn_rl::agent::SAC::new(loaded_model.unwrap());
+
+            println!("> Test avec le modèle chargé");
+            demo_model(loaded_agent);
+        }
+        "ppo" => {
+            let _agent = ppo::run::(&conf, false);
+            println!("> Chargement du modèle pour test");
+            let loaded_model = ppo::load_model(conf.dense_size, &path);
             let loaded_agent: burn_rl::agent::PPO =
                 burn_rl::agent::PPO::new(loaded_model.unwrap());
 
             println!("> Test avec le modèle chargé");
             demo_model(loaded_agent);
         }
+        "ppo_big" => {
+            let _agent = ppo_big::run::(&conf, false);
+            println!("> Chargement du modèle pour test");
+            let loaded_model = ppo_big::load_model(conf.dense_size, &path);
+            let loaded_agent: burn_rl::agent::PPO =
+                burn_rl::agent::PPO::new(loaded_model.unwrap());
+
+            println!("> Test avec le modèle chargé");
+            demo_model(loaded_agent);
+        }
+        "ppo_valid" => {
+            let _agent = ppo_valid::run::(&conf, false);
+            println!("> Chargement du modèle pour test");
+            let loaded_model = ppo_valid::load_model(conf.dense_size, &path);
+            let loaded_agent: burn_rl::agent::PPO =
+                burn_rl::agent::PPO::new(loaded_model.unwrap());
+
+            println!("> Test avec le modèle chargé");
+            demo_model(loaded_agent);
+        }
         &_ => {
-            dbg!("unknown algo {algo}");
+            println!("unknown algo {algo}");
         }
     }
 }
diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs
index 7b719ee..62bebc8 100644
--- a/bot/src/burnrl/mod.rs
+++ b/bot/src/burnrl/mod.rs
@@ -1,9 +1,5 @@
-pub mod dqn_big_model;
-pub mod dqn_model;
-pub mod dqn_valid_model;
+pub mod algos;
 pub mod environment;
 pub mod environment_big;
 pub mod environment_valid;
-pub mod ppo_model;
-pub mod sac_model;
 pub mod utils;
diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs
index 1f317d0..2fea85e 100644
--- a/bot/src/strategy/dqnburn.rs
+++ b/bot/src/strategy/dqnburn.rs
@@ -6,11 +6,11 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use log::info;
 use store::MoveRules;
 
-use crate::burnrl::dqn_model;
+use crate::burnrl::algos::dqn;
 use crate::burnrl::environment;
 use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
 
-type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
+type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>;
 
 /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
 #[derive(Debug)]
@@ -40,7 +40,7 @@ impl DqnBurnStrategy {
     pub fn new_with_model(model_path: &String) -> Self {
         info!("Loading model {model_path:?}");
         let mut strategy = Self::new();
-        strategy.model = dqn_model::load_model(256, model_path);
+        strategy.model = dqn::load_model(256, model_path);
         strategy
     }
 
diff --git a/justfile b/justfile
index f554b15..f89bc7c 100644
--- a/justfile
+++ b/justfile
@@ -25,13 +25,13 @@ pythonlib:
 trainsimple:
     cargo build --release --bin=train_dqn_simple
     LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
-trainbot:
+trainbot algo:
    #python ./store/python/trainModel.py
    # cargo run --bin=train_dqn # ok
    # ./bot/scripts/trainValid.sh
-    ./bot/scripts/train.sh
-plottrainbot:
-    ./bot/scripts/train.sh plot
+    ./bot/scripts/train.sh {{algo}}
+plottrainbot algo:
+    ./bot/scripts/train.sh plot {{algo}}
 debugtrainbot:
     cargo build --bin=train_dqn_burn
     RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
diff --git a/store/src/board.rs b/store/src/board.rs
index 4740f2d..da0bae8 100644
--- a/store/src/board.rs
+++ b/store/src/board.rs
@@ -271,7 +271,7 @@ impl Board {
             .map(|cells| {
                 cells
                     .into_iter()
-                    .map(|cell| format!("{:>5}", cell))
+                    .map(|cell| format!("{cell:>5}"))
                     .collect::>()
                     .join("")
             })
@@ -282,7 +282,7 @@ impl Board {
             .map(|cells| {
                 cells
                     .into_iter()
-                    .map(|cell| format!("{:>5}", cell))
+                    .map(|cell| format!("{cell:>5}"))
                     .collect::>()
                     .join("")
             })
diff --git a/store/src/game.rs b/store/src/game.rs
index 6f593bb..f8a1276 100644
--- a/store/src/game.rs
+++ b/store/src/game.rs
@@ -244,7 +244,7 @@ impl GameState {
         pos_bits.push_str(&white_bits);
         pos_bits.push_str(&black_bits);
 
-        pos_bits = format!("{:0>108}", pos_bits);
+        pos_bits = format!("{pos_bits:0>108}");
         // println!("{}", pos_bits);
         let pos_u8 = pos_bits
             .as_bytes()
@@ -647,9 +647,7 @@ impl GameState {
 
     fn inc_roll_count(&mut self, player_id: PlayerId) {
         self.players.get_mut(&player_id).map(|p| {
-            if p.dice_roll_count < u8::MAX {
-                p.dice_roll_count += 1;
-            }
+            p.dice_roll_count = p.dice_roll_count.saturating_add(1);
             p
         });
     }
diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs
index c8ea334..4e94d08 100644
--- a/store/src/game_rules_points.rs
+++ b/store/src/game_rules_points.rs
@@ -603,7 +603,7 @@ mod tests {
         );
         let points_rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 4) });
         let jans = points_rules.get_result_jans(8);
-        assert!(jans.0.len() > 0);
+        assert!(!jans.0.is_empty());
     }
 
     #[test]
@@ -628,7 +628,7 @@ mod tests {
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2,
            ],
        );
-        let mut rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
+        let rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
 
        assert_eq!(12, rules.get_points(5).0);
        // Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4