feat: c++ bindings
This commit is contained in:
parent
0429999672
commit
3490a184b3
6 changed files with 364 additions and 4 deletions
119
Cargo.lock
generated
119
Cargo.lock
generated
|
|
@ -1222,6 +1222,32 @@ dependencies = [
|
||||||
"libloading",
|
"libloading",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap"
|
||||||
|
version = "4.5.60"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a"
|
||||||
|
dependencies = [
|
||||||
|
"clap_builder",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_builder"
|
||||||
|
version = "4.5.60"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
|
||||||
|
dependencies = [
|
||||||
|
"anstyle",
|
||||||
|
"clap_lex",
|
||||||
|
"strsim",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_lex"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cmake"
|
name = "cmake"
|
||||||
version = "0.1.57"
|
version = "0.1.57"
|
||||||
|
|
@ -1242,6 +1268,17 @@ dependencies = [
|
||||||
"unicode-width 0.2.0",
|
"unicode-width 0.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "codespan-reporting"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"termcolor",
|
||||||
|
"unicode-width 0.2.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "color_quant"
|
name = "color_quant"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
|
@ -1943,6 +1980,68 @@ dependencies = [
|
||||||
"libloading",
|
"libloading",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cxx"
|
||||||
|
version = "1.0.194"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "747d8437319e3a2f43d93b341c137927ca70c0f5dabeea7a005a73665e247c7e"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"cxx-build",
|
||||||
|
"cxxbridge-cmd",
|
||||||
|
"cxxbridge-flags",
|
||||||
|
"cxxbridge-macro",
|
||||||
|
"foldhash 0.2.0",
|
||||||
|
"link-cplusplus",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cxx-build"
|
||||||
|
version = "1.0.194"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b0f4697d190a142477b16aef7da8a99bfdc41e7e8b1687583c0d23a79c7afc1e"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"codespan-reporting 0.13.1",
|
||||||
|
"indexmap",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"scratch",
|
||||||
|
"syn 2.0.114",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cxxbridge-cmd"
|
||||||
|
version = "1.0.194"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d0956799fa8678d4c50eed028f2de1c0552ae183c76e976cf7ca8c4e36a7c328"
|
||||||
|
dependencies = [
|
||||||
|
"clap",
|
||||||
|
"codespan-reporting 0.13.1",
|
||||||
|
"indexmap",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.114",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cxxbridge-flags"
|
||||||
|
version = "1.0.194"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "23384a836ab4f0ad98ace7e3955ad2de39de42378ab487dc28d3990392cb283a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cxxbridge-macro"
|
||||||
|
version = "1.0.194"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e6acc6b5822b9526adfb4fc377b67128fdd60aac757cc4a741a6278603f763cf"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.114",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "darling"
|
name = "darling"
|
||||||
version = "0.20.11"
|
version = "0.20.11"
|
||||||
|
|
@ -3733,6 +3832,15 @@ dependencies = [
|
||||||
"vcpkg",
|
"vcpkg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "link-cplusplus"
|
||||||
|
version = "1.0.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.4.15"
|
version = "0.4.15"
|
||||||
|
|
@ -4000,7 +4108,7 @@ dependencies = [
|
||||||
"bitflags 2.10.0",
|
"bitflags 2.10.0",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"cfg_aliases",
|
"cfg_aliases",
|
||||||
"codespan-reporting",
|
"codespan-reporting 0.12.0",
|
||||||
"half",
|
"half",
|
||||||
"hashbrown 0.15.5",
|
"hashbrown 0.15.5",
|
||||||
"hexf-parse",
|
"hexf-parse",
|
||||||
|
|
@ -5489,6 +5597,12 @@ version = "1.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scratch"
|
||||||
|
version = "1.0.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sdl2"
|
name = "sdl2"
|
||||||
version = "0.37.0"
|
version = "0.37.0"
|
||||||
|
|
@ -6617,7 +6731,10 @@ dependencies = [
|
||||||
name = "trictrac-store"
|
name = "trictrac-store"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"base64 0.21.7",
|
"base64 0.21.7",
|
||||||
|
"cxx",
|
||||||
|
"cxx-build",
|
||||||
"log",
|
"log",
|
||||||
"merge",
|
"merge",
|
||||||
"pyo3",
|
"pyo3",
|
||||||
|
|
|
||||||
4
justfile
4
justfile
|
|
@ -23,6 +23,10 @@ pythonlib:
|
||||||
rm -rf target/wheels
|
rm -rf target/wheels
|
||||||
maturin build -m store/Cargo.toml --release
|
maturin build -m store/Cargo.toml --release
|
||||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||||
|
cxxlib:
|
||||||
|
cargo build --release -p trictrac-store
|
||||||
|
@echo "Static lib: $(ls target/release/libtrictrac_store.a)"
|
||||||
|
@echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)"
|
||||||
trainbot algo:
|
trainbot algo:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,15 @@ edition = "2021"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "trictrac_store"
|
name = "trictrac_store"
|
||||||
# "cdylib" is necessary to produce a shared library for Python to import from.
|
# "cdylib" → Python .so built by maturin (pyengine)
|
||||||
# Only "rlib" is needed for other Rust crates to use this library
|
# "rlib" → used by other workspace crates (bot, client_cli)
|
||||||
crate-type = ["cdylib", "rlib"]
|
# "staticlib" → used by the C++ OpenSpiel game (cxxengine)
|
||||||
|
crate-type = ["cdylib", "rlib", "staticlib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0"
|
||||||
base64 = "0.21.7"
|
base64 = "0.21.7"
|
||||||
|
cxx = "1.0"
|
||||||
# provides macros for creating log messages to be used by a logger (for example env_logger)
|
# provides macros for creating log messages to be used by a logger (for example env_logger)
|
||||||
log = "0.4.20"
|
log = "0.4.20"
|
||||||
merge = "0.1.0"
|
merge = "0.1.0"
|
||||||
|
|
@ -21,3 +24,6 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
transpose = "0.2.2"
|
transpose = "0.2.2"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
cxx-build = "1.0"
|
||||||
|
|
|
||||||
7
store/build.rs
Normal file
7
store/build.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
fn main() {
|
||||||
|
cxx_build::bridge("src/cxxengine.rs")
|
||||||
|
.std("c++17")
|
||||||
|
.compile("trictrac-cxx");
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-changed=src/cxxengine.rs");
|
||||||
|
}
|
||||||
223
store/src/cxxengine.rs
Normal file
223
store/src/cxxengine.rs
Normal file
|
|
@ -0,0 +1,223 @@
|
||||||
|
//! C++ bindings for the TricTrac game engine via cxx.rs.
|
||||||
|
//!
|
||||||
|
//! Exposes an opaque `TricTracEngine` type to C++. The C++ side
|
||||||
|
//! (open_spiel/games/trictrac/trictrac.cc) holds it via
|
||||||
|
//! `rust::Box<trictrac_engine::TricTracEngine>`.
|
||||||
|
//!
|
||||||
|
//! The Rust engine always reasons from White's (player 1's) perspective.
|
||||||
|
//! For Black (player 2), the board is mirrored before computing actions
|
||||||
|
//! and events are mirrored back before being applied — exactly as in
|
||||||
|
//! pyengine.rs.
|
||||||
|
|
||||||
|
use crate::dice::Dice;
|
||||||
|
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||||
|
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
||||||
|
|
||||||
|
// ── cxx bridge declaration ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cxx::bridge(namespace = "trictrac_engine")]
|
||||||
|
pub mod ffi {
|
||||||
|
// ── Shared types (transparent to both Rust and C++) ───────────────────────
|
||||||
|
|
||||||
|
/// Two dice values passed from C++ when applying a chance outcome.
|
||||||
|
struct DicePair {
|
||||||
|
die1: u8,
|
||||||
|
die2: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Both players' cumulative scores: holes * 12 + points.
|
||||||
|
struct PlayerScores {
|
||||||
|
score_p1: i32,
|
||||||
|
score_p2: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Opaque Rust type and its free-function constructor ────────────────────
|
||||||
|
|
||||||
|
extern "Rust" {
|
||||||
|
/// Opaque handle to a running TricTrac game.
|
||||||
|
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
|
||||||
|
type TricTracEngine;
|
||||||
|
|
||||||
|
/// Construct a fresh engine with two players; player 1 (White) goes first.
|
||||||
|
fn new_trictrac_engine() -> Box<TricTracEngine>;
|
||||||
|
|
||||||
|
/// Deep-copy the engine — required by OpenSpiel's State::Clone().
|
||||||
|
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
|
||||||
|
|
||||||
|
// ── Queries ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
|
||||||
|
fn needs_roll(self: &TricTracEngine) -> bool;
|
||||||
|
|
||||||
|
/// True when Stage::Ended.
|
||||||
|
fn is_game_ended(self: &TricTracEngine) -> bool;
|
||||||
|
|
||||||
|
/// Active player index: 0 = player 1 (White), 1 = player 2 (Black).
|
||||||
|
fn current_player_idx(self: &TricTracEngine) -> u64;
|
||||||
|
|
||||||
|
/// Legal action indices for `player_idx` in [0, 513].
|
||||||
|
/// Returns an empty vector when it is not that player's turn.
|
||||||
|
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec<u64>;
|
||||||
|
|
||||||
|
/// Human-readable description of an action index.
|
||||||
|
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
|
||||||
|
|
||||||
|
/// Both players' scores.
|
||||||
|
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||||
|
|
||||||
|
/// 36-element state vector (i8). Mirrored for player_idx == 1.
|
||||||
|
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||||
|
|
||||||
|
/// Human-readable state description for `player_idx`.
|
||||||
|
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||||
|
|
||||||
|
/// Full debug representation of the current state.
|
||||||
|
fn to_debug_string(self: &TricTracEngine) -> String;
|
||||||
|
|
||||||
|
// ── Mutations ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Apply a dice-roll result. Returns Err (C++ exception) if not in
|
||||||
|
/// the RollWaiting stage.
|
||||||
|
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
|
||||||
|
|
||||||
|
/// Apply a player action. Returns Err (C++ exception) if the action
|
||||||
|
/// is not legal in the current state.
|
||||||
|
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Opaque type ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct TricTracEngine {
|
||||||
|
game_state: GameState,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Free-function constructor (declared in the bridge as a plain function) ────
|
||||||
|
|
||||||
|
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
|
||||||
|
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||||
|
game_state.init_player("player1");
|
||||||
|
game_state.init_player("player2");
|
||||||
|
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||||
|
Box::new(TricTracEngine { game_state })
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Method implementations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
impl TricTracEngine {
|
||||||
|
fn clone_engine(&self) -> Box<TricTracEngine> {
|
||||||
|
Box::new(TricTracEngine {
|
||||||
|
game_state: self.game_state.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn needs_roll(&self) -> bool {
|
||||||
|
self.game_state.turn_stage == TurnStage::RollWaiting
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_game_ended(&self) -> bool {
|
||||||
|
self.game_state.stage == Stage::Ended
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_player_idx(&self) -> u64 {
|
||||||
|
self.game_state.active_player_id - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_legal_actions(&self, player_idx: u64) -> Vec<u64> {
|
||||||
|
if player_idx != self.current_player_idx() {
|
||||||
|
return vec![];
|
||||||
|
}
|
||||||
|
if player_idx == 0 {
|
||||||
|
get_valid_action_indices(&self.game_state)
|
||||||
|
.into_iter()
|
||||||
|
.map(|i| i as u64)
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
let mirror = self.game_state.mirror();
|
||||||
|
get_valid_action_indices(&mirror)
|
||||||
|
.into_iter()
|
||||||
|
.map(|i| i as u64)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
|
||||||
|
TrictracAction::from_action_index(action_idx as usize)
|
||||||
|
.map(|a| format!("{}:{}", player_idx, a))
|
||||||
|
.unwrap_or_else(|| "unknown action".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_players_scores(&self) -> ffi::PlayerScores {
|
||||||
|
ffi::PlayerScores {
|
||||||
|
score_p1: self.score_for(1),
|
||||||
|
score_p2: self.score_for(2),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn score_for(&self, player_id: u64) -> i32 {
|
||||||
|
self.game_state
|
||||||
|
.players
|
||||||
|
.get(&player_id)
|
||||||
|
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||||
|
.unwrap_or(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||||
|
if player_idx == 0 {
|
||||||
|
self.game_state.to_vec()
|
||||||
|
} else {
|
||||||
|
self.game_state.mirror().to_vec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||||
|
if player_idx == 0 {
|
||||||
|
format!("{}", self.game_state)
|
||||||
|
} else {
|
||||||
|
format!("{}", self.game_state.mirror())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_debug_string(&self) -> String {
|
||||||
|
format!("{}", self.game_state)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> anyhow::Result<()> {
|
||||||
|
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||||
|
anyhow::bail!("apply_dice_roll: not in RollWaiting stage (currently {:?})",
|
||||||
|
self.game_state.turn_stage);
|
||||||
|
}
|
||||||
|
let player_id = self.game_state.active_player_id;
|
||||||
|
let dice = Dice {
|
||||||
|
values: (dice.die1, dice.die2),
|
||||||
|
};
|
||||||
|
self.game_state
|
||||||
|
.consume(&GameEvent::RollResult { player_id, dice });
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_action(&mut self, action_idx: u64) -> anyhow::Result<()> {
|
||||||
|
let needs_mirror = self.game_state.active_player_id == 2;
|
||||||
|
|
||||||
|
let event = TrictracAction::from_action_index(action_idx as usize).and_then(|a| {
|
||||||
|
let state = if needs_mirror {
|
||||||
|
&self.game_state.mirror()
|
||||||
|
} else {
|
||||||
|
&self.game_state
|
||||||
|
};
|
||||||
|
a.to_event(state)
|
||||||
|
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||||
|
});
|
||||||
|
|
||||||
|
match event {
|
||||||
|
Some(evt) if self.game_state.validate(&evt) => {
|
||||||
|
self.game_state.consume(&evt);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Some(_) => anyhow::bail!("apply_action: action {} is not valid in current state",
|
||||||
|
action_idx),
|
||||||
|
None => anyhow::bail!("apply_action: could not build event from action index {}",
|
||||||
|
action_idx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -21,3 +21,6 @@ pub mod training_common;
|
||||||
|
|
||||||
// python interface "trictrac_engine" (for AI training..)
|
// python interface "trictrac_engine" (for AI training..)
|
||||||
mod pyengine;
|
mod pyengine;
|
||||||
|
|
||||||
|
// C++ interface via cxx.rs (for OpenSpiel C++ integration)
|
||||||
|
pub mod cxxengine;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue