diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 6326253..ca338e1 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,11 +1,12 @@ pub mod dqn; pub mod strategy; -use log::{error, info}; +use log::{debug, error}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; +pub use strategy::random::RandomStrategy; pub use strategy::stable_baselines3::StableBaselines3Strategy; pub trait BotStrategy: std::fmt::Debug { @@ -64,7 +65,7 @@ impl Bot { } pub fn handle_event(&mut self, event: &GameEvent) -> Option { - info!(">>>> {:?} BOT handle", self.color); + debug!(">>>> {:?} BOT handle", self.color); let game = self.strategy.get_mut_game(); let internal_event = if self.color == Color::Black { &event.get_mirror() @@ -76,7 +77,7 @@ impl Bot { let turn_stage = game.turn_stage; game.consume(internal_event); if game.stage == Stage::Ended { - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); return None; } let active_player_id = if self.color == Color::Black { @@ -91,7 +92,7 @@ impl Bot { if active_player_id == self.player_id { let player_points = game.who_plays().map(|p| (p.points, p.holes)); if self.color == Color::Black { - info!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); + debug!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); } let internal_event = match game.turn_stage { TurnStage::MarkAdvPoints => Some(GameEvent::Mark { @@ -120,15 +121,15 @@ impl Bot { _ => None, }; return if self.color == Color::Black { - info!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); - info!("<<<< end {:?} BOT handle", self.color); + debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); + debug!("<<<< end {:?} BOT handle", self.color); internal_event.map(|evt| evt.get_mirror()) } else { - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); internal_event }; } - info!("<<<< end {:?} BOT handle", self.color); + debug!("<<<< end {:?} BOT handle", self.color); None } diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 0248cc5..109a9cf 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,4 +1,5 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use log::info; use std::path::Path; use store::MoveRules; @@ -31,9 +32,10 @@ impl DqnStrategy { Self::default() } - pub fn new_with_model>(model_path: P) -> Self { + pub fn new_with_model + std::fmt::Debug>(model_path: P) -> Self { let mut strategy = Self::new(); - if let Ok(model) = SimpleNeuralNetwork::load(model_path) { + if let Ok(model) = SimpleNeuralNetwork::load(&model_path) { + info!("Loading model {model_path:?}"); strategy.model = Some(model); } strategy diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index 3812188..731d1b1 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -2,4 +2,5 @@ pub mod client; pub mod default; pub mod dqn; pub mod erroneous_moves; +pub mod random; pub mod stable_baselines3; diff --git a/bot/src/strategy/random.rs b/bot/src/strategy/random.rs new file mode 100644 index 0000000..0bfd1c6 --- /dev/null +++ b/bot/src/strategy/random.rs @@ -0,0 +1,67 @@ +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use store::MoveRules; + +#[derive(Debug)] +pub struct RandomStrategy { + pub game: GameState, + pub player_id: PlayerId, + pub color: Color, +} + +impl Default for RandomStrategy { + fn default() -> Self { + let game = GameState::default(); + Self { + game, + player_id: 1, + color: Color::White, + } + } +} + +impl BotStrategy for RandomStrategy { + fn get_game(&self) -> &GameState { + &self.game + } + fn get_mut_game(&mut self) -> &mut GameState { + &mut self.game + } + + fn set_color(&mut self, color: Color) { + self.color = color; + } + + fn set_player_id(&mut self, player_id: PlayerId) { + self.player_id = player_id; + } + + fn calculate_points(&self) -> u8 { + self.game.dice_points.0 + } + + fn calculate_adv_points(&self) -> u8 { + self.game.dice_points.1 + } + + fn choose_go(&self) -> bool { + true + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + use rand::{seq::SliceRandom, thread_rng}; + let mut rng = thread_rng(); + let choosen_move = possible_moves + .choose(&mut rng) + .cloned() + .unwrap_or((CheckerMove::default(), CheckerMove::default())); + + if self.color == Color::White { + choosen_move + } else { + (choosen_move.0.mirror(), choosen_move.1.mirror()) + } + } +} diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index 9b6ab3a..8fb1c9e 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,4 +1,7 @@ -use bot::{BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, StableBaselines3Strategy}; +use bot::{ + BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, + StableBaselines3Strategy, +}; use itertools::Itertools; use crate::game_runner::GameRunner; @@ -32,13 +35,15 @@ impl App { "dummy" => { Some(Box::new(DefaultStrategy::default()) as Box) } + "random" => { + Some(Box::new(RandomStrategy::default()) as Box) + } "erroneous" => { Some(Box::new(ErroneousStrategy::default()) as Box) } "ai" => Some(Box::new(StableBaselines3Strategy::default()) as Box), - "dqn" => Some(Box::new(DqnStrategy::default()) - as Box), + "dqn" => Some(Box::new(DqnStrategy::default()) as Box), s if s.starts_with("ai:") => { let path = s.trim_start_matches("ai:"); Some(Box::new(StableBaselines3Strategy::new(path)) diff --git a/client_cli/src/game_runner.rs b/client_cli/src/game_runner.rs index 296c907..797dbc9 100644 --- a/client_cli/src/game_runner.rs +++ b/client_cli/src/game_runner.rs @@ -1,5 +1,5 @@ use bot::{Bot, BotStrategy}; -use log::{error, info}; +use log::{debug, error}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; // Application Game @@ -63,19 +63,19 @@ impl GameRunner { return None; } let valid_event = if self.state.validate(event) { - info!( + debug!( "--------------- new valid event {event:?} (stage {:?}) -----------", self.state.turn_stage ); self.state.consume(event); - info!( + debug!( " --> stage {:?} ; active player points {:?}", self.state.turn_stage, self.state.who_plays().map(|p| p.points) ); event } else { - info!("{}", self.state); + debug!("{}", self.state); error!("event not valid : {event:?}"); panic!("crash and burn"); &GameEvent::PlayError diff --git a/justfile b/justfile index 16f56ce..0501ded 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,7 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy + cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli @@ -21,6 +21,9 @@ profile: pythonlib: maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl +trainsimple: + cargo build --release --bin=train_dqn + LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok diff --git a/store/src/board.rs b/store/src/board.rs index 3e563d0..a838f10 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -37,7 +37,7 @@ impl Default for CheckerMove { impl CheckerMove { pub fn to_display_string(self) -> String { - format!("{:?} ", self) + format!("{self:?} ") } pub fn new(from: Field, to: Field) -> Result { @@ -569,7 +569,7 @@ impl Board { } let checker_color = self.get_checkers_color(field)?; if Some(color) != checker_color { - println!("field invalid : {:?}, {:?}, {:?}", color, field, self); + println!("field invalid : {color:?}, {field:?}, {self:?}"); return Err(Error::FieldInvalid); } let unit = match color { diff --git a/store/src/game.rs b/store/src/game.rs index c9995b8..200c321 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -4,7 +4,7 @@ use crate::dice::Dice; use crate::game_rules_moves::MoveRules; use crate::game_rules_points::{PointsRules, PossibleJans}; use crate::player::{Color, Player, PlayerId}; -use log::{error, info}; +use log::{debug, error, info}; // use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -521,14 +521,14 @@ impl GameState { self.inc_roll_count(self.active_player_id); self.turn_stage = TurnStage::MarkPoints; (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); - info!("points from result : {:?}", self.dice_points); + debug!("points from result : {:?}", self.dice_points); if !self.schools_enabled { // Schools are not enabled. We mark points automatically // the points earned by the opponent will be marked on its turn let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); if new_hole { let holes_count = self.get_active_player().unwrap().holes; - info!("new hole -> {holes_count:?}"); + debug!("new hole -> {holes_count:?}"); if holes_count > 12 { self.stage = Stage::Ended; } else { @@ -606,7 +606,7 @@ impl GameState { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { let player = &self.players.get(&self.active_player_id).unwrap(); - info!( + debug!( "get rollresult for {:?} {:?} {:?} (roll count {:?})", player.color, self.board, dice, player.dice_roll_count ); @@ -654,7 +654,7 @@ impl GameState { // if points > 0 && p.holes > 15 { if points > 0 { - info!( + debug!( "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})", p.holes, p.points ) diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs index ab67236..c8ea334 100644 --- a/store/src/game_rules_points.rs +++ b/store/src/game_rules_points.rs @@ -5,7 +5,7 @@ use crate::player::Color; use crate::CheckerMove; use crate::Error; -use log::info; +use log::debug; use serde::{Deserialize, Serialize}; use std::cmp; use std::collections::HashMap; @@ -384,7 +384,7 @@ impl PointsRules { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { let jans = self.get_jans(&self.board, dice_rolls_count); - info!("jans : {jans:?}"); + debug!("jans : {jans:?}"); let points_jans = jans.clone(); (jans, self.get_jans_points(points_jans)) }