refact : save model

This commit is contained in:
Henri Bourcereau 2025-07-23 21:16:28 +02:00
parent c6d33555ec
commit 6fa8a31cc7
2 changed files with 26 additions and 14 deletions

View file

@ -70,7 +70,8 @@ type MyAgent<E, B> = DQN<E, B, Net<B>>;
pub fn run<E: Environment, B: AutodiffBackend>(
num_episodes: usize,
visualized: bool,
) -> impl Agent<E> {
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
let mut env = E::new(visualized);
let model = Net::<B>::new(
@ -138,16 +139,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
}
}
}
// Save
let path = "models/burn_dqn".to_string();
let inference_network = agent.model().clone().into_record();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
println!("Modèle sauvegardé : {}", model_path);
recorder
.record(inference_network, model_path.into())
.unwrap();
agent.valid()
agent
}

View file

@ -1,13 +1,35 @@
use bot::burnrl::{dqn_model, environment, utils::demo_model};
use burn::backend::{Autodiff, NdArray};
use burn::module::Module;
use burn::record::{CompactRecorder, Recorder};
use burn_rl::agent::DQN;
use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
fn main() {
println!("> Entraînement");
let num_episodes = 3;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
println!("> Sauvegarde");
save(&agent);
demo_model::<Env>(agent);
// cette ligne sert à extraire le "cerveau" de l'agent entraîné,
// sans les données nécessaires à l'entraînement
let valid_agent = agent.valid();
println!("> Test");
demo_model::<Env>(valid_agent);
}
fn save(agent: &DQN<Env, Backend, dqn_model::Net<Backend>>) {
let path = "models/burn_dqn".to_string();
let inference_network = agent.model().clone().into_record();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
println!("Modèle sauvegardé : {}", model_path);
recorder
.record(inference_network, model_path.into())
.unwrap();
}