From 1c4c81441792b7bf6b603a6694314aac3e786e2b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 8 Mar 2026 21:23:31 +0100 Subject: [PATCH 01/13] fix: training_common::white_checker_moves_to_trictrac_action --- store/src/training_common.rs | 55 ++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/store/src/training_common.rs b/store/src/training_common.rs index 6a5b537..69765fc 100644 --- a/store/src/training_common.rs +++ b/store/src/training_common.rs @@ -224,7 +224,10 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result anyhow::Result Date: Sat, 7 Mar 2026 17:52:04 +0100 Subject: [PATCH 02/13] doc:rust open_spiel research --- doc/spiel_bot_research.md | 782 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 782 insertions(+) create mode 100644 doc/spiel_bot_research.md diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md new file mode 100644 index 0000000..a8863af --- /dev/null +++ b/doc/spiel_bot_research.md @@ -0,0 +1,782 @@ +# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac + +## 0. Context and Scope + +The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library +(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` +encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every +other stage to an inline random-opponent loop. + +`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency +for **self-play training**. Its goals: + +- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") + that works with Trictrac's multi-stage turn model and stochastic dice. +- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) + as the first algorithm. +- Remain **modular**: adding DQN or PPO later requires only a new + `impl Algorithm for Dqn` without touching the environment or network layers. +- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from + `trictrac-store`. + +--- + +## 1. Library Landscape + +### 1.1 Neural Network Frameworks + +| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | +| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- | +| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | +| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | +| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | +| ndarray alone | no | no | yes | mature | array ops only; no autograd | + +**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ +runtime needed, the `ndarray` backend is sufficient for CPU training and can +switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by +changing one type alias. + +`tch-rs` would be the best choice for raw training throughput (it is the most +battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the +pure-Rust constraint. If training speed becomes the bottleneck after prototyping, +switching `spiel_bot` to `tch-rs` is a one-line backend swap. + +### 1.2 Other Key Crates + +| Crate | Role | +| -------------------- | ----------------------------------------------------------------- | +| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | +| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | +| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | +| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | +| `anyhow` | error propagation (already used everywhere) | +| `indicatif` | training progress bars | +| `tracing` | structured logging per episode/iteration | + +### 1.3 What `burn-rl` Provides (and Does Not) + +The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) +provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` +trait. It does **not** provide: + +- MCTS or any tree-search algorithm +- Two-player self-play +- Legal action masking during training +- Chance-node handling + +For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its +own (simpler, more targeted) traits and implement MCTS from scratch. + +--- + +## 2. Trictrac-Specific Design Constraints + +### 2.1 Multi-Stage Turn Model + +A Trictrac turn passes through up to six `TurnStage` values. Only two involve +genuine player choice: + +| TurnStage | Node type | Handler | +| ---------------- | ------------------------------- | ------------------------------- | +| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | +| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | +| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | +| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | +| `Move` | **Player decision** | MCTS / policy network | +| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | + +The environment wrapper advances through forced/chance stages automatically so +that from the algorithm's perspective every node it sees is a genuine player +decision. + +### 2.2 Stochastic Dice in MCTS + +AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice +introduce stochasticity. Three approaches exist: + +**A. Outcome sampling (recommended)** +During each MCTS simulation, when a chance node is reached, sample one dice +outcome at random and continue. After many simulations the expected value +converges. This is the approach used by OpenSpiel's MCTS for stochastic games +and requires no changes to the standard PUCT formula. + +**B. Chance-node averaging (expectimax)** +At each chance node, expand all 21 unique dice pairs weighted by their +probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is +exact but multiplies the branching factor by ~21 at every dice roll, making it +prohibitively expensive. + +**C. Condition on dice in the observation (current approach)** +Dice values are already encoded at indices [192–193] of `to_tensor()`. The +network naturally conditions on the rolled dice when it evaluates a position. +MCTS only runs on player-decision nodes _after_ the dice have been sampled; +chance nodes are bypassed by the environment wrapper (approach A). The policy +and value heads learn to play optimally given any dice pair. + +**Use approach A + C together**: the environment samples dice automatically +(chance node bypass), and the 217-dim tensor encodes the dice so the network +can exploit them. + +### 2.3 Perspective / Mirroring + +All move rules and tensor encoding are defined from White's perspective. +`to_tensor()` must always be called after mirroring the state for Black. +The environment wrapper handles this transparently: every observation returned +to an algorithm is already in the active player's perspective. + +### 2.4 Legal Action Masking + +A crucial difference from the existing `bot/` code: instead of penalizing +invalid actions with `ERROR_REWARD`, the policy head logits are **masked** +before softmax — illegal action logits are set to `-inf`. This prevents the +network from wasting capacity on illegal moves and eliminates the need for the +penalty-reward hack. + +--- + +## 3. Proposed Crate Architecture + +``` +spiel_bot/ +├── Cargo.toml +└── src/ + ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" + │ + ├── env/ + │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface + │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store + │ + ├── mcts/ + │ ├── mod.rs # MctsConfig, run_mcts() entry point + │ ├── node.rs # MctsNode (visit count, W, prior, children) + │ └── search.rs # simulate(), backup(), select_action() + │ + ├── network/ + │ ├── mod.rs # PolicyValueNet trait + │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads + │ + ├── alphazero/ + │ ├── mod.rs # AlphaZeroConfig + │ ├── selfplay.rs # generate_episode() -> Vec + │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) + │ └── trainer.rs # training loop: selfplay → sample → loss → update + │ + └── agent/ + ├── mod.rs # Agent trait + ├── random.rs # RandomAgent (baseline) + └── mcts_agent.rs # MctsAgent: uses trained network for inference +``` + +Future algorithms slot in without touching the above: + +``` + ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer + └── ppo/ # (future) PPO: impl Algorithm + rollout buffer +``` + +--- + +## 4. Core Traits + +### 4.1 `GameEnv` — the minimal OpenSpiel interface + +```rust +use rand::Rng; + +/// Who controls the current node. +pub enum Player { + P1, // player index 0 + P2, // player index 1 + Chance, // dice roll + Terminal, // game over +} + +pub trait GameEnv: Clone + Send + Sync + 'static { + type State: Clone + Send + Sync; + + /// Fresh game state. + fn new_game(&self) -> Self::State; + + /// Who acts at this node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices (always in [0, action_space())). + /// Empty only at Terminal nodes. + fn legal_actions(&self, s: &Self::State) -> Vec; + + /// Apply a player action (must be legal). + fn apply(&self, s: &mut Self::State, action: usize); + + /// Advance a Chance node by sampling dice; no-op at non-Chance nodes. + fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng); + + /// Observation tensor from `pov`'s perspective (0 or 1). + /// Returns 217 f32 values for Trictrac. + fn observation(&self, s: &Self::State, pov: usize) -> Vec; + + /// Flat observation size (217 for Trictrac). + fn obs_size(&self) -> usize; + + /// Total action-space size (514 for Trictrac). + fn action_space(&self) -> usize; + + /// Game outcome per player, or None if not Terminal. + /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} +``` + +### 4.2 `PolicyValueNet` — neural network interface + +```rust +use burn::prelude::*; + +pub trait PolicyValueNet: Send + Sync { + /// Forward pass. + /// `obs`: [batch, obs_size] tensor. + /// Returns: (policy_logits [batch, action_space], value [batch]). + fn forward(&self, obs: Tensor) -> (Tensor, Tensor); + + /// Save weights to `path`. + fn save(&self, path: &std::path::Path) -> anyhow::Result<()>; + + /// Load weights from `path`. + fn load(path: &std::path::Path) -> anyhow::Result + where + Self: Sized; +} +``` + +### 4.3 `Agent` — player policy interface + +```rust +pub trait Agent: Send { + /// Select an action index given the current game state observation. + /// `legal`: mask of valid action indices. + fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize; +} +``` + +--- + +## 5. MCTS Implementation + +### 5.1 Node + +```rust +pub struct MctsNode { + n: u32, // visit count N(s, a) + w: f32, // sum of backed-up values W(s, a) + p: f32, // prior from policy head P(s, a) + children: Vec<(usize, MctsNode)>, // (action_idx, child) + is_expanded: bool, +} + +impl MctsNode { + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT score used for selection. + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} +``` + +### 5.2 Simulation Loop + +One MCTS simulation (for deterministic decision nodes): + +``` +1. SELECTION — traverse from root, always pick child with highest PUCT, + auto-advancing forced/chance nodes via env.apply_chance(). +2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get + (policy_logits, value). Mask illegal actions, softmax + the remaining logits → priors P(s,a) for each child. +3. BACKUP — propagate -value up the tree (negate at each level because + perspective alternates between P1 and P2). +``` + +After `n_simulations` iterations, action selection at the root: + +```rust +// During training: sample proportional to N^(1/temperature) +// During evaluation: argmax N +fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } +``` + +### 5.3 Configuration + +```rust +pub struct MctsConfig { + pub n_simulations: usize, // e.g. 200 + pub c_puct: f32, // exploration constant, e.g. 1.5 + pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 + pub dirichlet_eps: f32, // noise weight, e.g. 0.25 + pub temperature: f32, // action sampling temperature (anneals to 0) +} +``` + +### 5.4 Handling Chance Nodes Inside MCTS + +When simulation reaches a Chance node (dice roll), the environment automatically +samples dice and advances to the next decision node. The MCTS tree does **not** +branch on dice outcomes — it treats the sampled outcome as the state. This +corresponds to "outcome sampling" (approach A from §2.2). Because each +simulation independently samples dice, the Q-values at player nodes converge to +their expected value over many simulations. + +--- + +## 6. Network Architecture + +### 6.1 ResNet Policy-Value Network + +A single trunk with residual blocks, then two heads: + +``` +Input: [batch, 217] + ↓ +Linear(217 → 512) + ReLU + ↓ +ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU) + ↓ trunk output [batch, 512] + ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site) + └─ Value head: Linear(512 → 1) → tanh (output in [-1, 1]) +``` + +Burn implementation sketch: + +```rust +#[derive(Module, Debug)] +pub struct TrictracNet { + input: Linear, + res_blocks: Vec>, + policy_head: Linear, + value_head: Linear, +} + +impl TrictracNet { + pub fn forward(&self, obs: Tensor) + -> (Tensor, Tensor) + { + let x = activation::relu(self.input.forward(obs)); + let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x)); + let policy = self.policy_head.forward(x.clone()); // raw logits + let value = activation::tanh(self.value_head.forward(x)) + .squeeze(1); + (policy, value) + } +} +``` + +A simpler MLP (no residual blocks) is sufficient for a first version and much +faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two +heads. + +### 6.2 Loss Function + +``` +L = MSE(value_pred, z) + + CrossEntropy(policy_logits_masked, π_mcts) + - c_l2 * L2_regularization +``` + +Where: + +- `z` = game outcome (±1) from the active player's perspective +- `π_mcts` = normalized MCTS visit counts at the root (the policy target) +- Legal action masking is applied before computing CrossEntropy + +--- + +## 7. AlphaZero Training Loop + +``` +INIT + network ← random weights + replay ← empty ReplayBuffer(capacity = 100_000) + +LOOP forever: + ── Self-play phase ────────────────────────────────────────────── + (parallel with rayon, n_workers games at once) + for each game: + state ← env.new_game() + samples = [] + while not terminal: + advance forced/chance nodes automatically + obs ← env.observation(state, current_player) + legal ← env.legal_actions(state) + π, root_value ← mcts.run(state, network, config) + action ← sample from π (with temperature) + samples.push((obs, π, current_player)) + env.apply(state, action) + z ← env.returns(state) // final scores + for (obs, π, player) in samples: + replay.push(TrainSample { obs, policy: π, value: z[player] }) + + ── Training phase ─────────────────────────────────────────────── + for each gradient step: + batch ← replay.sample(batch_size) + (policy_logits, value_pred) ← network.forward(batch.obs) + loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) + optimizer.step(loss.backward()) + + ── Evaluation (every N iterations) ───────────────────────────── + win_rate ← evaluate(network_new vs network_prev, n_eval_games) + if win_rate > 0.55: save checkpoint +``` + +### 7.1 Replay Buffer + +```rust +pub struct TrainSample { + pub obs: Vec, // 217 values + pub policy: Vec, // 514 values (normalized MCTS visit counts) + pub value: f32, // game outcome ∈ {-1, 0, +1} +} + +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + pub fn push(&mut self, s: TrainSample) { + if self.data.len() == self.capacity { self.data.pop_front(); } + self.data.push_back(s); + } + + pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + // sample without replacement + } +} +``` + +### 7.2 Parallelism Strategy + +Self-play is embarrassingly parallel (each game is independent): + +```rust +let samples: Vec = (0..n_games) + .into_par_iter() // rayon + .flat_map(|_| generate_episode(&env, &network, &mcts_config)) + .collect(); +``` + +Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. +Self-play uses inference-only (no gradient tape), so a `NdArray` backend +(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with +`Autodiff>`. + +For larger scale, a producer-consumer architecture (crossbeam-channel) separates +self-play workers from the training thread, allowing continuous data generation +while the GPU trains. + +--- + +## 8. `TrictracEnv` Implementation Sketch + +```rust +use trictrac_store::{ + training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, + Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, +}; + +#[derive(Clone)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + fn current_player(&self, s: &GameState) -> Player { + match s.stage { + Stage::Ended => Player::Terminal, + _ => match s.turn_stage { + TurnStage::RollWaiting => Player::Chance, + _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, + }, + } + } + + fn legal_actions(&self, s: &GameState) -> Vec { + let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; + get_valid_action_indices(&view).unwrap_or_default() + } + + fn apply(&self, s: &mut GameState, action_idx: usize) { + // advance all forced/chance nodes first, then apply the player action + self.advance_forced(s); + let needs_mirror = s.active_player_id == 2; + let view = if needs_mirror { s.mirror() } else { s.clone() }; + if let Some(event) = TrictracAction::from_action_index(action_idx) + .and_then(|a| a.to_event(&view)) + .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) + { + let _ = s.consume(&event); + } + // advance any forced stages that follow + self.advance_forced(s); + } + + fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { + // RollDice → RollWaiting + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + // RollWaiting → next stage + let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; + let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); + self.advance_forced(s); + } + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } + } + + fn obs_size(&self) -> usize { 217 } + fn action_space(&self) -> usize { ACTION_SPACE_SIZE } + + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { return None; } + // Convert hole+point scores to ±1 outcome + let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [ 1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [ 0.0, 0.0], + }) + } +} + +impl TrictracEnv { + /// Advance through all forced (non-decision, non-chance) stages. + fn advance_forced(&self, s: &mut GameState) { + use trictrac_store::PointsRules; + loop { + match s.turn_stage { + TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { + // Scoring is deterministic; compute and apply automatically. + let color = s.player_color_by_id(&s.active_player_id) + .unwrap_or(trictrac_store::Color::White); + let drc = s.players.get(&s.active_player_id) + .map(|p| p.dice_roll_count).unwrap_or(0); + let pr = PointsRules::new(&color, &s.board, s.dice); + let pts = pr.get_points(drc); + let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; + let _ = s.consume(&GameEvent::Mark { + player_id: s.active_player_id, points, + }); + } + TurnStage::RollDice => { + // RollDice is a forced "initiate roll" action with no real choice. + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + } + _ => break, + } + } + } +} +``` + +--- + +## 9. Cargo.toml Changes + +### 9.1 Add `spiel_bot` to the workspace + +```toml +# Cargo.toml (workspace root) +[workspace] +resolver = "2" +members = ["client_cli", "bot", "store", "spiel_bot"] +``` + +### 9.2 `spiel_bot/Cargo.toml` + +```toml +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[features] +default = ["alphazero"] +alphazero = [] +# dqn = [] # future +# ppo = [] # future + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" +rayon = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# Burn: NdArray for pure-Rust CPU training +# Replace NdArray with Wgpu or Tch for GPU. +burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +# Optional: progress display and structured logging +indicatif = "0.17" +tracing = "0.1" + +[[bin]] +name = "az_train" +path = "src/bin/az_train.rs" + +[[bin]] +name = "az_eval" +path = "src/bin/az_eval.rs" +``` + +--- + +## 10. Comparison: `bot` crate vs `spiel_bot` + +| Aspect | `bot` (existing) | `spiel_bot` (proposed) | +| ---------------- | --------------------------- | -------------------------------------------- | +| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | +| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | +| Opponent | hardcoded random | self-play | +| Invalid actions | penalise with reward | legal action mask (no penalty) | +| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | +| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | +| Burn dep | yes (0.20) | yes (0.20), same backend | +| `burn-rl` dep | yes | no | +| C++ dep | no | no | +| Python dep | no | no | +| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | + +The two crates are **complementary**: `bot` is a working DQN/PPO baseline; +`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The +`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just +replace `TrictracEnvironment` with `TrictracEnv`). + +--- + +## 11. Implementation Order + +1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game + through the trait, verify terminal state and returns). +2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + + Burn forward/backward pass test with dummy data. +3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests + (visit counts sum to `n_simulations`, legal mask respected). +4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub + (one iteration, check loss decreases). +5. **Integration test**: run 100 self-play games with a tiny network (1 res block, + 64 hidden units), verify the training loop completes without panics. +6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s + on CPU, consistent with `random_game` throughput). +7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule + hyperparameter sweep. +8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report + win rate every checkpoint. + +--- + +## 12. Key Open Questions + +1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. + AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever + scored more holes). Better option: normalize the score margin. The final + choice affects how the value head is trained. + +2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). + MCTS with 200 simulations per step means ~120k network evaluations per game. + At batch inference this is feasible on CPU; on GPU it becomes fast. Consider + limiting `n_simulations` to 50–100 for early training. + +3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). + This is a long-horizon decision that AlphaZero handles naturally via MCTS + lookahead, but needs careful value normalization (a "Go" restarts scoring + within the same game). + +4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated + to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. + This is optional but reduces code duplication. + +5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, + 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. + A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. + +## Implementation results + +All benchmarks compile and run. Here's the complete results table: + +| Group | Benchmark | Time | +| ------- | ----------------------- | --------------------- | +| env | apply_chance | 3.87 µs | +| | legal_actions | 1.91 µs | +| | observation (to_tensor) | 341 ns | +| | random_game (baseline) | 3.55 ms → 282 games/s | +| network | mlp_b1 hidden=64 | 94.9 µs | +| | mlp_b32 hidden=64 | 141 µs | +| | mlp_b1 hidden=256 | 352 µs | +| | mlp_b32 hidden=256 | 479 µs | +| mcts | zero_eval n=1 | 6.8 µs | +| | zero_eval n=5 | 23.9 µs | +| | zero_eval n=20 | 90.9 µs | +| | mlp64 n=1 | 203 µs | +| | mlp64 n=5 | 622 µs | +| | mlp64 n=20 | 2.30 ms | +| episode | trictrac n=1 | 51.8 ms → 19 games/s | +| | trictrac n=2 | 145 ms → 7 games/s | +| train | mlp64 Adam b=16 | 1.93 ms | +| | mlp64 Adam b=64 | 2.68 ms | + +Key observations: + +- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game) +- observation (217-value tensor): only 341 ns — not a bottleneck +- legal_actions: 1.9 µs — well optimised +- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs +- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal +- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it +- Training: 2.7 ms/step at batch=64 → 370 steps/s + +### Summary of Step 8 + +spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary: + +- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct +- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%) +- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side +- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise) +- Output: win/draw/loss per side + combined decisive win rate + +Typical usage after training: +cargo run -p spiel_bot --bin az_eval --release -- \ + --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100 + +### az_train + +#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10) + +cargo run -p spiel_bot --bin az_train --release + +#### ResNet, more games, custom output dir + +cargo run -p spiel_bot --bin az_train --release -- \ + --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \ + --save-every 20 --out checkpoints/ + +#### Resume from iteration 50 + +cargo run -p spiel_bot --bin az_train --release -- \ + --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50 + +What the binary does each iteration: + +1. Calls model.valid() to get a zero-overhead inference copy for self-play +2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy) +3. Pushes samples into a ReplayBuffer (capacity --replay-cap) +4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min +5. Saves a .mpk checkpoint every --save-every iterations and always on the last From a6644e3c9dd7ff09227814e3a58322d0942c57bd Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:10:49 +0100 Subject: [PATCH 03/13] fix: to_tensor() normalization --- store/src/game.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/store/src/game.rs b/store/src/game.rs index f553bdb..2fde45c 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -225,22 +225,26 @@ impl GameState { let mut t = Vec::with_capacity(217); let pos: Vec = self.board.to_vec(); // 24 elements, positive=White, negative=Black - // [0..95] own (White) checkers, TD-Gammon encoding + // [0..95] own (White) checkers, TD-Gammon encoding. + // Each field contributes 4 values: + // (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1] + // The overflow term is divided by 12 because the maximum excess is + // 15 (all checkers) − 3 = 12. for &c in &pos { let own = c.max(0) as u8; t.push((own == 1) as u8 as f32); t.push((own == 2) as u8 as f32); t.push((own == 3) as u8 as f32); - t.push(own.saturating_sub(3) as f32); + t.push(own.saturating_sub(3) as f32 / 12.0); } - // [96..191] opp (Black) checkers, TD-Gammon encoding + // [96..191] opp (Black) checkers, same encoding. for &c in &pos { let opp = (-c).max(0) as u8; t.push((opp == 1) as u8 as f32); t.push((opp == 2) as u8 as f32); t.push((opp == 3) as u8 as f32); - t.push(opp.saturating_sub(3) as f32); + t.push(opp.saturating_sub(3) as f32 / 12.0); } // [192..193] dice From df05a430225ca63485d22dcd3c1794885df512dc Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:12:59 +0100 Subject: [PATCH 04/13] feat(spiel_bot): init crate & implements `GameEnv` trait + `TrictracEnv` --- Cargo.lock | 9 + Cargo.toml | 2 +- spiel_bot/Cargo.toml | 9 + spiel_bot/src/env/mod.rs | 121 ++++++++ spiel_bot/src/env/trictrac.rs | 535 ++++++++++++++++++++++++++++++++++ spiel_bot/src/lib.rs | 1 + 6 files changed, 676 insertions(+), 1 deletion(-) create mode 100644 spiel_bot/Cargo.toml create mode 100644 spiel_bot/src/env/mod.rs create mode 100644 spiel_bot/src/env/trictrac.rs create mode 100644 spiel_bot/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index a43261e..d1f5a20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5891,6 +5891,15 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spiel_bot" +version = "0.1.0" +dependencies = [ + "anyhow", + "rand 0.9.2", + "trictrac-store", +] + [[package]] name = "spin" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index b9e6d45..4c2eb15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store"] +members = ["client_cli", "bot", "store", "spiel_bot"] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml new file mode 100644 index 0000000..2459f51 --- /dev/null +++ b/spiel_bot/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs new file mode 100644 index 0000000..42b4ae0 --- /dev/null +++ b/spiel_bot/src/env/mod.rs @@ -0,0 +1,121 @@ +//! Game environment abstraction — the minimal "Rust OpenSpiel". +//! +//! A `GameEnv` describes the rules of a two-player, zero-sum game that may +//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, +//! and PPO interact with a game exclusively through this trait. +//! +//! # Node taxonomy +//! +//! Every game position belongs to one of four categories, returned by +//! [`GameEnv::current_player`]: +//! +//! | [`Player`] | Meaning | +//! |-----------|---------| +//! | `P1` | Player 1 (index 0) must choose an action | +//! | `P2` | Player 2 (index 1) must choose an action | +//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | +//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | +//! +//! # Perspective convention +//! +//! [`GameEnv::observation`] always returns the board from *the requested +//! player's* point of view. Callers pass `pov = 0` for Player 1 and +//! `pov = 1` for Player 2. The implementation is responsible for any +//! mirroring required (e.g. Trictrac always reasons from White's side). + +pub mod trictrac; +pub use trictrac::TrictracEnv; + +/// Who controls the current game node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Player { + /// Player 1 (index 0) is to move. + P1, + /// Player 2 (index 1) is to move. + P2, + /// A stochastic event (dice roll, etc.) must be resolved. + Chance, + /// The game is over. + Terminal, +} + +impl Player { + /// Returns the player index (0 or 1) if this is a decision node, + /// or `None` for `Chance` / `Terminal`. + pub fn index(self) -> Option { + match self { + Player::P1 => Some(0), + Player::P2 => Some(1), + _ => None, + } + } + + pub fn is_decision(self) -> bool { + matches!(self, Player::P1 | Player::P2) + } + + pub fn is_chance(self) -> bool { + self == Player::Chance + } + + pub fn is_terminal(self) -> bool { + self == Player::Terminal + } +} + +/// Trait that completely describes a two-player zero-sum game. +/// +/// Implementors must be cheaply cloneable (the type is used as a stateless +/// factory; the mutable game state lives in `Self::State`). +pub trait GameEnv: Clone + Send + Sync + 'static { + /// The mutable game state. Must be `Clone` so MCTS can copy + /// game trees without touching the environment. + type State: Clone + Send + Sync; + + // ── State creation ──────────────────────────────────────────────────── + + /// Create a fresh game state at the initial position. + fn new_game(&self) -> Self::State; + + // ── Node queries ────────────────────────────────────────────────────── + + /// Classify the current node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). + /// + /// The returned indices are in `[0, action_space())`. + /// The result is unspecified (may panic or return empty) when called at a + /// `Chance` or `Terminal` node. + fn legal_actions(&self, s: &Self::State) -> Vec; + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action. `action` must be a value returned by + /// [`legal_actions`] for the current state. + fn apply(&self, s: &mut Self::State, action: usize); + + /// Sample and apply a stochastic outcome. Must only be called when + /// `current_player(s) == Player::Chance`. + fn apply_chance(&self, s: &mut Self::State, rng: &mut R); + + // ── Observation ─────────────────────────────────────────────────────── + + /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). + /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. + fn observation(&self, s: &Self::State, pov: usize) -> Vec; + + /// Number of floats returned by [`observation`]. + fn obs_size(&self) -> usize; + + /// Total number of distinct action indices (the policy head output size). + fn action_space(&self) -> usize; + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Game outcome for each player, or `None` if the game is not over. + /// + /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. + /// Index 0 = Player 1, index 1 = Player 2. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs new file mode 100644 index 0000000..99ba058 --- /dev/null +++ b/spiel_bot/src/env/trictrac.rs @@ -0,0 +1,535 @@ +//! [`GameEnv`] implementation for Trictrac. +//! +//! # Game flow (schools_enabled = false) +//! +//! With scoring schools disabled (the standard training configuration), +//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine +//! applies them automatically inside `RollResult` and `Move`. The only +//! four stages that actually occur are: +//! +//! | `TurnStage` | [`Player`] kind | Handled by | +//! |-------------|-----------------|------------| +//! | `RollDice` | `Chance` | [`apply_chance`] | +//! | `RollWaiting` | `Chance` | [`apply_chance`] | +//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | +//! | `Move` | `P1`/`P2` | [`apply`] | +//! +//! # Perspective +//! +//! The Trictrac engine always reasons from White's perspective. Player 1 is +//! White; Player 2 is Black. When Player 2 is active, the board is mirrored +//! before computing legal actions / the observation tensor, and the resulting +//! event is mirrored back before being applied to the real state. This +//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. + +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, + Dice, GameEvent, GameState, Stage, TurnStage, +}; + +use super::{GameEnv, Player}; + +/// Stateless factory that produces Trictrac [`GameState`] environments. +/// +/// Schools (`schools_enabled`) are always disabled — scoring is automatic. +#[derive(Clone, Debug, Default)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + // ── State creation ──────────────────────────────────────────────────── + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + // ── Node queries ────────────────────────────────────────────────────── + + fn current_player(&self, s: &GameState) -> Player { + if s.stage == Stage::Ended { + return Player::Terminal; + } + match s.turn_stage { + TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, + _ => { + if s.active_player_id == 1 { + Player::P1 + } else { + Player::P2 + } + } + } + } + + /// Returns the legal action indices for the active player. + /// + /// The board is automatically mirrored for Player 2 so that the engine + /// always reasons from White's perspective. The returned indices are + /// identical in meaning for both players (checker ordinals are + /// perspective-relative). + /// + /// # Panics + /// + /// Panics in debug builds if called at a `Chance` or `Terminal` node. + fn legal_actions(&self, s: &GameState) -> Vec { + debug_assert!( + self.current_player(s).is_decision(), + "legal_actions called at a non-decision node (turn_stage={:?})", + s.turn_stage + ); + let indices = if s.active_player_id == 2 { + get_valid_action_indices(&s.mirror()) + } else { + get_valid_action_indices(s) + }; + indices.unwrap_or_default() + } + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action index to the game state. + /// + /// For Player 2, the action is decoded against the mirrored board and + /// the resulting event is un-mirrored before being applied. + /// + /// # Panics + /// + /// Panics in debug builds if `action` cannot be decoded or does not + /// produce a valid event for the current state. + fn apply(&self, s: &mut GameState, action: usize) { + let needs_mirror = s.active_player_id == 2; + + let event = if needs_mirror { + let view = s.mirror(); + TrictracAction::from_action_index(action) + .and_then(|a| a.to_event(&view)) + .map(|e| e.get_mirror(false)) + } else { + TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) + }; + + match event { + Some(e) => { + s.consume(&e).expect("apply: consume failed for valid action"); + } + None => { + panic!("apply: action index {action} produced no event in state {s}"); + } + } + } + + /// Sample dice and advance through a chance node. + /// + /// Handles both `RollDice` (triggers the roll mechanism, then samples + /// dice) and `RollWaiting` (only samples dice) in a single call so that + /// callers never need to distinguish the two. + /// + /// # Panics + /// + /// Panics in debug builds if called at a non-Chance node. + fn apply_chance(&self, s: &mut GameState, rng: &mut R) { + debug_assert!( + self.current_player(s).is_chance(), + "apply_chance called at a non-Chance node (turn_stage={:?})", + s.turn_stage + ); + + // Step 1: RollDice → RollWaiting (player initiates the roll). + if s.turn_stage == TurnStage::RollDice { + s.consume(&GameEvent::Roll { + player_id: s.active_player_id, + }) + .expect("apply_chance: Roll event failed"); + } + + // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. + // With schools_enabled=false, point marking is automatic inside consume(). + let dice = Dice { + values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), + }; + s.consume(&GameEvent::RollResult { + player_id: s.active_player_id, + dice, + }) + .expect("apply_chance: RollResult event failed"); + } + + // ── Observation ─────────────────────────────────────────────────────── + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { + s.to_tensor() + } else { + s.mirror().to_tensor() + } + } + + fn obs_size(&self) -> usize { + 217 + } + + fn action_space(&self) -> usize { + ACTION_SPACE_SIZE + } + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. + /// + /// The winner (higher cumulative score) receives `+1.0`; the loser + /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score + /// is `holes × 12 + points`. + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { + return None; + } + let score = |id: u64| -> i32 { + s.players + .get(&id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + }; + let s1 = score(1); + let s2 = score(2); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [0.0, 0.0], + }) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{rngs::SmallRng, Rng, SeedableRng}; + + fn env() -> TrictracEnv { + TrictracEnv + } + + fn seeded_rng(seed: u64) -> SmallRng { + SmallRng::seed_from_u64(seed) + } + + // ── Initial state ───────────────────────────────────────────────────── + + #[test] + fn new_game_is_chance_node() { + let e = env(); + let s = e.new_game(); + // A fresh game starts at RollDice — a Chance node. + assert_eq!(e.current_player(&s), Player::Chance); + assert!(e.returns(&s).is_none()); + } + + #[test] + fn new_game_is_not_terminal() { + let e = env(); + let s = e.new_game(); + assert_ne!(e.current_player(&s), Player::Terminal); + assert!(e.returns(&s).is_none()); + } + + // ── Chance nodes ────────────────────────────────────────────────────── + + #[test] + fn apply_chance_reaches_decision_node() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(1); + + // A single chance step must yield a decision node (or end the game, + // which only happens after 12 holes — impossible on the first roll). + e.apply_chance(&mut s, &mut rng); + let p = e.current_player(&s); + assert!( + p.is_decision(), + "expected decision node after first roll, got {p:?}" + ); + } + + #[test] + fn apply_chance_from_rollwaiting() { + // Check that apply_chance works when called mid-way (at RollWaiting). + let e = env(); + let mut s = e.new_game(); + assert_eq!(s.turn_stage, TurnStage::RollDice); + + // Manually advance to RollWaiting. + s.consume(&GameEvent::Roll { player_id: s.active_player_id }) + .unwrap(); + assert_eq!(s.turn_stage, TurnStage::RollWaiting); + + let mut rng = seeded_rng(2); + e.apply_chance(&mut s, &mut rng); + + let p = e.current_player(&s); + assert!(p.is_decision() || p.is_terminal()); + } + + // ── Legal actions ───────────────────────────────────────────────────── + + #[test] + fn legal_actions_nonempty_after_roll() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(3); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let actions = e.legal_actions(&s); + assert!( + !actions.is_empty(), + "legal_actions must be non-empty at a decision node" + ); + } + + #[test] + fn legal_actions_within_action_space() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(4); + + e.apply_chance(&mut s, &mut rng); + for &a in e.legal_actions(&s).iter() { + assert!( + a < e.action_space(), + "action {a} out of bounds (action_space={})", + e.action_space() + ); + } + } + + // ── Observations ────────────────────────────────────────────────────── + + #[test] + fn observation_has_correct_size() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(5); + e.apply_chance(&mut s, &mut rng); + + assert_eq!(e.observation(&s, 0).len(), e.obs_size()); + assert_eq!(e.observation(&s, 1).len(), e.obs_size()); + } + + #[test] + fn observation_values_in_unit_interval() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(6); + e.apply_chance(&mut s, &mut rng); + + for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { + for (i, &v) in obs.iter().enumerate() { + assert!( + v >= 0.0 && v <= 1.0, + "pov={pov}: obs[{i}] = {v} is outside [0,1]" + ); + } + } + } + + #[test] + fn p1_and_p2_observations_differ() { + // The board is mirrored for P2, so the two observations should differ + // whenever there are checkers in non-symmetric positions (always true + // in a real game after a few moves). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(7); + + // Advance far enough that the board is non-trivial. + for _ in 0..6 { + while e.current_player(&s).is_chance() { + e.apply_chance(&mut s, &mut rng); + } + if e.current_player(&s).is_terminal() { + break; + } + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + + if !e.current_player(&s).is_terminal() { + let obs0 = e.observation(&s, 0); + let obs1 = e.observation(&s, 1); + assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); + } + } + + // ── Applying actions ────────────────────────────────────────────────── + + #[test] + fn apply_changes_state() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(8); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let before = s.clone(); + let action = e.legal_actions(&s)[0]; + e.apply(&mut s, action); + + assert_ne!( + before.turn_stage, s.turn_stage, + "state must change after apply" + ); + } + + #[test] + fn apply_all_legal_actions_do_not_panic() { + // Verify that every action returned by legal_actions can be applied + // without panicking (on several independent copies of the same state). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(9); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + for action in e.legal_actions(&s) { + let mut copy = s.clone(); + e.apply(&mut copy, action); // must not panic + } + } + + // ── Full game ───────────────────────────────────────────────────────── + + /// Run a complete game with random actions through the `GameEnv` trait + /// and verify that: + /// - The game terminates. + /// - `returns()` is `Some` at the end. + /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's + /// score is ±1 / 0. + /// - No step panics. + #[test] + fn full_random_game_terminates() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(42); + let max_steps = 50_000; + + for step in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node"); + let idx = rng.random_range(0..actions.len()); + e.apply(&mut s, actions[idx]); + } + } + assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps"); + } + + let result = e.returns(&s); + assert!(result.is_some(), "returns() must be Some at Terminal"); + + let [r1, r2] = result.unwrap(); + let sum = r1 + r2; + assert!( + (sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5, + "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}" + ); + assert!( + r1.abs() <= 1.0 && r2.abs() <= 1.0, + "returns must be in [-1,1]: r1={r1}, r2={r2}" + ); + } + + /// Run multiple games with different seeds to stress-test for panics. + #[test] + fn multiple_games_no_panic() { + let e = env(); + let max_steps = 20_000; + + for seed in 0..10u64 { + let mut s = e.new_game(); + let mut rng = seeded_rng(seed); + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + let idx = rng.random_range(0..actions.len()); + e.apply(&mut s, actions[idx]); + } + } + } + } + } + + // ── Returns ─────────────────────────────────────────────────────────── + + #[test] + fn returns_none_mid_game() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(11); + + // Advance a few steps but do not finish the game. + for _ in 0..4 { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + if !e.current_player(&s).is_terminal() { + assert!( + e.returns(&s).is_none(), + "returns() must be None before the game ends" + ); + } + } + + // ── Player 2 actions ────────────────────────────────────────────────── + + /// Verify that Player 2 (Black) can take actions without panicking, + /// and that the state advances correctly. + #[test] + fn player2_can_act() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(12); + + // Keep stepping until Player 2 gets a turn. + let max_steps = 5_000; + let mut p2_acted = false; + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty()); + e.apply(&mut s, actions[0]); + p2_acted = true; + break; + } + Player::P1 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); + } +} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs new file mode 100644 index 0000000..3d7924f --- /dev/null +++ b/spiel_bot/src/lib.rs @@ -0,0 +1 @@ +pub mod env; From d5cd4c2402d294ad5a3ab0fa23d778193856ef60 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:30:27 +0100 Subject: [PATCH 05/13] feat(spiel_bot): network with mlp and resnet --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/network/mlp.rs | 223 ++++++++++++++++++++++++++++ spiel_bot/src/network/mod.rs | 64 ++++++++ spiel_bot/src/network/resnet.rs | 253 ++++++++++++++++++++++++++++++++ 6 files changed, 543 insertions(+) create mode 100644 spiel_bot/src/network/mlp.rs create mode 100644 spiel_bot/src/network/mod.rs create mode 100644 spiel_bot/src/network/resnet.rs diff --git a/Cargo.lock b/Cargo.lock index d1f5a20..2e81285 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5896,6 +5896,7 @@ name = "spiel_bot" version = "0.1.0" dependencies = [ "anyhow", + "burn", "rand 0.9.2", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 2459f51..fba2aab 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" trictrac-store = { path = "../store" } anyhow = "1" rand = "0.9" +burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 3d7924f..6e71016 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1 +1,2 @@ pub mod env; +pub mod network; diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs new file mode 100644 index 0000000..eb6184e --- /dev/null +++ b/spiel_bot/src/network/mlp.rs @@ -0,0 +1,223 @@ +//! Two-hidden-layer MLP policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`MlpNet`]. +#[derive(Debug, Clone)] +pub struct MlpConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for MlpConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 256, + } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Simple two-hidden-layer MLP with shared trunk and two heads. +/// +/// Prefer this over [`ResNet`](super::ResNet) when training time is a +/// priority, or as a fast baseline. +#[derive(Module, Debug)] +pub struct MlpNet { + fc1: Linear, + fc2: Linear, + policy_head: Linear, + value_head: Linear, +} + +impl MlpNet { + /// Construct a fresh network with random weights. + pub fn new(config: &MlpConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + value_head: LinearConfig::new(config.hidden_size, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + /// + /// The file is written exactly at `path`; callers should append `.mpk` if + /// they want the conventional extension. + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for MlpNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn default_net() -> MlpNet { + MlpNet::new(&MlpConfig::default(), &device()) + } + + fn zeros_obs(batch: usize) -> Tensor { + Tensor::zeros([batch, 217], &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let net = default_net(); + let obs = zeros_obs(4); + let (policy, value) = net.forward(obs); + + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let (policy, value) = net.forward(zeros_obs(1)); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let net = default_net(); + // Use a non-zero input so the output is not trivially at 0. + let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net.forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Policy logits ───────────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + // With random weights the 514 logits should not all be identical. + let net = default_net(); + let (policy, _) = net.forward(zeros_obs(1)); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical — network may be degenerate"); + } + + // ── Config propagation ──────────────────────────────────────────────── + + #[test] + fn custom_config_shapes() { + let config = MlpConfig { + obs_size: 10, + action_size: 20, + hidden_size: 32, + }; + let net = MlpNet::::new(&config, &device()); + let obs = Tensor::zeros([3, 10], &device()); + let (policy, value) = net.forward(obs); + assert_eq!(policy.dims(), [3, 20]); + assert_eq!(value.dims(), [3, 1]); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = MlpConfig::default(); + let net = default_net(); + + // Forward pass before saving. + let obs = Tensor::::ones([2, 217], &device()); + let (policy_before, value_before) = net.forward(obs.clone()); + + // Save to a temp file. + let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); + net.save(&path).expect("save failed"); + + // Load into a fresh model. + let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = loaded.forward(obs); + + // Outputs must be bitwise identical. + let p_before: Vec = policy_before.into_data().to_vec().unwrap(); + let p_after: Vec = policy_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let v_before: Vec = value_before.into_data().to_vec().unwrap(); + let v_after: Vec = value_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let _ = std::fs::remove_file(path); + } +} diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs new file mode 100644 index 0000000..df710e9 --- /dev/null +++ b/spiel_bot/src/network/mod.rs @@ -0,0 +1,64 @@ +//! Neural network abstractions for policy-value learning. +//! +//! # Trait +//! +//! [`PolicyValueNet`] is the single trait that all network architectures +//! implement. It takes an observation tensor and returns raw policy logits +//! plus a tanh-squashed scalar value estimate. +//! +//! # Architectures +//! +//! | Module | Description | Default hidden | +//! |--------|-------------|----------------| +//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | +//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | +//! +//! # Backend convention +//! +//! * **Inference / self-play** — use `NdArray` (no autodiff overhead). +//! * **Training** — use `Autodiff>` so Burn can differentiate +//! through the forward pass. +//! +//! Both modes use the exact same struct; only the type-level backend changes: +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! type InferBackend = NdArray; +//! type TrainBackend = Autodiff>; +//! +//! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); +//! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); +//! ``` +//! +//! # Output shapes +//! +//! Given a batch of `B` observations of size `obs_size`: +//! +//! | Output | Shape | Range | +//! |--------|-------|-------| +//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | +//! | `value` | `[B, 1]` | (-1, 1) via tanh | +//! +//! Callers are responsible for masking illegal actions in `policy_logits` +//! before passing to softmax. + +pub mod mlp; +pub mod resnet; + +pub use mlp::{MlpConfig, MlpNet}; +pub use resnet::{ResNet, ResNetConfig}; + +use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; + +/// A neural network that produces a policy and a value from an observation. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - policy output: `[batch, action_size]` — raw logits (no softmax applied) +/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) +/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses +/// `OnceCell` for lazy parameter initialisation, which is not `Sync`. +/// Use an `Arc>` wrapper if cross-thread sharing is needed. +pub trait PolicyValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor); +} diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs new file mode 100644 index 0000000..d20d5ad --- /dev/null +++ b/spiel_bot/src/network/resnet.rs @@ -0,0 +1,253 @@ +//! Residual-block policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU (input projection) +//! → ResBlock × 4 (residual trunk) +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! +//! ResBlock: +//! x → Linear → ReLU → Linear → (+x) → ReLU +//! ``` +//! +//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better +//! suited for long training runs where board-pattern recognition matters. + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`ResNet`]. +#[derive(Debug, Clone)] +pub struct ResNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of all hidden layers (input projection + residual blocks). + pub hidden_size: usize, +} + +impl Default for ResNetConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 512, + } + } +} + +// ── Residual block ──────────────────────────────────────────────────────────── + +/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`. +/// +/// Both linear layers preserve the hidden dimension so the skip connection +/// can be added without projection. +#[derive(Module, Debug)] +struct ResBlock { + fc1: Linear, + fc2: Linear, +} + +impl ResBlock { + fn new(hidden: usize, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(hidden, hidden).init(device), + fc2: LinearConfig::new(hidden, hidden).init(device), + } + } + + fn forward(&self, x: Tensor) -> Tensor { + let residual = x.clone(); + let out = relu(self.fc1.forward(x)); + relu(self.fc2.forward(out) + residual) + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Four-residual-block policy-value network. +/// +/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and +/// when representing complex positional patterns is important. +#[derive(Module, Debug)] +pub struct ResNet { + input: Linear, + block0: ResBlock, + block1: ResBlock, + block2: ResBlock, + block3: ResBlock, + policy_head: Linear, + value_head: Linear, +} + +impl ResNet { + /// Construct a fresh network with random weights. + pub fn new(config: &ResNetConfig, device: &B::Device) -> Self { + let h = config.hidden_size; + Self { + input: LinearConfig::new(config.obs_size, h).init(device), + block0: ResBlock::new(h, device), + block1: ResBlock::new(h, device), + block2: ResBlock::new(h, device), + block3: ResBlock::new(h, device), + policy_head: LinearConfig::new(h, config.action_size).init(device), + value_head: LinearConfig::new(h, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for ResNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.input.forward(obs)); + let x = self.block0.forward(x); + let x = self.block1.forward(x); + let x = self.block2.forward(x); + let x = self.block3.forward(x); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn small_config() -> ResNetConfig { + // Use a small hidden size so tests are fast. + ResNetConfig { + obs_size: 217, + action_size: 514, + hidden_size: 64, + } + } + + fn net() -> ResNet { + ResNet::new(&small_config(), &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let obs = Tensor::zeros([4, 217], &device()); + let (policy, value) = net().forward(obs); + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let (policy, value) = net().forward(Tensor::zeros([1, 217], &device())); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net().forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Residual connections ────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + let (policy, _) = net().forward(Tensor::zeros([1, 217], &device())); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical"); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = small_config(); + let model = net(); + let obs = Tensor::::ones([2, 217], &device()); + + let (policy_before, value_before) = model.forward(obs.clone()); + + let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk"); + model.save(&path).expect("save failed"); + + let loaded = ResNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = loaded.forward(obs); + + let p_before: Vec = policy_before.into_data().to_vec().unwrap(); + let p_after: Vec = policy_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let v_before: Vec = value_before.into_data().to_vec().unwrap(); + let v_after: Vec = value_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let _ = std::fs::remove_file(path); + } + + // ── Integration: both architectures satisfy PolicyValueNet ──────────── + + #[test] + fn resnet_satisfies_trait() { + fn requires_net>(net: &N, obs: Tensor) { + let (p, v) = net.forward(obs); + assert_eq!(p.dims()[1], 514); + assert_eq!(v.dims()[1], 1); + } + requires_net(&net(), Tensor::zeros([2, 217], &device())); + } +} From 58ae8ad3b3533edb77e5a7a75d8c3400a3c381b0 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:45:02 +0100 Subject: [PATCH 06/13] feat(spiel_bot): Monte-Carlo tree search --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/mcts/mod.rs | 408 +++++++++++++++++++++++++++++++++++ spiel_bot/src/mcts/node.rs | 91 ++++++++ spiel_bot/src/mcts/search.rs | 170 +++++++++++++++ 6 files changed, 672 insertions(+) create mode 100644 spiel_bot/src/mcts/mod.rs create mode 100644 spiel_bot/src/mcts/node.rs create mode 100644 spiel_bot/src/mcts/search.rs diff --git a/Cargo.lock b/Cargo.lock index 2e81285..0baa02a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5898,6 +5898,7 @@ dependencies = [ "anyhow", "burn", "rand 0.9.2", + "rand_distr", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index fba2aab..323c953 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -7,4 +7,5 @@ edition = "2021" trictrac-store = { path = "../store" } anyhow = "1" rand = "0.9" +rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 6e71016..5beb37c 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,2 +1,3 @@ pub mod env; +pub mod mcts; pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs new file mode 100644 index 0000000..e92bd09 --- /dev/null +++ b/spiel_bot/src/mcts/mod.rs @@ -0,0 +1,408 @@ +//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. +//! +//! # Algorithm +//! +//! The implementation follows AlphaZero's MCTS: +//! +//! 1. **Expand root** — run the network once to get priors and a value +//! estimate; optionally add Dirichlet noise for training-time exploration. +//! 2. **Simulate** `n_simulations` times: +//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. +//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; +//! chance nodes are **not** stored in the tree (outcome sampling). +//! - *Expansion* — evaluate the network at the leaf; populate children. +//! - *Backup* — propagate the value upward; negate at each player boundary. +//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). +//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). +//! +//! # Perspective convention +//! +//! Every [`MctsNode::w`] is stored **from the perspective of the player who +//! acts at that node**. The backup negates the child value whenever the +//! acting player differs between parent and child. +//! +//! # Stochastic games +//! +//! When [`GameEnv::current_player`] returns [`Player::Chance`], the +//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and +//! continues. Chance nodes are skipped transparently; Q-values converge to +//! their expectation over many simulations (outcome sampling). + +pub mod node; +mod search; + +pub use node::MctsNode; + +use rand::Rng; + +use crate::env::GameEnv; + +// ── Evaluator trait ──────────────────────────────────────────────────────── + +/// Evaluates a game position for use in MCTS. +/// +/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) +/// but the `mcts` module itself does **not** depend on Burn. +pub trait Evaluator: Send + Sync { + /// Evaluate `obs` (flat observation vector of length `obs_size`). + /// + /// Returns: + /// - `policy_logits`: one raw logit per action (`action_space` entries). + /// Illegal action entries are masked inside the search — no need to + /// zero them here. + /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. + fn evaluate(&self, obs: &[f32]) -> (Vec, f32); +} + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Hyperparameters for [`run_mcts`]. +#[derive(Debug, Clone)] +pub struct MctsConfig { + /// Number of MCTS simulations per move. Typical: 50–800. + pub n_simulations: usize, + /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0. + pub c_puct: f32, + /// Dirichlet noise concentration α. Set to `0.0` to disable. + /// Typical: `0.3` for Chess, `0.1` for large action spaces. + pub dirichlet_alpha: f32, + /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`. + pub dirichlet_eps: f32, + /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax. + pub temperature: f32, +} + +impl Default for MctsConfig { + fn default() -> Self { + Self { + n_simulations: 200, + c_puct: 1.5, + dirichlet_alpha: 0.3, + dirichlet_eps: 0.25, + temperature: 1.0, + } + } +} + +// ── Public interface ─────────────────────────────────────────────────────── + +/// Run MCTS from `state` and return the populated root [`MctsNode`]. +/// +/// `state` must be a player-decision node (`P1` or `P2`). +/// Use [`mcts_policy`] and [`select_action`] on the returned root. +/// +/// # Panics +/// +/// Panics if `env.current_player(state)` is not `P1` or `P2`. +pub fn run_mcts( + env: &E, + state: &E::State, + evaluator: &dyn Evaluator, + config: &MctsConfig, + rng: &mut impl Rng, +) -> MctsNode { + let player_idx = env + .current_player(state) + .index() + .expect("run_mcts called at a non-decision node"); + + // ── Expand root (network called once here, not inside the loop) ──────── + let mut root = MctsNode::new(1.0); + search::expand::(&mut root, state, env, evaluator, player_idx); + + // ── Optional Dirichlet noise for training exploration ────────────────── + if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 { + search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng); + } + + // ── Simulations ──────────────────────────────────────────────────────── + for _ in 0..config.n_simulations { + search::simulate::( + &mut root, + state.clone(), + env, + evaluator, + config, + rng, + player_idx, + ); + } + + root +} + +/// Compute the MCTS policy: normalized visit counts at the root. +/// +/// Returns a vector of length `action_space` where `policy[a]` is the +/// fraction of simulations that visited action `a`. +pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec { + let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum(); + let mut policy = vec![0.0f32; action_space]; + if total > 0.0 { + for (a, child) in &root.children { + policy[*a] = child.n as f32 / total; + } + } else if !root.children.is_empty() { + // n_simulations = 0: uniform over legal actions. + let uniform = 1.0 / root.children.len() as f32; + for (a, _) in &root.children { + policy[*a] = uniform; + } + } + policy +} + +/// Select an action index from the root after MCTS. +/// +/// * `temperature = 0` — greedy argmax of visit counts. +/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`. +/// +/// # Panics +/// +/// Panics if the root has no children. +pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { + assert!(!root.children.is_empty(), "select_action called on a root with no children"); + if temperature <= 0.0 { + root.children + .iter() + .max_by_key(|(_, c)| c.n) + .map(|(a, _)| *a) + .unwrap() + } else { + let weights: Vec = root + .children + .iter() + .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) + .collect(); + let total: f32 = weights.iter().sum(); + let mut r: f32 = rng.random::() * total; + for (i, (a, _)) in root.children.iter().enumerate() { + r -= weights[i]; + if r <= 0.0 { + return *a; + } + } + root.children.last().map(|(a, _)| *a).unwrap() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + use crate::env::Player; + + // ── Minimal deterministic test game ─────────────────────────────────── + // + // "Countdown" — two players alternate subtracting 1 or 2 from a counter. + // The player who brings the counter to 0 wins. + // No chance nodes, two legal actions (0 = -1, 1 = -2). + + #[derive(Clone, Debug)] + struct CState { + remaining: u8, + to_move: usize, // at terminal: last mover (winner) + } + + #[derive(Clone)] + struct CountdownEnv; + + impl crate::env::GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: 6, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { + Player::Terminal + } else if s.to_move == 0 { + Player::P1 + } else { + Player::P2 + } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + // to_move stays as winner + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 6.0, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + // Uniform evaluator: all logits = 0, value = 0. + // `action_space` must match the environment's `action_space()`. + struct ZeroEval(usize); + impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(42) + } + + fn config_n(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, // off for reproducibility + dirichlet_eps: 0.0, + temperature: 1.0, + } + } + + // ── Visit count tests ───────────────────────────────────────────────── + + #[test] + fn visit_counts_sum_to_n_simulations() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 50, "visit counts must sum to n_simulations"); + } + + #[test] + fn all_root_children_are_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); + for (a, _) in &root.children { + assert!(legal.contains(a), "child action {a} is not legal"); + } + } + + // ── Policy tests ───────────────────────────────────────────────────── + + #[test] + fn policy_sums_to_one() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); + } + + #[test] + fn policy_zero_for_illegal_actions() { + let env = CountdownEnv; + // remaining = 1 → only action 0 is legal + let state = CState { remaining: 1, to_move: 0 }; + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); + } + + // ── Action selection tests ──────────────────────────────────────────── + + #[test] + fn greedy_selects_most_visited() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); + let greedy = select_action(&root, 0.0, &mut rng()); + let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); + assert_eq!(greedy, most_visited); + } + + #[test] + fn temperature_sampling_stays_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let mut r = rng(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); + for _ in 0..20 { + let a = select_action(&root, 1.0, &mut r); + assert!(legal.contains(&a), "sampled action {a} is not legal"); + } + } + + // ── Zero-simulation edge case ───────────────────────────────────────── + + #[test] + fn zero_simulations_uniform_policy() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + // With 0 simulations, fallback is uniform over the 2 legal actions. + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5); + } + + // ── Root value ──────────────────────────────────────────────────────── + + #[test] + fn root_q_in_valid_range() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); + let q = root.q(); + assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); + } + + // ── Integration: run on a real Trictrac game ────────────────────────── + + #[test] + fn no_panic_on_trictrac_state() { + use crate::env::TrictracEnv; + + let env = TrictracEnv; + let mut state = env.new_game(); + let mut r = rng(); + + // Advance past the initial chance node to reach a decision node. + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, &mut r); + } + + if env.current_player(&state).is_terminal() { + return; // unlikely but safe + } + + let config = MctsConfig { + n_simulations: 5, // tiny for speed + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + ..MctsConfig::default() + }; + + let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); + assert!(root.n > 0); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 5); + } +} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs new file mode 100644 index 0000000..aff7735 --- /dev/null +++ b/spiel_bot/src/mcts/node.rs @@ -0,0 +1,91 @@ +//! MCTS tree node. +//! +//! [`MctsNode`] holds the visit statistics for one player-decision position in +//! the search tree. A node is *expanded* the first time the policy-value +//! network is evaluated there; before that it is a leaf. + +/// One node in the MCTS tree, representing a player-decision position. +/// +/// `w` stores the sum of values backed up into this node, always from the +/// perspective of **the player who acts here**. `q()` therefore also returns +/// a value in `(-1, 1)` from that same perspective. +#[derive(Debug)] +pub struct MctsNode { + /// Visit count `N(s, a)`. + pub n: u32, + /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. + pub w: f32, + /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). + pub p: f32, + /// Children: `(action_index, child_node)`, populated on first expansion. + pub children: Vec<(usize, MctsNode)>, + /// `true` after the network has been evaluated and children have been set up. + pub expanded: bool, +} + +impl MctsNode { + /// Create a fresh, unexpanded leaf with the given prior probability. + pub fn new(prior: f32) -> Self { + Self { + n: 0, + w: 0.0, + p: prior, + children: Vec::new(), + expanded: false, + } + } + + /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. + #[inline] + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT selection score: + /// + /// ```text + /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) + /// ``` + #[inline] + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn q_zero_when_unvisited() { + let node = MctsNode::new(0.5); + assert_eq!(node.q(), 0.0); + } + + #[test] + fn q_reflects_w_over_n() { + let mut node = MctsNode::new(0.5); + node.n = 4; + node.w = 2.0; + assert!((node.q() - 0.5).abs() < 1e-6); + } + + #[test] + fn puct_exploration_dominates_unvisited() { + // Unvisited child should outscore a visited child with negative Q. + let mut visited = MctsNode::new(0.5); + visited.n = 10; + visited.w = -5.0; // Q = -0.5 + + let unvisited = MctsNode::new(0.5); + + let parent_n = 10; + let c = 1.5; + assert!( + unvisited.puct(parent_n, c) > visited.puct(parent_n, c), + "unvisited child should have higher PUCT than a negatively-valued visited child" + ); + } +} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs new file mode 100644 index 0000000..c4960c7 --- /dev/null +++ b/spiel_bot/src/mcts/search.rs @@ -0,0 +1,170 @@ +//! Simulation, expansion, backup, and noise helpers. +//! +//! These are internal to the `mcts` module; the public entry points are +//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. + +use rand::Rng; +use rand_distr::{Gamma, Distribution}; + +use crate::env::GameEnv; +use super::{Evaluator, MctsConfig}; +use super::node::MctsNode; + +// ── Masked softmax ───────────────────────────────────────────────────────── + +/// Numerically stable softmax over `legal` actions only. +/// +/// Illegal logits are treated as `-∞` and receive probability `0.0`. +/// Returns a probability vector of length `action_space`. +pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec { + let mut probs = vec![0.0f32; action_space]; + if legal.is_empty() { + return probs; + } + let max_logit = legal + .iter() + .map(|&a| logits[a]) + .fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f32; + for &a in legal { + let e = (logits[a] - max_logit).exp(); + probs[a] = e; + sum += e; + } + if sum > 0.0 { + for &a in legal { + probs[a] /= sum; + } + } else { + let uniform = 1.0 / legal.len() as f32; + for &a in legal { + probs[a] = uniform; + } + } + probs +} + +// ── Dirichlet noise ──────────────────────────────────────────────────────── + +/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration. +/// +/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. +/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum. +pub(super) fn add_dirichlet_noise( + node: &mut MctsNode, + alpha: f32, + eps: f32, + rng: &mut impl Rng, +) { + let n = node.children.len(); + if n == 0 { + return; + } + let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else { + return; + }; + let samples: Vec = (0..n).map(|_| gamma.sample(rng) as f32).collect(); + let sum: f32 = samples.iter().sum(); + if sum <= 0.0 { + return; + } + for (i, (_, child)) in node.children.iter_mut().enumerate() { + let noise = samples[i] / sum; + child.p = (1.0 - eps) * child.p + eps * noise; + } +} + +// ── Expansion ────────────────────────────────────────────────────────────── + +/// Evaluate the network at `state` and populate `node` with children. +/// +/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`. +/// Returns the network value estimate from `player_idx`'s perspective. +pub(super) fn expand( + node: &mut MctsNode, + state: &E::State, + env: &E, + evaluator: &dyn Evaluator, + player_idx: usize, +) -> f32 { + let obs = env.observation(state, player_idx); + let legal = env.legal_actions(state); + let (logits, value) = evaluator.evaluate(&obs); + let priors = masked_softmax(&logits, &legal, env.action_space()); + node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect(); + node.expanded = true; + node.n = 1; + node.w = value; + value +} + +// ── Simulation ───────────────────────────────────────────────────────────── + +/// One MCTS simulation from an **already-expanded** decision node. +/// +/// Traverses the tree with PUCT selection, expands the first unvisited leaf, +/// and backs up the result. +/// +/// * `player_idx` — the player (0 or 1) who acts at `state`. +/// * Returns the backed-up value **from `player_idx`'s perspective**. +pub(super) fn simulate( + node: &mut MctsNode, + state: E::State, + env: &E, + evaluator: &dyn Evaluator, + config: &MctsConfig, + rng: &mut impl Rng, + player_idx: usize, +) -> f32 { + debug_assert!(node.expanded, "simulate called on unexpanded node"); + + // ── Selection: child with highest PUCT ──────────────────────────────── + let parent_n = node.n; + let best = node + .children + .iter() + .enumerate() + .max_by(|(_, (_, a)), (_, (_, b))| { + a.puct(parent_n, config.c_puct) + .partial_cmp(&b.puct(parent_n, config.c_puct)) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i) + .expect("expanded node must have at least one child"); + + let (action, child) = &mut node.children[best]; + let action = *action; + + // ── Apply action + advance through any chance nodes ─────────────────── + let mut next_state = state; + env.apply(&mut next_state, action); + while env.current_player(&next_state).is_chance() { + env.apply_chance(&mut next_state, rng); + } + + let next_cp = env.current_player(&next_state); + + // ── Evaluate leaf or terminal ────────────────────────────────────────── + // All values are converted to `player_idx`'s perspective before backup. + let child_value = if next_cp.is_terminal() { + let returns = env + .returns(&next_state) + .expect("terminal node must have returns"); + returns[player_idx] + } else { + let child_player = next_cp.index().unwrap(); + let v = if child.expanded { + simulate(child, next_state, env, evaluator, config, rng, child_player) + } else { + expand::(child, &next_state, env, evaluator, child_player) + }; + // Negate when the child belongs to the opponent. + if child_player == player_idx { v } else { -v } + }; + + // ── Backup ──────────────────────────────────────────────────────────── + node.n += 1; + node.w += child_value; + + child_value +} From b0ae4db2d9661405360e60f2115d0429e39c0af1 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 21:00:27 +0100 Subject: [PATCH 07/13] feat(spiel_bot): AlphaZero --- spiel_bot/src/alphazero/mod.rs | 117 ++++++++++++++ spiel_bot/src/alphazero/replay.rs | 144 +++++++++++++++++ spiel_bot/src/alphazero/selfplay.rs | 234 ++++++++++++++++++++++++++++ spiel_bot/src/alphazero/trainer.rs | 172 ++++++++++++++++++++ spiel_bot/src/lib.rs | 1 + 5 files changed, 668 insertions(+) create mode 100644 spiel_bot/src/alphazero/mod.rs create mode 100644 spiel_bot/src/alphazero/replay.rs create mode 100644 spiel_bot/src/alphazero/selfplay.rs create mode 100644 spiel_bot/src/alphazero/trainer.rs diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs new file mode 100644 index 0000000..bb86724 --- /dev/null +++ b/spiel_bot/src/alphazero/mod.rs @@ -0,0 +1,117 @@ +//! AlphaZero: self-play data generation, replay buffer, and training step. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | +//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | +//! | [`trainer`] | [`train_step`] | +//! +//! # Typical outer loop +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! use burn::optim::AdamConfig; +//! use spiel_bot::{ +//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, +//! env::TrictracEnv, +//! mcts::MctsConfig, +//! network::{MlpConfig, MlpNet}, +//! }; +//! +//! type Infer = NdArray; +//! type Train = Autodiff>; +//! +//! let device = Default::default(); +//! let env = TrictracEnv; +//! let config = AlphaZeroConfig::default(); +//! +//! // Build training model and optimizer. +//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); +//! let mut optimizer = AdamConfig::new().init(); +//! let mut replay = ReplayBuffer::new(config.replay_capacity); +//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); +//! +//! for _iter in 0..config.n_iterations { +//! // Convert to inference backend for self-play. +//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) +//! .load_record(train_model.clone().into_record()); +//! let eval = BurnEvaluator::new(infer_model, device.clone()); +//! +//! // Self-play: generate episodes. +//! for _ in 0..config.n_games_per_iter { +//! let samples = generate_episode(&env, &eval, &config.mcts, +//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); +//! replay.extend(samples); +//! } +//! +//! // Training: gradient steps. +//! if replay.len() >= config.batch_size { +//! for _ in 0..config.n_train_steps_per_iter { +//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) +//! .into_iter().cloned().collect(); +//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, +//! config.learning_rate); +//! train_model = m; +//! } +//! } +//! } +//! ``` + +pub mod replay; +pub mod selfplay; +pub mod trainer; + +pub use replay::{ReplayBuffer, TrainSample}; +pub use selfplay::{BurnEvaluator, generate_episode}; +pub use trainer::train_step; + +use crate::mcts::MctsConfig; + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Top-level AlphaZero hyperparameters. +/// +/// The MCTS parameters live in [`MctsConfig`]; this struct holds the +/// outer training-loop parameters. +#[derive(Debug, Clone)] +pub struct AlphaZeroConfig { + /// MCTS parameters for self-play. + pub mcts: MctsConfig, + /// Number of self-play games per training iteration. + pub n_games_per_iter: usize, + /// Number of gradient steps per training iteration. + pub n_train_steps_per_iter: usize, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Number of outer iterations (self-play + train) to run. + pub n_iterations: usize, + /// Move index after which the action temperature drops to 0 (greedy play). + pub temperature_drop_move: usize, +} + +impl Default for AlphaZeroConfig { + fn default() -> Self { + Self { + mcts: MctsConfig { + n_simulations: 100, + c_puct: 1.5, + dirichlet_alpha: 0.1, + dirichlet_eps: 0.25, + temperature: 1.0, + }, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + batch_size: 64, + replay_capacity: 50_000, + learning_rate: 1e-3, + n_iterations: 100, + temperature_drop_move: 30, + } + } +} diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs new file mode 100644 index 0000000..5e64cc4 --- /dev/null +++ b/spiel_bot/src/alphazero/replay.rs @@ -0,0 +1,144 @@ +//! Replay buffer for AlphaZero self-play data. + +use std::collections::VecDeque; +use rand::Rng; + +// ── Training sample ──────────────────────────────────────────────────────── + +/// One training example produced by self-play. +#[derive(Clone, Debug)] +pub struct TrainSample { + /// Observation tensor from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1). + pub policy: Vec, + /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw. + pub value: f32, +} + +// ── Replay buffer ────────────────────────────────────────────────────────── + +/// Fixed-capacity circular buffer of [`TrainSample`]s. +/// +/// When the buffer is full, the oldest sample is evicted on push. +/// Samples are drawn without replacement using a Fisher-Yates partial shuffle. +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + /// Create a buffer with the given maximum capacity. + pub fn new(capacity: usize) -> Self { + Self { + data: VecDeque::with_capacity(capacity.min(1024)), + capacity, + } + } + + /// Add a sample; evicts the oldest if at capacity. + pub fn push(&mut self, sample: TrainSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + /// Add all samples from an episode. + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { + self.push(s); + } + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Sample up to `n` distinct samples, without replacement. + /// + /// If the buffer has fewer than `n` samples, all are returned (shuffled). + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + let len = self.data.len(); + let n = n.min(len); + // Partial Fisher-Yates using index shuffling. + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(value: f32) -> TrainSample { + TrainSample { obs: vec![value], policy: vec![1.0], value } + } + + #[test] + fn push_and_len() { + let mut buf = ReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = ReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); // evicts 1.0 + assert_eq!(buf.len(), 3); + // Oldest remaining should be 2.0 + assert_eq!(buf.data[0].value, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(5, &mut rng); + assert_eq!(batch.len(), 5); + } + + #[test] + fn sample_batch_capped_at_len() { + let mut buf = ReplayBuffer::new(20); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(100, &mut rng); + assert_eq!(batch.len(), 2); + } + + #[test] + fn sample_batch_no_duplicates() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(1); + let batch = buf.sample_batch(10, &mut rng); + let mut seen: Vec = batch.iter().map(|s| s.value).collect(); + seen.sort_by(f32::total_cmp); + seen.dedup(); + assert_eq!(seen.len(), 10, "sample contained duplicates"); + } +} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs new file mode 100644 index 0000000..6f10f8d --- /dev/null +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -0,0 +1,234 @@ +//! Self-play episode generation and Burn-backed evaluator. + +use std::marker::PhantomData; + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::GameEnv; +use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; +use crate::network::PolicyValueNet; + +use super::replay::TrainSample; + +// ── BurnEvaluator ────────────────────────────────────────────────────────── + +/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. +/// +/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so +/// that self-play generates no gradient tape overhead. +pub struct BurnEvaluator> { + model: N, + device: B::Device, + _b: PhantomData, +} + +impl> BurnEvaluator { + pub fn new(model: N, device: B::Device) -> Self { + Self { model, device, _b: PhantomData } + } + + pub fn into_model(self) -> N { + self.model + } +} + +// Safety: NdArray modules are Send; we never share across threads without +// external synchronisation. +unsafe impl> Send for BurnEvaluator {} +unsafe impl> Sync for BurnEvaluator {} + +impl> Evaluator for BurnEvaluator { + fn evaluate(&self, obs: &[f32]) -> (Vec, f32) { + let obs_size = obs.len(); + let data = TensorData::new(obs.to_vec(), [1, obs_size]); + let obs_tensor = Tensor::::from_data(data, &self.device); + + let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); + + let policy: Vec = policy_tensor.into_data().to_vec().unwrap(); + let value: Vec = value_tensor.into_data().to_vec().unwrap(); + + (policy, value[0]) + } +} + +// ── Episode generation ───────────────────────────────────────────────────── + +/// One pending observation waiting for its game-outcome value label. +struct PendingSample { + obs: Vec, + policy: Vec, + player: usize, +} + +/// Play one full game using MCTS guided by `evaluator`. +/// +/// Returns a [`TrainSample`] for every decision step in the game. +/// +/// `temperature_fn(step)` controls exploration: return `1.0` for early +/// moves and `0.0` after a fixed number of moves (e.g. move 30). +pub fn generate_episode( + env: &E, + evaluator: &dyn Evaluator, + mcts_config: &MctsConfig, + temperature_fn: &dyn Fn(usize) -> f32, + rng: &mut impl Rng, +) -> Vec { + let mut state = env.new_game(); + let mut pending: Vec = Vec::new(); + let mut step = 0usize; + + loop { + // Advance through chance nodes automatically. + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + if env.current_player(&state).is_terminal() { + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + + // Run MCTS to get a policy. + let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng); + let policy = mcts::mcts_policy(&root, env.action_space()); + + // Record the observation from the acting player's perspective. + let obs = env.observation(&state, player_idx); + pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); + + // Select and apply the action. + let temperature = temperature_fn(step); + let action = mcts::select_action(&root, temperature, rng); + env.apply(&mut state, action); + step += 1; + } + + // Assign game outcomes. + let returns = env.returns(&state).unwrap_or([0.0; 2]); + pending + .into_iter() + .map(|s| TrainSample { + obs: s.obs, + policy: s.policy, + value: returns[s.player], + }) + .collect() +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::env::Player; + use crate::mcts::{Evaluator, MctsConfig}; + use crate::network::{MlpConfig, MlpNet}; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(7) + } + + // Countdown game (same as in mcts tests). + #[derive(Clone, Debug)] + struct CState { remaining: u8, to_move: usize } + + #[derive(Clone)] + struct CountdownEnv; + + impl GameEnv for CountdownEnv { + type State = CState; + fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } else { Player::P2 } + } + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { s.remaining = 0; } + else { s.remaining -= sub; s.to_move = 1 - s.to_move; } + } + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 4.0, s.to_move as f32] + } + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + fn tiny_config() -> MctsConfig { + MctsConfig { n_simulations: 5, c_puct: 1.5, + dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } + } + + // ── BurnEvaluator tests ─────────────────────────────────────────────── + + #[test] + fn burn_evaluator_output_shapes() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); + assert_eq!(policy.len(), 2, "policy length should equal action_space"); + assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); + } + + // ── generate_episode tests ──────────────────────────────────────────── + + #[test] + fn episode_terminates_and_has_samples() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn episode_sample_values_are_valid() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + for s in &samples { + assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "unexpected value {}", s.value); + let sum: f32 = s.policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); + assert_eq!(s.obs.len(), 2); + } + } + + #[test] + fn episode_with_temperature_zero() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + // temperature=0 means greedy; episode must still terminate + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); + assert!(!samples.is_empty()); + } +} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs new file mode 100644 index 0000000..d2482d1 --- /dev/null +++ b/spiel_bot/src/alphazero/trainer.rs @@ -0,0 +1,172 @@ +//! One gradient-descent training step for AlphaZero. +//! +//! The loss combines: +//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. +//! - **Value loss** — mean-squared error between the predicted value and the +//! actual game outcome. +//! +//! # Backend +//! +//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). +//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. +//! Weights are transferred between the two via [`burn::record`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + activation::log_softmax, + backend::AutodiffBackend, + Tensor, TensorData, + }, +}; + +use crate::network::PolicyValueNet; +use super::replay::TrainSample; + +/// Run one gradient step on `model` using `batch`. +/// +/// Returns the updated model and the scalar loss value for logging. +/// +/// # Parameters +/// +/// - `lr` — learning rate (e.g. `1e-3`). +/// - `batch` — slice of [`TrainSample`]s; must be non-empty. +pub fn train_step( + model: N, + optimizer: &mut O, + batch: &[TrainSample], + device: &B::Device, + lr: f64, +) -> (N, f32) +where + B: AutodiffBackend, + N: PolicyValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "train_step called with empty batch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + let action_size = batch[0].policy.len(); + + // ── Build input tensors ──────────────────────────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let policy_flat: Vec = batch.iter().flat_map(|s| s.policy.iter().copied()).collect(); + let value_flat: Vec = batch.iter().map(|s| s.value).collect(); + + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + let policy_target = Tensor::::from_data( + TensorData::new(policy_flat, [batch_size, action_size]), + device, + ); + let value_target = Tensor::::from_data( + TensorData::new(value_flat, [batch_size, 1]), + device, + ); + + // ── Forward pass ────────────────────────────────────────────────────── + let (policy_logits, value_pred) = model.forward(obs_tensor); + + // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ────────────────── + let log_probs = log_softmax(policy_logits, 1); + let policy_loss = (policy_target.clone().neg() * log_probs) + .sum_dim(1) + .mean(); + + // ── Value loss: MSE(value_pred, z) ──────────────────────────────────── + let diff = value_pred - value_target; + let value_loss = (diff.clone() * diff).mean(); + + // ── Combined loss ───────────────────────────────────────────────────── + let loss = policy_loss + value_loss; + + // Extract scalar before backward (consumes the tensor). + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + let model = optimizer.step(lr, model, grads); + + (model, loss_scalar) +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + + use crate::network::{MlpConfig, MlpNet}; + use super::super::replay::TrainSample; + + type B = Autodiff>; + + fn device() -> ::Device { + Default::default() + } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5f32; obs_size], + policy: { + let mut p = vec![0.0f32; action_size]; + p[i % action_size] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + } + + #[test] + fn train_step_returns_finite_loss() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); + } + + #[test] + fn loss_decreases_over_steps() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + // Same batch every step — loss should decrease. + let batch = dummy_batch(16, 4, 4); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); + model = m; + assert!(loss.is_finite()); + prev_loss = loss; + } + // After 10 steps on fixed data, loss should be below a reasonable threshold. + assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn train_step_batch_size_one() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(1, 2, 2); + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite()); + } +} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 5beb37c..23895b9 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,3 +1,4 @@ +pub mod alphazero; pub mod env; pub mod mcts; pub mod network; From 519dfe67ad952db34222762ae212a34ba267f50e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:18:59 +0100 Subject: [PATCH 08/13] fix(spiel_bot): mcts fix --- spiel_bot/src/mcts/mod.rs | 8 ++++++-- spiel_bot/src/mcts/search.rs | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index e92bd09..a0a690d 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -401,8 +401,12 @@ mod tests { }; let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); - assert!(root.n > 0); + // root.n = 1 (expansion) + n_simulations (one backup per simulation). + assert_eq!(root.n, 1 + config.n_simulations as u32); + // Children visit counts may sum to less than n_simulations when some + // simulations cross a chance node at depth 1 (turn ends after one move) + // and evaluate with the network directly without updating child.n. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 5); + assert!(total <= config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index c4960c7..55db701 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -138,8 +138,14 @@ pub(super) fn simulate( // ── Apply action + advance through any chance nodes ─────────────────── let mut next_state = state; env.apply(&mut next_state, action); + + // Track whether we crossed a chance node (dice roll) on the way down. + // If we did, the child's cached legal actions are for a *different* dice + // outcome and must not be reused — evaluate with the network directly. + let mut crossed_chance = false; while env.current_player(&next_state).is_chance() { env.apply_chance(&mut next_state, rng); + crossed_chance = true; } let next_cp = env.current_player(&next_state); @@ -153,7 +159,15 @@ pub(super) fn simulate( returns[player_idx] } else { let child_player = next_cp.index().unwrap(); - let v = if child.expanded { + let v = if crossed_chance { + // Outcome sampling: after dice, evaluate the resulting position + // directly with the network. Do NOT build the tree across chance + // boundaries — the dice change which actions are legal, so any + // previously cached children would be for a different outcome. + let obs = env.observation(&next_state, child_player); + let (_, value) = evaluator.evaluate(&obs); + value + } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) } else { expand::(child, &next_state, env, evaluator, child_player) From aea1e3faafc42b1953d709f03331db8d50dfc458 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:19:15 +0100 Subject: [PATCH 09/13] tests(spiel_bot): integration tests --- spiel_bot/tests/integration.rs | 391 +++++++++++++++++++++++++++++++++ 1 file changed, 391 insertions(+) create mode 100644 spiel_bot/tests/integration.rs diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs new file mode 100644 index 0000000..d73fda0 --- /dev/null +++ b/spiel_bot/tests/integration.rs @@ -0,0 +1,391 @@ +//! End-to-end integration tests for the AlphaZero training pipeline. +//! +//! Each test exercises the full chain: +//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`] +//! +//! Two environments are used: +//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves. +//! Used when we need many iterations without worrying about runtime. +//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that +//! the full pipeline compiles and runs correctly with 217-dim observations +//! and 514-dim action spaces. +//! +//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep +//! runtime minimal; correctness, not training quality, is what matters here. + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step}, + env::{GameEnv, Player, TrictracEnv}, + mcts::MctsConfig, + network::{MlpConfig, MlpNet, PolicyValueNet}, +}; + +// ── Backend aliases ──────────────────────────────────────────────────────── + +type Train = Autodiff>; +type Infer = NdArray; + +// ── Helpers ──────────────────────────────────────────────────────────────── + +fn train_device() -> ::Device { + Default::default() +} + +fn infer_device() -> ::Device { + Default::default() +} + +/// Tiny 64-unit MLP, compatible with an obs/action space of any size. +fn tiny_mlp(obs: usize, actions: usize) -> MlpNet { + let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 }; + MlpNet::new(&cfg, &train_device()) +} + +fn tiny_mcts(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + } +} + +fn seeded() -> SmallRng { + SmallRng::seed_from_u64(0) +} + +// ── Countdown environment (fast, local, no external deps) ───────────────── +// +// Two players alternate subtracting 1 or 2 from a counter that starts at N. +// The player who brings the counter to 0 wins. + +#[derive(Clone, Debug)] +struct CState { + remaining: u8, + to_move: usize, +} + +#[derive(Clone)] +struct CountdownEnv(u8); // starting value + +impl GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: self.0, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } + else { Player::P2 } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / self.0 as f32, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } +} + +// ── 1. Full loop on CountdownEnv ────────────────────────────────────────── + +/// The canonical AlphaZero loop: self-play → replay → train, iterated. +/// Uses CountdownEnv so each game terminates in < 10 moves. +#[test] +fn countdown_full_loop_no_panic() { + let env = CountdownEnv(8); + let mut rng = seeded(); + let mcts = tiny_mcts(3); + + let mut model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(1_000); + + for _iter in 0..5 { + // Self-play: 3 games per iteration. + for _ in 0..3 { + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + assert!(!samples.is_empty()); + replay.extend(samples); + } + + // Training: 4 gradient steps per iteration. + if replay.len() >= 4 { + for _ in 0..4 { + let batch: Vec = replay + .sample_batch(4, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + model = m; + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + } + } + + assert!(replay.len() > 0); +} + +// ── 2. Replay buffer invariants ─────────────────────────────────────────── + +/// After several Countdown games, replay capacity is respected and batch +/// shapes are consistent. +#[test] +fn replay_buffer_capacity_and_shapes() { + let env = CountdownEnv(6); + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + + let capacity = 50; + let mut replay = ReplayBuffer::new(capacity); + + for _ in 0..20 { + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + replay.extend(samples); + } + + assert!(replay.len() <= capacity, "buffer exceeded capacity"); + assert!(replay.len() > 0); + + let batch = replay.sample_batch(8, &mut rng); + assert_eq!(batch.len(), 8.min(replay.len())); + for s in &batch { + assert_eq!(s.obs.len(), env.obs_size()); + assert_eq!(s.policy.len(), env.action_space()); + let policy_sum: f32 = s.policy.iter().sum(); + assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}"); + assert!(s.value.abs() <= 1.0, "value {} out of range", s.value); + } +} + +// ── 3. TrictracEnv: sample shapes ───────────────────────────────────────── + +/// Verify that one TrictracEnv episode produces samples with the correct +/// tensor dimensions: obs = 217, policy = 514. +#[test] +fn trictrac_sample_shapes() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + + assert!(!samples.is_empty(), "Trictrac episode produced no samples"); + + for (i, s) in samples.iter().enumerate() { + assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len()); + assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len()); + let policy_sum: f32 = s.policy.iter().sum(); + assert!( + (policy_sum - 1.0).abs() < 1e-4, + "sample {i}: policy sums to {policy_sum}" + ); + assert!( + s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "sample {i}: unexpected value {}", + s.value + ); + } +} + +// ── 4. TrictracEnv: training step after real self-play ──────────────────── + +/// Collect one Trictrac episode, then verify that a gradient step runs +/// without panic and produces a finite loss. +#[test] +fn trictrac_train_step_finite_loss() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(10_000); + + // Generate one episode. + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + assert!(!samples.is_empty()); + let n_samples = samples.len(); + replay.extend(samples); + + // Train on a batch from this episode. + let batch_size = 8.min(n_samples); + let batch: Vec = replay + .sample_batch(batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); +} + +// ── 5. Backend transfer: train → infer → same outputs ───────────────────── + +/// Weights transferred from the training backend to the inference backend +/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes. +#[test] +fn valid_model_matches_train_model_outputs() { + use burn::tensor::{Tensor, TensorData}; + + let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let train_model = MlpNet::::new(&cfg, &train_device()); + let infer_model: MlpNet = train_model.valid(); + + // Build the same input on both backends. + let obs_data: Vec = vec![0.1, 0.2, 0.3, 0.4]; + + let obs_train = Tensor::::from_data( + TensorData::new(obs_data.clone(), [1, 4]), + &train_device(), + ); + let obs_infer = Tensor::::from_data( + TensorData::new(obs_data, [1, 4]), + &infer_device(), + ); + + let (p_train, v_train) = train_model.forward(obs_train); + let (p_infer, v_infer) = infer_model.forward(obs_infer); + + let p_train: Vec = p_train.into_data().to_vec().unwrap(); + let p_infer: Vec = p_infer.into_data().to_vec().unwrap(); + let v_train: Vec = v_train.into_data().to_vec().unwrap(); + let v_infer: Vec = v_infer.into_data().to_vec().unwrap(); + + for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() { + assert!( + (a - b).abs() < 1e-5, + "policy[{i}] differs after valid(): train={a}, infer={b}" + ); + } + assert!( + (v_train[0] - v_infer[0]).abs() < 1e-5, + "value differs after valid(): train={}, infer={}", + v_train[0], v_infer[0] + ); +} + +// ── 6. Loss converges on a fixed batch ──────────────────────────────────── + +/// With repeated gradient steps on the same Countdown batch, the loss must +/// decrease monotonically (or at least end lower than it started). +#[test] +fn loss_decreases_on_fixed_batch() { + let env = CountdownEnv(6); + let mut rng = seeded(); + let mcts = tiny_mcts(3); + let model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + + // Collect a fixed batch from one episode. + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples: Vec = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng); + assert!(!samples.is_empty()); + + let batch: Vec = { + let mut replay = ReplayBuffer::new(1000); + replay.extend(samples); + replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect() + }; + + // Overfit on the same fixed batch for 20 steps. + let mut model = tiny_mlp(env.obs_size(), env.action_space()); + let mut first_loss = f32::NAN; + let mut last_loss = f32::NAN; + + for step in 0..20 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2); + model = m; + assert!(loss.is_finite(), "loss is not finite at step {step}"); + if step == 0 { first_loss = loss; } + last_loss = loss; + } + + assert!( + last_loss < first_loss, + "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}" + ); +} + +// ── 7. Trictrac: multi-iteration loop ───────────────────────────────────── + +/// Two full self-play + train iterations on TrictracEnv. +/// Verifies the entire pipeline runs without panic end-to-end. +#[test] +fn trictrac_two_iteration_loop() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let mut model = MlpNet::::new(&cfg, &train_device()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(20_000); + + for iter in 0..2 { + // Self-play: 1 game per iteration. + let infer: MlpNet = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); + assert!(!samples.is_empty(), "iter {iter}: episode was empty"); + replay.extend(samples); + + // Training: 3 gradient steps. + let batch_size = 16.min(replay.len()); + for _ in 0..3 { + let batch: Vec = replay + .sample_batch(batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + model = m; + assert!(loss.is_finite(), "iter {iter}: loss={loss}"); + } + } +} From 9c82692ddb687057f15c0ef84d4dc8a5338f8697 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:49:55 +0100 Subject: [PATCH 10/13] feat(spiel_bot): benchmarks --- Cargo.lock | 120 ++++++++++++ spiel_bot/Cargo.toml | 7 + spiel_bot/benches/alphazero.rs | 341 +++++++++++++++++++++++++++++++++ 3 files changed, 468 insertions(+) create mode 100644 spiel_bot/benches/alphazero.rs diff --git a/Cargo.lock b/Cargo.lock index 0baa02a..34bfe80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.21" @@ -1116,6 +1122,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cast_trait" version = "0.1.2" @@ -1200,6 +1212,33 @@ dependencies = [ "rand 0.7.3", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1453,6 +1492,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -4461,6 +4536,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -4597,6 +4678,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.18.0" @@ -5897,6 +6006,7 @@ version = "0.1.0" dependencies = [ "anyhow", "burn", + "criterion", "rand 0.9.2", "rand_distr", "trictrac-store", @@ -6310,6 +6420,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.10.0" diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 323c953..3848dce 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,3 +9,10 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "alphazero" +harness = false diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs new file mode 100644 index 0000000..2950b09 --- /dev/null +++ b/spiel_bot/benches/alphazero.rs @@ -0,0 +1,341 @@ +//! AlphaZero pipeline benchmarks. +//! +//! Run with: +//! +//! ```sh +//! cargo bench -p spiel_bot +//! ``` +//! +//! Use `-- ` to run a specific group, e.g.: +//! +//! ```sh +//! cargo bench -p spiel_bot -- env/ +//! cargo bench -p spiel_bot -- network/ +//! cargo bench -p spiel_bot -- mcts/ +//! cargo bench -p spiel_bot -- episode/ +//! cargo bench -p spiel_bot -- train/ +//! ``` +//! +//! Target: ≥ 500 games/s for random play on CPU (consistent with +//! `random_game` throughput in `trictrac-store`). + +use std::time::Duration; + +use burn::{ + backend::NdArray, + tensor::{Tensor, TensorData, backend::Backend}, +}; +use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, + env::{GameEnv, Player, TrictracEnv}, + mcts::{Evaluator, MctsConfig, run_mcts}, + network::{MlpConfig, MlpNet, PolicyValueNet}, +}; + +// ── Shared types ─────────────────────────────────────────────────────────── + +type InferB = NdArray; +type TrainB = burn::backend::Autodiff>; + +fn infer_device() -> ::Device { Default::default() } +fn train_device() -> ::Device { Default::default() } + +fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) } + +/// Uniform evaluator (returns zero logits and zero value). +/// Used to isolate MCTS tree-traversal cost from network cost. +struct ZeroEval(usize); +impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } +} + +// ── 1. Environment primitives ────────────────────────────────────────────── + +/// Baseline performance of the raw Trictrac environment without MCTS. +/// Target: ≥ 500 full games / second. +fn bench_env(c: &mut Criterion) { + let env = TrictracEnv; + + let mut group = c.benchmark_group("env"); + group.measurement_time(Duration::from_secs(10)); + + // ── apply_chance ────────────────────────────────────────────────────── + group.bench_function("apply_chance", |b| { + b.iter_batched( + || { + // A fresh game is always at RollDice (Chance) — ready for apply_chance. + env.new_game() + }, + |mut s| { + env.apply_chance(&mut s, &mut seeded()); + black_box(s) + }, + BatchSize::SmallInput, + ) + }); + + // ── legal_actions ───────────────────────────────────────────────────── + group.bench_function("legal_actions", |b| { + let mut rng = seeded(); + let mut s = env.new_game(); + env.apply_chance(&mut s, &mut rng); + b.iter(|| black_box(env.legal_actions(&s))) + }); + + // ── observation (to_tensor) ─────────────────────────────────────────── + group.bench_function("observation", |b| { + let mut rng = seeded(); + let mut s = env.new_game(); + env.apply_chance(&mut s, &mut rng); + b.iter(|| black_box(env.observation(&s, 0))) + }); + + // ── full random game ────────────────────────────────────────────────── + group.sample_size(50); + group.bench_function("random_game", |b| { + b.iter_batched( + seeded, + |mut rng| { + let mut s = env.new_game(); + loop { + match env.current_player(&s) { + Player::Terminal => break, + Player::Chance => env.apply_chance(&mut s, &mut rng), + _ => { + let actions = env.legal_actions(&s); + let idx = rng.random_range(0..actions.len()); + env.apply(&mut s, actions[idx]); + } + } + } + black_box(s) + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +// ── 2. Network inference ─────────────────────────────────────────────────── + +/// Forward-pass latency for MLP variants (hidden = 64 / 256). +fn bench_network(c: &mut Criterion) { + let mut group = c.benchmark_group("network"); + group.measurement_time(Duration::from_secs(5)); + + for &hidden in &[64usize, 256] { + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = MlpNet::::new(&cfg, &infer_device()); + let obs: Vec = vec![0.5; 217]; + + // Batch size 1 — single-position evaluation as in MCTS. + group.bench_with_input( + BenchmarkId::new("mlp_b1", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs.clone(), [1, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + + // Batch size 32 — training mini-batch. + let obs32: Vec = vec![0.5; 217 * 32]; + group.bench_with_input( + BenchmarkId::new("mlp_b32", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs32.clone(), [32, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + } + + group.finish(); +} + +// ── 3. MCTS ─────────────────────────────────────────────────────────────── + +/// MCTS cost at different simulation budgets with two evaluator types: +/// - `zero` — isolates tree-traversal overhead (no network). +/// - `mlp64` — real MLP, shows end-to-end cost per move. +fn bench_mcts(c: &mut Criterion) { + let env = TrictracEnv; + + // Build a decision-node state (after dice roll). + let state = { + let mut s = env.new_game(); + let mut rng = seeded(); + while env.current_player(&s).is_chance() { + env.apply_chance(&mut s, &mut rng); + } + s + }; + + let mut group = c.benchmark_group("mcts"); + group.measurement_time(Duration::from_secs(10)); + + let zero_eval = ZeroEval(514); + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let mlp_model = MlpNet::::new(&mlp_cfg, &infer_device()); + let mlp_eval = BurnEvaluator::::new(mlp_model, infer_device()); + + for &n_sim in &[1usize, 5, 20] { + let cfg = MctsConfig { + n_simulations: n_sim, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + }; + + // Zero evaluator: tree traversal only. + group.bench_with_input( + BenchmarkId::new("zero_eval", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)), + BatchSize::SmallInput, + ) + }, + ); + + // MLP evaluator: full cost per decision. + group.bench_with_input( + BenchmarkId::new("mlp64", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)), + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── 4. Episode generation ───────────────────────────────────────────────── + +/// Full self-play episode latency (one complete game) at different MCTS +/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU. +fn bench_episode(c: &mut Criterion) { + let env = TrictracEnv; + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let model = MlpNet::::new(&mlp_cfg, &infer_device()); + let eval = BurnEvaluator::::new(model, infer_device()); + + let mut group = c.benchmark_group("episode"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(60)); + + for &n_sim in &[1usize, 2] { + let mcts_cfg = MctsConfig { + n_simulations: n_sim, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + }; + + group.bench_with_input( + BenchmarkId::new("trictrac", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| { + black_box(generate_episode( + &env, + &eval, + &mcts_cfg, + &|_| 1.0, + &mut rng, + )) + }, + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── 5. Training step ─────────────────────────────────────────────────────── + +/// Gradient-step latency for different batch sizes. +fn bench_train(c: &mut Criterion) { + use burn::optim::AdamConfig; + + let mut group = c.benchmark_group("train"); + group.measurement_time(Duration::from_secs(10)); + + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + + let dummy_samples = |n: usize| -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5; 217], + policy: { + let mut p = vec![0.0f32; 514]; + p[i % 514] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + }; + + for &batch_size in &[16usize, 64] { + let batch = dummy_samples(batch_size); + + group.bench_with_input( + BenchmarkId::new("mlp64_adam", batch_size), + &batch_size, + |b, _| { + b.iter_batched( + || { + ( + MlpNet::::new(&mlp_cfg, &train_device()), + AdamConfig::new().init::>(), + ) + }, + |(model, mut opt)| { + black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3)) + }, + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── Criterion entry point ────────────────────────────────────────────────── + +criterion_group!( + benches, + bench_env, + bench_network, + bench_mcts, + bench_episode, + bench_train, +); +criterion_main!(benches); From 822290d7224fa8cc6ce0aa324adcfe388dd705b2 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 23:05:53 +0100 Subject: [PATCH 11/13] feat(spiel_bot): upgrade network --- spiel_bot/benches/alphazero.rs | 34 +++++++++++- spiel_bot/src/alphazero/mod.rs | 14 ++++- spiel_bot/src/alphazero/trainer.rs | 86 ++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 3 deletions(-) diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs index 2950b09..00d5b02 100644 --- a/spiel_bot/benches/alphazero.rs +++ b/spiel_bot/benches/alphazero.rs @@ -32,7 +32,7 @@ use spiel_bot::{ alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, env::{GameEnv, Player, TrictracEnv}, mcts::{Evaluator, MctsConfig, run_mcts}, - network::{MlpConfig, MlpNet, PolicyValueNet}, + network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, }; // ── Shared types ─────────────────────────────────────────────────────────── @@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) { ); } + // ── ResNet (4 residual blocks) ──────────────────────────────────────── + for &hidden in &[256usize, 512] { + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = ResNet::::new(&cfg, &infer_device()); + let obs: Vec = vec![0.5; 217]; + + group.bench_with_input( + BenchmarkId::new("resnet_b1", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs.clone(), [1, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + + let obs32: Vec = vec![0.5; 217 * 32]; + group.bench_with_input( + BenchmarkId::new("resnet_b32", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs32.clone(), [32, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + } + group.finish(); } diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs index bb86724..d92224e 100644 --- a/spiel_bot/src/alphazero/mod.rs +++ b/spiel_bot/src/alphazero/mod.rs @@ -65,7 +65,7 @@ pub mod trainer; pub use replay::{ReplayBuffer, TrainSample}; pub use selfplay::{BurnEvaluator, generate_episode}; -pub use trainer::train_step; +pub use trainer::{cosine_lr, train_step}; use crate::mcts::MctsConfig; @@ -87,8 +87,17 @@ pub struct AlphaZeroConfig { pub batch_size: usize, /// Maximum number of samples in the replay buffer. pub replay_capacity: usize, - /// Adam learning rate. + /// Initial (peak) Adam learning rate. pub learning_rate: f64, + /// Minimum learning rate for cosine annealing (floor of the schedule). + /// + /// Pass `learning_rate == lr_min` to disable scheduling (constant LR). + /// Compute the current LR with [`cosine_lr`]: + /// + /// ```rust,ignore + /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps); + /// ``` + pub lr_min: f64, /// Number of outer iterations (self-play + train) to run. pub n_iterations: usize, /// Move index after which the action temperature drops to 0 (greedy play). @@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig { batch_size: 64, replay_capacity: 50_000, learning_rate: 1e-3, + lr_min: 1e-4, // cosine annealing floor n_iterations: 100, temperature_drop_move: 30, } diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs index d2482d1..9075519 100644 --- a/spiel_bot/src/alphazero/trainer.rs +++ b/spiel_bot/src/alphazero/trainer.rs @@ -5,6 +5,24 @@ //! - **Value loss** — mean-squared error between the predicted value and the //! actual game outcome. //! +//! # Learning-rate scheduling +//! +//! [`cosine_lr`] implements one-cycle cosine annealing: +//! +//! ```text +//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T)) +//! ``` +//! +//! Typical usage in the outer loop: +//! +//! ```rust,ignore +//! for step in 0..total_train_steps { +//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps); +//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr); +//! model = m; +//! } +//! ``` +//! //! # Backend //! //! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). @@ -96,6 +114,30 @@ where (model, loss_scalar) } +// ── Learning-rate schedule ───────────────────────────────────────────────── + +/// Cosine learning-rate schedule (one half-period, no warmup). +/// +/// Returns the learning rate for training step `step` out of `total_steps`: +/// +/// ```text +/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total)) +/// ``` +/// +/// - At `t = 0` returns `initial`. +/// - At `t = total_steps` (or beyond) returns `lr_min`. +/// +/// # Panics +/// +/// Does not panic. When `total_steps == 0`, returns `lr_min`. +pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 { + if total_steps == 0 || step >= total_steps { + return lr_min; + } + let progress = step as f64 / total_steps as f64; + lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos()) +} + // ── Tests ────────────────────────────────────────────────────────────────── #[cfg(test)] @@ -169,4 +211,48 @@ mod tests { let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); assert!(loss.is_finite()); } + + // ── cosine_lr ───────────────────────────────────────────────────────── + + #[test] + fn cosine_lr_at_step_zero_is_initial() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 100); + assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}"); + } + + #[test] + fn cosine_lr_at_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 100, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}"); + } + + #[test] + fn cosine_lr_beyond_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 200, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}"); + } + + #[test] + fn cosine_lr_midpoint_is_average() { + // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2. + let lr = super::cosine_lr(1e-3, 1e-5, 50, 100); + let expected = (1e-3 + 1e-5) / 2.0; + assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}"); + } + + #[test] + fn cosine_lr_monotone_decreasing() { + let mut prev = f64::INFINITY; + for step in 0..=100 { + let lr = super::cosine_lr(1e-3, 1e-5, step, 100); + assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}"); + prev = lr; + } + } + + #[test] + fn cosine_lr_zero_total_steps_returns_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 0); + assert!((lr - 1e-5).abs() < 1e-10); + } } From 3221b5256a370734d7cc5d861a981b0111ed49fb Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 23:06:21 +0100 Subject: [PATCH 12/13] feat(spiel_bot): alphazero eval binary --- spiel_bot/src/bin/az_eval.rs | 262 +++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 spiel_bot/src/bin/az_eval.rs diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs new file mode 100644 index 0000000..3c82519 --- /dev/null +++ b/spiel_bot/src/bin/az_eval.rs @@ -0,0 +1,262 @@ +//! Evaluate a trained AlphaZero checkpoint against a random player. +//! +//! # Usage +//! +//! ```sh +//! # Random weights (sanity check — should be ~50 %) +//! cargo run -p spiel_bot --bin az_eval --release +//! +//! # Trained MLP checkpoint +//! cargo run -p spiel_bot --bin az_eval --release -- \ +//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50 +//! +//! # Trained ResNet checkpoint +//! cargo run -p spiel_bot --bin az_eval --release -- \ +//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--checkpoint ` | (none) | Load weights from `.mpk` file; random weights if omitted | +//! | `--arch mlp\|resnet` | `mlp` | Network architecture | +//! | `--hidden ` | 256 (mlp) / 512 (resnet) | Hidden size | +//! | `--n-games ` | `100` | Games per side (total = 2 × N) | +//! | `--n-sim ` | `50` | MCTS simulations per move | +//! | `--seed ` | `42` | RNG seed | +//! | `--c-puct ` | `1.5` | PUCT exploration constant | + +use std::path::PathBuf; + +use burn::backend::NdArray; +use rand::{SeedableRng, rngs::SmallRng, Rng}; + +use spiel_bot::{ + alphazero::BurnEvaluator, + env::{GameEnv, Player, TrictracEnv}, + mcts::{Evaluator, MctsConfig, run_mcts, select_action}, + network::{MlpConfig, MlpNet, ResNet, ResNetConfig}, +}; + +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + checkpoint: Option, + arch: String, + hidden: Option, + n_games: usize, + n_sim: usize, + seed: u64, + c_puct: f32, +} + +impl Default for Args { + fn default() -> Self { + Self { + checkpoint: None, + arch: "mlp".into(), + hidden: None, + n_games: 100, + n_sim: 50, + seed: 42, + c_puct: 1.5, + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut args = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); } + "--arch" => { i += 1; args.arch = raw[i].clone(); } + "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); } + "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); } + "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); } + "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); } + "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + args +} + +// ── Game loop ───────────────────────────────────────────────────────────────── + +/// Play one complete game. +/// +/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black). +/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0). +fn play_game( + env: &TrictracEnv, + mcts_side: usize, + evaluator: &dyn Evaluator, + mcts_cfg: &MctsConfig, + rng: &mut SmallRng, +) -> [f32; 2] { + let mut state = env.new_game(); + loop { + match env.current_player(&state) { + Player::Terminal => { + return env.returns(&state).expect("Terminal state must have returns"); + } + Player::Chance => env.apply_chance(&mut state, rng), + player => { + let side = player.index().unwrap(); // 0 = P1, 1 = P2 + let action = if side == mcts_side { + let root = run_mcts(env, &state, evaluator, mcts_cfg, rng); + select_action(&root, 0.0, rng) // greedy (temperature = 0) + } else { + let actions = env.legal_actions(&state); + actions[rng.random_range(0..actions.len())] + }; + env.apply(&mut state, action); + } + } + } +} + +// ── Statistics ──────────────────────────────────────────────────────────────── + +#[derive(Default)] +struct Stats { + wins: u32, + draws: u32, + losses: u32, +} + +impl Stats { + fn record(&mut self, mcts_return: f32) { + if mcts_return > 0.0 { self.wins += 1; } + else if mcts_return < 0.0 { self.losses += 1; } + else { self.draws += 1; } + } + + fn total(&self) -> u32 { self.wins + self.draws + self.losses } + + fn win_rate_decisive(&self) -> f64 { + let d = self.wins + self.losses; + if d == 0 { 0.5 } else { self.wins as f64 / d as f64 } + } + + fn print(&self) { + let n = self.total(); + let pct = |k: u32| 100.0 * k as f64 / n as f64; + println!( + " Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)", + self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses), + ); + } +} + +// ── Evaluation ──────────────────────────────────────────────────────────────── + +fn run_evaluation( + evaluator: &dyn Evaluator, + n_games: usize, + mcts_cfg: &MctsConfig, + seed: u64, +) -> (Stats, Stats) { + let env = TrictracEnv; + let total = n_games * 2; + let mut as_p1 = Stats::default(); + let mut as_p2 = Stats::default(); + + for i in 0..total { + // Alternate sides: even games → MctsAgent as P1, odd → as P2. + let mcts_side = i % 2; + let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64)); + let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng); + + let mcts_return = result[mcts_side]; + if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); } + + let done = i + 1; + if done % 10 == 0 || done == total { + eprint!("\r [{done}/{total}] ", ); + } + } + eprintln!(); + (as_p1, as_p2) +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + let device: ::Device = Default::default(); + + // ── Load model ──────────────────────────────────────────────────────── + let evaluator: Box = match args.arch.as_str() { + "resnet" => { + let hidden = args.hidden.unwrap_or(512); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = match &args.checkpoint { + Some(path) => ResNet::::load(&cfg, path, &device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), + None => ResNet::new(&cfg, &device), + }; + Box::new(BurnEvaluator::>::new(model, device)) + } + "mlp" | _ => { + let hidden = args.hidden.unwrap_or(256); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = match &args.checkpoint { + Some(path) => MlpNet::::load(&cfg, path, &device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), + None => MlpNet::new(&cfg, &device), + }; + Box::new(BurnEvaluator::>::new(model, device)) + } + }; + + let mcts_cfg = MctsConfig { + n_simulations: args.n_sim, + c_puct: args.c_puct, + dirichlet_alpha: 0.0, // no exploration noise during evaluation + dirichlet_eps: 0.0, + temperature: 0.0, // greedy action selection + }; + + // ── Header ──────────────────────────────────────────────────────────── + let ckpt_label = args.checkpoint + .as_deref() + .and_then(|p| p.file_name()) + .and_then(|n| n.to_str()) + .unwrap_or("random weights"); + + println!(); + println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent", + args.arch, args.n_sim); + println!("Games per side: {} | Total: {} | Seed: {}", + args.n_games, args.n_games * 2, args.seed); + println!(); + + // ── Run ─────────────────────────────────────────────────────────────── + let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed); + + // ── Results ─────────────────────────────────────────────────────────── + println!("MctsAgent as P1 (White):"); + as_p1.print(); + + println!("MctsAgent as P2 (Black):"); + as_p2.print(); + + let combined_wins = as_p1.wins + as_p2.wins; + let combined_decisive = combined_wins + as_p1.losses + as_p2.losses; + let combined_wr = if combined_decisive == 0 { 0.5 } + else { combined_wins as f64 / combined_decisive as f64 }; + + println!(); + println!("Combined win rate (excluding draws): {:.1}% [{}/{}]", + combined_wr * 100.0, combined_wins, combined_decisive); + println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%", + as_p1.win_rate_decisive() * 100.0, + as_p2.win_rate_decisive() * 100.0); +} From 150efe302fc1294df640340a0a162e8e366db59b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 8 Mar 2026 12:28:39 +0100 Subject: [PATCH 13/13] feat(spiel_bot): az_train training command --- spiel_bot/src/bin/az_train.rs | 314 ++++++++++++++++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 spiel_bot/src/bin/az_train.rs diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs new file mode 100644 index 0000000..ab385c2 --- /dev/null +++ b/spiel_bot/src/bin/az_train.rs @@ -0,0 +1,314 @@ +//! AlphaZero self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh (MLP, default settings) +//! cargo run -p spiel_bot --bin az_train --release +//! +//! # ResNet, 200 iterations, save every 20 +//! cargo run -p spiel_bot --bin az_train --release -- \ +//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/ +//! +//! # Resume from a checkpoint +//! cargo run -p spiel_bot --bin az_train --release -- \ +//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--arch mlp\|resnet` | `mlp` | Network architecture | +//! | `--hidden N` | 256/512 | Hidden layer width | +//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | +//! | `--n-iter N` | `100` | Training iterations | +//! | `--n-games N` | `10` | Self-play games per iteration | +//! | `--n-train N` | `20` | Gradient steps per iteration | +//! | `--n-sim N` | `100` | MCTS simulations per move | +//! | `--batch N` | `64` | Mini-batch size | +//! | `--replay-cap N` | `50000` | Replay buffer capacity | +//! | `--lr F` | `1e-3` | Peak (initial) learning rate | +//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) | +//! | `--c-puct F` | `1.5` | PUCT exploration constant | +//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha | +//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight | +//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 | +//! | `--save-every N` | `10` | Save checkpoint every N iterations | +//! | `--seed N` | `42` | RNG seed | +//! | `--resume PATH` | (none) | Load weights from checkpoint before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{ + BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step, + }, + env::TrictracEnv, + mcts::MctsConfig, + network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + arch: String, + hidden: Option, + out_dir: PathBuf, + n_iter: usize, + n_games: usize, + n_train: usize, + n_sim: usize, + batch_size: usize, + replay_cap: usize, + lr: f64, + lr_min: f64, + c_puct: f32, + dirichlet_alpha: f32, + dirichlet_eps: f32, + temp_drop: usize, + save_every: usize, + seed: u64, + resume: Option, +} + +impl Default for Args { + fn default() -> Self { + Self { + arch: "mlp".into(), + hidden: None, + out_dir: PathBuf::from("checkpoints"), + n_iter: 100, + n_games: 10, + n_train: 20, + n_sim: 100, + batch_size: 64, + replay_cap: 50_000, + lr: 1e-3, + lr_min: 1e-4, + c_puct: 1.5, + dirichlet_alpha: 0.1, + dirichlet_eps: 0.25, + temp_drop: 30, + save_every: 10, + seed: 42, + resume: None, + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--arch" => { i += 1; a.arch = raw[i].clone(); } + "--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); } + "--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); } + "--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); } + "--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); } + "--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); } + "--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); } + "--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); } + "--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); } + "--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +/// Generic training loop, parameterised over the network type. +/// +/// `save_fn` receives the **training-backend** model and the target path; +/// it is called in the match arm where the concrete network type is known. +fn train_loop( + mut model: N, + save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>, + args: &Args, +) +where + N: PolicyValueNet + AutodiffModule + Clone, + >::InnerModule: PolicyValueNet + Send + 'static, +{ + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + // Type is inferred as OptimizerAdaptor at the call site. + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(args.replay_cap); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + // Total gradient steps (used for cosine LR denominator). + let total_train_steps = (args.n_iter * args.n_train).max(1); + let mut global_step = 0usize; + + println!( + "\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}", + "", args.arch, args.n_iter, args.n_games, args.n_sim, "" + ); + + for iter in 0..args.n_iter { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + // Convert to inference backend (zero autodiff overhead). + let infer_model: >::InnerModule = model.valid(); + let evaluator: BurnEvaluator>::InnerModule> = + BurnEvaluator::new(infer_model, infer_device.clone()); + + let mcts_cfg = MctsConfig { + n_simulations: args.n_sim, + c_puct: args.c_puct, + dirichlet_alpha: args.dirichlet_alpha, + dirichlet_eps: args.dirichlet_eps, + temperature: 1.0, + }; + + let temp_drop = args.temp_drop; + let temperature_fn = |step: usize| -> f32 { + if step < temp_drop { 1.0 } else { 0.0 } + }; + + let mut new_samples = 0usize; + for _ in 0..args.n_games { + let samples = + generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng); + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.batch_size { + for _ in 0..args.n_train { + let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); + let batch: Vec = replay + .sample_batch(args.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = + train_step(model, &mut optimizer, &batch, &train_device, lr); + model = m; + loss_sum += loss; + n_steps += 1; + global_step += 1; + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s", + iter + 1, + args.n_iter, + replay.len(), + new_samples, + avg_loss, + lr_now, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.n_iter; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1)); + match save_fn(&model, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nTraining complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + // Create output directory if it doesn't exist. + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + + match args.arch.as_str() { + "resnet" => { + let hidden = args.hidden.unwrap_or(512); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + + let model = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + ResNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => ResNet::::new(&cfg, &train_device), + }; + + train_loop( + model, + &|m: &ResNet, path: &Path| { + // Save via inference model to avoid autodiff record overhead. + m.valid().save(path) + }, + &args, + ); + } + + "mlp" | _ => { + let hidden = args.hidden.unwrap_or(256); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + + let model = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + MlpNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => MlpNet::::new(&cfg, &train_device), + }; + + train_loop( + model, + &|m: &MlpNet, path: &Path| m.valid().save(path), + &args, + ); + } + } +}