diff --git a/Cargo.lock b/Cargo.lock index a43261e..a6c9481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.21" @@ -1116,6 +1122,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cast_trait" version = "0.1.2" @@ -1200,6 +1212,33 @@ dependencies = [ "rand 0.7.3", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1453,6 +1492,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 
0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -4461,6 +4536,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -4597,6 +4678,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.18.0" @@ -5891,6 +6000,19 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = 
"spiel_bot" +version = "0.1.0" +dependencies = [ + "anyhow", + "burn", + "criterion", + "rand 0.9.2", + "rand_distr", + "rayon", + "trictrac-store", +] + [[package]] name = "spin" version = "0.10.0" @@ -6299,6 +6421,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index b9e6d45..4c2eb15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store"] +members = ["client_cli", "bot", "store", "spiel_bot"] diff --git a/doc/plan_cxxbindings.md b/doc/plan_cxxbindings.md deleted file mode 100644 index 29bf314..0000000 --- a/doc/plan_cxxbindings.md +++ /dev/null @@ -1,992 +0,0 @@ -# Plan: C++ OpenSpiel Game via cxx.rs - -> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings. -> -> Base on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`. - ---- - -## 1. Overview - -The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding: - -- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`. 
-- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`. -- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box` and delegates all logic to Rust. -- Build wired together via **corrosion** (CMake-native Rust integration) and `cxx-build`. - -The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games. - ---- - -## 2. Files to Create / Modify - -``` -trictrac/ - store/ - Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type - build.rs ← CREATE: cxx-build bridge registration - src/ - lib.rs ← MODIFY: add cxxengine module - cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl - -forks/open_spiel/ - CMakeLists.txt ← MODIFY: add Corrosion FetchContent - open_spiel/ - games/ - CMakeLists.txt ← MODIFY: add trictrac/ sources + test - trictrac/ ← CREATE directory - trictrac.h ← CREATE - trictrac.cc ← CREATE - trictrac_test.cc ← CREATE - - justfile ← MODIFY: add buildtrictrac target -trictrac/ - justfile ← MODIFY: add cxxlib target -``` - ---- - -## 3. Step 1 — Rust: `store/Cargo.toml` - -Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library. 
- -```toml -[package] -name = "trictrac-store" -version = "0.1.0" -edition = "2021" - -[lib] -name = "trictrac_store" -# cdylib → Python .so (used by maturin / pyengine) -# rlib → used by other Rust crates in the workspace -# staticlib → used by C++ consumers (cxxengine) -crate-type = ["cdylib", "rlib", "staticlib"] - -[dependencies] -base64 = "0.21.7" -cxx = "1.0" -log = "0.4.20" -merge = "0.1.0" -pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } -rand = "0.9" -serde = { version = "1.0", features = ["derive"] } -transpose = "0.2.2" - -[build-dependencies] -cxx-build = "1.0" -``` - -> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`). - ---- - -## 4. Step 2 — Rust: `store/build.rs` - -The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header. - -```rust -fn main() { - cxx_build::bridge("src/cxxengine.rs") - .std("c++17") - .compile("trictrac-cxx"); - - // Re-run if the bridge source changes - println!("cargo:rerun-if-changed=src/cxxengine.rs"); -} -``` - -`cxx-build` will: - -- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block. -- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header. -- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source. -- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`). - ---- - -## 5. Step 3 — Rust: `store/src/cxxengine.rs` - -This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3. - -### Design decisions vs. 
`pyengine.rs` - -| pyengine | cxxengine | Reason | -| ------------------------- | ---------------------------- | -------------------------------------------- | -| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception | -| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples | -| `Vec` for actions | `Vec` | cxx does not support `usize` | -| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays | -| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy | - -### File content - -```rust -//! # C++ bindings for the TricTrac game engine via cxx.rs -//! -//! Exposes an opaque `TricTracEngine` type and associated functions -//! to C++. The C++ side (trictrac.cc) uses `rust::Box`. -//! -//! The Rust engine always works from the perspective of White (player 1). -//! For Black (player 2), the board is mirrored before computing actions -//! and events are mirrored back before applying — exactly as in pyengine.rs. - -use crate::dice::Dice; -use crate::game::{GameEvent, GameState, Stage, TurnStage}; -use crate::training_common::{get_valid_action_indices, TrictracAction}; - -// ── cxx bridge declaration ──────────────────────────────────────────────────── - -#[cxx::bridge(namespace = "trictrac_engine")] -pub mod ffi { - // ── Shared types (visible to both Rust and C++) ─────────────────────────── - - /// Two dice values passed from C++ to Rust for a dice-roll event. - struct DicePair { - die1: u8, - die2: u8, - } - - /// Both players' scores: holes * 12 + points. - struct PlayerScores { - score_p1: i32, - score_p2: i32, - } - - // ── Opaque Rust type exposed to C++ ─────────────────────────────────────── - - extern "Rust" { - /// Opaque handle to a TricTrac game state. - /// C++ accesses this only through `rust::Box`. - type TricTracEngine; - - /// Create a new engine, initialise two players, begin with player 1. 
- fn new_trictrac_engine() -> Box; - - /// Return a deep copy of the engine (needed for State::Clone()). - fn clone_engine(self: &TricTracEngine) -> Box; - - // ── Queries ─────────────────────────────────────────────────────────── - - /// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node). - fn needs_roll(self: &TricTracEngine) -> bool; - - /// True when Stage::Ended. - fn is_game_ended(self: &TricTracEngine) -> bool; - - /// Active player index: 0 (player 1 / White) or 1 (player 2 / Black). - fn current_player_idx(self: &TricTracEngine) -> u64; - - /// Legal action indices for `player_idx`. Returns empty vec if it is - /// not that player's turn. Indices are in [0, 513]. - fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec; - - /// Human-readable action description, e.g. "0:Move { dice_order: true … }". - fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String; - - /// Both players' scores: holes * 12 + points. - fn get_players_scores(self: &TricTracEngine) -> PlayerScores; - - /// 36-element state observation vector (i8). Mirrored for player 1. - fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; - - /// Human-readable state description for `player_idx`. - fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; - - /// Full debug representation of the current state. - fn to_debug_string(self: &TricTracEngine) -> String; - - // ── Mutations ───────────────────────────────────────────────────────── - - /// Apply a dice roll result. Returns Err if not in RollWaiting stage. - fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>; - - /// Apply a player action (move, go, roll). Returns Err if invalid. 
- fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>; - } -} - -// ── Opaque type implementation ──────────────────────────────────────────────── - -pub struct TricTracEngine { - game_state: GameState, -} - -pub fn new_trictrac_engine() -> Box { - let mut game_state = GameState::new(false); // schools_enabled = false - game_state.init_player("player1"); - game_state.init_player("player2"); - game_state.consume(&GameEvent::BeginGame { goes_first: 1 }); - Box::new(TricTracEngine { game_state }) -} - -impl TricTracEngine { - fn clone_engine(&self) -> Box { - Box::new(TricTracEngine { - game_state: self.game_state.clone(), - }) - } - - fn needs_roll(&self) -> bool { - self.game_state.turn_stage == TurnStage::RollWaiting - } - - fn is_game_ended(&self) -> bool { - self.game_state.stage == Stage::Ended - } - - /// Returns 0 for player 1 (White) and 1 for player 2 (Black). - fn current_player_idx(&self) -> u64 { - self.game_state.active_player_id - 1 - } - - fn get_legal_actions(&self, player_idx: u64) -> Vec { - if player_idx == self.current_player_idx() { - if player_idx == 0 { - get_valid_action_indices(&self.game_state) - .into_iter() - .map(|i| i as u64) - .collect() - } else { - let mirror = self.game_state.mirror(); - get_valid_action_indices(&mirror) - .into_iter() - .map(|i| i as u64) - .collect() - } - } else { - vec![] - } - } - - fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String { - TrictracAction::from_action_index(action_idx as usize) - .map(|a| format!("{}:{}", player_idx, a)) - .unwrap_or_else(|| "unknown action".into()) - } - - fn get_players_scores(&self) -> ffi::PlayerScores { - ffi::PlayerScores { - score_p1: self.score_for(1), - score_p2: self.score_for(2), - } - } - - fn score_for(&self, player_id: u64) -> i32 { - if let Some(player) = self.game_state.players.get(&player_id) { - player.holes as i32 * 12 + player.points as i32 - } else { - -1 - } - } - - fn get_tensor(&self, player_idx: u64) -> Vec { - if 
player_idx == 0 { - self.game_state.to_vec() - } else { - self.game_state.mirror().to_vec() - } - } - - fn get_observation_string(&self, player_idx: u64) -> String { - if player_idx == 0 { - format!("{}", self.game_state) - } else { - format!("{}", self.game_state.mirror()) - } - } - - fn to_debug_string(&self) -> String { - format!("{}", self.game_state) - } - - fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> { - let player_id = self.game_state.active_player_id; - if self.game_state.turn_stage != TurnStage::RollWaiting { - return Err("Not in RollWaiting stage".into()); - } - let dice = Dice { - values: (dice.die1, dice.die2), - }; - self.game_state - .consume(&GameEvent::RollResult { player_id, dice }); - Ok(()) - } - - fn apply_action(&mut self, action_idx: u64) -> Result<(), String> { - let action_idx = action_idx as usize; - let needs_mirror = self.game_state.active_player_id == 2; - - let event = TrictracAction::from_action_index(action_idx) - .and_then(|a| { - let game_state = if needs_mirror { - &self.game_state.mirror() - } else { - &self.game_state - }; - a.to_event(game_state) - .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) - }); - - match event { - Some(evt) if self.game_state.validate(&evt) => { - self.game_state.consume(&evt); - Ok(()) - } - Some(_) => Err("Action is invalid".into()), - None => Err("Could not build event from action index".into()), - } - } -} -``` - -> **Note on `Result<(), String>`**: cxx.rs requires the error type to implement `std::error::Error`. `String` does not implement it directly. Two options: -> -> - Use `anyhow::Error` (add `anyhow` dependency). -> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`. -> -> The recommended approach is `anyhow`: -> -> ```toml -> [dependencies] -> anyhow = "1.0" -> ``` -> -> Then `fn apply_action(...) -> Result<(), anyhow::Error>` — cxx.rs will convert this to a C++ exception of type `rust::Error` carrying the message. 
- ---- - -## 6. Step 4 — Rust: `store/src/lib.rs` - -Add the new module: - -```rust -// existing modules … -mod pyengine; - -// NEW: C++ bindings via cxx.rs -pub mod cxxengine; -``` - ---- - -## 7. Step 5 — C++: `trictrac/trictrac.h` - -Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box` and delegates everything to it. - -```cpp -// open_spiel/games/trictrac/trictrac.h -#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_ -#define OPEN_SPIEL_GAMES_TRICTRAC_H_ - -#include -#include -#include - -#include "open_spiel/spiel.h" -#include "open_spiel/spiel_utils.h" - -// Generated by cxx-build from store/src/cxxengine.rs. -// The include path is set by CMake (see CMakeLists.txt). -#include "trictrac_store/src/cxxengine.rs.h" - -namespace open_spiel { -namespace trictrac { - -inline constexpr int kNumPlayers = 2; -inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes -inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust -inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust -inline constexpr int kDefaultMaxTurns = 1000; - -class TrictracGame; - -// --------------------------------------------------------------------------- -// TrictracState -// --------------------------------------------------------------------------- -class TrictracState : public State { - public: - explicit TrictracState(std::shared_ptr game); - TrictracState(const TrictracState& other); - - Player CurrentPlayer() const override; - std::vector LegalActions() const override; - std::string ActionToString(Player player, Action move_id) const override; - std::vector> ChanceOutcomes() const override; - std::string ToString() const override; - bool IsTerminal() const override; - std::vector Returns() const override; - std::string ObservationString(Player player) const override; - void ObservationTensor(Player player, absl::Span values) const override; - std::unique_ptr Clone() const override; - - protected: - void 
DoApplyAction(Action move_id) override; - - private: - // Decode a chance action index [0,35] to (die1, die2). - // Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action] - static trictrac_engine::DicePair DecodeChanceAction(Action action); - - // The Rust engine handle. Deep-copied via clone_engine() when cloning state. - rust::Box engine_; -}; - -// --------------------------------------------------------------------------- -// TrictracGame -// --------------------------------------------------------------------------- -class TrictracGame : public Game { - public: - explicit TrictracGame(const GameParameters& params); - - int NumDistinctActions() const override { return kNumDistinctActions; } - std::unique_ptr NewInitialState() const override; - int MaxChanceOutcomes() const override { return kNumChanceOutcomes; } - int NumPlayers() const override { return kNumPlayers; } - double MinUtility() const override { return 0.0; } - double MaxUtility() const override { return 200.0; } - int MaxGameLength() const override { return 3 * max_turns_; } - int MaxChanceNodesInHistory() const override { return MaxGameLength(); } - std::vector ObservationTensorShape() const override { - return {kStateEncodingSize}; - } - - private: - int max_turns_; -}; - -} // namespace trictrac -} // namespace open_spiel - -#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_ -``` - ---- - -## 8. 
Step 6 — C++: `trictrac/trictrac.cc` - -```cpp -// open_spiel/games/trictrac/trictrac.cc -#include "open_spiel/games/trictrac/trictrac.h" - -#include -#include -#include - -#include "open_spiel/abseil-cpp/absl/types/span.h" -#include "open_spiel/game_parameters.h" -#include "open_spiel/spiel.h" -#include "open_spiel/spiel_globals.h" -#include "open_spiel/spiel_utils.h" - -namespace open_spiel { -namespace trictrac { -namespace { - -// ── Game registration ──────────────────────────────────────────────────────── - -const GameType kGameType{ - /*short_name=*/"trictrac", - /*long_name=*/"Trictrac", - GameType::Dynamics::kSequential, - GameType::ChanceMode::kExplicitStochastic, - GameType::Information::kPerfectInformation, - GameType::Utility::kGeneralSum, - GameType::RewardModel::kRewards, - /*min_num_players=*/kNumPlayers, - /*max_num_players=*/kNumPlayers, - /*provides_information_state_string=*/false, - /*provides_information_state_tensor=*/false, - /*provides_observation_string=*/true, - /*provides_observation_tensor=*/true, - /*parameter_specification=*/{ - {"max_turns", GameParameter(kDefaultMaxTurns)}, - }}; - -static std::shared_ptr Factory(const GameParameters& params) { - return std::make_shared(params); -} - -REGISTER_SPIEL_GAME(kGameType, Factory); - -} // namespace - -// ── TrictracGame ───────────────────────────────────────────────────────────── - -TrictracGame::TrictracGame(const GameParameters& params) - : Game(kGameType, params), - max_turns_(ParameterValue("max_turns", kDefaultMaxTurns)) {} - -std::unique_ptr TrictracGame::NewInitialState() const { - return std::make_unique(shared_from_this()); -} - -// ── TrictracState ───────────────────────────────────────────────────────────── - -TrictracState::TrictracState(std::shared_ptr game) - : State(game), - engine_(trictrac_engine::new_trictrac_engine()) {} - -// Copy constructor: deep-copy the Rust engine via clone_engine(). 
-TrictracState::TrictracState(const TrictracState& other) - : State(other), - engine_(other.engine_->clone_engine()) {} - -std::unique_ptr TrictracState::Clone() const { - return std::make_unique(*this); -} - -// ── Current player ──────────────────────────────────────────────────────────── - -Player TrictracState::CurrentPlayer() const { - if (engine_->is_game_ended()) return kTerminalPlayerId; - if (engine_->needs_roll()) return kChancePlayerId; - return static_cast(engine_->current_player_idx()); -} - -// ── Legal actions ───────────────────────────────────────────────────────────── - -std::vector TrictracState::LegalActions() const { - if (IsChanceNode()) { - // All 36 dice outcomes are equally likely; return indices 0–35. - std::vector actions(kNumChanceOutcomes); - for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i; - return actions; - } - Player player = CurrentPlayer(); - rust::Vec rust_actions = - engine_->get_legal_actions(static_cast(player)); - std::vector actions; - actions.reserve(rust_actions.size()); - for (uint64_t a : rust_actions) actions.push_back(static_cast(a)); - return actions; -} - -// ── Chance outcomes ─────────────────────────────────────────────────────────── - -std::vector> TrictracState::ChanceOutcomes() const { - SPIEL_CHECK_TRUE(IsChanceNode()); - const double p = 1.0 / kNumChanceOutcomes; - std::vector> outcomes; - outcomes.reserve(kNumChanceOutcomes); - for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p); - return outcomes; -} - -// ── Apply action ────────────────────────────────────────────────────────────── - -/*static*/ -trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) { - // Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action] - return trictrac_engine::DicePair{ - /*die1=*/static_cast(action / 6 + 1), - /*die2=*/static_cast(action % 6 + 1), - }; -} - -void TrictracState::DoApplyAction(Action action) { - if (IsChanceNode()) { - 
engine_->apply_dice_roll(DecodeChanceAction(action)); - } else { - engine_->apply_action(static_cast(action)); - } -} - -// ── Terminal & returns ──────────────────────────────────────────────────────── - -bool TrictracState::IsTerminal() const { - return engine_->is_game_ended(); -} - -std::vector TrictracState::Returns() const { - trictrac_engine::PlayerScores scores = engine_->get_players_scores(); - return {static_cast(scores.score_p1), - static_cast(scores.score_p2)}; -} - -// ── Observation ─────────────────────────────────────────────────────────────── - -std::string TrictracState::ObservationString(Player player) const { - return std::string(engine_->get_observation_string( - static_cast(player))); -} - -void TrictracState::ObservationTensor(Player player, - absl::Span values) const { - SPIEL_CHECK_EQ(values.size(), kStateEncodingSize); - rust::Vec tensor = - engine_->get_tensor(static_cast(player)); - SPIEL_CHECK_EQ(tensor.size(), static_cast(kStateEncodingSize)); - for (int i = 0; i < kStateEncodingSize; ++i) { - values[i] = static_cast(tensor[i]); - } -} - -// ── Strings ─────────────────────────────────────────────────────────────────── - -std::string TrictracState::ToString() const { - return std::string(engine_->to_debug_string()); -} - -std::string TrictracState::ActionToString(Player player, Action action) const { - if (IsChanceNode()) { - trictrac_engine::DicePair d = DecodeChanceAction(action); - return "(" + std::to_string(d.die1) + ", " + std::to_string(d.die2) + ")"; - } - return std::string(engine_->action_to_string( - static_cast(player), static_cast(action))); -} - -} // namespace trictrac -} // namespace open_spiel -``` - ---- - -## 9. 
Step 7 — C++: `trictrac/trictrac_test.cc` - -```cpp -// open_spiel/games/trictrac/trictrac_test.cc -#include "open_spiel/games/trictrac/trictrac.h" - -#include -#include - -#include "open_spiel/spiel.h" -#include "open_spiel/tests/basic_tests.h" -#include "open_spiel/utils/init.h" - -namespace open_spiel { -namespace trictrac { -namespace { - -void BasicTrictracTests() { - testing::LoadGameTest("trictrac"); - testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5); -} - -} // namespace -} // namespace trictrac -} // namespace open_spiel - -int main(int argc, char** argv) { - open_spiel::Init(&argc, &argv); - open_spiel::trictrac::BasicTrictracTests(); - std::cout << "trictrac tests passed" << std::endl; - return 0; -} -``` - ---- - -## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt` - -The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined: - -```cmake -# ── Corrosion: CMake integration for Rust ──────────────────────────────────── -include(FetchContent) -FetchContent_Declare( - Corrosion - GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git - GIT_TAG v0.5.1 # pin to a stable release -) -FetchContent_MakeAvailable(Corrosion) - -# Import the trictrac-store Rust crate. -# This creates a CMake target named 'trictrac-store'. -corrosion_import_crate( - MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml - CRATES trictrac-store -) - -# Generate the cxx bridge from cxxengine.rs. 
-# corrosion_add_cxxbridge: -# - runs cxx-build as part of the Rust build -# - creates a CMake target 'trictrac_cxx_bridge' that: -# * compiles the generated C++ shim -# * exposes INTERFACE include dirs for the generated .rs.h header -corrosion_add_cxxbridge(trictrac_cxx_bridge - CRATE trictrac-store - FILES src/cxxengine.rs -) -``` - -> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing. - ---- - -## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt` - -Two changes: add the new source files to `GAME_SOURCES`, and add a test target. - -### 11.1 Add to `GAME_SOURCES` - -Find the alphabetically correct position (after `tic_tac_toe`, before `trade_comm`) and add: - -```cmake -set(GAME_SOURCES - # ... existing games ... - trictrac/trictrac.cc - trictrac/trictrac.h - # ... remaining games ... -) -``` - -### 11.2 Link cxx bridge into OpenSpiel objects - -The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed. - -The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add: - -```cmake -target_link_libraries(open_spiel - PUBLIC - trictrac_cxx_bridge # C++ shim generated by cxx-build - trictrac-store # Rust static library -) -``` - -If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game: - -```cmake -add_library(trictrac_game OBJECT - trictrac/trictrac.cc - trictrac/trictrac.h -) -target_include_directories(trictrac_game - PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/.. 
-) -target_link_libraries(trictrac_game - PUBLIC - trictrac_cxx_bridge - trictrac-store - open_spiel_core # or whatever the core target is called -) -``` - -Then reference `$` in relevant executables. - -### 11.3 Add the test - -```cmake -add_executable(trictrac_test - trictrac/trictrac_test.cc - ${OPEN_SPIEL_OBJECTS} - $ -) -target_link_libraries(trictrac_test - PRIVATE - trictrac_cxx_bridge - trictrac-store -) -add_test(trictrac_test trictrac_test) -``` - ---- - -## 12. Step 10 — Justfile updates - -### `trictrac/justfile` — add `cxxlib` target - -Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists: - -```just -cxxlib: - cargo build --release -p trictrac-store - @echo "Static lib: $(ls target/release/libtrictrac_store.a)" - @echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)" -``` - -### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac` - -```just -buildtrictrac: - # Rebuild the Rust static lib first, then CMake - cd ../../trictrac && cargo build --release -p trictrac-store - mkdir -p build && cd build && \ - CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \ - make -j$(nproc) trictrac_test - -testtrictrac: buildtrictrac - ./build/trictrac_test - -playtrictrac_cpp: - ./build/examples/example --game=trictrac -``` - ---- - -## 13. Key Design Decisions - -### 13.1 Opaque type with `clone_engine()` - -OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box` which calls `.clone()` on the inner `GameState` (which derives `Clone`). - -### 13.2 Action encoding: same 514-element space - -The C++ game uses the same 514-action encoding as the Python version and the Rust training code. 
This means: - -- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies. -- Action 0 = Roll (used as the bridge between Move and the next chance node). -- Actions 2–513 = Move variants (checker ordinal pair + dice order). -- A trained C++ model and Python model share the same action space. - -### 13.3 Chance outcome ordering - -The dice outcome ordering is identical to the Python version: - -``` -action → (die1, die2) -0 → (1,1) 6 → (2,1) ... 35 → (6,6) -``` - -(`die1 = action/6 + 1`, `die2 = action%6 + 1`) - -This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines. - -### 13.4 `GameType::Utility::kGeneralSum` + `kRewards` - -Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). Intermediate hole rewards are returned by `Returns()` at every state, not just the terminal. - -### 13.5 Mirror pattern preserved - -`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring — it simply passes `player_idx` and the Rust engine handles the rest. - -### 13.6 `rust::Box` vs `rust::UniquePtr` - -`rust::Box` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. `rust::UniquePtr` is for C++ types held in Rust. - -### 13.7 Separate struct from `pyengine.rs` - -`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids: - -- PyO3 and cxx attributes conflicting on the same type. -- Changes to one binding breaking the other. -- Feature-flag complexity. - ---- - -## 14. Known Challenges - -### 14.1 Corrosion path resolution - -`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. 
Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable). - -### 14.2 `staticlib` + `cdylib` in one crate - -Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions: - -- The `cdylib` build (for maturin) does not need `staticlib`, and building both doubles the compile time. -- Consider gating `staticlib` behind a Cargo feature: `crate-type` is not directly feature-gatable, but you can work around this with a separate `Cargo.toml` or a workspace profile. -- Alternatively, accept the extra compile time during development. - -### 14.3 Linker symbols from Rust std - -When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (typically `-ldl -lm -lpthread -lc`). - -### 14.4 `anyhow` for error types - -cxx.rs requires the `Err` type in `Result` to implement `std::error::Error + Send + Sync`. `String` does not satisfy this. Use `anyhow::Error` or define a thin newtype wrapper: - -```rust -use std::fmt; - -#[derive(Debug)] -struct EngineError(String); -impl fmt::Display for EngineError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) } -} -impl std::error::Error for EngineError {} -``` - -On the C++ side, errors become `rust::Error` exceptions. Wrap `DoApplyAction` in a try-catch during development to surface Rust errors as `SpielFatalError`. - -### 14.5 `UndoAction` not implemented - -OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. 
The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo — the history is append-only. To support undo, `Clone()` is the only reliable strategy (clone before applying, discard clone if undo needed). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks game-tree search. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo. - -### 14.6 Generated header path in `#include` - -The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the actual path that `cxx-build` (via corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory. - -### 14.7 `rust::String` to `std::string` conversion - -The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` is implicitly convertible to `std::string`. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast(...)`. - ---- - -## 15. Complete File Checklist - -``` -[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib -[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...) 
-[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;` -[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation -[ ] trictrac/justfile — add `cxxlib` target -[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge -[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test -[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file -[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file -[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file -[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac -``` - ---- - -## 16. Implementation Order - -Implement in this sequence to get early feedback at each step: - -1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced. -2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations. -3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist. -4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them. -5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination. -6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`. - ---- - -# Implementation complete - -All files are in place and trictrac_test passes (168 84 final scores, all assertions pass). 
- -What was done - -Rust side (from previous session, already complete): - -- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build -- trictrac/store/build.rs — drives cxx-build -- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl -- trictrac/store/src/lib.rs — added pub mod cxxengine; - -C++ side (this session): - -- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header -- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation -- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test - -Build system: - -- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge -- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable - -Justfiles: - -- trictrac/justfile — added cxxlib target -- forks/open_spiel/justfile — added buildtrictrac and testtrictrac - -Fixes discovered during build - -| Issue | Fix | -| ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------- | -| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries | -| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) | -| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h | -| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc | -| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) | -| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries | -| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort | -| Duplicate actions for doubles | Added std::unique after sort | -| Returns() returned non-zero 
at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states | diff --git a/doc/spiel_bot_parallel.md b/doc/spiel_bot_parallel.md new file mode 100644 index 0000000..d9e021e --- /dev/null +++ b/doc/spiel_bot_parallel.md @@ -0,0 +1,121 @@ +Part B — Batched MCTS leaf evaluation + +Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls. + +Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs) + +pub trait Evaluator: Send + Sync { +fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32); + + /// Evaluate a batch of observations at once. Default falls back to + /// sequential calls; backends override this for efficiency. + fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> { + obs_batch.iter().map(|obs| self.evaluate(obs)).collect() + } + +} + +Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs) + +Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows. + +fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> { +let b = obs_batch.len(); +let obs_size = obs_batch[0].len(); +let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect(); +let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device); +let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); +let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap(); +let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap(); +let action_size = policies.len() / b; +(0..b).map(|i| { +(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i]) +}).collect() +} + +Step B3 — Add eval_batch_size to MctsConfig + +pub struct MctsConfig { +// ... existing fields ... +/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
+pub eval_batch_size: usize, +} + +Default: 1 (backwards-compatible). + +Step B4 — Make the simulation iterative (mcts/search.rs) + +The current simulate is recursive. For batching we need to split it into two phases: + +descend (pure tree traversal — no network call): + +- Traverse from root following PUCT, advancing through chance nodes with apply_chance. +- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect). +- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance. +- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path. + +ascend (backup — no network call): + +- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions. +- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
+ +Step B5 — Add run_mcts_batched to mcts/mod.rs + +The new entry point, called by run_mcts when config.eval_batch_size > 1: + +expand root (1 network call) +optionally add Dirichlet noise + +for round in 0..(n*simulations / batch_size): +leaves = [] +for * in 0..batch_size: +leaf = descend(root, state, env, rng) +leaves.push(leaf) + + obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves + where leaf.kind == NeedsEval] + results = evaluator.evaluate_batch(obs_batch) + + for (leaf, result) in zip(leaves, results): + expand the leaf node (insert children from result.policy) + ascend(root, leaf.path, result.value, leaf.player_idx) + // ascend also handles terminal and crossed-chance leaves + +// handle remainder: n_simulations % batch_size + +run_mcts becomes a thin dispatcher: +if config.eval_batch_size <= 1 { +// existing path (unchanged) +} else { +run_mcts_batched(...) +} + +Step B6 — CLI flag in az_train.rs + +--eval-batch N default: 8 Leaf batch size for MCTS network calls + +--- + +Summary of file changes + +┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ +│ File │ Changes │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ spiel_bot/Cargo.toml │ add rayon │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ 
+│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │ +└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ + +Key design constraint + +Parts A and B are independent and composable: + +- A only touches the outer game loop. +- B only touches the inner MCTS per game. +- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state. diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md new file mode 100644 index 0000000..a8863af --- /dev/null +++ b/doc/spiel_bot_research.md @@ -0,0 +1,782 @@ +# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac + +## 0. Context and Scope + +The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library +(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` +encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every +other stage to an inline random-opponent loop. + +`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency +for **self-play training**. Its goals: + +- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") + that works with Trictrac's multi-stage turn model and stochastic dice. +- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) + as the first algorithm. +- Remain **modular**: adding DQN or PPO later requires only a new + `impl Algorithm for Dqn` without touching the environment or network layers. +- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from + `trictrac-store`. + +--- + +## 1. 
Library Landscape + +### 1.1 Neural Network Frameworks + +| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | +| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- | +| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | +| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | +| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | +| ndarray alone | no | no | yes | mature | array ops only; no autograd | + +**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ +runtime needed, the `ndarray` backend is sufficient for CPU training and can +switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by +changing one type alias. + +`tch-rs` would be the best choice for raw training throughput (it is the most +battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the +pure-Rust constraint. If training speed becomes the bottleneck after prototyping, +switching `spiel_bot` to `tch-rs` is a one-line backend swap. 
+ +### 1.2 Other Key Crates + +| Crate | Role | +| -------------------- | ----------------------------------------------------------------- | +| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | +| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | +| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | +| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | +| `anyhow` | error propagation (already used everywhere) | +| `indicatif` | training progress bars | +| `tracing` | structured logging per episode/iteration | + +### 1.3 What `burn-rl` Provides (and Does Not) + +The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) +provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` +trait. It does **not** provide: + +- MCTS or any tree-search algorithm +- Two-player self-play +- Legal action masking during training +- Chance-node handling + +For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its +own (simpler, more targeted) traits and implement MCTS from scratch. + +--- + +## 2. Trictrac-Specific Design Constraints + +### 2.1 Multi-Stage Turn Model + +A Trictrac turn passes through up to six `TurnStage` values. 
Only two involve +genuine player choice: + +| TurnStage | Node type | Handler | +| ---------------- | ------------------------------- | ------------------------------- | +| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | +| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | +| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | +| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | +| `Move` | **Player decision** | MCTS / policy network | +| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | + +The environment wrapper advances through forced/chance stages automatically so +that from the algorithm's perspective every node it sees is a genuine player +decision. + +### 2.2 Stochastic Dice in MCTS + +AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice +introduce stochasticity. Three approaches exist: + +**A. Outcome sampling (recommended)** +During each MCTS simulation, when a chance node is reached, sample one dice +outcome at random and continue. After many simulations the expected value +converges. This is the approach used by OpenSpiel's MCTS for stochastic games +and requires no changes to the standard PUCT formula. + +**B. Chance-node averaging (expectimax)** +At each chance node, expand all 21 unique dice pairs weighted by their +probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is +exact but multiplies the branching factor by ~21 at every dice roll, making it +prohibitively expensive. + +**C. Condition on dice in the observation (current approach)** +Dice values are already encoded at indices [192–193] of `to_tensor()`. The +network naturally conditions on the rolled dice when it evaluates a position. +MCTS only runs on player-decision nodes _after_ the dice have been sampled; +chance nodes are bypassed by the environment wrapper (approach A). 
The policy +and value heads learn to play optimally given any dice pair. + +**Use approach A + C together**: the environment samples dice automatically +(chance node bypass), and the 217-dim tensor encodes the dice so the network +can exploit them. + +### 2.3 Perspective / Mirroring + +All move rules and tensor encoding are defined from White's perspective. +`to_tensor()` must always be called after mirroring the state for Black. +The environment wrapper handles this transparently: every observation returned +to an algorithm is already in the active player's perspective. + +### 2.4 Legal Action Masking + +A crucial difference from the existing `bot/` code: instead of penalizing +invalid actions with `ERROR_REWARD`, the policy head logits are **masked** +before softmax — illegal action logits are set to `-inf`. This prevents the +network from wasting capacity on illegal moves and eliminates the need for the +penalty-reward hack. + +--- + +## 3. Proposed Crate Architecture + +``` +spiel_bot/ +├── Cargo.toml +└── src/ + ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" + │ + ├── env/ + │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface + │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store + │ + ├── mcts/ + │ ├── mod.rs # MctsConfig, run_mcts() entry point + │ ├── node.rs # MctsNode (visit count, W, prior, children) + │ └── search.rs # simulate(), backup(), select_action() + │ + ├── network/ + │ ├── mod.rs # PolicyValueNet trait + │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads + │ + ├── alphazero/ + │ ├── mod.rs # AlphaZeroConfig + │ ├── selfplay.rs # generate_episode() -> Vec + │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) + │ └── trainer.rs # training loop: selfplay → sample → loss → update + │ + └── agent/ + ├── mod.rs # Agent trait + ├── random.rs # RandomAgent (baseline) + └── mcts_agent.rs # MctsAgent: uses trained network for inference +``` + +Future algorithms slot in without 
touching the above: + +``` + ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer + └── ppo/ # (future) PPO: impl Algorithm + rollout buffer +``` + +--- + +## 4. Core Traits + +### 4.1 `GameEnv` — the minimal OpenSpiel interface + +```rust +use rand::Rng; + +/// Who controls the current node. +pub enum Player { + P1, // player index 0 + P2, // player index 1 + Chance, // dice roll + Terminal, // game over +} + +pub trait GameEnv: Clone + Send + Sync + 'static { + type State: Clone + Send + Sync; + + /// Fresh game state. + fn new_game(&self) -> Self::State; + + /// Who acts at this node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices (always in [0, action_space())). + /// Empty only at Terminal nodes. + fn legal_actions(&self, s: &Self::State) -> Vec<usize>; + + /// Apply a player action (must be legal). + fn apply(&self, s: &mut Self::State, action: usize); + + /// Advance a Chance node by sampling dice; no-op at non-Chance nodes. + fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng); + + /// Observation tensor from `pov`'s perspective (0 or 1). + /// Returns 217 f32 values for Trictrac. + fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>; + + /// Flat observation size (217 for Trictrac). + fn obs_size(&self) -> usize; + + /// Total action-space size (514 for Trictrac). + fn action_space(&self) -> usize; + + /// Game outcome per player, or None if not Terminal. + /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} +``` + +### 4.2 `PolicyValueNet` — neural network interface + +```rust +use burn::prelude::*; + +pub trait PolicyValueNet<B: Backend>: Send + Sync { + /// Forward pass. + /// `obs`: [batch, obs_size] tensor. + /// Returns: (policy_logits [batch, action_space], value [batch]). + fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>); + + /// Save weights to `path`.
+ fn save(&self, path: &std::path::Path) -> anyhow::Result<()>; + + /// Load weights from `path`. + fn load(path: &std::path::Path) -> anyhow::Result + where + Self: Sized; +} +``` + +### 4.3 `Agent` — player policy interface + +```rust +pub trait Agent: Send { + /// Select an action index given the current game state observation. + /// `legal`: mask of valid action indices. + fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize; +} +``` + +--- + +## 5. MCTS Implementation + +### 5.1 Node + +```rust +pub struct MctsNode { + n: u32, // visit count N(s, a) + w: f32, // sum of backed-up values W(s, a) + p: f32, // prior from policy head P(s, a) + children: Vec<(usize, MctsNode)>, // (action_idx, child) + is_expanded: bool, +} + +impl MctsNode { + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT score used for selection. + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} +``` + +### 5.2 Simulation Loop + +One MCTS simulation (for deterministic decision nodes): + +``` +1. SELECTION — traverse from root, always pick child with highest PUCT, + auto-advancing forced/chance nodes via env.apply_chance(). +2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get + (policy_logits, value). Mask illegal actions, softmax + the remaining logits → priors P(s,a) for each child. +3. BACKUP — propagate -value up the tree (negate at each level because + perspective alternates between P1 and P2). +``` + +After `n_simulations` iterations, action selection at the root: + +```rust +// During training: sample proportional to N^(1/temperature) +// During evaluation: argmax N +fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } +``` + +### 5.3 Configuration + +```rust +pub struct MctsConfig { + pub n_simulations: usize, // e.g. 200 + pub c_puct: f32, // exploration constant, e.g. 
1.5 + pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 + pub dirichlet_eps: f32, // noise weight, e.g. 0.25 + pub temperature: f32, // action sampling temperature (anneals to 0) +} +``` + +### 5.4 Handling Chance Nodes Inside MCTS + +When simulation reaches a Chance node (dice roll), the environment automatically +samples dice and advances to the next decision node. The MCTS tree does **not** +branch on dice outcomes — it treats the sampled outcome as the state. This +corresponds to "outcome sampling" (approach A from §2.2). Because each +simulation independently samples dice, the Q-values at player nodes converge to +their expected value over many simulations. + +--- + +## 6. Network Architecture + +### 6.1 ResNet Policy-Value Network + +A single trunk with residual blocks, then two heads: + +``` +Input: [batch, 217] + ↓ +Linear(217 → 512) + ReLU + ↓ +ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU) + ↓ trunk output [batch, 512] + ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site) + └─ Value head: Linear(512 → 1) → tanh (output in [-1, 1]) +``` + +Burn implementation sketch: + +```rust +#[derive(Module, Debug)] +pub struct TrictracNet { + input: Linear, + res_blocks: Vec>, + policy_head: Linear, + value_head: Linear, +} + +impl TrictracNet { + pub fn forward(&self, obs: Tensor) + -> (Tensor, Tensor) + { + let x = activation::relu(self.input.forward(obs)); + let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x)); + let policy = self.policy_head.forward(x.clone()); // raw logits + let value = activation::tanh(self.value_head.forward(x)) + .squeeze(1); + (policy, value) + } +} +``` + +A simpler MLP (no residual blocks) is sufficient for a first version and much +faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two +heads. 
+ +### 6.2 Loss Function + +``` +L = MSE(value_pred, z) + + CrossEntropy(policy_logits_masked, π_mcts) + - c_l2 * L2_regularization +``` + +Where: + +- `z` = game outcome (±1) from the active player's perspective +- `π_mcts` = normalized MCTS visit counts at the root (the policy target) +- Legal action masking is applied before computing CrossEntropy + +--- + +## 7. AlphaZero Training Loop + +``` +INIT + network ← random weights + replay ← empty ReplayBuffer(capacity = 100_000) + +LOOP forever: + ── Self-play phase ────────────────────────────────────────────── + (parallel with rayon, n_workers games at once) + for each game: + state ← env.new_game() + samples = [] + while not terminal: + advance forced/chance nodes automatically + obs ← env.observation(state, current_player) + legal ← env.legal_actions(state) + π, root_value ← mcts.run(state, network, config) + action ← sample from π (with temperature) + samples.push((obs, π, current_player)) + env.apply(state, action) + z ← env.returns(state) // final scores + for (obs, π, player) in samples: + replay.push(TrainSample { obs, policy: π, value: z[player] }) + + ── Training phase ─────────────────────────────────────────────── + for each gradient step: + batch ← replay.sample(batch_size) + (policy_logits, value_pred) ← network.forward(batch.obs) + loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) + optimizer.step(loss.backward()) + + ── Evaluation (every N iterations) ───────────────────────────── + win_rate ← evaluate(network_new vs network_prev, n_eval_games) + if win_rate > 0.55: save checkpoint +``` + +### 7.1 Replay Buffer + +```rust +pub struct TrainSample { + pub obs: Vec, // 217 values + pub policy: Vec, // 514 values (normalized MCTS visit counts) + pub value: f32, // game outcome ∈ {-1, 0, +1} +} + +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + pub fn push(&mut self, s: TrainSample) { + if self.data.len() == self.capacity { 
self.data.pop_front(); } + self.data.push_back(s); + } + + pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + // sample without replacement + } +} +``` + +### 7.2 Parallelism Strategy + +Self-play is embarrassingly parallel (each game is independent): + +```rust +let samples: Vec<TrainSample> = (0..n_games) + .into_par_iter() // rayon + .flat_map(|_| generate_episode(&env, &network, &mcts_config)) + .collect(); +``` + +Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. +Self-play uses inference-only (no gradient tape), so a `NdArray` backend +(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with +`Autodiff<NdArray>`. + +For larger scale, a producer-consumer architecture (crossbeam-channel) separates +self-play workers from the training thread, allowing continuous data generation +while the GPU trains. + +--- + +## 8. `TrictracEnv` Implementation Sketch + +```rust +use trictrac_store::{ + training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, + Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, +}; + +#[derive(Clone)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + fn current_player(&self, s: &GameState) -> Player { + match s.stage { + Stage::Ended => Player::Terminal, + _ => match s.turn_stage { + TurnStage::RollWaiting => Player::Chance, + _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, + }, + } + } + + fn legal_actions(&self, s: &GameState) -> Vec<usize> { + let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; + get_valid_action_indices(&view).unwrap_or_default() + } + + fn apply(&self, s: &mut GameState, action_idx: usize) { + // advance all forced/chance nodes first, then apply the player action + self.advance_forced(s); + let needs_mirror = s.active_player_id == 2; + let view = if needs_mirror { s.mirror() } else { s.clone() };
+ if let Some(event) = TrictracAction::from_action_index(action_idx) + .and_then(|a| a.to_event(&view)) + .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) + { + let _ = s.consume(&event); + } + // advance any forced stages that follow + self.advance_forced(s); + } + + fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { + // RollDice → RollWaiting + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + // RollWaiting → next stage + let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; + let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); + self.advance_forced(s); + } + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } + } + + fn obs_size(&self) -> usize { 217 } + fn action_space(&self) -> usize { ACTION_SPACE_SIZE } + + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { return None; } + // Convert hole+point scores to ±1 outcome + let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [ 1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [ 0.0, 0.0], + }) + } +} + +impl TrictracEnv { + /// Advance through all forced (non-decision, non-chance) stages. + fn advance_forced(&self, s: &mut GameState) { + use trictrac_store::PointsRules; + loop { + match s.turn_stage { + TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { + // Scoring is deterministic; compute and apply automatically. 
+ let color = s.player_color_by_id(&s.active_player_id) + .unwrap_or(trictrac_store::Color::White); + let drc = s.players.get(&s.active_player_id) + .map(|p| p.dice_roll_count).unwrap_or(0); + let pr = PointsRules::new(&color, &s.board, s.dice); + let pts = pr.get_points(drc); + let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; + let _ = s.consume(&GameEvent::Mark { + player_id: s.active_player_id, points, + }); + } + TurnStage::RollDice => { + // RollDice is a forced "initiate roll" action with no real choice. + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + } + _ => break, + } + } + } +} +``` + +--- + +## 9. Cargo.toml Changes + +### 9.1 Add `spiel_bot` to the workspace + +```toml +# Cargo.toml (workspace root) +[workspace] +resolver = "2" +members = ["client_cli", "bot", "store", "spiel_bot"] +``` + +### 9.2 `spiel_bot/Cargo.toml` + +```toml +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[features] +default = ["alphazero"] +alphazero = [] +# dqn = [] # future +# ppo = [] # future + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" +rayon = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# Burn: NdArray for pure-Rust CPU training +# Replace NdArray with Wgpu or Tch for GPU. +burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +# Optional: progress display and structured logging +indicatif = "0.17" +tracing = "0.1" + +[[bin]] +name = "az_train" +path = "src/bin/az_train.rs" + +[[bin]] +name = "az_eval" +path = "src/bin/az_eval.rs" +``` + +--- + +## 10. 
Comparison: `bot` crate vs `spiel_bot` + +| Aspect | `bot` (existing) | `spiel_bot` (proposed) | +| ---------------- | --------------------------- | -------------------------------------------- | +| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | +| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | +| Opponent | hardcoded random | self-play | +| Invalid actions | penalise with reward | legal action mask (no penalty) | +| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | +| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | +| Burn dep | yes (0.20) | yes (0.20), same backend | +| `burn-rl` dep | yes | no | +| C++ dep | no | no | +| Python dep | no | no | +| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | + +The two crates are **complementary**: `bot` is a working DQN/PPO baseline; +`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The +`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just +replace `TrictracEnvironment` with `TrictracEnv`). + +--- + +## 11. Implementation Order + +1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game + through the trait, verify terminal state and returns). +2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + + Burn forward/backward pass test with dummy data. +3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests + (visit counts sum to `n_simulations`, legal mask respected). +4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub + (one iteration, check loss decreases). +5. **Integration test**: run 100 self-play games with a tiny network (1 res block, + 64 hidden units), verify the training loop completes without panics. +6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s + on CPU, consistent with `random_game` throughput). +7. 
**Upgrade network**: 4 residual blocks, 512 hidden units; schedule + hyperparameter sweep. +8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report + win rate every checkpoint. + +--- + +## 12. Key Open Questions + +1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. + AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever + scored more holes). Better option: normalize the score margin. The final + choice affects how the value head is trained. + +2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). + MCTS with 200 simulations per step means ~120k network evaluations per game. + At batch inference this is feasible on CPU; on GPU it becomes fast. Consider + limiting `n_simulations` to 50–100 for early training. + +3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). + This is a long-horizon decision that AlphaZero handles naturally via MCTS + lookahead, but needs careful value normalization (a "Go" restarts scoring + within the same game). + +4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated + to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. + This is optional but reduces code duplication. + +5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, + 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. + A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. + +## Implementation results + +All benchmarks compile and run. 
Here's the complete results table: + +| Group | Benchmark | Time | +| ------- | ----------------------- | --------------------- | +| env | apply_chance | 3.87 µs | +| | legal_actions | 1.91 µs | +| | observation (to_tensor) | 341 ns | +| | random_game (baseline) | 3.55 ms → 282 games/s | +| network | mlp_b1 hidden=64 | 94.9 µs | +| | mlp_b32 hidden=64 | 141 µs | +| | mlp_b1 hidden=256 | 352 µs | +| | mlp_b32 hidden=256 | 479 µs | +| mcts | zero_eval n=1 | 6.8 µs | +| | zero_eval n=5 | 23.9 µs | +| | zero_eval n=20 | 90.9 µs | +| | mlp64 n=1 | 203 µs | +| | mlp64 n=5 | 622 µs | +| | mlp64 n=20 | 2.30 ms | +| episode | trictrac n=1 | 51.8 ms → 19 games/s | +| | trictrac n=2 | 145 ms → 7 games/s | +| train | mlp64 Adam b=16 | 1.93 ms | +| | mlp64 Adam b=64 | 2.68 ms | + +Key observations: + +- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game) +- observation (217-value tensor): only 341 ns — not a bottleneck +- legal_actions: 1.9 µs — well optimised +- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs +- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal +- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it +- Training: 2.7 ms/step at batch=64 → 370 steps/s + +### Summary of Step 8 + +spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary: + +- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct +- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%) +- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side +- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise) +- Output: win/draw/loss per side + combined decisive win rate + +Typical usage after training: +cargo run -p spiel_bot --bin az_eval --release -- \ + --checkpoint 
checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100 + +### az_train + +#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10) + +cargo run -p spiel_bot --bin az_train --release + +#### ResNet, more games, custom output dir + +cargo run -p spiel_bot --bin az_train --release -- \ + --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \ + --save-every 20 --out checkpoints/ + +#### Resume from iteration 50 + +cargo run -p spiel_bot --bin az_train --release -- \ + --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50 + +What the binary does each iteration: + +1. Calls model.valid() to get a zero-overhead inference copy for self-play +2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy) +3. Pushes samples into a ReplayBuffer (capacity --replay-cap) +4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min +5. Saves a .mpk checkpoint every --save-every iterations and always on the last diff --git a/doc/tensor_research.md b/doc/tensor_research.md new file mode 100644 index 0000000..b0d0ede --- /dev/null +++ b/doc/tensor_research.md @@ -0,0 +1,253 @@ +# Tensor research + +## Current tensor anatomy + +[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!) +[24] active player color: 0 or 1 +[25] turn_stage: 1–5 +[26–27] dice values (raw 1–6) +[28–31] white: points, holes, can_bredouille, can_big_bredouille +[32–35] black: same +───────────────────────────────── +Total 36 floats + +The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's +AlphaZero uses a fully-connected network. + +### Fundamental problems with the current encoding + +1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. 
The network + must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with + all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from. + +2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient + flow during training is uneven. + +3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it + triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers + produces a score. Including this explicitly is the single highest-value addition. + +4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely + different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0. + +5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the + starting position). It's in the Player struct but not exported. + +## Key Trictrac distinctions from backgammon that shape the encoding + +| Concept | Backgammon | Trictrac | +| ------------------------- | ---------------------- | --------------------------------------------------------- | +| Hitting a blot | Removes checker to bar | Scores points, checker stays | +| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened | +| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) | +| 3-checker field | Safe with spare | Safe with spare | +| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) | +| Both colors on a field | Impossible | Perfectly legal | +| Rest corner (field 12/13) | Does not exist | Special two-checker rules | + +The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. 
Splitting them into binary +indicators directly teaches the network the phase transitions the game hinges on. + +## Options + +### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D) + +The minimum viable improvement. + +For each of the 24 fields, encode own and opponent separately with 4 indicators each: + +own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target) +own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill) +own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare) +own_x[i]: max(0, count − 3) (overflow) +opp_1[i]: same for opponent +… + +Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec(). + +Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204 +Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured +overhead is negligible. +Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from +scratch will be substantially faster and the converged policy quality better. + +### Option B — Option A + Trictrac-specific derived features (flat 1D) + +Recommended starting point. 
+ +Add on top of Option A: + +// Quarter fill status — the single most important derived feature +quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields) +quarter_filled_opp[q] (q=0..3): same for opponent +→ 8 values + +// Exit readiness +can_exit_own: 1.0 if all own checkers are in fields 19–24 +can_exit_opp: same for opponent +→ 2 values + +// Rest corner status (field 12/13) +own_corner_taken: 1.0 if field 12 has ≥2 own checkers +opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers +→ 2 values + +// Jan de 3 coups counter (normalized) +dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0) +→ 1 value + +Size: 204 + 8 + 2 + 2 + 1 = 217 +Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve +the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes +expensive inference from the network. + +### Option C — Option B + richer positional features (flat 1D) + +More complete, higher sample efficiency, minor extra cost. + +Add on top of Option B: + +// Per-quarter fill fraction — how close to filling each quarter +own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0 +opp_quarter_fill_fraction[q] (q=0..3): same for opponent +→ 8 values + +// Blot counts — number of own/opponent single-checker fields globally +// (tells the network at a glance how much battage risk/opportunity exists) +own_blot_count: (number of own fields with exactly 1 checker) / 15.0 +opp_blot_count: same for opponent +→ 2 values + +// Bredouille would-double multiplier (already present, but explicitly scaled) +// No change needed, already binary + +Size: 217 + 8 + 2 = 227 +Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network +from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts). 
+ +### Option D — 2D spatial tensor {K, 24} + +For CNN-based networks. Best eventual architecture but requires changing the training setup. + +Shape {14, 24} — 14 feature channels over 24 field positions: + +Channel 0: own_count_1 (blot) +Channel 1: own_count_2 +Channel 2: own_count_3 +Channel 3: own_count_overflow (float) +Channel 4: opp_count_1 +Channel 5: opp_count_2 +Channel 6: opp_count_3 +Channel 7: opp_count_overflow +Channel 8: own_corner_mask (1.0 at field 12) +Channel 9: opp_corner_mask (1.0 at field 13) +Channel 10: final_quarter_mask (1.0 at fields 19–24) +Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter) +Channel 12: quarter_filled_opp (same for opponent) +Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers) + +Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform +value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global" +row by returning shape {K, 25} with position 0 holding global features. + +Size: 14 × 24 + few global channels ≈ 336–384 +C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated +accordingly. +Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's +alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H, +W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension. +Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel +size 2–3 captures adjacent-field "tout d'une" interactions. + +### On 3D tensors + +Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is +the fundamental tactical unit. 
A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and +field-within-quarter patterns jointly. + +However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The +extra architecture work makes this premature unless you're already building a custom network. The information content +is the same as Option D. + +### Recommendation + +Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and +the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions +(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move. + +Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when +the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time. + +Costs summary: + +| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain | +| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- | +| Current | 36 | — | — | — | baseline | +| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) | +| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) | +| C | 227 | to_vec() rewrite | constant update | none | large + moderate | +| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial | + +One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2, +the new to_vec() must express everything from the active player's perspective (which the mirror already handles for +the board). 
The quarter fill status and corner status should therefore be computed on the already-mirrored state, +which they will be if computed inside to_vec(). + +## Other algorithms + +The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully. + +### 1. Without MCTS, feature quality matters more + +AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred +simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup — +the network must learn the full strategic structure directly from gradient updates. + +This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO, +not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less +sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from +scratch. + +### 2. Normalization becomes mandatory, not optional + +AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps +Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw +values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can +be unstable. + +For DQN/PPO, every feature in the tensor should be in [0, 1]: + +dice values: / 6.0 +checker counts: overflow channel / 12.0 +points: / 12.0 +holes: / 12.0 +dice_roll_count: / 3.0 (clamped) + +Booleans and the TD-Gammon binary indicators are already in [0, 1]. + +### 3. 
The shape question depends on architecture, not algorithm + +| Architecture | Shape | When to use | +| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- | +| MLP | {217} flat | Any algorithm, simplest baseline | +| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) | +| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network | +| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now | + +The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want +convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet +template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all +— you just reshape the flat tensor in Python before passing it to the network. + +### One thing that genuinely changes: value function perspective + +AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This +works well. + +DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit +current-player indicator), because a single Q-network estimates action values for both players simultaneously. With +the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must +learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity. +This is a minor point for a symmetric game like Trictrac, but worth keeping in mind. + +Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. 
The features themselves are the same regardless of algorithm. diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml new file mode 100644 index 0000000..682505b --- /dev/null +++ b/spiel_bot/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" +rand_distr = "0.5" +burn = { version = "0.20", features = ["ndarray", "autodiff"] } +rayon = "1" + +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "alphazero" +harness = false diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs new file mode 100644 index 0000000..00d5b02 --- /dev/null +++ b/spiel_bot/benches/alphazero.rs @@ -0,0 +1,373 @@ +//! AlphaZero pipeline benchmarks. +//! +//! Run with: +//! +//! ```sh +//! cargo bench -p spiel_bot +//! ``` +//! +//! Use `-- ` to run a specific group, e.g.: +//! +//! ```sh +//! cargo bench -p spiel_bot -- env/ +//! cargo bench -p spiel_bot -- network/ +//! cargo bench -p spiel_bot -- mcts/ +//! cargo bench -p spiel_bot -- episode/ +//! cargo bench -p spiel_bot -- train/ +//! ``` +//! +//! Target: ≥ 500 games/s for random play on CPU (consistent with +//! `random_game` throughput in `trictrac-store`). 
+
+use std::time::Duration;
+
+use burn::{
+    backend::NdArray,
+    tensor::{Tensor, TensorData, backend::Backend},
+};
+use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use rand::{Rng, SeedableRng, rngs::SmallRng};
+
+use spiel_bot::{
+    alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
+    env::{GameEnv, Player, TrictracEnv},
+    mcts::{Evaluator, MctsConfig, run_mcts},
+    network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
+};
+
+// ── Shared types ───────────────────────────────────────────────────────────
+
+type InferB = NdArray;
+type TrainB = burn::backend::Autodiff<NdArray>;
+
+fn infer_device() -> <InferB as Backend>::Device { Default::default() }
+fn train_device() -> <TrainB as Backend>::Device { Default::default() }
+
+fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
+
+/// Uniform evaluator (returns zero logits and zero value).
+/// Used to isolate MCTS tree-traversal cost from network cost.
+struct ZeroEval(usize);
+impl Evaluator for ZeroEval {
+    fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
+        (vec![0.0f32; self.0], 0.0)
+    }
+}
+
+// ── 1. Environment primitives ──────────────────────────────────────────────
+
+/// Baseline performance of the raw Trictrac environment without MCTS.
+/// Target: ≥ 500 full games / second.
+fn bench_env(c: &mut Criterion) {
+    let env = TrictracEnv;
+
+    let mut group = c.benchmark_group("env");
+    group.measurement_time(Duration::from_secs(10));
+
+    // ── apply_chance ──────────────────────────────────────────────────────
+    group.bench_function("apply_chance", |b| {
+        b.iter_batched(
+            || {
+                // A fresh game is always at RollDice (Chance) — ready for apply_chance.
+                env.new_game()
+            },
+            |mut s| {
+                env.apply_chance(&mut s, &mut seeded());
+                black_box(s)
+            },
+            BatchSize::SmallInput,
+        )
+    });
+
+    // ── legal_actions ─────────────────────────────────────────────────────
+    group.bench_function("legal_actions", |b| {
+        let mut rng = seeded();
+        let mut s = env.new_game();
+        env.apply_chance(&mut s, &mut rng);
+        b.iter(|| black_box(env.legal_actions(&s)))
+    });
+
+    // ── observation (to_tensor) ───────────────────────────────────────────
+    group.bench_function("observation", |b| {
+        let mut rng = seeded();
+        let mut s = env.new_game();
+        env.apply_chance(&mut s, &mut rng);
+        b.iter(|| black_box(env.observation(&s, 0)))
+    });
+
+    // ── full random game ──────────────────────────────────────────────────
+    group.sample_size(50);
+    group.bench_function("random_game", |b| {
+        b.iter_batched(
+            seeded,
+            |mut rng| {
+                let mut s = env.new_game();
+                loop {
+                    match env.current_player(&s) {
+                        Player::Terminal => break,
+                        Player::Chance => env.apply_chance(&mut s, &mut rng),
+                        _ => {
+                            let actions = env.legal_actions(&s);
+                            let idx = rng.random_range(0..actions.len());
+                            env.apply(&mut s, actions[idx]);
+                        }
+                    }
+                }
+                black_box(s)
+            },
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.finish();
+}
+
+// ── 2. Network inference ───────────────────────────────────────────────────
+
+/// Forward-pass latency for MLP variants (hidden = 64 / 256).
+fn bench_network(c: &mut Criterion) {
+    let mut group = c.benchmark_group("network");
+    group.measurement_time(Duration::from_secs(5));
+
+    for &hidden in &[64usize, 256] {
+        let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
+        let model = MlpNet::<InferB>::new(&cfg, &infer_device());
+        let obs: Vec<f32> = vec![0.5; 217];
+
+        // Batch size 1 — single-position evaluation as in MCTS.
+        group.bench_with_input(
+            BenchmarkId::new("mlp_b1", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs.clone(), [1, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+
+        // Batch size 32 — training mini-batch.
+        let obs32: Vec<f32> = vec![0.5; 217 * 32];
+        group.bench_with_input(
+            BenchmarkId::new("mlp_b32", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs32.clone(), [32, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+    }
+
+    // ── ResNet (4 residual blocks) ────────────────────────────────────────
+    for &hidden in &[256usize, 512] {
+        let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
+        let model = ResNet::<InferB>::new(&cfg, &infer_device());
+        let obs: Vec<f32> = vec![0.5; 217];
+
+        group.bench_with_input(
+            BenchmarkId::new("resnet_b1", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs.clone(), [1, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+
+        let obs32: Vec<f32> = vec![0.5; 217 * 32];
+        group.bench_with_input(
+            BenchmarkId::new("resnet_b32", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs32.clone(), [32, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── 3. MCTS ─────────────────────────────────────────────────────────────────
+
+/// MCTS cost at different simulation budgets with two evaluator types:
+/// - `zero` — isolates tree-traversal overhead (no network).
+/// - `mlp64` — real MLP, shows end-to-end cost per move.
+fn bench_mcts(c: &mut Criterion) {
+    let env = TrictracEnv;
+
+    // Build a decision-node state (after dice roll).
+    let state = {
+        let mut s = env.new_game();
+        let mut rng = seeded();
+        while env.current_player(&s).is_chance() {
+            env.apply_chance(&mut s, &mut rng);
+        }
+        s
+    };
+
+    let mut group = c.benchmark_group("mcts");
+    group.measurement_time(Duration::from_secs(10));
+
+    let zero_eval = ZeroEval(514);
+    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+    let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
+    let mlp_eval = BurnEvaluator::<InferB>::new(mlp_model, infer_device());
+
+    for &n_sim in &[1usize, 5, 20] {
+        let cfg = MctsConfig {
+            n_simulations: n_sim,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.0,
+            dirichlet_eps: 0.0,
+            temperature: 1.0,
+        };
+
+        // Zero evaluator: tree traversal only.
+        group.bench_with_input(
+            BenchmarkId::new("zero_eval", n_sim),
+            &n_sim,
+            |b, _| {
+                b.iter_batched(
+                    seeded,
+                    |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+
+        // MLP evaluator: full cost per decision.
+        group.bench_with_input(
+            BenchmarkId::new("mlp64", n_sim),
+            &n_sim,
+            |b, _| {
+                b.iter_batched(
+                    seeded,
+                    |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── 4. Episode generation ──────────────────────────────────────────────────
+
+/// Full self-play episode latency (one complete game) at different MCTS
+/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
+fn bench_episode(c: &mut Criterion) {
+    let env = TrictracEnv;
+    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+    let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
+    let eval = BurnEvaluator::<InferB>::new(model, infer_device());
+
+    let mut group = c.benchmark_group("episode");
+    group.sample_size(10);
+    group.measurement_time(Duration::from_secs(60));
+
+    for &n_sim in &[1usize, 2] {
+        let mcts_cfg = MctsConfig {
+            n_simulations: n_sim,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.0,
+            dirichlet_eps: 0.0,
+            temperature: 1.0,
+        };
+
+        group.bench_with_input(
+            BenchmarkId::new("trictrac", n_sim),
+            &n_sim,
+            |b, _| {
+                b.iter_batched(
+                    seeded,
+                    |mut rng| {
+                        black_box(generate_episode(
+                            &env,
+                            &eval,
+                            &mcts_cfg,
+                            &|_| 1.0,
+                            &mut rng,
+                        ))
+                    },
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── 5. Training step ─────────────────────────────────────────────────────
+
+/// Gradient-step latency for different batch sizes.
+fn bench_train(c: &mut Criterion) {
+    use burn::optim::AdamConfig;
+
+    let mut group = c.benchmark_group("train");
+    group.measurement_time(Duration::from_secs(10));
+
+    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+
+    let dummy_samples = |n: usize| -> Vec<TrainSample> {
+        (0..n)
+            .map(|i| TrainSample {
+                obs: vec![0.5; 217],
+                policy: {
+                    let mut p = vec![0.0f32; 514];
+                    p[i % 514] = 1.0;
+                    p
+                },
+                value: if i % 2 == 0 { 1.0 } else { -1.0 },
+            })
+            .collect()
+    };
+
+    for &batch_size in &[16usize, 64] {
+        let batch = dummy_samples(batch_size);
+
+        group.bench_with_input(
+            BenchmarkId::new("mlp64_adam", batch_size),
+            &batch_size,
+            |b, _| {
+                b.iter_batched(
+                    || {
+                        (
+                            MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
+                            AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
+                        )
+                    },
+                    |(model, mut opt)| {
+                        black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
+                    },
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── Criterion entry point 
────────────────────────────────────────────────── + +criterion_group!( + benches, + bench_env, + bench_network, + bench_mcts, + bench_episode, + bench_train, +); +criterion_main!(benches); diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs new file mode 100644 index 0000000..d92224e --- /dev/null +++ b/spiel_bot/src/alphazero/mod.rs @@ -0,0 +1,127 @@ +//! AlphaZero: self-play data generation, replay buffer, and training step. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | +//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | +//! | [`trainer`] | [`train_step`] | +//! +//! # Typical outer loop +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! use burn::optim::AdamConfig; +//! use spiel_bot::{ +//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, +//! env::TrictracEnv, +//! mcts::MctsConfig, +//! network::{MlpConfig, MlpNet}, +//! }; +//! +//! type Infer = NdArray; +//! type Train = Autodiff>; +//! +//! let device = Default::default(); +//! let env = TrictracEnv; +//! let config = AlphaZeroConfig::default(); +//! +//! // Build training model and optimizer. +//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); +//! let mut optimizer = AdamConfig::new().init(); +//! let mut replay = ReplayBuffer::new(config.replay_capacity); +//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); +//! +//! for _iter in 0..config.n_iterations { +//! // Convert to inference backend for self-play. +//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) +//! .load_record(train_model.clone().into_record()); +//! let eval = BurnEvaluator::new(infer_model, device.clone()); +//! +//! // Self-play: generate episodes. +//! for _ in 0..config.n_games_per_iter { +//! let samples = generate_episode(&env, &eval, &config.mcts, +//! 
&|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); +//! replay.extend(samples); +//! } +//! +//! // Training: gradient steps. +//! if replay.len() >= config.batch_size { +//! for _ in 0..config.n_train_steps_per_iter { +//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) +//! .into_iter().cloned().collect(); +//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, +//! config.learning_rate); +//! train_model = m; +//! } +//! } +//! } +//! ``` + +pub mod replay; +pub mod selfplay; +pub mod trainer; + +pub use replay::{ReplayBuffer, TrainSample}; +pub use selfplay::{BurnEvaluator, generate_episode}; +pub use trainer::{cosine_lr, train_step}; + +use crate::mcts::MctsConfig; + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Top-level AlphaZero hyperparameters. +/// +/// The MCTS parameters live in [`MctsConfig`]; this struct holds the +/// outer training-loop parameters. +#[derive(Debug, Clone)] +pub struct AlphaZeroConfig { + /// MCTS parameters for self-play. + pub mcts: MctsConfig, + /// Number of self-play games per training iteration. + pub n_games_per_iter: usize, + /// Number of gradient steps per training iteration. + pub n_train_steps_per_iter: usize, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Initial (peak) Adam learning rate. + pub learning_rate: f64, + /// Minimum learning rate for cosine annealing (floor of the schedule). + /// + /// Pass `learning_rate == lr_min` to disable scheduling (constant LR). + /// Compute the current LR with [`cosine_lr`]: + /// + /// ```rust,ignore + /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps); + /// ``` + pub lr_min: f64, + /// Number of outer iterations (self-play + train) to run. + pub n_iterations: usize, + /// Move index after which the action temperature drops to 0 (greedy play). 
+    pub temperature_drop_move: usize,
+}
+
+impl Default for AlphaZeroConfig {
+    fn default() -> Self {
+        Self {
+            mcts: MctsConfig {
+                n_simulations: 100,
+                c_puct: 1.5,
+                dirichlet_alpha: 0.1,
+                dirichlet_eps: 0.25,
+                temperature: 1.0,
+            },
+            n_games_per_iter: 10,
+            n_train_steps_per_iter: 20,
+            batch_size: 64,
+            replay_capacity: 50_000,
+            learning_rate: 1e-3,
+            lr_min: 1e-4, // cosine annealing floor
+            n_iterations: 100,
+            temperature_drop_move: 30,
+        }
+    }
+}
diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs
new file mode 100644
index 0000000..5e64cc4
--- /dev/null
+++ b/spiel_bot/src/alphazero/replay.rs
@@ -0,0 +1,144 @@
+//! Replay buffer for AlphaZero self-play data.
+
+use std::collections::VecDeque;
+use rand::Rng;
+
+// ── Training sample ────────────────────────────────────────────────────────
+
+/// One training example produced by self-play.
+#[derive(Clone, Debug)]
+pub struct TrainSample {
+    /// Observation tensor from the acting player's perspective (`obs_size` floats).
+    pub obs: Vec<f32>,
+    /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
+    pub policy: Vec<f32>,
+    /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
+    pub value: f32,
+}
+
+// ── Replay buffer ────────────────────────────────────────────────────────────
+
+/// Fixed-capacity circular buffer of [`TrainSample`]s.
+///
+/// When the buffer is full, the oldest sample is evicted on push.
+/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
+pub struct ReplayBuffer {
+    data: VecDeque<TrainSample>,
+    capacity: usize,
+}
+
+impl ReplayBuffer {
+    /// Create a buffer with the given maximum capacity.
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            data: VecDeque::with_capacity(capacity.min(1024)),
+            capacity,
+        }
+    }
+
+    /// Add a sample; evicts the oldest if at capacity.
+ pub fn push(&mut self, sample: TrainSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + /// Add all samples from an episode. + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { + self.push(s); + } + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Sample up to `n` distinct samples, without replacement. + /// + /// If the buffer has fewer than `n` samples, all are returned (shuffled). + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + let len = self.data.len(); + let n = n.min(len); + // Partial Fisher-Yates using index shuffling. + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(value: f32) -> TrainSample { + TrainSample { obs: vec![value], policy: vec![1.0], value } + } + + #[test] + fn push_and_len() { + let mut buf = ReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = ReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); // evicts 1.0 + assert_eq!(buf.len(), 3); + // Oldest remaining should be 2.0 + assert_eq!(buf.data[0].value, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(5, &mut rng); + assert_eq!(batch.len(), 5); + } + + #[test] + fn sample_batch_capped_at_len() { + let mut 
buf = ReplayBuffer::new(20); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(100, &mut rng); + assert_eq!(batch.len(), 2); + } + + #[test] + fn sample_batch_no_duplicates() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(1); + let batch = buf.sample_batch(10, &mut rng); + let mut seen: Vec = batch.iter().map(|s| s.value).collect(); + seen.sort_by(f32::total_cmp); + seen.dedup(); + assert_eq!(seen.len(), 10, "sample contained duplicates"); + } +} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs new file mode 100644 index 0000000..b38b7f4 --- /dev/null +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -0,0 +1,238 @@ +//! Self-play episode generation and Burn-backed evaluator. + +use std::marker::PhantomData; + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::GameEnv; +use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; +use crate::network::PolicyValueNet; + +use super::replay::TrainSample; + +// ── BurnEvaluator ────────────────────────────────────────────────────────── + +/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. +/// +/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so +/// that self-play generates no gradient tape overhead. +pub struct BurnEvaluator> { + model: N, + device: B::Device, + _b: PhantomData, +} + +impl> BurnEvaluator { + pub fn new(model: N, device: B::Device) -> Self { + Self { model, device, _b: PhantomData } + } + + pub fn into_model(self) -> N { + self.model + } + + pub fn model_ref(&self) -> &N { + &self.model + } +} + +// Safety: NdArray modules are Send; we never share across threads without +// external synchronisation. 
+unsafe impl<B: Backend, N: PolicyValueNet<B>> Send for BurnEvaluator<B, N> {}
+unsafe impl<B: Backend, N: PolicyValueNet<B>> Sync for BurnEvaluator<B, N> {}
+
+impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
+    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
+        let obs_size = obs.len();
+        let data = TensorData::new(obs.to_vec(), [1, obs_size]);
+        let obs_tensor = Tensor::<B, 2>::from_data(data, &self.device);
+
+        let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
+
+        let policy: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
+        let value: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
+
+        (policy, value[0])
+    }
+}
+
+// ── Episode generation ─────────────────────────────────────────────────────
+
+/// One pending observation waiting for its game-outcome value label.
+struct PendingSample {
+    obs: Vec<f32>,
+    policy: Vec<f32>,
+    player: usize,
+}
+
+/// Play one full game using MCTS guided by `evaluator`.
+///
+/// Returns a [`TrainSample`] for every decision step in the game.
+///
+/// `temperature_fn(step)` controls exploration: return `1.0` for early
+/// moves and `0.0` after a fixed number of moves (e.g. move 30).
+pub fn generate_episode<E: GameEnv>(
+    env: &E,
+    evaluator: &dyn Evaluator,
+    mcts_config: &MctsConfig,
+    temperature_fn: &dyn Fn(usize) -> f32,
+    rng: &mut impl Rng,
+) -> Vec<TrainSample> {
+    let mut state = env.new_game();
+    let mut pending: Vec<PendingSample> = Vec::new();
+    let mut step = 0usize;
+
+    loop {
+        // Advance through chance nodes automatically.
+        while env.current_player(&state).is_chance() {
+            env.apply_chance(&mut state, rng);
+        }
+
+        if env.current_player(&state).is_terminal() {
+            break;
+        }
+
+        let player_idx = env.current_player(&state).index().unwrap();
+
+        // Run MCTS to get a policy.
+        let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
+        let policy = mcts::mcts_policy(&root, env.action_space());
+
+        // Record the observation from the acting player's perspective.
+ let obs = env.observation(&state, player_idx); + pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); + + // Select and apply the action. + let temperature = temperature_fn(step); + let action = mcts::select_action(&root, temperature, rng); + env.apply(&mut state, action); + step += 1; + } + + // Assign game outcomes. + let returns = env.returns(&state).unwrap_or([0.0; 2]); + pending + .into_iter() + .map(|s| TrainSample { + obs: s.obs, + policy: s.policy, + value: returns[s.player], + }) + .collect() +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::env::Player; + use crate::mcts::{Evaluator, MctsConfig}; + use crate::network::{MlpConfig, MlpNet}; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(7) + } + + // Countdown game (same as in mcts tests). 
+ #[derive(Clone, Debug)] + struct CState { remaining: u8, to_move: usize } + + #[derive(Clone)] + struct CountdownEnv; + + impl GameEnv for CountdownEnv { + type State = CState; + fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } else { Player::P2 } + } + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { s.remaining = 0; } + else { s.remaining -= sub; s.to_move = 1 - s.to_move; } + } + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 4.0, s.to_move as f32] + } + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + fn tiny_config() -> MctsConfig { + MctsConfig { n_simulations: 5, c_puct: 1.5, + dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } + } + + // ── BurnEvaluator tests ─────────────────────────────────────────────── + + #[test] + fn burn_evaluator_output_shapes() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); + assert_eq!(policy.len(), 2, "policy length should equal action_space"); + assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); + } + + // ── generate_episode tests ──────────────────────────────────────────── + + #[test] + fn episode_terminates_and_has_samples() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, 
hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn episode_sample_values_are_valid() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + for s in &samples { + assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "unexpected value {}", s.value); + let sum: f32 = s.policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); + assert_eq!(s.obs.len(), 2); + } + } + + #[test] + fn episode_with_temperature_zero() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + // temperature=0 means greedy; episode must still terminate + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); + assert!(!samples.is_empty()); + } +} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs new file mode 100644 index 0000000..9075519 --- /dev/null +++ b/spiel_bot/src/alphazero/trainer.rs @@ -0,0 +1,258 @@ +//! One gradient-descent training step for AlphaZero. +//! +//! The loss combines: +//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. +//! - **Value loss** — mean-squared error between the predicted value and the +//! actual game outcome. +//! +//! # Learning-rate scheduling +//! +//! [`cosine_lr`] implements one-cycle cosine annealing: +//! +//! ```text +//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T)) +//! ``` +//! +//! 
Typical usage in the outer loop: +//! +//! ```rust,ignore +//! for step in 0..total_train_steps { +//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps); +//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr); +//! model = m; +//! } +//! ``` +//! +//! # Backend +//! +//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). +//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. +//! Weights are transferred between the two via [`burn::record`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + activation::log_softmax, + backend::AutodiffBackend, + Tensor, TensorData, + }, +}; + +use crate::network::PolicyValueNet; +use super::replay::TrainSample; + +/// Run one gradient step on `model` using `batch`. +/// +/// Returns the updated model and the scalar loss value for logging. +/// +/// # Parameters +/// +/// - `lr` — learning rate (e.g. `1e-3`). +/// - `batch` — slice of [`TrainSample`]s; must be non-empty. 
+pub fn train_step<B, N, O>(
+    model: N,
+    optimizer: &mut O,
+    batch: &[TrainSample],
+    device: &B::Device,
+    lr: f64,
+) -> (N, f32)
+where
+    B: AutodiffBackend,
+    N: PolicyValueNet<B> + AutodiffModule<B>,
+    O: Optimizer<N, B>,
+{
+    assert!(!batch.is_empty(), "train_step called with empty batch");
+
+    let batch_size = batch.len();
+    let obs_size = batch[0].obs.len();
+    let action_size = batch[0].policy.len();
+
+    // ── Build input tensors ────────────────────────────────────────────────
+    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
+    let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
+    let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
+
+    let obs_tensor = Tensor::<B, 2>::from_data(
+        TensorData::new(obs_flat, [batch_size, obs_size]),
+        device,
+    );
+    let policy_target = Tensor::<B, 2>::from_data(
+        TensorData::new(policy_flat, [batch_size, action_size]),
+        device,
+    );
+    let value_target = Tensor::<B, 2>::from_data(
+        TensorData::new(value_flat, [batch_size, 1]),
+        device,
+    );
+
+    // ── Forward pass ────────────────────────────────────────────────────────
+    let (policy_logits, value_pred) = model.forward(obs_tensor);
+
+    // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
+    let log_probs = log_softmax(policy_logits, 1);
+    let policy_loss = (policy_target.clone().neg() * log_probs)
+        .sum_dim(1)
+        .mean();
+
+    // ── Value loss: MSE(value_pred, z) ────────────────────────────────────
+    let diff = value_pred - value_target;
+    let value_loss = (diff.clone() * diff).mean();
+
+    // ── Combined loss ───────────────────────────────────────────────────────
+    let loss = policy_loss + value_loss;
+
+    // Extract scalar before backward (consumes the tensor).
+ let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + let model = optimizer.step(lr, model, grads); + + (model, loss_scalar) +} + +// ── Learning-rate schedule ───────────────────────────────────────────────── + +/// Cosine learning-rate schedule (one half-period, no warmup). +/// +/// Returns the learning rate for training step `step` out of `total_steps`: +/// +/// ```text +/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total)) +/// ``` +/// +/// - At `t = 0` returns `initial`. +/// - At `t = total_steps` (or beyond) returns `lr_min`. +/// +/// # Panics +/// +/// Does not panic. When `total_steps == 0`, returns `lr_min`. +pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 { + if total_steps == 0 || step >= total_steps { + return lr_min; + } + let progress = step as f64 / total_steps as f64; + lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos()) +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + + use crate::network::{MlpConfig, MlpNet}; + use super::super::replay::TrainSample; + + type B = Autodiff>; + + fn device() -> ::Device { + Default::default() + } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5f32; obs_size], + policy: { + let mut p = vec![0.0f32; action_size]; + p[i % action_size] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + } + + #[test] + fn train_step_returns_finite_loss() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = 
AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); + } + + #[test] + fn loss_decreases_over_steps() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + // Same batch every step — loss should decrease. + let batch = dummy_batch(16, 4, 4); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); + model = m; + assert!(loss.is_finite()); + prev_loss = loss; + } + // After 10 steps on fixed data, loss should be below a reasonable threshold. + assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn train_step_batch_size_one() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(1, 2, 2); + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite()); + } + + // ── cosine_lr ───────────────────────────────────────────────────────── + + #[test] + fn cosine_lr_at_step_zero_is_initial() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 100); + assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}"); + } + + #[test] + fn cosine_lr_at_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 100, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}"); + } + + #[test] + fn cosine_lr_beyond_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 200, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}"); + } + + #[test] + fn cosine_lr_midpoint_is_average() { + // At t = total/2, cos(π/2) = 0, so lr = (initial + 
min) / 2. + let lr = super::cosine_lr(1e-3, 1e-5, 50, 100); + let expected = (1e-3 + 1e-5) / 2.0; + assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}"); + } + + #[test] + fn cosine_lr_monotone_decreasing() { + let mut prev = f64::INFINITY; + for step in 0..=100 { + let lr = super::cosine_lr(1e-3, 1e-5, step, 100); + assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}"); + prev = lr; + } + } + + #[test] + fn cosine_lr_zero_total_steps_returns_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 0); + assert!((lr - 1e-5).abs() < 1e-10); + } +} diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs new file mode 100644 index 0000000..3c82519 --- /dev/null +++ b/spiel_bot/src/bin/az_eval.rs @@ -0,0 +1,262 @@ +//! Evaluate a trained AlphaZero checkpoint against a random player. +//! +//! # Usage +//! +//! ```sh +//! # Random weights (sanity check — should be ~50 %) +//! cargo run -p spiel_bot --bin az_eval --release +//! +//! # Trained MLP checkpoint +//! cargo run -p spiel_bot --bin az_eval --release -- \ +//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50 +//! +//! # Trained ResNet checkpoint +//! cargo run -p spiel_bot --bin az_eval --release -- \ +//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--checkpoint ` | (none) | Load weights from `.mpk` file; random weights if omitted | +//! | `--arch mlp\|resnet` | `mlp` | Network architecture | +//! | `--hidden ` | 256 (mlp) / 512 (resnet) | Hidden size | +//! | `--n-games ` | `100` | Games per side (total = 2 × N) | +//! | `--n-sim ` | `50` | MCTS simulations per move | +//! | `--seed ` | `42` | RNG seed | +//! 
| `--c-puct ` | `1.5` | PUCT exploration constant | + +use std::path::PathBuf; + +use burn::backend::NdArray; +use rand::{SeedableRng, rngs::SmallRng, Rng}; + +use spiel_bot::{ + alphazero::BurnEvaluator, + env::{GameEnv, Player, TrictracEnv}, + mcts::{Evaluator, MctsConfig, run_mcts, select_action}, + network::{MlpConfig, MlpNet, ResNet, ResNetConfig}, +}; + +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + checkpoint: Option, + arch: String, + hidden: Option, + n_games: usize, + n_sim: usize, + seed: u64, + c_puct: f32, +} + +impl Default for Args { + fn default() -> Self { + Self { + checkpoint: None, + arch: "mlp".into(), + hidden: None, + n_games: 100, + n_sim: 50, + seed: 42, + c_puct: 1.5, + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut args = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); } + "--arch" => { i += 1; args.arch = raw[i].clone(); } + "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); } + "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); } + "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); } + "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); } + "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + args +} + +// ── Game loop ───────────────────────────────────────────────────────────────── + +/// Play one complete game. +/// +/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black). +/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0). 
+fn play_game( + env: &TrictracEnv, + mcts_side: usize, + evaluator: &dyn Evaluator, + mcts_cfg: &MctsConfig, + rng: &mut SmallRng, +) -> [f32; 2] { + let mut state = env.new_game(); + loop { + match env.current_player(&state) { + Player::Terminal => { + return env.returns(&state).expect("Terminal state must have returns"); + } + Player::Chance => env.apply_chance(&mut state, rng), + player => { + let side = player.index().unwrap(); // 0 = P1, 1 = P2 + let action = if side == mcts_side { + let root = run_mcts(env, &state, evaluator, mcts_cfg, rng); + select_action(&root, 0.0, rng) // greedy (temperature = 0) + } else { + let actions = env.legal_actions(&state); + actions[rng.random_range(0..actions.len())] + }; + env.apply(&mut state, action); + } + } + } +} + +// ── Statistics ──────────────────────────────────────────────────────────────── + +#[derive(Default)] +struct Stats { + wins: u32, + draws: u32, + losses: u32, +} + +impl Stats { + fn record(&mut self, mcts_return: f32) { + if mcts_return > 0.0 { self.wins += 1; } + else if mcts_return < 0.0 { self.losses += 1; } + else { self.draws += 1; } + } + + fn total(&self) -> u32 { self.wins + self.draws + self.losses } + + fn win_rate_decisive(&self) -> f64 { + let d = self.wins + self.losses; + if d == 0 { 0.5 } else { self.wins as f64 / d as f64 } + } + + fn print(&self) { + let n = self.total(); + let pct = |k: u32| 100.0 * k as f64 / n as f64; + println!( + " Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)", + self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses), + ); + } +} + +// ── Evaluation ──────────────────────────────────────────────────────────────── + +fn run_evaluation( + evaluator: &dyn Evaluator, + n_games: usize, + mcts_cfg: &MctsConfig, + seed: u64, +) -> (Stats, Stats) { + let env = TrictracEnv; + let total = n_games * 2; + let mut as_p1 = Stats::default(); + let mut as_p2 = Stats::default(); + + for i in 0..total { + // Alternate sides: even games 
→ MctsAgent as P1, odd → as P2. + let mcts_side = i % 2; + let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64)); + let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng); + + let mcts_return = result[mcts_side]; + if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); } + + let done = i + 1; + if done % 10 == 0 || done == total { + eprint!("\r [{done}/{total}] ", ); + } + } + eprintln!(); + (as_p1, as_p2) +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + let device: ::Device = Default::default(); + + // ── Load model ──────────────────────────────────────────────────────── + let evaluator: Box = match args.arch.as_str() { + "resnet" => { + let hidden = args.hidden.unwrap_or(512); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = match &args.checkpoint { + Some(path) => ResNet::::load(&cfg, path, &device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), + None => ResNet::new(&cfg, &device), + }; + Box::new(BurnEvaluator::>::new(model, device)) + } + "mlp" | _ => { + let hidden = args.hidden.unwrap_or(256); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = match &args.checkpoint { + Some(path) => MlpNet::::load(&cfg, path, &device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), + None => MlpNet::new(&cfg, &device), + }; + Box::new(BurnEvaluator::>::new(model, device)) + } + }; + + let mcts_cfg = MctsConfig { + n_simulations: args.n_sim, + c_puct: args.c_puct, + dirichlet_alpha: 0.0, // no exploration noise during evaluation + dirichlet_eps: 0.0, + temperature: 0.0, // greedy action selection + }; + + // ── Header ──────────────────────────────────────────────────────────── + let ckpt_label = args.checkpoint + .as_deref() + .and_then(|p| p.file_name()) + .and_then(|n| 
n.to_str()) + .unwrap_or("random weights"); + + println!(); + println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent", + args.arch, args.n_sim); + println!("Games per side: {} | Total: {} | Seed: {}", + args.n_games, args.n_games * 2, args.seed); + println!(); + + // ── Run ─────────────────────────────────────────────────────────────── + let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed); + + // ── Results ─────────────────────────────────────────────────────────── + println!("MctsAgent as P1 (White):"); + as_p1.print(); + + println!("MctsAgent as P2 (Black):"); + as_p2.print(); + + let combined_wins = as_p1.wins + as_p2.wins; + let combined_decisive = combined_wins + as_p1.losses + as_p2.losses; + let combined_wr = if combined_decisive == 0 { 0.5 } + else { combined_wins as f64 / combined_decisive as f64 }; + + println!(); + println!("Combined win rate (excluding draws): {:.1}% [{}/{}]", + combined_wr * 100.0, combined_wins, combined_decisive); + println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%", + as_p1.win_rate_decisive() * 100.0, + as_p2.win_rate_decisive() * 100.0); +} diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs new file mode 100644 index 0000000..824abe5 --- /dev/null +++ b/spiel_bot/src/bin/az_train.rs @@ -0,0 +1,331 @@ +//! AlphaZero self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh (MLP, default settings) +//! cargo run -p spiel_bot --bin az_train --release +//! +//! # ResNet, 200 iterations, save every 20 +//! cargo run -p spiel_bot --bin az_train --release -- \ +//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/ +//! +//! # Resume from a checkpoint +//! cargo run -p spiel_bot --bin az_train --release -- \ +//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! 
| `--arch mlp\|resnet` | `mlp` | Network architecture | +//! | `--hidden N` | 256/512 | Hidden layer width | +//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | +//! | `--n-iter N` | `100` | Training iterations | +//! | `--n-games N` | `10` | Self-play games per iteration | +//! | `--n-train N` | `20` | Gradient steps per iteration | +//! | `--n-sim N` | `100` | MCTS simulations per move | +//! | `--batch N` | `64` | Mini-batch size | +//! | `--replay-cap N` | `50000` | Replay buffer capacity | +//! | `--lr F` | `1e-3` | Peak (initial) learning rate | +//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) | +//! | `--c-puct F` | `1.5` | PUCT exploration constant | +//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha | +//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight | +//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 | +//! | `--save-every N` | `10` | Save checkpoint every N iterations | +//! | `--seed N` | `42` | RNG seed | +//! 
| `--resume PATH` | (none) | Load weights from checkpoint before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; +use rayon::prelude::*; + +use spiel_bot::{ + alphazero::{ + BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step, + }, + env::TrictracEnv, + mcts::MctsConfig, + network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + arch: String, + hidden: Option, + out_dir: PathBuf, + n_iter: usize, + n_games: usize, + n_train: usize, + n_sim: usize, + batch_size: usize, + replay_cap: usize, + lr: f64, + lr_min: f64, + c_puct: f32, + dirichlet_alpha: f32, + dirichlet_eps: f32, + temp_drop: usize, + save_every: usize, + seed: u64, + resume: Option, +} + +impl Default for Args { + fn default() -> Self { + Self { + arch: "mlp".into(), + hidden: None, + out_dir: PathBuf::from("checkpoints"), + n_iter: 100, + n_games: 10, + n_train: 20, + n_sim: 100, + batch_size: 64, + replay_cap: 50_000, + lr: 1e-3, + lr_min: 1e-4, + c_puct: 1.5, + dirichlet_alpha: 0.1, + dirichlet_eps: 0.25, + temp_drop: 30, + save_every: 10, + seed: 42, + resume: None, + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--arch" => { i += 1; a.arch = raw[i].clone(); } + "--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); } + 
"--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); } + "--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); } + "--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); } + "--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); } + "--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); } + "--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); } + "--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); } + "--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +/// Generic training loop, parameterised over the network type. +/// +/// `save_fn` receives the **training-backend** model and the target path; +/// it is called in the match arm where the concrete network type is known. +fn train_loop( + mut model: N, + save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>, + args: &Args, +) +where + N: PolicyValueNet + AutodiffModule + Clone, + >::InnerModule: PolicyValueNet + Send + 'static, +{ + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + // Type is inferred as OptimizerAdaptor at the call site. 
+ let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(args.replay_cap); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + // Total gradient steps (used for cosine LR denominator). + let total_train_steps = (args.n_iter * args.n_train).max(1); + let mut global_step = 0usize; + + println!( + "\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}", + "", args.arch, args.n_iter, args.n_games, args.n_sim, "" + ); + + for iter in 0..args.n_iter { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + // Convert to inference backend (zero autodiff overhead). + let infer_model: >::InnerModule = model.valid(); + let evaluator: BurnEvaluator>::InnerModule> = + BurnEvaluator::new(infer_model, infer_device.clone()); + + let mcts_cfg = MctsConfig { + n_simulations: args.n_sim, + c_puct: args.c_puct, + dirichlet_alpha: args.dirichlet_alpha, + dirichlet_eps: args.dirichlet_eps, + temperature: 1.0, + }; + + let temp_drop = args.temp_drop; + let temperature_fn = |step: usize| -> f32 { + if step < temp_drop { 1.0 } else { 0.0 } + }; + + // Prepare per-game seeds and evaluators sequentially so the main RNG + // and model cloning stay deterministic regardless of thread scheduling. + // Burn modules are Send but not Sync, so each task must own its model. 
+ let game_seeds: Vec = (0..args.n_games).map(|_| rng.random()).collect(); + let game_evals: Vec<_> = (0..args.n_games) + .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone())) + .collect(); + drop(evaluator); + + let all_samples: Vec> = game_seeds + .into_par_iter() + .zip(game_evals.into_par_iter()) + .map(|(seed, game_eval)| { + let mut game_rng = SmallRng::seed_from_u64(seed); + generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng) + }) + .collect(); + + let mut new_samples = 0usize; + for samples in all_samples { + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.batch_size { + for _ in 0..args.n_train { + let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); + let batch: Vec = replay + .sample_batch(args.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = + train_step(model, &mut optimizer, &batch, &train_device, lr); + model = m; + loss_sum += loss; + n_steps += 1; + global_step += 1; + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s", + iter + 1, + args.n_iter, + replay.len(), + new_samples, + avg_loss, + lr_now, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.n_iter; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1)); + match save_fn(&model, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" 
Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nTraining complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + // Create output directory if it doesn't exist. + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + + match args.arch.as_str() { + "resnet" => { + let hidden = args.hidden.unwrap_or(512); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + + let model = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + ResNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => ResNet::::new(&cfg, &train_device), + }; + + train_loop( + model, + &|m: &ResNet, path: &Path| { + // Save via inference model to avoid autodiff record overhead. + m.valid().save(path) + }, + &args, + ); + } + + "mlp" | _ => { + let hidden = args.hidden.unwrap_or(256); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + + let model = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + MlpNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => MlpNet::::new(&cfg, &train_device), + }; + + train_loop( + model, + &|m: &MlpNet, path: &Path| m.valid().save(path), + &args, + ); + } + } +} diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs new file mode 100644 index 0000000..0ebe978 --- /dev/null +++ b/spiel_bot/src/bin/dqn_train.rs @@ -0,0 +1,251 @@ +//! DQN self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh with default settings +//! cargo run -p spiel_bot --bin dqn_train --release +//! 
+//! # Custom hyperparameters
+//! cargo run -p spiel_bot --bin dqn_train --release -- \
+//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000
+//!
+//! # Resume from a checkpoint
+//! cargo run -p spiel_bot --bin dqn_train --release -- \
+//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100
+//! ```
+//!
+//! # Options
+//!
+//! | Flag | Default | Description |
+//! |------|---------|-------------|
+//! | `--hidden N` | 256 | Hidden layer width |
+//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
+//! | `--n-iter N` | 500 | Training iterations |
+//! | `--n-games N` | 10 | Self-play games per iteration |
+//! | `--n-train N` | 20 | Gradient steps per iteration |
+//! | `--batch N` | 64 | Mini-batch size |
+//! | `--replay-cap N` | 50000 | Replay buffer capacity |
+//! | `--lr F` | 1e-3 | Adam learning rate |
+//! | `--epsilon-start F` | 1.0 | Initial exploration rate |
+//! | `--epsilon-end F` | 0.05 | Final exploration rate |
+//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor |
+//! | `--gamma F` | 0.99 | Discount factor |
+//! | `--target-update N` | 100 | Hard-update target net every N steps |
+//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) |
+//! | `--save-every N` | 10 | Save checkpoint every N iterations |
+//! | `--seed N` | 42 | RNG seed |
+//! 
| `--resume PATH` | (none) | Load weights before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + dqn::{ + DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, + generate_dqn_episode, hard_update, linear_epsilon, + }, + env::TrictracEnv, + network::{QNet, QNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + hidden: usize, + out_dir: PathBuf, + save_every: usize, + seed: u64, + resume: Option, + config: DqnConfig, +} + +impl Default for Args { + fn default() -> Self { + Self { + hidden: 256, + out_dir: PathBuf::from("checkpoints"), + save_every: 10, + seed: 42, + resume: None, + config: DqnConfig::default(), + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); } + "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } + "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } + "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } + 
"--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } + "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } + "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } + "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } + "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +fn train_loop( + mut q_net: QNet, + cfg: &QNetConfig, + save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, + args: &Args, +) { + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + let mut optimizer = AdamConfig::new().init(); + let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + let mut target_net: QNet = hard_update::(&q_net); + let mut global_step = 0usize; + let mut epsilon = args.config.epsilon_start; + + println!( + "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", + "", args.config.n_iterations, args.config.n_games_per_iter, + args.config.n_train_steps_per_iter, "" + ); + + for iter in 0..args.config.n_iterations { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + let infer_q: QNet = q_net.valid(); + let mut new_samples = 0usize; + + for _ in 0..args.config.n_games_per_iter { + let 
samples = generate_dqn_episode( + &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, + ); + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.config.batch_size { + for _ in 0..args.config.n_train_steps_per_iter { + let batch: Vec<_> = replay + .sample_batch(args.config.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + // Target Q-values computed on the inference backend. + let target_q = compute_target_q( + &target_net, &batch, cfg.action_size, &infer_device, + ); + + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &target_q, + &train_device, args.config.learning_rate, args.config.gamma, + ); + q_net = q; + loss_sum += loss; + n_steps += 1; + global_step += 1; + + // Hard-update target net every target_update_freq steps. + if global_step % args.config.target_update_freq == 0 { + target_net = hard_update::(&q_net); + } + + // Linear epsilon decay. 
+ epsilon = linear_epsilon( + args.config.epsilon_start, + args.config.epsilon_end, + global_step, + args.config.epsilon_decay_steps, + ); + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", + iter + 1, + args.config.n_iterations, + replay.len(), + new_samples, + avg_loss, + epsilon, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.config.n_iterations; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); + match save_fn(&q_net, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nDQN training complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; + + let q_net = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + QNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => QNet::::new(&cfg, &train_device), + }; + + train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); +} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs new file mode 100644 index 0000000..aca1343 --- /dev/null +++ b/spiel_bot/src/dqn/episode.rs @@ -0,0 +1,247 @@ +//! 
DQN self-play episode generation. +//! +//! Both players share the same Q-network (the [`TrictracEnv`] handles board +//! mirroring so that each player always acts from "White's perspective"). +//! Transitions for both players are stored in the returned sample list. +//! +//! # Reward +//! +//! After each full decision (action applied and the state has advanced through +//! any intervening chance nodes back to the same player's next turn), the +//! reward is: +//! +//! ```text +//! r = (my_total_score_now − my_total_score_then) +//! − (opp_total_score_now − opp_total_score_then) +//! ``` +//! +//! where `total_score = holes × 12 + points`. +//! +//! # Transition structure +//! +//! We use a "pending transition" per player. When a player acts again, we +//! *complete* the previous pending transition by filling in `next_obs`, +//! `next_legal`, and computing `reward`. Terminal transitions are completed +//! when the game ends. + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::{GameEnv, TrictracEnv}; +use crate::network::QValueNet; +use super::DqnSample; + +// ── Internals ───────────────────────────────────────────────────────────────── + +struct PendingTransition { + obs: Vec, + action: usize, + /// Score snapshot `[p1_total, p2_total]` at the moment of the action. + score_before: [i32; 2], +} + +/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. 
+fn epsilon_greedy>( + q_net: &Q, + obs: &[f32], + legal: &[usize], + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, +) -> usize { + debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); + if rng.random::() < epsilon { + legal[rng.random_range(0..legal.len())] + } else { + let obs_tensor = Tensor::::from_data( + TensorData::new(obs.to_vec(), [1, obs.len()]), + device, + ); + let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); + legal + .iter() + .copied() + .max_by(|&a, &b| { + q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap() + } +} + +/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. +fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { + let opp_idx = 1 - player_idx; + ((score_after[player_idx] - score_before[player_idx]) + - (score_after[opp_idx] - score_before[opp_idx])) as f32 +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Play one full game and return all transitions for both players. +/// +/// - `q_net` uses the **inference backend** (no autodiff wrapper). +/// - `epsilon` in `[0, 1]`: probability of taking a random action. +/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). +pub fn generate_dqn_episode>( + env: &TrictracEnv, + q_net: &Q, + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, + reward_scale: f32, +) -> Vec { + let obs_size = env.obs_size(); + let mut state = env.new_game(); + let mut pending: [Option; 2] = [None, None]; + let mut samples: Vec = Vec::new(); + + loop { + // ── Advance past chance nodes ────────────────────────────────────── + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + let score_now = TrictracEnv::score_snapshot(&state); + + if env.current_player(&state).is_terminal() { + // Complete all pending transitions as terminal. 
+ for player_idx in 0..2 { + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: vec![0.0; obs_size], + next_legal: vec![], + done: true, + }); + } + } + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + let legal = env.legal_actions(&state); + let obs = env.observation(&state, player_idx); + + // ── Complete the previous transition for this player ─────────────── + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: obs.clone(), + next_legal: legal.clone(), + done: false, + }); + } + + // ── Pick and apply action ────────────────────────────────────────── + let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); + env.apply(&mut state, action); + + // ── Record new pending transition ────────────────────────────────── + pending[player_idx] = Some(PendingTransition { + obs, + action, + score_before: score_now, + }); + } + + samples +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::network::{QNet, QNetConfig}; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } + + fn tiny_q() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn episode_terminates_and_produces_samples() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn 
episode_obs_size_correct() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert_eq!(s.obs.len(), 217, "obs size mismatch"); + if s.done { + assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); + assert!(s.next_legal.is_empty()); + } else { + assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); + assert!(!s.next_legal.is_empty()); + } + } + } + + #[test] + fn episode_actions_within_action_space() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert!(s.action < 514, "action {} out of bounds", s.action); + } + } + + #[test] + fn greedy_episode_also_terminates() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty()); + } + + #[test] + fn at_least_one_done_sample() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + let n_done = samples.iter().filter(|s| s.done).count(); + // Two players, so 1 or 2 terminal transitions. + assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); + } + + #[test] + fn compute_reward_correct() { + // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. + let before = [2 * 12 + 10, 0]; + let after = [3 * 12 + 2, 0]; + let r = compute_reward(0, &before, &after); + assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); + } + + #[test] + fn compute_reward_with_opponent_scoring() { + // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. 
+ let before = [0, 0]; + let after = [2, 3]; + let r = compute_reward(0, &before, &after); + assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); + } +} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs new file mode 100644 index 0000000..8c34fc1 --- /dev/null +++ b/spiel_bot/src/dqn/mod.rs @@ -0,0 +1,232 @@ +//! DQN: self-play data generation, replay buffer, and training step. +//! +//! # Algorithm +//! +//! Deep Q-Network with: +//! - **ε-greedy** exploration (linearly decayed). +//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where +//! `score = holes × 12 + points`. +//! - **Experience replay** with a fixed-capacity circular buffer. +//! - **Target network**: hard-copied from the online Q-net every +//! `target_update_freq` gradient steps for training stability. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | +//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | + +pub mod episode; +pub mod trainer; + +pub use episode::generate_dqn_episode; +pub use trainer::{compute_target_q, dqn_train_step, hard_update}; + +use std::collections::VecDeque; +use rand::Rng; + +// ── DqnSample ───────────────────────────────────────────────────────────────── + +/// One transition `(s, a, r, s', done)` collected during self-play. +#[derive(Clone, Debug)] +pub struct DqnSample { + /// Observation from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// Action index taken. + pub action: usize, + /// Per-turn reward: `my_score_delta − opponent_score_delta`. + pub reward: f32, + /// Next observation from the same player's perspective. + /// All-zeros when `done = true` (ignored by the TD target). + pub next_obs: Vec, + /// Legal actions at `next_obs`. Empty when `done = true`. + pub next_legal: Vec, + /// `true` when `next_obs` is a terminal state. 
+ pub done: bool, +} + +// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── + +/// Fixed-capacity circular replay buffer for [`DqnSample`]s. +/// +/// When full, the oldest sample is evicted on push. +/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. +pub struct DqnReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl DqnReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } + } + + pub fn push(&mut self, sample: DqnSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { self.push(s); } + } + + pub fn len(&self) -> usize { self.data.len() } + pub fn is_empty(&self) -> bool { self.data.is_empty() } + + /// Sample up to `n` distinct samples without replacement. + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { + let len = self.data.len(); + let n = n.min(len); + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── DqnConfig ───────────────────────────────────────────────────────────────── + +/// Top-level DQN hyperparameters for the training loop. +#[derive(Debug, Clone)] +pub struct DqnConfig { + /// Initial exploration rate (1.0 = fully random). + pub epsilon_start: f32, + /// Final exploration rate after decay. + pub epsilon_end: f32, + /// Number of gradient steps over which ε decays linearly from start to end. + /// + /// Should be calibrated to the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). A value larger than that + /// means exploration never reaches `epsilon_end` during the run. + pub epsilon_decay_steps: usize, + /// Discount factor γ for the TD target. Typical: 0.99. 
+ pub gamma: f32, + /// Hard-copy Q → target every this many gradient steps. + /// + /// Should be much smaller than the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). + pub target_update_freq: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Number of outer iterations (self-play + train). + pub n_iterations: usize, + /// Self-play games per iteration. + pub n_games_per_iter: usize, + /// Gradient steps per iteration. + pub n_train_steps_per_iter: usize, + /// Reward normalisation divisor. + /// + /// Per-turn rewards (score delta) are divided by this constant before being + /// stored. Without normalisation, rewards can reach ±24 (jan with + /// bredouille = 12 pts × 2), driving Q-values into the hundreds and + /// causing MSE loss to grow unboundedly. + /// + /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping + /// Q-value magnitudes in a stable range. Set to `1.0` to disable. + pub reward_scale: f32, +} + +impl Default for DqnConfig { + fn default() -> Self { + // Total gradient steps with these defaults = 500 × 20 = 10_000, + // so epsilon decays fully and the target is updated 100 times. + Self { + epsilon_start: 1.0, + epsilon_end: 0.05, + epsilon_decay_steps: 10_000, + gamma: 0.99, + target_update_freq: 100, + learning_rate: 1e-3, + batch_size: 64, + replay_capacity: 50_000, + n_iterations: 500, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + reward_scale: 12.0, + } + } +} + +/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. 
+pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { + if decay_steps == 0 || step >= decay_steps { + return end; + } + start + (end - start) * (step as f32 / decay_steps as f32) +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(reward: f32) -> DqnSample { + DqnSample { + obs: vec![0.0], + action: 0, + reward, + next_obs: vec![0.0], + next_legal: vec![0], + done: false, + } + } + + #[test] + fn push_and_len() { + let mut buf = DqnReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = DqnReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); + assert_eq!(buf.len(), 3); + assert_eq!(buf.data[0].reward, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = DqnReplayBuffer::new(20); + for i in 0..10 { buf.push(dummy(i as f32)); } + let mut rng = SmallRng::seed_from_u64(0); + assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); + } + + #[test] + fn linear_epsilon_start() { + assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_end() { + assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_monotone() { + let mut prev = f32::INFINITY; + for step in 0..=100 { + let e = linear_epsilon(1.0, 0.05, step, 100); + assert!(e <= prev + 1e-6); + prev = e; + } + } +} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs new file mode 100644 index 0000000..b8b0a02 --- /dev/null +++ b/spiel_bot/src/dqn/trainer.rs @@ -0,0 +1,278 @@ +//! DQN gradient step and target-network management. +//! +//! # TD target +//! +//! ```text +//! 
y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done +//! y_i = r_i if done +//! ``` +//! +//! # Loss +//! +//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net) +//! and `y_i` (computed from the frozen target net). +//! +//! # Target network +//! +//! [`hard_update`] copies the online Q-net weights into the target net by +//! stripping the autodiff wrapper via [`AutodiffModule::valid`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + Int, Tensor, TensorData, + backend::{AutodiffBackend, Backend}, + }, +}; + +use crate::network::QValueNet; +use super::DqnSample; + +// ── Target Q computation ───────────────────────────────────────────────────── + +/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample. +/// +/// Returns a `Vec` of length `batch.len()`. Done samples get `0.0` +/// (their bootstrap term is dropped by the TD target anyway). +/// +/// The target network runs on the **inference backend** (`InferB`) with no +/// gradient tape, so this function is backend-agnostic (`B: Backend`). +pub fn compute_target_q>( + target_net: &Q, + batch: &[DqnSample], + action_size: usize, + device: &B::Device, +) -> Vec { + let batch_size = batch.len(); + + // Collect indices of non-done samples (done samples have no next state). + let non_done: Vec = batch + .iter() + .enumerate() + .filter(|(_, s)| !s.done) + .map(|(i, _)| i) + .collect(); + + if non_done.is_empty() { + return vec![0.0; batch_size]; + } + + let obs_size = batch[0].next_obs.len(); + let nd = non_done.len(); + + // Stack next observations for non-done samples → [nd, obs_size]. + let obs_flat: Vec = non_done + .iter() + .flat_map(|&i| batch[i].next_obs.iter().copied()) + .collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [nd, obs_size]), + device, + ); + + // Forward target net → [nd, action_size], then to Vec. 
+ let q_flat: Vec = target_net.forward(obs_tensor).into_data().to_vec().unwrap(); + + // For each non-done sample, pick max Q over legal next actions. + let mut result = vec![0.0f32; batch_size]; + for (k, &i) in non_done.iter().enumerate() { + let legal = &batch[i].next_legal; + let offset = k * action_size; + let max_q = legal + .iter() + .map(|&a| q_flat[offset + a]) + .fold(f32::NEG_INFINITY, f32::max); + // If legal is empty (shouldn't happen for non-done, but be safe): + result[i] = if max_q.is_finite() { max_q } else { 0.0 }; + } + result +} + +// ── Training step ───────────────────────────────────────────────────────────── + +/// Run one gradient step on `q_net` using `batch`. +/// +/// `target_max_q` must be pre-computed via [`compute_target_q`] using the +/// frozen target network and passed in here so that this function only +/// needs the **autodiff backend**. +/// +/// Returns the updated network and the scalar MSE loss. +pub fn dqn_train_step( + q_net: Q, + optimizer: &mut O, + batch: &[DqnSample], + target_max_q: &[f32], + device: &B::Device, + lr: f64, + gamma: f32, +) -> (Q, f32) +where + B: AutodiffBackend, + Q: QValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "dqn_train_step: empty batch"); + assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + + // ── Build observation tensor [B, obs_size] ──────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + + // ── Forward Q-net → [B, action_size] ───────────────────────────────── + let q_all = q_net.forward(obs_tensor); + + // ── Gather Q(s, a) for the taken action → [B] ──────────────────────── + let actions: Vec = batch.iter().map(|s| s.action as i32).collect(); + let action_tensor: Tensor = 
Tensor::::from_data( + TensorData::new(actions, [batch_size]), + device, + ) + .reshape([batch_size, 1]); // [B] → [B, 1] + let q_pred: Tensor = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B] + + // ── TD targets: r + γ · max_next_q · (1 − done) ────────────────────── + let targets: Vec = batch + .iter() + .zip(target_max_q.iter()) + .map(|(s, &max_q)| { + if s.done { s.reward } else { s.reward + gamma * max_q } + }) + .collect(); + let target_tensor = Tensor::::from_data( + TensorData::new(targets, [batch_size]), + device, + ); + + // ── MSE loss ────────────────────────────────────────────────────────── + let diff = q_pred - target_tensor.detach(); + let loss = (diff.clone() * diff).mean(); + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &q_net); + let q_net = optimizer.step(lr, q_net, grads); + + (q_net, loss_scalar) +} + +// ── Target network update ───────────────────────────────────────────────────── + +/// Hard-copy the online Q-net weights to a new target network. +/// +/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an +/// inference-backend module with identical weights. 
+pub fn hard_update>(q_net: &Q) -> Q::InnerModule { + q_net.valid() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + use crate::network::{QNet, QNetConfig}; + + type InferB = NdArray; + type TrainB = Autodiff>; + + fn infer_device() -> ::Device { Default::default() } + fn train_device() -> ::Device { Default::default() } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| DqnSample { + obs: vec![0.5f32; obs_size], + action: i % action_size, + reward: if i % 2 == 0 { 1.0 } else { -1.0 }, + next_obs: vec![0.5f32; obs_size], + next_legal: vec![0, 1], + done: i == n - 1, + }) + .collect() + } + + #[test] + fn compute_target_q_length() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 8); + } + + #[test] + fn compute_target_q_done_is_zero() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + // Single done sample. 
+ let batch = vec![DqnSample { + obs: vec![0.0; 4], + action: 0, + reward: 5.0, + next_obs: vec![0.0; 4], + next_legal: vec![], + done: true, + }]; + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 1); + assert_eq!(tq[0], 0.0); + } + + #[test] + fn train_step_returns_finite_loss() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + + #[test] + fn train_step_loss_decreases() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(16, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, + ); + q_net = q; + assert!(loss.is_finite()); + prev_loss = loss; + } + assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn hard_update_copies_weights() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = hard_update::(&q_net); + + let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); + let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); + // After hard_update the target produces finite outputs. 
+ assert!(q_out.iter().all(|v| v.is_finite())); + } +} diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs new file mode 100644 index 0000000..42b4ae0 --- /dev/null +++ b/spiel_bot/src/env/mod.rs @@ -0,0 +1,121 @@ +//! Game environment abstraction — the minimal "Rust OpenSpiel". +//! +//! A `GameEnv` describes the rules of a two-player, zero-sum game that may +//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, +//! and PPO interact with a game exclusively through this trait. +//! +//! # Node taxonomy +//! +//! Every game position belongs to one of four categories, returned by +//! [`GameEnv::current_player`]: +//! +//! | [`Player`] | Meaning | +//! |-----------|---------| +//! | `P1` | Player 1 (index 0) must choose an action | +//! | `P2` | Player 2 (index 1) must choose an action | +//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | +//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | +//! +//! # Perspective convention +//! +//! [`GameEnv::observation`] always returns the board from *the requested +//! player's* point of view. Callers pass `pov = 0` for Player 1 and +//! `pov = 1` for Player 2. The implementation is responsible for any +//! mirroring required (e.g. Trictrac always reasons from White's side). + +pub mod trictrac; +pub use trictrac::TrictracEnv; + +/// Who controls the current game node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Player { + /// Player 1 (index 0) is to move. + P1, + /// Player 2 (index 1) is to move. + P2, + /// A stochastic event (dice roll, etc.) must be resolved. + Chance, + /// The game is over. + Terminal, +} + +impl Player { + /// Returns the player index (0 or 1) if this is a decision node, + /// or `None` for `Chance` / `Terminal`. 
+ pub fn index(self) -> Option { + match self { + Player::P1 => Some(0), + Player::P2 => Some(1), + _ => None, + } + } + + pub fn is_decision(self) -> bool { + matches!(self, Player::P1 | Player::P2) + } + + pub fn is_chance(self) -> bool { + self == Player::Chance + } + + pub fn is_terminal(self) -> bool { + self == Player::Terminal + } +} + +/// Trait that completely describes a two-player zero-sum game. +/// +/// Implementors must be cheaply cloneable (the type is used as a stateless +/// factory; the mutable game state lives in `Self::State`). +pub trait GameEnv: Clone + Send + Sync + 'static { + /// The mutable game state. Must be `Clone` so MCTS can copy + /// game trees without touching the environment. + type State: Clone + Send + Sync; + + // ── State creation ──────────────────────────────────────────────────── + + /// Create a fresh game state at the initial position. + fn new_game(&self) -> Self::State; + + // ── Node queries ────────────────────────────────────────────────────── + + /// Classify the current node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). + /// + /// The returned indices are in `[0, action_space())`. + /// The result is unspecified (may panic or return empty) when called at a + /// `Chance` or `Terminal` node. + fn legal_actions(&self, s: &Self::State) -> Vec; + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action. `action` must be a value returned by + /// [`legal_actions`] for the current state. + fn apply(&self, s: &mut Self::State, action: usize); + + /// Sample and apply a stochastic outcome. Must only be called when + /// `current_player(s) == Player::Chance`. + fn apply_chance(&self, s: &mut Self::State, rng: &mut R); + + // ── Observation ─────────────────────────────────────────────────────── + + /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). 
+ /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. + fn observation(&self, s: &Self::State, pov: usize) -> Vec; + + /// Number of floats returned by [`observation`]. + fn obs_size(&self) -> usize; + + /// Total number of distinct action indices (the policy head output size). + fn action_space(&self) -> usize; + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Game outcome for each player, or `None` if the game is not over. + /// + /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. + /// Index 0 = Player 1, index 1 = Player 2. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs new file mode 100644 index 0000000..8dc3676 --- /dev/null +++ b/spiel_bot/src/env/trictrac.rs @@ -0,0 +1,547 @@ +//! [`GameEnv`] implementation for Trictrac. +//! +//! # Game flow (schools_enabled = false) +//! +//! With scoring schools disabled (the standard training configuration), +//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine +//! applies them automatically inside `RollResult` and `Move`. The only +//! four stages that actually occur are: +//! +//! | `TurnStage` | [`Player`] kind | Handled by | +//! |-------------|-----------------|------------| +//! | `RollDice` | `Chance` | [`apply_chance`] | +//! | `RollWaiting` | `Chance` | [`apply_chance`] | +//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | +//! | `Move` | `P1`/`P2` | [`apply`] | +//! +//! # Perspective +//! +//! The Trictrac engine always reasons from White's perspective. Player 1 is +//! White; Player 2 is Black. When Player 2 is active, the board is mirrored +//! before computing legal actions / the observation tensor, and the resulting +//! event is mirrored back before being applied to the real state. This +//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. 
+ +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, + Dice, GameEvent, GameState, Stage, TurnStage, +}; + +use super::{GameEnv, Player}; + +/// Stateless factory that produces Trictrac [`GameState`] environments. +/// +/// Schools (`schools_enabled`) are always disabled — scoring is automatic. +#[derive(Clone, Debug, Default)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + // ── State creation ──────────────────────────────────────────────────── + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + // ── Node queries ────────────────────────────────────────────────────── + + fn current_player(&self, s: &GameState) -> Player { + if s.stage == Stage::Ended { + return Player::Terminal; + } + match s.turn_stage { + TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, + _ => { + if s.active_player_id == 1 { + Player::P1 + } else { + Player::P2 + } + } + } + } + + /// Returns the legal action indices for the active player. + /// + /// The board is automatically mirrored for Player 2 so that the engine + /// always reasons from White's perspective. The returned indices are + /// identical in meaning for both players (checker ordinals are + /// perspective-relative). + /// + /// # Panics + /// + /// Panics in debug builds if called at a `Chance` or `Terminal` node. + fn legal_actions(&self, s: &GameState) -> Vec { + debug_assert!( + self.current_player(s).is_decision(), + "legal_actions called at a non-decision node (turn_stage={:?})", + s.turn_stage + ); + let indices = if s.active_player_id == 2 { + get_valid_action_indices(&s.mirror()) + } else { + get_valid_action_indices(s) + }; + indices.unwrap_or_default() + } + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action index to the game state. 
+ /// + /// For Player 2, the action is decoded against the mirrored board and + /// the resulting event is un-mirrored before being applied. + /// + /// # Panics + /// + /// Panics in debug builds if `action` cannot be decoded or does not + /// produce a valid event for the current state. + fn apply(&self, s: &mut GameState, action: usize) { + let needs_mirror = s.active_player_id == 2; + + let event = if needs_mirror { + let view = s.mirror(); + TrictracAction::from_action_index(action) + .and_then(|a| a.to_event(&view)) + .map(|e| e.get_mirror(false)) + } else { + TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) + }; + + match event { + Some(e) => { + s.consume(&e).expect("apply: consume failed for valid action"); + } + None => { + panic!("apply: action index {action} produced no event in state {s}"); + } + } + } + + /// Sample dice and advance through a chance node. + /// + /// Handles both `RollDice` (triggers the roll mechanism, then samples + /// dice) and `RollWaiting` (only samples dice) in a single call so that + /// callers never need to distinguish the two. + /// + /// # Panics + /// + /// Panics in debug builds if called at a non-Chance node. + fn apply_chance(&self, s: &mut GameState, rng: &mut R) { + debug_assert!( + self.current_player(s).is_chance(), + "apply_chance called at a non-Chance node (turn_stage={:?})", + s.turn_stage + ); + + // Step 1: RollDice → RollWaiting (player initiates the roll). + if s.turn_stage == TurnStage::RollDice { + s.consume(&GameEvent::Roll { + player_id: s.active_player_id, + }) + .expect("apply_chance: Roll event failed"); + } + + // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. + // With schools_enabled=false, point marking is automatic inside consume(). 
+ let dice = Dice { + values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), + }; + s.consume(&GameEvent::RollResult { + player_id: s.active_player_id, + dice, + }) + .expect("apply_chance: RollResult event failed"); + } + + // ── Observation ─────────────────────────────────────────────────────── + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { + s.to_tensor() + } else { + s.mirror().to_tensor() + } + } + + fn obs_size(&self) -> usize { + 217 + } + + fn action_space(&self) -> usize { + ACTION_SPACE_SIZE + } + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. + /// + /// The winner (higher cumulative score) receives `+1.0`; the loser + /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score + /// is `holes × 12 + points`. + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { + return None; + } + let score = |id: u64| -> i32 { + s.players + .get(&id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + }; + let s1 = score(1); + let s2 = score(2); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [0.0, 0.0], + }) + } +} + +// ── DQN helpers ─────────────────────────────────────────────────────────────── + +impl TrictracEnv { + /// Score snapshot for DQN reward computation. + /// + /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. + /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). 
+ pub fn score_snapshot(s: &GameState) -> [i32; 2] { + [s.total_score(1), s.total_score(2)] + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{rngs::SmallRng, Rng, SeedableRng}; + + fn env() -> TrictracEnv { + TrictracEnv + } + + fn seeded_rng(seed: u64) -> SmallRng { + SmallRng::seed_from_u64(seed) + } + + // ── Initial state ───────────────────────────────────────────────────── + + #[test] + fn new_game_is_chance_node() { + let e = env(); + let s = e.new_game(); + // A fresh game starts at RollDice — a Chance node. + assert_eq!(e.current_player(&s), Player::Chance); + assert!(e.returns(&s).is_none()); + } + + #[test] + fn new_game_is_not_terminal() { + let e = env(); + let s = e.new_game(); + assert_ne!(e.current_player(&s), Player::Terminal); + assert!(e.returns(&s).is_none()); + } + + // ── Chance nodes ────────────────────────────────────────────────────── + + #[test] + fn apply_chance_reaches_decision_node() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(1); + + // A single chance step must yield a decision node (or end the game, + // which only happens after 12 holes — impossible on the first roll). + e.apply_chance(&mut s, &mut rng); + let p = e.current_player(&s); + assert!( + p.is_decision(), + "expected decision node after first roll, got {p:?}" + ); + } + + #[test] + fn apply_chance_from_rollwaiting() { + // Check that apply_chance works when called mid-way (at RollWaiting). + let e = env(); + let mut s = e.new_game(); + assert_eq!(s.turn_stage, TurnStage::RollDice); + + // Manually advance to RollWaiting. 
+ s.consume(&GameEvent::Roll { player_id: s.active_player_id }) + .unwrap(); + assert_eq!(s.turn_stage, TurnStage::RollWaiting); + + let mut rng = seeded_rng(2); + e.apply_chance(&mut s, &mut rng); + + let p = e.current_player(&s); + assert!(p.is_decision() || p.is_terminal()); + } + + // ── Legal actions ───────────────────────────────────────────────────── + + #[test] + fn legal_actions_nonempty_after_roll() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(3); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let actions = e.legal_actions(&s); + assert!( + !actions.is_empty(), + "legal_actions must be non-empty at a decision node" + ); + } + + #[test] + fn legal_actions_within_action_space() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(4); + + e.apply_chance(&mut s, &mut rng); + for &a in e.legal_actions(&s).iter() { + assert!( + a < e.action_space(), + "action {a} out of bounds (action_space={})", + e.action_space() + ); + } + } + + // ── Observations ────────────────────────────────────────────────────── + + #[test] + fn observation_has_correct_size() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(5); + e.apply_chance(&mut s, &mut rng); + + assert_eq!(e.observation(&s, 0).len(), e.obs_size()); + assert_eq!(e.observation(&s, 1).len(), e.obs_size()); + } + + #[test] + fn observation_values_in_unit_interval() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(6); + e.apply_chance(&mut s, &mut rng); + + for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { + for (i, &v) in obs.iter().enumerate() { + assert!( + v >= 0.0 && v <= 1.0, + "pov={pov}: obs[{i}] = {v} is outside [0,1]" + ); + } + } + } + + #[test] + fn p1_and_p2_observations_differ() { + // The board is mirrored for P2, so the two observations should differ + // whenever there are checkers in non-symmetric positions (always true + // in a real 
game after a few moves). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(7); + + // Advance far enough that the board is non-trivial. + for _ in 0..6 { + while e.current_player(&s).is_chance() { + e.apply_chance(&mut s, &mut rng); + } + if e.current_player(&s).is_terminal() { + break; + } + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + + if !e.current_player(&s).is_terminal() { + let obs0 = e.observation(&s, 0); + let obs1 = e.observation(&s, 1); + assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); + } + } + + // ── Applying actions ────────────────────────────────────────────────── + + #[test] + fn apply_changes_state() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(8); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let before = s.clone(); + let action = e.legal_actions(&s)[0]; + e.apply(&mut s, action); + + assert_ne!( + before.turn_stage, s.turn_stage, + "state must change after apply" + ); + } + + #[test] + fn apply_all_legal_actions_do_not_panic() { + // Verify that every action returned by legal_actions can be applied + // without panicking (on several independent copies of the same state). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(9); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + for action in e.legal_actions(&s) { + let mut copy = s.clone(); + e.apply(&mut copy, action); // must not panic + } + } + + // ── Full game ───────────────────────────────────────────────────────── + + /// Run a complete game with random actions through the `GameEnv` trait + /// and verify that: + /// - The game terminates. + /// - `returns()` is `Some` at the end. + /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's + /// score is ±1 / 0. + /// - No step panics. 
+ #[test] + fn full_random_game_terminates() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(42); + let max_steps = 50_000; + + for step in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node"); + let idx = rng.random_range(0..actions.len()); + e.apply(&mut s, actions[idx]); + } + } + assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps"); + } + + let result = e.returns(&s); + assert!(result.is_some(), "returns() must be Some at Terminal"); + + let [r1, r2] = result.unwrap(); + let sum = r1 + r2; + assert!( + (sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5, + "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}" + ); + assert!( + r1.abs() <= 1.0 && r2.abs() <= 1.0, + "returns must be in [-1,1]: r1={r1}, r2={r2}" + ); + } + + /// Run multiple games with different seeds to stress-test for panics. + #[test] + fn multiple_games_no_panic() { + let e = env(); + let max_steps = 20_000; + + for seed in 0..10u64 { + let mut s = e.new_game(); + let mut rng = seeded_rng(seed); + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + let idx = rng.random_range(0..actions.len()); + e.apply(&mut s, actions[idx]); + } + } + } + } + } + + // ── Returns ─────────────────────────────────────────────────────────── + + #[test] + fn returns_none_mid_game() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(11); + + // Advance a few steps but do not finish the game. 
+ for _ in 0..4 { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + if !e.current_player(&s).is_terminal() { + assert!( + e.returns(&s).is_none(), + "returns() must be None before the game ends" + ); + } + } + + // ── Player 2 actions ────────────────────────────────────────────────── + + /// Verify that Player 2 (Black) can take actions without panicking, + /// and that the state advances correctly. + #[test] + fn player2_can_act() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(12); + + // Keep stepping until Player 2 gets a turn. + let max_steps = 5_000; + let mut p2_acted = false; + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty()); + e.apply(&mut s, actions[0]); + p2_acted = true; + break; + } + Player::P1 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); + } +} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs new file mode 100644 index 0000000..9dfb4de --- /dev/null +++ b/spiel_bot/src/lib.rs @@ -0,0 +1,5 @@ +pub mod alphazero; +pub mod dqn; +pub mod env; +pub mod mcts; +pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs new file mode 100644 index 0000000..eead171 --- /dev/null +++ b/spiel_bot/src/mcts/mod.rs @@ -0,0 +1,412 @@ +//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. +//! +//! # Algorithm +//! +//! The implementation follows AlphaZero's MCTS: +//! +//! 1. **Expand root** — run the network once to get priors and a value +//! estimate; optionally add Dirichlet noise for training-time exploration. 
+//! 2. **Simulate** `n_simulations` times: +//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. +//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; +//! chance nodes are **not** stored in the tree (outcome sampling). +//! - *Expansion* — evaluate the network at the leaf; populate children. +//! - *Backup* — propagate the value upward; negate at each player boundary. +//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). +//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). +//! +//! # Perspective convention +//! +//! Every [`MctsNode::w`] is stored **from the perspective of the player who +//! acts at that node**. The backup negates the child value whenever the +//! acting player differs between parent and child. +//! +//! # Stochastic games +//! +//! When [`GameEnv::current_player`] returns [`Player::Chance`], the +//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and +//! continues. Chance nodes are skipped transparently; Q-values converge to +//! their expectation over many simulations (outcome sampling). + +pub mod node; +mod search; + +pub use node::MctsNode; + +use rand::Rng; + +use crate::env::GameEnv; + +// ── Evaluator trait ──────────────────────────────────────────────────────── + +/// Evaluates a game position for use in MCTS. +/// +/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) +/// but the `mcts` module itself does **not** depend on Burn. +pub trait Evaluator: Send + Sync { + /// Evaluate `obs` (flat observation vector of length `obs_size`). + /// + /// Returns: + /// - `policy_logits`: one raw logit per action (`action_space` entries). + /// Illegal action entries are masked inside the search — no need to + /// zero them here. + /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. 
+ fn evaluate(&self, obs: &[f32]) -> (Vec, f32); +} + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Hyperparameters for [`run_mcts`]. +#[derive(Debug, Clone)] +pub struct MctsConfig { + /// Number of MCTS simulations per move. Typical: 50–800. + pub n_simulations: usize, + /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0. + pub c_puct: f32, + /// Dirichlet noise concentration α. Set to `0.0` to disable. + /// Typical: `0.3` for Chess, `0.1` for large action spaces. + pub dirichlet_alpha: f32, + /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`. + pub dirichlet_eps: f32, + /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax. + pub temperature: f32, +} + +impl Default for MctsConfig { + fn default() -> Self { + Self { + n_simulations: 200, + c_puct: 1.5, + dirichlet_alpha: 0.3, + dirichlet_eps: 0.25, + temperature: 1.0, + } + } +} + +// ── Public interface ─────────────────────────────────────────────────────── + +/// Run MCTS from `state` and return the populated root [`MctsNode`]. +/// +/// `state` must be a player-decision node (`P1` or `P2`). +/// Use [`mcts_policy`] and [`select_action`] on the returned root. +/// +/// # Panics +/// +/// Panics if `env.current_player(state)` is not `P1` or `P2`. 
+pub fn run_mcts( + env: &E, + state: &E::State, + evaluator: &dyn Evaluator, + config: &MctsConfig, + rng: &mut impl Rng, +) -> MctsNode { + let player_idx = env + .current_player(state) + .index() + .expect("run_mcts called at a non-decision node"); + + // ── Expand root (network called once here, not inside the loop) ──────── + let mut root = MctsNode::new(1.0); + search::expand::(&mut root, state, env, evaluator, player_idx); + + // ── Optional Dirichlet noise for training exploration ────────────────── + if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 { + search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng); + } + + // ── Simulations ──────────────────────────────────────────────────────── + for _ in 0..config.n_simulations { + search::simulate::( + &mut root, + state.clone(), + env, + evaluator, + config, + rng, + player_idx, + ); + } + + root +} + +/// Compute the MCTS policy: normalized visit counts at the root. +/// +/// Returns a vector of length `action_space` where `policy[a]` is the +/// fraction of simulations that visited action `a`. +pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec { + let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum(); + let mut policy = vec![0.0f32; action_space]; + if total > 0.0 { + for (a, child) in &root.children { + policy[*a] = child.n as f32 / total; + } + } else if !root.children.is_empty() { + // n_simulations = 0: uniform over legal actions. + let uniform = 1.0 / root.children.len() as f32; + for (a, _) in &root.children { + policy[*a] = uniform; + } + } + policy +} + +/// Select an action index from the root after MCTS. +/// +/// * `temperature = 0` — greedy argmax of visit counts. +/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`. +/// +/// # Panics +/// +/// Panics if the root has no children. 
+pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { + assert!(!root.children.is_empty(), "select_action called on a root with no children"); + if temperature <= 0.0 { + root.children + .iter() + .max_by_key(|(_, c)| c.n) + .map(|(a, _)| *a) + .unwrap() + } else { + let weights: Vec = root + .children + .iter() + .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) + .collect(); + let total: f32 = weights.iter().sum(); + let mut r: f32 = rng.random::() * total; + for (i, (a, _)) in root.children.iter().enumerate() { + r -= weights[i]; + if r <= 0.0 { + return *a; + } + } + root.children.last().map(|(a, _)| *a).unwrap() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + use crate::env::Player; + + // ── Minimal deterministic test game ─────────────────────────────────── + // + // "Countdown" — two players alternate subtracting 1 or 2 from a counter. + // The player who brings the counter to 0 wins. + // No chance nodes, two legal actions (0 = -1, 1 = -2). 
+ + #[derive(Clone, Debug)] + struct CState { + remaining: u8, + to_move: usize, // at terminal: last mover (winner) + } + + #[derive(Clone)] + struct CountdownEnv; + + impl crate::env::GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: 6, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { + Player::Terminal + } else if s.to_move == 0 { + Player::P1 + } else { + Player::P2 + } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + // to_move stays as winner + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 6.0, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + // Uniform evaluator: all logits = 0, value = 0. + // `action_space` must match the environment's `action_space()`. 
+ struct ZeroEval(usize); + impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(42) + } + + fn config_n(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, // off for reproducibility + dirichlet_eps: 0.0, + temperature: 1.0, + } + } + + // ── Visit count tests ───────────────────────────────────────────────── + + #[test] + fn visit_counts_sum_to_n_simulations() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 50, "visit counts must sum to n_simulations"); + } + + #[test] + fn all_root_children_are_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); + for (a, _) in &root.children { + assert!(legal.contains(a), "child action {a} is not legal"); + } + } + + // ── Policy tests ───────────────────────────────────────────────────── + + #[test] + fn policy_sums_to_one() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); + } + + #[test] + fn policy_zero_for_illegal_actions() { + let env = CountdownEnv; + // remaining = 1 → only action 0 is legal + let state = CState { remaining: 1, to_move: 0 }; + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); + } + + // ── Action selection tests 
──────────────────────────────────────────── + + #[test] + fn greedy_selects_most_visited() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); + let greedy = select_action(&root, 0.0, &mut rng()); + let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); + assert_eq!(greedy, most_visited); + } + + #[test] + fn temperature_sampling_stays_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let mut r = rng(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); + for _ in 0..20 { + let a = select_action(&root, 1.0, &mut r); + assert!(legal.contains(&a), "sampled action {a} is not legal"); + } + } + + // ── Zero-simulation edge case ───────────────────────────────────────── + + #[test] + fn zero_simulations_uniform_policy() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + // With 0 simulations, fallback is uniform over the 2 legal actions. + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5); + } + + // ── Root value ──────────────────────────────────────────────────────── + + #[test] + fn root_q_in_valid_range() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); + let q = root.q(); + assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); + } + + // ── Integration: run on a real Trictrac game ────────────────────────── + + #[test] + fn no_panic_on_trictrac_state() { + use crate::env::TrictracEnv; + + let env = TrictracEnv; + let mut state = env.new_game(); + let mut r = rng(); + + // Advance past the initial chance node to reach a decision node. 
+ while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, &mut r); + } + + if env.current_player(&state).is_terminal() { + return; // unlikely but safe + } + + let config = MctsConfig { + n_simulations: 5, // tiny for speed + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + ..MctsConfig::default() + }; + + let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); + // root.n = 1 (expansion) + n_simulations (one backup per simulation). + assert_eq!(root.n, 1 + config.n_simulations as u32); + // Every simulation crosses a chance node at depth 1 (dice roll after + // the player's move). Since the fix now updates child.n in that case, + // children visit counts must sum to exactly n_simulations. + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, config.n_simulations as u32); + } +} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs new file mode 100644 index 0000000..aff7735 --- /dev/null +++ b/spiel_bot/src/mcts/node.rs @@ -0,0 +1,91 @@ +//! MCTS tree node. +//! +//! [`MctsNode`] holds the visit statistics for one player-decision position in +//! the search tree. A node is *expanded* the first time the policy-value +//! network is evaluated there; before that it is a leaf. + +/// One node in the MCTS tree, representing a player-decision position. +/// +/// `w` stores the sum of values backed up into this node, always from the +/// perspective of **the player who acts here**. `q()` therefore also returns +/// a value in `(-1, 1)` from that same perspective. +#[derive(Debug)] +pub struct MctsNode { + /// Visit count `N(s, a)`. + pub n: u32, + /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. + pub w: f32, + /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). + pub p: f32, + /// Children: `(action_index, child_node)`, populated on first expansion. 
+ pub children: Vec<(usize, MctsNode)>, + /// `true` after the network has been evaluated and children have been set up. + pub expanded: bool, +} + +impl MctsNode { + /// Create a fresh, unexpanded leaf with the given prior probability. + pub fn new(prior: f32) -> Self { + Self { + n: 0, + w: 0.0, + p: prior, + children: Vec::new(), + expanded: false, + } + } + + /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. + #[inline] + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT selection score: + /// + /// ```text + /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) + /// ``` + #[inline] + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn q_zero_when_unvisited() { + let node = MctsNode::new(0.5); + assert_eq!(node.q(), 0.0); + } + + #[test] + fn q_reflects_w_over_n() { + let mut node = MctsNode::new(0.5); + node.n = 4; + node.w = 2.0; + assert!((node.q() - 0.5).abs() < 1e-6); + } + + #[test] + fn puct_exploration_dominates_unvisited() { + // Unvisited child should outscore a visited child with negative Q. + let mut visited = MctsNode::new(0.5); + visited.n = 10; + visited.w = -5.0; // Q = -0.5 + + let unvisited = MctsNode::new(0.5); + + let parent_n = 10; + let c = 1.5; + assert!( + unvisited.puct(parent_n, c) > visited.puct(parent_n, c), + "unvisited child should have higher PUCT than a negatively-valued visited child" + ); + } +} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs new file mode 100644 index 0000000..1d9750d --- /dev/null +++ b/spiel_bot/src/mcts/search.rs @@ -0,0 +1,190 @@ +//! Simulation, expansion, backup, and noise helpers. +//! +//! These are internal to the `mcts` module; the public entry points are +//! 
[`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. + +use rand::Rng; +use rand_distr::{Gamma, Distribution}; + +use crate::env::GameEnv; +use super::{Evaluator, MctsConfig}; +use super::node::MctsNode; + +// ── Masked softmax ───────────────────────────────────────────────────────── + +/// Numerically stable softmax over `legal` actions only. +/// +/// Illegal logits are treated as `-∞` and receive probability `0.0`. +/// Returns a probability vector of length `action_space`. +pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec { + let mut probs = vec![0.0f32; action_space]; + if legal.is_empty() { + return probs; + } + let max_logit = legal + .iter() + .map(|&a| logits[a]) + .fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f32; + for &a in legal { + let e = (logits[a] - max_logit).exp(); + probs[a] = e; + sum += e; + } + if sum > 0.0 { + for &a in legal { + probs[a] /= sum; + } + } else { + let uniform = 1.0 / legal.len() as f32; + for &a in legal { + probs[a] = uniform; + } + } + probs +} + +// ── Dirichlet noise ──────────────────────────────────────────────────────── + +/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration. +/// +/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. +/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum. 
+pub(super) fn add_dirichlet_noise( + node: &mut MctsNode, + alpha: f32, + eps: f32, + rng: &mut impl Rng, +) { + let n = node.children.len(); + if n == 0 { + return; + } + let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else { + return; + }; + let samples: Vec = (0..n).map(|_| gamma.sample(rng) as f32).collect(); + let sum: f32 = samples.iter().sum(); + if sum <= 0.0 { + return; + } + for (i, (_, child)) in node.children.iter_mut().enumerate() { + let noise = samples[i] / sum; + child.p = (1.0 - eps) * child.p + eps * noise; + } +} + +// ── Expansion ────────────────────────────────────────────────────────────── + +/// Evaluate the network at `state` and populate `node` with children. +/// +/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`. +/// Returns the network value estimate from `player_idx`'s perspective. +pub(super) fn expand( + node: &mut MctsNode, + state: &E::State, + env: &E, + evaluator: &dyn Evaluator, + player_idx: usize, +) -> f32 { + let obs = env.observation(state, player_idx); + let legal = env.legal_actions(state); + let (logits, value) = evaluator.evaluate(&obs); + let priors = masked_softmax(&logits, &legal, env.action_space()); + node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect(); + node.expanded = true; + node.n = 1; + node.w = value; + value +} + +// ── Simulation ───────────────────────────────────────────────────────────── + +/// One MCTS simulation from an **already-expanded** decision node. +/// +/// Traverses the tree with PUCT selection, expands the first unvisited leaf, +/// and backs up the result. +/// +/// * `player_idx` — the player (0 or 1) who acts at `state`. +/// * Returns the backed-up value **from `player_idx`'s perspective**. 
+pub(super) fn simulate( + node: &mut MctsNode, + state: E::State, + env: &E, + evaluator: &dyn Evaluator, + config: &MctsConfig, + rng: &mut impl Rng, + player_idx: usize, +) -> f32 { + debug_assert!(node.expanded, "simulate called on unexpanded node"); + + // ── Selection: child with highest PUCT ──────────────────────────────── + let parent_n = node.n; + let best = node + .children + .iter() + .enumerate() + .max_by(|(_, (_, a)), (_, (_, b))| { + a.puct(parent_n, config.c_puct) + .partial_cmp(&b.puct(parent_n, config.c_puct)) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i) + .expect("expanded node must have at least one child"); + + let (action, child) = &mut node.children[best]; + let action = *action; + + // ── Apply action + advance through any chance nodes ─────────────────── + let mut next_state = state; + env.apply(&mut next_state, action); + + // Track whether we crossed a chance node (dice roll) on the way down. + // If we did, the child's cached legal actions are for a *different* dice + // outcome and must not be reused — evaluate with the network directly. + let mut crossed_chance = false; + while env.current_player(&next_state).is_chance() { + env.apply_chance(&mut next_state, rng); + crossed_chance = true; + } + + let next_cp = env.current_player(&next_state); + + // ── Evaluate leaf or terminal ────────────────────────────────────────── + // All values are converted to `player_idx`'s perspective before backup. + let child_value = if next_cp.is_terminal() { + let returns = env + .returns(&next_state) + .expect("terminal node must have returns"); + returns[player_idx] + } else { + let child_player = next_cp.index().unwrap(); + let v = if crossed_chance { + // Outcome sampling: after dice, evaluate the resulting position + // directly with the network. Do NOT build the tree across chance + // boundaries — the dice change which actions are legal, so any + // previously cached children would be for a different outcome. 
+ let obs = env.observation(&next_state, child_player); + let (_, value) = evaluator.evaluate(&obs); + // Record the visit so that PUCT and mcts_policy use real counts. + // Without this, child.n stays 0 for every simulation in games where + // every player action is immediately followed by a chance node (e.g. + // Trictrac), causing mcts_policy to always return a uniform policy. + child.n += 1; + child.w += value; + value + } else if child.expanded { + simulate(child, next_state, env, evaluator, config, rng, child_player) + } else { + expand::(child, &next_state, env, evaluator, child_player) + }; + // Negate when the child belongs to the opponent. + if child_player == player_idx { v } else { -v } + }; + + // ── Backup ──────────────────────────────────────────────────────────── + node.n += 1; + node.w += child_value; + + child_value +} diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs new file mode 100644 index 0000000..eb6184e --- /dev/null +++ b/spiel_bot/src/network/mlp.rs @@ -0,0 +1,223 @@ +//! Two-hidden-layer MLP policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`MlpNet`]. +#[derive(Debug, Clone)] +pub struct MlpConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. 
+ pub hidden_size: usize, +} + +impl Default for MlpConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 256, + } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Simple two-hidden-layer MLP with shared trunk and two heads. +/// +/// Prefer this over [`ResNet`](super::ResNet) when training time is a +/// priority, or as a fast baseline. +#[derive(Module, Debug)] +pub struct MlpNet { + fc1: Linear, + fc2: Linear, + policy_head: Linear, + value_head: Linear, +} + +impl MlpNet { + /// Construct a fresh network with random weights. + pub fn new(config: &MlpConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + value_head: LinearConfig::new(config.hidden_size, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + /// + /// The file is written exactly at `path`; callers should append `.mpk` if + /// they want the conventional extension. + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. 
+ pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for MlpNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn default_net() -> MlpNet { + MlpNet::new(&MlpConfig::default(), &device()) + } + + fn zeros_obs(batch: usize) -> Tensor { + Tensor::zeros([batch, 217], &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let net = default_net(); + let obs = zeros_obs(4); + let (policy, value) = net.forward(obs); + + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let (policy, value) = net.forward(zeros_obs(1)); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let net = default_net(); + // Use a non-zero input so the output is not trivially at 0. 
+ let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net.forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Policy logits ───────────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + // With random weights the 514 logits should not all be identical. + let net = default_net(); + let (policy, _) = net.forward(zeros_obs(1)); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical — network may be degenerate"); + } + + // ── Config propagation ──────────────────────────────────────────────── + + #[test] + fn custom_config_shapes() { + let config = MlpConfig { + obs_size: 10, + action_size: 20, + hidden_size: 32, + }; + let net = MlpNet::::new(&config, &device()); + let obs = Tensor::zeros([3, 10], &device()); + let (policy, value) = net.forward(obs); + assert_eq!(policy.dims(), [3, 20]); + assert_eq!(value.dims(), [3, 1]); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = MlpConfig::default(); + let net = default_net(); + + // Forward pass before saving. + let obs = Tensor::::ones([2, 217], &device()); + let (policy_before, value_before) = net.forward(obs.clone()); + + // Save to a temp file. + let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); + net.save(&path).expect("save failed"); + + // Load into a fresh model. + let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = loaded.forward(obs); + + // Outputs must be bitwise identical. 
+ let p_before: Vec = policy_before.into_data().to_vec().unwrap(); + let p_after: Vec = policy_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let v_before: Vec = value_before.into_data().to_vec().unwrap(); + let v_after: Vec = value_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let _ = std::fs::remove_file(path); + } +} diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs new file mode 100644 index 0000000..64f93ec --- /dev/null +++ b/spiel_bot/src/network/mod.rs @@ -0,0 +1,78 @@ +//! Neural network abstractions for policy-value learning. +//! +//! # Trait +//! +//! [`PolicyValueNet`] is the single trait that all network architectures +//! implement. It takes an observation tensor and returns raw policy logits +//! plus a tanh-squashed scalar value estimate. +//! +//! # Architectures +//! +//! | Module | Description | Default hidden | +//! |--------|-------------|----------------| +//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | +//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | +//! +//! # Backend convention +//! +//! * **Inference / self-play** — use `NdArray` (no autodiff overhead). +//! * **Training** — use `Autodiff>` so Burn can differentiate +//! through the forward pass. +//! +//! Both modes use the exact same struct; only the type-level backend changes: +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! type InferBackend = NdArray; +//! type TrainBackend = Autodiff>; +//! +//! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); +//! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); +//! ``` +//! +//! 
# Output shapes +//! +//! Given a batch of `B` observations of size `obs_size`: +//! +//! | Output | Shape | Range | +//! |--------|-------|-------| +//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | +//! | `value` | `[B, 1]` | (-1, 1) via tanh | +//! +//! Callers are responsible for masking illegal actions in `policy_logits` +//! before passing to softmax. + +pub mod mlp; +pub mod qnet; +pub mod resnet; + +pub use mlp::{MlpConfig, MlpNet}; +pub use qnet::{QNet, QNetConfig}; +pub use resnet::{ResNet, ResNetConfig}; + +use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; + +/// A neural network that produces a policy and a value from an observation. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - policy output: `[batch, action_size]` — raw logits (no softmax applied) +/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) +/// +/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses +/// `OnceCell` for lazy parameter initialisation, which is not `Sync`. +/// Use an `Arc>` wrapper if cross-thread sharing is needed. +pub trait PolicyValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor); +} + +/// A neural network that outputs one Q-value per action. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - output: `[batch, action_size]` — raw Q-values (no activation) +/// +/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`]. +pub trait QValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> Tensor; +} diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs new file mode 100644 index 0000000..1737f72 --- /dev/null +++ b/spiel_bot/src/network/qnet.rs @@ -0,0 +1,147 @@ +//! Single-headed Q-value network for DQN. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! → Linear(hidden → action_size) ← raw Q-values, no activation +//! 
``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{activation::relu, backend::Backend, Tensor}, +}; +use std::path::Path; + +use super::QValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`QNet`]. +#[derive(Debug, Clone)] +pub struct QNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for QNetConfig { + fn default() -> Self { + Self { obs_size: 217, action_size: 514, hidden_size: 256 } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Two-hidden-layer MLP that outputs one Q-value per action. +#[derive(Module, Debug)] +pub struct QNet { + fc1: Linear, + fc2: Linear, + q_head: Linear, +} + +impl QNet { + /// Construct a fresh network with random weights. + pub fn new(config: &QNetConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. 
+ pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl QValueNet for QNet { + fn forward(&self, obs: Tensor) -> Tensor { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + self.q_head.forward(x) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + + fn default_net() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn forward_output_shape() { + let net = default_net(); + let obs = Tensor::zeros([4, 217], &device()); + let q = net.forward(obs); + assert_eq!(q.dims(), [4, 514]); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let q = net.forward(Tensor::zeros([1, 217], &device())); + assert_eq!(q.dims(), [1, 514]); + } + + #[test] + fn q_values_not_all_equal() { + let net = default_net(); + let q: Vec = net.forward(Tensor::zeros([1, 217], &device())) + .into_data().to_vec().unwrap(); + let first = q[0]; + assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6)); + } + + #[test] + fn custom_config_shapes() { + let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 }; + let net = QNet::::new(&cfg, &device()); + let q = net.forward(Tensor::zeros([3, 10], &device())); + assert_eq!(q.dims(), [3, 20]); + } + + #[test] + fn save_load_preserves_weights() { + let net = default_net(); + let obs = Tensor::::ones([2, 217], &device()); + let q_before: Vec = net.forward(obs.clone()).into_data().to_vec().unwrap(); + + let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk"); + net.save(&path).expect("save failed"); + + let loaded = QNet::::load(&QNetConfig::default(), 
&path, &device()).expect("load failed"); + let q_after: Vec = loaded.forward(obs).into_data().to_vec().unwrap(); + + for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}"); + } + let _ = std::fs::remove_file(path); + } +} diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs new file mode 100644 index 0000000..d20d5ad --- /dev/null +++ b/spiel_bot/src/network/resnet.rs @@ -0,0 +1,253 @@ +//! Residual-block policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU (input projection) +//! → ResBlock × 4 (residual trunk) +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! +//! ResBlock: +//! x → Linear → ReLU → Linear → (+x) → ReLU +//! ``` +//! +//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better +//! suited for long training runs where board-pattern recognition matters. + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`ResNet`]. +#[derive(Debug, Clone)] +pub struct ResNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of all hidden layers (input projection + residual blocks). + pub hidden_size: usize, +} + +impl Default for ResNetConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 512, + } + } +} + +// ── Residual block ──────────────────────────────────────────────────────────── + +/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`. 
+/// +/// Both linear layers preserve the hidden dimension so the skip connection +/// can be added without projection. +#[derive(Module, Debug)] +struct ResBlock { + fc1: Linear, + fc2: Linear, +} + +impl ResBlock { + fn new(hidden: usize, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(hidden, hidden).init(device), + fc2: LinearConfig::new(hidden, hidden).init(device), + } + } + + fn forward(&self, x: Tensor) -> Tensor { + let residual = x.clone(); + let out = relu(self.fc1.forward(x)); + relu(self.fc2.forward(out) + residual) + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Four-residual-block policy-value network. +/// +/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and +/// when representing complex positional patterns is important. +#[derive(Module, Debug)] +pub struct ResNet { + input: Linear, + block0: ResBlock, + block1: ResBlock, + block2: ResBlock, + block3: ResBlock, + policy_head: Linear, + value_head: Linear, +} + +impl ResNet { + /// Construct a fresh network with random weights. + pub fn new(config: &ResNetConfig, device: &B::Device) -> Self { + let h = config.hidden_size; + Self { + input: LinearConfig::new(config.obs_size, h).init(device), + block0: ResBlock::new(h, device), + block1: ResBlock::new(h, device), + block2: ResBlock::new(h, device), + block3: ResBlock::new(h, device), + policy_head: LinearConfig::new(h, config.action_size).init(device), + value_head: LinearConfig::new(h, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. 
+ pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for ResNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.input.forward(obs)); + let x = self.block0.forward(x); + let x = self.block1.forward(x); + let x = self.block2.forward(x); + let x = self.block3.forward(x); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn small_config() -> ResNetConfig { + // Use a small hidden size so tests are fast. 
+ ResNetConfig { + obs_size: 217, + action_size: 514, + hidden_size: 64, + } + } + + fn net() -> ResNet { + ResNet::new(&small_config(), &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let obs = Tensor::zeros([4, 217], &device()); + let (policy, value) = net().forward(obs); + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let (policy, value) = net().forward(Tensor::zeros([1, 217], &device())); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net().forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Residual connections ────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + let (policy, _) = net().forward(Tensor::zeros([1, 217], &device())); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical"); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = small_config(); + let model = net(); + let obs = Tensor::::ones([2, 217], &device()); + + let (policy_before, value_before) = model.forward(obs.clone()); + + let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk"); + model.save(&path).expect("save failed"); + + let loaded = ResNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = 
loaded.forward(obs); + + let p_before: Vec = policy_before.into_data().to_vec().unwrap(); + let p_after: Vec = policy_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let v_before: Vec = value_before.into_data().to_vec().unwrap(); + let v_after: Vec = value_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let _ = std::fs::remove_file(path); + } + + // ── Integration: both architectures satisfy PolicyValueNet ──────────── + + #[test] + fn resnet_satisfies_trait() { + fn requires_net>(net: &N, obs: Tensor) { + let (p, v) = net.forward(obs); + assert_eq!(p.dims()[1], 514); + assert_eq!(v.dims()[1], 1); + } + requires_net(&net(), Tensor::zeros([2, 217], &device())); + } +} diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs new file mode 100644 index 0000000..d73fda0 --- /dev/null +++ b/spiel_bot/tests/integration.rs @@ -0,0 +1,391 @@ +//! End-to-end integration tests for the AlphaZero training pipeline. +//! +//! Each test exercises the full chain: +//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`] +//! +//! Two environments are used: +//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves. +//! Used when we need many iterations without worrying about runtime. +//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that +//! the full pipeline compiles and runs correctly with 217-dim observations +//! and 514-dim action spaces. +//! +//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep +//! runtime minimal; correctness, not training quality, is what matters here. 
+ +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step}, + env::{GameEnv, Player, TrictracEnv}, + mcts::MctsConfig, + network::{MlpConfig, MlpNet, PolicyValueNet}, +}; + +// ── Backend aliases ──────────────────────────────────────────────────────── + +type Train = Autodiff>; +type Infer = NdArray; + +// ── Helpers ──────────────────────────────────────────────────────────────── + +fn train_device() -> ::Device { + Default::default() +} + +fn infer_device() -> ::Device { + Default::default() +} + +/// Tiny 64-unit MLP, compatible with an obs/action space of any size. +fn tiny_mlp(obs: usize, actions: usize) -> MlpNet { + let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 }; + MlpNet::new(&cfg, &train_device()) +} + +fn tiny_mcts(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + } +} + +fn seeded() -> SmallRng { + SmallRng::seed_from_u64(0) +} + +// ── Countdown environment (fast, local, no external deps) ───────────────── +// +// Two players alternate subtracting 1 or 2 from a counter that starts at N. +// The player who brings the counter to 0 wins. 
+ +#[derive(Clone, Debug)] +struct CState { + remaining: u8, + to_move: usize, +} + +#[derive(Clone)] +struct CountdownEnv(u8); // starting value + +impl GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: self.0, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } + else { Player::P2 } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / self.0 as f32, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } +} + +// ── 1. Full loop on CountdownEnv ────────────────────────────────────────── + +/// The canonical AlphaZero loop: self-play → replay → train, iterated. +/// Uses CountdownEnv so each game terminates in < 10 moves. +#[test] +fn countdown_full_loop_no_panic() { + let env = CountdownEnv(8); + let mut rng = seeded(); + let mcts = tiny_mcts(3); + + let mut model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(1_000); + + for _iter in 0..5 { + // Self-play: 3 games per iteration. 
+ for _ in 0..3 { + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + assert!(!samples.is_empty()); + replay.extend(samples); + } + + // Training: 4 gradient steps per iteration. + if replay.len() >= 4 { + for _ in 0..4 { + let batch: Vec = replay + .sample_batch(4, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + model = m; + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + } + } + + assert!(replay.len() > 0); +} + +// ── 2. Replay buffer invariants ─────────────────────────────────────────── + +/// After several Countdown games, replay capacity is respected and batch +/// shapes are consistent. +#[test] +fn replay_buffer_capacity_and_shapes() { + let env = CountdownEnv(6); + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + + let capacity = 50; + let mut replay = ReplayBuffer::new(capacity); + + for _ in 0..20 { + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + replay.extend(samples); + } + + assert!(replay.len() <= capacity, "buffer exceeded capacity"); + assert!(replay.len() > 0); + + let batch = replay.sample_batch(8, &mut rng); + assert_eq!(batch.len(), 8.min(replay.len())); + for s in &batch { + assert_eq!(s.obs.len(), env.obs_size()); + assert_eq!(s.policy.len(), env.action_space()); + let policy_sum: f32 = s.policy.iter().sum(); + assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}"); + assert!(s.value.abs() <= 1.0, "value {} out of range", s.value); + } +} + +// ── 3. 
TrictracEnv: sample shapes ───────────────────────────────────────── + +/// Verify that one TrictracEnv episode produces samples with the correct +/// tensor dimensions: obs = 217, policy = 514. +#[test] +fn trictrac_sample_shapes() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + + assert!(!samples.is_empty(), "Trictrac episode produced no samples"); + + for (i, s) in samples.iter().enumerate() { + assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len()); + assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len()); + let policy_sum: f32 = s.policy.iter().sum(); + assert!( + (policy_sum - 1.0).abs() < 1e-4, + "sample {i}: policy sums to {policy_sum}" + ); + assert!( + s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "sample {i}: unexpected value {}", + s.value + ); + } +} + +// ── 4. TrictracEnv: training step after real self-play ──────────────────── + +/// Collect one Trictrac episode, then verify that a gradient step runs +/// without panic and produces a finite loss. +#[test] +fn trictrac_train_step_finite_loss() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(10_000); + + // Generate one episode. + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + assert!(!samples.is_empty()); + let n_samples = samples.len(); + replay.extend(samples); + + // Train on a batch from this episode. 
+ let batch_size = 8.min(n_samples); + let batch: Vec = replay + .sample_batch(batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); +} + +// ── 5. Backend transfer: train → infer → same outputs ───────────────────── + +/// Weights transferred from the training backend to the inference backend +/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes. +#[test] +fn valid_model_matches_train_model_outputs() { + use burn::tensor::{Tensor, TensorData}; + + let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let train_model = MlpNet::::new(&cfg, &train_device()); + let infer_model: MlpNet = train_model.valid(); + + // Build the same input on both backends. + let obs_data: Vec = vec![0.1, 0.2, 0.3, 0.4]; + + let obs_train = Tensor::::from_data( + TensorData::new(obs_data.clone(), [1, 4]), + &train_device(), + ); + let obs_infer = Tensor::::from_data( + TensorData::new(obs_data, [1, 4]), + &infer_device(), + ); + + let (p_train, v_train) = train_model.forward(obs_train); + let (p_infer, v_infer) = infer_model.forward(obs_infer); + + let p_train: Vec = p_train.into_data().to_vec().unwrap(); + let p_infer: Vec = p_infer.into_data().to_vec().unwrap(); + let v_train: Vec = v_train.into_data().to_vec().unwrap(); + let v_infer: Vec = v_infer.into_data().to_vec().unwrap(); + + for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() { + assert!( + (a - b).abs() < 1e-5, + "policy[{i}] differs after valid(): train={a}, infer={b}" + ); + } + assert!( + (v_train[0] - v_infer[0]).abs() < 1e-5, + "value differs after valid(): train={}, infer={}", + v_train[0], v_infer[0] + ); +} + +// ── 6. 
Loss converges on a fixed batch ──────────────────────────────────── + +/// With repeated gradient steps on the same Countdown batch, the loss must +/// decrease monotonically (or at least end lower than it started). +#[test] +fn loss_decreases_on_fixed_batch() { + let env = CountdownEnv(6); + let mut rng = seeded(); + let mcts = tiny_mcts(3); + let model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + + // Collect a fixed batch from one episode. + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples: Vec = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng); + assert!(!samples.is_empty()); + + let batch: Vec = { + let mut replay = ReplayBuffer::new(1000); + replay.extend(samples); + replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect() + }; + + // Overfit on the same fixed batch for 20 steps. + let mut model = tiny_mlp(env.obs_size(), env.action_space()); + let mut first_loss = f32::NAN; + let mut last_loss = f32::NAN; + + for step in 0..20 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2); + model = m; + assert!(loss.is_finite(), "loss is not finite at step {step}"); + if step == 0 { first_loss = loss; } + last_loss = loss; + } + + assert!( + last_loss < first_loss, + "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}" + ); +} + +// ── 7. Trictrac: multi-iteration loop ───────────────────────────────────── + +/// Two full self-play + train iterations on TrictracEnv. +/// Verifies the entire pipeline runs without panic end-to-end. 
+#[test] +fn trictrac_two_iteration_loop() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let mut model = MlpNet::::new(&cfg, &train_device()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(20_000); + + for iter in 0..2 { + // Self-play: 1 game per iteration. + let infer: MlpNet = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); + assert!(!samples.is_empty(), "iter {iter}: episode was empty"); + replay.extend(samples); + + // Training: 3 gradient steps. + let batch_size = 16.min(replay.len()); + for _ in 0..3 { + let batch: Vec = replay + .sample_batch(batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + model = m; + assert!(loss.is_finite(), "iter {iter}: loss={loss}"); + } + } +} diff --git a/store/Cargo.toml b/store/Cargo.toml index a9234ff..935a2a0 100644 --- a/store/Cargo.toml +++ b/store/Cargo.toml @@ -25,5 +25,9 @@ rand = "0.9" serde = { version = "1.0", features = ["derive"] } transpose = "0.2.2" +[[bin]] +name = "random_game" +path = "src/bin/random_game.rs" + [build-dependencies] cxx-build = "1.0" diff --git a/store/src/bin/random_game.rs b/store/src/bin/random_game.rs new file mode 100644 index 0000000..6da3b9c --- /dev/null +++ b/store/src/bin/random_game.rs @@ -0,0 +1,262 @@ +//! Run one or many games of trictrac between two random players. +//! In single-game mode, prints play-by-play like OpenSpiel's `example.cc`. +//! In multi-game mode, runs silently and reports throughput at the end. +//! +//! Usage: +//! 
cargo run --bin random_game -- [--seed ] [--games ] [--max-steps ] [--verbose] + +use std::borrow::Cow; +use std::env; +use std::time::Instant; + +use trictrac_store::{ + training_common::sample_valid_action, + Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, +}; + +// ── CLI args ────────────────────────────────────────────────────────────────── + +struct Args { + seed: Option, + games: usize, + max_steps: usize, + verbose: bool, +} + +fn parse_args() -> Args { + let args: Vec = env::args().collect(); + let mut seed = None; + let mut games = 1; + let mut max_steps = 10_000; + let mut verbose = false; + + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--seed" => { + i += 1; + seed = args.get(i).and_then(|s| s.parse().ok()); + } + "--games" => { + i += 1; + if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) { + games = v; + } + } + "--max-steps" => { + i += 1; + if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) { + max_steps = v; + } + } + "--verbose" => verbose = true, + _ => {} + } + i += 1; + } + + Args { + seed, + games, + max_steps, + verbose, + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn player_label(id: u64) -> &'static str { + if id == 1 { "White" } else { "Black" } +} + +/// Apply a `Roll` + `RollResult` in one logical step, returning the dice. +/// This collapses the two-step dice phase into a single "chance node" action, +/// matching how the OpenSpiel layer exposes it. 
+fn apply_dice_roll(state: &mut GameState, roller: &mut DiceRoller) -> Result { + // RollDice → RollWaiting + state + .consume(&GameEvent::Roll { player_id: state.active_player_id }) + .map_err(|e| format!("Roll event failed: {e}"))?; + + // RollWaiting → Move / HoldOrGoChoice (or Stage::Ended if 13th hole) + let dice = roller.roll(); + state + .consume(&GameEvent::RollResult { player_id: state.active_player_id, dice }) + .map_err(|e| format!("RollResult event failed: {e}"))?; + + Ok(dice) +} + +/// Sample a random action and apply it to `state`, handling the Black-mirror +/// transform exactly as `cxxengine.rs::apply_action` does: +/// +/// 1. For Black, build a mirrored view of the state so that `sample_valid_action` +/// and `to_event` always reason from White's perspective. +/// 2. Mirror the resulting event back to the original coordinate frame before +/// calling `state.consume`. +/// +/// Returns the chosen action (in the view's coordinate frame) for display. +fn apply_player_action(state: &mut GameState) -> Result<(), String> { + let needs_mirror = state.active_player_id == 2; + + // Build a White-perspective view: borrowed for White, owned mirror for Black. + let view: Cow = if needs_mirror { + Cow::Owned(state.mirror()) + } else { + Cow::Borrowed(state) + }; + + let action = sample_valid_action(&view) + .ok_or_else(|| format!("no valid action in stage {:?}", state.turn_stage))?; + + let event = action + .to_event(&view) + .ok_or_else(|| format!("could not convert {action:?} to event"))?; + + // Translate the event from the view's frame back to the game's frame. + let event = if needs_mirror { event.get_mirror(false) } else { event }; + + state + .consume(&event) + .map_err(|e| format!("consume({action:?}): {e}"))?; + + Ok(()) +} + +// ── Single game ──────────────────────────────────────────────────────────────── + +/// Run one full game, optionally printing play-by-play. +/// Returns `(steps, truncated)`. 
+fn run_game(roller: &mut DiceRoller, max_steps: usize, quiet: bool, verbose: bool) -> (usize, bool) { + let mut state = GameState::new_with_players("White", "Black"); + let mut step = 0usize; + + if !quiet { + println!("{state}"); + } + + while state.stage != Stage::Ended { + step += 1; + if step > max_steps { + return (step - 1, true); + } + + match state.turn_stage { + TurnStage::RollDice => { + let player = state.active_player_id; + match apply_dice_roll(&mut state, roller) { + Ok(dice) => { + if !quiet { + println!( + "[step {step:4}] {} rolls: {} & {}", + player_label(player), + dice.values.0, + dice.values.1 + ); + } + } + Err(e) => { + eprintln!("Error during dice roll: {e}"); + eprintln!("State:\n{state}"); + return (step, true); + } + } + } + stage => { + let player = state.active_player_id; + match apply_player_action(&mut state) { + Ok(()) => { + if !quiet { + println!( + "[step {step:4}] {} ({stage:?})", + player_label(player) + ); + if verbose { + println!("{state}"); + } + } + } + Err(e) => { + eprintln!("Error: {e}"); + eprintln!("State:\n{state}"); + return (step, true); + } + } + } + } + } + + if !quiet { + println!("\n=== Game over after {step} steps ===\n"); + println!("{state}"); + + let white = state.players.get(&1); + let black = state.players.get(&2); + + match (white, black) { + (Some(w), Some(b)) => { + println!("White — holes: {:2}, points: {:2}", w.holes, w.points); + println!("Black — holes: {:2}, points: {:2}", b.holes, b.points); + println!(); + + let white_score = w.holes as i32 * 12 + w.points as i32; + let black_score = b.holes as i32 * 12 + b.points as i32; + + if white_score > black_score { + println!("Winner: White (+{})", white_score - black_score); + } else if black_score > white_score { + println!("Winner: Black (+{})", black_score - white_score); + } else { + println!("Draw"); + } + } + _ => eprintln!("Could not read final player scores."), + } + } + + (step, false) +} + +// ── Main 
────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + let mut roller = DiceRoller::new(args.seed); + + if args.games == 1 { + println!("=== Trictrac — random game ==="); + if let Some(s) = args.seed { + println!("seed: {s}"); + } + println!(); + run_game(&mut roller, args.max_steps, false, args.verbose); + } else { + println!("=== Trictrac — {} games ===", args.games); + if let Some(s) = args.seed { + println!("seed: {s}"); + } + println!(); + + let mut total_steps = 0u64; + let mut truncated = 0usize; + + let t0 = Instant::now(); + for _ in 0..args.games { + let (steps, trunc) = run_game(&mut roller, args.max_steps, !args.verbose, args.verbose); + total_steps += steps as u64; + if trunc { + truncated += 1; + } + } + let elapsed = t0.elapsed(); + + let secs = elapsed.as_secs_f64(); + println!("Games : {}", args.games); + println!("Truncated : {truncated}"); + println!("Total steps: {total_steps}"); + println!("Avg steps : {:.1}", total_steps as f64 / args.games as f64); + println!("Elapsed : {:.3} s", secs); + println!("Throughput : {:.1} games/s", args.games as f64 / secs); + println!(" {:.0} steps/s", total_steps as f64 / secs); + } +} diff --git a/store/src/board.rs b/store/src/board.rs index de0e450..0fba2d6 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -598,12 +598,40 @@ impl Board { core::array::from_fn(|i| i + min) } + /// Returns cumulative white-checker counts: result[i] = # white checkers in fields 1..=i. + /// result[0] = 0. 
+ pub fn white_checker_cumulative(&self) -> [u8; 25] { + let mut cum = [0u8; 25]; + let mut total = 0u8; + for (i, &count) in self.positions.iter().enumerate() { + if count > 0 { + total += count as u8; + } + cum[i + 1] = total; + } + cum + } + pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> { self.remove_checker(color, cmove.from)?; self.add_checker(color, cmove.to)?; Ok(()) } + /// Reverse a previously applied `move_checker`. No validation: assumes the move was valid. + pub fn unmove_checker(&mut self, color: &Color, cmove: CheckerMove) { + let unit = match color { + Color::White => 1, + Color::Black => -1, + }; + if cmove.from != 0 { + self.positions[cmove.from - 1] += unit; + } + if cmove.to != 0 { + self.positions[cmove.to - 1] -= unit; + } + } + pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> { if field == 0 { return Ok(()); diff --git a/store/src/cxxengine.rs b/store/src/cxxengine.rs index 29bc7fe..55d348c 100644 --- a/store/src/cxxengine.rs +++ b/store/src/cxxengine.rs @@ -83,8 +83,8 @@ pub mod ffi { /// Both players' scores. fn get_players_scores(self: &TricTracEngine) -> PlayerScores; - /// 36-element state vector (i8). Mirrored for player_idx == 1. - fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; + /// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1. + fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; /// Human-readable state description for `player_idx`. 
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; @@ -153,8 +153,7 @@ impl TricTracEngine { .map(|v| v.into_iter().map(|i| i as u64).collect()) } else { let mirror = self.game_state.mirror(); - get_valid_action_indices(&mirror) - .map(|v| v.into_iter().map(|i| i as u64).collect()) + get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect()) } })) } @@ -180,11 +179,11 @@ impl TricTracEngine { .unwrap_or(-1) } - fn get_tensor(&self, player_idx: u64) -> Vec { + fn get_tensor(&self, player_idx: u64) -> Vec { if player_idx == 0 { - self.game_state.to_vec() + self.game_state.to_tensor() } else { - self.game_state.mirror().to_vec() + self.game_state.mirror().to_tensor() } } @@ -243,8 +242,9 @@ impl TricTracEngine { self.game_state ), None => anyhow::bail!( - "apply_action: could not build event from action index {}", - action_idx + "apply_action: could not build event from action index {} in state {}", + action_idx, + self.game_state ), } })) diff --git a/store/src/game.rs b/store/src/game.rs index d32734d..e4e938c 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -156,13 +156,6 @@ impl GameState { if let Some(p1) = self.players.get(&1) { mirrored_players.insert(2, p1.mirror()); } - let mirrored_history = self - .history - .clone() - .iter() - .map(|evt| evt.get_mirror(false)) - .collect(); - let (move1, move2) = self.dice_moves; GameState { stage: self.stage, @@ -171,7 +164,7 @@ impl GameState { active_player_id: mirrored_active_player, // active_player_id: self.active_player_id, players: mirrored_players, - history: mirrored_history, + history: Vec::new(), dice: self.dice, dice_points: self.dice_points, dice_moves: (move1.mirror(), move2.mirror()), @@ -207,6 +200,110 @@ impl GameState { self.to_vec().iter().map(|&x| x as f32).collect() } + /// Get state as a tensor for neural network training (Option B, TD-Gammon style). + /// Returns 217 f32 values, all normalized to [0, 1]. 
+ /// + /// Must be called from the active player's perspective: callers should mirror + /// the GameState for Black before calling so that "own" always means White. + /// + /// Layout: + /// [0..95] own (White) checkers: 4 values per field × 24 fields + /// [96..191] opp (Black) checkers: 4 values per field × 24 fields + /// [192..193] dice values / 6 + /// [194] active player color (0=White, 1=Black) + /// [195] turn_stage / 5 + /// [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille + /// [200..203] Black player: same + /// [204..207] own quarter filled (quarters 1-4) + /// [208..211] opp quarter filled (quarters 1-4) + /// [212] own checkers all in exit zone (fields 19-24) + /// [213] opp checkers all in exit zone (fields 1-6) + /// [214] own coin de repos taken (field 12 has ≥2 own checkers) + /// [215] opp coin de repos taken (field 13 has ≥2 opp checkers) + /// [216] own dice_roll_count / 3, clamped to 1 + pub fn to_tensor(&self) -> Vec { + let mut t = Vec::with_capacity(217); + let pos: Vec = self.board.to_vec(); // 24 elements, positive=White, negative=Black + + // [0..95] own (White) checkers, TD-Gammon encoding. + // Each field contributes 4 values: + // (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1] + // The overflow term is divided by 12 because the maximum excess is + // 15 (all checkers) − 3 = 12. + for &c in &pos { + let own = c.max(0) as u8; + t.push((own == 1) as u8 as f32); + t.push((own == 2) as u8 as f32); + t.push((own == 3) as u8 as f32); + t.push(own.saturating_sub(3) as f32 / 12.0); + } + + // [96..191] opp (Black) checkers, same encoding. 
+ for &c in &pos { + let opp = (-c).max(0) as u8; + t.push((opp == 1) as u8 as f32); + t.push((opp == 2) as u8 as f32); + t.push((opp == 3) as u8 as f32); + t.push(opp.saturating_sub(3) as f32 / 12.0); + } + + // [192..193] dice + t.push(self.dice.values.0 as f32 / 6.0); + t.push(self.dice.values.1 as f32 / 6.0); + + // [194] active player color + t.push( + self.who_plays() + .map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 }) + .unwrap_or(0.0), + ); + + // [195] turn stage + t.push(u8::from(self.turn_stage) as f32 / 5.0); + + // [196..199] White player stats + let wp = self.get_white_player(); + t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0)); + t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0)); + t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32)); + t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32)); + + // [200..203] Black player stats + let bp = self.get_black_player(); + t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0)); + t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0)); + t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32)); + t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32)); + + // [204..207] own (White) quarter fill status + for &start in &[1usize, 7, 13, 19] { + t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32); + } + + // [208..211] opp (Black) quarter fill status + for &start in &[1usize, 7, 13, 19] { + t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32); + } + + // [212] can_exit_own: no own checker in fields 1-18 + t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32); + + // [213] can_exit_opp: no opp checker in fields 7-24 + t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32); + + // [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers) + t.push((pos[11] >= 2) as u8 as f32); + + // [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers) + t.push((pos[12] <= -2) as u8 as f32); + + // [216] own dice_roll_count / 3, clamped to 1 + 
t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0)); + + debug_assert_eq!(t.len(), 217, "to_tensor length mismatch"); + t + } + /// Get state as a vector (to be used for bot training input) : /// length = 36 /// i8 for board positions with negative values for blacks @@ -914,6 +1011,16 @@ impl GameState { self.mark_points(player_id, points) } + /// Total accumulated score for a player: `holes × 12 + points`. + /// + /// Returns `0` if `player_id` is not found (e.g. before `init_player`). + pub fn total_score(&self, player_id: PlayerId) -> i32 { + self.players + .get(&player_id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + } + fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { // Update player points and holes let mut new_hole = false; diff --git a/store/src/game_rules_moves.rs b/store/src/game_rules_moves.rs index 41221f2..396bcaf 100644 --- a/store/src/game_rules_moves.rs +++ b/store/src/game_rules_moves.rs @@ -220,7 +220,7 @@ impl MoveRules { // Si possible, les deux dés doivent être joués if moves.0.get_from() == 0 || moves.1.get_from() == 0 { let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]); - possible_moves_sequences.retain(|moves| self.check_exit_rules(moves).is_ok()); + possible_moves_sequences.retain(|moves| self.check_exit_rules(moves, None).is_ok()); // possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok()); if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() { if *moves == (EMPTY_MOVE, EMPTY_MOVE) { @@ -238,7 +238,7 @@ impl MoveRules { // check exit rules // if !ignored_rules.contains(&TricTracRule::Exit) { - self.check_exit_rules(moves)?; + self.check_exit_rules(moves, None)?; // } // --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ---- @@ -321,7 +321,11 @@ impl MoveRules { .is_empty() } - fn check_exit_rules(&self, moves: &(CheckerMove, CheckerMove)) -> Result<(), MoveError> { + fn 
check_exit_rules( + &self, + moves: &(CheckerMove, CheckerMove), + exit_seqs: Option<&[(CheckerMove, CheckerMove)]>, + ) -> Result<(), MoveError> { if !moves.0.is_exit() && !moves.1.is_exit() { return Ok(()); } @@ -331,16 +335,22 @@ impl MoveRules { } // toutes les sorties directes sont autorisées, ainsi que les nombres défaillants - let ignored_rules = vec![TricTracRule::Exit]; - let possible_moves_sequences_without_excedent = - self.get_possible_moves_sequences(false, ignored_rules); - if possible_moves_sequences_without_excedent.contains(moves) { + let owned; + let seqs = match exit_seqs { + Some(s) => s, + None => { + owned = self + .get_possible_moves_sequences(false, vec![TricTracRule::Exit]); + &owned + } + }; + if seqs.contains(moves) { return Ok(()); } // À ce stade au moins un des déplacements concerne un nombre en excédant // - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on // refuse cette séquence - if !possible_moves_sequences_without_excedent.is_empty() { + if !seqs.is_empty() { return Err(MoveError::ExitByEffectPossible); } @@ -361,17 +371,24 @@ impl MoveRules { let _ = board_to_check.move_checker(&Color::White, moves.0); let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check); - let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves); - if (is_move1_exedant && moves.0.get_from() != farthest_on_move1) - || (is_move2_exedant && moves.1.get_from() != farthest_on_move2) - { + // dice normal order + let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, true); + let is_not_farthest1 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1) + || (is_move2_exedant && moves.1.get_from() != farthest_on_move2); + + // dice reversed order + let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, false); + let is_not_farthest2 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1) + || (is_move2_exedant && moves.1.get_from() != farthest_on_move2); + + if 
is_not_farthest1 && is_not_farthest2 { return Err(MoveError::ExitNotFarthest); } Ok(()) } - fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) { + fn move_excedants(&self, moves: &(CheckerMove, CheckerMove), dice_order: bool) -> (bool, bool) { let move1to = if moves.0.get_to() == 0 { 25 } else { @@ -386,20 +403,16 @@ impl MoveRules { }; let dist2 = move2to - moves.1.get_from(); - let dist_min = cmp::min(dist1, dist2); - let dist_max = cmp::max(dist1, dist2); - - let dice_min = cmp::min(self.dice.values.0, self.dice.values.1) as usize; - let dice_max = cmp::max(self.dice.values.0, self.dice.values.1) as usize; - - let min_excedant = dist_min != 0 && dist_min < dice_min; - let max_excedant = dist_max != 0 && dist_max < dice_max; - - if dist_min == dist1 { - (min_excedant, max_excedant) + let (dice1, dice2) = if dice_order { + self.dice.values } else { - (max_excedant, min_excedant) - } + (self.dice.values.1, self.dice.values.0) + }; + + ( + dist1 != 0 && dist1 < dice1 as usize, + dist2 != 0 && dist2 < dice2 as usize, + ) } fn get_board_exit_farthest(board: &Board) -> Field { @@ -438,12 +451,18 @@ impl MoveRules { } else { (dice2, dice1) }; + let filling_seqs = if !ignored_rules.contains(&TricTracRule::MustFillQuarter) { + Some(self.get_quarter_filling_moves_sequences()) + } else { + None + }; let mut moves_seqs = self.get_possible_moves_sequences_by_dices( dice_max, dice_min, with_excedents, false, - ignored_rules.clone(), + &ignored_rules, + filling_seqs.as_deref(), ); // if we got valid sequences with the highest die, we don't accept sequences using only the // lowest die @@ -453,7 +472,8 @@ impl MoveRules { dice_max, with_excedents, ignore_empty, - ignored_rules, + &ignored_rules, + filling_seqs.as_deref(), ); moves_seqs.append(&mut moves_seqs_order2); let empty_removed = moves_seqs @@ -524,14 +544,16 @@ impl MoveRules { let mut moves_seqs = Vec::new(); let color = &Color::White; let ignored_rules = vec![TricTracRule::Exit, 
TricTracRule::MustFillQuarter]; + let mut board = self.board.clone(); for moves in self.get_possible_moves_sequences(true, ignored_rules) { - let mut board = self.board.clone(); board.move_checker(color, moves.0).unwrap(); board.move_checker(color, moves.1).unwrap(); // println!("get_quarter_filling_moves_sequences board : {:?}", board); if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) { moves_seqs.push(moves); } + board.unmove_checker(color, moves.1); + board.unmove_checker(color, moves.0); } moves_seqs } @@ -542,18 +564,27 @@ impl MoveRules { dice2: u8, with_excedents: bool, ignore_empty: bool, - ignored_rules: Vec, + ignored_rules: &[TricTracRule], + filling_seqs: Option<&[(CheckerMove, CheckerMove)]>, ) -> Vec<(CheckerMove, CheckerMove)> { let mut moves_seqs = Vec::new(); let color = &Color::White; let forbid_exits = self.has_checkers_outside_last_quarter(); + // Precompute non-excedant sequences once so check_exit_rules need not repeat + // the full move generation for every exit-move candidate. + // Only needed when Exit is not already ignored and exits are actually reachable. 
+ let exit_seqs = if !ignored_rules.contains(&TricTracRule::Exit) && !forbid_exits { + Some(self.get_possible_moves_sequences(false, vec![TricTracRule::Exit])) + } else { + None + }; + let mut board = self.board.clone(); // println!("==== First"); for first_move in self.board .get_possible_moves(*color, dice1, with_excedents, false, forbid_exits) { - let mut board2 = self.board.clone(); - if board2.move_checker(color, first_move).is_err() { + if board.move_checker(color, first_move).is_err() { println!("err move"); continue; } @@ -563,7 +594,7 @@ impl MoveRules { let mut has_second_dice_move = false; // println!(" ==== Second"); for second_move in - board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) + board.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) { if self .check_corner_rules(&(first_move, second_move)) @@ -587,24 +618,10 @@ impl MoveRules { && self.can_take_corner_by_effect()) && (ignored_rules.contains(&TricTracRule::Exit) || self - .check_exit_rules(&(first_move, second_move)) - // .inspect_err(|e| { - // println!( - // " 2nd (exit rule): {:?} - {:?}, {:?}", - // e, first_move, second_move - // ) - // }) - .is_ok()) - && (ignored_rules.contains(&TricTracRule::MustFillQuarter) - || self - .check_must_fill_quarter_rule(&(first_move, second_move)) - // .inspect_err(|e| { - // println!( - // " 2nd: {:?} - {:?}, {:?} for {:?}", - // e, first_move, second_move, self.board - // ) - // }) + .check_exit_rules(&(first_move, second_move), exit_seqs.as_deref()) .is_ok()) + && filling_seqs + .map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, second_move))) { if second_move.get_to() == 0 && first_move.get_to() == 0 @@ -627,16 +644,14 @@ impl MoveRules { && !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE)) && self.can_take_corner_by_effect()) && (ignored_rules.contains(&TricTracRule::Exit) - || self.check_exit_rules(&(first_move, EMPTY_MOVE)).is_ok()) - && 
(ignored_rules.contains(&TricTracRule::MustFillQuarter) - || self - .check_must_fill_quarter_rule(&(first_move, EMPTY_MOVE)) - .is_ok()) + || self.check_exit_rules(&(first_move, EMPTY_MOVE), exit_seqs.as_deref()).is_ok()) + && filling_seqs + .map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, EMPTY_MOVE))) { // empty move moves_seqs.push((first_move, EMPTY_MOVE)); } - //if board2.get_color_fields(*color).is_empty() { + board.unmove_checker(color, first_move); } moves_seqs } @@ -1495,6 +1510,7 @@ mod tests { CheckerMove::new(23, 0).unwrap(), CheckerMove::new(24, 0).unwrap(), ); + let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( vec![moves], state.get_possible_moves_sequences_by_dices( @@ -1502,7 +1518,8 @@ mod tests { state.dice.values.1, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); @@ -1517,6 +1534,7 @@ mod tests { CheckerMove::new(19, 23).unwrap(), CheckerMove::new(22, 0).unwrap(), )]; + let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( moves, state.get_possible_moves_sequences_by_dices( @@ -1524,7 +1542,8 @@ mod tests { state.dice.values.1, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); let moves = vec![( @@ -1538,7 +1557,8 @@ mod tests { state.dice.values.0, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); @@ -1554,6 +1574,7 @@ mod tests { CheckerMove::new(19, 21).unwrap(), CheckerMove::new(23, 0).unwrap(), ); + let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( vec![moves], state.get_possible_moves_sequences_by_dices( @@ -1561,7 +1582,8 @@ mod tests { state.dice.values.1, true, false, - vec![] + &[], + filling_seqs.as_deref(), ) ); } @@ -1580,13 +1602,26 @@ mod tests { CheckerMove::new(19, 23).unwrap(), CheckerMove::new(22, 0).unwrap(), ); - assert!(state.check_exit_rules(&moves).is_ok()); + assert!(state.check_exit_rules(&moves, None).is_ok()); let moves = ( CheckerMove::new(19, 24).unwrap(), 
CheckerMove::new(22, 0).unwrap(), ); - assert!(state.check_exit_rules(&moves).is_ok()); + assert!(state.check_exit_rules(&moves, None).is_ok()); + + state.dice.values = (6, 4); + state.board.set_positions( + &crate::Color::White, + [ + -4, -1, -2, -1, 0, 0, 0, -1, 0, 0, 0, 0, -5, -1, 0, 0, 0, 0, 2, 3, 2, 2, 5, 1, + ], + ); + let moves = ( + CheckerMove::new(20, 24).unwrap(), + CheckerMove::new(23, 0).unwrap(), + ); + assert!(state.check_exit_rules(&moves, None).is_ok()); } #[test] diff --git a/store/src/pyengine.rs b/store/src/pyengine.rs index b193987..43b5713 100644 --- a/store/src/pyengine.rs +++ b/store/src/pyengine.rs @@ -113,11 +113,11 @@ impl TricTrac { [self.get_score(1), self.get_score(2)] } - fn get_tensor(&self, player_idx: u64) -> Vec { + fn get_tensor(&self, player_idx: u64) -> Vec { if player_idx == 0 { - self.game_state.to_vec() + self.game_state.to_tensor() } else { - self.game_state.mirror().to_vec() + self.game_state.mirror().to_tensor() } } diff --git a/store/src/training_common.rs b/store/src/training_common.rs index 57094a9..69765fc 100644 --- a/store/src/training_common.rs +++ b/store/src/training_common.rs @@ -3,7 +3,6 @@ use std::cmp::{max, min}; use std::fmt::{Debug, Display, Formatter}; -use crate::board::Board; use crate::{CheckerMove, Dice, GameEvent, GameState}; use serde::{Deserialize, Serialize}; @@ -221,10 +220,14 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result anyhow::Result anyhow::Result anyhow::Result { - let dice = &state.dice; - let board = &state.board; - - if color == &crate::Color::Black { - // Moves are already 'white', so we don't mirror them - white_checker_moves_to_trictrac_action( - move1, - move2, - // &move1.clone().mirror(), - // &move2.clone().mirror(), - dice, - &board.clone().mirror(), - ) - // .map(|a| a.mirror()) + // Moves are always in White's coordinate system. For Black, mirror the board first. 
+ let cum = if color == &crate::Color::Black { + state.board.mirror().white_checker_cumulative() } else { - white_checker_moves_to_trictrac_action(move1, move2, dice, board) - } + state.board.white_checker_cumulative() + }; + white_checker_moves_to_trictrac_action(move1, move2, &state.dice, &cum) } fn white_checker_moves_to_trictrac_action( move1: &CheckerMove, move2: &CheckerMove, dice: &Dice, - board: &Board, + cum: &[u8; 25], ) -> anyhow::Result { let to1 = move1.get_to(); let to2 = move2.get_to(); @@ -302,7 +300,7 @@ fn white_checker_moves_to_trictrac_action( } } else { // double sortie - if from1 < from2 { + if from1 < from2 || from2 == 0 { max(dice.values.0, dice.values.1) as usize } else { min(dice.values.0, dice.values.1) as usize @@ -321,11 +319,21 @@ fn white_checker_moves_to_trictrac_action( } let dice_order = diff_move1 == dice.values.0 as usize; - let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize; - let mut tmp_board = board.clone(); - // should not raise an error for a valid action - tmp_board.move_checker(&crate::Color::White, *move1)?; - let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; + // cum[i] = # white checkers in fields 1..=i (precomputed by the caller). + // checker1 is the ordinal of the last checker at from1. + let checker1 = cum[from1] as usize; + // checker2 is the ordinal on the board after move1 (removed from from1, added to to1). + // Adjust the cumulative in O(1) without cloning the board. 
+ let checker2 = { + let mut c = cum[from2]; + if from1 > 0 && from2 >= from1 { + c -= 1; // one checker was removed from from1, shifting later ordinals down + } + if from1 > 0 && to1 > 0 && from2 >= to1 { + c += 1; // one checker was added at to1, shifting later ordinals up + } + c as usize + }; Ok(TrictracAction::Move { dice_order, checker1, @@ -456,5 +464,48 @@ mod tests { }), ttaction.ok() ); + + // Black player + state.active_player_id = 2; + state.dice = Dice { values: (6, 3) }; + state.board.set_positions( + &crate::Color::White, + [ + 2, -11, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 6, 4, + ], + ); + let ttaction = super::checker_moves_to_trictrac_action( + &CheckerMove::new(21, 0).unwrap(), + &CheckerMove::new(0, 0).unwrap(), + &crate::Color::Black, + &state, + ); + + assert_eq!( + Some(TrictracAction::Move { + dice_order: true, + checker1: 2, + checker2: 0, // blocked by white on last field + }), + ttaction.ok() + ); + + // same with dice order reversed + state.dice = Dice { values: (3, 6) }; + let ttaction = super::checker_moves_to_trictrac_action( + &CheckerMove::new(21, 0).unwrap(), + &CheckerMove::new(0, 0).unwrap(), + &crate::Color::Black, + &state, + ); + + assert_eq!( + Some(TrictracAction::Move { + dice_order: false, + checker1: 2, + checker2: 0, // blocked by white on last field + }), + ttaction.ok() + ); } }