Compare commits

...

2 commits

Author SHA1 Message Date
Henri Bourcereau 1b58ca4ccc refact dqn burn demo 2025-08-08 17:07:34 +02:00
Henri Bourcereau bf820ecc4e feat: bot random strategy 2025-08-08 16:24:40 +02:00
12 changed files with 189 additions and 109 deletions

View file

@ -1,9 +1,10 @@
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; use bot::dqn::burnrl::{
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; dqn_model, environment,
use burn::module::Module; utils::{demo_model, load_model, save_model},
use burn::record::{CompactRecorder, Recorder}; };
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN; use burn_rl::agent::DQN;
use burn_rl::base::{Action, Agent, ElemType, Environment, State}; use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>; type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment; type Env = environment::TrictracEnvironment;
@ -25,12 +26,9 @@ fn main() {
println!("> Sauvegarde du modèle de validation"); println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_50".to_string(); let path = "models/burn_dqn_40".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path); save_model(valid_agent.model().as_ref().unwrap(), &path);
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test"); println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path); let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model); let loaded_agent = DQN::new(loaded_model);
@ -38,31 +36,3 @@ fn main() {
println!("> Test avec le modèle chargé"); println!("> Test avec le modèle chargé");
demo_model(loaded_agent); demo_model(loaded_agent);
} }
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path);
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path);
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<environment::TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<environment::TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
}

View file

@ -1,12 +1,45 @@
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; use crate::dqn::burnrl::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices; use crate::dqn::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId}; use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear; use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend; use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement; use burn::tensor::cast::ToElement;
use burn::tensor::Tensor; use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN}; use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State}; use burn_rl::base::{Action, ElemType, Environment, State};
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}_model.mpk");
println!("Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
let model_path = format!("{path}_model.mpk");
println!("Chargement du modèle depuis : {model_path}");
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
}
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) { pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true); let mut env = TrictracEnvironment::new(true);

View file

@ -1,11 +1,12 @@
pub mod dqn; pub mod dqn;
pub mod strategy; pub mod strategy;
use log::{error, info}; use log::{debug, error};
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy; pub use strategy::default::DefaultStrategy;
pub use strategy::dqn::DqnStrategy; pub use strategy::dqn::DqnStrategy;
pub use strategy::erroneous_moves::ErroneousStrategy; pub use strategy::erroneous_moves::ErroneousStrategy;
pub use strategy::random::RandomStrategy;
pub use strategy::stable_baselines3::StableBaselines3Strategy; pub use strategy::stable_baselines3::StableBaselines3Strategy;
pub trait BotStrategy: std::fmt::Debug { pub trait BotStrategy: std::fmt::Debug {
@ -64,7 +65,7 @@ impl Bot {
} }
pub fn handle_event(&mut self, event: &GameEvent) -> Option<GameEvent> { pub fn handle_event(&mut self, event: &GameEvent) -> Option<GameEvent> {
info!(">>>> {:?} BOT handle", self.color); debug!(">>>> {:?} BOT handle", self.color);
let game = self.strategy.get_mut_game(); let game = self.strategy.get_mut_game();
let internal_event = if self.color == Color::Black { let internal_event = if self.color == Color::Black {
&event.get_mirror() &event.get_mirror()
@ -76,7 +77,7 @@ impl Bot {
let turn_stage = game.turn_stage; let turn_stage = game.turn_stage;
game.consume(internal_event); game.consume(internal_event);
if game.stage == Stage::Ended { if game.stage == Stage::Ended {
info!("<<<< end {:?} BOT handle", self.color); debug!("<<<< end {:?} BOT handle", self.color);
return None; return None;
} }
let active_player_id = if self.color == Color::Black { let active_player_id = if self.color == Color::Black {
@ -91,7 +92,7 @@ impl Bot {
if active_player_id == self.player_id { if active_player_id == self.player_id {
let player_points = game.who_plays().map(|p| (p.points, p.holes)); let player_points = game.who_plays().map(|p| (p.points, p.holes));
if self.color == Color::Black { if self.color == Color::Black {
info!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}"); debug!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}");
} }
let internal_event = match game.turn_stage { let internal_event = match game.turn_stage {
TurnStage::MarkAdvPoints => Some(GameEvent::Mark { TurnStage::MarkAdvPoints => Some(GameEvent::Mark {
@ -120,15 +121,15 @@ impl Bot {
_ => None, _ => None,
}; };
return if self.color == Color::Black { return if self.color == Color::Black {
info!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}");
info!("<<<< end {:?} BOT handle", self.color); debug!("<<<< end {:?} BOT handle", self.color);
internal_event.map(|evt| evt.get_mirror()) internal_event.map(|evt| evt.get_mirror())
} else { } else {
info!("<<<< end {:?} BOT handle", self.color); debug!("<<<< end {:?} BOT handle", self.color);
internal_event internal_event
}; };
} }
info!("<<<< end {:?} BOT handle", self.color); debug!("<<<< end {:?} BOT handle", self.color);
None None
} }

View file

@ -1,4 +1,5 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use log::info;
use std::path::Path; use std::path::Path;
use store::MoveRules; use store::MoveRules;
@ -31,9 +32,10 @@ impl DqnStrategy {
Self::default() Self::default()
} }
pub fn new_with_model<P: AsRef<Path>>(model_path: P) -> Self { pub fn new_with_model<P: AsRef<Path> + std::fmt::Debug>(model_path: P) -> Self {
let mut strategy = Self::new(); let mut strategy = Self::new();
if let Ok(model) = SimpleNeuralNetwork::load(model_path) { if let Ok(model) = SimpleNeuralNetwork::load(&model_path) {
info!("Loading model {model_path:?}");
strategy.model = Some(model); strategy.model = Some(model);
} }
strategy strategy
@ -112,50 +114,48 @@ impl BotStrategy for DqnStrategy {
fn choose_move(&self) -> (CheckerMove, CheckerMove) { fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser le DQN pour choisir le mouvement // Utiliser le DQN pour choisir le mouvement
if let Some(action) = self.get_dqn_action() { if let Some(TrictracAction::Move {
if let TrictracAction::Move { dice_order,
dice_order, from1,
from1, from2,
from2, }) = self.get_dqn_action()
} = action {
{ let dicevals = self.game.dice.values;
let dicevals = self.game.dice.values; let (mut dice1, mut dice2) = if dice_order {
let (mut dice1, mut dice2) = if dice_order { (dicevals.0, dicevals.1)
(dicevals.0, dicevals.1) } else {
} else { (dicevals.1, dicevals.0)
(dicevals.1, dicevals.0) };
};
if from1 == 0 { if from1 == 0 {
// empty move // empty move
dice1 = 0; dice1 = 0;
}
let mut to1 = from1 + dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
}
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
(checker_move1.mirror(), checker_move2.mirror())
};
return chosen_move;
} }
let mut to1 = from1 + dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
}
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
(checker_move1.mirror(), checker_move2.mirror())
};
return chosen_move;
} }
// Fallback : utiliser la stratégie par défaut // Fallback : utiliser la stratégie par défaut

View file

@ -2,4 +2,5 @@ pub mod client;
pub mod default; pub mod default;
pub mod dqn; pub mod dqn;
pub mod erroneous_moves; pub mod erroneous_moves;
pub mod random;
pub mod stable_baselines3; pub mod stable_baselines3;

View file

@ -0,0 +1,67 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use store::MoveRules;
#[derive(Debug)]
pub struct RandomStrategy {
pub game: GameState,
pub player_id: PlayerId,
pub color: Color,
}
impl Default for RandomStrategy {
fn default() -> Self {
let game = GameState::default();
Self {
game,
player_id: 1,
color: Color::White,
}
}
}
impl BotStrategy for RandomStrategy {
fn get_game(&self) -> &GameState {
&self.game
}
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
fn set_color(&mut self, color: Color) {
self.color = color;
}
fn set_player_id(&mut self, player_id: PlayerId) {
self.player_id = player_id;
}
fn calculate_points(&self) -> u8 {
self.game.dice_points.0
}
fn calculate_adv_points(&self) -> u8 {
self.game.dice_points.1
}
fn choose_go(&self) -> bool {
true
}
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
use rand::{seq::SliceRandom, thread_rng};
let mut rng = thread_rng();
let choosen_move = possible_moves
.choose(&mut rng)
.cloned()
.unwrap_or((CheckerMove::default(), CheckerMove::default()));
if self.color == Color::White {
choosen_move
} else {
(choosen_move.0.mirror(), choosen_move.1.mirror())
}
}
}

View file

@ -1,4 +1,7 @@
use bot::{BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, StableBaselines3Strategy}; use bot::{
BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
StableBaselines3Strategy,
};
use itertools::Itertools; use itertools::Itertools;
use crate::game_runner::GameRunner; use crate::game_runner::GameRunner;
@ -32,13 +35,15 @@ impl App {
"dummy" => { "dummy" => {
Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>) Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>)
} }
"random" => {
Some(Box::new(RandomStrategy::default()) as Box<dyn BotStrategy>)
}
"erroneous" => { "erroneous" => {
Some(Box::new(ErroneousStrategy::default()) as Box<dyn BotStrategy>) Some(Box::new(ErroneousStrategy::default()) as Box<dyn BotStrategy>)
} }
"ai" => Some(Box::new(StableBaselines3Strategy::default()) "ai" => Some(Box::new(StableBaselines3Strategy::default())
as Box<dyn BotStrategy>), as Box<dyn BotStrategy>),
"dqn" => Some(Box::new(DqnStrategy::default()) "dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
as Box<dyn BotStrategy>),
s if s.starts_with("ai:") => { s if s.starts_with("ai:") => {
let path = s.trim_start_matches("ai:"); let path = s.trim_start_matches("ai:");
Some(Box::new(StableBaselines3Strategy::new(path)) Some(Box::new(StableBaselines3Strategy::new(path))

View file

@ -1,5 +1,5 @@
use bot::{Bot, BotStrategy}; use bot::{Bot, BotStrategy};
use log::{error, info}; use log::{debug, error};
use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage};
// Application Game // Application Game
@ -63,19 +63,19 @@ impl GameRunner {
return None; return None;
} }
let valid_event = if self.state.validate(event) { let valid_event = if self.state.validate(event) {
info!( debug!(
"--------------- new valid event {event:?} (stage {:?}) -----------", "--------------- new valid event {event:?} (stage {:?}) -----------",
self.state.turn_stage self.state.turn_stage
); );
self.state.consume(event); self.state.consume(event);
info!( debug!(
" --> stage {:?} ; active player points {:?}", " --> stage {:?} ; active player points {:?}",
self.state.turn_stage, self.state.turn_stage,
self.state.who_plays().map(|p| p.points) self.state.who_plays().map(|p| p.points)
); );
event event
} else { } else {
info!("{}", self.state); debug!("{}", self.state);
error!("event not valid : {event:?}"); error!("event not valid : {event:?}");
panic!("crash and burn"); panic!("crash and burn");
&GameEvent::PlayError &GameEvent::PlayError

View file

@ -9,7 +9,7 @@ shell:
runcli: runcli:
RUST_LOG=info cargo run --bin=client_cli RUST_LOG=info cargo run --bin=client_cli
runclibots: runclibots:
RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
match: match:
cargo build --release --bin=client_cli cargo build --release --bin=client_cli
@ -21,6 +21,9 @@ profile:
pythonlib: pythonlib:
maturin build -m store/Cargo.toml --release maturin build -m store/Cargo.toml --release
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainsimple:
cargo build --release --bin=train_dqn
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out
trainbot: trainbot:
#python ./store/python/trainModel.py #python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok # cargo run --bin=train_dqn # ok

View file

@ -37,7 +37,7 @@ impl Default for CheckerMove {
impl CheckerMove { impl CheckerMove {
pub fn to_display_string(self) -> String { pub fn to_display_string(self) -> String {
format!("{:?} ", self) format!("{self:?} ")
} }
pub fn new(from: Field, to: Field) -> Result<Self, Error> { pub fn new(from: Field, to: Field) -> Result<Self, Error> {
@ -569,7 +569,7 @@ impl Board {
} }
let checker_color = self.get_checkers_color(field)?; let checker_color = self.get_checkers_color(field)?;
if Some(color) != checker_color { if Some(color) != checker_color {
println!("field invalid : {:?}, {:?}, {:?}", color, field, self); println!("field invalid : {color:?}, {field:?}, {self:?}");
return Err(Error::FieldInvalid); return Err(Error::FieldInvalid);
} }
let unit = match color { let unit = match color {

View file

@ -4,7 +4,7 @@ use crate::dice::Dice;
use crate::game_rules_moves::MoveRules; use crate::game_rules_moves::MoveRules;
use crate::game_rules_points::{PointsRules, PossibleJans}; use crate::game_rules_points::{PointsRules, PossibleJans};
use crate::player::{Color, Player, PlayerId}; use crate::player::{Color, Player, PlayerId};
use log::{error, info}; use log::{debug, error, info};
// use itertools::Itertools; // use itertools::Itertools;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -521,14 +521,14 @@ impl GameState {
self.inc_roll_count(self.active_player_id); self.inc_roll_count(self.active_player_id);
self.turn_stage = TurnStage::MarkPoints; self.turn_stage = TurnStage::MarkPoints;
(self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice);
info!("points from result : {:?}", self.dice_points); debug!("points from result : {:?}", self.dice_points);
if !self.schools_enabled { if !self.schools_enabled {
// Schools are not enabled. We mark points automatically // Schools are not enabled. We mark points automatically
// the points earned by the opponent will be marked on its turn // the points earned by the opponent will be marked on its turn
let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); let new_hole = self.mark_points(self.active_player_id, self.dice_points.0);
if new_hole { if new_hole {
let holes_count = self.get_active_player().unwrap().holes; let holes_count = self.get_active_player().unwrap().holes;
info!("new hole -> {holes_count:?}"); debug!("new hole -> {holes_count:?}");
if holes_count > 12 { if holes_count > 12 {
self.stage = Stage::Ended; self.stage = Stage::Ended;
} else { } else {
@ -606,7 +606,7 @@ impl GameState {
fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) {
let player = &self.players.get(&self.active_player_id).unwrap(); let player = &self.players.get(&self.active_player_id).unwrap();
info!( debug!(
"get rollresult for {:?} {:?} {:?} (roll count {:?})", "get rollresult for {:?} {:?} {:?} (roll count {:?})",
player.color, self.board, dice, player.dice_roll_count player.color, self.board, dice, player.dice_roll_count
); );
@ -654,7 +654,7 @@ impl GameState {
// if points > 0 && p.holes > 15 { // if points > 0 && p.holes > 15 {
if points > 0 { if points > 0 {
info!( debug!(
"player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})", "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})",
p.holes, p.points p.holes, p.points
) )

View file

@ -5,7 +5,7 @@ use crate::player::Color;
use crate::CheckerMove; use crate::CheckerMove;
use crate::Error; use crate::Error;
use log::info; use log::debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp; use std::cmp;
use std::collections::HashMap; use std::collections::HashMap;
@ -384,7 +384,7 @@ impl PointsRules {
pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) {
let jans = self.get_jans(&self.board, dice_rolls_count); let jans = self.get_jans(&self.board, dice_rolls_count);
info!("jans : {jans:?}"); debug!("jans : {jans:?}");
let points_jans = jans.clone(); let points_jans = jans.clone();
(jans, self.get_jans_points(points_jans)) (jans, self.get_jans_points(points_jans))
} }