From f3fc053dbd721c0920e24ce2160ffd80df53ee80 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 23 Jul 2025 21:28:29 +0200 Subject: [PATCH] save inference model --- bot/src/burnrl/main.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index a78b586..127e69c 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -12,24 +12,22 @@ fn main() { println!("> Entraînement"); let num_episodes = 3; let agent = dqn_model::run::(num_episodes, false); //true); - println!("> Sauvegarde"); - save(&agent); - // cette ligne sert à extraire le "cerveau" de l'agent entraîné, - // sans les données nécessaires à l'entraînement let valid_agent = agent.valid(); + println!("> Sauvegarde du modèle de validation"); + save_model(valid_agent.model().as_ref().unwrap()); + println!("> Test"); demo_model::(valid_agent); } -fn save(agent: &DQN>) { +fn save_model(model: &dqn_model::Net>) { let path = "models/burn_dqn".to_string(); - let inference_network = agent.model().clone().into_record(); let recorder = CompactRecorder::new(); let model_path = format!("{}_model.burn", path); - println!("Modèle sauvegardé : {}", model_path); + println!("Modèle de validation sauvegardé : {}", model_path); recorder - .record(inference_network, model_path.into()) + .record(model.clone().into_record(), model_path.into()) .unwrap(); }