diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 8408e6a..7b4584c 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -1,10 +1,9 @@ -use bot::dqn::burnrl::{ - dqn_model, environment, - utils::{demo_model, load_model, save_model}, -}; -use burn::backend::{Autodiff, NdArray}; +use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::Module; +use burn::record::{CompactRecorder, Recorder}; use burn_rl::agent::DQN; -use burn_rl::base::ElemType; +use burn_rl::base::{Action, Agent, ElemType, Environment, State}; type Backend = Autodiff>; type Env = environment::TrictracEnvironment; @@ -26,9 +25,12 @@ fn main() { println!("> Sauvegarde du modèle de validation"); - let path = "models/burn_dqn_40".to_string(); + let path = "models/burn_dqn_50".to_string(); save_model(valid_agent.model().as_ref().unwrap(), &path); + // println!("> Test avec le modèle entraîné"); + // demo_model::(valid_agent); + println!("> Chargement du modèle pour test"); let loaded_model = load_model(conf.dense_size, &path); let loaded_agent = DQN::new(loaded_model); @@ -36,3 +38,31 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } + +fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{}_model.mpk", path); + println!("Modèle de validation sauvegardé : {}", model_path); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { + let model_path = format!("{}_model.mpk", path); + println!("Chargement du modèle depuis : {}", model_path); + + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + + let record = recorder + .load(model_path.into(), &device) + .expect("Impossible de charger le modèle"); + + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) +} diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index 66fa850..ba04cb6 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -1,45 +1,12 @@ -use crate::dqn::burnrl::{ - dqn_model, - environment::{TrictracAction, TrictracEnvironment}, -}; +use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; use crate::dqn::dqn_common::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::{Module, Param, ParamId}; +use burn::module::{Param, ParamId}; use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; use burn::tensor::backend::Backend; use burn::tensor::cast::ToElement; use burn::tensor::Tensor; use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}_model.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { - let model_path = format!("{path}_model.mpk"); - println!("Chargement du modèle depuis : {model_path}"); - - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) -} +use burn_rl::base::{ElemType, Environment, State}; pub fn demo_model>(agent: DQN) { let mut env = TrictracEnvironment::new(true); diff --git a/bot/src/lib.rs b/bot/src/lib.rs index ca338e1..6326253 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,12 +1,11 @@ pub mod dqn; pub mod strategy; -use log::{debug, error}; +use log::{error, info}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; -pub use strategy::random::RandomStrategy; pub use strategy::stable_baselines3::StableBaselines3Strategy; pub trait BotStrategy: std::fmt::Debug { @@ -65,7 +64,7 @@ impl Bot { } pub fn handle_event(&mut self, event: &GameEvent) -> Option { - debug!(">>>> {:?} BOT handle", self.color); + info!(">>>> {:?} BOT handle", self.color); let game = self.strategy.get_mut_game(); let internal_event = if self.color == Color::Black { &event.get_mirror() @@ -77,7 +76,7 @@ impl Bot { let turn_stage = game.turn_stage; game.consume(internal_event); if game.stage == Stage::Ended { - debug!("<<<< end {:?} BOT handle", self.color); + info!("<<<< end {:?} BOT handle", self.color); return None; } let active_player_id = if self.color == Color::Black { @@ -92,7 +91,7 @@ impl Bot { if active_player_id == self.player_id { let player_points = game.who_plays().map(|p| (p.points, p.holes)); if self.color == Color::Black { - debug!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); + info!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); } let internal_event = match game.turn_stage { TurnStage::MarkAdvPoints => Some(GameEvent::Mark { @@ -121,15 +120,15 @@ impl Bot { _ => None, }; return if self.color == Color::Black { - debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); - debug!("<<<< end {:?} BOT handle", self.color); + info!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); + info!("<<<< end {:?} BOT handle", self.color); internal_event.map(|evt| evt.get_mirror()) } else { - debug!("<<<< end {:?} BOT handle", self.color); + info!("<<<< end {:?} BOT handle", self.color); internal_event }; } - debug!("<<<< end {:?} BOT handle", self.color); + info!("<<<< end {:?} BOT handle", self.color); None } diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 34fb853..0248cc5 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,5 +1,4 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; -use log::info; use std::path::Path; use store::MoveRules; @@ -32,10 +31,9 @@ impl DqnStrategy { Self::default() } - pub fn new_with_model + std::fmt::Debug>(model_path: P) -> Self { + pub fn new_with_model>(model_path: P) -> Self { let mut strategy = Self::new(); - if let Ok(model) = SimpleNeuralNetwork::load(&model_path) { - info!("Loading model {model_path:?}"); + if let Ok(model) = SimpleNeuralNetwork::load(model_path) { strategy.model = Some(model); } strategy @@ -114,48 +112,50 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement - if let Some(TrictracAction::Move { - dice_order, - from1, - from2, - }) = self.get_dqn_action() - { - let dicevals = self.game.dice.values; - let (mut dice1, mut dice2) = if dice_order { - (dicevals.0, dicevals.1) - } else { - (dicevals.1, dicevals.0) - }; + if let Some(action) = self.get_dqn_action() { + if let TrictracAction::Move { + dice_order, + from1, + from2, + } = action + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; - if from1 == 0 { - // empty move - dice1 = 0; - } - let mut to1 = from1 + dice1 as usize; - if 24 < to1 { - // sortie - to1 = 0; - } - if from2 == 0 { - // empty move - dice2 = 0; - } - let mut to2 = from2 + dice2 as usize; - if 24 < to2 { - // sortie - to2 = 0; - } + if from1 == 0 { + // empty move + dice1 = 0; + } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } - let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); - let chosen_move = if self.color == Color::White { - (checker_move1, checker_move2) - } else { - (checker_move1.mirror(), checker_move2.mirror()) - }; + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; - return chosen_move; + return chosen_move; + } } // Fallback : utiliser la stratégie par défaut diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index 731d1b1..3812188 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -2,5 +2,4 @@ pub mod client; pub mod default; pub mod dqn; pub mod erroneous_moves; -pub mod random; pub mod stable_baselines3; diff --git a/bot/src/strategy/random.rs b/bot/src/strategy/random.rs deleted file mode 100644 index 0bfd1c6..0000000 --- a/bot/src/strategy/random.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; -use store::MoveRules; - -#[derive(Debug)] -pub struct RandomStrategy { - pub game: GameState, - pub player_id: PlayerId, - pub color: Color, -} - -impl Default for RandomStrategy { - fn default() -> Self { - let game = GameState::default(); - Self { - game, - player_id: 1, - color: Color::White, - } - } -} - -impl BotStrategy for RandomStrategy { - fn get_game(&self) -> &GameState { - &self.game - } - fn get_mut_game(&mut self) -> &mut GameState { - &mut self.game - } - - fn set_color(&mut self, color: Color) { - self.color = color; - } - - fn set_player_id(&mut self, player_id: PlayerId) { - self.player_id = player_id; - } - - fn calculate_points(&self) -> u8 { - self.game.dice_points.0 - } - - fn calculate_adv_points(&self) -> u8 { - self.game.dice_points.1 - } - - fn choose_go(&self) -> bool { - true - } - - fn choose_move(&self) -> (CheckerMove, CheckerMove) { - let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - use rand::{seq::SliceRandom, thread_rng}; - let mut rng = thread_rng(); - let choosen_move = possible_moves - .choose(&mut rng) - .cloned() - .unwrap_or((CheckerMove::default(), CheckerMove::default())); - - if self.color == Color::White { - choosen_move - } else { - (choosen_move.0.mirror(), choosen_move.1.mirror()) - } - } -} diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index 8fb1c9e..9b6ab3a 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,7 +1,4 @@ -use bot::{ - BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, - StableBaselines3Strategy, -}; +use bot::{BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, StableBaselines3Strategy}; use itertools::Itertools; use crate::game_runner::GameRunner; @@ -35,15 +32,13 @@ impl App { "dummy" => { Some(Box::new(DefaultStrategy::default()) as Box) } - "random" => { - Some(Box::new(RandomStrategy::default()) as Box) - } "erroneous" => { Some(Box::new(ErroneousStrategy::default()) as Box) } "ai" => Some(Box::new(StableBaselines3Strategy::default()) as Box), - "dqn" => Some(Box::new(DqnStrategy::default()) as Box), + "dqn" => Some(Box::new(DqnStrategy::default()) + as Box), s if s.starts_with("ai:") => { let path = s.trim_start_matches("ai:"); Some(Box::new(StableBaselines3Strategy::new(path)) diff --git a/client_cli/src/game_runner.rs b/client_cli/src/game_runner.rs index 797dbc9..296c907 100644 --- a/client_cli/src/game_runner.rs +++ b/client_cli/src/game_runner.rs @@ -1,5 +1,5 @@ use bot::{Bot, BotStrategy}; -use log::{debug, error}; +use log::{error, info}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; // Application Game @@ -63,19 +63,19 @@ impl GameRunner { return None; } let valid_event = if self.state.validate(event) { - debug!( + info!( "--------------- new valid event {event:?} (stage {:?}) -----------", self.state.turn_stage ); self.state.consume(event); - debug!( + info!( " --> stage {:?} ; active player points {:?}", self.state.turn_stage, self.state.who_plays().map(|p| p.points) ); event } else { - debug!("{}", self.state); + info!("{}", self.state); error!("event not valid : {event:?}"); panic!("crash and burn"); &GameEvent::PlayError diff --git a/justfile b/justfile index 0501ded..16f56ce 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,7 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy + RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli @@ -21,9 +21,6 @@ profile: pythonlib: maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl -trainsimple: - cargo build --release --bin=train_dqn - LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok diff --git a/store/src/board.rs b/store/src/board.rs index a838f10..3e563d0 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -37,7 +37,7 @@ impl Default for CheckerMove { impl CheckerMove { pub fn to_display_string(self) -> String { - format!("{self:?} ") + format!("{:?} ", self) } pub fn new(from: Field, to: Field) -> Result { @@ -569,7 +569,7 @@ impl Board { } let checker_color = self.get_checkers_color(field)?; if Some(color) != checker_color { - println!("field invalid : {color:?}, {field:?}, {self:?}"); + println!("field invalid : {:?}, {:?}, {:?}", color, field, self); return Err(Error::FieldInvalid); } let unit = match color { diff --git a/store/src/game.rs b/store/src/game.rs index 200c321..c9995b8 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -4,7 +4,7 @@ use crate::dice::Dice; use crate::game_rules_moves::MoveRules; use crate::game_rules_points::{PointsRules, PossibleJans}; use crate::player::{Color, Player, PlayerId}; -use log::{debug, error, info}; +use log::{error, info}; // use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -521,14 +521,14 @@ impl GameState { self.inc_roll_count(self.active_player_id); self.turn_stage = TurnStage::MarkPoints; (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); - debug!("points from result : {:?}", self.dice_points); + info!("points from result : {:?}", self.dice_points); if !self.schools_enabled { // Schools are not enabled. We mark points automatically // the points earned by the opponent will be marked on its turn let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); if new_hole { let holes_count = self.get_active_player().unwrap().holes; - debug!("new hole -> {holes_count:?}"); + info!("new hole -> {holes_count:?}"); if holes_count > 12 { self.stage = Stage::Ended; } else { @@ -606,7 +606,7 @@ impl GameState { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { let player = &self.players.get(&self.active_player_id).unwrap(); - debug!( + info!( "get rollresult for {:?} {:?} {:?} (roll count {:?})", player.color, self.board, dice, player.dice_roll_count ); @@ -654,7 +654,7 @@ impl GameState { // if points > 0 && p.holes > 15 { if points > 0 { - debug!( + info!( "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})", p.holes, p.points ) diff --git a/store/src/game_rules_points.rs b/store/src/game_rules_points.rs index c8ea334..ab67236 100644 --- a/store/src/game_rules_points.rs +++ b/store/src/game_rules_points.rs @@ -5,7 +5,7 @@ use crate::player::Color; use crate::CheckerMove; use crate::Error; -use log::debug; +use log::info; use serde::{Deserialize, Serialize}; use std::cmp; use std::collections::HashMap; @@ -384,7 +384,7 @@ impl PointsRules { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { let jans = self.get_jans(&self.board, dice_rolls_count); - debug!("jans : {jans:?}"); + info!("jans : {jans:?}"); let points_jans = jans.clone(); (jans, self.get_jans_points(points_jans)) }