Compare commits

..

No commits in common. "042999967290eb8add6fd85ced44f347f1c7491a" and "4920ab96f843615340d39af7f3944339a39d5eb6" have entirely different histories.

30 changed files with 1410 additions and 2609 deletions

3379
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
[package] [package]
name = "trictrac-bot" name = "bot"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -13,10 +13,10 @@ path = "src/burnrl/main.rs"
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
trictrac-store = { path = "../store" } store = { path = "../store" }
rand = "0.9" rand = "0.8"
env_logger = "0.10" env_logger = "0.10"
burn = { version = "0.20", features = ["ndarray", "autodiff"] } burn = { version = "0.18", features = ["ndarray", "autodiff"] }
burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" } burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
log = "0.4.20" log = "0.4.20"
confy = "1.0.0" confy = "1.0.0"

View file

@ -1,5 +0,0 @@
import trictrac_store
game = trictrac_store.TricTrac()
print(game.current_player_idx())
print(game.get_legal_actions(game.current_player_idx()))

View file

@ -1,10 +1,10 @@
use std::io::Write; use std::io::Write;
use crate::training_common;
use burn::{prelude::Backend, tensor::Tensor}; use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State}; use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{rng, Rng}; use rand::{thread_rng, Rng};
use trictrac_store::training_common; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
use trictrac_store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -1.0012121; const ERROR_REWARD: f32 = -1.0012121;
const REWARD_VALID_MOVE: f32 = 1.0012121; const REWARD_VALID_MOVE: f32 = 1.0012121;
@ -52,10 +52,10 @@ pub struct TrictracAction {
impl Action for TrictracAction { impl Action for TrictracAction {
fn random() -> Self { fn random() -> Self {
use rand::{rng, Rng}; use rand::{thread_rng, Rng};
let mut rng = rng(); let mut rng = thread_rng();
TrictracAction { TrictracAction {
index: rng.random_range(0..Self::size() as u32), index: rng.gen_range(0..Self::size() as u32),
} }
} }
@ -288,11 +288,11 @@ impl TrictracEnvironment {
// reward += REWARD_VALID_MOVE; // reward += REWARD_VALID_MOVE;
// Simuler le résultat des dés après un Roll // Simuler le résultat des dés après un Roll
if matches!(action, TrictracAction::Roll) { if matches!(action, TrictracAction::Roll) {
let mut rng = rng(); let mut rng = thread_rng();
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6)); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult { let dice_event = GameEvent::RollResult {
player_id: self.active_player_id, player_id: self.active_player_id,
dice: trictrac_store::Dice { dice: store::Dice {
values: dice_values, values: dice_values,
}, },
}; };
@ -340,18 +340,18 @@ impl TrictracEnvironment {
// Exécuter l'action selon le turn_stage // Exécuter l'action selon le turn_stage
let mut calculate_points = false; let mut calculate_points = false;
let opponent_color = trictrac_store::Color::Black; let opponent_color = store::Color::Black;
let event = match self.game.turn_stage { let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll { TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id, player_id: self.opponent_id,
}, },
TurnStage::RollWaiting => { TurnStage::RollWaiting => {
let mut rng = rng(); let mut rng = thread_rng();
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6)); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true; calculate_points = true;
GameEvent::RollResult { GameEvent::RollResult {
player_id: self.opponent_id, player_id: self.opponent_id,
dice: trictrac_store::Dice { dice: store::Dice {
values: dice_values, values: dice_values,
}, },
} }
@ -371,7 +371,7 @@ impl TrictracEnvironment {
} }
} }
TurnStage::MarkAdvPoints => { TurnStage::MarkAdvPoints => {
let opponent_color = trictrac_store::Color::Black; let opponent_color = store::Color::Black;
let dice_roll_count = self let dice_roll_count = self
.game .game
.players .players

View file

@ -1,8 +1,8 @@
use crate::training_common;
use burn::{prelude::Backend, tensor::Tensor}; use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State}; use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{rng, Rng}; use rand::{thread_rng, Rng};
use trictrac_store::training_common; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
use trictrac_store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -1.0012121; const ERROR_REWARD: f32 = -1.0012121;
const REWARD_RATIO: f32 = 0.1; const REWARD_RATIO: f32 = 0.1;
@ -48,10 +48,10 @@ pub struct TrictracAction {
impl Action for TrictracAction { impl Action for TrictracAction {
fn random() -> Self { fn random() -> Self {
use rand::{rng, Rng}; use rand::{thread_rng, Rng};
let mut rng = rng(); let mut rng = thread_rng();
TrictracAction { TrictracAction {
index: rng.random_range(0..Self::size() as u32), index: rng.gen_range(0..Self::size() as u32),
} }
} }
@ -258,11 +258,11 @@ impl TrictracEnvironment {
// reward += REWARD_VALID_MOVE; // reward += REWARD_VALID_MOVE;
// Simuler le résultat des dés après un Roll // Simuler le résultat des dés après un Roll
if matches!(action, TrictracAction::Roll) { if matches!(action, TrictracAction::Roll) {
let mut rng = rng(); let mut rng = thread_rng();
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6)); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult { let dice_event = GameEvent::RollResult {
player_id: self.active_player_id, player_id: self.active_player_id,
dice: trictrac_store::Dice { dice: store::Dice {
values: dice_values, values: dice_values,
}, },
}; };
@ -310,18 +310,18 @@ impl TrictracEnvironment {
// Exécuter l'action selon le turn_stage // Exécuter l'action selon le turn_stage
let mut calculate_points = false; let mut calculate_points = false;
let opponent_color = trictrac_store::Color::Black; let opponent_color = store::Color::Black;
let event = match self.game.turn_stage { let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll { TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id, player_id: self.opponent_id,
}, },
TurnStage::RollWaiting => { TurnStage::RollWaiting => {
let mut rng = rng(); let mut rng = thread_rng();
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6)); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true; calculate_points = true;
GameEvent::RollResult { GameEvent::RollResult {
player_id: self.opponent_id, player_id: self.opponent_id,
dice: trictrac_store::Dice { dice: store::Dice {
values: dice_values, values: dice_values,
}, },
} }

View file

@ -1,7 +1,7 @@
use trictrac_bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid}; use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid};
use trictrac_bot::burnrl::environment::TrictracEnvironment; use bot::burnrl::environment::TrictracEnvironment;
use trictrac_bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid; use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
use trictrac_bot::burnrl::utils::{demo_model, Config}; use bot::burnrl::utils::{demo_model, Config};
use burn::backend::{Autodiff, NdArray}; use burn::backend::{Autodiff, NdArray};
use burn_rl::base::ElemType; use burn_rl::base::ElemType;
use std::env; use std::env;

View file

@ -1,16 +1,15 @@
pub mod burnrl; pub mod burnrl;
pub mod strategy; pub mod strategy;
pub mod training_common;
pub mod trictrac_board; pub mod trictrac_board;
use log::debug; use log::debug;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy; pub use strategy::default::DefaultStrategy;
pub use strategy::dqnburn::DqnBurnStrategy; pub use strategy::dqnburn::DqnBurnStrategy;
pub use strategy::erroneous_moves::ErroneousStrategy; pub use strategy::erroneous_moves::ErroneousStrategy;
pub use strategy::random::RandomStrategy; pub use strategy::random::RandomStrategy;
pub use strategy::stable_baselines3::StableBaselines3Strategy; pub use strategy::stable_baselines3::StableBaselines3Strategy;
use trictrac_store::{
CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage,
};
pub trait BotStrategy: std::fmt::Debug { pub trait BotStrategy: std::fmt::Debug {
fn get_game(&self) -> &GameState; fn get_game(&self) -> &GameState;
@ -71,7 +70,7 @@ impl Bot {
debug!(">>>> {:?} BOT handle", self.color); debug!(">>>> {:?} BOT handle", self.color);
let game = self.strategy.get_mut_game(); let game = self.strategy.get_mut_game();
let internal_event = if self.color == Color::Black { let internal_event = if self.color == Color::Black {
&event.get_mirror(false) &event.get_mirror()
} else { } else {
event event
}; };
@ -126,7 +125,7 @@ impl Bot {
return if self.color == Color::Black { return if self.color == Color::Black {
debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}"); debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}");
debug!("<<<< end {:?} BOT handle", self.color); debug!("<<<< end {:?} BOT handle", self.color);
internal_event.map(|evt| evt.get_mirror(false)) internal_event.map(|evt| evt.get_mirror())
} else { } else {
debug!("<<<< end {:?} BOT handle", self.color); debug!("<<<< end {:?} BOT handle", self.color);
internal_event internal_event
@ -145,7 +144,7 @@ impl Bot {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use trictrac_store::{Dice, Stage}; use store::{Dice, Stage};
#[test] #[test]
fn test_new() { fn test_new() {

View file

@ -1,5 +1,5 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use trictrac_store::MoveRules; use store::MoveRules;
#[derive(Debug)] #[derive(Debug)]
pub struct DefaultStrategy { pub struct DefaultStrategy {

View file

@ -4,13 +4,11 @@ use burn_rl::base::{ElemType, Model, State};
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use log::info; use log::info;
use trictrac_store::MoveRules; use store::MoveRules;
use crate::burnrl::algos::dqn; use crate::burnrl::algos::dqn;
use crate::burnrl::environment; use crate::burnrl::environment;
use trictrac_store::training_common::{ use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
get_valid_action_indices, sample_valid_action, TrictracAction,
};
type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>; type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>;
@ -154,7 +152,7 @@ impl BotStrategy for DqnBurnStrategy {
to1 = if fto1 < 0 { 0 } else { fto1 as usize }; to1 = if fto1 < 0 { 0 } else { fto1 as usize };
} }
let checker_move1 = trictrac_store::CheckerMove::new(from1, to1).unwrap_or_default(); let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = self.game.board.clone(); let mut tmp_board = self.game.board.clone();
let move_res = tmp_board.move_checker(&self.color, checker_move1); let move_res = tmp_board.move_checker(&self.color, checker_move1);

View file

@ -1,6 +1,5 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use rand::{prelude::IndexedRandom, rng}; use store::MoveRules;
use trictrac_store::MoveRules;
#[derive(Debug)] #[derive(Debug)]
pub struct RandomStrategy { pub struct RandomStrategy {
@ -52,7 +51,8 @@ impl BotStrategy for RandomStrategy {
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]); let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let mut rng = rng(); use rand::{seq::SliceRandom, thread_rng};
let mut rng = thread_rng();
let choosen_move = possible_moves let choosen_move = possible_moves
.choose(&mut rng) .choose(&mut rng)
.cloned() .cloned()

View file

@ -5,7 +5,7 @@ use std::io::Read;
use std::io::Write; use std::io::Write;
use std::path::Path; use std::path::Path;
use std::process::Command; use std::process::Command;
use trictrac_store::MoveRules; use store::MoveRules;
#[derive(Debug)] #[derive(Debug)]
pub struct StableBaselines3Strategy { pub struct StableBaselines3Strategy {
@ -79,12 +79,12 @@ impl StableBaselines3Strategy {
// Convertir l'étape du tour en entier // Convertir l'étape du tour en entier
let turn_stage = match self.game.turn_stage { let turn_stage = match self.game.turn_stage {
trictrac_store::TurnStage::RollDice => 0, store::TurnStage::RollDice => 0,
trictrac_store::TurnStage::RollWaiting => 1, store::TurnStage::RollWaiting => 1,
trictrac_store::TurnStage::MarkPoints => 2, store::TurnStage::MarkPoints => 2,
trictrac_store::TurnStage::HoldOrGoChoice => 3, store::TurnStage::HoldOrGoChoice => 3,
trictrac_store::TurnStage::Move => 4, store::TurnStage::Move => 4,
trictrac_store::TurnStage::MarkAdvPoints => 5, store::TurnStage::MarkAdvPoints => 5,
}; };
// Récupérer les points et trous des joueurs // Récupérer les points et trous des joueurs

View file

@ -3,9 +3,8 @@
use std::cmp::{max, min}; use std::cmp::{max, min};
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use crate::board::Board;
use crate::{CheckerMove, Dice, GameEvent, GameState};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use store::{CheckerMove, GameEvent, GameState};
// 1 (Roll) + 1 (Go) + 512 (mouvements possibles) // 1 (Roll) + 1 (Go) + 512 (mouvements possibles)
// avec 512 = 2 (choix du dé) * 16 * 16 (choix de la dame 0-15 pour chaque from) // avec 512 = 2 (choix du dé) * 16 * 16 (choix de la dame 0-15 pour chaque from)
@ -16,8 +15,7 @@ pub const ACTION_SPACE_SIZE: usize = 514;
pub enum TrictracAction { pub enum TrictracAction {
/// Lancer les dés /// Lancer les dés
Roll, Roll,
/// Faire un nouveau 'relevé' (repositionnement des dames à l'état de départ) après avoir gagné un trou, /// Continuer après avoir gagné un trou
/// au lieu de continuer dans la position courante
Go, Go,
/// Effectuer un mouvement de pions /// Effectuer un mouvement de pions
Move { Move {
@ -61,22 +59,6 @@ impl TrictracAction {
} }
} }
pub fn mirror(&self) -> TrictracAction {
match self {
TrictracAction::Roll => TrictracAction::Roll,
TrictracAction::Go => TrictracAction::Go,
TrictracAction::Move {
dice_order,
checker1,
checker2,
} => TrictracAction::Move {
dice_order: *dice_order,
checker1: if *checker1 == 0 { 0 } else { 25 - checker1 },
checker2: if *checker2 == 0 { 0 } else { 25 - checker2 },
},
}
}
pub fn to_event(&self, state: &GameState) -> Option<GameEvent> { pub fn to_event(&self, state: &GameState) -> Option<GameEvent> {
match self { match self {
TrictracAction::Roll => { TrictracAction::Roll => {
@ -111,17 +93,13 @@ impl TrictracAction {
(state.dice.values.1, state.dice.values.0) (state.dice.values.1, state.dice.values.0)
}; };
let color = &crate::Color::White; let color = &store::Color::White;
let from1 = state let from1 = state
.board .board
.get_checker_field(color, *checker1 as u8) .get_checker_field(color, *checker1 as u8)
.unwrap_or(0); .unwrap_or(0);
let mut to1 = from1 + dice1 as usize; let mut to1 = from1 + dice1 as usize;
if 24 < to1 { let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
// exit board
to1 = 0;
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = state.board.clone(); let mut tmp_board = state.board.clone();
let move_result = tmp_board.move_checker(color, checker_move1); let move_result = tmp_board.move_checker(color, checker_move1);
@ -133,10 +111,6 @@ impl TrictracAction {
.get_checker_field(color, *checker2 as u8) .get_checker_field(color, *checker2 as u8)
.unwrap_or(0); .unwrap_or(0);
let mut to2 = from2 + dice2 as usize; let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// exit board
to2 = 0;
}
// Gestion prise de coin par puissance // Gestion prise de coin par puissance
let opp_rest_field = 13; let opp_rest_field = 13;
@ -145,8 +119,8 @@ impl TrictracAction {
to2 -= 1; to2 -= 1;
} }
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
Some(GameEvent::Move { Some(GameEvent::Move {
player_id: state.active_player_id, player_id: state.active_player_id,
@ -192,11 +166,33 @@ impl TrictracAction {
pub fn action_space_size() -> usize { pub fn action_space_size() -> usize {
ACTION_SPACE_SIZE ACTION_SPACE_SIZE
} }
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
// match action {
// TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
// TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
// TrictracAction::Go => Some(GameEvent::Go { player_id }),
// TrictracAction::Move {
// dice_order,
// from1,
// from2,
// } => {
// // Effectuer un mouvement
// let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
// let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
//
// Some(GameEvent::Move {
// player_id: self.agent_player_id,
// moves: (checker_move1, checker_move2),
// })
// }
// };
// }
} }
/// Obtient les actions valides pour l'état de jeu actuel /// Obtient les actions valides pour l'état de jeu actuel
pub fn get_valid_actions(game_state: &GameState) -> Vec<TrictracAction> { pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
use crate::TurnStage; use store::TurnStage;
let mut valid_actions = Vec::new(); let mut valid_actions = Vec::new();
@ -219,9 +215,11 @@ pub fn get_valid_actions(game_state: &GameState) -> Vec<TrictracAction> {
valid_actions.push(TrictracAction::Go); valid_actions.push(TrictracAction::Go);
// Ajoute aussi les mouvements possibles // Ajoute aussi les mouvements possibles
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice); let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]); let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
assert_eq!(color, store::Color::White);
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state, &move1, &move2, &color, game_state,
@ -229,13 +227,15 @@ pub fn get_valid_actions(game_state: &GameState) -> Vec<TrictracAction> {
} }
} }
TurnStage::Move => { TurnStage::Move => {
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice); let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]); let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if possible_moves.is_empty() { if possible_moves.is_empty() {
// Empty move // Empty move
possible_moves.push((CheckerMove::default(), CheckerMove::default())); possible_moves.push((CheckerMove::default(), CheckerMove::default()));
} }
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
assert_eq!(color, store::Color::White);
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state, &move1, &move2, &color, game_state,
@ -251,40 +251,18 @@ pub fn get_valid_actions(game_state: &GameState) -> Vec<TrictracAction> {
valid_actions valid_actions
} }
// Valid only for White player
fn checker_moves_to_trictrac_action( fn checker_moves_to_trictrac_action(
move1: &CheckerMove, move1: &CheckerMove,
move2: &CheckerMove, move2: &CheckerMove,
color: &crate::Color, color: &store::Color,
state: &GameState, state: &crate::GameState,
) -> TrictracAction {
let dice = &state.dice;
let board = &state.board;
if color == &crate::Color::Black {
white_checker_moves_to_trictrac_action(
move1,
move2,
// &move1.clone().mirror(),
// &move2.clone().mirror(),
dice,
&board.clone().mirror(),
)
.mirror()
} else {
white_checker_moves_to_trictrac_action(move1, move2, dice, board)
}
}
fn white_checker_moves_to_trictrac_action(
move1: &CheckerMove,
move2: &CheckerMove,
dice: &Dice,
board: &Board,
) -> TrictracAction { ) -> TrictracAction {
let to1 = move1.get_to(); let to1 = move1.get_to();
let to2 = move2.get_to(); let to2 = move2.get_to();
let from1 = move1.get_from(); let from1 = move1.get_from();
let from2 = move2.get_from(); let from2 = move2.get_from();
let dice = state.dice;
let mut diff_move1 = if to1 > 0 { let mut diff_move1 = if to1 > 0 {
// Mouvement sans sortie // Mouvement sans sortie
@ -320,14 +298,14 @@ fn white_checker_moves_to_trictrac_action(
} }
let dice_order = diff_move1 == dice.values.0 as usize; let dice_order = diff_move1 == dice.values.0 as usize;
let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize; let checker1 = state.board.get_field_checker(color, from1) as usize;
let mut tmp_board = board.clone(); let mut tmp_board = state.board.clone();
// should not raise an error for a valid action // should not raise an error for a valid action
let move_res = tmp_board.move_checker(&crate::Color::White, *move1); let move_res = tmp_board.move_checker(color, *move1);
if move_res.is_err() { if move_res.is_err() {
panic!("error while moving checker {move_res:?}"); panic!("error while moving checker {move_res:?}");
} }
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; let checker2 = tmp_board.get_field_checker(color, from2) as usize;
TrictracAction::Move { TrictracAction::Move {
dice_order, dice_order,
checker1, checker1,
@ -336,7 +314,7 @@ fn white_checker_moves_to_trictrac_action(
} }
/// Retourne les indices des actions valides /// Retourne les indices des actions valides
pub fn get_valid_action_indices(game_state: &GameState) -> Vec<usize> { pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
get_valid_actions(game_state) get_valid_actions(game_state)
.into_iter() .into_iter()
.map(|action| action.to_action_index()) .map(|action| action.to_action_index())
@ -344,11 +322,11 @@ pub fn get_valid_action_indices(game_state: &GameState) -> Vec<usize> {
} }
/// Sélectionne une action valide aléatoire /// Sélectionne une action valide aléatoire
pub fn sample_valid_action(game_state: &GameState) -> Option<TrictracAction> { pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
use rand::{prelude::IndexedRandom, rng}; use rand::{seq::SliceRandom, thread_rng};
let valid_actions = get_valid_actions(game_state); let valid_actions = get_valid_actions(game_state);
let mut rng = rng(); let mut rng = thread_rng();
valid_actions.choose(&mut rng).cloned() valid_actions.choose(&mut rng).cloned()
} }

View file

@ -1,4 +1,5 @@
// https://docs.rs/board-game/ implementation // https://docs.rs/board-game/ implementation
use crate::training_common::{get_valid_actions, TrictracAction};
use board_game::board::{ use board_game::board::{
Board as BoardGameBoard, BoardDone, BoardMoves, Outcome, PlayError, Player as BoardGamePlayer, Board as BoardGameBoard, BoardDone, BoardMoves, Outcome, PlayError, Player as BoardGamePlayer,
}; };
@ -7,8 +8,7 @@ use internal_iterator::InternalIterator;
use std::fmt; use std::fmt;
use std::hash::Hash; use std::hash::Hash;
use std::ops::ControlFlow; use std::ops::ControlFlow;
use trictrac_store::training_common::{get_valid_actions, TrictracAction}; use store::Color;
use trictrac_store::Color;
#[derive(Clone, Debug, Eq, PartialEq, Hash)] #[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct TrictracBoard(crate::GameState); pub struct TrictracBoard(crate::GameState);

View file

@ -1,5 +1,5 @@
[package] [package]
name = "trictrac-client_cli" name = "client_cli"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -11,8 +11,8 @@ bincode = "1.3.3"
pico-args = "0.5.0" pico-args = "0.5.0"
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
renet = "0.0.13" renet = "0.0.13"
trictrac-store = { path = "../store" } store = { path = "../store" }
trictrac-bot = { path = "../bot" } bot = { path = "../bot" }
itertools = "0.13.0" itertools = "0.13.0"
env_logger = "0.11.6" env_logger = "0.11.6"
log = "0.4.20" log = "0.4.20"

View file

@ -1,11 +1,11 @@
use trictrac_bot::{ use bot::{
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy, BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
StableBaselines3Strategy, StableBaselines3Strategy,
}; };
use itertools::Itertools; use itertools::Itertools;
use crate::game_runner::GameRunner; use crate::game_runner::GameRunner;
use trictrac_store::{CheckerMove, GameEvent, GameState, Stage, TurnStage}; use store::{CheckerMove, GameEvent, GameState, Stage, TurnStage};
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct AppArgs { pub struct AppArgs {

View file

@ -1,6 +1,6 @@
use trictrac_bot::{Bot, BotStrategy}; use bot::{Bot, BotStrategy};
use log::{debug, error}; use log::{debug, error};
use trictrac_store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage};
// Application Game // Application Game
#[derive(Debug, Default)] #[derive(Debug, Default)]
@ -117,8 +117,8 @@ impl GameRunner {
} }
if let Some(winner) = self.state.determine_winner() { if let Some(winner) = self.state.determine_winner() {
next_event = Some(trictrac_store::GameEvent::EndGame { next_event = Some(store::GameEvent::EndGame {
reason: trictrac_store::EndGameReason::PlayerWon { winner }, reason: store::EndGameReason::PlayerWon { winner },
}); });
} }

View file

@ -3,10 +3,10 @@
"devenv": { "devenv": {
"locked": { "locked": {
"dir": "src/modules", "dir": "src/modules",
"lastModified": 1770390537, "lastModified": 1753667201,
"owner": "cachix", "owner": "cachix",
"repo": "devenv", "repo": "devenv",
"rev": "d6f45cc00829254a9a6f8807c8fbfaf3efa7e629", "rev": "4d584d7686a50387f975879788043e55af9f0ad4",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -19,14 +19,14 @@
"flake-compat": { "flake-compat": {
"flake": false, "flake": false,
"locked": { "locked": {
"lastModified": 1767039857, "lastModified": 1747046372,
"owner": "NixOS", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "NixOS", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"type": "github" "type": "github"
} }
@ -40,10 +40,10 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1769939035, "lastModified": 1750779888,
"owner": "cachix", "owner": "cachix",
"repo": "git-hooks.nix", "repo": "git-hooks.nix",
"rev": "a8ca480175326551d6c4121498316261cbb5b260", "rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -60,10 +60,10 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1762808025, "lastModified": 1709087332,
"owner": "hercules-ci", "owner": "hercules-ci",
"repo": "gitignore.nix", "repo": "gitignore.nix",
"rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c", "rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -74,40 +74,24 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1770136044, "lastModified": 1753432016,
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "e576e3c9cf9bad747afcddd9e34f51d18c855b4e", "rev": "6027c30c8e9810896b92429f0092f624f7b1aace",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "NixOS", "owner": "NixOS",
"ref": "nixos-25.11", "ref": "nixpkgs-unstable",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }
}, },
"nixpkgs-cmake3": {
"locked": {
"lastModified": 1758213207,
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "f4b140d5b253f5e2a1ff4e5506edbf8267724bde",
"type": "github"
},
"original": {
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "f4b140d5b253f5e2a1ff4e5506edbf8267724bde",
"type": "github"
}
},
"root": { "root": {
"inputs": { "inputs": {
"devenv": "devenv", "devenv": "devenv",
"git-hooks": "git-hooks", "git-hooks": "git-hooks",
"nixpkgs": "nixpkgs", "nixpkgs": "nixpkgs",
"nixpkgs-cmake3": "nixpkgs-cmake3",
"pre-commit-hooks": [ "pre-commit-hooks": [
"git-hooks" "git-hooks"
] ]

View file

@ -1,16 +1,13 @@
{ inputs, pkgs, ... }: { pkgs, ... }:
let
pkgs-cmake3 = import inputs.nixpkgs-cmake3 { system = pkgs.stdenv.system; };
in
{ {
packages = [ packages = [
# pour burn-rs # pour burn-rs
pkgs.SDL2_gfx pkgs.SDL2_gfx
# (compilation sdl2-sys) # (compilation sdl2-sys)
pkgs-cmake3.cmake pkgs.cmake
pkgs.libxcb
pkgs.libffi pkgs.libffi
pkgs.wayland-scanner pkgs.wayland-scanner
@ -18,12 +15,6 @@ in
pkgs.samply # code profiler pkgs.samply # code profiler
pkgs.feedgnuplot # to visualize bots training results pkgs.feedgnuplot # to visualize bots training results
# --- AI training with python ---
# generate python classes from rust code
pkgs.maturin
# required by python numpy
pkgs.libz
# for bevy # for bevy
pkgs.alsa-lib pkgs.alsa-lib
pkgs.udev pkgs.udev
@ -56,25 +47,6 @@ in
# https://devenv.sh/languages/ # https://devenv.sh/languages/
languages.rust.enable = true; languages.rust.enable = true;
# AI training with python
enterShell = ''
PYTHONPATH=$PYTHONPATH:$PWD/.devenv/state/venv/lib/python3/site-packages
'';
languages.python = {
enable = true;
uv.enable = true;
venv.enable = true;
venv.requirements = "
pip
gymnasium
numpy
stable-baselines3
shimmy
";
};
# https://devenv.sh/scripts/ # https://devenv.sh/scripts/
# scripts.hello.exec = "echo hello from $GREET"; # scripts.hello.exec = "echo hello from $GREET";

View file

@ -1,5 +1,3 @@
inputs: inputs:
nixpkgs: nixpkgs:
url: github:NixOS/nixpkgs/nixos-25.11 url: github:NixOS/nixpkgs/nixpkgs-unstable
nixpkgs-cmake3:
url: github:NixOS/nixpkgs/f4b140d5b253f5e2a1ff4e5506edbf8267724bde

View file

@ -1,31 +0,0 @@
# Python bindings
## Génération bindings
```sh
# Generate trictrac python lib as a wheel
maturin build -m store/Cargo.toml --release
# Install wheel in local python env
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
```
## Usage
Pour vérifier l'accès à la lib : lancer le shell interactif `python`
```python
Python 3.13.11 (main, Dec 5 2025, 16:06:33) [GCC 15.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import trictrac_store
>>> game = trictrac_store.TricTrac()
>>> game.get_active_player_id()
1
```
### Appels depuis python
`python bot/python/test.py`
## Interfaces
## Entraînement

View file

@ -20,7 +20,6 @@ profile:
cargo build --profile profiling cargo build --profile profiling
samply record ./target/profiling/client_cli --bot dummy,dummy samply record ./target/profiling/client_cli --bot dummy,dummy
pythonlib: pythonlib:
rm -rf target/wheels
maturin build -m store/Cargo.toml --release maturin build -m store/Cargo.toml --release
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainbot algo: trainbot algo:

View file

@ -1,23 +1,20 @@
[package] [package]
name = "trictrac-store" name = "store"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib] [lib]
name = "trictrac_store" name = "store"
# "cdylib" is necessary to produce a shared library for Python to import from.
# Only "rlib" is needed for other Rust crates to use this library # Only "rlib" is needed for other Rust crates to use this library
crate-type = ["cdylib", "rlib"] crate-type = ["rlib"]
[dependencies] [dependencies]
base64 = "0.21.7" base64 = "0.21.7"
# provides macros for creating log messages to be used by a logger (for example env_logger) # provides macros for creating log messages to be used by a logger (for example env_logger)
log = "0.4.20" log = "0.4.20"
merge = "0.1.0" merge = "0.1.0"
# generate python lib (with maturin) to be used in AI training rand = "0.8.5"
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
rand = "0.9"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
transpose = "0.2.2" transpose = "0.2.2"

View file

@ -1,8 +0,0 @@
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.maturin]
# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so)
features = ["pyo3/extension-module"]
# python-source = "python"

View file

@ -1,4 +1,4 @@
use rand::distr::{Distribution, Uniform}; use rand::distributions::{Distribution, Uniform};
use rand::{rngs::StdRng, SeedableRng}; use rand::{rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -17,7 +17,7 @@ impl DiceRoller {
pub fn new(opt_seed: Option<u64>) -> Self { pub fn new(opt_seed: Option<u64>) -> Self {
Self { Self {
rng: match opt_seed { rng: match opt_seed {
None => StdRng::from_rng(&mut rand::rng()), None => StdRng::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => SeedableRng::seed_from_u64(seed), Some(seed) => SeedableRng::seed_from_u64(seed),
}, },
} }
@ -26,7 +26,7 @@ impl DiceRoller {
/// Roll the dices which generates two random numbers between 1 and 6, replicating a perfect /// Roll the dices which generates two random numbers between 1 and 6, replicating a perfect
/// dice. We use the operating system's random number generator. /// dice. We use the operating system's random number generator.
pub fn roll(&mut self) -> Dice { pub fn roll(&mut self) -> Dice {
let between = Uniform::new_inclusive(1, 6).expect("1 > 6 !?"); let between = Uniform::new_inclusive(1, 6);
let v = (between.sample(&mut self.rng), between.sample(&mut self.rng)); let v = (between.sample(&mut self.rng), between.sample(&mut self.rng));

View file

@ -2,13 +2,13 @@
use crate::board::{Board, CheckerMove}; use crate::board::{Board, CheckerMove};
use crate::dice::Dice; use crate::dice::Dice;
use crate::game_rules_moves::MoveRules; use crate::game_rules_moves::MoveRules;
use crate::game_rules_points::{PointsRules, PossibleJans, PossibleJansMethods}; use crate::game_rules_points::{PointsRules, PossibleJans};
use crate::player::{Color, Player, PlayerId}; use crate::player::{Color, Player, PlayerId};
use log::{debug, error}; use log::{debug, error};
// use itertools::Itertools; // use itertools::Itertools;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::{fmt, str}; use std::{fmt, str};
@ -90,12 +90,7 @@ impl fmt::Display for GameState {
self.stage, self.turn_stage self.stage, self.turn_stage
)); ));
s.push_str(&format!("Dice: {:?}\n", self.dice)); s.push_str(&format!("Dice: {:?}\n", self.dice));
s.push_str(&format!( // s.push_str(&format!("Who plays: {}\n", self.who_plays().map(|player| &player.name ).unwrap_or("")));
"Who plays: {}\n",
self.who_plays()
.map(|player| &player.name)
.unwrap_or(&String::from(""))
));
s.push_str(&format!("Board: {:?}\n", self.board)); s.push_str(&format!("Board: {:?}\n", self.board));
// s.push_str(&format!("History: {:?}\n", self.history)); // s.push_str(&format!("History: {:?}\n", self.history));
write!(f, "{s}") write!(f, "{s}")
@ -143,40 +138,6 @@ impl GameState {
game game
} }
pub fn mirror(&self) -> GameState {
let mirrored_active_player = if self.active_player_id == 1 { 2 } else { 1 };
let mut mirrored_players = HashMap::new();
if let Some(p2) = self.players.get(&2) {
mirrored_players.insert(1, p2.mirror());
}
if let Some(p1) = self.players.get(&1) {
mirrored_players.insert(2, p1.mirror());
}
let mirrored_history = self
.history
.clone()
.iter()
.map(|evt| evt.get_mirror(false))
.collect();
let (move1, move2) = self.dice_moves;
GameState {
stage: self.stage,
turn_stage: self.turn_stage,
board: self.board.mirror(),
active_player_id: mirrored_active_player,
// active_player_id: self.active_player_id,
players: mirrored_players,
history: mirrored_history,
dice: self.dice,
dice_points: self.dice_points,
dice_moves: (move1.mirror(), move2.mirror()),
dice_jans: self.dice_jans.mirror(),
roll_first: self.roll_first,
schools_enabled: self.schools_enabled,
}
}
fn set_schools_enabled(&mut self, schools_enabled: bool) { fn set_schools_enabled(&mut self, schools_enabled: bool) {
self.schools_enabled = schools_enabled; self.schools_enabled = schools_enabled;
} }
@ -470,12 +431,10 @@ impl GameState {
Roll { player_id } => { Roll { player_id } => {
// Check player exists // Check player exists
if !self.players.contains_key(player_id) { if !self.players.contains_key(player_id) {
error!("unknown player_id");
return false; return false;
} }
// Check player is currently the one making their move // Check player is currently the one making their move
if self.active_player_id != *player_id { if self.active_player_id != *player_id {
error!("not active player_id");
return false; return false;
} }
// Check the turn stage // Check the turn stage
@ -572,7 +531,6 @@ impl GameState {
*moves *moves
}; };
if !rules.moves_follow_rules(&moves) { if !rules.moves_follow_rules(&moves) {
// println!(">>> rules not followed ");
error!("rules not followed "); error!("rules not followed ");
return false; return false;
} }
@ -592,7 +550,7 @@ impl GameState {
pub fn init_player(&mut self, player_name: &str) -> Option<PlayerId> { pub fn init_player(&mut self, player_name: &str) -> Option<PlayerId> {
if self.players.len() > 2 { if self.players.len() > 2 {
// println!("more than two players"); println!("more than two players");
return None; return None;
} }
@ -906,12 +864,10 @@ impl GameEvent {
} }
} }
pub fn get_mirror(&self, preserve_player: bool) -> Self { pub fn get_mirror(&self) -> Self {
// let mut mirror = self.clone(); // let mut mirror = self.clone();
let mirror_player_id = if let Some(player_id) = self.player_id() { let mirror_player_id = if let Some(player_id) = self.player_id() {
if preserve_player { if player_id == 1 {
player_id
} else if player_id == 1 {
2 2
} else { } else {
1 1

View file

@ -81,8 +81,7 @@ impl MoveRules {
let is_allowed = self.moves_allowed(moves); let is_allowed = self.moves_allowed(moves);
// let is_allowed = self.moves_allowed(moves, ignored_rules); // let is_allowed = self.moves_allowed(moves, ignored_rules);
if is_allowed.is_err() { if is_allowed.is_err() {
println!("Move not allowed : {:?}", is_allowed.unwrap_err()); info!("Move not allowed : {:?}", is_allowed.unwrap_err());
// info!("Move not allowed : {:?}", is_allowed.unwrap_err());
false false
} else { } else {
true true
@ -100,10 +99,6 @@ impl MoveRules {
if let Ok((field_count, Some(field_color))) = self.board.get_field_checkers(move0_from) if let Ok((field_count, Some(field_color))) = self.board.get_field_checkers(move0_from)
{ {
if color != field_color || field_count < 2 { if color != field_color || field_count < 2 {
println!(
"Move not physically possible 1. field_color {:?}, count {}",
field_color, field_count
);
info!("Move not physically possible"); info!("Move not physically possible");
return false; return false;
} }
@ -115,7 +110,6 @@ impl MoveRules {
if !self.board.passage_possible(color, &moves.0) if !self.board.passage_possible(color, &moves.0)
|| !self.board.move_possible(color, &chained_move) || !self.board.move_possible(color, &chained_move)
{ {
println!("Tout d'une : Move not physically possible");
info!("Tout d'une : Move not physically possible"); info!("Tout d'une : Move not physically possible");
return false; return false;
} }
@ -123,11 +117,6 @@ impl MoveRules {
|| !self.board.move_possible(color, &moves.1) || !self.board.move_possible(color, &moves.1)
{ {
// Move is not physically possible // Move is not physically possible
println!("Move not physically possible 2");
println!(
"board: {}, color: {:?} move: {:?}",
self.board, color, moves
);
info!("Move not physically possible"); info!("Move not physically possible");
return false; return false;
} }

View file

@ -69,26 +69,10 @@ pub type PossibleJans = HashMap<Jan, Vec<(CheckerMove, CheckerMove)>>;
pub trait PossibleJansMethods { pub trait PossibleJansMethods {
fn push(&mut self, jan: Jan, cmoves: (CheckerMove, CheckerMove)); fn push(&mut self, jan: Jan, cmoves: (CheckerMove, CheckerMove));
fn merge(&mut self, other: Self); fn merge(&mut self, other: Self);
fn mirror(&self) -> Self;
// fn get_points(&self) -> u8; // fn get_points(&self) -> u8;
} }
impl PossibleJansMethods for PossibleJans { impl PossibleJansMethods for PossibleJans {
fn mirror(&self) -> Self {
self.clone()
.into_iter()
.map(|(jan, moves)| {
(
jan,
moves
.into_iter()
.map(|(m1, m2)| (m1.mirror(), m2.mirror()))
.collect(),
)
})
.collect()
}
fn push(&mut self, jan: Jan, cmoves: (CheckerMove, CheckerMove)) { fn push(&mut self, jan: Jan, cmoves: (CheckerMove, CheckerMove)) {
if let Some(ways) = self.get_mut(&jan) { if let Some(ways) = self.get_mut(&jan) {
if !ways.contains(&cmoves) { if !ways.contains(&cmoves) {

View file

@ -16,8 +16,3 @@ pub use board::CheckerMove;
mod dice; mod dice;
pub use dice::{Dice, DiceRoller}; pub use dice::{Dice, DiceRoller};
pub mod training_common;
// python interface "trictrac_engine" (for AI training..)
mod pyengine;

View file

@ -1,11 +1,9 @@
use pyo3::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
// This just makes it easier to dissern between a player id and any ol' u64 // This just makes it easier to dissern between a player id and any ol' u64
pub type PlayerId = u64; pub type PlayerId = u64;
#[pyclass(eq, eq_int)]
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Color { pub enum Color {
White, White,
@ -48,16 +46,6 @@ impl Player {
} }
} }
pub fn mirror(&self) -> Self {
let mut player = self.clone();
player.color = if self.color == Color::White {
Color::Black
} else {
Color::White
};
player
}
pub fn to_bits_string(&self) -> String { pub fn to_bits_string(&self) -> String {
format!( format!(
"{:0>4b}{:0>4b}{:b}{:b}", "{:0>4b}{:0>4b}{:b}{:b}",

View file

@ -1,145 +0,0 @@
//! # Expose trictrac game state and rules in a python module
use pyo3::prelude::*;
use crate::dice::Dice;
use crate::game::{GameEvent, GameState, Stage, TurnStage};
use crate::player::PlayerId;
use crate::training_common::{get_valid_action_indices, TrictracAction};
#[pyclass]
struct TricTrac {
game_state: GameState,
}
#[pymethods]
impl TricTrac {
#[new]
fn new() -> Self {
let mut game_state = GameState::new(false); // schools_enabled = false
// Initialiser 2 joueurs
game_state.init_player("player1");
game_state.init_player("player2");
// Commencer la partie avec le joueur 1
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
TricTrac { game_state }
}
fn needs_roll(&self) -> bool {
self.game_state.turn_stage == TurnStage::RollWaiting
}
fn is_game_ended(&self) -> bool {
self.game_state.stage == Stage::Ended
}
// 0 or 1
fn current_player_idx(&self) -> u64 {
self.game_state.active_player_id - 1
}
fn get_legal_actions(&self, player_idx: u64) -> Vec<usize> {
if player_idx == self.current_player_idx() {
if player_idx == 0 {
get_valid_action_indices(&self.game_state)
} else {
let mirror = self.game_state.mirror();
get_valid_action_indices(&mirror)
}
} else {
vec![]
}
}
fn action_to_string(&self, player_idx: u64, action_idx: usize) -> String {
TrictracAction::from_action_index(action_idx)
.map(|a| format!("{}:{}", player_idx, a))
.unwrap_or("unknown action".into())
}
fn apply_dice_roll(&mut self, dices: (u8, u8)) -> PyResult<()> {
let player_id = self.game_state.active_player_id;
if self.game_state.turn_stage != TurnStage::RollWaiting {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
"Not in RollWaiting stage",
));
}
let dice = Dice { values: dices };
self.game_state
.consume(&GameEvent::RollResult { player_id, dice });
Ok(())
}
fn apply_action(&mut self, action_idx: usize) -> PyResult<()> {
if let Some(event) = TrictracAction::from_action_index(action_idx).and_then(|a| {
let needs_mirror = self.game_state.active_player_id == 2;
let game_state = if needs_mirror {
&self.game_state.mirror()
} else {
&self.game_state
};
a.to_event(game_state)
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
}) {
if self.game_state.validate(&event) {
self.game_state.consume(&event);
return Ok(());
} else {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
"Action is invalid",
));
}
}
Err(pyo3::exceptions::PyRuntimeError::new_err(
"Could not apply action",
))
}
/// Get a player total score (holes & points)
fn get_score(&self, player_id: PlayerId) -> i32 {
if let Some(player) = self.game_state.players.get(&player_id) {
player.holes as i32 * 12 + player.points as i32
} else {
-1
}
}
fn get_players_scores(&self) -> [i32; 2] {
[self.get_score(1), self.get_score(2)]
}
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
if player_idx == 0 {
self.game_state.to_vec()
} else {
self.game_state.mirror().to_vec()
}
}
fn get_observation_string(&self, player_idx: u64) -> String {
if player_idx == 0 {
format!("{}", self.game_state)
} else {
format!("{}", self.game_state.mirror())
}
}
/// Afficher l'état du jeu (pour le débogage)
fn __str__(&self) -> String {
format!("{}", self.game_state)
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn trictrac_store(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<TricTrac>()?;
Ok(())
}