Compare commits
2 commits
5b02293221
...
778ac1817b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
778ac1817b | ||
|
|
e4b3092018 |
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -11,6 +11,4 @@ devenv.local.nix
|
|||
|
||||
# generated by samply rust profiler
|
||||
profile.json
|
||||
|
||||
# IA modles used by bots
|
||||
/models
|
||||
bot/models
|
||||
|
|
|
|||
26
CLAUDE.md
26
CLAUDE.md
|
|
@ -1,26 +0,0 @@
|
|||
# Trictrac Project Guidelines
|
||||
|
||||
## Build & Run Commands
|
||||
- Build: `cargo build`
|
||||
- Test: `cargo test`
|
||||
- Test specific: `cargo test -- test_name`
|
||||
- Lint: `cargo clippy`
|
||||
- Format: `cargo fmt`
|
||||
- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
|
||||
- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
|
||||
- Build Python lib: `maturin build -m store/Cargo.toml --release`
|
||||
|
||||
## Code Style
|
||||
- Use Rust 2021 edition idioms
|
||||
- Error handling: Use Result<T, Error> pattern with custom Error types
|
||||
- Naming: snake_case for functions/variables, CamelCase for types
|
||||
- Imports: Group standard lib, external crates, then internal modules
|
||||
- Module structure: Prefer small, focused modules with clear responsibilities
|
||||
- Documentation: Document public APIs with doc comments
|
||||
- Testing: Write unit tests in same file as implementation
|
||||
- Python bindings: Use pyo3 for creating Python modules
|
||||
|
||||
## Architecture
|
||||
- Core game logic in `store` crate
|
||||
- Multiple clients: CLI, TUI, Bevy (graphical)
|
||||
- Bot interfaces in `bot` crate
|
||||
38
bot/scripts/train.sh
Executable file
38
bot/scripts/train.sh
Executable file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
||||
LOGS_DIR="$ROOT/bot/models/logs"
|
||||
|
||||
CFG_SIZE=12
|
||||
OPPONENT="random"
|
||||
|
||||
PLOT_EXT="png"
|
||||
|
||||
train() {
|
||||
cargo build --release --bin=train_dqn_burn
|
||||
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
|
||||
LOGS="$LOGS_DIR/$NAME.out"
|
||||
mkdir -p "$LOGS_DIR"
|
||||
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
|
||||
}
|
||||
|
||||
plot() {
|
||||
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||
LOGS="$LOGS_DIR/$NAME"
|
||||
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||
for cfg in $cfgs; do
|
||||
eval "$cfg"
|
||||
done
|
||||
|
||||
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
|
||||
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
||||
grep -v "info:" |
|
||||
awk -F '[ ,]' '{print $5}' |
|
||||
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||
}
|
||||
|
||||
if [ "$1" = "plot" ]; then
|
||||
plot
|
||||
else
|
||||
train
|
||||
fi
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
use crate::dqn::burnrl::environment::TrictracEnvironment;
|
||||
use crate::dqn::burnrl::utils::soft_update_linear;
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
|
|
@ -8,6 +9,7 @@ use burn::tensor::Tensor;
|
|||
use burn_rl::agent::DQN;
|
||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||
use std::fmt;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
|||
const MEMORY_SIZE: usize = 8192;
|
||||
|
||||
pub struct DqnConfig {
|
||||
pub min_steps: f32,
|
||||
pub max_steps: usize,
|
||||
pub num_episodes: usize,
|
||||
// pub memory_size: usize,
|
||||
pub dense_size: usize,
|
||||
pub eps_start: f64,
|
||||
pub eps_end: f64,
|
||||
pub eps_decay: f64,
|
||||
|
||||
pub gamma: f32,
|
||||
pub tau: f32,
|
||||
pub learning_rate: f32,
|
||||
pub batch_size: usize,
|
||||
pub clip_grad: f32,
|
||||
}
|
||||
|
||||
impl fmt::Display for DqnConfig {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut s = String::new();
|
||||
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
|
||||
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||
write!(f, "{s}")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DqnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_steps: 250.0,
|
||||
max_steps: 2000,
|
||||
num_episodes: 1000,
|
||||
// memory_size: 8192,
|
||||
dense_size: 256,
|
||||
eps_start: 0.9,
|
||||
eps_end: 0.05,
|
||||
eps_decay: 1000.0,
|
||||
|
||||
gamma: 0.999,
|
||||
tau: 0.005,
|
||||
learning_rate: 0.001,
|
||||
batch_size: 32,
|
||||
clip_grad: 100.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -85,12 +120,14 @@ impl Default for DqnConfig {
|
|||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<E: Environment, B: AutodiffBackend>(
|
||||
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
conf: &DqnConfig,
|
||||
visualized: bool,
|
||||
) -> DQN<E, B, Net<B>> {
|
||||
// ) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().min_steps = conf.min_steps;
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
|
|
@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
|
||||
let mut agent = MyAgent::new(model);
|
||||
|
||||
let config = DQNTrainingConfig::default();
|
||||
// let config = DQNTrainingConfig::default();
|
||||
let config = DQNTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
tau: conf.tau,
|
||||
learning_rate: conf.learning_rate,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
|
||||
|
|
@ -145,16 +191,13 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
step += 1;
|
||||
episode_duration += 1;
|
||||
|
||||
if snapshot.done() || episode_duration >= E::MAX_STEPS {
|
||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
env.reset();
|
||||
episode_done = true;
|
||||
|
||||
println!(
|
||||
"{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
|
||||
episode,
|
||||
episode_reward,
|
||||
episode_duration,
|
||||
now.elapsed().unwrap().as_secs()
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
now = SystemTime::now();
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
|||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TrictracState {
|
||||
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu
|
||||
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
|
||||
}
|
||||
|
||||
impl State for TrictracState {
|
||||
type Data = [f32; 36];
|
||||
type Data = [i8; 36];
|
||||
|
||||
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
|
||||
Tensor::from_floats(self.data, &B::Device::default())
|
||||
|
|
@ -25,8 +25,8 @@ impl State for TrictracState {
|
|||
impl TrictracState {
|
||||
/// Convertit un GameState en TrictracState
|
||||
pub fn from_game_state(game_state: &GameState) -> Self {
|
||||
let state_vec = game_state.to_vec_float();
|
||||
let mut data = [0.0; 36];
|
||||
let state_vec = game_state.to_vec();
|
||||
let mut data = [0; 36];
|
||||
|
||||
// Copier les données en s'assurant qu'on ne dépasse pas la taille
|
||||
let copy_len = state_vec.len().min(36);
|
||||
|
|
@ -39,6 +39,7 @@ impl TrictracState {
|
|||
/// Actions possibles dans Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct TrictracAction {
|
||||
// u32 as required by burn_rl::base::Action type
|
||||
pub index: u32,
|
||||
}
|
||||
|
||||
|
|
@ -82,7 +83,11 @@ pub struct TrictracEnvironment {
|
|||
opponent_id: PlayerId,
|
||||
current_state: TrictracState,
|
||||
episode_reward: f32,
|
||||
step_count: usize,
|
||||
pub step_count: usize,
|
||||
pub min_steps: f32,
|
||||
pub max_steps: usize,
|
||||
pub goodmoves_count: usize,
|
||||
pub goodmoves_ratio: f32,
|
||||
pub visualized: bool,
|
||||
}
|
||||
|
||||
|
|
@ -91,8 +96,6 @@ impl Environment for TrictracEnvironment {
|
|||
type ActionType = TrictracAction;
|
||||
type RewardType = f32;
|
||||
|
||||
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
|
||||
|
||||
fn new(visualized: bool) -> Self {
|
||||
let mut game = GameState::new(false);
|
||||
|
||||
|
|
@ -113,6 +116,10 @@ impl Environment for TrictracEnvironment {
|
|||
current_state,
|
||||
episode_reward: 0.0,
|
||||
step_count: 0,
|
||||
min_steps: 250.0,
|
||||
max_steps: 2000,
|
||||
goodmoves_count: 0,
|
||||
goodmoves_ratio: 0.0,
|
||||
visualized,
|
||||
}
|
||||
}
|
||||
|
|
@ -132,7 +139,18 @@ impl Environment for TrictracEnvironment {
|
|||
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
self.episode_reward = 0.0;
|
||||
self.goodmoves_ratio = if self.step_count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.goodmoves_count as f32 / self.step_count as f32
|
||||
};
|
||||
println!(
|
||||
"info: correct moves: {} ({}%)",
|
||||
self.goodmoves_count,
|
||||
(100.0 * self.goodmoves_ratio).round() as u32
|
||||
);
|
||||
self.step_count = 0;
|
||||
self.goodmoves_count = 0;
|
||||
|
||||
Snapshot::new(self.current_state, 0.0, false)
|
||||
}
|
||||
|
|
@ -149,14 +167,9 @@ impl Environment for TrictracEnvironment {
|
|||
// Exécuter l'action si c'est le tour de l'agent DQN
|
||||
if self.game.active_player_id == self.active_player_id {
|
||||
if let Some(action) = trictrac_action {
|
||||
match self.execute_action(action) {
|
||||
Ok(action_reward) => {
|
||||
reward = action_reward;
|
||||
}
|
||||
Err(_) => {
|
||||
// Action invalide, pénalité
|
||||
reward = -1.0;
|
||||
}
|
||||
reward = self.execute_action(action);
|
||||
if reward != Self::ERROR_REWARD {
|
||||
self.goodmoves_count += 1;
|
||||
}
|
||||
} else {
|
||||
// Action non convertible, pénalité
|
||||
|
|
@ -170,12 +183,12 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
|
||||
// Vérifier si la partie est terminée
|
||||
let done = self.game.stage == Stage::Ended
|
||||
|| self.game.determine_winner().is_some()
|
||||
|| self.step_count >= Self::MAX_STEPS;
|
||||
let max_steps = self.min_steps
|
||||
+ (self.max_steps as f32 - self.min_steps)
|
||||
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||
|
||||
if done {
|
||||
terminated = true;
|
||||
// Récompense finale basée sur le résultat
|
||||
if let Some(winner_id) = self.game.determine_winner() {
|
||||
if winner_id == self.active_player_id {
|
||||
|
|
@ -185,6 +198,7 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
}
|
||||
}
|
||||
let terminated = done || self.step_count >= max_steps.round() as usize;
|
||||
|
||||
// Mettre à jour l'état
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
|
|
@ -202,6 +216,9 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
|
||||
impl TrictracEnvironment {
|
||||
const ERROR_REWARD: f32 = -1.12121;
|
||||
const REWARD_RATIO: f32 = 1.0;
|
||||
|
||||
/// Convertit une action burn-rl vers une action Trictrac
|
||||
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
|
|
@ -228,10 +245,11 @@ impl TrictracEnvironment {
|
|||
}
|
||||
|
||||
/// Exécute une action Trictrac dans le jeu
|
||||
fn execute_action(
|
||||
&mut self,
|
||||
action: dqn_common::TrictracAction,
|
||||
) -> Result<f32, Box<dyn std::error::Error>> {
|
||||
// fn execute_action(
|
||||
// &mut self,
|
||||
// action: dqn_common::TrictracAction,
|
||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
|
||||
use dqn_common::TrictracAction;
|
||||
|
||||
let mut reward = 0.0;
|
||||
|
|
@ -310,16 +328,22 @@ impl TrictracEnvironment {
|
|||
if self.game.validate(&dice_event) {
|
||||
self.game.consume(&dice_event);
|
||||
let (points, adv_points) = self.game.dice_points;
|
||||
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||
if points > 0 {
|
||||
println!("info: rolled for {reward}");
|
||||
}
|
||||
// Récompense proportionnelle aux points
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Pénalité pour action invalide
|
||||
reward -= 2.0;
|
||||
// on annule les précédents reward
|
||||
// et on indique une valeur reconnaissable pour statistiques
|
||||
reward = Self::ERROR_REWARD;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(reward)
|
||||
reward
|
||||
}
|
||||
|
||||
/// Fait jouer l'adversaire avec une stratégie simple
|
||||
|
|
@ -329,15 +353,14 @@ impl TrictracEnvironment {
|
|||
// Si c'est le tour de l'adversaire, jouer automatiquement
|
||||
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||
// Utiliser la stratégie default pour l'adversaire
|
||||
use crate::strategy::default::DefaultStrategy;
|
||||
use crate::BotStrategy;
|
||||
|
||||
let mut default_strategy = DefaultStrategy::default();
|
||||
default_strategy.set_player_id(self.opponent_id);
|
||||
let mut strategy = crate::strategy::random::RandomStrategy::default();
|
||||
strategy.set_player_id(self.opponent_id);
|
||||
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
|
||||
default_strategy.set_color(color);
|
||||
strategy.set_color(color);
|
||||
}
|
||||
*default_strategy.get_mut_game() = self.game.clone();
|
||||
*strategy.get_mut_game() = self.game.clone();
|
||||
|
||||
// Exécuter l'action selon le turn_stage
|
||||
let event = match self.game.turn_stage {
|
||||
|
|
@ -365,7 +388,7 @@ impl TrictracEnvironment {
|
|||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
|
|
@ -397,7 +420,7 @@ impl TrictracEnvironment {
|
|||
}
|
||||
TurnStage::Move => GameEvent::Move {
|
||||
player_id: self.opponent_id,
|
||||
moves: default_strategy.choose_move(),
|
||||
moves: strategy.choose_move(),
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -408,3 +431,9 @@ impl TrictracEnvironment {
|
|||
reward
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
|
||||
fn as_mut(&mut self) -> &mut Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
|
|||
|
||||
fn main() {
|
||||
// println!("> Entraînement");
|
||||
|
||||
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||
let conf = dqn_model::DqnConfig {
|
||||
num_episodes: 40,
|
||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
||||
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
||||
dense_size: 256, // neural network complexity
|
||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||
min_steps: 250.0, // min steps by episode (mise à jour par la fonction)
|
||||
max_steps: 2000, // max steps by episode
|
||||
dense_size: 256, // neural network complexity
|
||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||
eps_end: 0.05,
|
||||
// eps_decay higher = epsilon decrease slower
|
||||
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||
// epsilon is updated at the start of each episode
|
||||
eps_decay: 3000.0,
|
||||
|
||||
gamma: 0.999, // discount factor. Plus élevé = encourage stratégies à long terme
|
||||
tau: 0.005, // soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||
// plus lente moins sensible aux coups de chance
|
||||
learning_rate: 0.001, // taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||
// converger
|
||||
batch_size: 32, // nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
clip_grad: 100.0, // plafonnement du gradient : limite max de correction à apporter
|
||||
};
|
||||
println!("{conf}----------");
|
||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||
|
||||
let valid_agent = agent.valid();
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ MEMORY_SIZE
|
|||
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
|
||||
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
|
||||
- Pourquoi c'est important :
|
||||
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
||||
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
||||
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
||||
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
||||
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
|
||||
|
||||
DENSE_SIZE
|
||||
|
|
@ -54,3 +54,53 @@ epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
|
|||
|
||||
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
|
||||
nouvelles (EPS*\*).
|
||||
|
||||
## Paramètres DQNTrainingConfig
|
||||
|
||||
1. `gamma` (Facteur d'actualisation / _Discount Factor_)
|
||||
|
||||
- À quoi ça sert ? Ça détermine l'importance des récompenses futures. Une valeur proche de 1 (ex: 0.99)
|
||||
indique à l'agent qu'une récompense obtenue dans le futur est presque aussi importante qu'une
|
||||
récompense immédiate. Il sera donc "patient" et capable de faire des sacrifices à court terme pour un
|
||||
gain plus grand plus tard.
|
||||
- Intuition : Un gamma de 0 rendrait l'agent "myope", ne se souciant que du prochain coup. Un gamma de
|
||||
0.99 l'encourage à élaborer des stratégies à long terme.
|
||||
|
||||
2. `tau` (Taux de mise à jour douce / _Soft Update Rate_)
|
||||
|
||||
- À quoi ça sert ? Pour stabiliser l'apprentissage, les algorithmes DQN utilisent souvent deux réseaux
|
||||
: un réseau principal qui apprend vite et un "réseau cible" (copie du premier) qui évolue lentement.
|
||||
tau contrôle la vitesse à laquelle les connaissances du réseau principal sont transférées vers le
|
||||
réseau cible.
|
||||
- Intuition : Une petite valeur (ex: 0.005) signifie que le réseau cible, qui sert de référence stable,
|
||||
ne se met à jour que très progressivement. C'est comme un "mentor" qui n'adopte pas immédiatement
|
||||
toutes les nouvelles idées de son "élève", ce qui évite de déstabiliser tout l'apprentissage sur un
|
||||
coup de chance (ou de malchance).
|
||||
|
||||
3. `learning_rate` (Taux d'apprentissage)
|
||||
|
||||
- À quoi ça sert ? C'est peut-être le plus classique des hyperparamètres. Il définit la "taille du
|
||||
pas" lors de la correction des erreurs. Après chaque prédiction, l'agent compare le résultat à ce
|
||||
qui s'est passé et ajuste ses poids. Le learning_rate détermine l'ampleur de cet ajustement.
|
||||
- Intuition : Trop élevé, et l'agent risque de sur-corriger et de ne jamais converger (comme chercher
|
||||
le fond d'une vallée en faisant des pas de géant). Trop bas, et l'apprentissage sera extrêmement
|
||||
lent.
|
||||
|
||||
4. `batch_size` (Taille du lot)
|
||||
|
||||
- À quoi ça sert ? L'agent apprend de ses expériences passées, qu'il stocke dans une "mémoire". Pour
|
||||
chaque session d'entraînement, au lieu d'apprendre d'une seule expérience, il en pioche un lot
|
||||
(batch) au hasard (ex: 32 expériences). Il calcule l'erreur moyenne sur ce lot pour mettre à jour
|
||||
ses poids.
|
||||
- Intuition : Apprendre sur un lot plutôt que sur une seule expérience rend l'apprentissage plus
|
||||
stable et plus général. L'agent se base sur une "moyenne" de situations plutôt que sur un cas
|
||||
particulier qui pourrait être une anomalie.
|
||||
|
||||
5. `clip_grad` (Plafonnement du gradient / _Gradient Clipping_)
|
||||
- À quoi ça sert ? C'est une sécurité pour éviter le problème des "gradients qui explosent". Parfois,
|
||||
une expérience très inattendue peut produire une erreur de prédiction énorme, ce qui entraîne une
|
||||
correction (un "gradient") démesurément grande. Une telle correction peut anéantir tout ce que le
|
||||
réseau a appris.
|
||||
- Intuition : clip_grad impose une limite. Si la correction à apporter dépasse un certain seuil, elle
|
||||
est ramenée à cette valeur maximale. C'est un garde-fou qui dit : "OK, on a fait une grosse erreur,
|
||||
mais on va corriger calmement, sans tout casser".
|
||||
|
|
|
|||
11
justfile
11
justfile
|
|
@ -9,8 +9,8 @@ shell:
|
|||
runcli:
|
||||
RUST_LOG=info cargo run --bin=client_cli
|
||||
runclibots:
|
||||
cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk
|
||||
#cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
|
||||
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
|
||||
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||
match:
|
||||
cargo build --release --bin=client_cli
|
||||
|
|
@ -28,12 +28,9 @@ trainsimple:
|
|||
trainbot:
|
||||
#python ./store/python/trainModel.py
|
||||
# cargo run --bin=train_dqn # ok
|
||||
# cargo run --bin=train_dqn_burn # utilise debug (why ?)
|
||||
cargo build --release --bin=train_dqn_burn
|
||||
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
|
||||
./bot/scripts/train.sh
|
||||
plottrainbot:
|
||||
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
|
||||
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
|
||||
./bot/scripts/train.sh plot
|
||||
debugtrainbot:
|
||||
cargo build --bin=train_dqn_burn
|
||||
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
||||
|
|
|
|||
|
|
@ -151,6 +151,7 @@ impl GameState {
|
|||
|
||||
/// Get state as a vector (to be used for bot training input) :
|
||||
/// length = 36
|
||||
/// i8 for board positions with negative values for blacks
|
||||
pub fn to_vec(&self) -> Vec<i8> {
|
||||
let state_len = 36;
|
||||
let mut state = Vec::with_capacity(state_len);
|
||||
|
|
|
|||
Loading…
Reference in a new issue