diff --git a/Cargo.lock b/Cargo.lock index a6c9481..a43261e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,12 +92,6 @@ dependencies = [ "libc", ] -[[package]] -name = "anes" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" - [[package]] name = "anstream" version = "0.6.21" @@ -1122,12 +1116,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" -[[package]] -name = "cast" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" - [[package]] name = "cast_trait" version = "0.1.2" @@ -1212,33 +1200,6 @@ dependencies = [ "rand 0.7.3", ] -[[package]] -name = "ciborium" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" - -[[package]] -name = "ciborium-ll" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" -dependencies = [ - "ciborium-io", - "half", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1492,42 +1453,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "criterion" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" -dependencies = [ - "anes", - "cast", - "ciborium", - "clap", - "criterion-plot", - "is-terminal", - "itertools 
0.10.5", - "num-traits", - "once_cell", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "walkdir", -] - -[[package]] -name = "criterion-plot" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" -dependencies = [ - "cast", - "itertools 0.10.5", -] - [[package]] name = "critical-section" version = "1.2.0" @@ -4536,12 +4461,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" -[[package]] -name = "oorandom" -version = "11.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" - [[package]] name = "opaque-debug" version = "0.3.1" @@ -4678,34 +4597,6 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" -[[package]] -name = "plotters" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" - -[[package]] -name = "plotters-svg" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" -dependencies = [ - "plotters-backend", -] - [[package]] name = "png" version = "0.18.0" @@ -6000,19 +5891,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = 
"spiel_bot" -version = "0.1.0" -dependencies = [ - "anyhow", - "burn", - "criterion", - "rand 0.9.2", - "rand_distr", - "rayon", - "trictrac-store", -] - [[package]] name = "spin" version = "0.10.0" @@ -6421,16 +6299,6 @@ dependencies = [ "zerovec", ] -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "tinyvec" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 4c2eb15..b9e6d45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store", "spiel_bot"] +members = ["client_cli", "bot", "store"] diff --git a/doc/plan_cxxbindings.md b/doc/plan_cxxbindings.md new file mode 100644 index 0000000..29bf314 --- /dev/null +++ b/doc/plan_cxxbindings.md @@ -0,0 +1,992 @@ +# Plan: C++ OpenSpiel Game via cxx.rs + +> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings. +> +> Base on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`. + +--- + +## 1. Overview + +The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding: + +- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`. +- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`. 
+- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box` and delegates all logic to Rust. +- Build wired together via **corrosion** (CMake-native Rust integration) and `cxx-build`. + +The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games. + +--- + +## 2. Files to Create / Modify + +``` +trictrac/ + store/ + Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type + build.rs ← CREATE: cxx-build bridge registration + src/ + lib.rs ← MODIFY: add cxxengine module + cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl + +forks/open_spiel/ + CMakeLists.txt ← MODIFY: add Corrosion FetchContent + open_spiel/ + games/ + CMakeLists.txt ← MODIFY: add trictrac/ sources + test + trictrac/ ← CREATE directory + trictrac.h ← CREATE + trictrac.cc ← CREATE + trictrac_test.cc ← CREATE + + justfile ← MODIFY: add buildtrictrac target +trictrac/ + justfile ← MODIFY: add cxxlib target +``` + +--- + +## 3. Step 1 — Rust: `store/Cargo.toml` + +Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library. + +```toml +[package] +name = "trictrac-store" +version = "0.1.0" +edition = "2021" + +[lib] +name = "trictrac_store" +# cdylib → Python .so (used by maturin / pyengine) +# rlib → used by other Rust crates in the workspace +# staticlib → used by C++ consumers (cxxengine) +crate-type = ["cdylib", "rlib", "staticlib"] + +[dependencies] +base64 = "0.21.7" +cxx = "1.0" +log = "0.4.20" +merge = "0.1.0" +pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } +rand = "0.9" +serde = { version = "1.0", features = ["derive"] } +transpose = "0.2.2" + +[build-dependencies] +cxx-build = "1.0" +``` + +> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. 
The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`). + +--- + +## 4. Step 2 — Rust: `store/build.rs` + +The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header. + +```rust +fn main() { + cxx_build::bridge("src/cxxengine.rs") + .std("c++17") + .compile("trictrac-cxx"); + + // Re-run if the bridge source changes + println!("cargo:rerun-if-changed=src/cxxengine.rs"); +} +``` + +`cxx-build` will: + +- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block. +- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header. +- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source. +- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`). + +--- + +## 5. Step 3 — Rust: `store/src/cxxengine.rs` + +This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3. + +### Design decisions vs. `pyengine.rs` + +| pyengine | cxxengine | Reason | +| ------------------------- | ---------------------------- | -------------------------------------------- | +| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception | +| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples | +| `Vec<usize>` for actions | `Vec<u64>` | cxx does not support `usize` | +| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays | +| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy | + +### File content + +```rust +//! # C++ bindings for the TricTrac game engine via cxx.rs +//! +//! Exposes an opaque `TricTracEngine` type and associated functions +//! to C++. 
The C++ side (trictrac.cc) uses `rust::Box<TricTracEngine>`. +//! +//! The Rust engine always works from the perspective of White (player 1). +//! For Black (player 2), the board is mirrored before computing actions +//! and events are mirrored back before applying — exactly as in pyengine.rs. + +use crate::dice::Dice; +use crate::game::{GameEvent, GameState, Stage, TurnStage}; +use crate::training_common::{get_valid_action_indices, TrictracAction}; + +// ── cxx bridge declaration ──────────────────────────────────────────────────── + +#[cxx::bridge(namespace = "trictrac_engine")] +pub mod ffi { + // ── Shared types (visible to both Rust and C++) ─────────────────────────── + + /// Two dice values passed from C++ to Rust for a dice-roll event. + struct DicePair { + die1: u8, + die2: u8, + } + + /// Both players' scores: holes * 12 + points. + struct PlayerScores { + score_p1: i32, + score_p2: i32, + } + + // ── Opaque Rust type exposed to C++ ─────────────────────────────────────── + + extern "Rust" { + /// Opaque handle to a TricTrac game state. + /// C++ accesses this only through `rust::Box<TricTracEngine>`. + type TricTracEngine; + + /// Create a new engine, initialise two players, begin with player 1. + fn new_trictrac_engine() -> Box<TricTracEngine>; + + /// Return a deep copy of the engine (needed for State::Clone()). + fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>; + + // ── Queries ─────────────────────────────────────────────────────────── + + /// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node). + fn needs_roll(self: &TricTracEngine) -> bool; + + /// True when Stage::Ended. + fn is_game_ended(self: &TricTracEngine) -> bool; + + /// Active player index: 0 (player 1 / White) or 1 (player 2 / Black). + fn current_player_idx(self: &TricTracEngine) -> u64; + + /// Legal action indices for `player_idx`. Returns empty vec if it is + /// not that player's turn. Indices are in [0, 513]. 
+ fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec<u64>; + + /// Human-readable action description, e.g. "0:Move { dice_order: true … }". + fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String; + + /// Both players' scores: holes * 12 + points. + fn get_players_scores(self: &TricTracEngine) -> PlayerScores; + + /// 36-element state observation vector (i8). Mirrored for player_idx 1 (Black). + fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>; + + /// Human-readable state description for `player_idx`. + fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; + + /// Full debug representation of the current state. + fn to_debug_string(self: &TricTracEngine) -> String; + + // ── Mutations ───────────────────────────────────────────────────────── + + /// Apply a dice roll result. Returns Err if not in RollWaiting stage. + fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>; + + /// Apply a player action (move, go, roll). Returns Err if invalid. + fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>; + } +} + +// ── Opaque type implementation ──────────────────────────────────────────────── + +pub struct TricTracEngine { + game_state: GameState, +} + +pub fn new_trictrac_engine() -> Box<TricTracEngine> { + let mut game_state = GameState::new(false); // schools_enabled = false + game_state.init_player("player1"); + game_state.init_player("player2"); + game_state.consume(&GameEvent::BeginGame { goes_first: 1 }); + Box::new(TricTracEngine { game_state }) +} + +impl TricTracEngine { + fn clone_engine(&self) -> Box<TricTracEngine> { + Box::new(TricTracEngine { + game_state: self.game_state.clone(), + }) + } + + fn needs_roll(&self) -> bool { + self.game_state.turn_stage == TurnStage::RollWaiting + } + + fn is_game_ended(&self) -> bool { + self.game_state.stage == Stage::Ended + } + + /// Returns 0 for player 1 (White) and 1 for player 2 (Black). 
+ fn current_player_idx(&self) -> u64 { + self.game_state.active_player_id - 1 + } + + fn get_legal_actions(&self, player_idx: u64) -> Vec { + if player_idx == self.current_player_idx() { + if player_idx == 0 { + get_valid_action_indices(&self.game_state) + .into_iter() + .map(|i| i as u64) + .collect() + } else { + let mirror = self.game_state.mirror(); + get_valid_action_indices(&mirror) + .into_iter() + .map(|i| i as u64) + .collect() + } + } else { + vec![] + } + } + + fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String { + TrictracAction::from_action_index(action_idx as usize) + .map(|a| format!("{}:{}", player_idx, a)) + .unwrap_or_else(|| "unknown action".into()) + } + + fn get_players_scores(&self) -> ffi::PlayerScores { + ffi::PlayerScores { + score_p1: self.score_for(1), + score_p2: self.score_for(2), + } + } + + fn score_for(&self, player_id: u64) -> i32 { + if let Some(player) = self.game_state.players.get(&player_id) { + player.holes as i32 * 12 + player.points as i32 + } else { + -1 + } + } + + fn get_tensor(&self, player_idx: u64) -> Vec { + if player_idx == 0 { + self.game_state.to_vec() + } else { + self.game_state.mirror().to_vec() + } + } + + fn get_observation_string(&self, player_idx: u64) -> String { + if player_idx == 0 { + format!("{}", self.game_state) + } else { + format!("{}", self.game_state.mirror()) + } + } + + fn to_debug_string(&self) -> String { + format!("{}", self.game_state) + } + + fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> { + let player_id = self.game_state.active_player_id; + if self.game_state.turn_stage != TurnStage::RollWaiting { + return Err("Not in RollWaiting stage".into()); + } + let dice = Dice { + values: (dice.die1, dice.die2), + }; + self.game_state + .consume(&GameEvent::RollResult { player_id, dice }); + Ok(()) + } + + fn apply_action(&mut self, action_idx: u64) -> Result<(), String> { + let action_idx = action_idx as usize; + let needs_mirror = 
self.game_state.active_player_id == 2; + + let event = TrictracAction::from_action_index(action_idx) + .and_then(|a| { + let game_state = if needs_mirror { + &self.game_state.mirror() + } else { + &self.game_state + }; + a.to_event(game_state) + .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) + }); + + match event { + Some(evt) if self.game_state.validate(&evt) => { + self.game_state.consume(&evt); + Ok(()) + } + Some(_) => Err("Action is invalid".into()), + None => Err("Could not build event from action index".into()), + } + } +} +``` + +> **Note on `Result<(), String>`**: cxx.rs requires the error type to implement `std::error::Error`. `String` does not implement it directly. Two options: +> +> - Use `anyhow::Error` (add `anyhow` dependency). +> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`. +> +> The recommended approach is `anyhow`: +> +> ```toml +> [dependencies] +> anyhow = "1.0" +> ``` +> +> Then `fn apply_action(...) -> Result<(), anyhow::Error>` — cxx.rs will convert this to a C++ exception of type `rust::Error` carrying the message. + +--- + +## 6. Step 4 — Rust: `store/src/lib.rs` + +Add the new module: + +```rust +// existing modules … +mod pyengine; + +// NEW: C++ bindings via cxx.rs +pub mod cxxengine; +``` + +--- + +## 7. Step 5 — C++: `trictrac/trictrac.h` + +Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box` and delegates everything to it. + +```cpp +// open_spiel/games/trictrac/trictrac.h +#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_ +#define OPEN_SPIEL_GAMES_TRICTRAC_H_ + +#include +#include +#include + +#include "open_spiel/spiel.h" +#include "open_spiel/spiel_utils.h" + +// Generated by cxx-build from store/src/cxxengine.rs. +// The include path is set by CMake (see CMakeLists.txt). 
+#include "trictrac_store/src/cxxengine.rs.h" + +namespace open_spiel { +namespace trictrac { + +inline constexpr int kNumPlayers = 2; +inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes +inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust +inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust +inline constexpr int kDefaultMaxTurns = 1000; + +class TrictracGame; + +// --------------------------------------------------------------------------- +// TrictracState +// --------------------------------------------------------------------------- +class TrictracState : public State { + public: + explicit TrictracState(std::shared_ptr<const Game> game); + TrictracState(const TrictracState& other); + + Player CurrentPlayer() const override; + std::vector<Action> LegalActions() const override; + std::string ActionToString(Player player, Action move_id) const override; + std::vector<std::pair<Action, double>> ChanceOutcomes() const override; + std::string ToString() const override; + bool IsTerminal() const override; + std::vector<double> Returns() const override; + std::string ObservationString(Player player) const override; + void ObservationTensor(Player player, absl::Span<float> values) const override; + std::unique_ptr<State> Clone() const override; + + protected: + void DoApplyAction(Action move_id) override; + + private: + // Decode a chance action index [0,35] to (die1, die2). + // Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action] + static trictrac_engine::DicePair DecodeChanceAction(Action action); + + // The Rust engine handle. Deep-copied via clone_engine() when cloning state. 
+ rust::Box engine_; +}; + +// --------------------------------------------------------------------------- +// TrictracGame +// --------------------------------------------------------------------------- +class TrictracGame : public Game { + public: + explicit TrictracGame(const GameParameters& params); + + int NumDistinctActions() const override { return kNumDistinctActions; } + std::unique_ptr NewInitialState() const override; + int MaxChanceOutcomes() const override { return kNumChanceOutcomes; } + int NumPlayers() const override { return kNumPlayers; } + double MinUtility() const override { return 0.0; } + double MaxUtility() const override { return 200.0; } + int MaxGameLength() const override { return 3 * max_turns_; } + int MaxChanceNodesInHistory() const override { return MaxGameLength(); } + std::vector ObservationTensorShape() const override { + return {kStateEncodingSize}; + } + + private: + int max_turns_; +}; + +} // namespace trictrac +} // namespace open_spiel + +#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_ +``` + +--- + +## 8. 
Step 6 — C++: `trictrac/trictrac.cc` + +```cpp +// open_spiel/games/trictrac/trictrac.cc +#include "open_spiel/games/trictrac/trictrac.h" + +#include +#include +#include + +#include "open_spiel/abseil-cpp/absl/types/span.h" +#include "open_spiel/game_parameters.h" +#include "open_spiel/spiel.h" +#include "open_spiel/spiel_globals.h" +#include "open_spiel/spiel_utils.h" + +namespace open_spiel { +namespace trictrac { +namespace { + +// ── Game registration ──────────────────────────────────────────────────────── + +const GameType kGameType{ + /*short_name=*/"trictrac", + /*long_name=*/"Trictrac", + GameType::Dynamics::kSequential, + GameType::ChanceMode::kExplicitStochastic, + GameType::Information::kPerfectInformation, + GameType::Utility::kGeneralSum, + GameType::RewardModel::kRewards, + /*min_num_players=*/kNumPlayers, + /*max_num_players=*/kNumPlayers, + /*provides_information_state_string=*/false, + /*provides_information_state_tensor=*/false, + /*provides_observation_string=*/true, + /*provides_observation_tensor=*/true, + /*parameter_specification=*/{ + {"max_turns", GameParameter(kDefaultMaxTurns)}, + }}; + +static std::shared_ptr Factory(const GameParameters& params) { + return std::make_shared(params); +} + +REGISTER_SPIEL_GAME(kGameType, Factory); + +} // namespace + +// ── TrictracGame ───────────────────────────────────────────────────────────── + +TrictracGame::TrictracGame(const GameParameters& params) + : Game(kGameType, params), + max_turns_(ParameterValue("max_turns", kDefaultMaxTurns)) {} + +std::unique_ptr TrictracGame::NewInitialState() const { + return std::make_unique(shared_from_this()); +} + +// ── TrictracState ───────────────────────────────────────────────────────────── + +TrictracState::TrictracState(std::shared_ptr game) + : State(game), + engine_(trictrac_engine::new_trictrac_engine()) {} + +// Copy constructor: deep-copy the Rust engine via clone_engine(). 
+TrictracState::TrictracState(const TrictracState& other) + : State(other), + engine_(other.engine_->clone_engine()) {} + +std::unique_ptr TrictracState::Clone() const { + return std::make_unique(*this); +} + +// ── Current player ──────────────────────────────────────────────────────────── + +Player TrictracState::CurrentPlayer() const { + if (engine_->is_game_ended()) return kTerminalPlayerId; + if (engine_->needs_roll()) return kChancePlayerId; + return static_cast(engine_->current_player_idx()); +} + +// ── Legal actions ───────────────────────────────────────────────────────────── + +std::vector TrictracState::LegalActions() const { + if (IsChanceNode()) { + // All 36 dice outcomes are equally likely; return indices 0–35. + std::vector actions(kNumChanceOutcomes); + for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i; + return actions; + } + Player player = CurrentPlayer(); + rust::Vec rust_actions = + engine_->get_legal_actions(static_cast(player)); + std::vector actions; + actions.reserve(rust_actions.size()); + for (uint64_t a : rust_actions) actions.push_back(static_cast(a)); + return actions; +} + +// ── Chance outcomes ─────────────────────────────────────────────────────────── + +std::vector> TrictracState::ChanceOutcomes() const { + SPIEL_CHECK_TRUE(IsChanceNode()); + const double p = 1.0 / kNumChanceOutcomes; + std::vector> outcomes; + outcomes.reserve(kNumChanceOutcomes); + for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p); + return outcomes; +} + +// ── Apply action ────────────────────────────────────────────────────────────── + +/*static*/ +trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) { + // Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action] + return trictrac_engine::DicePair{ + /*die1=*/static_cast(action / 6 + 1), + /*die2=*/static_cast(action % 6 + 1), + }; +} + +void TrictracState::DoApplyAction(Action action) { + if (IsChanceNode()) { + 
engine_->apply_dice_roll(DecodeChanceAction(action)); + } else { + engine_->apply_action(static_cast(action)); + } +} + +// ── Terminal & returns ──────────────────────────────────────────────────────── + +bool TrictracState::IsTerminal() const { + return engine_->is_game_ended(); +} + +std::vector TrictracState::Returns() const { + trictrac_engine::PlayerScores scores = engine_->get_players_scores(); + return {static_cast(scores.score_p1), + static_cast(scores.score_p2)}; +} + +// ── Observation ─────────────────────────────────────────────────────────────── + +std::string TrictracState::ObservationString(Player player) const { + return std::string(engine_->get_observation_string( + static_cast(player))); +} + +void TrictracState::ObservationTensor(Player player, + absl::Span values) const { + SPIEL_CHECK_EQ(values.size(), kStateEncodingSize); + rust::Vec tensor = + engine_->get_tensor(static_cast(player)); + SPIEL_CHECK_EQ(tensor.size(), static_cast(kStateEncodingSize)); + for (int i = 0; i < kStateEncodingSize; ++i) { + values[i] = static_cast(tensor[i]); + } +} + +// ── Strings ─────────────────────────────────────────────────────────────────── + +std::string TrictracState::ToString() const { + return std::string(engine_->to_debug_string()); +} + +std::string TrictracState::ActionToString(Player player, Action action) const { + if (IsChanceNode()) { + trictrac_engine::DicePair d = DecodeChanceAction(action); + return "(" + std::to_string(d.die1) + ", " + std::to_string(d.die2) + ")"; + } + return std::string(engine_->action_to_string( + static_cast(player), static_cast(action))); +} + +} // namespace trictrac +} // namespace open_spiel +``` + +--- + +## 9. 
Step 7 — C++: `trictrac/trictrac_test.cc` + +```cpp +// open_spiel/games/trictrac/trictrac_test.cc +#include "open_spiel/games/trictrac/trictrac.h" + +#include +#include + +#include "open_spiel/spiel.h" +#include "open_spiel/tests/basic_tests.h" +#include "open_spiel/utils/init.h" + +namespace open_spiel { +namespace trictrac { +namespace { + +void BasicTrictracTests() { + testing::LoadGameTest("trictrac"); + testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5); +} + +} // namespace +} // namespace trictrac +} // namespace open_spiel + +int main(int argc, char** argv) { + open_spiel::Init(&argc, &argv); + open_spiel::trictrac::BasicTrictracTests(); + std::cout << "trictrac tests passed" << std::endl; + return 0; +} +``` + +--- + +## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt` + +The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined: + +```cmake +# ── Corrosion: CMake integration for Rust ──────────────────────────────────── +include(FetchContent) +FetchContent_Declare( + Corrosion + GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git + GIT_TAG v0.5.1 # pin to a stable release +) +FetchContent_MakeAvailable(Corrosion) + +# Import the trictrac-store Rust crate. +# This creates a CMake target named 'trictrac-store'. +corrosion_import_crate( + MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml + CRATES trictrac-store +) + +# Generate the cxx bridge from cxxengine.rs. 
+# corrosion_add_cxxbridge: +# - runs cxx-build as part of the Rust build +# - creates a CMake target 'trictrac_cxx_bridge' that: +# * compiles the generated C++ shim +# * exposes INTERFACE include dirs for the generated .rs.h header +corrosion_add_cxxbridge(trictrac_cxx_bridge + CRATE trictrac-store + FILES src/cxxengine.rs +) +``` + +> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing. + +--- + +## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt` + +Two changes: add the new source files to `GAME_SOURCES`, and add a test target. + +### 11.1 Add to `GAME_SOURCES` + +Find the alphabetically correct position (after `tic_tac_toe`, before `trade_comm`) and add: + +```cmake +set(GAME_SOURCES + # ... existing games ... + trictrac/trictrac.cc + trictrac/trictrac.h + # ... remaining games ... +) +``` + +### 11.2 Link cxx bridge into OpenSpiel objects + +The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed. + +The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add: + +```cmake +target_link_libraries(open_spiel + PUBLIC + trictrac_cxx_bridge # C++ shim generated by cxx-build + trictrac-store # Rust static library +) +``` + +If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game: + +```cmake +add_library(trictrac_game OBJECT + trictrac/trictrac.cc + trictrac/trictrac.h +) +target_include_directories(trictrac_game + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/.. 
+) +target_link_libraries(trictrac_game + PUBLIC + trictrac_cxx_bridge + trictrac-store + open_spiel_core # or whatever the core target is called +) +``` + +Then reference `$` in relevant executables. + +### 11.3 Add the test + +```cmake +add_executable(trictrac_test + trictrac/trictrac_test.cc + ${OPEN_SPIEL_OBJECTS} + $ +) +target_link_libraries(trictrac_test + PRIVATE + trictrac_cxx_bridge + trictrac-store +) +add_test(trictrac_test trictrac_test) +``` + +--- + +## 12. Step 10 — Justfile updates + +### `trictrac/justfile` — add `cxxlib` target + +Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists: + +```just +cxxlib: + cargo build --release -p trictrac-store + @echo "Static lib: $(ls target/release/libtrictrac_store.a)" + @echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)" +``` + +### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac` + +```just +buildtrictrac: + # Rebuild the Rust static lib first, then CMake + cd ../../trictrac && cargo build --release -p trictrac-store + mkdir -p build && cd build && \ + CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \ + make -j$(nproc) trictrac_test + +testtrictrac: buildtrictrac + ./build/trictrac_test + +playtrictrac_cpp: + ./build/examples/example --game=trictrac +``` + +--- + +## 13. Key Design Decisions + +### 13.1 Opaque type with `clone_engine()` + +OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box` which calls `.clone()` on the inner `GameState` (which derives `Clone`). + +### 13.2 Action encoding: same 514-element space + +The C++ game uses the same 514-action encoding as the Python version and the Rust training code. 
This means: + +- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies. +- Action 0 = Roll (used as the bridge between Move and the next chance node). +- Actions 2–513 = Move variants (checker ordinal pair + dice order). +- A trained C++ model and Python model share the same action space. + +### 13.3 Chance outcome ordering + +The dice outcome ordering is identical to the Python version: + +``` +action → (die1, die2) +0 → (1,1) 6 → (2,1) ... 35 → (6,6) +``` + +(`die1 = action/6 + 1`, `die2 = action%6 + 1`) + +This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines. + +### 13.4 `GameType::Utility::kGeneralSum` + `kRewards` + +Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). Intermediate hole rewards are returned by `Returns()` at every state, not just the terminal. + +### 13.5 Mirror pattern preserved + +`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring — it simply passes `player_idx` and the Rust engine handles the rest. + +### 13.6 `rust::Box` vs `rust::UniquePtr` + +`rust::Box` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. `rust::UniquePtr` is for C++ types held in Rust. + +### 13.7 Separate struct from `pyengine.rs` + +`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids: + +- PyO3 and cxx attributes conflicting on the same type. +- Changes to one binding breaking the other. +- Feature-flag complexity. + +--- + +## 14. Known Challenges + +### 14.1 Corrosion path resolution + +`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. 
Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable). + +### 14.2 `staticlib` + `cdylib` in one crate + +Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions: + +- The `cdylib` build (for maturin) does not need `staticlib`, and building both doubles the compile time. +- Consider gating `staticlib` behind a Cargo feature: `crate-type` is not directly feature-gatable, but you can work around this with a separate `Cargo.toml` or a workspace profile. +- Alternatively, accept the extra compile time during development. + +### 14.3 Linker symbols from Rust std + +When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (typically `-ldl -lm -lpthread -lc`). + +### 14.4 `anyhow` for error types + +cxx.rs requires the `Err` type in `Result<T, E>` to implement `std::error::Error + Send + Sync`. `String` does not satisfy this. Use `anyhow::Error` or define a thin newtype wrapper: + +```rust +use std::fmt; + +#[derive(Debug)] +struct EngineError(String); +impl fmt::Display for EngineError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) } +} +impl std::error::Error for EngineError {} +``` + +On the C++ side, errors become `rust::Error` exceptions. Wrap `DoApplyAction` in a try-catch during development to surface Rust errors as `SpielFatalError`. + +### 14.5 `UndoAction` not implemented + +OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. 
The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo — the history is append-only. To support undo, `Clone()` is the only reliable strategy (clone before applying, discard clone if undo needed). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks game-tree search. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo. + +### 14.6 Generated header path in `#include` + +The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the actual path that `cxx-build` (via corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory. + +### 14.7 `rust::String` to `std::string` conversion + +The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` is implicitly convertible to `std::string`. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast<std::string>(...)`. + +--- + +## 15. Complete File Checklist + +``` +[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib +[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...) 
+[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;` +[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation +[ ] trictrac/justfile — add `cxxlib` target +[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge +[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test +[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file +[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file +[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file +[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac +``` + +--- + +## 16. Implementation Order + +Implement in this sequence to get early feedback at each step: + +1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced. +2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations. +3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist. +4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them. +5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination. +6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`. + +--- + +# Implementation complete + +All files are in place and trictrac_test passes (168 84 final scores, all assertions pass). 
+ +What was done + +Rust side (from previous session, already complete): + +- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build +- trictrac/store/build.rs — drives cxx-build +- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl +- trictrac/store/src/lib.rs — added pub mod cxxengine; + +C++ side (this session): + +- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header +- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation +- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test + +Build system: + +- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge +- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable + +Justfiles: + +- trictrac/justfile — added cxxlib target +- forks/open_spiel/justfile — added buildtrictrac and testtrictrac + +Fixes discovered during build + +| Issue | Fix | +| ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------- | +| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries | +| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) | +| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h | +| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc | +| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) | +| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries | +| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort | +| Duplicate actions for doubles | Added std::unique after sort | +| Returns() returned non-zero 
at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states | diff --git a/doc/spiel_bot_parallel.md b/doc/spiel_bot_parallel.md deleted file mode 100644 index d9e021e..0000000 --- a/doc/spiel_bot_parallel.md +++ /dev/null @@ -1,121 +0,0 @@ -Part B — Batched MCTS leaf evaluation - -Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls. - -Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs) - -pub trait Evaluator: Send + Sync { -fn evaluate(&self, obs: &[f32]) -> (Vec, f32); - - /// Evaluate a batch of observations at once. Default falls back to - /// sequential calls; backends override this for efficiency. - fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec, f32)> { - obs_batch.iter().map(|obs| self.evaluate(obs)).collect() - } - -} - -Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs) - -Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows. - -fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec, f32)> { -let b = obs_batch.len(); -let obs_size = obs_batch[0].len(); -let flat: Vec = obs_batch.iter().flat_map(|o| o.iter().copied()).collect(); -let obs_tensor = Tensor::::from_data(TensorData::new(flat, [b, obs_size]), &self.device); -let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); -let policies: Vec = policy_tensor.into_data().to_vec().unwrap(); -let values: Vec = value_tensor.into_data().to_vec().unwrap(); -let action_size = policies.len() / b; -(0..b).map(|i| { -(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i]) -}).collect() -} - -Step B3 — Add eval_batch_size to MctsConfig - -pub struct MctsConfig { -// ... existing fields ... -/// Number of leaves to batch per network call. 1 = no batching (current behaviour). 
-pub eval_batch_size: usize, -} - -Default: 1 (backwards-compatible). - -Step B4 — Make the simulation iterative (mcts/search.rs) - -The current simulate is recursive. For batching we need to split it into two phases: - -descend (pure tree traversal — no network call): - -- Traverse from root following PUCT, advancing through chance nodes with apply_chance. -- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect). -- Return a LeafWork { path: Vec, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance. -- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path. - -ascend (backup — no network call): - -- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions. -- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value. 
- -Step B5 — Add run_mcts_batched to mcts/mod.rs - -The new entry point, called by run_mcts when config.eval_batch_size > 1: - -expand root (1 network call) -optionally add Dirichlet noise - -for round in 0..(n*simulations / batch_size): -leaves = [] -for * in 0..batch_size: -leaf = descend(root, state, env, rng) -leaves.push(leaf) - - obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves - where leaf.kind == NeedsEval] - results = evaluator.evaluate_batch(obs_batch) - - for (leaf, result) in zip(leaves, results): - expand the leaf node (insert children from result.policy) - ascend(root, leaf.path, result.value, leaf.player_idx) - // ascend also handles terminal and crossed-chance leaves - -// handle remainder: n_simulations % batch_size - -run_mcts becomes a thin dispatcher: -if config.eval_batch_size <= 1 { -// existing path (unchanged) -} else { -run_mcts_batched(...) -} - -Step B6 — CLI flag in az_train.rs - ---eval-batch N default: 8 Leaf batch size for MCTS network calls - ---- - -Summary of file changes - -┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ -│ File │ Changes │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ spiel_bot/Cargo.toml │ add rayon │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ 
-│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │ -└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ - -Key design constraint - -Parts A and B are independent and composable: - -- A only touches the outer game loop. -- B only touches the inner MCTS per game. -- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state. diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md deleted file mode 100644 index a8863af..0000000 --- a/doc/spiel_bot_research.md +++ /dev/null @@ -1,782 +0,0 @@ -# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac - -## 0. Context and Scope - -The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library -(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` -encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every -other stage to an inline random-opponent loop. - -`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency -for **self-play training**. Its goals: - -- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") - that works with Trictrac's multi-stage turn model and stochastic dice. -- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) - as the first algorithm. -- Remain **modular**: adding DQN or PPO later requires only a new - `impl Algorithm for Dqn` without touching the environment or network layers. -- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from - `trictrac-store`. - ---- - -## 1. 
Library Landscape - -### 1.1 Neural Network Frameworks - -| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | -| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- | -| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | -| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | -| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | -| ndarray alone | no | no | yes | mature | array ops only; no autograd | - -**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ -runtime needed, the `ndarray` backend is sufficient for CPU training and can -switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by -changing one type alias. - -`tch-rs` would be the best choice for raw training throughput (it is the most -battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the -pure-Rust constraint. If training speed becomes the bottleneck after prototyping, -switching `spiel_bot` to `tch-rs` is a one-line backend swap. 
- -### 1.2 Other Key Crates - -| Crate | Role | -| -------------------- | ----------------------------------------------------------------- | -| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | -| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | -| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | -| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | -| `anyhow` | error propagation (already used everywhere) | -| `indicatif` | training progress bars | -| `tracing` | structured logging per episode/iteration | - -### 1.3 What `burn-rl` Provides (and Does Not) - -The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) -provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` -trait. It does **not** provide: - -- MCTS or any tree-search algorithm -- Two-player self-play -- Legal action masking during training -- Chance-node handling - -For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its -own (simpler, more targeted) traits and implement MCTS from scratch. - ---- - -## 2. Trictrac-Specific Design Constraints - -### 2.1 Multi-Stage Turn Model - -A Trictrac turn passes through up to six `TurnStage` values. 
Only two involve -genuine player choice: - -| TurnStage | Node type | Handler | -| ---------------- | ------------------------------- | ------------------------------- | -| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | -| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | -| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | -| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | -| `Move` | **Player decision** | MCTS / policy network | -| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | - -The environment wrapper advances through forced/chance stages automatically so -that from the algorithm's perspective every node it sees is a genuine player -decision. - -### 2.2 Stochastic Dice in MCTS - -AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice -introduce stochasticity. Three approaches exist: - -**A. Outcome sampling (recommended)** -During each MCTS simulation, when a chance node is reached, sample one dice -outcome at random and continue. After many simulations the expected value -converges. This is the approach used by OpenSpiel's MCTS for stochastic games -and requires no changes to the standard PUCT formula. - -**B. Chance-node averaging (expectimax)** -At each chance node, expand all 21 unique dice pairs weighted by their -probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is -exact but multiplies the branching factor by ~21 at every dice roll, making it -prohibitively expensive. - -**C. Condition on dice in the observation (current approach)** -Dice values are already encoded at indices [192–193] of `to_tensor()`. The -network naturally conditions on the rolled dice when it evaluates a position. -MCTS only runs on player-decision nodes _after_ the dice have been sampled; -chance nodes are bypassed by the environment wrapper (approach A). 
The policy -and value heads learn to play optimally given any dice pair. - -**Use approach A + C together**: the environment samples dice automatically -(chance node bypass), and the 217-dim tensor encodes the dice so the network -can exploit them. - -### 2.3 Perspective / Mirroring - -All move rules and tensor encoding are defined from White's perspective. -`to_tensor()` must always be called after mirroring the state for Black. -The environment wrapper handles this transparently: every observation returned -to an algorithm is already in the active player's perspective. - -### 2.4 Legal Action Masking - -A crucial difference from the existing `bot/` code: instead of penalizing -invalid actions with `ERROR_REWARD`, the policy head logits are **masked** -before softmax — illegal action logits are set to `-inf`. This prevents the -network from wasting capacity on illegal moves and eliminates the need for the -penalty-reward hack. - ---- - -## 3. Proposed Crate Architecture - -``` -spiel_bot/ -├── Cargo.toml -└── src/ - ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" - │ - ├── env/ - │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface - │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store - │ - ├── mcts/ - │ ├── mod.rs # MctsConfig, run_mcts() entry point - │ ├── node.rs # MctsNode (visit count, W, prior, children) - │ └── search.rs # simulate(), backup(), select_action() - │ - ├── network/ - │ ├── mod.rs # PolicyValueNet trait - │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads - │ - ├── alphazero/ - │ ├── mod.rs # AlphaZeroConfig - │ ├── selfplay.rs # generate_episode() -> Vec - │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) - │ └── trainer.rs # training loop: selfplay → sample → loss → update - │ - └── agent/ - ├── mod.rs # Agent trait - ├── random.rs # RandomAgent (baseline) - └── mcts_agent.rs # MctsAgent: uses trained network for inference -``` - -Future algorithms slot in without 
touching the above: - -``` - ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer - └── ppo/ # (future) PPO: impl Algorithm + rollout buffer -``` - ---- - -## 4. Core Traits - -### 4.1 `GameEnv` — the minimal OpenSpiel interface - -```rust -use rand::Rng; - -/// Who controls the current node. -pub enum Player { - P1, // player index 0 - P2, // player index 1 - Chance, // dice roll - Terminal, // game over -} - -pub trait GameEnv: Clone + Send + Sync + 'static { - type State: Clone + Send + Sync; - - /// Fresh game state. - fn new_game(&self) -> Self::State; - - /// Who acts at this node. - fn current_player(&self, s: &Self::State) -> Player; - - /// Legal action indices (always in [0, action_space())). - /// Empty only at Terminal nodes. - fn legal_actions(&self, s: &Self::State) -> Vec; - - /// Apply a player action (must be legal). - fn apply(&self, s: &mut Self::State, action: usize); - - /// Advance a Chance node by sampling dice; no-op at non-Chance nodes. - fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng); - - /// Observation tensor from `pov`'s perspective (0 or 1). - /// Returns 217 f32 values for Trictrac. - fn observation(&self, s: &Self::State, pov: usize) -> Vec; - - /// Flat observation size (217 for Trictrac). - fn obs_size(&self) -> usize; - - /// Total action-space size (514 for Trictrac). - fn action_space(&self) -> usize; - - /// Game outcome per player, or None if not Terminal. - /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw. - fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; -} -``` - -### 4.2 `PolicyValueNet` — neural network interface - -```rust -use burn::prelude::*; - -pub trait PolicyValueNet: Send + Sync { - /// Forward pass. - /// `obs`: [batch, obs_size] tensor. - /// Returns: (policy_logits [batch, action_space], value [batch]). - fn forward(&self, obs: Tensor) -> (Tensor, Tensor); - - /// Save weights to `path`. 
- fn save(&self, path: &std::path::Path) -> anyhow::Result<()>; - - /// Load weights from `path`. - fn load(path: &std::path::Path) -> anyhow::Result - where - Self: Sized; -} -``` - -### 4.3 `Agent` — player policy interface - -```rust -pub trait Agent: Send { - /// Select an action index given the current game state observation. - /// `legal`: mask of valid action indices. - fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize; -} -``` - ---- - -## 5. MCTS Implementation - -### 5.1 Node - -```rust -pub struct MctsNode { - n: u32, // visit count N(s, a) - w: f32, // sum of backed-up values W(s, a) - p: f32, // prior from policy head P(s, a) - children: Vec<(usize, MctsNode)>, // (action_idx, child) - is_expanded: bool, -} - -impl MctsNode { - pub fn q(&self) -> f32 { - if self.n == 0 { 0.0 } else { self.w / self.n as f32 } - } - - /// PUCT score used for selection. - pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { - self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) - } -} -``` - -### 5.2 Simulation Loop - -One MCTS simulation (for deterministic decision nodes): - -``` -1. SELECTION — traverse from root, always pick child with highest PUCT, - auto-advancing forced/chance nodes via env.apply_chance(). -2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get - (policy_logits, value). Mask illegal actions, softmax - the remaining logits → priors P(s,a) for each child. -3. BACKUP — propagate -value up the tree (negate at each level because - perspective alternates between P1 and P2). -``` - -After `n_simulations` iterations, action selection at the root: - -```rust -// During training: sample proportional to N^(1/temperature) -// During evaluation: argmax N -fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } -``` - -### 5.3 Configuration - -```rust -pub struct MctsConfig { - pub n_simulations: usize, // e.g. 200 - pub c_puct: f32, // exploration constant, e.g. 
1.5 - pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 - pub dirichlet_eps: f32, // noise weight, e.g. 0.25 - pub temperature: f32, // action sampling temperature (anneals to 0) -} -``` - -### 5.4 Handling Chance Nodes Inside MCTS - -When simulation reaches a Chance node (dice roll), the environment automatically -samples dice and advances to the next decision node. The MCTS tree does **not** -branch on dice outcomes — it treats the sampled outcome as the state. This -corresponds to "outcome sampling" (approach A from §2.2). Because each -simulation independently samples dice, the Q-values at player nodes converge to -their expected value over many simulations. - ---- - -## 6. Network Architecture - -### 6.1 ResNet Policy-Value Network - -A single trunk with residual blocks, then two heads: - -``` -Input: [batch, 217] - ↓ -Linear(217 → 512) + ReLU - ↓ -ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU) - ↓ trunk output [batch, 512] - ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site) - └─ Value head: Linear(512 → 1) → tanh (output in [-1, 1]) -``` - -Burn implementation sketch: - -```rust -#[derive(Module, Debug)] -pub struct TrictracNet { - input: Linear, - res_blocks: Vec>, - policy_head: Linear, - value_head: Linear, -} - -impl TrictracNet { - pub fn forward(&self, obs: Tensor) - -> (Tensor, Tensor) - { - let x = activation::relu(self.input.forward(obs)); - let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x)); - let policy = self.policy_head.forward(x.clone()); // raw logits - let value = activation::tanh(self.value_head.forward(x)) - .squeeze(1); - (policy, value) - } -} -``` - -A simpler MLP (no residual blocks) is sufficient for a first version and much -faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two -heads. 
- -### 6.2 Loss Function - -``` -L = MSE(value_pred, z) - + CrossEntropy(policy_logits_masked, π_mcts) - - c_l2 * L2_regularization -``` - -Where: - -- `z` = game outcome (±1) from the active player's perspective -- `π_mcts` = normalized MCTS visit counts at the root (the policy target) -- Legal action masking is applied before computing CrossEntropy - ---- - -## 7. AlphaZero Training Loop - -``` -INIT - network ← random weights - replay ← empty ReplayBuffer(capacity = 100_000) - -LOOP forever: - ── Self-play phase ────────────────────────────────────────────── - (parallel with rayon, n_workers games at once) - for each game: - state ← env.new_game() - samples = [] - while not terminal: - advance forced/chance nodes automatically - obs ← env.observation(state, current_player) - legal ← env.legal_actions(state) - π, root_value ← mcts.run(state, network, config) - action ← sample from π (with temperature) - samples.push((obs, π, current_player)) - env.apply(state, action) - z ← env.returns(state) // final scores - for (obs, π, player) in samples: - replay.push(TrainSample { obs, policy: π, value: z[player] }) - - ── Training phase ─────────────────────────────────────────────── - for each gradient step: - batch ← replay.sample(batch_size) - (policy_logits, value_pred) ← network.forward(batch.obs) - loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) - optimizer.step(loss.backward()) - - ── Evaluation (every N iterations) ───────────────────────────── - win_rate ← evaluate(network_new vs network_prev, n_eval_games) - if win_rate > 0.55: save checkpoint -``` - -### 7.1 Replay Buffer - -```rust -pub struct TrainSample { - pub obs: Vec, // 217 values - pub policy: Vec, // 514 values (normalized MCTS visit counts) - pub value: f32, // game outcome ∈ {-1, 0, +1} -} - -pub struct ReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - pub fn push(&mut self, s: TrainSample) { - if self.data.len() == self.capacity { 
self.data.pop_front(); } - self.data.push_back(s); - } - - pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { - // sample without replacement - } -} -``` - -### 7.2 Parallelism Strategy - -Self-play is embarrassingly parallel (each game is independent): - -```rust -let samples: Vec = (0..n_games) - .into_par_iter() // rayon - .flat_map(|_| generate_episode(&env, &network, &mcts_config)) - .collect(); -``` - -Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. -Self-play uses inference-only (no gradient tape), so a `NdArray` backend -(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with -`Autodiff>`. - -For larger scale, a producer-consumer architecture (crossbeam-channel) separates -self-play workers from the training thread, allowing continuous data generation -while the GPU trains. - ---- - -## 8. `TrictracEnv` Implementation Sketch - -```rust -use trictrac_store::{ - training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, - Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, -}; - -#[derive(Clone)] -pub struct TrictracEnv; - -impl GameEnv for TrictracEnv { - type State = GameState; - - fn new_game(&self) -> GameState { - GameState::new_with_players("P1", "P2") - } - - fn current_player(&self, s: &GameState) -> Player { - match s.stage { - Stage::Ended => Player::Terminal, - _ => match s.turn_stage { - TurnStage::RollWaiting => Player::Chance, - _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, - }, - } - } - - fn legal_actions(&self, s: &GameState) -> Vec { - let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; - get_valid_action_indices(&view).unwrap_or_default() - } - - fn apply(&self, s: &mut GameState, action_idx: usize) { - // advance all forced/chance nodes first, then apply the player action - self.advance_forced(s); - let needs_mirror = s.active_player_id == 2; - let view = if needs_mirror { s.mirror() } else { s.clone() }; 
- if let Some(event) = TrictracAction::from_action_index(action_idx) - .and_then(|a| a.to_event(&view)) - .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) - { - let _ = s.consume(&event); - } - // advance any forced stages that follow - self.advance_forced(s); - } - - fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { - // RollDice → RollWaiting - let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); - // RollWaiting → next stage - let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; - let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); - self.advance_forced(s); - } - - fn observation(&self, s: &GameState, pov: usize) -> Vec { - if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } - } - - fn obs_size(&self) -> usize { 217 } - fn action_space(&self) -> usize { ACTION_SPACE_SIZE } - - fn returns(&self, s: &GameState) -> Option<[f32; 2]> { - if s.stage != Stage::Ended { return None; } - // Convert hole+point scores to ±1 outcome - let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); - let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); - Some(match s1.cmp(&s2) { - std::cmp::Ordering::Greater => [ 1.0, -1.0], - std::cmp::Ordering::Less => [-1.0, 1.0], - std::cmp::Ordering::Equal => [ 0.0, 0.0], - }) - } -} - -impl TrictracEnv { - /// Advance through all forced (non-decision, non-chance) stages. - fn advance_forced(&self, s: &mut GameState) { - use trictrac_store::PointsRules; - loop { - match s.turn_stage { - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Scoring is deterministic; compute and apply automatically. 
- let color = s.player_color_by_id(&s.active_player_id) - .unwrap_or(trictrac_store::Color::White); - let drc = s.players.get(&s.active_player_id) - .map(|p| p.dice_roll_count).unwrap_or(0); - let pr = PointsRules::new(&color, &s.board, s.dice); - let pts = pr.get_points(drc); - let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; - let _ = s.consume(&GameEvent::Mark { - player_id: s.active_player_id, points, - }); - } - TurnStage::RollDice => { - // RollDice is a forced "initiate roll" action with no real choice. - let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); - } - _ => break, - } - } - } -} -``` - ---- - -## 9. Cargo.toml Changes - -### 9.1 Add `spiel_bot` to the workspace - -```toml -# Cargo.toml (workspace root) -[workspace] -resolver = "2" -members = ["client_cli", "bot", "store", "spiel_bot"] -``` - -### 9.2 `spiel_bot/Cargo.toml` - -```toml -[package] -name = "spiel_bot" -version = "0.1.0" -edition = "2021" - -[features] -default = ["alphazero"] -alphazero = [] -# dqn = [] # future -# ppo = [] # future - -[dependencies] -trictrac-store = { path = "../store" } -anyhow = "1" -rand = "0.9" -rayon = "1" -serde = { version = "1", features = ["derive"] } -serde_json = "1" - -# Burn: NdArray for pure-Rust CPU training -# Replace NdArray with Wgpu or Tch for GPU. -burn = { version = "0.20", features = ["ndarray", "autodiff"] } - -# Optional: progress display and structured logging -indicatif = "0.17" -tracing = "0.1" - -[[bin]] -name = "az_train" -path = "src/bin/az_train.rs" - -[[bin]] -name = "az_eval" -path = "src/bin/az_eval.rs" -``` - ---- - -## 10. 
Comparison: `bot` crate vs `spiel_bot` - -| Aspect | `bot` (existing) | `spiel_bot` (proposed) | -| ---------------- | --------------------------- | -------------------------------------------- | -| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | -| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | -| Opponent | hardcoded random | self-play | -| Invalid actions | penalise with reward | legal action mask (no penalty) | -| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | -| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | -| Burn dep | yes (0.20) | yes (0.20), same backend | -| `burn-rl` dep | yes | no | -| C++ dep | no | no | -| Python dep | no | no | -| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | - -The two crates are **complementary**: `bot` is a working DQN/PPO baseline; -`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The -`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just -replace `TrictracEnvironment` with `TrictracEnv`). - ---- - -## 11. Implementation Order - -1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game - through the trait, verify terminal state and returns). -2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + - Burn forward/backward pass test with dummy data. -3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests - (visit counts sum to `n_simulations`, legal mask respected). -4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub - (one iteration, check loss decreases). -5. **Integration test**: run 100 self-play games with a tiny network (1 res block, - 64 hidden units), verify the training loop completes without panics. -6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s - on CPU, consistent with `random_game` throughput). -7. 
**Upgrade network**: 4 residual blocks, 512 hidden units; schedule - hyperparameter sweep. -8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report - win rate every checkpoint. - ---- - -## 12. Key Open Questions - -1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. - AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever - scored more holes). Better option: normalize the score margin. The final - choice affects how the value head is trained. - -2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). - MCTS with 200 simulations per step means ~120k network evaluations per game. - At batch inference this is feasible on CPU; on GPU it becomes fast. Consider - limiting `n_simulations` to 50–100 for early training. - -3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). - This is a long-horizon decision that AlphaZero handles naturally via MCTS - lookahead, but needs careful value normalization (a "Go" restarts scoring - within the same game). - -4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated - to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. - This is optional but reduces code duplication. - -5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, - 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. - A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. - -## Implementation results - -All benchmarks compile and run. 
Here's the complete results table: - -| Group | Benchmark | Time | -| ------- | ----------------------- | --------------------- | -| env | apply_chance | 3.87 µs | -| | legal_actions | 1.91 µs | -| | observation (to_tensor) | 341 ns | -| | random_game (baseline) | 3.55 ms → 282 games/s | -| network | mlp_b1 hidden=64 | 94.9 µs | -| | mlp_b32 hidden=64 | 141 µs | -| | mlp_b1 hidden=256 | 352 µs | -| | mlp_b32 hidden=256 | 479 µs | -| mcts | zero_eval n=1 | 6.8 µs | -| | zero_eval n=5 | 23.9 µs | -| | zero_eval n=20 | 90.9 µs | -| | mlp64 n=1 | 203 µs | -| | mlp64 n=5 | 622 µs | -| | mlp64 n=20 | 2.30 ms | -| episode | trictrac n=1 | 51.8 ms → 19 games/s | -| | trictrac n=2 | 145 ms → 7 games/s | -| train | mlp64 Adam b=16 | 1.93 ms | -| | mlp64 Adam b=64 | 2.68 ms | - -Key observations: - -- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game) -- observation (217-value tensor): only 341 ns — not a bottleneck -- legal_actions: 1.9 µs — well optimised -- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs -- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal -- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it -- Training: 2.7 ms/step at batch=64 → 370 steps/s - -### Summary of Step 8 - -spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary: - -- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct -- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%) -- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side -- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise) -- Output: win/draw/loss per side + combined decisive win rate - -Typical usage after training: -cargo run -p spiel_bot --bin az_eval --release -- \ - --checkpoint 
checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100 - -### az_train - -#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10) - -cargo run -p spiel_bot --bin az_train --release - -#### ResNet, more games, custom output dir - -cargo run -p spiel_bot --bin az_train --release -- \ - --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \ - --save-every 20 --out checkpoints/ - -#### Resume from iteration 50 - -cargo run -p spiel_bot --bin az_train --release -- \ - --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50 - -What the binary does each iteration: - -1. Calls model.valid() to get a zero-overhead inference copy for self-play -2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy) -3. Pushes samples into a ReplayBuffer (capacity --replay-cap) -4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min -5. Saves a .mpk checkpoint every --save-every iterations and always on the last diff --git a/doc/tensor_research.md b/doc/tensor_research.md deleted file mode 100644 index b0d0ede..0000000 --- a/doc/tensor_research.md +++ /dev/null @@ -1,253 +0,0 @@ -# Tensor research - -## Current tensor anatomy - -[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!) -[24] active player color: 0 or 1 -[25] turn_stage: 1–5 -[26–27] dice values (raw 1–6) -[28–31] white: points, holes, can_bredouille, can_big_bredouille -[32–35] black: same -───────────────────────────────── -Total 36 floats - -The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's -AlphaZero uses a fully-connected network. - -### Fundamental problems with the current encoding - -1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. 
The network - must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with - all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from. - -2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient - flow during training is uneven. - -3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it - triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers - produces a score. Including this explicitly is the single highest-value addition. - -4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely - different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0. - -5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the - starting position). It's in the Player struct but not exported. - -## Key Trictrac distinctions from backgammon that shape the encoding - -| Concept | Backgammon | Trictrac | -| ------------------------- | ---------------------- | --------------------------------------------------------- | -| Hitting a blot | Removes checker to bar | Scores points, checker stays | -| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened | -| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) | -| 3-checker field | Safe with spare | Safe with spare | -| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) | -| Both colors on a field | Impossible | Perfectly legal | -| Rest corner (field 12/13) | Does not exist | Special two-checker rules | - -The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. 
Splitting them into binary -indicators directly teaches the network the phase transitions the game hinges on. - -## Options - -### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D) - -The minimum viable improvement. - -For each of the 24 fields, encode own and opponent separately with 4 indicators each: - -own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target) -own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill) -own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare) -own_x[i]: max(0, count − 3) (overflow) -opp_1[i]: same for opponent -… - -Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec(). - -Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204 -Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured -overhead is negligible. -Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from -scratch will be substantially faster and the converged policy quality better. - -### Option B — Option A + Trictrac-specific derived features (flat 1D) - -Recommended starting point. 
- -Add on top of Option A: - -// Quarter fill status — the single most important derived feature -quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields) -quarter_filled_opp[q] (q=0..3): same for opponent -→ 8 values - -// Exit readiness -can_exit_own: 1.0 if all own checkers are in fields 19–24 -can_exit_opp: same for opponent -→ 2 values - -// Rest corner status (field 12/13) -own_corner_taken: 1.0 if field 12 has ≥2 own checkers -opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers -→ 2 values - -// Jan de 3 coups counter (normalized) -dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0) -→ 1 value - -Size: 204 + 8 + 2 + 2 + 1 = 217 -Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve -the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes -expensive inference from the network. - -### Option C — Option B + richer positional features (flat 1D) - -More complete, higher sample efficiency, minor extra cost. - -Add on top of Option B: - -// Per-quarter fill fraction — how close to filling each quarter -own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0 -opp_quarter_fill_fraction[q] (q=0..3): same for opponent -→ 8 values - -// Blot counts — number of own/opponent single-checker fields globally -// (tells the network at a glance how much battage risk/opportunity exists) -own_blot_count: (number of own fields with exactly 1 checker) / 15.0 -opp_blot_count: same for opponent -→ 2 values - -// Bredouille would-double multiplier (already present, but explicitly scaled) -// No change needed, already binary - -Size: 217 + 8 + 2 = 227 -Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network -from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts). 
- -### Option D — 2D spatial tensor {K, 24} - -For CNN-based networks. Best eventual architecture but requires changing the training setup. - -Shape {14, 24} — 14 feature channels over 24 field positions: - -Channel 0: own_count_1 (blot) -Channel 1: own_count_2 -Channel 2: own_count_3 -Channel 3: own_count_overflow (float) -Channel 4: opp_count_1 -Channel 5: opp_count_2 -Channel 6: opp_count_3 -Channel 7: opp_count_overflow -Channel 8: own_corner_mask (1.0 at field 12) -Channel 9: opp_corner_mask (1.0 at field 13) -Channel 10: final_quarter_mask (1.0 at fields 19–24) -Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter) -Channel 12: quarter_filled_opp (same for opponent) -Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers) - -Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform -value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global" -row by returning shape {K, 25} with position 0 holding global features. - -Size: 14 × 24 + few global channels ≈ 336–384 -C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated -accordingly. -Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's -alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H, -W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension. -Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel -size 2–3 captures adjacent-field "tout d'une" interactions. - -### On 3D tensors - -Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is -the fundamental tactical unit. 
A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and -field-within-quarter patterns jointly. - -However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The -extra architecture work makes this premature unless you're already building a custom network. The information content -is the same as Option D. - -### Recommendation - -Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and -the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions -(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move. - -Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when -the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time. - -Costs summary: - -| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain | -| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- | -| Current | 36 | — | — | — | baseline | -| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) | -| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) | -| C | 227 | to_vec() rewrite | constant update | none | large + moderate | -| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial | - -One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2, -the new to_vec() must express everything from the active player's perspective (which the mirror already handles for -the board). 
The quarter fill status and corner status should therefore be computed on the already-mirrored state, -which they will be if computed inside to_vec(). - -## Other algorithms - -The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully. - -### 1. Without MCTS, feature quality matters more - -AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred -simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup — -the network must learn the full strategic structure directly from gradient updates. - -This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO, -not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less -sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from -scratch. - -### 2. Normalization becomes mandatory, not optional - -AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps -Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw -values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can -be unstable. - -For DQN/PPO, every feature in the tensor should be in [0, 1]: - -dice values: / 6.0 -checker counts: overflow channel / 12.0 -points: / 12.0 -holes: / 12.0 -dice_roll_count: / 3.0 (clamped) - -Booleans and the TD-Gammon binary indicators are already in [0, 1]. - -### 3. 
The shape question depends on architecture, not algorithm - -| Architecture | Shape | When to use | -| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- | -| MLP | {217} flat | Any algorithm, simplest baseline | -| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) | -| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network | -| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now | - -The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want -convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet -template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all -— you just reshape the flat tensor in Python before passing it to the network. - -### One thing that genuinely changes: value function perspective - -AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This -works well. - -DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit -current-player indicator), because a single Q-network estimates action values for both players simultaneously. With -the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must -learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity. -This is a minor point for a symmetric game like Trictrac, but worth keeping in mind. - -Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. 
The features themselves are the same regardless of algorithm. diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml deleted file mode 100644 index 682505b..0000000 --- a/spiel_bot/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "spiel_bot" -version = "0.1.0" -edition = "2021" - -[dependencies] -trictrac-store = { path = "../store" } -anyhow = "1" -rand = "0.9" -rand_distr = "0.5" -burn = { version = "0.20", features = ["ndarray", "autodiff"] } -rayon = "1" - -[dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } - -[[bench]] -name = "alphazero" -harness = false diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs deleted file mode 100644 index 00d5b02..0000000 --- a/spiel_bot/benches/alphazero.rs +++ /dev/null @@ -1,373 +0,0 @@ -//! AlphaZero pipeline benchmarks. -//! -//! Run with: -//! -//! ```sh -//! cargo bench -p spiel_bot -//! ``` -//! -//! Use `-- ` to run a specific group, e.g.: -//! -//! ```sh -//! cargo bench -p spiel_bot -- env/ -//! cargo bench -p spiel_bot -- network/ -//! cargo bench -p spiel_bot -- mcts/ -//! cargo bench -p spiel_bot -- episode/ -//! cargo bench -p spiel_bot -- train/ -//! ``` -//! -//! Target: ≥ 500 games/s for random play on CPU (consistent with -//! `random_game` throughput in `trictrac-store`). 
- -use std::time::Duration; - -use burn::{ - backend::NdArray, - tensor::{Tensor, TensorData, backend::Backend}, -}; -use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; -use rand::{Rng, SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, - env::{GameEnv, Player, TrictracEnv}, - mcts::{Evaluator, MctsConfig, run_mcts}, - network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, -}; - -// ── Shared types ─────────────────────────────────────────────────────────── - -type InferB = NdArray; -type TrainB = burn::backend::Autodiff>; - -fn infer_device() -> ::Device { Default::default() } -fn train_device() -> ::Device { Default::default() } - -fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) } - -/// Uniform evaluator (returns zero logits and zero value). -/// Used to isolate MCTS tree-traversal cost from network cost. -struct ZeroEval(usize); -impl Evaluator for ZeroEval { - fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { - (vec![0.0f32; self.0], 0.0) - } -} - -// ── 1. Environment primitives ────────────────────────────────────────────── - -/// Baseline performance of the raw Trictrac environment without MCTS. -/// Target: ≥ 500 full games / second. -fn bench_env(c: &mut Criterion) { - let env = TrictracEnv; - - let mut group = c.benchmark_group("env"); - group.measurement_time(Duration::from_secs(10)); - - // ── apply_chance ────────────────────────────────────────────────────── - group.bench_function("apply_chance", |b| { - b.iter_batched( - || { - // A fresh game is always at RollDice (Chance) — ready for apply_chance. 
- env.new_game() - }, - |mut s| { - env.apply_chance(&mut s, &mut seeded()); - black_box(s) - }, - BatchSize::SmallInput, - ) - }); - - // ── legal_actions ───────────────────────────────────────────────────── - group.bench_function("legal_actions", |b| { - let mut rng = seeded(); - let mut s = env.new_game(); - env.apply_chance(&mut s, &mut rng); - b.iter(|| black_box(env.legal_actions(&s))) - }); - - // ── observation (to_tensor) ─────────────────────────────────────────── - group.bench_function("observation", |b| { - let mut rng = seeded(); - let mut s = env.new_game(); - env.apply_chance(&mut s, &mut rng); - b.iter(|| black_box(env.observation(&s, 0))) - }); - - // ── full random game ────────────────────────────────────────────────── - group.sample_size(50); - group.bench_function("random_game", |b| { - b.iter_batched( - seeded, - |mut rng| { - let mut s = env.new_game(); - loop { - match env.current_player(&s) { - Player::Terminal => break, - Player::Chance => env.apply_chance(&mut s, &mut rng), - _ => { - let actions = env.legal_actions(&s); - let idx = rng.random_range(0..actions.len()); - env.apply(&mut s, actions[idx]); - } - } - } - black_box(s) - }, - BatchSize::SmallInput, - ) - }); - - group.finish(); -} - -// ── 2. Network inference ─────────────────────────────────────────────────── - -/// Forward-pass latency for MLP variants (hidden = 64 / 256). -fn bench_network(c: &mut Criterion) { - let mut group = c.benchmark_group("network"); - group.measurement_time(Duration::from_secs(5)); - - for &hidden in &[64usize, 256] { - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = MlpNet::::new(&cfg, &infer_device()); - let obs: Vec = vec![0.5; 217]; - - // Batch size 1 — single-position evaluation as in MCTS. 
- group.bench_with_input( - BenchmarkId::new("mlp_b1", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs.clone(), [1, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - - // Batch size 32 — training mini-batch. - let obs32: Vec = vec![0.5; 217 * 32]; - group.bench_with_input( - BenchmarkId::new("mlp_b32", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs32.clone(), [32, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - } - - // ── ResNet (4 residual blocks) ──────────────────────────────────────── - for &hidden in &[256usize, 512] { - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = ResNet::::new(&cfg, &infer_device()); - let obs: Vec = vec![0.5; 217]; - - group.bench_with_input( - BenchmarkId::new("resnet_b1", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs.clone(), [1, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - - let obs32: Vec = vec![0.5; 217 * 32]; - group.bench_with_input( - BenchmarkId::new("resnet_b32", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs32.clone(), [32, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - } - - group.finish(); -} - -// ── 3. MCTS ─────────────────────────────────────────────────────────────── - -/// MCTS cost at different simulation budgets with two evaluator types: -/// - `zero` — isolates tree-traversal overhead (no network). -/// - `mlp64` — real MLP, shows end-to-end cost per move. -fn bench_mcts(c: &mut Criterion) { - let env = TrictracEnv; - - // Build a decision-node state (after dice roll). 
- let state = { - let mut s = env.new_game(); - let mut rng = seeded(); - while env.current_player(&s).is_chance() { - env.apply_chance(&mut s, &mut rng); - } - s - }; - - let mut group = c.benchmark_group("mcts"); - group.measurement_time(Duration::from_secs(10)); - - let zero_eval = ZeroEval(514); - let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - let mlp_model = MlpNet::::new(&mlp_cfg, &infer_device()); - let mlp_eval = BurnEvaluator::::new(mlp_model, infer_device()); - - for &n_sim in &[1usize, 5, 20] { - let cfg = MctsConfig { - n_simulations: n_sim, - c_puct: 1.5, - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - temperature: 1.0, - }; - - // Zero evaluator: tree traversal only. - group.bench_with_input( - BenchmarkId::new("zero_eval", n_sim), - &n_sim, - |b, _| { - b.iter_batched( - seeded, - |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)), - BatchSize::SmallInput, - ) - }, - ); - - // MLP evaluator: full cost per decision. - group.bench_with_input( - BenchmarkId::new("mlp64", n_sim), - &n_sim, - |b, _| { - b.iter_batched( - seeded, - |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)), - BatchSize::SmallInput, - ) - }, - ); - } - - group.finish(); -} - -// ── 4. Episode generation ───────────────────────────────────────────────── - -/// Full self-play episode latency (one complete game) at different MCTS -/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU. 
-fn bench_episode(c: &mut Criterion) { - let env = TrictracEnv; - let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - let model = MlpNet::::new(&mlp_cfg, &infer_device()); - let eval = BurnEvaluator::::new(model, infer_device()); - - let mut group = c.benchmark_group("episode"); - group.sample_size(10); - group.measurement_time(Duration::from_secs(60)); - - for &n_sim in &[1usize, 2] { - let mcts_cfg = MctsConfig { - n_simulations: n_sim, - c_puct: 1.5, - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - temperature: 1.0, - }; - - group.bench_with_input( - BenchmarkId::new("trictrac", n_sim), - &n_sim, - |b, _| { - b.iter_batched( - seeded, - |mut rng| { - black_box(generate_episode( - &env, - &eval, - &mcts_cfg, - &|_| 1.0, - &mut rng, - )) - }, - BatchSize::SmallInput, - ) - }, - ); - } - - group.finish(); -} - -// ── 5. Training step ─────────────────────────────────────────────────────── - -/// Gradient-step latency for different batch sizes. -fn bench_train(c: &mut Criterion) { - use burn::optim::AdamConfig; - - let mut group = c.benchmark_group("train"); - group.measurement_time(Duration::from_secs(10)); - - let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - - let dummy_samples = |n: usize| -> Vec { - (0..n) - .map(|i| TrainSample { - obs: vec![0.5; 217], - policy: { - let mut p = vec![0.0f32; 514]; - p[i % 514] = 1.0; - p - }, - value: if i % 2 == 0 { 1.0 } else { -1.0 }, - }) - .collect() - }; - - for &batch_size in &[16usize, 64] { - let batch = dummy_samples(batch_size); - - group.bench_with_input( - BenchmarkId::new("mlp64_adam", batch_size), - &batch_size, - |b, _| { - b.iter_batched( - || { - ( - MlpNet::::new(&mlp_cfg, &train_device()), - AdamConfig::new().init::>(), - ) - }, - |(model, mut opt)| { - black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3)) - }, - BatchSize::SmallInput, - ) - }, - ); - } - - group.finish(); -} - -// ── Criterion entry point 
────────────────────────────────────────────────── - -criterion_group!( - benches, - bench_env, - bench_network, - bench_mcts, - bench_episode, - bench_train, -); -criterion_main!(benches); diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs deleted file mode 100644 index d92224e..0000000 --- a/spiel_bot/src/alphazero/mod.rs +++ /dev/null @@ -1,127 +0,0 @@ -//! AlphaZero: self-play data generation, replay buffer, and training step. -//! -//! # Modules -//! -//! | Module | Contents | -//! |--------|----------| -//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | -//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | -//! | [`trainer`] | [`train_step`] | -//! -//! # Typical outer loop -//! -//! ```rust,ignore -//! use burn::backend::{Autodiff, NdArray}; -//! use burn::optim::AdamConfig; -//! use spiel_bot::{ -//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, -//! env::TrictracEnv, -//! mcts::MctsConfig, -//! network::{MlpConfig, MlpNet}, -//! }; -//! -//! type Infer = NdArray; -//! type Train = Autodiff>; -//! -//! let device = Default::default(); -//! let env = TrictracEnv; -//! let config = AlphaZeroConfig::default(); -//! -//! // Build training model and optimizer. -//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); -//! let mut optimizer = AdamConfig::new().init(); -//! let mut replay = ReplayBuffer::new(config.replay_capacity); -//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); -//! -//! for _iter in 0..config.n_iterations { -//! // Convert to inference backend for self-play. -//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) -//! .load_record(train_model.clone().into_record()); -//! let eval = BurnEvaluator::new(infer_model, device.clone()); -//! -//! // Self-play: generate episodes. -//! for _ in 0..config.n_games_per_iter { -//! let samples = generate_episode(&env, &eval, &config.mcts, -//! 
&|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); -//! replay.extend(samples); -//! } -//! -//! // Training: gradient steps. -//! if replay.len() >= config.batch_size { -//! for _ in 0..config.n_train_steps_per_iter { -//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) -//! .into_iter().cloned().collect(); -//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, -//! config.learning_rate); -//! train_model = m; -//! } -//! } -//! } -//! ``` - -pub mod replay; -pub mod selfplay; -pub mod trainer; - -pub use replay::{ReplayBuffer, TrainSample}; -pub use selfplay::{BurnEvaluator, generate_episode}; -pub use trainer::{cosine_lr, train_step}; - -use crate::mcts::MctsConfig; - -// ── Configuration ───────────────────────────────────────────────────────── - -/// Top-level AlphaZero hyperparameters. -/// -/// The MCTS parameters live in [`MctsConfig`]; this struct holds the -/// outer training-loop parameters. -#[derive(Debug, Clone)] -pub struct AlphaZeroConfig { - /// MCTS parameters for self-play. - pub mcts: MctsConfig, - /// Number of self-play games per training iteration. - pub n_games_per_iter: usize, - /// Number of gradient steps per training iteration. - pub n_train_steps_per_iter: usize, - /// Mini-batch size for each gradient step. - pub batch_size: usize, - /// Maximum number of samples in the replay buffer. - pub replay_capacity: usize, - /// Initial (peak) Adam learning rate. - pub learning_rate: f64, - /// Minimum learning rate for cosine annealing (floor of the schedule). - /// - /// Pass `learning_rate == lr_min` to disable scheduling (constant LR). - /// Compute the current LR with [`cosine_lr`]: - /// - /// ```rust,ignore - /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps); - /// ``` - pub lr_min: f64, - /// Number of outer iterations (self-play + train) to run. - pub n_iterations: usize, - /// Move index after which the action temperature drops to 0 (greedy play). 
- pub temperature_drop_move: usize, -} - -impl Default for AlphaZeroConfig { - fn default() -> Self { - Self { - mcts: MctsConfig { - n_simulations: 100, - c_puct: 1.5, - dirichlet_alpha: 0.1, - dirichlet_eps: 0.25, - temperature: 1.0, - }, - n_games_per_iter: 10, - n_train_steps_per_iter: 20, - batch_size: 64, - replay_capacity: 50_000, - learning_rate: 1e-3, - lr_min: 1e-4, // cosine annealing floor - n_iterations: 100, - temperature_drop_move: 30, - } - } -} diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs deleted file mode 100644 index 5e64cc4..0000000 --- a/spiel_bot/src/alphazero/replay.rs +++ /dev/null @@ -1,144 +0,0 @@ -//! Replay buffer for AlphaZero self-play data. - -use std::collections::VecDeque; -use rand::Rng; - -// ── Training sample ──────────────────────────────────────────────────────── - -/// One training example produced by self-play. -#[derive(Clone, Debug)] -pub struct TrainSample { - /// Observation tensor from the acting player's perspective (`obs_size` floats). - pub obs: Vec, - /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1). - pub policy: Vec, - /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw. - pub value: f32, -} - -// ── Replay buffer ────────────────────────────────────────────────────────── - -/// Fixed-capacity circular buffer of [`TrainSample`]s. -/// -/// When the buffer is full, the oldest sample is evicted on push. -/// Samples are drawn without replacement using a Fisher-Yates partial shuffle. -pub struct ReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - /// Create a buffer with the given maximum capacity. - pub fn new(capacity: usize) -> Self { - Self { - data: VecDeque::with_capacity(capacity.min(1024)), - capacity, - } - } - - /// Add a sample; evicts the oldest if at capacity. 
- pub fn push(&mut self, sample: TrainSample) { - if self.data.len() == self.capacity { - self.data.pop_front(); - } - self.data.push_back(sample); - } - - /// Add all samples from an episode. - pub fn extend(&mut self, samples: impl IntoIterator) { - for s in samples { - self.push(s); - } - } - - pub fn len(&self) -> usize { - self.data.len() - } - - pub fn is_empty(&self) -> bool { - self.data.is_empty() - } - - /// Sample up to `n` distinct samples, without replacement. - /// - /// If the buffer has fewer than `n` samples, all are returned (shuffled). - pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { - let len = self.data.len(); - let n = n.min(len); - // Partial Fisher-Yates using index shuffling. - let mut indices: Vec = (0..len).collect(); - for i in 0..n { - let j = rng.random_range(i..len); - indices.swap(i, j); - } - indices[..n].iter().map(|&i| &self.data[i]).collect() - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - - fn dummy(value: f32) -> TrainSample { - TrainSample { obs: vec![value], policy: vec![1.0], value } - } - - #[test] - fn push_and_len() { - let mut buf = ReplayBuffer::new(10); - assert!(buf.is_empty()); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - assert_eq!(buf.len(), 2); - } - - #[test] - fn evicts_oldest_at_capacity() { - let mut buf = ReplayBuffer::new(3); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - buf.push(dummy(3.0)); - buf.push(dummy(4.0)); // evicts 1.0 - assert_eq!(buf.len(), 3); - // Oldest remaining should be 2.0 - assert_eq!(buf.data[0].value, 2.0); - } - - #[test] - fn sample_batch_size() { - let mut buf = ReplayBuffer::new(20); - for i in 0..10 { - buf.push(dummy(i as f32)); - } - let mut rng = SmallRng::seed_from_u64(0); - let batch = buf.sample_batch(5, &mut rng); - assert_eq!(batch.len(), 5); - } - - #[test] - fn sample_batch_capped_at_len() { - let mut 
buf = ReplayBuffer::new(20); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - let mut rng = SmallRng::seed_from_u64(0); - let batch = buf.sample_batch(100, &mut rng); - assert_eq!(batch.len(), 2); - } - - #[test] - fn sample_batch_no_duplicates() { - let mut buf = ReplayBuffer::new(20); - for i in 0..10 { - buf.push(dummy(i as f32)); - } - let mut rng = SmallRng::seed_from_u64(1); - let batch = buf.sample_batch(10, &mut rng); - let mut seen: Vec = batch.iter().map(|s| s.value).collect(); - seen.sort_by(f32::total_cmp); - seen.dedup(); - assert_eq!(seen.len(), 10, "sample contained duplicates"); - } -} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs deleted file mode 100644 index b38b7f4..0000000 --- a/spiel_bot/src/alphazero/selfplay.rs +++ /dev/null @@ -1,238 +0,0 @@ -//! Self-play episode generation and Burn-backed evaluator. - -use std::marker::PhantomData; - -use burn::tensor::{backend::Backend, Tensor, TensorData}; -use rand::Rng; - -use crate::env::GameEnv; -use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; -use crate::network::PolicyValueNet; - -use super::replay::TrainSample; - -// ── BurnEvaluator ────────────────────────────────────────────────────────── - -/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. -/// -/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so -/// that self-play generates no gradient tape overhead. -pub struct BurnEvaluator> { - model: N, - device: B::Device, - _b: PhantomData, -} - -impl> BurnEvaluator { - pub fn new(model: N, device: B::Device) -> Self { - Self { model, device, _b: PhantomData } - } - - pub fn into_model(self) -> N { - self.model - } - - pub fn model_ref(&self) -> &N { - &self.model - } -} - -// Safety: NdArray modules are Send; we never share across threads without -// external synchronisation. 
-unsafe impl> Send for BurnEvaluator {} -unsafe impl> Sync for BurnEvaluator {} - -impl> Evaluator for BurnEvaluator { - fn evaluate(&self, obs: &[f32]) -> (Vec, f32) { - let obs_size = obs.len(); - let data = TensorData::new(obs.to_vec(), [1, obs_size]); - let obs_tensor = Tensor::::from_data(data, &self.device); - - let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); - - let policy: Vec = policy_tensor.into_data().to_vec().unwrap(); - let value: Vec = value_tensor.into_data().to_vec().unwrap(); - - (policy, value[0]) - } -} - -// ── Episode generation ───────────────────────────────────────────────────── - -/// One pending observation waiting for its game-outcome value label. -struct PendingSample { - obs: Vec, - policy: Vec, - player: usize, -} - -/// Play one full game using MCTS guided by `evaluator`. -/// -/// Returns a [`TrainSample`] for every decision step in the game. -/// -/// `temperature_fn(step)` controls exploration: return `1.0` for early -/// moves and `0.0` after a fixed number of moves (e.g. move 30). -pub fn generate_episode( - env: &E, - evaluator: &dyn Evaluator, - mcts_config: &MctsConfig, - temperature_fn: &dyn Fn(usize) -> f32, - rng: &mut impl Rng, -) -> Vec { - let mut state = env.new_game(); - let mut pending: Vec = Vec::new(); - let mut step = 0usize; - - loop { - // Advance through chance nodes automatically. - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, rng); - } - - if env.current_player(&state).is_terminal() { - break; - } - - let player_idx = env.current_player(&state).index().unwrap(); - - // Run MCTS to get a policy. - let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng); - let policy = mcts::mcts_policy(&root, env.action_space()); - - // Record the observation from the acting player's perspective. 
- let obs = env.observation(&state, player_idx); - pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); - - // Select and apply the action. - let temperature = temperature_fn(step); - let action = mcts::select_action(&root, temperature, rng); - env.apply(&mut state, action); - step += 1; - } - - // Assign game outcomes. - let returns = env.returns(&state).unwrap_or([0.0; 2]); - pending - .into_iter() - .map(|s| TrainSample { - obs: s.obs, - policy: s.policy, - value: returns[s.player], - }) - .collect() -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - use rand::{SeedableRng, rngs::SmallRng}; - - use crate::env::Player; - use crate::mcts::{Evaluator, MctsConfig}; - use crate::network::{MlpConfig, MlpNet}; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn rng() -> SmallRng { - SmallRng::seed_from_u64(7) - } - - // Countdown game (same as in mcts tests). 
- #[derive(Clone, Debug)] - struct CState { remaining: u8, to_move: usize } - - #[derive(Clone)] - struct CountdownEnv; - - impl GameEnv for CountdownEnv { - type State = CState; - fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { Player::Terminal } - else if s.to_move == 0 { Player::P1 } else { Player::P2 } - } - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { s.remaining = 0; } - else { s.remaining -= sub; s.to_move = 1 - s.to_move; } - } - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / 4.0, s.to_move as f32] - } - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } - } - - fn tiny_config() -> MctsConfig { - MctsConfig { n_simulations: 5, c_puct: 1.5, - dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } - } - - // ── BurnEvaluator tests ─────────────────────────────────────────────── - - #[test] - fn burn_evaluator_output_shapes() { - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); - assert_eq!(policy.len(), 2, "policy length should equal action_space"); - assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); - } - - // ── generate_episode tests ──────────────────────────────────────────── - - #[test] - fn episode_terminates_and_has_samples() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, 
hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); - assert!(!samples.is_empty(), "episode must produce at least one sample"); - } - - #[test] - fn episode_sample_values_are_valid() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); - for s in &samples { - assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, - "unexpected value {}", s.value); - let sum: f32 = s.policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); - assert_eq!(s.obs.len(), 2); - } - } - - #[test] - fn episode_with_temperature_zero() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - // temperature=0 means greedy; episode must still terminate - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); - assert!(!samples.is_empty()); - } -} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs deleted file mode 100644 index 9075519..0000000 --- a/spiel_bot/src/alphazero/trainer.rs +++ /dev/null @@ -1,258 +0,0 @@ -//! One gradient-descent training step for AlphaZero. -//! -//! The loss combines: -//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. -//! - **Value loss** — mean-squared error between the predicted value and the -//! actual game outcome. -//! -//! # Learning-rate scheduling -//! -//! [`cosine_lr`] implements one-cycle cosine annealing: -//! -//! ```text -//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T)) -//! ``` -//! -//! 
Typical usage in the outer loop: -//! -//! ```rust,ignore -//! for step in 0..total_train_steps { -//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps); -//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr); -//! model = m; -//! } -//! ``` -//! -//! # Backend -//! -//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). -//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. -//! Weights are transferred between the two via [`burn::record`]. - -use burn::{ - module::AutodiffModule, - optim::{GradientsParams, Optimizer}, - prelude::ElementConversion, - tensor::{ - activation::log_softmax, - backend::AutodiffBackend, - Tensor, TensorData, - }, -}; - -use crate::network::PolicyValueNet; -use super::replay::TrainSample; - -/// Run one gradient step on `model` using `batch`. -/// -/// Returns the updated model and the scalar loss value for logging. -/// -/// # Parameters -/// -/// - `lr` — learning rate (e.g. `1e-3`). -/// - `batch` — slice of [`TrainSample`]s; must be non-empty. 
-pub fn train_step( - model: N, - optimizer: &mut O, - batch: &[TrainSample], - device: &B::Device, - lr: f64, -) -> (N, f32) -where - B: AutodiffBackend, - N: PolicyValueNet + AutodiffModule, - O: Optimizer, -{ - assert!(!batch.is_empty(), "train_step called with empty batch"); - - let batch_size = batch.len(); - let obs_size = batch[0].obs.len(); - let action_size = batch[0].policy.len(); - - // ── Build input tensors ──────────────────────────────────────────────── - let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); - let policy_flat: Vec = batch.iter().flat_map(|s| s.policy.iter().copied()).collect(); - let value_flat: Vec = batch.iter().map(|s| s.value).collect(); - - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [batch_size, obs_size]), - device, - ); - let policy_target = Tensor::::from_data( - TensorData::new(policy_flat, [batch_size, action_size]), - device, - ); - let value_target = Tensor::::from_data( - TensorData::new(value_flat, [batch_size, 1]), - device, - ); - - // ── Forward pass ────────────────────────────────────────────────────── - let (policy_logits, value_pred) = model.forward(obs_tensor); - - // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ────────────────── - let log_probs = log_softmax(policy_logits, 1); - let policy_loss = (policy_target.clone().neg() * log_probs) - .sum_dim(1) - .mean(); - - // ── Value loss: MSE(value_pred, z) ──────────────────────────────────── - let diff = value_pred - value_target; - let value_loss = (diff.clone() * diff).mean(); - - // ── Combined loss ───────────────────────────────────────────────────── - let loss = policy_loss + value_loss; - - // Extract scalar before backward (consumes the tensor). 
- let loss_scalar: f32 = loss.clone().into_scalar().elem(); - - // ── Backward + optimizer step ───────────────────────────────────────── - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &model); - let model = optimizer.step(lr, model, grads); - - (model, loss_scalar) -} - -// ── Learning-rate schedule ───────────────────────────────────────────────── - -/// Cosine learning-rate schedule (one half-period, no warmup). -/// -/// Returns the learning rate for training step `step` out of `total_steps`: -/// -/// ```text -/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total)) -/// ``` -/// -/// - At `t = 0` returns `initial`. -/// - At `t = total_steps` (or beyond) returns `lr_min`. -/// -/// # Panics -/// -/// Does not panic. When `total_steps == 0`, returns `lr_min`. -pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 { - if total_steps == 0 || step >= total_steps { - return lr_min; - } - let progress = step as f64 / total_steps as f64; - lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos()) -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::{ - backend::{Autodiff, NdArray}, - optim::AdamConfig, - }; - - use crate::network::{MlpConfig, MlpNet}; - use super::super::replay::TrainSample; - - type B = Autodiff>; - - fn device() -> ::Device { - Default::default() - } - - fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { - (0..n) - .map(|i| TrainSample { - obs: vec![0.5f32; obs_size], - policy: { - let mut p = vec![0.0f32; action_size]; - p[i % action_size] = 1.0; - p - }, - value: if i % 2 == 0 { 1.0 } else { -1.0 }, - }) - .collect() - } - - #[test] - fn train_step_returns_finite_loss() { - let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; - let model = MlpNet::::new(&config, &device()); - let mut optimizer = 
AdamConfig::new().init(); - let batch = dummy_batch(8, 4, 4); - - let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - assert!(loss > 0.0, "loss should be positive"); - } - - #[test] - fn loss_decreases_over_steps() { - let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let mut model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - // Same batch every step — loss should decrease. - let batch = dummy_batch(16, 4, 4); - - let mut prev_loss = f32::INFINITY; - for _ in 0..10 { - let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); - model = m; - assert!(loss.is_finite()); - prev_loss = loss; - } - // After 10 steps on fixed data, loss should be below a reasonable threshold. - assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); - } - - #[test] - fn train_step_batch_size_one() { - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(1, 2, 2); - let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); - assert!(loss.is_finite()); - } - - // ── cosine_lr ───────────────────────────────────────────────────────── - - #[test] - fn cosine_lr_at_step_zero_is_initial() { - let lr = super::cosine_lr(1e-3, 1e-5, 0, 100); - assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}"); - } - - #[test] - fn cosine_lr_at_end_is_min() { - let lr = super::cosine_lr(1e-3, 1e-5, 100, 100); - assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}"); - } - - #[test] - fn cosine_lr_beyond_end_is_min() { - let lr = super::cosine_lr(1e-3, 1e-5, 200, 100); - assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}"); - } - - #[test] - fn cosine_lr_midpoint_is_average() { - // At t = total/2, cos(π/2) = 0, so lr = (initial + 
min) / 2. - let lr = super::cosine_lr(1e-3, 1e-5, 50, 100); - let expected = (1e-3 + 1e-5) / 2.0; - assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}"); - } - - #[test] - fn cosine_lr_monotone_decreasing() { - let mut prev = f64::INFINITY; - for step in 0..=100 { - let lr = super::cosine_lr(1e-3, 1e-5, step, 100); - assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}"); - prev = lr; - } - } - - #[test] - fn cosine_lr_zero_total_steps_returns_min() { - let lr = super::cosine_lr(1e-3, 1e-5, 0, 0); - assert!((lr - 1e-5).abs() < 1e-10); - } -} diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs deleted file mode 100644 index 3c82519..0000000 --- a/spiel_bot/src/bin/az_eval.rs +++ /dev/null @@ -1,262 +0,0 @@ -//! Evaluate a trained AlphaZero checkpoint against a random player. -//! -//! # Usage -//! -//! ```sh -//! # Random weights (sanity check — should be ~50 %) -//! cargo run -p spiel_bot --bin az_eval --release -//! -//! # Trained MLP checkpoint -//! cargo run -p spiel_bot --bin az_eval --release -- \ -//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50 -//! -//! # Trained ResNet checkpoint -//! cargo run -p spiel_bot --bin az_eval --release -- \ -//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! | `--checkpoint ` | (none) | Load weights from `.mpk` file; random weights if omitted | -//! | `--arch mlp\|resnet` | `mlp` | Network architecture | -//! | `--hidden ` | 256 (mlp) / 512 (resnet) | Hidden size | -//! | `--n-games ` | `100` | Games per side (total = 2 × N) | -//! | `--n-sim ` | `50` | MCTS simulations per move | -//! | `--seed ` | `42` | RNG seed | -//! 
| `--c-puct ` | `1.5` | PUCT exploration constant | - -use std::path::PathBuf; - -use burn::backend::NdArray; -use rand::{SeedableRng, rngs::SmallRng, Rng}; - -use spiel_bot::{ - alphazero::BurnEvaluator, - env::{GameEnv, Player, TrictracEnv}, - mcts::{Evaluator, MctsConfig, run_mcts, select_action}, - network::{MlpConfig, MlpNet, ResNet, ResNetConfig}, -}; - -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - checkpoint: Option, - arch: String, - hidden: Option, - n_games: usize, - n_sim: usize, - seed: u64, - c_puct: f32, -} - -impl Default for Args { - fn default() -> Self { - Self { - checkpoint: None, - arch: "mlp".into(), - hidden: None, - n_games: 100, - n_sim: 50, - seed: 42, - c_puct: 1.5, - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut args = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); } - "--arch" => { i += 1; args.arch = raw[i].clone(); } - "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); } - "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); } - "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); } - "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); } - "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - args -} - -// ── Game loop ───────────────────────────────────────────────────────────────── - -/// Play one complete game. -/// -/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black). -/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0). 
-fn play_game( - env: &TrictracEnv, - mcts_side: usize, - evaluator: &dyn Evaluator, - mcts_cfg: &MctsConfig, - rng: &mut SmallRng, -) -> [f32; 2] { - let mut state = env.new_game(); - loop { - match env.current_player(&state) { - Player::Terminal => { - return env.returns(&state).expect("Terminal state must have returns"); - } - Player::Chance => env.apply_chance(&mut state, rng), - player => { - let side = player.index().unwrap(); // 0 = P1, 1 = P2 - let action = if side == mcts_side { - let root = run_mcts(env, &state, evaluator, mcts_cfg, rng); - select_action(&root, 0.0, rng) // greedy (temperature = 0) - } else { - let actions = env.legal_actions(&state); - actions[rng.random_range(0..actions.len())] - }; - env.apply(&mut state, action); - } - } - } -} - -// ── Statistics ──────────────────────────────────────────────────────────────── - -#[derive(Default)] -struct Stats { - wins: u32, - draws: u32, - losses: u32, -} - -impl Stats { - fn record(&mut self, mcts_return: f32) { - if mcts_return > 0.0 { self.wins += 1; } - else if mcts_return < 0.0 { self.losses += 1; } - else { self.draws += 1; } - } - - fn total(&self) -> u32 { self.wins + self.draws + self.losses } - - fn win_rate_decisive(&self) -> f64 { - let d = self.wins + self.losses; - if d == 0 { 0.5 } else { self.wins as f64 / d as f64 } - } - - fn print(&self) { - let n = self.total(); - let pct = |k: u32| 100.0 * k as f64 / n as f64; - println!( - " Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)", - self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses), - ); - } -} - -// ── Evaluation ──────────────────────────────────────────────────────────────── - -fn run_evaluation( - evaluator: &dyn Evaluator, - n_games: usize, - mcts_cfg: &MctsConfig, - seed: u64, -) -> (Stats, Stats) { - let env = TrictracEnv; - let total = n_games * 2; - let mut as_p1 = Stats::default(); - let mut as_p2 = Stats::default(); - - for i in 0..total { - // Alternate sides: even games 
→ MctsAgent as P1, odd → as P2. - let mcts_side = i % 2; - let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64)); - let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng); - - let mcts_return = result[mcts_side]; - if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); } - - let done = i + 1; - if done % 10 == 0 || done == total { - eprint!("\r [{done}/{total}] ", ); - } - } - eprintln!(); - (as_p1, as_p2) -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - let device: ::Device = Default::default(); - - // ── Load model ──────────────────────────────────────────────────────── - let evaluator: Box = match args.arch.as_str() { - "resnet" => { - let hidden = args.hidden.unwrap_or(512); - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = match &args.checkpoint { - Some(path) => ResNet::::load(&cfg, path, &device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), - None => ResNet::new(&cfg, &device), - }; - Box::new(BurnEvaluator::>::new(model, device)) - } - "mlp" | _ => { - let hidden = args.hidden.unwrap_or(256); - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = match &args.checkpoint { - Some(path) => MlpNet::::load(&cfg, path, &device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), - None => MlpNet::new(&cfg, &device), - }; - Box::new(BurnEvaluator::>::new(model, device)) - } - }; - - let mcts_cfg = MctsConfig { - n_simulations: args.n_sim, - c_puct: args.c_puct, - dirichlet_alpha: 0.0, // no exploration noise during evaluation - dirichlet_eps: 0.0, - temperature: 0.0, // greedy action selection - }; - - // ── Header ──────────────────────────────────────────────────────────── - let ckpt_label = args.checkpoint - .as_deref() - .and_then(|p| p.file_name()) - .and_then(|n| 
n.to_str()) - .unwrap_or("random weights"); - - println!(); - println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent", - args.arch, args.n_sim); - println!("Games per side: {} | Total: {} | Seed: {}", - args.n_games, args.n_games * 2, args.seed); - println!(); - - // ── Run ─────────────────────────────────────────────────────────────── - let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed); - - // ── Results ─────────────────────────────────────────────────────────── - println!("MctsAgent as P1 (White):"); - as_p1.print(); - - println!("MctsAgent as P2 (Black):"); - as_p2.print(); - - let combined_wins = as_p1.wins + as_p2.wins; - let combined_decisive = combined_wins + as_p1.losses + as_p2.losses; - let combined_wr = if combined_decisive == 0 { 0.5 } - else { combined_wins as f64 / combined_decisive as f64 }; - - println!(); - println!("Combined win rate (excluding draws): {:.1}% [{}/{}]", - combined_wr * 100.0, combined_wins, combined_decisive); - println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%", - as_p1.win_rate_decisive() * 100.0, - as_p2.win_rate_decisive() * 100.0); -} diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs deleted file mode 100644 index 824abe5..0000000 --- a/spiel_bot/src/bin/az_train.rs +++ /dev/null @@ -1,331 +0,0 @@ -//! AlphaZero self-play training loop. -//! -//! # Usage -//! -//! ```sh -//! # Start fresh (MLP, default settings) -//! cargo run -p spiel_bot --bin az_train --release -//! -//! # ResNet, 200 iterations, save every 20 -//! cargo run -p spiel_bot --bin az_train --release -- \ -//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/ -//! -//! # Resume from a checkpoint -//! cargo run -p spiel_bot --bin az_train --release -- \ -//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! 
| `--arch mlp\|resnet` | `mlp` | Network architecture | -//! | `--hidden N` | 256/512 | Hidden layer width | -//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | -//! | `--n-iter N` | `100` | Training iterations | -//! | `--n-games N` | `10` | Self-play games per iteration | -//! | `--n-train N` | `20` | Gradient steps per iteration | -//! | `--n-sim N` | `100` | MCTS simulations per move | -//! | `--batch N` | `64` | Mini-batch size | -//! | `--replay-cap N` | `50000` | Replay buffer capacity | -//! | `--lr F` | `1e-3` | Peak (initial) learning rate | -//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) | -//! | `--c-puct F` | `1.5` | PUCT exploration constant | -//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha | -//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight | -//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 | -//! | `--save-every N` | `10` | Save checkpoint every N iterations | -//! | `--seed N` | `42` | RNG seed | -//! 
| `--resume PATH` | (none) | Load weights from checkpoint before training | - -use std::path::{Path, PathBuf}; -use std::time::Instant; - -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, - tensor::backend::Backend, -}; -use rand::{Rng, SeedableRng, rngs::SmallRng}; -use rayon::prelude::*; - -use spiel_bot::{ - alphazero::{ - BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step, - }, - env::TrictracEnv, - mcts::MctsConfig, - network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, -}; - -type TrainB = Autodiff>; -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - arch: String, - hidden: Option, - out_dir: PathBuf, - n_iter: usize, - n_games: usize, - n_train: usize, - n_sim: usize, - batch_size: usize, - replay_cap: usize, - lr: f64, - lr_min: f64, - c_puct: f32, - dirichlet_alpha: f32, - dirichlet_eps: f32, - temp_drop: usize, - save_every: usize, - seed: u64, - resume: Option, -} - -impl Default for Args { - fn default() -> Self { - Self { - arch: "mlp".into(), - hidden: None, - out_dir: PathBuf::from("checkpoints"), - n_iter: 100, - n_games: 10, - n_train: 20, - n_sim: 100, - batch_size: 64, - replay_cap: 50_000, - lr: 1e-3, - lr_min: 1e-4, - c_puct: 1.5, - dirichlet_alpha: 0.1, - dirichlet_eps: 0.25, - temp_drop: 30, - save_every: 10, - seed: 42, - resume: None, - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut a = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--arch" => { i += 1; a.arch = raw[i].clone(); } - "--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); } - "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } - "--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); } - "--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); } - 
"--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); } - "--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); } - "--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); } - "--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); } - "--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); } - "--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); } - "--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); } - "--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); } - "--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); } - "--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); } - "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } - "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } - "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - a -} - -// ── Training loop ───────────────────────────────────────────────────────────── - -/// Generic training loop, parameterised over the network type. -/// -/// `save_fn` receives the **training-backend** model and the target path; -/// it is called in the match arm where the concrete network type is known. -fn train_loop( - mut model: N, - save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>, - args: &Args, -) -where - N: PolicyValueNet + AutodiffModule + Clone, - >::InnerModule: PolicyValueNet + Send + 'static, -{ - let train_device: ::Device = Default::default(); - let infer_device: ::Device = Default::default(); - - // Type is inferred as OptimizerAdaptor at the call site. 
- let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(args.replay_cap); - let mut rng = SmallRng::seed_from_u64(args.seed); - let env = TrictracEnv; - - // Total gradient steps (used for cosine LR denominator). - let total_train_steps = (args.n_iter * args.n_train).max(1); - let mut global_step = 0usize; - - println!( - "\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}", - "", args.arch, args.n_iter, args.n_games, args.n_sim, "" - ); - - for iter in 0..args.n_iter { - let t0 = Instant::now(); - - // ── Self-play ──────────────────────────────────────────────────── - // Convert to inference backend (zero autodiff overhead). - let infer_model: >::InnerModule = model.valid(); - let evaluator: BurnEvaluator>::InnerModule> = - BurnEvaluator::new(infer_model, infer_device.clone()); - - let mcts_cfg = MctsConfig { - n_simulations: args.n_sim, - c_puct: args.c_puct, - dirichlet_alpha: args.dirichlet_alpha, - dirichlet_eps: args.dirichlet_eps, - temperature: 1.0, - }; - - let temp_drop = args.temp_drop; - let temperature_fn = |step: usize| -> f32 { - if step < temp_drop { 1.0 } else { 0.0 } - }; - - // Prepare per-game seeds and evaluators sequentially so the main RNG - // and model cloning stay deterministic regardless of thread scheduling. - // Burn modules are Send but not Sync, so each task must own its model. 
- let game_seeds: Vec = (0..args.n_games).map(|_| rng.random()).collect(); - let game_evals: Vec<_> = (0..args.n_games) - .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone())) - .collect(); - drop(evaluator); - - let all_samples: Vec> = game_seeds - .into_par_iter() - .zip(game_evals.into_par_iter()) - .map(|(seed, game_eval)| { - let mut game_rng = SmallRng::seed_from_u64(seed); - generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng) - }) - .collect(); - - let mut new_samples = 0usize; - for samples in all_samples { - new_samples += samples.len(); - replay.extend(samples); - } - - // ── Training ───────────────────────────────────────────────────── - let mut loss_sum = 0.0f32; - let mut n_steps = 0usize; - - if replay.len() >= args.batch_size { - for _ in 0..args.n_train { - let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); - let batch: Vec = replay - .sample_batch(args.batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - let (m, loss) = - train_step(model, &mut optimizer, &batch, &train_device, lr); - model = m; - loss_sum += loss; - n_steps += 1; - global_step += 1; - } - } - - // ── Logging ────────────────────────────────────────────────────── - let elapsed = t0.elapsed(); - let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; - let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); - - println!( - "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s", - iter + 1, - args.n_iter, - replay.len(), - new_samples, - avg_loss, - lr_now, - elapsed.as_secs_f32(), - ); - - // ── Checkpoint ─────────────────────────────────────────────────── - let is_last = iter + 1 == args.n_iter; - if (iter + 1) % args.save_every == 0 || is_last { - let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1)); - match save_fn(&model, &path) { - Ok(()) => println!(" -> saved {}", path.display()), - Err(e) => eprintln!(" 
Warning: checkpoint save failed: {e}"), - } - } - } - - println!("\nTraining complete."); -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - - // Create output directory if it doesn't exist. - if let Err(e) = std::fs::create_dir_all(&args.out_dir) { - eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); - std::process::exit(1); - } - - let train_device: ::Device = Default::default(); - - match args.arch.as_str() { - "resnet" => { - let hidden = args.hidden.unwrap_or(512); - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - - let model = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - ResNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => ResNet::::new(&cfg, &train_device), - }; - - train_loop( - model, - &|m: &ResNet, path: &Path| { - // Save via inference model to avoid autodiff record overhead. - m.valid().save(path) - }, - &args, - ); - } - - "mlp" | _ => { - let hidden = args.hidden.unwrap_or(256); - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - - let model = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - MlpNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => MlpNet::::new(&cfg, &train_device), - }; - - train_loop( - model, - &|m: &MlpNet, path: &Path| m.valid().save(path), - &args, - ); - } - } -} diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs deleted file mode 100644 index 0ebe978..0000000 --- a/spiel_bot/src/bin/dqn_train.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! DQN self-play training loop. -//! -//! # Usage -//! -//! ```sh -//! # Start fresh with default settings -//! 
cargo run -p spiel_bot --bin dqn_train --release -//! -//! # Custom hyperparameters -//! cargo run -p spiel_bot --bin dqn_train --release -- \ -//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000 -//! -//! # Resume from a checkpoint -//! cargo run -p spiel_bot --bin dqn_train --release -- \ -//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! | `--hidden N` | 256 | Hidden layer width | -//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | -//! | `--n-iter N` | 100 | Training iterations | -//! | `--n-games N` | 10 | Self-play games per iteration | -//! | `--n-train N` | 20 | Gradient steps per iteration | -//! | `--batch N` | 64 | Mini-batch size | -//! | `--replay-cap N` | 50000 | Replay buffer capacity | -//! | `--lr F` | 1e-3 | Adam learning rate | -//! | `--epsilon-start F` | 1.0 | Initial exploration rate | -//! | `--epsilon-end F` | 0.05 | Final exploration rate | -//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor | -//! | `--gamma F` | 0.99 | Discount factor | -//! | `--target-update N` | 500 | Hard-update target net every N steps | -//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) | -//! | `--save-every N` | 10 | Save checkpoint every N iterations | -//! | `--seed N` | 42 | RNG seed | -//! 
| `--resume PATH` | (none) | Load weights before training | - -use std::path::{Path, PathBuf}; -use std::time::Instant; - -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, - tensor::backend::Backend, -}; -use rand::{SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - dqn::{ - DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, - generate_dqn_episode, hard_update, linear_epsilon, - }, - env::TrictracEnv, - network::{QNet, QNetConfig}, -}; - -type TrainB = Autodiff>; -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - hidden: usize, - out_dir: PathBuf, - save_every: usize, - seed: u64, - resume: Option, - config: DqnConfig, -} - -impl Default for Args { - fn default() -> Self { - Self { - hidden: 256, - out_dir: PathBuf::from("checkpoints"), - save_every: 10, - seed: 42, - resume: None, - config: DqnConfig::default(), - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut a = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } - "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } - "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } - "--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); } - "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } - "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } - "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } - "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } - "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } - 
"--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } - "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } - "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } - "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } - "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } - "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } - "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } - "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - a -} - -// ── Training loop ───────────────────────────────────────────────────────────── - -fn train_loop( - mut q_net: QNet, - cfg: &QNetConfig, - save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, - args: &Args, -) { - let train_device: ::Device = Default::default(); - let infer_device: ::Device = Default::default(); - - let mut optimizer = AdamConfig::new().init(); - let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); - let mut rng = SmallRng::seed_from_u64(args.seed); - let env = TrictracEnv; - - let mut target_net: QNet = hard_update::(&q_net); - let mut global_step = 0usize; - let mut epsilon = args.config.epsilon_start; - - println!( - "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", - "", args.config.n_iterations, args.config.n_games_per_iter, - args.config.n_train_steps_per_iter, "" - ); - - for iter in 0..args.config.n_iterations { - let t0 = Instant::now(); - - // ── Self-play ──────────────────────────────────────────────────── - let infer_q: QNet = q_net.valid(); - let mut new_samples = 0usize; - - for _ in 0..args.config.n_games_per_iter { - let 
samples = generate_dqn_episode( - &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, - ); - new_samples += samples.len(); - replay.extend(samples); - } - - // ── Training ───────────────────────────────────────────────────── - let mut loss_sum = 0.0f32; - let mut n_steps = 0usize; - - if replay.len() >= args.config.batch_size { - for _ in 0..args.config.n_train_steps_per_iter { - let batch: Vec<_> = replay - .sample_batch(args.config.batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - - // Target Q-values computed on the inference backend. - let target_q = compute_target_q( - &target_net, &batch, cfg.action_size, &infer_device, - ); - - let (q, loss) = dqn_train_step( - q_net, &mut optimizer, &batch, &target_q, - &train_device, args.config.learning_rate, args.config.gamma, - ); - q_net = q; - loss_sum += loss; - n_steps += 1; - global_step += 1; - - // Hard-update target net every target_update_freq steps. - if global_step % args.config.target_update_freq == 0 { - target_net = hard_update::(&q_net); - } - - // Linear epsilon decay. 
- epsilon = linear_epsilon( - args.config.epsilon_start, - args.config.epsilon_end, - global_step, - args.config.epsilon_decay_steps, - ); - } - } - - // ── Logging ────────────────────────────────────────────────────── - let elapsed = t0.elapsed(); - let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; - - println!( - "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", - iter + 1, - args.config.n_iterations, - replay.len(), - new_samples, - avg_loss, - epsilon, - elapsed.as_secs_f32(), - ); - - // ── Checkpoint ─────────────────────────────────────────────────── - let is_last = iter + 1 == args.config.n_iterations; - if (iter + 1) % args.save_every == 0 || is_last { - let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); - match save_fn(&q_net, &path) { - Ok(()) => println!(" -> saved {}", path.display()), - Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), - } - } - } - - println!("\nDQN training complete."); -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - - if let Err(e) = std::fs::create_dir_all(&args.out_dir) { - eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); - std::process::exit(1); - } - - let train_device: ::Device = Default::default(); - let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; - - let q_net = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - QNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => QNet::::new(&cfg, &train_device), - }; - - train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); -} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs deleted file mode 100644 index aca1343..0000000 --- a/spiel_bot/src/dqn/episode.rs +++ /dev/null @@ -1,247 +0,0 @@ -//! 
DQN self-play episode generation. -//! -//! Both players share the same Q-network (the [`TrictracEnv`] handles board -//! mirroring so that each player always acts from "White's perspective"). -//! Transitions for both players are stored in the returned sample list. -//! -//! # Reward -//! -//! After each full decision (action applied and the state has advanced through -//! any intervening chance nodes back to the same player's next turn), the -//! reward is: -//! -//! ```text -//! r = (my_total_score_now − my_total_score_then) -//! − (opp_total_score_now − opp_total_score_then) -//! ``` -//! -//! where `total_score = holes × 12 + points`. -//! -//! # Transition structure -//! -//! We use a "pending transition" per player. When a player acts again, we -//! *complete* the previous pending transition by filling in `next_obs`, -//! `next_legal`, and computing `reward`. Terminal transitions are completed -//! when the game ends. - -use burn::tensor::{backend::Backend, Tensor, TensorData}; -use rand::Rng; - -use crate::env::{GameEnv, TrictracEnv}; -use crate::network::QValueNet; -use super::DqnSample; - -// ── Internals ───────────────────────────────────────────────────────────────── - -struct PendingTransition { - obs: Vec, - action: usize, - /// Score snapshot `[p1_total, p2_total]` at the moment of the action. - score_before: [i32; 2], -} - -/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. 
-fn epsilon_greedy>( - q_net: &Q, - obs: &[f32], - legal: &[usize], - epsilon: f32, - rng: &mut impl Rng, - device: &B::Device, -) -> usize { - debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); - if rng.random::() < epsilon { - legal[rng.random_range(0..legal.len())] - } else { - let obs_tensor = Tensor::::from_data( - TensorData::new(obs.to_vec(), [1, obs.len()]), - device, - ); - let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); - legal - .iter() - .copied() - .max_by(|&a, &b| { - q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) - }) - .unwrap() - } -} - -/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. -fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { - let opp_idx = 1 - player_idx; - ((score_after[player_idx] - score_before[player_idx]) - - (score_after[opp_idx] - score_before[opp_idx])) as f32 -} - -// ── Public API ──────────────────────────────────────────────────────────────── - -/// Play one full game and return all transitions for both players. -/// -/// - `q_net` uses the **inference backend** (no autodiff wrapper). -/// - `epsilon` in `[0, 1]`: probability of taking a random action. -/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). -pub fn generate_dqn_episode>( - env: &TrictracEnv, - q_net: &Q, - epsilon: f32, - rng: &mut impl Rng, - device: &B::Device, - reward_scale: f32, -) -> Vec { - let obs_size = env.obs_size(); - let mut state = env.new_game(); - let mut pending: [Option; 2] = [None, None]; - let mut samples: Vec = Vec::new(); - - loop { - // ── Advance past chance nodes ────────────────────────────────────── - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, rng); - } - - let score_now = TrictracEnv::score_snapshot(&state); - - if env.current_player(&state).is_terminal() { - // Complete all pending transitions as terminal. 
- for player_idx in 0..2 { - if let Some(prev) = pending[player_idx].take() { - let reward = - compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; - samples.push(DqnSample { - obs: prev.obs, - action: prev.action, - reward, - next_obs: vec![0.0; obs_size], - next_legal: vec![], - done: true, - }); - } - } - break; - } - - let player_idx = env.current_player(&state).index().unwrap(); - let legal = env.legal_actions(&state); - let obs = env.observation(&state, player_idx); - - // ── Complete the previous transition for this player ─────────────── - if let Some(prev) = pending[player_idx].take() { - let reward = - compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; - samples.push(DqnSample { - obs: prev.obs, - action: prev.action, - reward, - next_obs: obs.clone(), - next_legal: legal.clone(), - done: false, - }); - } - - // ── Pick and apply action ────────────────────────────────────────── - let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); - env.apply(&mut state, action); - - // ── Record new pending transition ────────────────────────────────── - pending[player_idx] = Some(PendingTransition { - obs, - action, - score_before: score_now, - }); - } - - samples -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - use rand::{SeedableRng, rngs::SmallRng}; - - use crate::network::{QNet, QNetConfig}; - - type B = NdArray; - - fn device() -> ::Device { Default::default() } - fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } - - fn tiny_q() -> QNet { - QNet::new(&QNetConfig::default(), &device()) - } - - #[test] - fn episode_terminates_and_produces_samples() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - assert!(!samples.is_empty(), "episode must produce at least one sample"); - } - - #[test] - fn 
episode_obs_size_correct() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - for s in &samples { - assert_eq!(s.obs.len(), 217, "obs size mismatch"); - if s.done { - assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); - assert!(s.next_legal.is_empty()); - } else { - assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); - assert!(!s.next_legal.is_empty()); - } - } - } - - #[test] - fn episode_actions_within_action_space() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - for s in &samples { - assert!(s.action < 514, "action {} out of bounds", s.action); - } - } - - #[test] - fn greedy_episode_also_terminates() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); - assert!(!samples.is_empty()); - } - - #[test] - fn at_least_one_done_sample() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - let n_done = samples.iter().filter(|s| s.done).count(); - // Two players, so 1 or 2 terminal transitions. - assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); - } - - #[test] - fn compute_reward_correct() { - // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. - let before = [2 * 12 + 10, 0]; - let after = [3 * 12 + 2, 0]; - let r = compute_reward(0, &before, &after); - assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); - } - - #[test] - fn compute_reward_with_opponent_scoring() { - // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. 
- let before = [0, 0]; - let after = [2, 3]; - let r = compute_reward(0, &before, &after); - assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); - } -} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs deleted file mode 100644 index 8c34fc1..0000000 --- a/spiel_bot/src/dqn/mod.rs +++ /dev/null @@ -1,232 +0,0 @@ -//! DQN: self-play data generation, replay buffer, and training step. -//! -//! # Algorithm -//! -//! Deep Q-Network with: -//! - **ε-greedy** exploration (linearly decayed). -//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where -//! `score = holes × 12 + points`. -//! - **Experience replay** with a fixed-capacity circular buffer. -//! - **Target network**: hard-copied from the online Q-net every -//! `target_update_freq` gradient steps for training stability. -//! -//! # Modules -//! -//! | Module | Contents | -//! |--------|----------| -//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | -//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | - -pub mod episode; -pub mod trainer; - -pub use episode::generate_dqn_episode; -pub use trainer::{compute_target_q, dqn_train_step, hard_update}; - -use std::collections::VecDeque; -use rand::Rng; - -// ── DqnSample ───────────────────────────────────────────────────────────────── - -/// One transition `(s, a, r, s', done)` collected during self-play. -#[derive(Clone, Debug)] -pub struct DqnSample { - /// Observation from the acting player's perspective (`obs_size` floats). - pub obs: Vec, - /// Action index taken. - pub action: usize, - /// Per-turn reward: `my_score_delta − opponent_score_delta`. - pub reward: f32, - /// Next observation from the same player's perspective. - /// All-zeros when `done = true` (ignored by the TD target). - pub next_obs: Vec, - /// Legal actions at `next_obs`. Empty when `done = true`. - pub next_legal: Vec, - /// `true` when `next_obs` is a terminal state. 
- pub done: bool, -} - -// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── - -/// Fixed-capacity circular replay buffer for [`DqnSample`]s. -/// -/// When full, the oldest sample is evicted on push. -/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. -pub struct DqnReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl DqnReplayBuffer { - pub fn new(capacity: usize) -> Self { - Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } - } - - pub fn push(&mut self, sample: DqnSample) { - if self.data.len() == self.capacity { - self.data.pop_front(); - } - self.data.push_back(sample); - } - - pub fn extend(&mut self, samples: impl IntoIterator) { - for s in samples { self.push(s); } - } - - pub fn len(&self) -> usize { self.data.len() } - pub fn is_empty(&self) -> bool { self.data.is_empty() } - - /// Sample up to `n` distinct samples without replacement. - pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { - let len = self.data.len(); - let n = n.min(len); - let mut indices: Vec = (0..len).collect(); - for i in 0..n { - let j = rng.random_range(i..len); - indices.swap(i, j); - } - indices[..n].iter().map(|&i| &self.data[i]).collect() - } -} - -// ── DqnConfig ───────────────────────────────────────────────────────────────── - -/// Top-level DQN hyperparameters for the training loop. -#[derive(Debug, Clone)] -pub struct DqnConfig { - /// Initial exploration rate (1.0 = fully random). - pub epsilon_start: f32, - /// Final exploration rate after decay. - pub epsilon_end: f32, - /// Number of gradient steps over which ε decays linearly from start to end. - /// - /// Should be calibrated to the total number of gradient steps - /// (`n_iterations × n_train_steps_per_iter`). A value larger than that - /// means exploration never reaches `epsilon_end` during the run. - pub epsilon_decay_steps: usize, - /// Discount factor γ for the TD target. Typical: 0.99. 
- pub gamma: f32, - /// Hard-copy Q → target every this many gradient steps. - /// - /// Should be much smaller than the total number of gradient steps - /// (`n_iterations × n_train_steps_per_iter`). - pub target_update_freq: usize, - /// Adam learning rate. - pub learning_rate: f64, - /// Mini-batch size for each gradient step. - pub batch_size: usize, - /// Maximum number of samples in the replay buffer. - pub replay_capacity: usize, - /// Number of outer iterations (self-play + train). - pub n_iterations: usize, - /// Self-play games per iteration. - pub n_games_per_iter: usize, - /// Gradient steps per iteration. - pub n_train_steps_per_iter: usize, - /// Reward normalisation divisor. - /// - /// Per-turn rewards (score delta) are divided by this constant before being - /// stored. Without normalisation, rewards can reach ±24 (jan with - /// bredouille = 12 pts × 2), driving Q-values into the hundreds and - /// causing MSE loss to grow unboundedly. - /// - /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping - /// Q-value magnitudes in a stable range. Set to `1.0` to disable. - pub reward_scale: f32, -} - -impl Default for DqnConfig { - fn default() -> Self { - // Total gradient steps with these defaults = 500 × 20 = 10_000, - // so epsilon decays fully and the target is updated 100 times. - Self { - epsilon_start: 1.0, - epsilon_end: 0.05, - epsilon_decay_steps: 10_000, - gamma: 0.99, - target_update_freq: 100, - learning_rate: 1e-3, - batch_size: 64, - replay_capacity: 50_000, - n_iterations: 500, - n_games_per_iter: 10, - n_train_steps_per_iter: 20, - reward_scale: 12.0, - } - } -} - -/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. 
-pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { - if decay_steps == 0 || step >= decay_steps { - return end; - } - start + (end - start) * (step as f32 / decay_steps as f32) -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - - fn dummy(reward: f32) -> DqnSample { - DqnSample { - obs: vec![0.0], - action: 0, - reward, - next_obs: vec![0.0], - next_legal: vec![0], - done: false, - } - } - - #[test] - fn push_and_len() { - let mut buf = DqnReplayBuffer::new(10); - assert!(buf.is_empty()); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - assert_eq!(buf.len(), 2); - } - - #[test] - fn evicts_oldest_at_capacity() { - let mut buf = DqnReplayBuffer::new(3); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - buf.push(dummy(3.0)); - buf.push(dummy(4.0)); - assert_eq!(buf.len(), 3); - assert_eq!(buf.data[0].reward, 2.0); - } - - #[test] - fn sample_batch_size() { - let mut buf = DqnReplayBuffer::new(20); - for i in 0..10 { buf.push(dummy(i as f32)); } - let mut rng = SmallRng::seed_from_u64(0); - assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); - } - - #[test] - fn linear_epsilon_start() { - assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); - } - - #[test] - fn linear_epsilon_end() { - assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); - } - - #[test] - fn linear_epsilon_monotone() { - let mut prev = f32::INFINITY; - for step in 0..=100 { - let e = linear_epsilon(1.0, 0.05, step, 100); - assert!(e <= prev + 1e-6); - prev = e; - } - } -} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs deleted file mode 100644 index b8b0a02..0000000 --- a/spiel_bot/src/dqn/trainer.rs +++ /dev/null @@ -1,278 +0,0 @@ -//! DQN gradient step and target-network management. -//! -//! # TD target -//! -//! ```text -//! 
y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done -//! y_i = r_i if done -//! ``` -//! -//! # Loss -//! -//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net) -//! and `y_i` (computed from the frozen target net). -//! -//! # Target network -//! -//! [`hard_update`] copies the online Q-net weights into the target net by -//! stripping the autodiff wrapper via [`AutodiffModule::valid`]. - -use burn::{ - module::AutodiffModule, - optim::{GradientsParams, Optimizer}, - prelude::ElementConversion, - tensor::{ - Int, Tensor, TensorData, - backend::{AutodiffBackend, Backend}, - }, -}; - -use crate::network::QValueNet; -use super::DqnSample; - -// ── Target Q computation ───────────────────────────────────────────────────── - -/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample. -/// -/// Returns a `Vec` of length `batch.len()`. Done samples get `0.0` -/// (their bootstrap term is dropped by the TD target anyway). -/// -/// The target network runs on the **inference backend** (`InferB`) with no -/// gradient tape, so this function is backend-agnostic (`B: Backend`). -pub fn compute_target_q>( - target_net: &Q, - batch: &[DqnSample], - action_size: usize, - device: &B::Device, -) -> Vec { - let batch_size = batch.len(); - - // Collect indices of non-done samples (done samples have no next state). - let non_done: Vec = batch - .iter() - .enumerate() - .filter(|(_, s)| !s.done) - .map(|(i, _)| i) - .collect(); - - if non_done.is_empty() { - return vec![0.0; batch_size]; - } - - let obs_size = batch[0].next_obs.len(); - let nd = non_done.len(); - - // Stack next observations for non-done samples → [nd, obs_size]. - let obs_flat: Vec = non_done - .iter() - .flat_map(|&i| batch[i].next_obs.iter().copied()) - .collect(); - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [nd, obs_size]), - device, - ); - - // Forward target net → [nd, action_size], then to Vec. 
- let q_flat: Vec = target_net.forward(obs_tensor).into_data().to_vec().unwrap(); - - // For each non-done sample, pick max Q over legal next actions. - let mut result = vec![0.0f32; batch_size]; - for (k, &i) in non_done.iter().enumerate() { - let legal = &batch[i].next_legal; - let offset = k * action_size; - let max_q = legal - .iter() - .map(|&a| q_flat[offset + a]) - .fold(f32::NEG_INFINITY, f32::max); - // If legal is empty (shouldn't happen for non-done, but be safe): - result[i] = if max_q.is_finite() { max_q } else { 0.0 }; - } - result -} - -// ── Training step ───────────────────────────────────────────────────────────── - -/// Run one gradient step on `q_net` using `batch`. -/// -/// `target_max_q` must be pre-computed via [`compute_target_q`] using the -/// frozen target network and passed in here so that this function only -/// needs the **autodiff backend**. -/// -/// Returns the updated network and the scalar MSE loss. -pub fn dqn_train_step( - q_net: Q, - optimizer: &mut O, - batch: &[DqnSample], - target_max_q: &[f32], - device: &B::Device, - lr: f64, - gamma: f32, -) -> (Q, f32) -where - B: AutodiffBackend, - Q: QValueNet + AutodiffModule, - O: Optimizer, -{ - assert!(!batch.is_empty(), "dqn_train_step: empty batch"); - assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch"); - - let batch_size = batch.len(); - let obs_size = batch[0].obs.len(); - - // ── Build observation tensor [B, obs_size] ──────────────────────────── - let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [batch_size, obs_size]), - device, - ); - - // ── Forward Q-net → [B, action_size] ───────────────────────────────── - let q_all = q_net.forward(obs_tensor); - - // ── Gather Q(s, a) for the taken action → [B] ──────────────────────── - let actions: Vec = batch.iter().map(|s| s.action as i32).collect(); - let action_tensor: Tensor = 
Tensor::::from_data( - TensorData::new(actions, [batch_size]), - device, - ) - .reshape([batch_size, 1]); // [B] → [B, 1] - let q_pred: Tensor = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B] - - // ── TD targets: r + γ · max_next_q · (1 − done) ────────────────────── - let targets: Vec = batch - .iter() - .zip(target_max_q.iter()) - .map(|(s, &max_q)| { - if s.done { s.reward } else { s.reward + gamma * max_q } - }) - .collect(); - let target_tensor = Tensor::::from_data( - TensorData::new(targets, [batch_size]), - device, - ); - - // ── MSE loss ────────────────────────────────────────────────────────── - let diff = q_pred - target_tensor.detach(); - let loss = (diff.clone() * diff).mean(); - let loss_scalar: f32 = loss.clone().into_scalar().elem(); - - // ── Backward + optimizer step ───────────────────────────────────────── - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &q_net); - let q_net = optimizer.step(lr, q_net, grads); - - (q_net, loss_scalar) -} - -// ── Target network update ───────────────────────────────────────────────────── - -/// Hard-copy the online Q-net weights to a new target network. -/// -/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an -/// inference-backend module with identical weights. 
-pub fn hard_update>(q_net: &Q) -> Q::InnerModule { - q_net.valid() -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::{ - backend::{Autodiff, NdArray}, - optim::AdamConfig, - }; - use crate::network::{QNet, QNetConfig}; - - type InferB = NdArray; - type TrainB = Autodiff>; - - fn infer_device() -> ::Device { Default::default() } - fn train_device() -> ::Device { Default::default() } - - fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { - (0..n) - .map(|i| DqnSample { - obs: vec![0.5f32; obs_size], - action: i % action_size, - reward: if i % 2 == 0 { 1.0 } else { -1.0 }, - next_obs: vec![0.5f32; obs_size], - next_legal: vec![0, 1], - done: i == n - 1, - }) - .collect() - } - - #[test] - fn compute_target_q_length() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let target = QNet::::new(&cfg, &infer_device()); - let batch = dummy_batch(8, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - assert_eq!(tq.len(), 8); - } - - #[test] - fn compute_target_q_done_is_zero() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let target = QNet::::new(&cfg, &infer_device()); - // Single done sample. 
- let batch = vec![DqnSample { - obs: vec![0.0; 4], - action: 0, - reward: 5.0, - next_obs: vec![0.0; 4], - next_legal: vec![], - done: true, - }]; - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - assert_eq!(tq.len(), 1); - assert_eq!(tq[0], 0.0); - } - - #[test] - fn train_step_returns_finite_loss() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; - let q_net = QNet::::new(&cfg, &train_device()); - let target = QNet::::new(&cfg, &infer_device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(8, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - } - - #[test] - fn train_step_loss_decreases() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let mut q_net = QNet::::new(&cfg, &train_device()); - let target = QNet::::new(&cfg, &infer_device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(16, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - - let mut prev_loss = f32::INFINITY; - for _ in 0..10 { - let (q, loss) = dqn_train_step( - q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, - ); - q_net = q; - assert!(loss.is_finite()); - prev_loss = loss; - } - assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); - } - - #[test] - fn hard_update_copies_weights() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let q_net = QNet::::new(&cfg, &train_device()); - let target = hard_update::(&q_net); - - let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); - let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); - // After hard_update the target produces finite outputs. 
- assert!(q_out.iter().all(|v| v.is_finite())); - } -} diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs deleted file mode 100644 index 42b4ae0..0000000 --- a/spiel_bot/src/env/mod.rs +++ /dev/null @@ -1,121 +0,0 @@ -//! Game environment abstraction — the minimal "Rust OpenSpiel". -//! -//! A `GameEnv` describes the rules of a two-player, zero-sum game that may -//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, -//! and PPO interact with a game exclusively through this trait. -//! -//! # Node taxonomy -//! -//! Every game position belongs to one of four categories, returned by -//! [`GameEnv::current_player`]: -//! -//! | [`Player`] | Meaning | -//! |-----------|---------| -//! | `P1` | Player 1 (index 0) must choose an action | -//! | `P2` | Player 2 (index 1) must choose an action | -//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | -//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | -//! -//! # Perspective convention -//! -//! [`GameEnv::observation`] always returns the board from *the requested -//! player's* point of view. Callers pass `pov = 0` for Player 1 and -//! `pov = 1` for Player 2. The implementation is responsible for any -//! mirroring required (e.g. Trictrac always reasons from White's side). - -pub mod trictrac; -pub use trictrac::TrictracEnv; - -/// Who controls the current game node. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Player { - /// Player 1 (index 0) is to move. - P1, - /// Player 2 (index 1) is to move. - P2, - /// A stochastic event (dice roll, etc.) must be resolved. - Chance, - /// The game is over. - Terminal, -} - -impl Player { - /// Returns the player index (0 or 1) if this is a decision node, - /// or `None` for `Chance` / `Terminal`. 
- pub fn index(self) -> Option { - match self { - Player::P1 => Some(0), - Player::P2 => Some(1), - _ => None, - } - } - - pub fn is_decision(self) -> bool { - matches!(self, Player::P1 | Player::P2) - } - - pub fn is_chance(self) -> bool { - self == Player::Chance - } - - pub fn is_terminal(self) -> bool { - self == Player::Terminal - } -} - -/// Trait that completely describes a two-player zero-sum game. -/// -/// Implementors must be cheaply cloneable (the type is used as a stateless -/// factory; the mutable game state lives in `Self::State`). -pub trait GameEnv: Clone + Send + Sync + 'static { - /// The mutable game state. Must be `Clone` so MCTS can copy - /// game trees without touching the environment. - type State: Clone + Send + Sync; - - // ── State creation ──────────────────────────────────────────────────── - - /// Create a fresh game state at the initial position. - fn new_game(&self) -> Self::State; - - // ── Node queries ────────────────────────────────────────────────────── - - /// Classify the current node. - fn current_player(&self, s: &Self::State) -> Player; - - /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). - /// - /// The returned indices are in `[0, action_space())`. - /// The result is unspecified (may panic or return empty) when called at a - /// `Chance` or `Terminal` node. - fn legal_actions(&self, s: &Self::State) -> Vec; - - // ── State mutation ──────────────────────────────────────────────────── - - /// Apply a player action. `action` must be a value returned by - /// [`legal_actions`] for the current state. - fn apply(&self, s: &mut Self::State, action: usize); - - /// Sample and apply a stochastic outcome. Must only be called when - /// `current_player(s) == Player::Chance`. - fn apply_chance(&self, s: &mut Self::State, rng: &mut R); - - // ── Observation ─────────────────────────────────────────────────────── - - /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). 
- /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. - fn observation(&self, s: &Self::State, pov: usize) -> Vec; - - /// Number of floats returned by [`observation`]. - fn obs_size(&self) -> usize; - - /// Total number of distinct action indices (the policy head output size). - fn action_space(&self) -> usize; - - // ── Terminal values ─────────────────────────────────────────────────── - - /// Game outcome for each player, or `None` if the game is not over. - /// - /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. - /// Index 0 = Player 1, index 1 = Player 2. - fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; -} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs deleted file mode 100644 index 8dc3676..0000000 --- a/spiel_bot/src/env/trictrac.rs +++ /dev/null @@ -1,547 +0,0 @@ -//! [`GameEnv`] implementation for Trictrac. -//! -//! # Game flow (schools_enabled = false) -//! -//! With scoring schools disabled (the standard training configuration), -//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine -//! applies them automatically inside `RollResult` and `Move`. The only -//! four stages that actually occur are: -//! -//! | `TurnStage` | [`Player`] kind | Handled by | -//! |-------------|-----------------|------------| -//! | `RollDice` | `Chance` | [`apply_chance`] | -//! | `RollWaiting` | `Chance` | [`apply_chance`] | -//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | -//! | `Move` | `P1`/`P2` | [`apply`] | -//! -//! # Perspective -//! -//! The Trictrac engine always reasons from White's perspective. Player 1 is -//! White; Player 2 is Black. When Player 2 is active, the board is mirrored -//! before computing legal actions / the observation tensor, and the resulting -//! event is mirrored back before being applied to the real state. This -//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. 
- -use trictrac_store::{ - training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, - Dice, GameEvent, GameState, Stage, TurnStage, -}; - -use super::{GameEnv, Player}; - -/// Stateless factory that produces Trictrac [`GameState`] environments. -/// -/// Schools (`schools_enabled`) are always disabled — scoring is automatic. -#[derive(Clone, Debug, Default)] -pub struct TrictracEnv; - -impl GameEnv for TrictracEnv { - type State = GameState; - - // ── State creation ──────────────────────────────────────────────────── - - fn new_game(&self) -> GameState { - GameState::new_with_players("P1", "P2") - } - - // ── Node queries ────────────────────────────────────────────────────── - - fn current_player(&self, s: &GameState) -> Player { - if s.stage == Stage::Ended { - return Player::Terminal; - } - match s.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, - _ => { - if s.active_player_id == 1 { - Player::P1 - } else { - Player::P2 - } - } - } - } - - /// Returns the legal action indices for the active player. - /// - /// The board is automatically mirrored for Player 2 so that the engine - /// always reasons from White's perspective. The returned indices are - /// identical in meaning for both players (checker ordinals are - /// perspective-relative). - /// - /// # Panics - /// - /// Panics in debug builds if called at a `Chance` or `Terminal` node. - fn legal_actions(&self, s: &GameState) -> Vec { - debug_assert!( - self.current_player(s).is_decision(), - "legal_actions called at a non-decision node (turn_stage={:?})", - s.turn_stage - ); - let indices = if s.active_player_id == 2 { - get_valid_action_indices(&s.mirror()) - } else { - get_valid_action_indices(s) - }; - indices.unwrap_or_default() - } - - // ── State mutation ──────────────────────────────────────────────────── - - /// Apply a player action index to the game state. 
- /// - /// For Player 2, the action is decoded against the mirrored board and - /// the resulting event is un-mirrored before being applied. - /// - /// # Panics - /// - /// Panics in debug builds if `action` cannot be decoded or does not - /// produce a valid event for the current state. - fn apply(&self, s: &mut GameState, action: usize) { - let needs_mirror = s.active_player_id == 2; - - let event = if needs_mirror { - let view = s.mirror(); - TrictracAction::from_action_index(action) - .and_then(|a| a.to_event(&view)) - .map(|e| e.get_mirror(false)) - } else { - TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) - }; - - match event { - Some(e) => { - s.consume(&e).expect("apply: consume failed for valid action"); - } - None => { - panic!("apply: action index {action} produced no event in state {s}"); - } - } - } - - /// Sample dice and advance through a chance node. - /// - /// Handles both `RollDice` (triggers the roll mechanism, then samples - /// dice) and `RollWaiting` (only samples dice) in a single call so that - /// callers never need to distinguish the two. - /// - /// # Panics - /// - /// Panics in debug builds if called at a non-Chance node. - fn apply_chance(&self, s: &mut GameState, rng: &mut R) { - debug_assert!( - self.current_player(s).is_chance(), - "apply_chance called at a non-Chance node (turn_stage={:?})", - s.turn_stage - ); - - // Step 1: RollDice → RollWaiting (player initiates the roll). - if s.turn_stage == TurnStage::RollDice { - s.consume(&GameEvent::Roll { - player_id: s.active_player_id, - }) - .expect("apply_chance: Roll event failed"); - } - - // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. - // With schools_enabled=false, point marking is automatic inside consume(). 
- let dice = Dice { - values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), - }; - s.consume(&GameEvent::RollResult { - player_id: s.active_player_id, - dice, - }) - .expect("apply_chance: RollResult event failed"); - } - - // ── Observation ─────────────────────────────────────────────────────── - - fn observation(&self, s: &GameState, pov: usize) -> Vec { - if pov == 0 { - s.to_tensor() - } else { - s.mirror().to_tensor() - } - } - - fn obs_size(&self) -> usize { - 217 - } - - fn action_space(&self) -> usize { - ACTION_SPACE_SIZE - } - - // ── Terminal values ─────────────────────────────────────────────────── - - /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. - /// - /// The winner (higher cumulative score) receives `+1.0`; the loser - /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score - /// is `holes × 12 + points`. - fn returns(&self, s: &GameState) -> Option<[f32; 2]> { - if s.stage != Stage::Ended { - return None; - } - let score = |id: u64| -> i32 { - s.players - .get(&id) - .map(|p| p.holes as i32 * 12 + p.points as i32) - .unwrap_or(0) - }; - let s1 = score(1); - let s2 = score(2); - Some(match s1.cmp(&s2) { - std::cmp::Ordering::Greater => [1.0, -1.0], - std::cmp::Ordering::Less => [-1.0, 1.0], - std::cmp::Ordering::Equal => [0.0, 0.0], - }) - } -} - -// ── DQN helpers ─────────────────────────────────────────────────────────────── - -impl TrictracEnv { - /// Score snapshot for DQN reward computation. - /// - /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. - /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). 
- pub fn score_snapshot(s: &GameState) -> [i32; 2] { - [s.total_score(1), s.total_score(2)] - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{rngs::SmallRng, Rng, SeedableRng}; - - fn env() -> TrictracEnv { - TrictracEnv - } - - fn seeded_rng(seed: u64) -> SmallRng { - SmallRng::seed_from_u64(seed) - } - - // ── Initial state ───────────────────────────────────────────────────── - - #[test] - fn new_game_is_chance_node() { - let e = env(); - let s = e.new_game(); - // A fresh game starts at RollDice — a Chance node. - assert_eq!(e.current_player(&s), Player::Chance); - assert!(e.returns(&s).is_none()); - } - - #[test] - fn new_game_is_not_terminal() { - let e = env(); - let s = e.new_game(); - assert_ne!(e.current_player(&s), Player::Terminal); - assert!(e.returns(&s).is_none()); - } - - // ── Chance nodes ────────────────────────────────────────────────────── - - #[test] - fn apply_chance_reaches_decision_node() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(1); - - // A single chance step must yield a decision node (or end the game, - // which only happens after 12 holes — impossible on the first roll). - e.apply_chance(&mut s, &mut rng); - let p = e.current_player(&s); - assert!( - p.is_decision(), - "expected decision node after first roll, got {p:?}" - ); - } - - #[test] - fn apply_chance_from_rollwaiting() { - // Check that apply_chance works when called mid-way (at RollWaiting). - let e = env(); - let mut s = e.new_game(); - assert_eq!(s.turn_stage, TurnStage::RollDice); - - // Manually advance to RollWaiting. 
- s.consume(&GameEvent::Roll { player_id: s.active_player_id }) - .unwrap(); - assert_eq!(s.turn_stage, TurnStage::RollWaiting); - - let mut rng = seeded_rng(2); - e.apply_chance(&mut s, &mut rng); - - let p = e.current_player(&s); - assert!(p.is_decision() || p.is_terminal()); - } - - // ── Legal actions ───────────────────────────────────────────────────── - - #[test] - fn legal_actions_nonempty_after_roll() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(3); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - let actions = e.legal_actions(&s); - assert!( - !actions.is_empty(), - "legal_actions must be non-empty at a decision node" - ); - } - - #[test] - fn legal_actions_within_action_space() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(4); - - e.apply_chance(&mut s, &mut rng); - for &a in e.legal_actions(&s).iter() { - assert!( - a < e.action_space(), - "action {a} out of bounds (action_space={})", - e.action_space() - ); - } - } - - // ── Observations ────────────────────────────────────────────────────── - - #[test] - fn observation_has_correct_size() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(5); - e.apply_chance(&mut s, &mut rng); - - assert_eq!(e.observation(&s, 0).len(), e.obs_size()); - assert_eq!(e.observation(&s, 1).len(), e.obs_size()); - } - - #[test] - fn observation_values_in_unit_interval() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(6); - e.apply_chance(&mut s, &mut rng); - - for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { - for (i, &v) in obs.iter().enumerate() { - assert!( - v >= 0.0 && v <= 1.0, - "pov={pov}: obs[{i}] = {v} is outside [0,1]" - ); - } - } - } - - #[test] - fn p1_and_p2_observations_differ() { - // The board is mirrored for P2, so the two observations should differ - // whenever there are checkers in non-symmetric positions (always true - // in a real 
game after a few moves). - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(7); - - // Advance far enough that the board is non-trivial. - for _ in 0..6 { - while e.current_player(&s).is_chance() { - e.apply_chance(&mut s, &mut rng); - } - if e.current_player(&s).is_terminal() { - break; - } - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - - if !e.current_player(&s).is_terminal() { - let obs0 = e.observation(&s, 0); - let obs1 = e.observation(&s, 1); - assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); - } - } - - // ── Applying actions ────────────────────────────────────────────────── - - #[test] - fn apply_changes_state() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(8); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - let before = s.clone(); - let action = e.legal_actions(&s)[0]; - e.apply(&mut s, action); - - assert_ne!( - before.turn_stage, s.turn_stage, - "state must change after apply" - ); - } - - #[test] - fn apply_all_legal_actions_do_not_panic() { - // Verify that every action returned by legal_actions can be applied - // without panicking (on several independent copies of the same state). - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(9); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - for action in e.legal_actions(&s) { - let mut copy = s.clone(); - e.apply(&mut copy, action); // must not panic - } - } - - // ── Full game ───────────────────────────────────────────────────────── - - /// Run a complete game with random actions through the `GameEnv` trait - /// and verify that: - /// - The game terminates. - /// - `returns()` is `Some` at the end. - /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's - /// score is ±1 / 0. - /// - No step panics. 
- #[test] - fn full_random_game_terminates() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(42); - let max_steps = 50_000; - - for step in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node"); - let idx = rng.random_range(0..actions.len()); - e.apply(&mut s, actions[idx]); - } - } - assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps"); - } - - let result = e.returns(&s); - assert!(result.is_some(), "returns() must be Some at Terminal"); - - let [r1, r2] = result.unwrap(); - let sum = r1 + r2; - assert!( - (sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5, - "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}" - ); - assert!( - r1.abs() <= 1.0 && r2.abs() <= 1.0, - "returns must be in [-1,1]: r1={r1}, r2={r2}" - ); - } - - /// Run multiple games with different seeds to stress-test for panics. - #[test] - fn multiple_games_no_panic() { - let e = env(); - let max_steps = 20_000; - - for seed in 0..10u64 { - let mut s = e.new_game(); - let mut rng = seeded_rng(seed); - - for _ in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - let idx = rng.random_range(0..actions.len()); - e.apply(&mut s, actions[idx]); - } - } - } - } - } - - // ── Returns ─────────────────────────────────────────────────────────── - - #[test] - fn returns_none_mid_game() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(11); - - // Advance a few steps but do not finish the game. 
- for _ in 0..4 { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - } - } - - if !e.current_player(&s).is_terminal() { - assert!( - e.returns(&s).is_none(), - "returns() must be None before the game ends" - ); - } - } - - // ── Player 2 actions ────────────────────────────────────────────────── - - /// Verify that Player 2 (Black) can take actions without panicking, - /// and that the state advances correctly. - #[test] - fn player2_can_act() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(12); - - // Keep stepping until Player 2 gets a turn. - let max_steps = 5_000; - let mut p2_acted = false; - - for _ in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P2 => { - let actions = e.legal_actions(&s); - assert!(!actions.is_empty()); - e.apply(&mut s, actions[0]); - p2_acted = true; - break; - } - Player::P1 => { - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - } - } - - assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); - } -} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs deleted file mode 100644 index 9dfb4de..0000000 --- a/spiel_bot/src/lib.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod alphazero; -pub mod dqn; -pub mod env; -pub mod mcts; -pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs deleted file mode 100644 index eead171..0000000 --- a/spiel_bot/src/mcts/mod.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. -//! -//! # Algorithm -//! -//! The implementation follows AlphaZero's MCTS: -//! -//! 1. **Expand root** — run the network once to get priors and a value -//! 
estimate; optionally add Dirichlet noise for training-time exploration. -//! 2. **Simulate** `n_simulations` times: -//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. -//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; -//! chance nodes are **not** stored in the tree (outcome sampling). -//! - *Expansion* — evaluate the network at the leaf; populate children. -//! - *Backup* — propagate the value upward; negate at each player boundary. -//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). -//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). -//! -//! # Perspective convention -//! -//! Every [`MctsNode::w`] is stored **from the perspective of the player who -//! acts at that node**. The backup negates the child value whenever the -//! acting player differs between parent and child. -//! -//! # Stochastic games -//! -//! When [`GameEnv::current_player`] returns [`Player::Chance`], the -//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and -//! continues. Chance nodes are skipped transparently; Q-values converge to -//! their expectation over many simulations (outcome sampling). - -pub mod node; -mod search; - -pub use node::MctsNode; - -use rand::Rng; - -use crate::env::GameEnv; - -// ── Evaluator trait ──────────────────────────────────────────────────────── - -/// Evaluates a game position for use in MCTS. -/// -/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) -/// but the `mcts` module itself does **not** depend on Burn. -pub trait Evaluator: Send + Sync { - /// Evaluate `obs` (flat observation vector of length `obs_size`). - /// - /// Returns: - /// - `policy_logits`: one raw logit per action (`action_space` entries). - /// Illegal action entries are masked inside the search — no need to - /// zero them here. - /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. 
- fn evaluate(&self, obs: &[f32]) -> (Vec, f32); -} - -// ── Configuration ───────────────────────────────────────────────────────── - -/// Hyperparameters for [`run_mcts`]. -#[derive(Debug, Clone)] -pub struct MctsConfig { - /// Number of MCTS simulations per move. Typical: 50–800. - pub n_simulations: usize, - /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0. - pub c_puct: f32, - /// Dirichlet noise concentration α. Set to `0.0` to disable. - /// Typical: `0.3` for Chess, `0.1` for large action spaces. - pub dirichlet_alpha: f32, - /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`. - pub dirichlet_eps: f32, - /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax. - pub temperature: f32, -} - -impl Default for MctsConfig { - fn default() -> Self { - Self { - n_simulations: 200, - c_puct: 1.5, - dirichlet_alpha: 0.3, - dirichlet_eps: 0.25, - temperature: 1.0, - } - } -} - -// ── Public interface ─────────────────────────────────────────────────────── - -/// Run MCTS from `state` and return the populated root [`MctsNode`]. -/// -/// `state` must be a player-decision node (`P1` or `P2`). -/// Use [`mcts_policy`] and [`select_action`] on the returned root. -/// -/// # Panics -/// -/// Panics if `env.current_player(state)` is not `P1` or `P2`. 
-pub fn run_mcts( - env: &E, - state: &E::State, - evaluator: &dyn Evaluator, - config: &MctsConfig, - rng: &mut impl Rng, -) -> MctsNode { - let player_idx = env - .current_player(state) - .index() - .expect("run_mcts called at a non-decision node"); - - // ── Expand root (network called once here, not inside the loop) ──────── - let mut root = MctsNode::new(1.0); - search::expand::(&mut root, state, env, evaluator, player_idx); - - // ── Optional Dirichlet noise for training exploration ────────────────── - if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 { - search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng); - } - - // ── Simulations ──────────────────────────────────────────────────────── - for _ in 0..config.n_simulations { - search::simulate::( - &mut root, - state.clone(), - env, - evaluator, - config, - rng, - player_idx, - ); - } - - root -} - -/// Compute the MCTS policy: normalized visit counts at the root. -/// -/// Returns a vector of length `action_space` where `policy[a]` is the -/// fraction of simulations that visited action `a`. -pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec { - let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum(); - let mut policy = vec![0.0f32; action_space]; - if total > 0.0 { - for (a, child) in &root.children { - policy[*a] = child.n as f32 / total; - } - } else if !root.children.is_empty() { - // n_simulations = 0: uniform over legal actions. - let uniform = 1.0 / root.children.len() as f32; - for (a, _) in &root.children { - policy[*a] = uniform; - } - } - policy -} - -/// Select an action index from the root after MCTS. -/// -/// * `temperature = 0` — greedy argmax of visit counts. -/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`. -/// -/// # Panics -/// -/// Panics if the root has no children. 
-pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { - assert!(!root.children.is_empty(), "select_action called on a root with no children"); - if temperature <= 0.0 { - root.children - .iter() - .max_by_key(|(_, c)| c.n) - .map(|(a, _)| *a) - .unwrap() - } else { - let weights: Vec = root - .children - .iter() - .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) - .collect(); - let total: f32 = weights.iter().sum(); - let mut r: f32 = rng.random::() * total; - for (i, (a, _)) in root.children.iter().enumerate() { - r -= weights[i]; - if r <= 0.0 { - return *a; - } - } - root.children.last().map(|(a, _)| *a).unwrap() - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - use crate::env::Player; - - // ── Minimal deterministic test game ─────────────────────────────────── - // - // "Countdown" — two players alternate subtracting 1 or 2 from a counter. - // The player who brings the counter to 0 wins. - // No chance nodes, two legal actions (0 = -1, 1 = -2). 
- - #[derive(Clone, Debug)] - struct CState { - remaining: u8, - to_move: usize, // at terminal: last mover (winner) - } - - #[derive(Clone)] - struct CountdownEnv; - - impl crate::env::GameEnv for CountdownEnv { - type State = CState; - - fn new_game(&self) -> CState { - CState { remaining: 6, to_move: 0 } - } - - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { - Player::Terminal - } else if s.to_move == 0 { - Player::P1 - } else { - Player::P2 - } - } - - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { - s.remaining = 0; - // to_move stays as winner - } else { - s.remaining -= sub; - s.to_move = 1 - s.to_move; - } - } - - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / 6.0, s.to_move as f32] - } - - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } - } - - // Uniform evaluator: all logits = 0, value = 0. - // `action_space` must match the environment's `action_space()`. 
- struct ZeroEval(usize); - impl Evaluator for ZeroEval { - fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { - (vec![0.0f32; self.0], 0.0) - } - } - - fn rng() -> SmallRng { - SmallRng::seed_from_u64(42) - } - - fn config_n(n: usize) -> MctsConfig { - MctsConfig { - n_simulations: n, - c_puct: 1.5, - dirichlet_alpha: 0.0, // off for reproducibility - dirichlet_eps: 0.0, - temperature: 1.0, - } - } - - // ── Visit count tests ───────────────────────────────────────────────── - - #[test] - fn visit_counts_sum_to_n_simulations() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); - let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 50, "visit counts must sum to n_simulations"); - } - - #[test] - fn all_root_children_are_legal() { - let env = CountdownEnv; - let state = env.new_game(); - let legal = env.legal_actions(&state); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); - for (a, _) in &root.children { - assert!(legal.contains(a), "child action {a} is not legal"); - } - } - - // ── Policy tests ───────────────────────────────────────────────────── - - #[test] - fn policy_sums_to_one() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - let sum: f32 = policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); - } - - #[test] - fn policy_zero_for_illegal_actions() { - let env = CountdownEnv; - // remaining = 1 → only action 0 is legal - let state = CState { remaining: 1, to_move: 0 }; - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); - } - - // ── Action selection tests 
──────────────────────────────────────────── - - #[test] - fn greedy_selects_most_visited() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); - let greedy = select_action(&root, 0.0, &mut rng()); - let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); - assert_eq!(greedy, most_visited); - } - - #[test] - fn temperature_sampling_stays_legal() { - let env = CountdownEnv; - let state = env.new_game(); - let legal = env.legal_actions(&state); - let mut r = rng(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); - for _ in 0..20 { - let a = select_action(&root, 1.0, &mut r); - assert!(legal.contains(&a), "sampled action {a} is not legal"); - } - } - - // ── Zero-simulation edge case ───────────────────────────────────────── - - #[test] - fn zero_simulations_uniform_policy() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - // With 0 simulations, fallback is uniform over the 2 legal actions. - let sum: f32 = policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-5); - } - - // ── Root value ──────────────────────────────────────────────────────── - - #[test] - fn root_q_in_valid_range() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); - let q = root.q(); - assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); - } - - // ── Integration: run on a real Trictrac game ────────────────────────── - - #[test] - fn no_panic_on_trictrac_state() { - use crate::env::TrictracEnv; - - let env = TrictracEnv; - let mut state = env.new_game(); - let mut r = rng(); - - // Advance past the initial chance node to reach a decision node. 
- while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, &mut r); - } - - if env.current_player(&state).is_terminal() { - return; // unlikely but safe - } - - let config = MctsConfig { - n_simulations: 5, // tiny for speed - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - ..MctsConfig::default() - }; - - let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); - // root.n = 1 (expansion) + n_simulations (one backup per simulation). - assert_eq!(root.n, 1 + config.n_simulations as u32); - // Every simulation crosses a chance node at depth 1 (dice roll after - // the player's move). Since the fix now updates child.n in that case, - // children visit counts must sum to exactly n_simulations. - let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, config.n_simulations as u32); - } -} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs deleted file mode 100644 index aff7735..0000000 --- a/spiel_bot/src/mcts/node.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! MCTS tree node. -//! -//! [`MctsNode`] holds the visit statistics for one player-decision position in -//! the search tree. A node is *expanded* the first time the policy-value -//! network is evaluated there; before that it is a leaf. - -/// One node in the MCTS tree, representing a player-decision position. -/// -/// `w` stores the sum of values backed up into this node, always from the -/// perspective of **the player who acts here**. `q()` therefore also returns -/// a value in `(-1, 1)` from that same perspective. -#[derive(Debug)] -pub struct MctsNode { - /// Visit count `N(s, a)`. - pub n: u32, - /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. - pub w: f32, - /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). - pub p: f32, - /// Children: `(action_index, child_node)`, populated on first expansion. 
- pub children: Vec<(usize, MctsNode)>, - /// `true` after the network has been evaluated and children have been set up. - pub expanded: bool, -} - -impl MctsNode { - /// Create a fresh, unexpanded leaf with the given prior probability. - pub fn new(prior: f32) -> Self { - Self { - n: 0, - w: 0.0, - p: prior, - children: Vec::new(), - expanded: false, - } - } - - /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. - #[inline] - pub fn q(&self) -> f32 { - if self.n == 0 { 0.0 } else { self.w / self.n as f32 } - } - - /// PUCT selection score: - /// - /// ```text - /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) - /// ``` - #[inline] - pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { - self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn q_zero_when_unvisited() { - let node = MctsNode::new(0.5); - assert_eq!(node.q(), 0.0); - } - - #[test] - fn q_reflects_w_over_n() { - let mut node = MctsNode::new(0.5); - node.n = 4; - node.w = 2.0; - assert!((node.q() - 0.5).abs() < 1e-6); - } - - #[test] - fn puct_exploration_dominates_unvisited() { - // Unvisited child should outscore a visited child with negative Q. - let mut visited = MctsNode::new(0.5); - visited.n = 10; - visited.w = -5.0; // Q = -0.5 - - let unvisited = MctsNode::new(0.5); - - let parent_n = 10; - let c = 1.5; - assert!( - unvisited.puct(parent_n, c) > visited.puct(parent_n, c), - "unvisited child should have higher PUCT than a negatively-valued visited child" - ); - } -} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs deleted file mode 100644 index 1d9750d..0000000 --- a/spiel_bot/src/mcts/search.rs +++ /dev/null @@ -1,190 +0,0 @@ -//! Simulation, expansion, backup, and noise helpers. -//! -//! These are internal to the `mcts` module; the public entry points are -//! 
[`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. - -use rand::Rng; -use rand_distr::{Gamma, Distribution}; - -use crate::env::GameEnv; -use super::{Evaluator, MctsConfig}; -use super::node::MctsNode; - -// ── Masked softmax ───────────────────────────────────────────────────────── - -/// Numerically stable softmax over `legal` actions only. -/// -/// Illegal logits are treated as `-∞` and receive probability `0.0`. -/// Returns a probability vector of length `action_space`. -pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec { - let mut probs = vec![0.0f32; action_space]; - if legal.is_empty() { - return probs; - } - let max_logit = legal - .iter() - .map(|&a| logits[a]) - .fold(f32::NEG_INFINITY, f32::max); - let mut sum = 0.0f32; - for &a in legal { - let e = (logits[a] - max_logit).exp(); - probs[a] = e; - sum += e; - } - if sum > 0.0 { - for &a in legal { - probs[a] /= sum; - } - } else { - let uniform = 1.0 / legal.len() as f32; - for &a in legal { - probs[a] = uniform; - } - } - probs -} - -// ── Dirichlet noise ──────────────────────────────────────────────────────── - -/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration. -/// -/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. -/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum. 
-pub(super) fn add_dirichlet_noise( - node: &mut MctsNode, - alpha: f32, - eps: f32, - rng: &mut impl Rng, -) { - let n = node.children.len(); - if n == 0 { - return; - } - let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else { - return; - }; - let samples: Vec = (0..n).map(|_| gamma.sample(rng) as f32).collect(); - let sum: f32 = samples.iter().sum(); - if sum <= 0.0 { - return; - } - for (i, (_, child)) in node.children.iter_mut().enumerate() { - let noise = samples[i] / sum; - child.p = (1.0 - eps) * child.p + eps * noise; - } -} - -// ── Expansion ────────────────────────────────────────────────────────────── - -/// Evaluate the network at `state` and populate `node` with children. -/// -/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`. -/// Returns the network value estimate from `player_idx`'s perspective. -pub(super) fn expand( - node: &mut MctsNode, - state: &E::State, - env: &E, - evaluator: &dyn Evaluator, - player_idx: usize, -) -> f32 { - let obs = env.observation(state, player_idx); - let legal = env.legal_actions(state); - let (logits, value) = evaluator.evaluate(&obs); - let priors = masked_softmax(&logits, &legal, env.action_space()); - node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect(); - node.expanded = true; - node.n = 1; - node.w = value; - value -} - -// ── Simulation ───────────────────────────────────────────────────────────── - -/// One MCTS simulation from an **already-expanded** decision node. -/// -/// Traverses the tree with PUCT selection, expands the first unvisited leaf, -/// and backs up the result. -/// -/// * `player_idx` — the player (0 or 1) who acts at `state`. -/// * Returns the backed-up value **from `player_idx`'s perspective**. 
-pub(super) fn simulate( - node: &mut MctsNode, - state: E::State, - env: &E, - evaluator: &dyn Evaluator, - config: &MctsConfig, - rng: &mut impl Rng, - player_idx: usize, -) -> f32 { - debug_assert!(node.expanded, "simulate called on unexpanded node"); - - // ── Selection: child with highest PUCT ──────────────────────────────── - let parent_n = node.n; - let best = node - .children - .iter() - .enumerate() - .max_by(|(_, (_, a)), (_, (_, b))| { - a.puct(parent_n, config.c_puct) - .partial_cmp(&b.puct(parent_n, config.c_puct)) - .unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(i, _)| i) - .expect("expanded node must have at least one child"); - - let (action, child) = &mut node.children[best]; - let action = *action; - - // ── Apply action + advance through any chance nodes ─────────────────── - let mut next_state = state; - env.apply(&mut next_state, action); - - // Track whether we crossed a chance node (dice roll) on the way down. - // If we did, the child's cached legal actions are for a *different* dice - // outcome and must not be reused — evaluate with the network directly. - let mut crossed_chance = false; - while env.current_player(&next_state).is_chance() { - env.apply_chance(&mut next_state, rng); - crossed_chance = true; - } - - let next_cp = env.current_player(&next_state); - - // ── Evaluate leaf or terminal ────────────────────────────────────────── - // All values are converted to `player_idx`'s perspective before backup. - let child_value = if next_cp.is_terminal() { - let returns = env - .returns(&next_state) - .expect("terminal node must have returns"); - returns[player_idx] - } else { - let child_player = next_cp.index().unwrap(); - let v = if crossed_chance { - // Outcome sampling: after dice, evaluate the resulting position - // directly with the network. Do NOT build the tree across chance - // boundaries — the dice change which actions are legal, so any - // previously cached children would be for a different outcome. 
- let obs = env.observation(&next_state, child_player); - let (_, value) = evaluator.evaluate(&obs); - // Record the visit so that PUCT and mcts_policy use real counts. - // Without this, child.n stays 0 for every simulation in games where - // every player action is immediately followed by a chance node (e.g. - // Trictrac), causing mcts_policy to always return a uniform policy. - child.n += 1; - child.w += value; - value - } else if child.expanded { - simulate(child, next_state, env, evaluator, config, rng, child_player) - } else { - expand::(child, &next_state, env, evaluator, child_player) - }; - // Negate when the child belongs to the opponent. - if child_player == player_idx { v } else { -v } - }; - - // ── Backup ──────────────────────────────────────────────────────────── - node.n += 1; - node.w += child_value; - - child_value -} diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs deleted file mode 100644 index eb6184e..0000000 --- a/spiel_bot/src/network/mlp.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! Two-hidden-layer MLP policy-value network. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU -//! → Linear(hidden → hidden) → ReLU -//! ├─ policy_head: Linear(hidden → action_size) [raw logits] -//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] -//! ``` - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{ - activation::{relu, tanh}, - backend::Backend, - Tensor, - }, -}; -use std::path::Path; - -use super::PolicyValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`MlpNet`]. -#[derive(Debug, Clone)] -pub struct MlpConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of both hidden layers. 
- pub hidden_size: usize, -} - -impl Default for MlpConfig { - fn default() -> Self { - Self { - obs_size: 217, - action_size: 514, - hidden_size: 256, - } - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Simple two-hidden-layer MLP with shared trunk and two heads. -/// -/// Prefer this over [`ResNet`](super::ResNet) when training time is a -/// priority, or as a fast baseline. -#[derive(Module, Debug)] -pub struct MlpNet { - fc1: Linear, - fc2: Linear, - policy_head: Linear, - value_head: Linear, -} - -impl MlpNet { - /// Construct a fresh network with random weights. - pub fn new(config: &MlpConfig, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), - fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), - policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), - value_head: LinearConfig::new(config.hidden_size, 1).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - /// - /// The file is written exactly at `path`; callers should append `.mpk` if - /// they want the conventional extension. - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. 
- pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl PolicyValueNet for MlpNet { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { - let x = relu(self.fc1.forward(obs)); - let x = relu(self.fc2.forward(x)); - let policy = self.policy_head.forward(x.clone()); - let value = tanh(self.value_head.forward(x)); - (policy, value) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn default_net() -> MlpNet { - MlpNet::new(&MlpConfig::default(), &device()) - } - - fn zeros_obs(batch: usize) -> Tensor { - Tensor::zeros([batch, 217], &device()) - } - - // ── Shape tests ─────────────────────────────────────────────────────── - - #[test] - fn forward_output_shapes() { - let net = default_net(); - let obs = zeros_obs(4); - let (policy, value) = net.forward(obs); - - assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); - assert_eq!(value.dims(), [4, 1], "value shape mismatch"); - } - - #[test] - fn forward_single_sample() { - let net = default_net(); - let (policy, value) = net.forward(zeros_obs(1)); - assert_eq!(policy.dims(), [1, 514]); - assert_eq!(value.dims(), [1, 1]); - } - - // ── Value bounds ────────────────────────────────────────────────────── - - #[test] - fn value_in_tanh_range() { - let net = default_net(); - // Use a non-zero input so the output is not trivially at 0. 
- let obs = Tensor::::ones([8, 217], &device()); - let (_, value) = net.forward(obs); - let data: Vec = value.into_data().to_vec().unwrap(); - for v in &data { - assert!( - *v > -1.0 && *v < 1.0, - "value {v} is outside open interval (-1, 1)" - ); - } - } - - // ── Policy logits ───────────────────────────────────────────────────── - - #[test] - fn policy_logits_not_all_equal() { - // With random weights the 514 logits should not all be identical. - let net = default_net(); - let (policy, _) = net.forward(zeros_obs(1)); - let data: Vec = policy.into_data().to_vec().unwrap(); - let first = data[0]; - let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); - assert!(!all_same, "all policy logits are identical — network may be degenerate"); - } - - // ── Config propagation ──────────────────────────────────────────────── - - #[test] - fn custom_config_shapes() { - let config = MlpConfig { - obs_size: 10, - action_size: 20, - hidden_size: 32, - }; - let net = MlpNet::::new(&config, &device()); - let obs = Tensor::zeros([3, 10], &device()); - let (policy, value) = net.forward(obs); - assert_eq!(policy.dims(), [3, 20]); - assert_eq!(value.dims(), [3, 1]); - } - - // ── Save / Load ─────────────────────────────────────────────────────── - - #[test] - fn save_load_preserves_weights() { - let config = MlpConfig::default(); - let net = default_net(); - - // Forward pass before saving. - let obs = Tensor::::ones([2, 217], &device()); - let (policy_before, value_before) = net.forward(obs.clone()); - - // Save to a temp file. - let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); - net.save(&path).expect("save failed"); - - // Load into a fresh model. - let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); - let (policy_after, value_after) = loaded.forward(obs); - - // Outputs must be bitwise identical. 
- let p_before: Vec = policy_before.into_data().to_vec().unwrap(); - let p_after: Vec = policy_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let v_before: Vec = value_before.into_data().to_vec().unwrap(); - let v_after: Vec = value_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let _ = std::fs::remove_file(path); - } -} diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs deleted file mode 100644 index 64f93ec..0000000 --- a/spiel_bot/src/network/mod.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Neural network abstractions for policy-value learning. -//! -//! # Trait -//! -//! [`PolicyValueNet`] is the single trait that all network architectures -//! implement. It takes an observation tensor and returns raw policy logits -//! plus a tanh-squashed scalar value estimate. -//! -//! # Architectures -//! -//! | Module | Description | Default hidden | -//! |--------|-------------|----------------| -//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | -//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | -//! -//! # Backend convention -//! -//! * **Inference / self-play** — use `NdArray` (no autodiff overhead). -//! * **Training** — use `Autodiff>` so Burn can differentiate -//! through the forward pass. -//! -//! Both modes use the exact same struct; only the type-level backend changes: -//! -//! ```rust,ignore -//! use burn::backend::{Autodiff, NdArray}; -//! type InferBackend = NdArray; -//! type TrainBackend = Autodiff>; -//! -//! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); -//! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); -//! ``` -//! 
-//! # Output shapes -//! -//! Given a batch of `B` observations of size `obs_size`: -//! -//! | Output | Shape | Range | -//! |--------|-------|-------| -//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | -//! | `value` | `[B, 1]` | (-1, 1) via tanh | -//! -//! Callers are responsible for masking illegal actions in `policy_logits` -//! before passing to softmax. - -pub mod mlp; -pub mod qnet; -pub mod resnet; - -pub use mlp::{MlpConfig, MlpNet}; -pub use qnet::{QNet, QNetConfig}; -pub use resnet::{ResNet, ResNetConfig}; - -use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; - -/// A neural network that produces a policy and a value from an observation. -/// -/// # Shapes -/// - `obs`: `[batch, obs_size]` -/// - policy output: `[batch, action_size]` — raw logits (no softmax applied) -/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) -/// -/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses -/// `OnceCell` for lazy parameter initialisation, which is not `Sync`. -/// Use an `Arc>` wrapper if cross-thread sharing is needed. -pub trait PolicyValueNet: Module + Send + 'static { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor); -} - -/// A neural network that outputs one Q-value per action. -/// -/// # Shapes -/// - `obs`: `[batch, obs_size]` -/// - output: `[batch, action_size]` — raw Q-values (no activation) -/// -/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`]. -pub trait QValueNet: Module + Send + 'static { - fn forward(&self, obs: Tensor) -> Tensor; -} diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs deleted file mode 100644 index 1737f72..0000000 --- a/spiel_bot/src/network/qnet.rs +++ /dev/null @@ -1,147 +0,0 @@ -//! Single-headed Q-value network for DQN. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU -//! → Linear(hidden → hidden) → ReLU -//! → Linear(hidden → action_size) ← raw Q-values, no activation -//! 
``` - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{activation::relu, backend::Backend, Tensor}, -}; -use std::path::Path; - -use super::QValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`QNet`]. -#[derive(Debug, Clone)] -pub struct QNetConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of both hidden layers. - pub hidden_size: usize, -} - -impl Default for QNetConfig { - fn default() -> Self { - Self { obs_size: 217, action_size: 514, hidden_size: 256 } - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Two-hidden-layer MLP that outputs one Q-value per action. -#[derive(Module, Debug)] -pub struct QNet { - fc1: Linear, - fc2: Linear, - q_head: Linear, -} - -impl QNet { - /// Construct a fresh network with random weights. - pub fn new(config: &QNetConfig, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), - fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), - q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. 
- pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl QValueNet for QNet { - fn forward(&self, obs: Tensor) -> Tensor { - let x = relu(self.fc1.forward(obs)); - let x = relu(self.fc2.forward(x)); - self.q_head.forward(x) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { Default::default() } - - fn default_net() -> QNet { - QNet::new(&QNetConfig::default(), &device()) - } - - #[test] - fn forward_output_shape() { - let net = default_net(); - let obs = Tensor::zeros([4, 217], &device()); - let q = net.forward(obs); - assert_eq!(q.dims(), [4, 514]); - } - - #[test] - fn forward_single_sample() { - let net = default_net(); - let q = net.forward(Tensor::zeros([1, 217], &device())); - assert_eq!(q.dims(), [1, 514]); - } - - #[test] - fn q_values_not_all_equal() { - let net = default_net(); - let q: Vec = net.forward(Tensor::zeros([1, 217], &device())) - .into_data().to_vec().unwrap(); - let first = q[0]; - assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6)); - } - - #[test] - fn custom_config_shapes() { - let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 }; - let net = QNet::::new(&cfg, &device()); - let q = net.forward(Tensor::zeros([3, 10], &device())); - assert_eq!(q.dims(), [3, 20]); - } - - #[test] - fn save_load_preserves_weights() { - let net = default_net(); - let obs = Tensor::::ones([2, 217], &device()); - let q_before: Vec = net.forward(obs.clone()).into_data().to_vec().unwrap(); - - let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk"); - net.save(&path).expect("save failed"); - - let loaded = QNet::::load(&QNetConfig::default(), 
&path, &device()).expect("load failed"); - let q_after: Vec = loaded.forward(obs).into_data().to_vec().unwrap(); - - for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}"); - } - let _ = std::fs::remove_file(path); - } -} diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs deleted file mode 100644 index d20d5ad..0000000 --- a/spiel_bot/src/network/resnet.rs +++ /dev/null @@ -1,253 +0,0 @@ -//! Residual-block policy-value network. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU (input projection) -//! → ResBlock × 4 (residual trunk) -//! ├─ policy_head: Linear(hidden → action_size) [raw logits] -//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] -//! -//! ResBlock: -//! x → Linear → ReLU → Linear → (+x) → ReLU -//! ``` -//! -//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better -//! suited for long training runs where board-pattern recognition matters. - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{ - activation::{relu, tanh}, - backend::Backend, - Tensor, - }, -}; -use std::path::Path; - -use super::PolicyValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`ResNet`]. -#[derive(Debug, Clone)] -pub struct ResNetConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of all hidden layers (input projection + residual blocks). - pub hidden_size: usize, -} - -impl Default for ResNetConfig { - fn default() -> Self { - Self { - obs_size: 217, - action_size: 514, - hidden_size: 512, - } - } -} - -// ── Residual block ──────────────────────────────────────────────────────────── - -/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`. 
-/// -/// Both linear layers preserve the hidden dimension so the skip connection -/// can be added without projection. -#[derive(Module, Debug)] -struct ResBlock { - fc1: Linear, - fc2: Linear, -} - -impl ResBlock { - fn new(hidden: usize, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(hidden, hidden).init(device), - fc2: LinearConfig::new(hidden, hidden).init(device), - } - } - - fn forward(&self, x: Tensor) -> Tensor { - let residual = x.clone(); - let out = relu(self.fc1.forward(x)); - relu(self.fc2.forward(out) + residual) - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Four-residual-block policy-value network. -/// -/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and -/// when representing complex positional patterns is important. -#[derive(Module, Debug)] -pub struct ResNet { - input: Linear, - block0: ResBlock, - block1: ResBlock, - block2: ResBlock, - block3: ResBlock, - policy_head: Linear, - value_head: Linear, -} - -impl ResNet { - /// Construct a fresh network with random weights. - pub fn new(config: &ResNetConfig, device: &B::Device) -> Self { - let h = config.hidden_size; - Self { - input: LinearConfig::new(config.obs_size, h).init(device), - block0: ResBlock::new(h, device), - block1: ResBlock::new(h, device), - block2: ResBlock::new(h, device), - block3: ResBlock::new(h, device), - policy_head: LinearConfig::new(h, config.action_size).init(device), - value_head: LinearConfig::new(h, 1).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. 
- pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl PolicyValueNet for ResNet { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { - let x = relu(self.input.forward(obs)); - let x = self.block0.forward(x); - let x = self.block1.forward(x); - let x = self.block2.forward(x); - let x = self.block3.forward(x); - let policy = self.policy_head.forward(x.clone()); - let value = tanh(self.value_head.forward(x)); - (policy, value) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn small_config() -> ResNetConfig { - // Use a small hidden size so tests are fast. 
- ResNetConfig { - obs_size: 217, - action_size: 514, - hidden_size: 64, - } - } - - fn net() -> ResNet { - ResNet::new(&small_config(), &device()) - } - - // ── Shape tests ─────────────────────────────────────────────────────── - - #[test] - fn forward_output_shapes() { - let obs = Tensor::zeros([4, 217], &device()); - let (policy, value) = net().forward(obs); - assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); - assert_eq!(value.dims(), [4, 1], "value shape mismatch"); - } - - #[test] - fn forward_single_sample() { - let (policy, value) = net().forward(Tensor::zeros([1, 217], &device())); - assert_eq!(policy.dims(), [1, 514]); - assert_eq!(value.dims(), [1, 1]); - } - - // ── Value bounds ────────────────────────────────────────────────────── - - #[test] - fn value_in_tanh_range() { - let obs = Tensor::::ones([8, 217], &device()); - let (_, value) = net().forward(obs); - let data: Vec = value.into_data().to_vec().unwrap(); - for v in &data { - assert!( - *v > -1.0 && *v < 1.0, - "value {v} is outside open interval (-1, 1)" - ); - } - } - - // ── Residual connections ────────────────────────────────────────────── - - #[test] - fn policy_logits_not_all_equal() { - let (policy, _) = net().forward(Tensor::zeros([1, 217], &device())); - let data: Vec = policy.into_data().to_vec().unwrap(); - let first = data[0]; - let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); - assert!(!all_same, "all policy logits are identical"); - } - - // ── Save / Load ─────────────────────────────────────────────────────── - - #[test] - fn save_load_preserves_weights() { - let config = small_config(); - let model = net(); - let obs = Tensor::::ones([2, 217], &device()); - - let (policy_before, value_before) = model.forward(obs.clone()); - - let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk"); - model.save(&path).expect("save failed"); - - let loaded = ResNet::::load(&config, &path, &device()).expect("load failed"); - let (policy_after, value_after) = 
loaded.forward(obs); - - let p_before: Vec = policy_before.into_data().to_vec().unwrap(); - let p_after: Vec = policy_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let v_before: Vec = value_before.into_data().to_vec().unwrap(); - let v_after: Vec = value_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let _ = std::fs::remove_file(path); - } - - // ── Integration: both architectures satisfy PolicyValueNet ──────────── - - #[test] - fn resnet_satisfies_trait() { - fn requires_net>(net: &N, obs: Tensor) { - let (p, v) = net.forward(obs); - assert_eq!(p.dims()[1], 514); - assert_eq!(v.dims()[1], 1); - } - requires_net(&net(), Tensor::zeros([2, 217], &device())); - } -} diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs deleted file mode 100644 index d73fda0..0000000 --- a/spiel_bot/tests/integration.rs +++ /dev/null @@ -1,391 +0,0 @@ -//! End-to-end integration tests for the AlphaZero training pipeline. -//! -//! Each test exercises the full chain: -//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`] -//! -//! Two environments are used: -//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves. -//! Used when we need many iterations without worrying about runtime. -//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that -//! the full pipeline compiles and runs correctly with 217-dim observations -//! and 514-dim action spaces. -//! -//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep -//! runtime minimal; correctness, not training quality, is what matters here. 
- -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, -}; -use rand::{SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step}, - env::{GameEnv, Player, TrictracEnv}, - mcts::MctsConfig, - network::{MlpConfig, MlpNet, PolicyValueNet}, -}; - -// ── Backend aliases ──────────────────────────────────────────────────────── - -type Train = Autodiff>; -type Infer = NdArray; - -// ── Helpers ──────────────────────────────────────────────────────────────── - -fn train_device() -> ::Device { - Default::default() -} - -fn infer_device() -> ::Device { - Default::default() -} - -/// Tiny 64-unit MLP, compatible with an obs/action space of any size. -fn tiny_mlp(obs: usize, actions: usize) -> MlpNet { - let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 }; - MlpNet::new(&cfg, &train_device()) -} - -fn tiny_mcts(n: usize) -> MctsConfig { - MctsConfig { - n_simulations: n, - c_puct: 1.5, - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - temperature: 1.0, - } -} - -fn seeded() -> SmallRng { - SmallRng::seed_from_u64(0) -} - -// ── Countdown environment (fast, local, no external deps) ───────────────── -// -// Two players alternate subtracting 1 or 2 from a counter that starts at N. -// The player who brings the counter to 0 wins. 
- -#[derive(Clone, Debug)] -struct CState { - remaining: u8, - to_move: usize, -} - -#[derive(Clone)] -struct CountdownEnv(u8); // starting value - -impl GameEnv for CountdownEnv { - type State = CState; - - fn new_game(&self) -> CState { - CState { remaining: self.0, to_move: 0 } - } - - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { Player::Terminal } - else if s.to_move == 0 { Player::P1 } - else { Player::P2 } - } - - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { - s.remaining = 0; - } else { - s.remaining -= sub; - s.to_move = 1 - s.to_move; - } - } - - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / self.0 as f32, s.to_move as f32] - } - - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } -} - -// ── 1. Full loop on CountdownEnv ────────────────────────────────────────── - -/// The canonical AlphaZero loop: self-play → replay → train, iterated. -/// Uses CountdownEnv so each game terminates in < 10 moves. -#[test] -fn countdown_full_loop_no_panic() { - let env = CountdownEnv(8); - let mut rng = seeded(); - let mcts = tiny_mcts(3); - - let mut model = tiny_mlp(env.obs_size(), env.action_space()); - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(1_000); - - for _iter in 0..5 { - // Self-play: 3 games per iteration. 
- for _ in 0..3 { - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - assert!(!samples.is_empty()); - replay.extend(samples); - } - - // Training: 4 gradient steps per iteration. - if replay.len() >= 4 { - for _ in 0..4 { - let batch: Vec = replay - .sample_batch(4, &mut rng) - .into_iter() - .cloned() - .collect(); - let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); - model = m; - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - } - } - } - - assert!(replay.len() > 0); -} - -// ── 2. Replay buffer invariants ─────────────────────────────────────────── - -/// After several Countdown games, replay capacity is respected and batch -/// shapes are consistent. -#[test] -fn replay_buffer_capacity_and_shapes() { - let env = CountdownEnv(6); - let mut rng = seeded(); - let mcts = tiny_mcts(2); - let model = tiny_mlp(env.obs_size(), env.action_space()); - - let capacity = 50; - let mut replay = ReplayBuffer::new(capacity); - - for _ in 0..20 { - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - replay.extend(samples); - } - - assert!(replay.len() <= capacity, "buffer exceeded capacity"); - assert!(replay.len() > 0); - - let batch = replay.sample_batch(8, &mut rng); - assert_eq!(batch.len(), 8.min(replay.len())); - for s in &batch { - assert_eq!(s.obs.len(), env.obs_size()); - assert_eq!(s.policy.len(), env.action_space()); - let policy_sum: f32 = s.policy.iter().sum(); - assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}"); - assert!(s.value.abs() <= 1.0, "value {} out of range", s.value); - } -} - -// ── 3. 
TrictracEnv: sample shapes ───────────────────────────────────────── - -/// Verify that one TrictracEnv episode produces samples with the correct -/// tensor dimensions: obs = 217, policy = 514. -#[test] -fn trictrac_sample_shapes() { - let env = TrictracEnv; - let mut rng = seeded(); - let mcts = tiny_mcts(2); - let model = tiny_mlp(env.obs_size(), env.action_space()); - - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - - assert!(!samples.is_empty(), "Trictrac episode produced no samples"); - - for (i, s) in samples.iter().enumerate() { - assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len()); - assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len()); - let policy_sum: f32 = s.policy.iter().sum(); - assert!( - (policy_sum - 1.0).abs() < 1e-4, - "sample {i}: policy sums to {policy_sum}" - ); - assert!( - s.value == 1.0 || s.value == -1.0 || s.value == 0.0, - "sample {i}: unexpected value {}", - s.value - ); - } -} - -// ── 4. TrictracEnv: training step after real self-play ──────────────────── - -/// Collect one Trictrac episode, then verify that a gradient step runs -/// without panic and produces a finite loss. -#[test] -fn trictrac_train_step_finite_loss() { - let env = TrictracEnv; - let mut rng = seeded(); - let mcts = tiny_mcts(2); - let model = tiny_mlp(env.obs_size(), env.action_space()); - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(10_000); - - // Generate one episode. - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - assert!(!samples.is_empty()); - let n_samples = samples.len(); - replay.extend(samples); - - // Train on a batch from this episode. 
- let batch_size = 8.min(n_samples); - let batch: Vec = replay - .sample_batch(batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - - let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); - assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}"); - assert!(loss > 0.0, "loss should be positive"); -} - -// ── 5. Backend transfer: train → infer → same outputs ───────────────────── - -/// Weights transferred from the training backend to the inference backend -/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes. -#[test] -fn valid_model_matches_train_model_outputs() { - use burn::tensor::{Tensor, TensorData}; - - let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let train_model = MlpNet::::new(&cfg, &train_device()); - let infer_model: MlpNet = train_model.valid(); - - // Build the same input on both backends. - let obs_data: Vec = vec![0.1, 0.2, 0.3, 0.4]; - - let obs_train = Tensor::::from_data( - TensorData::new(obs_data.clone(), [1, 4]), - &train_device(), - ); - let obs_infer = Tensor::::from_data( - TensorData::new(obs_data, [1, 4]), - &infer_device(), - ); - - let (p_train, v_train) = train_model.forward(obs_train); - let (p_infer, v_infer) = infer_model.forward(obs_infer); - - let p_train: Vec = p_train.into_data().to_vec().unwrap(); - let p_infer: Vec = p_infer.into_data().to_vec().unwrap(); - let v_train: Vec = v_train.into_data().to_vec().unwrap(); - let v_infer: Vec = v_infer.into_data().to_vec().unwrap(); - - for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() { - assert!( - (a - b).abs() < 1e-5, - "policy[{i}] differs after valid(): train={a}, infer={b}" - ); - } - assert!( - (v_train[0] - v_infer[0]).abs() < 1e-5, - "value differs after valid(): train={}, infer={}", - v_train[0], v_infer[0] - ); -} - -// ── 6. 
Loss converges on a fixed batch ──────────────────────────────────── - -/// With repeated gradient steps on the same Countdown batch, the loss must -/// decrease monotonically (or at least end lower than it started). -#[test] -fn loss_decreases_on_fixed_batch() { - let env = CountdownEnv(6); - let mut rng = seeded(); - let mcts = tiny_mcts(3); - let model = tiny_mlp(env.obs_size(), env.action_space()); - let mut optimizer = AdamConfig::new().init(); - - // Collect a fixed batch from one episode. - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples: Vec = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng); - assert!(!samples.is_empty()); - - let batch: Vec = { - let mut replay = ReplayBuffer::new(1000); - replay.extend(samples); - replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect() - }; - - // Overfit on the same fixed batch for 20 steps. - let mut model = tiny_mlp(env.obs_size(), env.action_space()); - let mut first_loss = f32::NAN; - let mut last_loss = f32::NAN; - - for step in 0..20 { - let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2); - model = m; - assert!(loss.is_finite(), "loss is not finite at step {step}"); - if step == 0 { first_loss = loss; } - last_loss = loss; - } - - assert!( - last_loss < first_loss, - "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}" - ); -} - -// ── 7. Trictrac: multi-iteration loop ───────────────────────────────────── - -/// Two full self-play + train iterations on TrictracEnv. -/// Verifies the entire pipeline runs without panic end-to-end. 
-#[test] -fn trictrac_two_iteration_loop() { - let env = TrictracEnv; - let mut rng = seeded(); - let mcts = tiny_mcts(2); - - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - let mut model = MlpNet::::new(&cfg, &train_device()); - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(20_000); - - for iter in 0..2 { - // Self-play: 1 game per iteration. - let infer: MlpNet = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); - assert!(!samples.is_empty(), "iter {iter}: episode was empty"); - replay.extend(samples); - - // Training: 3 gradient steps. - let batch_size = 16.min(replay.len()); - for _ in 0..3 { - let batch: Vec = replay - .sample_batch(batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); - model = m; - assert!(loss.is_finite(), "iter {iter}: loss={loss}"); - } - } -} diff --git a/store/Cargo.toml b/store/Cargo.toml index 935a2a0..a9234ff 100644 --- a/store/Cargo.toml +++ b/store/Cargo.toml @@ -25,9 +25,5 @@ rand = "0.9" serde = { version = "1.0", features = ["derive"] } transpose = "0.2.2" -[[bin]] -name = "random_game" -path = "src/bin/random_game.rs" - [build-dependencies] cxx-build = "1.0" diff --git a/store/src/bin/random_game.rs b/store/src/bin/random_game.rs deleted file mode 100644 index 6da3b9c..0000000 --- a/store/src/bin/random_game.rs +++ /dev/null @@ -1,262 +0,0 @@ -//! Run one or many games of trictrac between two random players. -//! In single-game mode, prints play-by-play like OpenSpiel's `example.cc`. -//! In multi-game mode, runs silently and reports throughput at the end. -//! -//! Usage: -//! 
cargo run --bin random_game -- [--seed ] [--games ] [--max-steps ] [--verbose] - -use std::borrow::Cow; -use std::env; -use std::time::Instant; - -use trictrac_store::{ - training_common::sample_valid_action, - Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, -}; - -// ── CLI args ────────────────────────────────────────────────────────────────── - -struct Args { - seed: Option, - games: usize, - max_steps: usize, - verbose: bool, -} - -fn parse_args() -> Args { - let args: Vec = env::args().collect(); - let mut seed = None; - let mut games = 1; - let mut max_steps = 10_000; - let mut verbose = false; - - let mut i = 1; - while i < args.len() { - match args[i].as_str() { - "--seed" => { - i += 1; - seed = args.get(i).and_then(|s| s.parse().ok()); - } - "--games" => { - i += 1; - if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) { - games = v; - } - } - "--max-steps" => { - i += 1; - if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) { - max_steps = v; - } - } - "--verbose" => verbose = true, - _ => {} - } - i += 1; - } - - Args { - seed, - games, - max_steps, - verbose, - } -} - -// ── Helpers ─────────────────────────────────────────────────────────────────── - -fn player_label(id: u64) -> &'static str { - if id == 1 { "White" } else { "Black" } -} - -/// Apply a `Roll` + `RollResult` in one logical step, returning the dice. -/// This collapses the two-step dice phase into a single "chance node" action, -/// matching how the OpenSpiel layer exposes it. 
-fn apply_dice_roll(state: &mut GameState, roller: &mut DiceRoller) -> Result { - // RollDice → RollWaiting - state - .consume(&GameEvent::Roll { player_id: state.active_player_id }) - .map_err(|e| format!("Roll event failed: {e}"))?; - - // RollWaiting → Move / HoldOrGoChoice (or Stage::Ended if 13th hole) - let dice = roller.roll(); - state - .consume(&GameEvent::RollResult { player_id: state.active_player_id, dice }) - .map_err(|e| format!("RollResult event failed: {e}"))?; - - Ok(dice) -} - -/// Sample a random action and apply it to `state`, handling the Black-mirror -/// transform exactly as `cxxengine.rs::apply_action` does: -/// -/// 1. For Black, build a mirrored view of the state so that `sample_valid_action` -/// and `to_event` always reason from White's perspective. -/// 2. Mirror the resulting event back to the original coordinate frame before -/// calling `state.consume`. -/// -/// Returns the chosen action (in the view's coordinate frame) for display. -fn apply_player_action(state: &mut GameState) -> Result<(), String> { - let needs_mirror = state.active_player_id == 2; - - // Build a White-perspective view: borrowed for White, owned mirror for Black. - let view: Cow = if needs_mirror { - Cow::Owned(state.mirror()) - } else { - Cow::Borrowed(state) - }; - - let action = sample_valid_action(&view) - .ok_or_else(|| format!("no valid action in stage {:?}", state.turn_stage))?; - - let event = action - .to_event(&view) - .ok_or_else(|| format!("could not convert {action:?} to event"))?; - - // Translate the event from the view's frame back to the game's frame. - let event = if needs_mirror { event.get_mirror(false) } else { event }; - - state - .consume(&event) - .map_err(|e| format!("consume({action:?}): {e}"))?; - - Ok(()) -} - -// ── Single game ──────────────────────────────────────────────────────────────── - -/// Run one full game, optionally printing play-by-play. -/// Returns `(steps, truncated)`. 
-fn run_game(roller: &mut DiceRoller, max_steps: usize, quiet: bool, verbose: bool) -> (usize, bool) { - let mut state = GameState::new_with_players("White", "Black"); - let mut step = 0usize; - - if !quiet { - println!("{state}"); - } - - while state.stage != Stage::Ended { - step += 1; - if step > max_steps { - return (step - 1, true); - } - - match state.turn_stage { - TurnStage::RollDice => { - let player = state.active_player_id; - match apply_dice_roll(&mut state, roller) { - Ok(dice) => { - if !quiet { - println!( - "[step {step:4}] {} rolls: {} & {}", - player_label(player), - dice.values.0, - dice.values.1 - ); - } - } - Err(e) => { - eprintln!("Error during dice roll: {e}"); - eprintln!("State:\n{state}"); - return (step, true); - } - } - } - stage => { - let player = state.active_player_id; - match apply_player_action(&mut state) { - Ok(()) => { - if !quiet { - println!( - "[step {step:4}] {} ({stage:?})", - player_label(player) - ); - if verbose { - println!("{state}"); - } - } - } - Err(e) => { - eprintln!("Error: {e}"); - eprintln!("State:\n{state}"); - return (step, true); - } - } - } - } - } - - if !quiet { - println!("\n=== Game over after {step} steps ===\n"); - println!("{state}"); - - let white = state.players.get(&1); - let black = state.players.get(&2); - - match (white, black) { - (Some(w), Some(b)) => { - println!("White — holes: {:2}, points: {:2}", w.holes, w.points); - println!("Black — holes: {:2}, points: {:2}", b.holes, b.points); - println!(); - - let white_score = w.holes as i32 * 12 + w.points as i32; - let black_score = b.holes as i32 * 12 + b.points as i32; - - if white_score > black_score { - println!("Winner: White (+{})", white_score - black_score); - } else if black_score > white_score { - println!("Winner: Black (+{})", black_score - white_score); - } else { - println!("Draw"); - } - } - _ => eprintln!("Could not read final player scores."), - } - } - - (step, false) -} - -// ── Main 
────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - let mut roller = DiceRoller::new(args.seed); - - if args.games == 1 { - println!("=== Trictrac — random game ==="); - if let Some(s) = args.seed { - println!("seed: {s}"); - } - println!(); - run_game(&mut roller, args.max_steps, false, args.verbose); - } else { - println!("=== Trictrac — {} games ===", args.games); - if let Some(s) = args.seed { - println!("seed: {s}"); - } - println!(); - - let mut total_steps = 0u64; - let mut truncated = 0usize; - - let t0 = Instant::now(); - for _ in 0..args.games { - let (steps, trunc) = run_game(&mut roller, args.max_steps, !args.verbose, args.verbose); - total_steps += steps as u64; - if trunc { - truncated += 1; - } - } - let elapsed = t0.elapsed(); - - let secs = elapsed.as_secs_f64(); - println!("Games : {}", args.games); - println!("Truncated : {truncated}"); - println!("Total steps: {total_steps}"); - println!("Avg steps : {:.1}", total_steps as f64 / args.games as f64); - println!("Elapsed : {:.3} s", secs); - println!("Throughput : {:.1} games/s", args.games as f64 / secs); - println!(" {:.0} steps/s", total_steps as f64 / secs); - } -} diff --git a/store/src/board.rs b/store/src/board.rs index 0fba2d6..de0e450 100644 --- a/store/src/board.rs +++ b/store/src/board.rs @@ -598,40 +598,12 @@ impl Board { core::array::from_fn(|i| i + min) } - /// Returns cumulative white-checker counts: result[i] = # white checkers in fields 1..=i. - /// result[0] = 0. 
- pub fn white_checker_cumulative(&self) -> [u8; 25] { - let mut cum = [0u8; 25]; - let mut total = 0u8; - for (i, &count) in self.positions.iter().enumerate() { - if count > 0 { - total += count as u8; - } - cum[i + 1] = total; - } - cum - } - pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> { self.remove_checker(color, cmove.from)?; self.add_checker(color, cmove.to)?; Ok(()) } - /// Reverse a previously applied `move_checker`. No validation: assumes the move was valid. - pub fn unmove_checker(&mut self, color: &Color, cmove: CheckerMove) { - let unit = match color { - Color::White => 1, - Color::Black => -1, - }; - if cmove.from != 0 { - self.positions[cmove.from - 1] += unit; - } - if cmove.to != 0 { - self.positions[cmove.to - 1] -= unit; - } - } - pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> { if field == 0 { return Ok(()); diff --git a/store/src/cxxengine.rs b/store/src/cxxengine.rs index 55d348c..29bc7fe 100644 --- a/store/src/cxxengine.rs +++ b/store/src/cxxengine.rs @@ -83,8 +83,8 @@ pub mod ffi { /// Both players' scores. fn get_players_scores(self: &TricTracEngine) -> PlayerScores; - /// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1. - fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; + /// 36-element state vector (i8). Mirrored for player_idx == 1. + fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec; /// Human-readable state description for `player_idx`. 
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; @@ -153,7 +153,8 @@ impl TricTracEngine { .map(|v| v.into_iter().map(|i| i as u64).collect()) } else { let mirror = self.game_state.mirror(); - get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect()) + get_valid_action_indices(&mirror) + .map(|v| v.into_iter().map(|i| i as u64).collect()) } })) } @@ -179,11 +180,11 @@ impl TricTracEngine { .unwrap_or(-1) } - fn get_tensor(&self, player_idx: u64) -> Vec { + fn get_tensor(&self, player_idx: u64) -> Vec { if player_idx == 0 { - self.game_state.to_tensor() + self.game_state.to_vec() } else { - self.game_state.mirror().to_tensor() + self.game_state.mirror().to_vec() } } @@ -242,9 +243,8 @@ impl TricTracEngine { self.game_state ), None => anyhow::bail!( - "apply_action: could not build event from action index {} in state {}", - action_idx, - self.game_state + "apply_action: could not build event from action index {}", + action_idx ), } })) diff --git a/store/src/game.rs b/store/src/game.rs index e4e938c..d32734d 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -156,6 +156,13 @@ impl GameState { if let Some(p1) = self.players.get(&1) { mirrored_players.insert(2, p1.mirror()); } + let mirrored_history = self + .history + .clone() + .iter() + .map(|evt| evt.get_mirror(false)) + .collect(); + let (move1, move2) = self.dice_moves; GameState { stage: self.stage, @@ -164,7 +171,7 @@ impl GameState { active_player_id: mirrored_active_player, // active_player_id: self.active_player_id, players: mirrored_players, - history: Vec::new(), + history: mirrored_history, dice: self.dice, dice_points: self.dice_points, dice_moves: (move1.mirror(), move2.mirror()), @@ -200,110 +207,6 @@ impl GameState { self.to_vec().iter().map(|&x| x as f32).collect() } - /// Get state as a tensor for neural network training (Option B, TD-Gammon style). - /// Returns 217 f32 values, all normalized to [0, 1]. 
- /// - /// Must be called from the active player's perspective: callers should mirror - /// the GameState for Black before calling so that "own" always means White. - /// - /// Layout: - /// [0..95] own (White) checkers: 4 values per field × 24 fields - /// [96..191] opp (Black) checkers: 4 values per field × 24 fields - /// [192..193] dice values / 6 - /// [194] active player color (0=White, 1=Black) - /// [195] turn_stage / 5 - /// [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille - /// [200..203] Black player: same - /// [204..207] own quarter filled (quarters 1-4) - /// [208..211] opp quarter filled (quarters 1-4) - /// [212] own checkers all in exit zone (fields 19-24) - /// [213] opp checkers all in exit zone (fields 1-6) - /// [214] own coin de repos taken (field 12 has ≥2 own checkers) - /// [215] opp coin de repos taken (field 13 has ≥2 opp checkers) - /// [216] own dice_roll_count / 3, clamped to 1 - pub fn to_tensor(&self) -> Vec { - let mut t = Vec::with_capacity(217); - let pos: Vec = self.board.to_vec(); // 24 elements, positive=White, negative=Black - - // [0..95] own (White) checkers, TD-Gammon encoding. - // Each field contributes 4 values: - // (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1] - // The overflow term is divided by 12 because the maximum excess is - // 15 (all checkers) − 3 = 12. - for &c in &pos { - let own = c.max(0) as u8; - t.push((own == 1) as u8 as f32); - t.push((own == 2) as u8 as f32); - t.push((own == 3) as u8 as f32); - t.push(own.saturating_sub(3) as f32 / 12.0); - } - - // [96..191] opp (Black) checkers, same encoding. 
- for &c in &pos { - let opp = (-c).max(0) as u8; - t.push((opp == 1) as u8 as f32); - t.push((opp == 2) as u8 as f32); - t.push((opp == 3) as u8 as f32); - t.push(opp.saturating_sub(3) as f32 / 12.0); - } - - // [192..193] dice - t.push(self.dice.values.0 as f32 / 6.0); - t.push(self.dice.values.1 as f32 / 6.0); - - // [194] active player color - t.push( - self.who_plays() - .map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 }) - .unwrap_or(0.0), - ); - - // [195] turn stage - t.push(u8::from(self.turn_stage) as f32 / 5.0); - - // [196..199] White player stats - let wp = self.get_white_player(); - t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0)); - t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0)); - t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32)); - t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32)); - - // [200..203] Black player stats - let bp = self.get_black_player(); - t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0)); - t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0)); - t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32)); - t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32)); - - // [204..207] own (White) quarter fill status - for &start in &[1usize, 7, 13, 19] { - t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32); - } - - // [208..211] opp (Black) quarter fill status - for &start in &[1usize, 7, 13, 19] { - t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32); - } - - // [212] can_exit_own: no own checker in fields 1-18 - t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32); - - // [213] can_exit_opp: no opp checker in fields 7-24 - t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32); - - // [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers) - t.push((pos[11] >= 2) as u8 as f32); - - // [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers) - t.push((pos[12] <= -2) as u8 as f32); - - // [216] own dice_roll_count / 3, clamped to 1 - 
t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0)); - - debug_assert_eq!(t.len(), 217, "to_tensor length mismatch"); - t - } - /// Get state as a vector (to be used for bot training input) : /// length = 36 /// i8 for board positions with negative values for blacks @@ -1011,16 +914,6 @@ impl GameState { self.mark_points(player_id, points) } - /// Total accumulated score for a player: `holes × 12 + points`. - /// - /// Returns `0` if `player_id` is not found (e.g. before `init_player`). - pub fn total_score(&self, player_id: PlayerId) -> i32 { - self.players - .get(&player_id) - .map(|p| p.holes as i32 * 12 + p.points as i32) - .unwrap_or(0) - } - fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { // Update player points and holes let mut new_hole = false; diff --git a/store/src/game_rules_moves.rs b/store/src/game_rules_moves.rs index 396bcaf..41221f2 100644 --- a/store/src/game_rules_moves.rs +++ b/store/src/game_rules_moves.rs @@ -220,7 +220,7 @@ impl MoveRules { // Si possible, les deux dés doivent être joués if moves.0.get_from() == 0 || moves.1.get_from() == 0 { let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]); - possible_moves_sequences.retain(|moves| self.check_exit_rules(moves, None).is_ok()); + possible_moves_sequences.retain(|moves| self.check_exit_rules(moves).is_ok()); // possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok()); if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() { if *moves == (EMPTY_MOVE, EMPTY_MOVE) { @@ -238,7 +238,7 @@ impl MoveRules { // check exit rules // if !ignored_rules.contains(&TricTracRule::Exit) { - self.check_exit_rules(moves, None)?; + self.check_exit_rules(moves)?; // } // --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ---- @@ -321,11 +321,7 @@ impl MoveRules { .is_empty() } - fn check_exit_rules( - &self, - moves: &(CheckerMove, CheckerMove), - exit_seqs: 
Option<&[(CheckerMove, CheckerMove)]>, - ) -> Result<(), MoveError> { + fn check_exit_rules(&self, moves: &(CheckerMove, CheckerMove)) -> Result<(), MoveError> { if !moves.0.is_exit() && !moves.1.is_exit() { return Ok(()); } @@ -335,22 +331,16 @@ impl MoveRules { } // toutes les sorties directes sont autorisées, ainsi que les nombres défaillants - let owned; - let seqs = match exit_seqs { - Some(s) => s, - None => { - owned = self - .get_possible_moves_sequences(false, vec![TricTracRule::Exit]); - &owned - } - }; - if seqs.contains(moves) { + let ignored_rules = vec![TricTracRule::Exit]; + let possible_moves_sequences_without_excedent = + self.get_possible_moves_sequences(false, ignored_rules); + if possible_moves_sequences_without_excedent.contains(moves) { return Ok(()); } // À ce stade au moins un des déplacements concerne un nombre en excédant // - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on // refuse cette séquence - if !seqs.is_empty() { + if !possible_moves_sequences_without_excedent.is_empty() { return Err(MoveError::ExitByEffectPossible); } @@ -371,24 +361,17 @@ impl MoveRules { let _ = board_to_check.move_checker(&Color::White, moves.0); let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check); - // dice normal order - let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, true); - let is_not_farthest1 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1) - || (is_move2_exedant && moves.1.get_from() != farthest_on_move2); - - // dice reversed order - let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, false); - let is_not_farthest2 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1) - || (is_move2_exedant && moves.1.get_from() != farthest_on_move2); - - if is_not_farthest1 && is_not_farthest2 { + let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves); + if (is_move1_exedant && moves.0.get_from() != farthest_on_move1) + || 
(is_move2_exedant && moves.1.get_from() != farthest_on_move2) + { return Err(MoveError::ExitNotFarthest); } Ok(()) } - fn move_excedants(&self, moves: &(CheckerMove, CheckerMove), dice_order: bool) -> (bool, bool) { + fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) { let move1to = if moves.0.get_to() == 0 { 25 } else { @@ -403,16 +386,20 @@ impl MoveRules { }; let dist2 = move2to - moves.1.get_from(); - let (dice1, dice2) = if dice_order { - self.dice.values - } else { - (self.dice.values.1, self.dice.values.0) - }; + let dist_min = cmp::min(dist1, dist2); + let dist_max = cmp::max(dist1, dist2); - ( - dist1 != 0 && dist1 < dice1 as usize, - dist2 != 0 && dist2 < dice2 as usize, - ) + let dice_min = cmp::min(self.dice.values.0, self.dice.values.1) as usize; + let dice_max = cmp::max(self.dice.values.0, self.dice.values.1) as usize; + + let min_excedant = dist_min != 0 && dist_min < dice_min; + let max_excedant = dist_max != 0 && dist_max < dice_max; + + if dist_min == dist1 { + (min_excedant, max_excedant) + } else { + (max_excedant, min_excedant) + } } fn get_board_exit_farthest(board: &Board) -> Field { @@ -451,18 +438,12 @@ impl MoveRules { } else { (dice2, dice1) }; - let filling_seqs = if !ignored_rules.contains(&TricTracRule::MustFillQuarter) { - Some(self.get_quarter_filling_moves_sequences()) - } else { - None - }; let mut moves_seqs = self.get_possible_moves_sequences_by_dices( dice_max, dice_min, with_excedents, false, - &ignored_rules, - filling_seqs.as_deref(), + ignored_rules.clone(), ); // if we got valid sequences with the highest die, we don't accept sequences using only the // lowest die @@ -472,8 +453,7 @@ impl MoveRules { dice_max, with_excedents, ignore_empty, - &ignored_rules, - filling_seqs.as_deref(), + ignored_rules, ); moves_seqs.append(&mut moves_seqs_order2); let empty_removed = moves_seqs @@ -544,16 +524,14 @@ impl MoveRules { let mut moves_seqs = Vec::new(); let color = &Color::White; let ignored_rules = 
vec![TricTracRule::Exit, TricTracRule::MustFillQuarter]; - let mut board = self.board.clone(); for moves in self.get_possible_moves_sequences(true, ignored_rules) { + let mut board = self.board.clone(); board.move_checker(color, moves.0).unwrap(); board.move_checker(color, moves.1).unwrap(); // println!("get_quarter_filling_moves_sequences board : {:?}", board); if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) { moves_seqs.push(moves); } - board.unmove_checker(color, moves.1); - board.unmove_checker(color, moves.0); } moves_seqs } @@ -564,27 +542,18 @@ impl MoveRules { dice2: u8, with_excedents: bool, ignore_empty: bool, - ignored_rules: &[TricTracRule], - filling_seqs: Option<&[(CheckerMove, CheckerMove)]>, + ignored_rules: Vec, ) -> Vec<(CheckerMove, CheckerMove)> { let mut moves_seqs = Vec::new(); let color = &Color::White; let forbid_exits = self.has_checkers_outside_last_quarter(); - // Precompute non-excedant sequences once so check_exit_rules need not repeat - // the full move generation for every exit-move candidate. - // Only needed when Exit is not already ignored and exits are actually reachable. 
- let exit_seqs = if !ignored_rules.contains(&TricTracRule::Exit) && !forbid_exits { - Some(self.get_possible_moves_sequences(false, vec![TricTracRule::Exit])) - } else { - None - }; - let mut board = self.board.clone(); // println!("==== First"); for first_move in self.board .get_possible_moves(*color, dice1, with_excedents, false, forbid_exits) { - if board.move_checker(color, first_move).is_err() { + let mut board2 = self.board.clone(); + if board2.move_checker(color, first_move).is_err() { println!("err move"); continue; } @@ -594,7 +563,7 @@ impl MoveRules { let mut has_second_dice_move = false; // println!(" ==== Second"); for second_move in - board.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) + board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) { if self .check_corner_rules(&(first_move, second_move)) @@ -618,10 +587,24 @@ impl MoveRules { && self.can_take_corner_by_effect()) && (ignored_rules.contains(&TricTracRule::Exit) || self - .check_exit_rules(&(first_move, second_move), exit_seqs.as_deref()) + .check_exit_rules(&(first_move, second_move)) + // .inspect_err(|e| { + // println!( + // " 2nd (exit rule): {:?} - {:?}, {:?}", + // e, first_move, second_move + // ) + // }) + .is_ok()) + && (ignored_rules.contains(&TricTracRule::MustFillQuarter) + || self + .check_must_fill_quarter_rule(&(first_move, second_move)) + // .inspect_err(|e| { + // println!( + // " 2nd: {:?} - {:?}, {:?} for {:?}", + // e, first_move, second_move, self.board + // ) + // }) .is_ok()) - && filling_seqs - .map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, second_move))) { if second_move.get_to() == 0 && first_move.get_to() == 0 @@ -644,14 +627,16 @@ impl MoveRules { && !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE)) && self.can_take_corner_by_effect()) && (ignored_rules.contains(&TricTracRule::Exit) - || self.check_exit_rules(&(first_move, EMPTY_MOVE), exit_seqs.as_deref()).is_ok()) - && filling_seqs - 
.map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, EMPTY_MOVE))) + || self.check_exit_rules(&(first_move, EMPTY_MOVE)).is_ok()) + && (ignored_rules.contains(&TricTracRule::MustFillQuarter) + || self + .check_must_fill_quarter_rule(&(first_move, EMPTY_MOVE)) + .is_ok()) { // empty move moves_seqs.push((first_move, EMPTY_MOVE)); } - board.unmove_checker(color, first_move); + //if board2.get_color_fields(*color).is_empty() { } moves_seqs } @@ -1510,7 +1495,6 @@ mod tests { CheckerMove::new(23, 0).unwrap(), CheckerMove::new(24, 0).unwrap(), ); - let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( vec![moves], state.get_possible_moves_sequences_by_dices( @@ -1518,8 +1502,7 @@ mod tests { state.dice.values.1, true, false, - &[], - filling_seqs.as_deref(), + vec![] ) ); @@ -1534,7 +1517,6 @@ mod tests { CheckerMove::new(19, 23).unwrap(), CheckerMove::new(22, 0).unwrap(), )]; - let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( moves, state.get_possible_moves_sequences_by_dices( @@ -1542,8 +1524,7 @@ mod tests { state.dice.values.1, true, false, - &[], - filling_seqs.as_deref(), + vec![] ) ); let moves = vec![( @@ -1557,8 +1538,7 @@ mod tests { state.dice.values.0, true, false, - &[], - filling_seqs.as_deref(), + vec![] ) ); @@ -1574,7 +1554,6 @@ mod tests { CheckerMove::new(19, 21).unwrap(), CheckerMove::new(23, 0).unwrap(), ); - let filling_seqs = Some(state.get_quarter_filling_moves_sequences()); assert_eq!( vec![moves], state.get_possible_moves_sequences_by_dices( @@ -1582,8 +1561,7 @@ mod tests { state.dice.values.1, true, false, - &[], - filling_seqs.as_deref(), + vec![] ) ); } @@ -1602,26 +1580,13 @@ mod tests { CheckerMove::new(19, 23).unwrap(), CheckerMove::new(22, 0).unwrap(), ); - assert!(state.check_exit_rules(&moves, None).is_ok()); + assert!(state.check_exit_rules(&moves).is_ok()); let moves = ( CheckerMove::new(19, 24).unwrap(), CheckerMove::new(22, 0).unwrap(), ); - 
assert!(state.check_exit_rules(&moves, None).is_ok()); - - state.dice.values = (6, 4); - state.board.set_positions( - &crate::Color::White, - [ - -4, -1, -2, -1, 0, 0, 0, -1, 0, 0, 0, 0, -5, -1, 0, 0, 0, 0, 2, 3, 2, 2, 5, 1, - ], - ); - let moves = ( - CheckerMove::new(20, 24).unwrap(), - CheckerMove::new(23, 0).unwrap(), - ); - assert!(state.check_exit_rules(&moves, None).is_ok()); + assert!(state.check_exit_rules(&moves).is_ok()); } #[test] diff --git a/store/src/pyengine.rs b/store/src/pyengine.rs index 43b5713..b193987 100644 --- a/store/src/pyengine.rs +++ b/store/src/pyengine.rs @@ -113,11 +113,11 @@ impl TricTrac { [self.get_score(1), self.get_score(2)] } - fn get_tensor(&self, player_idx: u64) -> Vec { + fn get_tensor(&self, player_idx: u64) -> Vec { if player_idx == 0 { - self.game_state.to_tensor() + self.game_state.to_vec() } else { - self.game_state.mirror().to_tensor() + self.game_state.mirror().to_vec() } } diff --git a/store/src/training_common.rs b/store/src/training_common.rs index 69765fc..57094a9 100644 --- a/store/src/training_common.rs +++ b/store/src/training_common.rs @@ -3,6 +3,7 @@ use std::cmp::{max, min}; use std::fmt::{Debug, Display, Formatter}; +use crate::board::Board; use crate::{CheckerMove, Dice, GameEvent, GameState}; use serde::{Deserialize, Serialize}; @@ -220,14 +221,10 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result anyhow::Result anyhow::Result anyhow::Result { - // Moves are always in White's coordinate system. For Black, mirror the board first. 
- let cum = if color == &crate::Color::Black { - state.board.mirror().white_checker_cumulative() + let dice = &state.dice; + let board = &state.board; + + if color == &crate::Color::Black { + // Moves are already 'white', so we don't mirror them + white_checker_moves_to_trictrac_action( + move1, + move2, + // &move1.clone().mirror(), + // &move2.clone().mirror(), + dice, + &board.clone().mirror(), + ) + // .map(|a| a.mirror()) } else { - state.board.white_checker_cumulative() - }; - white_checker_moves_to_trictrac_action(move1, move2, &state.dice, &cum) + white_checker_moves_to_trictrac_action(move1, move2, dice, board) + } } fn white_checker_moves_to_trictrac_action( move1: &CheckerMove, move2: &CheckerMove, dice: &Dice, - cum: &[u8; 25], + board: &Board, ) -> anyhow::Result { let to1 = move1.get_to(); let to2 = move2.get_to(); @@ -300,7 +302,7 @@ fn white_checker_moves_to_trictrac_action( } } else { // double sortie - if from1 < from2 || from2 == 0 { + if from1 < from2 { max(dice.values.0, dice.values.1) as usize } else { min(dice.values.0, dice.values.1) as usize @@ -319,21 +321,11 @@ fn white_checker_moves_to_trictrac_action( } let dice_order = diff_move1 == dice.values.0 as usize; - // cum[i] = # white checkers in fields 1..=i (precomputed by the caller). - // checker1 is the ordinal of the last checker at from1. - let checker1 = cum[from1] as usize; - // checker2 is the ordinal on the board after move1 (removed from from1, added to to1). - // Adjust the cumulative in O(1) without cloning the board. 
- let checker2 = { - let mut c = cum[from2]; - if from1 > 0 && from2 >= from1 { - c -= 1; // one checker was removed from from1, shifting later ordinals down - } - if from1 > 0 && to1 > 0 && from2 >= to1 { - c += 1; // one checker was added at to1, shifting later ordinals up - } - c as usize - }; + let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize; + let mut tmp_board = board.clone(); + // should not raise an error for a valid action + tmp_board.move_checker(&crate::Color::White, *move1)?; + let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; Ok(TrictracAction::Move { dice_order, checker1, @@ -464,48 +456,5 @@ mod tests { }), ttaction.ok() ); - - // Black player - state.active_player_id = 2; - state.dice = Dice { values: (6, 3) }; - state.board.set_positions( - &crate::Color::White, - [ - 2, -11, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 6, 4, - ], - ); - let ttaction = super::checker_moves_to_trictrac_action( - &CheckerMove::new(21, 0).unwrap(), - &CheckerMove::new(0, 0).unwrap(), - &crate::Color::Black, - &state, - ); - - assert_eq!( - Some(TrictracAction::Move { - dice_order: true, - checker1: 2, - checker2: 0, // blocked by white on last field - }), - ttaction.ok() - ); - - // same with dice order reversed - state.dice = Dice { values: (3, 6) }; - let ttaction = super::checker_moves_to_trictrac_action( - &CheckerMove::new(21, 0).unwrap(), - &CheckerMove::new(0, 0).unwrap(), - &crate::Color::Black, - &state, - ); - - assert_eq!( - Some(TrictracAction::Move { - dice_order: false, - checker1: 2, - checker2: 0, // blocked by white on last field - }), - ttaction.ok() - ); } }