Compare commits

...

12 commits

22 changed files with 4930 additions and 5 deletions

131
Cargo.lock generated
View file

@ -92,6 +92,12 @@ dependencies = [
"libc",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.21"
@ -1116,6 +1122,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cast_trait"
version = "0.1.2"
@ -1200,6 +1212,33 @@ dependencies = [
"rand 0.7.3",
]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
@ -1453,6 +1492,42 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "critical-section"
version = "1.2.0"
@ -4461,6 +4536,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "opaque-debug"
version = "0.3.1"
@ -4597,6 +4678,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.18.0"
@ -5891,6 +6000,18 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "spiel_bot"
version = "0.1.0"
dependencies = [
"anyhow",
"burn",
"criterion",
"rand 0.9.2",
"rand_distr",
"trictrac-store",
]
[[package]]
name = "spin"
version = "0.10.0"
@ -6299,6 +6420,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.10.0"

View file

@ -1,4 +1,4 @@
[workspace]
resolver = "2"
members = ["client_cli", "bot", "store"]
members = ["client_cli", "bot", "store", "spiel_bot"]

782
doc/spiel_bot_research.md Normal file
View file

@ -0,0 +1,782 @@
# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac
## 0. Context and Scope
The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
other stage to an inline random-opponent loop.
`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
for **self-play training**. Its goals:
- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
that works with Trictrac's multi-stage turn model and stochastic dice.
- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer)
as the first algorithm.
- Remain **modular**: adding DQN or PPO later requires only a new
`impl Algorithm for Dqn` without touching the environment or network layers.
- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from
`trictrac-store`.
---
## 1. Library Landscape
### 1.1 Neural Network Frameworks
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
runtime needed, the `ndarray` backend is sufficient for CPU training and can
switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by
changing one type alias.
`tch-rs` would be the best choice for raw training throughput (it is the most
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
switching `spiel_bot` to `tch-rs` is a one-line backend swap.
### 1.2 Other Key Crates
| Crate | Role |
| -------------------- | ----------------------------------------------------------------- |
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
| `anyhow` | error propagation (already used everywhere) |
| `indicatif` | training progress bars |
| `tracing` | structured logging per episode/iteration |
### 1.3 What `burn-rl` Provides (and Does Not)
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
trait. It does **not** provide:
- MCTS or any tree-search algorithm
- Two-player self-play
- Legal action masking during training
- Chance-node handling
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
own (simpler, more targeted) traits and implement MCTS from scratch.
---
## 2. Trictrac-Specific Design Constraints
### 2.1 Multi-Stage Turn Model
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
genuine player choice:
| TurnStage | Node type | Handler |
| ---------------- | ------------------------------- | ------------------------------- |
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
| `Move` | **Player decision** | MCTS / policy network |
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
The environment wrapper advances through forced/chance stages automatically so
that from the algorithm's perspective every node it sees is a genuine player
decision.
### 2.2 Stochastic Dice in MCTS
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
introduce stochasticity. Three approaches exist:
**A. Outcome sampling (recommended)**
During each MCTS simulation, when a chance node is reached, sample one dice
outcome at random and continue. After many simulations the expected value
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
and requires no changes to the standard PUCT formula.
**B. Chance-node averaging (expectimax)**
At each chance node, expand all 21 unique dice pairs weighted by their
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
exact but multiplies the branching factor by ~21 at every dice roll, making it
prohibitively expensive.
**C. Condition on dice in the observation (current approach)**
Dice values are already encoded at indices [192193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.
**Use approach A + C together**: the environment samples dice automatically
(chance node bypass), and the 217-dim tensor encodes the dice so the network
can exploit them.
### 2.3 Perspective / Mirroring
All move rules and tensor encoding are defined from White's perspective.
`to_tensor()` must always be called after mirroring the state for Black.
The environment wrapper handles this transparently: every observation returned
to an algorithm is already in the active player's perspective.
### 2.4 Legal Action Masking
A crucial difference from the existing `bot/` code: instead of penalizing
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
before softmax — illegal action logits are set to `-inf`. This prevents the
network from wasting capacity on illegal moves and eliminates the need for the
penalty-reward hack.
---
## 3. Proposed Crate Architecture
```
spiel_bot/
├── Cargo.toml
└── src/
├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo"
├── env/
│ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface
│ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store
├── mcts/
│ ├── mod.rs # MctsConfig, run_mcts() entry point
│ ├── node.rs # MctsNode (visit count, W, prior, children)
│ └── search.rs # simulate(), backup(), select_action()
├── network/
│ ├── mod.rs # PolicyValueNet trait
│ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads
├── alphazero/
│ ├── mod.rs # AlphaZeroConfig
│ ├── selfplay.rs # generate_episode() -> Vec<TrainSample>
│ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle)
│ └── trainer.rs # training loop: selfplay → sample → loss → update
└── agent/
├── mod.rs # Agent trait
├── random.rs # RandomAgent (baseline)
└── mcts_agent.rs # MctsAgent: uses trained network for inference
```
Future algorithms slot in without touching the above:
```
├── dqn/ # (future) DQN: impl Algorithm + own replay buffer
└── ppo/ # (future) PPO: impl Algorithm + rollout buffer
```
---
## 4. Core Traits
### 4.1 `GameEnv` — the minimal OpenSpiel interface
```rust
use rand::Rng;
/// Who controls the current node.
pub enum Player {
P1, // player index 0
P2, // player index 1
Chance, // dice roll
Terminal, // game over
}
pub trait GameEnv: Clone + Send + Sync + 'static {
type State: Clone + Send + Sync;
/// Fresh game state.
fn new_game(&self) -> Self::State;
/// Who acts at this node.
fn current_player(&self, s: &Self::State) -> Player;
/// Legal action indices (always in [0, action_space())).
/// Empty only at Terminal nodes.
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
/// Apply a player action (must be legal).
fn apply(&self, s: &mut Self::State, action: usize);
/// Advance a Chance node by sampling dice; no-op at non-Chance nodes.
fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng);
/// Observation tensor from `pov`'s perspective (0 or 1).
/// Returns 217 f32 values for Trictrac.
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
/// Flat observation size (217 for Trictrac).
fn obs_size(&self) -> usize;
/// Total action-space size (514 for Trictrac).
fn action_space(&self) -> usize;
/// Game outcome per player, or None if not Terminal.
/// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw.
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
}
```
### 4.2 `PolicyValueNet` — neural network interface
```rust
use burn::prelude::*;
pub trait PolicyValueNet<B: Backend>: Send + Sync {
/// Forward pass.
/// `obs`: [batch, obs_size] tensor.
/// Returns: (policy_logits [batch, action_space], value [batch]).
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
/// Save weights to `path`.
fn save(&self, path: &std::path::Path) -> anyhow::Result<()>;
/// Load weights from `path`.
fn load(path: &std::path::Path) -> anyhow::Result<Self>
where
Self: Sized;
}
```
### 4.3 `Agent` — player policy interface
```rust
pub trait Agent: Send {
/// Select an action index given the current game state observation.
/// `legal`: mask of valid action indices.
fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize;
}
```
---
## 5. MCTS Implementation
### 5.1 Node
```rust
pub struct MctsNode {
n: u32, // visit count N(s, a)
w: f32, // sum of backed-up values W(s, a)
p: f32, // prior from policy head P(s, a)
children: Vec<(usize, MctsNode)>, // (action_idx, child)
is_expanded: bool,
}
impl MctsNode {
pub fn q(&self) -> f32 {
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
}
/// PUCT score used for selection.
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
}
}
```
### 5.2 Simulation Loop
One MCTS simulation (for deterministic decision nodes):
```
1. SELECTION — traverse from root, always pick child with highest PUCT,
auto-advancing forced/chance nodes via env.apply_chance().
2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get
(policy_logits, value). Mask illegal actions, softmax
the remaining logits → priors P(s,a) for each child.
3. BACKUP — propagate -value up the tree (negate at each level because
perspective alternates between P1 and P2).
```
After `n_simulations` iterations, action selection at the root:
```rust
// During training: sample proportional to N^(1/temperature)
// During evaluation: argmax N
fn select_action(root: &MctsNode, temperature: f32) -> usize { ... }
```
### 5.3 Configuration
```rust
pub struct MctsConfig {
pub n_simulations: usize, // e.g. 200
pub c_puct: f32, // exploration constant, e.g. 1.5
pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3
pub dirichlet_eps: f32, // noise weight, e.g. 0.25
pub temperature: f32, // action sampling temperature (anneals to 0)
}
```
### 5.4 Handling Chance Nodes Inside MCTS
When simulation reaches a Chance node (dice roll), the environment automatically
samples dice and advances to the next decision node. The MCTS tree does **not**
branch on dice outcomes — it treats the sampled outcome as the state. This
corresponds to "outcome sampling" (approach A from §2.2). Because each
simulation independently samples dice, the Q-values at player nodes converge to
their expected value over many simulations.
---
## 6. Network Architecture
### 6.1 ResNet Policy-Value Network
A single trunk with residual blocks, then two heads:
```
Input: [batch, 217]
Linear(217 → 512) + ReLU
ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU)
↓ trunk output [batch, 512]
├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site)
└─ Value head: Linear(512 → 1) → tanh (output in [-1, 1])
```
Burn implementation sketch:
```rust
#[derive(Module, Debug)]
pub struct TrictracNet<B: Backend> {
input: Linear<B>,
res_blocks: Vec<ResBlock<B>>,
policy_head: Linear<B>,
value_head: Linear<B>,
}
impl<B: Backend> TrictracNet<B> {
pub fn forward(&self, obs: Tensor<B, 2>)
-> (Tensor<B, 2>, Tensor<B, 1>)
{
let x = activation::relu(self.input.forward(obs));
let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x));
let policy = self.policy_head.forward(x.clone()); // raw logits
let value = activation::tanh(self.value_head.forward(x))
.squeeze(1);
(policy, value)
}
}
```
A simpler MLP (no residual blocks) is sufficient for a first version and much
faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two
heads.
### 6.2 Loss Function
```
L = MSE(value_pred, z)
+ CrossEntropy(policy_logits_masked, π_mcts)
- c_l2 * L2_regularization
```
Where:
- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy
---
## 7. AlphaZero Training Loop
```
INIT
network ← random weights
replay ← empty ReplayBuffer(capacity = 100_000)
LOOP forever:
── Self-play phase ──────────────────────────────────────────────
(parallel with rayon, n_workers games at once)
for each game:
state ← env.new_game()
samples = []
while not terminal:
advance forced/chance nodes automatically
obs ← env.observation(state, current_player)
legal ← env.legal_actions(state)
π, root_value ← mcts.run(state, network, config)
action ← sample from π (with temperature)
samples.push((obs, π, current_player))
env.apply(state, action)
z ← env.returns(state) // final scores
for (obs, π, player) in samples:
replay.push(TrainSample { obs, policy: π, value: z[player] })
── Training phase ───────────────────────────────────────────────
for each gradient step:
batch ← replay.sample(batch_size)
(policy_logits, value_pred) ← network.forward(batch.obs)
loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy)
optimizer.step(loss.backward())
── Evaluation (every N iterations) ─────────────────────────────
win_rate ← evaluate(network_new vs network_prev, n_eval_games)
if win_rate > 0.55: save checkpoint
```
### 7.1 Replay Buffer
```rust
pub struct TrainSample {
pub obs: Vec<f32>, // 217 values
pub policy: Vec<f32>, // 514 values (normalized MCTS visit counts)
pub value: f32, // game outcome ∈ {-1, 0, +1}
}
pub struct ReplayBuffer {
data: VecDeque<TrainSample>,
capacity: usize,
}
impl ReplayBuffer {
pub fn push(&mut self, s: TrainSample) {
if self.data.len() == self.capacity { self.data.pop_front(); }
self.data.push_back(s);
}
pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
// sample without replacement
}
}
```
### 7.2 Parallelism Strategy
Self-play is embarrassingly parallel (each game is independent):
```rust
let samples: Vec<TrainSample> = (0..n_games)
.into_par_iter() // rayon
.flat_map(|_| generate_episode(&env, &network, &mcts_config))
.collect();
```
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
`Autodiff<NdArray<f32>>`.
For larger scale, a producer-consumer architecture (crossbeam-channel) separates
self-play workers from the training thread, allowing continuous data generation
while the GPU trains.
---
## 8. `TrictracEnv` Implementation Sketch
```rust
use trictrac_store::{
training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE},
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
};
#[derive(Clone)]
pub struct TrictracEnv;
impl GameEnv for TrictracEnv {
type State = GameState;
fn new_game(&self) -> GameState {
GameState::new_with_players("P1", "P2")
}
fn current_player(&self, s: &GameState) -> Player {
match s.stage {
Stage::Ended => Player::Terminal,
_ => match s.turn_stage {
TurnStage::RollWaiting => Player::Chance,
_ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 },
},
}
}
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() };
get_valid_action_indices(&view).unwrap_or_default()
}
fn apply(&self, s: &mut GameState, action_idx: usize) {
// advance all forced/chance nodes first, then apply the player action
self.advance_forced(s);
let needs_mirror = s.active_player_id == 2;
let view = if needs_mirror { s.mirror() } else { s.clone() };
if let Some(event) = TrictracAction::from_action_index(action_idx)
.and_then(|a| a.to_event(&view))
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
{
let _ = s.consume(&event);
}
// advance any forced stages that follow
self.advance_forced(s);
}
fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) {
// RollDice → RollWaiting
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
// RollWaiting → next stage
let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) };
let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice });
self.advance_forced(s);
}
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() }
}
fn obs_size(&self) -> usize { 217 }
fn action_space(&self) -> usize { ACTION_SPACE_SIZE }
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
if s.stage != Stage::Ended { return None; }
// Convert hole+point scores to ±1 outcome
let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
Some(match s1.cmp(&s2) {
std::cmp::Ordering::Greater => [ 1.0, -1.0],
std::cmp::Ordering::Less => [-1.0, 1.0],
std::cmp::Ordering::Equal => [ 0.0, 0.0],
})
}
}
impl TrictracEnv {
/// Advance through all forced (non-decision, non-chance) stages.
fn advance_forced(&self, s: &mut GameState) {
use trictrac_store::PointsRules;
loop {
match s.turn_stage {
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
// Scoring is deterministic; compute and apply automatically.
let color = s.player_color_by_id(&s.active_player_id)
.unwrap_or(trictrac_store::Color::White);
let drc = s.players.get(&s.active_player_id)
.map(|p| p.dice_roll_count).unwrap_or(0);
let pr = PointsRules::new(&color, &s.board, s.dice);
let pts = pr.get_points(drc);
let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 };
let _ = s.consume(&GameEvent::Mark {
player_id: s.active_player_id, points,
});
}
TurnStage::RollDice => {
// RollDice is a forced "initiate roll" action with no real choice.
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
}
_ => break,
}
}
}
}
```
---
## 9. Cargo.toml Changes
### 9.1 Add `spiel_bot` to the workspace
```toml
# Cargo.toml (workspace root)
[workspace]
resolver = "2"
members = ["client_cli", "bot", "store", "spiel_bot"]
```
### 9.2 `spiel_bot/Cargo.toml`
```toml
[package]
name = "spiel_bot"
version = "0.1.0"
edition = "2021"
[features]
default = ["alphazero"]
alphazero = []
# dqn = [] # future
# ppo = [] # future
[dependencies]
trictrac-store = { path = "../store" }
anyhow = "1"
rand = "0.9"
rayon = "1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# Burn: NdArray for pure-Rust CPU training
# Replace NdArray with Wgpu or Tch for GPU.
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
# Optional: progress display and structured logging
indicatif = "0.17"
tracing = "0.1"
[[bin]]
name = "az_train"
path = "src/bin/az_train.rs"
[[bin]]
name = "az_eval"
path = "src/bin/az_eval.rs"
```
---
## 10. Comparison: `bot` crate vs `spiel_bot`
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
| ---------------- | --------------------------- | -------------------------------------------- |
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |
| Invalid actions | penalise with reward | legal action mask (no penalty) |
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
| Burn dep | yes (0.20) | yes (0.20), same backend |
| `burn-rl` dep | yes | no |
| C++ dep | no | no |
| Python dep | no | no |
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
replace `TrictracEnvironment` with `TrictracEnv`).
---
## 11. Implementation Order
1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game
through the trait, verify terminal state and returns).
2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) +
Burn forward/backward pass test with dummy data.
3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests
(visit counts sum to `n_simulations`, legal mask respected).
4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub
(one iteration, check loss decreases).
5. **Integration test**: run 100 self-play games with a tiny network (1 res block,
64 hidden units), verify the training loop completes without panics.
6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s
on CPU, consistent with `random_game` throughput).
7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule
hyperparameter sweep.
8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report
win rate every checkpoint.
---
## 12. Key Open Questions
1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
scored more holes). Better option: normalize the score margin. The final
choice affects how the value head is trained.
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
MCTS with 200 simulations per step means ~120k network evaluations per game.
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
limiting `n_simulations` to 50100 for early training.
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
This is a long-horizon decision that AlphaZero handles naturally via MCTS
lookahead, but needs careful value normalization (a "Go" restarts scoring
within the same game).
4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated
to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic.
This is optional but reduces code duplication.
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
## Implementation results
All benchmarks compile and run. Here's the complete results table:
| Group | Benchmark | Time |
| ------- | ----------------------- | --------------------- |
| env | apply_chance | 3.87 µs |
| | legal_actions | 1.91 µs |
| | observation (to_tensor) | 341 ns |
| | random_game (baseline) | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64 | 94.9 µs |
| | mlp_b32 hidden=64 | 141 µs |
| | mlp_b1 hidden=256 | 352 µs |
| | mlp_b32 hidden=256 | 479 µs |
| mcts | zero_eval n=1 | 6.8 µs |
| | zero_eval n=5 | 23.9 µs |
| | zero_eval n=20 | 90.9 µs |
| | mlp64 n=1 | 203 µs |
| | mlp64 n=5 | 622 µs |
| | mlp64 n=20 | 2.30 ms |
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
| | trictrac n=2 | 145 ms → 7 games/s |
| train | mlp64 Adam b=16 | 1.93 ms |
| | mlp64 Adam b=64 | 2.68 ms |
Key observations:
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimised
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
- Training: 2.7 ms/step at batch=64 → 370 steps/s
### Summary of Step 8
spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary:
- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side
- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate
Typical usage after training:
cargo run -p spiel_bot --bin az_eval --release -- \
--checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
### az_train
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
cargo run -p spiel_bot --bin az_train --release
#### ResNet, more games, custom output dir
cargo run -p spiel_bot --bin az_train --release -- \
--arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
--save-every 20 --out checkpoints/
#### Resume from iteration 50
cargo run -p spiel_bot --bin az_train --release -- \
--resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
What the binary does each iteration:
1. Calls model.valid() to get a zero-overhead inference copy for self-play
2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy)
3. Pushes samples into a ReplayBuffer (capacity --replay-cap)
4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min
5. Saves a .mpk checkpoint every --save-every iterations and always on the last

18
spiel_bot/Cargo.toml Normal file
View file

@ -0,0 +1,18 @@
[package]
name = "spiel_bot"
version = "0.1.0"
edition = "2021"
[dependencies]
trictrac-store = { path = "../store" }
anyhow = "1"
rand = "0.9"
rand_distr = "0.5"
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
name = "alphazero"
harness = false

View file

@ -0,0 +1,373 @@
//! AlphaZero pipeline benchmarks.
//!
//! Run with:
//!
//! ```sh
//! cargo bench -p spiel_bot
//! ```
//!
//! Use `-- <filter>` to run a specific group, e.g.:
//!
//! ```sh
//! cargo bench -p spiel_bot -- env/
//! cargo bench -p spiel_bot -- network/
//! cargo bench -p spiel_bot -- mcts/
//! cargo bench -p spiel_bot -- episode/
//! cargo bench -p spiel_bot -- train/
//! ```
//!
//! Target: ≥ 500 games/s for random play on CPU (consistent with
//! `random_game` throughput in `trictrac-store`).
use std::time::Duration;
use burn::{
backend::NdArray,
tensor::{Tensor, TensorData, backend::Backend},
};
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts},
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
// ── Shared types ───────────────────────────────────────────────────────────
type InferB = NdArray<f32>;
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
/// Uniform evaluator (returns zero logits and zero value).
/// Used to isolate MCTS tree-traversal cost from network cost.
struct ZeroEval(usize);
impl Evaluator for ZeroEval {
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
(vec![0.0f32; self.0], 0.0)
}
}
// ── 1. Environment primitives ──────────────────────────────────────────────
/// Baseline performance of the raw Trictrac environment without MCTS.
/// Target: ≥ 500 full games / second.
fn bench_env(c: &mut Criterion) {
let env = TrictracEnv;
let mut group = c.benchmark_group("env");
group.measurement_time(Duration::from_secs(10));
// ── apply_chance ──────────────────────────────────────────────────────
group.bench_function("apply_chance", |b| {
b.iter_batched(
|| {
// A fresh game is always at RollDice (Chance) — ready for apply_chance.
env.new_game()
},
|mut s| {
env.apply_chance(&mut s, &mut seeded());
black_box(s)
},
BatchSize::SmallInput,
)
});
// ── legal_actions ─────────────────────────────────────────────────────
group.bench_function("legal_actions", |b| {
let mut rng = seeded();
let mut s = env.new_game();
env.apply_chance(&mut s, &mut rng);
b.iter(|| black_box(env.legal_actions(&s)))
});
// ── observation (to_tensor) ───────────────────────────────────────────
group.bench_function("observation", |b| {
let mut rng = seeded();
let mut s = env.new_game();
env.apply_chance(&mut s, &mut rng);
b.iter(|| black_box(env.observation(&s, 0)))
});
// ── full random game ──────────────────────────────────────────────────
group.sample_size(50);
group.bench_function("random_game", |b| {
b.iter_batched(
seeded,
|mut rng| {
let mut s = env.new_game();
loop {
match env.current_player(&s) {
Player::Terminal => break,
Player::Chance => env.apply_chance(&mut s, &mut rng),
_ => {
let actions = env.legal_actions(&s);
let idx = rng.random_range(0..actions.len());
env.apply(&mut s, actions[idx]);
}
}
}
black_box(s)
},
BatchSize::SmallInput,
)
});
group.finish();
}
// ── 2. Network inference ───────────────────────────────────────────────────
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
fn bench_network(c: &mut Criterion) {
let mut group = c.benchmark_group("network");
group.measurement_time(Duration::from_secs(5));
for &hidden in &[64usize, 256] {
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = MlpNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
// Batch size 1 — single-position evaluation as in MCTS.
group.bench_with_input(
BenchmarkId::new("mlp_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
// Batch size 32 — training mini-batch.
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("mlp_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
// ── ResNet (4 residual blocks) ────────────────────────────────────────
for &hidden in &[256usize, 512] {
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = ResNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
group.bench_with_input(
BenchmarkId::new("resnet_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("resnet_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
group.finish();
}
// ── 3. MCTS ───────────────────────────────────────────────────────────────
/// MCTS cost at different simulation budgets with two evaluator types:
/// - `zero` — isolates tree-traversal overhead (no network).
/// - `mlp64` — real MLP, shows end-to-end cost per move.
fn bench_mcts(c: &mut Criterion) {
let env = TrictracEnv;
// Build a decision-node state (after dice roll).
let state = {
let mut s = env.new_game();
let mut rng = seeded();
while env.current_player(&s).is_chance() {
env.apply_chance(&mut s, &mut rng);
}
s
};
let mut group = c.benchmark_group("mcts");
group.measurement_time(Duration::from_secs(10));
let zero_eval = ZeroEval(514);
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());
for &n_sim in &[1usize, 5, 20] {
let cfg = MctsConfig {
n_simulations: n_sim,
c_puct: 1.5,
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
temperature: 1.0,
};
// Zero evaluator: tree traversal only.
group.bench_with_input(
BenchmarkId::new("zero_eval", n_sim),
&n_sim,
|b, _| {
b.iter_batched(
seeded,
|mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
BatchSize::SmallInput,
)
},
);
// MLP evaluator: full cost per decision.
group.bench_with_input(
BenchmarkId::new("mlp64", n_sim),
&n_sim,
|b, _| {
b.iter_batched(
seeded,
|mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
BatchSize::SmallInput,
)
},
);
}
group.finish();
}
// ── 4. Episode generation ─────────────────────────────────────────────────
/// Full self-play episode latency (one complete game) at different MCTS
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
fn bench_episode(c: &mut Criterion) {
let env = TrictracEnv;
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
let eval = BurnEvaluator::<InferB, _>::new(model, infer_device());
let mut group = c.benchmark_group("episode");
group.sample_size(10);
group.measurement_time(Duration::from_secs(60));
for &n_sim in &[1usize, 2] {
let mcts_cfg = MctsConfig {
n_simulations: n_sim,
c_puct: 1.5,
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
temperature: 1.0,
};
group.bench_with_input(
BenchmarkId::new("trictrac", n_sim),
&n_sim,
|b, _| {
b.iter_batched(
seeded,
|mut rng| {
black_box(generate_episode(
&env,
&eval,
&mcts_cfg,
&|_| 1.0,
&mut rng,
))
},
BatchSize::SmallInput,
)
},
);
}
group.finish();
}
// ── 5. Training step ───────────────────────────────────────────────────────
/// Gradient-step latency for different batch sizes.
fn bench_train(c: &mut Criterion) {
use burn::optim::AdamConfig;
let mut group = c.benchmark_group("train");
group.measurement_time(Duration::from_secs(10));
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let dummy_samples = |n: usize| -> Vec<TrainSample> {
(0..n)
.map(|i| TrainSample {
obs: vec![0.5; 217],
policy: {
let mut p = vec![0.0f32; 514];
p[i % 514] = 1.0;
p
},
value: if i % 2 == 0 { 1.0 } else { -1.0 },
})
.collect()
};
for &batch_size in &[16usize, 64] {
let batch = dummy_samples(batch_size);
group.bench_with_input(
BenchmarkId::new("mlp64_adam", batch_size),
&batch_size,
|b, _| {
b.iter_batched(
|| {
(
MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
)
},
|(model, mut opt)| {
black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
},
BatchSize::SmallInput,
)
},
);
}
group.finish();
}
// ── Criterion entry point ──────────────────────────────────────────────────
criterion_group!(
benches,
bench_env,
bench_network,
bench_mcts,
bench_episode,
bench_train,
);
criterion_main!(benches);

View file

@ -0,0 +1,127 @@
//! AlphaZero: self-play data generation, replay buffer, and training step.
//!
//! # Modules
//!
//! | Module | Contents |
//! |--------|----------|
//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] |
//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] |
//! | [`trainer`] | [`train_step`] |
//!
//! # Typical outer loop
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, NdArray};
//! use burn::optim::AdamConfig;
//! use spiel_bot::{
//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step},
//! env::TrictracEnv,
//! mcts::MctsConfig,
//! network::{MlpConfig, MlpNet},
//! };
//!
//! type Infer = NdArray<f32>;
//! type Train = Autodiff<NdArray<f32>>;
//!
//! let device = Default::default();
//! let env = TrictracEnv;
//! let config = AlphaZeroConfig::default();
//!
//! // Build training model and optimizer.
//! let mut train_model = MlpNet::<Train>::new(&MlpConfig::default(), &device);
//! let mut optimizer = AdamConfig::new().init();
//! let mut replay = ReplayBuffer::new(config.replay_capacity);
//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
//!
//! for _iter in 0..config.n_iterations {
//! // Convert to inference backend for self-play.
//! let infer_model = MlpNet::<Infer>::new(&MlpConfig::default(), &device)
//! .load_record(train_model.clone().into_record());
//! let eval = BurnEvaluator::new(infer_model, device.clone());
//!
//! // Self-play: generate episodes.
//! for _ in 0..config.n_games_per_iter {
//! let samples = generate_episode(&env, &eval, &config.mcts,
//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
//! replay.extend(samples);
//! }
//!
//! // Training: gradient steps.
//! if replay.len() >= config.batch_size {
//! for _ in 0..config.n_train_steps_per_iter {
//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng)
//! .into_iter().cloned().collect();
//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device,
//! config.learning_rate);
//! train_model = m;
//! }
//! }
//! }
//! ```
pub mod replay;
pub mod selfplay;
pub mod trainer;
pub use replay::{ReplayBuffer, TrainSample};
pub use selfplay::{BurnEvaluator, generate_episode};
pub use trainer::{cosine_lr, train_step};
use crate::mcts::MctsConfig;
// ── Configuration ─────────────────────────────────────────────────────────
/// Top-level AlphaZero hyperparameters.
///
/// The MCTS parameters live in [`MctsConfig`]; this struct holds the
/// outer training-loop parameters.
#[derive(Debug, Clone)]
pub struct AlphaZeroConfig {
/// MCTS parameters for self-play.
pub mcts: MctsConfig,
/// Number of self-play games per training iteration.
pub n_games_per_iter: usize,
/// Number of gradient steps per training iteration.
pub n_train_steps_per_iter: usize,
/// Mini-batch size for each gradient step.
pub batch_size: usize,
/// Maximum number of samples in the replay buffer.
pub replay_capacity: usize,
/// Initial (peak) Adam learning rate.
pub learning_rate: f64,
/// Minimum learning rate for cosine annealing (floor of the schedule).
///
/// Pass `learning_rate == lr_min` to disable scheduling (constant LR).
/// Compute the current LR with [`cosine_lr`]:
///
/// ```rust,ignore
/// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
/// ```
pub lr_min: f64,
/// Number of outer iterations (self-play + train) to run.
pub n_iterations: usize,
/// Move index after which the action temperature drops to 0 (greedy play).
pub temperature_drop_move: usize,
}
impl Default for AlphaZeroConfig {
fn default() -> Self {
Self {
mcts: MctsConfig {
n_simulations: 100,
c_puct: 1.5,
dirichlet_alpha: 0.1,
dirichlet_eps: 0.25,
temperature: 1.0,
},
n_games_per_iter: 10,
n_train_steps_per_iter: 20,
batch_size: 64,
replay_capacity: 50_000,
learning_rate: 1e-3,
lr_min: 1e-4, // cosine annealing floor
n_iterations: 100,
temperature_drop_move: 30,
}
}
}

View file

@ -0,0 +1,144 @@
//! Replay buffer for AlphaZero self-play data.
use std::collections::VecDeque;
use rand::Rng;
// ── Training sample ────────────────────────────────────────────────────────
/// One training example produced by self-play.
#[derive(Clone, Debug)]
pub struct TrainSample {
/// Observation tensor from the acting player's perspective (`obs_size` floats).
pub obs: Vec<f32>,
/// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
pub policy: Vec<f32>,
/// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
pub value: f32,
}
// ── Replay buffer ──────────────────────────────────────────────────────────
/// Fixed-capacity circular buffer of [`TrainSample`]s.
///
/// When the buffer is full, the oldest sample is evicted on push.
/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
pub struct ReplayBuffer {
data: VecDeque<TrainSample>,
capacity: usize,
}
impl ReplayBuffer {
/// Create a buffer with the given maximum capacity.
pub fn new(capacity: usize) -> Self {
Self {
data: VecDeque::with_capacity(capacity.min(1024)),
capacity,
}
}
/// Add a sample; evicts the oldest if at capacity.
pub fn push(&mut self, sample: TrainSample) {
if self.data.len() == self.capacity {
self.data.pop_front();
}
self.data.push_back(sample);
}
/// Add all samples from an episode.
pub fn extend(&mut self, samples: impl IntoIterator<Item = TrainSample>) {
for s in samples {
self.push(s);
}
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Sample up to `n` distinct samples, without replacement.
///
/// If the buffer has fewer than `n` samples, all are returned (shuffled).
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
let len = self.data.len();
let n = n.min(len);
// Partial Fisher-Yates using index shuffling.
let mut indices: Vec<usize> = (0..len).collect();
for i in 0..n {
let j = rng.random_range(i..len);
indices.swap(i, j);
}
indices[..n].iter().map(|&i| &self.data[i]).collect()
}
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use rand::{SeedableRng, rngs::SmallRng};
fn dummy(value: f32) -> TrainSample {
TrainSample { obs: vec![value], policy: vec![1.0], value }
}
#[test]
fn push_and_len() {
let mut buf = ReplayBuffer::new(10);
assert!(buf.is_empty());
buf.push(dummy(1.0));
buf.push(dummy(2.0));
assert_eq!(buf.len(), 2);
}
#[test]
fn evicts_oldest_at_capacity() {
let mut buf = ReplayBuffer::new(3);
buf.push(dummy(1.0));
buf.push(dummy(2.0));
buf.push(dummy(3.0));
buf.push(dummy(4.0)); // evicts 1.0
assert_eq!(buf.len(), 3);
// Oldest remaining should be 2.0
assert_eq!(buf.data[0].value, 2.0);
}
#[test]
fn sample_batch_size() {
let mut buf = ReplayBuffer::new(20);
for i in 0..10 {
buf.push(dummy(i as f32));
}
let mut rng = SmallRng::seed_from_u64(0);
let batch = buf.sample_batch(5, &mut rng);
assert_eq!(batch.len(), 5);
}
#[test]
fn sample_batch_capped_at_len() {
let mut buf = ReplayBuffer::new(20);
buf.push(dummy(1.0));
buf.push(dummy(2.0));
let mut rng = SmallRng::seed_from_u64(0);
let batch = buf.sample_batch(100, &mut rng);
assert_eq!(batch.len(), 2);
}
#[test]
fn sample_batch_no_duplicates() {
let mut buf = ReplayBuffer::new(20);
for i in 0..10 {
buf.push(dummy(i as f32));
}
let mut rng = SmallRng::seed_from_u64(1);
let batch = buf.sample_batch(10, &mut rng);
let mut seen: Vec<f32> = batch.iter().map(|s| s.value).collect();
seen.sort_by(f32::total_cmp);
seen.dedup();
assert_eq!(seen.len(), 10, "sample contained duplicates");
}
}

View file

@ -0,0 +1,234 @@
//! Self-play episode generation and Burn-backed evaluator.
use std::marker::PhantomData;
use burn::tensor::{backend::Backend, Tensor, TensorData};
use rand::Rng;
use crate::env::GameEnv;
use crate::mcts::{self, Evaluator, MctsConfig, MctsNode};
use crate::network::PolicyValueNet;
use super::replay::TrainSample;
// ── BurnEvaluator ──────────────────────────────────────────────────────────
/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS.
///
/// Use the **inference backend** (`NdArray<f32>`, no `Autodiff` wrapper) so
/// that self-play generates no gradient tape overhead.
pub struct BurnEvaluator<B: Backend, N: PolicyValueNet<B>> {
model: N,
device: B::Device,
_b: PhantomData<B>,
}
impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
pub fn new(model: N, device: B::Device) -> Self {
Self { model, device, _b: PhantomData }
}
pub fn into_model(self) -> N {
self.model
}
}
// Safety: NdArray<f32> modules are Send; we never share across threads without
// external synchronisation.
unsafe impl<B: Backend, N: PolicyValueNet<B>> Send for BurnEvaluator<B, N> {}
unsafe impl<B: Backend, N: PolicyValueNet<B>> Sync for BurnEvaluator<B, N> {}
impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
let obs_size = obs.len();
let data = TensorData::new(obs.to_vec(), [1, obs_size]);
let obs_tensor = Tensor::<B, 2>::from_data(data, &self.device);
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
let policy: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
let value: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
(policy, value[0])
}
}
// ── Episode generation ─────────────────────────────────────────────────────
/// One pending observation waiting for its game-outcome value label.
struct PendingSample {
obs: Vec<f32>,
policy: Vec<f32>,
player: usize,
}
/// Play one full game using MCTS guided by `evaluator`.
///
/// Returns a [`TrainSample`] for every decision step in the game.
///
/// `temperature_fn(step)` controls exploration: return `1.0` for early
/// moves and `0.0` after a fixed number of moves (e.g. move 30).
pub fn generate_episode<E: GameEnv>(
env: &E,
evaluator: &dyn Evaluator,
mcts_config: &MctsConfig,
temperature_fn: &dyn Fn(usize) -> f32,
rng: &mut impl Rng,
) -> Vec<TrainSample> {
let mut state = env.new_game();
let mut pending: Vec<PendingSample> = Vec::new();
let mut step = 0usize;
loop {
// Advance through chance nodes automatically.
while env.current_player(&state).is_chance() {
env.apply_chance(&mut state, rng);
}
if env.current_player(&state).is_terminal() {
break;
}
let player_idx = env.current_player(&state).index().unwrap();
// Run MCTS to get a policy.
let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
let policy = mcts::mcts_policy(&root, env.action_space());
// Record the observation from the acting player's perspective.
let obs = env.observation(&state, player_idx);
pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx });
// Select and apply the action.
let temperature = temperature_fn(step);
let action = mcts::select_action(&root, temperature, rng);
env.apply(&mut state, action);
step += 1;
}
// Assign game outcomes.
let returns = env.returns(&state).unwrap_or([0.0; 2]);
pending
.into_iter()
.map(|s| TrainSample {
obs: s.obs,
policy: s.policy,
value: returns[s.player],
})
.collect()
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use rand::{SeedableRng, rngs::SmallRng};
use crate::env::Player;
use crate::mcts::{Evaluator, MctsConfig};
use crate::network::{MlpConfig, MlpNet};
type B = NdArray<f32>;
fn device() -> <B as Backend>::Device {
Default::default()
}
fn rng() -> SmallRng {
SmallRng::seed_from_u64(7)
}
// Countdown game (same as in mcts tests).
#[derive(Clone, Debug)]
struct CState { remaining: u8, to_move: usize }
#[derive(Clone)]
struct CountdownEnv;
impl GameEnv for CountdownEnv {
type State = CState;
fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } }
fn current_player(&self, s: &CState) -> Player {
if s.remaining == 0 { Player::Terminal }
else if s.to_move == 0 { Player::P1 } else { Player::P2 }
}
fn legal_actions(&self, s: &CState) -> Vec<usize> {
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
}
fn apply(&self, s: &mut CState, action: usize) {
let sub = (action as u8) + 1;
if s.remaining <= sub { s.remaining = 0; }
else { s.remaining -= sub; s.to_move = 1 - s.to_move; }
}
fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
vec![s.remaining as f32 / 4.0, s.to_move as f32]
}
fn obs_size(&self) -> usize { 2 }
fn action_space(&self) -> usize { 2 }
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
if s.remaining != 0 { return None; }
let mut r = [-1.0f32; 2];
r[s.to_move] = 1.0;
Some(r)
}
}
fn tiny_config() -> MctsConfig {
MctsConfig { n_simulations: 5, c_puct: 1.5,
dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 }
}
// ── BurnEvaluator tests ───────────────────────────────────────────────
#[test]
fn burn_evaluator_output_shapes() {
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
let model = MlpNet::<B>::new(&config, &device());
let eval = BurnEvaluator::new(model, device());
let (policy, value) = eval.evaluate(&[0.5f32, 0.5]);
assert_eq!(policy.len(), 2, "policy length should equal action_space");
assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)");
}
// ── generate_episode tests ────────────────────────────────────────────
#[test]
fn episode_terminates_and_has_samples() {
let env = CountdownEnv;
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
let model = MlpNet::<B>::new(&config, &device());
let eval = BurnEvaluator::new(model, device());
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
assert!(!samples.is_empty(), "episode must produce at least one sample");
}
#[test]
fn episode_sample_values_are_valid() {
let env = CountdownEnv;
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
let model = MlpNet::<B>::new(&config, &device());
let eval = BurnEvaluator::new(model, device());
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
for s in &samples {
assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
"unexpected value {}", s.value);
let sum: f32 = s.policy.iter().sum();
assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}");
assert_eq!(s.obs.len(), 2);
}
}
#[test]
fn episode_with_temperature_zero() {
let env = CountdownEnv;
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
let model = MlpNet::<B>::new(&config, &device());
let eval = BurnEvaluator::new(model, device());
// temperature=0 means greedy; episode must still terminate
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng());
assert!(!samples.is_empty());
}
}

View file

@ -0,0 +1,258 @@
//! One gradient-descent training step for AlphaZero.
//!
//! The loss combines:
//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
//! - **Value loss** — mean-squared error between the predicted value and the
//! actual game outcome.
//!
//! # Learning-rate scheduling
//!
//! [`cosine_lr`] implements one-cycle cosine annealing:
//!
//! ```text
//! lr(t) = lr_min + 0.5 · (lr_max lr_min) · (1 + cos(π · t / T))
//! ```
//!
//! Typical usage in the outer loop:
//!
//! ```rust,ignore
//! for step in 0..total_train_steps {
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
//! model = m;
//! }
//! ```
//!
//! # Backend
//!
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
//! Self-play uses the inner backend (`NdArray<f32>`) for zero autodiff overhead.
//! Weights are transferred between the two via [`burn::record`].
use burn::{
module::AutodiffModule,
optim::{GradientsParams, Optimizer},
prelude::ElementConversion,
tensor::{
activation::log_softmax,
backend::AutodiffBackend,
Tensor, TensorData,
},
};
use crate::network::PolicyValueNet;
use super::replay::TrainSample;
/// Run one gradient step on `model` using `batch`.
///
/// Returns the updated model and the scalar loss value for logging.
///
/// # Parameters
///
/// - `lr` — learning rate (e.g. `1e-3`).
/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
pub fn train_step<B, N, O>(
model: N,
optimizer: &mut O,
batch: &[TrainSample],
device: &B::Device,
lr: f64,
) -> (N, f32)
where
B: AutodiffBackend,
N: PolicyValueNet<B> + AutodiffModule<B>,
O: Optimizer<N, B>,
{
assert!(!batch.is_empty(), "train_step called with empty batch");
let batch_size = batch.len();
let obs_size = batch[0].obs.len();
let action_size = batch[0].policy.len();
// ── Build input tensors ────────────────────────────────────────────────
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
let obs_tensor = Tensor::<B, 2>::from_data(
TensorData::new(obs_flat, [batch_size, obs_size]),
device,
);
let policy_target = Tensor::<B, 2>::from_data(
TensorData::new(policy_flat, [batch_size, action_size]),
device,
);
let value_target = Tensor::<B, 2>::from_data(
TensorData::new(value_flat, [batch_size, 1]),
device,
);
// ── Forward pass ──────────────────────────────────────────────────────
let (policy_logits, value_pred) = model.forward(obs_tensor);
// ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
let log_probs = log_softmax(policy_logits, 1);
let policy_loss = (policy_target.clone().neg() * log_probs)
.sum_dim(1)
.mean();
// ── Value loss: MSE(value_pred, z) ────────────────────────────────────
let diff = value_pred - value_target;
let value_loss = (diff.clone() * diff).mean();
// ── Combined loss ─────────────────────────────────────────────────────
let loss = policy_loss + value_loss;
// Extract scalar before backward (consumes the tensor).
let loss_scalar: f32 = loss.clone().into_scalar().elem();
// ── Backward + optimizer step ─────────────────────────────────────────
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
let model = optimizer.step(lr, model, grads);
(model, loss_scalar)
}
// ── Learning-rate schedule ─────────────────────────────────────────────────
/// Cosine learning-rate schedule (one half-period, no warmup).
///
/// Returns the learning rate for training step `step` out of `total_steps`:
///
/// ```text
/// lr(t) = lr_min + 0.5 · (initial lr_min) · (1 + cos(π · t / total))
/// ```
///
/// - At `t = 0` returns `initial`.
/// - At `t = total_steps` (or beyond) returns `lr_min`.
///
/// # Panics
///
/// Does not panic. When `total_steps == 0`, returns `lr_min`.
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
if total_steps == 0 || step >= total_steps {
return lr_min;
}
let progress = step as f64 / total_steps as f64;
lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use burn::{
backend::{Autodiff, NdArray},
optim::AdamConfig,
};
use crate::network::{MlpConfig, MlpNet};
use super::super::replay::TrainSample;
type B = Autodiff<NdArray<f32>>;
fn device() -> <B as burn::tensor::backend::Backend>::Device {
Default::default()
}
fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
(0..n)
.map(|i| TrainSample {
obs: vec![0.5f32; obs_size],
policy: {
let mut p = vec![0.0f32; action_size];
p[i % action_size] = 1.0;
p
},
value: if i % 2 == 0 { 1.0 } else { -1.0 },
})
.collect()
}
#[test]
fn train_step_returns_finite_loss() {
let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
let model = MlpNet::<B>::new(&config, &device());
let mut optimizer = AdamConfig::new().init();
let batch = dummy_batch(8, 4, 4);
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
assert!(loss.is_finite(), "loss must be finite, got {loss}");
assert!(loss > 0.0, "loss should be positive");
}
#[test]
fn loss_decreases_over_steps() {
let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
let mut model = MlpNet::<B>::new(&config, &device());
let mut optimizer = AdamConfig::new().init();
// Same batch every step — loss should decrease.
let batch = dummy_batch(16, 4, 4);
let mut prev_loss = f32::INFINITY;
for _ in 0..10 {
let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
model = m;
assert!(loss.is_finite());
prev_loss = loss;
}
// After 10 steps on fixed data, loss should be below a reasonable threshold.
assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}");
}
#[test]
fn train_step_batch_size_one() {
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
let model = MlpNet::<B>::new(&config, &device());
let mut optimizer = AdamConfig::new().init();
let batch = dummy_batch(1, 2, 2);
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
assert!(loss.is_finite());
}
// ── cosine_lr ─────────────────────────────────────────────────────────
#[test]
fn cosine_lr_at_step_zero_is_initial() {
let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
}
#[test]
fn cosine_lr_at_end_is_min() {
let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
}
#[test]
fn cosine_lr_beyond_end_is_min() {
let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
}
#[test]
fn cosine_lr_midpoint_is_average() {
// At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
let expected = (1e-3 + 1e-5) / 2.0;
assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
}
#[test]
fn cosine_lr_monotone_decreasing() {
let mut prev = f64::INFINITY;
for step in 0..=100 {
let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
prev = lr;
}
}
#[test]
fn cosine_lr_zero_total_steps_returns_min() {
let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
assert!((lr - 1e-5).abs() < 1e-10);
}
}

View file

@ -0,0 +1,262 @@
//! Evaluate a trained AlphaZero checkpoint against a random player.
//!
//! # Usage
//!
//! ```sh
//! # Random weights (sanity check — should be ~50 %)
//! cargo run -p spiel_bot --bin az_eval --release
//!
//! # Trained MLP checkpoint
//! cargo run -p spiel_bot --bin az_eval --release -- \
//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50
//!
//! # Trained ResNet checkpoint
//! cargo run -p spiel_bot --bin az_eval --release -- \
//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--checkpoint <path>` | (none) | Load weights from `.mpk` file; random weights if omitted |
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
//! | `--hidden <N>` | 256 (mlp) / 512 (resnet) | Hidden size |
//! | `--n-games <N>` | `100` | Games per side (total = 2 × N) |
//! | `--n-sim <N>` | `50` | MCTS simulations per move |
//! | `--seed <N>` | `42` | RNG seed |
//! | `--c-puct <F>` | `1.5` | PUCT exploration constant |
use std::path::PathBuf;
use burn::backend::NdArray;
use rand::{SeedableRng, rngs::SmallRng, Rng};
use spiel_bot::{
alphazero::BurnEvaluator,
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts, select_action},
network::{MlpConfig, MlpNet, ResNet, ResNetConfig},
};
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
struct Args {
checkpoint: Option<PathBuf>,
arch: String,
hidden: Option<usize>,
n_games: usize,
n_sim: usize,
seed: u64,
c_puct: f32,
}
impl Default for Args {
fn default() -> Self {
Self {
checkpoint: None,
arch: "mlp".into(),
hidden: None,
n_games: 100,
n_sim: 50,
seed: 42,
c_puct: 1.5,
}
}
}
fn parse_args() -> Args {
let raw: Vec<String> = std::env::args().collect();
let mut args = Args::default();
let mut i = 1;
while i < raw.len() {
match raw[i].as_str() {
"--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); }
"--arch" => { i += 1; args.arch = raw[i].clone(); }
"--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); }
"--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); }
"--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); }
"--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); }
"--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); }
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
}
i += 1;
}
args
}
// ── Game loop ─────────────────────────────────────────────────────────────────
/// Play one complete game.
///
/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black).
/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0).
fn play_game(
env: &TrictracEnv,
mcts_side: usize,
evaluator: &dyn Evaluator,
mcts_cfg: &MctsConfig,
rng: &mut SmallRng,
) -> [f32; 2] {
let mut state = env.new_game();
loop {
match env.current_player(&state) {
Player::Terminal => {
return env.returns(&state).expect("Terminal state must have returns");
}
Player::Chance => env.apply_chance(&mut state, rng),
player => {
let side = player.index().unwrap(); // 0 = P1, 1 = P2
let action = if side == mcts_side {
let root = run_mcts(env, &state, evaluator, mcts_cfg, rng);
select_action(&root, 0.0, rng) // greedy (temperature = 0)
} else {
let actions = env.legal_actions(&state);
actions[rng.random_range(0..actions.len())]
};
env.apply(&mut state, action);
}
}
}
}
// ── Statistics ────────────────────────────────────────────────────────────────
#[derive(Default)]
struct Stats {
wins: u32,
draws: u32,
losses: u32,
}
impl Stats {
fn record(&mut self, mcts_return: f32) {
if mcts_return > 0.0 { self.wins += 1; }
else if mcts_return < 0.0 { self.losses += 1; }
else { self.draws += 1; }
}
fn total(&self) -> u32 { self.wins + self.draws + self.losses }
fn win_rate_decisive(&self) -> f64 {
let d = self.wins + self.losses;
if d == 0 { 0.5 } else { self.wins as f64 / d as f64 }
}
fn print(&self) {
let n = self.total();
let pct = |k: u32| 100.0 * k as f64 / n as f64;
println!(
" Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)",
self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses),
);
}
}
// ── Evaluation ────────────────────────────────────────────────────────────────
fn run_evaluation(
evaluator: &dyn Evaluator,
n_games: usize,
mcts_cfg: &MctsConfig,
seed: u64,
) -> (Stats, Stats) {
let env = TrictracEnv;
let total = n_games * 2;
let mut as_p1 = Stats::default();
let mut as_p2 = Stats::default();
for i in 0..total {
// Alternate sides: even games → MctsAgent as P1, odd → as P2.
let mcts_side = i % 2;
let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64));
let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng);
let mcts_return = result[mcts_side];
if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); }
let done = i + 1;
if done % 10 == 0 || done == total {
eprint!("\r [{done}/{total}] ", );
}
}
eprintln!();
(as_p1, as_p2)
}
// ── Main ──────────────────────────────────────────────────────────────────────
fn main() {
let args = parse_args();
let device: <InferB as burn::tensor::backend::Backend>::Device = Default::default();
// ── Load model ────────────────────────────────────────────────────────
let evaluator: Box<dyn Evaluator> = match args.arch.as_str() {
"resnet" => {
let hidden = args.hidden.unwrap_or(512);
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = match &args.checkpoint {
Some(path) => ResNet::<InferB>::load(&cfg, path, &device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
None => ResNet::new(&cfg, &device),
};
Box::new(BurnEvaluator::<InferB, ResNet<InferB>>::new(model, device))
}
"mlp" | _ => {
let hidden = args.hidden.unwrap_or(256);
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = match &args.checkpoint {
Some(path) => MlpNet::<InferB>::load(&cfg, path, &device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
None => MlpNet::new(&cfg, &device),
};
Box::new(BurnEvaluator::<InferB, MlpNet<InferB>>::new(model, device))
}
};
let mcts_cfg = MctsConfig {
n_simulations: args.n_sim,
c_puct: args.c_puct,
dirichlet_alpha: 0.0, // no exploration noise during evaluation
dirichlet_eps: 0.0,
temperature: 0.0, // greedy action selection
};
// ── Header ────────────────────────────────────────────────────────────
let ckpt_label = args.checkpoint
.as_deref()
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
.unwrap_or("random weights");
println!();
println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent",
args.arch, args.n_sim);
println!("Games per side: {} | Total: {} | Seed: {}",
args.n_games, args.n_games * 2, args.seed);
println!();
// ── Run ───────────────────────────────────────────────────────────────
let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed);
// ── Results ───────────────────────────────────────────────────────────
println!("MctsAgent as P1 (White):");
as_p1.print();
println!("MctsAgent as P2 (Black):");
as_p2.print();
let combined_wins = as_p1.wins + as_p2.wins;
let combined_decisive = combined_wins + as_p1.losses + as_p2.losses;
let combined_wr = if combined_decisive == 0 { 0.5 }
else { combined_wins as f64 / combined_decisive as f64 };
println!();
println!("Combined win rate (excluding draws): {:.1}% [{}/{}]",
combined_wr * 100.0, combined_wins, combined_decisive);
println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%",
as_p1.win_rate_decisive() * 100.0,
as_p2.win_rate_decisive() * 100.0);
}

View file

@ -0,0 +1,314 @@
//! AlphaZero self-play training loop.
//!
//! # Usage
//!
//! ```sh
//! # Start fresh (MLP, default settings)
//! cargo run -p spiel_bot --bin az_train --release
//!
//! # ResNet, 200 iterations, save every 20
//! cargo run -p spiel_bot --bin az_train --release -- \
//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/
//!
//! # Resume from a checkpoint
//! cargo run -p spiel_bot --bin az_train --release -- \
//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
//! | `--hidden N` | 256/512 | Hidden layer width |
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
//! | `--n-iter N` | `100` | Training iterations |
//! | `--n-games N` | `10` | Self-play games per iteration |
//! | `--n-train N` | `20` | Gradient steps per iteration |
//! | `--n-sim N` | `100` | MCTS simulations per move |
//! | `--batch N` | `64` | Mini-batch size |
//! | `--replay-cap N` | `50000` | Replay buffer capacity |
//! | `--lr F` | `1e-3` | Peak (initial) learning rate |
//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) |
//! | `--c-puct F` | `1.5` | PUCT exploration constant |
//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha |
//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight |
//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 |
//! | `--save-every N` | `10` | Save checkpoint every N iterations |
//! | `--seed N` | `42` | RNG seed |
//! | `--resume PATH` | (none) | Load weights from checkpoint before training |
use std::path::{Path, PathBuf};
use std::time::Instant;
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{
BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step,
},
env::TrictracEnv,
mcts::MctsConfig,
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
type TrainB = Autodiff<NdArray<f32>>;
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
struct Args {
arch: String,
hidden: Option<usize>,
out_dir: PathBuf,
n_iter: usize,
n_games: usize,
n_train: usize,
n_sim: usize,
batch_size: usize,
replay_cap: usize,
lr: f64,
lr_min: f64,
c_puct: f32,
dirichlet_alpha: f32,
dirichlet_eps: f32,
temp_drop: usize,
save_every: usize,
seed: u64,
resume: Option<PathBuf>,
}
impl Default for Args {
fn default() -> Self {
Self {
arch: "mlp".into(),
hidden: None,
out_dir: PathBuf::from("checkpoints"),
n_iter: 100,
n_games: 10,
n_train: 20,
n_sim: 100,
batch_size: 64,
replay_cap: 50_000,
lr: 1e-3,
lr_min: 1e-4,
c_puct: 1.5,
dirichlet_alpha: 0.1,
dirichlet_eps: 0.25,
temp_drop: 30,
save_every: 10,
seed: 42,
resume: None,
}
}
}
fn parse_args() -> Args {
let raw: Vec<String> = std::env::args().collect();
let mut a = Args::default();
let mut i = 1;
while i < raw.len() {
match raw[i].as_str() {
"--arch" => { i += 1; a.arch = raw[i].clone(); }
"--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); }
"--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); }
"--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); }
"--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); }
"--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); }
"--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); }
"--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); }
"--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); }
"--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); }
"--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); }
"--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); }
"--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); }
"--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); }
"--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); }
"--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); }
"--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); }
"--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); }
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
}
i += 1;
}
a
}
// ── Training loop ─────────────────────────────────────────────────────────────
/// Generic training loop, parameterised over the network type.
///
/// `save_fn` receives the **training-backend** model and the target path;
/// it is called in the match arm where the concrete network type is known.
fn train_loop<N>(
mut model: N,
save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>,
args: &Args,
)
where
N: PolicyValueNet<TrainB> + AutodiffModule<TrainB> + Clone,
<N as AutodiffModule<TrainB>>::InnerModule: PolicyValueNet<InferB> + Send + 'static,
{
let train_device: <TrainB as Backend>::Device = Default::default();
let infer_device: <InferB as Backend>::Device = Default::default();
// Type is inferred as OptimizerAdaptor<Adam, N, TrainB> at the call site.
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(args.replay_cap);
let mut rng = SmallRng::seed_from_u64(args.seed);
let env = TrictracEnv;
// Total gradient steps (used for cosine LR denominator).
let total_train_steps = (args.n_iter * args.n_train).max(1);
let mut global_step = 0usize;
println!(
"\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}",
"", args.arch, args.n_iter, args.n_games, args.n_sim, ""
);
for iter in 0..args.n_iter {
let t0 = Instant::now();
// ── Self-play ────────────────────────────────────────────────────
// Convert to inference backend (zero autodiff overhead).
let infer_model: <N as AutodiffModule<TrainB>>::InnerModule = model.valid();
let evaluator: BurnEvaluator<InferB, <N as AutodiffModule<TrainB>>::InnerModule> =
BurnEvaluator::new(infer_model, infer_device.clone());
let mcts_cfg = MctsConfig {
n_simulations: args.n_sim,
c_puct: args.c_puct,
dirichlet_alpha: args.dirichlet_alpha,
dirichlet_eps: args.dirichlet_eps,
temperature: 1.0,
};
let temp_drop = args.temp_drop;
let temperature_fn = |step: usize| -> f32 {
if step < temp_drop { 1.0 } else { 0.0 }
};
let mut new_samples = 0usize;
for _ in 0..args.n_games {
let samples =
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
new_samples += samples.len();
replay.extend(samples);
}
// ── Training ─────────────────────────────────────────────────────
let mut loss_sum = 0.0f32;
let mut n_steps = 0usize;
if replay.len() >= args.batch_size {
for _ in 0..args.n_train {
let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps);
let batch: Vec<TrainSample> = replay
.sample_batch(args.batch_size, &mut rng)
.into_iter()
.cloned()
.collect();
let (m, loss) =
train_step(model, &mut optimizer, &batch, &train_device, lr);
model = m;
loss_sum += loss;
n_steps += 1;
global_step += 1;
}
}
// ── Logging ──────────────────────────────────────────────────────
let elapsed = t0.elapsed();
let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };
let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps);
println!(
"iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s",
iter + 1,
args.n_iter,
replay.len(),
new_samples,
avg_loss,
lr_now,
elapsed.as_secs_f32(),
);
// ── Checkpoint ───────────────────────────────────────────────────
let is_last = iter + 1 == args.n_iter;
if (iter + 1) % args.save_every == 0 || is_last {
let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1));
match save_fn(&model, &path) {
Ok(()) => println!(" -> saved {}", path.display()),
Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
}
}
}
println!("\nTraining complete.");
}
// ── Main ──────────────────────────────────────────────────────────────────────
fn main() {
let args = parse_args();
// Create output directory if it doesn't exist.
if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
std::process::exit(1);
}
let train_device: <TrainB as Backend>::Device = Default::default();
match args.arch.as_str() {
"resnet" => {
let hidden = args.hidden.unwrap_or(512);
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = match &args.resume {
Some(path) => {
println!("Resuming from {}", path.display());
ResNet::<TrainB>::load(&cfg, path, &train_device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
}
None => ResNet::<TrainB>::new(&cfg, &train_device),
};
train_loop(
model,
&|m: &ResNet<TrainB>, path: &Path| {
// Save via inference model to avoid autodiff record overhead.
m.valid().save(path)
},
&args,
);
}
"mlp" | _ => {
let hidden = args.hidden.unwrap_or(256);
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = match &args.resume {
Some(path) => {
println!("Resuming from {}", path.display());
MlpNet::<TrainB>::load(&cfg, path, &train_device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
}
None => MlpNet::<TrainB>::new(&cfg, &train_device),
};
train_loop(
model,
&|m: &MlpNet<TrainB>, path: &Path| m.valid().save(path),
&args,
);
}
}
}

121
spiel_bot/src/env/mod.rs vendored Normal file
View file

@ -0,0 +1,121 @@
//! Game environment abstraction — the minimal "Rust OpenSpiel".
//!
//! A `GameEnv` describes the rules of a two-player, zero-sum game that may
//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN,
//! and PPO interact with a game exclusively through this trait.
//!
//! # Node taxonomy
//!
//! Every game position belongs to one of four categories, returned by
//! [`GameEnv::current_player`]:
//!
//! | [`Player`] | Meaning |
//! |-----------|---------|
//! | `P1` | Player 1 (index 0) must choose an action |
//! | `P2` | Player 2 (index 1) must choose an action |
//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) |
//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful |
//!
//! # Perspective convention
//!
//! [`GameEnv::observation`] always returns the board from *the requested
//! player's* point of view. Callers pass `pov = 0` for Player 1 and
//! `pov = 1` for Player 2. The implementation is responsible for any
//! mirroring required (e.g. Trictrac always reasons from White's side).
pub mod trictrac;
pub use trictrac::TrictracEnv;
/// Who controls the current game node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Player {
/// Player 1 (index 0) is to move.
P1,
/// Player 2 (index 1) is to move.
P2,
/// A stochastic event (dice roll, etc.) must be resolved.
Chance,
/// The game is over.
Terminal,
}
impl Player {
/// Returns the player index (0 or 1) if this is a decision node,
/// or `None` for `Chance` / `Terminal`.
pub fn index(self) -> Option<usize> {
match self {
Player::P1 => Some(0),
Player::P2 => Some(1),
_ => None,
}
}
pub fn is_decision(self) -> bool {
matches!(self, Player::P1 | Player::P2)
}
pub fn is_chance(self) -> bool {
self == Player::Chance
}
pub fn is_terminal(self) -> bool {
self == Player::Terminal
}
}
/// Trait that completely describes a two-player zero-sum game.
///
/// Implementors must be cheaply cloneable (the type is used as a stateless
/// factory; the mutable game state lives in `Self::State`).
pub trait GameEnv: Clone + Send + Sync + 'static {
/// The mutable game state. Must be `Clone` so MCTS can copy
/// game trees without touching the environment.
type State: Clone + Send + Sync;
// ── State creation ────────────────────────────────────────────────────
/// Create a fresh game state at the initial position.
fn new_game(&self) -> Self::State;
// ── Node queries ──────────────────────────────────────────────────────
/// Classify the current node.
fn current_player(&self, s: &Self::State) -> Player;
/// Legal action indices at a decision node (`current_player` is `P1`/`P2`).
///
/// The returned indices are in `[0, action_space())`.
/// The result is unspecified (may panic or return empty) when called at a
/// `Chance` or `Terminal` node.
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
// ── State mutation ────────────────────────────────────────────────────
/// Apply a player action. `action` must be a value returned by
/// [`legal_actions`] for the current state.
fn apply(&self, s: &mut Self::State, action: usize);
/// Sample and apply a stochastic outcome. Must only be called when
/// `current_player(s) == Player::Chance`.
fn apply_chance<R: rand::Rng>(&self, s: &mut Self::State, rng: &mut R);
// ── Observation ───────────────────────────────────────────────────────
/// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2).
/// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`.
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
/// Number of floats returned by [`observation`].
fn obs_size(&self) -> usize;
/// Total number of distinct action indices (the policy head output size).
fn action_space(&self) -> usize;
// ── Terminal values ───────────────────────────────────────────────────
/// Game outcome for each player, or `None` if the game is not over.
///
/// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw.
/// Index 0 = Player 1, index 1 = Player 2.
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
}

535
spiel_bot/src/env/trictrac.rs vendored Normal file
View file

@ -0,0 +1,535 @@
//! [`GameEnv`] implementation for Trictrac.
//!
//! # Game flow (schools_enabled = false)
//!
//! With scoring schools disabled (the standard training configuration),
//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine
//! applies them automatically inside `RollResult` and `Move`. The only
//! four stages that actually occur are:
//!
//! | `TurnStage` | [`Player`] kind | Handled by |
//! |-------------|-----------------|------------|
//! | `RollDice` | `Chance` | [`apply_chance`] |
//! | `RollWaiting` | `Chance` | [`apply_chance`] |
//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] |
//! | `Move` | `P1`/`P2` | [`apply`] |
//!
//! # Perspective
//!
//! The Trictrac engine always reasons from White's perspective. Player 1 is
//! White; Player 2 is Black. When Player 2 is active, the board is mirrored
//! before computing legal actions / the observation tensor, and the resulting
//! event is mirrored back before being applied to the real state. This
//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`.
use trictrac_store::{
training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE},
Dice, GameEvent, GameState, Stage, TurnStage,
};
use super::{GameEnv, Player};
/// Stateless factory that produces Trictrac [`GameState`] environments.
///
/// Schools (`schools_enabled`) are always disabled — scoring is automatic.
#[derive(Clone, Debug, Default)]
pub struct TrictracEnv;
impl GameEnv for TrictracEnv {
type State = GameState;
// ── State creation ────────────────────────────────────────────────────
fn new_game(&self) -> GameState {
GameState::new_with_players("P1", "P2")
}
// ── Node queries ──────────────────────────────────────────────────────
fn current_player(&self, s: &GameState) -> Player {
if s.stage == Stage::Ended {
return Player::Terminal;
}
match s.turn_stage {
TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance,
_ => {
if s.active_player_id == 1 {
Player::P1
} else {
Player::P2
}
}
}
}
/// Returns the legal action indices for the active player.
///
/// The board is automatically mirrored for Player 2 so that the engine
/// always reasons from White's perspective. The returned indices are
/// identical in meaning for both players (checker ordinals are
/// perspective-relative).
///
/// # Panics
///
/// Panics in debug builds if called at a `Chance` or `Terminal` node.
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
debug_assert!(
self.current_player(s).is_decision(),
"legal_actions called at a non-decision node (turn_stage={:?})",
s.turn_stage
);
let indices = if s.active_player_id == 2 {
get_valid_action_indices(&s.mirror())
} else {
get_valid_action_indices(s)
};
indices.unwrap_or_default()
}
// ── State mutation ────────────────────────────────────────────────────
/// Apply a player action index to the game state.
///
/// For Player 2, the action is decoded against the mirrored board and
/// the resulting event is un-mirrored before being applied.
///
/// # Panics
///
/// Panics in debug builds if `action` cannot be decoded or does not
/// produce a valid event for the current state.
fn apply(&self, s: &mut GameState, action: usize) {
let needs_mirror = s.active_player_id == 2;
let event = if needs_mirror {
let view = s.mirror();
TrictracAction::from_action_index(action)
.and_then(|a| a.to_event(&view))
.map(|e| e.get_mirror(false))
} else {
TrictracAction::from_action_index(action).and_then(|a| a.to_event(s))
};
match event {
Some(e) => {
s.consume(&e).expect("apply: consume failed for valid action");
}
None => {
panic!("apply: action index {action} produced no event in state {s}");
}
}
}
/// Sample dice and advance through a chance node.
///
/// Handles both `RollDice` (triggers the roll mechanism, then samples
/// dice) and `RollWaiting` (only samples dice) in a single call so that
/// callers never need to distinguish the two.
///
/// # Panics
///
/// Panics in debug builds if called at a non-Chance node.
fn apply_chance<R: rand::Rng>(&self, s: &mut GameState, rng: &mut R) {
debug_assert!(
self.current_player(s).is_chance(),
"apply_chance called at a non-Chance node (turn_stage={:?})",
s.turn_stage
);
// Step 1: RollDice → RollWaiting (player initiates the roll).
if s.turn_stage == TurnStage::RollDice {
s.consume(&GameEvent::Roll {
player_id: s.active_player_id,
})
.expect("apply_chance: Roll event failed");
}
// Step 2: RollWaiting → Move / HoldOrGoChoice / Ended.
// With schools_enabled=false, point marking is automatic inside consume().
let dice = Dice {
values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)),
};
s.consume(&GameEvent::RollResult {
player_id: s.active_player_id,
dice,
})
.expect("apply_chance: RollResult event failed");
}
// ── Observation ───────────────────────────────────────────────────────
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
if pov == 0 {
s.to_tensor()
} else {
s.mirror().to_tensor()
}
}
fn obs_size(&self) -> usize {
217
}
fn action_space(&self) -> usize {
ACTION_SPACE_SIZE
}
// ── Terminal values ───────────────────────────────────────────────────
/// Returns `Some([r1, r2])` when the game is over, `None` otherwise.
///
/// The winner (higher cumulative score) receives `+1.0`; the loser
/// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score
/// is `holes × 12 + points`.
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
if s.stage != Stage::Ended {
return None;
}
let score = |id: u64| -> i32 {
s.players
.get(&id)
.map(|p| p.holes as i32 * 12 + p.points as i32)
.unwrap_or(0)
};
let s1 = score(1);
let s2 = score(2);
Some(match s1.cmp(&s2) {
std::cmp::Ordering::Greater => [1.0, -1.0],
std::cmp::Ordering::Less => [-1.0, 1.0],
std::cmp::Ordering::Equal => [0.0, 0.0],
})
}
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use rand::{rngs::SmallRng, Rng, SeedableRng};
fn env() -> TrictracEnv {
TrictracEnv
}
fn seeded_rng(seed: u64) -> SmallRng {
SmallRng::seed_from_u64(seed)
}
// ── Initial state ─────────────────────────────────────────────────────
#[test]
fn new_game_is_chance_node() {
let e = env();
let s = e.new_game();
// A fresh game starts at RollDice — a Chance node.
assert_eq!(e.current_player(&s), Player::Chance);
assert!(e.returns(&s).is_none());
}
#[test]
fn new_game_is_not_terminal() {
let e = env();
let s = e.new_game();
assert_ne!(e.current_player(&s), Player::Terminal);
assert!(e.returns(&s).is_none());
}
// ── Chance nodes ──────────────────────────────────────────────────────
#[test]
fn apply_chance_reaches_decision_node() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(1);
// A single chance step must yield a decision node (or end the game,
// which only happens after 12 holes — impossible on the first roll).
e.apply_chance(&mut s, &mut rng);
let p = e.current_player(&s);
assert!(
p.is_decision(),
"expected decision node after first roll, got {p:?}"
);
}
#[test]
fn apply_chance_from_rollwaiting() {
// Check that apply_chance works when called mid-way (at RollWaiting).
let e = env();
let mut s = e.new_game();
assert_eq!(s.turn_stage, TurnStage::RollDice);
// Manually advance to RollWaiting.
s.consume(&GameEvent::Roll { player_id: s.active_player_id })
.unwrap();
assert_eq!(s.turn_stage, TurnStage::RollWaiting);
let mut rng = seeded_rng(2);
e.apply_chance(&mut s, &mut rng);
let p = e.current_player(&s);
assert!(p.is_decision() || p.is_terminal());
}
// ── Legal actions ─────────────────────────────────────────────────────
#[test]
fn legal_actions_nonempty_after_roll() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(3);
e.apply_chance(&mut s, &mut rng);
assert!(e.current_player(&s).is_decision());
let actions = e.legal_actions(&s);
assert!(
!actions.is_empty(),
"legal_actions must be non-empty at a decision node"
);
}
#[test]
fn legal_actions_within_action_space() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(4);
e.apply_chance(&mut s, &mut rng);
for &a in e.legal_actions(&s).iter() {
assert!(
a < e.action_space(),
"action {a} out of bounds (action_space={})",
e.action_space()
);
}
}
// ── Observations ──────────────────────────────────────────────────────
#[test]
fn observation_has_correct_size() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(5);
e.apply_chance(&mut s, &mut rng);
assert_eq!(e.observation(&s, 0).len(), e.obs_size());
assert_eq!(e.observation(&s, 1).len(), e.obs_size());
}
#[test]
fn observation_values_in_unit_interval() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(6);
e.apply_chance(&mut s, &mut rng);
for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] {
for (i, &v) in obs.iter().enumerate() {
assert!(
v >= 0.0 && v <= 1.0,
"pov={pov}: obs[{i}] = {v} is outside [0,1]"
);
}
}
}
#[test]
fn p1_and_p2_observations_differ() {
// The board is mirrored for P2, so the two observations should differ
// whenever there are checkers in non-symmetric positions (always true
// in a real game after a few moves).
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(7);
// Advance far enough that the board is non-trivial.
for _ in 0..6 {
while e.current_player(&s).is_chance() {
e.apply_chance(&mut s, &mut rng);
}
if e.current_player(&s).is_terminal() {
break;
}
let actions = e.legal_actions(&s);
e.apply(&mut s, actions[0]);
}
if !e.current_player(&s).is_terminal() {
let obs0 = e.observation(&s, 0);
let obs1 = e.observation(&s, 1);
assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board");
}
}
// ── Applying actions ──────────────────────────────────────────────────
#[test]
fn apply_changes_state() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(8);
e.apply_chance(&mut s, &mut rng);
assert!(e.current_player(&s).is_decision());
let before = s.clone();
let action = e.legal_actions(&s)[0];
e.apply(&mut s, action);
assert_ne!(
before.turn_stage, s.turn_stage,
"state must change after apply"
);
}
#[test]
fn apply_all_legal_actions_do_not_panic() {
// Verify that every action returned by legal_actions can be applied
// without panicking (on several independent copies of the same state).
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(9);
e.apply_chance(&mut s, &mut rng);
assert!(e.current_player(&s).is_decision());
for action in e.legal_actions(&s) {
let mut copy = s.clone();
e.apply(&mut copy, action); // must not panic
}
}
// ── Full game ─────────────────────────────────────────────────────────
/// Run a complete game with random actions through the `GameEnv` trait
/// and verify that:
/// - The game terminates.
/// - `returns()` is `Some` at the end.
/// - The outcome is valid: scores sum to 0 (zero-sum) or each player's
/// score is ±1 / 0.
/// - No step panics.
#[test]
fn full_random_game_terminates() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(42);
let max_steps = 50_000;
for step in 0..max_steps {
match e.current_player(&s) {
Player::Terminal => break,
Player::Chance => e.apply_chance(&mut s, &mut rng),
Player::P1 | Player::P2 => {
let actions = e.legal_actions(&s);
assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node");
let idx = rng.random_range(0..actions.len());
e.apply(&mut s, actions[idx]);
}
}
assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps");
}
let result = e.returns(&s);
assert!(result.is_some(), "returns() must be Some at Terminal");
let [r1, r2] = result.unwrap();
let sum = r1 + r2;
assert!(
(sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5,
"game must be zero-sum: r1={r1}, r2={r2}, sum={sum}"
);
assert!(
r1.abs() <= 1.0 && r2.abs() <= 1.0,
"returns must be in [-1,1]: r1={r1}, r2={r2}"
);
}
/// Run multiple games with different seeds to stress-test for panics.
#[test]
fn multiple_games_no_panic() {
let e = env();
let max_steps = 20_000;
for seed in 0..10u64 {
let mut s = e.new_game();
let mut rng = seeded_rng(seed);
for _ in 0..max_steps {
match e.current_player(&s) {
Player::Terminal => break,
Player::Chance => e.apply_chance(&mut s, &mut rng),
Player::P1 | Player::P2 => {
let actions = e.legal_actions(&s);
let idx = rng.random_range(0..actions.len());
e.apply(&mut s, actions[idx]);
}
}
}
}
}
// ── Returns ───────────────────────────────────────────────────────────
#[test]
fn returns_none_mid_game() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(11);
// Advance a few steps but do not finish the game.
for _ in 0..4 {
match e.current_player(&s) {
Player::Terminal => break,
Player::Chance => e.apply_chance(&mut s, &mut rng),
Player::P1 | Player::P2 => {
let actions = e.legal_actions(&s);
e.apply(&mut s, actions[0]);
}
}
}
if !e.current_player(&s).is_terminal() {
assert!(
e.returns(&s).is_none(),
"returns() must be None before the game ends"
);
}
}
// ── Player 2 actions ──────────────────────────────────────────────────
/// Verify that Player 2 (Black) can take actions without panicking,
/// and that the state advances correctly.
#[test]
fn player2_can_act() {
let e = env();
let mut s = e.new_game();
let mut rng = seeded_rng(12);
// Keep stepping until Player 2 gets a turn.
let max_steps = 5_000;
let mut p2_acted = false;
for _ in 0..max_steps {
match e.current_player(&s) {
Player::Terminal => break,
Player::Chance => e.apply_chance(&mut s, &mut rng),
Player::P2 => {
let actions = e.legal_actions(&s);
assert!(!actions.is_empty());
e.apply(&mut s, actions[0]);
p2_acted = true;
break;
}
Player::P1 => {
let actions = e.legal_actions(&s);
e.apply(&mut s, actions[0]);
}
}
}
assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps");
}
}

4
spiel_bot/src/lib.rs Normal file
View file

@ -0,0 +1,4 @@
pub mod alphazero;
pub mod env;
pub mod mcts;
pub mod network;

412
spiel_bot/src/mcts/mod.rs Normal file
View file

@ -0,0 +1,412 @@
//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance.
//!
//! # Algorithm
//!
//! The implementation follows AlphaZero's MCTS:
//!
//! 1. **Expand root** — run the network once to get priors and a value
//! estimate; optionally add Dirichlet noise for training-time exploration.
//! 2. **Simulate** `n_simulations` times:
//! - *Selection* — traverse the tree with PUCT until an unvisited leaf.
//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes;
//! chance nodes are **not** stored in the tree (outcome sampling).
//! - *Expansion* — evaluate the network at the leaf; populate children.
//! - *Backup* — propagate the value upward; negate at each player boundary.
//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]).
//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]).
//!
//! # Perspective convention
//!
//! Every [`MctsNode::w`] is stored **from the perspective of the player who
//! acts at that node**. The backup negates the child value whenever the
//! acting player differs between parent and child.
//!
//! # Stochastic games
//!
//! When [`GameEnv::current_player`] returns [`Player::Chance`], the
//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and
//! continues. Chance nodes are skipped transparently; Q-values converge to
//! their expectation over many simulations (outcome sampling).
pub mod node;
mod search;
pub use node::MctsNode;
use rand::Rng;
use crate::env::GameEnv;
// ── Evaluator trait ────────────────────────────────────────────────────────
/// Evaluates a game position for use in MCTS.
///
/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet)
/// but the `mcts` module itself does **not** depend on Burn.
pub trait Evaluator: Send + Sync {
/// Evaluate `obs` (flat observation vector of length `obs_size`).
///
/// Returns:
/// - `policy_logits`: one raw logit per action (`action_space` entries).
/// Illegal action entries are masked inside the search — no need to
/// zero them here.
/// - `value`: scalar in `(-1, 1)` from **the current player's** perspective.
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
}
// ── Configuration ─────────────────────────────────────────────────────────
/// Hyperparameters for [`run_mcts`].
#[derive(Debug, Clone)]
pub struct MctsConfig {
/// Number of MCTS simulations per move. Typical: 50800.
pub n_simulations: usize,
/// PUCT exploration constant `c_puct`. Typical: 1.02.0.
pub c_puct: f32,
/// Dirichlet noise concentration α. Set to `0.0` to disable.
/// Typical: `0.3` for Chess, `0.1` for large action spaces.
pub dirichlet_alpha: f32,
/// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`.
pub dirichlet_eps: f32,
/// Action sampling temperature. `> 0` = proportional sample, `0` = argmax.
pub temperature: f32,
}
impl Default for MctsConfig {
fn default() -> Self {
Self {
n_simulations: 200,
c_puct: 1.5,
dirichlet_alpha: 0.3,
dirichlet_eps: 0.25,
temperature: 1.0,
}
}
}
// ── Public interface ───────────────────────────────────────────────────────
/// Run MCTS from `state` and return the populated root [`MctsNode`].
///
/// `state` must be a player-decision node (`P1` or `P2`).
/// Use [`mcts_policy`] and [`select_action`] on the returned root.
///
/// # Panics
///
/// Panics if `env.current_player(state)` is not `P1` or `P2`.
pub fn run_mcts<E: GameEnv>(
env: &E,
state: &E::State,
evaluator: &dyn Evaluator,
config: &MctsConfig,
rng: &mut impl Rng,
) -> MctsNode {
let player_idx = env
.current_player(state)
.index()
.expect("run_mcts called at a non-decision node");
// ── Expand root (network called once here, not inside the loop) ────────
let mut root = MctsNode::new(1.0);
search::expand::<E>(&mut root, state, env, evaluator, player_idx);
// ── Optional Dirichlet noise for training exploration ──────────────────
if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
}
// ── Simulations ────────────────────────────────────────────────────────
for _ in 0..config.n_simulations {
search::simulate::<E>(
&mut root,
state.clone(),
env,
evaluator,
config,
rng,
player_idx,
);
}
root
}
/// Compute the MCTS policy: normalized visit counts at the root.
///
/// Returns a vector of length `action_space` where `policy[a]` is the
/// fraction of simulations that visited action `a`.
pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum();
let mut policy = vec![0.0f32; action_space];
if total > 0.0 {
for (a, child) in &root.children {
policy[*a] = child.n as f32 / total;
}
} else if !root.children.is_empty() {
// n_simulations = 0: uniform over legal actions.
let uniform = 1.0 / root.children.len() as f32;
for (a, _) in &root.children {
policy[*a] = uniform;
}
}
policy
}
/// Select an action index from the root after MCTS.
///
/// * `temperature = 0` — greedy argmax of visit counts.
/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
///
/// # Panics
///
/// Panics if the root has no children.
pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize {
assert!(!root.children.is_empty(), "select_action called on a root with no children");
if temperature <= 0.0 {
root.children
.iter()
.max_by_key(|(_, c)| c.n)
.map(|(a, _)| *a)
.unwrap()
} else {
let weights: Vec<f32> = root
.children
.iter()
.map(|(_, c)| (c.n as f32).powf(1.0 / temperature))
.collect();
let total: f32 = weights.iter().sum();
let mut r: f32 = rng.random::<f32>() * total;
for (i, (a, _)) in root.children.iter().enumerate() {
r -= weights[i];
if r <= 0.0 {
return *a;
}
}
root.children.last().map(|(a, _)| *a).unwrap()
}
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use rand::{SeedableRng, rngs::SmallRng};
use crate::env::Player;
// ── Minimal deterministic test game ───────────────────────────────────
//
// "Countdown" — two players alternate subtracting 1 or 2 from a counter.
// The player who brings the counter to 0 wins.
// No chance nodes, two legal actions (0 = -1, 1 = -2).
#[derive(Clone, Debug)]
struct CState {
remaining: u8,
to_move: usize, // at terminal: last mover (winner)
}
#[derive(Clone)]
struct CountdownEnv;
impl crate::env::GameEnv for CountdownEnv {
type State = CState;
fn new_game(&self) -> CState {
CState { remaining: 6, to_move: 0 }
}
fn current_player(&self, s: &CState) -> Player {
if s.remaining == 0 {
Player::Terminal
} else if s.to_move == 0 {
Player::P1
} else {
Player::P2
}
}
fn legal_actions(&self, s: &CState) -> Vec<usize> {
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
}
fn apply(&self, s: &mut CState, action: usize) {
let sub = (action as u8) + 1;
if s.remaining <= sub {
s.remaining = 0;
// to_move stays as winner
} else {
s.remaining -= sub;
s.to_move = 1 - s.to_move;
}
}
fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
vec![s.remaining as f32 / 6.0, s.to_move as f32]
}
fn obs_size(&self) -> usize { 2 }
fn action_space(&self) -> usize { 2 }
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
if s.remaining != 0 { return None; }
let mut r = [-1.0f32; 2];
r[s.to_move] = 1.0;
Some(r)
}
}
// Uniform evaluator: all logits = 0, value = 0.
// `action_space` must match the environment's `action_space()`.
struct ZeroEval(usize);
impl Evaluator for ZeroEval {
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
(vec![0.0f32; self.0], 0.0)
}
}
fn rng() -> SmallRng {
SmallRng::seed_from_u64(42)
}
fn config_n(n: usize) -> MctsConfig {
MctsConfig {
n_simulations: n,
c_puct: 1.5,
dirichlet_alpha: 0.0, // off for reproducibility
dirichlet_eps: 0.0,
temperature: 1.0,
}
}
// ── Visit count tests ─────────────────────────────────────────────────
#[test]
fn visit_counts_sum_to_n_simulations() {
let env = CountdownEnv;
let state = env.new_game();
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng());
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
assert_eq!(total, 50, "visit counts must sum to n_simulations");
}
#[test]
fn all_root_children_are_legal() {
let env = CountdownEnv;
let state = env.new_game();
let legal = env.legal_actions(&state);
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng());
for (a, _) in &root.children {
assert!(legal.contains(a), "child action {a} is not legal");
}
}
// ── Policy tests ─────────────────────────────────────────────────────
#[test]
fn policy_sums_to_one() {
let env = CountdownEnv;
let state = env.new_game();
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng());
let policy = mcts_policy(&root, env.action_space());
let sum: f32 = policy.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0");
}
#[test]
fn policy_zero_for_illegal_actions() {
let env = CountdownEnv;
// remaining = 1 → only action 0 is legal
let state = CState { remaining: 1, to_move: 0 };
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng());
let policy = mcts_policy(&root, env.action_space());
assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass");
}
// ── Action selection tests ────────────────────────────────────────────
#[test]
fn greedy_selects_most_visited() {
let env = CountdownEnv;
let state = env.new_game();
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng());
let greedy = select_action(&root, 0.0, &mut rng());
let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap();
assert_eq!(greedy, most_visited);
}
#[test]
fn temperature_sampling_stays_legal() {
let env = CountdownEnv;
let state = env.new_game();
let legal = env.legal_actions(&state);
let mut r = rng();
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r);
for _ in 0..20 {
let a = select_action(&root, 1.0, &mut r);
assert!(legal.contains(&a), "sampled action {a} is not legal");
}
}
// ── Zero-simulation edge case ─────────────────────────────────────────
#[test]
fn zero_simulations_uniform_policy() {
let env = CountdownEnv;
let state = env.new_game();
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng());
let policy = mcts_policy(&root, env.action_space());
// With 0 simulations, fallback is uniform over the 2 legal actions.
let sum: f32 = policy.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
// ── Root value ────────────────────────────────────────────────────────
#[test]
fn root_q_in_valid_range() {
let env = CountdownEnv;
let state = env.new_game();
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng());
let q = root.q();
assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]");
}
// ── Integration: run on a real Trictrac game ──────────────────────────
#[test]
fn no_panic_on_trictrac_state() {
use crate::env::TrictracEnv;
let env = TrictracEnv;
let mut state = env.new_game();
let mut r = rng();
// Advance past the initial chance node to reach a decision node.
while env.current_player(&state).is_chance() {
env.apply_chance(&mut state, &mut r);
}
if env.current_player(&state).is_terminal() {
return; // unlikely but safe
}
let config = MctsConfig {
n_simulations: 5, // tiny for speed
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
..MctsConfig::default()
};
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
assert_eq!(root.n, 1 + config.n_simulations as u32);
// Children visit counts may sum to less than n_simulations when some
// simulations cross a chance node at depth 1 (turn ends after one move)
// and evaluate with the network directly without updating child.n.
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
assert!(total <= config.n_simulations as u32);
}
}

View file

@ -0,0 +1,91 @@
//! MCTS tree node.
//!
//! [`MctsNode`] holds the visit statistics for one player-decision position in
//! the search tree. A node is *expanded* the first time the policy-value
//! network is evaluated there; before that it is a leaf.
/// One node in the MCTS tree, representing a player-decision position.
///
/// `w` stores the sum of values backed up into this node, always from the
/// perspective of **the player who acts here**. `q()` therefore also returns
/// a value in `(-1, 1)` from that same perspective.
#[derive(Debug)]
pub struct MctsNode {
/// Visit count `N(s, a)`.
pub n: u32,
/// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective.
pub w: f32,
/// Prior probability `P(s, a)` assigned by the policy head (after masked softmax).
pub p: f32,
/// Children: `(action_index, child_node)`, populated on first expansion.
pub children: Vec<(usize, MctsNode)>,
/// `true` after the network has been evaluated and children have been set up.
pub expanded: bool,
}
impl MctsNode {
/// Create a fresh, unexpanded leaf with the given prior probability.
pub fn new(prior: f32) -> Self {
Self {
n: 0,
w: 0.0,
p: prior,
children: Vec::new(),
expanded: false,
}
}
/// `Q(s, a) = W / N`, or `0.0` if this node has never been visited.
#[inline]
pub fn q(&self) -> f32 {
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
}
/// PUCT selection score:
///
/// ```text
/// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a))
/// ```
#[inline]
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
}
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn q_zero_when_unvisited() {
let node = MctsNode::new(0.5);
assert_eq!(node.q(), 0.0);
}
#[test]
fn q_reflects_w_over_n() {
let mut node = MctsNode::new(0.5);
node.n = 4;
node.w = 2.0;
assert!((node.q() - 0.5).abs() < 1e-6);
}
#[test]
fn puct_exploration_dominates_unvisited() {
// Unvisited child should outscore a visited child with negative Q.
let mut visited = MctsNode::new(0.5);
visited.n = 10;
visited.w = -5.0; // Q = -0.5
let unvisited = MctsNode::new(0.5);
let parent_n = 10;
let c = 1.5;
assert!(
unvisited.puct(parent_n, c) > visited.puct(parent_n, c),
"unvisited child should have higher PUCT than a negatively-valued visited child"
);
}
}

View file

@ -0,0 +1,184 @@
//! Simulation, expansion, backup, and noise helpers.
//!
//! These are internal to the `mcts` module; the public entry points are
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
use rand::Rng;
use rand_distr::{Gamma, Distribution};
use crate::env::GameEnv;
use super::{Evaluator, MctsConfig};
use super::node::MctsNode;
// ── Masked softmax ─────────────────────────────────────────────────────────
/// Numerically stable softmax over `legal` actions only.
///
/// Illegal logits are treated as `-∞` and receive probability `0.0`.
/// Returns a probability vector of length `action_space`.
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
let mut probs = vec![0.0f32; action_space];
if legal.is_empty() {
return probs;
}
let max_logit = legal
.iter()
.map(|&a| logits[a])
.fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for &a in legal {
let e = (logits[a] - max_logit).exp();
probs[a] = e;
sum += e;
}
if sum > 0.0 {
for &a in legal {
probs[a] /= sum;
}
} else {
let uniform = 1.0 / legal.len() as f32;
for &a in legal {
probs[a] = uniform;
}
}
probs
}
// ── Dirichlet noise ────────────────────────────────────────────────────────
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
///
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
pub(super) fn add_dirichlet_noise(
node: &mut MctsNode,
alpha: f32,
eps: f32,
rng: &mut impl Rng,
) {
let n = node.children.len();
if n == 0 {
return;
}
let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
return;
};
let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
let sum: f32 = samples.iter().sum();
if sum <= 0.0 {
return;
}
for (i, (_, child)) in node.children.iter_mut().enumerate() {
let noise = samples[i] / sum;
child.p = (1.0 - eps) * child.p + eps * noise;
}
}
// ── Expansion ──────────────────────────────────────────────────────────────
/// Evaluate the network at `state` and populate `node` with children.
///
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
/// Returns the network value estimate from `player_idx`'s perspective.
pub(super) fn expand<E: GameEnv>(
node: &mut MctsNode,
state: &E::State,
env: &E,
evaluator: &dyn Evaluator,
player_idx: usize,
) -> f32 {
let obs = env.observation(state, player_idx);
let legal = env.legal_actions(state);
let (logits, value) = evaluator.evaluate(&obs);
let priors = masked_softmax(&logits, &legal, env.action_space());
node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
node.expanded = true;
node.n = 1;
node.w = value;
value
}
// ── Simulation ─────────────────────────────────────────────────────────────
/// One MCTS simulation from an **already-expanded** decision node.
///
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
/// and backs up the result.
///
/// * `player_idx` — the player (0 or 1) who acts at `state`.
/// * Returns the backed-up value **from `player_idx`'s perspective**.
pub(super) fn simulate<E: GameEnv>(
node: &mut MctsNode,
state: E::State,
env: &E,
evaluator: &dyn Evaluator,
config: &MctsConfig,
rng: &mut impl Rng,
player_idx: usize,
) -> f32 {
debug_assert!(node.expanded, "simulate called on unexpanded node");
// ── Selection: child with highest PUCT ────────────────────────────────
let parent_n = node.n;
let best = node
.children
.iter()
.enumerate()
.max_by(|(_, (_, a)), (_, (_, b))| {
a.puct(parent_n, config.c_puct)
.partial_cmp(&b.puct(parent_n, config.c_puct))
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.expect("expanded node must have at least one child");
let (action, child) = &mut node.children[best];
let action = *action;
// ── Apply action + advance through any chance nodes ───────────────────
let mut next_state = state;
env.apply(&mut next_state, action);
// Track whether we crossed a chance node (dice roll) on the way down.
// If we did, the child's cached legal actions are for a *different* dice
// outcome and must not be reused — evaluate with the network directly.
let mut crossed_chance = false;
while env.current_player(&next_state).is_chance() {
env.apply_chance(&mut next_state, rng);
crossed_chance = true;
}
let next_cp = env.current_player(&next_state);
// ── Evaluate leaf or terminal ──────────────────────────────────────────
// All values are converted to `player_idx`'s perspective before backup.
let child_value = if next_cp.is_terminal() {
let returns = env
.returns(&next_state)
.expect("terminal node must have returns");
returns[player_idx]
} else {
let child_player = next_cp.index().unwrap();
let v = if crossed_chance {
// Outcome sampling: after dice, evaluate the resulting position
// directly with the network. Do NOT build the tree across chance
// boundaries — the dice change which actions are legal, so any
// previously cached children would be for a different outcome.
let obs = env.observation(&next_state, child_player);
let (_, value) = evaluator.evaluate(&obs);
value
} else if child.expanded {
simulate(child, next_state, env, evaluator, config, rng, child_player)
} else {
expand::<E>(child, &next_state, env, evaluator, child_player)
};
// Negate when the child belongs to the opponent.
if child_player == player_idx { v } else { -v }
};
// ── Backup ────────────────────────────────────────────────────────────
node.n += 1;
node.w += child_value;
child_value
}

View file

@ -0,0 +1,223 @@
//! Two-hidden-layer MLP policy-value network.
//!
//! ```text
//! Input [B, obs_size]
//! → Linear(obs → hidden) → ReLU
//! → Linear(hidden → hidden) → ReLU
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
//! ```
use burn::{
module::Module,
nn::{Linear, LinearConfig},
record::{CompactRecorder, Recorder},
tensor::{
activation::{relu, tanh},
backend::Backend,
Tensor,
},
};
use std::path::Path;
use super::PolicyValueNet;
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for [`MlpNet`].
#[derive(Debug, Clone)]
pub struct MlpConfig {
/// Number of input features. 217 for Trictrac's `to_tensor()`.
pub obs_size: usize,
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
pub action_size: usize,
/// Width of both hidden layers.
pub hidden_size: usize,
}
impl Default for MlpConfig {
fn default() -> Self {
Self {
obs_size: 217,
action_size: 514,
hidden_size: 256,
}
}
}
// ── Network ───────────────────────────────────────────────────────────────────
/// Simple two-hidden-layer MLP with shared trunk and two heads.
///
/// Prefer this over [`ResNet`](super::ResNet) when training time is a
/// priority, or as a fast baseline.
#[derive(Module, Debug)]
pub struct MlpNet<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
policy_head: Linear<B>,
value_head: Linear<B>,
}
impl<B: Backend> MlpNet<B> {
/// Construct a fresh network with random weights.
pub fn new(config: &MlpConfig, device: &B::Device) -> Self {
Self {
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
value_head: LinearConfig::new(config.hidden_size, 1).init(device),
}
}
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
///
/// The file is written exactly at `path`; callers should append `.mpk` if
/// they want the conventional extension.
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
CompactRecorder::new()
.record(self.clone().into_record(), path.to_path_buf())
.map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}"))
}
/// Load weights from `path` into a fresh model built from `config`.
pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
let record = CompactRecorder::new()
.load(path.to_path_buf(), device)
.map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?;
Ok(Self::new(config, device).load_record(record))
}
}
impl<B: Backend> PolicyValueNet<B> for MlpNet<B> {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
let x = relu(self.fc1.forward(obs));
let x = relu(self.fc2.forward(x));
let policy = self.policy_head.forward(x.clone());
let value = tanh(self.value_head.forward(x));
(policy, value)
}
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type B = NdArray<f32>;
fn device() -> <B as Backend>::Device {
Default::default()
}
fn default_net() -> MlpNet<B> {
MlpNet::new(&MlpConfig::default(), &device())
}
fn zeros_obs(batch: usize) -> Tensor<B, 2> {
Tensor::zeros([batch, 217], &device())
}
// ── Shape tests ───────────────────────────────────────────────────────
#[test]
fn forward_output_shapes() {
let net = default_net();
let obs = zeros_obs(4);
let (policy, value) = net.forward(obs);
assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
assert_eq!(value.dims(), [4, 1], "value shape mismatch");
}
#[test]
fn forward_single_sample() {
let net = default_net();
let (policy, value) = net.forward(zeros_obs(1));
assert_eq!(policy.dims(), [1, 514]);
assert_eq!(value.dims(), [1, 1]);
}
// ── Value bounds ──────────────────────────────────────────────────────
#[test]
fn value_in_tanh_range() {
let net = default_net();
// Use a non-zero input so the output is not trivially at 0.
let obs = Tensor::<B, 2>::ones([8, 217], &device());
let (_, value) = net.forward(obs);
let data: Vec<f32> = value.into_data().to_vec().unwrap();
for v in &data {
assert!(
*v > -1.0 && *v < 1.0,
"value {v} is outside open interval (-1, 1)"
);
}
}
// ── Policy logits ─────────────────────────────────────────────────────
#[test]
fn policy_logits_not_all_equal() {
// With random weights the 514 logits should not all be identical.
let net = default_net();
let (policy, _) = net.forward(zeros_obs(1));
let data: Vec<f32> = policy.into_data().to_vec().unwrap();
let first = data[0];
let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
assert!(!all_same, "all policy logits are identical — network may be degenerate");
}
// ── Config propagation ────────────────────────────────────────────────
#[test]
fn custom_config_shapes() {
let config = MlpConfig {
obs_size: 10,
action_size: 20,
hidden_size: 32,
};
let net = MlpNet::<B>::new(&config, &device());
let obs = Tensor::zeros([3, 10], &device());
let (policy, value) = net.forward(obs);
assert_eq!(policy.dims(), [3, 20]);
assert_eq!(value.dims(), [3, 1]);
}
// ── Save / Load ───────────────────────────────────────────────────────
#[test]
fn save_load_preserves_weights() {
let config = MlpConfig::default();
let net = default_net();
// Forward pass before saving.
let obs = Tensor::<B, 2>::ones([2, 217], &device());
let (policy_before, value_before) = net.forward(obs.clone());
// Save to a temp file.
let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk");
net.save(&path).expect("save failed");
// Load into a fresh model.
let loaded = MlpNet::<B>::load(&config, &path, &device()).expect("load failed");
let (policy_after, value_after) = loaded.forward(obs);
// Outputs must be bitwise identical.
let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
}
let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
}
let _ = std::fs::remove_file(path);
}
}

View file

@ -0,0 +1,64 @@
//! Neural network abstractions for policy-value learning.
//!
//! # Trait
//!
//! [`PolicyValueNet<B>`] is the single trait that all network architectures
//! implement. It takes an observation tensor and returns raw policy logits
//! plus a tanh-squashed scalar value estimate.
//!
//! # Architectures
//!
//! | Module | Description | Default hidden |
//! |--------|-------------|----------------|
//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
//!
//! # Backend convention
//!
//! * **Inference / self-play** — use `NdArray<f32>` (no autodiff overhead).
//! * **Training** — use `Autodiff<NdArray<f32>>` so Burn can differentiate
//! through the forward pass.
//!
//! Both modes use the exact same struct; only the type-level backend changes:
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, NdArray};
//! type InferBackend = NdArray<f32>;
//! type TrainBackend = Autodiff<NdArray<f32>>;
//!
//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
//! ```
//!
//! # Output shapes
//!
//! Given a batch of `B` observations of size `obs_size`:
//!
//! | Output | Shape | Range |
//! |--------|-------|-------|
//! | `policy_logits` | `[B, action_size]` | (unnormalised) |
//! | `value` | `[B, 1]` | (-1, 1) via tanh |
//!
//! Callers are responsible for masking illegal actions in `policy_logits`
//! before passing to softmax.
pub mod mlp;
pub mod resnet;
pub use mlp::{MlpConfig, MlpNet};
pub use resnet::{ResNet, ResNetConfig};
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
/// A neural network that produces a policy and a value from an observation.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
}

View file

@ -0,0 +1,253 @@
//! Residual-block policy-value network.
//!
//! ```text
//! Input [B, obs_size]
//! → Linear(obs → hidden) → ReLU (input projection)
//! → ResBlock × 4 (residual trunk)
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
//!
//! ResBlock:
//! x → Linear → ReLU → Linear → (+x) → ReLU
//! ```
//!
//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better
//! suited for long training runs where board-pattern recognition matters.
use burn::{
module::Module,
nn::{Linear, LinearConfig},
record::{CompactRecorder, Recorder},
tensor::{
activation::{relu, tanh},
backend::Backend,
Tensor,
},
};
use std::path::Path;
use super::PolicyValueNet;
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for [`ResNet`].
#[derive(Debug, Clone)]
pub struct ResNetConfig {
/// Number of input features. 217 for Trictrac's `to_tensor()`.
pub obs_size: usize,
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
pub action_size: usize,
/// Width of all hidden layers (input projection + residual blocks).
pub hidden_size: usize,
}
impl Default for ResNetConfig {
fn default() -> Self {
Self {
obs_size: 217,
action_size: 514,
hidden_size: 512,
}
}
}
// ── Residual block ────────────────────────────────────────────────────────────
/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`.
///
/// Both linear layers preserve the hidden dimension so the skip connection
/// can be added without projection.
#[derive(Module, Debug)]
struct ResBlock<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
}
impl<B: Backend> ResBlock<B> {
fn new(hidden: usize, device: &B::Device) -> Self {
Self {
fc1: LinearConfig::new(hidden, hidden).init(device),
fc2: LinearConfig::new(hidden, hidden).init(device),
}
}
fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
let residual = x.clone();
let out = relu(self.fc1.forward(x));
relu(self.fc2.forward(out) + residual)
}
}
// ── Network ───────────────────────────────────────────────────────────────────
/// Four-residual-block policy-value network.
///
/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and
/// when representing complex positional patterns is important.
#[derive(Module, Debug)]
pub struct ResNet<B: Backend> {
input: Linear<B>,
block0: ResBlock<B>,
block1: ResBlock<B>,
block2: ResBlock<B>,
block3: ResBlock<B>,
policy_head: Linear<B>,
value_head: Linear<B>,
}
impl<B: Backend> ResNet<B> {
/// Construct a fresh network with random weights.
pub fn new(config: &ResNetConfig, device: &B::Device) -> Self {
let h = config.hidden_size;
Self {
input: LinearConfig::new(config.obs_size, h).init(device),
block0: ResBlock::new(h, device),
block1: ResBlock::new(h, device),
block2: ResBlock::new(h, device),
block3: ResBlock::new(h, device),
policy_head: LinearConfig::new(h, config.action_size).init(device),
value_head: LinearConfig::new(h, 1).init(device),
}
}
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
CompactRecorder::new()
.record(self.clone().into_record(), path.to_path_buf())
.map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}"))
}
/// Load weights from `path` into a fresh model built from `config`.
pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
let record = CompactRecorder::new()
.load(path.to_path_buf(), device)
.map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?;
Ok(Self::new(config, device).load_record(record))
}
}
impl<B: Backend> PolicyValueNet<B> for ResNet<B> {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
let x = relu(self.input.forward(obs));
let x = self.block0.forward(x);
let x = self.block1.forward(x);
let x = self.block2.forward(x);
let x = self.block3.forward(x);
let policy = self.policy_head.forward(x.clone());
let value = tanh(self.value_head.forward(x));
(policy, value)
}
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type B = NdArray<f32>;
fn device() -> <B as Backend>::Device {
Default::default()
}
fn small_config() -> ResNetConfig {
// Use a small hidden size so tests are fast.
ResNetConfig {
obs_size: 217,
action_size: 514,
hidden_size: 64,
}
}
fn net() -> ResNet<B> {
ResNet::new(&small_config(), &device())
}
// ── Shape tests ───────────────────────────────────────────────────────
#[test]
fn forward_output_shapes() {
let obs = Tensor::zeros([4, 217], &device());
let (policy, value) = net().forward(obs);
assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
assert_eq!(value.dims(), [4, 1], "value shape mismatch");
}
#[test]
fn forward_single_sample() {
let (policy, value) = net().forward(Tensor::zeros([1, 217], &device()));
assert_eq!(policy.dims(), [1, 514]);
assert_eq!(value.dims(), [1, 1]);
}
// ── Value bounds ──────────────────────────────────────────────────────
#[test]
fn value_in_tanh_range() {
let obs = Tensor::<B, 2>::ones([8, 217], &device());
let (_, value) = net().forward(obs);
let data: Vec<f32> = value.into_data().to_vec().unwrap();
for v in &data {
assert!(
*v > -1.0 && *v < 1.0,
"value {v} is outside open interval (-1, 1)"
);
}
}
// ── Residual connections ──────────────────────────────────────────────
#[test]
fn policy_logits_not_all_equal() {
let (policy, _) = net().forward(Tensor::zeros([1, 217], &device()));
let data: Vec<f32> = policy.into_data().to_vec().unwrap();
let first = data[0];
let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
assert!(!all_same, "all policy logits are identical");
}
// ── Save / Load ───────────────────────────────────────────────────────
#[test]
fn save_load_preserves_weights() {
let config = small_config();
let model = net();
let obs = Tensor::<B, 2>::ones([2, 217], &device());
let (policy_before, value_before) = model.forward(obs.clone());
let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk");
model.save(&path).expect("save failed");
let loaded = ResNet::<B>::load(&config, &path, &device()).expect("load failed");
let (policy_after, value_after) = loaded.forward(obs);
let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
}
let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
}
let _ = std::fs::remove_file(path);
}
// ── Integration: both architectures satisfy PolicyValueNet ────────────
#[test]
fn resnet_satisfies_trait() {
fn requires_net<B: Backend, N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) {
let (p, v) = net.forward(obs);
assert_eq!(p.dims()[1], 514);
assert_eq!(v.dims()[1], 1);
}
requires_net(&net(), Tensor::zeros([2, 217], &device()));
}
}

View file

@ -0,0 +1,391 @@
//! End-to-end integration tests for the AlphaZero training pipeline.
//!
//! Each test exercises the full chain:
//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`]
//!
//! Two environments are used:
//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves.
//! Used when we need many iterations without worrying about runtime.
//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that
//! the full pipeline compiles and runs correctly with 217-dim observations
//! and 514-dim action spaces.
//!
//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep
//! runtime minimal; correctness, not training quality, is what matters here.
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
};
use rand::{SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::MctsConfig,
network::{MlpConfig, MlpNet, PolicyValueNet},
};
// ── Backend aliases ────────────────────────────────────────────────────────
type Train = Autodiff<NdArray<f32>>;
type Infer = NdArray<f32>;
// ── Helpers ────────────────────────────────────────────────────────────────
fn train_device() -> <Train as burn::tensor::backend::Backend>::Device {
Default::default()
}
fn infer_device() -> <Infer as burn::tensor::backend::Backend>::Device {
Default::default()
}
/// Tiny 64-unit MLP, compatible with an obs/action space of any size.
fn tiny_mlp(obs: usize, actions: usize) -> MlpNet<Train> {
let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 };
MlpNet::new(&cfg, &train_device())
}
fn tiny_mcts(n: usize) -> MctsConfig {
MctsConfig {
n_simulations: n,
c_puct: 1.5,
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
temperature: 1.0,
}
}
fn seeded() -> SmallRng {
SmallRng::seed_from_u64(0)
}
// ── Countdown environment (fast, local, no external deps) ─────────────────
//
// Two players alternate subtracting 1 or 2 from a counter that starts at N.
// The player who brings the counter to 0 wins.
#[derive(Clone, Debug)]
struct CState {
remaining: u8,
to_move: usize,
}
#[derive(Clone)]
struct CountdownEnv(u8); // starting value
impl GameEnv for CountdownEnv {
type State = CState;
fn new_game(&self) -> CState {
CState { remaining: self.0, to_move: 0 }
}
fn current_player(&self, s: &CState) -> Player {
if s.remaining == 0 { Player::Terminal }
else if s.to_move == 0 { Player::P1 }
else { Player::P2 }
}
fn legal_actions(&self, s: &CState) -> Vec<usize> {
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
}
fn apply(&self, s: &mut CState, action: usize) {
let sub = (action as u8) + 1;
if s.remaining <= sub {
s.remaining = 0;
} else {
s.remaining -= sub;
s.to_move = 1 - s.to_move;
}
}
fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
vec![s.remaining as f32 / self.0 as f32, s.to_move as f32]
}
fn obs_size(&self) -> usize { 2 }
fn action_space(&self) -> usize { 2 }
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
if s.remaining != 0 { return None; }
let mut r = [-1.0f32; 2];
r[s.to_move] = 1.0;
Some(r)
}
}
// ── 1. Full loop on CountdownEnv ──────────────────────────────────────────
/// The canonical AlphaZero loop: self-play → replay → train, iterated.
/// Uses CountdownEnv so each game terminates in < 10 moves.
#[test]
fn countdown_full_loop_no_panic() {
let env = CountdownEnv(8);
let mut rng = seeded();
let mcts = tiny_mcts(3);
let mut model = tiny_mlp(env.obs_size(), env.action_space());
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(1_000);
for _iter in 0..5 {
// Self-play: 3 games per iteration.
for _ in 0..3 {
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
assert!(!samples.is_empty());
replay.extend(samples);
}
// Training: 4 gradient steps per iteration.
if replay.len() >= 4 {
for _ in 0..4 {
let batch: Vec<TrainSample> = replay
.sample_batch(4, &mut rng)
.into_iter()
.cloned()
.collect();
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
model = m;
assert!(loss.is_finite(), "loss must be finite, got {loss}");
}
}
}
assert!(replay.len() > 0);
}
// ── 2. Replay buffer invariants ───────────────────────────────────────────
/// After several Countdown games, replay capacity is respected and batch
/// shapes are consistent.
#[test]
fn replay_buffer_capacity_and_shapes() {
let env = CountdownEnv(6);
let mut rng = seeded();
let mcts = tiny_mcts(2);
let model = tiny_mlp(env.obs_size(), env.action_space());
let capacity = 50;
let mut replay = ReplayBuffer::new(capacity);
for _ in 0..20 {
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
replay.extend(samples);
}
assert!(replay.len() <= capacity, "buffer exceeded capacity");
assert!(replay.len() > 0);
let batch = replay.sample_batch(8, &mut rng);
assert_eq!(batch.len(), 8.min(replay.len()));
for s in &batch {
assert_eq!(s.obs.len(), env.obs_size());
assert_eq!(s.policy.len(), env.action_space());
let policy_sum: f32 = s.policy.iter().sum();
assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}");
assert!(s.value.abs() <= 1.0, "value {} out of range", s.value);
}
}
// ── 3. TrictracEnv: sample shapes ─────────────────────────────────────────
/// Verify that one TrictracEnv episode produces samples with the correct
/// tensor dimensions: obs = 217, policy = 514.
#[test]
fn trictrac_sample_shapes() {
let env = TrictracEnv;
let mut rng = seeded();
let mcts = tiny_mcts(2);
let model = tiny_mlp(env.obs_size(), env.action_space());
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
assert!(!samples.is_empty(), "Trictrac episode produced no samples");
for (i, s) in samples.iter().enumerate() {
assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len());
assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len());
let policy_sum: f32 = s.policy.iter().sum();
assert!(
(policy_sum - 1.0).abs() < 1e-4,
"sample {i}: policy sums to {policy_sum}"
);
assert!(
s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
"sample {i}: unexpected value {}",
s.value
);
}
}
// ── 4. TrictracEnv: training step after real self-play ────────────────────
/// Collect one Trictrac episode, then verify that a gradient step runs
/// without panic and produces a finite loss.
#[test]
fn trictrac_train_step_finite_loss() {
let env = TrictracEnv;
let mut rng = seeded();
let mcts = tiny_mcts(2);
let model = tiny_mlp(env.obs_size(), env.action_space());
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(10_000);
// Generate one episode.
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
assert!(!samples.is_empty());
let n_samples = samples.len();
replay.extend(samples);
// Train on a batch from this episode.
let batch_size = 8.min(n_samples);
let batch: Vec<TrainSample> = replay
.sample_batch(batch_size, &mut rng)
.into_iter()
.cloned()
.collect();
let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}");
assert!(loss > 0.0, "loss should be positive");
}
// ── 5. Backend transfer: train → infer → same outputs ─────────────────────
/// Weights transferred from the training backend to the inference backend
/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes.
#[test]
fn valid_model_matches_train_model_outputs() {
use burn::tensor::{Tensor, TensorData};
let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
let train_model = MlpNet::<Train>::new(&cfg, &train_device());
let infer_model: MlpNet<Infer> = train_model.valid();
// Build the same input on both backends.
let obs_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
let obs_train = Tensor::<Train, 2>::from_data(
TensorData::new(obs_data.clone(), [1, 4]),
&train_device(),
);
let obs_infer = Tensor::<Infer, 2>::from_data(
TensorData::new(obs_data, [1, 4]),
&infer_device(),
);
let (p_train, v_train) = train_model.forward(obs_train);
let (p_infer, v_infer) = infer_model.forward(obs_infer);
let p_train: Vec<f32> = p_train.into_data().to_vec().unwrap();
let p_infer: Vec<f32> = p_infer.into_data().to_vec().unwrap();
let v_train: Vec<f32> = v_train.into_data().to_vec().unwrap();
let v_infer: Vec<f32> = v_infer.into_data().to_vec().unwrap();
for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() {
assert!(
(a - b).abs() < 1e-5,
"policy[{i}] differs after valid(): train={a}, infer={b}"
);
}
assert!(
(v_train[0] - v_infer[0]).abs() < 1e-5,
"value differs after valid(): train={}, infer={}",
v_train[0], v_infer[0]
);
}
// ── 6. Loss converges on a fixed batch ────────────────────────────────────
/// With repeated gradient steps on the same Countdown batch, the loss must
/// decrease monotonically (or at least end lower than it started).
#[test]
fn loss_decreases_on_fixed_batch() {
let env = CountdownEnv(6);
let mut rng = seeded();
let mcts = tiny_mcts(3);
let model = tiny_mlp(env.obs_size(), env.action_space());
let mut optimizer = AdamConfig::new().init();
// Collect a fixed batch from one episode.
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples: Vec<TrainSample> = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng);
assert!(!samples.is_empty());
let batch: Vec<TrainSample> = {
let mut replay = ReplayBuffer::new(1000);
replay.extend(samples);
replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect()
};
// Overfit on the same fixed batch for 20 steps.
let mut model = tiny_mlp(env.obs_size(), env.action_space());
let mut first_loss = f32::NAN;
let mut last_loss = f32::NAN;
for step in 0..20 {
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2);
model = m;
assert!(loss.is_finite(), "loss is not finite at step {step}");
if step == 0 { first_loss = loss; }
last_loss = loss;
}
assert!(
last_loss < first_loss,
"loss did not decrease after 20 steps: first={first_loss}, last={last_loss}"
);
}
// ── 7. Trictrac: multi-iteration loop ─────────────────────────────────────
/// Two full self-play + train iterations on TrictracEnv.
/// Verifies the entire pipeline runs without panic end-to-end.
#[test]
fn trictrac_two_iteration_loop() {
let env = TrictracEnv;
let mut rng = seeded();
let mcts = tiny_mcts(2);
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let mut model = MlpNet::<Train>::new(&cfg, &train_device());
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(20_000);
for iter in 0..2 {
// Self-play: 1 game per iteration.
let infer: MlpNet<Infer> = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
assert!(!samples.is_empty(), "iter {iter}: episode was empty");
replay.extend(samples);
// Training: 3 gradient steps.
let batch_size = 16.min(replay.len());
for _ in 0..3 {
let batch: Vec<TrainSample> = replay
.sample_batch(batch_size, &mut rng)
.into_iter()
.cloned()
.collect();
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
model = m;
assert!(loss.is_finite(), "iter {iter}: loss={loss}");
}
}
}

View file

@ -225,22 +225,26 @@ impl GameState {
let mut t = Vec::with_capacity(217);
let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
// [0..95] own (White) checkers, TD-Gammon encoding
// [0..95] own (White) checkers, TD-Gammon encoding.
// Each field contributes 4 values:
// (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1]
// The overflow term is divided by 12 because the maximum excess is
// 15 (all checkers) 3 = 12.
for &c in &pos {
let own = c.max(0) as u8;
t.push((own == 1) as u8 as f32);
t.push((own == 2) as u8 as f32);
t.push((own == 3) as u8 as f32);
t.push(own.saturating_sub(3) as f32);
t.push(own.saturating_sub(3) as f32 / 12.0);
}
// [96..191] opp (Black) checkers, TD-Gammon encoding
// [96..191] opp (Black) checkers, same encoding.
for &c in &pos {
let opp = (-c).max(0) as u8;
t.push((opp == 1) as u8 as f32);
t.push((opp == 2) as u8 as f32);
t.push((opp == 3) as u8 as f32);
t.push(opp.saturating_sub(3) as f32);
t.push(opp.saturating_sub(3) as f32 / 12.0);
}
// [192..193] dice