diff --git a/Cargo.lock b/Cargo.lock index 320dcb8..a43261e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1222,6 +1222,32 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + [[package]] name = "cmake" version = "0.1.57" @@ -1242,6 +1268,17 @@ dependencies = [ "unicode-width 0.2.0", ] +[[package]] +name = "codespan-reporting" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681" +dependencies = [ + "serde", + "termcolor", + "unicode-width 0.2.0", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -1943,6 +1980,68 @@ dependencies = [ "libloading", ] +[[package]] +name = "cxx" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "747d8437319e3a2f43d93b341c137927ca70c0f5dabeea7a005a73665e247c7e" +dependencies = [ + "cc", + "cxx-build", + "cxxbridge-cmd", + "cxxbridge-flags", + "cxxbridge-macro", + "foldhash 0.2.0", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0f4697d190a142477b16aef7da8a99bfdc41e7e8b1687583c0d23a79c7afc1e" +dependencies = [ + "cc", + "codespan-reporting 0.13.1", + "indexmap", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.114", +] + +[[package]] +name = "cxxbridge-cmd" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0956799fa8678d4c50eed028f2de1c0552ae183c76e976cf7ca8c4e36a7c328" +dependencies = [ + "clap", + "codespan-reporting 0.13.1", + "indexmap", + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23384a836ab4f0ad98ace7e3955ad2de39de42378ab487dc28d3990392cb283a" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6acc6b5822b9526adfb4fc377b67128fdd60aac757cc4a741a6278603f763cf" +dependencies = [ + "indexmap", + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "darling" version = "0.20.11" @@ -3733,6 +3832,15 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "link-cplusplus" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -4000,7 +4108,7 @@ dependencies = [ "bitflags 2.10.0", "cfg-if", "cfg_aliases", - "codespan-reporting", + "codespan-reporting 0.12.0", "half", "hashbrown 0.15.5", "hexf-parse", @@ -5489,6 +5597,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" + [[package]] name = "sdl2" version = "0.37.0" @@ -6617,7 +6731,10 @@ dependencies = [ name = "trictrac-store" version = "0.1.0" dependencies = [ + "anyhow", "base64 0.21.7", + "cxx", + "cxx-build", "log", "merge", "pyo3", diff --git a/justfile b/justfile index 33c0654..2bfc052 100644 --- a/justfile +++ b/justfile @@ -23,6 +23,10 @@ pythonlib: rm -rf target/wheels maturin build -m store/Cargo.toml --release pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl +cxxlib: + cargo build --release -p trictrac-store + @echo "Static lib: $(ls target/release/libtrictrac_store.a)" + @echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)" trainbot algo: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok diff --git a/store/Cargo.toml b/store/Cargo.toml index 846e5fb..a9234ff 100644 --- a/store/Cargo.toml +++ b/store/Cargo.toml @@ -7,12 +7,15 @@ edition = "2021" [lib] name = "trictrac_store" -# "cdylib" is necessary to produce a shared library for Python to import from. -# Only "rlib" is needed for other Rust crates to use this library -crate-type = ["cdylib", "rlib"] +# "cdylib" → Python .so built by maturin (pyengine) +# "rlib" → used by other workspace crates (bot, client_cli) +# "staticlib" → used by the C++ OpenSpiel game (cxxengine) +crate-type = ["cdylib", "rlib", "staticlib"] [dependencies] +anyhow = "1.0" base64 = "0.21.7" +cxx = "1.0" # provides macros for creating log messages to be used by a logger (for example env_logger) log = "0.4.20" merge = "0.1.0" @@ -21,3 +24,6 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } rand = "0.9" serde = { version = "1.0", features = ["derive"] } transpose = "0.2.2" + +[build-dependencies] +cxx-build = "1.0" diff --git a/store/build.rs b/store/build.rs new file mode 100644 index 0000000..852bcf6 --- /dev/null +++ b/store/build.rs @@ -0,0 +1,7 @@ +fn main() { + cxx_build::bridge("src/cxxengine.rs") + .std("c++17") + .compile("trictrac-cxx"); + + println!("cargo:rerun-if-changed=src/cxxengine.rs"); +} diff --git a/store/src/cxxengine.rs b/store/src/cxxengine.rs new file mode 100644 index 0000000..86aa382 --- /dev/null +++ b/store/src/cxxengine.rs @@ -0,0 +1,223 @@ +//! C++ bindings for the TricTrac game engine via cxx.rs. +//! +//! Exposes an opaque `TricTracEngine` type to C++. The C++ side +//! (open_spiel/games/trictrac/trictrac.cc) holds it via +//! `rust::Box`. +//! +//! The Rust engine always reasons from White's (player 1's) perspective. +//! For Black (player 2), the board is mirrored before computing actions +//! and events are mirrored back before being applied — exactly as in +//! pyengine.rs. + +use crate::dice::Dice; +use crate::game::{GameEvent, GameState, Stage, TurnStage}; +use crate::training_common::{get_valid_action_indices, TrictracAction}; + +// ── cxx bridge declaration ──────────────────────────────────────────────────── + +#[cxx::bridge(namespace = "trictrac_engine")] +pub mod ffi { + // ── Shared types (transparent to both Rust and C++) ─────────────────────── + + /// Two dice values passed from C++ when applying a chance outcome. + struct DicePair { + die1: u8, + die2: u8, + } + + /// Both players' cumulative scores: holes * 12 + points. + struct PlayerScores { + score_p1: i32, + score_p2: i32, + } + + // ── Opaque Rust type and its free-function constructor ──────────────────── + + extern "Rust" { + /// Opaque handle to a running TricTrac game. + /// C++ accesses this only through `rust::Box`. + type TricTracEngine; + + /// Construct a fresh engine with two players; player 1 (White) goes first. + fn new_trictrac_engine() -> Box; + + /// Deep-copy the engine — required by OpenSpiel's State::Clone(). + fn clone_engine(self: &TricTracEngine) -> Box; + + // ── Queries ─────────────────────────────────────────────────────────── + + /// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node). + fn needs_roll(self: &TricTracEngine) -> bool; + + /// True when Stage::Ended. + fn is_game_ended(self: &TricTracEngine) -> bool; + + /// Active player index: 0 = player 1 (White), 1 = player 2 (Black). + fn current_player_idx(self: &TricTracEngine) -> u64; + + /// Legal action indices for `player_idx` in [0, 513]. + /// Returns an empty vector when it is not that player's turn. + fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec; + + /// Human-readable description of an action index. + fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String; + + /// Both players' scores. + fn get_players_scores(self: &TricTracEngine) -> PlayerScores; + + /// 36-element state vector (i8). Mirrored for player_idx == 1. + fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; + + /// Human-readable state description for `player_idx`. + fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; + + /// Full debug representation of the current state. + fn to_debug_string(self: &TricTracEngine) -> String; + + // ── Mutations ───────────────────────────────────────────────────────── + + /// Apply a dice-roll result. Returns Err (C++ exception) if not in + /// the RollWaiting stage. + fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>; + + /// Apply a player action. Returns Err (C++ exception) if the action + /// is not legal in the current state. + fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>; + } +} + +// ── Opaque type ─────────────────────────────────────────────────────────────── + +pub struct TricTracEngine { + game_state: GameState, +} + +// ── Free-function constructor (declared in the bridge as a plain function) ──── + +pub fn new_trictrac_engine() -> Box { + let mut game_state = GameState::new(false); // schools_enabled = false + game_state.init_player("player1"); + game_state.init_player("player2"); + game_state.consume(&GameEvent::BeginGame { goes_first: 1 }); + Box::new(TricTracEngine { game_state }) +} + +// ── Method implementations ──────────────────────────────────────────────────── + +impl TricTracEngine { + fn clone_engine(&self) -> Box { + Box::new(TricTracEngine { + game_state: self.game_state.clone(), + }) + } + + fn needs_roll(&self) -> bool { + self.game_state.turn_stage == TurnStage::RollWaiting + } + + fn is_game_ended(&self) -> bool { + self.game_state.stage == Stage::Ended + } + + fn current_player_idx(&self) -> u64 { + self.game_state.active_player_id - 1 + } + + fn get_legal_actions(&self, player_idx: u64) -> Vec { + if player_idx != self.current_player_idx() { + return vec![]; + } + if player_idx == 0 { + get_valid_action_indices(&self.game_state) + .into_iter() + .map(|i| i as u64) + .collect() + } else { + let mirror = self.game_state.mirror(); + get_valid_action_indices(&mirror) + .into_iter() + .map(|i| i as u64) + .collect() + } + } + + fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String { + TrictracAction::from_action_index(action_idx as usize) + .map(|a| format!("{}:{}", player_idx, a)) + .unwrap_or_else(|| "unknown action".into()) + } + + fn get_players_scores(&self) -> ffi::PlayerScores { + ffi::PlayerScores { + score_p1: self.score_for(1), + score_p2: self.score_for(2), + } + } + + fn score_for(&self, player_id: u64) -> i32 { + self.game_state + .players + .get(&player_id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(-1) + } + + fn get_tensor(&self, player_idx: u64) -> Vec { + if player_idx == 0 { + self.game_state.to_vec() + } else { + self.game_state.mirror().to_vec() + } + } + + fn get_observation_string(&self, player_idx: u64) -> String { + if player_idx == 0 { + format!("{}", self.game_state) + } else { + format!("{}", self.game_state.mirror()) + } + } + + fn to_debug_string(&self) -> String { + format!("{}", self.game_state) + } + + fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> anyhow::Result<()> { + if self.game_state.turn_stage != TurnStage::RollWaiting { + anyhow::bail!("apply_dice_roll: not in RollWaiting stage (currently {:?})", + self.game_state.turn_stage); + } + let player_id = self.game_state.active_player_id; + let dice = Dice { + values: (dice.die1, dice.die2), + }; + self.game_state + .consume(&GameEvent::RollResult { player_id, dice }); + Ok(()) + } + + fn apply_action(&mut self, action_idx: u64) -> anyhow::Result<()> { + let needs_mirror = self.game_state.active_player_id == 2; + + let event = TrictracAction::from_action_index(action_idx as usize).and_then(|a| { + let state = if needs_mirror { + &self.game_state.mirror() + } else { + &self.game_state + }; + a.to_event(state) + .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) + }); + + match event { + Some(evt) if self.game_state.validate(&evt) => { + self.game_state.consume(&evt); + Ok(()) + } + Some(_) => anyhow::bail!("apply_action: action {} is not valid in current state", + action_idx), + None => anyhow::bail!("apply_action: could not build event from action index {}", + action_idx), + } + } +} diff --git a/store/src/lib.rs b/store/src/lib.rs index 1bb8d1d..4fc8dff 100644 --- a/store/src/lib.rs +++ b/store/src/lib.rs @@ -21,3 +21,6 @@ pub mod training_common; // python interface "trictrac_engine" (for AI training..) mod pyengine; + +// C++ interface via cxx.rs (for OpenSpiel C++ integration) +pub mod cxxengine;