diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 3fd08c4..4da2866 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [[bin]] name = "train_dqn_burn" -path = "src/dqn/burnrl/main.rs" +path = "src/burnrl/main.rs" [[bin]] name = "train_dqn" diff --git a/bot/src/dqn/burnrl/dqn_model.rs b/bot/src/dqn/burnrl/dqn_model.rs index 0c333b0..af0e2dd 100644 --- a/bot/src/dqn/burnrl/dqn_model.rs +++ b/bot/src/dqn/burnrl/dqn_model.rs @@ -58,35 +58,17 @@ impl DQNModel for Net { } #[allow(unused)] -const MEMORY_SIZE: usize = 8192; - -pub struct DqnConfig { - pub num_episodes: usize, - // pub memory_size: usize, - pub dense_size: usize, - pub eps_start: f64, - pub eps_end: f64, - pub eps_decay: f64, -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - num_episodes: 1000, - // memory_size: 8192, - dense_size: 256, - eps_start: 0.9, - eps_end: 0.05, - eps_decay: 1000.0, - } - } -} +const MEMORY_SIZE: usize = 4096; +const DENSE_SIZE: usize = 128; +const EPS_DECAY: f64 = 1000.0; +const EPS_START: f64 = 0.9; +const EPS_END: f64 = 0.05; type MyAgent = DQN>; #[allow(unused)] pub fn run( - conf: &DqnConfig, + num_episodes: usize, visualized: bool, ) -> DQN> { // ) -> impl Agent { @@ -94,7 +76,7 @@ pub fn run( let model = Net::::new( <::StateType as State>::size(), - conf.dense_size, + DENSE_SIZE, <::ActionType as Action>::size(), ); @@ -112,7 +94,7 @@ pub fn run( let mut step = 0_usize; - for episode in 0..conf.num_episodes { + for episode in 0..num_episodes { let mut episode_done = false; let mut episode_reward: ElemType = 0.0; let mut episode_duration = 0_usize; @@ -120,8 +102,8 @@ pub fn run( let mut now = SystemTime::now(); while !episode_done { - let eps_threshold = conf.eps_end - + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let eps_threshold = + EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY); let action = DQN::>::react_with_exploration(&policy_net, state, eps_threshold); let snapshot = env.step(action); diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index f8e5f21..40bcc29 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -91,7 +91,8 @@ impl Environment for TrictracEnvironment { type ActionType = TrictracAction; type RewardType = f32; - const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies + const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies + // const MAX_STEPS: usize = 5; // Limite max pour éviter les parties infinies fn new(visualized: bool) -> Self { let mut game = GameState::new(false); @@ -259,7 +260,7 @@ impl TrictracEnvironment { // } TrictracAction::Go => { // Continuer après avoir gagné un trou - reward += 0.4; + reward += 0.2; Some(GameEvent::Go { player_id: self.active_player_id, }) @@ -288,7 +289,7 @@ impl TrictracEnvironment { let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - reward += 0.4; + reward += 0.2; Some(GameEvent::Move { player_id: self.active_player_id, moves: (checker_move1, checker_move2), diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index fdaafc6..0919d5f 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -1,4 +1,4 @@ -use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; +use bot::burnrl::{dqn_model, environment, utils::demo_model}; use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; use burn::module::Module; use burn::record::{CompactRecorder, Recorder}; @@ -10,16 +10,8 @@ type Env = environment::TrictracEnvironment; fn main() { println!("> Entraînement"); - let conf = dqn_model::DqnConfig { - num_episodes: 50, - // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant - // max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant - dense_size: 256, // neural network complexity - eps_start: 0.9, // epsilon initial value (0.9 => more exploration) - eps_end: 0.05, - eps_decay: 1000.0, - }; - let agent = dqn_model::run::(&conf, false); //true); + let num_episodes = 50; + let agent = dqn_model::run::(num_episodes, false); //true); let valid_agent = agent.valid(); @@ -32,7 +24,7 @@ fn main() { // demo_model::(valid_agent); println!("> Chargement du modèle pour test"); - let loaded_model = load_model(conf.dense_size, &path); + let loaded_model = load_model(&path); let loaded_agent = DQN::new(loaded_model); println!("> Test avec le modèle chargé"); @@ -48,7 +40,10 @@ fn save_model(model: &dqn_model::Net>, path: &String) { .unwrap(); } -fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { +fn load_model(path: &String) -> dqn_model::Net> { + // TODO : reprendre le DENSE_SIZE de dqn_model.rs + const DENSE_SIZE: usize = 128; + let model_path = format!("{}_model.mpk", path); println!("Chargement du modèle depuis : {}", model_path); @@ -61,7 +56,7 @@ fn load_model(dense_size: usize, path: &String) -> dqn_model::Net::StateType::size(), - dense_size, + DENSE_SIZE, ::ActionType::size(), ) .load_record(record) diff --git a/store/src/game_rules_moves.rs b/store/src/game_rules_moves.rs index 17e572e..1a67340 100644 --- a/store/src/game_rules_moves.rs +++ b/store/src/game_rules_moves.rs @@ -93,18 +93,6 @@ impl MoveRules { /// ---- moves_possibles : First of three checks for moves fn moves_possible(&self, moves: &(CheckerMove, CheckerMove)) -> bool { let color = &Color::White; - - let move0_from = moves.0.get_from(); - if 0 < move0_from && move0_from == moves.1.get_from() { - if let Ok((field_count, Some(field_color))) = self.board.get_field_checkers(move0_from) - { - if color != field_color || field_count < 2 { - info!("Move not physically possible"); - return false; - } - } - } - if let Ok(chained_move) = moves.0.chain(moves.1) { // Check intermediary move and chained_move : "Tout d'une" if !self.board.passage_possible(color, &moves.0) @@ -1017,7 +1005,7 @@ mod tests { #[test] fn moves_possible() { - let mut state = MoveRules::default(); + let state = MoveRules::default(); // Chained moves let moves = ( @@ -1033,17 +1021,6 @@ mod tests { ); assert!(!state.moves_possible(&moves)); - // Can't move the same checker twice - state.board.set_positions([ - 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); - state.dice.values = (2, 1); - let moves = ( - CheckerMove::new(3, 5).unwrap(), - CheckerMove::new(3, 4).unwrap(), - ); - assert!(!state.moves_possible(&moves)); - // black moves let state = MoveRules::new(&Color::Black, &Board::default(), Dice::default()); let moves = (