2025-08-08 17:07:34 +02:00
|
|
|
use bot::dqn::burnrl::{
|
|
|
|
|
dqn_model, environment,
|
|
|
|
|
utils::{demo_model, load_model, save_model},
|
|
|
|
|
};
|
|
|
|
|
use burn::backend::{Autodiff, NdArray};
|
2025-07-23 21:16:28 +02:00
|
|
|
use burn_rl::agent::DQN;
|
2025-08-08 17:07:34 +02:00
|
|
|
use burn_rl::base::ElemType;
|
2025-07-08 21:58:15 +02:00
|
|
|
|
|
|
|
|
type Backend = Autodiff<NdArray<ElemType>>;
|
|
|
|
|
type Env = environment::TrictracEnvironment;
|
|
|
|
|
|
|
|
|
|
fn main() {
|
2025-08-03 20:32:06 +02:00
|
|
|
// println!("> Entraînement");
|
2025-08-02 12:42:32 +02:00
|
|
|
let conf = dqn_model::DqnConfig {
|
2025-08-03 16:11:45 +02:00
|
|
|
num_episodes: 40,
|
2025-08-02 12:42:32 +02:00
|
|
|
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
2025-08-03 22:16:28 +02:00
|
|
|
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
2025-08-02 12:42:32 +02:00
|
|
|
dense_size: 256, // neural network complexity
|
|
|
|
|
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
|
|
|
|
eps_end: 0.05,
|
2025-08-03 16:11:45 +02:00
|
|
|
eps_decay: 3000.0,
|
2025-08-02 12:42:32 +02:00
|
|
|
};
|
|
|
|
|
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
2025-07-08 21:58:15 +02:00
|
|
|
|
2025-07-23 21:16:28 +02:00
|
|
|
let valid_agent = agent.valid();
|
|
|
|
|
|
2025-07-23 21:28:29 +02:00
|
|
|
println!("> Sauvegarde du modèle de validation");
|
|
|
|
|
|
2025-08-08 17:07:34 +02:00
|
|
|
let path = "models/burn_dqn_40".to_string();
|
2025-07-26 09:37:54 +02:00
|
|
|
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
|
|
|
|
|
2025-07-23 21:52:32 +02:00
|
|
|
println!("> Chargement du modèle pour test");
|
2025-08-02 12:42:32 +02:00
|
|
|
let loaded_model = load_model(conf.dense_size, &path);
|
2025-08-08 21:31:38 +02:00
|
|
|
let loaded_agent = DQN::new(loaded_model.unwrap());
|
2025-07-23 21:52:32 +02:00
|
|
|
|
|
|
|
|
println!("> Test avec le modèle chargé");
|
2025-07-26 09:37:54 +02:00
|
|
|
demo_model(loaded_agent);
|
2025-07-23 21:16:28 +02:00
|
|
|
}
|