trictrac/store/src/pyengine.rs

156 lines
4.6 KiB
Rust
Raw Normal View History

2026-01-18 18:41:08 +01:00
//! # Expose the trictrac game state and rules as a Python module
use pyo3::prelude::*;
use pyo3::types::PyDict;
use crate::board::CheckerMove;
use crate::dice::{Dice, DiceRoller};
use crate::game::{GameEvent, GameState, Stage, TurnStage};
use crate::game_rules_moves::MoveRules;
use crate::game_rules_points::PointsRules;
use crate::player::{Color, PlayerId};
2026-02-08 18:57:20 +01:00
use crate::training_common::{get_valid_action_indices, TrictracAction};
2026-01-18 18:41:08 +01:00
/// Python-facing handle that owns a single trictrac game.
#[pyclass]
struct TricTrac {
    /// Rules-engine state that every Python method reads from or mutates.
    game_state: GameState,
    // NOTE(review): the two fields below are only initialised in `new()` and
    // are not read anywhere in the visible code — confirm they are still needed.
    /// Index into `dice_roll_sequence` (presumably the next roll to use).
    current_dice_index: usize,
    /// Sequence of dice rolls as `(die1, die2)` pairs.
    dice_roll_sequence: Vec<(u8, u8)>,
}
#[pymethods]
impl TricTrac {
#[new]
fn new() -> Self {
let mut game_state = GameState::new(false); // schools_enabled = false
// Initialiser 2 joueurs
game_state.init_player("player1");
game_state.init_player("player2");
// Commencer la partie avec le joueur 1
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
TricTrac {
game_state,
dice_roll_sequence: Vec::new(),
current_dice_index: 0,
}
}
fn needs_roll(&self) -> bool {
self.game_state.turn_stage == TurnStage::RollWaiting
}
2026-01-18 18:41:08 +01:00
fn is_game_ended(&self) -> bool {
self.game_state.stage == Stage::Ended
2026-01-18 18:41:08 +01:00
}
// 0 or 1
fn current_player_idx(&self) -> u64 {
self.game_state.active_player_id - 1
}
2026-02-15 12:08:24 +01:00
fn get_legal_actions(&self, player_idx: u64) -> Vec<usize> {
if player_idx == self.current_player_idx() {
if player_idx == 0 {
get_valid_action_indices(&self.game_state)
} else {
let mirror = self.game_state.mirror();
get_valid_action_indices(&mirror)
}
2026-02-08 18:57:20 +01:00
} else {
vec![]
}
}
fn action_to_string(&self, player_idx: u64, action_idx: usize) -> String {
TrictracAction::from_action_index(action_idx)
.map(|a| format!("{}:{}", player_idx, a))
.unwrap_or("unknown action".into())
}
2026-02-08 18:57:20 +01:00
fn apply_dice_roll(&mut self, dices: (u8, u8)) -> PyResult<()> {
2026-01-18 18:41:08 +01:00
let player_id = self.game_state.active_player_id;
2026-02-08 18:57:20 +01:00
if self.game_state.turn_stage != TurnStage::RollWaiting {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
2026-02-08 18:57:20 +01:00
"Not in RollWaiting stage",
));
2026-01-18 18:41:08 +01:00
}
2026-02-08 18:57:20 +01:00
let dice = Dice { values: dices };
self.game_state
.consume(&GameEvent::RollResult { player_id, dice });
2026-01-18 18:41:08 +01:00
Ok(())
}
2026-02-08 18:57:20 +01:00
fn apply_action(&mut self, action_idx: usize) -> PyResult<()> {
2026-02-15 12:08:24 +01:00
if let Some(event) = TrictracAction::from_action_index(action_idx).and_then(|a| {
let needs_mirror = self.game_state.active_player_id == 2;
let game_state = if needs_mirror {
&self.game_state.mirror()
} else {
&self.game_state
};
a.to_event(game_state)
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
}) {
2026-02-08 18:57:20 +01:00
if self.game_state.validate(&event) {
self.game_state.consume(&event);
return Ok(());
} else {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
"Action is invalid",
));
}
2026-01-18 18:41:08 +01:00
}
2026-02-08 18:57:20 +01:00
Err(pyo3::exceptions::PyRuntimeError::new_err(
"Could not apply action",
))
2026-01-18 18:41:08 +01:00
}
2026-02-08 18:57:20 +01:00
/// Get a player total score (holes & points)
2026-01-18 18:41:08 +01:00
fn get_score(&self, player_id: PlayerId) -> i32 {
if let Some(player) = self.game_state.players.get(&player_id) {
2026-02-08 18:57:20 +01:00
player.holes as i32 * 12 + player.points as i32
2026-01-18 18:41:08 +01:00
} else {
-1
}
}
2026-02-08 18:57:20 +01:00
fn get_players_scores(&self) -> [i32; 2] {
[self.get_score(1), self.get_score(2)]
2026-01-18 18:41:08 +01:00
}
2026-02-15 12:08:24 +01:00
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
if player_idx == 0 {
self.game_state.to_vec()
} else {
self.game_state.mirror().to_vec()
}
}
fn get_observation_string(&self, player_idx: u64) -> String {
if player_idx == 0 {
format!("{}", self.game_state)
} else {
format!("{}", self.game_state.mirror())
}
2026-01-18 18:41:08 +01:00
}
/// Afficher l'état du jeu (pour le débogage)
fn __str__(&self) -> String {
format!("{}", self.game_state)
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn trictrac_store(m: &Bound<'_, PyModule>) -> PyResult<()> {
2026-01-18 18:41:08 +01:00
m.add_class::<TricTrac>()?;
Ok(())
}