diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 4da2866..3fd08c4 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -7,7 +7,7 @@ edition = "2021"
 
 [[bin]]
 name = "train_dqn_burn"
-path = "src/burnrl/main.rs"
+path = "src/dqn/burnrl/main.rs"
 
 [[bin]]
 name = "train_dqn"
diff --git a/bot/src/dqn/burnrl/dqn_model.rs b/bot/src/dqn/burnrl/dqn_model.rs
index af0e2dd..0c333b0 100644
--- a/bot/src/dqn/burnrl/dqn_model.rs
+++ b/bot/src/dqn/burnrl/dqn_model.rs
@@ -58,17 +58,35 @@ impl DQNModel for Net {
 }
 
 #[allow(unused)]
-const MEMORY_SIZE: usize = 4096;
-const DENSE_SIZE: usize = 128;
-const EPS_DECAY: f64 = 1000.0;
-const EPS_START: f64 = 0.9;
-const EPS_END: f64 = 0.05;
+const MEMORY_SIZE: usize = 8192;
+
+pub struct DqnConfig {
+    pub num_episodes: usize,
+    // pub memory_size: usize,
+    pub dense_size: usize,
+    pub eps_start: f64,
+    pub eps_end: f64,
+    pub eps_decay: f64,
+}
+
+impl Default for DqnConfig {
+    fn default() -> Self {
+        Self {
+            num_episodes: 1000,
+            // memory_size: 8192,
+            dense_size: 256,
+            eps_start: 0.9,
+            eps_end: 0.05,
+            eps_decay: 1000.0,
+        }
+    }
+}
 
 type MyAgent = DQN>;
 
 #[allow(unused)]
 pub fn run(
-    num_episodes: usize,
+    conf: &DqnConfig,
     visualized: bool,
 ) -> DQN> {
 // ) -> impl Agent {
@@ -76,7 +94,7 @@ pub fn run(
 
     let model = Net::::new(
         <::StateType as State>::size(),
-        DENSE_SIZE,
+        conf.dense_size,
         <::ActionType as Action>::size(),
     );
@@ -94,7 +112,7 @@ pub fn run(
 
     let mut step = 0_usize;
 
-    for episode in 0..num_episodes {
+    for episode in 0..conf.num_episodes {
        let mut episode_done = false;
         let mut episode_reward: ElemType = 0.0;
         let mut episode_duration = 0_usize;
@@ -102,8 +120,8 @@ pub fn run(
         let mut now = SystemTime::now();
 
         while !episode_done {
-            let eps_threshold =
-                EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
+            let eps_threshold = conf.eps_end
+                + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
             let action = DQN::>::react_with_exploration(&policy_net, state, eps_threshold);
             let snapshot = env.step(action);
diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs
index 40bcc29..f8e5f21 100644
--- a/bot/src/dqn/burnrl/environment.rs
+++ b/bot/src/dqn/burnrl/environment.rs
@@ -91,8 +91,7 @@ impl Environment for TrictracEnvironment {
     type ActionType = TrictracAction;
     type RewardType = f32;
 
-    const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
-    // const MAX_STEPS: usize = 5; // Limite max pour éviter les parties infinies
+    const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
 
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -260,7 +259,7 @@ impl TrictracEnvironment {
             // }
             TrictracAction::Go => {
                 // Continuer après avoir gagné un trou
-                reward += 0.2;
+                reward += 0.4;
                 Some(GameEvent::Go {
                     player_id: self.active_player_id,
                 })
@@ -289,7 +288,7 @@ impl TrictracEnvironment {
                 let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
                 let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
 
-                reward += 0.2;
+                reward += 0.4;
                 Some(GameEvent::Move {
                     player_id: self.active_player_id,
                     moves: (checker_move1, checker_move2),
diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs
index 0919d5f..fdaafc6 100644
--- a/bot/src/dqn/burnrl/main.rs
+++ b/bot/src/dqn/burnrl/main.rs
@@ -1,4 +1,4 @@
-use bot::burnrl::{dqn_model, environment, utils::demo_model};
+use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
 use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
 use burn::module::Module;
 use burn::record::{CompactRecorder, Recorder};
@@ -10,8 +10,16 @@ type Env = environment::TrictracEnvironment;
 
 fn main() {
     println!("> Entraînement");
-    let num_episodes = 50;
-    let agent = dqn_model::run::(num_episodes, false); //true);
+    let conf = dqn_model::DqnConfig {
+        num_episodes: 50,
+        // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
+        // max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
+        dense_size: 256,  // neural network complexity
+        eps_start: 0.9,   // epsilon initial value (0.9 => more exploration)
+        eps_end: 0.05,
+        eps_decay: 1000.0,
+    };
+    let agent = dqn_model::run::(&conf, false); //true);
 
     let valid_agent = agent.valid();
@@ -24,7 +32,7 @@ fn main() {
     // demo_model::(valid_agent);
 
     println!("> Chargement du modèle pour test");
-    let loaded_model = load_model(&path);
+    let loaded_model = load_model(conf.dense_size, &path);
     let loaded_agent = DQN::new(loaded_model);
 
     println!("> Test avec le modèle chargé");
@@ -40,10 +48,7 @@ fn save_model(model: &dqn_model::Net>, path: &String) {
         .unwrap();
 }
 
-fn load_model(path: &String) -> dqn_model::Net> {
-    // TODO : reprendre le DENSE_SIZE de dqn_model.rs
-    const DENSE_SIZE: usize = 128;
-
+fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> {
     let model_path = format!("{}_model.mpk", path);
     println!("Chargement du modèle depuis : {}", model_path);
@@ -56,7 +61,7 @@ fn load_model(path: &String) -> dqn_model::Net> {
 
     dqn_model::Net::new(
         ::StateType::size(),
-        DENSE_SIZE,
+        dense_size,
         ::ActionType::size(),
     )
     .load_record(record)
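Note on the new exploration parameters: `run` keeps the usual exponential epsilon-greedy schedule, eps_threshold = eps_end + (eps_start - eps_end) * exp(-step / eps_decay), but now reads the three constants from `DqnConfig`. The standalone sketch below is an illustration only, not part of the patch; `EpsConfig` is a trimmed, hypothetical copy of the epsilon fields, initialized with the defaults added in `dqn_model.rs`, to show how the threshold decays over training steps.

// Illustration only (not part of the patch): trimmed copy of the epsilon
// fields of DqnConfig, reproducing the schedule computed inside run().
struct EpsConfig {
    eps_start: f64,
    eps_end: f64,
    eps_decay: f64,
}

// eps_threshold = eps_end + (eps_start - eps_end) * exp(-step / eps_decay)
fn eps_threshold(conf: &EpsConfig, step: usize) -> f64 {
    conf.eps_end + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay)
}

fn main() {
    // Defaults added by this patch in DqnConfig::default()
    let conf = EpsConfig { eps_start: 0.9, eps_end: 0.05, eps_decay: 1000.0 };
    for step in [0usize, 500, 1000, 2000, 5000] {
        // step 0 -> 0.90, 1000 -> ~0.36, 5000 -> ~0.06: exploration fades toward eps_end
        println!("step {step:>5}: eps = {:.3}", eps_threshold(&conf, step));
    }
}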