save inference model

This commit is contained in:
Henri Bourcereau 2025-07-23 21:28:29 +02:00
parent 6fa8a31cc7
commit f3fc053dbd

View file

@ -12,24 +12,22 @@ fn main() {
println!("> Entraînement"); println!("> Entraînement");
let num_episodes = 3; let num_episodes = 3;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true); let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
println!("> Sauvegarde");
save(&agent);
// cette ligne sert à extraire le "cerveau" de l'agent entraîné,
// sans les données nécessaires à l'entraînement
let valid_agent = agent.valid(); let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap());
println!("> Test"); println!("> Test");
demo_model::<Env>(valid_agent); demo_model::<Env>(valid_agent);
} }
fn save(agent: &DQN<Env, Backend, dqn_model::Net<Backend>>) { fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
let path = "models/burn_dqn".to_string(); let path = "models/burn_dqn".to_string();
let inference_network = agent.model().clone().into_record();
let recorder = CompactRecorder::new(); let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path); let model_path = format!("{}_model.burn", path);
println!("Modèle sauvegardé : {}", model_path); println!("Modèle de validation sauvegardé : {}", model_path);
recorder recorder
.record(inference_network, model_path.into()) .record(model.clone().into_record(), model_path.into())
.unwrap(); .unwrap();
} }