From bf820ecc4e081bbecc7f29aa910562eeeba97c5e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 16:24:12 +0200 Subject: [PATCH 1/2] feat: bot random strategy --- bot/src/lib.rs | 17 +++++---- bot/src/strategy/dqn.rs | 6 ++- bot/src/strategy/mod.rs | 1 + bot/src/strategy/random.rs | 67 ++++++++++++++++++++++++++++++++++ client_cli/src/app.rs | 11 ++++-- client_cli/src/game_runner.rs | 8 ++-- justfile | 5 ++- store/src/board.rs | 4 +- store/src/game.rs | 10 ++--- store/src/game_rules_points.rs | 4 +- 10 files changed, 106 insertions(+), 27 deletions(-) create mode 100644 bot/src/strategy/random.rs diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 6326253..ca338e1 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,11 +1,12 @@ pub mod dqn; pub mod strategy; -use log::{error, info}; +use log::{debug, error}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; +pub use strategy::random::RandomStrategy; pub use strategy::stable_baselines3::StableBaselines3Strategy; pub trait BotStrategy: std::fmt::Debug { @@ -64,7 +65,7 @@ impl Bot { } pub fn handle_event(&mut self, event: &GameEvent) -> Option { - info!(">>>> {:?} BOT handle", self.color); + debug!(">>>> {:?} BOT handle", self.color); let game = self.strategy.get_mut_game(); let internal_event = if self.color == Color::Black { &event.get_mirror() @@ -76,7 +77,7 @@ impl Bot { let turn_stage = game.turn_stage; game.consume(internal_event); if game.stage == Stage::Ended { - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); return None; } let active_player_id = if self.color == Color::Black { @@ -91,7 +92,7 @@ impl Bot { if active_player_id == self.player_id { let player_points = game.who_plays().map(|p| (p.points, p.holes)); if self.color == Color::Black { - info!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); + debug!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); } let internal_event = match game.turn_stage { TurnStage::MarkAdvPoints => Some(GameEvent::Mark { @@ -120,15 +121,15 @@ impl Bot { _ => None, }; return if self.color == Color::Black { - info!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); - info!("<<<< end {:?} BOT handle", self.color); + debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); + debug!("<<<< end {:?} BOT handle", self.color); internal_event.map(|evt| evt.get_mirror()) } else { - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); internal_event }; } - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); None } diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 0248cc5..109a9cf 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,4 +1,5 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use log::info; use std::path::Path; use store::MoveRules; @@ -31,9 +32,10 @@ impl DqnStrategy { Self::default() } - pub fn new_with_model>(model_path: P) -> Self { + pub fn new_with_model + std::fmt::Debug>(model_path: P) -> Self { let mut strategy = Self::new(); - if let Ok(model) = SimpleNeuralNetwork::load(model_path) { + if let Ok(model) = SimpleNeuralNetwork::load(&model_path) { + info!("Loading model {model_path:?}"); strategy.model = Some(model); } strategy diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index 3812188..731d1b1 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -2,4 +2,5 @@ pub mod client; pub mod default; pub mod dqn; pub mod erroneous_moves; +pub mod random; pub mod stable_baselines3; diff --git a/bot/src/strategy/random.rs b/bot/src/strategy/random.rs new file mode 100644 index 0000000..0bfd1c6 --- /dev/null +++ b/bot/src/strategy/random.rs @@ -0,0 +1,67 @@ +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use store::MoveRules; + +#[derive(Debug)] +pub struct RandomStrategy { + pub game: GameState, + pub player_id: PlayerId, + pub color: Color, +} + +impl Default for RandomStrategy { + fn default() -> Self { + let game = GameState::default(); + Self { + game, + player_id: 1, + color: Color::White, + } + } +} + +impl BotStrategy for RandomStrategy { + fn get_game(&self) -> &GameState { + &self.game + } + fn get_mut_game(&mut self) -> &mut GameState { + &mut self.game + } + + fn set_color(&mut self, color: Color) { + self.color = color; + } + + fn set_player_id(&mut self, player_id: PlayerId) { + self.player_id = player_id; + } + + fn calculate_points(&self) -> u8 { + self.game.dice_points.0 + } + + fn calculate_adv_points(&self) -> u8 { + self.game.dice_points.1 + } + + fn choose_go(&self) -> bool { + true + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + use rand::{seq::SliceRandom, thread_rng}; + let mut rng = thread_rng(); + let choosen_move = possible_moves + .choose(&mut rng) + .cloned() + .unwrap_or((CheckerMove::default(), CheckerMove::default())); + + if self.color == Color::White { + choosen_move + } else { + (choosen_move.0.mirror(), choosen_move.1.mirror()) + } + } +} diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index 9b6ab3a..8fb1c9e 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,4 +1,7 @@ -use bot::{BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, StableBaselines3Strategy}; +use bot::{ + BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, + StableBaselines3Strategy, +}; use itertools::Itertools; use crate::game_runner::GameRunner; @@ -32,13 +35,15 @@ impl App { "dummy" => { Some(Box::new(DefaultStrategy::default()) as Box) } + "random" => { + Some(Box::new(RandomStrategy::default()) as Box) + } "erroneous" => { Some(Box::new(ErroneousStrategy::default()) as Box) } "ai" => Some(Box::new(StableBaselines3Strategy::default()) as Box), - "dqn" => Some(Box::new(DqnStrategy::default()) - as Box), + "dqn" => Some(Box::new(DqnStrategy::default()) as Box), s if s.starts_with("ai:") => { let path = s.trim_start_matches("ai:"); Some(Box::new(StableBaselines3Strategy::new(path)) diff --git a/client_cli/src/game_runner.rs b/client_cli/src/game_runner.rs index 296c907..797dbc9 100644 --- a/client_cli/src/game_runner.rs +++ b/client_cli/src/game_runner.rs @@ -1,5 +1,5 @@ use bot::{Bot, BotStrategy}; -use log::{error, info}; +use log::{debug, error}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; // Application Game @@ -63,19 +63,19 @@ impl GameRunner { return None; } let valid_event = if self.state.validate(event) { - info!( + debug!( "--------------- new valid event {event:?} (stage {:?}) -----------", self.state.turn_stage ); self.state.consume(event); - info!( + debug!( " --> stage {:?} ; active player points {:?}", self.state.turn_stage, self.state.who_plays().map(|p| p.points) ); event } else { - info!("{}", self.state); + debug!("{}", self.state); error!("event not valid : {event:?}"); panic!("crash and burn"); &GameEvent::PlayError diff --git a/justfile b/justfile index 16f56ce..0501ded 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,7 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy + cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli @@ -21,6 +21,9 @@ profile: pythonlib: maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl +trainsimple: + cargo build --release --bin=train_dqn + LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok diff --git a/store/src/board.rs b/store/src/board.rs index 3e563d0..a838f10 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -37,7 +37,7 @@ impl Default for CheckerMove { impl CheckerMove { pub fn to_display_string(self) -> String { - format!("{:?} ", self) + format!("{self:?} ") } pub fn new(from: Field, to: Field) -> Result { @@ -569,7 +569,7 @@ impl Board { } let checker_color = self.get_checkers_color(field)?; if Some(color) != checker_color { - println!("field invalid : {:?}, {:?}, {:?}", color, field, self); + println!("field invalid : {color:?}, {field:?}, {self:?}"); return Err(Error::FieldInvalid); } let unit = match color { diff --git a/store/src/game.rs b/store/src/game.rs index c9995b8..200c321 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -4,7 +4,7 @@ use crate::dice::Dice; use crate::game_rules_moves::MoveRules; use crate::game_rules_points::{PointsRules, PossibleJans}; use crate::player::{Color, Player, PlayerId}; -use log::{error, info}; +use log::{debug, error, info}; // use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -521,14 +521,14 @@ impl GameState { self.inc_roll_count(self.active_player_id); self.turn_stage = TurnStage::MarkPoints; (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); - info!("points from result : {:?}", self.dice_points); + debug!("points from result : {:?}", self.dice_points); if !self.schools_enabled { // Schools are not enabled. We mark points automatically // the points earned by the opponent will be marked on its turn let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); if new_hole { let holes_count = self.get_active_player().unwrap().holes; - info!("new hole -> {holes_count:?}"); + debug!("new hole -> {holes_count:?}"); if holes_count > 12 { self.stage = Stage::Ended; } else { @@ -606,7 +606,7 @@ impl GameState { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { let player = &self.players.get(&self.active_player_id).unwrap(); - info!( + debug!( "get rollresult for {:?} {:?} {:?} (roll count {:?})", player.color, self.board, dice, player.dice_roll_count ); @@ -654,7 +654,7 @@ impl GameState { // if points > 0 && p.holes > 15 { if points > 0 { - info!( + debug!( "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})", p.holes, p.points ) diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs index ab67236..c8ea334 100644 --- a/store/src/game_rules_points.rs +++ b/store/src/game_rules_points.rs @@ -5,7 +5,7 @@ use crate::player::Color; use crate::CheckerMove; use crate::Error; -use log::info; +use log::debug; use serde::{Deserialize, Serialize}; use std::cmp; use std::collections::HashMap; @@ -384,7 +384,7 @@ impl PointsRules { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { let jans = self.get_jans(&self.board, dice_rolls_count); - info!("jans : {jans:?}"); + debug!("jans : {jans:?}"); let points_jans = jans.clone(); (jans, self.get_jans_points(points_jans)) } From 1b58ca4ccc3220a98e5d6f9e753186116f2ed8aa Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 8 Aug 2025 17:07:34 +0200 Subject: [PATCH 2/2] refact dqn burn demo --- bot/src/dqn/burnrl/main.rs | 44 ++++---------------- bot/src/dqn/burnrl/utils.rs | 39 ++++++++++++++++-- bot/src/strategy/dqn.rs | 82 ++++++++++++++++++------------------- 3 files changed, 83 insertions(+), 82 deletions(-) diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 7b4584c..8408e6a 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -1,9 +1,10 @@ -use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::Module; -use burn::record::{CompactRecorder, Recorder}; +use bot::dqn::burnrl::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; use burn_rl::agent::DQN; -use burn_rl::base::{Action, Agent, ElemType, Environment, State}; +use burn_rl::base::ElemType; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; @@ -25,12 +26,9 @@ fn main() { println!("> Sauvegarde du modèle de validation"); - let path = "models/burn_dqn_50".to_string(); + let path = "models/burn_dqn_40".to_string(); save_model(valid_agent.model().as_ref().unwrap(), &path); - // println!("> Test avec le modèle entraîné"); - // demo_model::(valid_agent); - println!("> Chargement du modèle pour test"); let loaded_model = load_model(conf.dense_size, &path); let loaded_agent = DQN::new(loaded_model); @@ -38,31 +36,3 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } - -fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{}_model.mpk", path); - println!("Modèle de validation sauvegardé : {}", model_path); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { - let model_path = format!("{}_model.mpk", path); - println!("Chargement du modèle depuis : {}", model_path); - - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) -} diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index ba04cb6..66fa850 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -1,12 +1,45 @@ -use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; +use crate::dqn::burnrl::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; use crate::dqn::dqn_common::get_valid_action_indices; -use burn::module::{Param, ParamId}; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; use burn::tensor::backend::Backend; use burn::tensor::cast::ToElement; use burn::tensor::Tensor; use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{ElemType, Environment, State}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { + let model_path = format!("{path}_model.mpk"); + println!("Chargement du modèle depuis : {model_path}"); + + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + + let record = recorder + .load(model_path.into(), &device) + .expect("Impossible de charger le modèle"); + + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) +} pub fn demo_model>(agent: DQN) { let mut env = TrictracEnvironment::new(true); diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 109a9cf..34fb853 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -114,50 +114,48 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement - if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Move { - dice_order, - from1, - from2, - } = action - { - let dicevals = self.game.dice.values; - let (mut dice1, mut dice2) = if dice_order { - (dicevals.0, dicevals.1) - } else { - (dicevals.1, dicevals.0) - }; + if let Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) = self.get_dqn_action() + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; - if from1 == 0 { - // empty move - dice1 = 0; - } - let mut to1 = from1 + dice1 as usize; - if 24 < to1 { - // sortie - to1 = 0; - } - if from2 == 0 { - // empty move - dice2 = 0; - } - let mut to2 = from2 + dice2 as usize; - if 24 < to2 { - // sortie - to2 = 0; - } - - let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); - - let chosen_move = if self.color == Color::White { - (checker_move1, checker_move2) - } else { - (checker_move1.mirror(), checker_move2.mirror()) - }; - - return chosen_move; + if from1 == 0 { + // empty move + dice1 = 0; } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; + + return chosen_move; } // Fallback : utiliser la stratégie par défaut