Compare commits

...

2 commits

Author SHA1 Message Date
Henri Bourcereau 778ac1817b script train bots 2025-08-10 15:35:12 +02:00
Henri Bourcereau e4b3092018 train burn-rl with integers 2025-08-10 08:39:31 +02:00
9 changed files with 229 additions and 85 deletions

4
.gitignore vendored
View file

@ -11,6 +11,4 @@ devenv.local.nix
# generated by samply rust profiler
profile.json
# IA modles used by bots
/models
bot/models

View file

@ -1,26 +0,0 @@
# Trictrac Project Guidelines
## Build & Run Commands
- Build: `cargo build`
- Test: `cargo test`
- Test specific: `cargo test -- test_name`
- Lint: `cargo clippy`
- Format: `cargo fmt`
- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
- Build Python lib: `maturin build -m store/Cargo.toml --release`
## Code Style
- Use Rust 2021 edition idioms
- Error handling: Use Result<T, Error> pattern with custom Error types
- Naming: snake_case for functions/variables, CamelCase for types
- Imports: Group standard lib, external crates, then internal modules
- Module structure: Prefer small, focused modules with clear responsibilities
- Documentation: Document public APIs with doc comments
- Testing: Write unit tests in same file as implementation
- Python bindings: Use pyo3 for creating Python modules
## Architecture
- Core game logic in `store` crate
- Multiple clients: CLI, TUI, Bevy (graphical)
- Bot interfaces in `bot` crate

38
bot/scripts/train.sh Executable file
View file

@ -0,0 +1,38 @@
#!/usr/bin/env sh
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
LOGS_DIR="$ROOT/bot/models/logs"
CFG_SIZE=12
OPPONENT="random"
PLOT_EXT="png"
train() {
cargo build --release --bin=train_dqn_burn
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
LOGS="$LOGS_DIR/$NAME.out"
mkdir -p "$LOGS_DIR"
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
}
plot() {
NAME=$(ls "$LOGS_DIR" | tail -n 1)
LOGS="$LOGS_DIR/$NAME"
cfgs=$(head -n $CFG_SIZE "$LOGS")
for cfg in $cfgs; do
eval "$cfg"
done
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
grep -v "info:" |
awk -F '[ ,]' '{print $5}' |
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
}
if [ "$1" = "plot" ]; then
plot
else
train
fi

View file

@ -1,3 +1,4 @@
use crate::dqn::burnrl::environment::TrictracEnvironment;
use crate::dqn::burnrl::utils::soft_update_linear;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
@ -8,6 +9,7 @@ use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
use std::fmt;
use std::time::SystemTime;
#[derive(Module, Debug)]
@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
const MEMORY_SIZE: usize = 8192;
pub struct DqnConfig {
pub min_steps: f32,
pub max_steps: usize,
pub num_episodes: usize,
// pub memory_size: usize,
pub dense_size: usize,
pub eps_start: f64,
pub eps_end: f64,
pub eps_decay: f64,
pub gamma: f32,
pub tau: f32,
pub learning_rate: f32,
pub batch_size: usize,
pub clip_grad: f32,
}
impl fmt::Display for DqnConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = String::new();
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
s.push_str(&format!("gamma={:?}\n", self.gamma));
s.push_str(&format!("tau={:?}\n", self.tau));
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
write!(f, "{s}")
}
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
min_steps: 250.0,
max_steps: 2000,
num_episodes: 1000,
// memory_size: 8192,
dense_size: 256,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
}
}
}
@ -85,12 +120,14 @@ impl Default for DqnConfig {
type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)]
pub fn run<E: Environment, B: AutodiffBackend>(
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
conf: &DqnConfig,
visualized: bool,
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().min_steps = conf.min_steps;
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
let mut agent = MyAgent::new(model);
let config = DQNTrainingConfig::default();
// let config = DQNTrainingConfig::default();
let config = DQNTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
@ -145,16 +191,13 @@ pub fn run<E: Environment, B: AutodiffBackend>(
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= E::MAX_STEPS {
if snapshot.done() || episode_duration >= conf.max_steps {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
episode,
episode_reward,
episode_duration,
now.elapsed().unwrap().as_secs()
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
} else {

View file

@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
/// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
}
impl State for TrictracState {
type Data = [f32; 36];
type Data = [i8; 36];
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
Tensor::from_floats(self.data, &B::Device::default())
@ -25,8 +25,8 @@ impl State for TrictracState {
impl TrictracState {
/// Convertit un GameState en TrictracState
pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec_float();
let mut data = [0.0; 36];
let state_vec = game_state.to_vec();
let mut data = [0; 36];
// Copier les données en s'assurant qu'on ne dépasse pas la taille
let copy_len = state_vec.len().min(36);
@ -39,6 +39,7 @@ impl TrictracState {
/// Actions possibles dans Trictrac pour burn-rl
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrictracAction {
// u32 as required by burn_rl::base::Action type
pub index: u32,
}
@ -82,7 +83,11 @@ pub struct TrictracEnvironment {
opponent_id: PlayerId,
current_state: TrictracState,
episode_reward: f32,
step_count: usize,
pub step_count: usize,
pub min_steps: f32,
pub max_steps: usize,
pub goodmoves_count: usize,
pub goodmoves_ratio: f32,
pub visualized: bool,
}
@ -91,8 +96,6 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction;
type RewardType = f32;
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
@ -113,6 +116,10 @@ impl Environment for TrictracEnvironment {
current_state,
episode_reward: 0.0,
step_count: 0,
min_steps: 250.0,
max_steps: 2000,
goodmoves_count: 0,
goodmoves_ratio: 0.0,
visualized,
}
}
@ -132,7 +139,18 @@ impl Environment for TrictracEnvironment {
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.goodmoves_ratio = if self.step_count == 0 {
0.0
} else {
self.goodmoves_count as f32 / self.step_count as f32
};
println!(
"info: correct moves: {} ({}%)",
self.goodmoves_count,
(100.0 * self.goodmoves_ratio).round() as u32
);
self.step_count = 0;
self.goodmoves_count = 0;
Snapshot::new(self.current_state, 0.0, false)
}
@ -149,14 +167,9 @@ impl Environment for TrictracEnvironment {
// Exécuter l'action si c'est le tour de l'agent DQN
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
match self.execute_action(action) {
Ok(action_reward) => {
reward = action_reward;
}
Err(_) => {
// Action invalide, pénalité
reward = -1.0;
}
reward = self.execute_action(action);
if reward != Self::ERROR_REWARD {
self.goodmoves_count += 1;
}
} else {
// Action non convertible, pénalité
@ -170,12 +183,12 @@ impl Environment for TrictracEnvironment {
}
// Vérifier si la partie est terminée
let done = self.game.stage == Stage::Ended
|| self.game.determine_winner().is_some()
|| self.step_count >= Self::MAX_STEPS;
let max_steps = self.min_steps
+ (self.max_steps as f32 - self.min_steps)
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
terminated = true;
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
@ -185,6 +198,7 @@ impl Environment for TrictracEnvironment {
}
}
}
let terminated = done || self.step_count >= max_steps.round() as usize;
// Mettre à jour l'état
self.current_state = TrictracState::from_game_state(&self.game);
@ -202,6 +216,9 @@ impl Environment for TrictracEnvironment {
}
impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
@ -228,10 +245,11 @@ impl TrictracEnvironment {
}
/// Exécute une action Trictrac dans le jeu
fn execute_action(
&mut self,
action: dqn_common::TrictracAction,
) -> Result<f32, Box<dyn std::error::Error>> {
// fn execute_action(
// &mut self,
// action: dqn_common::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
use dqn_common::TrictracAction;
let mut reward = 0.0;
@ -310,16 +328,22 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
println!("info: rolled for {reward}");
}
// Récompense proportionnelle aux points
}
}
} else {
// Pénalité pour action invalide
reward -= 2.0;
// on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
reward = Self::ERROR_REWARD;
}
}
Ok(reward)
reward
}
/// Fait jouer l'adversaire avec une stratégie simple
@ -329,15 +353,14 @@ impl TrictracEnvironment {
// Si c'est le tour de l'adversaire, jouer automatiquement
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// Utiliser la stratégie default pour l'adversaire
use crate::strategy::default::DefaultStrategy;
use crate::BotStrategy;
let mut default_strategy = DefaultStrategy::default();
default_strategy.set_player_id(self.opponent_id);
let mut strategy = crate::strategy::random::RandomStrategy::default();
strategy.set_player_id(self.opponent_id);
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
default_strategy.set_color(color);
strategy.set_color(color);
}
*default_strategy.get_mut_game() = self.game.clone();
*strategy.get_mut_game() = self.game.clone();
// Exécuter l'action selon le turn_stage
let event = match self.game.turn_stage {
@ -365,7 +388,7 @@ impl TrictracEnvironment {
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.opponent_id,
@ -397,7 +420,7 @@ impl TrictracEnvironment {
}
TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id,
moves: default_strategy.choose_move(),
moves: strategy.choose_move(),
},
};
@ -408,3 +431,9 @@ impl TrictracEnvironment {
reward
}
}
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
fn as_mut(&mut self) -> &mut Self {
self
}
}

View file

@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
fn main() {
// println!("> Entraînement");
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
min_steps: 250.0, // min steps by episode (mise à jour par la fonction)
max_steps: 2000, // max steps by episode
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
eps_decay: 3000.0,
gamma: 0.999, // discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.005, // soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
// plus lente moins sensible aux coups de chance
learning_rate: 0.001, // taille du pas. Bas : plus lent, haut : risque de ne jamais
// converger
batch_size: 32, // nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 100.0, // plafonnement du gradient : limite max de correction à apporter
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
let valid_agent = agent.valid();

View file

@ -10,8 +10,8 @@ MEMORY_SIZE
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
- Pourquoi c'est important :
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
DENSE_SIZE
@ -54,3 +54,53 @@ epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
nouvelles (EPS*\*).
## Paramètres DQNTrainingConfig
1. `gamma` (Facteur d'actualisation / _Discount Factor_)
- À quoi ça sert ? Ça détermine l'importance des récompenses futures. Une valeur proche de 1 (ex: 0.99)
indique à l'agent qu'une récompense obtenue dans le futur est presque aussi importante qu'une
récompense immédiate. Il sera donc "patient" et capable de faire des sacrifices à court terme pour un
gain plus grand plus tard.
- Intuition : Un gamma de 0 rendrait l'agent "myope", ne se souciant que du prochain coup. Un gamma de
0.99 l'encourage à élaborer des stratégies à long terme.
2. `tau` (Taux de mise à jour douce / _Soft Update Rate_)
- À quoi ça sert ? Pour stabiliser l'apprentissage, les algorithmes DQN utilisent souvent deux réseaux
: un réseau principal qui apprend vite et un "réseau cible" (copie du premier) qui évolue lentement.
tau contrôle la vitesse à laquelle les connaissances du réseau principal sont transférées vers le
réseau cible.
- Intuition : Une petite valeur (ex: 0.005) signifie que le réseau cible, qui sert de référence stable,
ne se met à jour que très progressivement. C'est comme un "mentor" qui n'adopte pas immédiatement
toutes les nouvelles idées de son "élève", ce qui évite de déstabiliser tout l'apprentissage sur un
coup de chance (ou de malchance).
3. `learning_rate` (Taux d'apprentissage)
- À quoi ça sert ? C'est peut-être le plus classique des hyperparamètres. Il définit la "taille du
pas" lors de la correction des erreurs. Après chaque prédiction, l'agent compare le résultat à ce
qui s'est passé et ajuste ses poids. Le learning_rate détermine l'ampleur de cet ajustement.
- Intuition : Trop élevé, et l'agent risque de sur-corriger et de ne jamais converger (comme chercher
le fond d'une vallée en faisant des pas de géant). Trop bas, et l'apprentissage sera extrêmement
lent.
4. `batch_size` (Taille du lot)
- À quoi ça sert ? L'agent apprend de ses expériences passées, qu'il stocke dans une "mémoire". Pour
chaque session d'entraînement, au lieu d'apprendre d'une seule expérience, il en pioche un lot
(batch) au hasard (ex: 32 expériences). Il calcule l'erreur moyenne sur ce lot pour mettre à jour
ses poids.
- Intuition : Apprendre sur un lot plutôt que sur une seule expérience rend l'apprentissage plus
stable et plus général. L'agent se base sur une "moyenne" de situations plutôt que sur un cas
particulier qui pourrait être une anomalie.
5. `clip_grad` (Plafonnement du gradient / _Gradient Clipping_)
- À quoi ça sert ? C'est une sécurité pour éviter le problème des "gradients qui explosent". Parfois,
une expérience très inattendue peut produire une erreur de prédiction énorme, ce qui entraîne une
correction (un "gradient") démesurément grande. Une telle correction peut anéantir tout ce que le
réseau a appris.
- Intuition : clip_grad impose une limite. Si la correction à apporter dépasse un certain seuil, elle
est ramenée à cette valeur maximale. C'est un garde-fou qui dit : "OK, on a fait une grosse erreur,
mais on va corriger calmement, sans tout casser".

View file

@ -9,8 +9,8 @@ shell:
runcli:
RUST_LOG=info cargo run --bin=client_cli
runclibots:
cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk
#cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
match:
cargo build --release --bin=client_cli
@ -28,12 +28,9 @@ trainsimple:
trainbot:
#python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok
# cargo run --bin=train_dqn_burn # utilise debug (why ?)
cargo build --release --bin=train_dqn_burn
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
./bot/scripts/train.sh
plottrainbot:
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
./bot/scripts/train.sh plot
debugtrainbot:
cargo build --bin=train_dqn_burn
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn

View file

@ -151,6 +151,7 @@ impl GameState {
/// Get state as a vector (to be used for bot training input) :
/// length = 36
/// i8 for board positions with negative values for blacks
pub fn to_vec(&self) -> Vec<i8> {
let state_len = 36;
let mut state = Vec::with_capacity(state_len);