From 354dcfd3415f8e27d3781deca22ecdc32eadad46 Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Tue, 8 Jul 2025 21:58:15 +0200
Subject: [PATCH] wip burn-rl dqn example

---
 bot/Cargo.toml                |  10 +-
 bot/src/bin/train_burn_rl.rs  |   3 +-
 bot/src/bin/train_dqn_full.rs |   7 +-
 bot/src/burnrl/dqn_model.rs   | 142 ++++++++++++++++++
 .../environment.rs}           |  15 +-
 bot/src/burnrl/main.rs        |  16 ++
 bot/src/burnrl/mod.rs         |   3 +
 bot/src/burnrl/utils.rs       |  44 ++++++
 bot/src/lib.rs                |   3 +-
 bot/src/strategy.rs           |   1 -
 10 files changed, 224 insertions(+), 20 deletions(-)
 create mode 100644 bot/src/burnrl/dqn_model.rs
 rename bot/src/{strategy/burn_environment.rs => burnrl/environment.rs} (96%)
 create mode 100644 bot/src/burnrl/main.rs
 create mode 100644 bot/src/burnrl/mod.rs
 create mode 100644 bot/src/burnrl/utils.rs

diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 38bfee9..5578fae 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -5,13 +5,17 @@ edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
+[[bin]]
+name = "train_dqn_burn"
+path = "src/burnrl/main.rs"
+
 [[bin]]
 name = "train_dqn"
 path = "src/bin/train_dqn.rs"
 
-[[bin]]
-name = "train_burn_rl"
-path = "src/bin/train_burn_rl.rs"
+# [[bin]]
+# name = "train_burn_rl"
+# path = "src/bin/train_burn_rl.rs"
 
 [[bin]]
 name = "train_dqn_full"
diff --git a/bot/src/bin/train_burn_rl.rs b/bot/src/bin/train_burn_rl.rs
index 6962f84..73337cd 100644
--- a/bot/src/bin/train_burn_rl.rs
+++ b/bot/src/bin/train_burn_rl.rs
@@ -1,4 +1,4 @@
-use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
+use bot::burnrl::environment::{TrictracAction, TrictracEnvironment};
 use bot::strategy::dqn_common::get_valid_actions;
 use burn_rl::base::Environment;
 use rand::Rng;
@@ -224,4 +224,3 @@ fn print_help() {
     println!(" - Pour l'instant, implémente seulement une politique epsilon-greedy simple");
     println!(" - L'intégration avec un vrai agent DQN peut être ajoutée plus tard");
 }
-
diff --git a/bot/src/bin/train_dqn_full.rs b/bot/src/bin/train_dqn_full.rs
index 56321b1..42e90ae 100644
--- a/bot/src/bin/train_dqn_full.rs
+++ b/bot/src/bin/train_dqn_full.rs
@@ -1,5 +1,5 @@
+use bot::burnrl::environment::{TrictracAction, TrictracEnvironment};
 use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
-use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
 use bot::strategy::dqn_common::get_valid_actions;
 use burn::optim::AdamConfig;
 use burn_rl::base::Environment;
@@ -130,10 +130,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();
 
         // Sélectionner une action avec l'agent DQN
-        let action_index = agent.select_action(
-            &current_state_data,
-            &valid_indices,
-        );
+        let action_index = agent.select_action(&current_state_data, &valid_indices);
         let action = TrictracAction {
             index: action_index as u32,
         };
diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs
new file mode 100644
index 0000000..2a6db43
--- /dev/null
+++ b/bot/src/burnrl/dqn_model.rs
@@ -0,0 +1,142 @@
+use crate::burnrl::utils::soft_update_linear;
+use burn::module::Module;
+use burn::nn::{Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::tensor::activation::relu;
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::DQN;
+use burn_rl::agent::{DQNModel, DQNTrainingConfig};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+
+#[derive(Module, Debug)]
+pub struct Net<B: Backend> {
+    linear_0: Linear<B>,
+    linear_1: Linear<B>,
+    linear_2: Linear<B>,
+}
+
+impl<B: Backend> Net<B> {
+    #[allow(unused)]
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        Self {
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+        }
+    }
+
+    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
+        (self.linear_0, self.linear_1, self.linear_2)
+    }
+}
+
+impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
+
+        relu(self.linear_2.forward(layer_1_output))
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.forward(input)
+    }
+}
+
+impl<B: Backend> DQNModel<B> for Net<B> {
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
+
+        Self {
+            linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
+            linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
+            linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
+        }
+    }
+}
+
+#[allow(unused)]
+const MEMORY_SIZE: usize = 4096;
+const DENSE_SIZE: usize = 128;
+const EPS_DECAY: f64 = 1000.0;
+const EPS_START: f64 = 0.9;
+const EPS_END: f64 = 0.05;
+
+type MyAgent<E, B> = DQN<E, B, Net<B>>;
+
+#[allow(unused)]
+pub fn run<E: Environment, B: AutodiffBackend>(
+    num_episodes: usize,
+    visualized: bool,
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+
+    let model = Net::<B>::new(
+        <<E as Environment>::StateType as State>::size(),
+        DENSE_SIZE,
+        <<E as Environment>::ActionType as Action>::size(),
+    );
+
+    let mut agent = MyAgent::new(model);
+
+    let config = DQNTrainingConfig::default();
+
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+
+    let mut optimizer = AdamWConfig::new()
+        .with_grad_clipping(config.clip_grad.clone())
+        .init();
+
+    let mut policy_net = agent.model().as_ref().unwrap().clone();
+
+    let mut step = 0_usize;
+
+    for episode in 0..num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward: ElemType = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut state = env.state();
+
+        while !episode_done {
+            let eps_threshold =
+                EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
+            let action =
+                DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
+            let snapshot = env.step(action);
+
+            episode_reward +=
+                <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+
+            memory.push(
+                state,
+                *snapshot.state(),
+                action,
+                snapshot.reward().clone(),
+                snapshot.done(),
+            );
+
+            if config.batch_size < memory.len() {
+                policy_net =
+                    agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
+            }
+
+            step += 1;
+            episode_duration += 1;
+
+            if snapshot.done() || episode_duration >= E::MAX_STEPS {
+                env.reset();
+                episode_done = true;
+
+                println!(
+                    "{{\"episode\": {}, \"reward\": {:.4}, \"duration\": {}}}",
+                    episode, episode_reward, episode_duration
+                );
+            } else {
+                state = *snapshot.state();
+            }
+        }
+    }
+
+    agent.valid()
+}
diff --git a/bot/src/strategy/burn_environment.rs b/bot/src/burnrl/environment.rs
similarity index 96%
rename from bot/src/strategy/burn_environment.rs
rename to bot/src/burnrl/environment.rs
index 00d9ccd..669d3b4 100644
--- a/bot/src/strategy/burn_environment.rs
+++ b/bot/src/burnrl/environment.rs
@@ -1,3 +1,4 @@
+use crate::strategy::dqn_common;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
@@ -57,9 +58,7 @@ impl Action for TrictracAction {
     }
 
     fn size() -> usize {
-        // Utiliser l'espace d'actions compactes pour réduire la complexité
-        // Maximum estimé basé sur les actions contextuelles
-        1000 // Estimation conservative, sera ajusté dynamiquement
+        1252
     }
 }
 
@@ -205,8 +204,8 @@ impl TrictracEnvironment {
         &self,
         action: TrictracAction,
         game_state: &GameState,
-    ) -> Option<super::dqn_common::TrictracAction> {
-        use super::dqn_common::get_valid_actions;
+    ) -> Option<dqn_common::TrictracAction> {
+        use dqn_common::get_valid_actions;
 
         // Obtenir les actions valides dans le contexte actuel
         let valid_actions = get_valid_actions(game_state);
@@ -223,9 +222,9 @@ impl TrictracEnvironment {
     /// Exécute une action Trictrac dans le jeu
     fn execute_action(
         &mut self,
-        action: super::dqn_common::TrictracAction,
+        action: dqn_common::TrictracAction,
     ) -> Result<f64, Box<dyn std::error::Error>> {
-        use super::dqn_common::TrictracAction;
+        use dqn_common::TrictracAction;
 
         let mut reward = 0.0;
 
@@ -320,7 +319,7 @@ impl TrictracEnvironment {
         // Si c'est le tour de l'adversaire, jouer automatiquement
         if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
             // Utiliser la stratégie default pour l'adversaire
-            use super::default::DefaultStrategy;
+            use crate::strategy::default::DefaultStrategy;
             use crate::BotStrategy;
 
             let mut default_strategy = DefaultStrategy::default();
diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs
new file mode 100644
index 0000000..ef5da61
--- /dev/null
+++ b/bot/src/burnrl/main.rs
@@ -0,0 +1,16 @@
+use burn::backend::{Autodiff, NdArray};
+use burn_rl::base::ElemType;
+use bot::burnrl::{
+    dqn_model,
+    environment,
+    utils::demo_model,
+};
+
+type Backend = Autodiff<NdArray<ElemType>>;
+type Env = environment::TrictracEnvironment;
+
+fn main() {
+    let agent = dqn_model::run::<Env, Backend>(512, false); //true);
+
+    demo_model::<Env>(agent);
+}
diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs
new file mode 100644
index 0000000..f4380eb
--- /dev/null
+++ b/bot/src/burnrl/mod.rs
@@ -0,0 +1,3 @@
+pub mod dqn_model;
+pub mod environment;
+pub mod utils;
diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs
new file mode 100644
index 0000000..7cfb165
--- /dev/null
+++ b/bot/src/burnrl/utils.rs
@@ -0,0 +1,44 @@
+use burn::module::{Param, ParamId};
+use burn::nn::Linear;
+use burn::tensor::backend::Backend;
+use burn::tensor::Tensor;
+use burn_rl::base::{Agent, ElemType, Environment};
+
+pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+    let mut env = E::new(true);
+    let mut state = env.state();
+    let mut done = false;
+    while !done {
+        if let Some(action) = agent.react(&state) {
+            let snapshot = env.step(action);
+            state = *snapshot.state();
+            done = snapshot.done();
+        }
+    }
+}
+
+fn soft_update_tensor<const D: usize, B: Backend>(
+    this: &Param<Tensor<B, D>>,
+    that: &Param<Tensor<B, D>>,
+    tau: ElemType,
+) -> Param<Tensor<B, D>> {
+    let that_weight = that.val();
+    let this_weight = this.val();
+    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
+
+    Param::initialized(ParamId::new(), new_weight)
+}
+
+pub fn soft_update_linear<B: Backend>(
+    this: Linear<B>,
+    that: &Linear<B>,
+    tau: ElemType,
+) -> Linear<B> {
+    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
+    let bias = match (&this.bias, &that.bias) {
+        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
+        _ => None,
+    };
+
+    Linear::<B> { weight, bias }
+}
diff --git a/bot/src/lib.rs b/bot/src/lib.rs
index d3da040..0dc60c0 100644
--- a/bot/src/lib.rs
+++ b/bot/src/lib.rs
@@ -1,7 +1,8 @@
+pub mod burnrl;
 pub mod strategy;
 use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
 
-pub use strategy::burn_dqn_strategy::{BurnDqnStrategy, create_burn_dqn_strategy};
+pub use strategy::burn_dqn_strategy::{create_burn_dqn_strategy, BurnDqnStrategy};
 pub use strategy::default::DefaultStrategy;
 pub use strategy::dqn::DqnStrategy;
 pub use strategy::erroneous_moves::ErroneousStrategy;
diff --git a/bot/src/strategy.rs b/bot/src/strategy.rs
index e26c20f..a0ffc7a 100644
--- a/bot/src/strategy.rs
+++ b/bot/src/strategy.rs
@@ -1,6 +1,5 @@
 pub mod burn_dqn_agent;
 pub mod burn_dqn_strategy;
-pub mod burn_environment;
 pub mod client;
 pub mod default;
 pub mod dqn;