diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs index 2a6db43..f337289 100644 --- a/bot/src/burnrl/dqn_model.rs +++ b/bot/src/burnrl/dqn_model.rs @@ -2,6 +2,7 @@ use crate::burnrl::utils::soft_update_linear; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; @@ -138,5 +139,15 @@ pub fn run( } } + // Save + let path = "models/burn_dqn".to_string(); + let inference_network = agent.model().clone().into_record(); + let recorder = CompactRecorder::new(); + let model_path = format!("{}_model.burn", path); + println!("Modèle sauvegardé : {}", model_path); + recorder + .record(inference_network, model_path.into()) + .unwrap(); + agent.valid() } diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index ef5da61..6e55928 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,16 +1,13 @@ +use bot::burnrl::{dqn_model, environment, utils::demo_model}; use burn::backend::{Autodiff, NdArray}; use burn_rl::base::ElemType; -use bot::burnrl::{ - dqn_model, - environment, - utils::demo_model, -}; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; fn main() { - let agent = dqn_model::run::(512, false); //true); + let num_episodes = 3; + let agent = dqn_model::run::(num_episodes, false); //true); demo_model::(agent); } diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs index 7cfb165..bc8d836 100644 --- a/bot/src/burnrl/utils.rs +++ b/bot/src/burnrl/utils.rs @@ -12,6 +12,7 @@ pub fn demo_model(agent: impl Agent) { if let Some(action) = agent.react(&state) { let snapshot = env.step(action); state = *snapshot.state(); + // println!("{:?}", state); done = snapshot.done(); } } diff --git a/doc/backlog.md b/doc/backlog.md index b92c6d1..f41b9b7 100644 --- a/doc/backlog.md +++ b/doc/backlog.md @@ -1,11 +1,17 @@ # Backlog -position dans tutoriel : - ## DONE ## TODO +- bot burn + - train = `just trainbot` + - durée d'entrainement selon params ? + - save + - load and run against default bot + - many configs, save models selon config + - retrain against himself ? + ### Doc Cheatsheet : arbre des situations et priorité des règles diff --git a/justfile b/justfile index bb1d86e..305abed 100644 --- a/justfile +++ b/justfile @@ -21,4 +21,5 @@ trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok # cargo run --bin=train_burn_rl # doesn't save model - cargo run --bin=train_dqn_full + # cargo run --bin=train_dqn_full + cargo run --bin=train_dqn_burn