script train bots

This commit is contained in:
Henri Bourcereau 2025-08-10 15:32:41 +02:00
parent e4b3092018
commit 778ac1817b
7 changed files with 191 additions and 30 deletions

bot/scripts/train.sh (new executable file, 38 lines)

@@ -0,0 +1,38 @@
#!/usr/bin/env sh
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
LOGS_DIR="$ROOT/bot/models/logs"
CFG_SIZE=12
OPPONENT="random"
PLOT_EXT="png"
train() {
cargo build --release --bin=train_dqn_burn
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
LOGS="$LOGS_DIR/$NAME.out"
mkdir -p "$LOGS_DIR"
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
}
plot() {
NAME=$(ls "$LOGS_DIR" | tail -n 1)
LOGS="$LOGS_DIR/$NAME"
cfgs=$(head -n $CFG_SIZE "$LOGS")
for cfg in $cfgs; do
eval "$cfg"
done
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
grep -v "info:" |
awk -F '[ ,]' '{print $5}' |
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
}
if [ "$1" = "plot" ]; then
plot
else
train
fi
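
Usage sketch (not part of the commit): run the script with no argument to train, or with "plot" to chart the most recent log; the plot step assumes feedgnuplot is installed. For example, from the repository root:

    # build and run the trainer, teeing stdout to bot/models/logs/train_<timestamp>.out
    ./bot/scripts/train.sh

    # eval the 12-line config header of the newest log and plot per-episode rewards to a PNG
    ./bot/scripts/train.sh plot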

View file

@@ -1,3 +1,4 @@
use crate::dqn::burnrl::environment::TrictracEnvironment;
use crate::dqn::burnrl::utils::soft_update_linear;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
@@ -8,6 +9,7 @@ use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
use std::fmt;
use std::time::SystemTime;
#[derive(Module, Debug)]
@@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
const MEMORY_SIZE: usize = 8192;
pub struct DqnConfig {
pub min_steps: f32,
pub max_steps: usize,
pub num_episodes: usize,
// pub memory_size: usize,
pub dense_size: usize,
pub eps_start: f64,
pub eps_end: f64,
pub eps_decay: f64,
pub gamma: f32,
pub tau: f32,
pub learning_rate: f32,
pub batch_size: usize,
pub clip_grad: f32,
}
impl fmt::Display for DqnConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = String::new();
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
s.push_str(&format!("gamma={:?}\n", self.gamma));
s.push_str(&format!("tau={:?}\n", self.tau));
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
write!(f, "{s}")
}
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
min_steps: 250.0,
max_steps: 2000,
num_episodes: 1000,
// memory_size: 8192,
dense_size: 256,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
}
}
}
@@ -85,12 +120,14 @@ impl Default for DqnConfig {
type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)]
pub fn run<E: Environment, B: AutodiffBackend>(
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
conf: &DqnConfig,
visualized: bool,
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().min_steps = conf.min_steps;
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
@@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
let mut agent = MyAgent::new(model);
let config = DQNTrainingConfig::default();
// let config = DQNTrainingConfig::default();
let config = DQNTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
@@ -145,12 +191,12 @@ pub fn run<E: Environment, B: AutodiffBackend>(
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= E::MAX_STEPS {
if snapshot.done() || episode_duration >= conf.max_steps {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
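
For reference, the fmt::Display impl above emits one key=value line per field and the train_dqn_burn binary then prints a "----------" separator, so a training log starts with a 12-line header (matching CFG_SIZE=12 in bot/scripts/train.sh) followed by one JSON line per episode. An illustrative log opening, assuming the Default values shown above (the episode line values are made up):

    min_steps=250.0
    max_steps=2000
    num_episodes=1000
    dense_size=256
    eps_start=0.9
    eps_end=0.05
    eps_decay=1000.0
    gamma=0.999
    tau=0.005
    learning_rate=0.001
    batch_size=32
    clip_grad=100.0
    ----------
    {"episode": 0, "reward": 1.2345, "steps count": 250, "threshold": 0.897, "duration": 12}

The plot() helper skips the first 13 lines with tail -n +14, drops "info:" lines, and awk -F '[ ,]' '{print $5}' extracts the reward column for feedgnuplot.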

View file

@@ -84,7 +84,10 @@ pub struct TrictracEnvironment {
current_state: TrictracState,
episode_reward: f32,
pub step_count: usize,
pub min_steps: f32,
pub max_steps: usize,
pub goodmoves_count: usize,
pub goodmoves_ratio: f32,
pub visualized: bool,
}
@@ -93,8 +96,6 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction;
type RewardType = f32;
const MAX_STEPS: usize = 600; // Max limit to avoid infinite games
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
@@ -115,7 +116,10 @@ impl Environment for TrictracEnvironment {
current_state,
episode_reward: 0.0,
step_count: 0,
min_steps: 250.0,
max_steps: 2000,
goodmoves_count: 0,
goodmoves_ratio: 0.0,
visualized,
}
}
@@ -135,10 +139,15 @@ impl Environment for TrictracEnvironment {
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.goodmoves_ratio = if self.step_count == 0 {
0.0
} else {
self.goodmoves_count as f32 / self.step_count as f32
};
println!(
"correct moves: {} ({}%)",
"info: correct moves: {} ({}%)",
self.goodmoves_count,
100 * self.goodmoves_count / self.step_count
(100.0 * self.goodmoves_ratio).round() as u32
);
self.step_count = 0;
self.goodmoves_count = 0;
@@ -174,12 +183,12 @@ impl Environment for TrictracEnvironment {
}
// Check whether the game is over
let done = self.game.stage == Stage::Ended
|| self.game.determine_winner().is_some()
|| self.step_count >= Self::MAX_STEPS;
let max_steps = self.min_steps
+ (self.max_steps as f32 - self.min_steps)
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
terminated = true;
// Final reward based on the result
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
@@ -189,6 +198,7 @@ impl Environment for TrictracEnvironment {
}
}
}
let terminated = done || self.step_count >= max_steps.round() as usize;
// Update the state
self.current_state = TrictracState::from_game_state(&self.game);
@@ -320,7 +330,7 @@ impl TrictracEnvironment {
let (points, adv_points) = self.game.dice_points;
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
println!("rolled for {reward}");
println!("info: rolled for {reward}");
}
// Reward proportional to the points scored
}
@@ -421,3 +431,9 @@ impl TrictracEnvironment {
reward
}
}
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
fn as_mut(&mut self) -> &mut Self {
self
}
}
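
The adaptive episode cap introduced above grows from min_steps toward max_steps as the share of valid moves in the previous episode (goodmoves_ratio) approaches 1.0. A standalone sketch of the same formula (the helper name effective_step_cap is illustrative, not part of the commit):

    fn effective_step_cap(min_steps: f32, max_steps: usize, goodmoves_ratio: f32) -> usize {
        // exp((ratio - 1.0) / 0.25) equals 1.0 when every move was valid and decays quickly below that
        let cap = min_steps + (max_steps as f32 - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25);
        cap.round() as usize
    }

    // With the defaults min_steps = 250.0 and max_steps = 2000:
    //   effective_step_cap(250.0, 2000, 0.0) ≈ 282   (mostly invalid moves -> short episodes)
    //   effective_step_cap(250.0, 2000, 0.5) ≈ 487
    //   effective_step_cap(250.0, 2000, 1.0) == 2000 (all moves valid -> full-length episodes)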

View file

@@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
fn main() {
// println!("> Entraînement");
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
min_steps: 250.0, // min steps per episode (the effective limit is adjusted by the step-limit formula)
max_steps: 2000, // max steps per episode
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 1500.0,
// higher eps_decay = epsilon decreases more slowly
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
eps_decay: 3000.0,
gamma: 0.999, // discount factor. Higher = encourages long-term strategies
tau: 0.005, // soft update rate of the target network. Lower = slower adaptation,
// less sensitive to lucky streaks
learning_rate: 0.001, // step size. Low: slower learning; high: may never converge
batch_size: 32, // number of past experiences used to compute the mean error
clip_grad: 100.0, // gradient clipping: maximum correction to apply
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
let valid_agent = agent.valid();
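
To see what raising eps_decay from 1500 to 3000 changes, here is a small sketch of the schedule quoted in the comments above (the epsilon helper is illustrative; how the step counter is maintained is left to burn_rl):

    fn epsilon(step: f64, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
        // epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
        eps_end + (eps_start - eps_end) * (-step / eps_decay).exp()
    }

    // With eps_start = 0.9 and eps_end = 0.05:
    //   step = 1000: epsilon ≈ 0.49 with eps_decay = 1500, ≈ 0.66 with eps_decay = 3000
    //   step = 5000: epsilon ≈ 0.08 with eps_decay = 1500, ≈ 0.21 with eps_decay = 3000
    // A larger eps_decay therefore keeps the agent exploring for longer.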