From 12004ec4f38e5ddfc1d98d427ce2f53eef94e2aa Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 4 Aug 2025 18:04:40 +0200 Subject: [PATCH 1/7] wip bot mirror --- Cargo.lock | 2 + bot/Cargo.toml | 1 + bot/src/lib.rs | 88 ++++++--- bot/src/strategy/default.rs | 4 +- client_cli/Cargo.toml | 1 + client_cli/src/game_runner.rs | 15 +- justfile | 4 +- store/src/board.rs | 47 +++-- store/src/game.rs | 83 +++++++- store/src/game_rules_moves.rs | 336 ++++++++++++++++++++++----------- store/src/game_rules_points.rs | 329 ++++++++++++++++++++++---------- 11 files changed, 656 insertions(+), 254 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2ba864f..d504e2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,7 @@ dependencies = [ "burn", "burn-rl", "env_logger 0.10.0", + "log", "pretty_assertions", "rand 0.8.5", "serde", @@ -881,6 +882,7 @@ dependencies = [ "bot", "env_logger 0.11.6", "itertools 0.13.0", + "log", "pico-args", "pretty_assertions", "renet", diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 3fd08c4..a5667fa 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -22,3 +22,4 @@ rand = "0.8" env_logger = "0.10" burn = { version = "0.17", features = ["ndarray", "autodiff"] } burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" } +log = "0.4.20" diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 65424fc..6326253 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,6 +1,7 @@ pub mod dqn; pub mod strategy; +use log::{error, info}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; @@ -26,7 +27,7 @@ pub trait BotStrategy: std::fmt::Debug { pub struct Bot { pub player_id: PlayerId, strategy: Box, - // color: Color, + color: Color, // schools_enabled: bool, } @@ -34,9 +35,9 @@ impl Default for Bot { fn default() -> Self { let strategy = DefaultStrategy::default(); Self { - player_id: 2, + player_id: 1, strategy: Box::new(strategy), - // color: Color::Black, + color: Color::White, // schools_enabled: false, } } @@ -52,57 +53,86 @@ impl Bot { Color::White => 1, Color::Black => 2, }; - strategy.set_player_id(player_id); - strategy.set_color(color); + // strategy.set_player_id(player_id); + // strategy.set_color(color); Self { player_id, strategy, - // color, + color, // schools_enabled: false, } } pub fn handle_event(&mut self, event: &GameEvent) -> Option { + info!(">>>> {:?} BOT handle", self.color); let game = self.strategy.get_mut_game(); - game.consume(event); + let internal_event = if self.color == Color::Black { + &event.get_mirror() + } else { + event + }; + + let init_player_points = game.who_plays().map(|p| (p.points, p.holes)); + let turn_stage = game.turn_stage; + game.consume(internal_event); if game.stage == Stage::Ended { + info!("<<<< end {:?} BOT handle", self.color); return None; } - if game.active_player_id == self.player_id { - return match game.turn_stage { + let active_player_id = if self.color == Color::Black { + if game.active_player_id == 1 { + 2 + } else { + 1 + } + } else { + game.active_player_id + }; + if active_player_id == self.player_id { + let player_points = game.who_plays().map(|p| (p.points, p.holes)); + if self.color == Color::Black { + info!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); + } + let internal_event = match game.turn_stage { TurnStage::MarkAdvPoints => Some(GameEvent::Mark { - player_id: self.player_id, + player_id: 1, points: self.strategy.calculate_adv_points(), }), - TurnStage::RollDice => Some(GameEvent::Roll { - player_id: self.player_id, - }), + TurnStage::RollDice => Some(GameEvent::Roll { player_id: 1 }), TurnStage::MarkPoints => Some(GameEvent::Mark { - player_id: self.player_id, + player_id: 1, points: self.strategy.calculate_points(), }), TurnStage::Move => Some(GameEvent::Move { - player_id: self.player_id, + player_id: 1, moves: self.strategy.choose_move(), }), TurnStage::HoldOrGoChoice => { if self.strategy.choose_go() { - Some(GameEvent::Go { - player_id: self.player_id, - }) + Some(GameEvent::Go { player_id: 1 }) } else { Some(GameEvent::Move { - player_id: self.player_id, + player_id: 1, moves: self.strategy.choose_move(), }) } } _ => None, }; + return if self.color == Color::Black { + info!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); + info!("<<<< end {:?} BOT handle", self.color); + internal_event.map(|evt| evt.get_mirror()) + } else { + info!("<<<< end {:?} BOT handle", self.color); + internal_event + }; } + info!("<<<< end {:?} BOT handle", self.color); None } + // Only used in tests below pub fn get_state(&self) -> &GameState { self.strategy.get_game() } @@ -121,17 +151,31 @@ mod tests { } #[test] - fn test_consume() { + fn test_handle_event() { let mut bot = Bot::new(Box::new(DefaultStrategy::default()), Color::Black); // let mut bot = Bot::new(Box::new(DefaultStrategy::default()), Color::Black, false); let mut event = bot.handle_event(&GameEvent::BeginGame { goes_first: 2 }); assert_eq!(event, Some(GameEvent::Roll { player_id: 2 })); - assert_eq!(bot.get_state().active_player_id, 2); + assert_eq!(bot.get_state().active_player_id, 1); // bot internal active_player_id for black + event = bot.handle_event(&GameEvent::RollResult { + player_id: 2, + dice: Dice { values: (2, 3) }, + }); + assert_eq!( + event, + Some(GameEvent::Move { + player_id: 2, + moves: ( + CheckerMove::new(24, 21).unwrap(), + CheckerMove::new(24, 22).unwrap() + ) + }) + ); event = bot.handle_event(&GameEvent::BeginGame { goes_first: 1 }); assert_eq!(event, None); - assert_eq!(bot.get_state().active_player_id, 1); + assert_eq!(bot.get_state().active_player_id, 2); //internal active_player_id bot.handle_event(&GameEvent::RollResult { player_id: 1, dice: Dice { values: (2, 3) }, diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index e01f406..628ce83 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -13,8 +13,8 @@ impl Default for DefaultStrategy { let game = GameState::default(); Self { game, - player_id: 2, - color: Color::Black, + player_id: 1, + color: Color::White, } } } diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index 4dcd86f..6c1d4e1 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -15,3 +15,4 @@ store = { path = "../store" } bot = { path = "../bot" } itertools = "0.13.0" env_logger = "0.11.6" +log = "0.4.20" diff --git a/client_cli/src/game_runner.rs b/client_cli/src/game_runner.rs index 9944918..296c907 100644 --- a/client_cli/src/game_runner.rs +++ b/client_cli/src/game_runner.rs @@ -1,4 +1,5 @@ use bot::{Bot, BotStrategy}; +use log::{error, info}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; // Application Game @@ -62,11 +63,21 @@ impl GameRunner { return None; } let valid_event = if self.state.validate(event) { + info!( + "--------------- new valid event {event:?} (stage {:?}) -----------", + self.state.turn_stage + ); self.state.consume(event); + info!( + " --> stage {:?} ; active player points {:?}", + self.state.turn_stage, + self.state.who_plays().map(|p| p.points) + ); event } else { - println!("{}", self.state); - println!("event not valid : {:?}", event); + info!("{}", self.state); + error!("event not valid : {event:?}"); + panic!("crash and burn"); &GameEvent::PlayError }; diff --git a/justfile b/justfile index d4f14c4..16f56ce 100644 --- a/justfile +++ b/justfile @@ -9,8 +9,8 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - #RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy - RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn + RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy + # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn diff --git a/store/src/board.rs b/store/src/board.rs index 646e929..3e563d0 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -114,7 +114,7 @@ impl fmt::Display for Board { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut s = String::new(); s.push_str(&format!("{:?}", self.positions)); - write!(f, "{}", s) + write!(f, "{s}") } } @@ -132,8 +132,13 @@ impl Board { } /// Globally set pieces on board ( for tests ) - pub fn set_positions(&mut self, positions: [i8; 24]) { - self.positions = positions; + pub fn set_positions(&mut self, color: &Color, positions: [i8; 24]) { + let mut new_positions = positions; + if color == &Color::Black { + new_positions = new_positions.map(|c| 0 - c); + new_positions.reverse(); + } + self.positions = new_positions; } pub fn count_checkers(&self, color: Color, from: Field, to: Field) -> u8 { @@ -672,9 +677,12 @@ mod tests { #[test] fn is_quarter_fillable() { let mut board = Board::new(); - board.set_positions([ - 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, - ]); + board.set_positions( + &Color::White, + [ + 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, + ], + ); assert!(board.is_quarter_fillable(Color::Black, 1)); assert!(!board.is_quarter_fillable(Color::Black, 12)); assert!(board.is_quarter_fillable(Color::Black, 13)); @@ -683,25 +691,34 @@ mod tests { assert!(board.is_quarter_fillable(Color::White, 12)); assert!(!board.is_quarter_fillable(Color::White, 13)); assert!(board.is_quarter_fillable(Color::White, 24)); - board.set_positions([ - 5, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -8, 0, 0, 0, 0, 0, -5, - ]); + board.set_positions( + &Color::White, + [ + 5, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -8, 0, 0, 0, 0, 0, -5, + ], + ); assert!(board.is_quarter_fillable(Color::Black, 13)); assert!(!board.is_quarter_fillable(Color::Black, 24)); assert!(!board.is_quarter_fillable(Color::White, 1)); assert!(board.is_quarter_fillable(Color::White, 12)); - board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, - ]); + board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, + ], + ); assert!(board.is_quarter_fillable(Color::Black, 16)); } #[test] fn get_quarter_filling_candidate() { let mut board = Board::new(); - board.set_positions([ - 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + board.set_positions( + &Color::White, + [ + 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White)); } } diff --git a/store/src/game.rs b/store/src/game.rs index d500342..c9995b8 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -91,7 +91,8 @@ impl fmt::Display for GameState { s.push_str(&format!("Dice: {:?}\n", self.dice)); // s.push_str(&format!("Who plays: {}\n", self.who_plays().map(|player| &player.name ).unwrap_or(""))); s.push_str(&format!("Board: {:?}\n", self.board)); - write!(f, "{}", s) + // s.push_str(&format!("History: {:?}\n", self.history)); + write!(f, "{s}") } } @@ -372,22 +373,30 @@ impl GameState { } Go { player_id } => { if !self.players.contains_key(player_id) { - error!("Player {} unknown", player_id); + error!("Player {player_id} unknown"); return false; } // Check player is currently the one making their move if self.active_player_id != *player_id { + error!("Player not active : {}", self.active_player_id); return false; } // Check the player can leave (ie the game is in the KeepOrLeaveChoice stage) if self.turn_stage != TurnStage::HoldOrGoChoice { + error!("bad stage {:?}", self.turn_stage); + error!( + "black player points : {:?}", + self.get_black_player() + .map(|player| (player.points, player.holes)) + ); + // error!("history {:?}", self.history); return false; } } Move { player_id, moves } => { // Check player exists if !self.players.contains_key(player_id) { - error!("Player {} unknown", player_id); + error!("Player {player_id} unknown"); return false; } // Check player is currently the one making their move @@ -512,12 +521,15 @@ impl GameState { self.inc_roll_count(self.active_player_id); self.turn_stage = TurnStage::MarkPoints; (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); + info!("points from result : {:?}", self.dice_points); if !self.schools_enabled { // Schools are not enabled. We mark points automatically // the points earned by the opponent will be marked on its turn let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); if new_hole { - if self.get_active_player().unwrap().holes > 12 { + let holes_count = self.get_active_player().unwrap().holes; + info!("new hole -> {holes_count:?}"); + if holes_count > 12 { self.stage = Stage::Ended; } else { self.turn_stage = TurnStage::HoldOrGoChoice; @@ -594,6 +606,10 @@ impl GameState { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { let player = &self.players.get(&self.active_player_id).unwrap(); + info!( + "get rollresult for {:?} {:?} {:?} (roll count {:?})", + player.color, self.board, dice, player.dice_roll_count + ); let points_rules = PointsRules::new(&player.color, &self.board, *dice); points_rules.get_result_jans(player.dice_roll_count) } @@ -636,10 +652,11 @@ impl GameState { p.points = sum_points % 12; p.holes += holes; - if points > 0 && p.holes > 15 { + // if points > 0 && p.holes > 15 { + if points > 0 { info!( - "player {:?} holes : {:?} added points : {:?}", - player_id, p.holes, points + "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})", + p.holes, p.points ) } p @@ -733,6 +750,58 @@ impl GameEvent { _ => None, } } + + pub fn get_mirror(&self) -> Self { + // let mut mirror = self.clone(); + let mirror_player_id = if let Some(player_id) = self.player_id() { + if player_id == 1 { + 2 + } else { + 1 + } + } else { + 0 + }; + + match self { + Self::PlayerJoined { player_id: _, name } => Self::PlayerJoined { + player_id: mirror_player_id, + name: name.clone(), + }, + Self::PlayerDisconnected { player_id: _ } => GameEvent::PlayerDisconnected { + player_id: mirror_player_id, + }, + Self::Roll { player_id: _ } => GameEvent::Roll { + player_id: mirror_player_id, + }, + Self::RollResult { player_id: _, dice } => GameEvent::RollResult { + player_id: mirror_player_id, + dice: *dice, + }, + Self::Mark { + player_id: _, + points, + } => GameEvent::Mark { + player_id: mirror_player_id, + points: *points, + }, + Self::Go { player_id: _ } => GameEvent::Go { + player_id: mirror_player_id, + }, + Self::Move { + player_id: _, + moves: (move1, move2), + } => Self::Move { + player_id: mirror_player_id, + moves: (move1.mirror(), move2.mirror()), + }, + Self::BeginGame { goes_first } => GameEvent::BeginGame { + goes_first: (if *goes_first == 1 { 2 } else { 1 }), + }, + Self::EndGame { reason } => GameEvent::EndGame { reason: *reason }, + Self::PlayError => GameEvent::PlayError, + } + } } #[cfg(test)] diff --git a/store/src/game_rules_moves.rs b/store/src/game_rules_moves.rs index 17e572e..31c43fa 100644 --- a/store/src/game_rules_moves.rs +++ b/store/src/game_rules_moves.rs @@ -625,18 +625,24 @@ mod tests { #[test] fn can_take_corner_by_effect() { let mut rules = MoveRules::default(); - rules.board.set_positions([ - 10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, - ]); + rules.board.set_positions( + &Color::White, + [ + 10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, + ], + ); rules.dice.values = (4, 4); assert!(rules.can_take_corner_by_effect()); rules.dice.values = (5, 5); assert!(!rules.can_take_corner_by_effect()); - rules.board.set_positions([ - 10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, - ]); + rules.board.set_positions( + &Color::White, + [ + 10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, + ], + ); rules.dice.values = (4, 4); assert!(!rules.can_take_corner_by_effect()); } @@ -645,9 +651,12 @@ mod tests { fn prise_en_puissance() { let mut state = MoveRules::default(); // prise par puissance ok - state.board.set_positions([ - 10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, - ]); + state.board.set_positions( + &Color::White, + [ + 10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(8, 12).unwrap(), @@ -658,25 +667,34 @@ mod tests { assert!(state.moves_allowed(&moves).is_ok()); // opponent corner must be empty - state.board.set_positions([ - 10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13, - ]); + state.board.set_positions( + &Color::White, + [ + 10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13, + ], + ); assert!(!state.is_move_by_puissance(&moves)); assert!(!state.moves_follows_dices(&moves)); // Si on a la possibilité de prendre son coin à la fois par effet, c'est à dire naturellement, et aussi par puissance, on doit le prendre par effet - state.board.set_positions([ - 5, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, - ]); + state.board.set_positions( + &Color::White, + [ + 5, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, + ], + ); assert_eq!( Err(MoveError::CornerByEffectPossible), state.moves_allowed(&moves) ); // on a déjà pris son coin : on ne peux plus y deplacer des dames par puissance - state.board.set_positions([ - 8, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, - ]); + state.board.set_positions( + &Color::White, + [ + 8, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, + ], + ); assert!(!state.is_move_by_puissance(&moves)); assert!(!state.moves_follows_dices(&moves)); } @@ -685,9 +703,12 @@ mod tests { fn exit() { let mut state = MoveRules::default(); // exit ok - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(20, 0).unwrap(), @@ -697,9 +718,12 @@ mod tests { assert!(state.moves_allowed(&moves).is_ok()); // toutes les dames doivent être dans le jan de retour - state.board.set_positions([ - 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(20, 0).unwrap(), @@ -711,9 +735,12 @@ mod tests { ); // on ne peut pas sortir une dame avec un nombre excédant si on peut en jouer une avec un nombre défaillant - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 2, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 2, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(20, 0).unwrap(), @@ -725,9 +752,12 @@ mod tests { ); // on doit jouer le nombre excédant le plus éloigné - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(20, 0).unwrap(), @@ -741,9 +771,12 @@ mod tests { assert!(state.moves_allowed(&moves).is_ok()); // Cas de la dernière dame - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(23, 0).unwrap(), @@ -756,9 +789,12 @@ mod tests { #[test] fn move_check_opponent_fillable_quarter() { let mut state = MoveRules::default(); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(11, 16).unwrap(), @@ -766,9 +802,12 @@ mod tests { ); assert!(state.moves_allowed(&moves).is_ok()); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(11, 16).unwrap(), @@ -779,9 +818,12 @@ mod tests { state.moves_allowed(&moves) ); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(11, 16).unwrap(), @@ -789,9 +831,12 @@ mod tests { ); assert!(state.moves_allowed(&moves).is_ok()); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 1, -12, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 1, -12, + ], + ); state.dice.values = (5, 5); let moves = ( CheckerMove::new(11, 16).unwrap(), @@ -806,9 +851,12 @@ mod tests { #[test] fn move_check_fillable_quarter() { let mut state = MoveRules::default(); - state.board.set_positions([ - 3, 3, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 3, 3, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, + ], + ); state.dice.values = (5, 4); let moves = ( CheckerMove::new(1, 6).unwrap(), @@ -821,9 +869,12 @@ mod tests { ); assert_eq!(Err(MoveError::MustFillQuarter), state.moves_allowed(&moves)); - state.board.set_positions([ - 2, 3, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 2, 3, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 3); let moves = ( CheckerMove::new(6, 8).unwrap(), @@ -840,9 +891,12 @@ mod tests { #[test] fn move_play_all_dice() { let mut state = MoveRules::default(); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + ], + ); state.dice.values = (1, 3); let moves = ( CheckerMove::new(22, 0).unwrap(), @@ -861,9 +915,12 @@ mod tests { fn move_opponent_rest_corner_rules() { // fill with 2 checkers : forbidden let mut state = MoveRules::default(); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (1, 1); let moves = ( CheckerMove::new(12, 13).unwrap(), @@ -891,9 +948,12 @@ mod tests { fn move_rest_corner_enter() { // direct let mut state = MoveRules::default(); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let moves = ( CheckerMove::new(10, 12).unwrap(), @@ -915,9 +975,12 @@ mod tests { #[test] fn move_rest_corner_blocked() { let mut state = MoveRules::default(); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let moves = ( CheckerMove::new(0, 0).unwrap(), @@ -926,9 +989,12 @@ mod tests { assert!(state.moves_follows_dices(&moves)); assert!(state.moves_allowed(&moves).is_ok()); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + ], + ); state.dice.values = (2, 1); let moves = ( CheckerMove::new(23, 24).unwrap(), @@ -949,9 +1015,12 @@ mod tests { #[test] fn move_rest_corner_exit() { let mut state = MoveRules::default(); - state.board.set_positions([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 3); let moves = ( CheckerMove::new(12, 14).unwrap(), @@ -967,9 +1036,12 @@ mod tests { fn move_rest_corner_toutdune() { let mut state = MoveRules::default(); // We can't go to the occupied rest corner as an intermediary step - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let moves = ( CheckerMove::new(11, 13).unwrap(), @@ -978,9 +1050,12 @@ mod tests { assert!(!state.moves_possible(&moves)); // We can use the empty rest corner as an intermediary step - state.board.set_positions([ - 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, - ]); + state.board.set_positions( + &Color::White, + [ + 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, + ], + ); state.dice.values = (6, 5); let moves = ( CheckerMove::new(8, 13).unwrap(), @@ -994,9 +1069,12 @@ mod tests { #[test] fn move_play_stronger_dice() { let mut state = MoveRules::default(); - state.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, -1, -1, -1, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, -1, -1, -1, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 3); let moves = ( CheckerMove::new(12, 14).unwrap(), @@ -1034,9 +1112,12 @@ mod tests { assert!(!state.moves_possible(&moves)); // Can't move the same checker twice - state.board.set_positions([ - 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let moves = ( CheckerMove::new(3, 5).unwrap(), @@ -1056,9 +1137,12 @@ mod tests { #[test] fn filling_moves_sequences() { let mut state = MoveRules::default(); - state.board.set_positions([ - 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); // println!( @@ -1067,17 +1151,23 @@ mod tests { // ); assert_eq!(2, filling_moves_sequences.len()); - state.board.set_positions([ - 3, 2, 3, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 3, 2, 3, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 2); let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); // println!("{:?}", filling_moves_sequences); assert_eq!(2, filling_moves_sequences.len()); - state.board.set_positions([ - 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); // println!( @@ -1087,9 +1177,12 @@ mod tests { assert_eq!(2, filling_moves_sequences.len()); // positions - state.board.set_positions([ - 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, - ]); + state.board.set_positions( + &Color::White, + [ + 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, + ], + ); state.dice.values = (6, 5); let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); assert_eq!(1, filling_moves_sequences.len()); @@ -1099,19 +1192,46 @@ mod tests { fn scoring_filling_moves_sequences() { let mut state = MoveRules::default(); - state.board.set_positions([ - 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len()); - state.board.set_positions([ - 2, 3, 3, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 2, 3, 3, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 1); let filling_moves_sequences = state.get_scoring_quarter_filling_moves_sequences(); // println!("{:?}", filling_moves_sequences); assert_eq!(3, filling_moves_sequences.len()); + + // preserve filling + state.board.set_positions( + &Color::White, + [ + 2, 2, 2, 2, 2, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -1, -2, -3, -5, 0, -1, + ], + ); + state.dice.values = (3, 1); + assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len()); + + // preserve filling (black) + let mut state = MoveRules::new(&Color::Black, &Board::default(), Dice::default()); + state.board.set_positions( + &Color::Black, + [ + 1, 0, 5, 3, 2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -4, -2, -2, -2, -2, -2, + ], + ); + state.dice.values = (3, 1); + assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len()); } // prise de coin par puissance et conservation de jan #18 @@ -1120,9 +1240,12 @@ mod tests { fn corner_by_effect_and_filled_corner() { let mut state = MoveRules::default(); - state.board.set_positions([ - 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, - ]); + state.board.set_positions( + &Color::White, + [ + 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, + ], + ); state.dice.values = (6, 5); let moves = ( @@ -1155,9 +1278,12 @@ mod tests { fn get_possible_moves_sequences() { let mut state = MoveRules::default(); - state.board.set_positions([ - 2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + state.board.set_positions( + &Color::White, + [ + 2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); state.dice.values = (2, 3); let moves = ( CheckerMove::new(9, 11).unwrap(), diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs index 8656b54..24991eb 100644 --- a/store/src/game_rules_points.rs +++ b/store/src/game_rules_points.rs @@ -5,6 +5,7 @@ use crate::player::Color; use crate::CheckerMove; use crate::Error; +use log::info; use serde::{Deserialize, Serialize}; use std::cmp; use std::collections::HashMap; @@ -158,9 +159,9 @@ impl PointsRules { self.move_rules.dice = dice; } - pub fn update_positions(&mut self, positions: [i8; 24]) { - self.board.set_positions(positions); - self.move_rules.board.set_positions(positions); + pub fn update_positions(&mut self, color: &Color, positions: [i8; 24]) { + self.board.set_positions(color, positions); + self.move_rules.board.set_positions(color, positions); } fn get_jans(&self, board_ini: &Board, dice_rolls_count: u8) -> PossibleJans { @@ -381,6 +382,7 @@ impl PointsRules { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { let jans = self.get_jans(&self.board, dice_rolls_count); + info!("jans : {jans:?}"); let points_jans = jans.clone(); (jans, self.get_jans_points(points_jans)) } @@ -481,9 +483,12 @@ mod tests { #[test] fn get_jans_by_dice_order() { let mut rules = PointsRules::default(); - rules.board.set_positions([ - 2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); assert_eq!(1, jans.len()); @@ -495,9 +500,12 @@ mod tests { // On peut passer par une dame battue pour battre une autre dame // mais pas par une case remplie par l'adversaire - rules.board.set_positions([ - 2, 0, -1, -2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 2, 0, -1, -2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let mut jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); let jans_revert_dices = get_jans_by_ordered_dice(&rules.board, &[3, 2], None, false); @@ -506,25 +514,34 @@ mod tests { jans.merge(jans_revert_dices); assert_eq!(2, jans.get(&Jan::TrueHitSmallJan).unwrap().len()); - rules.board.set_positions([ - 2, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 2, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); assert_eq!(1, jans.len()); assert_eq!(2, jans.get(&Jan::TrueHitSmallJan).unwrap().len()); - rules.board.set_positions([ - 2, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 2, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); assert_eq!(1, jans.len()); assert_eq!(1, jans.get(&Jan::TrueHitSmallJan).unwrap().len()); - rules.board.set_positions([ - 2, 0, 1, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 2, 0, 1, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); assert_eq!(1, jans.len()); @@ -533,25 +550,34 @@ mod tests { // corners handling // deux dés bloqués (coin de repos et coin de l'adversaire) - rules.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); // le premier dé traité est le dernier du vecteur : 1 let jans = get_jans_by_ordered_dice(&rules.board, &[2, 1], None, false); // println!("jans (dés bloqués) : {:?}", jans.get(&Jan::TrueHit)); assert_eq!(0, jans.len()); // dé dans son coin de repos : peut tout de même battre à vrai - rules.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let jans = get_jans_by_ordered_dice(&rules.board, &[3, 3], None, false); assert_eq!(1, jans.len()); // premier dé bloqué, mais tout d'une possible en commençant par le second - rules.board.set_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); let mut jans = get_jans_by_ordered_dice(&rules.board, &[3, 1], None, false); let jans_revert_dices = get_jans_by_ordered_dice(&rules.board, &[1, 3], None, false); assert_eq!(1, jans_revert_dices.len()); @@ -569,169 +595,274 @@ mod tests { // ----- Jan de récompense // Battre à vrai une dame située dans la table des petits jans : 4 + 4 + 4 = 12 let mut rules = PointsRules::default(); - rules.update_positions([ - 2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 3) }); assert_eq!(12, rules.get_points(5).0); + // Calcul des points pour noir + let mut board = Board::new(); + board.set_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2, + ], + ); + let mut rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) }); + assert_eq!(12, rules.get_points(5).0); + // Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4 let mut rules = PointsRules::default(); - rules.update_positions([ - 2, 0, 0, -1, 2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 0, 0, -1, 2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 4) }); assert_eq!(4, rules.get_points(5).0); // Battre à vrai une dame située dans la table des grands jans : 2 let mut rules = PointsRules::default(); - rules.update_positions([ - 2, 0, -2, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 0, -2, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 4) }); assert_eq!((2, 2), rules.get_points(5)); // Battre à vrai le coin adverse par doublet : 6 - rules.update_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 2) }); assert_eq!(6, rules.get_points(5).0); // Cas de battage du coin de repos adverse impossible - rules.update_positions([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (1, 1) }); assert_eq!(0, rules.get_points(5).0); // ---- Jan de remplissage // Faire un petit jan : 4 - rules.update_positions([ - 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 1) }); assert_eq!(1, rules.get_jans(&rules.board, 5).len()); assert_eq!(4, rules.get_points(5).0); // Faire un petit jan avec un doublet : 6 - rules.update_positions([ - 2, 3, 1, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 3, 1, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (1, 1) }); assert_eq!(6, rules.get_points(5).0); // Faire un petit jan avec 2 moyens : 6 + 6 = 12 - rules.update_positions([ - 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (1, 1) }); assert_eq!(12, rules.get_points(5).0); // Conserver un jan avec un doublet : 6 - rules.update_positions([ - 3, 3, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 3, 3, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (1, 1) }); assert_eq!(6, rules.get_points(5).0); + // Conserver un jan + rules.update_positions( + &Color::White, + [ + 2, 2, 2, 2, 2, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -1, -2, -3, -5, 0, -1, + ], + ); + rules.set_dice(Dice { values: (3, 1) }); + assert_eq!((4, 0), rules.get_points(8)); + + // Conserver un jan (black) + let mut board = Board::new(); + board.set_positions( + &Color::Black, + [ + 1, 0, 5, 3, 2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -4, -2, -2, -2, -2, -2, + ], + ); + let rules = PointsRules::new(&Color::Black, &board, Dice { values: (3, 1) }); + assert_eq!((4, 0), rules.get_points(8)); + // ---- Sorties // Sortir toutes ses dames avant l'adversaire (simple) - rules.update_positions([ - 0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - ]); + let mut rules = PointsRules::default(); + rules.update_positions( + &Color::White, + [ + 0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + ], + ); rules.set_dice(Dice { values: (3, 1) }); assert_eq!(4, rules.get_points(5).0); // Sortir toutes ses dames avant l'adversaire (doublet) - rules.update_positions([ - 0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + ], + ); rules.set_dice(Dice { values: (2, 2) }); assert_eq!(6, rules.get_points(5).0); // ---- JANS RARES // Jan de six tables - rules.update_positions([ - 10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 3) }); assert_eq!(0, rules.get_points(5).0); - rules.update_positions([ - 10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 3) }); assert_eq!(4, rules.get_points(3).0); - rules.update_positions([ - 10, 1, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 10, 1, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 3) }); assert_eq!(0, rules.get_points(3).0); - rules.update_positions([ - 10, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 10, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 3) }); assert_eq!(0, rules.get_points(3).0); // Jan de deux tables - rules.update_positions([ - 13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 2) }); assert_eq!(6, rules.get_points(5).0); - rules.update_positions([ - 12, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 12, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 2) }); assert_eq!(0, rules.get_points(5).0); // Contre jan de deux tables - rules.update_positions([ - 13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (2, 2) }); assert_eq!((0, 6), rules.get_points(5)); // Jan de mézéas - rules.update_positions([ - 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (1, 1) }); assert_eq!(6, rules.get_points(5).0); - rules.update_positions([ - 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (1, 2) }); assert_eq!(4, rules.get_points(5).0); // Contre jan de mézéas - rules.update_positions([ - 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, - ]); + rules.update_positions( + &Color::White, + [ + 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, + ], + ); rules.set_dice(Dice { values: (1, 1) }); assert_eq!((0, 6), rules.get_points(5)); // ---- JANS QUI NE PEUT // Battre à faux une dame située dans la table des petits jans let mut rules = PointsRules::default(); - rules.update_positions([ - 2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 3) }); assert_eq!((0, 4), rules.get_points(5)); // Battre à faux une dame située dans la table des grands jans let mut rules = PointsRules::default(); - rules.update_positions([ - 2, 0, -2, -1, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 0, -2, -1, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 4) }); assert_eq!((0, 2), rules.get_points(5)); // Pour chaque dé non jouable (dame impuissante) let mut rules = PointsRules::default(); - rules.update_positions([ - 2, 0, -2, -2, -2, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]); + rules.update_positions( + &Color::White, + [ + 2, 0, -2, -2, -2, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ); rules.set_dice(Dice { values: (2, 4) }); assert_eq!((0, 4), rules.get_points(5)); } From dc80243a1add05ccff52cf727f8700b17a67b376 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 7 Aug 2025 20:42:59 +0200 Subject: [PATCH 2/7] fix black moves --- store/src/game_rules_points.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs index 24991eb..ab67236 100644 --- a/store/src/game_rules_points.rs +++ b/store/src/game_rules_points.rs @@ -144,7 +144,9 @@ impl PointsRules { } else { board.clone() }; - let move_rules = MoveRules::new(color, &board, dice); + // the board is already reverted for black, so we pretend color is white + let move_rules = MoveRules::new(&Color::White, &board, dice); + // let move_rules = MoveRules::new(color, &board, dice); // let move_rules = MoveRules::new(color, &self.board, dice, moves); Self { @@ -590,6 +592,20 @@ mod tests { // à vrai } + #[test] + fn get_result_jans() { + let mut board = Board::new(); + board.set_positions( + &Color::White, + [ + 0, 0, 5, 2, 4, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -2, -2, -2, -2, -2, -2, + ], + ); + let points_rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 4) }); + let jans = points_rules.get_result_jans(8); + assert!(jans.0.len() > 0); + } + #[test] fn get_points() { // ----- Jan de récompense @@ -711,7 +727,7 @@ mod tests { // Conserver un jan (black) let mut board = Board::new(); board.set_positions( - &Color::Black, + &Color::White, [ 1, 0, 5, 3, 2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -4, -2, -2, -2, -2, -2, ], From b02ce8d185b57cdcd6cf948213aed5b5efca4557 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 7 Aug 2025 21:01:40 +0200 Subject: [PATCH 3/7] fix dqn strategy color --- bot/src/strategy/dqn.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index af08341..0248cc5 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -19,8 +19,8 @@ impl Default for DqnStrategy { fn default() -> Self { Self { game: GameState::default(), - player_id: 2, - color: Color::Black, + player_id: 1, + color: Color::White, model: None, } } From bf820ecc4e081bbecc7f29aa910562eeeba97c5e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 16:24:12 +0200 Subject: [PATCH 4/7] feat: bot random strategy --- bot/src/lib.rs | 17 +++++---- bot/src/strategy/dqn.rs | 6 ++- bot/src/strategy/mod.rs | 1 + bot/src/strategy/random.rs | 67 ++++++++++++++++++++++++++++++++++ client_cli/src/app.rs | 11 ++++-- client_cli/src/game_runner.rs | 8 ++-- justfile | 5 ++- store/src/board.rs | 4 +- store/src/game.rs | 10 ++--- store/src/game_rules_points.rs | 4 +- 10 files changed, 106 insertions(+), 27 deletions(-) create mode 100644 bot/src/strategy/random.rs diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 6326253..ca338e1 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,11 +1,12 @@ pub mod dqn; pub mod strategy; -use log::{error, info}; +use log::{debug, error}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; +pub use strategy::random::RandomStrategy; pub use strategy::stable_baselines3::StableBaselines3Strategy; pub trait BotStrategy: std::fmt::Debug { @@ -64,7 +65,7 @@ impl Bot { } pub fn handle_event(&mut self, event: &GameEvent) -> Option { - info!(">>>> {:?} BOT handle", self.color); + debug!(">>>> {:?} BOT handle", self.color); let game = self.strategy.get_mut_game(); let internal_event = if self.color == Color::Black { &event.get_mirror() @@ -76,7 +77,7 @@ impl Bot { let turn_stage = game.turn_stage; game.consume(internal_event); if game.stage == Stage::Ended { - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); return None; } let active_player_id = if self.color == Color::Black { @@ -91,7 +92,7 @@ impl Bot { if active_player_id == self.player_id { let player_points = game.who_plays().map(|p| (p.points, p.holes)); if self.color == Color::Black { - info!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); + debug!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); } let internal_event = match game.turn_stage { TurnStage::MarkAdvPoints => Some(GameEvent::Mark { @@ -120,15 +121,15 @@ impl Bot { _ => None, }; return if self.color == Color::Black { - info!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); - info!("<<<< end {:?} BOT handle", self.color); + debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); + debug!("<<<< end {:?} BOT handle", self.color); internal_event.map(|evt| evt.get_mirror()) } else { - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); internal_event }; } - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); None } diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 0248cc5..109a9cf 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,4 +1,5 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use log::info; use std::path::Path; use store::MoveRules; @@ -31,9 +32,10 @@ impl DqnStrategy { Self::default() } - pub fn new_with_model>(model_path: P) -> Self { + pub fn new_with_model + std::fmt::Debug>(model_path: P) -> Self { let mut strategy = Self::new(); - if let Ok(model) = SimpleNeuralNetwork::load(model_path) { + if let Ok(model) = SimpleNeuralNetwork::load(&model_path) { + info!("Loading model {model_path:?}"); strategy.model = Some(model); } strategy diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index 3812188..731d1b1 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -2,4 +2,5 @@ pub mod client; pub mod default; pub mod dqn; pub mod erroneous_moves; +pub mod random; pub mod stable_baselines3; diff --git a/bot/src/strategy/random.rs b/bot/src/strategy/random.rs new file mode 100644 index 0000000..0bfd1c6 --- /dev/null +++ b/bot/src/strategy/random.rs @@ -0,0 +1,67 @@ +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use store::MoveRules; + +#[derive(Debug)] +pub struct RandomStrategy { + pub game: GameState, + pub player_id: PlayerId, + pub color: Color, +} + +impl Default for RandomStrategy { + fn default() -> Self { + let game = GameState::default(); + Self { + game, + player_id: 1, + color: Color::White, + } + } +} + +impl BotStrategy for RandomStrategy { + fn get_game(&self) -> &GameState { + &self.game + } + fn get_mut_game(&mut self) -> &mut GameState { + &mut self.game + } + + fn set_color(&mut self, color: Color) { + self.color = color; + } + + fn set_player_id(&mut self, player_id: PlayerId) { + self.player_id = player_id; + } + + fn calculate_points(&self) -> u8 { + self.game.dice_points.0 + } + + fn calculate_adv_points(&self) -> u8 { + self.game.dice_points.1 + } + + fn choose_go(&self) -> bool { + true + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + use rand::{seq::SliceRandom, thread_rng}; + let mut rng = thread_rng(); + let choosen_move = possible_moves + .choose(&mut rng) + .cloned() + .unwrap_or((CheckerMove::default(), CheckerMove::default())); + + if self.color == Color::White { + choosen_move + } else { + (choosen_move.0.mirror(), choosen_move.1.mirror()) + } + } +} diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index 9b6ab3a..8fb1c9e 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,4 +1,7 @@ -use bot::{BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, StableBaselines3Strategy}; +use bot::{ + BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, + StableBaselines3Strategy, +}; use itertools::Itertools; use crate::game_runner::GameRunner; @@ -32,13 +35,15 @@ impl App { "dummy" => { Some(Box::new(DefaultStrategy::default()) as Box) } + "random" => { + Some(Box::new(RandomStrategy::default()) as Box) + } "erroneous" => { Some(Box::new(ErroneousStrategy::default()) as Box) } "ai" => Some(Box::new(StableBaselines3Strategy::default()) as Box), - "dqn" => Some(Box::new(DqnStrategy::default()) - as Box), + "dqn" => Some(Box::new(DqnStrategy::default()) as Box), s if s.starts_with("ai:") => { let path = s.trim_start_matches("ai:"); Some(Box::new(StableBaselines3Strategy::new(path)) diff --git a/client_cli/src/game_runner.rs b/client_cli/src/game_runner.rs index 296c907..797dbc9 100644 --- a/client_cli/src/game_runner.rs +++ b/client_cli/src/game_runner.rs @@ -1,5 +1,5 @@ use bot::{Bot, BotStrategy}; -use log::{error, info}; +use log::{debug, error}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; // Application Game @@ -63,19 +63,19 @@ impl GameRunner { return None; } let valid_event = if self.state.validate(event) { - info!( + debug!( "--------------- new valid event {event:?} (stage {:?}) -----------", self.state.turn_stage ); self.state.consume(event); - info!( + debug!( " --> stage {:?} ; active player points {:?}", self.state.turn_stage, self.state.who_plays().map(|p| p.points) ); event } else { - info!("{}", self.state); + debug!("{}", self.state); error!("event not valid : {event:?}"); panic!("crash and burn"); &GameEvent::PlayError diff --git a/justfile b/justfile index 16f56ce..0501ded 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,7 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy + cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli @@ -21,6 +21,9 @@ profile: pythonlib: maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl +trainsimple: + cargo build --release --bin=train_dqn + LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok diff --git a/store/src/board.rs b/store/src/board.rs index 3e563d0..a838f10 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -37,7 +37,7 @@ impl Default for CheckerMove { impl CheckerMove { pub fn to_display_string(self) -> String { - format!("{:?} ", self) + format!("{self:?} ") } pub fn new(from: Field, to: Field) -> Result { @@ -569,7 +569,7 @@ impl Board { } let checker_color = self.get_checkers_color(field)?; if Some(color) != checker_color { - println!("field invalid : {:?}, {:?}, {:?}", color, field, self); + println!("field invalid : {color:?}, {field:?}, {self:?}"); return Err(Error::FieldInvalid); } let unit = match color { diff --git a/store/src/game.rs b/store/src/game.rs index c9995b8..200c321 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -4,7 +4,7 @@ use crate::dice::Dice; use crate::game_rules_moves::MoveRules; use crate::game_rules_points::{PointsRules, PossibleJans}; use crate::player::{Color, Player, PlayerId}; -use log::{error, info}; +use log::{debug, error, info}; // use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -521,14 +521,14 @@ impl GameState { self.inc_roll_count(self.active_player_id); self.turn_stage = TurnStage::MarkPoints; (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); - info!("points from result : {:?}", self.dice_points); + debug!("points from result : {:?}", self.dice_points); if !self.schools_enabled { // Schools are not enabled. We mark points automatically // the points earned by the opponent will be marked on its turn let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); if new_hole { let holes_count = self.get_active_player().unwrap().holes; - info!("new hole -> {holes_count:?}"); + debug!("new hole -> {holes_count:?}"); if holes_count > 12 { self.stage = Stage::Ended; } else { @@ -606,7 +606,7 @@ impl GameState { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { let player = &self.players.get(&self.active_player_id).unwrap(); - info!( + debug!( "get rollresult for {:?} {:?} {:?} (roll count {:?})", player.color, self.board, dice, player.dice_roll_count ); @@ -654,7 +654,7 @@ impl GameState { // if points > 0 && p.holes > 15 { if points > 0 { - info!( + debug!( "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})", p.holes, p.points ) diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs index ab67236..c8ea334 100644 --- a/store/src/game_rules_points.rs +++ b/store/src/game_rules_points.rs @@ -5,7 +5,7 @@ use crate::player::Color; use crate::CheckerMove; use crate::Error; -use log::info; +use log::debug; use serde::{Deserialize, Serialize}; use std::cmp; use std::collections::HashMap; @@ -384,7 +384,7 @@ impl PointsRules { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { let jans = self.get_jans(&self.board, dice_rolls_count); - info!("jans : {jans:?}"); + debug!("jans : {jans:?}"); let points_jans = jans.clone(); (jans, self.get_jans_points(points_jans)) } From 1b58ca4ccc3220a98e5d6f9e753186116f2ed8aa Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 17:07:34 +0200 Subject: [PATCH 5/7] refact dqn burn demo --- bot/src/dqn/burnrl/main.rs | 44 ++++---------------- bot/src/dqn/burnrl/utils.rs | 39 ++++++++++++++++-- bot/src/strategy/dqn.rs | 82 ++++++++++++++++++------------------- 3 files changed, 83 insertions(+), 82 deletions(-) diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 7b4584c..8408e6a 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -1,9 +1,10 @@ -use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::Module; -use burn::record::{CompactRecorder, Recorder}; +use bot::dqn::burnrl::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; use burn_rl::agent::DQN; -use burn_rl::base::{Action, Agent, ElemType, Environment, State}; +use burn_rl::base::ElemType; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; @@ -25,12 +26,9 @@ fn main() { println!("> Sauvegarde du modèle de validation"); - let path = "models/burn_dqn_50".to_string(); + let path = "models/burn_dqn_40".to_string(); save_model(valid_agent.model().as_ref().unwrap(), &path); - // println!("> Test avec le modèle entraîné"); - // demo_model::(valid_agent); - println!("> Chargement du modèle pour test"); let loaded_model = load_model(conf.dense_size, &path); let loaded_agent = DQN::new(loaded_model); @@ -38,31 +36,3 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } - -fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{}_model.mpk", path); - println!("Modèle de validation sauvegardé : {}", model_path); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { - let model_path = format!("{}_model.mpk", path); - println!("Chargement du modèle depuis : {}", model_path); - - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) -} diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index ba04cb6..66fa850 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -1,12 +1,45 @@ -use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; +use crate::dqn::burnrl::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; use crate::dqn::dqn_common::get_valid_action_indices; -use burn::module::{Param, ParamId}; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::backend::Backend; use burn::tensor::cast::ToElement; use burn::tensor::Tensor; use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{ElemType, Environment, State}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { + let model_path = format!("{path}_model.mpk"); + println!("Chargement du modèle depuis : {model_path}"); + + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + + let record = recorder + .load(model_path.into(), &device) + .expect("Impossible de charger le modèle"); + + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) +} pub fn demo_model>(agent: DQN) { let mut env = TrictracEnvironment::new(true); diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 109a9cf..34fb853 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -114,50 +114,48 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement - if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Move { - dice_order, - from1, - from2, - } = action - { - let dicevals = self.game.dice.values; - let (mut dice1, mut dice2) = if dice_order { - (dicevals.0, dicevals.1) - } else { - (dicevals.1, dicevals.0) - }; + if let Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) = self.get_dqn_action() + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; - if from1 == 0 { - // empty move - dice1 = 0; - } - let mut to1 = from1 + dice1 as usize; - if 24 < to1 { - // sortie - to1 = 0; - } - if from2 == 0 { - // empty move - dice2 = 0; - } - let mut to2 = from2 + dice2 as usize; - if 24 < to2 { - // sortie - to2 = 0; - } - - let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); - - let chosen_move = if self.color == Color::White { - (checker_move1, checker_move2) - } else { - (checker_move1.mirror(), checker_move2.mirror()) - }; - - return chosen_move; + if from1 == 0 { + // empty move + dice1 = 0; } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; + + return chosen_move; } // Fallback : utiliser la stratégie par défaut From a19c5d8596ed9372df107b07bd6d9f53bb537bb7 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 18:58:21 +0200 Subject: [PATCH 6/7] refact dqn simple --- bot/Cargo.toml | 4 +- bot/src/dqn/dqn_common.rs | 151 ----------------- bot/src/dqn/simple/dqn_model.rs | 154 ++++++++++++++++++ bot/src/dqn/simple/dqn_trainer.rs | 3 +- .../{bin/train_dqn.rs => dqn/simple/main.rs} | 3 +- bot/src/dqn/simple/mod.rs | 1 + bot/src/strategy/dqn.rs | 5 +- justfile | 4 +- 8 files changed, 165 insertions(+), 160 deletions(-) create mode 100644 bot/src/dqn/simple/dqn_model.rs rename bot/src/{bin/train_dqn.rs => dqn/simple/main.rs} (97%) diff --git a/bot/Cargo.toml b/bot/Cargo.toml index a5667fa..68ff52d 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -10,8 +10,8 @@ name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" [[bin]] -name = "train_dqn" -path = "src/bin/train_dqn.rs" +name = "train_dqn_simple" +path = "src/dqn/simple/main.rs" [dependencies] pretty_assertions = "1.4.0" diff --git a/bot/src/dqn/dqn_common.rs b/bot/src/dqn/dqn_common.rs index 3ea0738..2da4aa5 100644 --- a/bot/src/dqn/dqn_common.rs +++ b/bot/src/dqn/dqn_common.rs @@ -106,157 +106,6 @@ impl TrictracAction { // } } -/// Configuration pour l'agent DQN -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DqnConfig { - pub state_size: usize, - pub hidden_size: usize, - pub num_actions: usize, - pub learning_rate: f64, - pub gamma: f64, - pub epsilon: f64, - pub epsilon_decay: f64, - pub epsilon_min: f64, - pub replay_buffer_size: usize, - pub batch_size: usize, -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - state_size: 36, - hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi - num_actions: TrictracAction::action_space_size(), - learning_rate: 0.001, - gamma: 0.99, - epsilon: 0.1, - epsilon_decay: 0.995, - epsilon_min: 0.01, - replay_buffer_size: 10000, - batch_size: 32, - } - } -} - -/// Réseau de neurones DQN simplifié (matrice de poids basique) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SimpleNeuralNetwork { - pub weights1: Vec>, - pub biases1: Vec, - pub weights2: Vec>, - pub biases2: Vec, - pub weights3: Vec>, - pub biases3: Vec, -} - -impl SimpleNeuralNetwork { - pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self { - use rand::{thread_rng, Rng}; - let mut rng = thread_rng(); - - // Initialisation aléatoire des poids avec Xavier/Glorot - let scale1 = (2.0 / input_size as f32).sqrt(); - let weights1 = (0..hidden_size) - .map(|_| { - (0..input_size) - .map(|_| rng.gen_range(-scale1..scale1)) - .collect() - }) - .collect(); - let biases1 = vec![0.0; hidden_size]; - - let scale2 = (2.0 / hidden_size as f32).sqrt(); - let weights2 = (0..hidden_size) - .map(|_| { - (0..hidden_size) - .map(|_| rng.gen_range(-scale2..scale2)) - .collect() - }) - .collect(); - let biases2 = vec![0.0; hidden_size]; - - let scale3 = (2.0 / hidden_size as f32).sqrt(); - let weights3 = (0..output_size) - .map(|_| { - (0..hidden_size) - .map(|_| rng.gen_range(-scale3..scale3)) - .collect() - }) - .collect(); - let biases3 = vec![0.0; output_size]; - - Self { - weights1, - biases1, - weights2, - biases2, - weights3, - biases3, - } - } - - pub fn forward(&self, input: &[f32]) -> Vec { - // Première couche - let mut layer1: Vec = self.biases1.clone(); - for (i, neuron_weights) in self.weights1.iter().enumerate() { - for (j, &weight) in neuron_weights.iter().enumerate() { - if j < input.len() { - layer1[i] += input[j] * weight; - } - } - layer1[i] = layer1[i].max(0.0); // ReLU - } - - // Deuxième couche - let mut layer2: Vec = self.biases2.clone(); - for (i, neuron_weights) in self.weights2.iter().enumerate() { - for (j, &weight) in neuron_weights.iter().enumerate() { - if j < layer1.len() { - layer2[i] += layer1[j] * weight; - } - } - layer2[i] = layer2[i].max(0.0); // ReLU - } - - // Couche de sortie - let mut output: Vec = self.biases3.clone(); - for (i, neuron_weights) in self.weights3.iter().enumerate() { - for (j, &weight) in neuron_weights.iter().enumerate() { - if j < layer2.len() { - output[i] += layer2[j] * weight; - } - } - } - - output - } - - pub fn get_best_action(&self, input: &[f32]) -> usize { - let q_values = self.forward(input); - q_values - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .map(|(index, _)| index) - .unwrap_or(0) - } - - pub fn save>( - &self, - path: P, - ) -> Result<(), Box> { - let data = serde_json::to_string_pretty(self)?; - std::fs::write(path, data)?; - Ok(()) - } - - pub fn load>(path: P) -> Result> { - let data = std::fs::read_to_string(path)?; - let network = serde_json::from_str(&data)?; - Ok(network) - } -} - /// Obtient les actions valides pour l'état de jeu actuel pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { use store::TurnStage; diff --git a/bot/src/dqn/simple/dqn_model.rs b/bot/src/dqn/simple/dqn_model.rs new file mode 100644 index 0000000..ba46212 --- /dev/null +++ b/bot/src/dqn/simple/dqn_model.rs @@ -0,0 +1,154 @@ +use crate::dqn::dqn_common::TrictracAction; +use serde::{Deserialize, Serialize}; + +/// Configuration pour l'agent DQN +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DqnConfig { + pub state_size: usize, + pub hidden_size: usize, + pub num_actions: usize, + pub learning_rate: f64, + pub gamma: f64, + pub epsilon: f64, + pub epsilon_decay: f64, + pub epsilon_min: f64, + pub replay_buffer_size: usize, + pub batch_size: usize, +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + state_size: 36, + hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi + num_actions: TrictracAction::action_space_size(), + learning_rate: 0.001, + gamma: 0.99, + epsilon: 0.1, + epsilon_decay: 0.995, + epsilon_min: 0.01, + replay_buffer_size: 10000, + batch_size: 32, + } + } +} + +/// Réseau de neurones DQN simplifié (matrice de poids basique) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SimpleNeuralNetwork { + pub weights1: Vec>, + pub biases1: Vec, + pub weights2: Vec>, + pub biases2: Vec, + pub weights3: Vec>, + pub biases3: Vec, +} + +impl SimpleNeuralNetwork { + pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + + // Initialisation aléatoire des poids avec Xavier/Glorot + let scale1 = (2.0 / input_size as f32).sqrt(); + let weights1 = (0..hidden_size) + .map(|_| { + (0..input_size) + .map(|_| rng.gen_range(-scale1..scale1)) + .collect() + }) + .collect(); + let biases1 = vec![0.0; hidden_size]; + + let scale2 = (2.0 / hidden_size as f32).sqrt(); + let weights2 = (0..hidden_size) + .map(|_| { + (0..hidden_size) + .map(|_| rng.gen_range(-scale2..scale2)) + .collect() + }) + .collect(); + let biases2 = vec![0.0; hidden_size]; + + let scale3 = (2.0 / hidden_size as f32).sqrt(); + let weights3 = (0..output_size) + .map(|_| { + (0..hidden_size) + .map(|_| rng.gen_range(-scale3..scale3)) + .collect() + }) + .collect(); + let biases3 = vec![0.0; output_size]; + + Self { + weights1, + biases1, + weights2, + biases2, + weights3, + biases3, + } + } + + pub fn forward(&self, input: &[f32]) -> Vec { + // Première couche + let mut layer1: Vec = self.biases1.clone(); + for (i, neuron_weights) in self.weights1.iter().enumerate() { + for (j, &weight) in neuron_weights.iter().enumerate() { + if j < input.len() { + layer1[i] += input[j] * weight; + } + } + layer1[i] = layer1[i].max(0.0); // ReLU + } + + // Deuxième couche + let mut layer2: Vec = self.biases2.clone(); + for (i, neuron_weights) in self.weights2.iter().enumerate() { + for (j, &weight) in neuron_weights.iter().enumerate() { + if j < layer1.len() { + layer2[i] += layer1[j] * weight; + } + } + layer2[i] = layer2[i].max(0.0); // ReLU + } + + // Couche de sortie + let mut output: Vec = self.biases3.clone(); + for (i, neuron_weights) in self.weights3.iter().enumerate() { + for (j, &weight) in neuron_weights.iter().enumerate() { + if j < layer2.len() { + output[i] += layer2[j] * weight; + } + } + } + + output + } + + pub fn get_best_action(&self, input: &[f32]) -> usize { + let q_values = self.forward(input); + q_values + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(index, _)| index) + .unwrap_or(0) + } + + pub fn save>( + &self, + path: P, + ) -> Result<(), Box> { + let data = serde_json::to_string_pretty(self)?; + std::fs::write(path, data)?; + Ok(()) + } + + pub fn load>(path: P) -> Result> { + let data = std::fs::read_to_string(path)?; + let network = serde_json::from_str(&data)?; + Ok(network) + } +} + diff --git a/bot/src/dqn/simple/dqn_trainer.rs b/bot/src/dqn/simple/dqn_trainer.rs index dedf382..78e6dc7 100644 --- a/bot/src/dqn/simple/dqn_trainer.rs +++ b/bot/src/dqn/simple/dqn_trainer.rs @@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize}; use std::collections::VecDeque; use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; -use crate::dqn::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction}; +use super::dqn_model::{DqnConfig, SimpleNeuralNetwork}; +use crate::dqn::dqn_common::{get_valid_actions, TrictracAction}; /// Expérience pour le buffer de replay #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/bot/src/bin/train_dqn.rs b/bot/src/dqn/simple/main.rs similarity index 97% rename from bot/src/bin/train_dqn.rs rename to bot/src/dqn/simple/main.rs index e0929fb..30fd933 100644 --- a/bot/src/bin/train_dqn.rs +++ b/bot/src/dqn/simple/main.rs @@ -1,4 +1,5 @@ -use bot::dqn::dqn_common::{DqnConfig, TrictracAction}; +use bot::dqn::dqn_common::TrictracAction; +use bot::dqn::simple::dqn_model::DqnConfig; use bot::dqn::simple::dqn_trainer::DqnTrainer; use std::env; diff --git a/bot/src/dqn/simple/mod.rs b/bot/src/dqn/simple/mod.rs index 114bd10..8090a29 100644 --- a/bot/src/dqn/simple/mod.rs +++ b/bot/src/dqn/simple/mod.rs @@ -1 +1,2 @@ +pub mod dqn_model; pub mod dqn_trainer; diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 34fb853..cf24684 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -3,9 +3,8 @@ use log::info; use std::path::Path; use store::MoveRules; -use crate::dqn::dqn_common::{ - get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction, -}; +use crate::dqn::dqn_common::{get_valid_actions, sample_valid_action, TrictracAction}; +use crate::dqn::simple::dqn_model::SimpleNeuralNetwork; /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné #[derive(Debug)] diff --git a/justfile b/justfile index 0501ded..32193af 100644 --- a/justfile +++ b/justfile @@ -22,8 +22,8 @@ pythonlib: maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl trainsimple: - cargo build --release --bin=train_dqn - LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out + cargo build --release --bin=train_dqn_simple + LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok From 17d29b86330429ca652b9c8f8a91a764e0ccc7dd Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 21:31:38 +0200 Subject: [PATCH 7/7] runcli with bot dqn burn-rl --- bot/src/dqn/burnrl/environment.rs | 8 +- bot/src/dqn/burnrl/main.rs | 2 +- bot/src/dqn/burnrl/utils.rs | 28 +++-- bot/src/lib.rs | 1 + bot/src/strategy/dqnburn.rs | 176 ++++++++++++++++++++++++++++++ bot/src/strategy/mod.rs | 1 + client_cli/src/app.rs | 24 ++-- justfile | 3 +- 8 files changed, 212 insertions(+), 31 deletions(-) create mode 100644 bot/src/strategy/dqnburn.rs diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index 5716fa1..d5e0028 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -141,7 +141,7 @@ impl Environment for TrictracEnvironment { self.step_count += 1; // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = self.convert_action(action, &self.game); + let trictrac_action = Self::convert_action(action); let mut reward = 0.0; let mut terminated = false; @@ -203,11 +203,7 @@ impl Environment for TrictracEnvironment { impl TrictracEnvironment { /// Convertit une action burn-rl vers une action Trictrac - fn convert_action( - &self, - action: TrictracAction, - game_state: &GameState, - ) -> Option { + pub fn convert_action(action: TrictracAction) -> Option { dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) } diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 8408e6a..4b3a789 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -31,7 +31,7 @@ fn main() { println!("> Chargement du modèle pour test"); let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model); + let loaded_agent = DQN::new(loaded_model.unwrap()); println!("> Test avec le modèle chargé"); demo_model(loaded_agent); diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index 66fa850..a1d5480 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -22,23 +22,21 @@ pub fn save_model(model: &dqn_model::Net>, path: &String) { .unwrap(); } -pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { +pub fn load_model(dense_size: usize, path: &String) -> Option>> { let model_path = format!("{path}_model.mpk"); - println!("Chargement du modèle depuis : {model_path}"); + // println!("Chargement du modèle depuis : {model_path}"); - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() } pub fn demo_model>(agent: DQN) { diff --git a/bot/src/lib.rs b/bot/src/lib.rs index ca338e1..f9a4617 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -5,6 +5,7 @@ use log::{debug, error}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; +pub use strategy::dqnburn::DqnBurnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; pub use strategy::random::RandomStrategy; pub use strategy::stable_baselines3::StableBaselines3Strategy; diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs new file mode 100644 index 0000000..4fc0c06 --- /dev/null +++ b/bot/src/strategy/dqnburn.rs @@ -0,0 +1,176 @@ +use burn::backend::NdArray; +use burn::tensor::cast::ToElement; +use burn_rl::base::{ElemType, Model, State}; + +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use log::info; +use store::MoveRules; + +use crate::dqn::burnrl::{dqn_model, environment, utils}; +use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; + +type DqnBurnNetwork = dqn_model::Net>; + +/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné +#[derive(Debug)] +pub struct DqnBurnStrategy { + pub game: GameState, + pub player_id: PlayerId, + pub color: Color, + pub model: Option, +} + +impl Default for DqnBurnStrategy { + fn default() -> Self { + Self { + game: GameState::default(), + player_id: 1, + color: Color::White, + model: None, + } + } +} + +impl DqnBurnStrategy { + pub fn new() -> Self { + Self::default() + } + + pub fn new_with_model(model_path: &String) -> Self { + info!("Loading model {model_path:?}"); + let mut strategy = Self::new(); + strategy.model = utils::load_model(256, model_path); + strategy + } + + /// Utilise le modèle DQN pour choisir une action valide + fn get_dqn_action(&self) -> Option { + if let Some(ref model) = self.model { + let state = environment::TrictracState::from_game_state(&self.game); + let valid_actions_indices = get_valid_action_indices(&self.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + + // Obtenir les Q-values pour toutes les actions + let q_values = model.infer(state.to_tensor().unsqueeze()); + + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + environment::TrictracEnvironment::convert_action(environment::TrictracAction::from( + action_index, + )) + } else { + // Fallback : action aléatoire valide + sample_valid_action(&self.game) + } + } +} + +impl BotStrategy for DqnBurnStrategy { + fn get_game(&self) -> &GameState { + &self.game + } + + fn get_mut_game(&mut self) -> &mut GameState { + &mut self.game + } + + fn set_color(&mut self, color: Color) { + self.color = color; + } + + fn set_player_id(&mut self, player_id: PlayerId) { + self.player_id = player_id; + } + + fn calculate_points(&self) -> u8 { + self.game.dice_points.0 + } + + fn calculate_adv_points(&self) -> u8 { + self.game.dice_points.1 + } + + fn choose_go(&self) -> bool { + // Utiliser le DQN pour décider si on continue + if let Some(action) = self.get_dqn_action() { + matches!(action, TrictracAction::Go) + } else { + // Fallback : toujours continuer + true + } + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + // Utiliser le DQN pour choisir le mouvement + if let Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) = self.get_dqn_action() + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; + + if from1 == 0 { + // empty move + dice1 = 0; + } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; + + return chosen_move; + } + + // Fallback : utiliser la stratégie par défaut + let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + let chosen_move = *possible_moves + .first() + .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); + + if self.color == Color::White { + chosen_move + } else { + (chosen_move.0.mirror(), chosen_move.1.mirror()) + } + } +} diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index 731d1b1..b9fa3b2 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -1,6 +1,7 @@ pub mod client; pub mod default; pub mod dqn; +pub mod dqnburn; pub mod erroneous_moves; pub mod random; pub mod stable_baselines3; diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index 8fb1c9e..519adf1 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,5 +1,5 @@ use bot::{ - BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, + BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, StableBaselines3Strategy, }; use itertools::Itertools; @@ -25,11 +25,11 @@ pub struct App { impl App { // Constructs a new instance of [`App`]. pub fn new(args: AppArgs) -> Self { - let bot_strategies: Vec> = - args.bot - .as_deref() - .map(|str_bots| { - str_bots + let bot_strategies: Vec> = args + .bot + .as_deref() + .map(|str_bots| { + str_bots .split(",") .filter_map(|s| match s.trim() { "dummy" => { @@ -44,6 +44,9 @@ impl App { "ai" => Some(Box::new(StableBaselines3Strategy::default()) as Box), "dqn" => Some(Box::new(DqnStrategy::default()) as Box), + "dqnburn" => { + Some(Box::new(DqnBurnStrategy::default()) as Box) + } s if s.starts_with("ai:") => { let path = s.trim_start_matches("ai:"); Some(Box::new(StableBaselines3Strategy::new(path)) @@ -54,11 +57,16 @@ impl App { Some(Box::new(DqnStrategy::new_with_model(path)) as Box) } + s if s.starts_with("dqnburn:") => { + let path = s.trim_start_matches("dqnburn:"); + Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}"))) + as Box) + } _ => None, }) .collect() - }) - .unwrap_or_default(); + }) + .unwrap_or_default(); let schools_enabled = false; let should_quit = bot_strategies.len() > 1; Self { diff --git a/justfile b/justfile index 32193af..dcb5117 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,8 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy + cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk + #cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli