From 012ccf8b425023fcffe9739cd2886077e640e92f Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 18 Jan 2026 18:41:08 +0100 Subject: [PATCH] feat: python bindings --- Cargo.lock | 79 +++++++++++++ bot/python/test.py | 4 + bot/src/training_common.rs | 3 +- devenv.lock | 24 ++-- devenv.nix | 25 ++++ doc/python.md | 31 +++++ justfile | 1 + store/Cargo.toml | 5 +- store/pyproject.toml | 9 ++ store/src/lib.rs | 3 + store/src/player.rs | 2 + store/src/pyengine.rs | 230 +++++++++++++++++++++++++++++++++++++ 12 files changed, 402 insertions(+), 14 deletions(-) create mode 100644 bot/python/test.py create mode 100644 doc/python.md create mode 100644 store/pyproject.toml create mode 100644 store/src/pyengine.rs diff --git a/Cargo.lock b/Cargo.lock index a71f75a..de74a7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3460,6 +3460,15 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "merge" version = "0.1.0" @@ -4210,6 +4219,69 @@ dependencies = [ "num-traits", ] +[[package]] +name = "pyo3" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.106", +] + [[package]] name = "qoi" version = "0.4.1" @@ -5154,6 +5226,7 @@ dependencies = [ "base64 0.21.7", "log", "merge", + "pyo3", "rand 0.8.5", "serde", "transpose", @@ -5892,6 +5965,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "universal-hash" version = "0.5.1" diff --git a/bot/python/test.py b/bot/python/test.py new file mode 100644 index 0000000..7b13b7d --- /dev/null +++ b/bot/python/test.py @@ -0,0 +1,4 @@ +import store + +game = store.TricTrac() +print(game.get_state_dict()) diff --git a/bot/src/training_common.rs b/bot/src/training_common.rs index 8c85021..3754086 100644 --- a/bot/src/training_common.rs +++ b/bot/src/training_common.rs @@ -15,7 +15,8 @@ pub const ACTION_SPACE_SIZE: usize = 514; pub enum TrictracAction { /// Lancer les dés Roll, - /// Continuer après avoir gagné un trou + /// Faire un nouveau 'relevé' (repositionnement des dames à l'état de départ) après avoir gagné un trou, + /// au lieu de continuer dans la position courante Go, /// Effectuer un mouvement de pions Move { diff --git a/devenv.lock b/devenv.lock index c3d5629..f30fbdc 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,10 +3,10 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1753667201, + "lastModified": 1768056019, "owner": "cachix", "repo": "devenv", - "rev": "4d584d7686a50387f975879788043e55af9f0ad4", + "rev": "9bfc4a64c3a798ed8fa6cee3a519a9eac5e73cb5", "type": "github" }, "original": { @@ -19,14 +19,14 @@ "flake-compat": { "flake": false, "locked": { - "lastModified": 1747046372, - "owner": "edolstra", + "lastModified": 1767039857, + "owner": "NixOS", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", "type": "github" }, "original": { - "owner": "edolstra", + "owner": "NixOS", "repo": "flake-compat", "type": "github" } @@ -40,10 +40,10 @@ ] }, "locked": { - "lastModified": 1750779888, + "lastModified": 1767281941, "owner": "cachix", "repo": "git-hooks.nix", - "rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d", + "rev": "f0927703b7b1c8d97511c4116eb9b4ec6645a0fa", "type": "github" }, "original": { @@ -60,10 +60,10 @@ ] }, "locked": { - "lastModified": 1709087332, + "lastModified": 1762808025, "owner": "hercules-ci", "repo": "gitignore.nix", - "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c", "type": "github" }, "original": { @@ -74,10 +74,10 @@ }, "nixpkgs": { "locked": { - "lastModified": 1753432016, + "lastModified": 1767995494, "owner": "NixOS", "repo": "nixpkgs", - "rev": "6027c30c8e9810896b92429f0092f624f7b1aace", + "rev": "45a1530683263666f42d1de4cdda328109d5a676", "type": "github" }, "original": { diff --git a/devenv.nix b/devenv.nix index 1b51c9d..af6f116 100644 --- a/devenv.nix +++ b/devenv.nix @@ -15,6 +15,12 @@ pkgs.samply # code profiler pkgs.feedgnuplot # to visualize bots training results + # --- AI training with python --- + # generate python classes from rust code + pkgs.maturin + # required by python numpy + pkgs.libz + # for bevy pkgs.alsa-lib pkgs.udev @@ -47,6 +53,25 @@ # https://devenv.sh/languages/ languages.rust.enable = true; + + # AI training with python + enterShell = '' + PYTHONPATH=$PYTHONPATH:$PWD/.devenv/state/venv/lib/python3/site-packages + ''; + + languages.python = { + enable = true; + uv.enable = true; + venv.enable = true; + venv.requirements = " + pip + gymnasium + numpy + stable-baselines3 + shimmy + "; + }; + # https://devenv.sh/scripts/ # scripts.hello.exec = "echo hello from $GREET"; diff --git a/doc/python.md b/doc/python.md new file mode 100644 index 0000000..65b0239 --- /dev/null +++ b/doc/python.md @@ -0,0 +1,31 @@ +# Python bindings + +## Génération bindings + +```sh +# Generate trictrac python lib as a wheel +maturin build -m store/Cargo.toml --release +# Install wheel in local python env +pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl +``` + +## Usage + +Pour vérifier l'accès à la lib : lancer le shell interactif `python` + +```python +Python 3.13.11 (main, Dec 5 2025, 16:06:33) [GCC 15.2.0] on linux +Type "help", "copyright", "credits" or "license" for more information. +>>> import store +>>> game = store.TricTrac() +>>> game.get_active_player_id() +1 +``` + +### Appels depuis python + +`python bot/python/test.py` + +## Interfaces + +## Entraînement diff --git a/justfile b/justfile index 9c8bf58..33c0654 100644 --- a/justfile +++ b/justfile @@ -20,6 +20,7 @@ profile: cargo build --profile profiling samply record ./target/profiling/client_cli --bot dummy,dummy pythonlib: + rm -rf target/wheels maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl trainbot algo: diff --git a/store/Cargo.toml b/store/Cargo.toml index a071dd1..0517553 100644 --- a/store/Cargo.toml +++ b/store/Cargo.toml @@ -7,14 +7,17 @@ edition = "2021" [lib] name = "store" +# "cdylib" is necessary to produce a shared library for Python to import from. # Only "rlib" is needed for other Rust crates to use this library -crate-type = ["rlib"] +crate-type = ["cdylib", "rlib"] [dependencies] base64 = "0.21.7" # provides macros for creating log messages to be used by a logger (for example env_logger) log = "0.4.20" merge = "0.1.0" +# generate python lib (with maturin) to be used in AI training +pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } rand = "0.8.5" serde = { version = "1.0", features = ["derive"] } transpose = "0.2.2" diff --git a/store/pyproject.toml b/store/pyproject.toml new file mode 100644 index 0000000..8fe5762 --- /dev/null +++ b/store/pyproject.toml @@ -0,0 +1,9 @@ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[tool.maturin] +# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) +features = ["pyo3/extension-module"] +# python-source = "python" +# module-name = "trictrac.game" diff --git a/store/src/lib.rs b/store/src/lib.rs index 58a5727..60639e5 100644 --- a/store/src/lib.rs +++ b/store/src/lib.rs @@ -16,3 +16,6 @@ pub use board::CheckerMove; mod dice; pub use dice::{Dice, DiceRoller}; + +// python interface "trictrac_engine" (for AI training..) +mod pyengine; diff --git a/store/src/player.rs b/store/src/player.rs index d990a1f..eeb5829 100644 --- a/store/src/player.rs +++ b/store/src/player.rs @@ -1,9 +1,11 @@ +use pyo3::prelude::*; use serde::{Deserialize, Serialize}; use std::fmt; // This just makes it easier to dissern between a player id and any ol' u64 pub type PlayerId = u64; +#[pyclass(eq, eq_int)] #[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Color { White, diff --git a/store/src/pyengine.rs b/store/src/pyengine.rs new file mode 100644 index 0000000..b436baa --- /dev/null +++ b/store/src/pyengine.rs @@ -0,0 +1,230 @@ +//! # Expose trictrac game state and rules in a python module +use pyo3::prelude::*; +use pyo3::types::PyDict; + +use crate::board::CheckerMove; +use crate::dice::{Dice, DiceRoller}; +use crate::game::{GameEvent, GameState, Stage, TurnStage}; +use crate::game_rules_moves::MoveRules; +use crate::game_rules_points::PointsRules; +use crate::player::{Color, PlayerId}; + +#[pyclass] +struct TricTrac { + game_state: GameState, + dice_roll_sequence: Vec<(u8, u8)>, + current_dice_index: usize, +} + +#[pymethods] +impl TricTrac { + #[new] + fn new() -> Self { + let mut game_state = GameState::new(false); // schools_enabled = false + + // Initialiser 2 joueurs + game_state.init_player("player1"); + game_state.init_player("player2"); + + // Commencer la partie avec le joueur 1 + game_state.consume(&GameEvent::BeginGame { goes_first: 1 }); + + TricTrac { + game_state, + dice_roll_sequence: Vec::new(), + current_dice_index: 0, + } + } + + /// Obtenir l'état du jeu sous forme de dictionnaire + fn get_state_dict<'py>(&self, py: Python<'py>) -> PyResult> { + let dict = PyDict::new(py); + dict.set_item("stage", format!("{:?}", self.game_state.stage))?; + dict.set_item("turn_stage", format!("{:?}", self.game_state.turn_stage))?; + dict.set_item("active_player_id", self.game_state.active_player_id)?; + + // Board + let board_list = self.game_state.board.to_vec(); // returns Vec + dict.set_item("board", board_list)?; + + // Dice + dict.set_item("dice", (self.game_state.dice.values.0, self.game_state.dice.values.1))?; + + // Players + let players_dict = PyDict::new(py); + for (id, player) in &self.game_state.players { + let p_dict = PyDict::new(py); + p_dict.set_item("color", format!("{:?}", player.color))?; + p_dict.set_item("holes", player.holes)?; + p_dict.set_item("points", player.points)?; + p_dict.set_item("can_bredouille", player.can_bredouille)?; + p_dict.set_item("dice_roll_count", player.dice_roll_count)?; + players_dict.set_item(id, p_dict)?; + } + dict.set_item("players", players_dict)?; + + Ok(dict) + } + + /// Lance les dés ou utilise la séquence prédéfinie + fn roll_dice(&mut self) -> PyResult<(u8, u8)> { + let player_id = self.game_state.active_player_id; + + if self.game_state.turn_stage != TurnStage::RollDice { + return Err(pyo3::exceptions::PyRuntimeError::new_err("Not in RollDice stage")); + } + + self.game_state.consume(&GameEvent::Roll { player_id }); + + let dice = if self.current_dice_index < self.dice_roll_sequence.len() { + let vals = self.dice_roll_sequence[self.current_dice_index]; + self.current_dice_index += 1; + Dice { values: vals } + } else { + DiceRoller::default().roll() + }; + + self.game_state.consume(&GameEvent::RollResult { player_id, dice }); + + Ok(dice.values) + } + + /// Applique un mouvement (deux déplacements de dames) + fn apply_move(&mut self, from1: usize, to1: usize, from2: usize, to2: usize) -> PyResult<()> { + let player_id = self.game_state.active_player_id; + + let m1 = CheckerMove::new(from1, to1).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + let m2 = CheckerMove::new(from2, to2).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + + let moves = (m1, m2); + + if !self.game_state.validate(&GameEvent::Move { player_id, moves }) { + return Err(pyo3::exceptions::PyValueError::new_err("Invalid move")); + } + + self.game_state.consume(&GameEvent::Move { player_id, moves }); + Ok(()) + } + + /// Obtenir l'état du jeu sous forme de chaîne de caractères compacte + fn get_state_id(&self) -> String { + self.game_state.to_string_id() + } + + /// Renvoie les positions des pièces pour un joueur spécifique + fn get_checker_positions(&self, color: Color) -> Vec<(usize, i8)> { + self.game_state.board.get_color_fields(color) + } + + /// Obtenir la liste des mouvements légaux sous forme de paires (from, to) + fn get_available_moves(&self) -> Vec<((usize, usize), (usize, usize))> { + // L'agent joue toujours le joueur actif + let color = self + .game_state + .player_color_by_id(&self.game_state.active_player_id) + .unwrap_or(Color::White); + + // Si ce n'est pas le moment de déplacer les pièces, retourner une liste vide + if self.game_state.turn_stage != TurnStage::Move + && self.game_state.turn_stage != TurnStage::HoldOrGoChoice + { + return vec![]; + } + + let rules = MoveRules::new(&color, &self.game_state.board, self.game_state.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + // Convertir les mouvements CheckerMove en tuples (from, to) pour Python + possible_moves + .into_iter() + .map(|(move1, move2)| { + ( + (move1.get_from(), move1.get_to()), + (move2.get_from(), move2.get_to()), + ) + }) + .collect() + } + + /// Calcule les points maximaux que le joueur actif peut obtenir avec les dés actuels + fn calculate_points(&self) -> u8 { + let active_player = self + .game_state + .players + .get(&self.game_state.active_player_id); + + if let Some(player) = active_player { + let dice_roll_count = player.dice_roll_count; + let color = player.color; + + let points_rules = + PointsRules::new(&color, &self.game_state.board, self.game_state.dice); + let (points, _) = points_rules.get_points(dice_roll_count); + + points + } else { + 0 + } + } + + /// Réinitialise la partie + fn reset(&mut self) { + self.game_state = GameState::new(false); + + // Initialiser 2 joueurs + self.game_state.init_player("player1"); + self.game_state.init_player("player2"); + + // Commencer la partie avec le joueur 1 + self.game_state + .consume(&GameEvent::BeginGame { goes_first: 1 }); + + // Réinitialiser l'index de la séquence de dés + self.current_dice_index = 0; + } + + /// Vérifie si la partie est terminée + fn is_done(&self) -> bool { + self.game_state.stage == Stage::Ended || self.game_state.determine_winner().is_some() + } + + /// Obtenir le gagnant de la partie + fn get_winner(&self) -> Option { + self.game_state.determine_winner() + } + + /// Obtenir le score du joueur actif (nombre de trous) + fn get_score(&self, player_id: PlayerId) -> i32 { + if let Some(player) = self.game_state.players.get(&player_id) { + player.holes as i32 + } else { + -1 + } + } + + /// Obtenir l'ID du joueur actif + fn get_active_player_id(&self) -> PlayerId { + self.game_state.active_player_id + } + + /// Définir une séquence de dés à utiliser (pour la reproductibilité) + fn set_dice_sequence(&mut self, sequence: Vec<(u8, u8)>) { + self.dice_roll_sequence = sequence; + self.current_dice_index = 0; + } + + /// Afficher l'état du jeu (pour le débogage) + fn __str__(&self) -> String { + format!("{}", self.game_state) + } +} + +/// A Python module implemented in Rust. The name of this function must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. +#[pymodule] +fn store(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + + Ok(()) +}