From e4b3092018e2b81486bceb0cea8301e89d45064a Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Sun, 10 Aug 2025 08:39:31 +0200
Subject: [PATCH] train burn-rl with integers

---
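Note: the diff below moves the burn-rl training state from [f32; 36] to
[i8; 36] (filled from GameState::to_vec), replaces execute_action's Result
with a plain f32 reward that uses the ERROR_REWARD sentinel for invalid
moves, counts valid moves per episode, and trains against the random
strategy instead of the default one. The state tensor still goes through
Tensor::from_floats, which promotes the integer features to floats on the
backend. A minimal sketch of an explicit conversion, assuming burn's Tensor
API as already used in this diff (to_float_tensor is an illustrative
helper, not part of the patch):

    use burn::tensor::{backend::Backend, Tensor};

    fn to_float_tensor<B: Backend>(data: [i8; 36]) -> Tensor<B, 1> {
        // promote the integer features to f32 before building the tensor
        let floats: [f32; 36] = data.map(|v| v as f32);
        Tensor::from_floats(floats, &B::Device::default())
    }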
 CLAUDE.md                         | 26 ------------
 bot/src/dqn/burnrl/dqn_model.rs   |  7 +---
 bot/src/dqn/burnrl/environment.rs | 69 ++++++++++++++++++-------------
 bot/src/dqn/burnrl/main.rs        |  4 +-
 justfile                          |  4 +-
 store/src/game.rs                 |  1 +
 6 files changed, 47 insertions(+), 64 deletions(-)
 delete mode 100644 CLAUDE.md

diff --git a/CLAUDE.md b/CLAUDE.md
deleted file mode 100644
index bdbc72d..0000000
--- a/CLAUDE.md
+++ /dev/null
@@ -1,26 +0,0 @@
-# Trictrac Project Guidelines
-
-## Build & Run Commands
-- Build: `cargo build`
-- Test: `cargo test`
-- Test specific: `cargo test -- test_name`
-- Lint: `cargo clippy`
-- Format: `cargo fmt`
-- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
-- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
-- Build Python lib: `maturin build -m store/Cargo.toml --release`
-
-## Code Style
-- Use Rust 2021 edition idioms
-- Error handling: Use Result pattern with custom Error types
-- Naming: snake_case for functions/variables, CamelCase for types
-- Imports: Group standard lib, external crates, then internal modules
-- Module structure: Prefer small, focused modules with clear responsibilities
-- Documentation: Document public APIs with doc comments
-- Testing: Write unit tests in same file as implementation
-- Python bindings: Use pyo3 for creating Python modules
-
-## Architecture
-- Core game logic in `store` crate
-- Multiple clients: CLI, TUI, Bevy (graphical)
-- Bot interfaces in `bot` crate
\ No newline at end of file

diff --git a/bot/src/dqn/burnrl/dqn_model.rs b/bot/src/dqn/burnrl/dqn_model.rs
index 0c333b0..2dd696f 100644
--- a/bot/src/dqn/burnrl/dqn_model.rs
+++ b/bot/src/dqn/burnrl/dqn_model.rs
@@ -150,11 +150,8 @@ pub fn run(
             episode_done = true;
 
             println!(
-                "{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
-                episode,
-                episode_reward,
-                episode_duration,
-                now.elapsed().unwrap().as_secs()
+                "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
+                now.elapsed().unwrap().as_secs(),
             );
             now = SystemTime::now();
         } else {
diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs
index d5e0028..dd8b09f 100644
--- a/bot/src/dqn/burnrl/environment.rs
+++ b/bot/src/dqn/burnrl/environment.rs
@@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
 /// Trictrac game state for burn-rl
 #[derive(Debug, Clone, Copy)]
 pub struct TrictracState {
-    pub data: [f32; 36], // Vector representation of the game state
+    pub data: [i8; 36], // Vector representation of the game state
 }
 
 impl State for TrictracState {
-    type Data = [f32; 36];
+    type Data = [i8; 36];
 
     fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
         Tensor::from_floats(self.data, &B::Device::default())
@@ -25,8 +25,8 @@ impl TrictracState {
 
     /// Converts a GameState into a TrictracState
    pub fn from_game_state(game_state: &GameState) -> Self {
-        let state_vec = game_state.to_vec_float();
-        let mut data = [0.0; 36];
+        let state_vec = game_state.to_vec();
+        let mut data = [0; 36];
 
         // Copy the data, making sure we do not exceed the array size
         let copy_len = state_vec.len().min(36);
@@ -39,6 +39,7 @@ impl TrictracState {
 
 /// Possible Trictrac actions for burn-rl
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct TrictracAction {
+    // u32 as required by burn_rl::base::Action type
     pub index: u32,
 }
 
@@ -82,7 +83,8 @@ pub struct TrictracEnvironment {
     opponent_id: PlayerId,
     current_state: TrictracState,
     episode_reward: f32,
-    step_count: usize,
+    pub step_count: usize,
+    pub goodmoves_count: usize,
     pub visualized: bool,
 }
 
@@ -91,7 +93,7 @@ impl Environment for TrictracEnvironment {
     type ActionType = TrictracAction;
     type RewardType = f32;
 
-    const MAX_STEPS: usize = 700; // Upper limit to avoid endless games
+    const MAX_STEPS: usize = 600; // Upper limit to avoid endless games
 
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -113,6 +115,7 @@
             current_state,
             episode_reward: 0.0,
             step_count: 0,
+            goodmoves_count: 0,
             visualized,
         }
     }
@@ -132,7 +135,13 @@
 
         self.current_state = TrictracState::from_game_state(&self.game);
         self.episode_reward = 0.0;
+        println!(
+            "correct moves: {} ({}%)",
+            self.goodmoves_count,
+            100 * self.goodmoves_count / self.step_count.max(1)
+        );
         self.step_count = 0;
+        self.goodmoves_count = 0;
 
         Snapshot::new(self.current_state, 0.0, false)
     }
@@ -149,14 +158,9 @@
         // Execute the action if it is the DQN agent's turn
         if self.game.active_player_id == self.active_player_id {
             if let Some(action) = trictrac_action {
-                match self.execute_action(action) {
-                    Ok(action_reward) => {
-                        reward = action_reward;
-                    }
-                    Err(_) => {
-                        // Invalid action, penalty
-                        reward = -1.0;
-                    }
+                reward = self.execute_action(action);
+                if reward != Self::ERROR_REWARD {
+                    self.goodmoves_count += 1;
                 }
             } else {
                 // Non-convertible action, penalty
@@ -202,6 +206,9 @@
 }
 
 impl TrictracEnvironment {
+    const ERROR_REWARD: f32 = -1.12121;
+    const REWARD_RATIO: f32 = 1.0;
+
     /// Converts a burn-rl action into a Trictrac action
     pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
         dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
@@ -228,10 +235,11 @@
     }
 
     /// Executes a Trictrac action in the game
-    fn execute_action(
-        &mut self,
-        action: dqn_common::TrictracAction,
-    ) -> Result<f32, Box<dyn std::error::Error>> {
+    // fn execute_action(
+    //     &mut self,
+    //     action: dqn_common::TrictracAction,
+    // ) -> Result<f32, Box<dyn std::error::Error>> {
+    fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
         use dqn_common::TrictracAction;
 
         let mut reward = 0.0;
@@ -310,16 +318,22 @@
                     if self.game.validate(&dice_event) {
                         self.game.consume(&dice_event);
                         let (points, adv_points) = self.game.dice_points;
-                        reward += 0.3 * (points - adv_points) as f32; // Reward proportional to the points scored
+                        reward += Self::REWARD_RATIO * (points - adv_points) as f32;
+                        if points > 0 {
+                            println!("rolled for {reward}");
+                        }
+                        // Reward proportional to the points scored
                     }
                 }
             } else {
                 // Penalty for an invalid action
-                reward -= 2.0;
+                // cancel the earlier rewards
+                // and use a recognizable value for statistics
+                reward = Self::ERROR_REWARD;
             }
         }
 
-        Ok(reward)
+        reward
     }
 
     /// Makes the opponent play with a simple strategy
@@ -329,15 +343,14 @@
         // If it is the opponent's turn, play automatically
         if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
            // Use the default strategy for the opponent
-            use crate::strategy::default::DefaultStrategy;
             use crate::BotStrategy;
 
-            let mut default_strategy = DefaultStrategy::default();
-            default_strategy.set_player_id(self.opponent_id);
+            let mut strategy = crate::strategy::random::RandomStrategy::default();
+            strategy.set_player_id(self.opponent_id);
             if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
-                default_strategy.set_color(color);
+                strategy.set_color(color);
             }
-            *default_strategy.get_mut_game() = self.game.clone();
+            *strategy.get_mut_game() = self.game.clone();
 
             // Execute the action according to turn_stage
             let event = match self.game.turn_stage {
@@ -365,7 +378,7 @@
                 let points_rules =
                     PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
                 let (points, adv_points) = points_rules.get_points(dice_roll_count);
-                reward -= 0.3 * (points - adv_points) as f32; // Reward proportional to the points scored
+                reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Reward proportional to the points scored
 
                 GameEvent::Mark {
                     player_id: self.opponent_id,
@@ -397,7 +410,7 @@
             }
             TurnStage::Move => GameEvent::Move {
                 player_id: self.opponent_id,
-                moves: default_strategy.choose_move(),
+                moves: strategy.choose_move(),
             },
         };
 
diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs
index 4b3a789..fcc513a 100644
--- a/bot/src/dqn/burnrl/main.rs
+++ b/bot/src/dqn/burnrl/main.rs
@@ -14,11 +14,11 @@ fn main() {
     let conf = dqn_model::DqnConfig {
         num_episodes: 40,
         // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
-        // max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
+        // max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
         dense_size: 256, // neural network complexity
         eps_start: 0.9,  // epsilon initial value (0.9 => more exploration)
         eps_end: 0.05,
-        eps_decay: 3000.0,
+        eps_decay: 1500.0,
     };
 
     let agent = dqn_model::run::(&conf, false); //true);
diff --git a/justfile b/justfile
index dcb5117..6570cb1 100644
--- a/justfile
+++ b/justfile
@@ -28,12 +28,10 @@ trainsimple:
 trainbot:
     #python ./store/python/trainModel.py
     # cargo run --bin=train_dqn # ok
-    # cargo run --bin=train_dqn_burn # uses debug (why?)
     cargo build --release --bin=train_dqn_burn
     LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
 plottrainbot:
-    cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
-    #tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
+    cat /tmp/train.out | grep -v rolled | grep -v correct | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid --title 'adv = random ; density = 256 ; err_reward = -1 ; reward_ratio = 1 ; decay = 1500 ; max steps = 600' --terminal png > doc/trainbots_stats/train_random_256_1_1_1500_600.png
 debugtrainbot:
     cargo build --bin=train_dqn_burn
     RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
diff --git a/store/src/game.rs b/store/src/game.rs
index 200c321..2b7fa46 100644
--- a/store/src/game.rs
+++ b/store/src/game.rs
@@ -151,6 +151,7 @@ impl GameState {
 
     /// Get state as a vector (to be used for bot training input):
     /// length = 36
+    /// i8 for board positions with negative values for black
     pub fn to_vec(&self) -> Vec<i8> {
         let state_len = 36;
         let mut state = Vec::with_capacity(state_len);
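
Usage note (outside the patch, values in the sample line are made up): with
the justfile recipes above, a training and plotting run looks like:

    just trainbot       # release build, run train_dqn_burn, tee the log to /tmp/train.out
    just plottrainbot   # drop the "rolled"/"correct" debug lines, plot reward per episode as PNG

The awk -F '[ ,]' '{print $5}' in plottrainbot picks the reward value out of
the per-episode JSON line printed by dqn_model.rs, e.g.:

    {"episode": 3, "reward": -12.1212, "steps count": 600, "threshold": 0.8123, "duration": 9}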