script train bots

This commit is contained in:
Henri Bourcereau 2025-08-10 15:32:41 +02:00
parent e4b3092018
commit 778ac1817b
7 changed files with 191 additions and 30 deletions

4
.gitignore vendored
View file

@ -11,6 +11,4 @@ devenv.local.nix
# generated by samply rust profiler
profile.json
# IA modles used by bots
/models
bot/models

38
bot/scripts/train.sh Executable file
View file

@ -0,0 +1,38 @@
#!/usr/bin/env sh
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
LOGS_DIR="$ROOT/bot/models/logs"
CFG_SIZE=12
OPPONENT="random"
PLOT_EXT="png"
train() {
cargo build --release --bin=train_dqn_burn
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
LOGS="$LOGS_DIR/$NAME.out"
mkdir -p "$LOGS_DIR"
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
}
plot() {
NAME=$(ls "$LOGS_DIR" | tail -n 1)
LOGS="$LOGS_DIR/$NAME"
cfgs=$(head -n $CFG_SIZE "$LOGS")
for cfg in $cfgs; do
eval "$cfg"
done
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
grep -v "info:" |
awk -F '[ ,]' '{print $5}' |
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
}
if [ "$1" = "plot" ]; then
plot
else
train
fi

View file

@ -1,3 +1,4 @@
use crate::dqn::burnrl::environment::TrictracEnvironment;
use crate::dqn::burnrl::utils::soft_update_linear;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
@ -8,6 +9,7 @@ use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
use std::fmt;
use std::time::SystemTime;
#[derive(Module, Debug)]
@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
const MEMORY_SIZE: usize = 8192;
pub struct DqnConfig {
pub min_steps: f32,
pub max_steps: usize,
pub num_episodes: usize,
// pub memory_size: usize,
pub dense_size: usize,
pub eps_start: f64,
pub eps_end: f64,
pub eps_decay: f64,
pub gamma: f32,
pub tau: f32,
pub learning_rate: f32,
pub batch_size: usize,
pub clip_grad: f32,
}
impl fmt::Display for DqnConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = String::new();
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
s.push_str(&format!("gamma={:?}\n", self.gamma));
s.push_str(&format!("tau={:?}\n", self.tau));
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
write!(f, "{s}")
}
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
min_steps: 250.0,
max_steps: 2000,
num_episodes: 1000,
// memory_size: 8192,
dense_size: 256,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
}
}
}
@ -85,12 +120,14 @@ impl Default for DqnConfig {
type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)]
pub fn run<E: Environment, B: AutodiffBackend>(
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
conf: &DqnConfig,
visualized: bool,
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().min_steps = conf.min_steps;
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
let mut agent = MyAgent::new(model);
let config = DQNTrainingConfig::default();
// let config = DQNTrainingConfig::default();
let config = DQNTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
@ -145,12 +191,12 @@ pub fn run<E: Environment, B: AutodiffBackend>(
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= E::MAX_STEPS {
if snapshot.done() || episode_duration >= conf.max_steps {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();

View file

@ -84,7 +84,10 @@ pub struct TrictracEnvironment {
current_state: TrictracState,
episode_reward: f32,
pub step_count: usize,
pub min_steps: f32,
pub max_steps: usize,
pub goodmoves_count: usize,
pub goodmoves_ratio: f32,
pub visualized: bool,
}
@ -93,8 +96,6 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction;
type RewardType = f32;
const MAX_STEPS: usize = 600; // Limite max pour éviter les parties infinies
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
@ -115,7 +116,10 @@ impl Environment for TrictracEnvironment {
current_state,
episode_reward: 0.0,
step_count: 0,
min_steps: 250.0,
max_steps: 2000,
goodmoves_count: 0,
goodmoves_ratio: 0.0,
visualized,
}
}
@ -135,10 +139,15 @@ impl Environment for TrictracEnvironment {
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.goodmoves_ratio = if self.step_count == 0 {
0.0
} else {
self.goodmoves_count as f32 / self.step_count as f32
};
println!(
"correct moves: {} ({}%)",
"info: correct moves: {} ({}%)",
self.goodmoves_count,
100 * self.goodmoves_count / self.step_count
(100.0 * self.goodmoves_ratio).round() as u32
);
self.step_count = 0;
self.goodmoves_count = 0;
@ -174,12 +183,12 @@ impl Environment for TrictracEnvironment {
}
// Vérifier si la partie est terminée
let done = self.game.stage == Stage::Ended
|| self.game.determine_winner().is_some()
|| self.step_count >= Self::MAX_STEPS;
let max_steps = self.min_steps
+ (self.max_steps as f32 - self.min_steps)
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
terminated = true;
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
@ -189,6 +198,7 @@ impl Environment for TrictracEnvironment {
}
}
}
let terminated = done || self.step_count >= max_steps.round() as usize;
// Mettre à jour l'état
self.current_state = TrictracState::from_game_state(&self.game);
@ -320,7 +330,7 @@ impl TrictracEnvironment {
let (points, adv_points) = self.game.dice_points;
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
println!("rolled for {reward}");
println!("info: rolled for {reward}");
}
// Récompense proportionnelle aux points
}
@ -421,3 +431,9 @@ impl TrictracEnvironment {
reward
}
}
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
fn as_mut(&mut self) -> &mut Self {
self
}
}

View file

@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
fn main() {
// println!("> Entraînement");
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
min_steps: 250.0, // min steps by episode (mise à jour par la fonction)
max_steps: 2000, // max steps by episode
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 1500.0,
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
eps_decay: 3000.0,
gamma: 0.999, // discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.005, // soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
// plus lente moins sensible aux coups de chance
learning_rate: 0.001, // taille du pas. Bas : plus lent, haut : risque de ne jamais
// converger
batch_size: 32, // nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 100.0, // plafonnement du gradient : limite max de correction à apporter
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
let valid_agent = agent.valid();

View file

@ -54,3 +54,53 @@ epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
nouvelles (EPS*\*).
## Paramètres DQNTrainingConfig
1. `gamma` (Facteur d'actualisation / _Discount Factor_)
- À quoi ça sert ? Ça détermine l'importance des récompenses futures. Une valeur proche de 1 (ex: 0.99)
indique à l'agent qu'une récompense obtenue dans le futur est presque aussi importante qu'une
récompense immédiate. Il sera donc "patient" et capable de faire des sacrifices à court terme pour un
gain plus grand plus tard.
- Intuition : Un gamma de 0 rendrait l'agent "myope", ne se souciant que du prochain coup. Un gamma de
0.99 l'encourage à élaborer des stratégies à long terme.
2. `tau` (Taux de mise à jour douce / _Soft Update Rate_)
- À quoi ça sert ? Pour stabiliser l'apprentissage, les algorithmes DQN utilisent souvent deux réseaux
: un réseau principal qui apprend vite et un "réseau cible" (copie du premier) qui évolue lentement.
tau contrôle la vitesse à laquelle les connaissances du réseau principal sont transférées vers le
réseau cible.
- Intuition : Une petite valeur (ex: 0.005) signifie que le réseau cible, qui sert de référence stable,
ne se met à jour que très progressivement. C'est comme un "mentor" qui n'adopte pas immédiatement
toutes les nouvelles idées de son "élève", ce qui évite de déstabiliser tout l'apprentissage sur un
coup de chance (ou de malchance).
3. `learning_rate` (Taux d'apprentissage)
- À quoi ça sert ? C'est peut-être le plus classique des hyperparamètres. Il définit la "taille du
pas" lors de la correction des erreurs. Après chaque prédiction, l'agent compare le résultat à ce
qui s'est passé et ajuste ses poids. Le learning_rate détermine l'ampleur de cet ajustement.
- Intuition : Trop élevé, et l'agent risque de sur-corriger et de ne jamais converger (comme chercher
le fond d'une vallée en faisant des pas de géant). Trop bas, et l'apprentissage sera extrêmement
lent.
4. `batch_size` (Taille du lot)
- À quoi ça sert ? L'agent apprend de ses expériences passées, qu'il stocke dans une "mémoire". Pour
chaque session d'entraînement, au lieu d'apprendre d'une seule expérience, il en pioche un lot
(batch) au hasard (ex: 32 expériences). Il calcule l'erreur moyenne sur ce lot pour mettre à jour
ses poids.
- Intuition : Apprendre sur un lot plutôt que sur une seule expérience rend l'apprentissage plus
stable et plus général. L'agent se base sur une "moyenne" de situations plutôt que sur un cas
particulier qui pourrait être une anomalie.
5. `clip_grad` (Plafonnement du gradient / _Gradient Clipping_)
- À quoi ça sert ? C'est une sécurité pour éviter le problème des "gradients qui explosent". Parfois,
une expérience très inattendue peut produire une erreur de prédiction énorme, ce qui entraîne une
correction (un "gradient") démesurément grande. Une telle correction peut anéantir tout ce que le
réseau a appris.
- Intuition : clip_grad impose une limite. Si la correction à apporter dépasse un certain seuil, elle
est ramenée à cette valeur maximale. C'est un garde-fou qui dit : "OK, on a fait une grosse erreur,
mais on va corriger calmement, sans tout casser".

View file

@ -9,8 +9,8 @@ shell:
runcli:
RUST_LOG=info cargo run --bin=client_cli
runclibots:
cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk
#cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
match:
cargo build --release --bin=client_cli
@ -28,10 +28,9 @@ trainsimple:
trainbot:
#python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok
cargo build --release --bin=train_dqn_burn
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
./bot/scripts/train.sh
plottrainbot:
cat /tmp/train.out | grep -v rolled | grep -v correct | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid --title 'adv = random ; density = 256 ; err_reward = -1 ; reward_ratio = 1 ; decay = 1500 ; max steps = 600' --terminal png > doc/trainbots_stats/train_random_256_1_1_1500_600.png
./bot/scripts/train.sh plot
debugtrainbot:
cargo build --bin=train_dqn_burn
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn