diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index a7ce014..b0bf4b9 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -184,7 +184,8 @@ impl Environment for TrictracEnvironment { } } else { // Action non convertible, pénalité - reward = -0.5; + panic!("action non convertible"); + //reward = -0.5; } } diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 7a99f46..152bf0e 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -26,13 +26,13 @@ fn main() { // epsilon is updated at the start of each episode eps_decay: 2000.0, // 1000 ? - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + gamma: 0.9999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.0005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation // plus lente moins sensible aux coups de chance learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais // converger - batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. - clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) + batch_size: 64, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. + clip_grad: 50.0, // 100 limite max de correction à apporter au gradient (default 100) }; println!("{conf}----------"); let agent = dqn_model::run::(&conf, false); //true); @@ -41,7 +41,7 @@ fn main() { println!("> Sauvegarde du modèle de validation"); - let path = "models/burn_dqn_40".to_string(); + let path = "bot/models/burnrl_dqn".to_string(); save_model(valid_agent.model().as_ref().unwrap(), &path); println!("> Chargement du modèle pour test"); diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index 4ce4799..0682f2a 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -15,7 +15,7 @@ use burn_rl::base::{Action, ElemType, Environment, State}; pub fn save_model(model: &dqn_model::Net>, path: &String) { let recorder = CompactRecorder::new(); - let model_path = format!("{path}_model.mpk"); + let model_path = format!("{path}.mpk"); println!("Modèle de validation sauvegardé : {model_path}"); recorder .record(model.clone().into_record(), model_path.into()) @@ -23,7 +23,7 @@ pub fn save_model(model: &dqn_model::Net>, path: &String) { } pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}_model.mpk"); + let model_path = format!("{path}.mpk"); // println!("Chargement du modèle depuis : {model_path}"); CompactRecorder::new() diff --git a/bot/src/dqn/dqn_common.rs b/bot/src/dqn/dqn_common.rs index 9dae81f..b2f2bad 100644 --- a/bot/src/dqn/dqn_common.rs +++ b/bot/src/dqn/dqn_common.rs @@ -157,6 +157,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { } } + if valid_actions.is_empty() { + panic!("empty valid_actions for state {game_state}"); + } valid_actions } diff --git a/bot/src/dqn/dqn_common_big.rs b/bot/src/dqn/dqn_common_big.rs index ee0dff3..db9ee2b 100644 --- a/bot/src/dqn/dqn_common_big.rs +++ b/bot/src/dqn/dqn_common_big.rs @@ -161,6 +161,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { } } + if valid_actions.is_empty() { + panic!("empty valid_actions for state {game_state}"); + } valid_actions } diff --git a/justfile b/justfile index ffa3229..f554b15 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,7 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk + cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk #cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: