From 6fa8a31cc75ebc3cf030c169ca5808d84c051b86 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 23 Jul 2025 21:16:28 +0200 Subject: [PATCH] refact : save model --- bot/src/burnrl/dqn_model.rs | 16 +++------------- bot/src/burnrl/main.rs | 24 +++++++++++++++++++++++- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs index f337289..221b391 100644 --- a/bot/src/burnrl/dqn_model.rs +++ b/bot/src/burnrl/dqn_model.rs @@ -70,7 +70,8 @@ type MyAgent = DQN>; pub fn run( num_episodes: usize, visualized: bool, -) -> impl Agent { +) -> DQN> { + // ) -> impl Agent { let mut env = E::new(visualized); let model = Net::::new( @@ -138,16 +139,5 @@ pub fn run( } } } - - // Save - let path = "models/burn_dqn".to_string(); - let inference_network = agent.model().clone().into_record(); - let recorder = CompactRecorder::new(); - let model_path = format!("{}_model.burn", path); - println!("Modèle sauvegardé : {}", model_path); - recorder - .record(inference_network, model_path.into()) - .unwrap(); - - agent.valid() + agent } diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index 6e55928..a78b586 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,13 +1,35 @@ use bot::burnrl::{dqn_model, environment, utils::demo_model}; use burn::backend::{Autodiff, NdArray}; +use burn::module::Module; +use burn::record::{CompactRecorder, Recorder}; +use burn_rl::agent::DQN; use burn_rl::base::ElemType; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; fn main() { + println!("> Entraînement"); let num_episodes = 3; let agent = dqn_model::run::(num_episodes, false); //true); + println!("> Sauvegarde"); + save(&agent); - demo_model::(agent); + // cette ligne sert à extraire le "cerveau" de l'agent entraîné, + // sans les données nécessaires à l'entraînement + let valid_agent = agent.valid(); + + println!("> Test"); + demo_model::(valid_agent); +} + +fn save(agent: &DQN>) { + let path = "models/burn_dqn".to_string(); + let inference_network = agent.model().clone().into_record(); + let recorder = CompactRecorder::new(); + let model_path = format!("{}_model.burn", path); + println!("Modèle sauvegardé : {}", model_path); + recorder + .record(inference_network, model_path.into()) + .unwrap(); }