2025-07-23 17:25:05 +02:00
|
|
|
use bot::burnrl::{dqn_model, environment, utils::demo_model};
|
2025-07-08 21:58:15 +02:00
|
|
|
use burn::backend::{Autodiff, NdArray};
|
2025-07-23 21:16:28 +02:00
|
|
|
use burn::module::Module;
|
|
|
|
|
use burn::record::{CompactRecorder, Recorder};
|
|
|
|
|
use burn_rl::agent::DQN;
|
2025-07-08 21:58:15 +02:00
|
|
|
use burn_rl::base::ElemType;
|
|
|
|
|
|
|
|
|
|
type Backend = Autodiff<NdArray<ElemType>>;
|
|
|
|
|
type Env = environment::TrictracEnvironment;
|
|
|
|
|
|
|
|
|
|
fn main() {
|
2025-07-23 21:16:28 +02:00
|
|
|
println!("> Entraînement");
|
2025-07-23 17:25:05 +02:00
|
|
|
let num_episodes = 3;
|
|
|
|
|
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
2025-07-08 21:58:15 +02:00
|
|
|
|
2025-07-23 21:16:28 +02:00
|
|
|
let valid_agent = agent.valid();
|
|
|
|
|
|
2025-07-23 21:28:29 +02:00
|
|
|
println!("> Sauvegarde du modèle de validation");
|
|
|
|
|
save_model(valid_agent.model().as_ref().unwrap());
|
|
|
|
|
|
2025-07-23 21:16:28 +02:00
|
|
|
println!("> Test");
|
|
|
|
|
demo_model::<Env>(valid_agent);
|
|
|
|
|
}
|
|
|
|
|
|
2025-07-23 21:28:29 +02:00
|
|
|
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
|
2025-07-23 21:16:28 +02:00
|
|
|
let path = "models/burn_dqn".to_string();
|
|
|
|
|
let recorder = CompactRecorder::new();
|
|
|
|
|
let model_path = format!("{}_model.burn", path);
|
2025-07-23 21:28:29 +02:00
|
|
|
println!("Modèle de validation sauvegardé : {}", model_path);
|
2025-07-23 21:16:28 +02:00
|
|
|
recorder
|
2025-07-23 21:28:29 +02:00
|
|
|
.record(model.clone().into_record(), model_path.into())
|
2025-07-23 21:16:28 +02:00
|
|
|
.unwrap();
|
2025-07-08 21:58:15 +02:00
|
|
|
}
|