fix: convert_action from_action_index

This commit is contained in:
Henri Bourcereau 2025-07-25 17:26:02 +02:00
parent 1e18b784d1
commit b92c9eb7ff
4 changed files with 139 additions and 4 deletions

View file

@ -10,7 +10,7 @@ type Env = environment::TrictracEnvironment;
fn main() {
println!("> Entraînement");
let num_episodes = 3;
let num_episodes = 10;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
let valid_agent = agent.valid();
@ -18,6 +18,9 @@ fn main() {
println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap());
println!("> Test avec le modèle entraîné");
demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test");
let loaded_model = load_model();
let loaded_agent = DQN::new(loaded_model);
@ -29,7 +32,7 @@ fn main() {
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
let path = "models/burn_dqn".to_string();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path);
recorder
.record(model.clone().into_record(), model_path.into())
@ -41,7 +44,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
const DENSE_SIZE: usize = 128;
let path = "models/burn_dqn".to_string();
let model_path = format!("{}_model.burn", path);
let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path);
let device = NdArrayDevice::default();