diff --git a/doc/plan_cxxbindings.md b/doc/plan_cxxbindings.md deleted file mode 100644 index 29bf314..0000000 --- a/doc/plan_cxxbindings.md +++ /dev/null @@ -1,992 +0,0 @@ -# Plan: C++ OpenSpiel Game via cxx.rs - -> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings. -> -> Based on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`. - ---- - -## 1. Overview - -The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding: - -- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`. -- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`. -- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box<TricTracEngine>` and delegates all logic to Rust. -- The build is wired together via **Corrosion** (CMake-native Rust integration) and `cxx-build`. - -The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games. - ---- - -## 2. 
Files to Create / Modify - -``` -trictrac/ - store/ - Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type - build.rs ← CREATE: cxx-build bridge registration - src/ - lib.rs ← MODIFY: add cxxengine module - cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl - -forks/open_spiel/ - CMakeLists.txt ← MODIFY: add Corrosion FetchContent - open_spiel/ - games/ - CMakeLists.txt ← MODIFY: add trictrac/ sources + test - trictrac/ ← CREATE directory - trictrac.h ← CREATE - trictrac.cc ← CREATE - trictrac_test.cc ← CREATE - - justfile ← MODIFY: add buildtrictrac target -trictrac/ - justfile ← MODIFY: add cxxlib target -``` - ---- - -## 3. Step 1 — Rust: `store/Cargo.toml` - -Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library. - -```toml -[package] -name = "trictrac-store" -version = "0.1.0" -edition = "2021" - -[lib] -name = "trictrac_store" -# cdylib → Python .so (used by maturin / pyengine) -# rlib → used by other Rust crates in the workspace -# staticlib → used by C++ consumers (cxxengine) -crate-type = ["cdylib", "rlib", "staticlib"] - -[dependencies] -base64 = "0.21.7" -cxx = "1.0" -log = "0.4.20" -merge = "0.1.0" -pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } -rand = "0.9" -serde = { version = "1.0", features = ["derive"] } -transpose = "0.2.2" - -[build-dependencies] -cxx-build = "1.0" -``` - -> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`). - ---- - -## 4. Step 2 — Rust: `store/build.rs` - -The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header. 
- -```rust -fn main() { - cxx_build::bridge("src/cxxengine.rs") - .std("c++17") - .compile("trictrac-cxx"); - - // Re-run if the bridge source changes - println!("cargo:rerun-if-changed=src/cxxengine.rs"); -} -``` - -`cxx-build` will: - -- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block. -- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header. -- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source. -- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`). - ---- - -## 5. Step 3 — Rust: `store/src/cxxengine.rs` - -This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3. - -### Design decisions vs. `pyengine.rs` - -| pyengine | cxxengine | Reason | -| ------------------------- | ---------------------------- | -------------------------------------------- | -| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception | -| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples | -| `Vec` for actions | `Vec` | cxx does not support `usize` | -| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays | -| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy | - -### File content - -```rust -//! # C++ bindings for the TricTrac game engine via cxx.rs -//! -//! Exposes an opaque `TricTracEngine` type and associated functions -//! to C++. The C++ side (trictrac.cc) uses `rust::Box`. -//! -//! The Rust engine always works from the perspective of White (player 1). -//! For Black (player 2), the board is mirrored before computing actions -//! and events are mirrored back before applying — exactly as in pyengine.rs. 
- -use crate::dice::Dice; -use crate::game::{GameEvent, GameState, Stage, TurnStage}; -use crate::training_common::{get_valid_action_indices, TrictracAction}; - -// ── cxx bridge declaration ──────────────────────────────────────────────────── - -#[cxx::bridge(namespace = "trictrac_engine")] -pub mod ffi { - // ── Shared types (visible to both Rust and C++) ─────────────────────────── - - /// Two dice values passed from C++ to Rust for a dice-roll event. - struct DicePair { - die1: u8, - die2: u8, - } - - /// Both players' scores: holes * 12 + points. - struct PlayerScores { - score_p1: i32, - score_p2: i32, - } - - // ── Opaque Rust type exposed to C++ ─────────────────────────────────────── - - extern "Rust" { - /// Opaque handle to a TricTrac game state. - /// C++ accesses this only through `rust::Box`. - type TricTracEngine; - - /// Create a new engine, initialise two players, begin with player 1. - fn new_trictrac_engine() -> Box; - - /// Return a deep copy of the engine (needed for State::Clone()). - fn clone_engine(self: &TricTracEngine) -> Box; - - // ── Queries ─────────────────────────────────────────────────────────── - - /// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node). - fn needs_roll(self: &TricTracEngine) -> bool; - - /// True when Stage::Ended. - fn is_game_ended(self: &TricTracEngine) -> bool; - - /// Active player index: 0 (player 1 / White) or 1 (player 2 / Black). - fn current_player_idx(self: &TricTracEngine) -> u64; - - /// Legal action indices for `player_idx`. Returns empty vec if it is - /// not that player's turn. Indices are in [0, 513]. - fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec; - - /// Human-readable action description, e.g. "0:Move { dice_order: true … }". - fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String; - - /// Both players' scores: holes * 12 + points. 
- fn get_players_scores(self: &TricTracEngine) -> PlayerScores; - - /// 36-element state observation vector (i8). Mirrored for player 1. - fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; - - /// Human-readable state description for `player_idx`. - fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; - - /// Full debug representation of the current state. - fn to_debug_string(self: &TricTracEngine) -> String; - - // ── Mutations ───────────────────────────────────────────────────────── - - /// Apply a dice roll result. Returns Err if not in RollWaiting stage. - fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>; - - /// Apply a player action (move, go, roll). Returns Err if invalid. - fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>; - } -} - -// ── Opaque type implementation ──────────────────────────────────────────────── - -pub struct TricTracEngine { - game_state: GameState, -} - -pub fn new_trictrac_engine() -> Box { - let mut game_state = GameState::new(false); // schools_enabled = false - game_state.init_player("player1"); - game_state.init_player("player2"); - game_state.consume(&GameEvent::BeginGame { goes_first: 1 }); - Box::new(TricTracEngine { game_state }) -} - -impl TricTracEngine { - fn clone_engine(&self) -> Box { - Box::new(TricTracEngine { - game_state: self.game_state.clone(), - }) - } - - fn needs_roll(&self) -> bool { - self.game_state.turn_stage == TurnStage::RollWaiting - } - - fn is_game_ended(&self) -> bool { - self.game_state.stage == Stage::Ended - } - - /// Returns 0 for player 1 (White) and 1 for player 2 (Black). 
- fn current_player_idx(&self) -> u64 { - self.game_state.active_player_id - 1 - } - - fn get_legal_actions(&self, player_idx: u64) -> Vec { - if player_idx == self.current_player_idx() { - if player_idx == 0 { - get_valid_action_indices(&self.game_state) - .into_iter() - .map(|i| i as u64) - .collect() - } else { - let mirror = self.game_state.mirror(); - get_valid_action_indices(&mirror) - .into_iter() - .map(|i| i as u64) - .collect() - } - } else { - vec![] - } - } - - fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String { - TrictracAction::from_action_index(action_idx as usize) - .map(|a| format!("{}:{}", player_idx, a)) - .unwrap_or_else(|| "unknown action".into()) - } - - fn get_players_scores(&self) -> ffi::PlayerScores { - ffi::PlayerScores { - score_p1: self.score_for(1), - score_p2: self.score_for(2), - } - } - - fn score_for(&self, player_id: u64) -> i32 { - if let Some(player) = self.game_state.players.get(&player_id) { - player.holes as i32 * 12 + player.points as i32 - } else { - -1 - } - } - - fn get_tensor(&self, player_idx: u64) -> Vec { - if player_idx == 0 { - self.game_state.to_vec() - } else { - self.game_state.mirror().to_vec() - } - } - - fn get_observation_string(&self, player_idx: u64) -> String { - if player_idx == 0 { - format!("{}", self.game_state) - } else { - format!("{}", self.game_state.mirror()) - } - } - - fn to_debug_string(&self) -> String { - format!("{}", self.game_state) - } - - fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> { - let player_id = self.game_state.active_player_id; - if self.game_state.turn_stage != TurnStage::RollWaiting { - return Err("Not in RollWaiting stage".into()); - } - let dice = Dice { - values: (dice.die1, dice.die2), - }; - self.game_state - .consume(&GameEvent::RollResult { player_id, dice }); - Ok(()) - } - - fn apply_action(&mut self, action_idx: u64) -> Result<(), String> { - let action_idx = action_idx as usize; - let needs_mirror = 
self.game_state.active_player_id == 2; - - let event = TrictracAction::from_action_index(action_idx) - .and_then(|a| { - let game_state = if needs_mirror { - &self.game_state.mirror() - } else { - &self.game_state - }; - a.to_event(game_state) - .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) - }); - - match event { - Some(evt) if self.game_state.validate(&evt) => { - self.game_state.consume(&evt); - Ok(()) - } - Some(_) => Err("Action is invalid".into()), - None => Err("Could not build event from action index".into()), - } - } -} -``` - -> **Note on `Result<(), String>`**: cxx.rs only requires the error type to implement `std::fmt::Display`, so `String` is accepted, but a dedicated error type is easier to propagate with `?`. Two sturdier options: -> -> - Use `anyhow::Error` (add `anyhow` dependency). -> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`. -> -> The recommended approach is `anyhow`: -> -> ```toml -> [dependencies] -> anyhow = "1.0" -> ``` -> -> Then declare `fn apply_action(...) -> Result<(), anyhow::Error>`; cxx.rs converts an `Err` into a C++ exception of type `rust::Error` carrying the message. - ---- - -## 6. Step 4 — Rust: `store/src/lib.rs` - -Add the new module: - -```rust -// existing modules … -mod pyengine; - -// NEW: C++ bindings via cxx.rs -pub mod cxxengine; -``` - ---- - -## 7. Step 5 — C++: `trictrac/trictrac.h` - -Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box<TricTracEngine>` and delegates everything to it. - -```cpp -// open_spiel/games/trictrac/trictrac.h -#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_ -#define OPEN_SPIEL_GAMES_TRICTRAC_H_ - -#include -#include -#include - -#include "open_spiel/spiel.h" -#include "open_spiel/spiel_utils.h" - -// Generated by cxx-build from store/src/cxxengine.rs. -// The include path is set by CMake (see CMakeLists.txt). 
-#include "trictrac_store/src/cxxengine.rs.h" - -namespace open_spiel { -namespace trictrac { - -inline constexpr int kNumPlayers = 2; -inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes -inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust -inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust -inline constexpr int kDefaultMaxTurns = 1000; - -class TrictracGame; - -// --------------------------------------------------------------------------- -// TrictracState -// --------------------------------------------------------------------------- -class TrictracState : public State { - public: - explicit TrictracState(std::shared_ptr game); - TrictracState(const TrictracState& other); - - Player CurrentPlayer() const override; - std::vector LegalActions() const override; - std::string ActionToString(Player player, Action move_id) const override; - std::vector> ChanceOutcomes() const override; - std::string ToString() const override; - bool IsTerminal() const override; - std::vector Returns() const override; - std::string ObservationString(Player player) const override; - void ObservationTensor(Player player, absl::Span values) const override; - std::unique_ptr Clone() const override; - - protected: - void DoApplyAction(Action move_id) override; - - private: - // Decode a chance action index [0,35] to (die1, die2). - // Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action] - static trictrac_engine::DicePair DecodeChanceAction(Action action); - - // The Rust engine handle. Deep-copied via clone_engine() when cloning state. 
- rust::Box engine_; -}; - -// --------------------------------------------------------------------------- -// TrictracGame -// --------------------------------------------------------------------------- -class TrictracGame : public Game { - public: - explicit TrictracGame(const GameParameters& params); - - int NumDistinctActions() const override { return kNumDistinctActions; } - std::unique_ptr NewInitialState() const override; - int MaxChanceOutcomes() const override { return kNumChanceOutcomes; } - int NumPlayers() const override { return kNumPlayers; } - double MinUtility() const override { return 0.0; } - double MaxUtility() const override { return 200.0; } - int MaxGameLength() const override { return 3 * max_turns_; } - int MaxChanceNodesInHistory() const override { return MaxGameLength(); } - std::vector ObservationTensorShape() const override { - return {kStateEncodingSize}; - } - - private: - int max_turns_; -}; - -} // namespace trictrac -} // namespace open_spiel - -#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_ -``` - ---- - -## 8. 
Step 6 — C++: `trictrac/trictrac.cc` - -```cpp -// open_spiel/games/trictrac/trictrac.cc -#include "open_spiel/games/trictrac/trictrac.h" - -#include -#include -#include - -#include "open_spiel/abseil-cpp/absl/types/span.h" -#include "open_spiel/game_parameters.h" -#include "open_spiel/spiel.h" -#include "open_spiel/spiel_globals.h" -#include "open_spiel/spiel_utils.h" - -namespace open_spiel { -namespace trictrac { -namespace { - -// ── Game registration ──────────────────────────────────────────────────────── - -const GameType kGameType{ - /*short_name=*/"trictrac", - /*long_name=*/"Trictrac", - GameType::Dynamics::kSequential, - GameType::ChanceMode::kExplicitStochastic, - GameType::Information::kPerfectInformation, - GameType::Utility::kGeneralSum, - GameType::RewardModel::kRewards, - /*min_num_players=*/kNumPlayers, - /*max_num_players=*/kNumPlayers, - /*provides_information_state_string=*/false, - /*provides_information_state_tensor=*/false, - /*provides_observation_string=*/true, - /*provides_observation_tensor=*/true, - /*parameter_specification=*/{ - {"max_turns", GameParameter(kDefaultMaxTurns)}, - }}; - -static std::shared_ptr Factory(const GameParameters& params) { - return std::make_shared(params); -} - -REGISTER_SPIEL_GAME(kGameType, Factory); - -} // namespace - -// ── TrictracGame ───────────────────────────────────────────────────────────── - -TrictracGame::TrictracGame(const GameParameters& params) - : Game(kGameType, params), - max_turns_(ParameterValue("max_turns", kDefaultMaxTurns)) {} - -std::unique_ptr TrictracGame::NewInitialState() const { - return std::make_unique(shared_from_this()); -} - -// ── TrictracState ───────────────────────────────────────────────────────────── - -TrictracState::TrictracState(std::shared_ptr game) - : State(game), - engine_(trictrac_engine::new_trictrac_engine()) {} - -// Copy constructor: deep-copy the Rust engine via clone_engine(). 
-TrictracState::TrictracState(const TrictracState& other) - : State(other), - engine_(other.engine_->clone_engine()) {} - -std::unique_ptr TrictracState::Clone() const { - return std::make_unique(*this); -} - -// ── Current player ──────────────────────────────────────────────────────────── - -Player TrictracState::CurrentPlayer() const { - if (engine_->is_game_ended()) return kTerminalPlayerId; - if (engine_->needs_roll()) return kChancePlayerId; - return static_cast(engine_->current_player_idx()); -} - -// ── Legal actions ───────────────────────────────────────────────────────────── - -std::vector TrictracState::LegalActions() const { - if (IsChanceNode()) { - // All 36 dice outcomes are equally likely; return indices 0–35. - std::vector actions(kNumChanceOutcomes); - for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i; - return actions; - } - Player player = CurrentPlayer(); - rust::Vec rust_actions = - engine_->get_legal_actions(static_cast(player)); - std::vector actions; - actions.reserve(rust_actions.size()); - for (uint64_t a : rust_actions) actions.push_back(static_cast(a)); - return actions; -} - -// ── Chance outcomes ─────────────────────────────────────────────────────────── - -std::vector> TrictracState::ChanceOutcomes() const { - SPIEL_CHECK_TRUE(IsChanceNode()); - const double p = 1.0 / kNumChanceOutcomes; - std::vector> outcomes; - outcomes.reserve(kNumChanceOutcomes); - for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p); - return outcomes; -} - -// ── Apply action ────────────────────────────────────────────────────────────── - -/*static*/ -trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) { - // Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action] - return trictrac_engine::DicePair{ - /*die1=*/static_cast(action / 6 + 1), - /*die2=*/static_cast(action % 6 + 1), - }; -} - -void TrictracState::DoApplyAction(Action action) { - if (IsChanceNode()) { - 
engine_->apply_dice_roll(DecodeChanceAction(action)); - } else { - engine_->apply_action(static_cast(action)); - } -} - -// ── Terminal & returns ──────────────────────────────────────────────────────── - -bool TrictracState::IsTerminal() const { - return engine_->is_game_ended(); -} - -std::vector TrictracState::Returns() const { - trictrac_engine::PlayerScores scores = engine_->get_players_scores(); - return {static_cast(scores.score_p1), - static_cast(scores.score_p2)}; -} - -// ── Observation ─────────────────────────────────────────────────────────────── - -std::string TrictracState::ObservationString(Player player) const { - return std::string(engine_->get_observation_string( - static_cast(player))); -} - -void TrictracState::ObservationTensor(Player player, - absl::Span values) const { - SPIEL_CHECK_EQ(values.size(), kStateEncodingSize); - rust::Vec tensor = - engine_->get_tensor(static_cast(player)); - SPIEL_CHECK_EQ(tensor.size(), static_cast(kStateEncodingSize)); - for (int i = 0; i < kStateEncodingSize; ++i) { - values[i] = static_cast(tensor[i]); - } -} - -// ── Strings ─────────────────────────────────────────────────────────────────── - -std::string TrictracState::ToString() const { - return std::string(engine_->to_debug_string()); -} - -std::string TrictracState::ActionToString(Player player, Action action) const { - if (IsChanceNode()) { - trictrac_engine::DicePair d = DecodeChanceAction(action); - return "(" + std::to_string(d.die1) + ", " + std::to_string(d.die2) + ")"; - } - return std::string(engine_->action_to_string( - static_cast(player), static_cast(action))); -} - -} // namespace trictrac -} // namespace open_spiel -``` - ---- - -## 9. 
Step 7 — C++: `trictrac/trictrac_test.cc` - -```cpp -// open_spiel/games/trictrac/trictrac_test.cc -#include "open_spiel/games/trictrac/trictrac.h" - -#include -#include - -#include "open_spiel/spiel.h" -#include "open_spiel/tests/basic_tests.h" -#include "open_spiel/utils/init.h" - -namespace open_spiel { -namespace trictrac { -namespace { - -void BasicTrictracTests() { - testing::LoadGameTest("trictrac"); - testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5); -} - -} // namespace -} // namespace trictrac -} // namespace open_spiel - -int main(int argc, char** argv) { - open_spiel::Init(&argc, &argv); - open_spiel::trictrac::BasicTrictracTests(); - std::cout << "trictrac tests passed" << std::endl; - return 0; -} -``` - ---- - -## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt` - -The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined: - -```cmake -# ── Corrosion: CMake integration for Rust ──────────────────────────────────── -include(FetchContent) -FetchContent_Declare( - Corrosion - GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git - GIT_TAG v0.5.1 # pin to a stable release -) -FetchContent_MakeAvailable(Corrosion) - -# Import the trictrac-store Rust crate. -# This creates a CMake target named 'trictrac-store'. -corrosion_import_crate( - MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml - CRATES trictrac-store -) - -# Generate the cxx bridge from cxxengine.rs. 
-# corrosion_add_cxxbridge: -# - runs cxx-build as part of the Rust build -# - creates a CMake target 'trictrac_cxx_bridge' that: -# * compiles the generated C++ shim -# * exposes INTERFACE include dirs for the generated .rs.h header -corrosion_add_cxxbridge(trictrac_cxx_bridge - CRATE trictrac-store - FILES src/cxxengine.rs -) -``` - -> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing. - ---- - -## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt` - -Two changes: add the new source files to `GAME_SOURCES`, and add a test target. - -### 11.1 Add to `GAME_SOURCES` - -Find the alphabetically correct position (after `tic_tac_toe`, before `trade_comm`) and add: - -```cmake -set(GAME_SOURCES - # ... existing games ... - trictrac/trictrac.cc - trictrac/trictrac.h - # ... remaining games ... -) -``` - -### 11.2 Link cxx bridge into OpenSpiel objects - -The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed. - -The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add: - -```cmake -target_link_libraries(open_spiel - PUBLIC - trictrac_cxx_bridge # C++ shim generated by cxx-build - trictrac-store # Rust static library -) -``` - -If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game: - -```cmake -add_library(trictrac_game OBJECT - trictrac/trictrac.cc - trictrac/trictrac.h -) -target_include_directories(trictrac_game - PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/.. 
-) -target_link_libraries(trictrac_game - PUBLIC - trictrac_cxx_bridge - trictrac-store - open_spiel_core # or whatever the core target is called -) -``` - -Then reference `$` in relevant executables. - -### 11.3 Add the test - -```cmake -add_executable(trictrac_test - trictrac/trictrac_test.cc - ${OPEN_SPIEL_OBJECTS} - $ -) -target_link_libraries(trictrac_test - PRIVATE - trictrac_cxx_bridge - trictrac-store -) -add_test(trictrac_test trictrac_test) -``` - ---- - -## 12. Step 10 — Justfile updates - -### `trictrac/justfile` — add `cxxlib` target - -Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists: - -```just -cxxlib: - cargo build --release -p trictrac-store - @echo "Static lib: $(ls target/release/libtrictrac_store.a)" - @echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)" -``` - -### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac` - -```just -buildtrictrac: - # Rebuild the Rust static lib first, then CMake - cd ../../trictrac && cargo build --release -p trictrac-store - mkdir -p build && cd build && \ - CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \ - make -j$(nproc) trictrac_test - -testtrictrac: buildtrictrac - ./build/trictrac_test - -playtrictrac_cpp: - ./build/examples/example --game=trictrac -``` - ---- - -## 13. Key Design Decisions - -### 13.1 Opaque type with `clone_engine()` - -OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box` which calls `.clone()` on the inner `GameState` (which derives `Clone`). - -### 13.2 Action encoding: same 514-element space - -The C++ game uses the same 514-action encoding as the Python version and the Rust training code. 
This means: - -- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies. -- Action 0 = Roll (used as the bridge between Move and the next chance node). -- Actions 2–513 = Move variants (checker ordinal pair + dice order). -- A trained C++ model and Python model share the same action space. - -### 13.3 Chance outcome ordering - -The dice outcome ordering is identical to the Python version: - -``` -action → (die1, die2) -0 → (1,1) 6 → (2,1) ... 35 → (6,6) -``` - -(`die1 = action/6 + 1`, `die2 = action%6 + 1`, using integer division) - -This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines. - -### 13.4 `GameType::Utility::kGeneralSum` + `kRewards` - -Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). Intermediate hole rewards are returned by `Returns()` at every state, not just at the terminal state. - -### 13.5 Mirror pattern preserved - -`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring; it simply passes `player_idx` and the Rust engine handles the rest. - -### 13.6 `rust::Box<T>` vs `cxx::UniquePtr<T>` - -`rust::Box<T>` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. `cxx::UniquePtr<T>`, the Rust-side binding of `std::unique_ptr<T>`, goes the other way: it is for C++ types held in Rust. - -### 13.7 Separate struct from `pyengine.rs` - -`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids: - -- PyO3 and cxx attributes conflicting on the same type. -- Changes to one binding breaking the other. -- Feature-flag complexity. - ---- - -## 14. Known Challenges - -### 14.1 Corrosion path resolution - -`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. 
Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable). - -### 14.2 `staticlib` + `cdylib` in one crate - -Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions: - -- The `cdylib` build (for maturin) does not need `staticlib`, and producing both lengthens the build. -- `crate-type` cannot be gated behind a Cargo feature; the workarounds are a separate `Cargo.toml` (for example a thin wrapper crate) or overriding the crate type per invocation with `cargo rustc --crate-type staticlib`. -- Alternatively, accept the extra compile time during development. - -### 14.3 Linker symbols from Rust std - -When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (on Linux, typically `-ldl -lm -lpthread -lc`). - -### 14.4 `anyhow` for error types - -cxx.rs requires the error type of a returned `Result` to implement `std::fmt::Display`; the `Display` text becomes the message of the C++ exception. `String` satisfies this bound, but a real error type is easier to propagate with `?`. Use `anyhow::Error` or define a thin newtype wrapper: - -```rust -use std::fmt; - -#[derive(Debug)] -struct EngineError(String); -impl fmt::Display for EngineError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) } -} -impl std::error::Error for EngineError {} -``` - -On the C++ side, errors become `rust::Error` exceptions. Wrap the engine calls in `DoApplyAction` in a try-catch during development and report Rust errors through `SpielFatalError`. - -### 14.5 `UndoAction` not implemented - -OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. 
The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo; the history is append-only. To support undo, `Clone()` is the only reliable strategy (apply the action to a clone and discard the clone to undo). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks any algorithm that relies on `UndoAction`. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo. - -### 14.6 Generated header path in `#include` - -The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the path where `cxx-build` (via Corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory. - -### 14.7 `rust::String` to `std::string` conversion - -The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` provides an explicit conversion operator to `std::string`; a plain assignment without the explicit construction does not compile. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast<std::string>(...)`. - ---- - -## 15. Complete File Checklist - -``` -[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib -[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...) 
-[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;` -[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation -[ ] trictrac/justfile — add `cxxlib` target -[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge -[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test -[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file -[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file -[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file -[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac -``` - ---- - -## 16. Implementation Order - -Implement in this sequence to get early feedback at each step: - -1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced. -2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations. -3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist. -4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them. -5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination. -6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`. - ---- - -# Implementation complete - -All files are in place and trictrac_test passes (168 84 final scores, all assertions pass). 
-
-What was done
-
-Rust side (from previous session, already complete):
-
-- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build
-- trictrac/store/build.rs — drives cxx-build
-- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl
-- trictrac/store/src/lib.rs — added pub mod cxxengine;
-
-C++ side (this session):
-
-- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header
-- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation
-- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test
-
-Build system:
-
-- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge
-- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable
-
-Justfiles:
-
-- trictrac/justfile — added cxxlib target
-- forks/open_spiel/justfile — added buildtrictrac and testtrictrac
-
-Fixes discovered during build
-
-| Issue | Fix |
-| ----- | --- |
-| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries |
-| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) |
-| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h |
-| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc |
-| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) |
-| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries |
-| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort |
-| Duplicate actions for doubles | Added std::unique after sort |
-| Returns() returned non-zero at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states |
diff --git a/store/Cargo.toml b/store/Cargo.toml
index a9234ff..935a2a0 100644
--- a/store/Cargo.toml
+++ b/store/Cargo.toml
@@ -25,5 +25,9 @@ rand = "0.9"
 serde = { version = "1.0", features = ["derive"] }
 transpose = "0.2.2"
 
+[[bin]]
+name = "random_game"
+path = "src/bin/random_game.rs"
+
 [build-dependencies]
 cxx-build = "1.0"
diff --git a/store/src/bin/random_game.rs b/store/src/bin/random_game.rs
new file mode 100644
index 0000000..6da3b9c
--- /dev/null
+++ b/store/src/bin/random_game.rs
@@ -0,0 +1,262 @@
+//! Run one or many games of trictrac between two random players.
+//! In single-game mode, prints play-by-play like OpenSpiel's `example.cc`.
+//! In multi-game mode, runs silently and reports throughput at the end.
+//!
+//! Usage:
+//!   cargo run --bin random_game -- [--seed ] [--games ] [--max-steps ] [--verbose]
+
+use std::borrow::Cow;
+use std::env;
+use std::time::Instant;
+
+use trictrac_store::{
+    training_common::sample_valid_action,
+    Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
+};
+
+// ── CLI args ──────────────────────────────────────────────────────────────────
+
+struct Args {
+    seed: Option<u64>,
+    games: usize,
+    max_steps: usize,
+    verbose: bool,
+}
+
+fn parse_args() -> Args {
+    let args: Vec<String> = env::args().collect();
+    let mut seed = None;
+    let mut games = 1;
+    let mut max_steps = 10_000;
+    let mut verbose = false;
+
+    let mut i = 1;
+    while i < args.len() {
+        match args[i].as_str() {
+            "--seed" => {
+                i += 1;
+                seed = args.get(i).and_then(|s| s.parse().ok());
+            }
+            "--games" => {
+                i += 1;
+                if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) {
+                    games = v;
+                }
+            }
+            "--max-steps" => {
+                i += 1;
+                if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) {
+                    max_steps = v;
+                }
+            }
+            "--verbose" => verbose = true,
+            _ => {}
+        }
+        i += 1;
+    }
+
+    Args {
+        seed,
+        games,
+        max_steps,
+        verbose,
+    }
+}
+
+// ── Helpers ───────────────────────────────────────────────────────────────────
+
+fn player_label(id: u64) -> &'static str {
+    if id == 1 { "White" } else { "Black" }
+}
+
+/// Apply a `Roll` + `RollResult` in one logical step, returning the dice.
+/// This collapses the two-step dice phase into a single "chance node" action,
+/// matching how the OpenSpiel layer exposes it.
+fn apply_dice_roll(state: &mut GameState, roller: &mut DiceRoller) -> Result<Dice, String> {
+    // RollDice → RollWaiting
+    state
+        .consume(&GameEvent::Roll { player_id: state.active_player_id })
+        .map_err(|e| format!("Roll event failed: {e}"))?;
+
+    // RollWaiting → Move / HoldOrGoChoice (or Stage::Ended if 13th hole)
+    let dice = roller.roll();
+    state
+        .consume(&GameEvent::RollResult { player_id: state.active_player_id, dice })
+        .map_err(|e| format!("RollResult event failed: {e}"))?;
+
+    Ok(dice)
+}
+
+/// Sample a random action and apply it to `state`, handling the Black-mirror
+/// transform exactly as `cxxengine.rs::apply_action` does:
+///
+/// 1. For Black, build a mirrored view of the state so that `sample_valid_action`
+///    and `to_event` always reason from White's perspective.
+/// 2. Mirror the resulting event back to the original coordinate frame before
+///    calling `state.consume`.
+///
+/// Returns `Ok(())` on success, or an error message naming the step that failed.
+fn apply_player_action(state: &mut GameState) -> Result<(), String> {
+    let needs_mirror = state.active_player_id == 2;
+
+    // Build a White-perspective view: borrowed for White, owned mirror for Black.
+    let view: Cow<GameState> = if needs_mirror {
+        Cow::Owned(state.mirror())
+    } else {
+        Cow::Borrowed(state)
+    };
+
+    let action = sample_valid_action(&view)
+        .ok_or_else(|| format!("no valid action in stage {:?}", state.turn_stage))?;
+
+    let event = action
+        .to_event(&view)
+        .ok_or_else(|| format!("could not convert {action:?} to event"))?;
+
+    // Translate the event from the view's frame back to the game's frame.
+ let event = if needs_mirror { event.get_mirror(false) } else { event }; + + state + .consume(&event) + .map_err(|e| format!("consume({action:?}): {e}"))?; + + Ok(()) +} + +// ── Single game ──────────────────────────────────────────────────────────────── + +/// Run one full game, optionally printing play-by-play. +/// Returns `(steps, truncated)`. +fn run_game(roller: &mut DiceRoller, max_steps: usize, quiet: bool, verbose: bool) -> (usize, bool) { + let mut state = GameState::new_with_players("White", "Black"); + let mut step = 0usize; + + if !quiet { + println!("{state}"); + } + + while state.stage != Stage::Ended { + step += 1; + if step > max_steps { + return (step - 1, true); + } + + match state.turn_stage { + TurnStage::RollDice => { + let player = state.active_player_id; + match apply_dice_roll(&mut state, roller) { + Ok(dice) => { + if !quiet { + println!( + "[step {step:4}] {} rolls: {} & {}", + player_label(player), + dice.values.0, + dice.values.1 + ); + } + } + Err(e) => { + eprintln!("Error during dice roll: {e}"); + eprintln!("State:\n{state}"); + return (step, true); + } + } + } + stage => { + let player = state.active_player_id; + match apply_player_action(&mut state) { + Ok(()) => { + if !quiet { + println!( + "[step {step:4}] {} ({stage:?})", + player_label(player) + ); + if verbose { + println!("{state}"); + } + } + } + Err(e) => { + eprintln!("Error: {e}"); + eprintln!("State:\n{state}"); + return (step, true); + } + } + } + } + } + + if !quiet { + println!("\n=== Game over after {step} steps ===\n"); + println!("{state}"); + + let white = state.players.get(&1); + let black = state.players.get(&2); + + match (white, black) { + (Some(w), Some(b)) => { + println!("White — holes: {:2}, points: {:2}", w.holes, w.points); + println!("Black — holes: {:2}, points: {:2}", b.holes, b.points); + println!(); + + let white_score = w.holes as i32 * 12 + w.points as i32; + let black_score = b.holes as i32 * 12 + b.points as i32; + + if white_score > 
black_score { + println!("Winner: White (+{})", white_score - black_score); + } else if black_score > white_score { + println!("Winner: Black (+{})", black_score - white_score); + } else { + println!("Draw"); + } + } + _ => eprintln!("Could not read final player scores."), + } + } + + (step, false) +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + let mut roller = DiceRoller::new(args.seed); + + if args.games == 1 { + println!("=== Trictrac — random game ==="); + if let Some(s) = args.seed { + println!("seed: {s}"); + } + println!(); + run_game(&mut roller, args.max_steps, false, args.verbose); + } else { + println!("=== Trictrac — {} games ===", args.games); + if let Some(s) = args.seed { + println!("seed: {s}"); + } + println!(); + + let mut total_steps = 0u64; + let mut truncated = 0usize; + + let t0 = Instant::now(); + for _ in 0..args.games { + let (steps, trunc) = run_game(&mut roller, args.max_steps, !args.verbose, args.verbose); + total_steps += steps as u64; + if trunc { + truncated += 1; + } + } + let elapsed = t0.elapsed(); + + let secs = elapsed.as_secs_f64(); + println!("Games : {}", args.games); + println!("Truncated : {truncated}"); + println!("Total steps: {total_steps}"); + println!("Avg steps : {:.1}", total_steps as f64 / args.games as f64); + println!("Elapsed : {:.3} s", secs); + println!("Throughput : {:.1} games/s", args.games as f64 / secs); + println!(" {:.0} steps/s", total_steps as f64 / secs); + } +} diff --git a/store/src/board.rs b/store/src/board.rs index de0e450..0fba2d6 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -598,12 +598,40 @@ impl Board { core::array::from_fn(|i| i + min) } + /// Returns cumulative white-checker counts: result[i] = # white checkers in fields 1..=i. + /// result[0] = 0. 
+    pub fn white_checker_cumulative(&self) -> [u8; 25] {
+        let mut cum = [0u8; 25];
+        let mut total = 0u8;
+        for (i, &count) in self.positions.iter().enumerate() {
+            if count > 0 {
+                total += count as u8;
+            }
+            cum[i + 1] = total;
+        }
+        cum
+    }
+
     pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> {
         self.remove_checker(color, cmove.from)?;
         self.add_checker(color, cmove.to)?;
         Ok(())
     }
 
+    /// Reverse a previously applied `move_checker`. No validation: assumes the move was valid.
+    pub fn unmove_checker(&mut self, color: &Color, cmove: CheckerMove) {
+        let unit = match color {
+            Color::White => 1,
+            Color::Black => -1,
+        };
+        if cmove.from != 0 {
+            self.positions[cmove.from - 1] += unit;
+        }
+        if cmove.to != 0 {
+            self.positions[cmove.to - 1] -= unit;
+        }
+    }
+
     pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> {
         if field == 0 {
             return Ok(());
diff --git a/store/src/cxxengine.rs b/store/src/cxxengine.rs
index 29bc7fe..55d348c 100644
--- a/store/src/cxxengine.rs
+++ b/store/src/cxxengine.rs
@@ -83,8 +83,8 @@ pub mod ffi {
         /// Both players' scores.
         fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
 
-        /// 36-element state vector (i8). Mirrored for player_idx == 1.
-        fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
+        /// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1.
+        fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<f32>;
 
         /// Human-readable state description for `player_idx`.
         fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
@@ -153,8 +153,7 @@ impl TricTracEngine {
                     .map(|v| v.into_iter().map(|i| i as u64).collect())
             } else {
                 let mirror = self.game_state.mirror();
-                get_valid_action_indices(&mirror)
-                    .map(|v| v.into_iter().map(|i| i as u64).collect())
+                get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect())
             }
         }))
     }
@@ -180,11 +179,11 @@ impl TricTracEngine {
             .unwrap_or(-1)
     }
 
-    fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
+    fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
         if player_idx == 0 {
-            self.game_state.to_vec()
+            self.game_state.to_tensor()
         } else {
-            self.game_state.mirror().to_vec()
+            self.game_state.mirror().to_tensor()
         }
     }
 
@@ -243,8 +242,9 @@ impl TricTracEngine {
                     self.game_state
                 ),
                 None => anyhow::bail!(
-                    "apply_action: could not build event from action index {}",
-                    action_idx
+                    "apply_action: could not build event from action index {} in state {}",
+                    action_idx,
+                    self.game_state
                 ),
             }
         }))
diff --git a/store/src/game.rs b/store/src/game.rs
index d32734d..f553bdb 100644
--- a/store/src/game.rs
+++ b/store/src/game.rs
@@ -156,13 +156,6 @@ impl GameState {
         if let Some(p1) = self.players.get(&1) {
             mirrored_players.insert(2, p1.mirror());
         }
-        let mirrored_history = self
-            .history
-            .clone()
-            .iter()
-            .map(|evt| evt.get_mirror(false))
-            .collect();
-
         let (move1, move2) = self.dice_moves;
         GameState {
             stage: self.stage,
@@ -171,7 +164,7 @@ impl GameState {
             active_player_id: mirrored_active_player,
             // active_player_id: self.active_player_id,
             players: mirrored_players,
-            history: mirrored_history,
+            history: Vec::new(),
             dice: self.dice,
             dice_points: self.dice_points,
             dice_moves: (move1.mirror(), move2.mirror()),
@@ -207,6 +200,106 @@ impl GameState {
         self.to_vec().iter().map(|&x| x as f32).collect()
     }
 
+    /// Get state as a tensor for neural network training (Option B, TD-Gammon style).
+    /// Returns 217 f32 values, all normalized to [0, 1].
+    ///
+    /// Must be called from the active player's perspective: callers should mirror
+    /// the GameState for Black before calling so that "own" always means White.
+    ///
+    /// Layout:
+    ///   [0..95]    own (White) checkers: 4 values per field × 24 fields
+    ///   [96..191]  opp (Black) checkers: 4 values per field × 24 fields
+    ///   [192..193] dice values / 6
+    ///   [194]      active player color (0=White, 1=Black)
+    ///   [195]      turn_stage / 5
+    ///   [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille
+    ///   [200..203] Black player: same
+    ///   [204..207] own quarter filled (quarters 1-4)
+    ///   [208..211] opp quarter filled (quarters 1-4)
+    ///   [212]      own checkers all in exit zone (fields 19-24)
+    ///   [213]      opp checkers all in exit zone (fields 1-6)
+    ///   [214]      own coin de repos taken (field 12 has ≥2 own checkers)
+    ///   [215]      opp coin de repos taken (field 13 has ≥2 opp checkers)
+    ///   [216]      own dice_roll_count / 3, clamped to 1
+    pub fn to_tensor(&self) -> Vec<f32> {
+        let mut t = Vec::with_capacity(217);
+        let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
+
+        // [0..95] own (White) checkers, TD-Gammon encoding
+        for &c in &pos {
+            let own = c.max(0) as u8;
+            t.push((own == 1) as u8 as f32);
+            t.push((own == 2) as u8 as f32);
+            t.push((own == 3) as u8 as f32);
+            t.push(own.saturating_sub(3) as f32);
+        }
+
+        // [96..191] opp (Black) checkers, TD-Gammon encoding
+        for &c in &pos {
+            let opp = (-c).max(0) as u8;
+            t.push((opp == 1) as u8 as f32);
+            t.push((opp == 2) as u8 as f32);
+            t.push((opp == 3) as u8 as f32);
+            t.push(opp.saturating_sub(3) as f32);
+        }
+
+        // [192..193] dice
+        t.push(self.dice.values.0 as f32 / 6.0);
+        t.push(self.dice.values.1 as f32 / 6.0);
+
+        // [194] active player color
+        t.push(
+            self.who_plays()
+                .map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 })
+                .unwrap_or(0.0),
+        );
+
+        // [195] turn stage
+        t.push(u8::from(self.turn_stage) as f32 / 5.0);
+
+        // [196..199] White player stats
+        let wp
= self.get_white_player(); + t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0)); + t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0)); + t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32)); + t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32)); + + // [200..203] Black player stats + let bp = self.get_black_player(); + t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0)); + t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0)); + t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32)); + t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32)); + + // [204..207] own (White) quarter fill status + for &start in &[1usize, 7, 13, 19] { + t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32); + } + + // [208..211] opp (Black) quarter fill status + for &start in &[1usize, 7, 13, 19] { + t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32); + } + + // [212] can_exit_own: no own checker in fields 1-18 + t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32); + + // [213] can_exit_opp: no opp checker in fields 7-24 + t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32); + + // [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers) + t.push((pos[11] >= 2) as u8 as f32); + + // [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers) + t.push((pos[12] <= -2) as u8 as f32); + + // [216] own dice_roll_count / 3, clamped to 1 + t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0)); + + debug_assert_eq!(t.len(), 217, "to_tensor length mismatch"); + t + } + /// Get state as a vector (to be used for bot training input) : /// length = 36 /// i8 for board positions with negative values for blacks diff --git a/store/src/game_rules_moves.rs b/store/src/game_rules_moves.rs index 41221f2..396bcaf 100644 --- a/store/src/game_rules_moves.rs +++ b/store/src/game_rules_moves.rs @@ -220,7 +220,7 @@ impl MoveRules { // Si possible, les deux dés doivent être joués if moves.0.get_from() == 0 || 
moves.1.get_from() == 0 { let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]); - possible_moves_sequences.retain(|moves| self.check_exit_rules(moves).is_ok()); + possible_moves_sequences.retain(|moves| self.check_exit_rules(moves, None).is_ok()); // possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok()); if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() { if *moves == (EMPTY_MOVE, EMPTY_MOVE) { @@ -238,7 +238,7 @@ impl MoveRules { // check exit rules // if !ignored_rules.contains(&TricTracRule::Exit) { - self.check_exit_rules(moves)?; + self.check_exit_rules(moves, None)?; // } // --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ---- @@ -321,7 +321,11 @@ impl MoveRules { .is_empty() } - fn check_exit_rules(&self, moves: &(CheckerMove, CheckerMove)) -> Result<(), MoveError> { + fn check_exit_rules( + &self, + moves: &(CheckerMove, CheckerMove), + exit_seqs: Option<&[(CheckerMove, CheckerMove)]>, + ) -> Result<(), MoveError> { if !moves.0.is_exit() && !moves.1.is_exit() { return Ok(()); } @@ -331,16 +335,22 @@ impl MoveRules { } // toutes les sorties directes sont autorisées, ainsi que les nombres défaillants - let ignored_rules = vec![TricTracRule::Exit]; - let possible_moves_sequences_without_excedent = - self.get_possible_moves_sequences(false, ignored_rules); - if possible_moves_sequences_without_excedent.contains(moves) { + let owned; + let seqs = match exit_seqs { + Some(s) => s, + None => { + owned = self + .get_possible_moves_sequences(false, vec![TricTracRule::Exit]); + &owned + } + }; + if seqs.contains(moves) { return Ok(()); } // À ce stade au moins un des déplacements concerne un nombre en excédant // - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on // refuse cette séquence - if !possible_moves_sequences_without_excedent.is_empty() { + if !seqs.is_empty() { return Err(MoveError::ExitByEffectPossible); } @@ 
-361,17 +371,24 @@ impl MoveRules { let _ = board_to_check.move_checker(&Color::White, moves.0); let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check); - let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves); - if (is_move1_exedant && moves.0.get_from() != farthest_on_move1) - || (is_move2_exedant && moves.1.get_from() != farthest_on_move2) - { + // dice normal order + let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, true); + let is_not_farthest1 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1) + || (is_move2_exedant && moves.1.get_from() != farthest_on_move2); + + // dice reversed order + let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, false); + let is_not_farthest2 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1) + || (is_move2_exedant && moves.1.get_from() != farthest_on_move2); + + if is_not_farthest1 && is_not_farthest2 { return Err(MoveError::ExitNotFarthest); } Ok(()) } - fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) { + fn move_excedants(&self, moves: &(CheckerMove, CheckerMove), dice_order: bool) -> (bool, bool) { let move1to = if moves.0.get_to() == 0 { 25 } else { @@ -386,20 +403,16 @@ impl MoveRules { }; let dist2 = move2to - moves.1.get_from(); - let dist_min = cmp::min(dist1, dist2); - let dist_max = cmp::max(dist1, dist2); - - let dice_min = cmp::min(self.dice.values.0, self.dice.values.1) as usize; - let dice_max = cmp::max(self.dice.values.0, self.dice.values.1) as usize; - - let min_excedant = dist_min != 0 && dist_min < dice_min; - let max_excedant = dist_max != 0 && dist_max < dice_max; - - if dist_min == dist1 { - (min_excedant, max_excedant) + let (dice1, dice2) = if dice_order { + self.dice.values } else { - (max_excedant, min_excedant) - } + (self.dice.values.1, self.dice.values.0) + }; + + ( + dist1 != 0 && dist1 < dice1 as usize, + dist2 != 0 && dist2 < dice2 as usize, + ) } fn 
get_board_exit_farthest(board: &Board) -> Field {
@@ -438,12 +451,18 @@ impl MoveRules {
         } else {
             (dice2, dice1)
         };
+        let filling_seqs = if !ignored_rules.contains(&TricTracRule::MustFillQuarter) {
+            Some(self.get_quarter_filling_moves_sequences())
+        } else {
+            None
+        };
         let mut moves_seqs = self.get_possible_moves_sequences_by_dices(
             dice_max,
             dice_min,
             with_excedents,
             false,
-            ignored_rules.clone(),
+            &ignored_rules,
+            filling_seqs.as_deref(),
         );
         // if we got valid sequences with the highest die, we don't accept sequences using only the
         // lowest die
@@ -453,7 +472,8 @@ impl MoveRules {
             dice_max,
             with_excedents,
             ignore_empty,
-            ignored_rules,
+            &ignored_rules,
+            filling_seqs.as_deref(),
         );
         moves_seqs.append(&mut moves_seqs_order2);
         let empty_removed = moves_seqs
@@ -524,14 +544,16 @@ impl MoveRules {
         let mut moves_seqs = Vec::new();
         let color = &Color::White;
         let ignored_rules = vec![TricTracRule::Exit, TricTracRule::MustFillQuarter];
+        let mut board = self.board.clone();
         for moves in self.get_possible_moves_sequences(true, ignored_rules) {
-            let mut board = self.board.clone();
             board.move_checker(color, moves.0).unwrap();
             board.move_checker(color, moves.1).unwrap();
             // println!("get_quarter_filling_moves_sequences board : {:?}", board);
             if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) {
                 moves_seqs.push(moves);
             }
+            board.unmove_checker(color, moves.1);
+            board.unmove_checker(color, moves.0);
         }
         moves_seqs
     }
@@ -542,18 +564,27 @@ impl MoveRules {
         dice2: u8,
         with_excedents: bool,
         ignore_empty: bool,
-        ignored_rules: Vec<TricTracRule>,
+        ignored_rules: &[TricTracRule],
+        filling_seqs: Option<&[(CheckerMove, CheckerMove)]>,
     ) -> Vec<(CheckerMove, CheckerMove)> {
         let mut moves_seqs = Vec::new();
         let color = &Color::White;
         let forbid_exits = self.has_checkers_outside_last_quarter();
+        // Precompute non-excedant sequences once so check_exit_rules need not repeat
+        // the full move generation for every exit-move candidate.
+ // Only needed when Exit is not already ignored and exits are actually reachable. + let exit_seqs = if !ignored_rules.contains(&TricTracRule::Exit) && !forbid_exits { + Some(self.get_possible_moves_sequences(false, vec![TricTracRule::Exit])) + } else { + None + }; + let mut board = self.board.clone(); // println!("==== First"); for first_move in self.board .get_possible_moves(*color, dice1, with_excedents, false, forbid_exits) { - let mut board2 = self.board.clone(); - if board2.move_checker(color, first_move).is_err() { + if board.move_checker(color, first_move).is_err() { println!("err move"); continue; } @@ -563,7 +594,7 @@ impl MoveRules { let mut has_second_dice_move = false; // println!(" ==== Second"); for second_move in - board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) + board.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) { if self .check_corner_rules(&(first_move, second_move)) @@ -587,24 +618,10 @@ impl MoveRules { && self.can_take_corner_by_effect()) && (ignored_rules.contains(&TricTracRule::Exit) || self - .check_exit_rules(&(first_move, second_move)) - // .inspect_err(|e| { - // println!( - // " 2nd (exit rule): {:?} - {:?}, {:?}", - // e, first_move, second_move - // ) - // }) - .is_ok()) - && (ignored_rules.contains(&TricTracRule::MustFillQuarter) - || self - .check_must_fill_quarter_rule(&(first_move, second_move)) - // .inspect_err(|e| { - // println!( - // " 2nd: {:?} - {:?}, {:?} for {:?}", - // e, first_move, second_move, self.board - // ) - // }) + .check_exit_rules(&(first_move, second_move), exit_seqs.as_deref()) .is_ok()) + && filling_seqs + .map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, second_move))) { if second_move.get_to() == 0 && first_move.get_to() == 0 @@ -627,16 +644,14 @@ impl MoveRules { && !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE)) && self.can_take_corner_by_effect()) && (ignored_rules.contains(&TricTracRule::Exit) - || 
self.check_exit_rules(&(first_move, EMPTY_MOVE)).is_ok()) - && (ignored_rules.contains(&TricTracRule::MustFillQuarter) - || self - .check_must_fill_quarter_rule(&(first_move, EMPTY_MOVE)) - .is_ok()) + || self.check_exit_rules(&(first_move, EMPTY_MOVE), exit_seqs.as_deref()).is_ok()) + && filling_seqs + .map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, EMPTY_MOVE))) { // empty move moves_seqs.push((first_move, EMPTY_MOVE)); } - //if board2.get_color_fields(*color).is_empty() { + board.unmove_checker(color, first_move); } moves_seqs } @@ -1495,6 +1510,7 @@ mod tests { CheckerMove::new(23, 0).unwrap(), CheckerMove::new(24, 0).unwrap(), ); + let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( vec![moves], state.get_possible_moves_sequences_by_dices( @@ -1502,7 +1518,8 @@ mod tests { state.dice.values.1, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); @@ -1517,6 +1534,7 @@ mod tests { CheckerMove::new(19, 23).unwrap(), CheckerMove::new(22, 0).unwrap(), )]; + let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( moves, state.get_possible_moves_sequences_by_dices( @@ -1524,7 +1542,8 @@ mod tests { state.dice.values.1, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); let moves = vec![( @@ -1538,7 +1557,8 @@ mod tests { state.dice.values.0, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); @@ -1554,6 +1574,7 @@ mod tests { CheckerMove::new(19, 21).unwrap(), CheckerMove::new(23, 0).unwrap(), ); + let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( vec![moves], state.get_possible_moves_sequences_by_dices( @@ -1561,7 +1582,8 @@ mod tests { state.dice.values.1, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); } @@ -1580,13 +1602,26 @@ mod tests { CheckerMove::new(19, 23).unwrap(), CheckerMove::new(22, 0).unwrap(), ); - assert!(state.check_exit_rules(&moves).is_ok()); + assert!(state.check_exit_rules(&moves, None).is_ok()); 
let moves = ( CheckerMove::new(19, 24).unwrap(), CheckerMove::new(22, 0).unwrap(), ); - assert!(state.check_exit_rules(&moves).is_ok()); + assert!(state.check_exit_rules(&moves, None).is_ok()); + + state.dice.values = (6, 4); + state.board.set_positions( + &crate::Color::White, + [ + -4, -1, -2, -1, 0, 0, 0, -1, 0, 0, 0, 0, -5, -1, 0, 0, 0, 0, 2, 3, 2, 2, 5, 1, + ], + ); + let moves = ( + CheckerMove::new(20, 24).unwrap(), + CheckerMove::new(23, 0).unwrap(), + ); + assert!(state.check_exit_rules(&moves, None).is_ok()); } #[test] diff --git a/store/src/pyengine.rs b/store/src/pyengine.rs index b193987..43b5713 100644 --- a/store/src/pyengine.rs +++ b/store/src/pyengine.rs @@ -113,11 +113,11 @@ impl TricTrac { [self.get_score(1), self.get_score(2)] } - fn get_tensor(&self, player_idx: u64) -> Vec { + fn get_tensor(&self, player_idx: u64) -> Vec { if player_idx == 0 { - self.game_state.to_vec() + self.game_state.to_tensor() } else { - self.game_state.mirror().to_vec() + self.game_state.mirror().to_tensor() } } diff --git a/store/src/training_common.rs b/store/src/training_common.rs index 57094a9..69765fc 100644 --- a/store/src/training_common.rs +++ b/store/src/training_common.rs @@ -3,7 +3,6 @@ use std::cmp::{max, min}; use std::fmt::{Debug, Display, Formatter}; -use crate::board::Board; use crate::{CheckerMove, Dice, GameEvent, GameState}; use serde::{Deserialize, Serialize}; @@ -221,10 +220,14 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result anyhow::Result anyhow::Result anyhow::Result { - let dice = &state.dice; - let board = &state.board; - - if color == &crate::Color::Black { - // Moves are already 'white', so we don't mirror them - white_checker_moves_to_trictrac_action( - move1, - move2, - // &move1.clone().mirror(), - // &move2.clone().mirror(), - dice, - &board.clone().mirror(), - ) - // .map(|a| a.mirror()) + // Moves are always in White's coordinate system. For Black, mirror the board first. 
+ let cum = if color == &crate::Color::Black { + state.board.mirror().white_checker_cumulative() } else { - white_checker_moves_to_trictrac_action(move1, move2, dice, board) - } + state.board.white_checker_cumulative() + }; + white_checker_moves_to_trictrac_action(move1, move2, &state.dice, &cum) } fn white_checker_moves_to_trictrac_action( move1: &CheckerMove, move2: &CheckerMove, dice: &Dice, - board: &Board, + cum: &[u8; 25], ) -> anyhow::Result { let to1 = move1.get_to(); let to2 = move2.get_to(); @@ -302,7 +300,7 @@ fn white_checker_moves_to_trictrac_action( } } else { // double sortie - if from1 < from2 { + if from1 < from2 || from2 == 0 { max(dice.values.0, dice.values.1) as usize } else { min(dice.values.0, dice.values.1) as usize @@ -321,11 +319,21 @@ fn white_checker_moves_to_trictrac_action( } let dice_order = diff_move1 == dice.values.0 as usize; - let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize; - let mut tmp_board = board.clone(); - // should not raise an error for a valid action - tmp_board.move_checker(&crate::Color::White, *move1)?; - let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; + // cum[i] = # white checkers in fields 1..=i (precomputed by the caller). + // checker1 is the ordinal of the last checker at from1. + let checker1 = cum[from1] as usize; + // checker2 is the ordinal on the board after move1 (removed from from1, added to to1). + // Adjust the cumulative in O(1) without cloning the board. 
+ let checker2 = { + let mut c = cum[from2]; + if from1 > 0 && from2 >= from1 { + c -= 1; // one checker was removed from from1, shifting later ordinals down + } + if from1 > 0 && to1 > 0 && from2 >= to1 { + c += 1; // one checker was added at to1, shifting later ordinals up + } + c as usize + }; Ok(TrictracAction::Move { dice_order, checker1, @@ -456,5 +464,48 @@ mod tests { }), ttaction.ok() ); + + // Black player + state.active_player_id = 2; + state.dice = Dice { values: (6, 3) }; + state.board.set_positions( + &crate::Color::White, + [ + 2, -11, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 6, 4, + ], + ); + let ttaction = super::checker_moves_to_trictrac_action( + &CheckerMove::new(21, 0).unwrap(), + &CheckerMove::new(0, 0).unwrap(), + &crate::Color::Black, + &state, + ); + + assert_eq!( + Some(TrictracAction::Move { + dice_order: true, + checker1: 2, + checker2: 0, // blocked by white on last field + }), + ttaction.ok() + ); + + // same with dice order reversed + state.dice = Dice { values: (3, 6) }; + let ttaction = super::checker_moves_to_trictrac_action( + &CheckerMove::new(21, 0).unwrap(), + &CheckerMove::new(0, 0).unwrap(), + &crate::Color::Black, + &state, + ); + + assert_eq!( + Some(TrictracAction::Move { + dice_order: false, + checker1: 2, + checker2: 0, // blocked by white on last field + }), + ttaction.ok() + ); } }