Compare commits
No commits in common. "feature/open_spiel_parallel" and "main" have entirely different histories.
feature/op
...
main
37 changed files with 1105 additions and 7106 deletions
132
Cargo.lock
generated
132
Cargo.lock
generated
|
|
@ -92,12 +92,6 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anes"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.21"
|
||||
|
|
@ -1122,12 +1116,6 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||
|
||||
[[package]]
|
||||
name = "cast"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cast_trait"
|
||||
version = "0.1.2"
|
||||
|
|
@ -1212,33 +1200,6 @@ dependencies = [
|
|||
"rand 0.7.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"ciborium-ll",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-io"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-ll"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
|
|
@ -1492,42 +1453,6 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"is-terminal",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "critical-section"
|
||||
version = "1.2.0"
|
||||
|
|
@ -4536,12 +4461,6 @@ version = "1.70.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
version = "11.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
|
|
@ -4678,34 +4597,6 @@ version = "0.3.32"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"plotters-backend",
|
||||
"plotters-svg",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "plotters-backend"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
|
||||
|
||||
[[package]]
|
||||
name = "plotters-svg"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
|
||||
dependencies = [
|
||||
"plotters-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.18.0"
|
||||
|
|
@ -6000,19 +5891,6 @@ dependencies = [
|
|||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spiel_bot"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"burn",
|
||||
"criterion",
|
||||
"rand 0.9.2",
|
||||
"rand_distr",
|
||||
"rayon",
|
||||
"trictrac-store",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.10.0"
|
||||
|
|
@ -6421,16 +6299,6 @@ dependencies = [
|
|||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
|
||||
members = ["client_cli", "bot", "store", "spiel_bot"]
|
||||
members = ["client_cli", "bot", "store"]
|
||||
|
|
|
|||
992
doc/plan_cxxbindings.md
Normal file
992
doc/plan_cxxbindings.md
Normal file
|
|
@ -0,0 +1,992 @@
|
|||
# Plan: C++ OpenSpiel Game via cxx.rs
|
||||
|
||||
> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings.
|
||||
>
|
||||
> Base on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Overview
|
||||
|
||||
The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding:
|
||||
|
||||
- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`.
|
||||
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`.
|
||||
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box<ffi::TricTracEngine>` and delegates all logic to Rust.
|
||||
- Build wired together via **corrosion** (CMake-native Rust integration) and `cxx-build`.
|
||||
|
||||
The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games.
|
||||
|
||||
---
|
||||
|
||||
## 2. Files to Create / Modify
|
||||
|
||||
```
|
||||
trictrac/
|
||||
store/
|
||||
Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type
|
||||
build.rs ← CREATE: cxx-build bridge registration
|
||||
src/
|
||||
lib.rs ← MODIFY: add cxxengine module
|
||||
cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl
|
||||
|
||||
forks/open_spiel/
|
||||
CMakeLists.txt ← MODIFY: add Corrosion FetchContent
|
||||
open_spiel/
|
||||
games/
|
||||
CMakeLists.txt ← MODIFY: add trictrac/ sources + test
|
||||
trictrac/ ← CREATE directory
|
||||
trictrac.h ← CREATE
|
||||
trictrac.cc ← CREATE
|
||||
trictrac_test.cc ← CREATE
|
||||
|
||||
justfile ← MODIFY: add buildtrictrac target
|
||||
trictrac/
|
||||
justfile ← MODIFY: add cxxlib target
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Step 1 — Rust: `store/Cargo.toml`
|
||||
|
||||
Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library.
|
||||
|
||||
```toml
|
||||
[package]
|
||||
name = "trictrac-store"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "trictrac_store"
|
||||
# cdylib → Python .so (used by maturin / pyengine)
|
||||
# rlib → used by other Rust crates in the workspace
|
||||
# staticlib → used by C++ consumers (cxxengine)
|
||||
crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.21.7"
|
||||
cxx = "1.0"
|
||||
log = "0.4.20"
|
||||
merge = "0.1.0"
|
||||
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
|
||||
rand = "0.9"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
transpose = "0.2.2"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
```
|
||||
|
||||
> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`).
|
||||
|
||||
---
|
||||
|
||||
## 4. Step 2 — Rust: `store/build.rs`
|
||||
|
||||
The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header.
|
||||
|
||||
```rust
|
||||
fn main() {
|
||||
cxx_build::bridge("src/cxxengine.rs")
|
||||
.std("c++17")
|
||||
.compile("trictrac-cxx");
|
||||
|
||||
// Re-run if the bridge source changes
|
||||
println!("cargo:rerun-if-changed=src/cxxengine.rs");
|
||||
}
|
||||
```
|
||||
|
||||
`cxx-build` will:
|
||||
|
||||
- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block.
|
||||
- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header.
|
||||
- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source.
|
||||
- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`).
|
||||
|
||||
---
|
||||
|
||||
## 5. Step 3 — Rust: `store/src/cxxengine.rs`
|
||||
|
||||
This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3.
|
||||
|
||||
### Design decisions vs. `pyengine.rs`
|
||||
|
||||
| pyengine | cxxengine | Reason |
|
||||
| ------------------------- | ---------------------------- | -------------------------------------------- |
|
||||
| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception |
|
||||
| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples |
|
||||
| `Vec<usize>` for actions | `Vec<u64>` | cxx does not support `usize` |
|
||||
| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays |
|
||||
| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy |
|
||||
|
||||
### File content
|
||||
|
||||
```rust
|
||||
//! # C++ bindings for the TricTrac game engine via cxx.rs
|
||||
//!
|
||||
//! Exposes an opaque `TricTracEngine` type and associated functions
|
||||
//! to C++. The C++ side (trictrac.cc) uses `rust::Box<ffi::TricTracEngine>`.
|
||||
//!
|
||||
//! The Rust engine always works from the perspective of White (player 1).
|
||||
//! For Black (player 2), the board is mirrored before computing actions
|
||||
//! and events are mirrored back before applying — exactly as in pyengine.rs.
|
||||
|
||||
use crate::dice::Dice;
|
||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
||||
|
||||
// ── cxx bridge declaration ────────────────────────────────────────────────────
|
||||
|
||||
#[cxx::bridge(namespace = "trictrac_engine")]
|
||||
pub mod ffi {
|
||||
// ── Shared types (visible to both Rust and C++) ───────────────────────────
|
||||
|
||||
/// Two dice values passed from C++ to Rust for a dice-roll event.
|
||||
struct DicePair {
|
||||
die1: u8,
|
||||
die2: u8,
|
||||
}
|
||||
|
||||
/// Both players' scores: holes * 12 + points.
|
||||
struct PlayerScores {
|
||||
score_p1: i32,
|
||||
score_p2: i32,
|
||||
}
|
||||
|
||||
// ── Opaque Rust type exposed to C++ ───────────────────────────────────────
|
||||
|
||||
extern "Rust" {
|
||||
/// Opaque handle to a TricTrac game state.
|
||||
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
|
||||
type TricTracEngine;
|
||||
|
||||
/// Create a new engine, initialise two players, begin with player 1.
|
||||
fn new_trictrac_engine() -> Box<TricTracEngine>;
|
||||
|
||||
/// Return a deep copy of the engine (needed for State::Clone()).
|
||||
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
|
||||
|
||||
// ── Queries ───────────────────────────────────────────────────────────
|
||||
|
||||
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
|
||||
fn needs_roll(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// True when Stage::Ended.
|
||||
fn is_game_ended(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// Active player index: 0 (player 1 / White) or 1 (player 2 / Black).
|
||||
fn current_player_idx(self: &TricTracEngine) -> u64;
|
||||
|
||||
/// Legal action indices for `player_idx`. Returns empty vec if it is
|
||||
/// not that player's turn. Indices are in [0, 513].
|
||||
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec<u64>;
|
||||
|
||||
/// Human-readable action description, e.g. "0:Move { dice_order: true … }".
|
||||
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
|
||||
|
||||
/// Both players' scores: holes * 12 + points.
|
||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||
|
||||
/// 36-element state observation vector (i8). Mirrored for player 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||
|
||||
/// Human-readable state description for `player_idx`.
|
||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||
|
||||
/// Full debug representation of the current state.
|
||||
fn to_debug_string(self: &TricTracEngine) -> String;
|
||||
|
||||
// ── Mutations ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Apply a dice roll result. Returns Err if not in RollWaiting stage.
|
||||
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
|
||||
|
||||
/// Apply a player action (move, go, roll). Returns Err if invalid.
|
||||
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Opaque type implementation ────────────────────────────────────────────────
|
||||
|
||||
pub struct TricTracEngine {
|
||||
game_state: GameState,
|
||||
}
|
||||
|
||||
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
Box::new(TricTracEngine { game_state })
|
||||
}
|
||||
|
||||
impl TricTracEngine {
|
||||
fn clone_engine(&self) -> Box<TricTracEngine> {
|
||||
Box::new(TricTracEngine {
|
||||
game_state: self.game_state.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn needs_roll(&self) -> bool {
|
||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
||||
}
|
||||
|
||||
fn is_game_ended(&self) -> bool {
|
||||
self.game_state.stage == Stage::Ended
|
||||
}
|
||||
|
||||
/// Returns 0 for player 1 (White) and 1 for player 2 (Black).
|
||||
fn current_player_idx(&self) -> u64 {
|
||||
self.game_state.active_player_id - 1
|
||||
}
|
||||
|
||||
fn get_legal_actions(&self, player_idx: u64) -> Vec<u64> {
|
||||
if player_idx == self.current_player_idx() {
|
||||
if player_idx == 0 {
|
||||
get_valid_action_indices(&self.game_state)
|
||||
.into_iter()
|
||||
.map(|i| i as u64)
|
||||
.collect()
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
.into_iter()
|
||||
.map(|i| i as u64)
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
|
||||
TrictracAction::from_action_index(action_idx as usize)
|
||||
.map(|a| format!("{}:{}", player_idx, a))
|
||||
.unwrap_or_else(|| "unknown action".into())
|
||||
}
|
||||
|
||||
fn get_players_scores(&self) -> ffi::PlayerScores {
|
||||
ffi::PlayerScores {
|
||||
score_p1: self.score_for(1),
|
||||
score_p2: self.score_for(2),
|
||||
}
|
||||
}
|
||||
|
||||
fn score_for(&self, player_id: u64) -> i32 {
|
||||
if let Some(player) = self.game_state.players.get(&player_id) {
|
||||
player.holes as i32 * 12 + player.points as i32
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||
if player_idx == 0 {
|
||||
format!("{}", self.game_state)
|
||||
} else {
|
||||
format!("{}", self.game_state.mirror())
|
||||
}
|
||||
}
|
||||
|
||||
fn to_debug_string(&self) -> String {
|
||||
format!("{}", self.game_state)
|
||||
}
|
||||
|
||||
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> {
|
||||
let player_id = self.game_state.active_player_id;
|
||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||
return Err("Not in RollWaiting stage".into());
|
||||
}
|
||||
let dice = Dice {
|
||||
values: (dice.die1, dice.die2),
|
||||
};
|
||||
self.game_state
|
||||
.consume(&GameEvent::RollResult { player_id, dice });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_action(&mut self, action_idx: u64) -> Result<(), String> {
|
||||
let action_idx = action_idx as usize;
|
||||
let needs_mirror = self.game_state.active_player_id == 2;
|
||||
|
||||
let event = TrictracAction::from_action_index(action_idx)
|
||||
.and_then(|a| {
|
||||
let game_state = if needs_mirror {
|
||||
&self.game_state.mirror()
|
||||
} else {
|
||||
&self.game_state
|
||||
};
|
||||
a.to_event(game_state)
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
});
|
||||
|
||||
match event {
|
||||
Some(evt) if self.game_state.validate(&evt) => {
|
||||
self.game_state.consume(&evt);
|
||||
Ok(())
|
||||
}
|
||||
Some(_) => Err("Action is invalid".into()),
|
||||
None => Err("Could not build event from action index".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **Note on `Result<(), String>`**: cxx.rs requires the error type to implement `std::error::Error`. `String` does not implement it directly. Two options:
|
||||
>
|
||||
> - Use `anyhow::Error` (add `anyhow` dependency).
|
||||
> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`.
|
||||
>
|
||||
> The recommended approach is `anyhow`:
|
||||
>
|
||||
> ```toml
|
||||
> [dependencies]
|
||||
> anyhow = "1.0"
|
||||
> ```
|
||||
>
|
||||
> Then `fn apply_action(...) -> Result<(), anyhow::Error>` — cxx.rs will convert this to a C++ exception of type `rust::Error` carrying the message.
|
||||
|
||||
---
|
||||
|
||||
## 6. Step 4 — Rust: `store/src/lib.rs`
|
||||
|
||||
Add the new module:
|
||||
|
||||
```rust
|
||||
// existing modules …
|
||||
mod pyengine;
|
||||
|
||||
// NEW: C++ bindings via cxx.rs
|
||||
pub mod cxxengine;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Step 5 — C++: `trictrac/trictrac.h`
|
||||
|
||||
Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box<ffi::TricTracEngine>` and delegates everything to it.
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac.h
|
||||
#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
#define OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/spiel_utils.h"
|
||||
|
||||
// Generated by cxx-build from store/src/cxxengine.rs.
|
||||
// The include path is set by CMake (see CMakeLists.txt).
|
||||
#include "trictrac_store/src/cxxengine.rs.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
|
||||
inline constexpr int kNumPlayers = 2;
|
||||
inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes
|
||||
inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust
|
||||
inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust
|
||||
inline constexpr int kDefaultMaxTurns = 1000;
|
||||
|
||||
class TrictracGame;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrictracState
|
||||
// ---------------------------------------------------------------------------
|
||||
class TrictracState : public State {
|
||||
public:
|
||||
explicit TrictracState(std::shared_ptr<const Game> game);
|
||||
TrictracState(const TrictracState& other);
|
||||
|
||||
Player CurrentPlayer() const override;
|
||||
std::vector<Action> LegalActions() const override;
|
||||
std::string ActionToString(Player player, Action move_id) const override;
|
||||
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;
|
||||
std::string ToString() const override;
|
||||
bool IsTerminal() const override;
|
||||
std::vector<double> Returns() const override;
|
||||
std::string ObservationString(Player player) const override;
|
||||
void ObservationTensor(Player player, absl::Span<float> values) const override;
|
||||
std::unique_ptr<State> Clone() const override;
|
||||
|
||||
protected:
|
||||
void DoApplyAction(Action move_id) override;
|
||||
|
||||
private:
|
||||
// Decode a chance action index [0,35] to (die1, die2).
|
||||
// Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
static trictrac_engine::DicePair DecodeChanceAction(Action action);
|
||||
|
||||
// The Rust engine handle. Deep-copied via clone_engine() when cloning state.
|
||||
rust::Box<trictrac_engine::TricTracEngine> engine_;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrictracGame
|
||||
// ---------------------------------------------------------------------------
|
||||
class TrictracGame : public Game {
|
||||
public:
|
||||
explicit TrictracGame(const GameParameters& params);
|
||||
|
||||
int NumDistinctActions() const override { return kNumDistinctActions; }
|
||||
std::unique_ptr<State> NewInitialState() const override;
|
||||
int MaxChanceOutcomes() const override { return kNumChanceOutcomes; }
|
||||
int NumPlayers() const override { return kNumPlayers; }
|
||||
double MinUtility() const override { return 0.0; }
|
||||
double MaxUtility() const override { return 200.0; }
|
||||
int MaxGameLength() const override { return 3 * max_turns_; }
|
||||
int MaxChanceNodesInHistory() const override { return MaxGameLength(); }
|
||||
std::vector<int> ObservationTensorShape() const override {
|
||||
return {kStateEncodingSize};
|
||||
}
|
||||
|
||||
private:
|
||||
int max_turns_;
|
||||
};
|
||||
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
|
||||
#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Step 6 — C++: `trictrac/trictrac.cc`
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac.cc
|
||||
#include "open_spiel/games/trictrac/trictrac.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "open_spiel/abseil-cpp/absl/types/span.h"
|
||||
#include "open_spiel/game_parameters.h"
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/spiel_globals.h"
|
||||
#include "open_spiel/spiel_utils.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
namespace {
|
||||
|
||||
// ── Game registration ────────────────────────────────────────────────────────
|
||||
|
||||
const GameType kGameType{
|
||||
/*short_name=*/"trictrac",
|
||||
/*long_name=*/"Trictrac",
|
||||
GameType::Dynamics::kSequential,
|
||||
GameType::ChanceMode::kExplicitStochastic,
|
||||
GameType::Information::kPerfectInformation,
|
||||
GameType::Utility::kGeneralSum,
|
||||
GameType::RewardModel::kRewards,
|
||||
/*min_num_players=*/kNumPlayers,
|
||||
/*max_num_players=*/kNumPlayers,
|
||||
/*provides_information_state_string=*/false,
|
||||
/*provides_information_state_tensor=*/false,
|
||||
/*provides_observation_string=*/true,
|
||||
/*provides_observation_tensor=*/true,
|
||||
/*parameter_specification=*/{
|
||||
{"max_turns", GameParameter(kDefaultMaxTurns)},
|
||||
}};
|
||||
|
||||
static std::shared_ptr<const Game> Factory(const GameParameters& params) {
|
||||
return std::make_shared<const TrictracGame>(params);
|
||||
}
|
||||
|
||||
REGISTER_SPIEL_GAME(kGameType, Factory);
|
||||
|
||||
} // namespace
|
||||
|
||||
// ── TrictracGame ─────────────────────────────────────────────────────────────
|
||||
|
||||
TrictracGame::TrictracGame(const GameParameters& params)
|
||||
: Game(kGameType, params),
|
||||
max_turns_(ParameterValue<int>("max_turns", kDefaultMaxTurns)) {}
|
||||
|
||||
std::unique_ptr<State> TrictracGame::NewInitialState() const {
|
||||
return std::make_unique<TrictracState>(shared_from_this());
|
||||
}
|
||||
|
||||
// ── TrictracState ─────────────────────────────────────────────────────────────
|
||||
|
||||
TrictracState::TrictracState(std::shared_ptr<const Game> game)
|
||||
: State(game),
|
||||
engine_(trictrac_engine::new_trictrac_engine()) {}
|
||||
|
||||
// Copy constructor: deep-copy the Rust engine via clone_engine().
|
||||
TrictracState::TrictracState(const TrictracState& other)
|
||||
: State(other),
|
||||
engine_(other.engine_->clone_engine()) {}
|
||||
|
||||
std::unique_ptr<State> TrictracState::Clone() const {
|
||||
return std::make_unique<TrictracState>(*this);
|
||||
}
|
||||
|
||||
// ── Current player ────────────────────────────────────────────────────────────
|
||||
|
||||
Player TrictracState::CurrentPlayer() const {
|
||||
if (engine_->is_game_ended()) return kTerminalPlayerId;
|
||||
if (engine_->needs_roll()) return kChancePlayerId;
|
||||
return static_cast<Player>(engine_->current_player_idx());
|
||||
}
|
||||
|
||||
// ── Legal actions ─────────────────────────────────────────────────────────────
|
||||
|
||||
std::vector<Action> TrictracState::LegalActions() const {
|
||||
if (IsChanceNode()) {
|
||||
// All 36 dice outcomes are equally likely; return indices 0–35.
|
||||
std::vector<Action> actions(kNumChanceOutcomes);
|
||||
for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i;
|
||||
return actions;
|
||||
}
|
||||
Player player = CurrentPlayer();
|
||||
rust::Vec<uint64_t> rust_actions =
|
||||
engine_->get_legal_actions(static_cast<uint64_t>(player));
|
||||
std::vector<Action> actions;
|
||||
actions.reserve(rust_actions.size());
|
||||
for (uint64_t a : rust_actions) actions.push_back(static_cast<Action>(a));
|
||||
return actions;
|
||||
}
|
||||
|
||||
// ── Chance outcomes ───────────────────────────────────────────────────────────
|
||||
|
||||
std::vector<std::pair<Action, double>> TrictracState::ChanceOutcomes() const {
|
||||
SPIEL_CHECK_TRUE(IsChanceNode());
|
||||
const double p = 1.0 / kNumChanceOutcomes;
|
||||
std::vector<std::pair<Action, double>> outcomes;
|
||||
outcomes.reserve(kNumChanceOutcomes);
|
||||
for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p);
|
||||
return outcomes;
|
||||
}
|
||||
|
||||
// ── Apply action ──────────────────────────────────────────────────────────────
|
||||
|
||||
/*static*/
|
||||
trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) {
|
||||
// Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
return trictrac_engine::DicePair{
|
||||
/*die1=*/static_cast<uint8_t>(action / 6 + 1),
|
||||
/*die2=*/static_cast<uint8_t>(action % 6 + 1),
|
||||
};
|
||||
}
|
||||
|
||||
void TrictracState::DoApplyAction(Action action) {
|
||||
if (IsChanceNode()) {
|
||||
engine_->apply_dice_roll(DecodeChanceAction(action));
|
||||
} else {
|
||||
engine_->apply_action(static_cast<uint64_t>(action));
|
||||
}
|
||||
}
|
||||
|
||||
// ── Terminal & returns ────────────────────────────────────────────────────────
|
||||
|
||||
bool TrictracState::IsTerminal() const {
|
||||
return engine_->is_game_ended();
|
||||
}
|
||||
|
||||
std::vector<double> TrictracState::Returns() const {
|
||||
trictrac_engine::PlayerScores scores = engine_->get_players_scores();
|
||||
return {static_cast<double>(scores.score_p1),
|
||||
static_cast<double>(scores.score_p2)};
|
||||
}
|
||||
|
||||
// ── Observation ───────────────────────────────────────────────────────────────
|
||||
|
||||
std::string TrictracState::ObservationString(Player player) const {
|
||||
return std::string(engine_->get_observation_string(
|
||||
static_cast<uint64_t>(player)));
|
||||
}
|
||||
|
||||
void TrictracState::ObservationTensor(Player player,
|
||||
absl::Span<float> values) const {
|
||||
SPIEL_CHECK_EQ(values.size(), kStateEncodingSize);
|
||||
rust::Vec<int8_t> tensor =
|
||||
engine_->get_tensor(static_cast<uint64_t>(player));
|
||||
SPIEL_CHECK_EQ(tensor.size(), static_cast<size_t>(kStateEncodingSize));
|
||||
for (int i = 0; i < kStateEncodingSize; ++i) {
|
||||
values[i] = static_cast<float>(tensor[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Strings ───────────────────────────────────────────────────────────────────
|
||||
|
||||
std::string TrictracState::ToString() const {
|
||||
return std::string(engine_->to_debug_string());
|
||||
}
|
||||
|
||||
std::string TrictracState::ActionToString(Player player, Action action) const {
|
||||
if (IsChanceNode()) {
|
||||
trictrac_engine::DicePair d = DecodeChanceAction(action);
|
||||
return "(" + std::to_string(d.die1) + ", " + std::to_string(d.die2) + ")";
|
||||
}
|
||||
return std::string(engine_->action_to_string(
|
||||
static_cast<uint64_t>(player), static_cast<uint64_t>(action)));
|
||||
}
|
||||
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. Step 7 — C++: `trictrac/trictrac_test.cc`
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac_test.cc
|
||||
#include "open_spiel/games/trictrac/trictrac.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/tests/basic_tests.h"
|
||||
#include "open_spiel/utils/init.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
namespace {
|
||||
|
||||
void BasicTrictracTests() {
|
||||
testing::LoadGameTest("trictrac");
|
||||
testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
open_spiel::Init(&argc, &argv);
|
||||
open_spiel::trictrac::BasicTrictracTests();
|
||||
std::cout << "trictrac tests passed" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt`
|
||||
|
||||
The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined:
|
||||
|
||||
```cmake
|
||||
# ── Corrosion: CMake integration for Rust ────────────────────────────────────
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
Corrosion
|
||||
GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git
|
||||
GIT_TAG v0.5.1 # pin to a stable release
|
||||
)
|
||||
FetchContent_MakeAvailable(Corrosion)
|
||||
|
||||
# Import the trictrac-store Rust crate.
|
||||
# This creates a CMake target named 'trictrac-store'.
|
||||
corrosion_import_crate(
|
||||
MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml
|
||||
CRATES trictrac-store
|
||||
)
|
||||
|
||||
# Generate the cxx bridge from cxxengine.rs.
|
||||
# corrosion_add_cxxbridge:
|
||||
# - runs cxx-build as part of the Rust build
|
||||
# - creates a CMake target 'trictrac_cxx_bridge' that:
|
||||
# * compiles the generated C++ shim
|
||||
# * exposes INTERFACE include dirs for the generated .rs.h header
|
||||
corrosion_add_cxxbridge(trictrac_cxx_bridge
|
||||
CRATE trictrac-store
|
||||
FILES src/cxxengine.rs
|
||||
)
|
||||
```
|
||||
|
||||
> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing.
|
||||
|
||||
---
|
||||
|
||||
## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt`
|
||||
|
||||
Two changes: add the new source files to `GAME_SOURCES`, and add a test target.
|
||||
|
||||
### 11.1 Add to `GAME_SOURCES`
|
||||
|
||||
Find the alphabetically correct position (after `tic_tac_toe`, before `trade_comm`) and add:
|
||||
|
||||
```cmake
|
||||
set(GAME_SOURCES
|
||||
# ... existing games ...
|
||||
trictrac/trictrac.cc
|
||||
trictrac/trictrac.h
|
||||
# ... remaining games ...
|
||||
)
|
||||
```
|
||||
|
||||
### 11.2 Link cxx bridge into OpenSpiel objects
|
||||
|
||||
The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed.
|
||||
|
||||
The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add:
|
||||
|
||||
```cmake
|
||||
target_link_libraries(open_spiel
|
||||
PUBLIC
|
||||
trictrac_cxx_bridge # C++ shim generated by cxx-build
|
||||
trictrac-store # Rust static library
|
||||
)
|
||||
```
|
||||
|
||||
If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game:
|
||||
|
||||
```cmake
|
||||
add_library(trictrac_game OBJECT
|
||||
trictrac/trictrac.cc
|
||||
trictrac/trictrac.h
|
||||
)
|
||||
target_include_directories(trictrac_game
|
||||
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
target_link_libraries(trictrac_game
|
||||
PUBLIC
|
||||
trictrac_cxx_bridge
|
||||
trictrac-store
|
||||
open_spiel_core # or whatever the core target is called
|
||||
)
|
||||
```
|
||||
|
||||
Then reference `$<TARGET_OBJECTS:trictrac_game>` in relevant executables.
|
||||
|
||||
### 11.3 Add the test
|
||||
|
||||
```cmake
|
||||
add_executable(trictrac_test
|
||||
trictrac/trictrac_test.cc
|
||||
${OPEN_SPIEL_OBJECTS}
|
||||
$<TARGET_OBJECTS:tests>
|
||||
)
|
||||
target_link_libraries(trictrac_test
|
||||
PRIVATE
|
||||
trictrac_cxx_bridge
|
||||
trictrac-store
|
||||
)
|
||||
add_test(trictrac_test trictrac_test)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Step 10 — Justfile updates
|
||||
|
||||
### `trictrac/justfile` — add `cxxlib` target
|
||||
|
||||
Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists:
|
||||
|
||||
```just
|
||||
cxxlib:
|
||||
cargo build --release -p trictrac-store
|
||||
@echo "Static lib: $(ls target/release/libtrictrac_store.a)"
|
||||
@echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)"
|
||||
```
|
||||
|
||||
### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac`
|
||||
|
||||
```just
|
||||
buildtrictrac:
|
||||
# Rebuild the Rust static lib first, then CMake
|
||||
cd ../../trictrac && cargo build --release -p trictrac-store
|
||||
mkdir -p build && cd build && \
|
||||
CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \
|
||||
make -j$(nproc) trictrac_test
|
||||
|
||||
testtrictrac: buildtrictrac
|
||||
./build/trictrac_test
|
||||
|
||||
playtrictrac_cpp:
|
||||
./build/examples/example --game=trictrac
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 13. Key Design Decisions
|
||||
|
||||
### 13.1 Opaque type with `clone_engine()`
|
||||
|
||||
OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box<TricTracEngine>` which calls `.clone()` on the inner `GameState` (which derives `Clone`).
|
||||
|
||||
### 13.2 Action encoding: same 514-element space
|
||||
|
||||
The C++ game uses the same 514-action encoding as the Python version and the Rust training code. This means:
|
||||
|
||||
- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies.
|
||||
- Action 0 = Roll (used as the bridge between Move and the next chance node).
|
||||
- Actions 2–513 = Move variants (checker ordinal pair + dice order).
|
||||
- A trained C++ model and Python model share the same action space.
|
||||
|
||||
### 13.3 Chance outcome ordering
|
||||
|
||||
The dice outcome ordering is identical to the Python version:
|
||||
|
||||
```
|
||||
action → (die1, die2)
|
||||
0 → (1,1) 6 → (2,1) ... 35 → (6,6)
|
||||
```
|
||||
|
||||
(`die1 = action/6 + 1`, `die2 = action%6 + 1`)
|
||||
|
||||
This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines.
|
||||
|
||||
### 13.4 `GameType::Utility::kGeneralSum` + `kRewards`
|
||||
|
||||
Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). Intermediate hole rewards are returned by `Returns()` at every state, not just the terminal.
|
||||
|
||||
### 13.5 Mirror pattern preserved
|
||||
|
||||
`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring — it simply passes `player_idx` and the Rust engine handles the rest.
|
||||
|
||||
### 13.6 `rust::Box` vs `rust::UniquePtr`
|
||||
|
||||
`rust::Box<T>` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. `rust::UniquePtr<T>` is for C++ types held in Rust.
|
||||
|
||||
### 13.7 Separate struct from `pyengine.rs`
|
||||
|
||||
`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids:
|
||||
|
||||
- PyO3 and cxx attributes conflicting on the same type.
|
||||
- Changes to one binding breaking the other.
|
||||
- Feature-flag complexity.
|
||||
|
||||
---
|
||||
|
||||
## 14. Known Challenges
|
||||
|
||||
### 14.1 Corrosion path resolution
|
||||
|
||||
`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable).
|
||||
|
||||
### 14.2 `staticlib` + `cdylib` in one crate
|
||||
|
||||
Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions:
|
||||
|
||||
- The `cdylib` build (for maturin) does not need `staticlib`, and building both doubles the compile time.
|
||||
- Consider gating `staticlib` behind a Cargo feature: `crate-type` is not directly feature-gatable, but you can work around this with a separate `Cargo.toml` or a workspace profile.
|
||||
- Alternatively, accept the extra compile time during development.
|
||||
|
||||
### 14.3 Linker symbols from Rust std
|
||||
|
||||
When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (typically `-ldl -lm -lpthread -lc`).
|
||||
|
||||
### 14.4 `anyhow` for error types
|
||||
|
||||
cxx.rs requires the `Err` type in `Result<T, E>` to implement `std::error::Error + Send + Sync`. `String` does not satisfy this. Use `anyhow::Error` or define a thin newtype wrapper:
|
||||
|
||||
```rust
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EngineError(String);
|
||||
impl fmt::Display for EngineError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) }
|
||||
}
|
||||
impl std::error::Error for EngineError {}
|
||||
```
|
||||
|
||||
On the C++ side, errors become `rust::Error` exceptions. Wrap `DoApplyAction` in a try-catch during development to surface Rust errors as `SpielFatalError`.
|
||||
|
||||
### 14.5 `UndoAction` not implemented
|
||||
|
||||
OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo — the history is append-only. To support undo, `Clone()` is the only reliable strategy (clone before applying, discard clone if undo needed). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks game-tree search. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo.
|
||||
|
||||
### 14.6 Generated header path in `#include`
|
||||
|
||||
The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the actual path that `cxx-build` (via corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory.
|
||||
|
||||
### 14.7 `rust::String` to `std::string` conversion
|
||||
|
||||
The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` is implicitly convertible to `std::string`. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast<std::string>(...)`.
|
||||
|
||||
---
|
||||
|
||||
## 15. Complete File Checklist
|
||||
|
||||
```
|
||||
[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib
|
||||
[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...)
|
||||
[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;`
|
||||
[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation
|
||||
[ ] trictrac/justfile — add `cxxlib` target
|
||||
[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge
|
||||
[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file
|
||||
[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 16. Implementation Order
|
||||
|
||||
Implement in this sequence to get early feedback at each step:
|
||||
|
||||
1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced.
|
||||
2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations.
|
||||
3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist.
|
||||
4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them.
|
||||
5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination.
|
||||
6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`.
|
||||
|
||||
---
|
||||
|
||||
# Implementation complete
|
||||
|
||||
All files are in place and trictrac_test passes (168 84 final scores, all assertions pass).
|
||||
|
||||
What was done
|
||||
|
||||
Rust side (from previous session, already complete):
|
||||
|
||||
- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build
|
||||
- trictrac/store/build.rs — drives cxx-build
|
||||
- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl
|
||||
- trictrac/store/src/lib.rs — added pub mod cxxengine;
|
||||
|
||||
C++ side (this session):
|
||||
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test
|
||||
|
||||
Build system:
|
||||
|
||||
- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge
|
||||
- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable
|
||||
|
||||
Justfiles:
|
||||
|
||||
- trictrac/justfile — added cxxlib target
|
||||
- forks/open_spiel/justfile — added buildtrictrac and testtrictrac
|
||||
|
||||
Fixes discovered during build
|
||||
|
||||
| Issue | Fix |
|
||||
| ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries |
|
||||
| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) |
|
||||
| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h |
|
||||
| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc |
|
||||
| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) |
|
||||
| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries |
|
||||
| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort |
|
||||
| Duplicate actions for doubles | Added std::unique after sort |
|
||||
| Returns() returned non-zero at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states |
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
Part B — Batched MCTS leaf evaluation
|
||||
|
||||
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
|
||||
|
||||
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
|
||||
|
||||
pub trait Evaluator: Send + Sync {
|
||||
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
|
||||
|
||||
/// Evaluate a batch of observations at once. Default falls back to
|
||||
/// sequential calls; backends override this for efficiency.
|
||||
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
|
||||
obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
|
||||
|
||||
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
|
||||
|
||||
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
|
||||
let b = obs_batch.len();
|
||||
let obs_size = obs_batch[0].len();
|
||||
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
|
||||
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
|
||||
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
|
||||
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
|
||||
let action_size = policies.len() / b;
|
||||
(0..b).map(|i| {
|
||||
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
|
||||
}).collect()
|
||||
}
|
||||
|
||||
Step B3 — Add eval_batch_size to MctsConfig
|
||||
|
||||
pub struct MctsConfig {
|
||||
// ... existing fields ...
|
||||
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
|
||||
pub eval_batch_size: usize,
|
||||
}
|
||||
|
||||
Default: 1 (backwards-compatible).
|
||||
|
||||
Step B4 — Make the simulation iterative (mcts/search.rs)
|
||||
|
||||
The current simulate is recursive. For batching we need to split it into two phases:
|
||||
|
||||
descend (pure tree traversal — no network call):
|
||||
|
||||
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
|
||||
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
|
||||
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
|
||||
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
|
||||
|
||||
ascend (backup — no network call):
|
||||
|
||||
- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
|
||||
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
|
||||
|
||||
Step B5 — Add run_mcts_batched to mcts/mod.rs
|
||||
|
||||
The new entry point, called by run_mcts when config.eval_batch_size > 1:
|
||||
|
||||
expand root (1 network call)
|
||||
optionally add Dirichlet noise
|
||||
|
||||
for round in 0..(n*simulations / batch_size):
|
||||
leaves = []
|
||||
for * in 0..batch_size:
|
||||
leaf = descend(root, state, env, rng)
|
||||
leaves.push(leaf)
|
||||
|
||||
obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
|
||||
where leaf.kind == NeedsEval]
|
||||
results = evaluator.evaluate_batch(obs_batch)
|
||||
|
||||
for (leaf, result) in zip(leaves, results):
|
||||
expand the leaf node (insert children from result.policy)
|
||||
ascend(root, leaf.path, result.value, leaf.player_idx)
|
||||
// ascend also handles terminal and crossed-chance leaves
|
||||
|
||||
// handle remainder: n_simulations % batch_size
|
||||
|
||||
run_mcts becomes a thin dispatcher:
|
||||
if config.eval_batch_size <= 1 {
|
||||
// existing path (unchanged)
|
||||
} else {
|
||||
run_mcts_batched(...)
|
||||
}
|
||||
|
||||
Step B6 — CLI flag in az_train.rs
|
||||
|
||||
--eval-batch N default: 8 Leaf batch size for MCTS network calls
|
||||
|
||||
---
|
||||
|
||||
Summary of file changes
|
||||
|
||||
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
|
||||
│ File │ Changes │
|
||||
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||
│ spiel_bot/Cargo.toml │ add rayon │
|
||||
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||
│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
|
||||
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||
│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │
|
||||
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │
|
||||
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||
│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │
|
||||
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
Key design constraint
|
||||
|
||||
Parts A and B are independent and composable:
|
||||
|
||||
- A only touches the outer game loop.
|
||||
- B only touches the inner MCTS per game.
|
||||
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.
|
||||
|
|
@ -1,782 +0,0 @@
|
|||
# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac
|
||||
|
||||
## 0. Context and Scope
|
||||
|
||||
The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
|
||||
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
|
||||
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
|
||||
other stage to an inline random-opponent loop.
|
||||
|
||||
`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
|
||||
for **self-play training**. Its goals:
|
||||
|
||||
- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
|
||||
that works with Trictrac's multi-stage turn model and stochastic dice.
|
||||
- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer)
|
||||
as the first algorithm.
|
||||
- Remain **modular**: adding DQN or PPO later requires only a new
|
||||
`impl Algorithm for Dqn` without touching the environment or network layers.
|
||||
- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from
|
||||
`trictrac-store`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Library Landscape
|
||||
|
||||
### 1.1 Neural Network Frameworks
|
||||
|
||||
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
|
||||
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
|
||||
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
|
||||
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
|
||||
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
|
||||
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
|
||||
|
||||
**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
|
||||
runtime needed, the `ndarray` backend is sufficient for CPU training and can
|
||||
switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by
|
||||
changing one type alias.
|
||||
|
||||
`tch-rs` would be the best choice for raw training throughput (it is the most
|
||||
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
|
||||
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
|
||||
switching `spiel_bot` to `tch-rs` is a one-line backend swap.
|
||||
|
||||
### 1.2 Other Key Crates
|
||||
|
||||
| Crate | Role |
|
||||
| -------------------- | ----------------------------------------------------------------- |
|
||||
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
|
||||
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
|
||||
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
|
||||
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
|
||||
| `anyhow` | error propagation (already used everywhere) |
|
||||
| `indicatif` | training progress bars |
|
||||
| `tracing` | structured logging per episode/iteration |
|
||||
|
||||
### 1.3 What `burn-rl` Provides (and Does Not)
|
||||
|
||||
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
|
||||
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
|
||||
trait. It does **not** provide:
|
||||
|
||||
- MCTS or any tree-search algorithm
|
||||
- Two-player self-play
|
||||
- Legal action masking during training
|
||||
- Chance-node handling
|
||||
|
||||
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
|
||||
own (simpler, more targeted) traits and implement MCTS from scratch.
|
||||
|
||||
---
|
||||
|
||||
## 2. Trictrac-Specific Design Constraints
|
||||
|
||||
### 2.1 Multi-Stage Turn Model
|
||||
|
||||
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
|
||||
genuine player choice:
|
||||
|
||||
| TurnStage | Node type | Handler |
|
||||
| ---------------- | ------------------------------- | ------------------------------- |
|
||||
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
|
||||
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
|
||||
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
|
||||
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
|
||||
| `Move` | **Player decision** | MCTS / policy network |
|
||||
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
|
||||
|
||||
The environment wrapper advances through forced/chance stages automatically so
|
||||
that from the algorithm's perspective every node it sees is a genuine player
|
||||
decision.
|
||||
|
||||
### 2.2 Stochastic Dice in MCTS
|
||||
|
||||
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
|
||||
introduce stochasticity. Three approaches exist:
|
||||
|
||||
**A. Outcome sampling (recommended)**
|
||||
During each MCTS simulation, when a chance node is reached, sample one dice
|
||||
outcome at random and continue. After many simulations the expected value
|
||||
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
|
||||
and requires no changes to the standard PUCT formula.
|
||||
|
||||
**B. Chance-node averaging (expectimax)**
|
||||
At each chance node, expand all 21 unique dice pairs weighted by their
|
||||
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
|
||||
exact but multiplies the branching factor by ~21 at every dice roll, making it
|
||||
prohibitively expensive.
|
||||
|
||||
**C. Condition on dice in the observation (current approach)**
|
||||
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
|
||||
network naturally conditions on the rolled dice when it evaluates a position.
|
||||
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
|
||||
chance nodes are bypassed by the environment wrapper (approach A). The policy
|
||||
and value heads learn to play optimally given any dice pair.
|
||||
|
||||
**Use approach A + C together**: the environment samples dice automatically
|
||||
(chance node bypass), and the 217-dim tensor encodes the dice so the network
|
||||
can exploit them.
|
||||
|
||||
### 2.3 Perspective / Mirroring
|
||||
|
||||
All move rules and tensor encoding are defined from White's perspective.
|
||||
`to_tensor()` must always be called after mirroring the state for Black.
|
||||
The environment wrapper handles this transparently: every observation returned
|
||||
to an algorithm is already in the active player's perspective.
|
||||
|
||||
### 2.4 Legal Action Masking
|
||||
|
||||
A crucial difference from the existing `bot/` code: instead of penalizing
|
||||
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
|
||||
before softmax — illegal action logits are set to `-inf`. This prevents the
|
||||
network from wasting capacity on illegal moves and eliminates the need for the
|
||||
penalty-reward hack.
|
||||
|
||||
---
|
||||
|
||||
## 3. Proposed Crate Architecture
|
||||
|
||||
```
|
||||
spiel_bot/
|
||||
├── Cargo.toml
|
||||
└── src/
|
||||
├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo"
|
||||
│
|
||||
├── env/
|
||||
│ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface
|
||||
│ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store
|
||||
│
|
||||
├── mcts/
|
||||
│ ├── mod.rs # MctsConfig, run_mcts() entry point
|
||||
│ ├── node.rs # MctsNode (visit count, W, prior, children)
|
||||
│ └── search.rs # simulate(), backup(), select_action()
|
||||
│
|
||||
├── network/
|
||||
│ ├── mod.rs # PolicyValueNet trait
|
||||
│ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads
|
||||
│
|
||||
├── alphazero/
|
||||
│ ├── mod.rs # AlphaZeroConfig
|
||||
│ ├── selfplay.rs # generate_episode() -> Vec<TrainSample>
|
||||
│ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle)
|
||||
│ └── trainer.rs # training loop: selfplay → sample → loss → update
|
||||
│
|
||||
└── agent/
|
||||
├── mod.rs # Agent trait
|
||||
├── random.rs # RandomAgent (baseline)
|
||||
└── mcts_agent.rs # MctsAgent: uses trained network for inference
|
||||
```
|
||||
|
||||
Future algorithms slot in without touching the above:
|
||||
|
||||
```
|
||||
├── dqn/ # (future) DQN: impl Algorithm + own replay buffer
|
||||
└── ppo/ # (future) PPO: impl Algorithm + rollout buffer
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Core Traits
|
||||
|
||||
### 4.1 `GameEnv` — the minimal OpenSpiel interface
|
||||
|
||||
```rust
|
||||
use rand::Rng;
|
||||
|
||||
/// Who controls the current node.
|
||||
pub enum Player {
|
||||
P1, // player index 0
|
||||
P2, // player index 1
|
||||
Chance, // dice roll
|
||||
Terminal, // game over
|
||||
}
|
||||
|
||||
pub trait GameEnv: Clone + Send + Sync + 'static {
|
||||
type State: Clone + Send + Sync;
|
||||
|
||||
/// Fresh game state.
|
||||
fn new_game(&self) -> Self::State;
|
||||
|
||||
/// Who acts at this node.
|
||||
fn current_player(&self, s: &Self::State) -> Player;
|
||||
|
||||
/// Legal action indices (always in [0, action_space())).
|
||||
/// Empty only at Terminal nodes.
|
||||
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
|
||||
|
||||
/// Apply a player action (must be legal).
|
||||
fn apply(&self, s: &mut Self::State, action: usize);
|
||||
|
||||
/// Advance a Chance node by sampling dice; no-op at non-Chance nodes.
|
||||
fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng);
|
||||
|
||||
/// Observation tensor from `pov`'s perspective (0 or 1).
|
||||
/// Returns 217 f32 values for Trictrac.
|
||||
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
|
||||
|
||||
/// Flat observation size (217 for Trictrac).
|
||||
fn obs_size(&self) -> usize;
|
||||
|
||||
/// Total action-space size (514 for Trictrac).
|
||||
fn action_space(&self) -> usize;
|
||||
|
||||
/// Game outcome per player, or None if not Terminal.
|
||||
/// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw.
|
||||
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 `PolicyValueNet` — neural network interface
|
||||
|
||||
```rust
|
||||
use burn::prelude::*;
|
||||
|
||||
pub trait PolicyValueNet<B: Backend>: Send + Sync {
|
||||
/// Forward pass.
|
||||
/// `obs`: [batch, obs_size] tensor.
|
||||
/// Returns: (policy_logits [batch, action_space], value [batch]).
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
|
||||
|
||||
/// Save weights to `path`.
|
||||
fn save(&self, path: &std::path::Path) -> anyhow::Result<()>;
|
||||
|
||||
/// Load weights from `path`.
|
||||
fn load(path: &std::path::Path) -> anyhow::Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.3 `Agent` — player policy interface
|
||||
|
||||
```rust
|
||||
pub trait Agent: Send {
|
||||
/// Select an action index given the current game state observation.
|
||||
/// `legal`: mask of valid action indices.
|
||||
fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. MCTS Implementation
|
||||
|
||||
### 5.1 Node
|
||||
|
||||
```rust
|
||||
pub struct MctsNode {
|
||||
n: u32, // visit count N(s, a)
|
||||
w: f32, // sum of backed-up values W(s, a)
|
||||
p: f32, // prior from policy head P(s, a)
|
||||
children: Vec<(usize, MctsNode)>, // (action_idx, child)
|
||||
is_expanded: bool,
|
||||
}
|
||||
|
||||
impl MctsNode {
|
||||
pub fn q(&self) -> f32 {
|
||||
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
|
||||
}
|
||||
|
||||
/// PUCT score used for selection.
|
||||
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
|
||||
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 5.2 Simulation Loop
|
||||
|
||||
One MCTS simulation (for deterministic decision nodes):
|
||||
|
||||
```
|
||||
1. SELECTION — traverse from root, always pick child with highest PUCT,
|
||||
auto-advancing forced/chance nodes via env.apply_chance().
|
||||
2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get
|
||||
(policy_logits, value). Mask illegal actions, softmax
|
||||
the remaining logits → priors P(s,a) for each child.
|
||||
3. BACKUP — propagate -value up the tree (negate at each level because
|
||||
perspective alternates between P1 and P2).
|
||||
```
|
||||
|
||||
After `n_simulations` iterations, action selection at the root:
|
||||
|
||||
```rust
|
||||
// During training: sample proportional to N^(1/temperature)
|
||||
// During evaluation: argmax N
|
||||
fn select_action(root: &MctsNode, temperature: f32) -> usize { ... }
|
||||
```
|
||||
|
||||
### 5.3 Configuration
|
||||
|
||||
```rust
|
||||
pub struct MctsConfig {
|
||||
pub n_simulations: usize, // e.g. 200
|
||||
pub c_puct: f32, // exploration constant, e.g. 1.5
|
||||
pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3
|
||||
pub dirichlet_eps: f32, // noise weight, e.g. 0.25
|
||||
pub temperature: f32, // action sampling temperature (anneals to 0)
|
||||
}
|
||||
```
|
||||
|
||||
### 5.4 Handling Chance Nodes Inside MCTS
|
||||
|
||||
When simulation reaches a Chance node (dice roll), the environment automatically
|
||||
samples dice and advances to the next decision node. The MCTS tree does **not**
|
||||
branch on dice outcomes — it treats the sampled outcome as the state. This
|
||||
corresponds to "outcome sampling" (approach A from §2.2). Because each
|
||||
simulation independently samples dice, the Q-values at player nodes converge to
|
||||
their expected value over many simulations.
|
||||
|
||||
---
|
||||
|
||||
## 6. Network Architecture
|
||||
|
||||
### 6.1 ResNet Policy-Value Network
|
||||
|
||||
A single trunk with residual blocks, then two heads:
|
||||
|
||||
```
|
||||
Input: [batch, 217]
|
||||
↓
|
||||
Linear(217 → 512) + ReLU
|
||||
↓
|
||||
ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU)
|
||||
↓ trunk output [batch, 512]
|
||||
├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site)
|
||||
└─ Value head: Linear(512 → 1) → tanh (output in [-1, 1])
|
||||
```
|
||||
|
||||
Burn implementation sketch:
|
||||
|
||||
```rust
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TrictracNet<B: Backend> {
|
||||
input: Linear<B>,
|
||||
res_blocks: Vec<ResBlock<B>>,
|
||||
policy_head: Linear<B>,
|
||||
value_head: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TrictracNet<B> {
|
||||
pub fn forward(&self, obs: Tensor<B, 2>)
|
||||
-> (Tensor<B, 2>, Tensor<B, 1>)
|
||||
{
|
||||
let x = activation::relu(self.input.forward(obs));
|
||||
let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x));
|
||||
let policy = self.policy_head.forward(x.clone()); // raw logits
|
||||
let value = activation::tanh(self.value_head.forward(x))
|
||||
.squeeze(1);
|
||||
(policy, value)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
A simpler MLP (no residual blocks) is sufficient for a first version and much
|
||||
faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two
|
||||
heads.
|
||||
|
||||
### 6.2 Loss Function
|
||||
|
||||
```
|
||||
L = MSE(value_pred, z)
|
||||
+ CrossEntropy(policy_logits_masked, π_mcts)
|
||||
- c_l2 * L2_regularization
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `z` = game outcome (±1) from the active player's perspective
|
||||
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
|
||||
- Legal action masking is applied before computing CrossEntropy
|
||||
|
||||
---
|
||||
|
||||
## 7. AlphaZero Training Loop
|
||||
|
||||
```
|
||||
INIT
|
||||
network ← random weights
|
||||
replay ← empty ReplayBuffer(capacity = 100_000)
|
||||
|
||||
LOOP forever:
|
||||
── Self-play phase ──────────────────────────────────────────────
|
||||
(parallel with rayon, n_workers games at once)
|
||||
for each game:
|
||||
state ← env.new_game()
|
||||
samples = []
|
||||
while not terminal:
|
||||
advance forced/chance nodes automatically
|
||||
obs ← env.observation(state, current_player)
|
||||
legal ← env.legal_actions(state)
|
||||
π, root_value ← mcts.run(state, network, config)
|
||||
action ← sample from π (with temperature)
|
||||
samples.push((obs, π, current_player))
|
||||
env.apply(state, action)
|
||||
z ← env.returns(state) // final scores
|
||||
for (obs, π, player) in samples:
|
||||
replay.push(TrainSample { obs, policy: π, value: z[player] })
|
||||
|
||||
── Training phase ───────────────────────────────────────────────
|
||||
for each gradient step:
|
||||
batch ← replay.sample(batch_size)
|
||||
(policy_logits, value_pred) ← network.forward(batch.obs)
|
||||
loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy)
|
||||
optimizer.step(loss.backward())
|
||||
|
||||
── Evaluation (every N iterations) ─────────────────────────────
|
||||
win_rate ← evaluate(network_new vs network_prev, n_eval_games)
|
||||
if win_rate > 0.55: save checkpoint
|
||||
```
|
||||
|
||||
### 7.1 Replay Buffer
|
||||
|
||||
```rust
|
||||
pub struct TrainSample {
|
||||
pub obs: Vec<f32>, // 217 values
|
||||
pub policy: Vec<f32>, // 514 values (normalized MCTS visit counts)
|
||||
pub value: f32, // game outcome ∈ {-1, 0, +1}
|
||||
}
|
||||
|
||||
pub struct ReplayBuffer {
|
||||
data: VecDeque<TrainSample>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
pub fn push(&mut self, s: TrainSample) {
|
||||
if self.data.len() == self.capacity { self.data.pop_front(); }
|
||||
self.data.push_back(s);
|
||||
}
|
||||
|
||||
pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
|
||||
// sample without replacement
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7.2 Parallelism Strategy
|
||||
|
||||
Self-play is embarrassingly parallel (each game is independent):
|
||||
|
||||
```rust
|
||||
let samples: Vec<TrainSample> = (0..n_games)
|
||||
.into_par_iter() // rayon
|
||||
.flat_map(|_| generate_episode(&env, &network, &mcts_config))
|
||||
.collect();
|
||||
```
|
||||
|
||||
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
|
||||
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
|
||||
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
|
||||
`Autodiff<NdArray<f32>>`.
|
||||
|
||||
For larger scale, a producer-consumer architecture (crossbeam-channel) separates
|
||||
self-play workers from the training thread, allowing continuous data generation
|
||||
while the GPU trains.
|
||||
|
||||
---
|
||||
|
||||
## 8. `TrictracEnv` Implementation Sketch
|
||||
|
||||
```rust
|
||||
use trictrac_store::{
|
||||
training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE},
|
||||
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TrictracEnv;
|
||||
|
||||
impl GameEnv for TrictracEnv {
|
||||
type State = GameState;
|
||||
|
||||
fn new_game(&self) -> GameState {
|
||||
GameState::new_with_players("P1", "P2")
|
||||
}
|
||||
|
||||
fn current_player(&self, s: &GameState) -> Player {
|
||||
match s.stage {
|
||||
Stage::Ended => Player::Terminal,
|
||||
_ => match s.turn_stage {
|
||||
TurnStage::RollWaiting => Player::Chance,
|
||||
_ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
|
||||
let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() };
|
||||
get_valid_action_indices(&view).unwrap_or_default()
|
||||
}
|
||||
|
||||
fn apply(&self, s: &mut GameState, action_idx: usize) {
|
||||
// advance all forced/chance nodes first, then apply the player action
|
||||
self.advance_forced(s);
|
||||
let needs_mirror = s.active_player_id == 2;
|
||||
let view = if needs_mirror { s.mirror() } else { s.clone() };
|
||||
if let Some(event) = TrictracAction::from_action_index(action_idx)
|
||||
.and_then(|a| a.to_event(&view))
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
{
|
||||
let _ = s.consume(&event);
|
||||
}
|
||||
// advance any forced stages that follow
|
||||
self.advance_forced(s);
|
||||
}
|
||||
|
||||
fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) {
|
||||
// RollDice → RollWaiting
|
||||
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
|
||||
// RollWaiting → next stage
|
||||
let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) };
|
||||
let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice });
|
||||
self.advance_forced(s);
|
||||
}
|
||||
|
||||
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
|
||||
if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() }
|
||||
}
|
||||
|
||||
fn obs_size(&self) -> usize { 217 }
|
||||
fn action_space(&self) -> usize { ACTION_SPACE_SIZE }
|
||||
|
||||
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
|
||||
if s.stage != Stage::Ended { return None; }
|
||||
// Convert hole+point scores to ±1 outcome
|
||||
let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
|
||||
let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
|
||||
Some(match s1.cmp(&s2) {
|
||||
std::cmp::Ordering::Greater => [ 1.0, -1.0],
|
||||
std::cmp::Ordering::Less => [-1.0, 1.0],
|
||||
std::cmp::Ordering::Equal => [ 0.0, 0.0],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracEnv {
|
||||
/// Advance through all forced (non-decision, non-chance) stages.
|
||||
fn advance_forced(&self, s: &mut GameState) {
|
||||
use trictrac_store::PointsRules;
|
||||
loop {
|
||||
match s.turn_stage {
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
|
||||
// Scoring is deterministic; compute and apply automatically.
|
||||
let color = s.player_color_by_id(&s.active_player_id)
|
||||
.unwrap_or(trictrac_store::Color::White);
|
||||
let drc = s.players.get(&s.active_player_id)
|
||||
.map(|p| p.dice_roll_count).unwrap_or(0);
|
||||
let pr = PointsRules::new(&color, &s.board, s.dice);
|
||||
let pts = pr.get_points(drc);
|
||||
let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 };
|
||||
let _ = s.consume(&GameEvent::Mark {
|
||||
player_id: s.active_player_id, points,
|
||||
});
|
||||
}
|
||||
TurnStage::RollDice => {
|
||||
// RollDice is a forced "initiate roll" action with no real choice.
|
||||
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. Cargo.toml Changes
|
||||
|
||||
### 9.1 Add `spiel_bot` to the workspace
|
||||
|
||||
```toml
|
||||
# Cargo.toml (workspace root)
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["client_cli", "bot", "store", "spiel_bot"]
|
||||
```
|
||||
|
||||
### 9.2 `spiel_bot/Cargo.toml`
|
||||
|
||||
```toml
|
||||
[package]
|
||||
name = "spiel_bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["alphazero"]
|
||||
alphazero = []
|
||||
# dqn = [] # future
|
||||
# ppo = [] # future
|
||||
|
||||
[dependencies]
|
||||
trictrac-store = { path = "../store" }
|
||||
anyhow = "1"
|
||||
rand = "0.9"
|
||||
rayon = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# Burn: NdArray for pure-Rust CPU training
|
||||
# Replace NdArray with Wgpu or Tch for GPU.
|
||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||
|
||||
# Optional: progress display and structured logging
|
||||
indicatif = "0.17"
|
||||
tracing = "0.1"
|
||||
|
||||
[[bin]]
|
||||
name = "az_train"
|
||||
path = "src/bin/az_train.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "az_eval"
|
||||
path = "src/bin/az_eval.rs"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Comparison: `bot` crate vs `spiel_bot`
|
||||
|
||||
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
|
||||
| ---------------- | --------------------------- | -------------------------------------------- |
|
||||
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
|
||||
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
|
||||
| Opponent | hardcoded random | self-play |
|
||||
| Invalid actions | penalise with reward | legal action mask (no penalty) |
|
||||
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
|
||||
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
|
||||
| Burn dep | yes (0.20) | yes (0.20), same backend |
|
||||
| `burn-rl` dep | yes | no |
|
||||
| C++ dep | no | no |
|
||||
| Python dep | no | no |
|
||||
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
|
||||
|
||||
The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
|
||||
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
|
||||
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
|
||||
replace `TrictracEnvironment` with `TrictracEnv`).
|
||||
|
||||
---
|
||||
|
||||
## 11. Implementation Order
|
||||
|
||||
1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game
|
||||
through the trait, verify terminal state and returns).
|
||||
2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) +
|
||||
Burn forward/backward pass test with dummy data.
|
||||
3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests
|
||||
(visit counts sum to `n_simulations`, legal mask respected).
|
||||
4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub
|
||||
(one iteration, check loss decreases).
|
||||
5. **Integration test**: run 100 self-play games with a tiny network (1 res block,
|
||||
64 hidden units), verify the training loop completes without panics.
|
||||
6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s
|
||||
on CPU, consistent with `random_game` throughput).
|
||||
7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule
|
||||
hyperparameter sweep.
|
||||
8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report
|
||||
win rate every checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## 12. Key Open Questions
|
||||
|
||||
1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
|
||||
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
|
||||
scored more holes). Better option: normalize the score margin. The final
|
||||
choice affects how the value head is trained.
|
||||
|
||||
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
|
||||
MCTS with 200 simulations per step means ~120k network evaluations per game.
|
||||
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
|
||||
limiting `n_simulations` to 50–100 for early training.
|
||||
|
||||
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
|
||||
This is a long-horizon decision that AlphaZero handles naturally via MCTS
|
||||
lookahead, but needs careful value normalization (a "Go" restarts scoring
|
||||
within the same game).
|
||||
|
||||
4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated
|
||||
to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic.
|
||||
This is optional but reduces code duplication.
|
||||
|
||||
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
|
||||
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
|
||||
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
|
||||
|
||||
## Implementation results
|
||||
|
||||
All benchmarks compile and run. Here's the complete results table:
|
||||
|
||||
| Group | Benchmark | Time |
|
||||
| ------- | ----------------------- | --------------------- |
|
||||
| env | apply_chance | 3.87 µs |
|
||||
| | legal_actions | 1.91 µs |
|
||||
| | observation (to_tensor) | 341 ns |
|
||||
| | random_game (baseline) | 3.55 ms → 282 games/s |
|
||||
| network | mlp_b1 hidden=64 | 94.9 µs |
|
||||
| | mlp_b32 hidden=64 | 141 µs |
|
||||
| | mlp_b1 hidden=256 | 352 µs |
|
||||
| | mlp_b32 hidden=256 | 479 µs |
|
||||
| mcts | zero_eval n=1 | 6.8 µs |
|
||||
| | zero_eval n=5 | 23.9 µs |
|
||||
| | zero_eval n=20 | 90.9 µs |
|
||||
| | mlp64 n=1 | 203 µs |
|
||||
| | mlp64 n=5 | 622 µs |
|
||||
| | mlp64 n=20 | 2.30 ms |
|
||||
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
|
||||
| | trictrac n=2 | 145 ms → 7 games/s |
|
||||
| train | mlp64 Adam b=16 | 1.93 ms |
|
||||
| | mlp64 Adam b=64 | 2.68 ms |
|
||||
|
||||
Key observations:
|
||||
|
||||
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
|
||||
- observation (217-value tensor): only 341 ns — not a bottleneck
|
||||
- legal_actions: 1.9 µs — well optimised
|
||||
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
|
||||
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
|
||||
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
|
||||
- Training: 2.7 ms/step at batch=64 → 370 steps/s
|
||||
|
||||
### Summary of Step 8
|
||||
|
||||
spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary:
|
||||
|
||||
- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct
|
||||
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
|
||||
- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side
|
||||
- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise)
|
||||
- Output: win/draw/loss per side + combined decisive win rate
|
||||
|
||||
Typical usage after training:
|
||||
cargo run -p spiel_bot --bin az_eval --release -- \
|
||||
--checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
|
||||
|
||||
### az_train
|
||||
|
||||
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
|
||||
|
||||
cargo run -p spiel_bot --bin az_train --release
|
||||
|
||||
#### ResNet, more games, custom output dir
|
||||
|
||||
cargo run -p spiel_bot --bin az_train --release -- \
|
||||
--arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
|
||||
--save-every 20 --out checkpoints/
|
||||
|
||||
#### Resume from iteration 50
|
||||
|
||||
cargo run -p spiel_bot --bin az_train --release -- \
|
||||
--resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
|
||||
|
||||
What the binary does each iteration:
|
||||
|
||||
1. Calls model.valid() to get a zero-overhead inference copy for self-play
|
||||
2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy)
|
||||
3. Pushes samples into a ReplayBuffer (capacity --replay-cap)
|
||||
4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min
|
||||
5. Saves a .mpk checkpoint every --save-every iterations and always on the last
|
||||
|
|
@ -1,253 +0,0 @@
|
|||
# Tensor research
|
||||
|
||||
## Current tensor anatomy
|
||||
|
||||
[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!)
|
||||
[24] active player color: 0 or 1
|
||||
[25] turn_stage: 1–5
|
||||
[26–27] dice values (raw 1–6)
|
||||
[28–31] white: points, holes, can_bredouille, can_big_bredouille
|
||||
[32–35] black: same
|
||||
─────────────────────────────────
|
||||
Total 36 floats
|
||||
|
||||
The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's
|
||||
AlphaZero uses a fully-connected network.
|
||||
|
||||
### Fundamental problems with the current encoding
|
||||
|
||||
1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network
|
||||
must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with
|
||||
all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from.
|
||||
|
||||
2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient
|
||||
flow during training is uneven.
|
||||
|
||||
3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it
|
||||
triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers
|
||||
produces a score. Including this explicitly is the single highest-value addition.
|
||||
|
||||
4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely
|
||||
different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0.
|
||||
|
||||
5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the
|
||||
starting position). It's in the Player struct but not exported.
|
||||
|
||||
## Key Trictrac distinctions from backgammon that shape the encoding
|
||||
|
||||
| Concept | Backgammon | Trictrac |
|
||||
| ------------------------- | ---------------------- | --------------------------------------------------------- |
|
||||
| Hitting a blot | Removes checker to bar | Scores points, checker stays |
|
||||
| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened |
|
||||
| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) |
|
||||
| 3-checker field | Safe with spare | Safe with spare |
|
||||
| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) |
|
||||
| Both colors on a field | Impossible | Perfectly legal |
|
||||
| Rest corner (field 12/13) | Does not exist | Special two-checker rules |
|
||||
|
||||
The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary
|
||||
indicators directly teaches the network the phase transitions the game hinges on.
|
||||
|
||||
## Options
|
||||
|
||||
### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D)
|
||||
|
||||
The minimum viable improvement.
|
||||
|
||||
For each of the 24 fields, encode own and opponent separately with 4 indicators each:
|
||||
|
||||
own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target)
|
||||
own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill)
|
||||
own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare)
|
||||
own_x[i]: max(0, count − 3) (overflow)
|
||||
opp_1[i]: same for opponent
|
||||
…
|
||||
|
||||
Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec().
|
||||
|
||||
Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204
|
||||
Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured
|
||||
overhead is negligible.
|
||||
Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from
|
||||
scratch will be substantially faster and the converged policy quality better.
|
||||
|
||||
### Option B — Option A + Trictrac-specific derived features (flat 1D)
|
||||
|
||||
Recommended starting point.
|
||||
|
||||
Add on top of Option A:
|
||||
|
||||
// Quarter fill status — the single most important derived feature
|
||||
quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields)
|
||||
quarter_filled_opp[q] (q=0..3): same for opponent
|
||||
→ 8 values
|
||||
|
||||
// Exit readiness
|
||||
can_exit_own: 1.0 if all own checkers are in fields 19–24
|
||||
can_exit_opp: same for opponent
|
||||
→ 2 values
|
||||
|
||||
// Rest corner status (field 12/13)
|
||||
own_corner_taken: 1.0 if field 12 has ≥2 own checkers
|
||||
opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers
|
||||
→ 2 values
|
||||
|
||||
// Jan de 3 coups counter (normalized)
|
||||
dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0)
|
||||
→ 1 value
|
||||
|
||||
Size: 204 + 8 + 2 + 2 + 1 = 217
|
||||
Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve
|
||||
the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes
|
||||
expensive inference from the network.
|
||||
|
||||
### Option C — Option B + richer positional features (flat 1D)
|
||||
|
||||
More complete, higher sample efficiency, minor extra cost.
|
||||
|
||||
Add on top of Option B:
|
||||
|
||||
// Per-quarter fill fraction — how close to filling each quarter
|
||||
own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0
|
||||
opp_quarter_fill_fraction[q] (q=0..3): same for opponent
|
||||
→ 8 values
|
||||
|
||||
// Blot counts — number of own/opponent single-checker fields globally
|
||||
// (tells the network at a glance how much battage risk/opportunity exists)
|
||||
own_blot_count: (number of own fields with exactly 1 checker) / 15.0
|
||||
opp_blot_count: same for opponent
|
||||
→ 2 values
|
||||
|
||||
// Bredouille would-double multiplier (already present, but explicitly scaled)
|
||||
// No change needed, already binary
|
||||
|
||||
Size: 217 + 8 + 2 = 227
|
||||
Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network
|
||||
from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts).
|
||||
|
||||
### Option D — 2D spatial tensor {K, 24}
|
||||
|
||||
For CNN-based networks. Best eventual architecture but requires changing the training setup.
|
||||
|
||||
Shape {14, 24} — 14 feature channels over 24 field positions:
|
||||
|
||||
Channel 0: own_count_1 (blot)
|
||||
Channel 1: own_count_2
|
||||
Channel 2: own_count_3
|
||||
Channel 3: own_count_overflow (float)
|
||||
Channel 4: opp_count_1
|
||||
Channel 5: opp_count_2
|
||||
Channel 6: opp_count_3
|
||||
Channel 7: opp_count_overflow
|
||||
Channel 8: own_corner_mask (1.0 at field 12)
|
||||
Channel 9: opp_corner_mask (1.0 at field 13)
|
||||
Channel 10: final_quarter_mask (1.0 at fields 19–24)
|
||||
Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter)
|
||||
Channel 12: quarter_filled_opp (same for opponent)
|
||||
Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers)
|
||||
|
||||
Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform
|
||||
value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global"
|
||||
row by returning shape {K, 25} with position 0 holding global features.
|
||||
|
||||
Size: 14 × 24 + few global channels ≈ 336–384
|
||||
C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated
|
||||
accordingly.
|
||||
Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's
|
||||
alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H,
|
||||
W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension.
|
||||
Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel
|
||||
size 2–3 captures adjacent-field "tout d'une" interactions.
|
||||
|
||||
### On 3D tensors
|
||||
|
||||
Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is
|
||||
the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and
|
||||
field-within-quarter patterns jointly.
|
||||
|
||||
However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The
|
||||
extra architecture work makes this premature unless you're already building a custom network. The information content
|
||||
is the same as Option D.
|
||||
|
||||
### Recommendation
|
||||
|
||||
Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and
|
||||
the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions
|
||||
(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move.
|
||||
|
||||
Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when
|
||||
the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time.
|
||||
|
||||
Costs summary:
|
||||
|
||||
| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain |
|
||||
| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- |
|
||||
| Current | 36 | — | — | — | baseline |
|
||||
| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) |
|
||||
| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) |
|
||||
| C | 227 | to_vec() rewrite | constant update | none | large + moderate |
|
||||
| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial |
|
||||
|
||||
One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2,
|
||||
the new to_vec() must express everything from the active player's perspective (which the mirror already handles for
|
||||
the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state,
|
||||
which they will be if computed inside to_vec().
|
||||
|
||||
## Other algorithms
|
||||
|
||||
The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully.
|
||||
|
||||
### 1. Without MCTS, feature quality matters more
|
||||
|
||||
AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred
|
||||
simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup —
|
||||
the network must learn the full strategic structure directly from gradient updates.
|
||||
|
||||
This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO,
|
||||
not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less
|
||||
sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from
|
||||
scratch.
|
||||
|
||||
### 2. Normalization becomes mandatory, not optional
|
||||
|
||||
AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps
|
||||
Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw
|
||||
values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can
|
||||
be unstable.
|
||||
|
||||
For DQN/PPO, every feature in the tensor should be in [0, 1]:
|
||||
|
||||
dice values: / 6.0
|
||||
checker counts: overflow channel / 12.0
|
||||
points: / 12.0
|
||||
holes: / 12.0
|
||||
dice_roll_count: / 3.0 (clamped)
|
||||
|
||||
Booleans and the TD-Gammon binary indicators are already in [0, 1].
|
||||
|
||||
### 3. The shape question depends on architecture, not algorithm
|
||||
|
||||
| Architecture | Shape | When to use |
|
||||
| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- |
|
||||
| MLP | {217} flat | Any algorithm, simplest baseline |
|
||||
| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) |
|
||||
| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network |
|
||||
| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now |
|
||||
|
||||
The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want
|
||||
convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet
|
||||
template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all
|
||||
— you just reshape the flat tensor in Python before passing it to the network.
|
||||
|
||||
### One thing that genuinely changes: value function perspective
|
||||
|
||||
AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This
|
||||
works well.
|
||||
|
||||
DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit
|
||||
current-player indicator), because a single Q-network estimates action values for both players simultaneously. With
|
||||
the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must
|
||||
learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity.
|
||||
This is a minor point for a symmetric game like Trictrac, but worth keeping in mind.
|
||||
|
||||
Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm.
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
[package]
|
||||
name = "spiel_bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
trictrac-store = { path = "../store" }
|
||||
anyhow = "1"
|
||||
rand = "0.9"
|
||||
rand_distr = "0.5"
|
||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||
rayon = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
|
||||
[[bench]]
|
||||
name = "alphazero"
|
||||
harness = false
|
||||
|
|
@ -1,373 +0,0 @@
|
|||
//! AlphaZero pipeline benchmarks.
|
||||
//!
|
||||
//! Run with:
|
||||
//!
|
||||
//! ```sh
|
||||
//! cargo bench -p spiel_bot
|
||||
//! ```
|
||||
//!
|
||||
//! Use `-- <filter>` to run a specific group, e.g.:
|
||||
//!
|
||||
//! ```sh
|
||||
//! cargo bench -p spiel_bot -- env/
|
||||
//! cargo bench -p spiel_bot -- network/
|
||||
//! cargo bench -p spiel_bot -- mcts/
|
||||
//! cargo bench -p spiel_bot -- episode/
|
||||
//! cargo bench -p spiel_bot -- train/
|
||||
//! ```
|
||||
//!
|
||||
//! Target: ≥ 500 games/s for random play on CPU (consistent with
|
||||
//! `random_game` throughput in `trictrac-store`).
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
tensor::{Tensor, TensorData, backend::Backend},
|
||||
};
|
||||
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||
|
||||
use spiel_bot::{
|
||||
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
||||
env::{GameEnv, Player, TrictracEnv},
|
||||
mcts::{Evaluator, MctsConfig, run_mcts},
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
|
||||
};
|
||||
|
||||
// ── Shared types ───────────────────────────────────────────────────────────
|
||||
|
||||
type InferB = NdArray<f32>;
|
||||
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
|
||||
|
||||
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
|
||||
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
|
||||
|
||||
fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
|
||||
|
||||
/// Uniform evaluator (returns zero logits and zero value).
|
||||
/// Used to isolate MCTS tree-traversal cost from network cost.
|
||||
struct ZeroEval(usize);
|
||||
impl Evaluator for ZeroEval {
|
||||
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
|
||||
(vec![0.0f32; self.0], 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
// ── 1. Environment primitives ──────────────────────────────────────────────
|
||||
|
||||
/// Baseline performance of the raw Trictrac environment without MCTS.
|
||||
/// Target: ≥ 500 full games / second.
|
||||
fn bench_env(c: &mut Criterion) {
|
||||
let env = TrictracEnv;
|
||||
|
||||
let mut group = c.benchmark_group("env");
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
// ── apply_chance ──────────────────────────────────────────────────────
|
||||
group.bench_function("apply_chance", |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
// A fresh game is always at RollDice (Chance) — ready for apply_chance.
|
||||
env.new_game()
|
||||
},
|
||||
|mut s| {
|
||||
env.apply_chance(&mut s, &mut seeded());
|
||||
black_box(s)
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
// ── legal_actions ─────────────────────────────────────────────────────
|
||||
group.bench_function("legal_actions", |b| {
|
||||
let mut rng = seeded();
|
||||
let mut s = env.new_game();
|
||||
env.apply_chance(&mut s, &mut rng);
|
||||
b.iter(|| black_box(env.legal_actions(&s)))
|
||||
});
|
||||
|
||||
// ── observation (to_tensor) ───────────────────────────────────────────
|
||||
group.bench_function("observation", |b| {
|
||||
let mut rng = seeded();
|
||||
let mut s = env.new_game();
|
||||
env.apply_chance(&mut s, &mut rng);
|
||||
b.iter(|| black_box(env.observation(&s, 0)))
|
||||
});
|
||||
|
||||
// ── full random game ──────────────────────────────────────────────────
|
||||
group.sample_size(50);
|
||||
group.bench_function("random_game", |b| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| {
|
||||
let mut s = env.new_game();
|
||||
loop {
|
||||
match env.current_player(&s) {
|
||||
Player::Terminal => break,
|
||||
Player::Chance => env.apply_chance(&mut s, &mut rng),
|
||||
_ => {
|
||||
let actions = env.legal_actions(&s);
|
||||
let idx = rng.random_range(0..actions.len());
|
||||
env.apply(&mut s, actions[idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
black_box(s)
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 2. Network inference ───────────────────────────────────────────────────
|
||||
|
||||
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
|
||||
fn bench_network(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("network");
|
||||
group.measurement_time(Duration::from_secs(5));
|
||||
|
||||
for &hidden in &[64usize, 256] {
|
||||
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = MlpNet::<InferB>::new(&cfg, &infer_device());
|
||||
let obs: Vec<f32> = vec![0.5; 217];
|
||||
|
||||
// Batch size 1 — single-position evaluation as in MCTS.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp_b1", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
// Batch size 32 — training mini-batch.
|
||||
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp_b32", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// ── ResNet (4 residual blocks) ────────────────────────────────────────
|
||||
for &hidden in &[256usize, 512] {
|
||||
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = ResNet::<InferB>::new(&cfg, &infer_device());
|
||||
let obs: Vec<f32> = vec![0.5; 217];
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("resnet_b1", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("resnet_b32", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 3. MCTS ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// MCTS cost at different simulation budgets with two evaluator types:
|
||||
/// - `zero` — isolates tree-traversal overhead (no network).
|
||||
/// - `mlp64` — real MLP, shows end-to-end cost per move.
|
||||
fn bench_mcts(c: &mut Criterion) {
|
||||
let env = TrictracEnv;
|
||||
|
||||
// Build a decision-node state (after dice roll).
|
||||
let state = {
|
||||
let mut s = env.new_game();
|
||||
let mut rng = seeded();
|
||||
while env.current_player(&s).is_chance() {
|
||||
env.apply_chance(&mut s, &mut rng);
|
||||
}
|
||||
s
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group("mcts");
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
let zero_eval = ZeroEval(514);
|
||||
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
|
||||
let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());
|
||||
|
||||
for &n_sim in &[1usize, 5, 20] {
|
||||
let cfg = MctsConfig {
|
||||
n_simulations: n_sim,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0,
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
// Zero evaluator: tree traversal only.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("zero_eval", n_sim),
|
||||
&n_sim,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
// MLP evaluator: full cost per decision.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp64", n_sim),
|
||||
&n_sim,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 4. Episode generation ─────────────────────────────────────────────────
|
||||
|
||||
/// Full self-play episode latency (one complete game) at different MCTS
|
||||
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
|
||||
fn bench_episode(c: &mut Criterion) {
|
||||
let env = TrictracEnv;
|
||||
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
|
||||
let eval = BurnEvaluator::<InferB, _>::new(model, infer_device());
|
||||
|
||||
let mut group = c.benchmark_group("episode");
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(60));
|
||||
|
||||
for &n_sim in &[1usize, 2] {
|
||||
let mcts_cfg = MctsConfig {
|
||||
n_simulations: n_sim,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0,
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("trictrac", n_sim),
|
||||
&n_sim,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| {
|
||||
black_box(generate_episode(
|
||||
&env,
|
||||
&eval,
|
||||
&mcts_cfg,
|
||||
&|_| 1.0,
|
||||
&mut rng,
|
||||
))
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 5. Training step ───────────────────────────────────────────────────────
|
||||
|
||||
/// Gradient-step latency for different batch sizes.
|
||||
fn bench_train(c: &mut Criterion) {
|
||||
use burn::optim::AdamConfig;
|
||||
|
||||
let mut group = c.benchmark_group("train");
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
|
||||
let dummy_samples = |n: usize| -> Vec<TrainSample> {
|
||||
(0..n)
|
||||
.map(|i| TrainSample {
|
||||
obs: vec![0.5; 217],
|
||||
policy: {
|
||||
let mut p = vec![0.0f32; 514];
|
||||
p[i % 514] = 1.0;
|
||||
p
|
||||
},
|
||||
value: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
for &batch_size in &[16usize, 64] {
|
||||
let batch = dummy_samples(batch_size);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp64_adam", batch_size),
|
||||
&batch_size,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
(
|
||||
MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
|
||||
AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
|
||||
)
|
||||
},
|
||||
|(model, mut opt)| {
|
||||
black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── Criterion entry point ──────────────────────────────────────────────────
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_env,
|
||||
bench_network,
|
||||
bench_mcts,
|
||||
bench_episode,
|
||||
bench_train,
|
||||
);
|
||||
criterion_main!(benches);
|
||||
|
|
@ -1,127 +0,0 @@
|
|||
//! AlphaZero: self-play data generation, replay buffer, and training step.
|
||||
//!
|
||||
//! # Modules
|
||||
//!
|
||||
//! | Module | Contents |
|
||||
//! |--------|----------|
|
||||
//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] |
|
||||
//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] |
|
||||
//! | [`trainer`] | [`train_step`] |
|
||||
//!
|
||||
//! # Typical outer loop
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use burn::backend::{Autodiff, NdArray};
|
||||
//! use burn::optim::AdamConfig;
|
||||
//! use spiel_bot::{
|
||||
//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step},
|
||||
//! env::TrictracEnv,
|
||||
//! mcts::MctsConfig,
|
||||
//! network::{MlpConfig, MlpNet},
|
||||
//! };
|
||||
//!
|
||||
//! type Infer = NdArray<f32>;
|
||||
//! type Train = Autodiff<NdArray<f32>>;
|
||||
//!
|
||||
//! let device = Default::default();
|
||||
//! let env = TrictracEnv;
|
||||
//! let config = AlphaZeroConfig::default();
|
||||
//!
|
||||
//! // Build training model and optimizer.
|
||||
//! let mut train_model = MlpNet::<Train>::new(&MlpConfig::default(), &device);
|
||||
//! let mut optimizer = AdamConfig::new().init();
|
||||
//! let mut replay = ReplayBuffer::new(config.replay_capacity);
|
||||
//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
//!
|
||||
//! for _iter in 0..config.n_iterations {
|
||||
//! // Convert to inference backend for self-play.
|
||||
//! let infer_model = MlpNet::<Infer>::new(&MlpConfig::default(), &device)
|
||||
//! .load_record(train_model.clone().into_record());
|
||||
//! let eval = BurnEvaluator::new(infer_model, device.clone());
|
||||
//!
|
||||
//! // Self-play: generate episodes.
|
||||
//! for _ in 0..config.n_games_per_iter {
|
||||
//! let samples = generate_episode(&env, &eval, &config.mcts,
|
||||
//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
|
||||
//! replay.extend(samples);
|
||||
//! }
|
||||
//!
|
||||
//! // Training: gradient steps.
|
||||
//! if replay.len() >= config.batch_size {
|
||||
//! for _ in 0..config.n_train_steps_per_iter {
|
||||
//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng)
|
||||
//! .into_iter().cloned().collect();
|
||||
//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device,
|
||||
//! config.learning_rate);
|
||||
//! train_model = m;
|
||||
//! }
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod replay;
|
||||
pub mod selfplay;
|
||||
pub mod trainer;
|
||||
|
||||
pub use replay::{ReplayBuffer, TrainSample};
|
||||
pub use selfplay::{BurnEvaluator, generate_episode};
|
||||
pub use trainer::{cosine_lr, train_step};
|
||||
|
||||
use crate::mcts::MctsConfig;
|
||||
|
||||
// ── Configuration ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Top-level AlphaZero hyperparameters.
|
||||
///
|
||||
/// The MCTS parameters live in [`MctsConfig`]; this struct holds the
|
||||
/// outer training-loop parameters.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AlphaZeroConfig {
|
||||
/// MCTS parameters for self-play.
|
||||
pub mcts: MctsConfig,
|
||||
/// Number of self-play games per training iteration.
|
||||
pub n_games_per_iter: usize,
|
||||
/// Number of gradient steps per training iteration.
|
||||
pub n_train_steps_per_iter: usize,
|
||||
/// Mini-batch size for each gradient step.
|
||||
pub batch_size: usize,
|
||||
/// Maximum number of samples in the replay buffer.
|
||||
pub replay_capacity: usize,
|
||||
/// Initial (peak) Adam learning rate.
|
||||
pub learning_rate: f64,
|
||||
/// Minimum learning rate for cosine annealing (floor of the schedule).
|
||||
///
|
||||
/// Pass `learning_rate == lr_min` to disable scheduling (constant LR).
|
||||
/// Compute the current LR with [`cosine_lr`]:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
|
||||
/// ```
|
||||
pub lr_min: f64,
|
||||
/// Number of outer iterations (self-play + train) to run.
|
||||
pub n_iterations: usize,
|
||||
/// Move index after which the action temperature drops to 0 (greedy play).
|
||||
pub temperature_drop_move: usize,
|
||||
}
|
||||
|
||||
impl Default for AlphaZeroConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mcts: MctsConfig {
|
||||
n_simulations: 100,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.1,
|
||||
dirichlet_eps: 0.25,
|
||||
temperature: 1.0,
|
||||
},
|
||||
n_games_per_iter: 10,
|
||||
n_train_steps_per_iter: 20,
|
||||
batch_size: 64,
|
||||
replay_capacity: 50_000,
|
||||
learning_rate: 1e-3,
|
||||
lr_min: 1e-4, // cosine annealing floor
|
||||
n_iterations: 100,
|
||||
temperature_drop_move: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
//! Replay buffer for AlphaZero self-play data.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use rand::Rng;
|
||||
|
||||
// ── Training sample ────────────────────────────────────────────────────────
|
||||
|
||||
/// One training example produced by self-play.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TrainSample {
|
||||
/// Observation tensor from the acting player's perspective (`obs_size` floats).
|
||||
pub obs: Vec<f32>,
|
||||
/// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
|
||||
pub policy: Vec<f32>,
|
||||
/// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
|
||||
pub value: f32,
|
||||
}
|
||||
|
||||
// ── Replay buffer ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Fixed-capacity circular buffer of [`TrainSample`]s.
|
||||
///
|
||||
/// When the buffer is full, the oldest sample is evicted on push.
|
||||
/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
|
||||
pub struct ReplayBuffer {
|
||||
data: VecDeque<TrainSample>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
/// Create a buffer with the given maximum capacity.
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
data: VecDeque::with_capacity(capacity.min(1024)),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a sample; evicts the oldest if at capacity.
|
||||
pub fn push(&mut self, sample: TrainSample) {
|
||||
if self.data.len() == self.capacity {
|
||||
self.data.pop_front();
|
||||
}
|
||||
self.data.push_back(sample);
|
||||
}
|
||||
|
||||
/// Add all samples from an episode.
|
||||
pub fn extend(&mut self, samples: impl IntoIterator<Item = TrainSample>) {
|
||||
for s in samples {
|
||||
self.push(s);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
/// Sample up to `n` distinct samples, without replacement.
|
||||
///
|
||||
/// If the buffer has fewer than `n` samples, all are returned (shuffled).
|
||||
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
|
||||
let len = self.data.len();
|
||||
let n = n.min(len);
|
||||
// Partial Fisher-Yates using index shuffling.
|
||||
let mut indices: Vec<usize> = (0..len).collect();
|
||||
for i in 0..n {
|
||||
let j = rng.random_range(i..len);
|
||||
indices.swap(i, j);
|
||||
}
|
||||
indices[..n].iter().map(|&i| &self.data[i]).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
fn dummy(value: f32) -> TrainSample {
|
||||
TrainSample { obs: vec![value], policy: vec![1.0], value }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_and_len() {
|
||||
let mut buf = ReplayBuffer::new(10);
|
||||
assert!(buf.is_empty());
|
||||
buf.push(dummy(1.0));
|
||||
buf.push(dummy(2.0));
|
||||
assert_eq!(buf.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evicts_oldest_at_capacity() {
|
||||
let mut buf = ReplayBuffer::new(3);
|
||||
buf.push(dummy(1.0));
|
||||
buf.push(dummy(2.0));
|
||||
buf.push(dummy(3.0));
|
||||
buf.push(dummy(4.0)); // evicts 1.0
|
||||
assert_eq!(buf.len(), 3);
|
||||
// Oldest remaining should be 2.0
|
||||
assert_eq!(buf.data[0].value, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sample_batch_size() {
|
||||
let mut buf = ReplayBuffer::new(20);
|
||||
for i in 0..10 {
|
||||
buf.push(dummy(i as f32));
|
||||
}
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let batch = buf.sample_batch(5, &mut rng);
|
||||
assert_eq!(batch.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sample_batch_capped_at_len() {
|
||||
let mut buf = ReplayBuffer::new(20);
|
||||
buf.push(dummy(1.0));
|
||||
buf.push(dummy(2.0));
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let batch = buf.sample_batch(100, &mut rng);
|
||||
assert_eq!(batch.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sample_batch_no_duplicates() {
|
||||
let mut buf = ReplayBuffer::new(20);
|
||||
for i in 0..10 {
|
||||
buf.push(dummy(i as f32));
|
||||
}
|
||||
let mut rng = SmallRng::seed_from_u64(1);
|
||||
let batch = buf.sample_batch(10, &mut rng);
|
||||
let mut seen: Vec<f32> = batch.iter().map(|s| s.value).collect();
|
||||
seen.sort_by(f32::total_cmp);
|
||||
seen.dedup();
|
||||
assert_eq!(seen.len(), 10, "sample contained duplicates");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,238 +0,0 @@
|
|||
//! Self-play episode generation and Burn-backed evaluator.
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||
use rand::Rng;
|
||||
|
||||
use crate::env::GameEnv;
|
||||
use crate::mcts::{self, Evaluator, MctsConfig, MctsNode};
|
||||
use crate::network::PolicyValueNet;
|
||||
|
||||
use super::replay::TrainSample;
|
||||
|
||||
// ── BurnEvaluator ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS.
|
||||
///
|
||||
/// Use the **inference backend** (`NdArray<f32>`, no `Autodiff` wrapper) so
|
||||
/// that self-play generates no gradient tape overhead.
|
||||
pub struct BurnEvaluator<B: Backend, N: PolicyValueNet<B>> {
|
||||
model: N,
|
||||
device: B::Device,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
|
||||
pub fn new(model: N, device: B::Device) -> Self {
|
||||
Self { model, device, _b: PhantomData }
|
||||
}
|
||||
|
||||
pub fn into_model(self) -> N {
|
||||
self.model
|
||||
}
|
||||
|
||||
pub fn model_ref(&self) -> &N {
|
||||
&self.model
|
||||
}
|
||||
}
|
||||
|
||||
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
||||
// external synchronisation.
|
||||
unsafe impl<B: Backend, N: PolicyValueNet<B>> Send for BurnEvaluator<B, N> {}
|
||||
unsafe impl<B: Backend, N: PolicyValueNet<B>> Sync for BurnEvaluator<B, N> {}
|
||||
|
||||
impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
|
||||
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
|
||||
let obs_size = obs.len();
|
||||
let data = TensorData::new(obs.to_vec(), [1, obs_size]);
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(data, &self.device);
|
||||
|
||||
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
|
||||
|
||||
let policy: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
|
||||
let value: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
|
||||
|
||||
(policy, value[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ── Episode generation ─────────────────────────────────────────────────────
|
||||
|
||||
/// One pending observation waiting for its game-outcome value label.
|
||||
struct PendingSample {
|
||||
obs: Vec<f32>,
|
||||
policy: Vec<f32>,
|
||||
player: usize,
|
||||
}
|
||||
|
||||
/// Play one full game using MCTS guided by `evaluator`.
|
||||
///
|
||||
/// Returns a [`TrainSample`] for every decision step in the game.
|
||||
///
|
||||
/// `temperature_fn(step)` controls exploration: return `1.0` for early
|
||||
/// moves and `0.0` after a fixed number of moves (e.g. move 30).
|
||||
pub fn generate_episode<E: GameEnv>(
|
||||
env: &E,
|
||||
evaluator: &dyn Evaluator,
|
||||
mcts_config: &MctsConfig,
|
||||
temperature_fn: &dyn Fn(usize) -> f32,
|
||||
rng: &mut impl Rng,
|
||||
) -> Vec<TrainSample> {
|
||||
let mut state = env.new_game();
|
||||
let mut pending: Vec<PendingSample> = Vec::new();
|
||||
let mut step = 0usize;
|
||||
|
||||
loop {
|
||||
// Advance through chance nodes automatically.
|
||||
while env.current_player(&state).is_chance() {
|
||||
env.apply_chance(&mut state, rng);
|
||||
}
|
||||
|
||||
if env.current_player(&state).is_terminal() {
|
||||
break;
|
||||
}
|
||||
|
||||
let player_idx = env.current_player(&state).index().unwrap();
|
||||
|
||||
// Run MCTS to get a policy.
|
||||
let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
|
||||
let policy = mcts::mcts_policy(&root, env.action_space());
|
||||
|
||||
// Record the observation from the acting player's perspective.
|
||||
let obs = env.observation(&state, player_idx);
|
||||
pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx });
|
||||
|
||||
// Select and apply the action.
|
||||
let temperature = temperature_fn(step);
|
||||
let action = mcts::select_action(&root, temperature, rng);
|
||||
env.apply(&mut state, action);
|
||||
step += 1;
|
||||
}
|
||||
|
||||
// Assign game outcomes.
|
||||
let returns = env.returns(&state).unwrap_or([0.0; 2]);
|
||||
pending
|
||||
.into_iter()
|
||||
.map(|s| TrainSample {
|
||||
obs: s.obs,
|
||||
policy: s.policy,
|
||||
value: returns[s.player],
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::backend::NdArray;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
use crate::env::Player;
|
||||
use crate::mcts::{Evaluator, MctsConfig};
|
||||
use crate::network::{MlpConfig, MlpNet};
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
fn device() -> <B as Backend>::Device {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn rng() -> SmallRng {
|
||||
SmallRng::seed_from_u64(7)
|
||||
}
|
||||
|
||||
// Countdown game (same as in mcts tests).
|
||||
#[derive(Clone, Debug)]
|
||||
struct CState { remaining: u8, to_move: usize }
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CountdownEnv;
|
||||
|
||||
impl GameEnv for CountdownEnv {
|
||||
type State = CState;
|
||||
fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } }
|
||||
fn current_player(&self, s: &CState) -> Player {
|
||||
if s.remaining == 0 { Player::Terminal }
|
||||
else if s.to_move == 0 { Player::P1 } else { Player::P2 }
|
||||
}
|
||||
fn legal_actions(&self, s: &CState) -> Vec<usize> {
|
||||
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
|
||||
}
|
||||
fn apply(&self, s: &mut CState, action: usize) {
|
||||
let sub = (action as u8) + 1;
|
||||
if s.remaining <= sub { s.remaining = 0; }
|
||||
else { s.remaining -= sub; s.to_move = 1 - s.to_move; }
|
||||
}
|
||||
fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}
|
||||
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
|
||||
vec![s.remaining as f32 / 4.0, s.to_move as f32]
|
||||
}
|
||||
fn obs_size(&self) -> usize { 2 }
|
||||
fn action_space(&self) -> usize { 2 }
|
||||
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
|
||||
if s.remaining != 0 { return None; }
|
||||
let mut r = [-1.0f32; 2];
|
||||
r[s.to_move] = 1.0;
|
||||
Some(r)
|
||||
}
|
||||
}
|
||||
|
||||
fn tiny_config() -> MctsConfig {
|
||||
MctsConfig { n_simulations: 5, c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 }
|
||||
}
|
||||
|
||||
// ── BurnEvaluator tests ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn burn_evaluator_output_shapes() {
|
||||
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||
let model = MlpNet::<B>::new(&config, &device());
|
||||
let eval = BurnEvaluator::new(model, device());
|
||||
let (policy, value) = eval.evaluate(&[0.5f32, 0.5]);
|
||||
assert_eq!(policy.len(), 2, "policy length should equal action_space");
|
||||
assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)");
|
||||
}
|
||||
|
||||
// ── generate_episode tests ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn episode_terminates_and_has_samples() {
|
||||
let env = CountdownEnv;
|
||||
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||
let model = MlpNet::<B>::new(&config, &device());
|
||||
let eval = BurnEvaluator::new(model, device());
|
||||
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
|
||||
assert!(!samples.is_empty(), "episode must produce at least one sample");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn episode_sample_values_are_valid() {
|
||||
let env = CountdownEnv;
|
||||
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||
let model = MlpNet::<B>::new(&config, &device());
|
||||
let eval = BurnEvaluator::new(model, device());
|
||||
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
|
||||
for s in &samples {
|
||||
assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
|
||||
"unexpected value {}", s.value);
|
||||
let sum: f32 = s.policy.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}");
|
||||
assert_eq!(s.obs.len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn episode_with_temperature_zero() {
|
||||
let env = CountdownEnv;
|
||||
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||
let model = MlpNet::<B>::new(&config, &device());
|
||||
let eval = BurnEvaluator::new(model, device());
|
||||
// temperature=0 means greedy; episode must still terminate
|
||||
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng());
|
||||
assert!(!samples.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,258 +0,0 @@
|
|||
//! One gradient-descent training step for AlphaZero.
|
||||
//!
|
||||
//! The loss combines:
|
||||
//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
|
||||
//! - **Value loss** — mean-squared error between the predicted value and the
|
||||
//! actual game outcome.
|
||||
//!
|
||||
//! # Learning-rate scheduling
|
||||
//!
|
||||
//! [`cosine_lr`] implements one-cycle cosine annealing:
|
||||
//!
|
||||
//! ```text
|
||||
//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T))
|
||||
//! ```
|
||||
//!
|
||||
//! Typical usage in the outer loop:
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! for step in 0..total_train_steps {
|
||||
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
|
||||
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
|
||||
//! model = m;
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! # Backend
|
||||
//!
|
||||
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
||||
//! Self-play uses the inner backend (`NdArray<f32>`) for zero autodiff overhead.
|
||||
//! Weights are transferred between the two via [`burn::record`].
|
||||
|
||||
use burn::{
|
||||
module::AutodiffModule,
|
||||
optim::{GradientsParams, Optimizer},
|
||||
prelude::ElementConversion,
|
||||
tensor::{
|
||||
activation::log_softmax,
|
||||
backend::AutodiffBackend,
|
||||
Tensor, TensorData,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::network::PolicyValueNet;
|
||||
use super::replay::TrainSample;
|
||||
|
||||
/// Run one gradient step on `model` using `batch`.
|
||||
///
|
||||
/// Returns the updated model and the scalar loss value for logging.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `lr` — learning rate (e.g. `1e-3`).
|
||||
/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
|
||||
pub fn train_step<B, N, O>(
|
||||
model: N,
|
||||
optimizer: &mut O,
|
||||
batch: &[TrainSample],
|
||||
device: &B::Device,
|
||||
lr: f64,
|
||||
) -> (N, f32)
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
N: PolicyValueNet<B> + AutodiffModule<B>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
assert!(!batch.is_empty(), "train_step called with empty batch");
|
||||
|
||||
let batch_size = batch.len();
|
||||
let obs_size = batch[0].obs.len();
|
||||
let action_size = batch[0].policy.len();
|
||||
|
||||
// ── Build input tensors ────────────────────────────────────────────────
|
||||
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
|
||||
let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
|
||||
let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
|
||||
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(obs_flat, [batch_size, obs_size]),
|
||||
device,
|
||||
);
|
||||
let policy_target = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(policy_flat, [batch_size, action_size]),
|
||||
device,
|
||||
);
|
||||
let value_target = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(value_flat, [batch_size, 1]),
|
||||
device,
|
||||
);
|
||||
|
||||
// ── Forward pass ──────────────────────────────────────────────────────
|
||||
let (policy_logits, value_pred) = model.forward(obs_tensor);
|
||||
|
||||
// ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
|
||||
let log_probs = log_softmax(policy_logits, 1);
|
||||
let policy_loss = (policy_target.clone().neg() * log_probs)
|
||||
.sum_dim(1)
|
||||
.mean();
|
||||
|
||||
// ── Value loss: MSE(value_pred, z) ────────────────────────────────────
|
||||
let diff = value_pred - value_target;
|
||||
let value_loss = (diff.clone() * diff).mean();
|
||||
|
||||
// ── Combined loss ─────────────────────────────────────────────────────
|
||||
let loss = policy_loss + value_loss;
|
||||
|
||||
// Extract scalar before backward (consumes the tensor).
|
||||
let loss_scalar: f32 = loss.clone().into_scalar().elem();
|
||||
|
||||
// ── Backward + optimizer step ─────────────────────────────────────────
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
let model = optimizer.step(lr, model, grads);
|
||||
|
||||
(model, loss_scalar)
|
||||
}
|
||||
|
||||
// ── Learning-rate schedule ─────────────────────────────────────────────────
|
||||
|
||||
/// Cosine learning-rate schedule (one half-period, no warmup).
|
||||
///
|
||||
/// Returns the learning rate for training step `step` out of `total_steps`:
|
||||
///
|
||||
/// ```text
|
||||
/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total))
|
||||
/// ```
|
||||
///
|
||||
/// - At `t = 0` returns `initial`.
|
||||
/// - At `t = total_steps` (or beyond) returns `lr_min`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Does not panic. When `total_steps == 0`, returns `lr_min`.
|
||||
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
|
||||
if total_steps == 0 || step >= total_steps {
|
||||
return lr_min;
|
||||
}
|
||||
let progress = step as f64 / total_steps as f64;
|
||||
lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::{
|
||||
backend::{Autodiff, NdArray},
|
||||
optim::AdamConfig,
|
||||
};
|
||||
|
||||
use crate::network::{MlpConfig, MlpNet};
|
||||
use super::super::replay::TrainSample;
|
||||
|
||||
type B = Autodiff<NdArray<f32>>;
|
||||
|
||||
fn device() -> <B as burn::tensor::backend::Backend>::Device {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
|
||||
(0..n)
|
||||
.map(|i| TrainSample {
|
||||
obs: vec![0.5f32; obs_size],
|
||||
policy: {
|
||||
let mut p = vec![0.0f32; action_size];
|
||||
p[i % action_size] = 1.0;
|
||||
p
|
||||
},
|
||||
value: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn train_step_returns_finite_loss() {
|
||||
let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
|
||||
let model = MlpNet::<B>::new(&config, &device());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let batch = dummy_batch(8, 4, 4);
|
||||
|
||||
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
||||
assert!(loss.is_finite(), "loss must be finite, got {loss}");
|
||||
assert!(loss > 0.0, "loss should be positive");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_decreases_over_steps() {
|
||||
let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
|
||||
let mut model = MlpNet::<B>::new(&config, &device());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
// Same batch every step — loss should decrease.
|
||||
let batch = dummy_batch(16, 4, 4);
|
||||
|
||||
let mut prev_loss = f32::INFINITY;
|
||||
for _ in 0..10 {
|
||||
let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
|
||||
model = m;
|
||||
assert!(loss.is_finite());
|
||||
prev_loss = loss;
|
||||
}
|
||||
// After 10 steps on fixed data, loss should be below a reasonable threshold.
|
||||
assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn train_step_batch_size_one() {
|
||||
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||
let model = MlpNet::<B>::new(&config, &device());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let batch = dummy_batch(1, 2, 2);
|
||||
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
||||
assert!(loss.is_finite());
|
||||
}
|
||||
|
||||
// ── cosine_lr ─────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_at_step_zero_is_initial() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
|
||||
assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_at_end_is_min() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
|
||||
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_beyond_end_is_min() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
|
||||
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_midpoint_is_average() {
|
||||
// At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
|
||||
let expected = (1e-3 + 1e-5) / 2.0;
|
||||
assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_monotone_decreasing() {
|
||||
let mut prev = f64::INFINITY;
|
||||
for step in 0..=100 {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
|
||||
assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
|
||||
prev = lr;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_zero_total_steps_returns_min() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
|
||||
assert!((lr - 1e-5).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,262 +0,0 @@
|
|||
//! Evaluate a trained AlphaZero checkpoint against a random player.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```sh
|
||||
//! # Random weights (sanity check — should be ~50 %)
|
||||
//! cargo run -p spiel_bot --bin az_eval --release
|
||||
//!
|
||||
//! # Trained MLP checkpoint
|
||||
//! cargo run -p spiel_bot --bin az_eval --release -- \
|
||||
//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50
|
||||
//!
|
||||
//! # Trained ResNet checkpoint
|
||||
//! cargo run -p spiel_bot --bin az_eval --release -- \
|
||||
//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100
|
||||
//! ```
|
||||
//!
|
||||
//! # Options
|
||||
//!
|
||||
//! | Flag | Default | Description |
|
||||
//! |------|---------|-------------|
|
||||
//! | `--checkpoint <path>` | (none) | Load weights from `.mpk` file; random weights if omitted |
|
||||
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
|
||||
//! | `--hidden <N>` | 256 (mlp) / 512 (resnet) | Hidden size |
|
||||
//! | `--n-games <N>` | `100` | Games per side (total = 2 × N) |
|
||||
//! | `--n-sim <N>` | `50` | MCTS simulations per move |
|
||||
//! | `--seed <N>` | `42` | RNG seed |
|
||||
//! | `--c-puct <F>` | `1.5` | PUCT exploration constant |
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use burn::backend::NdArray;
|
||||
use rand::{SeedableRng, rngs::SmallRng, Rng};
|
||||
|
||||
use spiel_bot::{
|
||||
alphazero::BurnEvaluator,
|
||||
env::{GameEnv, Player, TrictracEnv},
|
||||
mcts::{Evaluator, MctsConfig, run_mcts, select_action},
|
||||
network::{MlpConfig, MlpNet, ResNet, ResNetConfig},
|
||||
};
|
||||
|
||||
type InferB = NdArray<f32>;
|
||||
|
||||
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
struct Args {
|
||||
checkpoint: Option<PathBuf>,
|
||||
arch: String,
|
||||
hidden: Option<usize>,
|
||||
n_games: usize,
|
||||
n_sim: usize,
|
||||
seed: u64,
|
||||
c_puct: f32,
|
||||
}
|
||||
|
||||
impl Default for Args {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
checkpoint: None,
|
||||
arch: "mlp".into(),
|
||||
hidden: None,
|
||||
n_games: 100,
|
||||
n_sim: 50,
|
||||
seed: 42,
|
||||
c_puct: 1.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let raw: Vec<String> = std::env::args().collect();
|
||||
let mut args = Args::default();
|
||||
let mut i = 1;
|
||||
while i < raw.len() {
|
||||
match raw[i].as_str() {
|
||||
"--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); }
|
||||
"--arch" => { i += 1; args.arch = raw[i].clone(); }
|
||||
"--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); }
|
||||
"--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); }
|
||||
"--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); }
|
||||
"--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); }
|
||||
"--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); }
|
||||
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
args
|
||||
}
|
||||
|
||||
// ── Game loop ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Play one complete game.
|
||||
///
|
||||
/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black).
|
||||
/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0).
|
||||
fn play_game(
|
||||
env: &TrictracEnv,
|
||||
mcts_side: usize,
|
||||
evaluator: &dyn Evaluator,
|
||||
mcts_cfg: &MctsConfig,
|
||||
rng: &mut SmallRng,
|
||||
) -> [f32; 2] {
|
||||
let mut state = env.new_game();
|
||||
loop {
|
||||
match env.current_player(&state) {
|
||||
Player::Terminal => {
|
||||
return env.returns(&state).expect("Terminal state must have returns");
|
||||
}
|
||||
Player::Chance => env.apply_chance(&mut state, rng),
|
||||
player => {
|
||||
let side = player.index().unwrap(); // 0 = P1, 1 = P2
|
||||
let action = if side == mcts_side {
|
||||
let root = run_mcts(env, &state, evaluator, mcts_cfg, rng);
|
||||
select_action(&root, 0.0, rng) // greedy (temperature = 0)
|
||||
} else {
|
||||
let actions = env.legal_actions(&state);
|
||||
actions[rng.random_range(0..actions.len())]
|
||||
};
|
||||
env.apply(&mut state, action);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Statistics ────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Default)]
|
||||
struct Stats {
|
||||
wins: u32,
|
||||
draws: u32,
|
||||
losses: u32,
|
||||
}
|
||||
|
||||
impl Stats {
|
||||
fn record(&mut self, mcts_return: f32) {
|
||||
if mcts_return > 0.0 { self.wins += 1; }
|
||||
else if mcts_return < 0.0 { self.losses += 1; }
|
||||
else { self.draws += 1; }
|
||||
}
|
||||
|
||||
fn total(&self) -> u32 { self.wins + self.draws + self.losses }
|
||||
|
||||
fn win_rate_decisive(&self) -> f64 {
|
||||
let d = self.wins + self.losses;
|
||||
if d == 0 { 0.5 } else { self.wins as f64 / d as f64 }
|
||||
}
|
||||
|
||||
fn print(&self) {
|
||||
let n = self.total();
|
||||
let pct = |k: u32| 100.0 * k as f64 / n as f64;
|
||||
println!(
|
||||
" Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)",
|
||||
self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Evaluation ────────────────────────────────────────────────────────────────
|
||||
|
||||
fn run_evaluation(
|
||||
evaluator: &dyn Evaluator,
|
||||
n_games: usize,
|
||||
mcts_cfg: &MctsConfig,
|
||||
seed: u64,
|
||||
) -> (Stats, Stats) {
|
||||
let env = TrictracEnv;
|
||||
let total = n_games * 2;
|
||||
let mut as_p1 = Stats::default();
|
||||
let mut as_p2 = Stats::default();
|
||||
|
||||
for i in 0..total {
|
||||
// Alternate sides: even games → MctsAgent as P1, odd → as P2.
|
||||
let mcts_side = i % 2;
|
||||
let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64));
|
||||
let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng);
|
||||
|
||||
let mcts_return = result[mcts_side];
|
||||
if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); }
|
||||
|
||||
let done = i + 1;
|
||||
if done % 10 == 0 || done == total {
|
||||
eprint!("\r [{done}/{total}] ", );
|
||||
}
|
||||
}
|
||||
eprintln!();
|
||||
(as_p1, as_p2)
|
||||
}
|
||||
|
||||
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
let device: <InferB as burn::tensor::backend::Backend>::Device = Default::default();
|
||||
|
||||
// ── Load model ────────────────────────────────────────────────────────
|
||||
let evaluator: Box<dyn Evaluator> = match args.arch.as_str() {
|
||||
"resnet" => {
|
||||
let hidden = args.hidden.unwrap_or(512);
|
||||
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = match &args.checkpoint {
|
||||
Some(path) => ResNet::<InferB>::load(&cfg, path, &device)
|
||||
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
|
||||
None => ResNet::new(&cfg, &device),
|
||||
};
|
||||
Box::new(BurnEvaluator::<InferB, ResNet<InferB>>::new(model, device))
|
||||
}
|
||||
"mlp" | _ => {
|
||||
let hidden = args.hidden.unwrap_or(256);
|
||||
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = match &args.checkpoint {
|
||||
Some(path) => MlpNet::<InferB>::load(&cfg, path, &device)
|
||||
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
|
||||
None => MlpNet::new(&cfg, &device),
|
||||
};
|
||||
Box::new(BurnEvaluator::<InferB, MlpNet<InferB>>::new(model, device))
|
||||
}
|
||||
};
|
||||
|
||||
let mcts_cfg = MctsConfig {
|
||||
n_simulations: args.n_sim,
|
||||
c_puct: args.c_puct,
|
||||
dirichlet_alpha: 0.0, // no exploration noise during evaluation
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 0.0, // greedy action selection
|
||||
};
|
||||
|
||||
// ── Header ────────────────────────────────────────────────────────────
|
||||
let ckpt_label = args.checkpoint
|
||||
.as_deref()
|
||||
.and_then(|p| p.file_name())
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("random weights");
|
||||
|
||||
println!();
|
||||
println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent",
|
||||
args.arch, args.n_sim);
|
||||
println!("Games per side: {} | Total: {} | Seed: {}",
|
||||
args.n_games, args.n_games * 2, args.seed);
|
||||
println!();
|
||||
|
||||
// ── Run ───────────────────────────────────────────────────────────────
|
||||
let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed);
|
||||
|
||||
// ── Results ───────────────────────────────────────────────────────────
|
||||
println!("MctsAgent as P1 (White):");
|
||||
as_p1.print();
|
||||
|
||||
println!("MctsAgent as P2 (Black):");
|
||||
as_p2.print();
|
||||
|
||||
let combined_wins = as_p1.wins + as_p2.wins;
|
||||
let combined_decisive = combined_wins + as_p1.losses + as_p2.losses;
|
||||
let combined_wr = if combined_decisive == 0 { 0.5 }
|
||||
else { combined_wins as f64 / combined_decisive as f64 };
|
||||
|
||||
println!();
|
||||
println!("Combined win rate (excluding draws): {:.1}% [{}/{}]",
|
||||
combined_wr * 100.0, combined_wins, combined_decisive);
|
||||
println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%",
|
||||
as_p1.win_rate_decisive() * 100.0,
|
||||
as_p2.win_rate_decisive() * 100.0);
|
||||
}
|
||||
|
|
@ -1,331 +0,0 @@
|
|||
//! AlphaZero self-play training loop.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```sh
|
||||
//! # Start fresh (MLP, default settings)
|
||||
//! cargo run -p spiel_bot --bin az_train --release
|
||||
//!
|
||||
//! # ResNet, 200 iterations, save every 20
|
||||
//! cargo run -p spiel_bot --bin az_train --release -- \
|
||||
//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/
|
||||
//!
|
||||
//! # Resume from a checkpoint
|
||||
//! cargo run -p spiel_bot --bin az_train --release -- \
|
||||
//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100
|
||||
//! ```
|
||||
//!
|
||||
//! # Options
|
||||
//!
|
||||
//! | Flag | Default | Description |
|
||||
//! |------|---------|-------------|
|
||||
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
|
||||
//! | `--hidden N` | 256/512 | Hidden layer width |
|
||||
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
|
||||
//! | `--n-iter N` | `100` | Training iterations |
|
||||
//! | `--n-games N` | `10` | Self-play games per iteration |
|
||||
//! | `--n-train N` | `20` | Gradient steps per iteration |
|
||||
//! | `--n-sim N` | `100` | MCTS simulations per move |
|
||||
//! | `--batch N` | `64` | Mini-batch size |
|
||||
//! | `--replay-cap N` | `50000` | Replay buffer capacity |
|
||||
//! | `--lr F` | `1e-3` | Peak (initial) learning rate |
|
||||
//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) |
|
||||
//! | `--c-puct F` | `1.5` | PUCT exploration constant |
|
||||
//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha |
|
||||
//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight |
|
||||
//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 |
|
||||
//! | `--save-every N` | `10` | Save checkpoint every N iterations |
|
||||
//! | `--seed N` | `42` | RNG seed |
|
||||
//! | `--resume PATH` | (none) | Load weights from checkpoint before training |
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use burn::{
|
||||
backend::{Autodiff, NdArray},
|
||||
module::AutodiffModule,
|
||||
optim::AdamConfig,
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use spiel_bot::{
|
||||
alphazero::{
|
||||
BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step,
|
||||
},
|
||||
env::TrictracEnv,
|
||||
mcts::MctsConfig,
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
|
||||
};
|
||||
|
||||
type TrainB = Autodiff<NdArray<f32>>;
|
||||
type InferB = NdArray<f32>;
|
||||
|
||||
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
struct Args {
|
||||
arch: String,
|
||||
hidden: Option<usize>,
|
||||
out_dir: PathBuf,
|
||||
n_iter: usize,
|
||||
n_games: usize,
|
||||
n_train: usize,
|
||||
n_sim: usize,
|
||||
batch_size: usize,
|
||||
replay_cap: usize,
|
||||
lr: f64,
|
||||
lr_min: f64,
|
||||
c_puct: f32,
|
||||
dirichlet_alpha: f32,
|
||||
dirichlet_eps: f32,
|
||||
temp_drop: usize,
|
||||
save_every: usize,
|
||||
seed: u64,
|
||||
resume: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Default for Args {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
arch: "mlp".into(),
|
||||
hidden: None,
|
||||
out_dir: PathBuf::from("checkpoints"),
|
||||
n_iter: 100,
|
||||
n_games: 10,
|
||||
n_train: 20,
|
||||
n_sim: 100,
|
||||
batch_size: 64,
|
||||
replay_cap: 50_000,
|
||||
lr: 1e-3,
|
||||
lr_min: 1e-4,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.1,
|
||||
dirichlet_eps: 0.25,
|
||||
temp_drop: 30,
|
||||
save_every: 10,
|
||||
seed: 42,
|
||||
resume: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let raw: Vec<String> = std::env::args().collect();
|
||||
let mut a = Args::default();
|
||||
let mut i = 1;
|
||||
while i < raw.len() {
|
||||
match raw[i].as_str() {
|
||||
"--arch" => { i += 1; a.arch = raw[i].clone(); }
|
||||
"--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); }
|
||||
"--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); }
|
||||
"--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); }
|
||||
"--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); }
|
||||
"--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); }
|
||||
"--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); }
|
||||
"--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); }
|
||||
"--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); }
|
||||
"--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); }
|
||||
"--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); }
|
||||
"--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); }
|
||||
"--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); }
|
||||
"--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); }
|
||||
"--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); }
|
||||
"--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); }
|
||||
"--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); }
|
||||
"--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); }
|
||||
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
a
|
||||
}
|
||||
|
||||
// ── Training loop ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Generic training loop, parameterised over the network type.
|
||||
///
|
||||
/// `save_fn` receives the **training-backend** model and the target path;
|
||||
/// it is called in the match arm where the concrete network type is known.
|
||||
fn train_loop<N>(
|
||||
mut model: N,
|
||||
save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>,
|
||||
args: &Args,
|
||||
)
|
||||
where
|
||||
N: PolicyValueNet<TrainB> + AutodiffModule<TrainB> + Clone,
|
||||
<N as AutodiffModule<TrainB>>::InnerModule: PolicyValueNet<InferB> + Send + 'static,
|
||||
{
|
||||
let train_device: <TrainB as Backend>::Device = Default::default();
|
||||
let infer_device: <InferB as Backend>::Device = Default::default();
|
||||
|
||||
// Type is inferred as OptimizerAdaptor<Adam, N, TrainB> at the call site.
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let mut replay = ReplayBuffer::new(args.replay_cap);
|
||||
let mut rng = SmallRng::seed_from_u64(args.seed);
|
||||
let env = TrictracEnv;
|
||||
|
||||
// Total gradient steps (used for cosine LR denominator).
|
||||
let total_train_steps = (args.n_iter * args.n_train).max(1);
|
||||
let mut global_step = 0usize;
|
||||
|
||||
println!(
|
||||
"\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}",
|
||||
"", args.arch, args.n_iter, args.n_games, args.n_sim, ""
|
||||
);
|
||||
|
||||
for iter in 0..args.n_iter {
|
||||
let t0 = Instant::now();
|
||||
|
||||
// ── Self-play ────────────────────────────────────────────────────
|
||||
// Convert to inference backend (zero autodiff overhead).
|
||||
let infer_model: <N as AutodiffModule<TrainB>>::InnerModule = model.valid();
|
||||
let evaluator: BurnEvaluator<InferB, <N as AutodiffModule<TrainB>>::InnerModule> =
|
||||
BurnEvaluator::new(infer_model, infer_device.clone());
|
||||
|
||||
let mcts_cfg = MctsConfig {
|
||||
n_simulations: args.n_sim,
|
||||
c_puct: args.c_puct,
|
||||
dirichlet_alpha: args.dirichlet_alpha,
|
||||
dirichlet_eps: args.dirichlet_eps,
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
let temp_drop = args.temp_drop;
|
||||
let temperature_fn = |step: usize| -> f32 {
|
||||
if step < temp_drop { 1.0 } else { 0.0 }
|
||||
};
|
||||
|
||||
// Prepare per-game seeds and evaluators sequentially so the main RNG
|
||||
// and model cloning stay deterministic regardless of thread scheduling.
|
||||
// Burn modules are Send but not Sync, so each task must own its model.
|
||||
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
|
||||
let game_evals: Vec<_> = (0..args.n_games)
|
||||
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
|
||||
.collect();
|
||||
drop(evaluator);
|
||||
|
||||
let all_samples: Vec<Vec<TrainSample>> = game_seeds
|
||||
.into_par_iter()
|
||||
.zip(game_evals.into_par_iter())
|
||||
.map(|(seed, game_eval)| {
|
||||
let mut game_rng = SmallRng::seed_from_u64(seed);
|
||||
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut new_samples = 0usize;
|
||||
for samples in all_samples {
|
||||
new_samples += samples.len();
|
||||
replay.extend(samples);
|
||||
}
|
||||
|
||||
// ── Training ─────────────────────────────────────────────────────
|
||||
let mut loss_sum = 0.0f32;
|
||||
let mut n_steps = 0usize;
|
||||
|
||||
if replay.len() >= args.batch_size {
|
||||
for _ in 0..args.n_train {
|
||||
let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps);
|
||||
let batch: Vec<TrainSample> = replay
|
||||
.sample_batch(args.batch_size, &mut rng)
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let (m, loss) =
|
||||
train_step(model, &mut optimizer, &batch, &train_device, lr);
|
||||
model = m;
|
||||
loss_sum += loss;
|
||||
n_steps += 1;
|
||||
global_step += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Logging ──────────────────────────────────────────────────────
|
||||
let elapsed = t0.elapsed();
|
||||
let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };
|
||||
let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps);
|
||||
|
||||
println!(
|
||||
"iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s",
|
||||
iter + 1,
|
||||
args.n_iter,
|
||||
replay.len(),
|
||||
new_samples,
|
||||
avg_loss,
|
||||
lr_now,
|
||||
elapsed.as_secs_f32(),
|
||||
);
|
||||
|
||||
// ── Checkpoint ───────────────────────────────────────────────────
|
||||
let is_last = iter + 1 == args.n_iter;
|
||||
if (iter + 1) % args.save_every == 0 || is_last {
|
||||
let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1));
|
||||
match save_fn(&model, &path) {
|
||||
Ok(()) => println!(" -> saved {}", path.display()),
|
||||
Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nTraining complete.");
|
||||
}
|
||||
|
||||
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
|
||||
// Create output directory if it doesn't exist.
|
||||
if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
|
||||
eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let train_device: <TrainB as Backend>::Device = Default::default();
|
||||
|
||||
match args.arch.as_str() {
|
||||
"resnet" => {
|
||||
let hidden = args.hidden.unwrap_or(512);
|
||||
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
|
||||
let model = match &args.resume {
|
||||
Some(path) => {
|
||||
println!("Resuming from {}", path.display());
|
||||
ResNet::<TrainB>::load(&cfg, path, &train_device)
|
||||
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
|
||||
}
|
||||
None => ResNet::<TrainB>::new(&cfg, &train_device),
|
||||
};
|
||||
|
||||
train_loop(
|
||||
model,
|
||||
&|m: &ResNet<TrainB>, path: &Path| {
|
||||
// Save via inference model to avoid autodiff record overhead.
|
||||
m.valid().save(path)
|
||||
},
|
||||
&args,
|
||||
);
|
||||
}
|
||||
|
||||
"mlp" | _ => {
|
||||
let hidden = args.hidden.unwrap_or(256);
|
||||
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
|
||||
let model = match &args.resume {
|
||||
Some(path) => {
|
||||
println!("Resuming from {}", path.display());
|
||||
MlpNet::<TrainB>::load(&cfg, path, &train_device)
|
||||
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
|
||||
}
|
||||
None => MlpNet::<TrainB>::new(&cfg, &train_device),
|
||||
};
|
||||
|
||||
train_loop(
|
||||
model,
|
||||
&|m: &MlpNet<TrainB>, path: &Path| m.valid().save(path),
|
||||
&args,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,251 +0,0 @@
|
|||
//! DQN self-play training loop.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```sh
|
||||
//! # Start fresh with default settings
|
||||
//! cargo run -p spiel_bot --bin dqn_train --release
|
||||
//!
|
||||
//! # Custom hyperparameters
|
||||
//! cargo run -p spiel_bot --bin dqn_train --release -- \
|
||||
//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000
|
||||
//!
|
||||
//! # Resume from a checkpoint
|
||||
//! cargo run -p spiel_bot --bin dqn_train --release -- \
|
||||
//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100
|
||||
//! ```
|
||||
//!
|
||||
//! # Options
|
||||
//!
|
||||
//! | Flag | Default | Description |
|
||||
//! |------|---------|-------------|
|
||||
//! | `--hidden N` | 256 | Hidden layer width |
|
||||
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
|
||||
//! | `--n-iter N` | 100 | Training iterations |
|
||||
//! | `--n-games N` | 10 | Self-play games per iteration |
|
||||
//! | `--n-train N` | 20 | Gradient steps per iteration |
|
||||
//! | `--batch N` | 64 | Mini-batch size |
|
||||
//! | `--replay-cap N` | 50000 | Replay buffer capacity |
|
||||
//! | `--lr F` | 1e-3 | Adam learning rate |
|
||||
//! | `--epsilon-start F` | 1.0 | Initial exploration rate |
|
||||
//! | `--epsilon-end F` | 0.05 | Final exploration rate |
|
||||
//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor |
|
||||
//! | `--gamma F` | 0.99 | Discount factor |
|
||||
//! | `--target-update N` | 500 | Hard-update target net every N steps |
|
||||
//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) |
|
||||
//! | `--save-every N` | 10 | Save checkpoint every N iterations |
|
||||
//! | `--seed N` | 42 | RNG seed |
|
||||
//! | `--resume PATH` | (none) | Load weights before training |
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use burn::{
|
||||
backend::{Autodiff, NdArray},
|
||||
module::AutodiffModule,
|
||||
optim::AdamConfig,
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
use spiel_bot::{
|
||||
dqn::{
|
||||
DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step,
|
||||
generate_dqn_episode, hard_update, linear_epsilon,
|
||||
},
|
||||
env::TrictracEnv,
|
||||
network::{QNet, QNetConfig},
|
||||
};
|
||||
|
||||
type TrainB = Autodiff<NdArray<f32>>;
|
||||
type InferB = NdArray<f32>;
|
||||
|
||||
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
struct Args {
|
||||
hidden: usize,
|
||||
out_dir: PathBuf,
|
||||
save_every: usize,
|
||||
seed: u64,
|
||||
resume: Option<PathBuf>,
|
||||
config: DqnConfig,
|
||||
}
|
||||
|
||||
impl Default for Args {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hidden: 256,
|
||||
out_dir: PathBuf::from("checkpoints"),
|
||||
save_every: 10,
|
||||
seed: 42,
|
||||
resume: None,
|
||||
config: DqnConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let raw: Vec<String> = std::env::args().collect();
|
||||
let mut a = Args::default();
|
||||
let mut i = 1;
|
||||
while i < raw.len() {
|
||||
match raw[i].as_str() {
|
||||
"--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); }
|
||||
"--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); }
|
||||
"--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); }
|
||||
"--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); }
|
||||
"--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); }
|
||||
"--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); }
|
||||
"--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); }
|
||||
"--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); }
|
||||
"--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); }
|
||||
"--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); }
|
||||
"--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); }
|
||||
"--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); }
|
||||
"--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); }
|
||||
"--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); }
|
||||
"--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); }
|
||||
"--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); }
|
||||
"--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); }
|
||||
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
a
|
||||
}
|
||||
|
||||
// ── Training loop ─────────────────────────────────────────────────────────────
|
||||
|
||||
fn train_loop(
|
||||
mut q_net: QNet<TrainB>,
|
||||
cfg: &QNetConfig,
|
||||
save_fn: &dyn Fn(&QNet<TrainB>, &Path) -> anyhow::Result<()>,
|
||||
args: &Args,
|
||||
) {
|
||||
let train_device: <TrainB as Backend>::Device = Default::default();
|
||||
let infer_device: <InferB as Backend>::Device = Default::default();
|
||||
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let mut replay = DqnReplayBuffer::new(args.config.replay_capacity);
|
||||
let mut rng = SmallRng::seed_from_u64(args.seed);
|
||||
let env = TrictracEnv;
|
||||
|
||||
let mut target_net: QNet<InferB> = hard_update::<TrainB, _>(&q_net);
|
||||
let mut global_step = 0usize;
|
||||
let mut epsilon = args.config.epsilon_start;
|
||||
|
||||
println!(
|
||||
"\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}",
|
||||
"", args.config.n_iterations, args.config.n_games_per_iter,
|
||||
args.config.n_train_steps_per_iter, ""
|
||||
);
|
||||
|
||||
for iter in 0..args.config.n_iterations {
|
||||
let t0 = Instant::now();
|
||||
|
||||
// ── Self-play ────────────────────────────────────────────────────
|
||||
let infer_q: QNet<InferB> = q_net.valid();
|
||||
let mut new_samples = 0usize;
|
||||
|
||||
for _ in 0..args.config.n_games_per_iter {
|
||||
let samples = generate_dqn_episode(
|
||||
&env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale,
|
||||
);
|
||||
new_samples += samples.len();
|
||||
replay.extend(samples);
|
||||
}
|
||||
|
||||
// ── Training ─────────────────────────────────────────────────────
|
||||
let mut loss_sum = 0.0f32;
|
||||
let mut n_steps = 0usize;
|
||||
|
||||
if replay.len() >= args.config.batch_size {
|
||||
for _ in 0..args.config.n_train_steps_per_iter {
|
||||
let batch: Vec<_> = replay
|
||||
.sample_batch(args.config.batch_size, &mut rng)
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Target Q-values computed on the inference backend.
|
||||
let target_q = compute_target_q(
|
||||
&target_net, &batch, cfg.action_size, &infer_device,
|
||||
);
|
||||
|
||||
let (q, loss) = dqn_train_step(
|
||||
q_net, &mut optimizer, &batch, &target_q,
|
||||
&train_device, args.config.learning_rate, args.config.gamma,
|
||||
);
|
||||
q_net = q;
|
||||
loss_sum += loss;
|
||||
n_steps += 1;
|
||||
global_step += 1;
|
||||
|
||||
// Hard-update target net every target_update_freq steps.
|
||||
if global_step % args.config.target_update_freq == 0 {
|
||||
target_net = hard_update::<TrainB, _>(&q_net);
|
||||
}
|
||||
|
||||
// Linear epsilon decay.
|
||||
epsilon = linear_epsilon(
|
||||
args.config.epsilon_start,
|
||||
args.config.epsilon_end,
|
||||
global_step,
|
||||
args.config.epsilon_decay_steps,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Logging ──────────────────────────────────────────────────────
|
||||
let elapsed = t0.elapsed();
|
||||
let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };
|
||||
|
||||
println!(
|
||||
"iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s",
|
||||
iter + 1,
|
||||
args.config.n_iterations,
|
||||
replay.len(),
|
||||
new_samples,
|
||||
avg_loss,
|
||||
epsilon,
|
||||
elapsed.as_secs_f32(),
|
||||
);
|
||||
|
||||
// ── Checkpoint ───────────────────────────────────────────────────
|
||||
let is_last = iter + 1 == args.config.n_iterations;
|
||||
if (iter + 1) % args.save_every == 0 || is_last {
|
||||
let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1));
|
||||
match save_fn(&q_net, &path) {
|
||||
Ok(()) => println!(" -> saved {}", path.display()),
|
||||
Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nDQN training complete.");
|
||||
}
|
||||
|
||||
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
|
||||
if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
|
||||
eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let train_device: <TrainB as Backend>::Device = Default::default();
|
||||
let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden };
|
||||
|
||||
let q_net = match &args.resume {
|
||||
Some(path) => {
|
||||
println!("Resuming from {}", path.display());
|
||||
QNet::<TrainB>::load(&cfg, path, &train_device)
|
||||
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
|
||||
}
|
||||
None => QNet::<TrainB>::new(&cfg, &train_device),
|
||||
};
|
||||
|
||||
train_loop(q_net, &cfg, &|m: &QNet<TrainB>, path| m.valid().save(path), &args);
|
||||
}
|
||||
|
|
@ -1,247 +0,0 @@
|
|||
//! DQN self-play episode generation.
|
||||
//!
|
||||
//! Both players share the same Q-network (the [`TrictracEnv`] handles board
|
||||
//! mirroring so that each player always acts from "White's perspective").
|
||||
//! Transitions for both players are stored in the returned sample list.
|
||||
//!
|
||||
//! # Reward
|
||||
//!
|
||||
//! After each full decision (action applied and the state has advanced through
|
||||
//! any intervening chance nodes back to the same player's next turn), the
|
||||
//! reward is:
|
||||
//!
|
||||
//! ```text
|
||||
//! r = (my_total_score_now − my_total_score_then)
|
||||
//! − (opp_total_score_now − opp_total_score_then)
|
||||
//! ```
|
||||
//!
|
||||
//! where `total_score = holes × 12 + points`.
|
||||
//!
|
||||
//! # Transition structure
|
||||
//!
|
||||
//! We use a "pending transition" per player. When a player acts again, we
|
||||
//! *complete* the previous pending transition by filling in `next_obs`,
|
||||
//! `next_legal`, and computing `reward`. Terminal transitions are completed
|
||||
//! when the game ends.
|
||||
|
||||
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||
use rand::Rng;
|
||||
|
||||
use crate::env::{GameEnv, TrictracEnv};
|
||||
use crate::network::QValueNet;
|
||||
use super::DqnSample;
|
||||
|
||||
// ── Internals ─────────────────────────────────────────────────────────────────
|
||||
|
||||
struct PendingTransition {
|
||||
obs: Vec<f32>,
|
||||
action: usize,
|
||||
/// Score snapshot `[p1_total, p2_total]` at the moment of the action.
|
||||
score_before: [i32; 2],
|
||||
}
|
||||
|
||||
/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise.
|
||||
fn epsilon_greedy<B: Backend, Q: QValueNet<B>>(
|
||||
q_net: &Q,
|
||||
obs: &[f32],
|
||||
legal: &[usize],
|
||||
epsilon: f32,
|
||||
rng: &mut impl Rng,
|
||||
device: &B::Device,
|
||||
) -> usize {
|
||||
debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions");
|
||||
if rng.random::<f32>() < epsilon {
|
||||
legal[rng.random_range(0..legal.len())]
|
||||
} else {
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(obs.to_vec(), [1, obs.len()]),
|
||||
device,
|
||||
);
|
||||
let q_values: Vec<f32> = q_net.forward(obs_tensor).into_data().to_vec().unwrap();
|
||||
legal
|
||||
.iter()
|
||||
.copied()
|
||||
.max_by(|&a, &b| {
|
||||
q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after.
|
||||
fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 {
|
||||
let opp_idx = 1 - player_idx;
|
||||
((score_after[player_idx] - score_before[player_idx])
|
||||
- (score_after[opp_idx] - score_before[opp_idx])) as f32
|
||||
}
|
||||
|
||||
// ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Play one full game and return all transitions for both players.
|
||||
///
|
||||
/// - `q_net` uses the **inference backend** (no autodiff wrapper).
|
||||
/// - `epsilon` in `[0, 1]`: probability of taking a random action.
|
||||
/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`).
|
||||
pub fn generate_dqn_episode<B: Backend, Q: QValueNet<B>>(
|
||||
env: &TrictracEnv,
|
||||
q_net: &Q,
|
||||
epsilon: f32,
|
||||
rng: &mut impl Rng,
|
||||
device: &B::Device,
|
||||
reward_scale: f32,
|
||||
) -> Vec<DqnSample> {
|
||||
let obs_size = env.obs_size();
|
||||
let mut state = env.new_game();
|
||||
let mut pending: [Option<PendingTransition>; 2] = [None, None];
|
||||
let mut samples: Vec<DqnSample> = Vec::new();
|
||||
|
||||
loop {
|
||||
// ── Advance past chance nodes ──────────────────────────────────────
|
||||
while env.current_player(&state).is_chance() {
|
||||
env.apply_chance(&mut state, rng);
|
||||
}
|
||||
|
||||
let score_now = TrictracEnv::score_snapshot(&state);
|
||||
|
||||
if env.current_player(&state).is_terminal() {
|
||||
// Complete all pending transitions as terminal.
|
||||
for player_idx in 0..2 {
|
||||
if let Some(prev) = pending[player_idx].take() {
|
||||
let reward =
|
||||
compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
|
||||
samples.push(DqnSample {
|
||||
obs: prev.obs,
|
||||
action: prev.action,
|
||||
reward,
|
||||
next_obs: vec![0.0; obs_size],
|
||||
next_legal: vec![],
|
||||
done: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let player_idx = env.current_player(&state).index().unwrap();
|
||||
let legal = env.legal_actions(&state);
|
||||
let obs = env.observation(&state, player_idx);
|
||||
|
||||
// ── Complete the previous transition for this player ───────────────
|
||||
if let Some(prev) = pending[player_idx].take() {
|
||||
let reward =
|
||||
compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
|
||||
samples.push(DqnSample {
|
||||
obs: prev.obs,
|
||||
action: prev.action,
|
||||
reward,
|
||||
next_obs: obs.clone(),
|
||||
next_legal: legal.clone(),
|
||||
done: false,
|
||||
});
|
||||
}
|
||||
|
||||
// ── Pick and apply action ──────────────────────────────────────────
|
||||
let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device);
|
||||
env.apply(&mut state, action);
|
||||
|
||||
// ── Record new pending transition ──────────────────────────────────
|
||||
pending[player_idx] = Some(PendingTransition {
|
||||
obs,
|
||||
action,
|
||||
score_before: score_now,
|
||||
});
|
||||
}
|
||||
|
||||
samples
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::backend::NdArray;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
use crate::network::{QNet, QNetConfig};
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
fn device() -> <B as Backend>::Device { Default::default() }
|
||||
fn rng() -> SmallRng { SmallRng::seed_from_u64(7) }
|
||||
|
||||
fn tiny_q() -> QNet<B> {
|
||||
QNet::new(&QNetConfig::default(), &device())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn episode_terminates_and_produces_samples() {
|
||||
let env = TrictracEnv;
|
||||
let q = tiny_q();
|
||||
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
|
||||
assert!(!samples.is_empty(), "episode must produce at least one sample");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn episode_obs_size_correct() {
|
||||
let env = TrictracEnv;
|
||||
let q = tiny_q();
|
||||
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
|
||||
for s in &samples {
|
||||
assert_eq!(s.obs.len(), 217, "obs size mismatch");
|
||||
if s.done {
|
||||
assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size");
|
||||
assert!(s.next_legal.is_empty());
|
||||
} else {
|
||||
assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch");
|
||||
assert!(!s.next_legal.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn episode_actions_within_action_space() {
|
||||
let env = TrictracEnv;
|
||||
let q = tiny_q();
|
||||
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
|
||||
for s in &samples {
|
||||
assert!(s.action < 514, "action {} out of bounds", s.action);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn greedy_episode_also_terminates() {
|
||||
let env = TrictracEnv;
|
||||
let q = tiny_q();
|
||||
let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0);
|
||||
assert!(!samples.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn at_least_one_done_sample() {
|
||||
let env = TrictracEnv;
|
||||
let q = tiny_q();
|
||||
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
|
||||
let n_done = samples.iter().filter(|s| s.done).count();
|
||||
// Two players, so 1 or 2 terminal transitions.
|
||||
assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_reward_correct() {
|
||||
// P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged.
|
||||
let before = [2 * 12 + 10, 0];
|
||||
let after = [3 * 12 + 2, 0];
|
||||
let r = compute_reward(0, &before, &after);
|
||||
assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_reward_with_opponent_scoring() {
|
||||
// P1 gains 2, opp gains 3 → net = -1 from P1's perspective.
|
||||
let before = [0, 0];
|
||||
let after = [2, 3];
|
||||
let r = compute_reward(0, &before, &after);
|
||||
assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,232 +0,0 @@
|
|||
//! DQN: self-play data generation, replay buffer, and training step.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//!
|
||||
//! Deep Q-Network with:
|
||||
//! - **ε-greedy** exploration (linearly decayed).
|
||||
//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where
|
||||
//! `score = holes × 12 + points`.
|
||||
//! - **Experience replay** with a fixed-capacity circular buffer.
|
||||
//! - **Target network**: hard-copied from the online Q-net every
|
||||
//! `target_update_freq` gradient steps for training stability.
|
||||
//!
|
||||
//! # Modules
|
||||
//!
|
||||
//! | Module | Contents |
|
||||
//! |--------|----------|
|
||||
//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] |
|
||||
//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] |
|
||||
|
||||
pub mod episode;
|
||||
pub mod trainer;
|
||||
|
||||
pub use episode::generate_dqn_episode;
|
||||
pub use trainer::{compute_target_q, dqn_train_step, hard_update};
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use rand::Rng;
|
||||
|
||||
// ── DqnSample ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// One transition `(s, a, r, s', done)` collected during self-play.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DqnSample {
|
||||
/// Observation from the acting player's perspective (`obs_size` floats).
|
||||
pub obs: Vec<f32>,
|
||||
/// Action index taken.
|
||||
pub action: usize,
|
||||
/// Per-turn reward: `my_score_delta − opponent_score_delta`.
|
||||
pub reward: f32,
|
||||
/// Next observation from the same player's perspective.
|
||||
/// All-zeros when `done = true` (ignored by the TD target).
|
||||
pub next_obs: Vec<f32>,
|
||||
/// Legal actions at `next_obs`. Empty when `done = true`.
|
||||
pub next_legal: Vec<usize>,
|
||||
/// `true` when `next_obs` is a terminal state.
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
// ── DqnReplayBuffer ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Fixed-capacity circular replay buffer for [`DqnSample`]s.
|
||||
///
|
||||
/// When full, the oldest sample is evicted on push.
|
||||
/// Batches are drawn without replacement via a partial Fisher-Yates shuffle.
|
||||
pub struct DqnReplayBuffer {
|
||||
data: VecDeque<DqnSample>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl DqnReplayBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity }
|
||||
}
|
||||
|
||||
pub fn push(&mut self, sample: DqnSample) {
|
||||
if self.data.len() == self.capacity {
|
||||
self.data.pop_front();
|
||||
}
|
||||
self.data.push_back(sample);
|
||||
}
|
||||
|
||||
pub fn extend(&mut self, samples: impl IntoIterator<Item = DqnSample>) {
|
||||
for s in samples { self.push(s); }
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize { self.data.len() }
|
||||
pub fn is_empty(&self) -> bool { self.data.is_empty() }
|
||||
|
||||
/// Sample up to `n` distinct samples without replacement.
|
||||
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> {
|
||||
let len = self.data.len();
|
||||
let n = n.min(len);
|
||||
let mut indices: Vec<usize> = (0..len).collect();
|
||||
for i in 0..n {
|
||||
let j = rng.random_range(i..len);
|
||||
indices.swap(i, j);
|
||||
}
|
||||
indices[..n].iter().map(|&i| &self.data[i]).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── DqnConfig ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Top-level DQN hyperparameters for the training loop.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DqnConfig {
|
||||
/// Initial exploration rate (1.0 = fully random).
|
||||
pub epsilon_start: f32,
|
||||
/// Final exploration rate after decay.
|
||||
pub epsilon_end: f32,
|
||||
/// Number of gradient steps over which ε decays linearly from start to end.
|
||||
///
|
||||
/// Should be calibrated to the total number of gradient steps
|
||||
/// (`n_iterations × n_train_steps_per_iter`). A value larger than that
|
||||
/// means exploration never reaches `epsilon_end` during the run.
|
||||
pub epsilon_decay_steps: usize,
|
||||
/// Discount factor γ for the TD target. Typical: 0.99.
|
||||
pub gamma: f32,
|
||||
/// Hard-copy Q → target every this many gradient steps.
|
||||
///
|
||||
/// Should be much smaller than the total number of gradient steps
|
||||
/// (`n_iterations × n_train_steps_per_iter`).
|
||||
pub target_update_freq: usize,
|
||||
/// Adam learning rate.
|
||||
pub learning_rate: f64,
|
||||
/// Mini-batch size for each gradient step.
|
||||
pub batch_size: usize,
|
||||
/// Maximum number of samples in the replay buffer.
|
||||
pub replay_capacity: usize,
|
||||
/// Number of outer iterations (self-play + train).
|
||||
pub n_iterations: usize,
|
||||
/// Self-play games per iteration.
|
||||
pub n_games_per_iter: usize,
|
||||
/// Gradient steps per iteration.
|
||||
pub n_train_steps_per_iter: usize,
|
||||
/// Reward normalisation divisor.
|
||||
///
|
||||
/// Per-turn rewards (score delta) are divided by this constant before being
|
||||
/// stored. Without normalisation, rewards can reach ±24 (jan with
|
||||
/// bredouille = 12 pts × 2), driving Q-values into the hundreds and
|
||||
/// causing MSE loss to grow unboundedly.
|
||||
///
|
||||
/// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping
|
||||
/// Q-value magnitudes in a stable range. Set to `1.0` to disable.
|
||||
pub reward_scale: f32,
|
||||
}
|
||||
|
||||
impl Default for DqnConfig {
|
||||
fn default() -> Self {
|
||||
// Total gradient steps with these defaults = 500 × 20 = 10_000,
|
||||
// so epsilon decays fully and the target is updated 100 times.
|
||||
Self {
|
||||
epsilon_start: 1.0,
|
||||
epsilon_end: 0.05,
|
||||
epsilon_decay_steps: 10_000,
|
||||
gamma: 0.99,
|
||||
target_update_freq: 100,
|
||||
learning_rate: 1e-3,
|
||||
batch_size: 64,
|
||||
replay_capacity: 50_000,
|
||||
n_iterations: 500,
|
||||
n_games_per_iter: 10,
|
||||
n_train_steps_per_iter: 20,
|
||||
reward_scale: 12.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps.
|
||||
pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 {
|
||||
if decay_steps == 0 || step >= decay_steps {
|
||||
return end;
|
||||
}
|
||||
start + (end - start) * (step as f32 / decay_steps as f32)
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
fn dummy(reward: f32) -> DqnSample {
|
||||
DqnSample {
|
||||
obs: vec![0.0],
|
||||
action: 0,
|
||||
reward,
|
||||
next_obs: vec![0.0],
|
||||
next_legal: vec![0],
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_and_len() {
|
||||
let mut buf = DqnReplayBuffer::new(10);
|
||||
assert!(buf.is_empty());
|
||||
buf.push(dummy(1.0));
|
||||
buf.push(dummy(2.0));
|
||||
assert_eq!(buf.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evicts_oldest_at_capacity() {
|
||||
let mut buf = DqnReplayBuffer::new(3);
|
||||
buf.push(dummy(1.0));
|
||||
buf.push(dummy(2.0));
|
||||
buf.push(dummy(3.0));
|
||||
buf.push(dummy(4.0));
|
||||
assert_eq!(buf.len(), 3);
|
||||
assert_eq!(buf.data[0].reward, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sample_batch_size() {
|
||||
let mut buf = DqnReplayBuffer::new(20);
|
||||
for i in 0..10 { buf.push(dummy(i as f32)); }
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_epsilon_start() {
|
||||
assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_epsilon_end() {
|
||||
assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_epsilon_monotone() {
|
||||
let mut prev = f32::INFINITY;
|
||||
for step in 0..=100 {
|
||||
let e = linear_epsilon(1.0, 0.05, step, 100);
|
||||
assert!(e <= prev + 1e-6);
|
||||
prev = e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,278 +0,0 @@
|
|||
//! DQN gradient step and target-network management.
|
||||
//!
|
||||
//! # TD target
|
||||
//!
|
||||
//! ```text
|
||||
//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done
|
||||
//! y_i = r_i if done
|
||||
//! ```
|
||||
//!
|
||||
//! # Loss
|
||||
//!
|
||||
//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
|
||||
//! and `y_i` (computed from the frozen target net).
|
||||
//!
|
||||
//! # Target network
|
||||
//!
|
||||
//! [`hard_update`] copies the online Q-net weights into the target net by
|
||||
//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
|
||||
|
||||
use burn::{
|
||||
module::AutodiffModule,
|
||||
optim::{GradientsParams, Optimizer},
|
||||
prelude::ElementConversion,
|
||||
tensor::{
|
||||
Int, Tensor, TensorData,
|
||||
backend::{AutodiffBackend, Backend},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::network::QValueNet;
|
||||
use super::DqnSample;
|
||||
|
||||
// ── Target Q computation ─────────────────────────────────────────────────────
|
||||
|
||||
/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
|
||||
///
|
||||
/// Returns a `Vec<f32>` of length `batch.len()`. Done samples get `0.0`
|
||||
/// (their bootstrap term is dropped by the TD target anyway).
|
||||
///
|
||||
/// The target network runs on the **inference backend** (`InferB`) with no
|
||||
/// gradient tape, so this function is backend-agnostic (`B: Backend`).
|
||||
pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
|
||||
target_net: &Q,
|
||||
batch: &[DqnSample],
|
||||
action_size: usize,
|
||||
device: &B::Device,
|
||||
) -> Vec<f32> {
|
||||
let batch_size = batch.len();
|
||||
|
||||
// Collect indices of non-done samples (done samples have no next state).
|
||||
let non_done: Vec<usize> = batch
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, s)| !s.done)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
if non_done.is_empty() {
|
||||
return vec![0.0; batch_size];
|
||||
}
|
||||
|
||||
let obs_size = batch[0].next_obs.len();
|
||||
let nd = non_done.len();
|
||||
|
||||
// Stack next observations for non-done samples → [nd, obs_size].
|
||||
let obs_flat: Vec<f32> = non_done
|
||||
.iter()
|
||||
.flat_map(|&i| batch[i].next_obs.iter().copied())
|
||||
.collect();
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(obs_flat, [nd, obs_size]),
|
||||
device,
|
||||
);
|
||||
|
||||
// Forward target net → [nd, action_size], then to Vec<f32>.
|
||||
let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();
|
||||
|
||||
// For each non-done sample, pick max Q over legal next actions.
|
||||
let mut result = vec![0.0f32; batch_size];
|
||||
for (k, &i) in non_done.iter().enumerate() {
|
||||
let legal = &batch[i].next_legal;
|
||||
let offset = k * action_size;
|
||||
let max_q = legal
|
||||
.iter()
|
||||
.map(|&a| q_flat[offset + a])
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
// If legal is empty (shouldn't happen for non-done, but be safe):
|
||||
result[i] = if max_q.is_finite() { max_q } else { 0.0 };
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// ── Training step ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run one gradient step on `q_net` using `batch`.
|
||||
///
|
||||
/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
|
||||
/// frozen target network and passed in here so that this function only
|
||||
/// needs the **autodiff backend**.
|
||||
///
|
||||
/// Returns the updated network and the scalar MSE loss.
|
||||
pub fn dqn_train_step<B, Q, O>(
|
||||
q_net: Q,
|
||||
optimizer: &mut O,
|
||||
batch: &[DqnSample],
|
||||
target_max_q: &[f32],
|
||||
device: &B::Device,
|
||||
lr: f64,
|
||||
gamma: f32,
|
||||
) -> (Q, f32)
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
Q: QValueNet<B> + AutodiffModule<B>,
|
||||
O: Optimizer<Q, B>,
|
||||
{
|
||||
assert!(!batch.is_empty(), "dqn_train_step: empty batch");
|
||||
assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");
|
||||
|
||||
let batch_size = batch.len();
|
||||
let obs_size = batch[0].obs.len();
|
||||
|
||||
// ── Build observation tensor [B, obs_size] ────────────────────────────
|
||||
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(obs_flat, [batch_size, obs_size]),
|
||||
device,
|
||||
);
|
||||
|
||||
// ── Forward Q-net → [B, action_size] ─────────────────────────────────
|
||||
let q_all = q_net.forward(obs_tensor);
|
||||
|
||||
// ── Gather Q(s, a) for the taken action → [B] ────────────────────────
|
||||
let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
|
||||
let action_tensor: Tensor<B, 2, Int> = Tensor::<B, 1, Int>::from_data(
|
||||
TensorData::new(actions, [batch_size]),
|
||||
device,
|
||||
)
|
||||
.reshape([batch_size, 1]); // [B] → [B, 1]
|
||||
let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]
|
||||
|
||||
// ── TD targets: r + γ · max_next_q · (1 − done) ──────────────────────
|
||||
let targets: Vec<f32> = batch
|
||||
.iter()
|
||||
.zip(target_max_q.iter())
|
||||
.map(|(s, &max_q)| {
|
||||
if s.done { s.reward } else { s.reward + gamma * max_q }
|
||||
})
|
||||
.collect();
|
||||
let target_tensor = Tensor::<B, 1>::from_data(
|
||||
TensorData::new(targets, [batch_size]),
|
||||
device,
|
||||
);
|
||||
|
||||
// ── MSE loss ──────────────────────────────────────────────────────────
|
||||
let diff = q_pred - target_tensor.detach();
|
||||
let loss = (diff.clone() * diff).mean();
|
||||
let loss_scalar: f32 = loss.clone().into_scalar().elem();
|
||||
|
||||
// ── Backward + optimizer step ─────────────────────────────────────────
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &q_net);
|
||||
let q_net = optimizer.step(lr, q_net, grads);
|
||||
|
||||
(q_net, loss_scalar)
|
||||
}
|
||||
|
||||
// ── Target network update ─────────────────────────────────────────────────────
|
||||
|
||||
/// Hard-copy the online Q-net weights to a new target network.
|
||||
///
|
||||
/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
|
||||
/// inference-backend module with identical weights.
|
||||
pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
|
||||
q_net.valid()
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::{
|
||||
backend::{Autodiff, NdArray},
|
||||
optim::AdamConfig,
|
||||
};
|
||||
use crate::network::{QNet, QNetConfig};
|
||||
|
||||
type InferB = NdArray<f32>;
|
||||
type TrainB = Autodiff<NdArray<f32>>;
|
||||
|
||||
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
|
||||
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
|
||||
|
||||
fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
|
||||
(0..n)
|
||||
.map(|i| DqnSample {
|
||||
obs: vec![0.5f32; obs_size],
|
||||
action: i % action_size,
|
||||
reward: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||
next_obs: vec![0.5f32; obs_size],
|
||||
next_legal: vec![0, 1],
|
||||
done: i == n - 1,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_target_q_length() {
|
||||
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
|
||||
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||
let batch = dummy_batch(8, 4, 4);
|
||||
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||
assert_eq!(tq.len(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_target_q_done_is_zero() {
|
||||
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
|
||||
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||
// Single done sample.
|
||||
let batch = vec![DqnSample {
|
||||
obs: vec![0.0; 4],
|
||||
action: 0,
|
||||
reward: 5.0,
|
||||
next_obs: vec![0.0; 4],
|
||||
next_legal: vec![],
|
||||
done: true,
|
||||
}];
|
||||
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||
assert_eq!(tq.len(), 1);
|
||||
assert_eq!(tq[0], 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn train_step_returns_finite_loss() {
|
||||
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
|
||||
let q_net = QNet::<TrainB>::new(&cfg, &train_device());
|
||||
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let batch = dummy_batch(8, 4, 4);
|
||||
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||
let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99);
|
||||
assert!(loss.is_finite(), "loss must be finite, got {loss}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn train_step_loss_decreases() {
|
||||
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
|
||||
let mut q_net = QNet::<TrainB>::new(&cfg, &train_device());
|
||||
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let batch = dummy_batch(16, 4, 4);
|
||||
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||
|
||||
let mut prev_loss = f32::INFINITY;
|
||||
for _ in 0..10 {
|
||||
let (q, loss) = dqn_train_step(
|
||||
q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99,
|
||||
);
|
||||
q_net = q;
|
||||
assert!(loss.is_finite());
|
||||
prev_loss = loss;
|
||||
}
|
||||
assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hard_update_copies_weights() {
|
||||
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
|
||||
let q_net = QNet::<TrainB>::new(&cfg, &train_device());
|
||||
let target = hard_update::<TrainB, _>(&q_net);
|
||||
|
||||
let obs = burn::tensor::Tensor::<InferB, 2>::zeros([1, 4], &infer_device());
|
||||
let q_out: Vec<f32> = target.forward(obs).into_data().to_vec().unwrap();
|
||||
// After hard_update the target produces finite outputs.
|
||||
assert!(q_out.iter().all(|v| v.is_finite()));
|
||||
}
|
||||
}
|
||||
121
spiel_bot/src/env/mod.rs
vendored
121
spiel_bot/src/env/mod.rs
vendored
|
|
@ -1,121 +0,0 @@
|
|||
//! Game environment abstraction — the minimal "Rust OpenSpiel".
|
||||
//!
|
||||
//! A `GameEnv` describes the rules of a two-player, zero-sum game that may
|
||||
//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN,
|
||||
//! and PPO interact with a game exclusively through this trait.
|
||||
//!
|
||||
//! # Node taxonomy
|
||||
//!
|
||||
//! Every game position belongs to one of four categories, returned by
|
||||
//! [`GameEnv::current_player`]:
|
||||
//!
|
||||
//! | [`Player`] | Meaning |
|
||||
//! |-----------|---------|
|
||||
//! | `P1` | Player 1 (index 0) must choose an action |
|
||||
//! | `P2` | Player 2 (index 1) must choose an action |
|
||||
//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) |
|
||||
//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful |
|
||||
//!
|
||||
//! # Perspective convention
|
||||
//!
|
||||
//! [`GameEnv::observation`] always returns the board from *the requested
|
||||
//! player's* point of view. Callers pass `pov = 0` for Player 1 and
|
||||
//! `pov = 1` for Player 2. The implementation is responsible for any
|
||||
//! mirroring required (e.g. Trictrac always reasons from White's side).
|
||||
|
||||
pub mod trictrac;
|
||||
pub use trictrac::TrictracEnv;
|
||||
|
||||
/// Who controls the current game node.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Player {
|
||||
/// Player 1 (index 0) is to move.
|
||||
P1,
|
||||
/// Player 2 (index 1) is to move.
|
||||
P2,
|
||||
/// A stochastic event (dice roll, etc.) must be resolved.
|
||||
Chance,
|
||||
/// The game is over.
|
||||
Terminal,
|
||||
}
|
||||
|
||||
impl Player {
|
||||
/// Returns the player index (0 or 1) if this is a decision node,
|
||||
/// or `None` for `Chance` / `Terminal`.
|
||||
pub fn index(self) -> Option<usize> {
|
||||
match self {
|
||||
Player::P1 => Some(0),
|
||||
Player::P2 => Some(1),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_decision(self) -> bool {
|
||||
matches!(self, Player::P1 | Player::P2)
|
||||
}
|
||||
|
||||
pub fn is_chance(self) -> bool {
|
||||
self == Player::Chance
|
||||
}
|
||||
|
||||
pub fn is_terminal(self) -> bool {
|
||||
self == Player::Terminal
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that completely describes a two-player zero-sum game.
|
||||
///
|
||||
/// Implementors must be cheaply cloneable (the type is used as a stateless
|
||||
/// factory; the mutable game state lives in `Self::State`).
|
||||
pub trait GameEnv: Clone + Send + Sync + 'static {
|
||||
/// The mutable game state. Must be `Clone` so MCTS can copy
|
||||
/// game trees without touching the environment.
|
||||
type State: Clone + Send + Sync;
|
||||
|
||||
// ── State creation ────────────────────────────────────────────────────
|
||||
|
||||
/// Create a fresh game state at the initial position.
|
||||
fn new_game(&self) -> Self::State;
|
||||
|
||||
// ── Node queries ──────────────────────────────────────────────────────
|
||||
|
||||
/// Classify the current node.
|
||||
fn current_player(&self, s: &Self::State) -> Player;
|
||||
|
||||
/// Legal action indices at a decision node (`current_player` is `P1`/`P2`).
|
||||
///
|
||||
/// The returned indices are in `[0, action_space())`.
|
||||
/// The result is unspecified (may panic or return empty) when called at a
|
||||
/// `Chance` or `Terminal` node.
|
||||
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
|
||||
|
||||
// ── State mutation ────────────────────────────────────────────────────
|
||||
|
||||
/// Apply a player action. `action` must be a value returned by
|
||||
/// [`legal_actions`] for the current state.
|
||||
fn apply(&self, s: &mut Self::State, action: usize);
|
||||
|
||||
/// Sample and apply a stochastic outcome. Must only be called when
|
||||
/// `current_player(s) == Player::Chance`.
|
||||
fn apply_chance<R: rand::Rng>(&self, s: &mut Self::State, rng: &mut R);
|
||||
|
||||
// ── Observation ───────────────────────────────────────────────────────
|
||||
|
||||
/// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2).
|
||||
/// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`.
|
||||
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
|
||||
|
||||
/// Number of floats returned by [`observation`].
|
||||
fn obs_size(&self) -> usize;
|
||||
|
||||
/// Total number of distinct action indices (the policy head output size).
|
||||
fn action_space(&self) -> usize;
|
||||
|
||||
// ── Terminal values ───────────────────────────────────────────────────
|
||||
|
||||
/// Game outcome for each player, or `None` if the game is not over.
|
||||
///
|
||||
/// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw.
|
||||
/// Index 0 = Player 1, index 1 = Player 2.
|
||||
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
|
||||
}
|
||||
547
spiel_bot/src/env/trictrac.rs
vendored
547
spiel_bot/src/env/trictrac.rs
vendored
|
|
@ -1,547 +0,0 @@
|
|||
//! [`GameEnv`] implementation for Trictrac.
|
||||
//!
|
||||
//! # Game flow (schools_enabled = false)
|
||||
//!
|
||||
//! With scoring schools disabled (the standard training configuration),
|
||||
//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine
|
||||
//! applies them automatically inside `RollResult` and `Move`. The only
|
||||
//! four stages that actually occur are:
|
||||
//!
|
||||
//! | `TurnStage` | [`Player`] kind | Handled by |
|
||||
//! |-------------|-----------------|------------|
|
||||
//! | `RollDice` | `Chance` | [`apply_chance`] |
|
||||
//! | `RollWaiting` | `Chance` | [`apply_chance`] |
|
||||
//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] |
|
||||
//! | `Move` | `P1`/`P2` | [`apply`] |
|
||||
//!
|
||||
//! # Perspective
|
||||
//!
|
||||
//! The Trictrac engine always reasons from White's perspective. Player 1 is
|
||||
//! White; Player 2 is Black. When Player 2 is active, the board is mirrored
|
||||
//! before computing legal actions / the observation tensor, and the resulting
|
||||
//! event is mirrored back before being applied to the real state. This
|
||||
//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`.
|
||||
|
||||
use trictrac_store::{
|
||||
training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE},
|
||||
Dice, GameEvent, GameState, Stage, TurnStage,
|
||||
};
|
||||
|
||||
use super::{GameEnv, Player};
|
||||
|
||||
/// Stateless factory that produces Trictrac [`GameState`] environments.
|
||||
///
|
||||
/// Schools (`schools_enabled`) are always disabled — scoring is automatic.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct TrictracEnv;
|
||||
|
||||
impl GameEnv for TrictracEnv {
|
||||
type State = GameState;
|
||||
|
||||
// ── State creation ────────────────────────────────────────────────────
|
||||
|
||||
fn new_game(&self) -> GameState {
|
||||
GameState::new_with_players("P1", "P2")
|
||||
}
|
||||
|
||||
// ── Node queries ──────────────────────────────────────────────────────
|
||||
|
||||
fn current_player(&self, s: &GameState) -> Player {
|
||||
if s.stage == Stage::Ended {
|
||||
return Player::Terminal;
|
||||
}
|
||||
match s.turn_stage {
|
||||
TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance,
|
||||
_ => {
|
||||
if s.active_player_id == 1 {
|
||||
Player::P1
|
||||
} else {
|
||||
Player::P2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the legal action indices for the active player.
|
||||
///
|
||||
/// The board is automatically mirrored for Player 2 so that the engine
|
||||
/// always reasons from White's perspective. The returned indices are
|
||||
/// identical in meaning for both players (checker ordinals are
|
||||
/// perspective-relative).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug builds if called at a `Chance` or `Terminal` node.
|
||||
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
|
||||
debug_assert!(
|
||||
self.current_player(s).is_decision(),
|
||||
"legal_actions called at a non-decision node (turn_stage={:?})",
|
||||
s.turn_stage
|
||||
);
|
||||
let indices = if s.active_player_id == 2 {
|
||||
get_valid_action_indices(&s.mirror())
|
||||
} else {
|
||||
get_valid_action_indices(s)
|
||||
};
|
||||
indices.unwrap_or_default()
|
||||
}
|
||||
|
||||
// ── State mutation ────────────────────────────────────────────────────
|
||||
|
||||
/// Apply a player action index to the game state.
|
||||
///
|
||||
/// For Player 2, the action is decoded against the mirrored board and
|
||||
/// the resulting event is un-mirrored before being applied.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug builds if `action` cannot be decoded or does not
|
||||
/// produce a valid event for the current state.
|
||||
fn apply(&self, s: &mut GameState, action: usize) {
|
||||
let needs_mirror = s.active_player_id == 2;
|
||||
|
||||
let event = if needs_mirror {
|
||||
let view = s.mirror();
|
||||
TrictracAction::from_action_index(action)
|
||||
.and_then(|a| a.to_event(&view))
|
||||
.map(|e| e.get_mirror(false))
|
||||
} else {
|
||||
TrictracAction::from_action_index(action).and_then(|a| a.to_event(s))
|
||||
};
|
||||
|
||||
match event {
|
||||
Some(e) => {
|
||||
s.consume(&e).expect("apply: consume failed for valid action");
|
||||
}
|
||||
None => {
|
||||
panic!("apply: action index {action} produced no event in state {s}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample dice and advance through a chance node.
|
||||
///
|
||||
/// Handles both `RollDice` (triggers the roll mechanism, then samples
|
||||
/// dice) and `RollWaiting` (only samples dice) in a single call so that
|
||||
/// callers never need to distinguish the two.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug builds if called at a non-Chance node.
|
||||
fn apply_chance<R: rand::Rng>(&self, s: &mut GameState, rng: &mut R) {
|
||||
debug_assert!(
|
||||
self.current_player(s).is_chance(),
|
||||
"apply_chance called at a non-Chance node (turn_stage={:?})",
|
||||
s.turn_stage
|
||||
);
|
||||
|
||||
// Step 1: RollDice → RollWaiting (player initiates the roll).
|
||||
if s.turn_stage == TurnStage::RollDice {
|
||||
s.consume(&GameEvent::Roll {
|
||||
player_id: s.active_player_id,
|
||||
})
|
||||
.expect("apply_chance: Roll event failed");
|
||||
}
|
||||
|
||||
// Step 2: RollWaiting → Move / HoldOrGoChoice / Ended.
|
||||
// With schools_enabled=false, point marking is automatic inside consume().
|
||||
let dice = Dice {
|
||||
values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)),
|
||||
};
|
||||
s.consume(&GameEvent::RollResult {
|
||||
player_id: s.active_player_id,
|
||||
dice,
|
||||
})
|
||||
.expect("apply_chance: RollResult event failed");
|
||||
}
|
||||
|
||||
// ── Observation ───────────────────────────────────────────────────────
|
||||
|
||||
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
|
||||
if pov == 0 {
|
||||
s.to_tensor()
|
||||
} else {
|
||||
s.mirror().to_tensor()
|
||||
}
|
||||
}
|
||||
|
||||
fn obs_size(&self) -> usize {
|
||||
217
|
||||
}
|
||||
|
||||
fn action_space(&self) -> usize {
|
||||
ACTION_SPACE_SIZE
|
||||
}
|
||||
|
||||
// ── Terminal values ───────────────────────────────────────────────────
|
||||
|
||||
/// Returns `Some([r1, r2])` when the game is over, `None` otherwise.
|
||||
///
|
||||
/// The winner (higher cumulative score) receives `+1.0`; the loser
|
||||
/// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score
|
||||
/// is `holes × 12 + points`.
|
||||
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
|
||||
if s.stage != Stage::Ended {
|
||||
return None;
|
||||
}
|
||||
let score = |id: u64| -> i32 {
|
||||
s.players
|
||||
.get(&id)
|
||||
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||
.unwrap_or(0)
|
||||
};
|
||||
let s1 = score(1);
|
||||
let s2 = score(2);
|
||||
Some(match s1.cmp(&s2) {
|
||||
std::cmp::Ordering::Greater => [1.0, -1.0],
|
||||
std::cmp::Ordering::Less => [-1.0, 1.0],
|
||||
std::cmp::Ordering::Equal => [0.0, 0.0],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── DQN helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
impl TrictracEnv {
|
||||
/// Score snapshot for DQN reward computation.
|
||||
///
|
||||
/// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`.
|
||||
/// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2).
|
||||
pub fn score_snapshot(s: &GameState) -> [i32; 2] {
|
||||
[s.total_score(1), s.total_score(2)]
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::{rngs::SmallRng, Rng, SeedableRng};
|
||||
|
||||
fn env() -> TrictracEnv {
|
||||
TrictracEnv
|
||||
}
|
||||
|
||||
fn seeded_rng(seed: u64) -> SmallRng {
|
||||
SmallRng::seed_from_u64(seed)
|
||||
}
|
||||
|
||||
// ── Initial state ─────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn new_game_is_chance_node() {
|
||||
let e = env();
|
||||
let s = e.new_game();
|
||||
// A fresh game starts at RollDice — a Chance node.
|
||||
assert_eq!(e.current_player(&s), Player::Chance);
|
||||
assert!(e.returns(&s).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_game_is_not_terminal() {
|
||||
let e = env();
|
||||
let s = e.new_game();
|
||||
assert_ne!(e.current_player(&s), Player::Terminal);
|
||||
assert!(e.returns(&s).is_none());
|
||||
}
|
||||
|
||||
// ── Chance nodes ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn apply_chance_reaches_decision_node() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(1);
|
||||
|
||||
// A single chance step must yield a decision node (or end the game,
|
||||
// which only happens after 12 holes — impossible on the first roll).
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
let p = e.current_player(&s);
|
||||
assert!(
|
||||
p.is_decision(),
|
||||
"expected decision node after first roll, got {p:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_chance_from_rollwaiting() {
|
||||
// Check that apply_chance works when called mid-way (at RollWaiting).
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
assert_eq!(s.turn_stage, TurnStage::RollDice);
|
||||
|
||||
// Manually advance to RollWaiting.
|
||||
s.consume(&GameEvent::Roll { player_id: s.active_player_id })
|
||||
.unwrap();
|
||||
assert_eq!(s.turn_stage, TurnStage::RollWaiting);
|
||||
|
||||
let mut rng = seeded_rng(2);
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
|
||||
let p = e.current_player(&s);
|
||||
assert!(p.is_decision() || p.is_terminal());
|
||||
}
|
||||
|
||||
// ── Legal actions ─────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn legal_actions_nonempty_after_roll() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(3);
|
||||
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
assert!(e.current_player(&s).is_decision());
|
||||
|
||||
let actions = e.legal_actions(&s);
|
||||
assert!(
|
||||
!actions.is_empty(),
|
||||
"legal_actions must be non-empty at a decision node"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn legal_actions_within_action_space() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(4);
|
||||
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
for &a in e.legal_actions(&s).iter() {
|
||||
assert!(
|
||||
a < e.action_space(),
|
||||
"action {a} out of bounds (action_space={})",
|
||||
e.action_space()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Observations ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn observation_has_correct_size() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(5);
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
|
||||
assert_eq!(e.observation(&s, 0).len(), e.obs_size());
|
||||
assert_eq!(e.observation(&s, 1).len(), e.obs_size());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn observation_values_in_unit_interval() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(6);
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
|
||||
for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] {
|
||||
for (i, &v) in obs.iter().enumerate() {
|
||||
assert!(
|
||||
v >= 0.0 && v <= 1.0,
|
||||
"pov={pov}: obs[{i}] = {v} is outside [0,1]"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn p1_and_p2_observations_differ() {
|
||||
// The board is mirrored for P2, so the two observations should differ
|
||||
// whenever there are checkers in non-symmetric positions (always true
|
||||
// in a real game after a few moves).
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(7);
|
||||
|
||||
// Advance far enough that the board is non-trivial.
|
||||
for _ in 0..6 {
|
||||
while e.current_player(&s).is_chance() {
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
}
|
||||
if e.current_player(&s).is_terminal() {
|
||||
break;
|
||||
}
|
||||
let actions = e.legal_actions(&s);
|
||||
e.apply(&mut s, actions[0]);
|
||||
}
|
||||
|
||||
if !e.current_player(&s).is_terminal() {
|
||||
let obs0 = e.observation(&s, 0);
|
||||
let obs1 = e.observation(&s, 1);
|
||||
assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Applying actions ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn apply_changes_state() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(8);
|
||||
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
assert!(e.current_player(&s).is_decision());
|
||||
|
||||
let before = s.clone();
|
||||
let action = e.legal_actions(&s)[0];
|
||||
e.apply(&mut s, action);
|
||||
|
||||
assert_ne!(
|
||||
before.turn_stage, s.turn_stage,
|
||||
"state must change after apply"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_all_legal_actions_do_not_panic() {
|
||||
// Verify that every action returned by legal_actions can be applied
|
||||
// without panicking (on several independent copies of the same state).
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(9);
|
||||
|
||||
e.apply_chance(&mut s, &mut rng);
|
||||
assert!(e.current_player(&s).is_decision());
|
||||
|
||||
for action in e.legal_actions(&s) {
|
||||
let mut copy = s.clone();
|
||||
e.apply(&mut copy, action); // must not panic
|
||||
}
|
||||
}
|
||||
|
||||
// ── Full game ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Run a complete game with random actions through the `GameEnv` trait
|
||||
/// and verify that:
|
||||
/// - The game terminates.
|
||||
/// - `returns()` is `Some` at the end.
|
||||
/// - The outcome is valid: scores sum to 0 (zero-sum) or each player's
|
||||
/// score is ±1 / 0.
|
||||
/// - No step panics.
|
||||
#[test]
|
||||
fn full_random_game_terminates() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(42);
|
||||
let max_steps = 50_000;
|
||||
|
||||
for step in 0..max_steps {
|
||||
match e.current_player(&s) {
|
||||
Player::Terminal => break,
|
||||
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||
Player::P1 | Player::P2 => {
|
||||
let actions = e.legal_actions(&s);
|
||||
assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node");
|
||||
let idx = rng.random_range(0..actions.len());
|
||||
e.apply(&mut s, actions[idx]);
|
||||
}
|
||||
}
|
||||
assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps");
|
||||
}
|
||||
|
||||
let result = e.returns(&s);
|
||||
assert!(result.is_some(), "returns() must be Some at Terminal");
|
||||
|
||||
let [r1, r2] = result.unwrap();
|
||||
let sum = r1 + r2;
|
||||
assert!(
|
||||
(sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5,
|
||||
"game must be zero-sum: r1={r1}, r2={r2}, sum={sum}"
|
||||
);
|
||||
assert!(
|
||||
r1.abs() <= 1.0 && r2.abs() <= 1.0,
|
||||
"returns must be in [-1,1]: r1={r1}, r2={r2}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Run multiple games with different seeds to stress-test for panics.
|
||||
#[test]
|
||||
fn multiple_games_no_panic() {
|
||||
let e = env();
|
||||
let max_steps = 20_000;
|
||||
|
||||
for seed in 0..10u64 {
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(seed);
|
||||
|
||||
for _ in 0..max_steps {
|
||||
match e.current_player(&s) {
|
||||
Player::Terminal => break,
|
||||
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||
Player::P1 | Player::P2 => {
|
||||
let actions = e.legal_actions(&s);
|
||||
let idx = rng.random_range(0..actions.len());
|
||||
e.apply(&mut s, actions[idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Returns ───────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn returns_none_mid_game() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(11);
|
||||
|
||||
// Advance a few steps but do not finish the game.
|
||||
for _ in 0..4 {
|
||||
match e.current_player(&s) {
|
||||
Player::Terminal => break,
|
||||
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||
Player::P1 | Player::P2 => {
|
||||
let actions = e.legal_actions(&s);
|
||||
e.apply(&mut s, actions[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !e.current_player(&s).is_terminal() {
|
||||
assert!(
|
||||
e.returns(&s).is_none(),
|
||||
"returns() must be None before the game ends"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Player 2 actions ──────────────────────────────────────────────────
|
||||
|
||||
/// Verify that Player 2 (Black) can take actions without panicking,
|
||||
/// and that the state advances correctly.
|
||||
#[test]
|
||||
fn player2_can_act() {
|
||||
let e = env();
|
||||
let mut s = e.new_game();
|
||||
let mut rng = seeded_rng(12);
|
||||
|
||||
// Keep stepping until Player 2 gets a turn.
|
||||
let max_steps = 5_000;
|
||||
let mut p2_acted = false;
|
||||
|
||||
for _ in 0..max_steps {
|
||||
match e.current_player(&s) {
|
||||
Player::Terminal => break,
|
||||
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||
Player::P2 => {
|
||||
let actions = e.legal_actions(&s);
|
||||
assert!(!actions.is_empty());
|
||||
e.apply(&mut s, actions[0]);
|
||||
p2_acted = true;
|
||||
break;
|
||||
}
|
||||
Player::P1 => {
|
||||
let actions = e.legal_actions(&s);
|
||||
e.apply(&mut s, actions[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
pub mod alphazero;
|
||||
pub mod dqn;
|
||||
pub mod env;
|
||||
pub mod mcts;
|
||||
pub mod network;
|
||||
|
|
@ -1,412 +0,0 @@
|
|||
//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//!
|
||||
//! The implementation follows AlphaZero's MCTS:
|
||||
//!
|
||||
//! 1. **Expand root** — run the network once to get priors and a value
|
||||
//! estimate; optionally add Dirichlet noise for training-time exploration.
|
||||
//! 2. **Simulate** `n_simulations` times:
|
||||
//! - *Selection* — traverse the tree with PUCT until an unvisited leaf.
|
||||
//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes;
|
||||
//! chance nodes are **not** stored in the tree (outcome sampling).
|
||||
//! - *Expansion* — evaluate the network at the leaf; populate children.
|
||||
//! - *Backup* — propagate the value upward; negate at each player boundary.
|
||||
//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]).
|
||||
//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]).
|
||||
//!
|
||||
//! # Perspective convention
|
||||
//!
|
||||
//! Every [`MctsNode::w`] is stored **from the perspective of the player who
|
||||
//! acts at that node**. The backup negates the child value whenever the
|
||||
//! acting player differs between parent and child.
|
||||
//!
|
||||
//! # Stochastic games
|
||||
//!
|
||||
//! When [`GameEnv::current_player`] returns [`Player::Chance`], the
|
||||
//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and
|
||||
//! continues. Chance nodes are skipped transparently; Q-values converge to
|
||||
//! their expectation over many simulations (outcome sampling).
|
||||
|
||||
pub mod node;
|
||||
mod search;
|
||||
|
||||
pub use node::MctsNode;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
use crate::env::GameEnv;
|
||||
|
||||
// ── Evaluator trait ────────────────────────────────────────────────────────
|
||||
|
||||
/// Evaluates a game position for use in MCTS.
|
||||
///
|
||||
/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet)
|
||||
/// but the `mcts` module itself does **not** depend on Burn.
|
||||
pub trait Evaluator: Send + Sync {
|
||||
/// Evaluate `obs` (flat observation vector of length `obs_size`).
|
||||
///
|
||||
/// Returns:
|
||||
/// - `policy_logits`: one raw logit per action (`action_space` entries).
|
||||
/// Illegal action entries are masked inside the search — no need to
|
||||
/// zero them here.
|
||||
/// - `value`: scalar in `(-1, 1)` from **the current player's** perspective.
|
||||
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
|
||||
}
|
||||
|
||||
// ── Configuration ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Hyperparameters for [`run_mcts`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MctsConfig {
|
||||
/// Number of MCTS simulations per move. Typical: 50–800.
|
||||
pub n_simulations: usize,
|
||||
/// PUCT exploration constant `c_puct`. Typical: 1.0–2.0.
|
||||
pub c_puct: f32,
|
||||
/// Dirichlet noise concentration α. Set to `0.0` to disable.
|
||||
/// Typical: `0.3` for Chess, `0.1` for large action spaces.
|
||||
pub dirichlet_alpha: f32,
|
||||
/// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`.
|
||||
pub dirichlet_eps: f32,
|
||||
/// Action sampling temperature. `> 0` = proportional sample, `0` = argmax.
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl Default for MctsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
n_simulations: 200,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.3,
|
||||
dirichlet_eps: 0.25,
|
||||
temperature: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Public interface ───────────────────────────────────────────────────────
|
||||
|
||||
/// Run MCTS from `state` and return the populated root [`MctsNode`].
|
||||
///
|
||||
/// `state` must be a player-decision node (`P1` or `P2`).
|
||||
/// Use [`mcts_policy`] and [`select_action`] on the returned root.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `env.current_player(state)` is not `P1` or `P2`.
|
||||
pub fn run_mcts<E: GameEnv>(
|
||||
env: &E,
|
||||
state: &E::State,
|
||||
evaluator: &dyn Evaluator,
|
||||
config: &MctsConfig,
|
||||
rng: &mut impl Rng,
|
||||
) -> MctsNode {
|
||||
let player_idx = env
|
||||
.current_player(state)
|
||||
.index()
|
||||
.expect("run_mcts called at a non-decision node");
|
||||
|
||||
// ── Expand root (network called once here, not inside the loop) ────────
|
||||
let mut root = MctsNode::new(1.0);
|
||||
search::expand::<E>(&mut root, state, env, evaluator, player_idx);
|
||||
|
||||
// ── Optional Dirichlet noise for training exploration ──────────────────
|
||||
if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
|
||||
search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
|
||||
}
|
||||
|
||||
// ── Simulations ────────────────────────────────────────────────────────
|
||||
for _ in 0..config.n_simulations {
|
||||
search::simulate::<E>(
|
||||
&mut root,
|
||||
state.clone(),
|
||||
env,
|
||||
evaluator,
|
||||
config,
|
||||
rng,
|
||||
player_idx,
|
||||
);
|
||||
}
|
||||
|
||||
root
|
||||
}
|
||||
|
||||
/// Compute the MCTS policy: normalized visit counts at the root.
|
||||
///
|
||||
/// Returns a vector of length `action_space` where `policy[a]` is the
|
||||
/// fraction of simulations that visited action `a`.
|
||||
pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
|
||||
let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum();
|
||||
let mut policy = vec![0.0f32; action_space];
|
||||
if total > 0.0 {
|
||||
for (a, child) in &root.children {
|
||||
policy[*a] = child.n as f32 / total;
|
||||
}
|
||||
} else if !root.children.is_empty() {
|
||||
// n_simulations = 0: uniform over legal actions.
|
||||
let uniform = 1.0 / root.children.len() as f32;
|
||||
for (a, _) in &root.children {
|
||||
policy[*a] = uniform;
|
||||
}
|
||||
}
|
||||
policy
|
||||
}
|
||||
|
||||
/// Select an action index from the root after MCTS.
|
||||
///
|
||||
/// * `temperature = 0` — greedy argmax of visit counts.
|
||||
/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the root has no children.
|
||||
pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize {
|
||||
assert!(!root.children.is_empty(), "select_action called on a root with no children");
|
||||
if temperature <= 0.0 {
|
||||
root.children
|
||||
.iter()
|
||||
.max_by_key(|(_, c)| c.n)
|
||||
.map(|(a, _)| *a)
|
||||
.unwrap()
|
||||
} else {
|
||||
let weights: Vec<f32> = root
|
||||
.children
|
||||
.iter()
|
||||
.map(|(_, c)| (c.n as f32).powf(1.0 / temperature))
|
||||
.collect();
|
||||
let total: f32 = weights.iter().sum();
|
||||
let mut r: f32 = rng.random::<f32>() * total;
|
||||
for (i, (a, _)) in root.children.iter().enumerate() {
|
||||
r -= weights[i];
|
||||
if r <= 0.0 {
|
||||
return *a;
|
||||
}
|
||||
}
|
||||
root.children.last().map(|(a, _)| *a).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
use crate::env::Player;
|
||||
|
||||
// ── Minimal deterministic test game ───────────────────────────────────
|
||||
//
|
||||
// "Countdown" — two players alternate subtracting 1 or 2 from a counter.
|
||||
// The player who brings the counter to 0 wins.
|
||||
// No chance nodes, two legal actions (0 = -1, 1 = -2).
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct CState {
|
||||
remaining: u8,
|
||||
to_move: usize, // at terminal: last mover (winner)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CountdownEnv;
|
||||
|
||||
impl crate::env::GameEnv for CountdownEnv {
|
||||
type State = CState;
|
||||
|
||||
fn new_game(&self) -> CState {
|
||||
CState { remaining: 6, to_move: 0 }
|
||||
}
|
||||
|
||||
fn current_player(&self, s: &CState) -> Player {
|
||||
if s.remaining == 0 {
|
||||
Player::Terminal
|
||||
} else if s.to_move == 0 {
|
||||
Player::P1
|
||||
} else {
|
||||
Player::P2
|
||||
}
|
||||
}
|
||||
|
||||
fn legal_actions(&self, s: &CState) -> Vec<usize> {
|
||||
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
|
||||
}
|
||||
|
||||
fn apply(&self, s: &mut CState, action: usize) {
|
||||
let sub = (action as u8) + 1;
|
||||
if s.remaining <= sub {
|
||||
s.remaining = 0;
|
||||
// to_move stays as winner
|
||||
} else {
|
||||
s.remaining -= sub;
|
||||
s.to_move = 1 - s.to_move;
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
|
||||
|
||||
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
|
||||
vec![s.remaining as f32 / 6.0, s.to_move as f32]
|
||||
}
|
||||
|
||||
fn obs_size(&self) -> usize { 2 }
|
||||
fn action_space(&self) -> usize { 2 }
|
||||
|
||||
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
|
||||
if s.remaining != 0 { return None; }
|
||||
let mut r = [-1.0f32; 2];
|
||||
r[s.to_move] = 1.0;
|
||||
Some(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Uniform evaluator: all logits = 0, value = 0.
|
||||
// `action_space` must match the environment's `action_space()`.
|
||||
struct ZeroEval(usize);
|
||||
impl Evaluator for ZeroEval {
|
||||
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
|
||||
(vec![0.0f32; self.0], 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn rng() -> SmallRng {
|
||||
SmallRng::seed_from_u64(42)
|
||||
}
|
||||
|
||||
fn config_n(n: usize) -> MctsConfig {
|
||||
MctsConfig {
|
||||
n_simulations: n,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0, // off for reproducibility
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Visit count tests ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn visit_counts_sum_to_n_simulations() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng());
|
||||
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||
assert_eq!(total, 50, "visit counts must sum to n_simulations");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_root_children_are_legal() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let legal = env.legal_actions(&state);
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng());
|
||||
for (a, _) in &root.children {
|
||||
assert!(legal.contains(a), "child action {a} is not legal");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Policy tests ─────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn policy_sums_to_one() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng());
|
||||
let policy = mcts_policy(&root, env.action_space());
|
||||
let sum: f32 = policy.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_zero_for_illegal_actions() {
|
||||
let env = CountdownEnv;
|
||||
// remaining = 1 → only action 0 is legal
|
||||
let state = CState { remaining: 1, to_move: 0 };
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng());
|
||||
let policy = mcts_policy(&root, env.action_space());
|
||||
assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass");
|
||||
}
|
||||
|
||||
// ── Action selection tests ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn greedy_selects_most_visited() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng());
|
||||
let greedy = select_action(&root, 0.0, &mut rng());
|
||||
let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap();
|
||||
assert_eq!(greedy, most_visited);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temperature_sampling_stays_legal() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let legal = env.legal_actions(&state);
|
||||
let mut r = rng();
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r);
|
||||
for _ in 0..20 {
|
||||
let a = select_action(&root, 1.0, &mut r);
|
||||
assert!(legal.contains(&a), "sampled action {a} is not legal");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Zero-simulation edge case ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn zero_simulations_uniform_policy() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng());
|
||||
let policy = mcts_policy(&root, env.action_space());
|
||||
// With 0 simulations, fallback is uniform over the 2 legal actions.
|
||||
let sum: f32 = policy.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 1e-5);
|
||||
}
|
||||
|
||||
// ── Root value ────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn root_q_in_valid_range() {
|
||||
let env = CountdownEnv;
|
||||
let state = env.new_game();
|
||||
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng());
|
||||
let q = root.q();
|
||||
assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]");
|
||||
}
|
||||
|
||||
// ── Integration: run on a real Trictrac game ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn no_panic_on_trictrac_state() {
|
||||
use crate::env::TrictracEnv;
|
||||
|
||||
let env = TrictracEnv;
|
||||
let mut state = env.new_game();
|
||||
let mut r = rng();
|
||||
|
||||
// Advance past the initial chance node to reach a decision node.
|
||||
while env.current_player(&state).is_chance() {
|
||||
env.apply_chance(&mut state, &mut r);
|
||||
}
|
||||
|
||||
if env.current_player(&state).is_terminal() {
|
||||
return; // unlikely but safe
|
||||
}
|
||||
|
||||
let config = MctsConfig {
|
||||
n_simulations: 5, // tiny for speed
|
||||
dirichlet_alpha: 0.0,
|
||||
dirichlet_eps: 0.0,
|
||||
..MctsConfig::default()
|
||||
};
|
||||
|
||||
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
||||
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
|
||||
assert_eq!(root.n, 1 + config.n_simulations as u32);
|
||||
// Every simulation crosses a chance node at depth 1 (dice roll after
|
||||
// the player's move). Since the fix now updates child.n in that case,
|
||||
// children visit counts must sum to exactly n_simulations.
|
||||
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||
assert_eq!(total, config.n_simulations as u32);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
//! MCTS tree node.
|
||||
//!
|
||||
//! [`MctsNode`] holds the visit statistics for one player-decision position in
|
||||
//! the search tree. A node is *expanded* the first time the policy-value
|
||||
//! network is evaluated there; before that it is a leaf.
|
||||
|
||||
/// One node in the MCTS tree, representing a player-decision position.
|
||||
///
|
||||
/// `w` stores the sum of values backed up into this node, always from the
|
||||
/// perspective of **the player who acts here**. `q()` therefore also returns
|
||||
/// a value in `(-1, 1)` from that same perspective.
|
||||
#[derive(Debug)]
|
||||
pub struct MctsNode {
|
||||
/// Visit count `N(s, a)`.
|
||||
pub n: u32,
|
||||
/// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective.
|
||||
pub w: f32,
|
||||
/// Prior probability `P(s, a)` assigned by the policy head (after masked softmax).
|
||||
pub p: f32,
|
||||
/// Children: `(action_index, child_node)`, populated on first expansion.
|
||||
pub children: Vec<(usize, MctsNode)>,
|
||||
/// `true` after the network has been evaluated and children have been set up.
|
||||
pub expanded: bool,
|
||||
}
|
||||
|
||||
impl MctsNode {
|
||||
/// Create a fresh, unexpanded leaf with the given prior probability.
|
||||
pub fn new(prior: f32) -> Self {
|
||||
Self {
|
||||
n: 0,
|
||||
w: 0.0,
|
||||
p: prior,
|
||||
children: Vec::new(),
|
||||
expanded: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// `Q(s, a) = W / N`, or `0.0` if this node has never been visited.
|
||||
#[inline]
|
||||
pub fn q(&self) -> f32 {
|
||||
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
|
||||
}
|
||||
|
||||
/// PUCT selection score:
|
||||
///
|
||||
/// ```text
|
||||
/// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a))
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
|
||||
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn q_zero_when_unvisited() {
|
||||
let node = MctsNode::new(0.5);
|
||||
assert_eq!(node.q(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn q_reflects_w_over_n() {
|
||||
let mut node = MctsNode::new(0.5);
|
||||
node.n = 4;
|
||||
node.w = 2.0;
|
||||
assert!((node.q() - 0.5).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn puct_exploration_dominates_unvisited() {
|
||||
// Unvisited child should outscore a visited child with negative Q.
|
||||
let mut visited = MctsNode::new(0.5);
|
||||
visited.n = 10;
|
||||
visited.w = -5.0; // Q = -0.5
|
||||
|
||||
let unvisited = MctsNode::new(0.5);
|
||||
|
||||
let parent_n = 10;
|
||||
let c = 1.5;
|
||||
assert!(
|
||||
unvisited.puct(parent_n, c) > visited.puct(parent_n, c),
|
||||
"unvisited child should have higher PUCT than a negatively-valued visited child"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
//! Simulation, expansion, backup, and noise helpers.
|
||||
//!
|
||||
//! These are internal to the `mcts` module; the public entry points are
|
||||
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
|
||||
|
||||
use rand::Rng;
|
||||
use rand_distr::{Gamma, Distribution};
|
||||
|
||||
use crate::env::GameEnv;
|
||||
use super::{Evaluator, MctsConfig};
|
||||
use super::node::MctsNode;
|
||||
|
||||
// ── Masked softmax ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Numerically stable softmax over `legal` actions only.
|
||||
///
|
||||
/// Illegal logits are treated as `-∞` and receive probability `0.0`.
|
||||
/// Returns a probability vector of length `action_space`.
|
||||
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
|
||||
let mut probs = vec![0.0f32; action_space];
|
||||
if legal.is_empty() {
|
||||
return probs;
|
||||
}
|
||||
let max_logit = legal
|
||||
.iter()
|
||||
.map(|&a| logits[a])
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0f32;
|
||||
for &a in legal {
|
||||
let e = (logits[a] - max_logit).exp();
|
||||
probs[a] = e;
|
||||
sum += e;
|
||||
}
|
||||
if sum > 0.0 {
|
||||
for &a in legal {
|
||||
probs[a] /= sum;
|
||||
}
|
||||
} else {
|
||||
let uniform = 1.0 / legal.len() as f32;
|
||||
for &a in legal {
|
||||
probs[a] = uniform;
|
||||
}
|
||||
}
|
||||
probs
|
||||
}
|
||||
|
||||
// ── Dirichlet noise ────────────────────────────────────────────────────────
|
||||
|
||||
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
|
||||
///
|
||||
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
|
||||
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
|
||||
pub(super) fn add_dirichlet_noise(
|
||||
node: &mut MctsNode,
|
||||
alpha: f32,
|
||||
eps: f32,
|
||||
rng: &mut impl Rng,
|
||||
) {
|
||||
let n = node.children.len();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
|
||||
return;
|
||||
};
|
||||
let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
|
||||
let sum: f32 = samples.iter().sum();
|
||||
if sum <= 0.0 {
|
||||
return;
|
||||
}
|
||||
for (i, (_, child)) in node.children.iter_mut().enumerate() {
|
||||
let noise = samples[i] / sum;
|
||||
child.p = (1.0 - eps) * child.p + eps * noise;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Expansion ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Evaluate the network at `state` and populate `node` with children.
|
||||
///
|
||||
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
|
||||
/// Returns the network value estimate from `player_idx`'s perspective.
|
||||
pub(super) fn expand<E: GameEnv>(
|
||||
node: &mut MctsNode,
|
||||
state: &E::State,
|
||||
env: &E,
|
||||
evaluator: &dyn Evaluator,
|
||||
player_idx: usize,
|
||||
) -> f32 {
|
||||
let obs = env.observation(state, player_idx);
|
||||
let legal = env.legal_actions(state);
|
||||
let (logits, value) = evaluator.evaluate(&obs);
|
||||
let priors = masked_softmax(&logits, &legal, env.action_space());
|
||||
node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
|
||||
node.expanded = true;
|
||||
node.n = 1;
|
||||
node.w = value;
|
||||
value
|
||||
}
|
||||
|
||||
// ── Simulation ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// One MCTS simulation from an **already-expanded** decision node.
|
||||
///
|
||||
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
|
||||
/// and backs up the result.
|
||||
///
|
||||
/// * `player_idx` — the player (0 or 1) who acts at `state`.
|
||||
/// * Returns the backed-up value **from `player_idx`'s perspective**.
|
||||
pub(super) fn simulate<E: GameEnv>(
|
||||
node: &mut MctsNode,
|
||||
state: E::State,
|
||||
env: &E,
|
||||
evaluator: &dyn Evaluator,
|
||||
config: &MctsConfig,
|
||||
rng: &mut impl Rng,
|
||||
player_idx: usize,
|
||||
) -> f32 {
|
||||
debug_assert!(node.expanded, "simulate called on unexpanded node");
|
||||
|
||||
// ── Selection: child with highest PUCT ────────────────────────────────
|
||||
let parent_n = node.n;
|
||||
let best = node
|
||||
.children
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, (_, a)), (_, (_, b))| {
|
||||
a.puct(parent_n, config.c_puct)
|
||||
.partial_cmp(&b.puct(parent_n, config.c_puct))
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.expect("expanded node must have at least one child");
|
||||
|
||||
let (action, child) = &mut node.children[best];
|
||||
let action = *action;
|
||||
|
||||
// ── Apply action + advance through any chance nodes ───────────────────
|
||||
let mut next_state = state;
|
||||
env.apply(&mut next_state, action);
|
||||
|
||||
// Track whether we crossed a chance node (dice roll) on the way down.
|
||||
// If we did, the child's cached legal actions are for a *different* dice
|
||||
// outcome and must not be reused — evaluate with the network directly.
|
||||
let mut crossed_chance = false;
|
||||
while env.current_player(&next_state).is_chance() {
|
||||
env.apply_chance(&mut next_state, rng);
|
||||
crossed_chance = true;
|
||||
}
|
||||
|
||||
let next_cp = env.current_player(&next_state);
|
||||
|
||||
// ── Evaluate leaf or terminal ──────────────────────────────────────────
|
||||
// All values are converted to `player_idx`'s perspective before backup.
|
||||
let child_value = if next_cp.is_terminal() {
|
||||
let returns = env
|
||||
.returns(&next_state)
|
||||
.expect("terminal node must have returns");
|
||||
returns[player_idx]
|
||||
} else {
|
||||
let child_player = next_cp.index().unwrap();
|
||||
let v = if crossed_chance {
|
||||
// Outcome sampling: after dice, evaluate the resulting position
|
||||
// directly with the network. Do NOT build the tree across chance
|
||||
// boundaries — the dice change which actions are legal, so any
|
||||
// previously cached children would be for a different outcome.
|
||||
let obs = env.observation(&next_state, child_player);
|
||||
let (_, value) = evaluator.evaluate(&obs);
|
||||
// Record the visit so that PUCT and mcts_policy use real counts.
|
||||
// Without this, child.n stays 0 for every simulation in games where
|
||||
// every player action is immediately followed by a chance node (e.g.
|
||||
// Trictrac), causing mcts_policy to always return a uniform policy.
|
||||
child.n += 1;
|
||||
child.w += value;
|
||||
value
|
||||
} else if child.expanded {
|
||||
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
||||
} else {
|
||||
expand::<E>(child, &next_state, env, evaluator, child_player)
|
||||
};
|
||||
// Negate when the child belongs to the opponent.
|
||||
if child_player == player_idx { v } else { -v }
|
||||
};
|
||||
|
||||
// ── Backup ────────────────────────────────────────────────────────────
|
||||
node.n += 1;
|
||||
node.w += child_value;
|
||||
|
||||
child_value
|
||||
}
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
//! Two-hidden-layer MLP policy-value network.
|
||||
//!
|
||||
//! ```text
|
||||
//! Input [B, obs_size]
|
||||
//! → Linear(obs → hidden) → ReLU
|
||||
//! → Linear(hidden → hidden) → ReLU
|
||||
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
|
||||
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
|
||||
//! ```
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Linear, LinearConfig},
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::{
|
||||
activation::{relu, tanh},
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
use super::PolicyValueNet;
|
||||
|
||||
// ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for [`MlpNet`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MlpConfig {
|
||||
/// Number of input features. 217 for Trictrac's `to_tensor()`.
|
||||
pub obs_size: usize,
|
||||
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
|
||||
pub action_size: usize,
|
||||
/// Width of both hidden layers.
|
||||
pub hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Default for MlpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
obs_size: 217,
|
||||
action_size: 514,
|
||||
hidden_size: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Network ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Simple two-hidden-layer MLP with shared trunk and two heads.
|
||||
///
|
||||
/// Prefer this over [`ResNet`](super::ResNet) when training time is a
|
||||
/// priority, or as a fast baseline.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MlpNet<B: Backend> {
|
||||
fc1: Linear<B>,
|
||||
fc2: Linear<B>,
|
||||
policy_head: Linear<B>,
|
||||
value_head: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MlpNet<B> {
|
||||
/// Construct a fresh network with random weights.
|
||||
pub fn new(config: &MlpConfig, device: &B::Device) -> Self {
|
||||
Self {
|
||||
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
|
||||
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
|
||||
policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
|
||||
value_head: LinearConfig::new(config.hidden_size, 1).init(device),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||
///
|
||||
/// The file is written exactly at `path`; callers should append `.mpk` if
|
||||
/// they want the conventional extension.
|
||||
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||
CompactRecorder::new()
|
||||
.record(self.clone().into_record(), path.to_path_buf())
|
||||
.map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}"))
|
||||
}
|
||||
|
||||
/// Load weights from `path` into a fresh model built from `config`.
|
||||
pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||
let record = CompactRecorder::new()
|
||||
.load(path.to_path_buf(), device)
|
||||
.map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?;
|
||||
Ok(Self::new(config, device).load_record(record))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PolicyValueNet<B> for MlpNet<B> {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
||||
let x = relu(self.fc1.forward(obs));
|
||||
let x = relu(self.fc2.forward(x));
|
||||
let policy = self.policy_head.forward(x.clone());
|
||||
let value = tanh(self.value_head.forward(x));
|
||||
(policy, value)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::backend::NdArray;
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
fn device() -> <B as Backend>::Device {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn default_net() -> MlpNet<B> {
|
||||
MlpNet::new(&MlpConfig::default(), &device())
|
||||
}
|
||||
|
||||
fn zeros_obs(batch: usize) -> Tensor<B, 2> {
|
||||
Tensor::zeros([batch, 217], &device())
|
||||
}
|
||||
|
||||
// ── Shape tests ───────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn forward_output_shapes() {
|
||||
let net = default_net();
|
||||
let obs = zeros_obs(4);
|
||||
let (policy, value) = net.forward(obs);
|
||||
|
||||
assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
|
||||
assert_eq!(value.dims(), [4, 1], "value shape mismatch");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_single_sample() {
|
||||
let net = default_net();
|
||||
let (policy, value) = net.forward(zeros_obs(1));
|
||||
assert_eq!(policy.dims(), [1, 514]);
|
||||
assert_eq!(value.dims(), [1, 1]);
|
||||
}
|
||||
|
||||
// ── Value bounds ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn value_in_tanh_range() {
|
||||
let net = default_net();
|
||||
// Use a non-zero input so the output is not trivially at 0.
|
||||
let obs = Tensor::<B, 2>::ones([8, 217], &device());
|
||||
let (_, value) = net.forward(obs);
|
||||
let data: Vec<f32> = value.into_data().to_vec().unwrap();
|
||||
for v in &data {
|
||||
assert!(
|
||||
*v > -1.0 && *v < 1.0,
|
||||
"value {v} is outside open interval (-1, 1)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Policy logits ─────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn policy_logits_not_all_equal() {
|
||||
// With random weights the 514 logits should not all be identical.
|
||||
let net = default_net();
|
||||
let (policy, _) = net.forward(zeros_obs(1));
|
||||
let data: Vec<f32> = policy.into_data().to_vec().unwrap();
|
||||
let first = data[0];
|
||||
let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
|
||||
assert!(!all_same, "all policy logits are identical — network may be degenerate");
|
||||
}
|
||||
|
||||
// ── Config propagation ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn custom_config_shapes() {
|
||||
let config = MlpConfig {
|
||||
obs_size: 10,
|
||||
action_size: 20,
|
||||
hidden_size: 32,
|
||||
};
|
||||
let net = MlpNet::<B>::new(&config, &device());
|
||||
let obs = Tensor::zeros([3, 10], &device());
|
||||
let (policy, value) = net.forward(obs);
|
||||
assert_eq!(policy.dims(), [3, 20]);
|
||||
assert_eq!(value.dims(), [3, 1]);
|
||||
}
|
||||
|
||||
// ── Save / Load ───────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn save_load_preserves_weights() {
|
||||
let config = MlpConfig::default();
|
||||
let net = default_net();
|
||||
|
||||
// Forward pass before saving.
|
||||
let obs = Tensor::<B, 2>::ones([2, 217], &device());
|
||||
let (policy_before, value_before) = net.forward(obs.clone());
|
||||
|
||||
// Save to a temp file.
|
||||
let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk");
|
||||
net.save(&path).expect("save failed");
|
||||
|
||||
// Load into a fresh model.
|
||||
let loaded = MlpNet::<B>::load(&config, &path, &device()).expect("load failed");
|
||||
let (policy_after, value_after) = loaded.forward(obs);
|
||||
|
||||
// Outputs must be bitwise identical.
|
||||
let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
|
||||
let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
|
||||
for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
|
||||
assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
|
||||
}
|
||||
|
||||
let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
|
||||
let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
|
||||
for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
|
||||
assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
|
||||
}
|
||||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
//! Neural network abstractions for policy-value learning.
|
||||
//!
|
||||
//! # Trait
|
||||
//!
|
||||
//! [`PolicyValueNet<B>`] is the single trait that all network architectures
|
||||
//! implement. It takes an observation tensor and returns raw policy logits
|
||||
//! plus a tanh-squashed scalar value estimate.
|
||||
//!
|
||||
//! # Architectures
|
||||
//!
|
||||
//! | Module | Description | Default hidden |
|
||||
//! |--------|-------------|----------------|
|
||||
//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
|
||||
//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
|
||||
//!
|
||||
//! # Backend convention
|
||||
//!
|
||||
//! * **Inference / self-play** — use `NdArray<f32>` (no autodiff overhead).
|
||||
//! * **Training** — use `Autodiff<NdArray<f32>>` so Burn can differentiate
|
||||
//! through the forward pass.
|
||||
//!
|
||||
//! Both modes use the exact same struct; only the type-level backend changes:
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use burn::backend::{Autodiff, NdArray};
|
||||
//! type InferBackend = NdArray<f32>;
|
||||
//! type TrainBackend = Autodiff<NdArray<f32>>;
|
||||
//!
|
||||
//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
|
||||
//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
|
||||
//! ```
|
||||
//!
|
||||
//! # Output shapes
|
||||
//!
|
||||
//! Given a batch of `B` observations of size `obs_size`:
|
||||
//!
|
||||
//! | Output | Shape | Range |
|
||||
//! |--------|-------|-------|
|
||||
//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) |
|
||||
//! | `value` | `[B, 1]` | (-1, 1) via tanh |
|
||||
//!
|
||||
//! Callers are responsible for masking illegal actions in `policy_logits`
|
||||
//! before passing to softmax.
|
||||
|
||||
pub mod mlp;
|
||||
pub mod qnet;
|
||||
pub mod resnet;
|
||||
|
||||
pub use mlp::{MlpConfig, MlpNet};
|
||||
pub use qnet::{QNet, QNetConfig};
|
||||
pub use resnet::{ResNet, ResNetConfig};
|
||||
|
||||
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
||||
|
||||
/// A neural network that produces a policy and a value from an observation.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - `obs`: `[batch, obs_size]`
|
||||
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
|
||||
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
|
||||
///
|
||||
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
|
||||
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
|
||||
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
|
||||
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
|
||||
}
|
||||
|
||||
/// A neural network that outputs one Q-value per action.
|
||||
///
|
||||
/// # Shapes
|
||||
/// - `obs`: `[batch, obs_size]`
|
||||
/// - output: `[batch, action_size]` — raw Q-values (no activation)
|
||||
///
|
||||
/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
|
||||
pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
|
||||
}
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
//! Single-headed Q-value network for DQN.
|
||||
//!
|
||||
//! ```text
|
||||
//! Input [B, obs_size]
|
||||
//! → Linear(obs → hidden) → ReLU
|
||||
//! → Linear(hidden → hidden) → ReLU
|
||||
//! → Linear(hidden → action_size) ← raw Q-values, no activation
|
||||
//! ```
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Linear, LinearConfig},
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::{activation::relu, backend::Backend, Tensor},
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
use super::QValueNet;
|
||||
|
||||
// ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for [`QNet`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QNetConfig {
|
||||
/// Number of input features. 217 for Trictrac's `to_tensor()`.
|
||||
pub obs_size: usize,
|
||||
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
|
||||
pub action_size: usize,
|
||||
/// Width of both hidden layers.
|
||||
pub hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Default for QNetConfig {
|
||||
fn default() -> Self {
|
||||
Self { obs_size: 217, action_size: 514, hidden_size: 256 }
|
||||
}
|
||||
}
|
||||
|
||||
// ── Network ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Two-hidden-layer MLP that outputs one Q-value per action.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct QNet<B: Backend> {
|
||||
fc1: Linear<B>,
|
||||
fc2: Linear<B>,
|
||||
q_head: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> QNet<B> {
|
||||
/// Construct a fresh network with random weights.
|
||||
pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
|
||||
Self {
|
||||
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
|
||||
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
|
||||
q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||
CompactRecorder::new()
|
||||
.record(self.clone().into_record(), path.to_path_buf())
|
||||
.map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
|
||||
}
|
||||
|
||||
/// Load weights from `path` into a fresh model built from `config`.
|
||||
pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||
let record = CompactRecorder::new()
|
||||
.load(path.to_path_buf(), device)
|
||||
.map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
|
||||
Ok(Self::new(config, device).load_record(record))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> QValueNet<B> for QNet<B> {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let x = relu(self.fc1.forward(obs));
|
||||
let x = relu(self.fc2.forward(x));
|
||||
self.q_head.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::backend::NdArray;
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
fn device() -> <B as Backend>::Device { Default::default() }
|
||||
|
||||
fn default_net() -> QNet<B> {
|
||||
QNet::new(&QNetConfig::default(), &device())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_output_shape() {
|
||||
let net = default_net();
|
||||
let obs = Tensor::zeros([4, 217], &device());
|
||||
let q = net.forward(obs);
|
||||
assert_eq!(q.dims(), [4, 514]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_single_sample() {
|
||||
let net = default_net();
|
||||
let q = net.forward(Tensor::zeros([1, 217], &device()));
|
||||
assert_eq!(q.dims(), [1, 514]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn q_values_not_all_equal() {
|
||||
let net = default_net();
|
||||
let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
|
||||
.into_data().to_vec().unwrap();
|
||||
let first = q[0];
|
||||
assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_config_shapes() {
|
||||
let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
|
||||
let net = QNet::<B>::new(&cfg, &device());
|
||||
let q = net.forward(Tensor::zeros([3, 10], &device()));
|
||||
assert_eq!(q.dims(), [3, 20]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_load_preserves_weights() {
|
||||
let net = default_net();
|
||||
let obs = Tensor::<B, 2>::ones([2, 217], &device());
|
||||
let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();
|
||||
|
||||
let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk");
|
||||
net.save(&path).expect("save failed");
|
||||
|
||||
let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
|
||||
let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();
|
||||
|
||||
for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
|
||||
assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
|
||||
}
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,253 +0,0 @@
|
|||
//! Residual-block policy-value network.
|
||||
//!
|
||||
//! ```text
|
||||
//! Input [B, obs_size]
|
||||
//! → Linear(obs → hidden) → ReLU (input projection)
|
||||
//! → ResBlock × 4 (residual trunk)
|
||||
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
|
||||
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
|
||||
//!
|
||||
//! ResBlock:
|
||||
//! x → Linear → ReLU → Linear → (+x) → ReLU
|
||||
//! ```
|
||||
//!
|
||||
//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better
|
||||
//! suited for long training runs where board-pattern recognition matters.
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Linear, LinearConfig},
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::{
|
||||
activation::{relu, tanh},
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
use super::PolicyValueNet;
|
||||
|
||||
// ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for [`ResNet`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResNetConfig {
|
||||
/// Number of input features. 217 for Trictrac's `to_tensor()`.
|
||||
pub obs_size: usize,
|
||||
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
|
||||
pub action_size: usize,
|
||||
/// Width of all hidden layers (input projection + residual blocks).
|
||||
pub hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Default for ResNetConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
obs_size: 217,
|
||||
action_size: 514,
|
||||
hidden_size: 512,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Residual block ────────────────────────────────────────────────────────────
|
||||
|
||||
/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`.
|
||||
///
|
||||
/// Both linear layers preserve the hidden dimension so the skip connection
|
||||
/// can be added without projection.
|
||||
#[derive(Module, Debug)]
|
||||
struct ResBlock<B: Backend> {
|
||||
fc1: Linear<B>,
|
||||
fc2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResBlock<B> {
|
||||
fn new(hidden: usize, device: &B::Device) -> Self {
|
||||
Self {
|
||||
fc1: LinearConfig::new(hidden, hidden).init(device),
|
||||
fc2: LinearConfig::new(hidden, hidden).init(device),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let residual = x.clone();
|
||||
let out = relu(self.fc1.forward(x));
|
||||
relu(self.fc2.forward(out) + residual)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Network ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Four-residual-block policy-value network.
|
||||
///
|
||||
/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and
|
||||
/// when representing complex positional patterns is important.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResNet<B: Backend> {
|
||||
input: Linear<B>,
|
||||
block0: ResBlock<B>,
|
||||
block1: ResBlock<B>,
|
||||
block2: ResBlock<B>,
|
||||
block3: ResBlock<B>,
|
||||
policy_head: Linear<B>,
|
||||
value_head: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResNet<B> {
|
||||
/// Construct a fresh network with random weights.
|
||||
pub fn new(config: &ResNetConfig, device: &B::Device) -> Self {
|
||||
let h = config.hidden_size;
|
||||
Self {
|
||||
input: LinearConfig::new(config.obs_size, h).init(device),
|
||||
block0: ResBlock::new(h, device),
|
||||
block1: ResBlock::new(h, device),
|
||||
block2: ResBlock::new(h, device),
|
||||
block3: ResBlock::new(h, device),
|
||||
policy_head: LinearConfig::new(h, config.action_size).init(device),
|
||||
value_head: LinearConfig::new(h, 1).init(device),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||
CompactRecorder::new()
|
||||
.record(self.clone().into_record(), path.to_path_buf())
|
||||
.map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}"))
|
||||
}
|
||||
|
||||
/// Load weights from `path` into a fresh model built from `config`.
|
||||
pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||
let record = CompactRecorder::new()
|
||||
.load(path.to_path_buf(), device)
|
||||
.map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?;
|
||||
Ok(Self::new(config, device).load_record(record))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PolicyValueNet<B> for ResNet<B> {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
||||
let x = relu(self.input.forward(obs));
|
||||
let x = self.block0.forward(x);
|
||||
let x = self.block1.forward(x);
|
||||
let x = self.block2.forward(x);
|
||||
let x = self.block3.forward(x);
|
||||
let policy = self.policy_head.forward(x.clone());
|
||||
let value = tanh(self.value_head.forward(x));
|
||||
(policy, value)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::backend::NdArray;
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
fn device() -> <B as Backend>::Device {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn small_config() -> ResNetConfig {
|
||||
// Use a small hidden size so tests are fast.
|
||||
ResNetConfig {
|
||||
obs_size: 217,
|
||||
action_size: 514,
|
||||
hidden_size: 64,
|
||||
}
|
||||
}
|
||||
|
||||
fn net() -> ResNet<B> {
|
||||
ResNet::new(&small_config(), &device())
|
||||
}
|
||||
|
||||
// ── Shape tests ───────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn forward_output_shapes() {
|
||||
let obs = Tensor::zeros([4, 217], &device());
|
||||
let (policy, value) = net().forward(obs);
|
||||
assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
|
||||
assert_eq!(value.dims(), [4, 1], "value shape mismatch");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_single_sample() {
|
||||
let (policy, value) = net().forward(Tensor::zeros([1, 217], &device()));
|
||||
assert_eq!(policy.dims(), [1, 514]);
|
||||
assert_eq!(value.dims(), [1, 1]);
|
||||
}
|
||||
|
||||
// ── Value bounds ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn value_in_tanh_range() {
|
||||
let obs = Tensor::<B, 2>::ones([8, 217], &device());
|
||||
let (_, value) = net().forward(obs);
|
||||
let data: Vec<f32> = value.into_data().to_vec().unwrap();
|
||||
for v in &data {
|
||||
assert!(
|
||||
*v > -1.0 && *v < 1.0,
|
||||
"value {v} is outside open interval (-1, 1)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Residual connections ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn policy_logits_not_all_equal() {
|
||||
let (policy, _) = net().forward(Tensor::zeros([1, 217], &device()));
|
||||
let data: Vec<f32> = policy.into_data().to_vec().unwrap();
|
||||
let first = data[0];
|
||||
let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
|
||||
assert!(!all_same, "all policy logits are identical");
|
||||
}
|
||||
|
||||
// ── Save / Load ───────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn save_load_preserves_weights() {
|
||||
let config = small_config();
|
||||
let model = net();
|
||||
let obs = Tensor::<B, 2>::ones([2, 217], &device());
|
||||
|
||||
let (policy_before, value_before) = model.forward(obs.clone());
|
||||
|
||||
let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk");
|
||||
model.save(&path).expect("save failed");
|
||||
|
||||
let loaded = ResNet::<B>::load(&config, &path, &device()).expect("load failed");
|
||||
let (policy_after, value_after) = loaded.forward(obs);
|
||||
|
||||
let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
|
||||
let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
|
||||
for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
|
||||
assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
|
||||
}
|
||||
|
||||
let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
|
||||
let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
|
||||
for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
|
||||
assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
|
||||
}
|
||||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
// ── Integration: both architectures satisfy PolicyValueNet ────────────
|
||||
|
||||
#[test]
|
||||
fn resnet_satisfies_trait() {
|
||||
fn requires_net<B: Backend, N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) {
|
||||
let (p, v) = net.forward(obs);
|
||||
assert_eq!(p.dims()[1], 514);
|
||||
assert_eq!(v.dims()[1], 1);
|
||||
}
|
||||
requires_net(&net(), Tensor::zeros([2, 217], &device()));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,391 +0,0 @@
|
|||
//! End-to-end integration tests for the AlphaZero training pipeline.
|
||||
//!
|
||||
//! Each test exercises the full chain:
|
||||
//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`]
|
||||
//!
|
||||
//! Two environments are used:
|
||||
//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves.
|
||||
//! Used when we need many iterations without worrying about runtime.
|
||||
//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that
|
||||
//! the full pipeline compiles and runs correctly with 217-dim observations
|
||||
//! and 514-dim action spaces.
|
||||
//!
|
||||
//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep
|
||||
//! runtime minimal; correctness, not training quality, is what matters here.
|
||||
|
||||
use burn::{
|
||||
backend::{Autodiff, NdArray},
|
||||
module::AutodiffModule,
|
||||
optim::AdamConfig,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
use spiel_bot::{
|
||||
alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step},
|
||||
env::{GameEnv, Player, TrictracEnv},
|
||||
mcts::MctsConfig,
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet},
|
||||
};
|
||||
|
||||
// ── Backend aliases ────────────────────────────────────────────────────────
|
||||
|
||||
type Train = Autodiff<NdArray<f32>>;
|
||||
type Infer = NdArray<f32>;
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
fn train_device() -> <Train as burn::tensor::backend::Backend>::Device {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn infer_device() -> <Infer as burn::tensor::backend::Backend>::Device {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// Tiny 64-unit MLP, compatible with an obs/action space of any size.
|
||||
fn tiny_mlp(obs: usize, actions: usize) -> MlpNet<Train> {
|
||||
let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 };
|
||||
MlpNet::new(&cfg, &train_device())
|
||||
}
|
||||
|
||||
fn tiny_mcts(n: usize) -> MctsConfig {
|
||||
MctsConfig {
|
||||
n_simulations: n,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0,
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
fn seeded() -> SmallRng {
|
||||
SmallRng::seed_from_u64(0)
|
||||
}
|
||||
|
||||
// ── Countdown environment (fast, local, no external deps) ─────────────────
|
||||
//
|
||||
// Two players alternate subtracting 1 or 2 from a counter that starts at N.
|
||||
// The player who brings the counter to 0 wins.
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct CState {
|
||||
remaining: u8,
|
||||
to_move: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CountdownEnv(u8); // starting value
|
||||
|
||||
impl GameEnv for CountdownEnv {
|
||||
type State = CState;
|
||||
|
||||
fn new_game(&self) -> CState {
|
||||
CState { remaining: self.0, to_move: 0 }
|
||||
}
|
||||
|
||||
fn current_player(&self, s: &CState) -> Player {
|
||||
if s.remaining == 0 { Player::Terminal }
|
||||
else if s.to_move == 0 { Player::P1 }
|
||||
else { Player::P2 }
|
||||
}
|
||||
|
||||
fn legal_actions(&self, s: &CState) -> Vec<usize> {
|
||||
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
|
||||
}
|
||||
|
||||
fn apply(&self, s: &mut CState, action: usize) {
|
||||
let sub = (action as u8) + 1;
|
||||
if s.remaining <= sub {
|
||||
s.remaining = 0;
|
||||
} else {
|
||||
s.remaining -= sub;
|
||||
s.to_move = 1 - s.to_move;
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
|
||||
|
||||
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
|
||||
vec![s.remaining as f32 / self.0 as f32, s.to_move as f32]
|
||||
}
|
||||
|
||||
fn obs_size(&self) -> usize { 2 }
|
||||
fn action_space(&self) -> usize { 2 }
|
||||
|
||||
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
|
||||
if s.remaining != 0 { return None; }
|
||||
let mut r = [-1.0f32; 2];
|
||||
r[s.to_move] = 1.0;
|
||||
Some(r)
|
||||
}
|
||||
}
|
||||
|
||||
// ── 1. Full loop on CountdownEnv ──────────────────────────────────────────
|
||||
|
||||
/// The canonical AlphaZero loop: self-play → replay → train, iterated.
|
||||
/// Uses CountdownEnv so each game terminates in < 10 moves.
|
||||
#[test]
|
||||
fn countdown_full_loop_no_panic() {
|
||||
let env = CountdownEnv(8);
|
||||
let mut rng = seeded();
|
||||
let mcts = tiny_mcts(3);
|
||||
|
||||
let mut model = tiny_mlp(env.obs_size(), env.action_space());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let mut replay = ReplayBuffer::new(1_000);
|
||||
|
||||
for _iter in 0..5 {
|
||||
// Self-play: 3 games per iteration.
|
||||
for _ in 0..3 {
|
||||
let infer = model.valid();
|
||||
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
|
||||
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
|
||||
assert!(!samples.is_empty());
|
||||
replay.extend(samples);
|
||||
}
|
||||
|
||||
// Training: 4 gradient steps per iteration.
|
||||
if replay.len() >= 4 {
|
||||
for _ in 0..4 {
|
||||
let batch: Vec<TrainSample> = replay
|
||||
.sample_batch(4, &mut rng)
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
|
||||
model = m;
|
||||
assert!(loss.is_finite(), "loss must be finite, got {loss}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(replay.len() > 0);
|
||||
}
|
||||
|
||||
// ── 2. Replay buffer invariants ───────────────────────────────────────────
|
||||
|
||||
/// After several Countdown games, replay capacity is respected and batch
|
||||
/// shapes are consistent.
|
||||
#[test]
|
||||
fn replay_buffer_capacity_and_shapes() {
|
||||
let env = CountdownEnv(6);
|
||||
let mut rng = seeded();
|
||||
let mcts = tiny_mcts(2);
|
||||
let model = tiny_mlp(env.obs_size(), env.action_space());
|
||||
|
||||
let capacity = 50;
|
||||
let mut replay = ReplayBuffer::new(capacity);
|
||||
|
||||
for _ in 0..20 {
|
||||
let infer = model.valid();
|
||||
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
|
||||
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
|
||||
replay.extend(samples);
|
||||
}
|
||||
|
||||
assert!(replay.len() <= capacity, "buffer exceeded capacity");
|
||||
assert!(replay.len() > 0);
|
||||
|
||||
let batch = replay.sample_batch(8, &mut rng);
|
||||
assert_eq!(batch.len(), 8.min(replay.len()));
|
||||
for s in &batch {
|
||||
assert_eq!(s.obs.len(), env.obs_size());
|
||||
assert_eq!(s.policy.len(), env.action_space());
|
||||
let policy_sum: f32 = s.policy.iter().sum();
|
||||
assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}");
|
||||
assert!(s.value.abs() <= 1.0, "value {} out of range", s.value);
|
||||
}
|
||||
}
|
||||
|
||||
// ── 3. TrictracEnv: sample shapes ─────────────────────────────────────────
|
||||
|
||||
/// Verify that one TrictracEnv episode produces samples with the correct
|
||||
/// tensor dimensions: obs = 217, policy = 514.
|
||||
#[test]
|
||||
fn trictrac_sample_shapes() {
|
||||
let env = TrictracEnv;
|
||||
let mut rng = seeded();
|
||||
let mcts = tiny_mcts(2);
|
||||
let model = tiny_mlp(env.obs_size(), env.action_space());
|
||||
|
||||
let infer = model.valid();
|
||||
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
|
||||
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
|
||||
|
||||
assert!(!samples.is_empty(), "Trictrac episode produced no samples");
|
||||
|
||||
for (i, s) in samples.iter().enumerate() {
|
||||
assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len());
|
||||
assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len());
|
||||
let policy_sum: f32 = s.policy.iter().sum();
|
||||
assert!(
|
||||
(policy_sum - 1.0).abs() < 1e-4,
|
||||
"sample {i}: policy sums to {policy_sum}"
|
||||
);
|
||||
assert!(
|
||||
s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
|
||||
"sample {i}: unexpected value {}",
|
||||
s.value
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── 4. TrictracEnv: training step after real self-play ────────────────────
|
||||
|
||||
/// Collect one Trictrac episode, then verify that a gradient step runs
|
||||
/// without panic and produces a finite loss.
|
||||
#[test]
|
||||
fn trictrac_train_step_finite_loss() {
|
||||
let env = TrictracEnv;
|
||||
let mut rng = seeded();
|
||||
let mcts = tiny_mcts(2);
|
||||
let model = tiny_mlp(env.obs_size(), env.action_space());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let mut replay = ReplayBuffer::new(10_000);
|
||||
|
||||
// Generate one episode.
|
||||
let infer = model.valid();
|
||||
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
|
||||
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
|
||||
assert!(!samples.is_empty());
|
||||
let n_samples = samples.len();
|
||||
replay.extend(samples);
|
||||
|
||||
// Train on a batch from this episode.
|
||||
let batch_size = 8.min(n_samples);
|
||||
let batch: Vec<TrainSample> = replay
|
||||
.sample_batch(batch_size, &mut rng)
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
|
||||
assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}");
|
||||
assert!(loss > 0.0, "loss should be positive");
|
||||
}
|
||||
|
||||
// ── 5. Backend transfer: train → infer → same outputs ─────────────────────
|
||||
|
||||
/// Weights transferred from the training backend to the inference backend
|
||||
/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes.
|
||||
#[test]
|
||||
fn valid_model_matches_train_model_outputs() {
|
||||
use burn::tensor::{Tensor, TensorData};
|
||||
|
||||
let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
|
||||
let train_model = MlpNet::<Train>::new(&cfg, &train_device());
|
||||
let infer_model: MlpNet<Infer> = train_model.valid();
|
||||
|
||||
// Build the same input on both backends.
|
||||
let obs_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
|
||||
|
||||
let obs_train = Tensor::<Train, 2>::from_data(
|
||||
TensorData::new(obs_data.clone(), [1, 4]),
|
||||
&train_device(),
|
||||
);
|
||||
let obs_infer = Tensor::<Infer, 2>::from_data(
|
||||
TensorData::new(obs_data, [1, 4]),
|
||||
&infer_device(),
|
||||
);
|
||||
|
||||
let (p_train, v_train) = train_model.forward(obs_train);
|
||||
let (p_infer, v_infer) = infer_model.forward(obs_infer);
|
||||
|
||||
let p_train: Vec<f32> = p_train.into_data().to_vec().unwrap();
|
||||
let p_infer: Vec<f32> = p_infer.into_data().to_vec().unwrap();
|
||||
let v_train: Vec<f32> = v_train.into_data().to_vec().unwrap();
|
||||
let v_infer: Vec<f32> = v_infer.into_data().to_vec().unwrap();
|
||||
|
||||
for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() {
|
||||
assert!(
|
||||
(a - b).abs() < 1e-5,
|
||||
"policy[{i}] differs after valid(): train={a}, infer={b}"
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
(v_train[0] - v_infer[0]).abs() < 1e-5,
|
||||
"value differs after valid(): train={}, infer={}",
|
||||
v_train[0], v_infer[0]
|
||||
);
|
||||
}
|
||||
|
||||
// ── 6. Loss converges on a fixed batch ────────────────────────────────────
|
||||
|
||||
/// With repeated gradient steps on the same Countdown batch, the loss must
|
||||
/// decrease monotonically (or at least end lower than it started).
|
||||
#[test]
|
||||
fn loss_decreases_on_fixed_batch() {
|
||||
let env = CountdownEnv(6);
|
||||
let mut rng = seeded();
|
||||
let mcts = tiny_mcts(3);
|
||||
let model = tiny_mlp(env.obs_size(), env.action_space());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
|
||||
// Collect a fixed batch from one episode.
|
||||
let infer = model.valid();
|
||||
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
|
||||
let samples: Vec<TrainSample> = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng);
|
||||
assert!(!samples.is_empty());
|
||||
|
||||
let batch: Vec<TrainSample> = {
|
||||
let mut replay = ReplayBuffer::new(1000);
|
||||
replay.extend(samples);
|
||||
replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect()
|
||||
};
|
||||
|
||||
// Overfit on the same fixed batch for 20 steps.
|
||||
let mut model = tiny_mlp(env.obs_size(), env.action_space());
|
||||
let mut first_loss = f32::NAN;
|
||||
let mut last_loss = f32::NAN;
|
||||
|
||||
for step in 0..20 {
|
||||
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2);
|
||||
model = m;
|
||||
assert!(loss.is_finite(), "loss is not finite at step {step}");
|
||||
if step == 0 { first_loss = loss; }
|
||||
last_loss = loss;
|
||||
}
|
||||
|
||||
assert!(
|
||||
last_loss < first_loss,
|
||||
"loss did not decrease after 20 steps: first={first_loss}, last={last_loss}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── 7. Trictrac: multi-iteration loop ─────────────────────────────────────
|
||||
|
||||
/// Two full self-play + train iterations on TrictracEnv.
|
||||
/// Verifies the entire pipeline runs without panic end-to-end.
|
||||
#[test]
|
||||
fn trictrac_two_iteration_loop() {
|
||||
let env = TrictracEnv;
|
||||
let mut rng = seeded();
|
||||
let mcts = tiny_mcts(2);
|
||||
|
||||
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
let mut model = MlpNet::<Train>::new(&cfg, &train_device());
|
||||
let mut optimizer = AdamConfig::new().init();
|
||||
let mut replay = ReplayBuffer::new(20_000);
|
||||
|
||||
for iter in 0..2 {
|
||||
// Self-play: 1 game per iteration.
|
||||
let infer: MlpNet<Infer> = model.valid();
|
||||
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
|
||||
let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
|
||||
assert!(!samples.is_empty(), "iter {iter}: episode was empty");
|
||||
replay.extend(samples);
|
||||
|
||||
// Training: 3 gradient steps.
|
||||
let batch_size = 16.min(replay.len());
|
||||
for _ in 0..3 {
|
||||
let batch: Vec<TrainSample> = replay
|
||||
.sample_batch(batch_size, &mut rng)
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
|
||||
model = m;
|
||||
assert!(loss.is_finite(), "iter {iter}: loss={loss}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -25,9 +25,5 @@ rand = "0.9"
|
|||
serde = { version = "1.0", features = ["derive"] }
|
||||
transpose = "0.2.2"
|
||||
|
||||
[[bin]]
|
||||
name = "random_game"
|
||||
path = "src/bin/random_game.rs"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,262 +0,0 @@
|
|||
//! Run one or many games of trictrac between two random players.
|
||||
//! In single-game mode, prints play-by-play like OpenSpiel's `example.cc`.
|
||||
//! In multi-game mode, runs silently and reports throughput at the end.
|
||||
//!
|
||||
//! Usage:
|
||||
//! cargo run --bin random_game -- [--seed <u64>] [--games <usize>] [--max-steps <usize>] [--verbose]
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::env;
|
||||
use std::time::Instant;
|
||||
|
||||
use trictrac_store::{
|
||||
training_common::sample_valid_action,
|
||||
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
|
||||
};
|
||||
|
||||
// ── CLI args ──────────────────────────────────────────────────────────────────
|
||||
|
||||
struct Args {
|
||||
seed: Option<u64>,
|
||||
games: usize,
|
||||
max_steps: usize,
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let mut seed = None;
|
||||
let mut games = 1;
|
||||
let mut max_steps = 10_000;
|
||||
let mut verbose = false;
|
||||
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--seed" => {
|
||||
i += 1;
|
||||
seed = args.get(i).and_then(|s| s.parse().ok());
|
||||
}
|
||||
"--games" => {
|
||||
i += 1;
|
||||
if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) {
|
||||
games = v;
|
||||
}
|
||||
}
|
||||
"--max-steps" => {
|
||||
i += 1;
|
||||
if let Some(v) = args.get(i).and_then(|s| s.parse().ok()) {
|
||||
max_steps = v;
|
||||
}
|
||||
}
|
||||
"--verbose" => verbose = true,
|
||||
_ => {}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
Args {
|
||||
seed,
|
||||
games,
|
||||
max_steps,
|
||||
verbose,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn player_label(id: u64) -> &'static str {
|
||||
if id == 1 { "White" } else { "Black" }
|
||||
}
|
||||
|
||||
/// Apply a `Roll` + `RollResult` in one logical step, returning the dice.
|
||||
/// This collapses the two-step dice phase into a single "chance node" action,
|
||||
/// matching how the OpenSpiel layer exposes it.
|
||||
fn apply_dice_roll(state: &mut GameState, roller: &mut DiceRoller) -> Result<Dice, String> {
|
||||
// RollDice → RollWaiting
|
||||
state
|
||||
.consume(&GameEvent::Roll { player_id: state.active_player_id })
|
||||
.map_err(|e| format!("Roll event failed: {e}"))?;
|
||||
|
||||
// RollWaiting → Move / HoldOrGoChoice (or Stage::Ended if 13th hole)
|
||||
let dice = roller.roll();
|
||||
state
|
||||
.consume(&GameEvent::RollResult { player_id: state.active_player_id, dice })
|
||||
.map_err(|e| format!("RollResult event failed: {e}"))?;
|
||||
|
||||
Ok(dice)
|
||||
}
|
||||
|
||||
/// Sample a random action and apply it to `state`, handling the Black-mirror
|
||||
/// transform exactly as `cxxengine.rs::apply_action` does:
|
||||
///
|
||||
/// 1. For Black, build a mirrored view of the state so that `sample_valid_action`
|
||||
/// and `to_event` always reason from White's perspective.
|
||||
/// 2. Mirror the resulting event back to the original coordinate frame before
|
||||
/// calling `state.consume`.
|
||||
///
|
||||
/// Returns the chosen action (in the view's coordinate frame) for display.
|
||||
fn apply_player_action(state: &mut GameState) -> Result<(), String> {
|
||||
let needs_mirror = state.active_player_id == 2;
|
||||
|
||||
// Build a White-perspective view: borrowed for White, owned mirror for Black.
|
||||
let view: Cow<GameState> = if needs_mirror {
|
||||
Cow::Owned(state.mirror())
|
||||
} else {
|
||||
Cow::Borrowed(state)
|
||||
};
|
||||
|
||||
let action = sample_valid_action(&view)
|
||||
.ok_or_else(|| format!("no valid action in stage {:?}", state.turn_stage))?;
|
||||
|
||||
let event = action
|
||||
.to_event(&view)
|
||||
.ok_or_else(|| format!("could not convert {action:?} to event"))?;
|
||||
|
||||
// Translate the event from the view's frame back to the game's frame.
|
||||
let event = if needs_mirror { event.get_mirror(false) } else { event };
|
||||
|
||||
state
|
||||
.consume(&event)
|
||||
.map_err(|e| format!("consume({action:?}): {e}"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Single game ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run one full game, optionally printing play-by-play.
|
||||
/// Returns `(steps, truncated)`.
|
||||
fn run_game(roller: &mut DiceRoller, max_steps: usize, quiet: bool, verbose: bool) -> (usize, bool) {
|
||||
let mut state = GameState::new_with_players("White", "Black");
|
||||
let mut step = 0usize;
|
||||
|
||||
if !quiet {
|
||||
println!("{state}");
|
||||
}
|
||||
|
||||
while state.stage != Stage::Ended {
|
||||
step += 1;
|
||||
if step > max_steps {
|
||||
return (step - 1, true);
|
||||
}
|
||||
|
||||
match state.turn_stage {
|
||||
TurnStage::RollDice => {
|
||||
let player = state.active_player_id;
|
||||
match apply_dice_roll(&mut state, roller) {
|
||||
Ok(dice) => {
|
||||
if !quiet {
|
||||
println!(
|
||||
"[step {step:4}] {} rolls: {} & {}",
|
||||
player_label(player),
|
||||
dice.values.0,
|
||||
dice.values.1
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error during dice roll: {e}");
|
||||
eprintln!("State:\n{state}");
|
||||
return (step, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
stage => {
|
||||
let player = state.active_player_id;
|
||||
match apply_player_action(&mut state) {
|
||||
Ok(()) => {
|
||||
if !quiet {
|
||||
println!(
|
||||
"[step {step:4}] {} ({stage:?})",
|
||||
player_label(player)
|
||||
);
|
||||
if verbose {
|
||||
println!("{state}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error: {e}");
|
||||
eprintln!("State:\n{state}");
|
||||
return (step, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !quiet {
|
||||
println!("\n=== Game over after {step} steps ===\n");
|
||||
println!("{state}");
|
||||
|
||||
let white = state.players.get(&1);
|
||||
let black = state.players.get(&2);
|
||||
|
||||
match (white, black) {
|
||||
(Some(w), Some(b)) => {
|
||||
println!("White — holes: {:2}, points: {:2}", w.holes, w.points);
|
||||
println!("Black — holes: {:2}, points: {:2}", b.holes, b.points);
|
||||
println!();
|
||||
|
||||
let white_score = w.holes as i32 * 12 + w.points as i32;
|
||||
let black_score = b.holes as i32 * 12 + b.points as i32;
|
||||
|
||||
if white_score > black_score {
|
||||
println!("Winner: White (+{})", white_score - black_score);
|
||||
} else if black_score > white_score {
|
||||
println!("Winner: Black (+{})", black_score - white_score);
|
||||
} else {
|
||||
println!("Draw");
|
||||
}
|
||||
}
|
||||
_ => eprintln!("Could not read final player scores."),
|
||||
}
|
||||
}
|
||||
|
||||
(step, false)
|
||||
}
|
||||
|
||||
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn main() {
|
||||
let args = parse_args();
|
||||
let mut roller = DiceRoller::new(args.seed);
|
||||
|
||||
if args.games == 1 {
|
||||
println!("=== Trictrac — random game ===");
|
||||
if let Some(s) = args.seed {
|
||||
println!("seed: {s}");
|
||||
}
|
||||
println!();
|
||||
run_game(&mut roller, args.max_steps, false, args.verbose);
|
||||
} else {
|
||||
println!("=== Trictrac — {} games ===", args.games);
|
||||
if let Some(s) = args.seed {
|
||||
println!("seed: {s}");
|
||||
}
|
||||
println!();
|
||||
|
||||
let mut total_steps = 0u64;
|
||||
let mut truncated = 0usize;
|
||||
|
||||
let t0 = Instant::now();
|
||||
for _ in 0..args.games {
|
||||
let (steps, trunc) = run_game(&mut roller, args.max_steps, !args.verbose, args.verbose);
|
||||
total_steps += steps as u64;
|
||||
if trunc {
|
||||
truncated += 1;
|
||||
}
|
||||
}
|
||||
let elapsed = t0.elapsed();
|
||||
|
||||
let secs = elapsed.as_secs_f64();
|
||||
println!("Games : {}", args.games);
|
||||
println!("Truncated : {truncated}");
|
||||
println!("Total steps: {total_steps}");
|
||||
println!("Avg steps : {:.1}", total_steps as f64 / args.games as f64);
|
||||
println!("Elapsed : {:.3} s", secs);
|
||||
println!("Throughput : {:.1} games/s", args.games as f64 / secs);
|
||||
println!(" {:.0} steps/s", total_steps as f64 / secs);
|
||||
}
|
||||
}
|
||||
|
|
@ -598,40 +598,12 @@ impl Board {
|
|||
core::array::from_fn(|i| i + min)
|
||||
}
|
||||
|
||||
/// Returns cumulative white-checker counts: result[i] = # white checkers in fields 1..=i.
|
||||
/// result[0] = 0.
|
||||
pub fn white_checker_cumulative(&self) -> [u8; 25] {
|
||||
let mut cum = [0u8; 25];
|
||||
let mut total = 0u8;
|
||||
for (i, &count) in self.positions.iter().enumerate() {
|
||||
if count > 0 {
|
||||
total += count as u8;
|
||||
}
|
||||
cum[i + 1] = total;
|
||||
}
|
||||
cum
|
||||
}
|
||||
|
||||
pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> {
|
||||
self.remove_checker(color, cmove.from)?;
|
||||
self.add_checker(color, cmove.to)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reverse a previously applied `move_checker`. No validation: assumes the move was valid.
|
||||
pub fn unmove_checker(&mut self, color: &Color, cmove: CheckerMove) {
|
||||
let unit = match color {
|
||||
Color::White => 1,
|
||||
Color::Black => -1,
|
||||
};
|
||||
if cmove.from != 0 {
|
||||
self.positions[cmove.from - 1] += unit;
|
||||
}
|
||||
if cmove.to != 0 {
|
||||
self.positions[cmove.to - 1] -= unit;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> {
|
||||
if field == 0 {
|
||||
return Ok(());
|
||||
|
|
|
|||
|
|
@ -83,8 +83,8 @@ pub mod ffi {
|
|||
/// Both players' scores.
|
||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||
|
||||
/// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<f32>;
|
||||
/// 36-element state vector (i8). Mirrored for player_idx == 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||
|
||||
/// Human-readable state description for `player_idx`.
|
||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||
|
|
@ -153,7 +153,8 @@ impl TricTracEngine {
|
|||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
get_valid_action_indices(&mirror)
|
||||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
|
@ -179,11 +180,11 @@ impl TricTracEngine {
|
|||
.unwrap_or(-1)
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_tensor()
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_tensor()
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -242,9 +243,8 @@ impl TricTracEngine {
|
|||
self.game_state
|
||||
),
|
||||
None => anyhow::bail!(
|
||||
"apply_action: could not build event from action index {} in state {}",
|
||||
action_idx,
|
||||
self.game_state
|
||||
"apply_action: could not build event from action index {}",
|
||||
action_idx
|
||||
),
|
||||
}
|
||||
}))
|
||||
|
|
|
|||
|
|
@ -156,6 +156,13 @@ impl GameState {
|
|||
if let Some(p1) = self.players.get(&1) {
|
||||
mirrored_players.insert(2, p1.mirror());
|
||||
}
|
||||
let mirrored_history = self
|
||||
.history
|
||||
.clone()
|
||||
.iter()
|
||||
.map(|evt| evt.get_mirror(false))
|
||||
.collect();
|
||||
|
||||
let (move1, move2) = self.dice_moves;
|
||||
GameState {
|
||||
stage: self.stage,
|
||||
|
|
@ -164,7 +171,7 @@ impl GameState {
|
|||
active_player_id: mirrored_active_player,
|
||||
// active_player_id: self.active_player_id,
|
||||
players: mirrored_players,
|
||||
history: Vec::new(),
|
||||
history: mirrored_history,
|
||||
dice: self.dice,
|
||||
dice_points: self.dice_points,
|
||||
dice_moves: (move1.mirror(), move2.mirror()),
|
||||
|
|
@ -200,110 +207,6 @@ impl GameState {
|
|||
self.to_vec().iter().map(|&x| x as f32).collect()
|
||||
}
|
||||
|
||||
/// Get state as a tensor for neural network training (Option B, TD-Gammon style).
|
||||
/// Returns 217 f32 values, all normalized to [0, 1].
|
||||
///
|
||||
/// Must be called from the active player's perspective: callers should mirror
|
||||
/// the GameState for Black before calling so that "own" always means White.
|
||||
///
|
||||
/// Layout:
|
||||
/// [0..95] own (White) checkers: 4 values per field × 24 fields
|
||||
/// [96..191] opp (Black) checkers: 4 values per field × 24 fields
|
||||
/// [192..193] dice values / 6
|
||||
/// [194] active player color (0=White, 1=Black)
|
||||
/// [195] turn_stage / 5
|
||||
/// [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille
|
||||
/// [200..203] Black player: same
|
||||
/// [204..207] own quarter filled (quarters 1-4)
|
||||
/// [208..211] opp quarter filled (quarters 1-4)
|
||||
/// [212] own checkers all in exit zone (fields 19-24)
|
||||
/// [213] opp checkers all in exit zone (fields 1-6)
|
||||
/// [214] own coin de repos taken (field 12 has ≥2 own checkers)
|
||||
/// [215] opp coin de repos taken (field 13 has ≥2 opp checkers)
|
||||
/// [216] own dice_roll_count / 3, clamped to 1
|
||||
pub fn to_tensor(&self) -> Vec<f32> {
|
||||
let mut t = Vec::with_capacity(217);
|
||||
let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
|
||||
|
||||
// [0..95] own (White) checkers, TD-Gammon encoding.
|
||||
// Each field contributes 4 values:
|
||||
// (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1]
|
||||
// The overflow term is divided by 12 because the maximum excess is
|
||||
// 15 (all checkers) − 3 = 12.
|
||||
for &c in &pos {
|
||||
let own = c.max(0) as u8;
|
||||
t.push((own == 1) as u8 as f32);
|
||||
t.push((own == 2) as u8 as f32);
|
||||
t.push((own == 3) as u8 as f32);
|
||||
t.push(own.saturating_sub(3) as f32 / 12.0);
|
||||
}
|
||||
|
||||
// [96..191] opp (Black) checkers, same encoding.
|
||||
for &c in &pos {
|
||||
let opp = (-c).max(0) as u8;
|
||||
t.push((opp == 1) as u8 as f32);
|
||||
t.push((opp == 2) as u8 as f32);
|
||||
t.push((opp == 3) as u8 as f32);
|
||||
t.push(opp.saturating_sub(3) as f32 / 12.0);
|
||||
}
|
||||
|
||||
// [192..193] dice
|
||||
t.push(self.dice.values.0 as f32 / 6.0);
|
||||
t.push(self.dice.values.1 as f32 / 6.0);
|
||||
|
||||
// [194] active player color
|
||||
t.push(
|
||||
self.who_plays()
|
||||
.map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 })
|
||||
.unwrap_or(0.0),
|
||||
);
|
||||
|
||||
// [195] turn stage
|
||||
t.push(u8::from(self.turn_stage) as f32 / 5.0);
|
||||
|
||||
// [196..199] White player stats
|
||||
let wp = self.get_white_player();
|
||||
t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0));
|
||||
t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0));
|
||||
t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
|
||||
t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
|
||||
|
||||
// [200..203] Black player stats
|
||||
let bp = self.get_black_player();
|
||||
t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0));
|
||||
t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0));
|
||||
t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
|
||||
t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
|
||||
|
||||
// [204..207] own (White) quarter fill status
|
||||
for &start in &[1usize, 7, 13, 19] {
|
||||
t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32);
|
||||
}
|
||||
|
||||
// [208..211] opp (Black) quarter fill status
|
||||
for &start in &[1usize, 7, 13, 19] {
|
||||
t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32);
|
||||
}
|
||||
|
||||
// [212] can_exit_own: no own checker in fields 1-18
|
||||
t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32);
|
||||
|
||||
// [213] can_exit_opp: no opp checker in fields 7-24
|
||||
t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32);
|
||||
|
||||
// [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers)
|
||||
t.push((pos[11] >= 2) as u8 as f32);
|
||||
|
||||
// [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers)
|
||||
t.push((pos[12] <= -2) as u8 as f32);
|
||||
|
||||
// [216] own dice_roll_count / 3, clamped to 1
|
||||
t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0));
|
||||
|
||||
debug_assert_eq!(t.len(), 217, "to_tensor length mismatch");
|
||||
t
|
||||
}
|
||||
|
||||
/// Get state as a vector (to be used for bot training input) :
|
||||
/// length = 36
|
||||
/// i8 for board positions with negative values for blacks
|
||||
|
|
@ -1011,16 +914,6 @@ impl GameState {
|
|||
self.mark_points(player_id, points)
|
||||
}
|
||||
|
||||
/// Total accumulated score for a player: `holes × 12 + points`.
|
||||
///
|
||||
/// Returns `0` if `player_id` is not found (e.g. before `init_player`).
|
||||
pub fn total_score(&self, player_id: PlayerId) -> i32 {
|
||||
self.players
|
||||
.get(&player_id)
|
||||
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
// Update player points and holes
|
||||
let mut new_hole = false;
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ impl MoveRules {
|
|||
// Si possible, les deux dés doivent être joués
|
||||
if moves.0.get_from() == 0 || moves.1.get_from() == 0 {
|
||||
let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]);
|
||||
possible_moves_sequences.retain(|moves| self.check_exit_rules(moves, None).is_ok());
|
||||
possible_moves_sequences.retain(|moves| self.check_exit_rules(moves).is_ok());
|
||||
// possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok());
|
||||
if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() {
|
||||
if *moves == (EMPTY_MOVE, EMPTY_MOVE) {
|
||||
|
|
@ -238,7 +238,7 @@ impl MoveRules {
|
|||
|
||||
// check exit rules
|
||||
// if !ignored_rules.contains(&TricTracRule::Exit) {
|
||||
self.check_exit_rules(moves, None)?;
|
||||
self.check_exit_rules(moves)?;
|
||||
// }
|
||||
|
||||
// --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ----
|
||||
|
|
@ -321,11 +321,7 @@ impl MoveRules {
|
|||
.is_empty()
|
||||
}
|
||||
|
||||
fn check_exit_rules(
|
||||
&self,
|
||||
moves: &(CheckerMove, CheckerMove),
|
||||
exit_seqs: Option<&[(CheckerMove, CheckerMove)]>,
|
||||
) -> Result<(), MoveError> {
|
||||
fn check_exit_rules(&self, moves: &(CheckerMove, CheckerMove)) -> Result<(), MoveError> {
|
||||
if !moves.0.is_exit() && !moves.1.is_exit() {
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -335,22 +331,16 @@ impl MoveRules {
|
|||
}
|
||||
|
||||
// toutes les sorties directes sont autorisées, ainsi que les nombres défaillants
|
||||
let owned;
|
||||
let seqs = match exit_seqs {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
owned = self
|
||||
.get_possible_moves_sequences(false, vec![TricTracRule::Exit]);
|
||||
&owned
|
||||
}
|
||||
};
|
||||
if seqs.contains(moves) {
|
||||
let ignored_rules = vec![TricTracRule::Exit];
|
||||
let possible_moves_sequences_without_excedent =
|
||||
self.get_possible_moves_sequences(false, ignored_rules);
|
||||
if possible_moves_sequences_without_excedent.contains(moves) {
|
||||
return Ok(());
|
||||
}
|
||||
// À ce stade au moins un des déplacements concerne un nombre en excédant
|
||||
// - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on
|
||||
// refuse cette séquence
|
||||
if !seqs.is_empty() {
|
||||
if !possible_moves_sequences_without_excedent.is_empty() {
|
||||
return Err(MoveError::ExitByEffectPossible);
|
||||
}
|
||||
|
||||
|
|
@ -371,24 +361,17 @@ impl MoveRules {
|
|||
let _ = board_to_check.move_checker(&Color::White, moves.0);
|
||||
let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check);
|
||||
|
||||
// dice normal order
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, true);
|
||||
let is_not_farthest1 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2);
|
||||
|
||||
// dice reversed order
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, false);
|
||||
let is_not_farthest2 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2);
|
||||
|
||||
if is_not_farthest1 && is_not_farthest2 {
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves);
|
||||
if (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2)
|
||||
{
|
||||
return Err(MoveError::ExitNotFarthest);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn move_excedants(&self, moves: &(CheckerMove, CheckerMove), dice_order: bool) -> (bool, bool) {
|
||||
fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) {
|
||||
let move1to = if moves.0.get_to() == 0 {
|
||||
25
|
||||
} else {
|
||||
|
|
@ -403,16 +386,20 @@ impl MoveRules {
|
|||
};
|
||||
let dist2 = move2to - moves.1.get_from();
|
||||
|
||||
let (dice1, dice2) = if dice_order {
|
||||
self.dice.values
|
||||
} else {
|
||||
(self.dice.values.1, self.dice.values.0)
|
||||
};
|
||||
let dist_min = cmp::min(dist1, dist2);
|
||||
let dist_max = cmp::max(dist1, dist2);
|
||||
|
||||
(
|
||||
dist1 != 0 && dist1 < dice1 as usize,
|
||||
dist2 != 0 && dist2 < dice2 as usize,
|
||||
)
|
||||
let dice_min = cmp::min(self.dice.values.0, self.dice.values.1) as usize;
|
||||
let dice_max = cmp::max(self.dice.values.0, self.dice.values.1) as usize;
|
||||
|
||||
let min_excedant = dist_min != 0 && dist_min < dice_min;
|
||||
let max_excedant = dist_max != 0 && dist_max < dice_max;
|
||||
|
||||
if dist_min == dist1 {
|
||||
(min_excedant, max_excedant)
|
||||
} else {
|
||||
(max_excedant, min_excedant)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_board_exit_farthest(board: &Board) -> Field {
|
||||
|
|
@ -451,18 +438,12 @@ impl MoveRules {
|
|||
} else {
|
||||
(dice2, dice1)
|
||||
};
|
||||
let filling_seqs = if !ignored_rules.contains(&TricTracRule::MustFillQuarter) {
|
||||
Some(self.get_quarter_filling_moves_sequences())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut moves_seqs = self.get_possible_moves_sequences_by_dices(
|
||||
dice_max,
|
||||
dice_min,
|
||||
with_excedents,
|
||||
false,
|
||||
&ignored_rules,
|
||||
filling_seqs.as_deref(),
|
||||
ignored_rules.clone(),
|
||||
);
|
||||
// if we got valid sequences with the highest die, we don't accept sequences using only the
|
||||
// lowest die
|
||||
|
|
@ -472,8 +453,7 @@ impl MoveRules {
|
|||
dice_max,
|
||||
with_excedents,
|
||||
ignore_empty,
|
||||
&ignored_rules,
|
||||
filling_seqs.as_deref(),
|
||||
ignored_rules,
|
||||
);
|
||||
moves_seqs.append(&mut moves_seqs_order2);
|
||||
let empty_removed = moves_seqs
|
||||
|
|
@ -544,16 +524,14 @@ impl MoveRules {
|
|||
let mut moves_seqs = Vec::new();
|
||||
let color = &Color::White;
|
||||
let ignored_rules = vec![TricTracRule::Exit, TricTracRule::MustFillQuarter];
|
||||
let mut board = self.board.clone();
|
||||
for moves in self.get_possible_moves_sequences(true, ignored_rules) {
|
||||
let mut board = self.board.clone();
|
||||
board.move_checker(color, moves.0).unwrap();
|
||||
board.move_checker(color, moves.1).unwrap();
|
||||
// println!("get_quarter_filling_moves_sequences board : {:?}", board);
|
||||
if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) {
|
||||
moves_seqs.push(moves);
|
||||
}
|
||||
board.unmove_checker(color, moves.1);
|
||||
board.unmove_checker(color, moves.0);
|
||||
}
|
||||
moves_seqs
|
||||
}
|
||||
|
|
@ -564,27 +542,18 @@ impl MoveRules {
|
|||
dice2: u8,
|
||||
with_excedents: bool,
|
||||
ignore_empty: bool,
|
||||
ignored_rules: &[TricTracRule],
|
||||
filling_seqs: Option<&[(CheckerMove, CheckerMove)]>,
|
||||
ignored_rules: Vec<TricTracRule>,
|
||||
) -> Vec<(CheckerMove, CheckerMove)> {
|
||||
let mut moves_seqs = Vec::new();
|
||||
let color = &Color::White;
|
||||
let forbid_exits = self.has_checkers_outside_last_quarter();
|
||||
// Precompute non-excedant sequences once so check_exit_rules need not repeat
|
||||
// the full move generation for every exit-move candidate.
|
||||
// Only needed when Exit is not already ignored and exits are actually reachable.
|
||||
let exit_seqs = if !ignored_rules.contains(&TricTracRule::Exit) && !forbid_exits {
|
||||
Some(self.get_possible_moves_sequences(false, vec![TricTracRule::Exit]))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut board = self.board.clone();
|
||||
// println!("==== First");
|
||||
for first_move in
|
||||
self.board
|
||||
.get_possible_moves(*color, dice1, with_excedents, false, forbid_exits)
|
||||
{
|
||||
if board.move_checker(color, first_move).is_err() {
|
||||
let mut board2 = self.board.clone();
|
||||
if board2.move_checker(color, first_move).is_err() {
|
||||
println!("err move");
|
||||
continue;
|
||||
}
|
||||
|
|
@ -594,7 +563,7 @@ impl MoveRules {
|
|||
let mut has_second_dice_move = false;
|
||||
// println!(" ==== Second");
|
||||
for second_move in
|
||||
board.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits)
|
||||
board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits)
|
||||
{
|
||||
if self
|
||||
.check_corner_rules(&(first_move, second_move))
|
||||
|
|
@ -618,10 +587,24 @@ impl MoveRules {
|
|||
&& self.can_take_corner_by_effect())
|
||||
&& (ignored_rules.contains(&TricTracRule::Exit)
|
||||
|| self
|
||||
.check_exit_rules(&(first_move, second_move), exit_seqs.as_deref())
|
||||
.check_exit_rules(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd (exit rule): {:?} - {:?}, {:?}",
|
||||
// e, first_move, second_move
|
||||
// )
|
||||
// })
|
||||
.is_ok())
|
||||
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter)
|
||||
|| self
|
||||
.check_must_fill_quarter_rule(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd: {:?} - {:?}, {:?} for {:?}",
|
||||
// e, first_move, second_move, self.board
|
||||
// )
|
||||
// })
|
||||
.is_ok())
|
||||
&& filling_seqs
|
||||
.map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, second_move)))
|
||||
{
|
||||
if second_move.get_to() == 0
|
||||
&& first_move.get_to() == 0
|
||||
|
|
@ -644,14 +627,16 @@ impl MoveRules {
|
|||
&& !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE))
|
||||
&& self.can_take_corner_by_effect())
|
||||
&& (ignored_rules.contains(&TricTracRule::Exit)
|
||||
|| self.check_exit_rules(&(first_move, EMPTY_MOVE), exit_seqs.as_deref()).is_ok())
|
||||
&& filling_seqs
|
||||
.map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, EMPTY_MOVE)))
|
||||
|| self.check_exit_rules(&(first_move, EMPTY_MOVE)).is_ok())
|
||||
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter)
|
||||
|| self
|
||||
.check_must_fill_quarter_rule(&(first_move, EMPTY_MOVE))
|
||||
.is_ok())
|
||||
{
|
||||
// empty move
|
||||
moves_seqs.push((first_move, EMPTY_MOVE));
|
||||
}
|
||||
board.unmove_checker(color, first_move);
|
||||
//if board2.get_color_fields(*color).is_empty() {
|
||||
}
|
||||
moves_seqs
|
||||
}
|
||||
|
|
@ -1510,7 +1495,6 @@ mod tests {
|
|||
CheckerMove::new(23, 0).unwrap(),
|
||||
CheckerMove::new(24, 0).unwrap(),
|
||||
);
|
||||
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences_by_dices(
|
||||
|
|
@ -1518,8 +1502,7 @@ mod tests {
|
|||
state.dice.values.1,
|
||||
true,
|
||||
false,
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
vec![]
|
||||
)
|
||||
);
|
||||
|
||||
|
|
@ -1534,7 +1517,6 @@ mod tests {
|
|||
CheckerMove::new(19, 23).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
)];
|
||||
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
|
||||
assert_eq!(
|
||||
moves,
|
||||
state.get_possible_moves_sequences_by_dices(
|
||||
|
|
@ -1542,8 +1524,7 @@ mod tests {
|
|||
state.dice.values.1,
|
||||
true,
|
||||
false,
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
vec![]
|
||||
)
|
||||
);
|
||||
let moves = vec![(
|
||||
|
|
@ -1557,8 +1538,7 @@ mod tests {
|
|||
state.dice.values.0,
|
||||
true,
|
||||
false,
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
vec![]
|
||||
)
|
||||
);
|
||||
|
||||
|
|
@ -1574,7 +1554,6 @@ mod tests {
|
|||
CheckerMove::new(19, 21).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
);
|
||||
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences_by_dices(
|
||||
|
|
@ -1582,8 +1561,7 @@ mod tests {
|
|||
state.dice.values.1,
|
||||
true,
|
||||
false,
|
||||
&[],
|
||||
filling_seqs.as_deref(),
|
||||
vec![]
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
@ -1602,26 +1580,13 @@ mod tests {
|
|||
CheckerMove::new(19, 23).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
);
|
||||
assert!(state.check_exit_rules(&moves, None).is_ok());
|
||||
assert!(state.check_exit_rules(&moves).is_ok());
|
||||
|
||||
let moves = (
|
||||
CheckerMove::new(19, 24).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
);
|
||||
assert!(state.check_exit_rules(&moves, None).is_ok());
|
||||
|
||||
state.dice.values = (6, 4);
|
||||
state.board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
-4, -1, -2, -1, 0, 0, 0, -1, 0, 0, 0, 0, -5, -1, 0, 0, 0, 0, 2, 3, 2, 2, 5, 1,
|
||||
],
|
||||
);
|
||||
let moves = (
|
||||
CheckerMove::new(20, 24).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
);
|
||||
assert!(state.check_exit_rules(&moves, None).is_ok());
|
||||
assert!(state.check_exit_rules(&moves).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -113,11 +113,11 @@ impl TricTrac {
|
|||
[self.get_score(1), self.get_score(2)]
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_tensor()
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_tensor()
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
use std::cmp::{max, min};
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
|
||||
use crate::board::Board;
|
||||
use crate::{CheckerMove, Dice, GameEvent, GameState};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -220,14 +221,10 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
|||
// Ajoute aussi les mouvements possibles
|
||||
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
// rules.board is already White-perspective (mirrored if Black): compute cum once.
|
||||
let cum = rules.board.white_checker_cumulative();
|
||||
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(white_checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
&cum,
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1, &move2, &color, game_state,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
|
@ -238,14 +235,10 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
|||
// Empty move
|
||||
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
|
||||
}
|
||||
// rules.board is already White-perspective (mirrored if Black): compute cum once.
|
||||
let cum = rules.board.white_checker_cumulative();
|
||||
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(white_checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
&cum,
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1, &move2, &color, game_state,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
|
@ -258,27 +251,36 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
|||
Ok(valid_actions)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
color: &crate::Color,
|
||||
state: &GameState,
|
||||
) -> anyhow::Result<TrictracAction> {
|
||||
// Moves are always in White's coordinate system. For Black, mirror the board first.
|
||||
let cum = if color == &crate::Color::Black {
|
||||
state.board.mirror().white_checker_cumulative()
|
||||
let dice = &state.dice;
|
||||
let board = &state.board;
|
||||
|
||||
if color == &crate::Color::Black {
|
||||
// Moves are already 'white', so we don't mirror them
|
||||
white_checker_moves_to_trictrac_action(
|
||||
move1,
|
||||
move2,
|
||||
// &move1.clone().mirror(),
|
||||
// &move2.clone().mirror(),
|
||||
dice,
|
||||
&board.clone().mirror(),
|
||||
)
|
||||
// .map(|a| a.mirror())
|
||||
} else {
|
||||
state.board.white_checker_cumulative()
|
||||
};
|
||||
white_checker_moves_to_trictrac_action(move1, move2, &state.dice, &cum)
|
||||
white_checker_moves_to_trictrac_action(move1, move2, dice, board)
|
||||
}
|
||||
}
|
||||
|
||||
fn white_checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
dice: &Dice,
|
||||
cum: &[u8; 25],
|
||||
board: &Board,
|
||||
) -> anyhow::Result<TrictracAction> {
|
||||
let to1 = move1.get_to();
|
||||
let to2 = move2.get_to();
|
||||
|
|
@ -300,7 +302,7 @@ fn white_checker_moves_to_trictrac_action(
|
|||
}
|
||||
} else {
|
||||
// double sortie
|
||||
if from1 < from2 || from2 == 0 {
|
||||
if from1 < from2 {
|
||||
max(dice.values.0, dice.values.1) as usize
|
||||
} else {
|
||||
min(dice.values.0, dice.values.1) as usize
|
||||
|
|
@ -319,21 +321,11 @@ fn white_checker_moves_to_trictrac_action(
|
|||
}
|
||||
let dice_order = diff_move1 == dice.values.0 as usize;
|
||||
|
||||
// cum[i] = # white checkers in fields 1..=i (precomputed by the caller).
|
||||
// checker1 is the ordinal of the last checker at from1.
|
||||
let checker1 = cum[from1] as usize;
|
||||
// checker2 is the ordinal on the board after move1 (removed from from1, added to to1).
|
||||
// Adjust the cumulative in O(1) without cloning the board.
|
||||
let checker2 = {
|
||||
let mut c = cum[from2];
|
||||
if from1 > 0 && from2 >= from1 {
|
||||
c -= 1; // one checker was removed from from1, shifting later ordinals down
|
||||
}
|
||||
if from1 > 0 && to1 > 0 && from2 >= to1 {
|
||||
c += 1; // one checker was added at to1, shifting later ordinals up
|
||||
}
|
||||
c as usize
|
||||
};
|
||||
let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize;
|
||||
let mut tmp_board = board.clone();
|
||||
// should not raise an error for a valid action
|
||||
tmp_board.move_checker(&crate::Color::White, *move1)?;
|
||||
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize;
|
||||
Ok(TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
|
|
@ -464,48 +456,5 @@ mod tests {
|
|||
}),
|
||||
ttaction.ok()
|
||||
);
|
||||
|
||||
// Black player
|
||||
state.active_player_id = 2;
|
||||
state.dice = Dice { values: (6, 3) };
|
||||
state.board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
2, -11, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 6, 4,
|
||||
],
|
||||
);
|
||||
let ttaction = super::checker_moves_to_trictrac_action(
|
||||
&CheckerMove::new(21, 0).unwrap(),
|
||||
&CheckerMove::new(0, 0).unwrap(),
|
||||
&crate::Color::Black,
|
||||
&state,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some(TrictracAction::Move {
|
||||
dice_order: true,
|
||||
checker1: 2,
|
||||
checker2: 0, // blocked by white on last field
|
||||
}),
|
||||
ttaction.ok()
|
||||
);
|
||||
|
||||
// same with dice order reversed
|
||||
state.dice = Dice { values: (3, 6) };
|
||||
let ttaction = super::checker_moves_to_trictrac_action(
|
||||
&CheckerMove::new(21, 0).unwrap(),
|
||||
&CheckerMove::new(0, 0).unwrap(),
|
||||
&crate::Color::Black,
|
||||
&state,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some(TrictracAction::Move {
|
||||
dice_order: false,
|
||||
checker1: 2,
|
||||
checker2: 0, // blocked by white on last field
|
||||
}),
|
||||
ttaction.ok()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue