Compare commits
14 commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1c4c814417 | |||
| db5c1ea4f4 | |||
| aa7f5fe42a | |||
| 145ab7dcda | |||
| f26808d798 | |||
| 43eb5bf18d | |||
| dfc485a47a | |||
| a239c02937 | |||
| 6beaa56202 | |||
| 45b9db61e3 | |||
| 44a5ba87b0 | |||
| bd4c75228b | |||
| 8732512736 | |||
| eba93f0f13 |
9 changed files with 585 additions and 1104 deletions
|
|
@ -1,992 +0,0 @@
|
|||
# Plan: C++ OpenSpiel Game via cxx.rs
|
||||
|
||||
> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings.
|
||||
>
|
||||
> Base on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Overview
|
||||
|
||||
The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding:
|
||||
|
||||
- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`.
|
||||
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`.
|
||||
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box<ffi::TricTracEngine>` and delegates all logic to Rust.
|
||||
- Build wired together via **corrosion** (CMake-native Rust integration) and `cxx-build`.
|
||||
|
||||
The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games.
|
||||
|
||||
---
|
||||
|
||||
## 2. Files to Create / Modify
|
||||
|
||||
```
|
||||
trictrac/
|
||||
store/
|
||||
Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type
|
||||
build.rs ← CREATE: cxx-build bridge registration
|
||||
src/
|
||||
lib.rs ← MODIFY: add cxxengine module
|
||||
cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl
|
||||
|
||||
forks/open_spiel/
|
||||
CMakeLists.txt ← MODIFY: add Corrosion FetchContent
|
||||
open_spiel/
|
||||
games/
|
||||
CMakeLists.txt ← MODIFY: add trictrac/ sources + test
|
||||
trictrac/ ← CREATE directory
|
||||
trictrac.h ← CREATE
|
||||
trictrac.cc ← CREATE
|
||||
trictrac_test.cc ← CREATE
|
||||
|
||||
justfile ← MODIFY: add buildtrictrac target
|
||||
trictrac/
|
||||
justfile ← MODIFY: add cxxlib target
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Step 1 — Rust: `store/Cargo.toml`
|
||||
|
||||
Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library.
|
||||
|
||||
```toml
|
||||
[package]
|
||||
name = "trictrac-store"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "trictrac_store"
|
||||
# cdylib → Python .so (used by maturin / pyengine)
|
||||
# rlib → used by other Rust crates in the workspace
|
||||
# staticlib → used by C++ consumers (cxxengine)
|
||||
crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.21.7"
|
||||
cxx = "1.0"
|
||||
log = "0.4.20"
|
||||
merge = "0.1.0"
|
||||
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
|
||||
rand = "0.9"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
transpose = "0.2.2"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
```
|
||||
|
||||
> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`).
|
||||
|
||||
---
|
||||
|
||||
## 4. Step 2 — Rust: `store/build.rs`
|
||||
|
||||
The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header.
|
||||
|
||||
```rust
|
||||
fn main() {
|
||||
cxx_build::bridge("src/cxxengine.rs")
|
||||
.std("c++17")
|
||||
.compile("trictrac-cxx");
|
||||
|
||||
// Re-run if the bridge source changes
|
||||
println!("cargo:rerun-if-changed=src/cxxengine.rs");
|
||||
}
|
||||
```
|
||||
|
||||
`cxx-build` will:
|
||||
|
||||
- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block.
|
||||
- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header.
|
||||
- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source.
|
||||
- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`).
|
||||
|
||||
---
|
||||
|
||||
## 5. Step 3 — Rust: `store/src/cxxengine.rs`
|
||||
|
||||
This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3.
|
||||
|
||||
### Design decisions vs. `pyengine.rs`
|
||||
|
||||
| pyengine | cxxengine | Reason |
|
||||
| ------------------------- | ---------------------------- | -------------------------------------------- |
|
||||
| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception |
|
||||
| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples |
|
||||
| `Vec<usize>` for actions | `Vec<u64>` | cxx does not support `usize` |
|
||||
| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays |
|
||||
| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy |
|
||||
|
||||
### File content
|
||||
|
||||
```rust
|
||||
//! # C++ bindings for the TricTrac game engine via cxx.rs
|
||||
//!
|
||||
//! Exposes an opaque `TricTracEngine` type and associated functions
|
||||
//! to C++. The C++ side (trictrac.cc) uses `rust::Box<ffi::TricTracEngine>`.
|
||||
//!
|
||||
//! The Rust engine always works from the perspective of White (player 1).
|
||||
//! For Black (player 2), the board is mirrored before computing actions
|
||||
//! and events are mirrored back before applying — exactly as in pyengine.rs.
|
||||
|
||||
use crate::dice::Dice;
|
||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
||||
|
||||
// ── cxx bridge declaration ────────────────────────────────────────────────────
|
||||
|
||||
#[cxx::bridge(namespace = "trictrac_engine")]
|
||||
pub mod ffi {
|
||||
// ── Shared types (visible to both Rust and C++) ───────────────────────────
|
||||
|
||||
/// Two dice values passed from C++ to Rust for a dice-roll event.
|
||||
struct DicePair {
|
||||
die1: u8,
|
||||
die2: u8,
|
||||
}
|
||||
|
||||
/// Both players' scores: holes * 12 + points.
|
||||
struct PlayerScores {
|
||||
score_p1: i32,
|
||||
score_p2: i32,
|
||||
}
|
||||
|
||||
// ── Opaque Rust type exposed to C++ ───────────────────────────────────────
|
||||
|
||||
extern "Rust" {
|
||||
/// Opaque handle to a TricTrac game state.
|
||||
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
|
||||
type TricTracEngine;
|
||||
|
||||
/// Create a new engine, initialise two players, begin with player 1.
|
||||
fn new_trictrac_engine() -> Box<TricTracEngine>;
|
||||
|
||||
/// Return a deep copy of the engine (needed for State::Clone()).
|
||||
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
|
||||
|
||||
// ── Queries ───────────────────────────────────────────────────────────
|
||||
|
||||
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
|
||||
fn needs_roll(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// True when Stage::Ended.
|
||||
fn is_game_ended(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// Active player index: 0 (player 1 / White) or 1 (player 2 / Black).
|
||||
fn current_player_idx(self: &TricTracEngine) -> u64;
|
||||
|
||||
/// Legal action indices for `player_idx`. Returns empty vec if it is
|
||||
/// not that player's turn. Indices are in [0, 513].
|
||||
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec<u64>;
|
||||
|
||||
/// Human-readable action description, e.g. "0:Move { dice_order: true … }".
|
||||
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
|
||||
|
||||
/// Both players' scores: holes * 12 + points.
|
||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||
|
||||
/// 36-element state observation vector (i8). Mirrored for player 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||
|
||||
/// Human-readable state description for `player_idx`.
|
||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||
|
||||
/// Full debug representation of the current state.
|
||||
fn to_debug_string(self: &TricTracEngine) -> String;
|
||||
|
||||
// ── Mutations ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Apply a dice roll result. Returns Err if not in RollWaiting stage.
|
||||
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
|
||||
|
||||
/// Apply a player action (move, go, roll). Returns Err if invalid.
|
||||
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Opaque type implementation ────────────────────────────────────────────────
|
||||
|
||||
pub struct TricTracEngine {
|
||||
game_state: GameState,
|
||||
}
|
||||
|
||||
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
Box::new(TricTracEngine { game_state })
|
||||
}
|
||||
|
||||
impl TricTracEngine {
|
||||
fn clone_engine(&self) -> Box<TricTracEngine> {
|
||||
Box::new(TricTracEngine {
|
||||
game_state: self.game_state.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn needs_roll(&self) -> bool {
|
||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
||||
}
|
||||
|
||||
fn is_game_ended(&self) -> bool {
|
||||
self.game_state.stage == Stage::Ended
|
||||
}
|
||||
|
||||
/// Returns 0 for player 1 (White) and 1 for player 2 (Black).
|
||||
fn current_player_idx(&self) -> u64 {
|
||||
self.game_state.active_player_id - 1
|
||||
}
|
||||
|
||||
fn get_legal_actions(&self, player_idx: u64) -> Vec<u64> {
|
||||
if player_idx == self.current_player_idx() {
|
||||
if player_idx == 0 {
|
||||
get_valid_action_indices(&self.game_state)
|
||||
.into_iter()
|
||||
.map(|i| i as u64)
|
||||
.collect()
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
.into_iter()
|
||||
.map(|i| i as u64)
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
|
||||
TrictracAction::from_action_index(action_idx as usize)
|
||||
.map(|a| format!("{}:{}", player_idx, a))
|
||||
.unwrap_or_else(|| "unknown action".into())
|
||||
}
|
||||
|
||||
fn get_players_scores(&self) -> ffi::PlayerScores {
|
||||
ffi::PlayerScores {
|
||||
score_p1: self.score_for(1),
|
||||
score_p2: self.score_for(2),
|
||||
}
|
||||
}
|
||||
|
||||
fn score_for(&self, player_id: u64) -> i32 {
|
||||
if let Some(player) = self.game_state.players.get(&player_id) {
|
||||
player.holes as i32 * 12 + player.points as i32
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||
if player_idx == 0 {
|
||||
format!("{}", self.game_state)
|
||||
} else {
|
||||
format!("{}", self.game_state.mirror())
|
||||
}
|
||||
}
|
||||
|
||||
fn to_debug_string(&self) -> String {
|
||||
format!("{}", self.game_state)
|
||||
}
|
||||
|
||||
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> {
|
||||
let player_id = self.game_state.active_player_id;
|
||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||
return Err("Not in RollWaiting stage".into());
|
||||
}
|
||||
let dice = Dice {
|
||||
values: (dice.die1, dice.die2),
|
||||
};
|
||||
self.game_state
|
||||
.consume(&GameEvent::RollResult { player_id, dice });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_action(&mut self, action_idx: u64) -> Result<(), String> {
|
||||
let action_idx = action_idx as usize;
|
||||
let needs_mirror = self.game_state.active_player_id == 2;
|
||||
|
||||
let event = TrictracAction::from_action_index(action_idx)
|
||||
.and_then(|a| {
|
||||
let game_state = if needs_mirror {
|
||||
&self.game_state.mirror()
|
||||
} else {
|
||||
&self.game_state
|
||||
};
|
||||
a.to_event(game_state)
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
});
|
||||
|
||||
match event {
|
||||
Some(evt) if self.game_state.validate(&evt) => {
|
||||
self.game_state.consume(&evt);
|
||||
Ok(())
|
||||
}
|
||||
Some(_) => Err("Action is invalid".into()),
|
||||
None => Err("Could not build event from action index".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **Note on `Result<(), String>`**: cxx.rs requires the error type to implement `std::error::Error`. `String` does not implement it directly. Two options:
|
||||
>
|
||||
> - Use `anyhow::Error` (add `anyhow` dependency).
|
||||
> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`.
|
||||
>
|
||||
> The recommended approach is `anyhow`:
|
||||
>
|
||||
> ```toml
|
||||
> [dependencies]
|
||||
> anyhow = "1.0"
|
||||
> ```
|
||||
>
|
||||
> Then `fn apply_action(...) -> Result<(), anyhow::Error>` — cxx.rs will convert this to a C++ exception of type `rust::Error` carrying the message.
|
||||
|
||||
---
|
||||
|
||||
## 6. Step 4 — Rust: `store/src/lib.rs`
|
||||
|
||||
Add the new module:
|
||||
|
||||
```rust
|
||||
// existing modules …
|
||||
mod pyengine;
|
||||
|
||||
// NEW: C++ bindings via cxx.rs
|
||||
pub mod cxxengine;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Step 5 — C++: `trictrac/trictrac.h`
|
||||
|
||||
Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box<ffi::TricTracEngine>` and delegates everything to it.
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac.h
|
||||
#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
#define OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/spiel_utils.h"
|
||||
|
||||
// Generated by cxx-build from store/src/cxxengine.rs.
|
||||
// The include path is set by CMake (see CMakeLists.txt).
|
||||
#include "trictrac_store/src/cxxengine.rs.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
|
||||
inline constexpr int kNumPlayers = 2;
|
||||
inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes
|
||||
inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust
|
||||
inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust
|
||||
inline constexpr int kDefaultMaxTurns = 1000;
|
||||
|
||||
class TrictracGame;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrictracState
|
||||
// ---------------------------------------------------------------------------
|
||||
class TrictracState : public State {
|
||||
public:
|
||||
explicit TrictracState(std::shared_ptr<const Game> game);
|
||||
TrictracState(const TrictracState& other);
|
||||
|
||||
Player CurrentPlayer() const override;
|
||||
std::vector<Action> LegalActions() const override;
|
||||
std::string ActionToString(Player player, Action move_id) const override;
|
||||
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;
|
||||
std::string ToString() const override;
|
||||
bool IsTerminal() const override;
|
||||
std::vector<double> Returns() const override;
|
||||
std::string ObservationString(Player player) const override;
|
||||
void ObservationTensor(Player player, absl::Span<float> values) const override;
|
||||
std::unique_ptr<State> Clone() const override;
|
||||
|
||||
protected:
|
||||
void DoApplyAction(Action move_id) override;
|
||||
|
||||
private:
|
||||
// Decode a chance action index [0,35] to (die1, die2).
|
||||
// Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
static trictrac_engine::DicePair DecodeChanceAction(Action action);
|
||||
|
||||
// The Rust engine handle. Deep-copied via clone_engine() when cloning state.
|
||||
rust::Box<trictrac_engine::TricTracEngine> engine_;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrictracGame
|
||||
// ---------------------------------------------------------------------------
|
||||
class TrictracGame : public Game {
|
||||
public:
|
||||
explicit TrictracGame(const GameParameters& params);
|
||||
|
||||
int NumDistinctActions() const override { return kNumDistinctActions; }
|
||||
std::unique_ptr<State> NewInitialState() const override;
|
||||
int MaxChanceOutcomes() const override { return kNumChanceOutcomes; }
|
||||
int NumPlayers() const override { return kNumPlayers; }
|
||||
double MinUtility() const override { return 0.0; }
|
||||
double MaxUtility() const override { return 200.0; }
|
||||
int MaxGameLength() const override { return 3 * max_turns_; }
|
||||
int MaxChanceNodesInHistory() const override { return MaxGameLength(); }
|
||||
std::vector<int> ObservationTensorShape() const override {
|
||||
return {kStateEncodingSize};
|
||||
}
|
||||
|
||||
private:
|
||||
int max_turns_;
|
||||
};
|
||||
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
|
||||
#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Step 6 — C++: `trictrac/trictrac.cc`
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac.cc
|
||||
#include "open_spiel/games/trictrac/trictrac.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "open_spiel/abseil-cpp/absl/types/span.h"
|
||||
#include "open_spiel/game_parameters.h"
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/spiel_globals.h"
|
||||
#include "open_spiel/spiel_utils.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
namespace {
|
||||
|
||||
// ── Game registration ────────────────────────────────────────────────────────
|
||||
|
||||
const GameType kGameType{
|
||||
/*short_name=*/"trictrac",
|
||||
/*long_name=*/"Trictrac",
|
||||
GameType::Dynamics::kSequential,
|
||||
GameType::ChanceMode::kExplicitStochastic,
|
||||
GameType::Information::kPerfectInformation,
|
||||
GameType::Utility::kGeneralSum,
|
||||
GameType::RewardModel::kRewards,
|
||||
/*min_num_players=*/kNumPlayers,
|
||||
/*max_num_players=*/kNumPlayers,
|
||||
/*provides_information_state_string=*/false,
|
||||
/*provides_information_state_tensor=*/false,
|
||||
/*provides_observation_string=*/true,
|
||||
/*provides_observation_tensor=*/true,
|
||||
/*parameter_specification=*/{
|
||||
{"max_turns", GameParameter(kDefaultMaxTurns)},
|
||||
}};
|
||||
|
||||
static std::shared_ptr<const Game> Factory(const GameParameters& params) {
|
||||
return std::make_shared<const TrictracGame>(params);
|
||||
}
|
||||
|
||||
REGISTER_SPIEL_GAME(kGameType, Factory);
|
||||
|
||||
} // namespace
|
||||
|
||||
// ── TrictracGame ─────────────────────────────────────────────────────────────
|
||||
|
||||
TrictracGame::TrictracGame(const GameParameters& params)
|
||||
: Game(kGameType, params),
|
||||
max_turns_(ParameterValue<int>("max_turns", kDefaultMaxTurns)) {}
|
||||
|
||||
std::unique_ptr<State> TrictracGame::NewInitialState() const {
|
||||
return std::make_unique<TrictracState>(shared_from_this());
|
||||
}
|
||||
|
||||
// ── TrictracState ─────────────────────────────────────────────────────────────
|
||||
|
||||
TrictracState::TrictracState(std::shared_ptr<const Game> game)
|
||||
: State(game),
|
||||
engine_(trictrac_engine::new_trictrac_engine()) {}
|
||||
|
||||
// Copy constructor: deep-copy the Rust engine via clone_engine().
|
||||
TrictracState::TrictracState(const TrictracState& other)
|
||||
: State(other),
|
||||
engine_(other.engine_->clone_engine()) {}
|
||||
|
||||
std::unique_ptr<State> TrictracState::Clone() const {
|
||||
return std::make_unique<TrictracState>(*this);
|
||||
}
|
||||
|
||||
// ── Current player ────────────────────────────────────────────────────────────
|
||||
|
||||
Player TrictracState::CurrentPlayer() const {
|
||||
if (engine_->is_game_ended()) return kTerminalPlayerId;
|
||||
if (engine_->needs_roll()) return kChancePlayerId;
|
||||
return static_cast<Player>(engine_->current_player_idx());
|
||||
}
|
||||
|
||||
// ── Legal actions ─────────────────────────────────────────────────────────────
|
||||
|
||||
std::vector<Action> TrictracState::LegalActions() const {
|
||||
if (IsChanceNode()) {
|
||||
// All 36 dice outcomes are equally likely; return indices 0–35.
|
||||
std::vector<Action> actions(kNumChanceOutcomes);
|
||||
for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i;
|
||||
return actions;
|
||||
}
|
||||
Player player = CurrentPlayer();
|
||||
rust::Vec<uint64_t> rust_actions =
|
||||
engine_->get_legal_actions(static_cast<uint64_t>(player));
|
||||
std::vector<Action> actions;
|
||||
actions.reserve(rust_actions.size());
|
||||
for (uint64_t a : rust_actions) actions.push_back(static_cast<Action>(a));
|
||||
return actions;
|
||||
}
|
||||
|
||||
// ── Chance outcomes ───────────────────────────────────────────────────────────
|
||||
|
||||
std::vector<std::pair<Action, double>> TrictracState::ChanceOutcomes() const {
|
||||
SPIEL_CHECK_TRUE(IsChanceNode());
|
||||
const double p = 1.0 / kNumChanceOutcomes;
|
||||
std::vector<std::pair<Action, double>> outcomes;
|
||||
outcomes.reserve(kNumChanceOutcomes);
|
||||
for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p);
|
||||
return outcomes;
|
||||
}
|
||||
|
||||
// ── Apply action ──────────────────────────────────────────────────────────────
|
||||
|
||||
/*static*/
|
||||
trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) {
|
||||
// Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
return trictrac_engine::DicePair{
|
||||
/*die1=*/static_cast<uint8_t>(action / 6 + 1),
|
||||
/*die2=*/static_cast<uint8_t>(action % 6 + 1),
|
||||
};
|
||||
}
|
||||
|
||||
void TrictracState::DoApplyAction(Action action) {
|
||||
if (IsChanceNode()) {
|
||||
engine_->apply_dice_roll(DecodeChanceAction(action));
|
||||
} else {
|
||||
engine_->apply_action(static_cast<uint64_t>(action));
|
||||
}
|
||||
}
|
||||
|
||||
// ── Terminal & returns ────────────────────────────────────────────────────────
|
||||
|
||||
bool TrictracState::IsTerminal() const {
|
||||
return engine_->is_game_ended();
|
||||
}
|
||||
|
||||
std::vector<double> TrictracState::Returns() const {
|
||||
trictrac_engine::PlayerScores scores = engine_->get_players_scores();
|
||||
return {static_cast<double>(scores.score_p1),
|
||||
static_cast<double>(scores.score_p2)};
|
||||
}
|
||||
|
||||
// ── Observation ───────────────────────────────────────────────────────────────
|
||||
|
||||
std::string TrictracState::ObservationString(Player player) const {
|
||||
return std::string(engine_->get_observation_string(
|
||||
static_cast<uint64_t>(player)));
|
||||
}
|
||||
|
||||
void TrictracState::ObservationTensor(Player player,
|
||||
absl::Span<float> values) const {
|
||||
SPIEL_CHECK_EQ(values.size(), kStateEncodingSize);
|
||||
rust::Vec<int8_t> tensor =
|
||||
engine_->get_tensor(static_cast<uint64_t>(player));
|
||||
SPIEL_CHECK_EQ(tensor.size(), static_cast<size_t>(kStateEncodingSize));
|
||||
for (int i = 0; i < kStateEncodingSize; ++i) {
|
||||
values[i] = static_cast<float>(tensor[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Strings ───────────────────────────────────────────────────────────────────
|
||||
|
||||
std::string TrictracState::ToString() const {
|
||||
return std::string(engine_->to_debug_string());
|
||||
}
|
||||
|
||||
std::string TrictracState::ActionToString(Player player, Action action) const {
|
||||
if (IsChanceNode()) {
|
||||
trictrac_engine::DicePair d = DecodeChanceAction(action);
|
||||
return "(" + std::to_string(d.die1) + ", " + std::to_string(d.die2) + ")";
|
||||
}
|
||||
return std::string(engine_->action_to_string(
|
||||
static_cast<uint64_t>(player), static_cast<uint64_t>(action)));
|
||||
}
|
||||
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. Step 7 — C++: `trictrac/trictrac_test.cc`
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac_test.cc
|
||||
#include "open_spiel/games/trictrac/trictrac.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/tests/basic_tests.h"
|
||||
#include "open_spiel/utils/init.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
namespace {
|
||||
|
||||
void BasicTrictracTests() {
|
||||
testing::LoadGameTest("trictrac");
|
||||
testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
open_spiel::Init(&argc, &argv);
|
||||
open_spiel::trictrac::BasicTrictracTests();
|
||||
std::cout << "trictrac tests passed" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt`
|
||||
|
||||
The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined:
|
||||
|
||||
```cmake
|
||||
# ── Corrosion: CMake integration for Rust ────────────────────────────────────
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
Corrosion
|
||||
GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git
|
||||
GIT_TAG v0.5.1 # pin to a stable release
|
||||
)
|
||||
FetchContent_MakeAvailable(Corrosion)
|
||||
|
||||
# Import the trictrac-store Rust crate.
|
||||
# This creates a CMake target named 'trictrac-store'.
|
||||
corrosion_import_crate(
|
||||
MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml
|
||||
CRATES trictrac-store
|
||||
)
|
||||
|
||||
# Generate the cxx bridge from cxxengine.rs.
|
||||
# corrosion_add_cxxbridge:
|
||||
# - runs cxx-build as part of the Rust build
|
||||
# - creates a CMake target 'trictrac_cxx_bridge' that:
|
||||
# * compiles the generated C++ shim
|
||||
# * exposes INTERFACE include dirs for the generated .rs.h header
|
||||
corrosion_add_cxxbridge(trictrac_cxx_bridge
|
||||
CRATE trictrac-store
|
||||
FILES src/cxxengine.rs
|
||||
)
|
||||
```
|
||||
|
||||
> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing.
|
||||
|
||||
---
|
||||
|
||||
## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt`
|
||||
|
||||
Two changes: add the new source files to `GAME_SOURCES`, and add a test target.
|
||||
|
||||
### 11.1 Add to `GAME_SOURCES`
|
||||
|
||||
Find the alphabetically correct position (after `tic_tac_toe`, before `trade_comm`) and add:
|
||||
|
||||
```cmake
|
||||
set(GAME_SOURCES
|
||||
# ... existing games ...
|
||||
trictrac/trictrac.cc
|
||||
trictrac/trictrac.h
|
||||
# ... remaining games ...
|
||||
)
|
||||
```
|
||||
|
||||
### 11.2 Link cxx bridge into OpenSpiel objects
|
||||
|
||||
The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed.
|
||||
|
||||
The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add:
|
||||
|
||||
```cmake
|
||||
target_link_libraries(open_spiel
|
||||
PUBLIC
|
||||
trictrac_cxx_bridge # C++ shim generated by cxx-build
|
||||
trictrac-store # Rust static library
|
||||
)
|
||||
```
|
||||
|
||||
If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game:
|
||||
|
||||
```cmake
|
||||
add_library(trictrac_game OBJECT
|
||||
trictrac/trictrac.cc
|
||||
trictrac/trictrac.h
|
||||
)
|
||||
target_include_directories(trictrac_game
|
||||
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
target_link_libraries(trictrac_game
|
||||
PUBLIC
|
||||
trictrac_cxx_bridge
|
||||
trictrac-store
|
||||
open_spiel_core # or whatever the core target is called
|
||||
)
|
||||
```
|
||||
|
||||
Then reference `$<TARGET_OBJECTS:trictrac_game>` in relevant executables.
|
||||
|
||||
### 11.3 Add the test
|
||||
|
||||
```cmake
|
||||
add_executable(trictrac_test
|
||||
trictrac/trictrac_test.cc
|
||||
${OPEN_SPIEL_OBJECTS}
|
||||
$<TARGET_OBJECTS:tests>
|
||||
)
|
||||
target_link_libraries(trictrac_test
|
||||
PRIVATE
|
||||
trictrac_cxx_bridge
|
||||
trictrac-store
|
||||
)
|
||||
add_test(trictrac_test trictrac_test)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Step 10 — Justfile updates
|
||||
|
||||
### `trictrac/justfile` — add `cxxlib` target
|
||||
|
||||
Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists:
|
||||
|
||||
```just
|
||||
cxxlib:
|
||||
cargo build --release -p trictrac-store
|
||||
@echo "Static lib: $(ls target/release/libtrictrac_store.a)"
|
||||
@echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)"
|
||||
```
|
||||
|
||||
### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac`
|
||||
|
||||
```just
|
||||
buildtrictrac:
|
||||
# Rebuild the Rust static lib first, then CMake
|
||||
cd ../../trictrac && cargo build --release -p trictrac-store
|
||||
mkdir -p build && cd build && \
|
||||
CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \
|
||||
make -j$(nproc) trictrac_test
|
||||
|
||||
testtrictrac: buildtrictrac
|
||||
./build/trictrac_test
|
||||
|
||||
playtrictrac_cpp:
|
||||
./build/examples/example --game=trictrac
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 13. Key Design Decisions
|
||||
|
||||
### 13.1 Opaque type with `clone_engine()`
|
||||
|
||||
OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box<TricTracEngine>` which calls `.clone()` on the inner `GameState` (which derives `Clone`).
|
||||
|
||||
### 13.2 Action encoding: same 514-element space
|
||||
|
||||
The C++ game uses the same 514-action encoding as the Python version and the Rust training code. This means:
|
||||
|
||||
- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies.
|
||||
- Action 0 = Roll (used as the bridge between Move and the next chance node).
|
||||
- Actions 2–513 = Move variants (checker ordinal pair + dice order).
|
||||
- A trained C++ model and Python model share the same action space.
|
||||
|
||||
### 13.3 Chance outcome ordering
|
||||
|
||||
The dice outcome ordering is identical to the Python version:
|
||||
|
||||
```
|
||||
action → (die1, die2)
|
||||
0 → (1,1) 6 → (2,1) ... 35 → (6,6)
|
||||
```
|
||||
|
||||
(`die1 = action/6 + 1`, `die2 = action%6 + 1`)
|
||||
|
||||
This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines.
|
||||
|
||||
### 13.4 `GameType::Utility::kGeneralSum` + `kRewards`
|
||||
|
||||
Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). Intermediate hole rewards are returned by `Returns()` at every state, not just the terminal.
|
||||
|
||||
### 13.5 Mirror pattern preserved
|
||||
|
||||
`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring — it simply passes `player_idx` and the Rust engine handles the rest.
|
||||
|
||||
### 13.6 `rust::Box` vs `rust::UniquePtr`
|
||||
|
||||
`rust::Box<T>` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. `rust::UniquePtr<T>` is for C++ types held in Rust.
|
||||
|
||||
### 13.7 Separate struct from `pyengine.rs`
|
||||
|
||||
`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids:
|
||||
|
||||
- PyO3 and cxx attributes conflicting on the same type.
|
||||
- Changes to one binding breaking the other.
|
||||
- Feature-flag complexity.
|
||||
|
||||
---
|
||||
|
||||
## 14. Known Challenges
|
||||
|
||||
### 14.1 Corrosion path resolution
|
||||
|
||||
`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable).
|
||||
|
||||
### 14.2 `staticlib` + `cdylib` in one crate
|
||||
|
||||
Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions:
|
||||
|
||||
- The `cdylib` build (for maturin) does not need `staticlib`, and building both doubles the compile time.
|
||||
- Consider gating `staticlib` behind a Cargo feature: `crate-type` is not directly feature-gatable, but you can work around this with a separate `Cargo.toml` or a workspace profile.
|
||||
- Alternatively, accept the extra compile time during development.
|
||||
|
||||
### 14.3 Linker symbols from Rust std
|
||||
|
||||
When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (typically `-ldl -lm -lpthread -lc`).
|
||||
|
||||
### 14.4 `anyhow` for error types
|
||||
|
||||
cxx.rs requires the `Err` type in `Result<T, E>` to implement `std::error::Error + Send + Sync`. `String` does not satisfy this. Use `anyhow::Error` or define a thin newtype wrapper:
|
||||
|
||||
```rust
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EngineError(String);
|
||||
impl fmt::Display for EngineError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) }
|
||||
}
|
||||
impl std::error::Error for EngineError {}
|
||||
```
|
||||
|
||||
On the C++ side, errors become `rust::Error` exceptions. Wrap `DoApplyAction` in a try-catch during development to surface Rust errors as `SpielFatalError`.
|
||||
|
||||
### 14.5 `UndoAction` not implemented
|
||||
|
||||
OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo — the history is append-only. To support undo, `Clone()` is the only reliable strategy (clone before applying, discard clone if undo needed). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks game-tree search. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo.
|
||||
|
||||
### 14.6 Generated header path in `#include`
|
||||
|
||||
The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the actual path that `cxx-build` (via corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory.
|
||||
|
||||
### 14.7 `rust::String` to `std::string` conversion
|
||||
|
||||
The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` is implicitly convertible to `std::string`. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast<std::string>(...)`.
|
||||
|
||||
---
|
||||
|
||||
## 15. Complete File Checklist
|
||||
|
||||
```
|
||||
[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib
|
||||
[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...)
|
||||
[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;`
|
||||
[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation
|
||||
[ ] trictrac/justfile — add `cxxlib` target
|
||||
[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge
|
||||
[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file
|
||||
[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 16. Implementation Order
|
||||
|
||||
Implement in this sequence to get early feedback at each step:
|
||||
|
||||
1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced.
|
||||
2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations.
|
||||
3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist.
|
||||
4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them.
|
||||
5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination.
|
||||
6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`.
|
||||
|
||||
---
|
||||
|
||||
# Implementation complete
|
||||
|
||||
All files are in place and trictrac_test passes (168 84 final scores, all assertions pass).
|
||||
|
||||
What was done
|
||||
|
||||
Rust side (from previous session, already complete):
|
||||
|
||||
- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build
|
||||
- trictrac/store/build.rs — drives cxx-build
|
||||
- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl
|
||||
- trictrac/store/src/lib.rs — added pub mod cxxengine;
|
||||
|
||||
C++ side (this session):
|
||||
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test
|
||||
|
||||
Build system:
|
||||
|
||||
- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge
|
||||
- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable
|
||||
|
||||
Justfiles:
|
||||
|
||||
- trictrac/justfile — added cxxlib target
|
||||
- forks/open_spiel/justfile — added buildtrictrac and testtrictrac
|
||||
|
||||
Fixes discovered during build
|
||||
|
||||
| Issue | Fix |
|
||||
| ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries |
|
||||
| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) |
|
||||
| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h |
|
||||
| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc |
|
||||
| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) |
|
||||
| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries |
|
||||
| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort |
|
||||
| Duplicate actions for doubles | Added std::unique after sort |
|
||||
| Returns() returned non-zero at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states |
|
||||
|
|
@ -25,5 +25,9 @@ rand = "0.9"
|
|||
serde = { version = "1.0", features = ["derive"] }
|
||||
transpose = "0.2.2"
|
||||
|
||||
[[bin]]
|
||||
name = "random_game"
|
||||
path = "src/bin/random_game.rs"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
|
|||
262
store/src/bin/random_game.rs
Normal file
262
store/src/bin/random_game.rs
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
//! Run one or many games of trictrac between two random players.
|
||||
//! In single-game mode, prints play-by-play like OpenSpiel's `example.cc`.
|
||||
//! In multi-game mode, runs silently and reports throughput at the end.
|
||||
//!
|
||||
//! Usage:
|
||||
//! cargo run --bin random_game -- [--seed <u64>] [--games <usize>] [--max-steps <usize>] [--verbose]
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::env;
|
||||
use std::time::Instant;
|
||||
|
||||
use trictrac_store::{
|
||||
training_common::sample_valid_action,
|
||||
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
|
||||
};
|
||||
|
||||
// ── CLI args ──────────────────────────────────────────────────────────────────
|
||||
|
||||
struct Args {
|
||||
seed: Option<u64>,
|
||||
games: usize,
|
||||
max_steps: usize,
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let mut seed = None;
|
||||
let mut games = 1;
|
||||
let mut max_steps = 10_000;
|
||||
let mut verbose = false;
|
||||
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--seed" => {
|
||||
i += 1;
|
||||
seed = args.get(i).and_then(|s| s.parse().ok());
|
||||
}
|
||||
"--games" => {
|
||||
i += 1;
|
||||
if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) {
|
||||
games = v;
|
||||
}
|
||||
}
|
||||
"--max-steps" => {
|
||||
i += 1;
|
||||
if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) {
|
||||
max_steps = v;
|
||||
}
|
||||
}
|
||||
"--verbose" => verbose = true,
|
||||
_ => {}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
Args {
|
||||
seed,
|
||||
games,
|
||||
max_steps,
|
||||
verbose,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn player_label(id: u64) -> &'static str {
|
||||
if id == 1 { "White" } else { "Black" }
|
||||
}
|
||||
|
||||
/// Apply a `Roll` + `RollResult` in one logical step, returning the dice.
|
||||
/// This collapses the two-step dice phase into a single "chance node" action,
|
||||
/// matching how the OpenSpiel layer exposes it.
|
||||
fn apply_dice_roll(state: &mut GameState, roller: &mut DiceRoller) -> Result<Dice, String> {
|
||||
// RollDice → RollWaiting
|
||||
state
|
||||
.consume(&GameEvent::Roll { player_id: state.active_player_id })
|
||||
.map_err(|e| format!("Roll event failed: {e}"))?;
|
||||
|
||||
// RollWaiting → Move / HoldOrGoChoice (or Stage::Ended if 13th hole)
|
||||
let dice = roller.roll();
|
||||
state
|
||||
.consume(&GameEvent::RollResult { player_id: state.active_player_id, dice })
|
||||
.map_err(|e| format!("RollResult event failed: {e}"))?;
|
||||
|
||||
Ok(dice)
|
||||
}
|
||||
|
||||
/// Sample a random action and apply it to `state`, handling the Black-mirror
|
||||
/// transform exactly as `cxxengine.rs::apply_action` does:
|
||||
///
|
||||
/// 1. For Black, build a mirrored view of the state so that `sample_valid_action`
|
||||
/// and `to_event` always reason from White's perspective.
|
||||
/// 2. Mirror the resulting event back to the original coordinate frame before
|
||||
/// calling `state.consume`.
|
||||
///
|
||||
/// Returns the chosen action (in the view's coordinate frame) for display.
|
||||
fn apply_player_action(state: &mut GameState) -> Result<(), String> {
|
||||
let needs_mirror = state.active_player_id == 2;
|
||||
|
||||
// Build a White-perspective view: borrowed for White, owned mirror for Black.
|
||||
let view: Cow<GameState> = if needs_mirror {
|
||||
Cow::Owned(state.mirror())
|
||||
} else {
|
||||
Cow::Borrowed(state)
|
||||
};
|
||||
|
||||
let action = sample_valid_action(&view)
|
||||
.ok_or_else(|| format!("no valid action in stage {:?}", state.turn_stage))?;
|
||||
|
||||
let event = action
|
||||
.to_event(&view)
|
||||
.ok_or_else(|| format!("could not convert {action:?} to event"))?;
|
||||
|
||||
// Translate the event from the view's frame back to the game's frame.
|
||||
let event = if needs_mirror { event.get_mirror(false) } else { event };
|
||||
|
||||
state
|
||||
.consume(&event)
|
||||
.map_err(|e| format!("consume({action:?}): {e}"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Single game ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run one full game, optionally printing play-by-play.
|
||||
/// Returns `(steps, truncated)`.
|
||||
fn run_game(roller: &mut DiceRoller, max_steps: usize, quiet: bool, verbose: bool) -> (usize, bool) {
|
||||
let mut state = GameState::new_with_players("White", "Black");
|
||||
let mut step = 0usize;
|
||||
|
||||
if !quiet {
|
||||
println!("{state}");
|
||||
}
|
||||
|
||||
while state.stage != Stage::Ended {
|
||||
step += 1;
|
||||
if step > max_steps {
|
||||
return (step - 1, true);
|
||||
}
|
||||
|
||||
match state.turn_stage {
|
||||
TurnStage::RollDice => {
|
||||
let player = state.active_player_id;
|
||||
match apply_dice_roll(&mut state, roller) {
|
||||
Ok(dice) => {
|
||||
if !quiet {
|
||||
println!(
|
||||
"[step {step:4}] {} rolls: {} & {}",
|
||||
player_label(player),
|
||||
dice.values.0,
|
||||
dice.values.1
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error during dice roll: {e}");
|
||||
eprintln!("State:\n{state}");
|
||||
return (step, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
stage => {
|
||||
let player = state.active_player_id;
|
||||
match apply_player_action(&mut state) {
|
||||
Ok(()) => {
|
||||
if !quiet {
|
||||
println!(
|
||||
"[step {step:4}] {} ({stage:?})",
|
||||
player_label(player)
|
||||
);
|
||||
if verbose {
|
||||
println!("{state}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error: {e}");
|
||||
eprintln!("State:\n{state}");
|
||||
return (step, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !quiet {
|
||||
println!("\n=== Game over after {step} steps ===\n");
|
||||
println!("{state}");
|
||||
|
||||
let white = state.players.get(&1);
|
||||
let black = state.players.get(&2);
|
||||
|
||||
match (white, black) {
|
||||
(Some(w), Some(b)) => {
|
||||
println!("White — holes: {:2}, points: {:2}", w.holes, w.points);
|
||||
println!("Black — holes: {:2}, points: {:2}", b.holes, b.points);
|
||||
println!();
|
||||
|
||||
let white_score = w.holes as i32 * 12 + w.points as i32;
|
||||
let black_score = b.holes as i32 * 12 + b.points as i32;
|
||||
|
||||
if white_score > black_score {
|
||||
println!("Winner: White (+{})", white_score - black_score);
|
||||
} else if black_score > white_score {
|
||||
println!("Winner: Black (+{})", black_score - white_score);
|
||||
} else {
|
||||
println!("Draw");
|
||||
}
|
||||
}
|
||||
_ => eprintln!("Could not read final player scores."),
|
||||
}
|
||||
}
|
||||
|
||||
(step, false)
|
||||
}
|
||||
|
||||
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
let mut roller = DiceRoller::new(args.seed);
|
||||
|
||||
if args.games == 1 {
|
||||
println!("=== Trictrac — random game ===");
|
||||
if let Some(s) = args.seed {
|
||||
println!("seed: {s}");
|
||||
}
|
||||
println!();
|
||||
run_game(&mut roller, args.max_steps, false, args.verbose);
|
||||
} else {
|
||||
println!("=== Trictrac — {} games ===", args.games);
|
||||
if let Some(s) = args.seed {
|
||||
println!("seed: {s}");
|
||||
}
|
||||
println!();
|
||||
|
||||
let mut total_steps = 0u64;
|
||||
let mut truncated = 0usize;
|
||||
|
||||
let t0 = Instant::now();
|
||||
for _ in 0..args.games {
|
||||
let (steps, trunc) = run_game(&mut roller, args.max_steps, !args.verbose, args.verbose);
|
||||
total_steps += steps as u64;
|
||||
if trunc {
|
||||
truncated += 1;
|
||||
}
|
||||
}
|
||||
let elapsed = t0.elapsed();
|
||||
|
||||
let secs = elapsed.as_secs_f64();
|
||||
println!("Games : {}", args.games);
|
||||
println!("Truncated : {truncated}");
|
||||
println!("Total steps: {total_steps}");
|
||||
println!("Avg steps : {:.1}", total_steps as f64 / args.games as f64);
|
||||
println!("Elapsed : {:.3} s", secs);
|
||||
println!("Throughput : {:.1} games/s", args.games as f64 / secs);
|
||||
println!(" {:.0} steps/s", total_steps as f64 / secs);
|
||||
}
|
||||
}
|
||||
|
|
@ -598,12 +598,40 @@ impl Board {
|
|||
core::array::from_fn(|i| i + min)
|
||||
}
|
||||
|
||||
/// Returns cumulative white-checker counts: result[i] = # white checkers in fields 1..=i.
|
||||
/// result[0] = 0.
|
||||
pub fn white_checker_cumulative(&self) -> [u8; 25] {
|
||||
let mut cum = [0u8; 25];
|
||||
let mut total = 0u8;
|
||||
for (i, &count) in self.positions.iter().enumerate() {
|
||||
if count > 0 {
|
||||
total += count as u8;
|
||||
}
|
||||
cum[i + 1] = total;
|
||||
}
|
||||
cum
|
||||
}
|
||||
|
||||
pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> {
|
||||
self.remove_checker(color, cmove.from)?;
|
||||
self.add_checker(color, cmove.to)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reverse a previously applied `move_checker`. No validation: assumes the move was valid.
|
||||
pub fn unmove_checker(&mut self, color: &Color, cmove: CheckerMove) {
|
||||
let unit = match color {
|
||||
Color::White => 1,
|
||||
Color::Black => -1,
|
||||
};
|
||||
if cmove.from != 0 {
|
||||
self.positions[cmove.from - 1] += unit;
|
||||
}
|
||||
if cmove.to != 0 {
|
||||
self.positions[cmove.to - 1] -= unit;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> {
|
||||
if field == 0 {
|
||||
return Ok(());
|
||||
|
|
|
|||
|
|
@ -83,8 +83,8 @@ pub mod ffi {
|
|||
/// Both players' scores.
|
||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||
|
||||
/// 36-element state vector (i8). Mirrored for player_idx == 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||
/// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<f32>;
|
||||
|
||||
/// Human-readable state description for `player_idx`.
|
||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||
|
|
@ -153,8 +153,7 @@ impl TricTracEngine {
|
|||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
|
@ -180,11 +179,11 @@ impl TricTracEngine {
|
|||
.unwrap_or(-1)
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
self.game_state.to_tensor()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
self.game_state.mirror().to_tensor()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,8 +242,9 @@ impl TricTracEngine {
|
|||
self.game_state
|
||||
),
|
||||
None => anyhow::bail!(
|
||||
"apply_action: could not build event from action index {}",
|
||||
action_idx
|
||||
"apply_action: could not build event from action index {} in state {}",
|
||||
action_idx,
|
||||
self.game_state
|
||||
),
|
||||
}
|
||||
}))
|
||||
|
|
|
|||
|
|
@ -156,13 +156,6 @@ impl GameState {
|
|||
if let Some(p1) = self.players.get(&1) {
|
||||
mirrored_players.insert(2, p1.mirror());
|
||||
}
|
||||
let mirrored_history = self
|
||||
.history
|
||||
.clone()
|
||||
.iter()
|
||||
.map(|evt| evt.get_mirror(false))
|
||||
.collect();
|
||||
|
||||
let (move1, move2) = self.dice_moves;
|
||||
GameState {
|
||||
stage: self.stage,
|
||||
|
|
@ -171,7 +164,7 @@ impl GameState {
|
|||
active_player_id: mirrored_active_player,
|
||||
// active_player_id: self.active_player_id,
|
||||
players: mirrored_players,
|
||||
history: mirrored_history,
|
||||
history: Vec::new(),
|
||||
dice: self.dice,
|
||||
dice_points: self.dice_points,
|
||||
dice_moves: (move1.mirror(), move2.mirror()),
|
||||
|
|
@ -207,6 +200,106 @@ impl GameState {
|
|||
self.to_vec().iter().map(|&x| x as f32).collect()
|
||||
}
|
||||
|
||||
/// Get state as a tensor for neural network training (Option B, TD-Gammon style).
|
||||
/// Returns 217 f32 values, all normalized to [0, 1].
|
||||
///
|
||||
/// Must be called from the active player's perspective: callers should mirror
|
||||
/// the GameState for Black before calling so that "own" always means White.
|
||||
///
|
||||
/// Layout:
|
||||
/// [0..95] own (White) checkers: 4 values per field × 24 fields
|
||||
/// [96..191] opp (Black) checkers: 4 values per field × 24 fields
|
||||
/// [192..193] dice values / 6
|
||||
/// [194] active player color (0=White, 1=Black)
|
||||
/// [195] turn_stage / 5
|
||||
/// [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille
|
||||
/// [200..203] Black player: same
|
||||
/// [204..207] own quarter filled (quarters 1-4)
|
||||
/// [208..211] opp quarter filled (quarters 1-4)
|
||||
/// [212] own checkers all in exit zone (fields 19-24)
|
||||
/// [213] opp checkers all in exit zone (fields 1-6)
|
||||
/// [214] own coin de repos taken (field 12 has ≥2 own checkers)
|
||||
/// [215] opp coin de repos taken (field 13 has ≥2 opp checkers)
|
||||
/// [216] own dice_roll_count / 3, clamped to 1
|
||||
pub fn to_tensor(&self) -> Vec<f32> {
|
||||
let mut t = Vec::with_capacity(217);
|
||||
let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
|
||||
|
||||
// [0..95] own (White) checkers, TD-Gammon encoding
|
||||
for &c in &pos {
|
||||
let own = c.max(0) as u8;
|
||||
t.push((own == 1) as u8 as f32);
|
||||
t.push((own == 2) as u8 as f32);
|
||||
t.push((own == 3) as u8 as f32);
|
||||
t.push(own.saturating_sub(3) as f32);
|
||||
}
|
||||
|
||||
// [96..191] opp (Black) checkers, TD-Gammon encoding
|
||||
for &c in &pos {
|
||||
let opp = (-c).max(0) as u8;
|
||||
t.push((opp == 1) as u8 as f32);
|
||||
t.push((opp == 2) as u8 as f32);
|
||||
t.push((opp == 3) as u8 as f32);
|
||||
t.push(opp.saturating_sub(3) as f32);
|
||||
}
|
||||
|
||||
// [192..193] dice
|
||||
t.push(self.dice.values.0 as f32 / 6.0);
|
||||
t.push(self.dice.values.1 as f32 / 6.0);
|
||||
|
||||
// [194] active player color
|
||||
t.push(
|
||||
self.who_plays()
|
||||
.map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 })
|
||||
.unwrap_or(0.0),
|
||||
);
|
||||
|
||||
// [195] turn stage
|
||||
t.push(u8::from(self.turn_stage) as f32 / 5.0);
|
||||
|
||||
// [196..199] White player stats
|
||||
let wp = self.get_white_player();
|
||||
t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0));
|
||||
t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0));
|
||||
t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
|
||||
t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
|
||||
|
||||
// [200..203] Black player stats
|
||||
let bp = self.get_black_player();
|
||||
t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0));
|
||||
t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0));
|
||||
t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
|
||||
t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
|
||||
|
||||
// [204..207] own (White) quarter fill status
|
||||
for &start in &[1usize, 7, 13, 19] {
|
||||
t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32);
|
||||
}
|
||||
|
||||
// [208..211] opp (Black) quarter fill status
|
||||
for &start in &[1usize, 7, 13, 19] {
|
||||
t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32);
|
||||
}
|
||||
|
||||
// [212] can_exit_own: no own checker in fields 1-18
|
||||
t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32);
|
||||
|
||||
// [213] can_exit_opp: no opp checker in fields 7-24
|
||||
t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32);
|
||||
|
||||
// [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers)
|
||||
t.push((pos[11] >= 2) as u8 as f32);
|
||||
|
||||
// [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers)
|
||||
t.push((pos[12] <= -2) as u8 as f32);
|
||||
|
||||
// [216] own dice_roll_count / 3, clamped to 1
|
||||
t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0));
|
||||
|
||||
debug_assert_eq!(t.len(), 217, "to_tensor length mismatch");
|
||||
t
|
||||
}
|
||||
|
||||
/// Get state as a vector (to be used for bot training input) :
|
||||
/// length = 36
|
||||
/// i8 for board positions with negative values for blacks
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ impl MoveRules {
|
|||
// Si possible, les deux dés doivent être joués
|
||||
if moves.0.get_from() == 0 || moves.1.get_from() == 0 {
|
||||
let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]);
|
||||
possible_moves_sequences.retain(|moves| self.check_exit_rules(moves).is_ok());
|
||||
possible_moves_sequences.retain(|moves| self.check_exit_rules(moves, None).is_ok());
|
||||
// possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok());
|
||||
if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() {
|
||||
if *moves == (EMPTY_MOVE, EMPTY_MOVE) {
|
||||
|
|
@ -238,7 +238,7 @@ impl MoveRules {
|
|||
|
||||
// check exit rules
|
||||
// if !ignored_rules.contains(&TricTracRule::Exit) {
|
||||
self.check_exit_rules(moves)?;
|
||||
self.check_exit_rules(moves, None)?;
|
||||
// }
|
||||
|
||||
// --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ----
|
||||
|
|
@ -321,7 +321,11 @@ impl MoveRules {
|
|||
.is_empty()
|
||||
}
|
||||
|
||||
fn check_exit_rules(&self, moves: &(CheckerMove, CheckerMove)) -> Result<(), MoveError> {
|
||||
fn check_exit_rules(
|
||||
&self,
|
||||
moves: &(CheckerMove, CheckerMove),
|
||||
exit_seqs: Option<&[(CheckerMove, CheckerMove)]>,
|
||||
) -> Result<(), MoveError> {
|
||||
if !moves.0.is_exit() && !moves.1.is_exit() {
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -331,16 +335,22 @@ impl MoveRules {
|
|||
}
|
||||
|
||||
// toutes les sorties directes sont autorisées, ainsi que les nombres défaillants
|
||||
let ignored_rules = vec![TricTracRule::Exit];
|
||||
let possible_moves_sequences_without_excedent =
|
||||
self.get_possible_moves_sequences(false, ignored_rules);
|
||||
if possible_moves_sequences_without_excedent.contains(moves) {
|
||||
let owned;
|
||||
let seqs = match exit_seqs {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
owned = self
|
||||
.get_possible_moves_sequences(false, vec![TricTracRule::Exit]);
|
||||
&owned
|
||||
}
|
||||
};
|
||||
if seqs.contains(moves) {
|
||||
return Ok(());
|
||||
}
|
||||
// À ce stade au moins un des déplacements concerne un nombre en excédant
|
||||
// - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on
|
||||
// refuse cette séquence
|
||||
if !possible_moves_sequences_without_excedent.is_empty() {
|
||||
if !seqs.is_empty() {
|
||||
return Err(MoveError::ExitByEffectPossible);
|
||||
}
|
||||
|
||||
|
|
@ -361,17 +371,24 @@ impl MoveRules {
|
|||
let _ = board_to_check.move_checker(&Color::White, moves.0);
|
||||
let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check);
|
||||
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves);
|
||||
if (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2)
|
||||
{
|
||||
// dice normal order
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, true);
|
||||
let is_not_farthest1 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2);
|
||||
|
||||
// dice reversed order
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, false);
|
||||
let is_not_farthest2 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2);
|
||||
|
||||
if is_not_farthest1 && is_not_farthest2 {
|
||||
return Err(MoveError::ExitNotFarthest);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) {
|
||||
fn move_excedants(&self, moves: &(CheckerMove, CheckerMove), dice_order: bool) -> (bool, bool) {
|
||||
let move1to = if moves.0.get_to() == 0 {
|
||||
25
|
||||
} else {
|
||||
|
|
@ -386,20 +403,16 @@ impl MoveRules {
|
|||
};
|
||||
let dist2 = move2to - moves.1.get_from();
|
||||
|
||||
let dist_min = cmp::min(dist1, dist2);
|
||||
let dist_max = cmp::max(dist1, dist2);
|
||||
|
||||
let dice_min = cmp::min(self.dice.values.0, self.dice.values.1) as usize;
|
||||
let dice_max = cmp::max(self.dice.values.0, self.dice.values.1) as usize;
|
||||
|
||||
let min_excedant = dist_min != 0 && dist_min < dice_min;
|
||||
let max_excedant = dist_max != 0 && dist_max < dice_max;
|
||||
|
||||
if dist_min == dist1 {
|
||||
(min_excedant, max_excedant)
|
||||
let (dice1, dice2) = if dice_order {
|
||||
self.dice.values
|
||||
} else {
|
||||
(max_excedant, min_excedant)
|
||||
}
|
||||
(self.dice.values.1, self.dice.values.0)
|
||||
};
|
||||
|
||||
(
|
||||
dist1 != 0 && dist1 < dice1 as usize,
|
||||
dist2 != 0 && dist2 < dice2 as usize,
|
||||
)
|
||||
}
|
||||
|
||||
fn get_board_exit_farthest(board: &Board) -> Field {
|
||||
|
|
@ -438,12 +451,18 @@ impl MoveRules {
|
|||
} else {
|
||||
(dice2, dice1)
|
||||
};
|
||||
let filling_seqs = if !ignored_rules.contains(&TricTracRule::MustFillQuarter) {
|
||||
Some(self.get_quarter_filling_moves_sequences())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut moves_seqs = self.get_possible_moves_sequences_by_dices(
|
||||
dice_max,
|
||||
dice_min,
|
||||
with_excedents,
|
||||
false,
|
||||
ignored_rules.clone(),
|
||||
&ignored_rules,
|
||||
filling_seqs.as_deref(),
|
||||
);
|
||||
// if we got valid sequences with the highest die, we don't accept sequences using only the
|
||||
// lowest die
|
||||
|
|
@ -453,7 +472,8 @@ impl MoveRules {
|
|||
dice_max,
|
||||
with_excedents,
|
||||
ignore_empty,
|
||||
ignored_rules,
|
||||
&ignored_rules,
|
||||
filling_seqs.as_deref(),
|
||||
);
|
||||
moves_seqs.append(&mut moves_seqs_order2);
|
||||
let empty_removed = moves_seqs
|
||||
|
|
@ -524,14 +544,16 @@ impl MoveRules {
|
|||
let mut moves_seqs = Vec::new();
|
||||
let color = &Color::White;
|
||||
let ignored_rules = vec![TricTracRule::Exit, TricTracRule::MustFillQuarter];
|
||||
let mut board = self.board.clone();
|
||||
for moves in self.get_possible_moves_sequences(true, ignored_rules) {
|
||||
let mut board = self.board.clone();
|
||||
board.move_checker(color, moves.0).unwrap();
|
||||
board.move_checker(color, moves.1).unwrap();
|
||||
// println!("get_quarter_filling_moves_sequences board : {:?}", board);
|
||||
if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) {
|
||||
moves_seqs.push(moves);
|
||||
}
|
||||
board.unmove_checker(color, moves.1);
|
||||
board.unmove_checker(color, moves.0);
|
||||
}
|
||||
moves_seqs
|
||||
}
|
||||
|
|
@ -542,18 +564,27 @@ impl MoveRules {
|
|||
dice2: u8,
|
||||
with_excedents: bool,
|
||||
ignore_empty: bool,
|
||||
ignored_rules: Vec<TricTracRule>,
|
||||
ignored_rules: &[TricTracRule],
|
||||
filling_seqs: Option<&[(CheckerMove, CheckerMove)]>,
|
||||
) -> Vec<(CheckerMove, CheckerMove)> {
|
||||
let mut moves_seqs = Vec::new();
|
||||
let color = &Color::White;
|
||||
let forbid_exits = self.has_checkers_outside_last_quarter();
|
||||
// Precompute non-excedant sequences once so check_exit_rules need not repeat
|
||||
// the full move generation for every exit-move candidate.
|
||||
// Only needed when Exit is not already ignored and exits are actually reachable.
|
||||
let exit_seqs = if !ignored_rules.contains(&TricTracRule::Exit) && !forbid_exits {
|
||||
Some(self.get_possible_moves_sequences(false, vec![TricTracRule::Exit]))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut board = self.board.clone();
|
||||
// println!("==== First");
|
||||
for first_move in
|
||||
self.board
|
||||
.get_possible_moves(*color, dice1, with_excedents, false, forbid_exits)
|
||||
{
|
||||
let mut board2 = self.board.clone();
|
||||
if board2.move_checker(color, first_move).is_err() {
|
||||
if board.move_checker(color, first_move).is_err() {
|
||||
println!("err move");
|
||||
continue;
|
||||
}
|
||||
|
|
@ -563,7 +594,7 @@ impl MoveRules {
|
|||
let mut has_second_dice_move = false;
|
||||
// println!(" ==== Second");
|
||||
for second_move in
|
||||
board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits)
|
||||
board.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits)
|
||||
{
|
||||
if self
|
||||
.check_corner_rules(&(first_move, second_move))
|
||||
|
|
@ -587,24 +618,10 @@ impl MoveRules {
|
|||
&& self.can_take_corner_by_effect())
|
||||
&& (ignored_rules.contains(&TricTracRule::Exit)
|
||||
|| self
|
||||
.check_exit_rules(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd (exit rule): {:?} - {:?}, {:?}",
|
||||
// e, first_move, second_move
|
||||
// )
|
||||
// })
|
||||
.is_ok())
|
||||
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter)
|
||||
|| self
|
||||
.check_must_fill_quarter_rule(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd: {:?} - {:?}, {:?} for {:?}",
|
||||
// e, first_move, second_move, self.board
|
||||
// )
|
||||
// })
|
||||
.check_exit_rules(&(first_move, second_move), exit_seqs.as_deref())
|
||||
.is_ok())
|
||||
&& filling_seqs
|
||||
.map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, second_move)))
|
||||
{
|
||||
if second_move.get_to() == 0
|
||||
&& first_move.get_to() == 0
|
||||
|
|
@ -627,16 +644,14 @@ impl MoveRules {
|
|||
&& !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE))
|
||||
&& self.can_take_corner_by_effect())
|
||||
&& (ignored_rules.contains(&TricTracRule::Exit)
|
||||
|| self.check_exit_rules(&(first_move, EMPTY_MOVE)).is_ok())
|
||||
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter)
|
||||
|| self
|
||||
.check_must_fill_quarter_rule(&(first_move, EMPTY_MOVE))
|
||||
.is_ok())
|
||||
|| self.check_exit_rules(&(first_move, EMPTY_MOVE), exit_seqs.as_deref()).is_ok())
|
||||
&& filling_seqs
|
||||
.map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, EMPTY_MOVE)))
|
||||
{
|
||||
// empty move
|
||||
moves_seqs.push((first_move, EMPTY_MOVE));
|
||||
}
|
||||
//if board2.get_color_fields(*color).is_empty() {
|
||||
board.unmove_checker(color, first_move);
|
||||
}
|
||||
moves_seqs
|
||||
}
|
||||
|
|
@ -1495,6 +1510,7 @@ mod tests {
|
|||
CheckerMove::new(23, 0).unwrap(),
|
||||
CheckerMove::new(24, 0).unwrap(),
|
||||
);
|
||||
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences_by_dices(
|
||||
|
|
@ -1502,7 +1518,8 @@ mod tests {
|
|||
state.dice.values.1,
|
||||
true,
|
||||
false,
|
||||
vec![]
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
)
|
||||
);
|
||||
|
||||
|
|
@ -1517,6 +1534,7 @@ mod tests {
|
|||
CheckerMove::new(19, 23).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
)];
|
||||
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
|
||||
assert_eq!(
|
||||
moves,
|
||||
state.get_possible_moves_sequences_by_dices(
|
||||
|
|
@ -1524,7 +1542,8 @@ mod tests {
|
|||
state.dice.values.1,
|
||||
true,
|
||||
false,
|
||||
vec![]
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
)
|
||||
);
|
||||
let moves = vec![(
|
||||
|
|
@ -1538,7 +1557,8 @@ mod tests {
|
|||
state.dice.values.0,
|
||||
true,
|
||||
false,
|
||||
vec![]
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
)
|
||||
);
|
||||
|
||||
|
|
@ -1554,6 +1574,7 @@ mod tests {
|
|||
CheckerMove::new(19, 21).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
);
|
||||
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences_by_dices(
|
||||
|
|
@ -1561,7 +1582,8 @@ mod tests {
|
|||
state.dice.values.1,
|
||||
true,
|
||||
false,
|
||||
vec![]
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
@ -1580,13 +1602,26 @@ mod tests {
|
|||
CheckerMove::new(19, 23).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
);
|
||||
assert!(state.check_exit_rules(&moves).is_ok());
|
||||
assert!(state.check_exit_rules(&moves, None).is_ok());
|
||||
|
||||
let moves = (
|
||||
CheckerMove::new(19, 24).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
);
|
||||
assert!(state.check_exit_rules(&moves).is_ok());
|
||||
assert!(state.check_exit_rules(&moves, None).is_ok());
|
||||
|
||||
state.dice.values = (6, 4);
|
||||
state.board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
-4, -1, -2, -1, 0, 0, 0, -1, 0, 0, 0, 0, -5, -1, 0, 0, 0, 0, 2, 3, 2, 2, 5, 1,
|
||||
],
|
||||
);
|
||||
let moves = (
|
||||
CheckerMove::new(20, 24).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
);
|
||||
assert!(state.check_exit_rules(&moves, None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -113,11 +113,11 @@ impl TricTrac {
|
|||
[self.get_score(1), self.get_score(2)]
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
self.game_state.to_tensor()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
self.game_state.mirror().to_tensor()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
use std::cmp::{max, min};
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
|
||||
use crate::board::Board;
|
||||
use crate::{CheckerMove, Dice, GameEvent, GameState};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -221,10 +220,14 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
|||
// Ajoute aussi les mouvements possibles
|
||||
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// rules.board is already White-perspective (mirrored if Black): compute cum once.
|
||||
let cum = rules.board.white_checker_cumulative();
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1, &move2, &color, game_state,
|
||||
valid_actions.push(white_checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
&cum,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
|
@ -235,10 +238,14 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
|||
// Empty move
|
||||
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
|
||||
}
|
||||
|
||||
// rules.board is already White-perspective (mirrored if Black): compute cum once.
|
||||
let cum = rules.board.white_checker_cumulative();
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1, &move2, &color, game_state,
|
||||
valid_actions.push(white_checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
&cum,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
|
@ -251,36 +258,27 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
|||
Ok(valid_actions)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
color: &crate::Color,
|
||||
state: &GameState,
|
||||
) -> anyhow::Result<TrictracAction> {
|
||||
let dice = &state.dice;
|
||||
let board = &state.board;
|
||||
|
||||
if color == &crate::Color::Black {
|
||||
// Moves are already 'white', so we don't mirror them
|
||||
white_checker_moves_to_trictrac_action(
|
||||
move1,
|
||||
move2,
|
||||
// &move1.clone().mirror(),
|
||||
// &move2.clone().mirror(),
|
||||
dice,
|
||||
&board.clone().mirror(),
|
||||
)
|
||||
// .map(|a| a.mirror())
|
||||
// Moves are always in White's coordinate system. For Black, mirror the board first.
|
||||
let cum = if color == &crate::Color::Black {
|
||||
state.board.mirror().white_checker_cumulative()
|
||||
} else {
|
||||
white_checker_moves_to_trictrac_action(move1, move2, dice, board)
|
||||
}
|
||||
state.board.white_checker_cumulative()
|
||||
};
|
||||
white_checker_moves_to_trictrac_action(move1, move2, &state.dice, &cum)
|
||||
}
|
||||
|
||||
fn white_checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
dice: &Dice,
|
||||
board: &Board,
|
||||
cum: &[u8; 25],
|
||||
) -> anyhow::Result<TrictracAction> {
|
||||
let to1 = move1.get_to();
|
||||
let to2 = move2.get_to();
|
||||
|
|
@ -302,7 +300,7 @@ fn white_checker_moves_to_trictrac_action(
|
|||
}
|
||||
} else {
|
||||
// double sortie
|
||||
if from1 < from2 {
|
||||
if from1 < from2 || from2 == 0 {
|
||||
max(dice.values.0, dice.values.1) as usize
|
||||
} else {
|
||||
min(dice.values.0, dice.values.1) as usize
|
||||
|
|
@ -321,11 +319,21 @@ fn white_checker_moves_to_trictrac_action(
|
|||
}
|
||||
let dice_order = diff_move1 == dice.values.0 as usize;
|
||||
|
||||
let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize;
|
||||
let mut tmp_board = board.clone();
|
||||
// should not raise an error for a valid action
|
||||
tmp_board.move_checker(&crate::Color::White, *move1)?;
|
||||
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize;
|
||||
// cum[i] = # white checkers in fields 1..=i (precomputed by the caller).
|
||||
// checker1 is the ordinal of the last checker at from1.
|
||||
let checker1 = cum[from1] as usize;
|
||||
// checker2 is the ordinal on the board after move1 (removed from from1, added to to1).
|
||||
// Adjust the cumulative in O(1) without cloning the board.
|
||||
let checker2 = {
|
||||
let mut c = cum[from2];
|
||||
if from1 > 0 && from2 >= from1 {
|
||||
c -= 1; // one checker was removed from from1, shifting later ordinals down
|
||||
}
|
||||
if from1 > 0 && to1 > 0 && from2 >= to1 {
|
||||
c += 1; // one checker was added at to1, shifting later ordinals up
|
||||
}
|
||||
c as usize
|
||||
};
|
||||
Ok(TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
|
|
@ -456,5 +464,48 @@ mod tests {
|
|||
}),
|
||||
ttaction.ok()
|
||||
);
|
||||
|
||||
// Black player
|
||||
state.active_player_id = 2;
|
||||
state.dice = Dice { values: (6, 3) };
|
||||
state.board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
2, -11, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 6, 4,
|
||||
],
|
||||
);
|
||||
let ttaction = super::checker_moves_to_trictrac_action(
|
||||
&CheckerMove::new(21, 0).unwrap(),
|
||||
&CheckerMove::new(0, 0).unwrap(),
|
||||
&crate::Color::Black,
|
||||
&state,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some(TrictracAction::Move {
|
||||
dice_order: true,
|
||||
checker1: 2,
|
||||
checker2: 0, // blocked by white on last field
|
||||
}),
|
||||
ttaction.ok()
|
||||
);
|
||||
|
||||
// same with dice order reversed
|
||||
state.dice = Dice { values: (3, 6) };
|
||||
let ttaction = super::checker_moves_to_trictrac_action(
|
||||
&CheckerMove::new(21, 0).unwrap(),
|
||||
&CheckerMove::new(0, 0).unwrap(),
|
||||
&crate::Color::Black,
|
||||
&state,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some(TrictracAction::Move {
|
||||
dice_order: false,
|
||||
checker1: 2,
|
||||
checker2: 0, // blocked by white on last field
|
||||
}),
|
||||
ttaction.ok()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue