From fcd50bc0f230825b176cd81debc72a62c1b4bcd0 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 19 Aug 2025 16:27:37 +0200 Subject: [PATCH] refacto: bot directories --- bot/Cargo.toml | 8 +-- .../burnrl_big => burnrl/dqn}/dqn_model.rs | 13 +++-- bot/src/{dqn/burnrl => burnrl/dqn}/main.rs | 13 ++--- bot/src/{dqn/burnrl_big => burnrl/dqn}/mod.rs | 1 - bot/src/{dqn/burnrl => burnrl/dqn}/utils.rs | 8 ++- .../burnrl => burnrl/dqn_big}/dqn_model.rs | 4 +- .../burnrl_big => burnrl/dqn_big}/main.rs | 7 +-- .../burnrl_valid => burnrl/dqn_big}/mod.rs | 1 - .../burnrl_valid => burnrl/dqn_big}/utils.rs | 8 ++- .../dqn_valid}/dqn_model.rs | 4 +- .../burnrl_valid => burnrl/dqn_valid}/main.rs | 5 +- .../{dqn/burnrl => burnrl/dqn_valid}/mod.rs | 1 - .../burnrl_big => burnrl/dqn_valid}/utils.rs | 8 ++- bot/src/{dqn => }/burnrl/environment.rs | 53 ++++++++++++------- .../environment_big.rs} | 16 +++--- .../environment_valid.rs} | 16 +++--- bot/src/burnrl/mod.rs | 6 +++ bot/src/dqn/mod.rs | 7 --- .../{dqn/simple => dqn_simple}/dqn_model.rs | 3 +- .../{dqn/simple => dqn_simple}/dqn_trainer.rs | 2 +- bot/src/{dqn/simple => dqn_simple}/main.rs | 6 +-- bot/src/{dqn/simple => dqn_simple}/mod.rs | 0 bot/src/lib.rs | 5 +- bot/src/strategy/dqn.rs | 4 +- bot/src/strategy/dqnburn.rs | 5 +- .../{dqn/dqn_common.rs => training_common.rs} | 0 ...n_common_big.rs => training_common_big.rs} | 0 27 files changed, 110 insertions(+), 94 deletions(-) rename bot/src/{dqn/burnrl_big => burnrl/dqn}/dqn_model.rs (92%) rename bot/src/{dqn/burnrl => burnrl/dqn}/main.rs (85%) rename bot/src/{dqn/burnrl_big => burnrl/dqn}/mod.rs (61%) rename bot/src/{dqn/burnrl => burnrl/dqn}/utils.rs (95%) rename bot/src/{dqn/burnrl => burnrl/dqn_big}/dqn_model.rs (98%) rename bot/src/{dqn/burnrl_big => burnrl/dqn_big}/main.rs (94%) rename bot/src/{dqn/burnrl_valid => burnrl/dqn_big}/mod.rs (61%) rename bot/src/{dqn/burnrl_valid => burnrl/dqn_big}/utils.rs (95%) rename bot/src/{dqn/burnrl_valid => burnrl/dqn_valid}/dqn_model.rs (98%) rename bot/src/{dqn/burnrl_valid => burnrl/dqn_valid}/main.rs (96%) rename bot/src/{dqn/burnrl => burnrl/dqn_valid}/mod.rs (61%) rename bot/src/{dqn/burnrl_big => burnrl/dqn_valid}/utils.rs (95%) rename bot/src/{dqn => }/burnrl/environment.rs (91%) rename bot/src/{dqn/burnrl_big/environment.rs => burnrl/environment_big.rs} (96%) rename bot/src/{dqn/burnrl_valid/environment.rs => burnrl/environment_valid.rs} (96%) create mode 100644 bot/src/burnrl/mod.rs delete mode 100644 bot/src/dqn/mod.rs rename bot/src/{dqn/simple => dqn_simple}/dqn_model.rs (98%) rename bot/src/{dqn/simple => dqn_simple}/dqn_trainer.rs (99%) rename bot/src/{dqn/simple => dqn_simple}/main.rs (96%) rename bot/src/{dqn/simple => dqn_simple}/mod.rs (100%) rename bot/src/{dqn/dqn_common.rs => training_common.rs} (100%) rename bot/src/{dqn/dqn_common_big.rs => training_common_big.rs} (100%) diff --git a/bot/Cargo.toml b/bot/Cargo.toml index c043393..1dea531 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -7,19 +7,19 @@ edition = "2021" [[bin]] name = "train_dqn_burn_valid" -path = "src/dqn/burnrl_valid/main.rs" +path = "src/burnrl/dqn_valid/main.rs" [[bin]] name = "train_dqn_burn_big" -path = "src/dqn/burnrl_big/main.rs" +path = "src/burnrl/dqn_big/main.rs" [[bin]] name = "train_dqn_burn" -path = "src/dqn/burnrl/main.rs" +path = "src/burnrl/dqn/main.rs" [[bin]] name = "train_dqn_simple" -path = "src/dqn/simple/main.rs" +path = "src/dqn_simple/main.rs" [dependencies] pretty_assertions = "1.4.0" diff --git a/bot/src/dqn/burnrl_big/dqn_model.rs b/bot/src/burnrl/dqn/dqn_model.rs similarity index 92% rename from bot/src/dqn/burnrl_big/dqn_model.rs rename to bot/src/burnrl/dqn/dqn_model.rs index f50bf31..204cef0 100644 --- a/bot/src/dqn/burnrl_big/dqn_model.rs +++ b/bot/src/burnrl/dqn/dqn_model.rs @@ -1,5 +1,5 @@ -use crate::dqn::burnrl_big::environment::TrictracEnvironment; -use crate::dqn::burnrl_big::utils::soft_update_linear; +use crate::burnrl::dqn::utils::soft_update_linear; +use crate::burnrl::environment::TrictracEnvironment; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; @@ -126,7 +126,7 @@ pub fn run, B: AutodiffBackend>( ) -> DQN> { // ) -> impl Agent { let mut env = E::new(visualized); - env.as_mut().min_steps = conf.min_steps; + // env.as_mut().min_steps = conf.min_steps; env.as_mut().max_steps = conf.max_steps; let model = Net::::new( @@ -193,12 +193,17 @@ pub fn run, B: AutodiffBackend>( if snapshot.done() || episode_duration >= conf.max_steps { let envmut = env.as_mut(); + let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32) + * 100.0) + .round() as u32; println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}", envmut.goodmoves_count, + goodmoves_ratio, envmut.pointrolls_count, now.elapsed().unwrap().as_secs(), ); + if goodmoves_ratio < 5 && 10 < episode {} env.reset(); episode_done = true; now = SystemTime::now(); diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/burnrl/dqn/main.rs similarity index 85% rename from bot/src/dqn/burnrl/main.rs rename to bot/src/burnrl/dqn/main.rs index 152bf0e..fb55c60 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/burnrl/dqn/main.rs @@ -1,7 +1,8 @@ -use bot::dqn::burnrl::{ - dqn_model, environment, +use bot::burnrl::dqn::{ + dqn_model, utils::{demo_model, load_model, save_model}, }; +use bot::burnrl::environment; use burn::backend::{Autodiff, NdArray}; use burn_rl::agent::DQN; use burn_rl::base::ElemType; @@ -15,9 +16,9 @@ fn main() { // See also MEMORY_SIZE in dqn_model.rs : 8192 let conf = dqn_model::DqnConfig { // defaults - num_episodes: 40, // 40 + num_episodes: 50, // 40 min_steps: 1000.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 2000, // 1000 max steps by episode + max_steps: 1000, // 1000 max steps by episode dense_size: 256, // 128 neural network complexity (default 128) eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) eps_end: 0.05, // 0.05 @@ -31,8 +32,8 @@ fn main() { // plus lente moins sensible aux coups de chance learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais // converger - batch_size: 64, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. - clip_grad: 50.0, // 100 limite max de correction à apporter au gradient (default 100) + batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. + clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100) }; println!("{conf}----------"); let agent = dqn_model::run::(&conf, false); //true); diff --git a/bot/src/dqn/burnrl_big/mod.rs b/bot/src/burnrl/dqn/mod.rs similarity index 61% rename from bot/src/dqn/burnrl_big/mod.rs rename to bot/src/burnrl/dqn/mod.rs index f4380eb..27fcc58 100644 --- a/bot/src/dqn/burnrl_big/mod.rs +++ b/bot/src/burnrl/dqn/mod.rs @@ -1,3 +1,2 @@ pub mod dqn_model; -pub mod environment; pub mod utils; diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/burnrl/dqn/utils.rs similarity index 95% rename from bot/src/dqn/burnrl/utils.rs rename to bot/src/burnrl/dqn/utils.rs index 0682f2a..77e2402 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/burnrl/dqn/utils.rs @@ -1,8 +1,6 @@ -use crate::dqn::burnrl::{ - dqn_model, - environment::{TrictracAction, TrictracEnvironment}, -}; -use crate::dqn::dqn_common::get_valid_action_indices; +use crate::burnrl::dqn::dqn_model; +use crate::burnrl::environment::{TrictracAction, TrictracEnvironment}; +use crate::training_common::get_valid_action_indices; use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; diff --git a/bot/src/dqn/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_big/dqn_model.rs similarity index 98% rename from bot/src/dqn/burnrl/dqn_model.rs rename to bot/src/burnrl/dqn_big/dqn_model.rs index 3e90904..1ccafef 100644 --- a/bot/src/dqn/burnrl/dqn_model.rs +++ b/bot/src/burnrl/dqn_big/dqn_model.rs @@ -1,5 +1,5 @@ -use crate::dqn::burnrl::environment::TrictracEnvironment; -use crate::dqn::burnrl::utils::soft_update_linear; +use crate::burnrl::dqn_big::utils::soft_update_linear; +use crate::burnrl::environment_big::TrictracEnvironment; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; diff --git a/bot/src/dqn/burnrl_big/main.rs b/bot/src/burnrl/dqn_big/main.rs similarity index 94% rename from bot/src/dqn/burnrl_big/main.rs rename to bot/src/burnrl/dqn_big/main.rs index c7221ec..a8c5c9f 100644 --- a/bot/src/dqn/burnrl_big/main.rs +++ b/bot/src/burnrl/dqn_big/main.rs @@ -1,13 +1,14 @@ -use bot::dqn::burnrl_big::{ - dqn_model, environment, +use bot::burnrl::dqn_big::{ + dqn_model, utils::{demo_model, load_model, save_model}, }; +use bot::burnrl::environment_big; use burn::backend::{Autodiff, NdArray}; use burn_rl::agent::DQN; use burn_rl::base::ElemType; type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; +type Env = environment_big::TrictracEnvironment; fn main() { // println!("> Entraînement"); diff --git a/bot/src/dqn/burnrl_valid/mod.rs b/bot/src/burnrl/dqn_big/mod.rs similarity index 61% rename from bot/src/dqn/burnrl_valid/mod.rs rename to bot/src/burnrl/dqn_big/mod.rs index f4380eb..27fcc58 100644 --- a/bot/src/dqn/burnrl_valid/mod.rs +++ b/bot/src/burnrl/dqn_big/mod.rs @@ -1,3 +1,2 @@ pub mod dqn_model; -pub mod environment; pub mod utils; diff --git a/bot/src/dqn/burnrl_valid/utils.rs b/bot/src/burnrl/dqn_big/utils.rs similarity index 95% rename from bot/src/dqn/burnrl_valid/utils.rs rename to bot/src/burnrl/dqn_big/utils.rs index 6cced18..fa8de44 100644 --- a/bot/src/dqn/burnrl_valid/utils.rs +++ b/bot/src/burnrl/dqn_big/utils.rs @@ -1,8 +1,6 @@ -use crate::dqn::burnrl_valid::{ - dqn_model, - environment::{TrictracAction, TrictracEnvironment}, -}; -use crate::dqn::dqn_common::get_valid_action_indices; +use crate::burnrl::dqn_big::dqn_model; +use crate::burnrl::environment_big::{TrictracAction, TrictracEnvironment}; +use crate::training_common_big::get_valid_action_indices; use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; diff --git a/bot/src/dqn/burnrl_valid/dqn_model.rs b/bot/src/burnrl/dqn_valid/dqn_model.rs similarity index 98% rename from bot/src/dqn/burnrl_valid/dqn_model.rs rename to bot/src/burnrl/dqn_valid/dqn_model.rs index 4dd5180..9d53a2f 100644 --- a/bot/src/dqn/burnrl_valid/dqn_model.rs +++ b/bot/src/burnrl/dqn_valid/dqn_model.rs @@ -1,5 +1,5 @@ -use crate::dqn::burnrl_valid::environment::TrictracEnvironment; -use crate::dqn::burnrl_valid::utils::soft_update_linear; +use crate::burnrl::dqn_valid::utils::soft_update_linear; +use crate::burnrl::environment::TrictracEnvironment; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; diff --git a/bot/src/dqn/burnrl_valid/main.rs b/bot/src/burnrl/dqn_valid/main.rs similarity index 96% rename from bot/src/dqn/burnrl_valid/main.rs rename to bot/src/burnrl/dqn_valid/main.rs index ee0dd1f..b049372 100644 --- a/bot/src/dqn/burnrl_valid/main.rs +++ b/bot/src/burnrl/dqn_valid/main.rs @@ -1,7 +1,8 @@ -use bot::dqn::burnrl_valid::{ - dqn_model, environment, +use bot::burnrl::dqn_valid::{ + dqn_model, utils::{demo_model, load_model, save_model}, }; +use bot::burnrl::environment; use burn::backend::{Autodiff, NdArray}; use burn_rl::agent::DQN; use burn_rl::base::ElemType; diff --git a/bot/src/dqn/burnrl/mod.rs b/bot/src/burnrl/dqn_valid/mod.rs similarity index 61% rename from bot/src/dqn/burnrl/mod.rs rename to bot/src/burnrl/dqn_valid/mod.rs index f4380eb..27fcc58 100644 --- a/bot/src/dqn/burnrl/mod.rs +++ b/bot/src/burnrl/dqn_valid/mod.rs @@ -1,3 +1,2 @@ pub mod dqn_model; -pub mod environment; pub mod utils; diff --git a/bot/src/dqn/burnrl_big/utils.rs b/bot/src/burnrl/dqn_valid/utils.rs similarity index 95% rename from bot/src/dqn/burnrl_big/utils.rs rename to bot/src/burnrl/dqn_valid/utils.rs index 88c8971..2e87e2a 100644 --- a/bot/src/dqn/burnrl_big/utils.rs +++ b/bot/src/burnrl/dqn_valid/utils.rs @@ -1,8 +1,6 @@ -use crate::dqn::burnrl_big::{ - dqn_model, - environment::{TrictracAction, TrictracEnvironment}, -}; -use crate::dqn::dqn_common_big::get_valid_action_indices; +use crate::burnrl::dqn_valid::dqn_model; +use crate::burnrl::environment_valid::{TrictracAction, TrictracEnvironment}; +use crate::training_common::get_valid_action_indices; use burn::backend::{ndarray::NdArrayDevice, NdArray}; use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/burnrl/environment.rs similarity index 91% rename from bot/src/dqn/burnrl/environment.rs rename to bot/src/burnrl/environment.rs index b0bf4b9..1d8e80d 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -1,13 +1,15 @@ -use crate::dqn::dqn_common; +use std::io::Write; + +use crate::training_common; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; -const ERROR_REWARD: f32 = -2.12121; -const REWARD_VALID_MOVE: f32 = 2.12121; +const ERROR_REWARD: f32 = -1.12121; +const REWARD_VALID_MOVE: f32 = 1.12121; const REWARD_RATIO: f32 = 0.01; -const WIN_POINTS: f32 = 0.1; +const WIN_POINTS: f32 = 1.0; /// État du jeu Trictrac pour burn-rl #[derive(Debug, Clone, Copy)] @@ -89,7 +91,7 @@ pub struct TrictracEnvironment { current_state: TrictracState, episode_reward: f32, pub step_count: usize, - pub min_steps: f32, + pub best_ratio: f32, pub max_steps: usize, pub pointrolls_count: usize, pub goodmoves_count: usize, @@ -122,7 +124,7 @@ impl Environment for TrictracEnvironment { current_state, episode_reward: 0.0, step_count: 0, - min_steps: 250.0, + best_ratio: 0.0, max_steps: 2000, pointrolls_count: 0, goodmoves_count: 0, @@ -151,10 +153,21 @@ impl Environment for TrictracEnvironment { } else { self.goodmoves_count as f32 / self.step_count as f32 }; + self.best_ratio = self.best_ratio.max(self.goodmoves_ratio); + let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 { + let path = "bot/models/logs/debug.log"; + if let Ok(mut out) = std::fs::File::create(path) { + write!(out, "{:?}", self.game.history); + } + "!!!!" + } else { + "" + }; println!( - "info: correct moves: {} ({}%)", + "info: correct moves: {} ({}%) {}", self.goodmoves_count, - (100.0 * self.goodmoves_ratio).round() as u32 + (100.0 * self.goodmoves_ratio).round() as u32, + warning ); self.step_count = 0; self.pointrolls_count = 0; @@ -195,9 +208,10 @@ impl Environment for TrictracEnvironment { } // Vérifier si la partie est terminée - let max_steps = self.min_steps - + (self.max_steps as f32 - self.min_steps) - * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + // let max_steps = self.max_steps; + // let max_steps = self.min_steps + // + (self.max_steps as f32 - self.min_steps) + // * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); if done { @@ -210,7 +224,8 @@ impl Environment for TrictracEnvironment { } } } - let terminated = done || self.step_count >= max_steps.round() as usize; + let terminated = done || self.step_count >= self.max_steps; + // let terminated = done || self.step_count >= max_steps.round() as usize; // Mettre à jour l'état self.current_state = TrictracState::from_game_state(&self.game); @@ -229,8 +244,8 @@ impl Environment for TrictracEnvironment { impl TrictracEnvironment { /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + training_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac @@ -239,8 +254,8 @@ impl TrictracEnvironment { &self, action: TrictracAction, game_state: &GameState, - ) -> Option { - use dqn_common::get_valid_actions; + ) -> Option { + use training_common::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(game_state); @@ -257,10 +272,10 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: dqn_common::TrictracAction, + // action: training_common::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) { - use dqn_common::TrictracAction; + fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) { + use training_common::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/burnrl/environment_big.rs similarity index 96% rename from bot/src/dqn/burnrl_big/environment.rs rename to bot/src/burnrl/environment_big.rs index 53572ec..b362fc1 100644 --- a/bot/src/dqn/burnrl_big/environment.rs +++ b/bot/src/burnrl/environment_big.rs @@ -1,4 +1,4 @@ -use crate::dqn::dqn_common_big; +use crate::training_common_big; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; @@ -229,8 +229,8 @@ impl Environment for TrictracEnvironment { impl TrictracEnvironment { /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac @@ -239,8 +239,8 @@ impl TrictracEnvironment { &self, action: TrictracAction, game_state: &GameState, - ) -> Option { - use dqn_common_big::get_valid_actions; + ) -> Option { + use training_common_big::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(game_state); @@ -257,10 +257,10 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action:dqn_common_big::TrictracAction, + // action:training_common_big::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { - use dqn_common_big::TrictracAction; + fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) { + use training_common_big::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; diff --git a/bot/src/dqn/burnrl_valid/environment.rs b/bot/src/burnrl/environment_valid.rs similarity index 96% rename from bot/src/dqn/burnrl_valid/environment.rs rename to bot/src/burnrl/environment_valid.rs index 7b1291f..346044c 100644 --- a/bot/src/dqn/burnrl_valid/environment.rs +++ b/bot/src/burnrl/environment_valid.rs @@ -1,4 +1,4 @@ -use crate::dqn::dqn_common_big; +use crate::training_common_big; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; @@ -214,16 +214,16 @@ impl TrictracEnvironment { const REWARD_RATIO: f32 = 1.0; /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac fn convert_valid_action_index( &self, action: TrictracAction, - ) -> Option { - use dqn_common_big::get_valid_actions; + ) -> Option { + use training_common_big::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(&self.game); @@ -240,10 +240,10 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: dqn_common_big::TrictracAction, + // action: training_common_big::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { - use dqn_common_big::TrictracAction; + fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) { + use training_common_big::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs new file mode 100644 index 0000000..0afacb4 --- /dev/null +++ b/bot/src/burnrl/mod.rs @@ -0,0 +1,6 @@ +pub mod dqn; +pub mod dqn_big; +pub mod dqn_valid; +pub mod environment; +pub mod environment_big; +pub mod environment_valid; diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs deleted file mode 100644 index 7b12487..0000000 --- a/bot/src/dqn/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod burnrl; -pub mod burnrl_big; -pub mod dqn_common; -pub mod dqn_common_big; -pub mod simple; - -pub mod burnrl_valid; diff --git a/bot/src/dqn/simple/dqn_model.rs b/bot/src/dqn_simple/dqn_model.rs similarity index 98% rename from bot/src/dqn/simple/dqn_model.rs rename to bot/src/dqn_simple/dqn_model.rs index ba46212..9c31f44 100644 --- a/bot/src/dqn/simple/dqn_model.rs +++ b/bot/src/dqn_simple/dqn_model.rs @@ -1,4 +1,4 @@ -use crate::dqn::dqn_common::TrictracAction; +use crate::training_common_big::TrictracAction; use serde::{Deserialize, Serialize}; /// Configuration pour l'agent DQN @@ -151,4 +151,3 @@ impl SimpleNeuralNetwork { Ok(network) } } - diff --git a/bot/src/dqn/simple/dqn_trainer.rs b/bot/src/dqn_simple/dqn_trainer.rs similarity index 99% rename from bot/src/dqn/simple/dqn_trainer.rs rename to bot/src/dqn_simple/dqn_trainer.rs index a2ca5a8..ed60f5e 100644 --- a/bot/src/dqn/simple/dqn_trainer.rs +++ b/bot/src/dqn_simple/dqn_trainer.rs @@ -6,7 +6,7 @@ use std::collections::VecDeque; use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; use super::dqn_model::{DqnConfig, SimpleNeuralNetwork}; -use crate::dqn::dqn_common_big::{get_valid_actions, TrictracAction}; +use crate::training_common_big::{get_valid_actions, TrictracAction}; /// Expérience pour le buffer de replay #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/bot/src/dqn/simple/main.rs b/bot/src/dqn_simple/main.rs similarity index 96% rename from bot/src/dqn/simple/main.rs rename to bot/src/dqn_simple/main.rs index dba015a..024f895 100644 --- a/bot/src/dqn/simple/main.rs +++ b/bot/src/dqn_simple/main.rs @@ -1,6 +1,6 @@ -use bot::dqn::dqn_common::TrictracAction; -use bot::dqn::simple::dqn_model::DqnConfig; -use bot::dqn::simple::dqn_trainer::DqnTrainer; +use bot::dqn_simple::dqn_model::DqnConfig; +use bot::dqn_simple::dqn_trainer::DqnTrainer; +use bot::training_common::TrictracAction; use std::env; fn main() -> Result<(), Box> { diff --git a/bot/src/dqn/simple/mod.rs b/bot/src/dqn_simple/mod.rs similarity index 100% rename from bot/src/dqn/simple/mod.rs rename to bot/src/dqn_simple/mod.rs diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 202bfeb..6e3b269 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,5 +1,8 @@ -pub mod dqn; +pub mod burnrl; +pub mod dqn_simple; pub mod strategy; +pub mod training_common; +pub mod training_common_big; use log::debug; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 20ce0d5..2874195 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -3,8 +3,8 @@ use log::info; use std::path::Path; use store::MoveRules; -use crate::dqn::dqn_common_big::{get_valid_actions, sample_valid_action, TrictracAction}; -use crate::dqn::simple::dqn_model::SimpleNeuralNetwork; +use crate::dqn_simple::dqn_model::SimpleNeuralNetwork; +use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction}; /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné #[derive(Debug)] diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs index b95ce90..3d25c2b 100644 --- a/bot/src/strategy/dqnburn.rs +++ b/bot/src/strategy/dqnburn.rs @@ -6,8 +6,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use log::info; use store::MoveRules; -use crate::dqn::burnrl::{dqn_model, environment, utils}; -use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; +use crate::burnrl::dqn::{dqn_model, utils}; +use crate::burnrl::environment; +use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; type DqnBurnNetwork = dqn_model::Net>; diff --git a/bot/src/dqn/dqn_common.rs b/bot/src/training_common.rs similarity index 100% rename from bot/src/dqn/dqn_common.rs rename to bot/src/training_common.rs diff --git a/bot/src/dqn/dqn_common_big.rs b/bot/src/training_common_big.rs similarity index 100% rename from bot/src/dqn/dqn_common_big.rs rename to bot/src/training_common_big.rs