refact models paths
This commit is contained in:
parent
2499c3377f
commit
e66921fcce
|
|
@ -184,7 +184,8 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
} else {
|
||||
// Action non convertible, pénalité
|
||||
reward = -0.5;
|
||||
panic!("action non convertible");
|
||||
//reward = -0.5;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,13 +26,13 @@ fn main() {
|
|||
// epsilon is updated at the start of each episode
|
||||
eps_decay: 2000.0, // 1000 ?
|
||||
|
||||
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
|
||||
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||
gamma: 0.9999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
|
||||
tau: 0.0005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||
// plus lente moins sensible aux coups de chance
|
||||
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||
// converger
|
||||
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||
batch_size: 64, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
clip_grad: 50.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||
};
|
||||
println!("{conf}----------");
|
||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||
|
|
@ -41,7 +41,7 @@ fn main() {
|
|||
|
||||
println!("> Sauvegarde du modèle de validation");
|
||||
|
||||
let path = "models/burn_dqn_40".to_string();
|
||||
let path = "bot/models/burnrl_dqn".to_string();
|
||||
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ use burn_rl::base::{Action, ElemType, Environment, State};
|
|||
|
||||
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
|
|
@ -23,7 +23,7 @@ pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
|||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
|
|
|
|||
|
|
@ -157,6 +157,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
|||
}
|
||||
}
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
panic!("empty valid_actions for state {game_state}");
|
||||
}
|
||||
valid_actions
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -161,6 +161,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
|||
}
|
||||
}
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
panic!("empty valid_actions for state {game_state}");
|
||||
}
|
||||
valid_actions
|
||||
}
|
||||
|
||||
|
|
|
|||
2
justfile
2
justfile
|
|
@ -9,7 +9,7 @@ shell:
|
|||
runcli:
|
||||
RUST_LOG=info cargo run --bin=client_cli
|
||||
runclibots:
|
||||
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
|
||||
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk
|
||||
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||
match:
|
||||
|
|
|
|||
Loading…
Reference in a new issue