This commit is contained in:
Henri Bourcereau 2025-07-23 17:25:05 +02:00
parent 354dcfd341
commit c6d33555ec
5 changed files with 25 additions and 9 deletions

View file

@ -2,6 +2,7 @@ use crate::burnrl::utils::soft_update_linear;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
@ -138,5 +139,15 @@ pub fn run<E: Environment, B: AutodiffBackend>(
}
}
// Save
let path = "models/burn_dqn".to_string();
let inference_network = agent.model().clone().into_record();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
println!("Modèle sauvegardé : {}", model_path);
recorder
.record(inference_network, model_path.into())
.unwrap();
agent.valid()
}

View file

@ -1,16 +1,13 @@
use bot::burnrl::{dqn_model, environment, utils::demo_model};
use burn::backend::{Autodiff, NdArray};
use burn_rl::base::ElemType;
use bot::burnrl::{
dqn_model,
environment,
utils::demo_model,
};
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
/// Entry point: trains an agent through `dqn_model::run` on the Trictrac
/// environment, then hands the returned agent to `demo_model`.
fn main() {
    // Episode count forwarded to `dqn_model::run`. Training is launched once
    // only: an earlier hard-coded 512-episode run whose result was immediately
    // shadowed has been removed.
    let num_episodes = 3;
    // Second argument stays `false`, as before (the `true` variant was left
    // disabled) — NOTE(review): confirm its meaning against `dqn_model::run`.
    let agent = dqn_model::run::<Env, Backend>(num_episodes, false);
    demo_model::<Env>(agent);
}

View file

@ -12,6 +12,7 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
if let Some(action) = agent.react(&state) {
let snapshot = env.step(action);
state = *snapshot.state();
// println!("{:?}", state);
done = snapshot.done();
}
}

View file

@ -1,11 +1,17 @@
# Backlog
position dans tutoriel :
## DONE
## TODO
- bot burn
- train = `just trainbot`
- durée d'entrainement selon params ?
- save
- load and run against default bot
- many configs, save models selon config
- retrain against itself ?
### Doc
Cheatsheet : arbre des situations et priorité des règles

View file

@ -21,4 +21,5 @@ trainbot:
#python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok
# cargo run --bin=train_burn_rl # doesn't save model
cargo run --bin=train_dqn_full
# cargo run --bin=train_dqn_full
cargo run --bin=train_dqn_burn