refact: remove python & c++ bindings
This commit is contained in:
parent
7f63df2946
commit
f556ae10b8
12 changed files with 5 additions and 668 deletions
172
Cargo.lock
generated
172
Cargo.lock
generated
|
|
@ -1381,7 +1381,6 @@ checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anstyle",
|
"anstyle",
|
||||||
"clap_lex",
|
"clap_lex",
|
||||||
"strsim",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1447,17 +1446,6 @@ dependencies = [
|
||||||
"unicode-width 0.2.0",
|
"unicode-width 0.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "codespan-reporting"
|
|
||||||
version = "0.13.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
|
|
||||||
dependencies = [
|
|
||||||
"serde",
|
|
||||||
"termcolor",
|
|
||||||
"unicode-width 0.2.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "collection_literals"
|
name = "collection_literals"
|
||||||
version = "1.0.3"
|
version = "1.0.3"
|
||||||
|
|
@ -2278,68 +2266,6 @@ dependencies = [
|
||||||
"libloading",
|
"libloading",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cxx"
|
|
||||||
version = "1.0.194"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "747d8437319e3a2f43d93b341c137927ca70c0f5dabeea7a005a73665e247c7e"
|
|
||||||
dependencies = [
|
|
||||||
"cc",
|
|
||||||
"cxx-build",
|
|
||||||
"cxxbridge-cmd",
|
|
||||||
"cxxbridge-flags",
|
|
||||||
"cxxbridge-macro",
|
|
||||||
"foldhash 0.2.0",
|
|
||||||
"link-cplusplus",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cxx-build"
|
|
||||||
version = "1.0.194"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b0f4697d190a142477b16aef7da8a99bfdc41e7e8b1687583c0d23a79c7afc1e"
|
|
||||||
dependencies = [
|
|
||||||
"cc",
|
|
||||||
"codespan-reporting 0.13.1",
|
|
||||||
"indexmap",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"scratch",
|
|
||||||
"syn 2.0.114",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cxxbridge-cmd"
|
|
||||||
version = "1.0.194"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d0956799fa8678d4c50eed028f2de1c0552ae183c76e976cf7ca8c4e36a7c328"
|
|
||||||
dependencies = [
|
|
||||||
"clap",
|
|
||||||
"codespan-reporting 0.13.1",
|
|
||||||
"indexmap",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn 2.0.114",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cxxbridge-flags"
|
|
||||||
version = "1.0.194"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "23384a836ab4f0ad98ace7e3955ad2de39de42378ab487dc28d3990392cb283a"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cxxbridge-macro"
|
|
||||||
version = "1.0.194"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e6acc6b5822b9526adfb4fc377b67128fdd60aac757cc4a741a6278603f763cf"
|
|
||||||
dependencies = [
|
|
||||||
"indexmap",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn 2.0.114",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "darling"
|
name = "darling"
|
||||||
version = "0.20.11"
|
version = "0.20.11"
|
||||||
|
|
@ -4875,15 +4801,6 @@ version = "1.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bfae20f6b19ad527b550c223fddc3077a547fc70cda94b9b566575423fd303ee"
|
checksum = "bfae20f6b19ad527b550c223fddc3077a547fc70cda94b9b566575423fd303ee"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "link-cplusplus"
|
|
||||||
version = "1.0.12"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82"
|
|
||||||
dependencies = [
|
|
||||||
"cc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.4.15"
|
version = "0.4.15"
|
||||||
|
|
@ -5073,15 +4990,6 @@ dependencies = [
|
||||||
"stable_deref_trait",
|
"stable_deref_trait",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "memoffset"
|
|
||||||
version = "0.9.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
|
||||||
dependencies = [
|
|
||||||
"autocfg",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "merge"
|
name = "merge"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
@ -5180,7 +5088,7 @@ dependencies = [
|
||||||
"bitflags 2.10.0",
|
"bitflags 2.10.0",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"cfg_aliases",
|
"cfg_aliases",
|
||||||
"codespan-reporting 0.12.0",
|
"codespan-reporting",
|
||||||
"half",
|
"half",
|
||||||
"hashbrown 0.15.5",
|
"hashbrown 0.15.5",
|
||||||
"hexf-parse",
|
"hexf-parse",
|
||||||
|
|
@ -6045,69 +5953,6 @@ dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pyo3"
|
|
||||||
version = "0.23.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"indoc",
|
|
||||||
"libc",
|
|
||||||
"memoffset",
|
|
||||||
"once_cell",
|
|
||||||
"portable-atomic",
|
|
||||||
"pyo3-build-config",
|
|
||||||
"pyo3-ffi",
|
|
||||||
"pyo3-macros",
|
|
||||||
"unindent",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pyo3-build-config"
|
|
||||||
version = "0.23.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
|
|
||||||
dependencies = [
|
|
||||||
"once_cell",
|
|
||||||
"target-lexicon",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pyo3-ffi"
|
|
||||||
version = "0.23.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
|
|
||||||
dependencies = [
|
|
||||||
"libc",
|
|
||||||
"pyo3-build-config",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pyo3-macros"
|
|
||||||
version = "0.23.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"pyo3-macros-backend",
|
|
||||||
"quote",
|
|
||||||
"syn 2.0.114",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pyo3-macros-backend"
|
|
||||||
version = "0.23.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
|
|
||||||
dependencies = [
|
|
||||||
"heck",
|
|
||||||
"proc-macro2",
|
|
||||||
"pyo3-build-config",
|
|
||||||
"quote",
|
|
||||||
"syn 2.0.114",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "qoi"
|
name = "qoi"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
|
|
@ -6955,12 +6800,6 @@ version = "1.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "scratch"
|
|
||||||
version = "1.0.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sdl2"
|
name = "sdl2"
|
||||||
version = "0.37.0"
|
version = "0.37.0"
|
||||||
|
|
@ -7626,12 +7465,6 @@ dependencies = [
|
||||||
"xattr",
|
"xattr",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "target-lexicon"
|
|
||||||
version = "0.12.16"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tch"
|
name = "tch"
|
||||||
version = "0.22.0"
|
version = "0.22.0"
|
||||||
|
|
@ -8276,11 +8109,8 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"base64 0.21.7",
|
"base64 0.21.7",
|
||||||
"cxx",
|
|
||||||
"cxx-build",
|
|
||||||
"log",
|
"log",
|
||||||
"merge",
|
"merge",
|
||||||
"pyo3",
|
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
"serde",
|
"serde",
|
||||||
"transpose",
|
"transpose",
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ path = "src/burnrl/main.rs"
|
||||||
pretty_assertions = "1.4.0"
|
pretty_assertions = "1.4.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
trictrac-store = { path = "../store", features = ["python"] }
|
trictrac-store = { path = "../store" }
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
env_logger = "0.10"
|
env_logger = "0.10"
|
||||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ bincode = "1.3.3"
|
||||||
pico-args = "0.5.0"
|
pico-args = "0.5.0"
|
||||||
pretty_assertions = "1.4.0"
|
pretty_assertions = "1.4.0"
|
||||||
renet = "0.0.13"
|
renet = "0.0.13"
|
||||||
trictrac-store = { path = "../store", features = ["python"] }
|
trictrac-store = { path = "../store" }
|
||||||
trictrac-bot = { path = "../bot" }
|
trictrac-bot = { path = "../bot" }
|
||||||
spiel_bot = { path = "../spiel_bot" }
|
spiel_bot = { path = "../spiel_bot" }
|
||||||
itertools = "0.13.0"
|
itertools = "0.13.0"
|
||||||
|
|
|
||||||
52
devenv.nix
52
devenv.nix
|
|
@ -21,63 +21,11 @@ in
|
||||||
pkgs.samply # code profiler
|
pkgs.samply # code profiler
|
||||||
pkgs.feedgnuplot # to visualize bots training results
|
pkgs.feedgnuplot # to visualize bots training results
|
||||||
|
|
||||||
# --- AI training with python ---
|
|
||||||
# generate python classes from rust code
|
|
||||||
pkgs.maturin
|
|
||||||
# required by python numpy
|
|
||||||
pkgs.libz
|
|
||||||
|
|
||||||
# for bevy
|
|
||||||
pkgs.alsa-lib
|
|
||||||
pkgs.udev
|
|
||||||
|
|
||||||
# bevy fast compile
|
|
||||||
pkgs.clang
|
|
||||||
pkgs.lld
|
|
||||||
|
|
||||||
# copié de https://github.com/mmai/Hyperspeedcube/blob/develop/devenv.nix
|
|
||||||
# TODO : retirer ce qui est inutile
|
|
||||||
# pour erreur à l'exécution, selon https://github.com/emilk/egui/discussions/1587
|
|
||||||
pkgs.libxkbcommon
|
|
||||||
pkgs.libGL
|
|
||||||
|
|
||||||
# WINIT_UNIX_BACKEND=wayland
|
|
||||||
pkgs.wayland
|
|
||||||
|
|
||||||
# WINIT_UNIX_BACKEND=x11
|
|
||||||
pkgs.xorg.libXcursor
|
|
||||||
pkgs.xorg.libXrandr
|
|
||||||
pkgs.xorg.libXi
|
|
||||||
pkgs.xorg.libX11
|
|
||||||
|
|
||||||
pkgs.vulkan-headers
|
|
||||||
pkgs.vulkan-loader
|
|
||||||
# ------------ fin copie
|
|
||||||
|
|
||||||
];
|
];
|
||||||
|
|
||||||
# https://devenv.sh/languages/
|
# https://devenv.sh/languages/
|
||||||
languages.rust.enable = true;
|
languages.rust.enable = true;
|
||||||
|
|
||||||
|
|
||||||
# AI training with python
|
|
||||||
enterShell = ''
|
|
||||||
PYTHONPATH=$PYTHONPATH:$PWD/.devenv/state/venv/lib/python3/site-packages
|
|
||||||
'';
|
|
||||||
|
|
||||||
languages.python = {
|
|
||||||
enable = true;
|
|
||||||
uv.enable = true;
|
|
||||||
venv.enable = true;
|
|
||||||
venv.requirements = "
|
|
||||||
pip
|
|
||||||
gymnasium
|
|
||||||
numpy
|
|
||||||
stable-baselines3
|
|
||||||
shimmy
|
|
||||||
";
|
|
||||||
};
|
|
||||||
|
|
||||||
# https://devenv.sh/scripts/
|
# https://devenv.sh/scripts/
|
||||||
# scripts.hello.exec = "echo hello from $GREET";
|
# scripts.hello.exec = "echo hello from $GREET";
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
trictrac-store = { path = "../store", features = ["python"] }
|
trictrac-store = { path = "../store" }
|
||||||
trictrac-bot = { path = "../bot" }
|
trictrac-bot = { path = "../bot" }
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
|
|
|
||||||
|
|
@ -7,26 +7,14 @@ edition = "2021"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "trictrac_store"
|
name = "trictrac_store"
|
||||||
# "cdylib" → Python .so built by maturin (pyengine)
|
crate-type = ["rlib"]
|
||||||
# "rlib" → used by other workspace crates (bot, client_cli)
|
|
||||||
# "staticlib" → used by the C++ OpenSpiel game (cxxengine)
|
|
||||||
crate-type = ["cdylib", "rlib", "staticlib"]
|
|
||||||
|
|
||||||
[features]
|
|
||||||
# Enable Python bindings (required for maturin / AI training). Not available on wasm32.
|
|
||||||
python = ["pyo3"]
|
|
||||||
# Enable C++ bridge for OpenSpiel integration. Not available on wasm32.
|
|
||||||
cpp = ["dep:cxx"]
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
base64 = "0.21.7"
|
base64 = "0.21.7"
|
||||||
cxx = { version = "1.0", optional = true }
|
|
||||||
# provides macros for creating log messages to be used by a logger (for example env_logger)
|
# provides macros for creating log messages to be used by a logger (for example env_logger)
|
||||||
log = "0.4.20"
|
log = "0.4.20"
|
||||||
merge = "0.1.0"
|
merge = "0.1.0"
|
||||||
# generate python lib (with maturin) to be used in AI training
|
|
||||||
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"], optional = true }
|
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
transpose = "0.2.2"
|
transpose = "0.2.2"
|
||||||
|
|
@ -34,6 +22,3 @@ transpose = "0.2.2"
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "random_game"
|
name = "random_game"
|
||||||
path = "src/bin/random_game.rs"
|
path = "src/bin/random_game.rs"
|
||||||
|
|
||||||
[build-dependencies]
|
|
||||||
cxx-build = "1.0"
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
fn main() {
|
|
||||||
if std::env::var("CARGO_FEATURE_CPP").is_ok() {
|
|
||||||
cxx_build::bridge("src/cxxengine.rs")
|
|
||||||
.std("c++17")
|
|
||||||
.compile("trictrac-cxx");
|
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=src/cxxengine.rs");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
[build-system]
|
|
||||||
requires = ["maturin>=1.0,<2.0"]
|
|
||||||
build-backend = "maturin"
|
|
||||||
|
|
||||||
[tool.maturin]
|
|
||||||
# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so)
|
|
||||||
features = ["pyo3/extension-module"]
|
|
||||||
# python-source = "python"
|
|
||||||
|
|
@ -1,252 +0,0 @@
|
||||||
//! C++ bindings for the TricTrac game engine via cxx.rs.
|
|
||||||
//!
|
|
||||||
//! Exposes an opaque `TricTracEngine` type to C++. The C++ side
|
|
||||||
//! (open_spiel/games/trictrac/trictrac.cc) holds it via
|
|
||||||
//! `rust::Box<trictrac_engine::TricTracEngine>`.
|
|
||||||
//!
|
|
||||||
//! The Rust engine always reasons from White's (player 1's) perspective.
|
|
||||||
//! For Black (player 2), the board is mirrored before computing actions
|
|
||||||
//! and events are mirrored back before being applied — exactly as in
|
|
||||||
//! pyengine.rs.
|
|
||||||
|
|
||||||
use std::panic::{self, AssertUnwindSafe};
|
|
||||||
|
|
||||||
use crate::dice::Dice;
|
|
||||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
|
||||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
|
||||||
|
|
||||||
/// Catch any Rust panic and convert it to anyhow::Error so it never
|
|
||||||
/// crosses the C FFI boundary as undefined behaviour.
|
|
||||||
fn catch_panics<F, T>(f: F) -> anyhow::Result<T>
|
|
||||||
where
|
|
||||||
F: FnOnce() -> anyhow::Result<T> + panic::UnwindSafe,
|
|
||||||
{
|
|
||||||
panic::catch_unwind(f).unwrap_or_else(|e| {
|
|
||||||
let msg = e
|
|
||||||
.downcast_ref::<String>()
|
|
||||||
.map(|s| s.as_str())
|
|
||||||
.or_else(|| e.downcast_ref::<&str>().copied())
|
|
||||||
.unwrap_or("unknown panic payload");
|
|
||||||
Err(anyhow::anyhow!("Rust panic in FFI: {}", msg))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── cxx bridge declaration ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "trictrac_engine")]
|
|
||||||
pub mod ffi {
|
|
||||||
// ── Shared types (transparent to both Rust and C++) ───────────────────────
|
|
||||||
|
|
||||||
/// Two dice values passed from C++ when applying a chance outcome.
|
|
||||||
struct DicePair {
|
|
||||||
die1: u8,
|
|
||||||
die2: u8,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Both players' cumulative scores: holes * 12 + points.
|
|
||||||
struct PlayerScores {
|
|
||||||
score_p1: i32,
|
|
||||||
score_p2: i32,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Opaque Rust type and its free-function constructor ────────────────────
|
|
||||||
|
|
||||||
extern "Rust" {
|
|
||||||
/// Opaque handle to a running TricTrac game.
|
|
||||||
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
|
|
||||||
type TricTracEngine;
|
|
||||||
|
|
||||||
/// Construct a fresh engine with two players; player 1 (White) goes first.
|
|
||||||
fn new_trictrac_engine() -> Box<TricTracEngine>;
|
|
||||||
|
|
||||||
/// Deep-copy the engine — required by OpenSpiel's State::Clone().
|
|
||||||
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
|
|
||||||
|
|
||||||
// ── Queries ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
|
|
||||||
fn needs_roll(self: &TricTracEngine) -> bool;
|
|
||||||
|
|
||||||
/// True when Stage::Ended.
|
|
||||||
fn is_game_ended(self: &TricTracEngine) -> bool;
|
|
||||||
|
|
||||||
/// Active player index: 0 = player 1 (White), 1 = player 2 (Black).
|
|
||||||
fn current_player_idx(self: &TricTracEngine) -> u64;
|
|
||||||
|
|
||||||
/// Legal action indices for `player_idx` in [0, 513].
|
|
||||||
/// Returns an empty vector when it is not that player's turn.
|
|
||||||
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Result<Vec<u64>>;
|
|
||||||
|
|
||||||
/// Human-readable description of an action index.
|
|
||||||
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
|
|
||||||
|
|
||||||
/// Both players' scores.
|
|
||||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
|
||||||
|
|
||||||
/// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1.
|
|
||||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<f32>;
|
|
||||||
|
|
||||||
/// Human-readable state description for `player_idx`.
|
|
||||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
|
||||||
|
|
||||||
/// Full debug representation of the current state.
|
|
||||||
fn to_debug_string(self: &TricTracEngine) -> String;
|
|
||||||
|
|
||||||
// ── Mutations ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// Apply a dice-roll result. Returns Err (C++ exception) if not in
|
|
||||||
/// the RollWaiting stage.
|
|
||||||
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
|
|
||||||
|
|
||||||
/// Apply a player action. Returns Err (C++ exception) if the action
|
|
||||||
/// is not legal in the current state.
|
|
||||||
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Opaque type ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
pub struct TricTracEngine {
|
|
||||||
game_state: GameState,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Free-function constructor (declared in the bridge as a plain function) ────
|
|
||||||
|
|
||||||
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
|
|
||||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
|
||||||
game_state.init_player("player1");
|
|
||||||
game_state.init_player("player2");
|
|
||||||
game_state
|
|
||||||
.consume(&GameEvent::BeginGame { goes_first: 1 })
|
|
||||||
.expect("BeginGame failed during engine initialization");
|
|
||||||
Box::new(TricTracEngine { game_state })
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Method implementations ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
impl TricTracEngine {
|
|
||||||
fn clone_engine(&self) -> Box<TricTracEngine> {
|
|
||||||
Box::new(TricTracEngine {
|
|
||||||
game_state: self.game_state.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn needs_roll(&self) -> bool {
|
|
||||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_game_ended(&self) -> bool {
|
|
||||||
self.game_state.stage == Stage::Ended
|
|
||||||
}
|
|
||||||
|
|
||||||
fn current_player_idx(&self) -> u64 {
|
|
||||||
self.game_state.active_player_id - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_legal_actions(&self, player_idx: u64) -> anyhow::Result<Vec<u64>> {
|
|
||||||
if player_idx != self.current_player_idx() {
|
|
||||||
return Ok(vec![]);
|
|
||||||
}
|
|
||||||
catch_panics(AssertUnwindSafe(|| {
|
|
||||||
if player_idx == 0 {
|
|
||||||
get_valid_action_indices(&self.game_state)
|
|
||||||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
|
||||||
} else {
|
|
||||||
let mirror = self.game_state.mirror();
|
|
||||||
get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect())
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
|
|
||||||
TrictracAction::from_action_index(action_idx as usize)
|
|
||||||
.map(|a| format!("{}:{}", player_idx, a))
|
|
||||||
.unwrap_or_else(|| "unknown action".into())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_players_scores(&self) -> ffi::PlayerScores {
|
|
||||||
ffi::PlayerScores {
|
|
||||||
score_p1: self.score_for(1),
|
|
||||||
score_p2: self.score_for(2),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn score_for(&self, player_id: u64) -> i32 {
|
|
||||||
self.game_state
|
|
||||||
.players
|
|
||||||
.get(&player_id)
|
|
||||||
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
|
||||||
.unwrap_or(-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
|
|
||||||
if player_idx == 0 {
|
|
||||||
self.game_state.to_tensor()
|
|
||||||
} else {
|
|
||||||
self.game_state.mirror().to_tensor()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
|
||||||
if player_idx == 0 {
|
|
||||||
format!("{}", self.game_state)
|
|
||||||
} else {
|
|
||||||
format!("{}", self.game_state.mirror())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_debug_string(&self) -> String {
|
|
||||||
format!("{}", self.game_state)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> anyhow::Result<()> {
|
|
||||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
|
||||||
anyhow::bail!(
|
|
||||||
"apply_dice_roll: not in RollWaiting stage (currently {:?})",
|
|
||||||
self.game_state.turn_stage
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let player_id = self.game_state.active_player_id;
|
|
||||||
let dice = Dice {
|
|
||||||
values: (dice.die1, dice.die2),
|
|
||||||
};
|
|
||||||
catch_panics(AssertUnwindSafe(|| {
|
|
||||||
self.game_state
|
|
||||||
.consume(&GameEvent::RollResult { player_id, dice })
|
|
||||||
.map_err(|e| anyhow::anyhow!(e))
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_action(&mut self, action_idx: u64) -> anyhow::Result<()> {
|
|
||||||
catch_panics(AssertUnwindSafe(|| {
|
|
||||||
let needs_mirror = self.game_state.active_player_id == 2;
|
|
||||||
|
|
||||||
let event = TrictracAction::from_action_index(action_idx as usize).and_then(|a| {
|
|
||||||
let state = if needs_mirror {
|
|
||||||
&self.game_state.mirror()
|
|
||||||
} else {
|
|
||||||
&self.game_state
|
|
||||||
};
|
|
||||||
a.to_event(state)
|
|
||||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
|
||||||
});
|
|
||||||
|
|
||||||
match event {
|
|
||||||
Some(evt) if self.game_state.validate(&evt) => self
|
|
||||||
.game_state
|
|
||||||
.consume(&evt)
|
|
||||||
.map_err(|e| anyhow::anyhow!(e)),
|
|
||||||
Some(evt) => anyhow::bail!(
|
|
||||||
"apply_action: event {:?} is not valid in current state {}",
|
|
||||||
evt,
|
|
||||||
self.game_state
|
|
||||||
),
|
|
||||||
None => anyhow::bail!(
|
|
||||||
"apply_action: could not build event from action index {} in state {}",
|
|
||||||
action_idx,
|
|
||||||
self.game_state
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -18,11 +18,3 @@ mod dice;
|
||||||
pub use dice::{Dice, DiceRoller};
|
pub use dice::{Dice, DiceRoller};
|
||||||
|
|
||||||
pub mod training_common;
|
pub mod training_common;
|
||||||
|
|
||||||
// python interface "trictrac_engine" (for AI training..)
|
|
||||||
#[cfg(feature = "python")]
|
|
||||||
mod pyengine;
|
|
||||||
|
|
||||||
// C++ interface via cxx.rs (for OpenSpiel C++ integration)
|
|
||||||
#[cfg(feature = "cpp")]
|
|
||||||
pub mod cxxengine;
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,9 @@
|
||||||
#[cfg(feature = "python")]
|
|
||||||
use pyo3::prelude::*;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
// This just makes it easier to dissern between a player id and any ol' u64
|
// This just makes it easier to dissern between a player id and any ol' u64
|
||||||
pub type PlayerId = u64;
|
pub type PlayerId = u64;
|
||||||
|
|
||||||
#[cfg_attr(feature = "python", pyclass(eq, eq_int))]
|
|
||||||
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub enum Color {
|
pub enum Color {
|
||||||
White,
|
White,
|
||||||
|
|
|
||||||
|
|
@ -1,146 +0,0 @@
|
||||||
//! # Expose trictrac game state and rules in a python module
|
|
||||||
use pyo3::prelude::*;
|
|
||||||
|
|
||||||
use crate::dice::Dice;
|
|
||||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
|
||||||
use crate::player::PlayerId;
|
|
||||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
|
||||||
|
|
||||||
#[pyclass]
|
|
||||||
struct TricTrac {
|
|
||||||
game_state: GameState,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymethods]
|
|
||||||
impl TricTrac {
|
|
||||||
#[new]
|
|
||||||
fn new() -> Self {
|
|
||||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
|
||||||
|
|
||||||
// Initialiser 2 joueurs
|
|
||||||
game_state.init_player("player1");
|
|
||||||
game_state.init_player("player2");
|
|
||||||
|
|
||||||
// Commencer la partie avec le joueur 1
|
|
||||||
let _ = game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
|
||||||
|
|
||||||
TricTrac { game_state }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn needs_roll(&self) -> bool {
|
|
||||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_game_ended(&self) -> bool {
|
|
||||||
self.game_state.stage == Stage::Ended
|
|
||||||
}
|
|
||||||
|
|
||||||
// 0 or 1
|
|
||||||
fn current_player_idx(&self) -> u64 {
|
|
||||||
self.game_state.active_player_id - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_legal_actions(&self, player_idx: u64) -> Vec<usize> {
|
|
||||||
if player_idx == self.current_player_idx() {
|
|
||||||
if player_idx == 0 {
|
|
||||||
get_valid_action_indices(&self.game_state).unwrap()
|
|
||||||
} else {
|
|
||||||
let mirror = self.game_state.mirror();
|
|
||||||
get_valid_action_indices(&mirror).unwrap()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
vec![]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn action_to_string(&self, player_idx: u64, action_idx: usize) -> String {
|
|
||||||
TrictracAction::from_action_index(action_idx)
|
|
||||||
.map(|a| format!("{}:{}", player_idx, a))
|
|
||||||
.unwrap_or("unknown action".into())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_dice_roll(&mut self, dices: (u8, u8)) -> PyResult<()> {
|
|
||||||
let player_id = self.game_state.active_player_id;
|
|
||||||
|
|
||||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
|
||||||
return Err(pyo3::exceptions::PyRuntimeError::new_err(
|
|
||||||
"Not in RollWaiting stage",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let dice = Dice { values: dices };
|
|
||||||
let _ = self
|
|
||||||
.game_state
|
|
||||||
.consume(&GameEvent::RollResult { player_id, dice });
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_action(&mut self, action_idx: usize) -> PyResult<()> {
|
|
||||||
if let Some(event) = TrictracAction::from_action_index(action_idx).and_then(|a| {
|
|
||||||
let needs_mirror = self.game_state.active_player_id == 2;
|
|
||||||
let game_state = if needs_mirror {
|
|
||||||
&self.game_state.mirror()
|
|
||||||
} else {
|
|
||||||
&self.game_state
|
|
||||||
};
|
|
||||||
a.to_event(game_state)
|
|
||||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
|
||||||
}) {
|
|
||||||
if self.game_state.validate(&event) {
|
|
||||||
let _ = self.game_state.consume(&event);
|
|
||||||
return Ok(());
|
|
||||||
} else {
|
|
||||||
return Err(pyo3::exceptions::PyRuntimeError::new_err(
|
|
||||||
"Action is invalid",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(pyo3::exceptions::PyRuntimeError::new_err(
|
|
||||||
"Could not apply action",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a player total score (holes & points)
|
|
||||||
fn get_score(&self, player_id: PlayerId) -> i32 {
|
|
||||||
if let Some(player) = self.game_state.players.get(&player_id) {
|
|
||||||
player.holes as i32 * 12 + player.points as i32
|
|
||||||
} else {
|
|
||||||
-1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_players_scores(&self) -> [i32; 2] {
|
|
||||||
[self.get_score(1), self.get_score(2)]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
|
|
||||||
if player_idx == 0 {
|
|
||||||
self.game_state.to_tensor()
|
|
||||||
} else {
|
|
||||||
self.game_state.mirror().to_tensor()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
|
||||||
if player_idx == 0 {
|
|
||||||
format!("{}", self.game_state)
|
|
||||||
} else {
|
|
||||||
format!("{}", self.game_state.mirror())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Afficher l'état du jeu (pour le débogage)
|
|
||||||
fn __str__(&self) -> String {
|
|
||||||
format!("{}", self.game_state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A Python module implemented in Rust. The name of this function must match
|
|
||||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
|
||||||
/// import the module.
|
|
||||||
#[pymodule]
|
|
||||||
fn trictrac_store(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
||||||
m.add_class::<TricTrac>()?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue