refactor: bot directories

Henri Bourcereau 2025-08-19 16:27:37 +02:00
parent e66921fcce
commit fcd50bc0f2
27 changed files with 110 additions and 94 deletions

View file

@ -7,19 +7,19 @@ edition = "2021"
[[bin]]
name = "train_dqn_burn_valid"
path = "src/dqn/burnrl_valid/main.rs"
path = "src/burnrl/dqn_valid/main.rs"
[[bin]]
name = "train_dqn_burn_big"
path = "src/dqn/burnrl_big/main.rs"
path = "src/burnrl/dqn_big/main.rs"
[[bin]]
name = "train_dqn_burn"
path = "src/dqn/burnrl/main.rs"
path = "src/burnrl/dqn/main.rs"
[[bin]]
name = "train_dqn_simple"
path = "src/dqn/simple/main.rs"
path = "src/dqn_simple/main.rs"
[dependencies]
pretty_assertions = "1.4.0"

View file

@ -1,5 +1,5 @@
use crate::dqn::burnrl_big::environment::TrictracEnvironment;
use crate::dqn::burnrl_big::utils::soft_update_linear;
use crate::burnrl::dqn::utils::soft_update_linear;
use crate::burnrl::environment::TrictracEnvironment;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
@ -126,7 +126,7 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().min_steps = conf.min_steps;
// env.as_mut().min_steps = conf.min_steps;
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
@ -193,12 +193,17 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
if snapshot.done() || episode_duration >= conf.max_steps {
let envmut = env.as_mut();
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
* 100.0)
.round() as u32;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
envmut.goodmoves_count,
goodmoves_ratio,
envmut.pointrolls_count,
now.elapsed().unwrap().as_secs(),
);
if goodmoves_ratio < 5 && 10 < episode {}
env.reset();
episode_done = true;
now = SystemTime::now();

View file

@ -1,7 +1,8 @@
use bot::dqn::burnrl::{
dqn_model, environment,
use bot::burnrl::dqn::{
dqn_model,
utils::{demo_model, load_model, save_model},
};
use bot::burnrl::environment;
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN;
use burn_rl::base::ElemType;
@ -15,9 +16,9 @@ fn main() {
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 40, // 40
num_episodes: 50, // 40
min_steps: 1000.0, // 1000 minimum of max steps per episode (updated by the function)
max_steps: 2000, // 1000 max steps per episode
max_steps: 1000, // 1000 max steps per episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
@ -31,8 +32,8 @@ fn main() {
// slower, less sensitive to lucky streaks
learning_rate: 0.001, // 0.001 step size. Low: slower; high: risk of never
// converging
batch_size: 64, // 32 number of past experiences used to compute the mean error
clip_grad: 50.0, // 100 maximum correction applied to the gradient (default 100)
batch_size: 128, // 32 number of past experiences used to compute the mean error
clip_grad: 70.0, // 100 maximum correction applied to the gradient (default 100)
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);

View file

@ -1,3 +1,2 @@
pub mod dqn_model;
pub mod environment;
pub mod utils;

View file

@ -1,8 +1,6 @@
use crate::dqn::burnrl::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use crate::burnrl::dqn::dqn_model;
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::training_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;

View file

@ -1,5 +1,5 @@
use crate::dqn::burnrl::environment::TrictracEnvironment;
use crate::dqn::burnrl::utils::soft_update_linear;
use crate::burnrl::dqn_big::utils::soft_update_linear;
use crate::burnrl::environment_big::TrictracEnvironment;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;

View file

@ -1,13 +1,14 @@
use bot::dqn::burnrl_big::{
dqn_model, environment,
use bot::burnrl::dqn_big::{
dqn_model,
utils::{demo_model, load_model, save_model},
};
use bot::burnrl::environment_big;
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN;
use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
type Env = environment_big::TrictracEnvironment;
fn main() {
// println!("> Training");

View file

@ -1,3 +1,2 @@
pub mod dqn_model;
pub mod environment;
pub mod utils;

View file

@ -1,8 +1,6 @@
use crate::dqn::burnrl_valid::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use crate::burnrl::dqn_big::dqn_model;
use crate::burnrl::environment_big::{TrictracAction, TrictracEnvironment};
use crate::training_common_big::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;

View file

@ -1,5 +1,5 @@
use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
use crate::dqn::burnrl_valid::utils::soft_update_linear;
use crate::burnrl::dqn_valid::utils::soft_update_linear;
use crate::burnrl::environment::TrictracEnvironment;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;

View file

@ -1,7 +1,8 @@
use bot::dqn::burnrl_valid::{
dqn_model, environment,
use bot::burnrl::dqn_valid::{
dqn_model,
utils::{demo_model, load_model, save_model},
};
use bot::burnrl::environment;
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN;
use burn_rl::base::ElemType;

View file

@ -1,3 +1,2 @@
pub mod dqn_model;
pub mod environment;
pub mod utils;

View file

@ -1,8 +1,6 @@
use crate::dqn::burnrl_big::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common_big::get_valid_action_indices;
use crate::burnrl::dqn_valid::dqn_model;
use crate::burnrl::environment_valid::{TrictracAction, TrictracEnvironment};
use crate::training_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;

View file

@ -1,13 +1,15 @@
use crate::dqn::dqn_common;
use std::io::Write;
use crate::training_common;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -2.12121;
const REWARD_VALID_MOVE: f32 = 2.12121;
const ERROR_REWARD: f32 = -1.12121;
const REWARD_VALID_MOVE: f32 = 1.12121;
const REWARD_RATIO: f32 = 0.01;
const WIN_POINTS: f32 = 0.1;
const WIN_POINTS: f32 = 1.0;
/// Trictrac game state for burn-rl
#[derive(Debug, Clone, Copy)]
@ -89,7 +91,7 @@ pub struct TrictracEnvironment {
current_state: TrictracState,
episode_reward: f32,
pub step_count: usize,
pub min_steps: f32,
pub best_ratio: f32,
pub max_steps: usize,
pub pointrolls_count: usize,
pub goodmoves_count: usize,
@ -122,7 +124,7 @@ impl Environment for TrictracEnvironment {
current_state,
episode_reward: 0.0,
step_count: 0,
min_steps: 250.0,
best_ratio: 0.0,
max_steps: 2000,
pointrolls_count: 0,
goodmoves_count: 0,
@ -151,10 +153,21 @@ impl Environment for TrictracEnvironment {
} else {
self.goodmoves_count as f32 / self.step_count as f32
};
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
let path = "bot/models/logs/debug.log";
if let Ok(mut out) = std::fs::File::create(path) {
write!(out, "{:?}", self.game.history);
}
"!!!!"
} else {
""
};
println!(
"info: correct moves: {} ({}%)",
"info: correct moves: {} ({}%) {}",
self.goodmoves_count,
(100.0 * self.goodmoves_ratio).round() as u32
(100.0 * self.goodmoves_ratio).round() as u32,
warning
);
self.step_count = 0;
self.pointrolls_count = 0;
@ -195,9 +208,10 @@ impl Environment for TrictracEnvironment {
}
// Check whether the game is over
let max_steps = self.min_steps
+ (self.max_steps as f32 - self.min_steps)
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
// let max_steps = self.max_steps;
// let max_steps = self.min_steps
// + (self.max_steps as f32 - self.min_steps)
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
@ -210,7 +224,8 @@ impl Environment for TrictracEnvironment {
}
}
}
let terminated = done || self.step_count >= max_steps.round() as usize;
let terminated = done || self.step_count >= self.max_steps;
// let terminated = done || self.step_count >= max_steps.round() as usize;
// Update the state
self.current_state = TrictracState::from_game_state(&self.game);
@ -229,8 +244,8 @@ impl Environment for TrictracEnvironment {
impl TrictracEnvironment {
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Converts an action index within the valid actions into a Trictrac action
@ -239,8 +254,8 @@ impl TrictracEnvironment {
&self,
action: TrictracAction,
game_state: &GameState,
) -> Option<dqn_common::TrictracAction> {
use dqn_common::get_valid_actions;
) -> Option<training_common::TrictracAction> {
use training_common::get_valid_actions;
// Get the valid actions in the current context
let valid_actions = get_valid_actions(game_state);
@ -257,10 +272,10 @@ impl TrictracEnvironment {
/// Executes a Trictrac action in the game
// fn execute_action(
// &mut self,
// action: dqn_common::TrictracAction,
// action: training_common::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
use dqn_common::TrictracAction;
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
use training_common::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;

View file

@ -1,4 +1,4 @@
use crate::dqn::dqn_common_big;
use crate::training_common_big;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
@ -229,8 +229,8 @@ impl Environment for TrictracEnvironment {
impl TrictracEnvironment {
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Converts an action index within the valid actions into a Trictrac action
@ -239,8 +239,8 @@ impl TrictracEnvironment {
&self,
action: TrictracAction,
game_state: &GameState,
) -> Option<dqn_common_big::TrictracAction> {
use dqn_common_big::get_valid_actions;
) -> Option<training_common_big::TrictracAction> {
use training_common_big::get_valid_actions;
// Get the valid actions in the current context
let valid_actions = get_valid_actions(game_state);
@ -257,10 +257,10 @@ impl TrictracEnvironment {
/// Executes a Trictrac action in the game
// fn execute_action(
// &mut self,
// action:dqn_common_big::TrictracAction,
// action:training_common_big::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
use dqn_common_big::TrictracAction;
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
use training_common_big::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;

View file

@ -1,4 +1,4 @@
use crate::dqn::dqn_common_big;
use crate::training_common_big;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
@ -214,16 +214,16 @@ impl TrictracEnvironment {
const REWARD_RATIO: f32 = 1.0;
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Converts an action index within the valid actions into a Trictrac action
fn convert_valid_action_index(
&self,
action: TrictracAction,
) -> Option<dqn_common_big::TrictracAction> {
use dqn_common_big::get_valid_actions;
) -> Option<training_common_big::TrictracAction> {
use training_common_big::get_valid_actions;
// Get the valid actions in the current context
let valid_actions = get_valid_actions(&self.game);
@ -240,10 +240,10 @@ impl TrictracEnvironment {
/// Executes a Trictrac action in the game
// fn execute_action(
// &mut self,
// action: dqn_common_big::TrictracAction,
// action: training_common_big::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
use dqn_common_big::TrictracAction;
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
use training_common_big::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;

bot/src/burnrl/mod.rs (new file, +6)
View file

@ -0,0 +1,6 @@
pub mod dqn;
pub mod dqn_big;
pub mod dqn_valid;
pub mod environment;
pub mod environment_big;
pub mod environment_valid;

View file

@ -1,7 +0,0 @@
pub mod burnrl;
pub mod burnrl_big;
pub mod dqn_common;
pub mod dqn_common_big;
pub mod simple;
pub mod burnrl_valid;

View file

@ -1,4 +1,4 @@
use crate::dqn::dqn_common::TrictracAction;
use crate::training_common_big::TrictracAction;
use serde::{Deserialize, Serialize};
/// Configuration for the DQN agent
@ -151,4 +151,3 @@ impl SimpleNeuralNetwork {
Ok(network)
}
}

View file

@ -6,7 +6,7 @@ use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
use crate::dqn::dqn_common_big::{get_valid_actions, TrictracAction};
use crate::training_common_big::{get_valid_actions, TrictracAction};
/// Experience for the replay buffer
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -1,6 +1,6 @@
use bot::dqn::dqn_common::TrictracAction;
use bot::dqn::simple::dqn_model::DqnConfig;
use bot::dqn::simple::dqn_trainer::DqnTrainer;
use bot::dqn_simple::dqn_model::DqnConfig;
use bot::dqn_simple::dqn_trainer::DqnTrainer;
use bot::training_common::TrictracAction;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {

View file

@ -1,5 +1,8 @@
pub mod dqn;
pub mod burnrl;
pub mod dqn_simple;
pub mod strategy;
pub mod training_common;
pub mod training_common_big;
use log::debug;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};

View file

@ -3,8 +3,8 @@ use log::info;
use std::path::Path;
use store::MoveRules;
use crate::dqn::dqn_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
/// DQN strategy for the bot - it only loads and uses a pre-trained model
#[derive(Debug)]

View file

@ -6,8 +6,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use log::info;
use store::MoveRules;
use crate::dqn::burnrl::{dqn_model, environment, utils};
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
use crate::burnrl::dqn::{dqn_model, utils};
use crate::burnrl::environment;
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;