action mask

This commit is contained in:
Henri Bourcereau 2025-07-26 09:37:54 +02:00
parent cb30fd3229
commit 3e1775428d
7 changed files with 111 additions and 554 deletions

View file

@ -10,27 +10,28 @@ type Env = environment::TrictracEnvironment;
fn main() {
println!("> Entraînement");
let num_episodes = 10;
let num_episodes = 50;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap());
println!("> Test avec le modèle entraîné");
demo_model::<Env>(valid_agent);
let path = "models/burn_dqn_50".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path);
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test");
let loaded_model = load_model();
let loaded_model = load_model(&path);
let loaded_agent = DQN::new(loaded_model);
println!("> Test avec le modèle chargé");
demo_model::<Env>(loaded_agent);
demo_model(loaded_agent);
}
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
let path = "models/burn_dqn".to_string();
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path);
@ -39,11 +40,10 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
.unwrap();
}
fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
// TODO : reprendre le DENSE_SIZE de dqn_model.rs
const DENSE_SIZE: usize = 128;
let path = "models/burn_dqn".to_string();
let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path);