trictrac/bot/src/dqn/burnrl/main.rs

39 lines
1.3 KiB
Rust
Raw Normal View History

2025-08-08 17:07:34 +02:00
use bot::dqn::burnrl::{
dqn_model, environment,
utils::{demo_model, load_model, save_model},
};
use burn::backend::{Autodiff, NdArray};
2025-07-23 21:16:28 +02:00
use burn_rl::agent::DQN;
2025-08-08 17:07:34 +02:00
use burn_rl::base::ElemType;
2025-07-08 21:58:15 +02:00
/// Training backend: burn's NdArray backend wrapped in `Autodiff` (gradient tracking),
/// using burn-rl's `ElemType` as the tensor element type.
type Backend = Autodiff<NdArray<ElemType>>;
/// Environment type the DQN agent is trained and evaluated on (passed to `dqn_model::run`).
type Env = environment::TrictracEnvironment;
fn main() {
2025-08-03 20:32:06 +02:00
// println!("> Entraînement");
let conf = dqn_model::DqnConfig {
2025-08-03 16:11:45 +02:00
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
2025-08-03 22:16:28 +02:00
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
2025-08-03 16:11:45 +02:00
eps_decay: 3000.0,
};
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
2025-07-08 21:58:15 +02:00
2025-07-23 21:16:28 +02:00
let valid_agent = agent.valid();
2025-07-23 21:28:29 +02:00
println!("> Sauvegarde du modèle de validation");
2025-08-08 17:07:34 +02:00
let path = "models/burn_dqn_40".to_string();
2025-07-26 09:37:54 +02:00
save_model(valid_agent.model().as_ref().unwrap(), &path);
2025-07-23 21:52:32 +02:00
println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path);
2025-08-08 21:31:38 +02:00
let loaded_agent = DQN::new(loaded_model.unwrap());
2025-07-23 21:52:32 +02:00
println!("> Test avec le modèle chargé");
2025-07-26 09:37:54 +02:00
demo_model(loaded_agent);
2025-07-23 21:16:28 +02:00
}