// trictrac/bot/src/burnrl/main.rs
use bot::burnrl::{dqn_model, environment, utils::demo_model};
2025-07-23 21:52:32 +02:00
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
2025-07-23 21:16:28 +02:00
use burn::module::Module;
use burn::record::{CompactRecorder, Recorder};
use burn_rl::agent::DQN;
2025-07-23 21:52:32 +02:00
use burn_rl::base::{Action, Agent, ElemType, Environment, State};
2025-07-08 21:58:15 +02:00
/// Backend handed to `dqn_model::run` in `main`: `NdArray` over `ElemType`, wrapped in `Autodiff`.
type Backend = Autodiff<NdArray<ElemType>>;
/// Environment type used in `main` for both the training run and the demo run.
type Env = environment::TrictracEnvironment;
fn main() {
2025-07-23 21:16:28 +02:00
println!("> Entraînement");
2025-07-23 17:25:05 +02:00
let num_episodes = 3;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
2025-07-08 21:58:15 +02:00
2025-07-23 21:16:28 +02:00
let valid_agent = agent.valid();
2025-07-23 21:28:29 +02:00
println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap());
2025-07-23 21:52:32 +02:00
println!("> Chargement du modèle pour test");
let loaded_model = load_model();
let loaded_agent = DQN::new(loaded_model);
println!("> Test avec le modèle chargé");
demo_model::<Env>(loaded_agent);
2025-07-23 21:16:28 +02:00
}
2025-07-23 21:28:29 +02:00
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
2025-07-23 21:16:28 +02:00
let path = "models/burn_dqn".to_string();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
2025-07-23 21:28:29 +02:00
println!("Modèle de validation sauvegardé : {}", model_path);
2025-07-23 21:16:28 +02:00
recorder
2025-07-23 21:28:29 +02:00
.record(model.clone().into_record(), model_path.into())
2025-07-23 21:16:28 +02:00
.unwrap();
2025-07-08 21:58:15 +02:00
}
2025-07-23 21:52:32 +02:00
/// Loads a record from `models/burn_dqn_model.burn` with a `CompactRecorder`
/// and applies it to a freshly constructed `dqn_model::Net`.
///
/// The network is built from the environment's state size, a local
/// `DENSE_SIZE`, and the environment's action size.
///
/// # Panics
///
/// Panics with "Impossible de charger le modèle" if the record cannot be loaded.
fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
    // TODO: reuse the DENSE_SIZE defined in dqn_model.rs instead of duplicating it.
    // NOTE(review): presumably this must match the size used during training; verify.
    const DENSE_SIZE: usize = 128;

    let base_path = "models/burn_dqn";
    let model_path = format!("{}_model.burn", base_path);
    println!("Chargement du modèle depuis : {}", model_path);

    let device = NdArrayDevice::default();
    let record = CompactRecorder::new()
        .load(model_path.into(), &device)
        .expect("Impossible de charger le modèle");

    // Network dimensions come from the environment's state and action types.
    let state_size = <environment::TrictracEnvironment as Environment>::StateType::size();
    let action_size = <environment::TrictracEnvironment as Environment>::ActionType::size();

    let net: dqn_model::Net<NdArray<ElemType>> =
        dqn_model::Net::new(state_size, DENSE_SIZE, action_size);
    net.load_record(record)
}