refact models paths

This commit is contained in:
Henri Bourcereau 2025-08-18 17:44:01 +02:00
parent 2499c3377f
commit e66921fcce
6 changed files with 16 additions and 9 deletions

View file

@ -184,7 +184,8 @@ impl Environment for TrictracEnvironment {
}
} else {
// Action non convertible, pénalité
reward = -0.5;
panic!("action non convertible");
//reward = -0.5;
}
}

View file

@ -26,13 +26,13 @@ fn main() {
// epsilon is updated at the start of each episode
eps_decay: 2000.0, // 1000 ?
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
gamma: 0.9999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.0005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
// plus lente moins sensible aux coups de chance
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
// converger
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
batch_size: 64, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 50.0, // 100 limite max de correction à apporter au gradient (default 100)
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
@ -41,7 +41,7 @@ fn main() {
println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_40".to_string();
let path = "bot/models/burnrl_dqn".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path);
println!("> Chargement du modèle pour test");

View file

@ -15,7 +15,7 @@ use burn_rl::base::{Action, ElemType, Environment, State};
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}_model.mpk");
let model_path = format!("{path}.mpk");
println!("Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
@ -23,7 +23,7 @@ pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
}
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
let model_path = format!("{path}_model.mpk");
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()

View file

@ -157,6 +157,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
}
}
if valid_actions.is_empty() {
panic!("empty valid_actions for state {game_state}");
}
valid_actions
}

View file

@ -161,6 +161,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
}
}
if valid_actions.is_empty() {
panic!("empty valid_actions for state {game_state}");
}
valid_actions
}

View file

@ -9,7 +9,7 @@ shell:
runcli:
RUST_LOG=info cargo run --bin=client_cli
runclibots:
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
match: