diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index 127e69c..41a29e2 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,9 +1,9 @@ use bot::burnrl::{dqn_model, environment, utils::demo_model}; -use burn::backend::{Autodiff, NdArray}; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; use burn::module::Module; use burn::record::{CompactRecorder, Recorder}; use burn_rl::agent::DQN; -use burn_rl::base::ElemType; +use burn_rl::base::{Action, Agent, ElemType, Environment, State}; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; @@ -18,8 +18,12 @@ fn main() { println!("> Sauvegarde du modèle de validation"); save_model(valid_agent.model().as_ref().unwrap()); - println!("> Test"); - demo_model::(valid_agent); + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(); + let loaded_agent = DQN::new(loaded_model); + + println!("> Test avec le modèle chargé"); + demo_model::(loaded_agent); } fn save_model(model: &dqn_model::Net>) { @@ -31,3 +35,26 @@ fn save_model(model: &dqn_model::Net>) { .record(model.clone().into_record(), model_path.into()) .unwrap(); } + +fn load_model() -> dqn_model::Net> { + // TODO : reprendre le DENSE_SIZE de dqn_model.rs + const DENSE_SIZE: usize = 128; + + let path = "models/burn_dqn".to_string(); + let model_path = format!("{}_model.burn", path); + println!("Chargement du modèle depuis : {}", model_path); + + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + + let record = recorder + .load(model_path.into(), &device) + .expect("Impossible de charger le modèle"); + + dqn_model::Net::new( + ::StateType::size(), + DENSE_SIZE, + ::ActionType::size(), + ) + .load_record(record) +}