Compare commits
2 commits
5b02293221
...
778ac1817b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
778ac1817b | ||
|
|
e4b3092018 |
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -11,6 +11,4 @@ devenv.local.nix
|
||||||
|
|
||||||
# generated by samply rust profiler
|
# generated by samply rust profiler
|
||||||
profile.json
|
profile.json
|
||||||
|
bot/models
|
||||||
# IA modles used by bots
|
|
||||||
/models
|
|
||||||
|
|
|
||||||
26
CLAUDE.md
26
CLAUDE.md
|
|
@ -1,26 +0,0 @@
|
||||||
# Trictrac Project Guidelines
|
|
||||||
|
|
||||||
## Build & Run Commands
|
|
||||||
- Build: `cargo build`
|
|
||||||
- Test: `cargo test`
|
|
||||||
- Test specific: `cargo test -- test_name`
|
|
||||||
- Lint: `cargo clippy`
|
|
||||||
- Format: `cargo fmt`
|
|
||||||
- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
|
|
||||||
- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
|
|
||||||
- Build Python lib: `maturin build -m store/Cargo.toml --release`
|
|
||||||
|
|
||||||
## Code Style
|
|
||||||
- Use Rust 2021 edition idioms
|
|
||||||
- Error handling: Use Result<T, Error> pattern with custom Error types
|
|
||||||
- Naming: snake_case for functions/variables, CamelCase for types
|
|
||||||
- Imports: Group standard lib, external crates, then internal modules
|
|
||||||
- Module structure: Prefer small, focused modules with clear responsibilities
|
|
||||||
- Documentation: Document public APIs with doc comments
|
|
||||||
- Testing: Write unit tests in same file as implementation
|
|
||||||
- Python bindings: Use pyo3 for creating Python modules
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
- Core game logic in `store` crate
|
|
||||||
- Multiple clients: CLI, TUI, Bevy (graphical)
|
|
||||||
- Bot interfaces in `bot` crate
|
|
||||||
38
bot/scripts/train.sh
Executable file
38
bot/scripts/train.sh
Executable file
|
|
@ -0,0 +1,38 @@
|
||||||
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
||||||
|
LOGS_DIR="$ROOT/bot/models/logs"
|
||||||
|
|
||||||
|
CFG_SIZE=12
|
||||||
|
OPPONENT="random"
|
||||||
|
|
||||||
|
PLOT_EXT="png"
|
||||||
|
|
||||||
|
train() {
|
||||||
|
cargo build --release --bin=train_dqn_burn
|
||||||
|
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
|
||||||
|
LOGS="$LOGS_DIR/$NAME.out"
|
||||||
|
mkdir -p "$LOGS_DIR"
|
||||||
|
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
|
||||||
|
}
|
||||||
|
|
||||||
|
plot() {
|
||||||
|
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||||
|
LOGS="$LOGS_DIR/$NAME"
|
||||||
|
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||||
|
for cfg in $cfgs; do
|
||||||
|
eval "$cfg"
|
||||||
|
done
|
||||||
|
|
||||||
|
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
|
||||||
|
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
||||||
|
grep -v "info:" |
|
||||||
|
awk -F '[ ,]' '{print $5}' |
|
||||||
|
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ "$1" = "plot" ]; then
|
||||||
|
plot
|
||||||
|
else
|
||||||
|
train
|
||||||
|
fi
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::dqn::burnrl::environment::TrictracEnvironment;
|
||||||
use crate::dqn::burnrl::utils::soft_update_linear;
|
use crate::dqn::burnrl::utils::soft_update_linear;
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::nn::{Linear, LinearConfig};
|
use burn::nn::{Linear, LinearConfig};
|
||||||
|
|
@ -8,6 +9,7 @@ use burn::tensor::Tensor;
|
||||||
use burn_rl::agent::DQN;
|
use burn_rl::agent::DQN;
|
||||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||||
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::fmt;
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
|
|
@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
||||||
const MEMORY_SIZE: usize = 8192;
|
const MEMORY_SIZE: usize = 8192;
|
||||||
|
|
||||||
pub struct DqnConfig {
|
pub struct DqnConfig {
|
||||||
|
pub min_steps: f32,
|
||||||
|
pub max_steps: usize,
|
||||||
pub num_episodes: usize,
|
pub num_episodes: usize,
|
||||||
// pub memory_size: usize,
|
|
||||||
pub dense_size: usize,
|
pub dense_size: usize,
|
||||||
pub eps_start: f64,
|
pub eps_start: f64,
|
||||||
pub eps_end: f64,
|
pub eps_end: f64,
|
||||||
pub eps_decay: f64,
|
pub eps_decay: f64,
|
||||||
|
|
||||||
|
pub gamma: f32,
|
||||||
|
pub tau: f32,
|
||||||
|
pub learning_rate: f32,
|
||||||
|
pub batch_size: usize,
|
||||||
|
pub clip_grad: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for DqnConfig {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
let mut s = String::new();
|
||||||
|
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
|
||||||
|
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||||
|
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||||
|
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||||
|
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||||
|
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||||
|
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||||
|
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||||
|
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||||
|
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||||
|
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||||
|
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||||
|
write!(f, "{s}")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for DqnConfig {
|
impl Default for DqnConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
min_steps: 250.0,
|
||||||
|
max_steps: 2000,
|
||||||
num_episodes: 1000,
|
num_episodes: 1000,
|
||||||
// memory_size: 8192,
|
|
||||||
dense_size: 256,
|
dense_size: 256,
|
||||||
eps_start: 0.9,
|
eps_start: 0.9,
|
||||||
eps_end: 0.05,
|
eps_end: 0.05,
|
||||||
eps_decay: 1000.0,
|
eps_decay: 1000.0,
|
||||||
|
|
||||||
|
gamma: 0.999,
|
||||||
|
tau: 0.005,
|
||||||
|
learning_rate: 0.001,
|
||||||
|
batch_size: 32,
|
||||||
|
clip_grad: 100.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -85,12 +120,14 @@ impl Default for DqnConfig {
|
||||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
pub fn run<E: Environment, B: AutodiffBackend>(
|
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||||
conf: &DqnConfig,
|
conf: &DqnConfig,
|
||||||
visualized: bool,
|
visualized: bool,
|
||||||
) -> DQN<E, B, Net<B>> {
|
) -> DQN<E, B, Net<B>> {
|
||||||
// ) -> impl Agent<E> {
|
// ) -> impl Agent<E> {
|
||||||
let mut env = E::new(visualized);
|
let mut env = E::new(visualized);
|
||||||
|
env.as_mut().min_steps = conf.min_steps;
|
||||||
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
|
||||||
let model = Net::<B>::new(
|
let model = Net::<B>::new(
|
||||||
<<E as Environment>::StateType as State>::size(),
|
<<E as Environment>::StateType as State>::size(),
|
||||||
|
|
@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
|
|
||||||
let mut agent = MyAgent::new(model);
|
let mut agent = MyAgent::new(model);
|
||||||
|
|
||||||
let config = DQNTrainingConfig::default();
|
// let config = DQNTrainingConfig::default();
|
||||||
|
let config = DQNTrainingConfig {
|
||||||
|
gamma: conf.gamma,
|
||||||
|
tau: conf.tau,
|
||||||
|
learning_rate: conf.learning_rate,
|
||||||
|
batch_size: conf.batch_size,
|
||||||
|
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||||
|
conf.clip_grad,
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
|
||||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||||
|
|
||||||
|
|
@ -145,16 +191,13 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
step += 1;
|
step += 1;
|
||||||
episode_duration += 1;
|
episode_duration += 1;
|
||||||
|
|
||||||
if snapshot.done() || episode_duration >= E::MAX_STEPS {
|
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||||
env.reset();
|
env.reset();
|
||||||
episode_done = true;
|
episode_done = true;
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
|
||||||
episode,
|
now.elapsed().unwrap().as_secs(),
|
||||||
episode_reward,
|
|
||||||
episode_duration,
|
|
||||||
now.elapsed().unwrap().as_secs()
|
|
||||||
);
|
);
|
||||||
now = SystemTime::now();
|
now = SystemTime::now();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct TrictracState {
|
pub struct TrictracState {
|
||||||
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu
|
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State for TrictracState {
|
impl State for TrictracState {
|
||||||
type Data = [f32; 36];
|
type Data = [i8; 36];
|
||||||
|
|
||||||
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
|
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
|
||||||
Tensor::from_floats(self.data, &B::Device::default())
|
Tensor::from_floats(self.data, &B::Device::default())
|
||||||
|
|
@ -25,8 +25,8 @@ impl State for TrictracState {
|
||||||
impl TrictracState {
|
impl TrictracState {
|
||||||
/// Convertit un GameState en TrictracState
|
/// Convertit un GameState en TrictracState
|
||||||
pub fn from_game_state(game_state: &GameState) -> Self {
|
pub fn from_game_state(game_state: &GameState) -> Self {
|
||||||
let state_vec = game_state.to_vec_float();
|
let state_vec = game_state.to_vec();
|
||||||
let mut data = [0.0; 36];
|
let mut data = [0; 36];
|
||||||
|
|
||||||
// Copier les données en s'assurant qu'on ne dépasse pas la taille
|
// Copier les données en s'assurant qu'on ne dépasse pas la taille
|
||||||
let copy_len = state_vec.len().min(36);
|
let copy_len = state_vec.len().min(36);
|
||||||
|
|
@ -39,6 +39,7 @@ impl TrictracState {
|
||||||
/// Actions possibles dans Trictrac pour burn-rl
|
/// Actions possibles dans Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
pub struct TrictracAction {
|
pub struct TrictracAction {
|
||||||
|
// u32 as required by burn_rl::base::Action type
|
||||||
pub index: u32,
|
pub index: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -82,7 +83,11 @@ pub struct TrictracEnvironment {
|
||||||
opponent_id: PlayerId,
|
opponent_id: PlayerId,
|
||||||
current_state: TrictracState,
|
current_state: TrictracState,
|
||||||
episode_reward: f32,
|
episode_reward: f32,
|
||||||
step_count: usize,
|
pub step_count: usize,
|
||||||
|
pub min_steps: f32,
|
||||||
|
pub max_steps: usize,
|
||||||
|
pub goodmoves_count: usize,
|
||||||
|
pub goodmoves_ratio: f32,
|
||||||
pub visualized: bool,
|
pub visualized: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -91,8 +96,6 @@ impl Environment for TrictracEnvironment {
|
||||||
type ActionType = TrictracAction;
|
type ActionType = TrictracAction;
|
||||||
type RewardType = f32;
|
type RewardType = f32;
|
||||||
|
|
||||||
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
|
|
||||||
|
|
||||||
fn new(visualized: bool) -> Self {
|
fn new(visualized: bool) -> Self {
|
||||||
let mut game = GameState::new(false);
|
let mut game = GameState::new(false);
|
||||||
|
|
||||||
|
|
@ -113,6 +116,10 @@ impl Environment for TrictracEnvironment {
|
||||||
current_state,
|
current_state,
|
||||||
episode_reward: 0.0,
|
episode_reward: 0.0,
|
||||||
step_count: 0,
|
step_count: 0,
|
||||||
|
min_steps: 250.0,
|
||||||
|
max_steps: 2000,
|
||||||
|
goodmoves_count: 0,
|
||||||
|
goodmoves_ratio: 0.0,
|
||||||
visualized,
|
visualized,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +139,18 @@ impl Environment for TrictracEnvironment {
|
||||||
|
|
||||||
self.current_state = TrictracState::from_game_state(&self.game);
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
self.episode_reward = 0.0;
|
self.episode_reward = 0.0;
|
||||||
|
self.goodmoves_ratio = if self.step_count == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
self.goodmoves_count as f32 / self.step_count as f32
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"info: correct moves: {} ({}%)",
|
||||||
|
self.goodmoves_count,
|
||||||
|
(100.0 * self.goodmoves_ratio).round() as u32
|
||||||
|
);
|
||||||
self.step_count = 0;
|
self.step_count = 0;
|
||||||
|
self.goodmoves_count = 0;
|
||||||
|
|
||||||
Snapshot::new(self.current_state, 0.0, false)
|
Snapshot::new(self.current_state, 0.0, false)
|
||||||
}
|
}
|
||||||
|
|
@ -149,14 +167,9 @@ impl Environment for TrictracEnvironment {
|
||||||
// Exécuter l'action si c'est le tour de l'agent DQN
|
// Exécuter l'action si c'est le tour de l'agent DQN
|
||||||
if self.game.active_player_id == self.active_player_id {
|
if self.game.active_player_id == self.active_player_id {
|
||||||
if let Some(action) = trictrac_action {
|
if let Some(action) = trictrac_action {
|
||||||
match self.execute_action(action) {
|
reward = self.execute_action(action);
|
||||||
Ok(action_reward) => {
|
if reward != Self::ERROR_REWARD {
|
||||||
reward = action_reward;
|
self.goodmoves_count += 1;
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
// Action invalide, pénalité
|
|
||||||
reward = -1.0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Action non convertible, pénalité
|
// Action non convertible, pénalité
|
||||||
|
|
@ -170,12 +183,12 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vérifier si la partie est terminée
|
// Vérifier si la partie est terminée
|
||||||
let done = self.game.stage == Stage::Ended
|
let max_steps = self.min_steps
|
||||||
|| self.game.determine_winner().is_some()
|
+ (self.max_steps as f32 - self.min_steps)
|
||||||
|| self.step_count >= Self::MAX_STEPS;
|
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||||
|
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||||
|
|
||||||
if done {
|
if done {
|
||||||
terminated = true;
|
|
||||||
// Récompense finale basée sur le résultat
|
// Récompense finale basée sur le résultat
|
||||||
if let Some(winner_id) = self.game.determine_winner() {
|
if let Some(winner_id) = self.game.determine_winner() {
|
||||||
if winner_id == self.active_player_id {
|
if winner_id == self.active_player_id {
|
||||||
|
|
@ -185,6 +198,7 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let terminated = done || self.step_count >= max_steps.round() as usize;
|
||||||
|
|
||||||
// Mettre à jour l'état
|
// Mettre à jour l'état
|
||||||
self.current_state = TrictracState::from_game_state(&self.game);
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
|
|
@ -202,6 +216,9 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrictracEnvironment {
|
impl TrictracEnvironment {
|
||||||
|
const ERROR_REWARD: f32 = -1.12121;
|
||||||
|
const REWARD_RATIO: f32 = 1.0;
|
||||||
|
|
||||||
/// Convertit une action burn-rl vers une action Trictrac
|
/// Convertit une action burn-rl vers une action Trictrac
|
||||||
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||||
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||||
|
|
@ -228,10 +245,11 @@ impl TrictracEnvironment {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Exécute une action Trictrac dans le jeu
|
/// Exécute une action Trictrac dans le jeu
|
||||||
fn execute_action(
|
// fn execute_action(
|
||||||
&mut self,
|
// &mut self,
|
||||||
action: dqn_common::TrictracAction,
|
// action: dqn_common::TrictracAction,
|
||||||
) -> Result<f32, Box<dyn std::error::Error>> {
|
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||||
|
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
|
||||||
use dqn_common::TrictracAction;
|
use dqn_common::TrictracAction;
|
||||||
|
|
||||||
let mut reward = 0.0;
|
let mut reward = 0.0;
|
||||||
|
|
@ -310,16 +328,22 @@ impl TrictracEnvironment {
|
||||||
if self.game.validate(&dice_event) {
|
if self.game.validate(&dice_event) {
|
||||||
self.game.consume(&dice_event);
|
self.game.consume(&dice_event);
|
||||||
let (points, adv_points) = self.game.dice_points;
|
let (points, adv_points) = self.game.dice_points;
|
||||||
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||||
|
if points > 0 {
|
||||||
|
println!("info: rolled for {reward}");
|
||||||
|
}
|
||||||
|
// Récompense proportionnelle aux points
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Pénalité pour action invalide
|
// Pénalité pour action invalide
|
||||||
reward -= 2.0;
|
// on annule les précédents reward
|
||||||
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
|
reward = Self::ERROR_REWARD;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(reward)
|
reward
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fait jouer l'adversaire avec une stratégie simple
|
/// Fait jouer l'adversaire avec une stratégie simple
|
||||||
|
|
@ -329,15 +353,14 @@ impl TrictracEnvironment {
|
||||||
// Si c'est le tour de l'adversaire, jouer automatiquement
|
// Si c'est le tour de l'adversaire, jouer automatiquement
|
||||||
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||||
// Utiliser la stratégie default pour l'adversaire
|
// Utiliser la stratégie default pour l'adversaire
|
||||||
use crate::strategy::default::DefaultStrategy;
|
|
||||||
use crate::BotStrategy;
|
use crate::BotStrategy;
|
||||||
|
|
||||||
let mut default_strategy = DefaultStrategy::default();
|
let mut strategy = crate::strategy::random::RandomStrategy::default();
|
||||||
default_strategy.set_player_id(self.opponent_id);
|
strategy.set_player_id(self.opponent_id);
|
||||||
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
|
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
|
||||||
default_strategy.set_color(color);
|
strategy.set_color(color);
|
||||||
}
|
}
|
||||||
*default_strategy.get_mut_game() = self.game.clone();
|
*strategy.get_mut_game() = self.game.clone();
|
||||||
|
|
||||||
// Exécuter l'action selon le turn_stage
|
// Exécuter l'action selon le turn_stage
|
||||||
let event = match self.game.turn_stage {
|
let event = match self.game.turn_stage {
|
||||||
|
|
@ -365,7 +388,7 @@ impl TrictracEnvironment {
|
||||||
let points_rules =
|
let points_rules =
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
GameEvent::Mark {
|
GameEvent::Mark {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
|
|
@ -397,7 +420,7 @@ impl TrictracEnvironment {
|
||||||
}
|
}
|
||||||
TurnStage::Move => GameEvent::Move {
|
TurnStage::Move => GameEvent::Move {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
moves: default_strategy.choose_move(),
|
moves: strategy.choose_move(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -408,3 +431,9 @@ impl TrictracEnvironment {
|
||||||
reward
|
reward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
|
||||||
|
fn as_mut(&mut self) -> &mut Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// println!("> Entraînement");
|
// println!("> Entraînement");
|
||||||
|
|
||||||
|
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||||
let conf = dqn_model::DqnConfig {
|
let conf = dqn_model::DqnConfig {
|
||||||
num_episodes: 40,
|
num_episodes: 40,
|
||||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
min_steps: 250.0, // min steps by episode (mise à jour par la fonction)
|
||||||
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
max_steps: 2000, // max steps by episode
|
||||||
dense_size: 256, // neural network complexity
|
dense_size: 256, // neural network complexity
|
||||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||||
eps_end: 0.05,
|
eps_end: 0.05,
|
||||||
|
// eps_decay higher = epsilon decrease slower
|
||||||
|
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||||
|
// epsilon is updated at the start of each episode
|
||||||
eps_decay: 3000.0,
|
eps_decay: 3000.0,
|
||||||
|
|
||||||
|
gamma: 0.999, // discount factor. Plus élevé = encourage stratégies à long terme
|
||||||
|
tau: 0.005, // soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||||
|
// plus lente moins sensible aux coups de chance
|
||||||
|
learning_rate: 0.001, // taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||||
|
// converger
|
||||||
|
batch_size: 32, // nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||||
|
clip_grad: 100.0, // plafonnement du gradient : limite max de correction à apporter
|
||||||
};
|
};
|
||||||
|
println!("{conf}----------");
|
||||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||||
|
|
||||||
let valid_agent = agent.valid();
|
let valid_agent = agent.valid();
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ MEMORY_SIZE
|
||||||
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
|
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
|
||||||
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
|
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
|
||||||
- Pourquoi c'est important :
|
- Pourquoi c'est important :
|
||||||
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
||||||
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
||||||
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
|
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
|
||||||
|
|
||||||
DENSE_SIZE
|
DENSE_SIZE
|
||||||
|
|
@ -54,3 +54,53 @@ epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
|
||||||
|
|
||||||
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
|
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
|
||||||
nouvelles (EPS*\*).
|
nouvelles (EPS*\*).
|
||||||
|
|
||||||
|
## Paramètres DQNTrainingConfig
|
||||||
|
|
||||||
|
1. `gamma` (Facteur d'actualisation / _Discount Factor_)
|
||||||
|
|
||||||
|
- À quoi ça sert ? Ça détermine l'importance des récompenses futures. Une valeur proche de 1 (ex: 0.99)
|
||||||
|
indique à l'agent qu'une récompense obtenue dans le futur est presque aussi importante qu'une
|
||||||
|
récompense immédiate. Il sera donc "patient" et capable de faire des sacrifices à court terme pour un
|
||||||
|
gain plus grand plus tard.
|
||||||
|
- Intuition : Un gamma de 0 rendrait l'agent "myope", ne se souciant que du prochain coup. Un gamma de
|
||||||
|
0.99 l'encourage à élaborer des stratégies à long terme.
|
||||||
|
|
||||||
|
2. `tau` (Taux de mise à jour douce / _Soft Update Rate_)
|
||||||
|
|
||||||
|
- À quoi ça sert ? Pour stabiliser l'apprentissage, les algorithmes DQN utilisent souvent deux réseaux
|
||||||
|
: un réseau principal qui apprend vite et un "réseau cible" (copie du premier) qui évolue lentement.
|
||||||
|
tau contrôle la vitesse à laquelle les connaissances du réseau principal sont transférées vers le
|
||||||
|
réseau cible.
|
||||||
|
- Intuition : Une petite valeur (ex: 0.005) signifie que le réseau cible, qui sert de référence stable,
|
||||||
|
ne se met à jour que très progressivement. C'est comme un "mentor" qui n'adopte pas immédiatement
|
||||||
|
toutes les nouvelles idées de son "élève", ce qui évite de déstabiliser tout l'apprentissage sur un
|
||||||
|
coup de chance (ou de malchance).
|
||||||
|
|
||||||
|
3. `learning_rate` (Taux d'apprentissage)
|
||||||
|
|
||||||
|
- À quoi ça sert ? C'est peut-être le plus classique des hyperparamètres. Il définit la "taille du
|
||||||
|
pas" lors de la correction des erreurs. Après chaque prédiction, l'agent compare le résultat à ce
|
||||||
|
qui s'est passé et ajuste ses poids. Le learning_rate détermine l'ampleur de cet ajustement.
|
||||||
|
- Intuition : Trop élevé, et l'agent risque de sur-corriger et de ne jamais converger (comme chercher
|
||||||
|
le fond d'une vallée en faisant des pas de géant). Trop bas, et l'apprentissage sera extrêmement
|
||||||
|
lent.
|
||||||
|
|
||||||
|
4. `batch_size` (Taille du lot)
|
||||||
|
|
||||||
|
- À quoi ça sert ? L'agent apprend de ses expériences passées, qu'il stocke dans une "mémoire". Pour
|
||||||
|
chaque session d'entraînement, au lieu d'apprendre d'une seule expérience, il en pioche un lot
|
||||||
|
(batch) au hasard (ex: 32 expériences). Il calcule l'erreur moyenne sur ce lot pour mettre à jour
|
||||||
|
ses poids.
|
||||||
|
- Intuition : Apprendre sur un lot plutôt que sur une seule expérience rend l'apprentissage plus
|
||||||
|
stable et plus général. L'agent se base sur une "moyenne" de situations plutôt que sur un cas
|
||||||
|
particulier qui pourrait être une anomalie.
|
||||||
|
|
||||||
|
5. `clip_grad` (Plafonnement du gradient / _Gradient Clipping_)
|
||||||
|
- À quoi ça sert ? C'est une sécurité pour éviter le problème des "gradients qui explosent". Parfois,
|
||||||
|
une expérience très inattendue peut produire une erreur de prédiction énorme, ce qui entraîne une
|
||||||
|
correction (un "gradient") démesurément grande. Une telle correction peut anéantir tout ce que le
|
||||||
|
réseau a appris.
|
||||||
|
- Intuition : clip_grad impose une limite. Si la correction à apporter dépasse un certain seuil, elle
|
||||||
|
est ramenée à cette valeur maximale. C'est un garde-fou qui dit : "OK, on a fait une grosse erreur,
|
||||||
|
mais on va corriger calmement, sans tout casser".
|
||||||
|
|
|
||||||
11
justfile
11
justfile
|
|
@ -9,8 +9,8 @@ shell:
|
||||||
runcli:
|
runcli:
|
||||||
RUST_LOG=info cargo run --bin=client_cli
|
RUST_LOG=info cargo run --bin=client_cli
|
||||||
runclibots:
|
runclibots:
|
||||||
cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk
|
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
|
||||||
#cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
|
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
||||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||||
match:
|
match:
|
||||||
cargo build --release --bin=client_cli
|
cargo build --release --bin=client_cli
|
||||||
|
|
@ -28,12 +28,9 @@ trainsimple:
|
||||||
trainbot:
|
trainbot:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
# cargo run --bin=train_dqn_burn # utilise debug (why ?)
|
./bot/scripts/train.sh
|
||||||
cargo build --release --bin=train_dqn_burn
|
|
||||||
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
|
|
||||||
plottrainbot:
|
plottrainbot:
|
||||||
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
|
./bot/scripts/train.sh plot
|
||||||
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
|
|
||||||
debugtrainbot:
|
debugtrainbot:
|
||||||
cargo build --bin=train_dqn_burn
|
cargo build --bin=train_dqn_burn
|
||||||
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,7 @@ impl GameState {
|
||||||
|
|
||||||
/// Get state as a vector (to be used for bot training input) :
|
/// Get state as a vector (to be used for bot training input) :
|
||||||
/// length = 36
|
/// length = 36
|
||||||
|
/// i8 for board positions with negative values for blacks
|
||||||
pub fn to_vec(&self) -> Vec<i8> {
|
pub fn to_vec(&self) -> Vec<i8> {
|
||||||
let state_len = 36;
|
let state_len = 36;
|
||||||
let mut state = Vec::with_capacity(state_len);
|
let mut state = Vec::with_capacity(state_len);
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue