Compare commits

...
Sign in to create a new pull request.

31 commits

Author SHA1 Message Date
1554286f25 doc: research parallel 2026-03-12 10:13:07 +01:00
31bb568c2a feat(spiel_bot): az_train parallel games with rayon 2026-03-12 10:13:07 +01:00
e80dade303 fix: --n-sim training parameter 2026-03-11 22:17:03 +01:00
e7d13c9a02 feat(spiel_bot): dqn 2026-03-10 22:12:52 +01:00
7c0f230e3d doc: tensor research 2026-03-10 08:19:24 +01:00
150efe302f feat(spiel_bot): az_train training command 2026-03-09 09:17:17 +01:00
3221b5256a feat(spiel_bot): alphazero eval binary 2026-03-09 09:17:17 +01:00
822290d722 feat(spiel_bot): upgrade network 2026-03-09 09:17:17 +01:00
9c82692ddb feat(spiel_bot): benchmarks 2026-03-09 09:17:17 +01:00
aea1e3faaf tests(spiel_bot): integration tests 2026-03-09 09:17:17 +01:00
519dfe67ad fix(spiel_bot): mcts fix 2026-03-09 09:17:17 +01:00
b0ae4db2d9 feat(spiel_bot): AlphaZero 2026-03-09 09:17:17 +01:00
58ae8ad3b3 feat(spiel_bot): Monte-Carlo tree search 2026-03-09 09:17:17 +01:00
d5cd4c2402 feat(spiel_bot): network with mlp and resnet 2026-03-09 09:17:17 +01:00
df05a43022 feat(spiel_bot): init crate & implements GameEnv trait + TrictracEnv 2026-03-09 09:17:17 +01:00
a6644e3c9d fix: to_tensor() normalization 2026-03-09 09:17:17 +01:00
85ccca4741 doc:rust open_spiel research 2026-03-09 09:17:17 +01:00
1c4c814417 fix: training_common::white_checker_moves_to_trictrac_action 2026-03-09 09:16:30 +01:00
db5c1ea4f4 debug 2026-03-07 13:53:13 +01:00
aa7f5fe42a feat: add get_tensor on GameState more explicit for training than the minimal get_vec() 2026-03-07 12:56:03 +01:00
145ab7dcda Merge branch 'feature/performance' into develop 2026-03-06 22:19:26 +01:00
f26808d798 clean research 2026-03-06 22:19:08 +01:00
43eb5bf18d refact(perf): use board::white_checker_cumulative to convert move to trictracAction 2026-03-06 22:19:08 +01:00
dfc485a47a refact(perf): precompute non excedant get_possible_moves_sequences 2026-03-06 22:19:08 +01:00
a239c02937 refact(perf): less board clones with new function unmove_checker() 2026-03-06 22:19:07 +01:00
6beaa56202 refact(perf): remove moves history from mirror() 2026-03-06 22:19:07 +01:00
45b9db61e3 refact(perf): remove Recursive get_possible_moves_sequences in check_must_fill_quarter_rule 2026-03-06 22:19:07 +01:00
44a5ba87b0 perf research 2026-03-06 18:11:03 +01:00
bd4c75228b fix: exit with farthest rule (2) 2026-03-06 18:09:15 +01:00
8732512736 feat: command to play random games with open_spiel logic 2026-03-06 17:33:28 +01:00
eba93f0f13 Merge tag 'v0.2.0' into develop
v0.2.0
2026-03-06 15:13:28 +01:00
37 changed files with 7106 additions and 1105 deletions

132
Cargo.lock generated
View file

@ -92,6 +92,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.21" version = "0.6.21"
@ -1116,6 +1122,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cast_trait" name = "cast_trait"
version = "0.1.2" version = "0.1.2"
@ -1200,6 +1212,33 @@ dependencies = [
"rand 0.7.3", "rand 0.7.3",
] ]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]] [[package]]
name = "cipher" name = "cipher"
version = "0.4.4" version = "0.4.4"
@ -1453,6 +1492,42 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]] [[package]]
name = "critical-section" name = "critical-section"
version = "1.2.0" version = "1.2.0"
@ -4461,6 +4536,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]] [[package]]
name = "opaque-debug" name = "opaque-debug"
version = "0.3.1" version = "0.3.1"
@ -4597,6 +4678,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]] [[package]]
name = "png" name = "png"
version = "0.18.0" version = "0.18.0"
@ -5891,6 +6000,19 @@ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.60.2",
] ]
[[package]]
name = "spiel_bot"
version = "0.1.0"
dependencies = [
"anyhow",
"burn",
"criterion",
"rand 0.9.2",
"rand_distr",
"rayon",
"trictrac-store",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.10.0" version = "0.10.0"
@ -6299,6 +6421,16 @@ dependencies = [
"zerovec", "zerovec",
] ]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.10.0" version = "1.10.0"

View file

@ -1,4 +1,4 @@
[workspace] [workspace]
resolver = "2" resolver = "2"
members = ["client_cli", "bot", "store"] members = ["client_cli", "bot", "store", "spiel_bot"]

View file

@ -1,992 +0,0 @@
# Plan: C++ OpenSpiel Game via cxx.rs
> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings.
>
> Based on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`.
---
## 1. Overview
The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding:
- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`.
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`.
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box<ffi::TricTracEngine>` and delegates all logic to Rust.
- Build wired together via **corrosion** (CMake-native Rust integration) and `cxx-build`.
The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games.
---
## 2. Files to Create / Modify
```
trictrac/
store/
Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type
build.rs ← CREATE: cxx-build bridge registration
src/
lib.rs ← MODIFY: add cxxengine module
cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl
forks/open_spiel/
CMakeLists.txt ← MODIFY: add Corrosion FetchContent
open_spiel/
games/
CMakeLists.txt ← MODIFY: add trictrac/ sources + test
trictrac/ ← CREATE directory
trictrac.h ← CREATE
trictrac.cc ← CREATE
trictrac_test.cc ← CREATE
justfile ← MODIFY: add buildtrictrac target
trictrac/
justfile ← MODIFY: add cxxlib target
```
---
## 3. Step 1 — Rust: `store/Cargo.toml`
Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library.
```toml
[package]
name = "trictrac-store"
version = "0.1.0"
edition = "2021"
[lib]
name = "trictrac_store"
# cdylib → Python .so (used by maturin / pyengine)
# rlib → used by other Rust crates in the workspace
# staticlib → used by C++ consumers (cxxengine)
crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
base64 = "0.21.7"
cxx = "1.0"
log = "0.4.20"
merge = "0.1.0"
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
rand = "0.9"
serde = { version = "1.0", features = ["derive"] }
transpose = "0.2.2"
[build-dependencies]
cxx-build = "1.0"
```
> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`).
---
## 4. Step 2 — Rust: `store/build.rs`
The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header.
```rust
fn main() {
cxx_build::bridge("src/cxxengine.rs")
.std("c++17")
.compile("trictrac-cxx");
// Re-run if the bridge source changes
println!("cargo:rerun-if-changed=src/cxxengine.rs");
}
```
`cxx-build` will:
- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block.
- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header.
- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source.
- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`).
---
## 5. Step 3 — Rust: `store/src/cxxengine.rs`
This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3.
### Design decisions vs. `pyengine.rs`
| pyengine | cxxengine | Reason |
| ------------------------- | ---------------------------- | -------------------------------------------- |
| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception |
| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples |
| `Vec<usize>` for actions | `Vec<u64>` | cxx does not support `usize` |
| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays |
| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy |
### File content
```rust
//! # C++ bindings for the TricTrac game engine via cxx.rs
//!
//! Exposes an opaque `TricTracEngine` type and associated functions
//! to C++. The C++ side (trictrac.cc) uses `rust::Box<ffi::TricTracEngine>`.
//!
//! The Rust engine always works from the perspective of White (player 1).
//! For Black (player 2), the board is mirrored before computing actions
//! and events are mirrored back before applying — exactly as in pyengine.rs.
use crate::dice::Dice;
use crate::game::{GameEvent, GameState, Stage, TurnStage};
use crate::training_common::{get_valid_action_indices, TrictracAction};
// ── cxx bridge declaration ────────────────────────────────────────────────────
#[cxx::bridge(namespace = "trictrac_engine")]
pub mod ffi {
// ── Shared types (visible to both Rust and C++) ───────────────────────────
/// Two dice values passed from C++ to Rust for a dice-roll event.
struct DicePair {
die1: u8,
die2: u8,
}
/// Both players' scores: holes * 12 + points.
struct PlayerScores {
score_p1: i32,
score_p2: i32,
}
// ── Opaque Rust type exposed to C++ ───────────────────────────────────────
extern "Rust" {
/// Opaque handle to a TricTrac game state.
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
type TricTracEngine;
/// Create a new engine, initialise two players, begin with player 1.
fn new_trictrac_engine() -> Box<TricTracEngine>;
/// Return a deep copy of the engine (needed for State::Clone()).
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
// ── Queries ───────────────────────────────────────────────────────────
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
fn needs_roll(self: &TricTracEngine) -> bool;
/// True when Stage::Ended.
fn is_game_ended(self: &TricTracEngine) -> bool;
/// Active player index: 0 (player 1 / White) or 1 (player 2 / Black).
fn current_player_idx(self: &TricTracEngine) -> u64;
/// Legal action indices for `player_idx`. Returns empty vec if it is
/// not that player's turn. Indices are in [0, 513].
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec<u64>;
/// Human-readable action description, e.g. "0:Move { dice_order: true … }".
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
/// Both players' scores: holes * 12 + points.
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
/// 36-element state observation vector (i8). Mirrored for player 1.
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
/// Human-readable state description for `player_idx`.
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
/// Full debug representation of the current state.
fn to_debug_string(self: &TricTracEngine) -> String;
// ── Mutations ─────────────────────────────────────────────────────────
/// Apply a dice roll result. Returns Err if not in RollWaiting stage.
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
/// Apply a player action (move, go, roll). Returns Err if invalid.
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
}
}
// ── Opaque type implementation ────────────────────────────────────────────────
pub struct TricTracEngine {
game_state: GameState,
}
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
let mut game_state = GameState::new(false); // schools_enabled = false
game_state.init_player("player1");
game_state.init_player("player2");
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
Box::new(TricTracEngine { game_state })
}
impl TricTracEngine {
fn clone_engine(&self) -> Box<TricTracEngine> {
Box::new(TricTracEngine {
game_state: self.game_state.clone(),
})
}
fn needs_roll(&self) -> bool {
self.game_state.turn_stage == TurnStage::RollWaiting
}
fn is_game_ended(&self) -> bool {
self.game_state.stage == Stage::Ended
}
/// Returns 0 for player 1 (White) and 1 for player 2 (Black).
fn current_player_idx(&self) -> u64 {
self.game_state.active_player_id - 1
}
fn get_legal_actions(&self, player_idx: u64) -> Vec<u64> {
if player_idx == self.current_player_idx() {
if player_idx == 0 {
get_valid_action_indices(&self.game_state)
.into_iter()
.map(|i| i as u64)
.collect()
} else {
let mirror = self.game_state.mirror();
get_valid_action_indices(&mirror)
.into_iter()
.map(|i| i as u64)
.collect()
}
} else {
vec![]
}
}
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
TrictracAction::from_action_index(action_idx as usize)
.map(|a| format!("{}:{}", player_idx, a))
.unwrap_or_else(|| "unknown action".into())
}
fn get_players_scores(&self) -> ffi::PlayerScores {
ffi::PlayerScores {
score_p1: self.score_for(1),
score_p2: self.score_for(2),
}
}
fn score_for(&self, player_id: u64) -> i32 {
if let Some(player) = self.game_state.players.get(&player_id) {
player.holes as i32 * 12 + player.points as i32
} else {
-1
}
}
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
if player_idx == 0 {
self.game_state.to_vec()
} else {
self.game_state.mirror().to_vec()
}
}
fn get_observation_string(&self, player_idx: u64) -> String {
if player_idx == 0 {
format!("{}", self.game_state)
} else {
format!("{}", self.game_state.mirror())
}
}
fn to_debug_string(&self) -> String {
format!("{}", self.game_state)
}
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> {
let player_id = self.game_state.active_player_id;
if self.game_state.turn_stage != TurnStage::RollWaiting {
return Err("Not in RollWaiting stage".into());
}
let dice = Dice {
values: (dice.die1, dice.die2),
};
self.game_state
.consume(&GameEvent::RollResult { player_id, dice });
Ok(())
}
fn apply_action(&mut self, action_idx: u64) -> Result<(), String> {
let action_idx = action_idx as usize;
let needs_mirror = self.game_state.active_player_id == 2;
let event = TrictracAction::from_action_index(action_idx)
.and_then(|a| {
let game_state = if needs_mirror {
&self.game_state.mirror()
} else {
&self.game_state
};
a.to_event(game_state)
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
});
match event {
Some(evt) if self.game_state.validate(&evt) => {
self.game_state.consume(&evt);
Ok(())
}
Some(_) => Err("Action is invalid".into()),
None => Err("Could not build event from action index".into()),
}
}
}
```
> **Note on `Result<(), String>`**: cxx.rs requires the error type to implement `std::error::Error`. `String` does not implement it directly. Two options:
>
> - Use `anyhow::Error` (add `anyhow` dependency).
> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`.
>
> The recommended approach is `anyhow`:
>
> ```toml
> [dependencies]
> anyhow = "1.0"
> ```
>
> Then `fn apply_action(...) -> Result<(), anyhow::Error>` — cxx.rs will convert this to a C++ exception of type `rust::Error` carrying the message.
---
## 6. Step 4 — Rust: `store/src/lib.rs`
Add the new module:
```rust
// existing modules …
mod pyengine;
// NEW: C++ bindings via cxx.rs
pub mod cxxengine;
```
---
## 7. Step 5 — C++: `trictrac/trictrac.h`
Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box<ffi::TricTracEngine>` and delegates everything to it.
```cpp
// open_spiel/games/trictrac/trictrac.h
#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_
#define OPEN_SPIEL_GAMES_TRICTRAC_H_
#include <memory>
#include <string>
#include <vector>
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"
// Generated by cxx-build from store/src/cxxengine.rs.
// The include path is set by CMake (see CMakeLists.txt).
#include "trictrac_store/src/cxxengine.rs.h"
namespace open_spiel {
namespace trictrac {
inline constexpr int kNumPlayers = 2;
inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes
inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust
inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust
inline constexpr int kDefaultMaxTurns = 1000;
class TrictracGame;
// ---------------------------------------------------------------------------
// TrictracState
// ---------------------------------------------------------------------------
class TrictracState : public State {
public:
explicit TrictracState(std::shared_ptr<const Game> game);
TrictracState(const TrictracState& other);
Player CurrentPlayer() const override;
std::vector<Action> LegalActions() const override;
std::string ActionToString(Player player, Action move_id) const override;
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;
std::string ToString() const override;
bool IsTerminal() const override;
std::vector<double> Returns() const override;
std::string ObservationString(Player player) const override;
void ObservationTensor(Player player, absl::Span<float> values) const override;
std::unique_ptr<State> Clone() const override;
protected:
void DoApplyAction(Action move_id) override;
private:
// Decode a chance action index [0,35] to (die1, die2).
// Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action]
static trictrac_engine::DicePair DecodeChanceAction(Action action);
// The Rust engine handle. Deep-copied via clone_engine() when cloning state.
rust::Box<trictrac_engine::TricTracEngine> engine_;
};
// ---------------------------------------------------------------------------
// TrictracGame
// ---------------------------------------------------------------------------
class TrictracGame : public Game {
public:
explicit TrictracGame(const GameParameters& params);
int NumDistinctActions() const override { return kNumDistinctActions; }
std::unique_ptr<State> NewInitialState() const override;
int MaxChanceOutcomes() const override { return kNumChanceOutcomes; }
int NumPlayers() const override { return kNumPlayers; }
double MinUtility() const override { return 0.0; }
double MaxUtility() const override { return 200.0; }
int MaxGameLength() const override { return 3 * max_turns_; }
int MaxChanceNodesInHistory() const override { return MaxGameLength(); }
std::vector<int> ObservationTensorShape() const override {
return {kStateEncodingSize};
}
private:
int max_turns_;
};
} // namespace trictrac
} // namespace open_spiel
#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_
```
---
## 8. Step 6 — C++: `trictrac/trictrac.cc`
```cpp
// open_spiel/games/trictrac/trictrac.cc
#include "open_spiel/games/trictrac/trictrac.h"
#include <memory>
#include <string>
#include <vector>
#include "open_spiel/abseil-cpp/absl/types/span.h"
#include "open_spiel/game_parameters.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_globals.h"
#include "open_spiel/spiel_utils.h"
namespace open_spiel {
namespace trictrac {
namespace {
// ── Game registration ────────────────────────────────────────────────────────
const GameType kGameType{
/*short_name=*/"trictrac",
/*long_name=*/"Trictrac",
GameType::Dynamics::kSequential,
GameType::ChanceMode::kExplicitStochastic,
GameType::Information::kPerfectInformation,
GameType::Utility::kGeneralSum,
GameType::RewardModel::kRewards,
/*min_num_players=*/kNumPlayers,
/*max_num_players=*/kNumPlayers,
/*provides_information_state_string=*/false,
/*provides_information_state_tensor=*/false,
/*provides_observation_string=*/true,
/*provides_observation_tensor=*/true,
/*parameter_specification=*/{
{"max_turns", GameParameter(kDefaultMaxTurns)},
}};
static std::shared_ptr<const Game> Factory(const GameParameters& params) {
return std::make_shared<const TrictracGame>(params);
}
REGISTER_SPIEL_GAME(kGameType, Factory);
} // namespace
// ── TrictracGame ─────────────────────────────────────────────────────────────
TrictracGame::TrictracGame(const GameParameters& params)
: Game(kGameType, params),
max_turns_(ParameterValue<int>("max_turns", kDefaultMaxTurns)) {}
std::unique_ptr<State> TrictracGame::NewInitialState() const {
return std::make_unique<TrictracState>(shared_from_this());
}
// ── TrictracState ─────────────────────────────────────────────────────────────
TrictracState::TrictracState(std::shared_ptr<const Game> game)
: State(game),
engine_(trictrac_engine::new_trictrac_engine()) {}
// Copy constructor: deep-copy the Rust engine via clone_engine().
TrictracState::TrictracState(const TrictracState& other)
: State(other),
engine_(other.engine_->clone_engine()) {}
std::unique_ptr<State> TrictracState::Clone() const {
return std::make_unique<TrictracState>(*this);
}
// ── Current player ────────────────────────────────────────────────────────────
Player TrictracState::CurrentPlayer() const {
if (engine_->is_game_ended()) return kTerminalPlayerId;
if (engine_->needs_roll()) return kChancePlayerId;
return static_cast<Player>(engine_->current_player_idx());
}
// ── Legal actions ─────────────────────────────────────────────────────────────
std::vector<Action> TrictracState::LegalActions() const {
if (IsChanceNode()) {
// All 36 dice outcomes are equally likely; return indices 035.
std::vector<Action> actions(kNumChanceOutcomes);
for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i;
return actions;
}
Player player = CurrentPlayer();
rust::Vec<uint64_t> rust_actions =
engine_->get_legal_actions(static_cast<uint64_t>(player));
std::vector<Action> actions;
actions.reserve(rust_actions.size());
for (uint64_t a : rust_actions) actions.push_back(static_cast<Action>(a));
return actions;
}
// ── Chance outcomes ───────────────────────────────────────────────────────────
std::vector<std::pair<Action, double>> TrictracState::ChanceOutcomes() const {
SPIEL_CHECK_TRUE(IsChanceNode());
const double p = 1.0 / kNumChanceOutcomes;
std::vector<std::pair<Action, double>> outcomes;
outcomes.reserve(kNumChanceOutcomes);
for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p);
return outcomes;
}
// ── Apply action ──────────────────────────────────────────────────────────────
/*static*/
trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) {
// Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action]
return trictrac_engine::DicePair{
/*die1=*/static_cast<uint8_t>(action / 6 + 1),
/*die2=*/static_cast<uint8_t>(action % 6 + 1),
};
}
void TrictracState::DoApplyAction(Action action) {
if (IsChanceNode()) {
engine_->apply_dice_roll(DecodeChanceAction(action));
} else {
engine_->apply_action(static_cast<uint64_t>(action));
}
}
// ── Terminal & returns ────────────────────────────────────────────────────────
bool TrictracState::IsTerminal() const {
return engine_->is_game_ended();
}
std::vector<double> TrictracState::Returns() const {
trictrac_engine::PlayerScores scores = engine_->get_players_scores();
return {static_cast<double>(scores.score_p1),
static_cast<double>(scores.score_p2)};
}
// ── Observation ───────────────────────────────────────────────────────────────
std::string TrictracState::ObservationString(Player player) const {
return std::string(engine_->get_observation_string(
static_cast<uint64_t>(player)));
}
void TrictracState::ObservationTensor(Player player,
absl::Span<float> values) const {
SPIEL_CHECK_EQ(values.size(), kStateEncodingSize);
rust::Vec<int8_t> tensor =
engine_->get_tensor(static_cast<uint64_t>(player));
SPIEL_CHECK_EQ(tensor.size(), static_cast<size_t>(kStateEncodingSize));
for (int i = 0; i < kStateEncodingSize; ++i) {
values[i] = static_cast<float>(tensor[i]);
}
}
// ── Strings ───────────────────────────────────────────────────────────────────
std::string TrictracState::ToString() const {
return std::string(engine_->to_debug_string());
}
std::string TrictracState::ActionToString(Player player, Action action) const {
if (IsChanceNode()) {
trictrac_engine::DicePair d = DecodeChanceAction(action);
return "(" + std::to_string(d.die1) + ", " + std::to_string(d.die2) + ")";
}
return std::string(engine_->action_to_string(
static_cast<uint64_t>(player), static_cast<uint64_t>(action)));
}
} // namespace trictrac
} // namespace open_spiel
```
---
## 9. Step 7 — C++: `trictrac/trictrac_test.cc`
```cpp
// open_spiel/games/trictrac/trictrac_test.cc
#include "open_spiel/games/trictrac/trictrac.h"
#include <iostream>
#include <memory>
#include "open_spiel/spiel.h"
#include "open_spiel/tests/basic_tests.h"
#include "open_spiel/utils/init.h"
namespace open_spiel {
namespace trictrac {
namespace {
void BasicTrictracTests() {
testing::LoadGameTest("trictrac");
testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5);
}
} // namespace
} // namespace trictrac
} // namespace open_spiel
int main(int argc, char** argv) {
open_spiel::Init(&argc, &argv);
open_spiel::trictrac::BasicTrictracTests();
std::cout << "trictrac tests passed" << std::endl;
return 0;
}
```
---
## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt`
The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined:
```cmake
# ── Corrosion: CMake integration for Rust ────────────────────────────────────
include(FetchContent)
FetchContent_Declare(
Corrosion
GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git
GIT_TAG v0.5.1 # pin to a stable release
)
FetchContent_MakeAvailable(Corrosion)
# Import the trictrac-store Rust crate.
# This creates a CMake target named 'trictrac-store'.
corrosion_import_crate(
MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml
CRATES trictrac-store
)
# Generate the cxx bridge from cxxengine.rs.
# corrosion_add_cxxbridge:
# - runs cxx-build as part of the Rust build
# - creates a CMake target 'trictrac_cxx_bridge' that:
# * compiles the generated C++ shim
# * exposes INTERFACE include dirs for the generated .rs.h header
corrosion_add_cxxbridge(trictrac_cxx_bridge
CRATE trictrac-store
FILES src/cxxengine.rs
)
```
> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing.
---
## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt`
Two changes: add the new source files to `GAME_SOURCES`, and add a test target.
### 11.1 Add to `GAME_SOURCES`
Find the alphabetically correct position (after `trade_comm`, since `trictrac` sorts after it) and add:
```cmake
set(GAME_SOURCES
# ... existing games ...
trictrac/trictrac.cc
trictrac/trictrac.h
# ... remaining games ...
)
```
### 11.2 Link cxx bridge into OpenSpiel objects
The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed.
The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add:
```cmake
target_link_libraries(open_spiel
PUBLIC
trictrac_cxx_bridge # C++ shim generated by cxx-build
trictrac-store # Rust static library
)
```
If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game:
```cmake
add_library(trictrac_game OBJECT
trictrac/trictrac.cc
trictrac/trictrac.h
)
target_include_directories(trictrac_game
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..
)
target_link_libraries(trictrac_game
PUBLIC
trictrac_cxx_bridge
trictrac-store
open_spiel_core # or whatever the core target is called
)
```
Then reference `$<TARGET_OBJECTS:trictrac_game>` in relevant executables.
### 11.3 Add the test
```cmake
add_executable(trictrac_test
trictrac/trictrac_test.cc
${OPEN_SPIEL_OBJECTS}
$<TARGET_OBJECTS:tests>
)
target_link_libraries(trictrac_test
PRIVATE
trictrac_cxx_bridge
trictrac-store
)
add_test(trictrac_test trictrac_test)
```
---
## 12. Step 10 — Justfile updates
### `trictrac/justfile` — add `cxxlib` target
Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists:
```just
cxxlib:
cargo build --release -p trictrac-store
@echo "Static lib: $(ls target/release/libtrictrac_store.a)"
@echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)"
```
### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac`
```just
buildtrictrac:
# Rebuild the Rust static lib first, then CMake
cd ../../trictrac && cargo build --release -p trictrac-store
mkdir -p build && cd build && \
CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \
make -j$(nproc) trictrac_test
testtrictrac: buildtrictrac
./build/trictrac_test
playtrictrac_cpp:
./build/examples/example --game=trictrac
```
---
## 13. Key Design Decisions
### 13.1 Opaque type with `clone_engine()`
OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box<TricTracEngine>` which calls `.clone()` on the inner `GameState` (which derives `Clone`).
### 13.2 Action encoding: same 514-element space
The C++ game uses the same 514-action encoding as the Python version and the Rust training code. This means:
- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies.
- Action 0 = Roll (used as the bridge between Move and the next chance node).
- Actions 2–513 = Move variants (checker ordinal pair + dice order).
- A trained C++ model and Python model share the same action space.
### 13.3 Chance outcome ordering
The dice outcome ordering is identical to the Python version:
```
action → (die1, die2)
0 → (1,1) 6 → (2,1) ... 35 → (6,6)
```
(`die1 = action/6 + 1`, `die2 = action%6 + 1`)
This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines.
### 13.4 `GameType::Utility::kGeneralSum` + `kRewards`
Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). Intermediate hole rewards are returned by `Returns()` at every state, not just the terminal.
### 13.5 Mirror pattern preserved
`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring — it simply passes `player_idx` and the Rust engine handles the rest.
### 13.6 `rust::Box` vs `rust::UniquePtr`
`rust::Box<T>` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. `rust::UniquePtr<T>` is for C++ types held in Rust.
### 13.7 Separate struct from `pyengine.rs`
`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids:
- PyO3 and cxx attributes conflicting on the same type.
- Changes to one binding breaking the other.
- Feature-flag complexity.
---
## 14. Known Challenges
### 14.1 Corrosion path resolution
`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable).
### 14.2 `staticlib` + `cdylib` in one crate
Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions:
- The `cdylib` build (for maturin) does not need `staticlib`, and building both doubles the compile time.
- Consider gating `staticlib` behind a Cargo feature: `crate-type` is not directly feature-gatable, but you can work around this with a separate `Cargo.toml` or a workspace profile.
- Alternatively, accept the extra compile time during development.
### 14.3 Linker symbols from Rust std
When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (typically `-ldl -lm -lpthread -lc`).
### 14.4 `anyhow` for error types
cxx.rs requires the `Err` type in `Result<T, E>` to implement `std::error::Error + Send + Sync`. `String` does not satisfy this. Use `anyhow::Error` or define a thin newtype wrapper:
```rust
use std::fmt;
#[derive(Debug)]
struct EngineError(String);
impl fmt::Display for EngineError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) }
}
impl std::error::Error for EngineError {}
```
On the C++ side, errors become `rust::Error` exceptions. Wrap `DoApplyAction` in a try-catch during development to surface Rust errors as `SpielFatalError`.
### 14.5 `UndoAction` not implemented
OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo — the history is append-only. To support undo, `Clone()` is the only reliable strategy (clone before applying, discard clone if undo needed). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks game-tree search. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo.
### 14.6 Generated header path in `#include`
The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the actual path that `cxx-build` (via corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory.
### 14.7 `rust::String` to `std::string` conversion
The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` is implicitly convertible to `std::string`. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast<std::string>(...)`.
---
## 15. Complete File Checklist
```
[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib
[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...)
[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;`
[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation
[ ] trictrac/justfile — add `cxxlib` target
[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge
[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file
[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac
```
---
## 16. Implementation Order
Implement in this sequence to get early feedback at each step:
1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced.
2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations.
3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist.
4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them.
5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination.
6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`.
---
# Implementation complete
All files are in place and trictrac_test passes (168–84 final scores, all assertions pass).
What was done
Rust side (from previous session, already complete):
- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build
- trictrac/store/build.rs — drives cxx-build
- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl
- trictrac/store/src/lib.rs — added pub mod cxxengine;
C++ side (this session):
- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header
- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation
- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test
Build system:
- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge
- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable
Justfiles:
- trictrac/justfile — added cxxlib target
- forks/open_spiel/justfile — added buildtrictrac and testtrictrac
Fixes discovered during build
| Issue | Fix |
| ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------- |
| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries |
| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) |
| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h |
| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc |
| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) |
| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries |
| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort |
| Duplicate actions for doubles | Added std::unique after sort |
| Returns() returned non-zero at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states |

121
doc/spiel_bot_parallel.md Normal file
View file

@ -0,0 +1,121 @@
Part B — Batched MCTS leaf evaluation
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
pub trait Evaluator: Send + Sync {
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
/// Evaluate a batch of observations at once. Default falls back to
/// sequential calls; backends override this for efficiency.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
}
}
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
let b = obs_batch.len();
let obs_size = obs_batch[0].len();
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
let action_size = policies.len() / b;
(0..b).map(|i| {
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
}).collect()
}
Step B3 — Add eval_batch_size to MctsConfig
pub struct MctsConfig {
// ... existing fields ...
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
pub eval_batch_size: usize,
}
Default: 1 (backwards-compatible).
Step B4 — Make the simulation iterative (mcts/search.rs)
The current simulate is recursive. For batching we need to split it into two phases:
descend (pure tree traversal — no network call):
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
ascend (backup — no network call):
- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
Step B5 — Add run_mcts_batched to mcts/mod.rs
The new entry point, called by run_mcts when config.eval_batch_size > 1:
expand root (1 network call)
optionally add Dirichlet noise
for round in 0..(n_simulations / batch_size):
leaves = []
for _ in 0..batch_size:
leaf = descend(root, state, env, rng)
leaves.push(leaf)
obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
where leaf.kind == NeedsEval]
results = evaluator.evaluate_batch(obs_batch)
for (leaf, result) in zip(leaves, results):
expand the leaf node (insert children from result.policy)
ascend(root, leaf.path, result.value, leaf.player_idx)
// ascend also handles terminal and crossed-chance leaves
// handle remainder: n_simulations % batch_size
run_mcts becomes a thin dispatcher:
if config.eval_batch_size <= 1 {
// existing path (unchanged)
} else {
run_mcts_batched(...)
}
Step B6 — CLI flag in az_train.rs
--eval-batch N default: 8 Leaf batch size for MCTS network calls
---
Summary of file changes
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
│ File │ Changes │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ spiel_bot/Cargo.toml │ add rayon │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
Key design constraint
Parts A and B are independent and composable:
- A only touches the outer game loop.
- B only touches the inner MCTS per game.
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.

782
doc/spiel_bot_research.md Normal file
View file

@ -0,0 +1,782 @@
# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac
## 0. Context and Scope
The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
other stage to an inline random-opponent loop.
`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
for **self-play training**. Its goals:
- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
that works with Trictrac's multi-stage turn model and stochastic dice.
- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer)
as the first algorithm.
- Remain **modular**: adding DQN or PPO later requires only a new
`impl Algorithm for Dqn` without touching the environment or network layers.
- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from
`trictrac-store`.
---
## 1. Library Landscape
### 1.1 Neural Network Frameworks
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
runtime needed, the `ndarray` backend is sufficient for CPU training and can
switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by
changing one type alias.
`tch-rs` would be the best choice for raw training throughput (it is the most
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
switching `spiel_bot` to `tch-rs` is a one-line backend swap.
### 1.2 Other Key Crates
| Crate | Role |
| -------------------- | ----------------------------------------------------------------- |
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
| `anyhow` | error propagation (already used everywhere) |
| `indicatif` | training progress bars |
| `tracing` | structured logging per episode/iteration |
### 1.3 What `burn-rl` Provides (and Does Not)
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
trait. It does **not** provide:
- MCTS or any tree-search algorithm
- Two-player self-play
- Legal action masking during training
- Chance-node handling
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
own (simpler, more targeted) traits and implement MCTS from scratch.
---
## 2. Trictrac-Specific Design Constraints
### 2.1 Multi-Stage Turn Model
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
genuine player choice:
| TurnStage | Node type | Handler |
| ---------------- | ------------------------------- | ------------------------------- |
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
| `Move` | **Player decision** | MCTS / policy network |
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
The environment wrapper advances through forced/chance stages automatically so
that from the algorithm's perspective every node it sees is a genuine player
decision.
### 2.2 Stochastic Dice in MCTS
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
introduce stochasticity. Three approaches exist:
**A. Outcome sampling (recommended)**
During each MCTS simulation, when a chance node is reached, sample one dice
outcome at random and continue. After many simulations the expected value
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
and requires no changes to the standard PUCT formula.
**B. Chance-node averaging (expectimax)**
At each chance node, expand all 21 unique dice pairs weighted by their
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
exact but multiplies the branching factor by ~21 at every dice roll, making it
prohibitively expensive.
**C. Condition on dice in the observation (current approach)**
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.
**Use approach A + C together**: the environment samples dice automatically
(chance node bypass), and the 217-dim tensor encodes the dice so the network
can exploit them.
### 2.3 Perspective / Mirroring
All move rules and tensor encoding are defined from White's perspective.
`to_tensor()` must always be called after mirroring the state for Black.
The environment wrapper handles this transparently: every observation returned
to an algorithm is already in the active player's perspective.
### 2.4 Legal Action Masking
A crucial difference from the existing `bot/` code: instead of penalizing
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
before softmax — illegal action logits are set to `-inf`. This prevents the
network from wasting capacity on illegal moves and eliminates the need for the
penalty-reward hack.
---
## 3. Proposed Crate Architecture
```
spiel_bot/
├── Cargo.toml
└── src/
├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo"
├── env/
│ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface
│ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store
├── mcts/
│ ├── mod.rs # MctsConfig, run_mcts() entry point
│ ├── node.rs # MctsNode (visit count, W, prior, children)
│ └── search.rs # simulate(), backup(), select_action()
├── network/
│ ├── mod.rs # PolicyValueNet trait
│ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads
├── alphazero/
│ ├── mod.rs # AlphaZeroConfig
│ ├── selfplay.rs # generate_episode() -> Vec<TrainSample>
│ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle)
│ └── trainer.rs # training loop: selfplay → sample → loss → update
└── agent/
├── mod.rs # Agent trait
├── random.rs # RandomAgent (baseline)
└── mcts_agent.rs # MctsAgent: uses trained network for inference
```
Future algorithms slot in without touching the above:
```
├── dqn/ # (future) DQN: impl Algorithm + own replay buffer
└── ppo/ # (future) PPO: impl Algorithm + rollout buffer
```
---
## 4. Core Traits
### 4.1 `GameEnv` — the minimal OpenSpiel interface
```rust
use rand::Rng;
/// Who controls the current node.
pub enum Player {
P1, // player index 0
P2, // player index 1
Chance, // dice roll
Terminal, // game over
}
pub trait GameEnv: Clone + Send + Sync + 'static {
type State: Clone + Send + Sync;
/// Fresh game state.
fn new_game(&self) -> Self::State;
/// Who acts at this node.
fn current_player(&self, s: &Self::State) -> Player;
/// Legal action indices (always in [0, action_space())).
/// Empty only at Terminal nodes.
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
/// Apply a player action (must be legal).
fn apply(&self, s: &mut Self::State, action: usize);
/// Advance a Chance node by sampling dice; no-op at non-Chance nodes.
fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng);
/// Observation tensor from `pov`'s perspective (0 or 1).
/// Returns 217 f32 values for Trictrac.
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
/// Flat observation size (217 for Trictrac).
fn obs_size(&self) -> usize;
/// Total action-space size (514 for Trictrac).
fn action_space(&self) -> usize;
/// Game outcome per player, or None if not Terminal.
/// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw.
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
}
```
### 4.2 `PolicyValueNet` — neural network interface
```rust
use burn::prelude::*;
pub trait PolicyValueNet<B: Backend>: Send + Sync {
/// Forward pass.
/// `obs`: [batch, obs_size] tensor.
/// Returns: (policy_logits [batch, action_space], value [batch]).
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
/// Save weights to `path`.
fn save(&self, path: &std::path::Path) -> anyhow::Result<()>;
/// Load weights from `path`.
fn load(path: &std::path::Path) -> anyhow::Result<Self>
where
Self: Sized;
}
```
### 4.3 `Agent` — player policy interface
```rust
pub trait Agent: Send {
/// Select an action index given the current game state observation.
/// `legal`: mask of valid action indices.
fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize;
}
```
---
## 5. MCTS Implementation
### 5.1 Node
```rust
pub struct MctsNode {
n: u32, // visit count N(s, a)
w: f32, // sum of backed-up values W(s, a)
p: f32, // prior from policy head P(s, a)
children: Vec<(usize, MctsNode)>, // (action_idx, child)
is_expanded: bool,
}
impl MctsNode {
pub fn q(&self) -> f32 {
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
}
/// PUCT score used for selection.
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
}
}
```
### 5.2 Simulation Loop
One MCTS simulation (for deterministic decision nodes):
```
1. SELECTION — traverse from root, always pick child with highest PUCT,
auto-advancing forced/chance nodes via env.apply_chance().
2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get
(policy_logits, value). Mask illegal actions, softmax
the remaining logits → priors P(s,a) for each child.
3. BACKUP — propagate -value up the tree (negate at each level because
perspective alternates between P1 and P2).
```
After `n_simulations` iterations, action selection at the root:
```rust
// During training: sample proportional to N^(1/temperature)
// During evaluation: argmax N
fn select_action(root: &MctsNode, temperature: f32) -> usize { ... }
```
### 5.3 Configuration
```rust
pub struct MctsConfig {
pub n_simulations: usize, // e.g. 200
pub c_puct: f32, // exploration constant, e.g. 1.5
pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3
pub dirichlet_eps: f32, // noise weight, e.g. 0.25
pub temperature: f32, // action sampling temperature (anneals to 0)
}
```
### 5.4 Handling Chance Nodes Inside MCTS
When simulation reaches a Chance node (dice roll), the environment automatically
samples dice and advances to the next decision node. The MCTS tree does **not**
branch on dice outcomes — it treats the sampled outcome as the state. This
corresponds to "outcome sampling" (approach A from §2.2). Because each
simulation independently samples dice, the Q-values at player nodes converge to
their expected value over many simulations.
---
## 6. Network Architecture
### 6.1 ResNet Policy-Value Network
A single trunk with residual blocks, then two heads:
```
Input: [batch, 217]
Linear(217 → 512) + ReLU
ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU)
↓ trunk output [batch, 512]
├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site)
└─ Value head: Linear(512 → 1) → tanh (output in [-1, 1])
```
Burn implementation sketch:
```rust
#[derive(Module, Debug)]
pub struct TrictracNet<B: Backend> {
input: Linear<B>,
res_blocks: Vec<ResBlock<B>>,
policy_head: Linear<B>,
value_head: Linear<B>,
}
impl<B: Backend> TrictracNet<B> {
pub fn forward(&self, obs: Tensor<B, 2>)
-> (Tensor<B, 2>, Tensor<B, 1>)
{
let x = activation::relu(self.input.forward(obs));
let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x));
let policy = self.policy_head.forward(x.clone()); // raw logits
let value = activation::tanh(self.value_head.forward(x))
.squeeze(1);
(policy, value)
}
}
```
A simpler MLP (no residual blocks) is sufficient for a first version and much
faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two
heads.
### 6.2 Loss Function
```
L = MSE(value_pred, z)
+ CrossEntropy(policy_logits_masked, π_mcts)
- c_l2 * L2_regularization
```
Where:
- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy
---
## 7. AlphaZero Training Loop
```
INIT
network ← random weights
replay ← empty ReplayBuffer(capacity = 100_000)
LOOP forever:
── Self-play phase ──────────────────────────────────────────────
(parallel with rayon, n_workers games at once)
for each game:
state ← env.new_game()
samples = []
while not terminal:
advance forced/chance nodes automatically
obs ← env.observation(state, current_player)
legal ← env.legal_actions(state)
π, root_value ← mcts.run(state, network, config)
action ← sample from π (with temperature)
samples.push((obs, π, current_player))
env.apply(state, action)
z ← env.returns(state) // final scores
for (obs, π, player) in samples:
replay.push(TrainSample { obs, policy: π, value: z[player] })
── Training phase ───────────────────────────────────────────────
for each gradient step:
batch ← replay.sample(batch_size)
(policy_logits, value_pred) ← network.forward(batch.obs)
loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy)
optimizer.step(loss.backward())
── Evaluation (every N iterations) ─────────────────────────────
win_rate ← evaluate(network_new vs network_prev, n_eval_games)
if win_rate > 0.55: save checkpoint
```
### 7.1 Replay Buffer
```rust
pub struct TrainSample {
pub obs: Vec<f32>, // 217 values
pub policy: Vec<f32>, // 514 values (normalized MCTS visit counts)
pub value: f32, // game outcome ∈ {-1, 0, +1}
}
pub struct ReplayBuffer {
data: VecDeque<TrainSample>,
capacity: usize,
}
impl ReplayBuffer {
pub fn push(&mut self, s: TrainSample) {
if self.data.len() == self.capacity { self.data.pop_front(); }
self.data.push_back(s);
}
pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
// sample without replacement
}
}
```
### 7.2 Parallelism Strategy
Self-play is embarrassingly parallel (each game is independent):
```rust
let samples: Vec<TrainSample> = (0..n_games)
.into_par_iter() // rayon
.flat_map(|_| generate_episode(&env, &network, &mcts_config))
.collect();
```
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
`Autodiff<NdArray<f32>>`.
For larger scale, a producer-consumer architecture (crossbeam-channel) separates
self-play workers from the training thread, allowing continuous data generation
while the GPU trains.
---
## 8. `TrictracEnv` Implementation Sketch
```rust
use trictrac_store::{
training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE},
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
};
#[derive(Clone)]
pub struct TrictracEnv;
impl GameEnv for TrictracEnv {
type State = GameState;
fn new_game(&self) -> GameState {
GameState::new_with_players("P1", "P2")
}
fn current_player(&self, s: &GameState) -> Player {
match s.stage {
Stage::Ended => Player::Terminal,
_ => match s.turn_stage {
TurnStage::RollWaiting => Player::Chance,
_ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 },
},
}
}
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() };
get_valid_action_indices(&view).unwrap_or_default()
}
fn apply(&self, s: &mut GameState, action_idx: usize) {
// advance all forced/chance nodes first, then apply the player action
self.advance_forced(s);
let needs_mirror = s.active_player_id == 2;
let view = if needs_mirror { s.mirror() } else { s.clone() };
if let Some(event) = TrictracAction::from_action_index(action_idx)
.and_then(|a| a.to_event(&view))
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
{
let _ = s.consume(&event);
}
// advance any forced stages that follow
self.advance_forced(s);
}
fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) {
// RollDice → RollWaiting
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
// RollWaiting → next stage
let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) };
let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice });
self.advance_forced(s);
}
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() }
}
fn obs_size(&self) -> usize { 217 }
fn action_space(&self) -> usize { ACTION_SPACE_SIZE }
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
if s.stage != Stage::Ended { return None; }
// Convert hole+point scores to ±1 outcome
let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
Some(match s1.cmp(&s2) {
std::cmp::Ordering::Greater => [ 1.0, -1.0],
std::cmp::Ordering::Less => [-1.0, 1.0],
std::cmp::Ordering::Equal => [ 0.0, 0.0],
})
}
}
impl TrictracEnv {
/// Advance through all forced (non-decision, non-chance) stages.
fn advance_forced(&self, s: &mut GameState) {
use trictrac_store::PointsRules;
loop {
match s.turn_stage {
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
// Scoring is deterministic; compute and apply automatically.
let color = s.player_color_by_id(&s.active_player_id)
.unwrap_or(trictrac_store::Color::White);
let drc = s.players.get(&s.active_player_id)
.map(|p| p.dice_roll_count).unwrap_or(0);
let pr = PointsRules::new(&color, &s.board, s.dice);
let pts = pr.get_points(drc);
let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 };
let _ = s.consume(&GameEvent::Mark {
player_id: s.active_player_id, points,
});
}
TurnStage::RollDice => {
// RollDice is a forced "initiate roll" action with no real choice.
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
}
_ => break,
}
}
}
}
```
---
## 9. Cargo.toml Changes
### 9.1 Add `spiel_bot` to the workspace
```toml
# Cargo.toml (workspace root)
[workspace]
resolver = "2"
members = ["client_cli", "bot", "store", "spiel_bot"]
```
### 9.2 `spiel_bot/Cargo.toml`
```toml
[package]
name = "spiel_bot"
version = "0.1.0"
edition = "2021"
[features]
default = ["alphazero"]
alphazero = []
# dqn = [] # future
# ppo = [] # future
[dependencies]
trictrac-store = { path = "../store" }
anyhow = "1"
rand = "0.9"
rayon = "1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# Burn: NdArray for pure-Rust CPU training
# Replace NdArray with Wgpu or Tch for GPU.
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
# Optional: progress display and structured logging
indicatif = "0.17"
tracing = "0.1"
[[bin]]
name = "az_train"
path = "src/bin/az_train.rs"
[[bin]]
name = "az_eval"
path = "src/bin/az_eval.rs"
```
---
## 10. Comparison: `bot` crate vs `spiel_bot`
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
| ---------------- | --------------------------- | -------------------------------------------- |
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |
| Invalid actions | penalise with reward | legal action mask (no penalty) |
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
| Burn dep | yes (0.20) | yes (0.20), same backend |
| `burn-rl` dep | yes | no |
| C++ dep | no | no |
| Python dep | no | no |
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
replace `TrictracEnvironment` with `TrictracEnv`).
---
## 11. Implementation Order
1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game
through the trait, verify terminal state and returns).
2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) +
Burn forward/backward pass test with dummy data.
3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests
(visit counts sum to `n_simulations`, legal mask respected).
4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub
(one iteration, check loss decreases).
5. **Integration test**: run 100 self-play games with a tiny network (1 res block,
64 hidden units), verify the training loop completes without panics.
6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s
on CPU, consistent with `random_game` throughput).
7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule
hyperparameter sweep.
8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report
win rate every checkpoint.
---
## 12. Key Open Questions
1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
scored more holes). Better option: normalize the score margin. The final
choice affects how the value head is trained.
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
MCTS with 200 simulations per step means ~120k network evaluations per game.
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
limiting `n_simulations` to 50–100 for early training.
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
This is a long-horizon decision that AlphaZero handles naturally via MCTS
lookahead, but needs careful value normalization (a "Go" restarts scoring
within the same game).
4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated
to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic.
This is optional but reduces code duplication.
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
## Implementation results
All benchmarks compile and run. Here's the complete results table:
| Group | Benchmark | Time |
| ------- | ----------------------- | --------------------- |
| env | apply_chance | 3.87 µs |
| | legal_actions | 1.91 µs |
| | observation (to_tensor) | 341 ns |
| | random_game (baseline) | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64 | 94.9 µs |
| | mlp_b32 hidden=64 | 141 µs |
| | mlp_b1 hidden=256 | 352 µs |
| | mlp_b32 hidden=256 | 479 µs |
| mcts | zero_eval n=1 | 6.8 µs |
| | zero_eval n=5 | 23.9 µs |
| | zero_eval n=20 | 90.9 µs |
| | mlp64 n=1 | 203 µs |
| | mlp64 n=5 | 622 µs |
| | mlp64 n=20 | 2.30 ms |
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
| | trictrac n=2 | 145 ms → 7 games/s |
| train | mlp64 Adam b=16 | 1.93 ms |
| | mlp64 Adam b=64 | 2.68 ms |
Key observations:
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimised
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
- Training: 2.7 ms/step at batch=64 → 370 steps/s
### Summary of Step 8
spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary:
- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side
- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate
Typical usage after training:
cargo run -p spiel_bot --bin az_eval --release -- \
--checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
### az_train
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
cargo run -p spiel_bot --bin az_train --release
#### ResNet, more games, custom output dir
cargo run -p spiel_bot --bin az_train --release -- \
--arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
--save-every 20 --out checkpoints/
#### Resume from iteration 50
cargo run -p spiel_bot --bin az_train --release -- \
--resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
What the binary does each iteration:
1. Calls model.valid() to get a zero-overhead inference copy for self-play
2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy)
3. Pushes samples into a ReplayBuffer (capacity --replay-cap)
4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min
5. Saves a .mpk checkpoint every --save-every iterations and always on the last

253
doc/tensor_research.md Normal file
View file

@ -0,0 +1,253 @@
# Tensor research
## Current tensor anatomy
[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!)
[24] active player color: 0 or 1
[25] turn_stage: 1–5
[26–27] dice values (raw 1–6)
[28–31] white: points, holes, can_bredouille, can_big_bredouille
[32–35] black: same
─────────────────────────────────
Total 36 floats
The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's
AlphaZero uses a fully-connected network.
### Fundamental problems with the current encoding
1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network
must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with
all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from.
2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient
flow during training is uneven.
3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it
triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers
produces a score. Including this explicitly is the single highest-value addition.
4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely
different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0.
5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the
starting position). It's in the Player struct but not exported.
## Key Trictrac distinctions from backgammon that shape the encoding
| Concept | Backgammon | Trictrac |
| ------------------------- | ---------------------- | --------------------------------------------------------- |
| Hitting a blot | Removes checker to bar | Scores points, checker stays |
| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened |
| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) |
| 3-checker field | Safe with spare | Safe with spare |
| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) |
| Both colors on a field | Impossible | Perfectly legal |
| Rest corner (field 12/13) | Does not exist | Special two-checker rules |
The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary
indicators directly teaches the network the phase transitions the game hinges on.
## Options
### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D)
The minimum viable improvement.
For each of the 24 fields, encode own and opponent separately with 4 indicators each:
own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target)
own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill)
own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare)
own_x[i]: max(0, count − 3) (overflow)
opp_1[i]: same for opponent
Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec().
Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204
Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured
overhead is negligible.
Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from
scratch will be substantially faster and the converged policy quality better.
### Option B — Option A + Trictrac-specific derived features (flat 1D)
Recommended starting point.
Add on top of Option A:
// Quarter fill status — the single most important derived feature
quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields)
quarter_filled_opp[q] (q=0..3): same for opponent
→ 8 values
// Exit readiness
can_exit_own: 1.0 if all own checkers are in fields 19–24
can_exit_opp: same for opponent
→ 2 values
// Rest corner status (field 12/13)
own_corner_taken: 1.0 if field 12 has ≥2 own checkers
opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers
→ 2 values
// Jan de 3 coups counter (normalized)
dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0)
→ 1 value
Size: 204 + 8 + 2 + 2 + 1 = 217
Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve
the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes
expensive inference from the network.
### Option C — Option B + richer positional features (flat 1D)
More complete, higher sample efficiency, minor extra cost.
Add on top of Option B:
// Per-quarter fill fraction — how close to filling each quarter
own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0
opp_quarter_fill_fraction[q] (q=0..3): same for opponent
→ 8 values
// Blot counts — number of own/opponent single-checker fields globally
// (tells the network at a glance how much battage risk/opportunity exists)
own_blot_count: (number of own fields with exactly 1 checker) / 15.0
opp_blot_count: same for opponent
→ 2 values
// Bredouille would-double multiplier (already present, but explicitly scaled)
// No change needed, already binary
Size: 217 + 8 + 2 = 227
Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network
from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts).
### Option D — 2D spatial tensor {K, 24}
For CNN-based networks. Best eventual architecture but requires changing the training setup.
Shape {14, 24} — 14 feature channels over 24 field positions:
Channel 0: own_count_1 (blot)
Channel 1: own_count_2
Channel 2: own_count_3
Channel 3: own_count_overflow (float)
Channel 4: opp_count_1
Channel 5: opp_count_2
Channel 6: opp_count_3
Channel 7: opp_count_overflow
Channel 8: own_corner_mask (1.0 at field 12)
Channel 9: opp_corner_mask (1.0 at field 13)
Channel 10: final_quarter_mask (1.0 at fields 19–24)
Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter)
Channel 12: quarter_filled_opp (same for opponent)
Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers)
Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform
value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global"
row by returning shape {K, 25} with position 0 holding global features.
Size: 14 × 24 + few global channels ≈ 336–384
C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated
accordingly.
Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's
alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H,
W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension.
Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel
size 2–3 captures adjacent-field "tout d'une" interactions.
### On 3D tensors
Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is
the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and
field-within-quarter patterns jointly.
However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The
extra architecture work makes this premature unless you're already building a custom network. The information content
is the same as Option D.
### Recommendation
Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and
the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions
(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move.
Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when
the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time.
Costs summary:
| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain |
| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- |
| Current | 36 | — | — | — | baseline |
| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) |
| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) |
| C | 227 | to_vec() rewrite | constant update | none | large + moderate |
| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial |
One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2,
the new to_vec() must express everything from the active player's perspective (which the mirror already handles for
the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state,
which they will be if computed inside to_vec().
## Other algorithms
The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully.
### 1. Without MCTS, feature quality matters more
AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred
simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup —
the network must learn the full strategic structure directly from gradient updates.
This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO,
not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less
sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from
scratch.
### 2. Normalization becomes mandatory, not optional
AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps
Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw
values (dice 16, counts 015, booleans 0/1, points 012) in the same vector, gradient flow is uneven and training can
be unstable.
For DQN/PPO, every feature in the tensor should be in [0, 1]:
dice values: / 6.0
checker counts: overflow channel / 12.0
points: / 12.0
holes: / 12.0
dice_roll_count: / 3.0 (clamped)
Booleans and the TD-Gammon binary indicators are already in [0, 1].
### 3. The shape question depends on architecture, not algorithm
| Architecture | Shape | When to use |
| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- |
| MLP | {217} flat | Any algorithm, simplest baseline |
| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) |
| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network |
| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now |
The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want
convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet
template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all
— you just reshape the flat tensor in Python before passing it to the network.
### One thing that genuinely changes: value function perspective
AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This
works well.
DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit
current-player indicator), because a single Q-network estimates action values for both players simultaneously. With
the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must
learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity.
This is a minor point for a symmetric game like Trictrac, but worth keeping in mind.
Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm.

19
spiel_bot/Cargo.toml Normal file
View file

@ -0,0 +1,19 @@
[package]
name = "spiel_bot"
version = "0.1.0"
edition = "2021"
[dependencies]
# Game rules, state, and action encoding (workspace crate).
trictrac-store = { path = "../store" }
anyhow = "1"
rand = "0.9"
# Extra distributions not in `rand` core — presumably Dirichlet for
# AlphaZero root noise; confirm against mcts/selfplay usage.
rand_distr = "0.5"
# Pure-Rust CPU backend; "autodiff" enables the training backend wrapper.
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
# Parallel self-play game generation.
rayon = "1"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
# harness = false: Criterion supplies its own main() via criterion_main!.
name = "alphazero"
harness = false

View file

@ -0,0 +1,373 @@
//! AlphaZero pipeline benchmarks.
//!
//! Run with:
//!
//! ```sh
//! cargo bench -p spiel_bot
//! ```
//!
//! Use `-- <filter>` to run a specific group, e.g.:
//!
//! ```sh
//! cargo bench -p spiel_bot -- env/
//! cargo bench -p spiel_bot -- network/
//! cargo bench -p spiel_bot -- mcts/
//! cargo bench -p spiel_bot -- episode/
//! cargo bench -p spiel_bot -- train/
//! ```
//!
//! Target: ≥ 500 games/s for random play on CPU (consistent with
//! `random_game` throughput in `trictrac-store`).
use std::time::Duration;
use burn::{
backend::NdArray,
tensor::{Tensor, TensorData, backend::Backend},
};
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts},
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
// ── Shared types ───────────────────────────────────────────────────────────
// Inference backend: plain NdArray with no autodiff tape (self-play / MCTS).
type InferB = NdArray<f32>;
// Training backend: NdArray wrapped in Autodiff for gradient computation.
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
// Default device handles for each backend.
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
// Fixed-seed RNG so every benchmark run replays identical games.
fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
/// Uniform evaluator (returns zero logits and zero value).
/// Used to isolate MCTS tree-traversal cost from network cost.
struct ZeroEval(usize);
impl Evaluator for ZeroEval {
    fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
        // All-zero logits yield a uniform prior; zero value is neutral.
        let logits: Vec<f32> = std::iter::repeat(0.0f32).take(self.0).collect();
        (logits, 0.0)
    }
}
// ── 1. Environment primitives ──────────────────────────────────────────────
/// Baseline performance of the raw Trictrac environment without MCTS.
/// Target: ≥ 500 full games / second.
///
/// Measures the three per-step primitives (`apply_chance`, `legal_actions`,
/// `observation`) plus a full random-playout game as throughput baseline.
fn bench_env(c: &mut Criterion) {
    let env = TrictracEnv;
    let mut group = c.benchmark_group("env");
    group.measurement_time(Duration::from_secs(10));
    // ── apply_chance ──────────────────────────────────────────────────────
    group.bench_function("apply_chance", |b| {
        b.iter_batched(
            || {
                // A fresh game is always at RollDice (Chance) — ready for
                // apply_chance. Build the RNG here in setup so that SmallRng
                // seeding cost is NOT included in the measured routine
                // (previously it was constructed inside the closure).
                (env.new_game(), seeded())
            },
            |(mut s, mut rng)| {
                env.apply_chance(&mut s, &mut rng);
                black_box(s)
            },
            BatchSize::SmallInput,
        )
    });
    // ── legal_actions ─────────────────────────────────────────────────────
    group.bench_function("legal_actions", |b| {
        let mut rng = seeded();
        let mut s = env.new_game();
        env.apply_chance(&mut s, &mut rng);
        b.iter(|| black_box(env.legal_actions(&s)))
    });
    // ── observation (to_tensor) ───────────────────────────────────────────
    group.bench_function("observation", |b| {
        let mut rng = seeded();
        let mut s = env.new_game();
        env.apply_chance(&mut s, &mut rng);
        b.iter(|| black_box(env.observation(&s, 0)))
    });
    // ── full random game ──────────────────────────────────────────────────
    // Fewer samples: one game is ~ms-scale, not µs-scale.
    group.sample_size(50);
    group.bench_function("random_game", |b| {
        b.iter_batched(
            seeded,
            |mut rng| {
                let mut s = env.new_game();
                loop {
                    match env.current_player(&s) {
                        Player::Terminal => break,
                        Player::Chance => env.apply_chance(&mut s, &mut rng),
                        _ => {
                            // Decision node: pick a uniformly random legal action.
                            let actions = env.legal_actions(&s);
                            let idx = rng.random_range(0..actions.len());
                            env.apply(&mut s, actions[idx]);
                        }
                    }
                }
                black_box(s)
            },
            BatchSize::SmallInput,
        )
    });
    group.finish();
}
// ── 2. Network inference ───────────────────────────────────────────────────
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
fn bench_network(c: &mut Criterion) {
let mut group = c.benchmark_group("network");
group.measurement_time(Duration::from_secs(5));
for &hidden in &[64usize, 256] {
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = MlpNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
// Batch size 1 — single-position evaluation as in MCTS.
group.bench_with_input(
BenchmarkId::new("mlp_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
// Batch size 32 — training mini-batch.
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("mlp_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
// ── ResNet (4 residual blocks) ────────────────────────────────────────
for &hidden in &[256usize, 512] {
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = ResNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
group.bench_with_input(
BenchmarkId::new("resnet_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("resnet_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
group.finish();
}
// ── 3. MCTS ───────────────────────────────────────────────────────────────
/// MCTS cost at different simulation budgets with two evaluator types:
/// - `zero` — isolates tree-traversal overhead (no network).
/// - `mlp64` — real MLP, shows end-to-end cost per move.
fn bench_mcts(c: &mut Criterion) {
    let env = TrictracEnv;
    // Advance a fresh game past all chance nodes to reach a decision node.
    let mut state = env.new_game();
    {
        let mut rng = seeded();
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, &mut rng);
        }
    }
    let mut group = c.benchmark_group("mcts");
    group.measurement_time(Duration::from_secs(10));
    let zero_eval = ZeroEval(514);
    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
    let mlp_eval = BurnEvaluator::<InferB, _>::new(
        MlpNet::<InferB>::new(&mlp_cfg, &infer_device()),
        infer_device(),
    );
    for &n_sim in &[1usize, 5, 20] {
        // Deterministic search: no Dirichlet noise, temperature 1.
        let cfg = MctsConfig {
            n_simulations: n_sim,
            c_puct: 1.5,
            dirichlet_alpha: 0.0,
            dirichlet_eps: 0.0,
            temperature: 1.0,
        };
        // Zero evaluator: tree traversal only.
        group.bench_with_input(BenchmarkId::new("zero_eval", n_sim), &n_sim, |b, _| {
            b.iter_batched(
                seeded,
                |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
                BatchSize::SmallInput,
            )
        });
        // MLP evaluator: full cost per decision.
        group.bench_with_input(BenchmarkId::new("mlp64", n_sim), &n_sim, |b, _| {
            b.iter_batched(
                seeded,
                |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
                BatchSize::SmallInput,
            )
        });
    }
    group.finish();
}
// ── 4. Episode generation ─────────────────────────────────────────────────
/// Full self-play episode latency (one complete game) at different MCTS
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
fn bench_episode(c: &mut Criterion) {
    let env = TrictracEnv;
    let net_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
    let evaluator = BurnEvaluator::<InferB, _>::new(
        MlpNet::<InferB>::new(&net_cfg, &infer_device()),
        infer_device(),
    );
    let mut group = c.benchmark_group("episode");
    // Episodes take tens of ms — keep the sample count low, the window long.
    group.sample_size(10);
    group.measurement_time(Duration::from_secs(60));
    for &budget in &[1usize, 2] {
        let mcts_cfg = MctsConfig {
            n_simulations: budget,
            c_puct: 1.5,
            dirichlet_alpha: 0.0,
            dirichlet_eps: 0.0,
            temperature: 1.0,
        };
        group.bench_with_input(BenchmarkId::new("trictrac", budget), &budget, |b, _| {
            b.iter_batched(
                seeded,
                |mut rng| {
                    // Constant temperature schedule: always 1.0.
                    let episode =
                        generate_episode(&env, &evaluator, &mcts_cfg, &|_| 1.0, &mut rng);
                    black_box(episode)
                },
                BatchSize::SmallInput,
            )
        });
    }
    group.finish();
}
// ── 5. Training step ───────────────────────────────────────────────────────
/// Gradient-step latency for different batch sizes.
fn bench_train(c: &mut Criterion) {
    use burn::optim::AdamConfig;
    let mut group = c.benchmark_group("train");
    group.measurement_time(Duration::from_secs(10));
    let net_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
    // Synthetic batch: one-hot policies cycling through the action space,
    // alternating ±1 value targets.
    let make_batch = |n: usize| -> Vec<TrainSample> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let mut policy = vec![0.0f32; 514];
            policy[i % 514] = 1.0;
            samples.push(TrainSample {
                obs: vec![0.5; 217],
                policy,
                value: if i % 2 == 0 { 1.0 } else { -1.0 },
            });
        }
        samples
    };
    for &batch_size in &[16usize, 64] {
        let batch = make_batch(batch_size);
        group.bench_with_input(BenchmarkId::new("mlp64_adam", batch_size), &batch_size, |b, _| {
            b.iter_batched(
                // Fresh model + optimizer per measurement so Adam moment
                // state never carries over between iterations.
                || {
                    (
                        MlpNet::<TrainB>::new(&net_cfg, &train_device()),
                        AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
                    )
                },
                |(model, mut opt)| {
                    black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
                },
                BatchSize::SmallInput,
            )
        });
    }
    group.finish();
}
// ── Criterion entry point ──────────────────────────────────────────────────
// Registers every benchmark group above and generates the binary's `main`.
criterion_group!(
    benches,
    bench_env,
    bench_network,
    bench_mcts,
    bench_episode,
    bench_train,
);
criterion_main!(benches);

View file

@ -0,0 +1,127 @@
//! AlphaZero: self-play data generation, replay buffer, and training step.
//!
//! # Modules
//!
//! | Module | Contents |
//! |--------|----------|
//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] |
//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] |
//! | [`trainer`] | [`train_step`] |
//!
//! # Typical outer loop
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, NdArray};
//! use burn::optim::AdamConfig;
//! use spiel_bot::{
//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step},
//! env::TrictracEnv,
//! mcts::MctsConfig,
//! network::{MlpConfig, MlpNet},
//! };
//!
//! type Infer = NdArray<f32>;
//! type Train = Autodiff<NdArray<f32>>;
//!
//! let device = Default::default();
//! let env = TrictracEnv;
//! let config = AlphaZeroConfig::default();
//!
//! // Build training model and optimizer.
//! let mut train_model = MlpNet::<Train>::new(&MlpConfig::default(), &device);
//! let mut optimizer = AdamConfig::new().init();
//! let mut replay = ReplayBuffer::new(config.replay_capacity);
//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
//!
//! for _iter in 0..config.n_iterations {
//! // Convert to inference backend for self-play.
//! let infer_model = MlpNet::<Infer>::new(&MlpConfig::default(), &device)
//! .load_record(train_model.clone().into_record());
//! let eval = BurnEvaluator::new(infer_model, device.clone());
//!
//! // Self-play: generate episodes.
//! for _ in 0..config.n_games_per_iter {
//! let samples = generate_episode(&env, &eval, &config.mcts,
//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
//! replay.extend(samples);
//! }
//!
//! // Training: gradient steps.
//! if replay.len() >= config.batch_size {
//! for _ in 0..config.n_train_steps_per_iter {
//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng)
//! .into_iter().cloned().collect();
//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device,
//! config.learning_rate);
//! train_model = m;
//! }
//! }
//! }
//! ```
pub mod replay;
pub mod selfplay;
pub mod trainer;
pub use replay::{ReplayBuffer, TrainSample};
pub use selfplay::{BurnEvaluator, generate_episode};
pub use trainer::{cosine_lr, train_step};
use crate::mcts::MctsConfig;
// ── Configuration ─────────────────────────────────────────────────────────
/// Top-level AlphaZero hyperparameters.
///
/// The MCTS parameters live in [`MctsConfig`]; this struct holds the
/// outer training-loop parameters.
#[derive(Debug, Clone)]
pub struct AlphaZeroConfig {
    /// MCTS parameters for self-play.
    pub mcts: MctsConfig,
    /// Number of self-play games per training iteration.
    pub n_games_per_iter: usize,
    /// Number of gradient steps per training iteration.
    pub n_train_steps_per_iter: usize,
    /// Mini-batch size for each gradient step.
    pub batch_size: usize,
    /// Maximum number of samples in the replay buffer.
    pub replay_capacity: usize,
    /// Initial (peak) Adam learning rate.
    pub learning_rate: f64,
    /// Minimum learning rate for cosine annealing (floor of the schedule).
    ///
    /// Pass `learning_rate == lr_min` to disable scheduling (constant LR).
    /// Compute the current LR with [`cosine_lr`]:
    ///
    /// ```rust,ignore
    /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
    /// ```
    pub lr_min: f64,
    /// Number of outer iterations (self-play + train) to run.
    pub n_iterations: usize,
    /// Move index after which the action temperature drops to 0 (greedy play).
    ///
    /// NOTE(review): callers are expected to build the self-play
    /// `temperature_fn` from this value — the module-level example hardcodes
    /// the equivalent threshold (30); confirm usage in the training binary.
    pub temperature_drop_move: usize,
}
impl Default for AlphaZeroConfig {
    /// Reasonable starting hyperparameters for CPU-scale experiments.
    fn default() -> Self {
        // Moderate search budget with standard AlphaZero exploration noise.
        let mcts = MctsConfig {
            n_simulations: 100,
            c_puct: 1.5,
            dirichlet_alpha: 0.1,
            dirichlet_eps: 0.25,
            temperature: 1.0,
        };
        Self {
            mcts,
            n_games_per_iter: 10,
            n_train_steps_per_iter: 20,
            batch_size: 64,
            replay_capacity: 50_000,
            learning_rate: 1e-3,
            // Floor of the cosine-annealing schedule.
            lr_min: 1e-4,
            n_iterations: 100,
            temperature_drop_move: 30,
        }
    }
}

View file

@ -0,0 +1,144 @@
//! Replay buffer for AlphaZero self-play data.
use std::collections::VecDeque;
use rand::Rng;
// ── Training sample ────────────────────────────────────────────────────────
/// One training example produced by self-play.
///
/// All samples within one training batch must share the same `obs` and
/// `policy` lengths: the trainer reads the tensor sizes from the first
/// sample of the batch.
#[derive(Clone, Debug)]
pub struct TrainSample {
    /// Observation tensor from the acting player's perspective (`obs_size` floats).
    pub obs: Vec<f32>,
    /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
    pub policy: Vec<f32>,
    /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
    pub value: f32,
}
// ── Replay buffer ──────────────────────────────────────────────────────────
/// Fixed-capacity circular buffer of [`TrainSample`]s.
///
/// When the buffer is full, the oldest sample is evicted on push.
/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
pub struct ReplayBuffer {
    // Stored samples, oldest at the front (pop_front evicts).
    data: VecDeque<TrainSample>,
    // Maximum number of samples retained.
    capacity: usize,
}
impl ReplayBuffer {
    /// Create a buffer with the given maximum capacity.
    pub fn new(capacity: usize) -> Self {
        // Cap the eager allocation: a huge `capacity` should not reserve
        // memory before any sample exists.
        Self {
            data: VecDeque::with_capacity(capacity.min(1024)),
            capacity,
        }
    }
    /// Add a sample; evicts the oldest if at capacity.
    pub fn push(&mut self, sample: TrainSample) {
        if self.data.len() == self.capacity {
            self.data.pop_front();
        }
        self.data.push_back(sample);
    }
    /// Add all samples from an episode.
    pub fn extend(&mut self, samples: impl IntoIterator<Item = TrainSample>) {
        samples.into_iter().for_each(|s| self.push(s));
    }
    /// Number of samples currently stored.
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// True when no sample is stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    /// Sample up to `n` distinct samples, without replacement.
    ///
    /// If the buffer has fewer than `n` samples, all are returned (shuffled).
    pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
        let len = self.data.len();
        let take = n.min(len);
        // Partial Fisher-Yates: only the first `take` slots need shuffling.
        let mut order: Vec<usize> = (0..len).collect();
        for slot in 0..take {
            let pick = rng.random_range(slot..len);
            order.swap(slot, pick);
        }
        order[..take].iter().map(|&idx| &self.data[idx]).collect()
    }
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};
    /// Minimal sample whose `value` field doubles as an identity marker.
    fn dummy(value: f32) -> TrainSample {
        TrainSample { obs: vec![value], policy: vec![1.0], value }
    }
    #[test]
    fn push_and_len() {
        let mut buf = ReplayBuffer::new(10);
        assert!(buf.is_empty());
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        assert_eq!(buf.len(), 2);
    }
    #[test]
    fn evicts_oldest_at_capacity() {
        let mut buf = ReplayBuffer::new(3);
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        buf.push(dummy(3.0));
        buf.push(dummy(4.0)); // evicts 1.0
        assert_eq!(buf.len(), 3);
        // Oldest remaining should be 2.0
        assert_eq!(buf.data[0].value, 2.0);
    }
    #[test]
    fn sample_batch_size() {
        let mut buf = ReplayBuffer::new(20);
        for i in 0..10 {
            buf.push(dummy(i as f32));
        }
        let mut rng = SmallRng::seed_from_u64(0);
        let batch = buf.sample_batch(5, &mut rng);
        assert_eq!(batch.len(), 5);
    }
    #[test]
    fn sample_batch_capped_at_len() {
        let mut buf = ReplayBuffer::new(20);
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        let mut rng = SmallRng::seed_from_u64(0);
        // Requesting more than the buffer holds returns everything once.
        let batch = buf.sample_batch(100, &mut rng);
        assert_eq!(batch.len(), 2);
    }
    #[test]
    fn sample_batch_no_duplicates() {
        let mut buf = ReplayBuffer::new(20);
        for i in 0..10 {
            buf.push(dummy(i as f32));
        }
        let mut rng = SmallRng::seed_from_u64(1);
        let batch = buf.sample_batch(10, &mut rng);
        // Without-replacement property: all ten distinct markers must appear.
        let mut seen: Vec<f32> = batch.iter().map(|s| s.value).collect();
        seen.sort_by(f32::total_cmp);
        seen.dedup();
        assert_eq!(seen.len(), 10, "sample contained duplicates");
    }
}

View file

@ -0,0 +1,238 @@
//! Self-play episode generation and Burn-backed evaluator.
use std::marker::PhantomData;
use burn::tensor::{backend::Backend, Tensor, TensorData};
use rand::Rng;
use crate::env::GameEnv;
use crate::mcts::{self, Evaluator, MctsConfig, MctsNode};
use crate::network::PolicyValueNet;
use super::replay::TrainSample;
// ── BurnEvaluator ──────────────────────────────────────────────────────────
/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS.
///
/// Use the **inference backend** (`NdArray<f32>`, no `Autodiff` wrapper) so
/// that self-play generates no gradient tape overhead.
pub struct BurnEvaluator<B: Backend, N: PolicyValueNet<B>> {
    // The wrapped policy/value network.
    model: N,
    // Device on which input tensors are created in `evaluate`.
    device: B::Device,
    // Ties the otherwise-unused backend parameter `B` to the struct.
    _b: PhantomData<B>,
}
impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
    /// Wrap `model` for evaluation, creating input tensors on `device`.
    pub fn new(model: N, device: B::Device) -> Self {
        Self { model, device, _b: PhantomData }
    }
    /// Consume the evaluator and recover the wrapped network.
    pub fn into_model(self) -> N {
        self.model
    }
    /// Borrow the wrapped network.
    pub fn model_ref(&self) -> &N {
        &self.model
    }
}
// SAFETY: `BurnEvaluator` stores only `model`, `device`, and a zero-sized
// `PhantomData<B>` backend tag. With the bounds below every data-carrying
// field is itself Send/Sync, so these impls claim nothing the field types do
// not already guarantee. (The previous unconditional impls asserted
// thread-safety for *any* model/device type, which is unsound for backends
// whose modules or devices are not thread-safe.)
unsafe impl<B: Backend, N: PolicyValueNet<B> + Send> Send for BurnEvaluator<B, N> where B::Device: Send {}
unsafe impl<B: Backend, N: PolicyValueNet<B> + Sync> Sync for BurnEvaluator<B, N> where B::Device: Sync {}
impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
    /// Run a single-observation forward pass; returns (policy logits, value).
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
        // Shape the flat observation as a batch of one.
        let input = Tensor::<B, 2>::from_data(
            TensorData::new(obs.to_vec(), [1, obs.len()]),
            &self.device,
        );
        let (policy_out, value_out) = self.model.forward(input);
        let policy: Vec<f32> = policy_out.into_data().to_vec().unwrap();
        let value: Vec<f32> = value_out.into_data().to_vec().unwrap();
        (policy, value[0])
    }
}
// ── Episode generation ─────────────────────────────────────────────────────
/// One pending observation waiting for its game-outcome value label.
struct PendingSample {
    // Observation at the decision point, from the acting player's perspective.
    obs: Vec<f32>,
    // MCTS visit-count policy recorded at the same point.
    policy: Vec<f32>,
    // Index (0 or 1) of the player who acted; selects which return to assign.
    player: usize,
}
/// Play one full game using MCTS guided by `evaluator`.
///
/// Returns a [`TrainSample`] for every decision step in the game.
///
/// `temperature_fn(step)` controls exploration: return `1.0` for early
/// moves and `0.0` after a fixed number of moves (e.g. move 30).
pub fn generate_episode<E: GameEnv>(
    env: &E,
    evaluator: &dyn Evaluator,
    mcts_config: &MctsConfig,
    temperature_fn: &dyn Fn(usize) -> f32,
    rng: &mut impl Rng,
) -> Vec<TrainSample> {
    let mut state = env.new_game();
    let mut pending: Vec<PendingSample> = Vec::new();
    let mut step = 0usize;
    loop {
        // Advance through chance nodes automatically.
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, rng);
        }
        if env.current_player(&state).is_terminal() {
            break;
        }
        let player_idx = env.current_player(&state).index().unwrap();
        // Run MCTS to get a policy.
        let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
        let policy = mcts::mcts_policy(&root, env.action_space());
        // Record the observation from the acting player's perspective.
        // `policy` is moved (not cloned) into the pending sample — it is not
        // referenced again after this point.
        let obs = env.observation(&state, player_idx);
        pending.push(PendingSample { obs, policy, player: player_idx });
        // Select and apply the action.
        let temperature = temperature_fn(step);
        let action = mcts::select_action(&root, temperature, rng);
        env.apply(&mut state, action);
        step += 1;
    }
    // Assign game outcomes. The loop only exits at a terminal state, so
    // `unwrap_or` merely guards an env that returns `None` there.
    let returns = env.returns(&state).unwrap_or([0.0; 2]);
    pending
        .into_iter()
        .map(|s| TrainSample {
            obs: s.obs,
            policy: s.policy,
            value: returns[s.player],
        })
        .collect()
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use rand::{SeedableRng, rngs::SmallRng};
    use crate::env::Player;
    use crate::mcts::{Evaluator, MctsConfig};
    use crate::network::{MlpConfig, MlpNet};
    type B = NdArray<f32>;
    fn device() -> <B as Backend>::Device {
        Default::default()
    }
    /// Fixed-seed RNG so tests are deterministic.
    fn rng() -> SmallRng {
        SmallRng::seed_from_u64(7)
    }
    // Countdown game (same as in mcts tests).
    // Deterministic two-player game: start at 4, subtract 1 or 2 per move;
    // whoever reaches exactly 0 wins.
    #[derive(Clone, Debug)]
    struct CState { remaining: u8, to_move: usize }
    #[derive(Clone)]
    struct CountdownEnv;
    impl GameEnv for CountdownEnv {
        type State = CState;
        fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } }
        fn current_player(&self, s: &CState) -> Player {
            if s.remaining == 0 { Player::Terminal }
            else if s.to_move == 0 { Player::P1 } else { Player::P2 }
        }
        // Action 0 subtracts 1; action 1 (subtract 2) only while >= 2 remain.
        fn legal_actions(&self, s: &CState) -> Vec<usize> {
            if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
        }
        fn apply(&self, s: &mut CState, action: usize) {
            let sub = (action as u8) + 1;
            // A move that reaches (or would pass) 0 ends the game and does
            // NOT toggle `to_move`, so the winner stays indexed by `to_move`.
            if s.remaining <= sub { s.remaining = 0; }
            else { s.remaining -= sub; s.to_move = 1 - s.to_move; }
        }
        fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}
        fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
            vec![s.remaining as f32 / 4.0, s.to_move as f32]
        }
        fn obs_size(&self) -> usize { 2 }
        fn action_space(&self) -> usize { 2 }
        fn returns(&self, s: &CState) -> Option<[f32; 2]> {
            if s.remaining != 0 { return None; }
            // `to_move` still indexes the player who made the final move
            // (see `apply`), i.e. the winner.
            let mut r = [-1.0f32; 2];
            r[s.to_move] = 1.0;
            Some(r)
        }
    }
    /// Small search budget, no exploration noise — fast deterministic tests.
    fn tiny_config() -> MctsConfig {
        MctsConfig { n_simulations: 5, c_puct: 1.5,
            dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 }
    }
    // ── BurnEvaluator tests ───────────────────────────────────────────────
    #[test]
    fn burn_evaluator_output_shapes() {
        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
        let model = MlpNet::<B>::new(&config, &device());
        let eval = BurnEvaluator::new(model, device());
        let (policy, value) = eval.evaluate(&[0.5f32, 0.5]);
        assert_eq!(policy.len(), 2, "policy length should equal action_space");
        assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)");
    }
    // ── generate_episode tests ────────────────────────────────────────────
    #[test]
    fn episode_terminates_and_has_samples() {
        let env = CountdownEnv;
        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
        let model = MlpNet::<B>::new(&config, &device());
        let eval = BurnEvaluator::new(model, device());
        let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
        assert!(!samples.is_empty(), "episode must produce at least one sample");
    }
    #[test]
    fn episode_sample_values_are_valid() {
        let env = CountdownEnv;
        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
        let model = MlpNet::<B>::new(&config, &device());
        let eval = BurnEvaluator::new(model, device());
        let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
        // Every sample must carry a ±1/0 outcome, a normalized policy, and
        // an observation of the env's declared size.
        for s in &samples {
            assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
                "unexpected value {}", s.value);
            let sum: f32 = s.policy.iter().sum();
            assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}");
            assert_eq!(s.obs.len(), 2);
        }
    }
    #[test]
    fn episode_with_temperature_zero() {
        let env = CountdownEnv;
        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
        let model = MlpNet::<B>::new(&config, &device());
        let eval = BurnEvaluator::new(model, device());
        // temperature=0 means greedy; episode must still terminate
        let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng());
        assert!(!samples.is_empty());
    }
}

View file

@ -0,0 +1,258 @@
//! One gradient-descent training step for AlphaZero.
//!
//! The loss combines:
//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
//! - **Value loss** — mean-squared error between the predicted value and the
//! actual game outcome.
//!
//! # Learning-rate scheduling
//!
//! [`cosine_lr`] implements one-cycle cosine annealing:
//!
//! ```text
//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T))
//! ```
//!
//! Typical usage in the outer loop:
//!
//! ```rust,ignore
//! for step in 0..total_train_steps {
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
//! model = m;
//! }
//! ```
//!
//! # Backend
//!
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
//! Self-play uses the inner backend (`NdArray<f32>`) for zero autodiff overhead.
//! Weights are transferred between the two via [`burn::record`].
use burn::{
module::AutodiffModule,
optim::{GradientsParams, Optimizer},
prelude::ElementConversion,
tensor::{
activation::log_softmax,
backend::AutodiffBackend,
Tensor, TensorData,
},
};
use crate::network::PolicyValueNet;
use super::replay::TrainSample;
/// Run one gradient step on `model` using `batch`.
///
/// Returns the updated model and the scalar loss value for logging.
///
/// # Parameters
///
/// - `lr` — learning rate (e.g. `1e-3`).
/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
///
/// # Panics
///
/// Panics if `batch` is empty.
pub fn train_step<B, N, O>(
    model: N,
    optimizer: &mut O,
    batch: &[TrainSample],
    device: &B::Device,
    lr: f64,
) -> (N, f32)
where
    B: AutodiffBackend,
    N: PolicyValueNet<B> + AutodiffModule<B>,
    O: Optimizer<N, B>,
{
    assert!(!batch.is_empty(), "train_step called with empty batch");
    let batch_size = batch.len();
    // Tensor dimensions come from the first sample; all samples in a batch
    // are expected to share the same obs/policy lengths.
    let obs_size = batch[0].obs.len();
    let action_size = batch[0].policy.len();
    // ── Build input tensors ────────────────────────────────────────────────
    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
    let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
    let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [batch_size, obs_size]),
        device,
    );
    let policy_target = Tensor::<B, 2>::from_data(
        TensorData::new(policy_flat, [batch_size, action_size]),
        device,
    );
    let value_target = Tensor::<B, 2>::from_data(
        TensorData::new(value_flat, [batch_size, 1]),
        device,
    );
    // ── Forward pass ──────────────────────────────────────────────────────
    let (policy_logits, value_pred) = model.forward(obs_tensor);
    // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
    let log_probs = log_softmax(policy_logits, 1);
    // `policy_target` is consumed here (its last use), so no clone is needed.
    let policy_loss = (policy_target.neg() * log_probs)
        .sum_dim(1)
        .mean();
    // ── Value loss: MSE(value_pred, z) ────────────────────────────────────
    let diff = value_pred - value_target;
    let value_loss = (diff.clone() * diff).mean();
    // ── Combined loss ─────────────────────────────────────────────────────
    let loss = policy_loss + value_loss;
    // Extract scalar before backward (consumes the tensor).
    let loss_scalar: f32 = loss.clone().into_scalar().elem();
    // ── Backward + optimizer step ─────────────────────────────────────────
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    let model = optimizer.step(lr, model, grads);
    (model, loss_scalar)
}
// ── Learning-rate schedule ─────────────────────────────────────────────────
/// Cosine learning-rate schedule (one half-period, no warmup).
///
/// Returns the learning rate for training step `step` out of `total_steps`:
///
/// ```text
/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total))
/// ```
///
/// - At `t = 0` returns `initial`.
/// - At `t = total_steps` (or beyond) returns `lr_min`.
///
/// # Panics
///
/// Does not panic. When `total_steps == 0`, returns `lr_min`.
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
    if total_steps == 0 || step >= total_steps {
        lr_min
    } else {
        let progress = step as f64 / total_steps as f64;
        let cosine = (std::f64::consts::PI * progress).cos();
        lr_min + 0.5 * (initial - lr_min) * (1.0 + cosine)
    }
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::AdamConfig,
    };
    use crate::network::{MlpConfig, MlpNet};
    use super::super::replay::TrainSample;
    type B = Autodiff<NdArray<f32>>;
    fn device() -> <B as burn::tensor::backend::Backend>::Device {
        Default::default()
    }
    /// One-hot policies cycling over actions, values alternating ±1.
    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
        (0..n)
            .map(|i| TrainSample {
                obs: vec![0.5f32; obs_size],
                policy: {
                    let mut p = vec![0.0f32; action_size];
                    p[i % action_size] = 1.0;
                    p
                },
                value: if i % 2 == 0 { 1.0 } else { -1.0 },
            })
            .collect()
    }
    #[test]
    fn train_step_returns_finite_loss() {
        let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
        let model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(8, 4, 4);
        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
        assert!(loss > 0.0, "loss should be positive");
    }
    #[test]
    fn loss_decreases_over_steps() {
        let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
        let mut model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        // Same batch every step — loss should decrease.
        let batch = dummy_batch(16, 4, 4);
        let mut prev_loss = f32::INFINITY;
        for _ in 0..10 {
            let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
            model = m;
            assert!(loss.is_finite());
            prev_loss = loss;
        }
        // After 10 steps on fixed data, loss should be below a reasonable threshold.
        // Note: only the FINAL loss is asserted; strict step-to-step decrease
        // is not required (despite the test name), which keeps the test
        // robust to optimizer noise.
        assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}");
    }
    #[test]
    fn train_step_batch_size_one() {
        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
        let model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(1, 2, 2);
        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
        assert!(loss.is_finite());
    }
    // ── cosine_lr ─────────────────────────────────────────────────────────
    #[test]
    fn cosine_lr_at_step_zero_is_initial() {
        let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
        assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
    }
    #[test]
    fn cosine_lr_at_end_is_min() {
        let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
        assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
    }
    #[test]
    fn cosine_lr_beyond_end_is_min() {
        let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
        assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
    }
    #[test]
    fn cosine_lr_midpoint_is_average() {
        // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
        let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
        let expected = (1e-3 + 1e-5) / 2.0;
        assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
    }
    #[test]
    fn cosine_lr_monotone_decreasing() {
        let mut prev = f64::INFINITY;
        for step in 0..=100 {
            let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
            assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
            prev = lr;
        }
    }
    #[test]
    fn cosine_lr_zero_total_steps_returns_min() {
        let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
        assert!((lr - 1e-5).abs() < 1e-10);
    }
}

View file

@ -0,0 +1,262 @@
//! Evaluate a trained AlphaZero checkpoint against a random player.
//!
//! # Usage
//!
//! ```sh
//! # Random weights (sanity check — should be ~50 %)
//! cargo run -p spiel_bot --bin az_eval --release
//!
//! # Trained MLP checkpoint
//! cargo run -p spiel_bot --bin az_eval --release -- \
//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50
//!
//! # Trained ResNet checkpoint
//! cargo run -p spiel_bot --bin az_eval --release -- \
//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--checkpoint <path>` | (none) | Load weights from `.mpk` file; random weights if omitted |
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
//! | `--hidden <N>` | 256 (mlp) / 512 (resnet) | Hidden size |
//! | `--n-games <N>` | `100` | Games per side (total = 2 × N) |
//! | `--n-sim <N>` | `50` | MCTS simulations per move |
//! | `--seed <N>` | `42` | RNG seed |
//! | `--c-puct <F>` | `1.5` | PUCT exploration constant |
use std::path::PathBuf;
use burn::backend::NdArray;
use rand::{SeedableRng, rngs::SmallRng, Rng};
use spiel_bot::{
alphazero::BurnEvaluator,
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts, select_action},
network::{MlpConfig, MlpNet, ResNet, ResNetConfig},
};
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
/// Parsed command-line options (see the module docs for the full table).
struct Args {
    /// `--checkpoint` — weights file; random weights when `None`.
    checkpoint: Option<PathBuf>,
    /// `--arch` — "mlp" or "resnet".
    arch: String,
    /// `--hidden` — hidden size; `None` selects the per-arch default (256/512).
    hidden: Option<usize>,
    /// `--n-games` — games per side (total played = 2 × n_games).
    n_games: usize,
    /// `--n-sim` — MCTS simulations per move.
    n_sim: usize,
    /// `--seed` — base RNG seed (offset per game in `run_evaluation`).
    seed: u64,
    /// `--c-puct` — PUCT exploration constant.
    c_puct: f32,
}
impl Default for Args {
fn default() -> Self {
Self {
checkpoint: None,
arch: "mlp".into(),
hidden: None,
n_games: 100,
n_sim: 50,
seed: 42,
c_puct: 1.5,
}
}
}
/// Parse command-line flags into [`Args`], exiting with a message on any
/// unknown flag or on a flag whose required value is missing.
///
/// Previously, a flag given as the last argument (e.g. `az_eval --seed`)
/// caused an out-of-bounds panic; it now exits with a clear error.
fn parse_args() -> Args {
    let raw: Vec<String> = std::env::args().collect();
    let mut args = Args::default();
    // Fetch the value that must follow `flag`, or exit with a clear message.
    fn value<'a>(raw: &'a [String], i: &mut usize, flag: &str) -> &'a str {
        *i += 1;
        match raw.get(*i) {
            Some(v) => v,
            None => {
                eprintln!("Missing value for {flag}");
                std::process::exit(1);
            }
        }
    }
    let mut i = 1;
    while i < raw.len() {
        match raw[i].as_str() {
            "--checkpoint" => args.checkpoint = Some(PathBuf::from(value(&raw, &mut i, "--checkpoint"))),
            "--arch" => args.arch = value(&raw, &mut i, "--arch").to_string(),
            "--hidden" => args.hidden = Some(value(&raw, &mut i, "--hidden").parse().expect("--hidden must be an integer")),
            "--n-games" => args.n_games = value(&raw, &mut i, "--n-games").parse().expect("--n-games must be an integer"),
            "--n-sim" => args.n_sim = value(&raw, &mut i, "--n-sim").parse().expect("--n-sim must be an integer"),
            "--seed" => args.seed = value(&raw, &mut i, "--seed").parse().expect("--seed must be an integer"),
            "--c-puct" => args.c_puct = value(&raw, &mut i, "--c-puct").parse().expect("--c-puct must be a float"),
            other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
        }
        i += 1;
    }
    args
}
// ── Game loop ─────────────────────────────────────────────────────────────────
/// Play one complete game.
///
/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black).
/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0).
fn play_game(
    env: &TrictracEnv,
    mcts_side: usize,
    evaluator: &dyn Evaluator,
    mcts_cfg: &MctsConfig,
    rng: &mut SmallRng,
) -> [f32; 2] {
    let mut state = env.new_game();
    loop {
        let who = env.current_player(&state);
        if matches!(who, Player::Terminal) {
            return env.returns(&state).expect("Terminal state must have returns");
        }
        if matches!(who, Player::Chance) {
            env.apply_chance(&mut state, rng);
            continue;
        }
        let side = who.index().unwrap(); // 0 = P1, 1 = P2
        let action = if side == mcts_side {
            // MCTS side: greedy search (temperature = 0).
            let root = run_mcts(env, &state, evaluator, mcts_cfg, rng);
            select_action(&root, 0.0, rng)
        } else {
            // Opponent: uniformly random legal move.
            let legal = env.legal_actions(&state);
            legal[rng.random_range(0..legal.len())]
        };
        env.apply(&mut state, action);
    }
}
// ── Statistics ────────────────────────────────────────────────────────────────
/// Win/draw/loss tally for one side of the evaluation.
#[derive(Default)]
struct Stats {
    wins: u32,
    draws: u32,
    losses: u32,
}
impl Stats {
    /// Classify one game outcome: positive → win, negative → loss, else draw.
    fn record(&mut self, mcts_return: f32) {
        match mcts_return.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => self.wins += 1,
            Some(std::cmp::Ordering::Less) => self.losses += 1,
            // Zero (and any non-comparable value) counts as a draw.
            _ => self.draws += 1,
        }
    }
    /// Total number of games recorded.
    fn total(&self) -> u32 {
        self.wins + self.draws + self.losses
    }
    /// Win rate among decisive (non-draw) games; 0.5 when none are decisive.
    fn win_rate_decisive(&self) -> f64 {
        let decisive = self.wins + self.losses;
        if decisive == 0 {
            return 0.5;
        }
        f64::from(self.wins) / f64::from(decisive)
    }
    /// Print the tally with per-outcome percentages.
    fn print(&self) {
        let n = self.total();
        let pct = |k: u32| 100.0 * k as f64 / n as f64;
        println!(
            " Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)",
            self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses),
        );
    }
}
// ── Evaluation ────────────────────────────────────────────────────────────────
/// Play `2 × n_games` games (alternating which side the MCTS agent takes)
/// and return the tallies as `(as_p1, as_p2)`.
fn run_evaluation(
    evaluator: &dyn Evaluator,
    n_games: usize,
    mcts_cfg: &MctsConfig,
    seed: u64,
) -> (Stats, Stats) {
    let env = TrictracEnv;
    let total = n_games * 2;
    // stats[0] tallies games played as P1, stats[1] as P2.
    let mut stats = [Stats::default(), Stats::default()];
    for i in 0..total {
        // Alternate sides: even games → MctsAgent as P1, odd → as P2.
        let mcts_side = i % 2;
        // Per-game deterministic RNG derived from the base seed.
        let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64));
        let outcome = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng);
        stats[mcts_side].record(outcome[mcts_side]);
        let done = i + 1;
        if done % 10 == 0 || done == total {
            eprint!("\r [{done}/{total}] ");
        }
    }
    eprintln!();
    let [as_p1, as_p2] = stats;
    (as_p1, as_p2)
}
// ── Main ──────────────────────────────────────────────────────────────────────
/// Entry point: parse flags, build the evaluator (optionally from a
/// checkpoint), run both-sides evaluation, and print aggregate statistics.
fn main() {
    let args = parse_args();
    let device: <InferB as burn::tensor::backend::Backend>::Device = Default::default();
    // ── Load model ────────────────────────────────────────────────────────
    let evaluator: Box<dyn Evaluator> = match args.arch.as_str() {
        "resnet" => {
            let hidden = args.hidden.unwrap_or(512);
            let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
            let model = match &args.checkpoint {
                Some(path) => ResNet::<InferB>::load(&cfg, path, &device)
                    .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
                None => ResNet::new(&cfg, &device),
            };
            Box::new(BurnEvaluator::<InferB, ResNet<InferB>>::new(model, device))
        }
        // Any unrecognized --arch value silently falls back to the MLP branch.
        "mlp" | _ => {
            let hidden = args.hidden.unwrap_or(256);
            let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
            let model = match &args.checkpoint {
                Some(path) => MlpNet::<InferB>::load(&cfg, path, &device)
                    .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
                None => MlpNet::new(&cfg, &device),
            };
            Box::new(BurnEvaluator::<InferB, MlpNet<InferB>>::new(model, device))
        }
    };
    // No noise and greedy selection: evaluation measures playing strength,
    // not exploration.
    let mcts_cfg = MctsConfig {
        n_simulations: args.n_sim,
        c_puct: args.c_puct,
        dirichlet_alpha: 0.0, // no exploration noise during evaluation
        dirichlet_eps: 0.0,
        temperature: 0.0, // greedy action selection
    };
    // ── Header ────────────────────────────────────────────────────────────
    let ckpt_label = args.checkpoint
        .as_deref()
        .and_then(|p| p.file_name())
        .and_then(|n| n.to_str())
        .unwrap_or("random weights");
    println!();
    println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent",
        args.arch, args.n_sim);
    println!("Games per side: {} | Total: {} | Seed: {}",
        args.n_games, args.n_games * 2, args.seed);
    println!();
    // ── Run ───────────────────────────────────────────────────────────────
    let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed);
    // ── Results ───────────────────────────────────────────────────────────
    println!("MctsAgent as P1 (White):");
    as_p1.print();
    println!("MctsAgent as P2 (Black):");
    as_p2.print();
    // Combined win rate pools wins/losses across both sides, excluding draws.
    let combined_wins = as_p1.wins + as_p2.wins;
    let combined_decisive = combined_wins + as_p1.losses + as_p2.losses;
    let combined_wr = if combined_decisive == 0 { 0.5 }
        else { combined_wins as f64 / combined_decisive as f64 };
    println!();
    println!("Combined win rate (excluding draws): {:.1}% [{}/{}]",
        combined_wr * 100.0, combined_wins, combined_decisive);
    println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%",
        as_p1.win_rate_decisive() * 100.0,
        as_p2.win_rate_decisive() * 100.0);
}

View file

@ -0,0 +1,331 @@
//! AlphaZero self-play training loop.
//!
//! # Usage
//!
//! ```sh
//! # Start fresh (MLP, default settings)
//! cargo run -p spiel_bot --bin az_train --release
//!
//! # ResNet, 200 iterations, save every 20
//! cargo run -p spiel_bot --bin az_train --release -- \
//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/
//!
//! # Resume from a checkpoint
//! cargo run -p spiel_bot --bin az_train --release -- \
//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
//! | `--hidden N` | 256/512 | Hidden layer width |
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
//! | `--n-iter N` | `100` | Training iterations |
//! | `--n-games N` | `10` | Self-play games per iteration |
//! | `--n-train N` | `20` | Gradient steps per iteration |
//! | `--n-sim N` | `100` | MCTS simulations per move |
//! | `--batch N` | `64` | Mini-batch size |
//! | `--replay-cap N` | `50000` | Replay buffer capacity |
//! | `--lr F` | `1e-3` | Peak (initial) learning rate |
//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) |
//! | `--c-puct F` | `1.5` | PUCT exploration constant |
//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha |
//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight |
//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 |
//! | `--save-every N` | `10` | Save checkpoint every N iterations |
//! | `--seed N` | `42` | RNG seed |
//! | `--resume PATH` | (none) | Load weights from checkpoint before training |
use std::path::{Path, PathBuf};
use std::time::Instant;
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use rayon::prelude::*;
use spiel_bot::{
alphazero::{
BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step,
},
env::TrictracEnv,
mcts::MctsConfig,
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
type TrainB = Autodiff<NdArray<f32>>;
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
struct Args {
    /// Network architecture: "mlp" or "resnet".
    arch: String,
    /// Hidden layer width; `None` selects the per-arch default
    /// (256 for MLP, 512 for ResNet — see `main`).
    hidden: Option<usize>,
    /// Directory where checkpoint files are written.
    out_dir: PathBuf,
    /// Number of outer training iterations.
    n_iter: usize,
    /// Self-play games per iteration.
    n_games: usize,
    /// Gradient steps per iteration.
    n_train: usize,
    /// MCTS simulations per move.
    n_sim: usize,
    /// Mini-batch size for each gradient step.
    batch_size: usize,
    /// Replay buffer capacity.
    replay_cap: usize,
    /// Peak (initial) learning rate.
    lr: f64,
    /// Floor learning rate for cosine annealing.
    lr_min: f64,
    /// PUCT exploration constant.
    c_puct: f32,
    /// Dirichlet root-noise alpha.
    dirichlet_alpha: f32,
    /// Dirichlet root-noise weight.
    dirichlet_eps: f32,
    /// Move index after which the sampling temperature drops to 0.
    temp_drop: usize,
    /// Save a checkpoint every N iterations (and always on the last one).
    save_every: usize,
    /// RNG seed for self-play and batch sampling.
    seed: u64,
    /// Optional checkpoint to load weights from before training.
    resume: Option<PathBuf>,
}
impl Default for Args {
fn default() -> Self {
Self {
arch: "mlp".into(),
hidden: None,
out_dir: PathBuf::from("checkpoints"),
n_iter: 100,
n_games: 10,
n_train: 20,
n_sim: 100,
batch_size: 64,
replay_cap: 50_000,
lr: 1e-3,
lr_min: 1e-4,
c_puct: 1.5,
dirichlet_alpha: 0.1,
dirichlet_eps: 0.25,
temp_drop: 30,
save_every: 10,
seed: 42,
resume: None,
}
}
}
/// Parse command-line arguments (see the module docs for the option table).
///
/// Unknown flags abort with an error. A flag given without its value also
/// aborts with a clear message; the original indexed `raw[i]` past the end
/// and panicked with an opaque index-out-of-bounds.
fn parse_args() -> Args {
    // Fetch the token following `flag`, exiting cleanly when it is missing.
    fn next_value<'a>(raw: &'a [String], i: &mut usize, flag: &str) -> &'a str {
        *i += 1;
        match raw.get(*i) {
            Some(v) => v,
            None => {
                eprintln!("Missing value for {flag}");
                std::process::exit(1);
            }
        }
    }
    let raw: Vec<String> = std::env::args().collect();
    let mut a = Args::default();
    let mut i = 1;
    while i < raw.len() {
        match raw[i].as_str() {
            "--arch" => a.arch = next_value(&raw, &mut i, "--arch").to_owned(),
            "--hidden" => a.hidden = Some(next_value(&raw, &mut i, "--hidden").parse().expect("--hidden: integer")),
            "--out" => a.out_dir = PathBuf::from(next_value(&raw, &mut i, "--out")),
            "--n-iter" => a.n_iter = next_value(&raw, &mut i, "--n-iter").parse().expect("--n-iter: integer"),
            "--n-games" => a.n_games = next_value(&raw, &mut i, "--n-games").parse().expect("--n-games: integer"),
            "--n-train" => a.n_train = next_value(&raw, &mut i, "--n-train").parse().expect("--n-train: integer"),
            "--n-sim" => a.n_sim = next_value(&raw, &mut i, "--n-sim").parse().expect("--n-sim: integer"),
            "--batch" => a.batch_size = next_value(&raw, &mut i, "--batch").parse().expect("--batch: integer"),
            "--replay-cap" => a.replay_cap = next_value(&raw, &mut i, "--replay-cap").parse().expect("--replay-cap: integer"),
            "--lr" => a.lr = next_value(&raw, &mut i, "--lr").parse().expect("--lr: float"),
            "--lr-min" => a.lr_min = next_value(&raw, &mut i, "--lr-min").parse().expect("--lr-min: float"),
            "--c-puct" => a.c_puct = next_value(&raw, &mut i, "--c-puct").parse().expect("--c-puct: float"),
            "--dirichlet-alpha" => a.dirichlet_alpha = next_value(&raw, &mut i, "--dirichlet-alpha").parse().expect("--dirichlet-alpha: float"),
            "--dirichlet-eps" => a.dirichlet_eps = next_value(&raw, &mut i, "--dirichlet-eps").parse().expect("--dirichlet-eps: float"),
            "--temp-drop" => a.temp_drop = next_value(&raw, &mut i, "--temp-drop").parse().expect("--temp-drop: integer"),
            "--save-every" => a.save_every = next_value(&raw, &mut i, "--save-every").parse().expect("--save-every: integer"),
            "--seed" => a.seed = next_value(&raw, &mut i, "--seed").parse().expect("--seed: integer"),
            "--resume" => a.resume = Some(PathBuf::from(next_value(&raw, &mut i, "--resume"))),
            other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
        }
        i += 1;
    }
    a
}
// ── Training loop ─────────────────────────────────────────────────────────────
/// Generic training loop, parameterised over the network type.
///
/// Each iteration: (1) clone the model onto the inference backend, (2) run
/// `args.n_games` self-play games in parallel with rayon, (3) push the
/// resulting samples into the replay buffer, (4) run `args.n_train` gradient
/// steps with a cosine-annealed learning rate, and (5) periodically save a
/// checkpoint via `save_fn`.
///
/// `save_fn` receives the **training-backend** model and the target path;
/// it is called in the match arm where the concrete network type is known.
fn train_loop<N>(
    mut model: N,
    save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>,
    args: &Args,
)
where
    N: PolicyValueNet<TrainB> + AutodiffModule<TrainB> + Clone,
    <N as AutodiffModule<TrainB>>::InnerModule: PolicyValueNet<InferB> + Send + 'static,
{
    let train_device: <TrainB as Backend>::Device = Default::default();
    let infer_device: <InferB as Backend>::Device = Default::default();
    // Type is inferred as OptimizerAdaptor<Adam, N, TrainB> at the call site.
    let mut optimizer = AdamConfig::new().init();
    let mut replay = ReplayBuffer::new(args.replay_cap);
    let mut rng = SmallRng::seed_from_u64(args.seed);
    let env = TrictracEnv;
    // Total gradient steps (used for cosine LR denominator).
    let total_train_steps = (args.n_iter * args.n_train).max(1);
    let mut global_step = 0usize;
    println!(
        "\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}",
        "", args.arch, args.n_iter, args.n_games, args.n_sim, ""
    );
    for iter in 0..args.n_iter {
        let t0 = Instant::now();
        // ── Self-play ────────────────────────────────────────────────────
        // Convert to inference backend (zero autodiff overhead).
        let infer_model: <N as AutodiffModule<TrainB>>::InnerModule = model.valid();
        let evaluator: BurnEvaluator<InferB, <N as AutodiffModule<TrainB>>::InnerModule> =
            BurnEvaluator::new(infer_model, infer_device.clone());
        let mcts_cfg = MctsConfig {
            n_simulations: args.n_sim,
            c_puct: args.c_puct,
            dirichlet_alpha: args.dirichlet_alpha,
            dirichlet_eps: args.dirichlet_eps,
            temperature: 1.0,
        };
        let temp_drop = args.temp_drop;
        // Sample proportionally to MCTS visit counts early on, then greedily.
        let temperature_fn = |step: usize| -> f32 {
            if step < temp_drop { 1.0 } else { 0.0 }
        };
        // Prepare per-game seeds and evaluators sequentially so the main RNG
        // and model cloning stay deterministic regardless of thread scheduling.
        // Burn modules are Send but not Sync, so each task must own its model.
        let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
        let game_evals: Vec<_> = (0..args.n_games)
            .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
            .collect();
        drop(evaluator);
        // Each game runs on its own rayon worker with its own RNG + model.
        let all_samples: Vec<Vec<TrainSample>> = game_seeds
            .into_par_iter()
            .zip(game_evals.into_par_iter())
            .map(|(seed, game_eval)| {
                let mut game_rng = SmallRng::seed_from_u64(seed);
                generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
            })
            .collect();
        let mut new_samples = 0usize;
        for samples in all_samples {
            new_samples += samples.len();
            replay.extend(samples);
        }
        // ── Training ─────────────────────────────────────────────────────
        let mut loss_sum = 0.0f32;
        let mut n_steps = 0usize;
        // Skip training until the buffer can fill at least one mini-batch.
        if replay.len() >= args.batch_size {
            for _ in 0..args.n_train {
                let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps);
                let batch: Vec<TrainSample> = replay
                    .sample_batch(args.batch_size, &mut rng)
                    .into_iter()
                    .cloned()
                    .collect();
                let (m, loss) =
                    train_step(model, &mut optimizer, &batch, &train_device, lr);
                model = m;
                loss_sum += loss;
                n_steps += 1;
                global_step += 1;
            }
        }
        // ── Logging ──────────────────────────────────────────────────────
        let elapsed = t0.elapsed();
        // NaN loss marks "no training this iteration" in the log line.
        let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };
        let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps);
        println!(
            "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s",
            iter + 1,
            args.n_iter,
            replay.len(),
            new_samples,
            avg_loss,
            lr_now,
            elapsed.as_secs_f32(),
        );
        // ── Checkpoint ───────────────────────────────────────────────────
        // NOTE(review): `--save-every 0` would panic on the modulo below —
        // confirm whether a guard is wanted at parse time.
        let is_last = iter + 1 == args.n_iter;
        if (iter + 1) % args.save_every == 0 || is_last {
            let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1));
            match save_fn(&model, &path) {
                Ok(()) => println!(" -> saved {}", path.display()),
                Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
            }
        }
    }
    println!("\nTraining complete.");
}
// ── Main ──────────────────────────────────────────────────────────────────────
/// Entry point: parse CLI args, build (or resume) the requested network on
/// the training backend, and run the generic training loop with a
/// type-specific checkpoint saver.
fn main() {
    let args = parse_args();
    // Create output directory if it doesn't exist.
    if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
        eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
        std::process::exit(1);
    }
    let train_device: <TrainB as Backend>::Device = Default::default();
    match args.arch.as_str() {
        "resnet" => {
            let hidden = args.hidden.unwrap_or(512);
            let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
            let model = match &args.resume {
                Some(path) => {
                    println!("Resuming from {}", path.display());
                    ResNet::<TrainB>::load(&cfg, path, &train_device)
                        .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
                }
                None => ResNet::<TrainB>::new(&cfg, &train_device),
            };
            train_loop(
                model,
                &|m: &ResNet<TrainB>, path: &Path| {
                    // Save via inference model to avoid autodiff record overhead.
                    m.valid().save(path)
                },
                &args,
            );
        }
        other => {
            // The original `"mlp" | _` arm silently trained an MLP for any
            // unrecognised value (e.g. `--arch resent`); warn so typos are
            // visible, then keep the same fallback behaviour.
            if other != "mlp" {
                eprintln!("Warning: unknown --arch '{other}', defaulting to mlp");
            }
            let hidden = args.hidden.unwrap_or(256);
            let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
            let model = match &args.resume {
                Some(path) => {
                    println!("Resuming from {}", path.display());
                    MlpNet::<TrainB>::load(&cfg, path, &train_device)
                        .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
                }
                None => MlpNet::<TrainB>::new(&cfg, &train_device),
            };
            train_loop(
                model,
                &|m: &MlpNet<TrainB>, path: &Path| m.valid().save(path),
                &args,
            );
        }
    }
}

View file

@ -0,0 +1,251 @@
//! DQN self-play training loop.
//!
//! # Usage
//!
//! ```sh
//! # Start fresh with default settings
//! cargo run -p spiel_bot --bin dqn_train --release
//!
//! # Custom hyperparameters
//! cargo run -p spiel_bot --bin dqn_train --release -- \
//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000
//!
//! # Resume from a checkpoint
//! cargo run -p spiel_bot --bin dqn_train --release -- \
//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--hidden N` | 256 | Hidden layer width |
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
//! | `--n-iter N` | 500 | Training iterations |
//! | `--n-games N` | 10 | Self-play games per iteration |
//! | `--n-train N` | 20 | Gradient steps per iteration |
//! | `--batch N` | 64 | Mini-batch size |
//! | `--replay-cap N` | 50000 | Replay buffer capacity |
//! | `--lr F` | 1e-3 | Adam learning rate |
//! | `--epsilon-start F` | 1.0 | Initial exploration rate |
//! | `--epsilon-end F` | 0.05 | Final exploration rate |
//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor |
//! | `--gamma F` | 0.99 | Discount factor |
//! | `--target-update N` | 500 | Hard-update target net every N steps |
//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) |
//! | `--save-every N` | 10 | Save checkpoint every N iterations |
//! | `--seed N` | 42 | RNG seed |
//! | `--resume PATH` | (none) | Load weights before training |
use std::path::{Path, PathBuf};
use std::time::Instant;
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{SeedableRng, rngs::SmallRng};
use spiel_bot::{
dqn::{
DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step,
generate_dqn_episode, hard_update, linear_epsilon,
},
env::TrictracEnv,
network::{QNet, QNetConfig},
};
type TrainB = Autodiff<NdArray<f32>>;
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
struct Args {
    /// Hidden layer width of the Q-network.
    hidden: usize,
    /// Directory where checkpoint files are written.
    out_dir: PathBuf,
    /// Save a checkpoint every N iterations (and always on the last one).
    save_every: usize,
    /// RNG seed for self-play and batch sampling.
    seed: u64,
    /// Optional checkpoint to load weights from before training.
    resume: Option<PathBuf>,
    /// DQN hyperparameters (each overridable via its CLI flag).
    config: DqnConfig,
}
impl Default for Args {
fn default() -> Self {
Self {
hidden: 256,
out_dir: PathBuf::from("checkpoints"),
save_every: 10,
seed: 42,
resume: None,
config: DqnConfig::default(),
}
}
}
/// Parse command-line arguments (see the module docs for the option table).
///
/// Unknown flags abort with an error. A flag given without its value also
/// aborts with a clear message; the original indexed `raw[i]` past the end
/// and panicked with an opaque index-out-of-bounds.
fn parse_args() -> Args {
    // Fetch the token following `flag`, exiting cleanly when it is missing.
    fn next_value<'a>(raw: &'a [String], i: &mut usize, flag: &str) -> &'a str {
        *i += 1;
        match raw.get(*i) {
            Some(v) => v,
            None => {
                eprintln!("Missing value for {flag}");
                std::process::exit(1);
            }
        }
    }
    let raw: Vec<String> = std::env::args().collect();
    let mut a = Args::default();
    let mut i = 1;
    while i < raw.len() {
        match raw[i].as_str() {
            "--hidden" => a.hidden = next_value(&raw, &mut i, "--hidden").parse().expect("--hidden: integer"),
            "--out" => a.out_dir = PathBuf::from(next_value(&raw, &mut i, "--out")),
            "--n-iter" => a.config.n_iterations = next_value(&raw, &mut i, "--n-iter").parse().expect("--n-iter: integer"),
            "--n-games" => a.config.n_games_per_iter = next_value(&raw, &mut i, "--n-games").parse().expect("--n-games: integer"),
            "--n-train" => a.config.n_train_steps_per_iter = next_value(&raw, &mut i, "--n-train").parse().expect("--n-train: integer"),
            "--batch" => a.config.batch_size = next_value(&raw, &mut i, "--batch").parse().expect("--batch: integer"),
            "--replay-cap" => a.config.replay_capacity = next_value(&raw, &mut i, "--replay-cap").parse().expect("--replay-cap: integer"),
            "--lr" => a.config.learning_rate = next_value(&raw, &mut i, "--lr").parse().expect("--lr: float"),
            "--epsilon-start" => a.config.epsilon_start = next_value(&raw, &mut i, "--epsilon-start").parse().expect("--epsilon-start: float"),
            "--epsilon-end" => a.config.epsilon_end = next_value(&raw, &mut i, "--epsilon-end").parse().expect("--epsilon-end: float"),
            "--epsilon-decay" => a.config.epsilon_decay_steps = next_value(&raw, &mut i, "--epsilon-decay").parse().expect("--epsilon-decay: integer"),
            "--gamma" => a.config.gamma = next_value(&raw, &mut i, "--gamma").parse().expect("--gamma: float"),
            "--target-update" => a.config.target_update_freq = next_value(&raw, &mut i, "--target-update").parse().expect("--target-update: integer"),
            "--reward-scale" => a.config.reward_scale = next_value(&raw, &mut i, "--reward-scale").parse().expect("--reward-scale: float"),
            "--save-every" => a.save_every = next_value(&raw, &mut i, "--save-every").parse().expect("--save-every: integer"),
            "--seed" => a.seed = next_value(&raw, &mut i, "--seed").parse().expect("--seed: integer"),
            "--resume" => a.resume = Some(PathBuf::from(next_value(&raw, &mut i, "--resume"))),
            other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
        }
        i += 1;
    }
    a
}
// ── Training loop ─────────────────────────────────────────────────────────────
/// DQN training loop: alternate ε-greedy self-play with TD(0) gradient steps.
///
/// The target network is hard-updated from the online net every
/// `target_update_freq` gradient steps, and ε decays per gradient step —
/// so it does not move until the replay buffer can fill a batch.
fn train_loop(
    mut q_net: QNet<TrainB>,
    cfg: &QNetConfig,
    save_fn: &dyn Fn(&QNet<TrainB>, &Path) -> anyhow::Result<()>,
    args: &Args,
) {
    let train_device: <TrainB as Backend>::Device = Default::default();
    let infer_device: <InferB as Backend>::Device = Default::default();
    let mut optimizer = AdamConfig::new().init();
    let mut replay = DqnReplayBuffer::new(args.config.replay_capacity);
    let mut rng = SmallRng::seed_from_u64(args.seed);
    let env = TrictracEnv;
    // Frozen copy of the online net used for TD targets.
    let mut target_net: QNet<InferB> = hard_update::<TrainB, _>(&q_net);
    let mut global_step = 0usize;
    let mut epsilon = args.config.epsilon_start;
    println!(
        "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}",
        "", args.config.n_iterations, args.config.n_games_per_iter,
        args.config.n_train_steps_per_iter, ""
    );
    for iter in 0..args.config.n_iterations {
        let t0 = Instant::now();
        // ── Self-play ────────────────────────────────────────────────────
        // Episodes run on the inference backend (no autodiff overhead).
        let infer_q: QNet<InferB> = q_net.valid();
        let mut new_samples = 0usize;
        for _ in 0..args.config.n_games_per_iter {
            let samples = generate_dqn_episode(
                &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale,
            );
            new_samples += samples.len();
            replay.extend(samples);
        }
        // ── Training ─────────────────────────────────────────────────────
        let mut loss_sum = 0.0f32;
        let mut n_steps = 0usize;
        // Skip training until the buffer can fill at least one mini-batch.
        if replay.len() >= args.config.batch_size {
            for _ in 0..args.config.n_train_steps_per_iter {
                let batch: Vec<_> = replay
                    .sample_batch(args.config.batch_size, &mut rng)
                    .into_iter()
                    .cloned()
                    .collect();
                // Target Q-values computed on the inference backend.
                let target_q = compute_target_q(
                    &target_net, &batch, cfg.action_size, &infer_device,
                );
                let (q, loss) = dqn_train_step(
                    q_net, &mut optimizer, &batch, &target_q,
                    &train_device, args.config.learning_rate, args.config.gamma,
                );
                q_net = q;
                loss_sum += loss;
                n_steps += 1;
                global_step += 1;
                // Hard-update target net every target_update_freq steps.
                if global_step % args.config.target_update_freq == 0 {
                    target_net = hard_update::<TrainB, _>(&q_net);
                }
                // Linear epsilon decay.
                epsilon = linear_epsilon(
                    args.config.epsilon_start,
                    args.config.epsilon_end,
                    global_step,
                    args.config.epsilon_decay_steps,
                );
            }
        }
        // ── Logging ──────────────────────────────────────────────────────
        let elapsed = t0.elapsed();
        // NaN loss marks "no training this iteration" in the log line.
        let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };
        println!(
            "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s",
            iter + 1,
            args.config.n_iterations,
            replay.len(),
            new_samples,
            avg_loss,
            epsilon,
            elapsed.as_secs_f32(),
        );
        // ── Checkpoint ───────────────────────────────────────────────────
        // NOTE(review): `--save-every 0` would panic on the modulo below —
        // confirm whether a guard is wanted at parse time.
        let is_last = iter + 1 == args.config.n_iterations;
        if (iter + 1) % args.save_every == 0 || is_last {
            let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1));
            match save_fn(&q_net, &path) {
                Ok(()) => println!(" -> saved {}", path.display()),
                Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
            }
        }
    }
    println!("\nDQN training complete.");
}
// ── Main ──────────────────────────────────────────────────────────────────────
fn main() {
let args = parse_args();
if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
std::process::exit(1);
}
let train_device: <TrainB as Backend>::Device = Default::default();
let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden };
let q_net = match &args.resume {
Some(path) => {
println!("Resuming from {}", path.display());
QNet::<TrainB>::load(&cfg, path, &train_device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
}
None => QNet::<TrainB>::new(&cfg, &train_device),
};
train_loop(q_net, &cfg, &|m: &QNet<TrainB>, path| m.valid().save(path), &args);
}

View file

@ -0,0 +1,247 @@
//! DQN self-play episode generation.
//!
//! Both players share the same Q-network (the [`TrictracEnv`] handles board
//! mirroring so that each player always acts from "White's perspective").
//! Transitions for both players are stored in the returned sample list.
//!
//! # Reward
//!
//! After each full decision (action applied and the state has advanced through
//! any intervening chance nodes back to the same player's next turn), the
//! reward is:
//!
//! ```text
//! r = (my_total_score_now - my_total_score_then)
//!     - (opp_total_score_now - opp_total_score_then)
//! ```
//!
//! where `total_score = holes × 12 + points`.
//!
//! # Transition structure
//!
//! We use a "pending transition" per player. When a player acts again, we
//! *complete* the previous pending transition by filling in `next_obs`,
//! `next_legal`, and computing `reward`. Terminal transitions are completed
//! when the game ends.
use burn::tensor::{backend::Backend, Tensor, TensorData};
use rand::Rng;
use crate::env::{GameEnv, TrictracEnv};
use crate::network::QValueNet;
use super::DqnSample;
// ── Internals ─────────────────────────────────────────────────────────────────
/// A transition awaiting completion: the acting player's observation and
/// action are held here until that player acts again (or the game ends),
/// at which point `next_obs`/`next_legal`/`reward` can be filled in.
struct PendingTransition {
    // Observation at the moment of the action.
    obs: Vec<f32>,
    // Action index the player took.
    action: usize,
    /// Score snapshot `[p1_total, p2_total]` at the moment of the action.
    score_before: [i32; 2],
}
/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise.
///
/// The greedy branch runs one batch-of-1 forward pass and returns the legal
/// action with the highest Q-value. Ties — and NaN comparisons, which are
/// treated as equal — resolve to the last such action, per `Iterator::max_by`.
fn epsilon_greedy<B: Backend, Q: QValueNet<B>>(
    q_net: &Q,
    obs: &[f32],
    legal: &[usize],
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
) -> usize {
    debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions");
    if rng.random::<f32>() < epsilon {
        // Exploration: uniform choice among the legal actions.
        legal[rng.random_range(0..legal.len())]
    } else {
        // Exploitation: argmax over legal actions of the Q-net's output.
        let obs_tensor = Tensor::<B, 2>::from_data(
            TensorData::new(obs.to_vec(), [1, obs.len()]),
            device,
        );
        let q_values: Vec<f32> = q_net.forward(obs_tensor).into_data().to_vec().unwrap();
        legal
            .iter()
            .copied()
            .max_by(|&a, &b| {
                q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap()
    }
}
/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after.
///
/// Zero-sum shaping: the player's own score gain minus the opponent's gain.
fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 {
    let opp = 1 - player_idx;
    let my_gain = score_after[player_idx] - score_before[player_idx];
    let opp_gain = score_after[opp] - score_before[opp];
    (my_gain - opp_gain) as f32
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Play one full game and return all transitions for both players.
///
/// - `q_net` uses the **inference backend** (no autodiff wrapper).
/// - `epsilon` in `[0, 1]`: probability of taking a random action.
/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`).
pub fn generate_dqn_episode<B: Backend, Q: QValueNet<B>>(
    env: &TrictracEnv,
    q_net: &Q,
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
    reward_scale: f32,
) -> Vec<DqnSample> {
    let obs_size = env.obs_size();
    let mut state = env.new_game();
    // One pending (incomplete) transition per player, completed when that
    // player acts again or when the game ends.
    let mut pending: [Option<PendingTransition>; 2] = [None, None];
    let mut samples: Vec<DqnSample> = Vec::new();
    loop {
        // ── Advance past chance nodes ──────────────────────────────────────
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, rng);
        }
        // Snapshot scores after chance resolution, before the next decision.
        let score_now = TrictracEnv::score_snapshot(&state);
        if env.current_player(&state).is_terminal() {
            // Complete all pending transitions as terminal.
            for player_idx in 0..2 {
                if let Some(prev) = pending[player_idx].take() {
                    let reward =
                        compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
                    samples.push(DqnSample {
                        obs: prev.obs,
                        action: prev.action,
                        reward,
                        // Terminal next state: zero observation, no legal moves
                        // (the TD target drops the bootstrap term when done).
                        next_obs: vec![0.0; obs_size],
                        next_legal: vec![],
                        done: true,
                    });
                }
            }
            break;
        }
        let player_idx = env.current_player(&state).index().unwrap();
        let legal = env.legal_actions(&state);
        let obs = env.observation(&state, player_idx);
        // ── Complete the previous transition for this player ───────────────
        if let Some(prev) = pending[player_idx].take() {
            let reward =
                compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
            samples.push(DqnSample {
                obs: prev.obs,
                action: prev.action,
                reward,
                next_obs: obs.clone(),
                next_legal: legal.clone(),
                done: false,
            });
        }
        // ── Pick and apply action ──────────────────────────────────────────
        let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device);
        env.apply(&mut state, action);
        // ── Record new pending transition ──────────────────────────────────
        pending[player_idx] = Some(PendingTransition {
            obs,
            action,
            score_before: score_now,
        });
    }
    samples
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use rand::{SeedableRng, rngs::SmallRng};
    use crate::network::{QNet, QNetConfig};
    type B = NdArray<f32>;
    // Shared fixtures: default device, fixed-seed RNG, untrained Q-net.
    fn device() -> <B as Backend>::Device { Default::default() }
    fn rng() -> SmallRng { SmallRng::seed_from_u64(7) }
    fn tiny_q() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }
    #[test]
    fn episode_terminates_and_produces_samples() {
        let env = TrictracEnv;
        let q = tiny_q();
        // ε = 1.0 → fully random play, independent of the untrained weights.
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        assert!(!samples.is_empty(), "episode must produce at least one sample");
    }
    #[test]
    fn episode_obs_size_correct() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        for s in &samples {
            assert_eq!(s.obs.len(), 217, "obs size mismatch");
            if s.done {
                // Terminal transitions carry a zero observation and no moves.
                assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size");
                assert!(s.next_legal.is_empty());
            } else {
                assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch");
                assert!(!s.next_legal.is_empty());
            }
        }
    }
    #[test]
    fn episode_actions_within_action_space() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        for s in &samples {
            assert!(s.action < 514, "action {} out of bounds", s.action);
        }
    }
    #[test]
    fn greedy_episode_also_terminates() {
        let env = TrictracEnv;
        let q = tiny_q();
        // ε = 0.0 → always the argmax branch; exercises the forward pass.
        let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0);
        assert!(!samples.is_empty());
    }
    #[test]
    fn at_least_one_done_sample() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        let n_done = samples.iter().filter(|s| s.done).count();
        // Two players, so 1 or 2 terminal transitions.
        assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}");
    }
    #[test]
    fn compute_reward_correct() {
        // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged.
        let before = [2 * 12 + 10, 0];
        let after = [3 * 12 + 2, 0];
        let r = compute_reward(0, &before, &after);
        assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}");
    }
    #[test]
    fn compute_reward_with_opponent_scoring() {
        // P1 gains 2, opp gains 3 → net = -1 from P1's perspective.
        let before = [0, 0];
        let after = [2, 3];
        let r = compute_reward(0, &before, &after);
        assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}");
    }
}

232
spiel_bot/src/dqn/mod.rs Normal file
View file

@ -0,0 +1,232 @@
//! DQN: self-play data generation, replay buffer, and training step.
//!
//! # Algorithm
//!
//! Deep Q-Network with:
//! - **ε-greedy** exploration (linearly decayed).
//! - **Dense per-turn rewards**: `my_score_delta - opponent_score_delta` where
//! `score = holes × 12 + points`.
//! - **Experience replay** with a fixed-capacity circular buffer.
//! - **Target network**: hard-copied from the online Q-net every
//! `target_update_freq` gradient steps for training stability.
//!
//! # Modules
//!
//! | Module | Contents |
//! |--------|----------|
//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] |
//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] |
pub mod episode;
pub mod trainer;
pub use episode::generate_dqn_episode;
pub use trainer::{compute_target_q, dqn_train_step, hard_update};
use std::collections::VecDeque;
use rand::Rng;
// ── DqnSample ─────────────────────────────────────────────────────────────────
/// One transition `(s, a, r, s', done)` collected during self-play.
#[derive(Clone, Debug)]
pub struct DqnSample {
    /// Observation from the acting player's perspective (`obs_size` floats).
    pub obs: Vec<f32>,
    /// Action index taken.
    pub action: usize,
    /// Per-turn reward: `my_score_delta - opponent_score_delta`.
    pub reward: f32,
    /// Next observation from the same player's perspective.
    /// All-zeros when `done = true` (ignored by the TD target).
    pub next_obs: Vec<f32>,
    /// Legal actions at `next_obs`. Empty when `done = true`.
    pub next_legal: Vec<usize>,
    /// `true` when `next_obs` is a terminal state.
    pub done: bool,
}
// ── DqnReplayBuffer ───────────────────────────────────────────────────────────
/// Fixed-capacity circular replay buffer for [`DqnSample`]s.
///
/// When full, the oldest sample is evicted on push.
/// Batches are drawn without replacement via a partial Fisher-Yates shuffle.
pub struct DqnReplayBuffer {
    // Samples in insertion order; front = oldest (eviction side).
    data: VecDeque<DqnSample>,
    // Maximum number of retained samples; `push` evicts beyond this.
    capacity: usize,
}
impl DqnReplayBuffer {
    /// Create a buffer holding at most `capacity` samples.
    pub fn new(capacity: usize) -> Self {
        // Pre-allocate at most 1024 slots up front; larger buffers grow lazily.
        let data = VecDeque::with_capacity(capacity.min(1024));
        Self { data, capacity }
    }
    /// Append `sample`, evicting the oldest entry when at capacity.
    pub fn push(&mut self, sample: DqnSample) {
        if self.data.len() == self.capacity {
            self.data.pop_front();
        }
        self.data.push_back(sample);
    }
    /// Push every sample from `samples`, in order.
    pub fn extend(&mut self, samples: impl IntoIterator<Item = DqnSample>) {
        samples.into_iter().for_each(|s| self.push(s));
    }
    pub fn len(&self) -> usize { self.data.len() }
    pub fn is_empty(&self) -> bool { self.data.is_empty() }
    /// Sample up to `n` distinct samples without replacement
    /// (partial Fisher-Yates over an index vector).
    pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> {
        let len = self.data.len();
        let take = n.min(len);
        let mut order: Vec<usize> = (0..len).collect();
        for slot in 0..take {
            let pick = rng.random_range(slot..len);
            order.swap(slot, pick);
        }
        order[..take].iter().map(|&idx| &self.data[idx]).collect()
    }
}
// ── DqnConfig ─────────────────────────────────────────────────────────────────
/// Top-level DQN hyperparameters for the training loop.
///
/// Fields map one-to-one onto the `dqn_train` CLI flags.
#[derive(Debug, Clone)]
pub struct DqnConfig {
    /// Initial exploration rate (1.0 = fully random).
    pub epsilon_start: f32,
    /// Final exploration rate after decay.
    pub epsilon_end: f32,
    /// Number of gradient steps over which ε decays linearly from start to end.
    ///
    /// Should be calibrated to the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`). A value larger than that
    /// means exploration never reaches `epsilon_end` during the run.
    pub epsilon_decay_steps: usize,
    /// Discount factor γ for the TD target. Typical: 0.99.
    pub gamma: f32,
    /// Hard-copy Q → target every this many gradient steps.
    ///
    /// Should be much smaller than the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`).
    pub target_update_freq: usize,
    /// Adam learning rate.
    pub learning_rate: f64,
    /// Mini-batch size for each gradient step.
    pub batch_size: usize,
    /// Maximum number of samples in the replay buffer.
    pub replay_capacity: usize,
    /// Number of outer iterations (self-play + train).
    pub n_iterations: usize,
    /// Self-play games per iteration.
    pub n_games_per_iter: usize,
    /// Gradient steps per iteration.
    pub n_train_steps_per_iter: usize,
    /// Reward normalisation divisor.
    ///
    /// Per-turn rewards (score delta) are divided by this constant before being
    /// stored. Without normalisation, rewards can reach ±24 (jan with
    /// bredouille = 12 pts × 2), driving Q-values into the hundreds and
    /// causing MSE loss to grow unboundedly.
    ///
    /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping
    /// Q-value magnitudes in a stable range. Set to `1.0` to disable.
    pub reward_scale: f32,
}
impl Default for DqnConfig {
fn default() -> Self {
// Total gradient steps with these defaults = 500 × 20 = 10_000,
// so epsilon decays fully and the target is updated 100 times.
Self {
epsilon_start: 1.0,
epsilon_end: 0.05,
epsilon_decay_steps: 10_000,
gamma: 0.99,
target_update_freq: 100,
learning_rate: 1e-3,
batch_size: 64,
replay_capacity: 50_000,
n_iterations: 500,
n_games_per_iter: 10,
n_train_steps_per_iter: 20,
reward_scale: 12.0,
}
}
}
/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps.
///
/// Once `step` reaches `decay_steps` (or when `decay_steps == 0`), the floor
/// value `end` is returned unconditionally.
pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 {
    if decay_steps == 0 || step >= decay_steps {
        end
    } else {
        let frac = step as f32 / decay_steps as f32;
        start + (end - start) * frac
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};
    // Minimal one-float sample; `reward` doubles as an identity tag so that
    // eviction order can be asserted.
    fn dummy(reward: f32) -> DqnSample {
        DqnSample {
            obs: vec![0.0],
            action: 0,
            reward,
            next_obs: vec![0.0],
            next_legal: vec![0],
            done: false,
        }
    }
    #[test]
    fn push_and_len() {
        let mut buf = DqnReplayBuffer::new(10);
        assert!(buf.is_empty());
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        assert_eq!(buf.len(), 2);
    }
    #[test]
    fn evicts_oldest_at_capacity() {
        let mut buf = DqnReplayBuffer::new(3);
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        buf.push(dummy(3.0));
        buf.push(dummy(4.0));
        assert_eq!(buf.len(), 3);
        // Sample 1.0 was evicted; the front is now the second-oldest push.
        assert_eq!(buf.data[0].reward, 2.0);
    }
    #[test]
    fn sample_batch_size() {
        let mut buf = DqnReplayBuffer::new(20);
        for i in 0..10 { buf.push(dummy(i as f32)); }
        let mut rng = SmallRng::seed_from_u64(0);
        assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
    }
    #[test]
    fn linear_epsilon_start() {
        assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6);
    }
    #[test]
    fn linear_epsilon_end() {
        assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6);
    }
    #[test]
    fn linear_epsilon_monotone() {
        // ε must never increase as training progresses.
        let mut prev = f32::INFINITY;
        for step in 0..=100 {
            let e = linear_epsilon(1.0, 0.05, step, 100);
            assert!(e <= prev + 1e-6);
            prev = e;
        }
    }
}

View file

@ -0,0 +1,278 @@
//! DQN gradient step and target-network management.
//!
//! # TD target
//!
//! ```text
//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done
//! y_i = r_i if done
//! ```
//!
//! # Loss
//!
//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
//! and `y_i` (computed from the frozen target net).
//!
//! # Target network
//!
//! [`hard_update`] copies the online Q-net weights into the target net by
//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
use burn::{
module::AutodiffModule,
optim::{GradientsParams, Optimizer},
prelude::ElementConversion,
tensor::{
Int, Tensor, TensorData,
backend::{AutodiffBackend, Backend},
},
};
use crate::network::QValueNet;
use super::DqnSample;
// ── Target Q computation ─────────────────────────────────────────────────────
/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
///
/// Returns a `Vec<f32>` of length `batch.len()`; entries for `done` samples
/// are `0.0` (their bootstrap term is dropped by the TD target anyway).
///
/// The target network runs on the **inference backend** (`InferB`) with no
/// gradient tape, so this function is backend-agnostic (`B: Backend`).
pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
    target_net: &Q,
    batch: &[DqnSample],
    action_size: usize,
    device: &B::Device,
) -> Vec<f32> {
    let n = batch.len();
    // Indices of samples that still have a successor state to bootstrap from.
    let live: Vec<usize> = batch
        .iter()
        .enumerate()
        .filter_map(|(i, s)| if s.done { None } else { Some(i) })
        .collect();
    if live.is_empty() {
        return vec![0.0; n];
    }
    let obs_size = batch[0].next_obs.len();
    // Stack the live next-observations into a single [live.len(), obs_size] batch.
    let stacked: Vec<f32> = live
        .iter()
        .flat_map(|&i| batch[i].next_obs.iter().copied())
        .collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(stacked, [live.len(), obs_size]),
        device,
    );
    // One forward pass → flat [live.len() * action_size] Q-value vector.
    let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();
    let mut out = vec![0.0f32; n];
    for (row, &i) in live.iter().enumerate() {
        let base = row * action_size;
        let best = batch[i]
            .next_legal
            .iter()
            .map(|&a| q_flat[base + a])
            .fold(f32::NEG_INFINITY, f32::max);
        // An (unexpected) empty legal set leaves NEG_INFINITY — fall back to 0.
        out[i] = if best.is_finite() { best } else { 0.0 };
    }
    out
}
// ── Training step ─────────────────────────────────────────────────────────────

/// Run one gradient step on `q_net` using `batch`.
///
/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
/// frozen target network and passed in here so that this function only
/// needs the **autodiff backend**.
///
/// Returns the updated network and the scalar MSE loss.
///
/// # Panics
///
/// Panics if `batch` is empty or if `batch` and `target_max_q` differ in
/// length.
pub fn dqn_train_step<B, Q, O>(
    q_net: Q,
    optimizer: &mut O,
    batch: &[DqnSample],
    target_max_q: &[f32],
    device: &B::Device,
    lr: f64,
    gamma: f32,
) -> (Q, f32)
where
    B: AutodiffBackend,
    Q: QValueNet<B> + AutodiffModule<B>,
    O: Optimizer<Q, B>,
{
    assert!(!batch.is_empty(), "dqn_train_step: empty batch");
    assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");
    let batch_size = batch.len();
    let obs_size = batch[0].obs.len();
    // ── Build observation tensor [B, obs_size] ────────────────────────────
    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [batch_size, obs_size]),
        device,
    );
    // ── Forward Q-net → [B, action_size] ─────────────────────────────────
    let q_all = q_net.forward(obs_tensor);
    // ── Gather Q(s, a) for the taken action → [B] ────────────────────────
    let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
    let action_tensor: Tensor<B, 2, Int> = Tensor::<B, 1, Int>::from_data(
        TensorData::new(actions, [batch_size]),
        device,
    )
    .reshape([batch_size, 1]); // [B] → [B, 1]
    let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]
    // ── TD targets: r + γ · max_next_q · (1 − done) ───────────────────────
    let targets: Vec<f32> = batch
        .iter()
        .zip(target_max_q.iter())
        .map(|(s, &max_q)| {
            // Done transitions have no bootstrap term.
            if s.done { s.reward } else { s.reward + gamma * max_q }
        })
        .collect();
    let target_tensor = Tensor::<B, 1>::from_data(
        TensorData::new(targets, [batch_size]),
        device,
    );
    // ── MSE loss ──────────────────────────────────────────────────────────
    // detach() keeps the target out of the autodiff graph: gradients must
    // flow only through q_pred (the online network).
    let diff = q_pred - target_tensor.detach();
    let loss = (diff.clone() * diff).mean();
    let loss_scalar: f32 = loss.clone().into_scalar().elem();
    // ── Backward + optimizer step ─────────────────────────────────────────
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &q_net);
    let q_net = optimizer.step(lr, q_net, grads);
    (q_net, loss_scalar)
}
// ── Target network update ─────────────────────────────────────────────────────
/// Hard-copy the online Q-net weights to a new target network.
///
/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
/// inference-backend module with identical weights.
pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
q_net.valid()
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::AdamConfig,
    };
    use crate::network::{QNet, QNetConfig};

    type InferB = NdArray<f32>;
    type TrainB = Autodiff<NdArray<f32>>;

    fn infer_device() -> <InferB as Backend>::Device { Default::default() }
    fn train_device() -> <TrainB as Backend>::Device { Default::default() }

    /// Build `n` synthetic samples; only the last one is marked `done`.
    /// Requires `action_size >= 2` (next_legal is hard-coded to {0, 1}).
    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
        (0..n)
            .map(|i| DqnSample {
                obs: vec![0.5f32; obs_size],
                action: i % action_size,
                reward: if i % 2 == 0 { 1.0 } else { -1.0 },
                next_obs: vec![0.5f32; obs_size],
                next_legal: vec![0, 1],
                done: i == n - 1,
            })
            .collect()
    }

    #[test]
    fn compute_target_q_length() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let batch = dummy_batch(8, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        // One target value per sample, whatever the done flags are.
        assert_eq!(tq.len(), 8);
    }

    #[test]
    fn compute_target_q_done_is_zero() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        // Single done sample: its bootstrap term must be exactly 0.0.
        let batch = vec![DqnSample {
            obs: vec![0.0; 4],
            action: 0,
            reward: 5.0,
            next_obs: vec![0.0; 4],
            next_legal: vec![],
            done: true,
        }];
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        assert_eq!(tq.len(), 1);
        assert_eq!(tq[0], 0.0);
    }

    #[test]
    fn train_step_returns_finite_loss() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
        let q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(8, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99);
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
    }

    #[test]
    fn train_step_loss_decreases() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
        let mut q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(16, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        let mut last_loss = f32::INFINITY;
        for _ in 0..10 {
            let (q, loss) = dqn_train_step(
                q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99,
            );
            q_net = q;
            assert!(loss.is_finite());
            last_loss = loss;
        }
        // Ten Adam steps at lr 1e-2 on fixed targets must bring the MSE under
        // a loose bound. (The loop never compared consecutive losses, so the
        // old "loss did not decrease" message described a check that was not
        // performed; the message now states what is actually asserted.)
        assert!(last_loss < 5.0, "final loss too high after 10 steps: {last_loss}");
    }

    #[test]
    fn hard_update_copies_weights() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = hard_update::<TrainB, _>(&q_net);
        let obs = burn::tensor::Tensor::<InferB, 2>::zeros([1, 4], &infer_device());
        let q_out: Vec<f32> = target.forward(obs).into_data().to_vec().unwrap();
        assert!(q_out.iter().all(|v| v.is_finite()));
        // Strengthened check: the target must reproduce the online net's
        // Q-values on the same input, not merely produce finite numbers.
        let obs_ad = burn::tensor::Tensor::<TrainB, 2>::zeros([1, 4], &train_device());
        let q_online: Vec<f32> = q_net.forward(obs_ad).into_data().to_vec().unwrap();
        for (a, b) in q_out.iter().zip(q_online.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "target and online Q-values must match after hard_update: {a} vs {b}"
            );
        }
    }
}

121
spiel_bot/src/env/mod.rs vendored Normal file
View file

@ -0,0 +1,121 @@
//! Game environment abstraction — the minimal "Rust OpenSpiel".
//!
//! A `GameEnv` describes the rules of a two-player, zero-sum game that may
//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN,
//! and PPO interact with a game exclusively through this trait.
//!
//! # Node taxonomy
//!
//! Every game position belongs to one of four categories, returned by
//! [`GameEnv::current_player`]:
//!
//! | [`Player`] | Meaning |
//! |-----------|---------|
//! | `P1` | Player 1 (index 0) must choose an action |
//! | `P2` | Player 2 (index 1) must choose an action |
//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) |
//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful |
//!
//! # Perspective convention
//!
//! [`GameEnv::observation`] always returns the board from *the requested
//! player's* point of view. Callers pass `pov = 0` for Player 1 and
//! `pov = 1` for Player 2. The implementation is responsible for any
//! mirroring required (e.g. Trictrac always reasons from White's side).
pub mod trictrac;
pub use trictrac::TrictracEnv;
/// Who controls the current game node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Player {
    /// Player 1 (index 0) is to move.
    P1,
    /// Player 2 (index 1) is to move.
    P2,
    /// A stochastic event (dice roll, etc.) must be resolved.
    Chance,
    /// The game is over.
    Terminal,
}

impl Player {
    /// Returns the player index (0 or 1) if this is a decision node,
    /// or `None` for `Chance` / `Terminal`.
    pub fn index(self) -> Option<usize> {
        match self {
            Player::P1 => Some(0),
            Player::P2 => Some(1),
            Player::Chance | Player::Terminal => None,
        }
    }

    /// True at `P1` / `P2` nodes, where a player must choose an action.
    pub fn is_decision(self) -> bool {
        self.index().is_some()
    }

    /// True while a stochastic event must be sampled.
    pub fn is_chance(self) -> bool {
        matches!(self, Player::Chance)
    }

    /// True once the game is over.
    pub fn is_terminal(self) -> bool {
        matches!(self, Player::Terminal)
    }
}
/// Trait that completely describes a two-player zero-sum game.
///
/// Implementors must be cheaply cloneable (the type is used as a stateless
/// factory; the mutable game state lives in `Self::State`).
pub trait GameEnv: Clone + Send + Sync + 'static {
    /// The mutable game state. Must be `Clone` so MCTS can copy
    /// game trees without touching the environment.
    type State: Clone + Send + Sync;

    // ── State creation ────────────────────────────────────────────────────

    /// Create a fresh game state at the initial position.
    fn new_game(&self) -> Self::State;

    // ── Node queries ──────────────────────────────────────────────────────

    /// Classify the current node (see the module-level node taxonomy).
    fn current_player(&self, s: &Self::State) -> Player;

    /// Legal action indices at a decision node (`current_player` is `P1`/`P2`).
    ///
    /// The returned indices are in `[0, action_space())`.
    /// The result is unspecified (may panic or return empty) when called at a
    /// `Chance` or `Terminal` node.
    fn legal_actions(&self, s: &Self::State) -> Vec<usize>;

    // ── State mutation ────────────────────────────────────────────────────

    /// Apply a player action. `action` must be a value returned by
    /// [`Self::legal_actions`] for the current state.
    fn apply(&self, s: &mut Self::State, action: usize);

    /// Sample and apply a stochastic outcome. Must only be called when
    /// `current_player(s) == Player::Chance`.
    fn apply_chance<R: rand::Rng>(&self, s: &mut Self::State, rng: &mut R);

    // ── Observation ───────────────────────────────────────────────────────

    /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2).
    /// The returned vector has exactly [`Self::obs_size`] elements, all in `[0, 1]`.
    fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;

    /// Number of floats returned by [`Self::observation`].
    fn obs_size(&self) -> usize;

    /// Total number of distinct action indices (the policy head output size).
    fn action_space(&self) -> usize;

    // ── Terminal values ───────────────────────────────────────────────────

    /// Game outcome for each player, or `None` if the game is not over.
    ///
    /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw.
    /// Index 0 = Player 1, index 1 = Player 2.
    fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
}

547
spiel_bot/src/env/trictrac.rs vendored Normal file
View file

@ -0,0 +1,547 @@
//! [`GameEnv`] implementation for Trictrac.
//!
//! # Game flow (schools_enabled = false)
//!
//! With scoring schools disabled (the standard training configuration),
//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine
//! applies them automatically inside `RollResult` and `Move`. The only
//! four stages that actually occur are:
//!
//! | `TurnStage` | [`Player`] kind | Handled by |
//! |-------------|-----------------|------------|
//! | `RollDice` | `Chance` | [`apply_chance`] |
//! | `RollWaiting` | `Chance` | [`apply_chance`] |
//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] |
//! | `Move` | `P1`/`P2` | [`apply`] |
//!
//! # Perspective
//!
//! The Trictrac engine always reasons from White's perspective. Player 1 is
//! White; Player 2 is Black. When Player 2 is active, the board is mirrored
//! before computing legal actions / the observation tensor, and the resulting
//! event is mirrored back before being applied to the real state. This
//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`.
use trictrac_store::{
training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE},
Dice, GameEvent, GameState, Stage, TurnStage,
};
use super::{GameEnv, Player};
/// Stateless factory that produces Trictrac [`GameState`] environments.
///
/// Schools (`schools_enabled`) are always disabled — scoring is automatic.
///
/// Zero-sized unit struct: all mutable game data lives in the [`GameState`]
/// values it creates, so cloning the env is free.
#[derive(Clone, Debug, Default)]
pub struct TrictracEnv;
impl GameEnv for TrictracEnv {
    type State = GameState;

    // ── State creation ────────────────────────────────────────────────────

    /// Start a fresh two-player game ("P1" = White, "P2" = Black).
    fn new_game(&self) -> GameState {
        GameState::new_with_players("P1", "P2")
    }

    // ── Node queries ──────────────────────────────────────────────────────

    /// Classify the current node: `Terminal` once the game has ended,
    /// `Chance` while dice remain to be rolled, otherwise the active player.
    fn current_player(&self, s: &GameState) -> Player {
        if s.stage == Stage::Ended {
            return Player::Terminal;
        }
        match s.turn_stage {
            // Both roll stages are resolved in a single `apply_chance` call.
            TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance,
            _ => {
                // Engine player ids: 1 = White (P1), otherwise Black (P2).
                if s.active_player_id == 1 {
                    Player::P1
                } else {
                    Player::P2
                }
            }
        }
    }

    /// Returns the legal action indices for the active player.
    ///
    /// The board is automatically mirrored for Player 2 so that the engine
    /// always reasons from White's perspective. The returned indices are
    /// identical in meaning for both players (checker ordinals are
    /// perspective-relative).
    ///
    /// # Panics
    ///
    /// Panics in debug builds if called at a `Chance` or `Terminal` node.
    fn legal_actions(&self, s: &GameState) -> Vec<usize> {
        debug_assert!(
            self.current_player(s).is_decision(),
            "legal_actions called at a non-decision node (turn_stage={:?})",
            s.turn_stage
        );
        let indices = if s.active_player_id == 2 {
            // Black to move: compute actions on the mirrored (White-view) board.
            get_valid_action_indices(&s.mirror())
        } else {
            get_valid_action_indices(s)
        };
        indices.unwrap_or_default()
    }

    // ── State mutation ────────────────────────────────────────────────────

    /// Apply a player action index to the game state.
    ///
    /// For Player 2, the action is decoded against the mirrored board and
    /// the resulting event is un-mirrored before being applied.
    ///
    /// # Panics
    ///
    /// Panics in debug builds if `action` cannot be decoded or does not
    /// produce a valid event for the current state.
    fn apply(&self, s: &mut GameState, action: usize) {
        let needs_mirror = s.active_player_id == 2;
        let event = if needs_mirror {
            // Decode on the White-view board, then mirror the event back
            // onto the real board before consuming it.
            let view = s.mirror();
            TrictracAction::from_action_index(action)
                .and_then(|a| a.to_event(&view))
                .map(|e| e.get_mirror(false))
        } else {
            TrictracAction::from_action_index(action).and_then(|a| a.to_event(s))
        };
        match event {
            Some(e) => {
                s.consume(&e).expect("apply: consume failed for valid action");
            }
            None => {
                panic!("apply: action index {action} produced no event in state {s}");
            }
        }
    }

    /// Sample dice and advance through a chance node.
    ///
    /// Handles both `RollDice` (triggers the roll mechanism, then samples
    /// dice) and `RollWaiting` (only samples dice) in a single call so that
    /// callers never need to distinguish the two.
    ///
    /// # Panics
    ///
    /// Panics in debug builds if called at a non-Chance node.
    fn apply_chance<R: rand::Rng>(&self, s: &mut GameState, rng: &mut R) {
        debug_assert!(
            self.current_player(s).is_chance(),
            "apply_chance called at a non-Chance node (turn_stage={:?})",
            s.turn_stage
        );
        // Step 1: RollDice → RollWaiting (player initiates the roll).
        if s.turn_stage == TurnStage::RollDice {
            s.consume(&GameEvent::Roll {
                player_id: s.active_player_id,
            })
            .expect("apply_chance: Roll event failed");
        }
        // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended.
        // With schools_enabled=false, point marking is automatic inside consume().
        let dice = Dice {
            values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)),
        };
        s.consume(&GameEvent::RollResult {
            player_id: s.active_player_id,
            dice,
        })
        .expect("apply_chance: RollResult event failed");
    }

    // ── Observation ───────────────────────────────────────────────────────

    /// Observation tensor from `pov`'s perspective; the board is mirrored
    /// for Player 2 so both players "see" themselves as White.
    fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
        if pov == 0 {
            s.to_tensor()
        } else {
            s.mirror().to_tensor()
        }
    }

    // NOTE(review): hard-coded — must stay in sync with the length of
    // `GameState::to_tensor()`; verify whenever the tensor layout changes.
    fn obs_size(&self) -> usize {
        217
    }

    fn action_space(&self) -> usize {
        ACTION_SPACE_SIZE
    }

    // ── Terminal values ───────────────────────────────────────────────────

    /// Returns `Some([r1, r2])` when the game is over, `None` otherwise.
    ///
    /// The winner (higher cumulative score) receives `+1.0`; the loser
    /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score
    /// is `holes × 12 + points`.
    fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
        if s.stage != Stage::Ended {
            return None;
        }
        // Cumulative score: holes dominate (12 per hole), points break ties.
        let score = |id: u64| -> i32 {
            s.players
                .get(&id)
                .map(|p| p.holes as i32 * 12 + p.points as i32)
                .unwrap_or(0)
        };
        let s1 = score(1);
        let s2 = score(2);
        Some(match s1.cmp(&s2) {
            std::cmp::Ordering::Greater => [1.0, -1.0],
            std::cmp::Ordering::Less => [-1.0, 1.0],
            std::cmp::Ordering::Equal => [0.0, 0.0],
        })
    }
}
// ── DQN helpers ───────────────────────────────────────────────────────────────
impl TrictracEnv {
    /// Score snapshot for DQN reward computation.
    ///
    /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`.
    /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2).
    // NOTE(review): assumes `GameState::total_score` uses the same
    // holes × 12 + points scale as the closure in `returns()` — verify.
    pub fn score_snapshot(s: &GameState) -> [i32; 2] {
        [s.total_score(1), s.total_score(2)]
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::SmallRng, Rng, SeedableRng};

    fn env() -> TrictracEnv {
        TrictracEnv
    }

    fn seeded_rng(seed: u64) -> SmallRng {
        SmallRng::seed_from_u64(seed)
    }

    // ── Initial state ─────────────────────────────────────────────────────

    #[test]
    fn new_game_is_chance_node() {
        let e = env();
        let s = e.new_game();
        // A fresh game starts at RollDice — a Chance node.
        assert_eq!(e.current_player(&s), Player::Chance);
        assert!(e.returns(&s).is_none());
    }

    #[test]
    fn new_game_is_not_terminal() {
        let e = env();
        let s = e.new_game();
        assert_ne!(e.current_player(&s), Player::Terminal);
        assert!(e.returns(&s).is_none());
    }

    // ── Chance nodes ──────────────────────────────────────────────────────

    #[test]
    fn apply_chance_reaches_decision_node() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(1);
        // A single chance step must yield a decision node (or end the game,
        // which only happens after 12 holes — impossible on the first roll).
        e.apply_chance(&mut s, &mut rng);
        let p = e.current_player(&s);
        assert!(
            p.is_decision(),
            "expected decision node after first roll, got {p:?}"
        );
    }

    #[test]
    fn apply_chance_from_rollwaiting() {
        // Check that apply_chance works when called mid-way (at RollWaiting).
        let e = env();
        let mut s = e.new_game();
        assert_eq!(s.turn_stage, TurnStage::RollDice);
        // Manually advance to RollWaiting.
        s.consume(&GameEvent::Roll { player_id: s.active_player_id })
            .unwrap();
        assert_eq!(s.turn_stage, TurnStage::RollWaiting);
        let mut rng = seeded_rng(2);
        e.apply_chance(&mut s, &mut rng);
        let p = e.current_player(&s);
        assert!(p.is_decision() || p.is_terminal());
    }

    // ── Legal actions ─────────────────────────────────────────────────────

    #[test]
    fn legal_actions_nonempty_after_roll() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(3);
        e.apply_chance(&mut s, &mut rng);
        assert!(e.current_player(&s).is_decision());
        let actions = e.legal_actions(&s);
        assert!(
            !actions.is_empty(),
            "legal_actions must be non-empty at a decision node"
        );
    }

    #[test]
    fn legal_actions_within_action_space() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(4);
        e.apply_chance(&mut s, &mut rng);
        for &a in e.legal_actions(&s).iter() {
            assert!(
                a < e.action_space(),
                "action {a} out of bounds (action_space={})",
                e.action_space()
            );
        }
    }

    // ── Observations ──────────────────────────────────────────────────────

    #[test]
    fn observation_has_correct_size() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(5);
        e.apply_chance(&mut s, &mut rng);
        assert_eq!(e.observation(&s, 0).len(), e.obs_size());
        assert_eq!(e.observation(&s, 1).len(), e.obs_size());
    }

    #[test]
    fn observation_values_in_unit_interval() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(6);
        e.apply_chance(&mut s, &mut rng);
        for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] {
            for (i, &v) in obs.iter().enumerate() {
                assert!(
                    v >= 0.0 && v <= 1.0,
                    "pov={pov}: obs[{i}] = {v} is outside [0,1]"
                );
            }
        }
    }

    #[test]
    fn p1_and_p2_observations_differ() {
        // The board is mirrored for P2, so the two observations should differ
        // whenever there are checkers in non-symmetric positions (always true
        // in a real game after a few moves).
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(7);
        // Advance far enough that the board is non-trivial.
        for _ in 0..6 {
            while e.current_player(&s).is_chance() {
                e.apply_chance(&mut s, &mut rng);
            }
            if e.current_player(&s).is_terminal() {
                break;
            }
            let actions = e.legal_actions(&s);
            e.apply(&mut s, actions[0]);
        }
        if !e.current_player(&s).is_terminal() {
            let obs0 = e.observation(&s, 0);
            let obs1 = e.observation(&s, 1);
            assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board");
        }
    }

    // ── Applying actions ──────────────────────────────────────────────────

    #[test]
    fn apply_changes_state() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(8);
        e.apply_chance(&mut s, &mut rng);
        assert!(e.current_player(&s).is_decision());
        let before = s.clone();
        let action = e.legal_actions(&s)[0];
        e.apply(&mut s, action);
        assert_ne!(
            before.turn_stage, s.turn_stage,
            "state must change after apply"
        );
    }

    #[test]
    fn apply_all_legal_actions_do_not_panic() {
        // Verify that every action returned by legal_actions can be applied
        // without panicking (on several independent copies of the same state).
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(9);
        e.apply_chance(&mut s, &mut rng);
        assert!(e.current_player(&s).is_decision());
        for action in e.legal_actions(&s) {
            let mut copy = s.clone();
            e.apply(&mut copy, action); // must not panic
        }
    }

    // ── Full game ─────────────────────────────────────────────────────────

    /// Run a complete game with random actions through the `GameEnv` trait
    /// and verify that:
    /// - The game terminates.
    /// - `returns()` is `Some` at the end.
    /// - The outcome is zero-sum and each return is in [-1, 1].
    /// - No step panics.
    #[test]
    fn full_random_game_terminates() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(42);
        let max_steps = 50_000;
        for step in 0..max_steps {
            match e.current_player(&s) {
                Player::Terminal => break,
                Player::Chance => e.apply_chance(&mut s, &mut rng),
                Player::P1 | Player::P2 => {
                    let actions = e.legal_actions(&s);
                    assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node");
                    let idx = rng.random_range(0..actions.len());
                    e.apply(&mut s, actions[idx]);
                }
            }
        }
        // Checking termination *after* the loop (instead of asserting on the
        // step counter inside it) also accepts a game that ends on the very
        // last permitted step.
        assert!(
            e.current_player(&s).is_terminal(),
            "game did not terminate within {max_steps} steps"
        );
        let result = e.returns(&s);
        assert!(result.is_some(), "returns() must be Some at Terminal");
        let [r1, r2] = result.unwrap();
        let sum = r1 + r2;
        // The original condition OR-ed two identical tests; one suffices.
        assert!(sum.abs() < 1e-5, "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}");
        assert!(
            r1.abs() <= 1.0 && r2.abs() <= 1.0,
            "returns must be in [-1,1]: r1={r1}, r2={r2}"
        );
    }

    /// Run multiple games with different seeds to stress-test for panics.
    #[test]
    fn multiple_games_no_panic() {
        let e = env();
        let max_steps = 20_000;
        for seed in 0..10u64 {
            let mut s = e.new_game();
            let mut rng = seeded_rng(seed);
            for _ in 0..max_steps {
                match e.current_player(&s) {
                    Player::Terminal => break,
                    Player::Chance => e.apply_chance(&mut s, &mut rng),
                    Player::P1 | Player::P2 => {
                        let actions = e.legal_actions(&s);
                        let idx = rng.random_range(0..actions.len());
                        e.apply(&mut s, actions[idx]);
                    }
                }
            }
        }
    }

    // ── Returns ───────────────────────────────────────────────────────────

    #[test]
    fn returns_none_mid_game() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(11);
        // Advance a few steps but do not finish the game.
        for _ in 0..4 {
            match e.current_player(&s) {
                Player::Terminal => break,
                Player::Chance => e.apply_chance(&mut s, &mut rng),
                Player::P1 | Player::P2 => {
                    let actions = e.legal_actions(&s);
                    e.apply(&mut s, actions[0]);
                }
            }
        }
        if !e.current_player(&s).is_terminal() {
            assert!(
                e.returns(&s).is_none(),
                "returns() must be None before the game ends"
            );
        }
    }

    // ── Player 2 actions ──────────────────────────────────────────────────

    /// Verify that Player 2 (Black) can take actions without panicking,
    /// and that the state advances correctly.
    #[test]
    fn player2_can_act() {
        let e = env();
        let mut s = e.new_game();
        let mut rng = seeded_rng(12);
        // Keep stepping until Player 2 gets a turn.
        let max_steps = 5_000;
        let mut p2_acted = false;
        for _ in 0..max_steps {
            match e.current_player(&s) {
                Player::Terminal => break,
                Player::Chance => e.apply_chance(&mut s, &mut rng),
                Player::P2 => {
                    let actions = e.legal_actions(&s);
                    assert!(!actions.is_empty());
                    e.apply(&mut s, actions[0]);
                    p2_acted = true;
                    break;
                }
                Player::P1 => {
                    let actions = e.legal_actions(&s);
                    e.apply(&mut s, actions[0]);
                }
            }
        }
        assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps");
    }
}

5
spiel_bot/src/lib.rs Normal file
View file

@ -0,0 +1,5 @@
//! Crate root: exposes the game-environment abstraction (`env`), tree
//! search (`mcts`), the network definitions, and the `alphazero` / `dqn`
//! training modules.
pub mod alphazero;
pub mod dqn;
pub mod env;
pub mod mcts;
pub mod network;

412
spiel_bot/src/mcts/mod.rs Normal file
View file

@ -0,0 +1,412 @@
//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance.
//!
//! # Algorithm
//!
//! The implementation follows AlphaZero's MCTS:
//!
//! 1. **Expand root** — run the network once to get priors and a value
//! estimate; optionally add Dirichlet noise for training-time exploration.
//! 2. **Simulate** `n_simulations` times:
//! - *Selection* — traverse the tree with PUCT until an unvisited leaf.
//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes;
//! chance nodes are **not** stored in the tree (outcome sampling).
//! - *Expansion* — evaluate the network at the leaf; populate children.
//! - *Backup* — propagate the value upward; negate at each player boundary.
//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]).
//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]).
//!
//! # Perspective convention
//!
//! Every [`MctsNode::w`] is stored **from the perspective of the player who
//! acts at that node**. The backup negates the child value whenever the
//! acting player differs between parent and child.
//!
//! # Stochastic games
//!
//! When [`GameEnv::current_player`] returns [`Player::Chance`], the
//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and
//! continues. Chance nodes are skipped transparently; Q-values converge to
//! their expectation over many simulations (outcome sampling).
pub mod node;
mod search;
pub use node::MctsNode;
use rand::Rng;
use crate::env::GameEnv;
// ── Evaluator trait ────────────────────────────────────────────────────────

/// Evaluates a game position for use in MCTS.
///
/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet)
/// but the `mcts` module itself does **not** depend on Burn.
pub trait Evaluator: Send + Sync {
    /// Evaluate `obs` (flat observation vector of length `obs_size`).
    ///
    /// Returns:
    /// - `policy_logits`: one raw logit per action (`action_space` entries).
    ///   Illegal action entries are masked inside the search — no need to
    ///   zero them here.
    /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective.
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
}
// ── Configuration ─────────────────────────────────────────────────────────

/// Hyperparameters for [`run_mcts`].
#[derive(Debug, Clone)]
pub struct MctsConfig {
    /// Number of MCTS simulations per move. Typical: 50–800.
    pub n_simulations: usize,
    /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0.
    pub c_puct: f32,
    /// Dirichlet noise concentration α. Set to `0.0` to disable.
    /// Typical: `0.3` for Chess, `0.1` for large action spaces.
    pub dirichlet_alpha: f32,
    /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`.
    pub dirichlet_eps: f32,
    /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax.
    pub temperature: f32,
}
impl Default for MctsConfig {
fn default() -> Self {
Self {
n_simulations: 200,
c_puct: 1.5,
dirichlet_alpha: 0.3,
dirichlet_eps: 0.25,
temperature: 1.0,
}
}
}
// ── Public interface ───────────────────────────────────────────────────────

/// Run MCTS from `state` and return the populated root [`MctsNode`].
///
/// `state` must be a player-decision node (`P1` or `P2`).
/// Use [`mcts_policy`] and [`select_action`] on the returned root.
///
/// # Panics
///
/// Panics if `env.current_player(state)` is not `P1` or `P2`.
pub fn run_mcts<E: GameEnv>(
    env: &E,
    state: &E::State,
    evaluator: &dyn Evaluator,
    config: &MctsConfig,
    rng: &mut impl Rng,
) -> MctsNode {
    let player_idx = env
        .current_player(state)
        .index()
        .expect("run_mcts called at a non-decision node");
    // ── Expand root (network called once here, not inside the loop) ────────
    // The root's own prior (1.0) is irrelevant — priors only matter on children.
    let mut root = MctsNode::new(1.0);
    search::expand::<E>(&mut root, state, env, evaluator, player_idx);
    // ── Optional Dirichlet noise for training exploration ──────────────────
    // Both alpha and eps must be positive; either set to 0.0 disables noise.
    if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
        search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
    }
    // ── Simulations ────────────────────────────────────────────────────────
    // Each simulation clones the state so the caller's state is never mutated.
    for _ in 0..config.n_simulations {
        search::simulate::<E>(
            &mut root,
            state.clone(),
            env,
            evaluator,
            config,
            rng,
            player_idx,
        );
    }
    root
}
/// Compute the MCTS policy: normalized visit counts at the root.
///
/// Returns a vector of length `action_space` where `policy[a]` is the
/// fraction of simulations that visited action `a`.
pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
    let mut policy = vec![0.0f32; action_space];
    let visits: f32 = root.children.iter().map(|(_, child)| child.n as f32).sum();
    if visits > 0.0 {
        // Normal case: mass proportional to each child's visit count.
        for (action, child) in root.children.iter() {
            policy[*action] = child.n as f32 / visits;
        }
    } else if !root.children.is_empty() {
        // n_simulations = 0: fall back to uniform over the legal actions.
        let mass = 1.0 / root.children.len() as f32;
        for (action, _) in root.children.iter() {
            policy[*action] = mass;
        }
    }
    policy
}
/// Select an action index from the root after MCTS.
///
/// * `temperature = 0` — greedy argmax of visit counts.
/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
///
/// # Panics
///
/// Panics if the root has no children.
pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize {
    assert!(
        !root.children.is_empty(),
        "select_action called on a root with no children"
    );
    if temperature <= 0.0 {
        // Greedy: the most-visited child wins.
        let (best, _) = root.children.iter().max_by_key(|(_, child)| child.n).unwrap();
        return *best;
    }
    // Tempered sampling: weight each child by N^(1/T), draw a point
    // uniformly in [0, total), then walk the cumulative distribution.
    let weights: Vec<f32> = root
        .children
        .iter()
        .map(|(_, child)| (child.n as f32).powf(1.0 / temperature))
        .collect();
    let total: f32 = weights.iter().sum();
    let mut point = rng.random::<f32>() * total;
    for ((action, _), w) in root.children.iter().zip(weights.iter()) {
        point -= w;
        if point <= 0.0 {
            return *action;
        }
    }
    // Floating-point slack left `point` slightly positive: take the last child.
    root.children.last().map(|(action, _)| *action).unwrap()
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};
    use crate::env::Player;
    // ── Minimal deterministic test game ───────────────────────────────────
    //
    // "Countdown" — two players alternate subtracting 1 or 2 from a counter.
    // The player who brings the counter to 0 wins.
    // No chance nodes, two legal actions (0 = -1, 1 = -2).
    #[derive(Clone, Debug)]
    struct CState {
        remaining: u8,
        to_move: usize, // at terminal: last mover (winner)
    }
    #[derive(Clone)]
    struct CountdownEnv;
    impl crate::env::GameEnv for CountdownEnv {
        type State = CState;
        fn new_game(&self) -> CState {
            CState { remaining: 6, to_move: 0 }
        }
        fn current_player(&self, s: &CState) -> Player {
            if s.remaining == 0 {
                Player::Terminal
            } else if s.to_move == 0 {
                Player::P1
            } else {
                Player::P2
            }
        }
        fn legal_actions(&self, s: &CState) -> Vec<usize> {
            // Subtracting 2 (action 1) is only legal while at least 2 remain.
            if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
        }
        fn apply(&self, s: &mut CState, action: usize) {
            let sub = (action as u8) + 1;
            if s.remaining <= sub {
                s.remaining = 0;
                // to_move stays as winner
            } else {
                s.remaining -= sub;
                s.to_move = 1 - s.to_move;
            }
        }
        // Countdown has no chance nodes, so this is a no-op.
        fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
        fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
            vec![s.remaining as f32 / 6.0, s.to_move as f32]
        }
        fn obs_size(&self) -> usize { 2 }
        fn action_space(&self) -> usize { 2 }
        fn returns(&self, s: &CState) -> Option<[f32; 2]> {
            if s.remaining != 0 { return None; }
            // +1 for the winner (the last mover), -1 for the loser.
            let mut r = [-1.0f32; 2];
            r[s.to_move] = 1.0;
            Some(r)
        }
    }
    // Uniform evaluator: all logits = 0, value = 0.
    // `action_space` must match the environment's `action_space()`.
    struct ZeroEval(usize);
    impl Evaluator for ZeroEval {
        fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
            (vec![0.0f32; self.0], 0.0)
        }
    }
    // Fixed seed so every test run consumes an identical RNG stream.
    fn rng() -> SmallRng {
        SmallRng::seed_from_u64(42)
    }
    // Config with `n` simulations and exploration noise disabled.
    fn config_n(n: usize) -> MctsConfig {
        MctsConfig {
            n_simulations: n,
            c_puct: 1.5,
            dirichlet_alpha: 0.0, // off for reproducibility
            dirichlet_eps: 0.0,
            temperature: 1.0,
        }
    }
    // ── Visit count tests ─────────────────────────────────────────────────
    #[test]
    fn visit_counts_sum_to_n_simulations() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng());
        let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
        assert_eq!(total, 50, "visit counts must sum to n_simulations");
    }
    #[test]
    fn all_root_children_are_legal() {
        let env = CountdownEnv;
        let state = env.new_game();
        let legal = env.legal_actions(&state);
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng());
        for (a, _) in &root.children {
            assert!(legal.contains(a), "child action {a} is not legal");
        }
    }
    // ── Policy tests ─────────────────────────────────────────────────────
    #[test]
    fn policy_sums_to_one() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng());
        let policy = mcts_policy(&root, env.action_space());
        let sum: f32 = policy.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0");
    }
    #[test]
    fn policy_zero_for_illegal_actions() {
        let env = CountdownEnv;
        // remaining = 1 → only action 0 is legal
        let state = CState { remaining: 1, to_move: 0 };
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng());
        let policy = mcts_policy(&root, env.action_space());
        assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass");
    }
    // ── Action selection tests ────────────────────────────────────────────
    #[test]
    fn greedy_selects_most_visited() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng());
        let greedy = select_action(&root, 0.0, &mut rng());
        let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap();
        assert_eq!(greedy, most_visited);
    }
    #[test]
    fn temperature_sampling_stays_legal() {
        let env = CountdownEnv;
        let state = env.new_game();
        let legal = env.legal_actions(&state);
        let mut r = rng();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r);
        // Sampling is stochastic, so try a handful of draws.
        for _ in 0..20 {
            let a = select_action(&root, 1.0, &mut r);
            assert!(legal.contains(&a), "sampled action {a} is not legal");
        }
    }
    // ── Zero-simulation edge case ─────────────────────────────────────────
    #[test]
    fn zero_simulations_uniform_policy() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng());
        let policy = mcts_policy(&root, env.action_space());
        // With 0 simulations, fallback is uniform over the 2 legal actions.
        let sum: f32 = policy.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
    // ── Root value ────────────────────────────────────────────────────────
    #[test]
    fn root_q_in_valid_range() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng());
        let q = root.q();
        assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]");
    }
    // ── Integration: run on a real Trictrac game ──────────────────────────
    #[test]
    fn no_panic_on_trictrac_state() {
        use crate::env::TrictracEnv;
        let env = TrictracEnv;
        let mut state = env.new_game();
        let mut r = rng();
        // Advance past the initial chance node to reach a decision node.
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, &mut r);
        }
        if env.current_player(&state).is_terminal() {
            return; // unlikely but safe
        }
        let config = MctsConfig {
            n_simulations: 5, // tiny for speed
            dirichlet_alpha: 0.0,
            dirichlet_eps: 0.0,
            ..MctsConfig::default()
        };
        let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
        // root.n = 1 (expansion) + n_simulations (one backup per simulation).
        assert_eq!(root.n, 1 + config.n_simulations as u32);
        // Every simulation crosses a chance node at depth 1 (dice roll after
        // the player's move). Since the fix now updates child.n in that case,
        // children visit counts must sum to exactly n_simulations.
        let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
        assert_eq!(total, config.n_simulations as u32);
    }
}

View file

@ -0,0 +1,91 @@
//! MCTS tree node.
//!
//! [`MctsNode`] stores the search statistics for a single player-decision
//! position. A node starts life as an unexpanded leaf and becomes *expanded*
//! the first time the policy-value network is evaluated there.
/// One node in the MCTS tree, representing a player-decision position.
///
/// `w` accumulates the values backed up through this node, always expressed
/// from the point of view of **the player who acts here**; `q()` shares that
/// perspective and therefore also lies in `(-1, 1)`.
#[derive(Debug)]
pub struct MctsNode {
    /// Visit count `N(s, a)`.
    pub n: u32,
    /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective.
    pub w: f32,
    /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax).
    pub p: f32,
    /// Children: `(action_index, child_node)`, populated on first expansion.
    pub children: Vec<(usize, MctsNode)>,
    /// `true` after the network has been evaluated and children have been set up.
    pub expanded: bool,
}
impl MctsNode {
    /// Build a fresh, unexpanded leaf carrying the prior `P(s, a)`.
    pub fn new(prior: f32) -> Self {
        MctsNode {
            p: prior,
            n: 0,
            w: 0.0,
            children: Vec::new(),
            expanded: false,
        }
    }
    /// Mean action value `Q(s, a) = W / N`; `0.0` for a never-visited node.
    #[inline]
    pub fn q(&self) -> f32 {
        match self.n {
            0 => 0.0,
            n => self.w / n as f32,
        }
    }
    /// PUCT selection score:
    ///
    /// ```text
    /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a))
    /// ```
    #[inline]
    pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
        let exploration = c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32);
        self.q() + exploration
    }
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn q_zero_when_unvisited() {
        // A fresh leaf has N = 0, so Q must default to 0.
        assert_eq!(MctsNode::new(0.5).q(), 0.0);
    }
    #[test]
    fn q_reflects_w_over_n() {
        let mut node = MctsNode::new(0.5);
        node.w = 2.0;
        node.n = 4;
        let q = node.q();
        assert!((q - 0.5).abs() < 1e-6, "expected Q = W/N = 0.5, got {q}");
    }
    #[test]
    fn puct_exploration_dominates_unvisited() {
        // An unvisited child must outscore a visited child whose Q is negative.
        let fresh = MctsNode::new(0.5);
        let mut seen = MctsNode::new(0.5);
        seen.n = 10;
        seen.w = -5.0; // Q = -0.5
        let (parent_n, c) = (10, 1.5);
        assert!(
            fresh.puct(parent_n, c) > seen.puct(parent_n, c),
            "unvisited child should have higher PUCT than a negatively-valued visited child"
        );
    }
}

View file

@ -0,0 +1,190 @@
//! Simulation, expansion, backup, and noise helpers.
//!
//! These are internal to the `mcts` module; the public entry points are
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
use rand::Rng;
use rand_distr::{Gamma, Distribution};
use crate::env::GameEnv;
use super::{Evaluator, MctsConfig};
use super::node::MctsNode;
// ── Masked softmax ─────────────────────────────────────────────────────────
/// Numerically stable softmax restricted to the `legal` action indices.
///
/// Actions outside `legal` behave as if their logit were `-∞` and keep
/// probability `0.0`. The result always has length `action_space`.
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
    let mut probs = vec![0.0f32; action_space];
    if legal.is_empty() {
        return probs;
    }
    // Subtract the max legal logit before exponentiating for stability.
    let max_logit = legal
        .iter()
        .fold(f32::NEG_INFINITY, |m, &a| m.max(logits[a]));
    for &a in legal {
        probs[a] = (logits[a] - max_logit).exp();
    }
    let sum: f32 = legal.iter().map(|&a| probs[a]).sum();
    if sum > 0.0 {
        for &a in legal {
            probs[a] /= sum;
        }
    } else {
        // Degenerate case (e.g. NaN logits): fall back to uniform over legal.
        let uniform = 1.0 / legal.len() as f32;
        for &a in legal {
            probs[a] = uniform;
        }
    }
    probs
}
// ── Dirichlet noise ────────────────────────────────────────────────────────
/// Blend Dirichlet(α, …, α) noise into the root children's priors for
/// exploration.
///
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. The Dirichlet
/// sample is produced via the Gamma trick: draw `Gamma(α, 1)` per child and
/// normalise. No-op when the node has no children, the Gamma parameters are
/// invalid, or the drawn samples sum to zero.
pub(super) fn add_dirichlet_noise(
    node: &mut MctsNode,
    alpha: f32,
    eps: f32,
    rng: &mut impl Rng,
) {
    let child_count = node.children.len();
    if child_count == 0 {
        return;
    }
    let gamma = match Gamma::new(alpha as f64, 1.0_f64) {
        Ok(g) => g,
        Err(_) => return,
    };
    let draws: Vec<f32> = (0..child_count).map(|_| gamma.sample(rng) as f32).collect();
    let total: f32 = draws.iter().sum();
    if total <= 0.0 {
        return;
    }
    for ((_, child), draw) in node.children.iter_mut().zip(draws) {
        // Convex mix of the network prior and the normalised noise sample.
        child.p = (1.0 - eps) * child.p + eps * (draw / total);
    }
}
// ── Expansion ──────────────────────────────────────────────────────────────
/// Evaluate the network at `state` and attach one child per legal action.
///
/// Marks `node` expanded with `n = 1` and `w = value`, then returns the
/// network's value estimate from `player_idx`'s perspective.
pub(super) fn expand<E: GameEnv>(
    node: &mut MctsNode,
    state: &E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    player_idx: usize,
) -> f32 {
    let obs = env.observation(state, player_idx);
    let legal = env.legal_actions(state);
    let (logits, value) = evaluator.evaluate(&obs);
    // Priors come from a softmax over legal actions only.
    let priors = masked_softmax(&logits, &legal, env.action_space());
    node.children = legal
        .into_iter()
        .map(|a| (a, MctsNode::new(priors[a])))
        .collect();
    node.n = 1;
    node.w = value;
    node.expanded = true;
    value
}
// ── Simulation ─────────────────────────────────────────────────────────────
/// One MCTS simulation from an **already-expanded** decision node.
///
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
/// and backs up the result.
///
/// * `player_idx` — the player (0 or 1) who acts at `state`.
/// * Returns the backed-up value **from `player_idx`'s perspective**.
pub(super) fn simulate<E: GameEnv>(
    node: &mut MctsNode,
    state: E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    config: &MctsConfig,
    rng: &mut impl Rng,
    player_idx: usize,
) -> f32 {
    debug_assert!(node.expanded, "simulate called on unexpanded node");
    // ── Selection: child with highest PUCT ────────────────────────────────
    // Read the parent visit count before mutably borrowing `children` below.
    let parent_n = node.n;
    let best = node
        .children
        .iter()
        .enumerate()
        .max_by(|(_, (_, a)), (_, (_, b))| {
            // f32 has no total order; NaN scores compare as Equal.
            a.puct(parent_n, config.c_puct)
                .partial_cmp(&b.puct(parent_n, config.c_puct))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .expect("expanded node must have at least one child");
    let (action, child) = &mut node.children[best];
    let action = *action;
    // ── Apply action + advance through any chance nodes ───────────────────
    let mut next_state = state;
    env.apply(&mut next_state, action);
    // Track whether we crossed a chance node (dice roll) on the way down.
    // If we did, the child's cached legal actions are for a *different* dice
    // outcome and must not be reused — evaluate with the network directly.
    let mut crossed_chance = false;
    while env.current_player(&next_state).is_chance() {
        env.apply_chance(&mut next_state, rng);
        crossed_chance = true;
    }
    let next_cp = env.current_player(&next_state);
    // ── Evaluate leaf or terminal ──────────────────────────────────────────
    // All values are converted to `player_idx`'s perspective before backup.
    let child_value = if next_cp.is_terminal() {
        // Terminal: use the exact game outcome, no network call needed.
        let returns = env
            .returns(&next_state)
            .expect("terminal node must have returns");
        returns[player_idx]
    } else {
        let child_player = next_cp.index().unwrap();
        let v = if crossed_chance {
            // Outcome sampling: after dice, evaluate the resulting position
            // directly with the network. Do NOT build the tree across chance
            // boundaries — the dice change which actions are legal, so any
            // previously cached children would be for a different outcome.
            let obs = env.observation(&next_state, child_player);
            let (_, value) = evaluator.evaluate(&obs);
            // Record the visit so that PUCT and mcts_policy use real counts.
            // Without this, child.n stays 0 for every simulation in games where
            // every player action is immediately followed by a chance node (e.g.
            // Trictrac), causing mcts_policy to always return a uniform policy.
            child.n += 1;
            child.w += value;
            value
        } else if child.expanded {
            // Interior node: recurse one ply deeper for the next player.
            simulate(child, next_state, env, evaluator, config, rng, child_player)
        } else {
            // First visit: expand with the network and end this simulation.
            expand::<E>(child, &next_state, env, evaluator, child_player)
        };
        // Negate when the child belongs to the opponent.
        if child_player == player_idx { v } else { -v }
    };
    // ── Backup ────────────────────────────────────────────────────────────
    // One visit and the perspective-corrected value accrue to this node.
    node.n += 1;
    node.w += child_value;
    child_value
}

View file

@ -0,0 +1,223 @@
//! Two-hidden-layer MLP policy-value network.
//!
//! ```text
//! Input [B, obs_size]
//! → Linear(obs → hidden) → ReLU
//! → Linear(hidden → hidden) → ReLU
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
//! ```
use burn::{
module::Module,
nn::{Linear, LinearConfig},
record::{CompactRecorder, Recorder},
tensor::{
activation::{relu, tanh},
backend::Backend,
Tensor,
},
};
use std::path::Path;
use super::PolicyValueNet;
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for [`MlpNet`].
#[derive(Debug, Clone)]
pub struct MlpConfig {
    /// Number of input features. 217 for Trictrac's `to_tensor()`.
    pub obs_size: usize,
    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
    pub action_size: usize,
    /// Width of both hidden layers.
    pub hidden_size: usize,
}
impl Default for MlpConfig {
    /// Trictrac-sized defaults: 217 inputs, 514 actions, 256 hidden units.
    fn default() -> Self {
        Self {
            obs_size: 217,
            action_size: 514,
            hidden_size: 256,
        }
    }
}
// ── Network ───────────────────────────────────────────────────────────────────
/// Simple two-hidden-layer MLP with shared trunk and two heads.
///
/// Prefer this over [`ResNet`](super::ResNet) when training time is a
/// priority, or as a fast baseline.
#[derive(Module, Debug)]
pub struct MlpNet<B: Backend> {
    // First trunk layer: obs_size → hidden_size.
    fc1: Linear<B>,
    // Second trunk layer: hidden_size → hidden_size.
    fc2: Linear<B>,
    // Policy head: hidden_size → action_size, raw logits.
    policy_head: Linear<B>,
    // Value head: hidden_size → 1, tanh-squashed in `forward`.
    value_head: Linear<B>,
}
impl<B: Backend> MlpNet<B> {
/// Construct a fresh network with random weights.
pub fn new(config: &MlpConfig, device: &B::Device) -> Self {
Self {
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
value_head: LinearConfig::new(config.hidden_size, 1).init(device),
}
}
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
///
/// The file is written exactly at `path`; callers should append `.mpk` if
/// they want the conventional extension.
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
CompactRecorder::new()
.record(self.clone().into_record(), path.to_path_buf())
.map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}"))
}
/// Load weights from `path` into a fresh model built from `config`.
pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
let record = CompactRecorder::new()
.load(path.to_path_buf(), device)
.map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?;
Ok(Self::new(config, device).load_record(record))
}
}
impl<B: Backend> PolicyValueNet<B> for MlpNet<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
        // Shared two-layer ReLU trunk, then one linear head per output.
        let trunk = relu(self.fc2.forward(relu(self.fc1.forward(obs))));
        let value = tanh(self.value_head.forward(trunk.clone()));
        let policy = self.policy_head.forward(trunk);
        (policy, value)
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    // CPU NdArray backend keeps these tests fast and GPU-free.
    type B = NdArray<f32>;
    fn device() -> <B as Backend>::Device {
        Default::default()
    }
    fn default_net() -> MlpNet<B> {
        MlpNet::new(&MlpConfig::default(), &device())
    }
    // All-zero observation batch with the default 217 features.
    fn zeros_obs(batch: usize) -> Tensor<B, 2> {
        Tensor::zeros([batch, 217], &device())
    }
    // ── Shape tests ───────────────────────────────────────────────────────
    #[test]
    fn forward_output_shapes() {
        let net = default_net();
        let obs = zeros_obs(4);
        let (policy, value) = net.forward(obs);
        assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
        assert_eq!(value.dims(), [4, 1], "value shape mismatch");
    }
    #[test]
    fn forward_single_sample() {
        let net = default_net();
        let (policy, value) = net.forward(zeros_obs(1));
        assert_eq!(policy.dims(), [1, 514]);
        assert_eq!(value.dims(), [1, 1]);
    }
    // ── Value bounds ──────────────────────────────────────────────────────
    #[test]
    fn value_in_tanh_range() {
        let net = default_net();
        // Use a non-zero input so the output is not trivially at 0.
        let obs = Tensor::<B, 2>::ones([8, 217], &device());
        let (_, value) = net.forward(obs);
        let data: Vec<f32> = value.into_data().to_vec().unwrap();
        for v in &data {
            assert!(
                *v > -1.0 && *v < 1.0,
                "value {v} is outside open interval (-1, 1)"
            );
        }
    }
    // ── Policy logits ─────────────────────────────────────────────────────
    #[test]
    fn policy_logits_not_all_equal() {
        // With random weights the 514 logits should not all be identical.
        let net = default_net();
        let (policy, _) = net.forward(zeros_obs(1));
        let data: Vec<f32> = policy.into_data().to_vec().unwrap();
        let first = data[0];
        let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
        assert!(!all_same, "all policy logits are identical — network may be degenerate");
    }
    // ── Config propagation ────────────────────────────────────────────────
    #[test]
    fn custom_config_shapes() {
        let config = MlpConfig {
            obs_size: 10,
            action_size: 20,
            hidden_size: 32,
        };
        let net = MlpNet::<B>::new(&config, &device());
        let obs = Tensor::zeros([3, 10], &device());
        let (policy, value) = net.forward(obs);
        assert_eq!(policy.dims(), [3, 20]);
        assert_eq!(value.dims(), [3, 1]);
    }
    // ── Save / Load ───────────────────────────────────────────────────────
    #[test]
    fn save_load_preserves_weights() {
        let config = MlpConfig::default();
        let net = default_net();
        // Forward pass before saving.
        let obs = Tensor::<B, 2>::ones([2, 217], &device());
        let (policy_before, value_before) = net.forward(obs.clone());
        // Save to a temp file.
        let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk");
        net.save(&path).expect("save failed");
        // Load into a fresh model.
        let loaded = MlpNet::<B>::load(&config, &path, &device()).expect("load failed");
        let (policy_after, value_after) = loaded.forward(obs);
        // Outputs must match within a small numeric tolerance.
        let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
        let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
        for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
        }
        let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
        let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
        for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
        }
        let _ = std::fs::remove_file(path);
    }
}

View file

@ -0,0 +1,78 @@
//! Neural network abstractions for policy-value learning.
//!
//! # Trait
//!
//! [`PolicyValueNet<B>`] is the single trait that all network architectures
//! implement. It takes an observation tensor and returns raw policy logits
//! plus a tanh-squashed scalar value estimate.
//!
//! # Architectures
//!
//! | Module | Description | Default hidden |
//! |--------|-------------|----------------|
//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
//!
//! # Backend convention
//!
//! * **Inference / self-play** — use `NdArray<f32>` (no autodiff overhead).
//! * **Training** — use `Autodiff<NdArray<f32>>` so Burn can differentiate
//! through the forward pass.
//!
//! Both modes use the exact same struct; only the type-level backend changes:
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, NdArray};
//! type InferBackend = NdArray<f32>;
//! type TrainBackend = Autodiff<NdArray<f32>>;
//!
//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
//! ```
//!
//! # Output shapes
//!
//! Given a batch of `B` observations of size `obs_size`:
//!
//! | Output | Shape | Range |
//! |--------|-------|-------|
//! | `policy_logits` | `[B, action_size]` | (unnormalised) |
//! | `value` | `[B, 1]` | (-1, 1) via tanh |
//!
//! Callers are responsible for masking illegal actions in `policy_logits`
//! before passing to softmax.
pub mod mlp;
pub mod qnet;
pub mod resnet;
pub use mlp::{MlpConfig, MlpNet};
pub use qnet::{QNet, QNetConfig};
pub use resnet::{ResNet, ResNetConfig};
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
/// A neural network that produces a policy and a value from an observation.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
///
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
    /// Run a batched forward pass; returns `(policy_logits, value)`.
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
}
/// A neural network that outputs one Q-value per action.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - output: `[batch, action_size]` — raw Q-values (no activation)
///
/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
    /// Run a batched forward pass; returns one raw Q-value per action.
    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
}

View file

@ -0,0 +1,147 @@
//! Single-headed Q-value network for DQN.
//!
//! ```text
//! Input [B, obs_size]
//! → Linear(obs → hidden) → ReLU
//! → Linear(hidden → hidden) → ReLU
//! → Linear(hidden → action_size) ← raw Q-values, no activation
//! ```
use burn::{
module::Module,
nn::{Linear, LinearConfig},
record::{CompactRecorder, Recorder},
tensor::{activation::relu, backend::Backend, Tensor},
};
use std::path::Path;
use super::QValueNet;
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for [`QNet`].
#[derive(Debug, Clone)]
pub struct QNetConfig {
    /// Number of input features. 217 for Trictrac's `to_tensor()`.
    pub obs_size: usize,
    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
    pub action_size: usize,
    /// Width of both hidden layers.
    pub hidden_size: usize,
}
impl Default for QNetConfig {
    /// Trictrac-sized defaults: 217 inputs, 514 actions, 256 hidden units.
    fn default() -> Self {
        Self { obs_size: 217, action_size: 514, hidden_size: 256 }
    }
}
// ── Network ───────────────────────────────────────────────────────────────────
/// Two-hidden-layer MLP that outputs one Q-value per action.
#[derive(Module, Debug)]
pub struct QNet<B: Backend> {
    // First trunk layer: obs_size → hidden_size.
    fc1: Linear<B>,
    // Second trunk layer: hidden_size → hidden_size.
    fc2: Linear<B>,
    // Output layer: hidden_size → action_size, raw Q-values.
    q_head: Linear<B>,
}
impl<B: Backend> QNet<B> {
/// Construct a fresh network with random weights.
pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
Self {
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
}
}
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
CompactRecorder::new()
.record(self.clone().into_record(), path.to_path_buf())
.map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
}
/// Load weights from `path` into a fresh model built from `config`.
pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
let record = CompactRecorder::new()
.load(path.to_path_buf(), device)
.map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
Ok(Self::new(config, device).load_record(record))
}
}
impl<B: Backend> QValueNet<B> for QNet<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
        // Two ReLU layers feeding the linear Q head (no output activation).
        let hidden = relu(self.fc2.forward(relu(self.fc1.forward(obs))));
        self.q_head.forward(hidden)
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    // CPU NdArray backend keeps these tests fast and GPU-free.
    type B = NdArray<f32>;
    fn device() -> <B as Backend>::Device { Default::default() }
    fn default_net() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }
    #[test]
    fn forward_output_shape() {
        let net = default_net();
        let obs = Tensor::zeros([4, 217], &device());
        let q = net.forward(obs);
        assert_eq!(q.dims(), [4, 514]);
    }
    #[test]
    fn forward_single_sample() {
        let net = default_net();
        let q = net.forward(Tensor::zeros([1, 217], &device()));
        assert_eq!(q.dims(), [1, 514]);
    }
    #[test]
    fn q_values_not_all_equal() {
        // With random weights the output should not be a constant vector.
        let net = default_net();
        let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
            .into_data().to_vec().unwrap();
        let first = q[0];
        assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
    }
    #[test]
    fn custom_config_shapes() {
        let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
        let net = QNet::<B>::new(&cfg, &device());
        let q = net.forward(Tensor::zeros([3, 10], &device()));
        assert_eq!(q.dims(), [3, 20]);
    }
    #[test]
    fn save_load_preserves_weights() {
        // A disk round-trip must reproduce the same outputs within tolerance.
        let net = default_net();
        let obs = Tensor::<B, 2>::ones([2, 217], &device());
        let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();
        let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk");
        net.save(&path).expect("save failed");
        let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
        let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();
        for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
        }
        let _ = std::fs::remove_file(path);
    }
}

View file

@ -0,0 +1,253 @@
//! Residual-block policy-value network.
//!
//! ```text
//! Input [B, obs_size]
//! → Linear(obs → hidden) → ReLU (input projection)
//! → ResBlock × 4 (residual trunk)
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
//!
//! ResBlock:
//! x → Linear → ReLU → Linear → (+x) → ReLU
//! ```
//!
//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better
//! suited for long training runs where board-pattern recognition matters.
use burn::{
module::Module,
nn::{Linear, LinearConfig},
record::{CompactRecorder, Recorder},
tensor::{
activation::{relu, tanh},
backend::Backend,
Tensor,
},
};
use std::path::Path;
use super::PolicyValueNet;
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for [`ResNet`].
#[derive(Debug, Clone)]
pub struct ResNetConfig {
    /// Number of input features. 217 for Trictrac's `to_tensor()`.
    pub obs_size: usize,
    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
    pub action_size: usize,
    /// Width of all hidden layers (input projection + residual blocks).
    pub hidden_size: usize,
}
impl Default for ResNetConfig {
    /// Trictrac-sized defaults; 512 hidden units (wider than the MLP default).
    fn default() -> Self {
        Self {
            obs_size: 217,
            action_size: 514,
            hidden_size: 512,
        }
    }
}
// ── Residual block ────────────────────────────────────────────────────────────
/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`.
///
/// Both linear layers keep the hidden width, so the skip connection can be
/// added without a projection.
#[derive(Module, Debug)]
struct ResBlock<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
}
impl<B: Backend> ResBlock<B> {
    fn new(hidden: usize, device: &B::Device) -> Self {
        ResBlock {
            fc1: LinearConfig::new(hidden, hidden).init(device),
            fc2: LinearConfig::new(hidden, hidden).init(device),
        }
    }
    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        // Keep the input for the skip connection, then apply the two layers.
        let skip = x.clone();
        let inner = relu(self.fc1.forward(x));
        relu(self.fc2.forward(inner) + skip)
    }
}
// ── Network ───────────────────────────────────────────────────────────────────
/// Four-residual-block policy-value network.
///
/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and
/// when representing complex positional patterns is important.
#[derive(Module, Debug)]
pub struct ResNet<B: Backend> {
    // Input projection: obs_size → hidden_size.
    input: Linear<B>,
    // Residual trunk, applied in order block0 → block3.
    block0: ResBlock<B>,
    block1: ResBlock<B>,
    block2: ResBlock<B>,
    block3: ResBlock<B>,
    // Policy head: hidden_size → action_size, raw logits.
    policy_head: Linear<B>,
    // Value head: hidden_size → 1, tanh-squashed in `forward`.
    value_head: Linear<B>,
}
impl<B: Backend> ResNet<B> {
/// Construct a fresh network with random weights.
pub fn new(config: &ResNetConfig, device: &B::Device) -> Self {
let h = config.hidden_size;
Self {
input: LinearConfig::new(config.obs_size, h).init(device),
block0: ResBlock::new(h, device),
block1: ResBlock::new(h, device),
block2: ResBlock::new(h, device),
block3: ResBlock::new(h, device),
policy_head: LinearConfig::new(h, config.action_size).init(device),
value_head: LinearConfig::new(h, 1).init(device),
}
}
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
CompactRecorder::new()
.record(self.clone().into_record(), path.to_path_buf())
.map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}"))
}
/// Load weights from `path` into a fresh model built from `config`.
pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
let record = CompactRecorder::new()
.load(path.to_path_buf(), device)
.map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?;
Ok(Self::new(config, device).load_record(record))
}
}
impl<B: Backend> PolicyValueNet<B> for ResNet<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
        // Input projection, then the four-block residual trunk in order.
        let mut x = relu(self.input.forward(obs));
        for block in [&self.block0, &self.block1, &self.block2, &self.block3] {
            x = block.forward(x);
        }
        let value = tanh(self.value_head.forward(x.clone()));
        let policy = self.policy_head.forward(x);
        (policy, value)
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    /// CPU backend used by every unit test.
    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device {
        Default::default()
    }

    /// Config with the real Trictrac obs/action dimensions (217 / 514) but a
    /// small hidden layer so tests stay fast.
    fn small_config() -> ResNetConfig {
        // Use a small hidden size so tests are fast.
        ResNetConfig {
            obs_size: 217,
            action_size: 514,
            hidden_size: 64,
        }
    }

    fn net() -> ResNet<B> {
        ResNet::new(&small_config(), &device())
    }

    // ── Shape tests ───────────────────────────────────────────────────────
    #[test]
    fn forward_output_shapes() {
        let obs = Tensor::zeros([4, 217], &device());
        let (policy, value) = net().forward(obs);
        assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
        assert_eq!(value.dims(), [4, 1], "value shape mismatch");
    }

    #[test]
    fn forward_single_sample() {
        let (policy, value) = net().forward(Tensor::zeros([1, 217], &device()));
        assert_eq!(policy.dims(), [1, 514]);
        assert_eq!(value.dims(), [1, 1]);
    }

    // ── Value bounds ──────────────────────────────────────────────────────
    #[test]
    fn value_in_tanh_range() {
        let obs = Tensor::<B, 2>::ones([8, 217], &device());
        let (_, value) = net().forward(obs);
        let data: Vec<f32> = value.into_data().to_vec().unwrap();
        for v in &data {
            // The value head is tanh-squashed, so outputs of a finite
            // pre-activation stay strictly inside (-1, 1).
            assert!(
                *v > -1.0 && *v < 1.0,
                "value {v} is outside open interval (-1, 1)"
            );
        }
    }

    // ── Residual connections ──────────────────────────────────────────────
    #[test]
    fn policy_logits_not_all_equal() {
        let (policy, _) = net().forward(Tensor::zeros([1, 217], &device()));
        let data: Vec<f32> = policy.into_data().to_vec().unwrap();
        let first = data[0];
        // Random init should produce distinct logits; all-equal output would
        // indicate a dead forward path (e.g. zeroed weights or a wiring bug).
        let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
        assert!(!all_same, "all policy logits are identical");
    }

    // ── Save / Load ───────────────────────────────────────────────────────
    #[test]
    fn save_load_preserves_weights() {
        let config = small_config();
        let model = net();
        let obs = Tensor::<B, 2>::ones([2, 217], &device());
        let (policy_before, value_before) = model.forward(obs.clone());
        // Include the process id in the file name so concurrent test runs
        // (e.g. CI shards on one machine) cannot clobber each other's file.
        let path = std::env::temp_dir()
            .join(format!("spiel_bot_test_resnet_{}.mpk", std::process::id()));
        model.save(&path).expect("save failed");
        let loaded = ResNet::<B>::load(&config, &path, &device()).expect("load failed");
        let (policy_after, value_after) = loaded.forward(obs);
        let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
        let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
        for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
        }
        let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
        let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
        for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
        }
        let _ = std::fs::remove_file(path);
    }

    // ── Integration: both architectures satisfy PolicyValueNet ────────────
    #[test]
    fn resnet_satisfies_trait() {
        // Generic helper proves ResNet is usable through the trait object
        // boundary the rest of the pipeline relies on.
        fn requires_net<B: Backend, N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) {
            let (p, v) = net.forward(obs);
            assert_eq!(p.dims()[1], 514);
            assert_eq!(v.dims()[1], 1);
        }
        requires_net(&net(), Tensor::zeros([2, 217], &device()));
    }
}

View file

@ -0,0 +1,391 @@
//! End-to-end integration tests for the AlphaZero training pipeline.
//!
//! Each test exercises the full chain:
//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`]
//!
//! Two environments are used:
//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves.
//! Used when we need many iterations without worrying about runtime.
//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that
//! the full pipeline compiles and runs correctly with 217-dim observations
//! and 514-dim action spaces.
//!
//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep
//! runtime minimal; correctness, not training quality, is what matters here.
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
};
use rand::{SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::MctsConfig,
network::{MlpConfig, MlpNet, PolicyValueNet},
};
// ── Backend aliases ────────────────────────────────────────────────────────
type Train = Autodiff<NdArray<f32>>;
type Infer = NdArray<f32>;
// ── Helpers ────────────────────────────────────────────────────────────────
/// Default device for the autodiff (training) backend.
fn train_device() -> <Train as burn::tensor::backend::Backend>::Device {
    Default::default()
}
/// Default device for the plain (inference) backend.
fn infer_device() -> <Infer as burn::tensor::backend::Backend>::Device {
    Default::default()
}
/// Build a tiny 64-unit MLP on the training backend for the given
/// observation/action sizes.
fn tiny_mlp(obs: usize, actions: usize) -> MlpNet<Train> {
    let config = MlpConfig {
        obs_size: obs,
        action_size: actions,
        hidden_size: 64,
    };
    MlpNet::new(&config, &train_device())
}
/// MCTS config with `n` simulations and no Dirichlet exploration noise
/// (alpha = eps = 0), so searches are cheap and reproducible in tests.
fn tiny_mcts(n: usize) -> MctsConfig {
    MctsConfig {
        n_simulations: n,
        c_puct: 1.5,
        dirichlet_alpha: 0.0,
        dirichlet_eps: 0.0,
        temperature: 1.0,
    }
}
/// Fixed-seed RNG so every test run is reproducible.
fn seeded() -> SmallRng {
    SmallRng::seed_from_u64(0)
}
// ── Countdown environment (fast, local, no external deps) ─────────────────
//
// Two players alternate subtracting 1 or 2 from a counter that starts at N.
// The player who brings the counter to 0 wins.
/// State of the Countdown toy game.
#[derive(Clone, Debug)]
struct CState {
    remaining: u8, // counter; the game is over when it reaches 0
    to_move: usize, // 0 = P1, 1 = P2; left pointing at the winner once terminal
}
/// Countdown environment; the tuple field is the starting counter value.
#[derive(Clone)]
struct CountdownEnv(u8); // starting value
impl GameEnv for CountdownEnv {
    type State = CState;

    fn new_game(&self) -> CState {
        CState { remaining: self.0, to_move: 0 }
    }

    /// Terminal once the counter hits zero; otherwise the side in `to_move`.
    fn current_player(&self, s: &CState) -> Player {
        match (s.remaining, s.to_move) {
            (0, _) => Player::Terminal,
            (_, 0) => Player::P1,
            _ => Player::P2,
        }
    }

    /// Action 0 subtracts one; action 1 (subtract two) needs at least 2 left.
    fn legal_actions(&self, s: &CState) -> Vec<usize> {
        match s.remaining {
            0 | 1 => vec![0],
            _ => vec![0, 1],
        }
    }

    fn apply(&self, s: &mut CState, action: usize) {
        let amount = (action as u8) + 1;
        match s.remaining.checked_sub(amount) {
            // Reaching (or overshooting) zero ends the game; `to_move` is
            // deliberately left pointing at the winning player.
            None | Some(0) => s.remaining = 0,
            Some(rest) => {
                s.remaining = rest;
                s.to_move = 1 - s.to_move;
            }
        }
    }

    fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {
        // Countdown has no chance nodes.
    }

    fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
        let progress = s.remaining as f32 / self.0 as f32;
        vec![progress, s.to_move as f32]
    }

    fn obs_size(&self) -> usize {
        2
    }

    fn action_space(&self) -> usize {
        2
    }

    /// +1 for the player who reached zero, -1 for the other; None mid-game.
    fn returns(&self, s: &CState) -> Option<[f32; 2]> {
        (s.remaining == 0).then(|| {
            let mut rewards = [-1.0f32; 2];
            rewards[s.to_move] = 1.0;
            rewards
        })
    }
}
// ── 1. Full loop on CountdownEnv ──────────────────────────────────────────
/// The canonical AlphaZero loop: self-play → replay → train, iterated.
/// Uses CountdownEnv so each game terminates in < 10 moves.
#[test]
fn countdown_full_loop_no_panic() {
    let env = CountdownEnv(8);
    let mut rng = seeded();
    let mcts = tiny_mcts(3);
    let mut model = tiny_mlp(env.obs_size(), env.action_space());
    let mut optimizer = AdamConfig::new().init();
    let mut replay = ReplayBuffer::new(1_000);
    for _iter in 0..5 {
        // Self-play: 3 games per iteration.
        for _ in 0..3 {
            // Snapshot the current training weights onto the inference backend.
            let infer = model.valid();
            let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
            let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
            assert!(!samples.is_empty());
            replay.extend(samples);
        }
        // Training: 4 gradient steps per iteration.
        if replay.len() >= 4 {
            for _ in 0..4 {
                let batch: Vec<TrainSample> = replay
                    .sample_batch(4, &mut rng)
                    .into_iter()
                    .cloned()
                    .collect();
                // `train_step` consumes and returns the model (Burn optimizer style).
                let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
                model = m;
                assert!(loss.is_finite(), "loss must be finite, got {loss}");
            }
        }
    }
    assert!(replay.len() > 0);
}
// ── 2. Replay buffer invariants ───────────────────────────────────────────
/// After several Countdown games, replay capacity is respected and batch
/// shapes are consistent.
#[test]
fn replay_buffer_capacity_and_shapes() {
    let env = CountdownEnv(6);
    let mut rng = seeded();
    let mcts = tiny_mcts(2);
    let model = tiny_mlp(env.obs_size(), env.action_space());
    let capacity = 50;
    let mut replay = ReplayBuffer::new(capacity);
    // 20 short games is more than enough to overflow the 50-sample buffer.
    for _ in 0..20 {
        let infer = model.valid();
        let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
        let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
        replay.extend(samples);
    }
    assert!(replay.len() <= capacity, "buffer exceeded capacity");
    assert!(replay.len() > 0);
    let batch = replay.sample_batch(8, &mut rng);
    assert_eq!(batch.len(), 8.min(replay.len()));
    for s in &batch {
        // Each sample must match the env's declared obs/action dimensions,
        // carry a normalized policy, and a value within [-1, 1].
        assert_eq!(s.obs.len(), env.obs_size());
        assert_eq!(s.policy.len(), env.action_space());
        let policy_sum: f32 = s.policy.iter().sum();
        assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}");
        assert!(s.value.abs() <= 1.0, "value {} out of range", s.value);
    }
}
// ── 3. TrictracEnv: sample shapes ─────────────────────────────────────────
/// Verify that one TrictracEnv episode produces samples with the correct
/// tensor dimensions: obs = 217, policy = 514.
#[test]
fn trictrac_sample_shapes() {
    let env = TrictracEnv;
    let mut rng = seeded();
    let mcts = tiny_mcts(2);
    let model = tiny_mlp(env.obs_size(), env.action_space());
    let infer = model.valid();
    let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
    let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
    assert!(!samples.is_empty(), "Trictrac episode produced no samples");
    for (i, s) in samples.iter().enumerate() {
        // Real-game dimensions: 217-dim observation, 514-dim action space.
        assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len());
        assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len());
        let policy_sum: f32 = s.policy.iter().sum();
        assert!(
            (policy_sum - 1.0).abs() < 1e-4,
            "sample {i}: policy sums to {policy_sum}"
        );
        // Game outcomes are win/loss/draw, so targets are exactly ±1 or 0.
        assert!(
            s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
            "sample {i}: unexpected value {}",
            s.value
        );
    }
}
// ── 4. TrictracEnv: training step after real self-play ────────────────────
/// Collect one Trictrac episode, then verify that a gradient step runs
/// without panic and produces a finite loss.
#[test]
fn trictrac_train_step_finite_loss() {
    let env = TrictracEnv;
    let mut rng = seeded();
    let mcts = tiny_mcts(2);
    let model = tiny_mlp(env.obs_size(), env.action_space());
    let mut optimizer = AdamConfig::new().init();
    let mut replay = ReplayBuffer::new(10_000);
    // Generate one episode.
    let infer = model.valid();
    let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
    let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
    assert!(!samples.is_empty());
    let n_samples = samples.len();
    replay.extend(samples);
    // Train on a batch from this episode (capped at 8 to stay fast).
    let batch_size = 8.min(n_samples);
    let batch: Vec<TrainSample> = replay
        .sample_batch(batch_size, &mut rng)
        .into_iter()
        .cloned()
        .collect();
    let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
    assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}");
    // An untrained net cannot already have zero loss on real targets.
    assert!(loss > 0.0, "loss should be positive");
}
// ── 5. Backend transfer: train → infer → same outputs ─────────────────────
/// Weights transferred from the training backend to the inference backend
/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes.
#[test]
fn valid_model_matches_train_model_outputs() {
    use burn::tensor::{Tensor, TensorData};
    let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
    let train_model = MlpNet::<Train>::new(&cfg, &train_device());
    // `valid()` strips the autodiff wrapper; the weights must carry over.
    let infer_model: MlpNet<Infer> = train_model.valid();
    // Build the same input on both backends.
    let obs_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
    let obs_train = Tensor::<Train, 2>::from_data(
        TensorData::new(obs_data.clone(), [1, 4]),
        &train_device(),
    );
    let obs_infer = Tensor::<Infer, 2>::from_data(
        TensorData::new(obs_data, [1, 4]),
        &infer_device(),
    );
    let (p_train, v_train) = train_model.forward(obs_train);
    let (p_infer, v_infer) = infer_model.forward(obs_infer);
    let p_train: Vec<f32> = p_train.into_data().to_vec().unwrap();
    let p_infer: Vec<f32> = p_infer.into_data().to_vec().unwrap();
    let v_train: Vec<f32> = v_train.into_data().to_vec().unwrap();
    let v_infer: Vec<f32> = v_infer.into_data().to_vec().unwrap();
    for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-5,
            "policy[{i}] differs after valid(): train={a}, infer={b}"
        );
    }
    assert!(
        (v_train[0] - v_infer[0]).abs() < 1e-5,
        "value differs after valid(): train={}, infer={}",
        v_train[0], v_infer[0]
    );
}
// ── 6. Loss converges on a fixed batch ────────────────────────────────────
/// With repeated gradient steps on the same Countdown batch, the loss must
/// decrease monotonically (or at least end lower than it started).
#[test]
fn loss_decreases_on_fixed_batch() {
    let env = CountdownEnv(6);
    let mut rng = seeded();
    let mcts = tiny_mcts(3);
    let model = tiny_mlp(env.obs_size(), env.action_space());
    let mut optimizer = AdamConfig::new().init();
    // Collect a fixed batch from one episode (temperature 0 → greedy play).
    let infer = model.valid();
    let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
    let samples: Vec<TrainSample> = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng);
    assert!(!samples.is_empty());
    let batch: Vec<TrainSample> = {
        let mut replay = ReplayBuffer::new(1000);
        replay.extend(samples);
        replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect()
    };
    // Overfit on the same fixed batch for 20 steps.
    // NOTE(review): a *fresh* model is trained here, not the one that
    // generated the episode — presumably intentional since the test only
    // needs some model to fit the batch; confirm.
    let mut model = tiny_mlp(env.obs_size(), env.action_space());
    let mut first_loss = f32::NAN;
    let mut last_loss = f32::NAN;
    for step in 0..20 {
        let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2);
        model = m;
        assert!(loss.is_finite(), "loss is not finite at step {step}");
        if step == 0 { first_loss = loss; }
        last_loss = loss;
    }
    // Only end-lower-than-start is asserted, not strict monotonicity.
    assert!(
        last_loss < first_loss,
        "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}"
    );
}
// ── 7. Trictrac: multi-iteration loop ─────────────────────────────────────
/// Two full self-play + train iterations on TrictracEnv.
/// Verifies the entire pipeline runs without panic end-to-end.
#[test]
fn trictrac_two_iteration_loop() {
    let env = TrictracEnv;
    let mut rng = seeded();
    let mcts = tiny_mcts(2);
    let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
    let mut model = MlpNet::<Train>::new(&cfg, &train_device());
    let mut optimizer = AdamConfig::new().init();
    let mut replay = ReplayBuffer::new(20_000);
    for iter in 0..2 {
        // Self-play: 1 game per iteration. The temperature schedule explores
        // for the first 30 plies then switches to greedy play.
        let infer: MlpNet<Infer> = model.valid();
        let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
        let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
        assert!(!samples.is_empty(), "iter {iter}: episode was empty");
        replay.extend(samples);
        // Training: 3 gradient steps.
        let batch_size = 16.min(replay.len());
        for _ in 0..3 {
            let batch: Vec<TrainSample> = replay
                .sample_batch(batch_size, &mut rng)
                .into_iter()
                .cloned()
                .collect();
            let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
            model = m;
            assert!(loss.is_finite(), "iter {iter}: loss={loss}");
        }
    }
}

View file

@ -25,5 +25,9 @@ rand = "0.9"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
transpose = "0.2.2" transpose = "0.2.2"
[[bin]]
name = "random_game"
path = "src/bin/random_game.rs"
[build-dependencies] [build-dependencies]
cxx-build = "1.0" cxx-build = "1.0"

View file

@ -0,0 +1,262 @@
//! Run one or many games of trictrac between two random players.
//! In single-game mode, prints play-by-play like OpenSpiel's `example.cc`.
//! In multi-game mode, runs silently and reports throughput at the end.
//!
//! Usage:
//! cargo run --bin random_game -- [--seed <u64>] [--games <usize>] [--max-steps <usize>] [--verbose]
use std::borrow::Cow;
use std::env;
use std::time::Instant;
use trictrac_store::{
training_common::sample_valid_action,
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
};
// ── CLI args ──────────────────────────────────────────────────────────────────
/// Parsed command-line options for the random-game runner.
struct Args {
    seed: Option<u64>, // dice-roller seed; None = roller default (presumably nondeterministic — confirm in DiceRoller::new)
    games: usize, // number of games; 1 = verbose single-game mode
    max_steps: usize, // per-game step cap before the game counts as truncated
    verbose: bool, // print the full state after every move
}
/// Parse CLI flags (see the module doc for usage).
///
/// Defaults: 1 game, 10 000 max steps, no seed, not verbose.
/// Unknown flags and unparsable values emit a warning on stderr instead of
/// being silently ignored, so typos like `--game` are visible to the user.
fn parse_args() -> Args {
    let args: Vec<String> = env::args().collect();
    let mut seed = None;
    let mut games = 1;
    let mut max_steps = 10_000;
    let mut verbose = false;
    let mut i = 1;
    while i < args.len() {
        match args[i].as_str() {
            "--seed" => {
                i += 1;
                seed = args.get(i).and_then(|s| s.parse().ok());
                if seed.is_none() {
                    eprintln!("warning: --seed expects an unsigned integer; ignoring");
                }
            }
            "--games" => {
                i += 1;
                match args.get(i).and_then(|s| s.parse().ok()) {
                    Some(v) => games = v,
                    None => eprintln!("warning: --games expects an unsigned integer; keeping {games}"),
                }
            }
            "--max-steps" => {
                i += 1;
                match args.get(i).and_then(|s| s.parse().ok()) {
                    Some(v) => max_steps = v,
                    None => eprintln!("warning: --max-steps expects an unsigned integer; keeping {max_steps}"),
                }
            }
            "--verbose" => verbose = true,
            other => eprintln!("warning: unrecognized argument `{other}` ignored"),
        }
        i += 1;
    }
    Args {
        seed,
        games,
        max_steps,
        verbose,
    }
}
// ── Helpers ───────────────────────────────────────────────────────────────────
/// Map a player id to its display name: 1 is White, anything else is Black.
fn player_label(id: u64) -> &'static str {
    match id {
        1 => "White",
        _ => "Black",
    }
}
/// Apply a `Roll` + `RollResult` in one logical step, returning the dice.
/// This collapses the two-step dice phase into a single "chance node" action,
/// matching how the OpenSpiel layer exposes it.
///
/// Errors from the engine are stringified so the caller can print them
/// alongside the current state.
fn apply_dice_roll(state: &mut GameState, roller: &mut DiceRoller) -> Result<Dice, String> {
    // RollDice → RollWaiting
    state
        .consume(&GameEvent::Roll { player_id: state.active_player_id })
        .map_err(|e| format!("Roll event failed: {e}"))?;
    // RollWaiting → Move / HoldOrGoChoice (or Stage::Ended if 13th hole)
    let dice = roller.roll();
    state
        .consume(&GameEvent::RollResult { player_id: state.active_player_id, dice })
        .map_err(|e| format!("RollResult event failed: {e}"))?;
    Ok(dice)
}
/// Sample a random action and apply it to `state`, handling the Black-mirror
/// transform exactly as `cxxengine.rs::apply_action` does:
///
/// 1. For Black, build a mirrored view of the state so that `sample_valid_action`
///    and `to_event` always reason from White's perspective.
/// 2. Mirror the resulting event back to the original coordinate frame before
///    calling `state.consume`.
///
/// Returns the chosen action (in the view's coordinate frame) for display.
fn apply_player_action(state: &mut GameState) -> Result<(), String> {
    let needs_mirror = state.active_player_id == 2;
    // Build a White-perspective view: borrowed for White, owned mirror for Black.
    // Cow avoids cloning the state in the common (White) case.
    let view: Cow<GameState> = if needs_mirror {
        Cow::Owned(state.mirror())
    } else {
        Cow::Borrowed(state)
    };
    let action = sample_valid_action(&view)
        .ok_or_else(|| format!("no valid action in stage {:?}", state.turn_stage))?;
    let event = action
        .to_event(&view)
        .ok_or_else(|| format!("could not convert {action:?} to event"))?;
    // Translate the event from the view's frame back to the game's frame.
    let event = if needs_mirror { event.get_mirror(false) } else { event };
    state
        .consume(&event)
        .map_err(|e| format!("consume({action:?}): {e}"))?;
    Ok(())
}
// ── Single game ────────────────────────────────────────────────────────────────
/// Run one full game, optionally printing play-by-play.
/// Returns `(steps, truncated)`: `truncated` is true when the step cap was
/// reached or the engine rejected an event (the error goes to stderr).
fn run_game(roller: &mut DiceRoller, max_steps: usize, quiet: bool, verbose: bool) -> (usize, bool) {
    let mut state = GameState::new_with_players("White", "Black");
    let mut step = 0usize;
    if !quiet {
        println!("{state}");
    }
    while state.stage != Stage::Ended {
        step += 1;
        if step > max_steps {
            // Step cap hit: report the number of steps actually played.
            return (step - 1, true);
        }
        match state.turn_stage {
            TurnStage::RollDice => {
                let player = state.active_player_id;
                match apply_dice_roll(&mut state, roller) {
                    Ok(dice) => {
                        if !quiet {
                            println!(
                                "[step {step:4}] {} rolls: {} & {}",
                                player_label(player),
                                dice.values.0,
                                dice.values.1
                            );
                        }
                    }
                    Err(e) => {
                        eprintln!("Error during dice roll: {e}");
                        eprintln!("State:\n{state}");
                        return (step, true);
                    }
                }
            }
            stage => {
                // Any non-roll stage is a player decision node.
                let player = state.active_player_id;
                match apply_player_action(&mut state) {
                    Ok(()) => {
                        if !quiet {
                            println!(
                                "[step {step:4}] {} ({stage:?})",
                                player_label(player)
                            );
                            if verbose {
                                println!("{state}");
                            }
                        }
                    }
                    Err(e) => {
                        eprintln!("Error: {e}");
                        eprintln!("State:\n{state}");
                        return (step, true);
                    }
                }
            }
        }
    }
    if !quiet {
        println!("\n=== Game over after {step} steps ===\n");
        println!("{state}");
        let white = state.players.get(&1);
        let black = state.players.get(&2);
        match (white, black) {
            (Some(w), Some(b)) => {
                println!("White — holes: {:2}, points: {:2}", w.holes, w.points);
                println!("Black — holes: {:2}, points: {:2}", b.holes, b.points);
                println!();
                // Delegate scoring to the engine's `total_score`
                // (holes × 12 + points) so this binary cannot drift from the
                // store crate's definition of the final score.
                let white_score = state.total_score(1);
                let black_score = state.total_score(2);
                if white_score > black_score {
                    println!("Winner: White (+{})", white_score - black_score);
                } else if black_score > white_score {
                    println!("Winner: Black (+{})", black_score - white_score);
                } else {
                    println!("Draw");
                }
            }
            _ => eprintln!("Could not read final player scores."),
        }
    }
    (step, false)
}
// ── Main ──────────────────────────────────────────────────────────────────────
/// Entry point: single-game mode prints play-by-play; multi-game mode runs
/// quietly and reports throughput statistics at the end.
fn main() {
    let args = parse_args();
    let mut roller = DiceRoller::new(args.seed);
    if args.games == 1 {
        // Single-game mode: full play-by-play output.
        println!("=== Trictrac — random game ===");
        if let Some(s) = args.seed {
            println!("seed: {s}");
        }
        println!();
        run_game(&mut roller, args.max_steps, false, args.verbose);
    } else {
        // Benchmark mode: games run quietly unless --verbose is set.
        println!("=== Trictrac — {} games ===", args.games);
        if let Some(s) = args.seed {
            println!("seed: {s}");
        }
        println!();
        let mut total_steps = 0u64;
        let mut truncated = 0usize;
        let t0 = Instant::now();
        for _ in 0..args.games {
            let (steps, trunc) = run_game(&mut roller, args.max_steps, !args.verbose, args.verbose);
            total_steps += steps as u64;
            if trunc {
                truncated += 1;
            }
        }
        let elapsed = t0.elapsed();
        let secs = elapsed.as_secs_f64();
        println!("Games      : {}", args.games);
        println!("Truncated  : {truncated}");
        println!("Total steps: {total_steps}");
        println!("Avg steps  : {:.1}", total_steps as f64 / args.games as f64);
        println!("Elapsed    : {:.3} s", secs);
        println!("Throughput : {:.1} games/s", args.games as f64 / secs);
        println!("             {:.0} steps/s", total_steps as f64 / secs);
    }
}

View file

@ -598,12 +598,40 @@ impl Board {
core::array::from_fn(|i| i + min) core::array::from_fn(|i| i + min)
} }
/// Cumulative white-checker counts: `result[i]` is the number of white
/// checkers on fields `1..=i`, with `result[0] = 0`.
pub fn white_checker_cumulative(&self) -> [u8; 25] {
    let mut cum = [0u8; 25];
    let mut running = 0u8;
    for (idx, &signed_count) in self.positions.iter().enumerate() {
        // Positive entries are White checkers; Black (negative) squares
        // contribute nothing to the running total.
        if signed_count > 0 {
            running += signed_count as u8;
        }
        cum[idx + 1] = running;
    }
    cum
}
pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> { pub fn move_checker(&mut self, color: &Color, cmove: CheckerMove) -> Result<(), Error> {
self.remove_checker(color, cmove.from)?; self.remove_checker(color, cmove.from)?;
self.add_checker(color, cmove.to)?; self.add_checker(color, cmove.to)?;
Ok(()) Ok(())
} }
/// Reverse a previously applied `move_checker`. No validation: assumes the move was valid.
///
/// A `from`/`to` of 0 denotes off-board (entry/exit), so only non-zero
/// fields touch `positions`. Counts are signed: positive = White,
/// negative = Black.
pub fn unmove_checker(&mut self, color: &Color, cmove: CheckerMove) {
    // +1 restores a White checker, -1 a Black one.
    let unit = match color {
        Color::White => 1,
        Color::Black => -1,
    };
    // Put the checker back on its origin square…
    if cmove.from != 0 {
        self.positions[cmove.from - 1] += unit;
    }
    // …and remove it from its destination square.
    if cmove.to != 0 {
        self.positions[cmove.to - 1] -= unit;
    }
}
pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> { pub fn remove_checker(&mut self, color: &Color, field: Field) -> Result<(), Error> {
if field == 0 { if field == 0 {
return Ok(()); return Ok(());

View file

@ -83,8 +83,8 @@ pub mod ffi {
/// Both players' scores. /// Both players' scores.
fn get_players_scores(self: &TricTracEngine) -> PlayerScores; fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
/// 36-element state vector (i8). Mirrored for player_idx == 1. /// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1.
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>; fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<f32>;
/// Human-readable state description for `player_idx`. /// Human-readable state description for `player_idx`.
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String; fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
@ -153,8 +153,7 @@ impl TricTracEngine {
.map(|v| v.into_iter().map(|i| i as u64).collect()) .map(|v| v.into_iter().map(|i| i as u64).collect())
} else { } else {
let mirror = self.game_state.mirror(); let mirror = self.game_state.mirror();
get_valid_action_indices(&mirror) get_valid_action_indices(&mirror).map(|v| v.into_iter().map(|i| i as u64).collect())
.map(|v| v.into_iter().map(|i| i as u64).collect())
} }
})) }))
} }
@ -180,11 +179,11 @@ impl TricTracEngine {
.unwrap_or(-1) .unwrap_or(-1)
} }
fn get_tensor(&self, player_idx: u64) -> Vec<i8> { fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
if player_idx == 0 { if player_idx == 0 {
self.game_state.to_vec() self.game_state.to_tensor()
} else { } else {
self.game_state.mirror().to_vec() self.game_state.mirror().to_tensor()
} }
} }
@ -243,8 +242,9 @@ impl TricTracEngine {
self.game_state self.game_state
), ),
None => anyhow::bail!( None => anyhow::bail!(
"apply_action: could not build event from action index {}", "apply_action: could not build event from action index {} in state {}",
action_idx action_idx,
self.game_state
), ),
} }
})) }))

View file

@ -156,13 +156,6 @@ impl GameState {
if let Some(p1) = self.players.get(&1) { if let Some(p1) = self.players.get(&1) {
mirrored_players.insert(2, p1.mirror()); mirrored_players.insert(2, p1.mirror());
} }
let mirrored_history = self
.history
.clone()
.iter()
.map(|evt| evt.get_mirror(false))
.collect();
let (move1, move2) = self.dice_moves; let (move1, move2) = self.dice_moves;
GameState { GameState {
stage: self.stage, stage: self.stage,
@ -171,7 +164,7 @@ impl GameState {
active_player_id: mirrored_active_player, active_player_id: mirrored_active_player,
// active_player_id: self.active_player_id, // active_player_id: self.active_player_id,
players: mirrored_players, players: mirrored_players,
history: mirrored_history, history: Vec::new(),
dice: self.dice, dice: self.dice,
dice_points: self.dice_points, dice_points: self.dice_points,
dice_moves: (move1.mirror(), move2.mirror()), dice_moves: (move1.mirror(), move2.mirror()),
@ -207,6 +200,110 @@ impl GameState {
self.to_vec().iter().map(|&x| x as f32).collect() self.to_vec().iter().map(|&x| x as f32).collect()
} }
/// Get state as a tensor for neural network training (Option B, TD-Gammon style).
/// Returns 217 f32 values, all normalized to [0, 1].
///
/// Must be called from the active player's perspective: callers should mirror
/// the GameState for Black before calling so that "own" always means White.
///
/// Layout:
///   [0..95]    own (White) checkers: 4 values per field × 24 fields
///   [96..191]  opp (Black) checkers: 4 values per field × 24 fields
///   [192..193] dice values / 6
///   [194]      active player color (0=White, 1=Black)
///   [195]      turn_stage / 5
///   [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille
///   [200..203] Black player: same
///   [204..207] own quarter filled (quarters 1-4)
///   [208..211] opp quarter filled (quarters 1-4)
///   [212]      own side ready to exit: no own checker left in fields 1-18
///   [213]      opp side ready to exit: no opp checker left in fields 7-24
///   [214]      own coin de repos taken (field 12 has ≥2 own checkers)
///   [215]      opp coin de repos taken (field 13 has ≥2 opp checkers)
///   [216]      own dice_roll_count / 3, clamped to 1
pub fn to_tensor(&self) -> Vec<f32> {
    let mut t = Vec::with_capacity(217);
    let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
    // [0..95] own (White) checkers, TD-Gammon encoding.
    // Each field contributes 4 values:
    //   (count==1), (count==2), (count==3), (count-3)/12  ← all in [0,1]
    // The overflow term is divided by 12 because the maximum excess is
    // 15 (all checkers) − 3 = 12.
    for &c in &pos {
        let own = c.max(0) as u8;
        t.push((own == 1) as u8 as f32);
        t.push((own == 2) as u8 as f32);
        t.push((own == 3) as u8 as f32);
        t.push(own.saturating_sub(3) as f32 / 12.0);
    }
    // [96..191] opp (Black) checkers, same encoding (counts negated first).
    for &c in &pos {
        let opp = (-c).max(0) as u8;
        t.push((opp == 1) as u8 as f32);
        t.push((opp == 2) as u8 as f32);
        t.push((opp == 3) as u8 as f32);
        t.push(opp.saturating_sub(3) as f32 / 12.0);
    }
    // [192..193] dice
    t.push(self.dice.values.0 as f32 / 6.0);
    t.push(self.dice.values.1 as f32 / 6.0);
    // [194] active player color (defaults to 0.0 when nobody is on turn)
    t.push(
        self.who_plays()
            .map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 })
            .unwrap_or(0.0),
    );
    // [195] turn stage
    t.push(u8::from(self.turn_stage) as f32 / 5.0);
    // [196..199] White player stats (0.0 when the player is missing)
    let wp = self.get_white_player();
    t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0));
    t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0));
    t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
    t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
    // [200..203] Black player stats
    let bp = self.get_black_player();
    t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0));
    t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0));
    t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
    t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
    // [204..207] own (White) quarter fill status (quarters start at fields 1/7/13/19)
    for &start in &[1usize, 7, 13, 19] {
        t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32);
    }
    // [208..211] opp (Black) quarter fill status
    for &start in &[1usize, 7, 13, 19] {
        t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32);
    }
    // [212] can_exit_own: no own checker in fields 1-18
    t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32);
    // [213] can_exit_opp: no opp checker in fields 7-24
    t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32);
    // [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers)
    t.push((pos[11] >= 2) as u8 as f32);
    // [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers)
    t.push((pos[12] <= -2) as u8 as f32);
    // [216] own dice_roll_count / 3, clamped to 1
    t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0));
    debug_assert_eq!(t.len(), 217, "to_tensor length mismatch");
    t
}
/// Get state as a vector (to be used for bot training input) : /// Get state as a vector (to be used for bot training input) :
/// length = 36 /// length = 36
/// i8 for board positions with negative values for blacks /// i8 for board positions with negative values for blacks
@ -914,6 +1011,16 @@ impl GameState {
self.mark_points(player_id, points) self.mark_points(player_id, points)
} }
/// Total accumulated score for a player: `holes × 12 + points`.
///
/// Returns `0` if `player_id` is not found (e.g. before `init_player`).
pub fn total_score(&self, player_id: PlayerId) -> i32 {
    self.players
        .get(&player_id)
        .map_or(0, |p| p.holes as i32 * 12 + p.points as i32)
}
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
// Update player points and holes // Update player points and holes
let mut new_hole = false; let mut new_hole = false;

View file

@ -220,7 +220,7 @@ impl MoveRules {
// Si possible, les deux dés doivent être joués // Si possible, les deux dés doivent être joués
if moves.0.get_from() == 0 || moves.1.get_from() == 0 { if moves.0.get_from() == 0 || moves.1.get_from() == 0 {
let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]); let mut possible_moves_sequences = self.get_possible_moves_sequences(true, vec![]);
possible_moves_sequences.retain(|moves| self.check_exit_rules(moves).is_ok()); possible_moves_sequences.retain(|moves| self.check_exit_rules(moves, None).is_ok());
// possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok()); // possible_moves_sequences.retain(|moves| self.check_corner_rules(moves).is_ok());
if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() { if !possible_moves_sequences.contains(moves) && !possible_moves_sequences.is_empty() {
if *moves == (EMPTY_MOVE, EMPTY_MOVE) { if *moves == (EMPTY_MOVE, EMPTY_MOVE) {
@ -238,7 +238,7 @@ impl MoveRules {
// check exit rules // check exit rules
// if !ignored_rules.contains(&TricTracRule::Exit) { // if !ignored_rules.contains(&TricTracRule::Exit) {
self.check_exit_rules(moves)?; self.check_exit_rules(moves, None)?;
// } // }
// --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ---- // --- interdit de jouer dans un cadran que l'adversaire peut encore remplir ----
@ -321,7 +321,11 @@ impl MoveRules {
.is_empty() .is_empty()
} }
fn check_exit_rules(&self, moves: &(CheckerMove, CheckerMove)) -> Result<(), MoveError> { fn check_exit_rules(
&self,
moves: &(CheckerMove, CheckerMove),
exit_seqs: Option<&[(CheckerMove, CheckerMove)]>,
) -> Result<(), MoveError> {
if !moves.0.is_exit() && !moves.1.is_exit() { if !moves.0.is_exit() && !moves.1.is_exit() {
return Ok(()); return Ok(());
} }
@ -331,16 +335,22 @@ impl MoveRules {
} }
// toutes les sorties directes sont autorisées, ainsi que les nombres défaillants // toutes les sorties directes sont autorisées, ainsi que les nombres défaillants
let ignored_rules = vec![TricTracRule::Exit]; let owned;
let possible_moves_sequences_without_excedent = let seqs = match exit_seqs {
self.get_possible_moves_sequences(false, ignored_rules); Some(s) => s,
if possible_moves_sequences_without_excedent.contains(moves) { None => {
owned = self
.get_possible_moves_sequences(false, vec![TricTracRule::Exit]);
&owned
}
};
if seqs.contains(moves) {
return Ok(()); return Ok(());
} }
// À ce stade au moins un des déplacements concerne un nombre en excédant // À ce stade au moins un des déplacements concerne un nombre en excédant
// - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on // - si d'autres séquences de mouvements sans nombre en excédant sont possibles, on
// refuse cette séquence // refuse cette séquence
if !possible_moves_sequences_without_excedent.is_empty() { if !seqs.is_empty() {
return Err(MoveError::ExitByEffectPossible); return Err(MoveError::ExitByEffectPossible);
} }
@ -361,17 +371,24 @@ impl MoveRules {
let _ = board_to_check.move_checker(&Color::White, moves.0); let _ = board_to_check.move_checker(&Color::White, moves.0);
let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check); let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check);
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves); // dice normal order
if (is_move1_exedant && moves.0.get_from() != farthest_on_move1) let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, true);
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2) let is_not_farthest1 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
{ || (is_move2_exedant && moves.1.get_from() != farthest_on_move2);
// dice reversed order
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves, false);
let is_not_farthest2 = (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2);
if is_not_farthest1 && is_not_farthest2 {
return Err(MoveError::ExitNotFarthest); return Err(MoveError::ExitNotFarthest);
} }
Ok(()) Ok(())
} }
fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) { fn move_excedants(&self, moves: &(CheckerMove, CheckerMove), dice_order: bool) -> (bool, bool) {
let move1to = if moves.0.get_to() == 0 { let move1to = if moves.0.get_to() == 0 {
25 25
} else { } else {
@ -386,20 +403,16 @@ impl MoveRules {
}; };
let dist2 = move2to - moves.1.get_from(); let dist2 = move2to - moves.1.get_from();
let dist_min = cmp::min(dist1, dist2); let (dice1, dice2) = if dice_order {
let dist_max = cmp::max(dist1, dist2); self.dice.values
let dice_min = cmp::min(self.dice.values.0, self.dice.values.1) as usize;
let dice_max = cmp::max(self.dice.values.0, self.dice.values.1) as usize;
let min_excedant = dist_min != 0 && dist_min < dice_min;
let max_excedant = dist_max != 0 && dist_max < dice_max;
if dist_min == dist1 {
(min_excedant, max_excedant)
} else { } else {
(max_excedant, min_excedant) (self.dice.values.1, self.dice.values.0)
} };
(
dist1 != 0 && dist1 < dice1 as usize,
dist2 != 0 && dist2 < dice2 as usize,
)
} }
fn get_board_exit_farthest(board: &Board) -> Field { fn get_board_exit_farthest(board: &Board) -> Field {
@ -438,12 +451,18 @@ impl MoveRules {
} else { } else {
(dice2, dice1) (dice2, dice1)
}; };
let filling_seqs = if !ignored_rules.contains(&TricTracRule::MustFillQuarter) {
Some(self.get_quarter_filling_moves_sequences())
} else {
None
};
let mut moves_seqs = self.get_possible_moves_sequences_by_dices( let mut moves_seqs = self.get_possible_moves_sequences_by_dices(
dice_max, dice_max,
dice_min, dice_min,
with_excedents, with_excedents,
false, false,
ignored_rules.clone(), &ignored_rules,
filling_seqs.as_deref(),
); );
// if we got valid sequences with the highest die, we don't accept sequences using only the // if we got valid sequences with the highest die, we don't accept sequences using only the
// lowest die // lowest die
@ -453,7 +472,8 @@ impl MoveRules {
dice_max, dice_max,
with_excedents, with_excedents,
ignore_empty, ignore_empty,
ignored_rules, &ignored_rules,
filling_seqs.as_deref(),
); );
moves_seqs.append(&mut moves_seqs_order2); moves_seqs.append(&mut moves_seqs_order2);
let empty_removed = moves_seqs let empty_removed = moves_seqs
@ -524,14 +544,16 @@ impl MoveRules {
let mut moves_seqs = Vec::new(); let mut moves_seqs = Vec::new();
let color = &Color::White; let color = &Color::White;
let ignored_rules = vec![TricTracRule::Exit, TricTracRule::MustFillQuarter]; let ignored_rules = vec![TricTracRule::Exit, TricTracRule::MustFillQuarter];
for moves in self.get_possible_moves_sequences(true, ignored_rules) {
let mut board = self.board.clone(); let mut board = self.board.clone();
for moves in self.get_possible_moves_sequences(true, ignored_rules) {
board.move_checker(color, moves.0).unwrap(); board.move_checker(color, moves.0).unwrap();
board.move_checker(color, moves.1).unwrap(); board.move_checker(color, moves.1).unwrap();
// println!("get_quarter_filling_moves_sequences board : {:?}", board); // println!("get_quarter_filling_moves_sequences board : {:?}", board);
if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) { if board.any_quarter_filled(*color) && !moves_seqs.contains(&moves) {
moves_seqs.push(moves); moves_seqs.push(moves);
} }
board.unmove_checker(color, moves.1);
board.unmove_checker(color, moves.0);
} }
moves_seqs moves_seqs
} }
@ -542,18 +564,27 @@ impl MoveRules {
dice2: u8, dice2: u8,
with_excedents: bool, with_excedents: bool,
ignore_empty: bool, ignore_empty: bool,
ignored_rules: Vec<TricTracRule>, ignored_rules: &[TricTracRule],
filling_seqs: Option<&[(CheckerMove, CheckerMove)]>,
) -> Vec<(CheckerMove, CheckerMove)> { ) -> Vec<(CheckerMove, CheckerMove)> {
let mut moves_seqs = Vec::new(); let mut moves_seqs = Vec::new();
let color = &Color::White; let color = &Color::White;
let forbid_exits = self.has_checkers_outside_last_quarter(); let forbid_exits = self.has_checkers_outside_last_quarter();
// Precompute non-excedant sequences once so check_exit_rules need not repeat
// the full move generation for every exit-move candidate.
// Only needed when Exit is not already ignored and exits are actually reachable.
let exit_seqs = if !ignored_rules.contains(&TricTracRule::Exit) && !forbid_exits {
Some(self.get_possible_moves_sequences(false, vec![TricTracRule::Exit]))
} else {
None
};
let mut board = self.board.clone();
// println!("==== First"); // println!("==== First");
for first_move in for first_move in
self.board self.board
.get_possible_moves(*color, dice1, with_excedents, false, forbid_exits) .get_possible_moves(*color, dice1, with_excedents, false, forbid_exits)
{ {
let mut board2 = self.board.clone(); if board.move_checker(color, first_move).is_err() {
if board2.move_checker(color, first_move).is_err() {
println!("err move"); println!("err move");
continue; continue;
} }
@ -563,7 +594,7 @@ impl MoveRules {
let mut has_second_dice_move = false; let mut has_second_dice_move = false;
// println!(" ==== Second"); // println!(" ==== Second");
for second_move in for second_move in
board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits) board.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits)
{ {
if self if self
.check_corner_rules(&(first_move, second_move)) .check_corner_rules(&(first_move, second_move))
@ -587,24 +618,10 @@ impl MoveRules {
&& self.can_take_corner_by_effect()) && self.can_take_corner_by_effect())
&& (ignored_rules.contains(&TricTracRule::Exit) && (ignored_rules.contains(&TricTracRule::Exit)
|| self || self
.check_exit_rules(&(first_move, second_move)) .check_exit_rules(&(first_move, second_move), exit_seqs.as_deref())
// .inspect_err(|e| {
// println!(
// " 2nd (exit rule): {:?} - {:?}, {:?}",
// e, first_move, second_move
// )
// })
.is_ok())
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter)
|| self
.check_must_fill_quarter_rule(&(first_move, second_move))
// .inspect_err(|e| {
// println!(
// " 2nd: {:?} - {:?}, {:?} for {:?}",
// e, first_move, second_move, self.board
// )
// })
.is_ok()) .is_ok())
&& filling_seqs
.map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, second_move)))
{ {
if second_move.get_to() == 0 if second_move.get_to() == 0
&& first_move.get_to() == 0 && first_move.get_to() == 0
@ -627,16 +644,14 @@ impl MoveRules {
&& !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE)) && !(self.is_move_by_puissance(&(first_move, EMPTY_MOVE))
&& self.can_take_corner_by_effect()) && self.can_take_corner_by_effect())
&& (ignored_rules.contains(&TricTracRule::Exit) && (ignored_rules.contains(&TricTracRule::Exit)
|| self.check_exit_rules(&(first_move, EMPTY_MOVE)).is_ok()) || self.check_exit_rules(&(first_move, EMPTY_MOVE), exit_seqs.as_deref()).is_ok())
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter) && filling_seqs
|| self .map_or(true, |seqs| seqs.is_empty() || seqs.contains(&(first_move, EMPTY_MOVE)))
.check_must_fill_quarter_rule(&(first_move, EMPTY_MOVE))
.is_ok())
{ {
// empty move // empty move
moves_seqs.push((first_move, EMPTY_MOVE)); moves_seqs.push((first_move, EMPTY_MOVE));
} }
//if board2.get_color_fields(*color).is_empty() { board.unmove_checker(color, first_move);
} }
moves_seqs moves_seqs
} }
@ -1495,6 +1510,7 @@ mod tests {
CheckerMove::new(23, 0).unwrap(), CheckerMove::new(23, 0).unwrap(),
CheckerMove::new(24, 0).unwrap(), CheckerMove::new(24, 0).unwrap(),
); );
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
assert_eq!( assert_eq!(
vec![moves], vec![moves],
state.get_possible_moves_sequences_by_dices( state.get_possible_moves_sequences_by_dices(
@ -1502,7 +1518,8 @@ mod tests {
state.dice.values.1, state.dice.values.1,
true, true,
false, false,
vec![] &[],
filling_seqs.as_deref(),
) )
); );
@ -1517,6 +1534,7 @@ mod tests {
CheckerMove::new(19, 23).unwrap(), CheckerMove::new(19, 23).unwrap(),
CheckerMove::new(22, 0).unwrap(), CheckerMove::new(22, 0).unwrap(),
)]; )];
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
assert_eq!( assert_eq!(
moves, moves,
state.get_possible_moves_sequences_by_dices( state.get_possible_moves_sequences_by_dices(
@ -1524,7 +1542,8 @@ mod tests {
state.dice.values.1, state.dice.values.1,
true, true,
false, false,
vec![] &[],
filling_seqs.as_deref(),
) )
); );
let moves = vec![( let moves = vec![(
@ -1538,7 +1557,8 @@ mod tests {
state.dice.values.0, state.dice.values.0,
true, true,
false, false,
vec![] &[],
filling_seqs.as_deref(),
) )
); );
@ -1554,6 +1574,7 @@ mod tests {
CheckerMove::new(19, 21).unwrap(), CheckerMove::new(19, 21).unwrap(),
CheckerMove::new(23, 0).unwrap(), CheckerMove::new(23, 0).unwrap(),
); );
let filling_seqs = Some(state.get_quarter_filling_moves_sequences());
assert_eq!( assert_eq!(
vec![moves], vec![moves],
state.get_possible_moves_sequences_by_dices( state.get_possible_moves_sequences_by_dices(
@ -1561,7 +1582,8 @@ mod tests {
state.dice.values.1, state.dice.values.1,
true, true,
false, false,
vec![] &[],
filling_seqs.as_deref(),
) )
); );
} }
@ -1580,13 +1602,26 @@ mod tests {
CheckerMove::new(19, 23).unwrap(), CheckerMove::new(19, 23).unwrap(),
CheckerMove::new(22, 0).unwrap(), CheckerMove::new(22, 0).unwrap(),
); );
assert!(state.check_exit_rules(&moves).is_ok()); assert!(state.check_exit_rules(&moves, None).is_ok());
let moves = ( let moves = (
CheckerMove::new(19, 24).unwrap(), CheckerMove::new(19, 24).unwrap(),
CheckerMove::new(22, 0).unwrap(), CheckerMove::new(22, 0).unwrap(),
); );
assert!(state.check_exit_rules(&moves).is_ok()); assert!(state.check_exit_rules(&moves, None).is_ok());
state.dice.values = (6, 4);
state.board.set_positions(
&crate::Color::White,
[
-4, -1, -2, -1, 0, 0, 0, -1, 0, 0, 0, 0, -5, -1, 0, 0, 0, 0, 2, 3, 2, 2, 5, 1,
],
);
let moves = (
CheckerMove::new(20, 24).unwrap(),
CheckerMove::new(23, 0).unwrap(),
);
assert!(state.check_exit_rules(&moves, None).is_ok());
} }
#[test] #[test]

View file

@ -113,11 +113,11 @@ impl TricTrac {
[self.get_score(1), self.get_score(2)] [self.get_score(1), self.get_score(2)]
} }
fn get_tensor(&self, player_idx: u64) -> Vec<i8> { fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
if player_idx == 0 { if player_idx == 0 {
self.game_state.to_vec() self.game_state.to_tensor()
} else { } else {
self.game_state.mirror().to_vec() self.game_state.mirror().to_tensor()
} }
} }

View file

@ -3,7 +3,6 @@
use std::cmp::{max, min}; use std::cmp::{max, min};
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use crate::board::Board;
use crate::{CheckerMove, Dice, GameEvent, GameState}; use crate::{CheckerMove, Dice, GameEvent, GameState};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -221,10 +220,14 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
// Ajoute aussi les mouvements possibles // Ajoute aussi les mouvements possibles
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice); let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]); let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
// rules.board is already White-perspective (mirrored if Black): compute cum once.
let cum = rules.board.white_checker_cumulative();
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(white_checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state, &move1,
&move2,
&game_state.dice,
&cum,
)?); )?);
} }
} }
@ -235,10 +238,14 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
// Empty move // Empty move
possible_moves.push((CheckerMove::default(), CheckerMove::default())); possible_moves.push((CheckerMove::default(), CheckerMove::default()));
} }
// rules.board is already White-perspective (mirrored if Black): compute cum once.
let cum = rules.board.white_checker_cumulative();
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(white_checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state, &move1,
&move2,
&game_state.dice,
&cum,
)?); )?);
} }
} }
@ -251,36 +258,27 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
Ok(valid_actions) Ok(valid_actions)
} }
#[cfg(test)]
fn checker_moves_to_trictrac_action( fn checker_moves_to_trictrac_action(
move1: &CheckerMove, move1: &CheckerMove,
move2: &CheckerMove, move2: &CheckerMove,
color: &crate::Color, color: &crate::Color,
state: &GameState, state: &GameState,
) -> anyhow::Result<TrictracAction> { ) -> anyhow::Result<TrictracAction> {
let dice = &state.dice; // Moves are always in White's coordinate system. For Black, mirror the board first.
let board = &state.board; let cum = if color == &crate::Color::Black {
state.board.mirror().white_checker_cumulative()
if color == &crate::Color::Black {
// Moves are already 'white', so we don't mirror them
white_checker_moves_to_trictrac_action(
move1,
move2,
// &move1.clone().mirror(),
// &move2.clone().mirror(),
dice,
&board.clone().mirror(),
)
// .map(|a| a.mirror())
} else { } else {
white_checker_moves_to_trictrac_action(move1, move2, dice, board) state.board.white_checker_cumulative()
} };
white_checker_moves_to_trictrac_action(move1, move2, &state.dice, &cum)
} }
fn white_checker_moves_to_trictrac_action( fn white_checker_moves_to_trictrac_action(
move1: &CheckerMove, move1: &CheckerMove,
move2: &CheckerMove, move2: &CheckerMove,
dice: &Dice, dice: &Dice,
board: &Board, cum: &[u8; 25],
) -> anyhow::Result<TrictracAction> { ) -> anyhow::Result<TrictracAction> {
let to1 = move1.get_to(); let to1 = move1.get_to();
let to2 = move2.get_to(); let to2 = move2.get_to();
@ -302,7 +300,7 @@ fn white_checker_moves_to_trictrac_action(
} }
} else { } else {
// double sortie // double sortie
if from1 < from2 { if from1 < from2 || from2 == 0 {
max(dice.values.0, dice.values.1) as usize max(dice.values.0, dice.values.1) as usize
} else { } else {
min(dice.values.0, dice.values.1) as usize min(dice.values.0, dice.values.1) as usize
@ -321,11 +319,21 @@ fn white_checker_moves_to_trictrac_action(
} }
let dice_order = diff_move1 == dice.values.0 as usize; let dice_order = diff_move1 == dice.values.0 as usize;
let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize; // cum[i] = # white checkers in fields 1..=i (precomputed by the caller).
let mut tmp_board = board.clone(); // checker1 is the ordinal of the last checker at from1.
// should not raise an error for a valid action let checker1 = cum[from1] as usize;
tmp_board.move_checker(&crate::Color::White, *move1)?; // checker2 is the ordinal on the board after move1 (removed from from1, added to to1).
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; // Adjust the cumulative in O(1) without cloning the board.
let checker2 = {
let mut c = cum[from2];
if from1 > 0 && from2 >= from1 {
c -= 1; // one checker was removed from from1, shifting later ordinals down
}
if from1 > 0 && to1 > 0 && from2 >= to1 {
c += 1; // one checker was added at to1, shifting later ordinals up
}
c as usize
};
Ok(TrictracAction::Move { Ok(TrictracAction::Move {
dice_order, dice_order,
checker1, checker1,
@ -456,5 +464,48 @@ mod tests {
}), }),
ttaction.ok() ttaction.ok()
); );
// Black player
state.active_player_id = 2;
state.dice = Dice { values: (6, 3) };
state.board.set_positions(
&crate::Color::White,
[
2, -11, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 6, 4,
],
);
let ttaction = super::checker_moves_to_trictrac_action(
&CheckerMove::new(21, 0).unwrap(),
&CheckerMove::new(0, 0).unwrap(),
&crate::Color::Black,
&state,
);
assert_eq!(
Some(TrictracAction::Move {
dice_order: true,
checker1: 2,
checker2: 0, // blocked by white on last field
}),
ttaction.ok()
);
// same with dice order reversed
state.dice = Dice { values: (3, 6) };
let ttaction = super::checker_moves_to_trictrac_action(
&CheckerMove::new(21, 0).unwrap(),
&CheckerMove::new(0, 0).unwrap(),
&crate::Color::Black,
&state,
);
assert_eq!(
Some(TrictracAction::Move {
dice_order: false,
checker1: 2,
checker2: 0, // blocked by white on last field
}),
ttaction.ok()
);
} }
} }