fix: train bot dqn burnrl : extract config

Henri Bourcereau 2025-08-02 12:42:32 +02:00
parent ad5ae17168
commit 28c2aa836f
4 changed files with 46 additions and 24 deletions

Cargo.toml

@@ -7,7 +7,7 @@ edition = "2021"
 [[bin]]
 name = "train_dqn_burn"
-path = "src/burnrl/main.rs"
+path = "src/dqn/burnrl/main.rs"
 
 [[bin]]
 name = "train_dqn"

src/dqn/burnrl/dqn_model.rs

@@ -58,17 +58,35 @@ impl<B: Backend> DQNModel<B> for Net<B> {
 }
 
 #[allow(unused)]
-const MEMORY_SIZE: usize = 4096;
-const DENSE_SIZE: usize = 128;
-const EPS_DECAY: f64 = 1000.0;
-const EPS_START: f64 = 0.9;
-const EPS_END: f64 = 0.05;
+const MEMORY_SIZE: usize = 8192;
+
+pub struct DqnConfig {
+    pub num_episodes: usize,
+    // pub memory_size: usize,
+    pub dense_size: usize,
+    pub eps_start: f64,
+    pub eps_end: f64,
+    pub eps_decay: f64,
+}
+
+impl Default for DqnConfig {
+    fn default() -> Self {
+        Self {
+            num_episodes: 1000,
+            // memory_size: 8192,
+            dense_size: 256,
+            eps_start: 0.9,
+            eps_end: 0.05,
+            eps_decay: 1000.0,
+        }
+    }
+}
 
 type MyAgent<E, B> = DQN<E, B, Net<B>>;
 
 #[allow(unused)]
 pub fn run<E: Environment, B: AutodiffBackend>(
-    num_episodes: usize,
+    conf: &DqnConfig,
     visualized: bool,
 ) -> DQN<E, B, Net<B>> {
     // ) -> impl Agent<E> {
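
Side note on the new struct: because DqnConfig implements Default, a caller that only wants to change one or two knobs does not have to spell out every field. A minimal sketch using standard struct update syntax (the import path is taken from the main.rs diff below):

use bot::dqn::burnrl::dqn_model::DqnConfig;

fn example_config() -> DqnConfig {
    DqnConfig {
        // Override only the episode budget; the other fields
        // (dense_size, eps_*) keep their Default values.
        num_episodes: 50,
        ..Default::default()
    }
}
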
@@ -76,7 +94,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
     let model = Net::<B>::new(
         <<E as Environment>::StateType as State>::size(),
-        DENSE_SIZE,
+        conf.dense_size,
         <<E as Environment>::ActionType as Action>::size(),
     );
@@ -94,7 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
     let mut step = 0_usize;
 
-    for episode in 0..num_episodes {
+    for episode in 0..conf.num_episodes {
         let mut episode_done = false;
         let mut episode_reward: ElemType = 0.0;
         let mut episode_duration = 0_usize;
@@ -102,8 +120,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         let mut now = SystemTime::now();
 
         while !episode_done {
-            let eps_threshold =
-                EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
+            let eps_threshold = conf.eps_end
+                + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
             let action =
                 DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
             let snapshot = env.step(action);
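
For context, this is the usual exponential epsilon-greedy annealing: the threshold starts at eps_start and decays toward eps_end as the global step counter grows, with time constant eps_decay (counted in environment steps, not episodes). A self-contained sketch with the Default values, to make the curve concrete:

fn eps_threshold(step: usize, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
    eps_end + (eps_start - eps_end) * f64::exp(-(step as f64) / eps_decay)
}

fn main() {
    // With the defaults (0.9, 0.05, 1000.0):
    // step 0     -> 0.900
    // step 1000  -> ~0.363  (one time constant)
    // step 3000  -> ~0.092
    // step 10000 -> ~0.050  (effectively eps_end)
    for step in [0, 1_000, 3_000, 10_000] {
        println!("step {step}: eps = {:.3}", eps_threshold(step, 0.9, 0.05, 1000.0));
    }
}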

src/dqn/burnrl/environment.rs

@@ -91,8 +91,7 @@ impl Environment for TrictracEnvironment {
     type ActionType = TrictracAction;
     type RewardType = f32;
-    const MAX_STEPS: usize = 1000; // Max cap to avoid infinite games
-    // const MAX_STEPS: usize = 5; // Max cap to avoid infinite games
+    const MAX_STEPS: usize = 700; // Max cap to avoid infinite games
 
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -260,7 +259,7 @@ impl TrictracEnvironment {
             // }
             TrictracAction::Go => {
                 // Continue after winning a hole (trou)
-                reward += 0.2;
+                reward += 0.4;
                 Some(GameEvent::Go {
                     player_id: self.active_player_id,
                 })
@@ -289,7 +288,7 @@ impl TrictracEnvironment {
                 let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
                 let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
-                reward += 0.2;
+                reward += 0.4;
                 Some(GameEvent::Move {
                     player_id: self.active_player_id,
                     moves: (checker_move1, checker_move2),

src/dqn/burnrl/main.rs

@@ -1,4 +1,4 @@
-use bot::burnrl::{dqn_model, environment, utils::demo_model};
+use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
 use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
 use burn::module::Module;
 use burn::record::{CompactRecorder, Recorder};
@@ -10,8 +10,16 @@ type Env = environment::TrictracEnvironment;
 fn main() {
     println!("> Entraînement");
 
-    let num_episodes = 50;
-    let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
+    let conf = dqn_model::DqnConfig {
+        num_episodes: 50,
+        // memory_size: 8192, // must be set in dqn_model.rs via the MEMORY_SIZE constant
+        // max_steps: 700, // must be set in environment.rs via the MAX_STEPS constant
+        dense_size: 256,  // neural network complexity
+        eps_start: 0.9,   // initial epsilon value (0.9 => more exploration)
+        eps_end: 0.05,
+        eps_decay: 1000.0,
+    };
+    let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
 
     let valid_agent = agent.valid();
@@ -24,7 +32,7 @@ fn main() {
     // demo_model::<Env>(valid_agent);
 
     println!("> Chargement du modèle pour test");
-    let loaded_model = load_model(&path);
+    let loaded_model = load_model(conf.dense_size, &path);
     let loaded_agent = DQN::new(loaded_model);
 
     println!("> Test avec le modèle chargé");
@@ -40,10 +48,7 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
         .unwrap();
 }
 
-fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
-    // TODO: reuse the DENSE_SIZE from dqn_model.rs
-    const DENSE_SIZE: usize = 128;
+fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
     let model_path = format!("{}_model.mpk", path);
     println!("Chargement du modèle depuis : {}", model_path);
@@ -56,7 +61,7 @@ fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
     dqn_model::Net::new(
         <environment::TrictracEnvironment as Environment>::StateType::size(),
-        DENSE_SIZE,
+        dense_size,
         <environment::TrictracEnvironment as Environment>::ActionType::size(),
     )
     .load_record(record)
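
Worth spelling out why this hunk matters: the old code hardcoded DENSE_SIZE = 128 at load time (the old TODO already flagged it), and once the width became configurable with a default of 256, the hardcoded value would no longer match the trained checkpoint. The Net rebuilt here must have the same shape as the saved record for load_record to fit it, so the width is now threaded through as a parameter. A sketch of the calling convention, with conf and path as in main above:

// The width used to rebuild the Net must be the one the checkpoint was
// trained with; taking it from the same DqnConfig keeps the two in sync.
let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model);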