From 85ccca47412617d2bd4e0880d132189b6c32886e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 17:52:04 +0100 Subject: [PATCH 01/16] doc:rust open_spiel research --- doc/spiel_bot_research.md | 782 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 782 insertions(+) create mode 100644 doc/spiel_bot_research.md diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md new file mode 100644 index 0000000..a8863af --- /dev/null +++ b/doc/spiel_bot_research.md @@ -0,0 +1,782 @@ +# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac + +## 0. Context and Scope + +The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library +(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` +encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every +other stage to an inline random-opponent loop. + +`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency +for **self-play training**. Its goals: + +- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") + that works with Trictrac's multi-stage turn model and stochastic dice. +- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) + as the first algorithm. +- Remain **modular**: adding DQN or PPO later requires only a new + `impl Algorithm for Dqn` without touching the environment or network layers. +- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from + `trictrac-store`. + +--- + +## 1. Library Landscape + +### 1.1 Neural Network Frameworks + +| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | +| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- | +| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | +| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | +| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | +| ndarray alone | no | no | yes | mature | array ops only; no autograd | + +**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ +runtime needed, the `ndarray` backend is sufficient for CPU training and can +switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by +changing one type alias. + +`tch-rs` would be the best choice for raw training throughput (it is the most +battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the +pure-Rust constraint. If training speed becomes the bottleneck after prototyping, +switching `spiel_bot` to `tch-rs` is a one-line backend swap. + +### 1.2 Other Key Crates + +| Crate | Role | +| -------------------- | ----------------------------------------------------------------- | +| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | +| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | +| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | +| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | +| `anyhow` | error propagation (already used everywhere) | +| `indicatif` | training progress bars | +| `tracing` | structured logging per episode/iteration | + +### 1.3 What `burn-rl` Provides (and Does Not) + +The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) +provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` +trait. It does **not** provide: + +- MCTS or any tree-search algorithm +- Two-player self-play +- Legal action masking during training +- Chance-node handling + +For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its +own (simpler, more targeted) traits and implement MCTS from scratch. + +--- + +## 2. Trictrac-Specific Design Constraints + +### 2.1 Multi-Stage Turn Model + +A Trictrac turn passes through up to six `TurnStage` values. Only two involve +genuine player choice: + +| TurnStage | Node type | Handler | +| ---------------- | ------------------------------- | ------------------------------- | +| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | +| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | +| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | +| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | +| `Move` | **Player decision** | MCTS / policy network | +| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | + +The environment wrapper advances through forced/chance stages automatically so +that from the algorithm's perspective every node it sees is a genuine player +decision. + +### 2.2 Stochastic Dice in MCTS + +AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice +introduce stochasticity. Three approaches exist: + +**A. Outcome sampling (recommended)** +During each MCTS simulation, when a chance node is reached, sample one dice +outcome at random and continue. After many simulations the expected value +converges. This is the approach used by OpenSpiel's MCTS for stochastic games +and requires no changes to the standard PUCT formula. + +**B. Chance-node averaging (expectimax)** +At each chance node, expand all 21 unique dice pairs weighted by their +probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is +exact but multiplies the branching factor by ~21 at every dice roll, making it +prohibitively expensive. + +**C. Condition on dice in the observation (current approach)** +Dice values are already encoded at indices [192–193] of `to_tensor()`. The +network naturally conditions on the rolled dice when it evaluates a position. +MCTS only runs on player-decision nodes _after_ the dice have been sampled; +chance nodes are bypassed by the environment wrapper (approach A). The policy +and value heads learn to play optimally given any dice pair. + +**Use approach A + C together**: the environment samples dice automatically +(chance node bypass), and the 217-dim tensor encodes the dice so the network +can exploit them. + +### 2.3 Perspective / Mirroring + +All move rules and tensor encoding are defined from White's perspective. +`to_tensor()` must always be called after mirroring the state for Black. +The environment wrapper handles this transparently: every observation returned +to an algorithm is already in the active player's perspective. + +### 2.4 Legal Action Masking + +A crucial difference from the existing `bot/` code: instead of penalizing +invalid actions with `ERROR_REWARD`, the policy head logits are **masked** +before softmax — illegal action logits are set to `-inf`. This prevents the +network from wasting capacity on illegal moves and eliminates the need for the +penalty-reward hack. + +--- + +## 3. Proposed Crate Architecture + +``` +spiel_bot/ +├── Cargo.toml +└── src/ + ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" + │ + ├── env/ + │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface + │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store + │ + ├── mcts/ + │ ├── mod.rs # MctsConfig, run_mcts() entry point + │ ├── node.rs # MctsNode (visit count, W, prior, children) + │ └── search.rs # simulate(), backup(), select_action() + │ + ├── network/ + │ ├── mod.rs # PolicyValueNet trait + │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads + │ + ├── alphazero/ + │ ├── mod.rs # AlphaZeroConfig + │ ├── selfplay.rs # generate_episode() -> Vec + │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) + │ └── trainer.rs # training loop: selfplay → sample → loss → update + │ + └── agent/ + ├── mod.rs # Agent trait + ├── random.rs # RandomAgent (baseline) + └── mcts_agent.rs # MctsAgent: uses trained network for inference +``` + +Future algorithms slot in without touching the above: + +``` + ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer + └── ppo/ # (future) PPO: impl Algorithm + rollout buffer +``` + +--- + +## 4. Core Traits + +### 4.1 `GameEnv` — the minimal OpenSpiel interface + +```rust +use rand::Rng; + +/// Who controls the current node. +pub enum Player { + P1, // player index 0 + P2, // player index 1 + Chance, // dice roll + Terminal, // game over +} + +pub trait GameEnv: Clone + Send + Sync + 'static { + type State: Clone + Send + Sync; + + /// Fresh game state. + fn new_game(&self) -> Self::State; + + /// Who acts at this node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices (always in [0, action_space())). + /// Empty only at Terminal nodes. + fn legal_actions(&self, s: &Self::State) -> Vec; + + /// Apply a player action (must be legal). + fn apply(&self, s: &mut Self::State, action: usize); + + /// Advance a Chance node by sampling dice; no-op at non-Chance nodes. + fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng); + + /// Observation tensor from `pov`'s perspective (0 or 1). + /// Returns 217 f32 values for Trictrac. + fn observation(&self, s: &Self::State, pov: usize) -> Vec; + + /// Flat observation size (217 for Trictrac). + fn obs_size(&self) -> usize; + + /// Total action-space size (514 for Trictrac). + fn action_space(&self) -> usize; + + /// Game outcome per player, or None if not Terminal. + /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} +``` + +### 4.2 `PolicyValueNet` — neural network interface + +```rust +use burn::prelude::*; + +pub trait PolicyValueNet: Send + Sync { + /// Forward pass. + /// `obs`: [batch, obs_size] tensor. + /// Returns: (policy_logits [batch, action_space], value [batch]). + fn forward(&self, obs: Tensor) -> (Tensor, Tensor); + + /// Save weights to `path`. + fn save(&self, path: &std::path::Path) -> anyhow::Result<()>; + + /// Load weights from `path`. + fn load(path: &std::path::Path) -> anyhow::Result + where + Self: Sized; +} +``` + +### 4.3 `Agent` — player policy interface + +```rust +pub trait Agent: Send { + /// Select an action index given the current game state observation. + /// `legal`: mask of valid action indices. + fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize; +} +``` + +--- + +## 5. MCTS Implementation + +### 5.1 Node + +```rust +pub struct MctsNode { + n: u32, // visit count N(s, a) + w: f32, // sum of backed-up values W(s, a) + p: f32, // prior from policy head P(s, a) + children: Vec<(usize, MctsNode)>, // (action_idx, child) + is_expanded: bool, +} + +impl MctsNode { + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT score used for selection. + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} +``` + +### 5.2 Simulation Loop + +One MCTS simulation (for deterministic decision nodes): + +``` +1. SELECTION — traverse from root, always pick child with highest PUCT, + auto-advancing forced/chance nodes via env.apply_chance(). +2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get + (policy_logits, value). Mask illegal actions, softmax + the remaining logits → priors P(s,a) for each child. +3. BACKUP — propagate -value up the tree (negate at each level because + perspective alternates between P1 and P2). +``` + +After `n_simulations` iterations, action selection at the root: + +```rust +// During training: sample proportional to N^(1/temperature) +// During evaluation: argmax N +fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } +``` + +### 5.3 Configuration + +```rust +pub struct MctsConfig { + pub n_simulations: usize, // e.g. 200 + pub c_puct: f32, // exploration constant, e.g. 1.5 + pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 + pub dirichlet_eps: f32, // noise weight, e.g. 0.25 + pub temperature: f32, // action sampling temperature (anneals to 0) +} +``` + +### 5.4 Handling Chance Nodes Inside MCTS + +When simulation reaches a Chance node (dice roll), the environment automatically +samples dice and advances to the next decision node. The MCTS tree does **not** +branch on dice outcomes — it treats the sampled outcome as the state. This +corresponds to "outcome sampling" (approach A from §2.2). Because each +simulation independently samples dice, the Q-values at player nodes converge to +their expected value over many simulations. + +--- + +## 6. Network Architecture + +### 6.1 ResNet Policy-Value Network + +A single trunk with residual blocks, then two heads: + +``` +Input: [batch, 217] + ↓ +Linear(217 → 512) + ReLU + ↓ +ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU) + ↓ trunk output [batch, 512] + ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site) + └─ Value head: Linear(512 → 1) → tanh (output in [-1, 1]) +``` + +Burn implementation sketch: + +```rust +#[derive(Module, Debug)] +pub struct TrictracNet { + input: Linear, + res_blocks: Vec>, + policy_head: Linear, + value_head: Linear, +} + +impl TrictracNet { + pub fn forward(&self, obs: Tensor) + -> (Tensor, Tensor) + { + let x = activation::relu(self.input.forward(obs)); + let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x)); + let policy = self.policy_head.forward(x.clone()); // raw logits + let value = activation::tanh(self.value_head.forward(x)) + .squeeze(1); + (policy, value) + } +} +``` + +A simpler MLP (no residual blocks) is sufficient for a first version and much +faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two +heads. + +### 6.2 Loss Function + +``` +L = MSE(value_pred, z) + + CrossEntropy(policy_logits_masked, π_mcts) + - c_l2 * L2_regularization +``` + +Where: + +- `z` = game outcome (±1) from the active player's perspective +- `π_mcts` = normalized MCTS visit counts at the root (the policy target) +- Legal action masking is applied before computing CrossEntropy + +--- + +## 7. AlphaZero Training Loop + +``` +INIT + network ← random weights + replay ← empty ReplayBuffer(capacity = 100_000) + +LOOP forever: + ── Self-play phase ────────────────────────────────────────────── + (parallel with rayon, n_workers games at once) + for each game: + state ← env.new_game() + samples = [] + while not terminal: + advance forced/chance nodes automatically + obs ← env.observation(state, current_player) + legal ← env.legal_actions(state) + π, root_value ← mcts.run(state, network, config) + action ← sample from π (with temperature) + samples.push((obs, π, current_player)) + env.apply(state, action) + z ← env.returns(state) // final scores + for (obs, π, player) in samples: + replay.push(TrainSample { obs, policy: π, value: z[player] }) + + ── Training phase ─────────────────────────────────────────────── + for each gradient step: + batch ← replay.sample(batch_size) + (policy_logits, value_pred) ← network.forward(batch.obs) + loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) + optimizer.step(loss.backward()) + + ── Evaluation (every N iterations) ───────────────────────────── + win_rate ← evaluate(network_new vs network_prev, n_eval_games) + if win_rate > 0.55: save checkpoint +``` + +### 7.1 Replay Buffer + +```rust +pub struct TrainSample { + pub obs: Vec, // 217 values + pub policy: Vec, // 514 values (normalized MCTS visit counts) + pub value: f32, // game outcome ∈ {-1, 0, +1} +} + +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + pub fn push(&mut self, s: TrainSample) { + if self.data.len() == self.capacity { self.data.pop_front(); } + self.data.push_back(s); + } + + pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + // sample without replacement + } +} +``` + +### 7.2 Parallelism Strategy + +Self-play is embarrassingly parallel (each game is independent): + +```rust +let samples: Vec = (0..n_games) + .into_par_iter() // rayon + .flat_map(|_| generate_episode(&env, &network, &mcts_config)) + .collect(); +``` + +Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. +Self-play uses inference-only (no gradient tape), so a `NdArray` backend +(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with +`Autodiff>`. + +For larger scale, a producer-consumer architecture (crossbeam-channel) separates +self-play workers from the training thread, allowing continuous data generation +while the GPU trains. + +--- + +## 8. `TrictracEnv` Implementation Sketch + +```rust +use trictrac_store::{ + training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, + Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, +}; + +#[derive(Clone)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + fn current_player(&self, s: &GameState) -> Player { + match s.stage { + Stage::Ended => Player::Terminal, + _ => match s.turn_stage { + TurnStage::RollWaiting => Player::Chance, + _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, + }, + } + } + + fn legal_actions(&self, s: &GameState) -> Vec { + let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; + get_valid_action_indices(&view).unwrap_or_default() + } + + fn apply(&self, s: &mut GameState, action_idx: usize) { + // advance all forced/chance nodes first, then apply the player action + self.advance_forced(s); + let needs_mirror = s.active_player_id == 2; + let view = if needs_mirror { s.mirror() } else { s.clone() }; + if let Some(event) = TrictracAction::from_action_index(action_idx) + .and_then(|a| a.to_event(&view)) + .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) + { + let _ = s.consume(&event); + } + // advance any forced stages that follow + self.advance_forced(s); + } + + fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { + // RollDice → RollWaiting + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + // RollWaiting → next stage + let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; + let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); + self.advance_forced(s); + } + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } + } + + fn obs_size(&self) -> usize { 217 } + fn action_space(&self) -> usize { ACTION_SPACE_SIZE } + + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { return None; } + // Convert hole+point scores to ±1 outcome + let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [ 1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [ 0.0, 0.0], + }) + } +} + +impl TrictracEnv { + /// Advance through all forced (non-decision, non-chance) stages. + fn advance_forced(&self, s: &mut GameState) { + use trictrac_store::PointsRules; + loop { + match s.turn_stage { + TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { + // Scoring is deterministic; compute and apply automatically. + let color = s.player_color_by_id(&s.active_player_id) + .unwrap_or(trictrac_store::Color::White); + let drc = s.players.get(&s.active_player_id) + .map(|p| p.dice_roll_count).unwrap_or(0); + let pr = PointsRules::new(&color, &s.board, s.dice); + let pts = pr.get_points(drc); + let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; + let _ = s.consume(&GameEvent::Mark { + player_id: s.active_player_id, points, + }); + } + TurnStage::RollDice => { + // RollDice is a forced "initiate roll" action with no real choice. + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + } + _ => break, + } + } + } +} +``` + +--- + +## 9. Cargo.toml Changes + +### 9.1 Add `spiel_bot` to the workspace + +```toml +# Cargo.toml (workspace root) +[workspace] +resolver = "2" +members = ["client_cli", "bot", "store", "spiel_bot"] +``` + +### 9.2 `spiel_bot/Cargo.toml` + +```toml +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[features] +default = ["alphazero"] +alphazero = [] +# dqn = [] # future +# ppo = [] # future + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" +rayon = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# Burn: NdArray for pure-Rust CPU training +# Replace NdArray with Wgpu or Tch for GPU. +burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +# Optional: progress display and structured logging +indicatif = "0.17" +tracing = "0.1" + +[[bin]] +name = "az_train" +path = "src/bin/az_train.rs" + +[[bin]] +name = "az_eval" +path = "src/bin/az_eval.rs" +``` + +--- + +## 10. Comparison: `bot` crate vs `spiel_bot` + +| Aspect | `bot` (existing) | `spiel_bot` (proposed) | +| ---------------- | --------------------------- | -------------------------------------------- | +| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | +| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | +| Opponent | hardcoded random | self-play | +| Invalid actions | penalise with reward | legal action mask (no penalty) | +| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | +| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | +| Burn dep | yes (0.20) | yes (0.20), same backend | +| `burn-rl` dep | yes | no | +| C++ dep | no | no | +| Python dep | no | no | +| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | + +The two crates are **complementary**: `bot` is a working DQN/PPO baseline; +`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The +`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just +replace `TrictracEnvironment` with `TrictracEnv`). + +--- + +## 11. Implementation Order + +1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game + through the trait, verify terminal state and returns). +2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + + Burn forward/backward pass test with dummy data. +3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests + (visit counts sum to `n_simulations`, legal mask respected). +4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub + (one iteration, check loss decreases). +5. **Integration test**: run 100 self-play games with a tiny network (1 res block, + 64 hidden units), verify the training loop completes without panics. +6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s + on CPU, consistent with `random_game` throughput). +7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule + hyperparameter sweep. +8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report + win rate every checkpoint. + +--- + +## 12. Key Open Questions + +1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. + AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever + scored more holes). Better option: normalize the score margin. The final + choice affects how the value head is trained. + +2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). + MCTS with 200 simulations per step means ~120k network evaluations per game. + At batch inference this is feasible on CPU; on GPU it becomes fast. Consider + limiting `n_simulations` to 50–100 for early training. + +3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). + This is a long-horizon decision that AlphaZero handles naturally via MCTS + lookahead, but needs careful value normalization (a "Go" restarts scoring + within the same game). + +4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated + to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. + This is optional but reduces code duplication. + +5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, + 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. + A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. + +## Implementation results + +All benchmarks compile and run. Here's the complete results table: + +| Group | Benchmark | Time | +| ------- | ----------------------- | --------------------- | +| env | apply_chance | 3.87 µs | +| | legal_actions | 1.91 µs | +| | observation (to_tensor) | 341 ns | +| | random_game (baseline) | 3.55 ms → 282 games/s | +| network | mlp_b1 hidden=64 | 94.9 µs | +| | mlp_b32 hidden=64 | 141 µs | +| | mlp_b1 hidden=256 | 352 µs | +| | mlp_b32 hidden=256 | 479 µs | +| mcts | zero_eval n=1 | 6.8 µs | +| | zero_eval n=5 | 23.9 µs | +| | zero_eval n=20 | 90.9 µs | +| | mlp64 n=1 | 203 µs | +| | mlp64 n=5 | 622 µs | +| | mlp64 n=20 | 2.30 ms | +| episode | trictrac n=1 | 51.8 ms → 19 games/s | +| | trictrac n=2 | 145 ms → 7 games/s | +| train | mlp64 Adam b=16 | 1.93 ms | +| | mlp64 Adam b=64 | 2.68 ms | + +Key observations: + +- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game) +- observation (217-value tensor): only 341 ns — not a bottleneck +- legal_actions: 1.9 µs — well optimised +- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs +- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal +- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it +- Training: 2.7 ms/step at batch=64 → 370 steps/s + +### Summary of Step 8 + +spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary: + +- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct +- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%) +- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side +- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise) +- Output: win/draw/loss per side + combined decisive win rate + +Typical usage after training: +cargo run -p spiel_bot --bin az_eval --release -- \ + --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100 + +### az_train + +#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10) + +cargo run -p spiel_bot --bin az_train --release + +#### ResNet, more games, custom output dir + +cargo run -p spiel_bot --bin az_train --release -- \ + --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \ + --save-every 20 --out checkpoints/ + +#### Resume from iteration 50 + +cargo run -p spiel_bot --bin az_train --release -- \ + --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50 + +What the binary does each iteration: + +1. Calls model.valid() to get a zero-overhead inference copy for self-play +2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy) +3. Pushes samples into a ReplayBuffer (capacity --replay-cap) +4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min +5. Saves a .mpk checkpoint every --save-every iterations and always on the last From a6644e3c9dd7ff09227814e3a58322d0942c57bd Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:10:49 +0100 Subject: [PATCH 02/16] fix: to_tensor() normalization --- store/src/game.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/store/src/game.rs b/store/src/game.rs index f553bdb..2fde45c 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -225,22 +225,26 @@ impl GameState { let mut t = Vec::with_capacity(217); let pos: Vec = self.board.to_vec(); // 24 elements, positive=White, negative=Black - // [0..95] own (White) checkers, TD-Gammon encoding + // [0..95] own (White) checkers, TD-Gammon encoding. + // Each field contributes 4 values: + // (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1] + // The overflow term is divided by 12 because the maximum excess is + // 15 (all checkers) − 3 = 12. for &c in &pos { let own = c.max(0) as u8; t.push((own == 1) as u8 as f32); t.push((own == 2) as u8 as f32); t.push((own == 3) as u8 as f32); - t.push(own.saturating_sub(3) as f32); + t.push(own.saturating_sub(3) as f32 / 12.0); } - // [96..191] opp (Black) checkers, TD-Gammon encoding + // [96..191] opp (Black) checkers, same encoding. for &c in &pos { let opp = (-c).max(0) as u8; t.push((opp == 1) as u8 as f32); t.push((opp == 2) as u8 as f32); t.push((opp == 3) as u8 as f32); - t.push(opp.saturating_sub(3) as f32); + t.push(opp.saturating_sub(3) as f32 / 12.0); } // [192..193] dice From df05a430225ca63485d22dcd3c1794885df512dc Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:12:59 +0100 Subject: [PATCH 03/16] feat(spiel_bot): init crate & implements `GameEnv` trait + `TrictracEnv` --- Cargo.lock | 9 + Cargo.toml | 2 +- spiel_bot/Cargo.toml | 9 + spiel_bot/src/env/mod.rs | 121 ++++++++ spiel_bot/src/env/trictrac.rs | 535 ++++++++++++++++++++++++++++++++++ spiel_bot/src/lib.rs | 1 + 6 files changed, 676 insertions(+), 1 deletion(-) create mode 100644 spiel_bot/Cargo.toml create mode 100644 spiel_bot/src/env/mod.rs create mode 100644 spiel_bot/src/env/trictrac.rs create mode 100644 spiel_bot/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index a43261e..d1f5a20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5891,6 +5891,15 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spiel_bot" +version = "0.1.0" +dependencies = [ + "anyhow", + "rand 0.9.2", + "trictrac-store", +] + [[package]] name = "spin" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index b9e6d45..4c2eb15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store"] +members = ["client_cli", "bot", "store", "spiel_bot"] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml new file mode 100644 index 0000000..2459f51 --- /dev/null +++ b/spiel_bot/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs new file mode 100644 index 0000000..42b4ae0 --- /dev/null +++ b/spiel_bot/src/env/mod.rs @@ -0,0 +1,121 @@ +//! Game environment abstraction — the minimal "Rust OpenSpiel". +//! +//! A `GameEnv` describes the rules of a two-player, zero-sum game that may +//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, +//! and PPO interact with a game exclusively through this trait. +//! +//! # Node taxonomy +//! +//! Every game position belongs to one of four categories, returned by +//! [`GameEnv::current_player`]: +//! +//! | [`Player`] | Meaning | +//! |-----------|---------| +//! | `P1` | Player 1 (index 0) must choose an action | +//! | `P2` | Player 2 (index 1) must choose an action | +//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | +//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | +//! +//! # Perspective convention +//! +//! [`GameEnv::observation`] always returns the board from *the requested +//! player's* point of view. Callers pass `pov = 0` for Player 1 and +//! `pov = 1` for Player 2. The implementation is responsible for any +//! mirroring required (e.g. Trictrac always reasons from White's side). + +pub mod trictrac; +pub use trictrac::TrictracEnv; + +/// Who controls the current game node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Player { + /// Player 1 (index 0) is to move. + P1, + /// Player 2 (index 1) is to move. + P2, + /// A stochastic event (dice roll, etc.) must be resolved. + Chance, + /// The game is over. + Terminal, +} + +impl Player { + /// Returns the player index (0 or 1) if this is a decision node, + /// or `None` for `Chance` / `Terminal`. + pub fn index(self) -> Option { + match self { + Player::P1 => Some(0), + Player::P2 => Some(1), + _ => None, + } + } + + pub fn is_decision(self) -> bool { + matches!(self, Player::P1 | Player::P2) + } + + pub fn is_chance(self) -> bool { + self == Player::Chance + } + + pub fn is_terminal(self) -> bool { + self == Player::Terminal + } +} + +/// Trait that completely describes a two-player zero-sum game. +/// +/// Implementors must be cheaply cloneable (the type is used as a stateless +/// factory; the mutable game state lives in `Self::State`). +pub trait GameEnv: Clone + Send + Sync + 'static { + /// The mutable game state. Must be `Clone` so MCTS can copy + /// game trees without touching the environment. + type State: Clone + Send + Sync; + + // ── State creation ──────────────────────────────────────────────────── + + /// Create a fresh game state at the initial position. + fn new_game(&self) -> Self::State; + + // ── Node queries ────────────────────────────────────────────────────── + + /// Classify the current node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). + /// + /// The returned indices are in `[0, action_space())`. + /// The result is unspecified (may panic or return empty) when called at a + /// `Chance` or `Terminal` node. + fn legal_actions(&self, s: &Self::State) -> Vec; + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action. `action` must be a value returned by + /// [`legal_actions`] for the current state. + fn apply(&self, s: &mut Self::State, action: usize); + + /// Sample and apply a stochastic outcome. Must only be called when + /// `current_player(s) == Player::Chance`. + fn apply_chance(&self, s: &mut Self::State, rng: &mut R); + + // ── Observation ─────────────────────────────────────────────────────── + + /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). + /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. + fn observation(&self, s: &Self::State, pov: usize) -> Vec; + + /// Number of floats returned by [`observation`]. + fn obs_size(&self) -> usize; + + /// Total number of distinct action indices (the policy head output size). + fn action_space(&self) -> usize; + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Game outcome for each player, or `None` if the game is not over. + /// + /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. + /// Index 0 = Player 1, index 1 = Player 2. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs new file mode 100644 index 0000000..99ba058 --- /dev/null +++ b/spiel_bot/src/env/trictrac.rs @@ -0,0 +1,535 @@ +//! [`GameEnv`] implementation for Trictrac. +//! +//! # Game flow (schools_enabled = false) +//! +//! With scoring schools disabled (the standard training configuration), +//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine +//! applies them automatically inside `RollResult` and `Move`. The only +//! four stages that actually occur are: +//! +//! | `TurnStage` | [`Player`] kind | Handled by | +//! |-------------|-----------------|------------| +//! | `RollDice` | `Chance` | [`apply_chance`] | +//! | `RollWaiting` | `Chance` | [`apply_chance`] | +//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | +//! | `Move` | `P1`/`P2` | [`apply`] | +//! +//! # Perspective +//! +//! The Trictrac engine always reasons from White's perspective. Player 1 is +//! White; Player 2 is Black. When Player 2 is active, the board is mirrored +//! before computing legal actions / the observation tensor, and the resulting +//! event is mirrored back before being applied to the real state. This +//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. + +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, + Dice, GameEvent, GameState, Stage, TurnStage, +}; + +use super::{GameEnv, Player}; + +/// Stateless factory that produces Trictrac [`GameState`] environments. +/// +/// Schools (`schools_enabled`) are always disabled — scoring is automatic. +#[derive(Clone, Debug, Default)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + // ── State creation ──────────────────────────────────────────────────── + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + // ── Node queries ────────────────────────────────────────────────────── + + fn current_player(&self, s: &GameState) -> Player { + if s.stage == Stage::Ended { + return Player::Terminal; + } + match s.turn_stage { + TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, + _ => { + if s.active_player_id == 1 { + Player::P1 + } else { + Player::P2 + } + } + } + } + + /// Returns the legal action indices for the active player. + /// + /// The board is automatically mirrored for Player 2 so that the engine + /// always reasons from White's perspective. The returned indices are + /// identical in meaning for both players (checker ordinals are + /// perspective-relative). + /// + /// # Panics + /// + /// Panics in debug builds if called at a `Chance` or `Terminal` node. + fn legal_actions(&self, s: &GameState) -> Vec { + debug_assert!( + self.current_player(s).is_decision(), + "legal_actions called at a non-decision node (turn_stage={:?})", + s.turn_stage + ); + let indices = if s.active_player_id == 2 { + get_valid_action_indices(&s.mirror()) + } else { + get_valid_action_indices(s) + }; + indices.unwrap_or_default() + } + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action index to the game state. + /// + /// For Player 2, the action is decoded against the mirrored board and + /// the resulting event is un-mirrored before being applied. + /// + /// # Panics + /// + /// Panics in debug builds if `action` cannot be decoded or does not + /// produce a valid event for the current state. + fn apply(&self, s: &mut GameState, action: usize) { + let needs_mirror = s.active_player_id == 2; + + let event = if needs_mirror { + let view = s.mirror(); + TrictracAction::from_action_index(action) + .and_then(|a| a.to_event(&view)) + .map(|e| e.get_mirror(false)) + } else { + TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) + }; + + match event { + Some(e) => { + s.consume(&e).expect("apply: consume failed for valid action"); + } + None => { + panic!("apply: action index {action} produced no event in state {s}"); + } + } + } + + /// Sample dice and advance through a chance node. + /// + /// Handles both `RollDice` (triggers the roll mechanism, then samples + /// dice) and `RollWaiting` (only samples dice) in a single call so that + /// callers never need to distinguish the two. + /// + /// # Panics + /// + /// Panics in debug builds if called at a non-Chance node. + fn apply_chance(&self, s: &mut GameState, rng: &mut R) { + debug_assert!( + self.current_player(s).is_chance(), + "apply_chance called at a non-Chance node (turn_stage={:?})", + s.turn_stage + ); + + // Step 1: RollDice → RollWaiting (player initiates the roll). + if s.turn_stage == TurnStage::RollDice { + s.consume(&GameEvent::Roll { + player_id: s.active_player_id, + }) + .expect("apply_chance: Roll event failed"); + } + + // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. + // With schools_enabled=false, point marking is automatic inside consume(). + let dice = Dice { + values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), + }; + s.consume(&GameEvent::RollResult { + player_id: s.active_player_id, + dice, + }) + .expect("apply_chance: RollResult event failed"); + } + + // ── Observation ─────────────────────────────────────────────────────── + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { + s.to_tensor() + } else { + s.mirror().to_tensor() + } + } + + fn obs_size(&self) -> usize { + 217 + } + + fn action_space(&self) -> usize { + ACTION_SPACE_SIZE + } + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. + /// + /// The winner (higher cumulative score) receives `+1.0`; the loser + /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score + /// is `holes × 12 + points`. + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { + return None; + } + let score = |id: u64| -> i32 { + s.players + .get(&id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + }; + let s1 = score(1); + let s2 = score(2); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [0.0, 0.0], + }) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{rngs::SmallRng, Rng, SeedableRng}; + + fn env() -> TrictracEnv { + TrictracEnv + } + + fn seeded_rng(seed: u64) -> SmallRng { + SmallRng::seed_from_u64(seed) + } + + // ── Initial state ───────────────────────────────────────────────────── + + #[test] + fn new_game_is_chance_node() { + let e = env(); + let s = e.new_game(); + // A fresh game starts at RollDice — a Chance node. + assert_eq!(e.current_player(&s), Player::Chance); + assert!(e.returns(&s).is_none()); + } + + #[test] + fn new_game_is_not_terminal() { + let e = env(); + let s = e.new_game(); + assert_ne!(e.current_player(&s), Player::Terminal); + assert!(e.returns(&s).is_none()); + } + + // ── Chance nodes ────────────────────────────────────────────────────── + + #[test] + fn apply_chance_reaches_decision_node() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(1); + + // A single chance step must yield a decision node (or end the game, + // which only happens after 12 holes — impossible on the first roll). + e.apply_chance(&mut s, &mut rng); + let p = e.current_player(&s); + assert!( + p.is_decision(), + "expected decision node after first roll, got {p:?}" + ); + } + + #[test] + fn apply_chance_from_rollwaiting() { + // Check that apply_chance works when called mid-way (at RollWaiting). + let e = env(); + let mut s = e.new_game(); + assert_eq!(s.turn_stage, TurnStage::RollDice); + + // Manually advance to RollWaiting. + s.consume(&GameEvent::Roll { player_id: s.active_player_id }) + .unwrap(); + assert_eq!(s.turn_stage, TurnStage::RollWaiting); + + let mut rng = seeded_rng(2); + e.apply_chance(&mut s, &mut rng); + + let p = e.current_player(&s); + assert!(p.is_decision() || p.is_terminal()); + } + + // ── Legal actions ───────────────────────────────────────────────────── + + #[test] + fn legal_actions_nonempty_after_roll() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(3); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let actions = e.legal_actions(&s); + assert!( + !actions.is_empty(), + "legal_actions must be non-empty at a decision node" + ); + } + + #[test] + fn legal_actions_within_action_space() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(4); + + e.apply_chance(&mut s, &mut rng); + for &a in e.legal_actions(&s).iter() { + assert!( + a < e.action_space(), + "action {a} out of bounds (action_space={})", + e.action_space() + ); + } + } + + // ── Observations ────────────────────────────────────────────────────── + + #[test] + fn observation_has_correct_size() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(5); + e.apply_chance(&mut s, &mut rng); + + assert_eq!(e.observation(&s, 0).len(), e.obs_size()); + assert_eq!(e.observation(&s, 1).len(), e.obs_size()); + } + + #[test] + fn observation_values_in_unit_interval() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(6); + e.apply_chance(&mut s, &mut rng); + + for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { + for (i, &v) in obs.iter().enumerate() { + assert!( + v >= 0.0 && v <= 1.0, + "pov={pov}: obs[{i}] = {v} is outside [0,1]" + ); + } + } + } + + #[test] + fn p1_and_p2_observations_differ() { + // The board is mirrored for P2, so the two observations should differ + // whenever there are checkers in non-symmetric positions (always true + // in a real game after a few moves). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(7); + + // Advance far enough that the board is non-trivial. + for _ in 0..6 { + while e.current_player(&s).is_chance() { + e.apply_chance(&mut s, &mut rng); + } + if e.current_player(&s).is_terminal() { + break; + } + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + + if !e.current_player(&s).is_terminal() { + let obs0 = e.observation(&s, 0); + let obs1 = e.observation(&s, 1); + assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); + } + } + + // ── Applying actions ────────────────────────────────────────────────── + + #[test] + fn apply_changes_state() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(8); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let before = s.clone(); + let action = e.legal_actions(&s)[0]; + e.apply(&mut s, action); + + assert_ne!( + before.turn_stage, s.turn_stage, + "state must change after apply" + ); + } + + #[test] + fn apply_all_legal_actions_do_not_panic() { + // Verify that every action returned by legal_actions can be applied + // without panicking (on several independent copies of the same state). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(9); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + for action in e.legal_actions(&s) { + let mut copy = s.clone(); + e.apply(&mut copy, action); // must not panic + } + } + + // ── Full game ───────────────────────────────────────────────────────── + + /// Run a complete game with random actions through the `GameEnv` trait + /// and verify that: + /// - The game terminates. + /// - `returns()` is `Some` at the end. + /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's + /// score is ±1 / 0. + /// - No step panics. + #[test] + fn full_random_game_terminates() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(42); + let max_steps = 50_000; + + for step in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node"); + let idx = rng.random_range(0..actions.len()); + e.apply(&mut s, actions[idx]); + } + } + assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps"); + } + + let result = e.returns(&s); + assert!(result.is_some(), "returns() must be Some at Terminal"); + + let [r1, r2] = result.unwrap(); + let sum = r1 + r2; + assert!( + (sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5, + "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}" + ); + assert!( + r1.abs() <= 1.0 && r2.abs() <= 1.0, + "returns must be in [-1,1]: r1={r1}, r2={r2}" + ); + } + + /// Run multiple games with different seeds to stress-test for panics. + #[test] + fn multiple_games_no_panic() { + let e = env(); + let max_steps = 20_000; + + for seed in 0..10u64 { + let mut s = e.new_game(); + let mut rng = seeded_rng(seed); + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + let idx = rng.random_range(0..actions.len()); + e.apply(&mut s, actions[idx]); + } + } + } + } + } + + // ── Returns ─────────────────────────────────────────────────────────── + + #[test] + fn returns_none_mid_game() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(11); + + // Advance a few steps but do not finish the game. + for _ in 0..4 { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + if !e.current_player(&s).is_terminal() { + assert!( + e.returns(&s).is_none(), + "returns() must be None before the game ends" + ); + } + } + + // ── Player 2 actions ────────────────────────────────────────────────── + + /// Verify that Player 2 (Black) can take actions without panicking, + /// and that the state advances correctly. + #[test] + fn player2_can_act() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(12); + + // Keep stepping until Player 2 gets a turn. + let max_steps = 5_000; + let mut p2_acted = false; + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty()); + e.apply(&mut s, actions[0]); + p2_acted = true; + break; + } + Player::P1 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); + } +} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs new file mode 100644 index 0000000..3d7924f --- /dev/null +++ b/spiel_bot/src/lib.rs @@ -0,0 +1 @@ +pub mod env; From d5cd4c2402d294ad5a3ab0fa23d778193856ef60 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:30:27 +0100 Subject: [PATCH 04/16] feat(spiel_bot): network with mlp and resnet --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/network/mlp.rs | 223 ++++++++++++++++++++++++++++ spiel_bot/src/network/mod.rs | 64 ++++++++ spiel_bot/src/network/resnet.rs | 253 ++++++++++++++++++++++++++++++++ 6 files changed, 543 insertions(+) create mode 100644 spiel_bot/src/network/mlp.rs create mode 100644 spiel_bot/src/network/mod.rs create mode 100644 spiel_bot/src/network/resnet.rs diff --git a/Cargo.lock b/Cargo.lock index d1f5a20..2e81285 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5896,6 +5896,7 @@ name = "spiel_bot" version = "0.1.0" dependencies = [ "anyhow", + "burn", "rand 0.9.2", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 2459f51..fba2aab 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" trictrac-store = { path = "../store" } anyhow = "1" rand = "0.9" +burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 3d7924f..6e71016 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1 +1,2 @@ pub mod env; +pub mod network; diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs new file mode 100644 index 0000000..eb6184e --- /dev/null +++ b/spiel_bot/src/network/mlp.rs @@ -0,0 +1,223 @@ +//! Two-hidden-layer MLP policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`MlpNet`]. +#[derive(Debug, Clone)] +pub struct MlpConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for MlpConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 256, + } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Simple two-hidden-layer MLP with shared trunk and two heads. +/// +/// Prefer this over [`ResNet`](super::ResNet) when training time is a +/// priority, or as a fast baseline. +#[derive(Module, Debug)] +pub struct MlpNet { + fc1: Linear, + fc2: Linear, + policy_head: Linear, + value_head: Linear, +} + +impl MlpNet { + /// Construct a fresh network with random weights. + pub fn new(config: &MlpConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + value_head: LinearConfig::new(config.hidden_size, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + /// + /// The file is written exactly at `path`; callers should append `.mpk` if + /// they want the conventional extension. + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for MlpNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn default_net() -> MlpNet { + MlpNet::new(&MlpConfig::default(), &device()) + } + + fn zeros_obs(batch: usize) -> Tensor { + Tensor::zeros([batch, 217], &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let net = default_net(); + let obs = zeros_obs(4); + let (policy, value) = net.forward(obs); + + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let (policy, value) = net.forward(zeros_obs(1)); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let net = default_net(); + // Use a non-zero input so the output is not trivially at 0. + let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net.forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Policy logits ───────────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + // With random weights the 514 logits should not all be identical. + let net = default_net(); + let (policy, _) = net.forward(zeros_obs(1)); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical — network may be degenerate"); + } + + // ── Config propagation ──────────────────────────────────────────────── + + #[test] + fn custom_config_shapes() { + let config = MlpConfig { + obs_size: 10, + action_size: 20, + hidden_size: 32, + }; + let net = MlpNet::::new(&config, &device()); + let obs = Tensor::zeros([3, 10], &device()); + let (policy, value) = net.forward(obs); + assert_eq!(policy.dims(), [3, 20]); + assert_eq!(value.dims(), [3, 1]); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = MlpConfig::default(); + let net = default_net(); + + // Forward pass before saving. + let obs = Tensor::::ones([2, 217], &device()); + let (policy_before, value_before) = net.forward(obs.clone()); + + // Save to a temp file. + let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); + net.save(&path).expect("save failed"); + + // Load into a fresh model. + let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = loaded.forward(obs); + + // Outputs must be bitwise identical. + let p_before: Vec = policy_before.into_data().to_vec().unwrap(); + let p_after: Vec = policy_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let v_before: Vec = value_before.into_data().to_vec().unwrap(); + let v_after: Vec = value_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let _ = std::fs::remove_file(path); + } +} diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs new file mode 100644 index 0000000..df710e9 --- /dev/null +++ b/spiel_bot/src/network/mod.rs @@ -0,0 +1,64 @@ +//! Neural network abstractions for policy-value learning. +//! +//! # Trait +//! +//! [`PolicyValueNet`] is the single trait that all network architectures +//! implement. It takes an observation tensor and returns raw policy logits +//! plus a tanh-squashed scalar value estimate. +//! +//! # Architectures +//! +//! | Module | Description | Default hidden | +//! |--------|-------------|----------------| +//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | +//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | +//! +//! # Backend convention +//! +//! * **Inference / self-play** — use `NdArray` (no autodiff overhead). +//! * **Training** — use `Autodiff>` so Burn can differentiate +//! through the forward pass. +//! +//! Both modes use the exact same struct; only the type-level backend changes: +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! type InferBackend = NdArray; +//! type TrainBackend = Autodiff>; +//! +//! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); +//! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); +//! ``` +//! +//! # Output shapes +//! +//! Given a batch of `B` observations of size `obs_size`: +//! +//! | Output | Shape | Range | +//! |--------|-------|-------| +//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | +//! | `value` | `[B, 1]` | (-1, 1) via tanh | +//! +//! Callers are responsible for masking illegal actions in `policy_logits` +//! before passing to softmax. + +pub mod mlp; +pub mod resnet; + +pub use mlp::{MlpConfig, MlpNet}; +pub use resnet::{ResNet, ResNetConfig}; + +use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; + +/// A neural network that produces a policy and a value from an observation. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - policy output: `[batch, action_size]` — raw logits (no softmax applied) +/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) +/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses +/// `OnceCell` for lazy parameter initialisation, which is not `Sync`. +/// Use an `Arc>` wrapper if cross-thread sharing is needed. +pub trait PolicyValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor); +} diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs new file mode 100644 index 0000000..d20d5ad --- /dev/null +++ b/spiel_bot/src/network/resnet.rs @@ -0,0 +1,253 @@ +//! Residual-block policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU (input projection) +//! → ResBlock × 4 (residual trunk) +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! +//! ResBlock: +//! x → Linear → ReLU → Linear → (+x) → ReLU +//! ``` +//! +//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better +//! suited for long training runs where board-pattern recognition matters. + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`ResNet`]. +#[derive(Debug, Clone)] +pub struct ResNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of all hidden layers (input projection + residual blocks). + pub hidden_size: usize, +} + +impl Default for ResNetConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 512, + } + } +} + +// ── Residual block ──────────────────────────────────────────────────────────── + +/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`. +/// +/// Both linear layers preserve the hidden dimension so the skip connection +/// can be added without projection. +#[derive(Module, Debug)] +struct ResBlock { + fc1: Linear, + fc2: Linear, +} + +impl ResBlock { + fn new(hidden: usize, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(hidden, hidden).init(device), + fc2: LinearConfig::new(hidden, hidden).init(device), + } + } + + fn forward(&self, x: Tensor) -> Tensor { + let residual = x.clone(); + let out = relu(self.fc1.forward(x)); + relu(self.fc2.forward(out) + residual) + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Four-residual-block policy-value network. +/// +/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and +/// when representing complex positional patterns is important. +#[derive(Module, Debug)] +pub struct ResNet { + input: Linear, + block0: ResBlock, + block1: ResBlock, + block2: ResBlock, + block3: ResBlock, + policy_head: Linear, + value_head: Linear, +} + +impl ResNet { + /// Construct a fresh network with random weights. + pub fn new(config: &ResNetConfig, device: &B::Device) -> Self { + let h = config.hidden_size; + Self { + input: LinearConfig::new(config.obs_size, h).init(device), + block0: ResBlock::new(h, device), + block1: ResBlock::new(h, device), + block2: ResBlock::new(h, device), + block3: ResBlock::new(h, device), + policy_head: LinearConfig::new(h, config.action_size).init(device), + value_head: LinearConfig::new(h, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for ResNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.input.forward(obs)); + let x = self.block0.forward(x); + let x = self.block1.forward(x); + let x = self.block2.forward(x); + let x = self.block3.forward(x); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn small_config() -> ResNetConfig { + // Use a small hidden size so tests are fast. + ResNetConfig { + obs_size: 217, + action_size: 514, + hidden_size: 64, + } + } + + fn net() -> ResNet { + ResNet::new(&small_config(), &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let obs = Tensor::zeros([4, 217], &device()); + let (policy, value) = net().forward(obs); + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let (policy, value) = net().forward(Tensor::zeros([1, 217], &device())); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net().forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Residual connections ────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + let (policy, _) = net().forward(Tensor::zeros([1, 217], &device())); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical"); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = small_config(); + let model = net(); + let obs = Tensor::::ones([2, 217], &device()); + + let (policy_before, value_before) = model.forward(obs.clone()); + + let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk"); + model.save(&path).expect("save failed"); + + let loaded = ResNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = loaded.forward(obs); + + let p_before: Vec = policy_before.into_data().to_vec().unwrap(); + let p_after: Vec = policy_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let v_before: Vec = value_before.into_data().to_vec().unwrap(); + let v_after: Vec = value_after.into_data().to_vec().unwrap(); + for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); + } + + let _ = std::fs::remove_file(path); + } + + // ── Integration: both architectures satisfy PolicyValueNet ──────────── + + #[test] + fn resnet_satisfies_trait() { + fn requires_net>(net: &N, obs: Tensor) { + let (p, v) = net.forward(obs); + assert_eq!(p.dims()[1], 514); + assert_eq!(v.dims()[1], 1); + } + requires_net(&net(), Tensor::zeros([2, 217], &device())); + } +} From 58ae8ad3b3533edb77e5a7a75d8c3400a3c381b0 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 20:45:02 +0100 Subject: [PATCH 05/16] feat(spiel_bot): Monte-Carlo tree search --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/mcts/mod.rs | 408 +++++++++++++++++++++++++++++++++++ spiel_bot/src/mcts/node.rs | 91 ++++++++ spiel_bot/src/mcts/search.rs | 170 +++++++++++++++ 6 files changed, 672 insertions(+) create mode 100644 spiel_bot/src/mcts/mod.rs create mode 100644 spiel_bot/src/mcts/node.rs create mode 100644 spiel_bot/src/mcts/search.rs diff --git a/Cargo.lock b/Cargo.lock index 2e81285..0baa02a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5898,6 +5898,7 @@ dependencies = [ "anyhow", "burn", "rand 0.9.2", + "rand_distr", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index fba2aab..323c953 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -7,4 +7,5 @@ edition = "2021" trictrac-store = { path = "../store" } anyhow = "1" rand = "0.9" +rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 6e71016..5beb37c 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,2 +1,3 @@ pub mod env; +pub mod mcts; pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs new file mode 100644 index 0000000..e92bd09 --- /dev/null +++ b/spiel_bot/src/mcts/mod.rs @@ -0,0 +1,408 @@ +//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. +//! +//! # Algorithm +//! +//! The implementation follows AlphaZero's MCTS: +//! +//! 1. **Expand root** — run the network once to get priors and a value +//! estimate; optionally add Dirichlet noise for training-time exploration. +//! 2. **Simulate** `n_simulations` times: +//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. +//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; +//! chance nodes are **not** stored in the tree (outcome sampling). +//! - *Expansion* — evaluate the network at the leaf; populate children. +//! - *Backup* — propagate the value upward; negate at each player boundary. +//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). +//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). +//! +//! # Perspective convention +//! +//! Every [`MctsNode::w`] is stored **from the perspective of the player who +//! acts at that node**. The backup negates the child value whenever the +//! acting player differs between parent and child. +//! +//! # Stochastic games +//! +//! When [`GameEnv::current_player`] returns [`Player::Chance`], the +//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and +//! continues. Chance nodes are skipped transparently; Q-values converge to +//! their expectation over many simulations (outcome sampling). + +pub mod node; +mod search; + +pub use node::MctsNode; + +use rand::Rng; + +use crate::env::GameEnv; + +// ── Evaluator trait ──────────────────────────────────────────────────────── + +/// Evaluates a game position for use in MCTS. +/// +/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) +/// but the `mcts` module itself does **not** depend on Burn. +pub trait Evaluator: Send + Sync { + /// Evaluate `obs` (flat observation vector of length `obs_size`). + /// + /// Returns: + /// - `policy_logits`: one raw logit per action (`action_space` entries). + /// Illegal action entries are masked inside the search — no need to + /// zero them here. + /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. + fn evaluate(&self, obs: &[f32]) -> (Vec, f32); +} + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Hyperparameters for [`run_mcts`]. +#[derive(Debug, Clone)] +pub struct MctsConfig { + /// Number of MCTS simulations per move. Typical: 50–800. + pub n_simulations: usize, + /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0. + pub c_puct: f32, + /// Dirichlet noise concentration α. Set to `0.0` to disable. + /// Typical: `0.3` for Chess, `0.1` for large action spaces. + pub dirichlet_alpha: f32, + /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`. + pub dirichlet_eps: f32, + /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax. + pub temperature: f32, +} + +impl Default for MctsConfig { + fn default() -> Self { + Self { + n_simulations: 200, + c_puct: 1.5, + dirichlet_alpha: 0.3, + dirichlet_eps: 0.25, + temperature: 1.0, + } + } +} + +// ── Public interface ─────────────────────────────────────────────────────── + +/// Run MCTS from `state` and return the populated root [`MctsNode`]. +/// +/// `state` must be a player-decision node (`P1` or `P2`). +/// Use [`mcts_policy`] and [`select_action`] on the returned root. +/// +/// # Panics +/// +/// Panics if `env.current_player(state)` is not `P1` or `P2`. +pub fn run_mcts( + env: &E, + state: &E::State, + evaluator: &dyn Evaluator, + config: &MctsConfig, + rng: &mut impl Rng, +) -> MctsNode { + let player_idx = env + .current_player(state) + .index() + .expect("run_mcts called at a non-decision node"); + + // ── Expand root (network called once here, not inside the loop) ──────── + let mut root = MctsNode::new(1.0); + search::expand::(&mut root, state, env, evaluator, player_idx); + + // ── Optional Dirichlet noise for training exploration ────────────────── + if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 { + search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng); + } + + // ── Simulations ──────────────────────────────────────────────────────── + for _ in 0..config.n_simulations { + search::simulate::( + &mut root, + state.clone(), + env, + evaluator, + config, + rng, + player_idx, + ); + } + + root +} + +/// Compute the MCTS policy: normalized visit counts at the root. +/// +/// Returns a vector of length `action_space` where `policy[a]` is the +/// fraction of simulations that visited action `a`. +pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec { + let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum(); + let mut policy = vec![0.0f32; action_space]; + if total > 0.0 { + for (a, child) in &root.children { + policy[*a] = child.n as f32 / total; + } + } else if !root.children.is_empty() { + // n_simulations = 0: uniform over legal actions. + let uniform = 1.0 / root.children.len() as f32; + for (a, _) in &root.children { + policy[*a] = uniform; + } + } + policy +} + +/// Select an action index from the root after MCTS. +/// +/// * `temperature = 0` — greedy argmax of visit counts. +/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`. +/// +/// # Panics +/// +/// Panics if the root has no children. +pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { + assert!(!root.children.is_empty(), "select_action called on a root with no children"); + if temperature <= 0.0 { + root.children + .iter() + .max_by_key(|(_, c)| c.n) + .map(|(a, _)| *a) + .unwrap() + } else { + let weights: Vec = root + .children + .iter() + .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) + .collect(); + let total: f32 = weights.iter().sum(); + let mut r: f32 = rng.random::() * total; + for (i, (a, _)) in root.children.iter().enumerate() { + r -= weights[i]; + if r <= 0.0 { + return *a; + } + } + root.children.last().map(|(a, _)| *a).unwrap() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + use crate::env::Player; + + // ── Minimal deterministic test game ─────────────────────────────────── + // + // "Countdown" — two players alternate subtracting 1 or 2 from a counter. + // The player who brings the counter to 0 wins. + // No chance nodes, two legal actions (0 = -1, 1 = -2). + + #[derive(Clone, Debug)] + struct CState { + remaining: u8, + to_move: usize, // at terminal: last mover (winner) + } + + #[derive(Clone)] + struct CountdownEnv; + + impl crate::env::GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: 6, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { + Player::Terminal + } else if s.to_move == 0 { + Player::P1 + } else { + Player::P2 + } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + // to_move stays as winner + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 6.0, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + // Uniform evaluator: all logits = 0, value = 0. + // `action_space` must match the environment's `action_space()`. + struct ZeroEval(usize); + impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(42) + } + + fn config_n(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, // off for reproducibility + dirichlet_eps: 0.0, + temperature: 1.0, + } + } + + // ── Visit count tests ───────────────────────────────────────────────── + + #[test] + fn visit_counts_sum_to_n_simulations() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 50, "visit counts must sum to n_simulations"); + } + + #[test] + fn all_root_children_are_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); + for (a, _) in &root.children { + assert!(legal.contains(a), "child action {a} is not legal"); + } + } + + // ── Policy tests ───────────────────────────────────────────────────── + + #[test] + fn policy_sums_to_one() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); + } + + #[test] + fn policy_zero_for_illegal_actions() { + let env = CountdownEnv; + // remaining = 1 → only action 0 is legal + let state = CState { remaining: 1, to_move: 0 }; + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); + } + + // ── Action selection tests ──────────────────────────────────────────── + + #[test] + fn greedy_selects_most_visited() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); + let greedy = select_action(&root, 0.0, &mut rng()); + let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); + assert_eq!(greedy, most_visited); + } + + #[test] + fn temperature_sampling_stays_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let mut r = rng(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); + for _ in 0..20 { + let a = select_action(&root, 1.0, &mut r); + assert!(legal.contains(&a), "sampled action {a} is not legal"); + } + } + + // ── Zero-simulation edge case ───────────────────────────────────────── + + #[test] + fn zero_simulations_uniform_policy() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + // With 0 simulations, fallback is uniform over the 2 legal actions. + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5); + } + + // ── Root value ──────────────────────────────────────────────────────── + + #[test] + fn root_q_in_valid_range() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); + let q = root.q(); + assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); + } + + // ── Integration: run on a real Trictrac game ────────────────────────── + + #[test] + fn no_panic_on_trictrac_state() { + use crate::env::TrictracEnv; + + let env = TrictracEnv; + let mut state = env.new_game(); + let mut r = rng(); + + // Advance past the initial chance node to reach a decision node. + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, &mut r); + } + + if env.current_player(&state).is_terminal() { + return; // unlikely but safe + } + + let config = MctsConfig { + n_simulations: 5, // tiny for speed + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + ..MctsConfig::default() + }; + + let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); + assert!(root.n > 0); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 5); + } +} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs new file mode 100644 index 0000000..aff7735 --- /dev/null +++ b/spiel_bot/src/mcts/node.rs @@ -0,0 +1,91 @@ +//! MCTS tree node. +//! +//! [`MctsNode`] holds the visit statistics for one player-decision position in +//! the search tree. A node is *expanded* the first time the policy-value +//! network is evaluated there; before that it is a leaf. + +/// One node in the MCTS tree, representing a player-decision position. +/// +/// `w` stores the sum of values backed up into this node, always from the +/// perspective of **the player who acts here**. `q()` therefore also returns +/// a value in `(-1, 1)` from that same perspective. +#[derive(Debug)] +pub struct MctsNode { + /// Visit count `N(s, a)`. + pub n: u32, + /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. + pub w: f32, + /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). + pub p: f32, + /// Children: `(action_index, child_node)`, populated on first expansion. + pub children: Vec<(usize, MctsNode)>, + /// `true` after the network has been evaluated and children have been set up. + pub expanded: bool, +} + +impl MctsNode { + /// Create a fresh, unexpanded leaf with the given prior probability. + pub fn new(prior: f32) -> Self { + Self { + n: 0, + w: 0.0, + p: prior, + children: Vec::new(), + expanded: false, + } + } + + /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. + #[inline] + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT selection score: + /// + /// ```text + /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) + /// ``` + #[inline] + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn q_zero_when_unvisited() { + let node = MctsNode::new(0.5); + assert_eq!(node.q(), 0.0); + } + + #[test] + fn q_reflects_w_over_n() { + let mut node = MctsNode::new(0.5); + node.n = 4; + node.w = 2.0; + assert!((node.q() - 0.5).abs() < 1e-6); + } + + #[test] + fn puct_exploration_dominates_unvisited() { + // Unvisited child should outscore a visited child with negative Q. + let mut visited = MctsNode::new(0.5); + visited.n = 10; + visited.w = -5.0; // Q = -0.5 + + let unvisited = MctsNode::new(0.5); + + let parent_n = 10; + let c = 1.5; + assert!( + unvisited.puct(parent_n, c) > visited.puct(parent_n, c), + "unvisited child should have higher PUCT than a negatively-valued visited child" + ); + } +} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs new file mode 100644 index 0000000..c4960c7 --- /dev/null +++ b/spiel_bot/src/mcts/search.rs @@ -0,0 +1,170 @@ +//! Simulation, expansion, backup, and noise helpers. +//! +//! These are internal to the `mcts` module; the public entry points are +//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. + +use rand::Rng; +use rand_distr::{Gamma, Distribution}; + +use crate::env::GameEnv; +use super::{Evaluator, MctsConfig}; +use super::node::MctsNode; + +// ── Masked softmax ───────────────────────────────────────────────────────── + +/// Numerically stable softmax over `legal` actions only. +/// +/// Illegal logits are treated as `-∞` and receive probability `0.0`. +/// Returns a probability vector of length `action_space`. +pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec { + let mut probs = vec![0.0f32; action_space]; + if legal.is_empty() { + return probs; + } + let max_logit = legal + .iter() + .map(|&a| logits[a]) + .fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f32; + for &a in legal { + let e = (logits[a] - max_logit).exp(); + probs[a] = e; + sum += e; + } + if sum > 0.0 { + for &a in legal { + probs[a] /= sum; + } + } else { + let uniform = 1.0 / legal.len() as f32; + for &a in legal { + probs[a] = uniform; + } + } + probs +} + +// ── Dirichlet noise ──────────────────────────────────────────────────────── + +/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration. +/// +/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. +/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum. +pub(super) fn add_dirichlet_noise( + node: &mut MctsNode, + alpha: f32, + eps: f32, + rng: &mut impl Rng, +) { + let n = node.children.len(); + if n == 0 { + return; + } + let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else { + return; + }; + let samples: Vec = (0..n).map(|_| gamma.sample(rng) as f32).collect(); + let sum: f32 = samples.iter().sum(); + if sum <= 0.0 { + return; + } + for (i, (_, child)) in node.children.iter_mut().enumerate() { + let noise = samples[i] / sum; + child.p = (1.0 - eps) * child.p + eps * noise; + } +} + +// ── Expansion ────────────────────────────────────────────────────────────── + +/// Evaluate the network at `state` and populate `node` with children. +/// +/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`. +/// Returns the network value estimate from `player_idx`'s perspective. +pub(super) fn expand( + node: &mut MctsNode, + state: &E::State, + env: &E, + evaluator: &dyn Evaluator, + player_idx: usize, +) -> f32 { + let obs = env.observation(state, player_idx); + let legal = env.legal_actions(state); + let (logits, value) = evaluator.evaluate(&obs); + let priors = masked_softmax(&logits, &legal, env.action_space()); + node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect(); + node.expanded = true; + node.n = 1; + node.w = value; + value +} + +// ── Simulation ───────────────────────────────────────────────────────────── + +/// One MCTS simulation from an **already-expanded** decision node. +/// +/// Traverses the tree with PUCT selection, expands the first unvisited leaf, +/// and backs up the result. +/// +/// * `player_idx` — the player (0 or 1) who acts at `state`. +/// * Returns the backed-up value **from `player_idx`'s perspective**. +pub(super) fn simulate( + node: &mut MctsNode, + state: E::State, + env: &E, + evaluator: &dyn Evaluator, + config: &MctsConfig, + rng: &mut impl Rng, + player_idx: usize, +) -> f32 { + debug_assert!(node.expanded, "simulate called on unexpanded node"); + + // ── Selection: child with highest PUCT ──────────────────────────────── + let parent_n = node.n; + let best = node + .children + .iter() + .enumerate() + .max_by(|(_, (_, a)), (_, (_, b))| { + a.puct(parent_n, config.c_puct) + .partial_cmp(&b.puct(parent_n, config.c_puct)) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i) + .expect("expanded node must have at least one child"); + + let (action, child) = &mut node.children[best]; + let action = *action; + + // ── Apply action + advance through any chance nodes ─────────────────── + let mut next_state = state; + env.apply(&mut next_state, action); + while env.current_player(&next_state).is_chance() { + env.apply_chance(&mut next_state, rng); + } + + let next_cp = env.current_player(&next_state); + + // ── Evaluate leaf or terminal ────────────────────────────────────────── + // All values are converted to `player_idx`'s perspective before backup. + let child_value = if next_cp.is_terminal() { + let returns = env + .returns(&next_state) + .expect("terminal node must have returns"); + returns[player_idx] + } else { + let child_player = next_cp.index().unwrap(); + let v = if child.expanded { + simulate(child, next_state, env, evaluator, config, rng, child_player) + } else { + expand::(child, &next_state, env, evaluator, child_player) + }; + // Negate when the child belongs to the opponent. + if child_player == player_idx { v } else { -v } + }; + + // ── Backup ──────────────────────────────────────────────────────────── + node.n += 1; + node.w += child_value; + + child_value +} From b0ae4db2d9661405360e60f2115d0429e39c0af1 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 21:00:27 +0100 Subject: [PATCH 06/16] feat(spiel_bot): AlphaZero --- spiel_bot/src/alphazero/mod.rs | 117 ++++++++++++++ spiel_bot/src/alphazero/replay.rs | 144 +++++++++++++++++ spiel_bot/src/alphazero/selfplay.rs | 234 ++++++++++++++++++++++++++++ spiel_bot/src/alphazero/trainer.rs | 172 ++++++++++++++++++++ spiel_bot/src/lib.rs | 1 + 5 files changed, 668 insertions(+) create mode 100644 spiel_bot/src/alphazero/mod.rs create mode 100644 spiel_bot/src/alphazero/replay.rs create mode 100644 spiel_bot/src/alphazero/selfplay.rs create mode 100644 spiel_bot/src/alphazero/trainer.rs diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs new file mode 100644 index 0000000..bb86724 --- /dev/null +++ b/spiel_bot/src/alphazero/mod.rs @@ -0,0 +1,117 @@ +//! AlphaZero: self-play data generation, replay buffer, and training step. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | +//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | +//! | [`trainer`] | [`train_step`] | +//! +//! # Typical outer loop +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! use burn::optim::AdamConfig; +//! use spiel_bot::{ +//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, +//! env::TrictracEnv, +//! mcts::MctsConfig, +//! network::{MlpConfig, MlpNet}, +//! }; +//! +//! type Infer = NdArray; +//! type Train = Autodiff>; +//! +//! let device = Default::default(); +//! let env = TrictracEnv; +//! let config = AlphaZeroConfig::default(); +//! +//! // Build training model and optimizer. +//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); +//! let mut optimizer = AdamConfig::new().init(); +//! let mut replay = ReplayBuffer::new(config.replay_capacity); +//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); +//! +//! for _iter in 0..config.n_iterations { +//! // Convert to inference backend for self-play. +//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) +//! .load_record(train_model.clone().into_record()); +//! let eval = BurnEvaluator::new(infer_model, device.clone()); +//! +//! // Self-play: generate episodes. +//! for _ in 0..config.n_games_per_iter { +//! let samples = generate_episode(&env, &eval, &config.mcts, +//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); +//! replay.extend(samples); +//! } +//! +//! // Training: gradient steps. +//! if replay.len() >= config.batch_size { +//! for _ in 0..config.n_train_steps_per_iter { +//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) +//! .into_iter().cloned().collect(); +//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, +//! config.learning_rate); +//! train_model = m; +//! } +//! } +//! } +//! ``` + +pub mod replay; +pub mod selfplay; +pub mod trainer; + +pub use replay::{ReplayBuffer, TrainSample}; +pub use selfplay::{BurnEvaluator, generate_episode}; +pub use trainer::train_step; + +use crate::mcts::MctsConfig; + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Top-level AlphaZero hyperparameters. +/// +/// The MCTS parameters live in [`MctsConfig`]; this struct holds the +/// outer training-loop parameters. +#[derive(Debug, Clone)] +pub struct AlphaZeroConfig { + /// MCTS parameters for self-play. + pub mcts: MctsConfig, + /// Number of self-play games per training iteration. + pub n_games_per_iter: usize, + /// Number of gradient steps per training iteration. + pub n_train_steps_per_iter: usize, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Number of outer iterations (self-play + train) to run. + pub n_iterations: usize, + /// Move index after which the action temperature drops to 0 (greedy play). + pub temperature_drop_move: usize, +} + +impl Default for AlphaZeroConfig { + fn default() -> Self { + Self { + mcts: MctsConfig { + n_simulations: 100, + c_puct: 1.5, + dirichlet_alpha: 0.1, + dirichlet_eps: 0.25, + temperature: 1.0, + }, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + batch_size: 64, + replay_capacity: 50_000, + learning_rate: 1e-3, + n_iterations: 100, + temperature_drop_move: 30, + } + } +} diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs new file mode 100644 index 0000000..5e64cc4 --- /dev/null +++ b/spiel_bot/src/alphazero/replay.rs @@ -0,0 +1,144 @@ +//! Replay buffer for AlphaZero self-play data. + +use std::collections::VecDeque; +use rand::Rng; + +// ── Training sample ──────────────────────────────────────────────────────── + +/// One training example produced by self-play. +#[derive(Clone, Debug)] +pub struct TrainSample { + /// Observation tensor from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1). + pub policy: Vec, + /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw. + pub value: f32, +} + +// ── Replay buffer ────────────────────────────────────────────────────────── + +/// Fixed-capacity circular buffer of [`TrainSample`]s. +/// +/// When the buffer is full, the oldest sample is evicted on push. +/// Samples are drawn without replacement using a Fisher-Yates partial shuffle. +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + /// Create a buffer with the given maximum capacity. + pub fn new(capacity: usize) -> Self { + Self { + data: VecDeque::with_capacity(capacity.min(1024)), + capacity, + } + } + + /// Add a sample; evicts the oldest if at capacity. + pub fn push(&mut self, sample: TrainSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + /// Add all samples from an episode. + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { + self.push(s); + } + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Sample up to `n` distinct samples, without replacement. + /// + /// If the buffer has fewer than `n` samples, all are returned (shuffled). + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + let len = self.data.len(); + let n = n.min(len); + // Partial Fisher-Yates using index shuffling. + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(value: f32) -> TrainSample { + TrainSample { obs: vec![value], policy: vec![1.0], value } + } + + #[test] + fn push_and_len() { + let mut buf = ReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = ReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); // evicts 1.0 + assert_eq!(buf.len(), 3); + // Oldest remaining should be 2.0 + assert_eq!(buf.data[0].value, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(5, &mut rng); + assert_eq!(batch.len(), 5); + } + + #[test] + fn sample_batch_capped_at_len() { + let mut buf = ReplayBuffer::new(20); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(100, &mut rng); + assert_eq!(batch.len(), 2); + } + + #[test] + fn sample_batch_no_duplicates() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(1); + let batch = buf.sample_batch(10, &mut rng); + let mut seen: Vec = batch.iter().map(|s| s.value).collect(); + seen.sort_by(f32::total_cmp); + seen.dedup(); + assert_eq!(seen.len(), 10, "sample contained duplicates"); + } +} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs new file mode 100644 index 0000000..6f10f8d --- /dev/null +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -0,0 +1,234 @@ +//! Self-play episode generation and Burn-backed evaluator. + +use std::marker::PhantomData; + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::GameEnv; +use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; +use crate::network::PolicyValueNet; + +use super::replay::TrainSample; + +// ── BurnEvaluator ────────────────────────────────────────────────────────── + +/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. +/// +/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so +/// that self-play generates no gradient tape overhead. +pub struct BurnEvaluator> { + model: N, + device: B::Device, + _b: PhantomData, +} + +impl> BurnEvaluator { + pub fn new(model: N, device: B::Device) -> Self { + Self { model, device, _b: PhantomData } + } + + pub fn into_model(self) -> N { + self.model + } +} + +// Safety: NdArray modules are Send; we never share across threads without +// external synchronisation. +unsafe impl> Send for BurnEvaluator {} +unsafe impl> Sync for BurnEvaluator {} + +impl> Evaluator for BurnEvaluator { + fn evaluate(&self, obs: &[f32]) -> (Vec, f32) { + let obs_size = obs.len(); + let data = TensorData::new(obs.to_vec(), [1, obs_size]); + let obs_tensor = Tensor::::from_data(data, &self.device); + + let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); + + let policy: Vec = policy_tensor.into_data().to_vec().unwrap(); + let value: Vec = value_tensor.into_data().to_vec().unwrap(); + + (policy, value[0]) + } +} + +// ── Episode generation ───────────────────────────────────────────────────── + +/// One pending observation waiting for its game-outcome value label. +struct PendingSample { + obs: Vec, + policy: Vec, + player: usize, +} + +/// Play one full game using MCTS guided by `evaluator`. +/// +/// Returns a [`TrainSample`] for every decision step in the game. +/// +/// `temperature_fn(step)` controls exploration: return `1.0` for early +/// moves and `0.0` after a fixed number of moves (e.g. move 30). +pub fn generate_episode( + env: &E, + evaluator: &dyn Evaluator, + mcts_config: &MctsConfig, + temperature_fn: &dyn Fn(usize) -> f32, + rng: &mut impl Rng, +) -> Vec { + let mut state = env.new_game(); + let mut pending: Vec = Vec::new(); + let mut step = 0usize; + + loop { + // Advance through chance nodes automatically. + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + if env.current_player(&state).is_terminal() { + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + + // Run MCTS to get a policy. + let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng); + let policy = mcts::mcts_policy(&root, env.action_space()); + + // Record the observation from the acting player's perspective. + let obs = env.observation(&state, player_idx); + pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); + + // Select and apply the action. + let temperature = temperature_fn(step); + let action = mcts::select_action(&root, temperature, rng); + env.apply(&mut state, action); + step += 1; + } + + // Assign game outcomes. + let returns = env.returns(&state).unwrap_or([0.0; 2]); + pending + .into_iter() + .map(|s| TrainSample { + obs: s.obs, + policy: s.policy, + value: returns[s.player], + }) + .collect() +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::env::Player; + use crate::mcts::{Evaluator, MctsConfig}; + use crate::network::{MlpConfig, MlpNet}; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(7) + } + + // Countdown game (same as in mcts tests). + #[derive(Clone, Debug)] + struct CState { remaining: u8, to_move: usize } + + #[derive(Clone)] + struct CountdownEnv; + + impl GameEnv for CountdownEnv { + type State = CState; + fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } else { Player::P2 } + } + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { s.remaining = 0; } + else { s.remaining -= sub; s.to_move = 1 - s.to_move; } + } + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 4.0, s.to_move as f32] + } + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + fn tiny_config() -> MctsConfig { + MctsConfig { n_simulations: 5, c_puct: 1.5, + dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } + } + + // ── BurnEvaluator tests ─────────────────────────────────────────────── + + #[test] + fn burn_evaluator_output_shapes() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); + assert_eq!(policy.len(), 2, "policy length should equal action_space"); + assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); + } + + // ── generate_episode tests ──────────────────────────────────────────── + + #[test] + fn episode_terminates_and_has_samples() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn episode_sample_values_are_valid() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + for s in &samples { + assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "unexpected value {}", s.value); + let sum: f32 = s.policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); + assert_eq!(s.obs.len(), 2); + } + } + + #[test] + fn episode_with_temperature_zero() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + // temperature=0 means greedy; episode must still terminate + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); + assert!(!samples.is_empty()); + } +} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs new file mode 100644 index 0000000..d2482d1 --- /dev/null +++ b/spiel_bot/src/alphazero/trainer.rs @@ -0,0 +1,172 @@ +//! One gradient-descent training step for AlphaZero. +//! +//! The loss combines: +//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. +//! - **Value loss** — mean-squared error between the predicted value and the +//! actual game outcome. +//! +//! # Backend +//! +//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). +//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. +//! Weights are transferred between the two via [`burn::record`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + activation::log_softmax, + backend::AutodiffBackend, + Tensor, TensorData, + }, +}; + +use crate::network::PolicyValueNet; +use super::replay::TrainSample; + +/// Run one gradient step on `model` using `batch`. +/// +/// Returns the updated model and the scalar loss value for logging. +/// +/// # Parameters +/// +/// - `lr` — learning rate (e.g. `1e-3`). +/// - `batch` — slice of [`TrainSample`]s; must be non-empty. +pub fn train_step( + model: N, + optimizer: &mut O, + batch: &[TrainSample], + device: &B::Device, + lr: f64, +) -> (N, f32) +where + B: AutodiffBackend, + N: PolicyValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "train_step called with empty batch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + let action_size = batch[0].policy.len(); + + // ── Build input tensors ──────────────────────────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let policy_flat: Vec = batch.iter().flat_map(|s| s.policy.iter().copied()).collect(); + let value_flat: Vec = batch.iter().map(|s| s.value).collect(); + + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + let policy_target = Tensor::::from_data( + TensorData::new(policy_flat, [batch_size, action_size]), + device, + ); + let value_target = Tensor::::from_data( + TensorData::new(value_flat, [batch_size, 1]), + device, + ); + + // ── Forward pass ────────────────────────────────────────────────────── + let (policy_logits, value_pred) = model.forward(obs_tensor); + + // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ────────────────── + let log_probs = log_softmax(policy_logits, 1); + let policy_loss = (policy_target.clone().neg() * log_probs) + .sum_dim(1) + .mean(); + + // ── Value loss: MSE(value_pred, z) ──────────────────────────────────── + let diff = value_pred - value_target; + let value_loss = (diff.clone() * diff).mean(); + + // ── Combined loss ───────────────────────────────────────────────────── + let loss = policy_loss + value_loss; + + // Extract scalar before backward (consumes the tensor). + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + let model = optimizer.step(lr, model, grads); + + (model, loss_scalar) +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + + use crate::network::{MlpConfig, MlpNet}; + use super::super::replay::TrainSample; + + type B = Autodiff>; + + fn device() -> ::Device { + Default::default() + } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5f32; obs_size], + policy: { + let mut p = vec![0.0f32; action_size]; + p[i % action_size] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + } + + #[test] + fn train_step_returns_finite_loss() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); + } + + #[test] + fn loss_decreases_over_steps() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + // Same batch every step — loss should decrease. + let batch = dummy_batch(16, 4, 4); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); + model = m; + assert!(loss.is_finite()); + prev_loss = loss; + } + // After 10 steps on fixed data, loss should be below a reasonable threshold. + assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn train_step_batch_size_one() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(1, 2, 2); + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite()); + } +} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 5beb37c..23895b9 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,3 +1,4 @@ +pub mod alphazero; pub mod env; pub mod mcts; pub mod network; From 519dfe67ad952db34222762ae212a34ba267f50e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:18:59 +0100 Subject: [PATCH 07/16] fix(spiel_bot): mcts fix --- spiel_bot/src/mcts/mod.rs | 8 ++++++-- spiel_bot/src/mcts/search.rs | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index e92bd09..a0a690d 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -401,8 +401,12 @@ mod tests { }; let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); - assert!(root.n > 0); + // root.n = 1 (expansion) + n_simulations (one backup per simulation). + assert_eq!(root.n, 1 + config.n_simulations as u32); + // Children visit counts may sum to less than n_simulations when some + // simulations cross a chance node at depth 1 (turn ends after one move) + // and evaluate with the network directly without updating child.n. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 5); + assert!(total <= config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index c4960c7..55db701 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -138,8 +138,14 @@ pub(super) fn simulate( // ── Apply action + advance through any chance nodes ─────────────────── let mut next_state = state; env.apply(&mut next_state, action); + + // Track whether we crossed a chance node (dice roll) on the way down. + // If we did, the child's cached legal actions are for a *different* dice + // outcome and must not be reused — evaluate with the network directly. + let mut crossed_chance = false; while env.current_player(&next_state).is_chance() { env.apply_chance(&mut next_state, rng); + crossed_chance = true; } let next_cp = env.current_player(&next_state); @@ -153,7 +159,15 @@ pub(super) fn simulate( returns[player_idx] } else { let child_player = next_cp.index().unwrap(); - let v = if child.expanded { + let v = if crossed_chance { + // Outcome sampling: after dice, evaluate the resulting position + // directly with the network. Do NOT build the tree across chance + // boundaries — the dice change which actions are legal, so any + // previously cached children would be for a different outcome. + let obs = env.observation(&next_state, child_player); + let (_, value) = evaluator.evaluate(&obs); + value + } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) } else { expand::(child, &next_state, env, evaluator, child_player) From aea1e3faafc42b1953d709f03331db8d50dfc458 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:19:15 +0100 Subject: [PATCH 08/16] tests(spiel_bot): integration tests --- spiel_bot/tests/integration.rs | 391 +++++++++++++++++++++++++++++++++ 1 file changed, 391 insertions(+) create mode 100644 spiel_bot/tests/integration.rs diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs new file mode 100644 index 0000000..d73fda0 --- /dev/null +++ b/spiel_bot/tests/integration.rs @@ -0,0 +1,391 @@ +//! End-to-end integration tests for the AlphaZero training pipeline. +//! +//! Each test exercises the full chain: +//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`] +//! +//! Two environments are used: +//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves. +//! Used when we need many iterations without worrying about runtime. +//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that +//! the full pipeline compiles and runs correctly with 217-dim observations +//! and 514-dim action spaces. +//! +//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep +//! runtime minimal; correctness, not training quality, is what matters here. + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step}, + env::{GameEnv, Player, TrictracEnv}, + mcts::MctsConfig, + network::{MlpConfig, MlpNet, PolicyValueNet}, +}; + +// ── Backend aliases ──────────────────────────────────────────────────────── + +type Train = Autodiff>; +type Infer = NdArray; + +// ── Helpers ──────────────────────────────────────────────────────────────── + +fn train_device() -> ::Device { + Default::default() +} + +fn infer_device() -> ::Device { + Default::default() +} + +/// Tiny 64-unit MLP, compatible with an obs/action space of any size. +fn tiny_mlp(obs: usize, actions: usize) -> MlpNet { + let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 }; + MlpNet::new(&cfg, &train_device()) +} + +fn tiny_mcts(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + } +} + +fn seeded() -> SmallRng { + SmallRng::seed_from_u64(0) +} + +// ── Countdown environment (fast, local, no external deps) ───────────────── +// +// Two players alternate subtracting 1 or 2 from a counter that starts at N. +// The player who brings the counter to 0 wins. + +#[derive(Clone, Debug)] +struct CState { + remaining: u8, + to_move: usize, +} + +#[derive(Clone)] +struct CountdownEnv(u8); // starting value + +impl GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: self.0, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } + else { Player::P2 } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / self.0 as f32, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } +} + +// ── 1. Full loop on CountdownEnv ────────────────────────────────────────── + +/// The canonical AlphaZero loop: self-play → replay → train, iterated. +/// Uses CountdownEnv so each game terminates in < 10 moves. +#[test] +fn countdown_full_loop_no_panic() { + let env = CountdownEnv(8); + let mut rng = seeded(); + let mcts = tiny_mcts(3); + + let mut model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(1_000); + + for _iter in 0..5 { + // Self-play: 3 games per iteration. + for _ in 0..3 { + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + assert!(!samples.is_empty()); + replay.extend(samples); + } + + // Training: 4 gradient steps per iteration. + if replay.len() >= 4 { + for _ in 0..4 { + let batch: Vec = replay + .sample_batch(4, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + model = m; + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + } + } + + assert!(replay.len() > 0); +} + +// ── 2. Replay buffer invariants ─────────────────────────────────────────── + +/// After several Countdown games, replay capacity is respected and batch +/// shapes are consistent. +#[test] +fn replay_buffer_capacity_and_shapes() { + let env = CountdownEnv(6); + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + + let capacity = 50; + let mut replay = ReplayBuffer::new(capacity); + + for _ in 0..20 { + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + replay.extend(samples); + } + + assert!(replay.len() <= capacity, "buffer exceeded capacity"); + assert!(replay.len() > 0); + + let batch = replay.sample_batch(8, &mut rng); + assert_eq!(batch.len(), 8.min(replay.len())); + for s in &batch { + assert_eq!(s.obs.len(), env.obs_size()); + assert_eq!(s.policy.len(), env.action_space()); + let policy_sum: f32 = s.policy.iter().sum(); + assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}"); + assert!(s.value.abs() <= 1.0, "value {} out of range", s.value); + } +} + +// ── 3. TrictracEnv: sample shapes ───────────────────────────────────────── + +/// Verify that one TrictracEnv episode produces samples with the correct +/// tensor dimensions: obs = 217, policy = 514. +#[test] +fn trictrac_sample_shapes() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + + assert!(!samples.is_empty(), "Trictrac episode produced no samples"); + + for (i, s) in samples.iter().enumerate() { + assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len()); + assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len()); + let policy_sum: f32 = s.policy.iter().sum(); + assert!( + (policy_sum - 1.0).abs() < 1e-4, + "sample {i}: policy sums to {policy_sum}" + ); + assert!( + s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "sample {i}: unexpected value {}", + s.value + ); + } +} + +// ── 4. TrictracEnv: training step after real self-play ──────────────────── + +/// Collect one Trictrac episode, then verify that a gradient step runs +/// without panic and produces a finite loss. +#[test] +fn trictrac_train_step_finite_loss() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + let model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(10_000); + + // Generate one episode. + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); + assert!(!samples.is_empty()); + let n_samples = samples.len(); + replay.extend(samples); + + // Train on a batch from this episode. + let batch_size = 8.min(n_samples); + let batch: Vec = replay + .sample_batch(batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); +} + +// ── 5. Backend transfer: train → infer → same outputs ───────────────────── + +/// Weights transferred from the training backend to the inference backend +/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes. +#[test] +fn valid_model_matches_train_model_outputs() { + use burn::tensor::{Tensor, TensorData}; + + let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let train_model = MlpNet::::new(&cfg, &train_device()); + let infer_model: MlpNet = train_model.valid(); + + // Build the same input on both backends. + let obs_data: Vec = vec![0.1, 0.2, 0.3, 0.4]; + + let obs_train = Tensor::::from_data( + TensorData::new(obs_data.clone(), [1, 4]), + &train_device(), + ); + let obs_infer = Tensor::::from_data( + TensorData::new(obs_data, [1, 4]), + &infer_device(), + ); + + let (p_train, v_train) = train_model.forward(obs_train); + let (p_infer, v_infer) = infer_model.forward(obs_infer); + + let p_train: Vec = p_train.into_data().to_vec().unwrap(); + let p_infer: Vec = p_infer.into_data().to_vec().unwrap(); + let v_train: Vec = v_train.into_data().to_vec().unwrap(); + let v_infer: Vec = v_infer.into_data().to_vec().unwrap(); + + for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() { + assert!( + (a - b).abs() < 1e-5, + "policy[{i}] differs after valid(): train={a}, infer={b}" + ); + } + assert!( + (v_train[0] - v_infer[0]).abs() < 1e-5, + "value differs after valid(): train={}, infer={}", + v_train[0], v_infer[0] + ); +} + +// ── 6. Loss converges on a fixed batch ──────────────────────────────────── + +/// With repeated gradient steps on the same Countdown batch, the loss must +/// decrease monotonically (or at least end lower than it started). +#[test] +fn loss_decreases_on_fixed_batch() { + let env = CountdownEnv(6); + let mut rng = seeded(); + let mcts = tiny_mcts(3); + let model = tiny_mlp(env.obs_size(), env.action_space()); + let mut optimizer = AdamConfig::new().init(); + + // Collect a fixed batch from one episode. + let infer = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples: Vec = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng); + assert!(!samples.is_empty()); + + let batch: Vec = { + let mut replay = ReplayBuffer::new(1000); + replay.extend(samples); + replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect() + }; + + // Overfit on the same fixed batch for 20 steps. + let mut model = tiny_mlp(env.obs_size(), env.action_space()); + let mut first_loss = f32::NAN; + let mut last_loss = f32::NAN; + + for step in 0..20 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2); + model = m; + assert!(loss.is_finite(), "loss is not finite at step {step}"); + if step == 0 { first_loss = loss; } + last_loss = loss; + } + + assert!( + last_loss < first_loss, + "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}" + ); +} + +// ── 7. Trictrac: multi-iteration loop ───────────────────────────────────── + +/// Two full self-play + train iterations on TrictracEnv. +/// Verifies the entire pipeline runs without panic end-to-end. +#[test] +fn trictrac_two_iteration_loop() { + let env = TrictracEnv; + let mut rng = seeded(); + let mcts = tiny_mcts(2); + + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let mut model = MlpNet::::new(&cfg, &train_device()); + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(20_000); + + for iter in 0..2 { + // Self-play: 1 game per iteration. + let infer: MlpNet = model.valid(); + let eval = BurnEvaluator::::new(infer, infer_device()); + let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); + assert!(!samples.is_empty(), "iter {iter}: episode was empty"); + replay.extend(samples); + + // Training: 3 gradient steps. + let batch_size = 16.min(replay.len()); + for _ in 0..3 { + let batch: Vec = replay + .sample_batch(batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); + model = m; + assert!(loss.is_finite(), "iter {iter}: loss={loss}"); + } + } +} From 9c82692ddb687057f15c0ef84d4dc8a5338f8697 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:49:55 +0100 Subject: [PATCH 09/16] feat(spiel_bot): benchmarks --- Cargo.lock | 120 ++++++++++++ spiel_bot/Cargo.toml | 7 + spiel_bot/benches/alphazero.rs | 341 +++++++++++++++++++++++++++++++++ 3 files changed, 468 insertions(+) create mode 100644 spiel_bot/benches/alphazero.rs diff --git a/Cargo.lock b/Cargo.lock index 0baa02a..34bfe80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.21" @@ -1116,6 +1122,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cast_trait" version = "0.1.2" @@ -1200,6 +1212,33 @@ dependencies = [ "rand 0.7.3", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1453,6 +1492,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -4461,6 +4536,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -4597,6 +4678,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.18.0" @@ -5897,6 +6006,7 @@ version = "0.1.0" dependencies = [ "anyhow", "burn", + "criterion", "rand 0.9.2", "rand_distr", "trictrac-store", @@ -6310,6 +6420,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.10.0" diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 323c953..3848dce 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,3 +9,10 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "alphazero" +harness = false diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs new file mode 100644 index 0000000..2950b09 --- /dev/null +++ b/spiel_bot/benches/alphazero.rs @@ -0,0 +1,341 @@ +//! AlphaZero pipeline benchmarks. +//! +//! Run with: +//! +//! ```sh +//! cargo bench -p spiel_bot +//! ``` +//! +//! Use `-- ` to run a specific group, e.g.: +//! +//! ```sh +//! cargo bench -p spiel_bot -- env/ +//! cargo bench -p spiel_bot -- network/ +//! cargo bench -p spiel_bot -- mcts/ +//! cargo bench -p spiel_bot -- episode/ +//! cargo bench -p spiel_bot -- train/ +//! ``` +//! +//! Target: ≥ 500 games/s for random play on CPU (consistent with +//! `random_game` throughput in `trictrac-store`). + +use std::time::Duration; + +use burn::{ + backend::NdArray, + tensor::{Tensor, TensorData, backend::Backend}, +}; +use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, + env::{GameEnv, Player, TrictracEnv}, + mcts::{Evaluator, MctsConfig, run_mcts}, + network::{MlpConfig, MlpNet, PolicyValueNet}, +}; + +// ── Shared types ─────────────────────────────────────────────────────────── + +type InferB = NdArray; +type TrainB = burn::backend::Autodiff>; + +fn infer_device() -> ::Device { Default::default() } +fn train_device() -> ::Device { Default::default() } + +fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) } + +/// Uniform evaluator (returns zero logits and zero value). +/// Used to isolate MCTS tree-traversal cost from network cost. +struct ZeroEval(usize); +impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } +} + +// ── 1. Environment primitives ────────────────────────────────────────────── + +/// Baseline performance of the raw Trictrac environment without MCTS. +/// Target: ≥ 500 full games / second. +fn bench_env(c: &mut Criterion) { + let env = TrictracEnv; + + let mut group = c.benchmark_group("env"); + group.measurement_time(Duration::from_secs(10)); + + // ── apply_chance ────────────────────────────────────────────────────── + group.bench_function("apply_chance", |b| { + b.iter_batched( + || { + // A fresh game is always at RollDice (Chance) — ready for apply_chance. + env.new_game() + }, + |mut s| { + env.apply_chance(&mut s, &mut seeded()); + black_box(s) + }, + BatchSize::SmallInput, + ) + }); + + // ── legal_actions ───────────────────────────────────────────────────── + group.bench_function("legal_actions", |b| { + let mut rng = seeded(); + let mut s = env.new_game(); + env.apply_chance(&mut s, &mut rng); + b.iter(|| black_box(env.legal_actions(&s))) + }); + + // ── observation (to_tensor) ─────────────────────────────────────────── + group.bench_function("observation", |b| { + let mut rng = seeded(); + let mut s = env.new_game(); + env.apply_chance(&mut s, &mut rng); + b.iter(|| black_box(env.observation(&s, 0))) + }); + + // ── full random game ────────────────────────────────────────────────── + group.sample_size(50); + group.bench_function("random_game", |b| { + b.iter_batched( + seeded, + |mut rng| { + let mut s = env.new_game(); + loop { + match env.current_player(&s) { + Player::Terminal => break, + Player::Chance => env.apply_chance(&mut s, &mut rng), + _ => { + let actions = env.legal_actions(&s); + let idx = rng.random_range(0..actions.len()); + env.apply(&mut s, actions[idx]); + } + } + } + black_box(s) + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +// ── 2. Network inference ─────────────────────────────────────────────────── + +/// Forward-pass latency for MLP variants (hidden = 64 / 256). +fn bench_network(c: &mut Criterion) { + let mut group = c.benchmark_group("network"); + group.measurement_time(Duration::from_secs(5)); + + for &hidden in &[64usize, 256] { + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = MlpNet::::new(&cfg, &infer_device()); + let obs: Vec = vec![0.5; 217]; + + // Batch size 1 — single-position evaluation as in MCTS. + group.bench_with_input( + BenchmarkId::new("mlp_b1", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs.clone(), [1, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + + // Batch size 32 — training mini-batch. + let obs32: Vec = vec![0.5; 217 * 32]; + group.bench_with_input( + BenchmarkId::new("mlp_b32", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs32.clone(), [32, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + } + + group.finish(); +} + +// ── 3. MCTS ─────────────────────────────────────────────────────────────── + +/// MCTS cost at different simulation budgets with two evaluator types: +/// - `zero` — isolates tree-traversal overhead (no network). +/// - `mlp64` — real MLP, shows end-to-end cost per move. +fn bench_mcts(c: &mut Criterion) { + let env = TrictracEnv; + + // Build a decision-node state (after dice roll). + let state = { + let mut s = env.new_game(); + let mut rng = seeded(); + while env.current_player(&s).is_chance() { + env.apply_chance(&mut s, &mut rng); + } + s + }; + + let mut group = c.benchmark_group("mcts"); + group.measurement_time(Duration::from_secs(10)); + + let zero_eval = ZeroEval(514); + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let mlp_model = MlpNet::::new(&mlp_cfg, &infer_device()); + let mlp_eval = BurnEvaluator::::new(mlp_model, infer_device()); + + for &n_sim in &[1usize, 5, 20] { + let cfg = MctsConfig { + n_simulations: n_sim, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + }; + + // Zero evaluator: tree traversal only. + group.bench_with_input( + BenchmarkId::new("zero_eval", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)), + BatchSize::SmallInput, + ) + }, + ); + + // MLP evaluator: full cost per decision. + group.bench_with_input( + BenchmarkId::new("mlp64", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)), + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── 4. Episode generation ───────────────────────────────────────────────── + +/// Full self-play episode latency (one complete game) at different MCTS +/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU. +fn bench_episode(c: &mut Criterion) { + let env = TrictracEnv; + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let model = MlpNet::::new(&mlp_cfg, &infer_device()); + let eval = BurnEvaluator::::new(model, infer_device()); + + let mut group = c.benchmark_group("episode"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(60)); + + for &n_sim in &[1usize, 2] { + let mcts_cfg = MctsConfig { + n_simulations: n_sim, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + }; + + group.bench_with_input( + BenchmarkId::new("trictrac", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| { + black_box(generate_episode( + &env, + &eval, + &mcts_cfg, + &|_| 1.0, + &mut rng, + )) + }, + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── 5. Training step ─────────────────────────────────────────────────────── + +/// Gradient-step latency for different batch sizes. +fn bench_train(c: &mut Criterion) { + use burn::optim::AdamConfig; + + let mut group = c.benchmark_group("train"); + group.measurement_time(Duration::from_secs(10)); + + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + + let dummy_samples = |n: usize| -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5; 217], + policy: { + let mut p = vec![0.0f32; 514]; + p[i % 514] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + }; + + for &batch_size in &[16usize, 64] { + let batch = dummy_samples(batch_size); + + group.bench_with_input( + BenchmarkId::new("mlp64_adam", batch_size), + &batch_size, + |b, _| { + b.iter_batched( + || { + ( + MlpNet::::new(&mlp_cfg, &train_device()), + AdamConfig::new().init::>(), + ) + }, + |(model, mut opt)| { + black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3)) + }, + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── Criterion entry point ────────────────────────────────────────────────── + +criterion_group!( + benches, + bench_env, + bench_network, + bench_mcts, + bench_episode, + bench_train, +); +criterion_main!(benches); From 822290d7224fa8cc6ce0aa324adcfe388dd705b2 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 23:05:53 +0100 Subject: [PATCH 10/16] feat(spiel_bot): upgrade network --- spiel_bot/benches/alphazero.rs | 34 +++++++++++- spiel_bot/src/alphazero/mod.rs | 14 ++++- spiel_bot/src/alphazero/trainer.rs | 86 ++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 3 deletions(-) diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs index 2950b09..00d5b02 100644 --- a/spiel_bot/benches/alphazero.rs +++ b/spiel_bot/benches/alphazero.rs @@ -32,7 +32,7 @@ use spiel_bot::{ alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, env::{GameEnv, Player, TrictracEnv}, mcts::{Evaluator, MctsConfig, run_mcts}, - network::{MlpConfig, MlpNet, PolicyValueNet}, + network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, }; // ── Shared types ─────────────────────────────────────────────────────────── @@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) { ); } + // ── ResNet (4 residual blocks) ──────────────────────────────────────── + for &hidden in &[256usize, 512] { + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = ResNet::::new(&cfg, &infer_device()); + let obs: Vec = vec![0.5; 217]; + + group.bench_with_input( + BenchmarkId::new("resnet_b1", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs.clone(), [1, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + + let obs32: Vec = vec![0.5; 217 * 32]; + group.bench_with_input( + BenchmarkId::new("resnet_b32", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs32.clone(), [32, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + } + group.finish(); } diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs index bb86724..d92224e 100644 --- a/spiel_bot/src/alphazero/mod.rs +++ b/spiel_bot/src/alphazero/mod.rs @@ -65,7 +65,7 @@ pub mod trainer; pub use replay::{ReplayBuffer, TrainSample}; pub use selfplay::{BurnEvaluator, generate_episode}; -pub use trainer::train_step; +pub use trainer::{cosine_lr, train_step}; use crate::mcts::MctsConfig; @@ -87,8 +87,17 @@ pub struct AlphaZeroConfig { pub batch_size: usize, /// Maximum number of samples in the replay buffer. pub replay_capacity: usize, - /// Adam learning rate. + /// Initial (peak) Adam learning rate. pub learning_rate: f64, + /// Minimum learning rate for cosine annealing (floor of the schedule). + /// + /// Pass `learning_rate == lr_min` to disable scheduling (constant LR). + /// Compute the current LR with [`cosine_lr`]: + /// + /// ```rust,ignore + /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps); + /// ``` + pub lr_min: f64, /// Number of outer iterations (self-play + train) to run. pub n_iterations: usize, /// Move index after which the action temperature drops to 0 (greedy play). @@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig { batch_size: 64, replay_capacity: 50_000, learning_rate: 1e-3, + lr_min: 1e-4, // cosine annealing floor n_iterations: 100, temperature_drop_move: 30, } diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs index d2482d1..9075519 100644 --- a/spiel_bot/src/alphazero/trainer.rs +++ b/spiel_bot/src/alphazero/trainer.rs @@ -5,6 +5,24 @@ //! - **Value loss** — mean-squared error between the predicted value and the //! actual game outcome. //! +//! # Learning-rate scheduling +//! +//! [`cosine_lr`] implements one-cycle cosine annealing: +//! +//! ```text +//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T)) +//! ``` +//! +//! Typical usage in the outer loop: +//! +//! ```rust,ignore +//! for step in 0..total_train_steps { +//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps); +//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr); +//! model = m; +//! } +//! ``` +//! //! # Backend //! //! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). @@ -96,6 +114,30 @@ where (model, loss_scalar) } +// ── Learning-rate schedule ───────────────────────────────────────────────── + +/// Cosine learning-rate schedule (one half-period, no warmup). +/// +/// Returns the learning rate for training step `step` out of `total_steps`: +/// +/// ```text +/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total)) +/// ``` +/// +/// - At `t = 0` returns `initial`. +/// - At `t = total_steps` (or beyond) returns `lr_min`. +/// +/// # Panics +/// +/// Does not panic. When `total_steps == 0`, returns `lr_min`. +pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 { + if total_steps == 0 || step >= total_steps { + return lr_min; + } + let progress = step as f64 / total_steps as f64; + lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos()) +} + // ── Tests ────────────────────────────────────────────────────────────────── #[cfg(test)] @@ -169,4 +211,48 @@ mod tests { let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); assert!(loss.is_finite()); } + + // ── cosine_lr ───────────────────────────────────────────────────────── + + #[test] + fn cosine_lr_at_step_zero_is_initial() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 100); + assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}"); + } + + #[test] + fn cosine_lr_at_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 100, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}"); + } + + #[test] + fn cosine_lr_beyond_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 200, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}"); + } + + #[test] + fn cosine_lr_midpoint_is_average() { + // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2. + let lr = super::cosine_lr(1e-3, 1e-5, 50, 100); + let expected = (1e-3 + 1e-5) / 2.0; + assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}"); + } + + #[test] + fn cosine_lr_monotone_decreasing() { + let mut prev = f64::INFINITY; + for step in 0..=100 { + let lr = super::cosine_lr(1e-3, 1e-5, step, 100); + assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}"); + prev = lr; + } + } + + #[test] + fn cosine_lr_zero_total_steps_returns_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 0); + assert!((lr - 1e-5).abs() < 1e-10); + } } From 3221b5256a370734d7cc5d861a981b0111ed49fb Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 23:06:21 +0100 Subject: [PATCH 11/16] feat(spiel_bot): alphazero eval binary --- spiel_bot/src/bin/az_eval.rs | 262 +++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 spiel_bot/src/bin/az_eval.rs diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs new file mode 100644 index 0000000..3c82519 --- /dev/null +++ b/spiel_bot/src/bin/az_eval.rs @@ -0,0 +1,262 @@ +//! Evaluate a trained AlphaZero checkpoint against a random player. +//! +//! # Usage +//! +//! ```sh +//! # Random weights (sanity check — should be ~50 %) +//! cargo run -p spiel_bot --bin az_eval --release +//! +//! # Trained MLP checkpoint +//! cargo run -p spiel_bot --bin az_eval --release -- \ +//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50 +//! +//! # Trained ResNet checkpoint +//! cargo run -p spiel_bot --bin az_eval --release -- \ +//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--checkpoint ` | (none) | Load weights from `.mpk` file; random weights if omitted | +//! | `--arch mlp\|resnet` | `mlp` | Network architecture | +//! | `--hidden ` | 256 (mlp) / 512 (resnet) | Hidden size | +//! | `--n-games ` | `100` | Games per side (total = 2 × N) | +//! | `--n-sim ` | `50` | MCTS simulations per move | +//! | `--seed ` | `42` | RNG seed | +//! | `--c-puct ` | `1.5` | PUCT exploration constant | + +use std::path::PathBuf; + +use burn::backend::NdArray; +use rand::{SeedableRng, rngs::SmallRng, Rng}; + +use spiel_bot::{ + alphazero::BurnEvaluator, + env::{GameEnv, Player, TrictracEnv}, + mcts::{Evaluator, MctsConfig, run_mcts, select_action}, + network::{MlpConfig, MlpNet, ResNet, ResNetConfig}, +}; + +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + checkpoint: Option, + arch: String, + hidden: Option, + n_games: usize, + n_sim: usize, + seed: u64, + c_puct: f32, +} + +impl Default for Args { + fn default() -> Self { + Self { + checkpoint: None, + arch: "mlp".into(), + hidden: None, + n_games: 100, + n_sim: 50, + seed: 42, + c_puct: 1.5, + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut args = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); } + "--arch" => { i += 1; args.arch = raw[i].clone(); } + "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); } + "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); } + "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); } + "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); } + "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + args +} + +// ── Game loop ───────────────────────────────────────────────────────────────── + +/// Play one complete game. +/// +/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black). +/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0). +fn play_game( + env: &TrictracEnv, + mcts_side: usize, + evaluator: &dyn Evaluator, + mcts_cfg: &MctsConfig, + rng: &mut SmallRng, +) -> [f32; 2] { + let mut state = env.new_game(); + loop { + match env.current_player(&state) { + Player::Terminal => { + return env.returns(&state).expect("Terminal state must have returns"); + } + Player::Chance => env.apply_chance(&mut state, rng), + player => { + let side = player.index().unwrap(); // 0 = P1, 1 = P2 + let action = if side == mcts_side { + let root = run_mcts(env, &state, evaluator, mcts_cfg, rng); + select_action(&root, 0.0, rng) // greedy (temperature = 0) + } else { + let actions = env.legal_actions(&state); + actions[rng.random_range(0..actions.len())] + }; + env.apply(&mut state, action); + } + } + } +} + +// ── Statistics ──────────────────────────────────────────────────────────────── + +#[derive(Default)] +struct Stats { + wins: u32, + draws: u32, + losses: u32, +} + +impl Stats { + fn record(&mut self, mcts_return: f32) { + if mcts_return > 0.0 { self.wins += 1; } + else if mcts_return < 0.0 { self.losses += 1; } + else { self.draws += 1; } + } + + fn total(&self) -> u32 { self.wins + self.draws + self.losses } + + fn win_rate_decisive(&self) -> f64 { + let d = self.wins + self.losses; + if d == 0 { 0.5 } else { self.wins as f64 / d as f64 } + } + + fn print(&self) { + let n = self.total(); + let pct = |k: u32| 100.0 * k as f64 / n as f64; + println!( + " Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)", + self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses), + ); + } +} + +// ── Evaluation ──────────────────────────────────────────────────────────────── + +fn run_evaluation( + evaluator: &dyn Evaluator, + n_games: usize, + mcts_cfg: &MctsConfig, + seed: u64, +) -> (Stats, Stats) { + let env = TrictracEnv; + let total = n_games * 2; + let mut as_p1 = Stats::default(); + let mut as_p2 = Stats::default(); + + for i in 0..total { + // Alternate sides: even games → MctsAgent as P1, odd → as P2. + let mcts_side = i % 2; + let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64)); + let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng); + + let mcts_return = result[mcts_side]; + if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); } + + let done = i + 1; + if done % 10 == 0 || done == total { + eprint!("\r [{done}/{total}] ", ); + } + } + eprintln!(); + (as_p1, as_p2) +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + let device: ::Device = Default::default(); + + // ── Load model ──────────────────────────────────────────────────────── + let evaluator: Box = match args.arch.as_str() { + "resnet" => { + let hidden = args.hidden.unwrap_or(512); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = match &args.checkpoint { + Some(path) => ResNet::::load(&cfg, path, &device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), + None => ResNet::new(&cfg, &device), + }; + Box::new(BurnEvaluator::>::new(model, device)) + } + "mlp" | _ => { + let hidden = args.hidden.unwrap_or(256); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = match &args.checkpoint { + Some(path) => MlpNet::::load(&cfg, path, &device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), + None => MlpNet::new(&cfg, &device), + }; + Box::new(BurnEvaluator::>::new(model, device)) + } + }; + + let mcts_cfg = MctsConfig { + n_simulations: args.n_sim, + c_puct: args.c_puct, + dirichlet_alpha: 0.0, // no exploration noise during evaluation + dirichlet_eps: 0.0, + temperature: 0.0, // greedy action selection + }; + + // ── Header ──────────────────────────────────────────────────────────── + let ckpt_label = args.checkpoint + .as_deref() + .and_then(|p| p.file_name()) + .and_then(|n| n.to_str()) + .unwrap_or("random weights"); + + println!(); + println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent", + args.arch, args.n_sim); + println!("Games per side: {} | Total: {} | Seed: {}", + args.n_games, args.n_games * 2, args.seed); + println!(); + + // ── Run ─────────────────────────────────────────────────────────────── + let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed); + + // ── Results ─────────────────────────────────────────────────────────── + println!("MctsAgent as P1 (White):"); + as_p1.print(); + + println!("MctsAgent as P2 (Black):"); + as_p2.print(); + + let combined_wins = as_p1.wins + as_p2.wins; + let combined_decisive = combined_wins + as_p1.losses + as_p2.losses; + let combined_wr = if combined_decisive == 0 { 0.5 } + else { combined_wins as f64 / combined_decisive as f64 }; + + println!(); + println!("Combined win rate (excluding draws): {:.1}% [{}/{}]", + combined_wr * 100.0, combined_wins, combined_decisive); + println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%", + as_p1.win_rate_decisive() * 100.0, + as_p2.win_rate_decisive() * 100.0); +} From 150efe302fc1294df640340a0a162e8e366db59b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 8 Mar 2026 12:28:39 +0100 Subject: [PATCH 12/16] feat(spiel_bot): az_train training command --- spiel_bot/src/bin/az_train.rs | 314 ++++++++++++++++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 spiel_bot/src/bin/az_train.rs diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs new file mode 100644 index 0000000..ab385c2 --- /dev/null +++ b/spiel_bot/src/bin/az_train.rs @@ -0,0 +1,314 @@ +//! AlphaZero self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh (MLP, default settings) +//! cargo run -p spiel_bot --bin az_train --release +//! +//! # ResNet, 200 iterations, save every 20 +//! cargo run -p spiel_bot --bin az_train --release -- \ +//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/ +//! +//! # Resume from a checkpoint +//! cargo run -p spiel_bot --bin az_train --release -- \ +//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--arch mlp\|resnet` | `mlp` | Network architecture | +//! | `--hidden N` | 256/512 | Hidden layer width | +//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | +//! | `--n-iter N` | `100` | Training iterations | +//! | `--n-games N` | `10` | Self-play games per iteration | +//! | `--n-train N` | `20` | Gradient steps per iteration | +//! | `--n-sim N` | `100` | MCTS simulations per move | +//! | `--batch N` | `64` | Mini-batch size | +//! | `--replay-cap N` | `50000` | Replay buffer capacity | +//! | `--lr F` | `1e-3` | Peak (initial) learning rate | +//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) | +//! | `--c-puct F` | `1.5` | PUCT exploration constant | +//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha | +//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight | +//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 | +//! | `--save-every N` | `10` | Save checkpoint every N iterations | +//! | `--seed N` | `42` | RNG seed | +//! | `--resume PATH` | (none) | Load weights from checkpoint before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{ + BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step, + }, + env::TrictracEnv, + mcts::MctsConfig, + network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + arch: String, + hidden: Option, + out_dir: PathBuf, + n_iter: usize, + n_games: usize, + n_train: usize, + n_sim: usize, + batch_size: usize, + replay_cap: usize, + lr: f64, + lr_min: f64, + c_puct: f32, + dirichlet_alpha: f32, + dirichlet_eps: f32, + temp_drop: usize, + save_every: usize, + seed: u64, + resume: Option, +} + +impl Default for Args { + fn default() -> Self { + Self { + arch: "mlp".into(), + hidden: None, + out_dir: PathBuf::from("checkpoints"), + n_iter: 100, + n_games: 10, + n_train: 20, + n_sim: 100, + batch_size: 64, + replay_cap: 50_000, + lr: 1e-3, + lr_min: 1e-4, + c_puct: 1.5, + dirichlet_alpha: 0.1, + dirichlet_eps: 0.25, + temp_drop: 30, + save_every: 10, + seed: 42, + resume: None, + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--arch" => { i += 1; a.arch = raw[i].clone(); } + "--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); } + "--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); } + "--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); } + "--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); } + "--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); } + "--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); } + "--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); } + "--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); } + "--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +/// Generic training loop, parameterised over the network type. +/// +/// `save_fn` receives the **training-backend** model and the target path; +/// it is called in the match arm where the concrete network type is known. +fn train_loop( + mut model: N, + save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>, + args: &Args, +) +where + N: PolicyValueNet + AutodiffModule + Clone, + >::InnerModule: PolicyValueNet + Send + 'static, +{ + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + // Type is inferred as OptimizerAdaptor at the call site. + let mut optimizer = AdamConfig::new().init(); + let mut replay = ReplayBuffer::new(args.replay_cap); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + // Total gradient steps (used for cosine LR denominator). + let total_train_steps = (args.n_iter * args.n_train).max(1); + let mut global_step = 0usize; + + println!( + "\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}", + "", args.arch, args.n_iter, args.n_games, args.n_sim, "" + ); + + for iter in 0..args.n_iter { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + // Convert to inference backend (zero autodiff overhead). + let infer_model: >::InnerModule = model.valid(); + let evaluator: BurnEvaluator>::InnerModule> = + BurnEvaluator::new(infer_model, infer_device.clone()); + + let mcts_cfg = MctsConfig { + n_simulations: args.n_sim, + c_puct: args.c_puct, + dirichlet_alpha: args.dirichlet_alpha, + dirichlet_eps: args.dirichlet_eps, + temperature: 1.0, + }; + + let temp_drop = args.temp_drop; + let temperature_fn = |step: usize| -> f32 { + if step < temp_drop { 1.0 } else { 0.0 } + }; + + let mut new_samples = 0usize; + for _ in 0..args.n_games { + let samples = + generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng); + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.batch_size { + for _ in 0..args.n_train { + let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); + let batch: Vec = replay + .sample_batch(args.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + let (m, loss) = + train_step(model, &mut optimizer, &batch, &train_device, lr); + model = m; + loss_sum += loss; + n_steps += 1; + global_step += 1; + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s", + iter + 1, + args.n_iter, + replay.len(), + new_samples, + avg_loss, + lr_now, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.n_iter; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1)); + match save_fn(&model, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nTraining complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + // Create output directory if it doesn't exist. + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + + match args.arch.as_str() { + "resnet" => { + let hidden = args.hidden.unwrap_or(512); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + + let model = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + ResNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => ResNet::::new(&cfg, &train_device), + }; + + train_loop( + model, + &|m: &ResNet, path: &Path| { + // Save via inference model to avoid autodiff record overhead. + m.valid().save(path) + }, + &args, + ); + } + + "mlp" | _ => { + let hidden = args.hidden.unwrap_or(256); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + + let model = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + MlpNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => MlpNet::::new(&cfg, &train_device), + }; + + train_loop( + model, + &|m: &MlpNet, path: &Path| m.valid().save(path), + &args, + ); + } + } +} From 7c0f230e3de58319bc26558aed5ca3153d5a58fe Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 10 Mar 2026 08:19:24 +0100 Subject: [PATCH 13/16] doc: tensor research --- doc/tensor_research.md | 253 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 doc/tensor_research.md diff --git a/doc/tensor_research.md b/doc/tensor_research.md new file mode 100644 index 0000000..b0d0ede --- /dev/null +++ b/doc/tensor_research.md @@ -0,0 +1,253 @@ +# Tensor research + +## Current tensor anatomy + +[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!) +[24] active player color: 0 or 1 +[25] turn_stage: 1–5 +[26–27] dice values (raw 1–6) +[28–31] white: points, holes, can_bredouille, can_big_bredouille +[32–35] black: same +───────────────────────────────── +Total 36 floats + +The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's +AlphaZero uses a fully-connected network. + +### Fundamental problems with the current encoding + +1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network + must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with + all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from. + +2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient + flow during training is uneven. + +3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it + triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers + produces a score. Including this explicitly is the single highest-value addition. + +4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely + different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0. + +5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the + starting position). It's in the Player struct but not exported. + +## Key Trictrac distinctions from backgammon that shape the encoding + +| Concept | Backgammon | Trictrac | +| ------------------------- | ---------------------- | --------------------------------------------------------- | +| Hitting a blot | Removes checker to bar | Scores points, checker stays | +| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened | +| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) | +| 3-checker field | Safe with spare | Safe with spare | +| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) | +| Both colors on a field | Impossible | Perfectly legal | +| Rest corner (field 12/13) | Does not exist | Special two-checker rules | + +The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary +indicators directly teaches the network the phase transitions the game hinges on. + +## Options + +### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D) + +The minimum viable improvement. + +For each of the 24 fields, encode own and opponent separately with 4 indicators each: + +own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target) +own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill) +own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare) +own_x[i]: max(0, count − 3) (overflow) +opp_1[i]: same for opponent +… + +Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec(). + +Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204 +Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured +overhead is negligible. +Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from +scratch will be substantially faster and the converged policy quality better. + +### Option B — Option A + Trictrac-specific derived features (flat 1D) + +Recommended starting point. + +Add on top of Option A: + +// Quarter fill status — the single most important derived feature +quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields) +quarter_filled_opp[q] (q=0..3): same for opponent +→ 8 values + +// Exit readiness +can_exit_own: 1.0 if all own checkers are in fields 19–24 +can_exit_opp: same for opponent +→ 2 values + +// Rest corner status (field 12/13) +own_corner_taken: 1.0 if field 12 has ≥2 own checkers +opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers +→ 2 values + +// Jan de 3 coups counter (normalized) +dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0) +→ 1 value + +Size: 204 + 8 + 2 + 2 + 1 = 217 +Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve +the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes +expensive inference from the network. + +### Option C — Option B + richer positional features (flat 1D) + +More complete, higher sample efficiency, minor extra cost. + +Add on top of Option B: + +// Per-quarter fill fraction — how close to filling each quarter +own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0 +opp_quarter_fill_fraction[q] (q=0..3): same for opponent +→ 8 values + +// Blot counts — number of own/opponent single-checker fields globally +// (tells the network at a glance how much battage risk/opportunity exists) +own_blot_count: (number of own fields with exactly 1 checker) / 15.0 +opp_blot_count: same for opponent +→ 2 values + +// Bredouille would-double multiplier (already present, but explicitly scaled) +// No change needed, already binary + +Size: 217 + 8 + 2 = 227 +Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network +from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts). + +### Option D — 2D spatial tensor {K, 24} + +For CNN-based networks. Best eventual architecture but requires changing the training setup. + +Shape {14, 24} — 14 feature channels over 24 field positions: + +Channel 0: own_count_1 (blot) +Channel 1: own_count_2 +Channel 2: own_count_3 +Channel 3: own_count_overflow (float) +Channel 4: opp_count_1 +Channel 5: opp_count_2 +Channel 6: opp_count_3 +Channel 7: opp_count_overflow +Channel 8: own_corner_mask (1.0 at field 12) +Channel 9: opp_corner_mask (1.0 at field 13) +Channel 10: final_quarter_mask (1.0 at fields 19–24) +Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter) +Channel 12: quarter_filled_opp (same for opponent) +Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers) + +Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform +value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global" +row by returning shape {K, 25} with position 0 holding global features. + +Size: 14 × 24 + few global channels ≈ 336–384 +C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated +accordingly. +Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's +alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H, +W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension. +Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel +size 2–3 captures adjacent-field "tout d'une" interactions. + +### On 3D tensors + +Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is +the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and +field-within-quarter patterns jointly. + +However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The +extra architecture work makes this premature unless you're already building a custom network. The information content +is the same as Option D. + +### Recommendation + +Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and +the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions +(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move. + +Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when +the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time. + +Costs summary: + +| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain | +| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- | +| Current | 36 | — | — | — | baseline | +| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) | +| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) | +| C | 227 | to_vec() rewrite | constant update | none | large + moderate | +| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial | + +One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2, +the new to_vec() must express everything from the active player's perspective (which the mirror already handles for +the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state, +which they will be if computed inside to_vec(). + +## Other algorithms + +The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully. + +### 1. Without MCTS, feature quality matters more + +AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred +simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup — +the network must learn the full strategic structure directly from gradient updates. + +This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO, +not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less +sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from +scratch. + +### 2. Normalization becomes mandatory, not optional + +AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps +Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw +values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can +be unstable. + +For DQN/PPO, every feature in the tensor should be in [0, 1]: + +dice values: / 6.0 +checker counts: overflow channel / 12.0 +points: / 12.0 +holes: / 12.0 +dice_roll_count: / 3.0 (clamped) + +Booleans and the TD-Gammon binary indicators are already in [0, 1]. + +### 3. The shape question depends on architecture, not algorithm + +| Architecture | Shape | When to use | +| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- | +| MLP | {217} flat | Any algorithm, simplest baseline | +| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) | +| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network | +| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now | + +The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want +convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet +template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all +— you just reshape the flat tensor in Python before passing it to the network. + +### One thing that genuinely changes: value function perspective + +AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This +works well. + +DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit +current-player indicator), because a single Q-network estimates action values for both players simultaneously. With +the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must +learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity. +This is a minor point for a symmetric game like Trictrac, but worth keeping in mind. + +Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm. From e7d13c9a02480da812e13ecb762440e3937d3e7f Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 10 Mar 2026 22:12:52 +0100 Subject: [PATCH 14/16] feat(spiel_bot): dqn --- spiel_bot/src/bin/dqn_train.rs | 251 +++++++++++++++++++++++++++++ spiel_bot/src/dqn/episode.rs | 247 +++++++++++++++++++++++++++++ spiel_bot/src/dqn/mod.rs | 232 +++++++++++++++++++++++++++ spiel_bot/src/dqn/trainer.rs | 278 +++++++++++++++++++++++++++++++++ spiel_bot/src/env/trictrac.rs | 12 ++ spiel_bot/src/lib.rs | 1 + spiel_bot/src/network/mod.rs | 14 ++ spiel_bot/src/network/qnet.rs | 147 +++++++++++++++++ store/src/game.rs | 10 ++ 9 files changed, 1192 insertions(+) create mode 100644 spiel_bot/src/bin/dqn_train.rs create mode 100644 spiel_bot/src/dqn/episode.rs create mode 100644 spiel_bot/src/dqn/mod.rs create mode 100644 spiel_bot/src/dqn/trainer.rs create mode 100644 spiel_bot/src/network/qnet.rs diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs new file mode 100644 index 0000000..0ebe978 --- /dev/null +++ b/spiel_bot/src/bin/dqn_train.rs @@ -0,0 +1,251 @@ +//! DQN self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh with default settings +//! cargo run -p spiel_bot --bin dqn_train --release +//! +//! # Custom hyperparameters +//! cargo run -p spiel_bot --bin dqn_train --release -- \ +//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000 +//! +//! # Resume from a checkpoint +//! cargo run -p spiel_bot --bin dqn_train --release -- \ +//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--hidden N` | 256 | Hidden layer width | +//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | +//! | `--n-iter N` | 100 | Training iterations | +//! | `--n-games N` | 10 | Self-play games per iteration | +//! | `--n-train N` | 20 | Gradient steps per iteration | +//! | `--batch N` | 64 | Mini-batch size | +//! | `--replay-cap N` | 50000 | Replay buffer capacity | +//! | `--lr F` | 1e-3 | Adam learning rate | +//! | `--epsilon-start F` | 1.0 | Initial exploration rate | +//! | `--epsilon-end F` | 0.05 | Final exploration rate | +//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor | +//! | `--gamma F` | 0.99 | Discount factor | +//! | `--target-update N` | 500 | Hard-update target net every N steps | +//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) | +//! | `--save-every N` | 10 | Save checkpoint every N iterations | +//! | `--seed N` | 42 | RNG seed | +//! | `--resume PATH` | (none) | Load weights before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + dqn::{ + DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, + generate_dqn_episode, hard_update, linear_epsilon, + }, + env::TrictracEnv, + network::{QNet, QNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + hidden: usize, + out_dir: PathBuf, + save_every: usize, + seed: u64, + resume: Option, + config: DqnConfig, +} + +impl Default for Args { + fn default() -> Self { + Self { + hidden: 256, + out_dir: PathBuf::from("checkpoints"), + save_every: 10, + seed: 42, + resume: None, + config: DqnConfig::default(), + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); } + "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } + "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } + "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } + "--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } + "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } + "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } + "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } + "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +fn train_loop( + mut q_net: QNet, + cfg: &QNetConfig, + save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, + args: &Args, +) { + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + let mut optimizer = AdamConfig::new().init(); + let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + let mut target_net: QNet = hard_update::(&q_net); + let mut global_step = 0usize; + let mut epsilon = args.config.epsilon_start; + + println!( + "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", + "", args.config.n_iterations, args.config.n_games_per_iter, + args.config.n_train_steps_per_iter, "" + ); + + for iter in 0..args.config.n_iterations { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + let infer_q: QNet = q_net.valid(); + let mut new_samples = 0usize; + + for _ in 0..args.config.n_games_per_iter { + let samples = generate_dqn_episode( + &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, + ); + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.config.batch_size { + for _ in 0..args.config.n_train_steps_per_iter { + let batch: Vec<_> = replay + .sample_batch(args.config.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + // Target Q-values computed on the inference backend. + let target_q = compute_target_q( + &target_net, &batch, cfg.action_size, &infer_device, + ); + + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &target_q, + &train_device, args.config.learning_rate, args.config.gamma, + ); + q_net = q; + loss_sum += loss; + n_steps += 1; + global_step += 1; + + // Hard-update target net every target_update_freq steps. + if global_step % args.config.target_update_freq == 0 { + target_net = hard_update::(&q_net); + } + + // Linear epsilon decay. + epsilon = linear_epsilon( + args.config.epsilon_start, + args.config.epsilon_end, + global_step, + args.config.epsilon_decay_steps, + ); + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", + iter + 1, + args.config.n_iterations, + replay.len(), + new_samples, + avg_loss, + epsilon, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.config.n_iterations; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); + match save_fn(&q_net, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nDQN training complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; + + let q_net = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + QNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => QNet::::new(&cfg, &train_device), + }; + + train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); +} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs new file mode 100644 index 0000000..aca1343 --- /dev/null +++ b/spiel_bot/src/dqn/episode.rs @@ -0,0 +1,247 @@ +//! DQN self-play episode generation. +//! +//! Both players share the same Q-network (the [`TrictracEnv`] handles board +//! mirroring so that each player always acts from "White's perspective"). +//! Transitions for both players are stored in the returned sample list. +//! +//! # Reward +//! +//! After each full decision (action applied and the state has advanced through +//! any intervening chance nodes back to the same player's next turn), the +//! reward is: +//! +//! ```text +//! r = (my_total_score_now − my_total_score_then) +//! − (opp_total_score_now − opp_total_score_then) +//! ``` +//! +//! where `total_score = holes × 12 + points`. +//! +//! # Transition structure +//! +//! We use a "pending transition" per player. When a player acts again, we +//! *complete* the previous pending transition by filling in `next_obs`, +//! `next_legal`, and computing `reward`. Terminal transitions are completed +//! when the game ends. + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::{GameEnv, TrictracEnv}; +use crate::network::QValueNet; +use super::DqnSample; + +// ── Internals ───────────────────────────────────────────────────────────────── + +struct PendingTransition { + obs: Vec, + action: usize, + /// Score snapshot `[p1_total, p2_total]` at the moment of the action. + score_before: [i32; 2], +} + +/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. +fn epsilon_greedy>( + q_net: &Q, + obs: &[f32], + legal: &[usize], + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, +) -> usize { + debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); + if rng.random::() < epsilon { + legal[rng.random_range(0..legal.len())] + } else { + let obs_tensor = Tensor::::from_data( + TensorData::new(obs.to_vec(), [1, obs.len()]), + device, + ); + let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); + legal + .iter() + .copied() + .max_by(|&a, &b| { + q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap() + } +} + +/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. +fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { + let opp_idx = 1 - player_idx; + ((score_after[player_idx] - score_before[player_idx]) + - (score_after[opp_idx] - score_before[opp_idx])) as f32 +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Play one full game and return all transitions for both players. +/// +/// - `q_net` uses the **inference backend** (no autodiff wrapper). +/// - `epsilon` in `[0, 1]`: probability of taking a random action. +/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). +pub fn generate_dqn_episode>( + env: &TrictracEnv, + q_net: &Q, + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, + reward_scale: f32, +) -> Vec { + let obs_size = env.obs_size(); + let mut state = env.new_game(); + let mut pending: [Option; 2] = [None, None]; + let mut samples: Vec = Vec::new(); + + loop { + // ── Advance past chance nodes ────────────────────────────────────── + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + let score_now = TrictracEnv::score_snapshot(&state); + + if env.current_player(&state).is_terminal() { + // Complete all pending transitions as terminal. + for player_idx in 0..2 { + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: vec![0.0; obs_size], + next_legal: vec![], + done: true, + }); + } + } + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + let legal = env.legal_actions(&state); + let obs = env.observation(&state, player_idx); + + // ── Complete the previous transition for this player ─────────────── + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: obs.clone(), + next_legal: legal.clone(), + done: false, + }); + } + + // ── Pick and apply action ────────────────────────────────────────── + let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); + env.apply(&mut state, action); + + // ── Record new pending transition ────────────────────────────────── + pending[player_idx] = Some(PendingTransition { + obs, + action, + score_before: score_now, + }); + } + + samples +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::network::{QNet, QNetConfig}; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } + + fn tiny_q() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn episode_terminates_and_produces_samples() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn episode_obs_size_correct() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert_eq!(s.obs.len(), 217, "obs size mismatch"); + if s.done { + assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); + assert!(s.next_legal.is_empty()); + } else { + assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); + assert!(!s.next_legal.is_empty()); + } + } + } + + #[test] + fn episode_actions_within_action_space() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert!(s.action < 514, "action {} out of bounds", s.action); + } + } + + #[test] + fn greedy_episode_also_terminates() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty()); + } + + #[test] + fn at_least_one_done_sample() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + let n_done = samples.iter().filter(|s| s.done).count(); + // Two players, so 1 or 2 terminal transitions. + assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); + } + + #[test] + fn compute_reward_correct() { + // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. + let before = [2 * 12 + 10, 0]; + let after = [3 * 12 + 2, 0]; + let r = compute_reward(0, &before, &after); + assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); + } + + #[test] + fn compute_reward_with_opponent_scoring() { + // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. + let before = [0, 0]; + let after = [2, 3]; + let r = compute_reward(0, &before, &after); + assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); + } +} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs new file mode 100644 index 0000000..8c34fc1 --- /dev/null +++ b/spiel_bot/src/dqn/mod.rs @@ -0,0 +1,232 @@ +//! DQN: self-play data generation, replay buffer, and training step. +//! +//! # Algorithm +//! +//! Deep Q-Network with: +//! - **ε-greedy** exploration (linearly decayed). +//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where +//! `score = holes × 12 + points`. +//! - **Experience replay** with a fixed-capacity circular buffer. +//! - **Target network**: hard-copied from the online Q-net every +//! `target_update_freq` gradient steps for training stability. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | +//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | + +pub mod episode; +pub mod trainer; + +pub use episode::generate_dqn_episode; +pub use trainer::{compute_target_q, dqn_train_step, hard_update}; + +use std::collections::VecDeque; +use rand::Rng; + +// ── DqnSample ───────────────────────────────────────────────────────────────── + +/// One transition `(s, a, r, s', done)` collected during self-play. +#[derive(Clone, Debug)] +pub struct DqnSample { + /// Observation from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// Action index taken. + pub action: usize, + /// Per-turn reward: `my_score_delta − opponent_score_delta`. + pub reward: f32, + /// Next observation from the same player's perspective. + /// All-zeros when `done = true` (ignored by the TD target). + pub next_obs: Vec, + /// Legal actions at `next_obs`. Empty when `done = true`. + pub next_legal: Vec, + /// `true` when `next_obs` is a terminal state. + pub done: bool, +} + +// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── + +/// Fixed-capacity circular replay buffer for [`DqnSample`]s. +/// +/// When full, the oldest sample is evicted on push. +/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. +pub struct DqnReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl DqnReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } + } + + pub fn push(&mut self, sample: DqnSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { self.push(s); } + } + + pub fn len(&self) -> usize { self.data.len() } + pub fn is_empty(&self) -> bool { self.data.is_empty() } + + /// Sample up to `n` distinct samples without replacement. + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { + let len = self.data.len(); + let n = n.min(len); + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── DqnConfig ───────────────────────────────────────────────────────────────── + +/// Top-level DQN hyperparameters for the training loop. +#[derive(Debug, Clone)] +pub struct DqnConfig { + /// Initial exploration rate (1.0 = fully random). + pub epsilon_start: f32, + /// Final exploration rate after decay. + pub epsilon_end: f32, + /// Number of gradient steps over which ε decays linearly from start to end. + /// + /// Should be calibrated to the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). A value larger than that + /// means exploration never reaches `epsilon_end` during the run. + pub epsilon_decay_steps: usize, + /// Discount factor γ for the TD target. Typical: 0.99. + pub gamma: f32, + /// Hard-copy Q → target every this many gradient steps. + /// + /// Should be much smaller than the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). + pub target_update_freq: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Number of outer iterations (self-play + train). + pub n_iterations: usize, + /// Self-play games per iteration. + pub n_games_per_iter: usize, + /// Gradient steps per iteration. + pub n_train_steps_per_iter: usize, + /// Reward normalisation divisor. + /// + /// Per-turn rewards (score delta) are divided by this constant before being + /// stored. Without normalisation, rewards can reach ±24 (jan with + /// bredouille = 12 pts × 2), driving Q-values into the hundreds and + /// causing MSE loss to grow unboundedly. + /// + /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping + /// Q-value magnitudes in a stable range. Set to `1.0` to disable. + pub reward_scale: f32, +} + +impl Default for DqnConfig { + fn default() -> Self { + // Total gradient steps with these defaults = 500 × 20 = 10_000, + // so epsilon decays fully and the target is updated 100 times. + Self { + epsilon_start: 1.0, + epsilon_end: 0.05, + epsilon_decay_steps: 10_000, + gamma: 0.99, + target_update_freq: 100, + learning_rate: 1e-3, + batch_size: 64, + replay_capacity: 50_000, + n_iterations: 500, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + reward_scale: 12.0, + } + } +} + +/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. +pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { + if decay_steps == 0 || step >= decay_steps { + return end; + } + start + (end - start) * (step as f32 / decay_steps as f32) +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(reward: f32) -> DqnSample { + DqnSample { + obs: vec![0.0], + action: 0, + reward, + next_obs: vec![0.0], + next_legal: vec![0], + done: false, + } + } + + #[test] + fn push_and_len() { + let mut buf = DqnReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = DqnReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); + assert_eq!(buf.len(), 3); + assert_eq!(buf.data[0].reward, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = DqnReplayBuffer::new(20); + for i in 0..10 { buf.push(dummy(i as f32)); } + let mut rng = SmallRng::seed_from_u64(0); + assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); + } + + #[test] + fn linear_epsilon_start() { + assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_end() { + assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_monotone() { + let mut prev = f32::INFINITY; + for step in 0..=100 { + let e = linear_epsilon(1.0, 0.05, step, 100); + assert!(e <= prev + 1e-6); + prev = e; + } + } +} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs new file mode 100644 index 0000000..b8b0a02 --- /dev/null +++ b/spiel_bot/src/dqn/trainer.rs @@ -0,0 +1,278 @@ +//! DQN gradient step and target-network management. +//! +//! # TD target +//! +//! ```text +//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done +//! y_i = r_i if done +//! ``` +//! +//! # Loss +//! +//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net) +//! and `y_i` (computed from the frozen target net). +//! +//! # Target network +//! +//! [`hard_update`] copies the online Q-net weights into the target net by +//! stripping the autodiff wrapper via [`AutodiffModule::valid`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + Int, Tensor, TensorData, + backend::{AutodiffBackend, Backend}, + }, +}; + +use crate::network::QValueNet; +use super::DqnSample; + +// ── Target Q computation ───────────────────────────────────────────────────── + +/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample. +/// +/// Returns a `Vec` of length `batch.len()`. Done samples get `0.0` +/// (their bootstrap term is dropped by the TD target anyway). +/// +/// The target network runs on the **inference backend** (`InferB`) with no +/// gradient tape, so this function is backend-agnostic (`B: Backend`). +pub fn compute_target_q>( + target_net: &Q, + batch: &[DqnSample], + action_size: usize, + device: &B::Device, +) -> Vec { + let batch_size = batch.len(); + + // Collect indices of non-done samples (done samples have no next state). + let non_done: Vec = batch + .iter() + .enumerate() + .filter(|(_, s)| !s.done) + .map(|(i, _)| i) + .collect(); + + if non_done.is_empty() { + return vec![0.0; batch_size]; + } + + let obs_size = batch[0].next_obs.len(); + let nd = non_done.len(); + + // Stack next observations for non-done samples → [nd, obs_size]. + let obs_flat: Vec = non_done + .iter() + .flat_map(|&i| batch[i].next_obs.iter().copied()) + .collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [nd, obs_size]), + device, + ); + + // Forward target net → [nd, action_size], then to Vec. + let q_flat: Vec = target_net.forward(obs_tensor).into_data().to_vec().unwrap(); + + // For each non-done sample, pick max Q over legal next actions. + let mut result = vec![0.0f32; batch_size]; + for (k, &i) in non_done.iter().enumerate() { + let legal = &batch[i].next_legal; + let offset = k * action_size; + let max_q = legal + .iter() + .map(|&a| q_flat[offset + a]) + .fold(f32::NEG_INFINITY, f32::max); + // If legal is empty (shouldn't happen for non-done, but be safe): + result[i] = if max_q.is_finite() { max_q } else { 0.0 }; + } + result +} + +// ── Training step ───────────────────────────────────────────────────────────── + +/// Run one gradient step on `q_net` using `batch`. +/// +/// `target_max_q` must be pre-computed via [`compute_target_q`] using the +/// frozen target network and passed in here so that this function only +/// needs the **autodiff backend**. +/// +/// Returns the updated network and the scalar MSE loss. +pub fn dqn_train_step( + q_net: Q, + optimizer: &mut O, + batch: &[DqnSample], + target_max_q: &[f32], + device: &B::Device, + lr: f64, + gamma: f32, +) -> (Q, f32) +where + B: AutodiffBackend, + Q: QValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "dqn_train_step: empty batch"); + assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + + // ── Build observation tensor [B, obs_size] ──────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + + // ── Forward Q-net → [B, action_size] ───────────────────────────────── + let q_all = q_net.forward(obs_tensor); + + // ── Gather Q(s, a) for the taken action → [B] ──────────────────────── + let actions: Vec = batch.iter().map(|s| s.action as i32).collect(); + let action_tensor: Tensor = Tensor::::from_data( + TensorData::new(actions, [batch_size]), + device, + ) + .reshape([batch_size, 1]); // [B] → [B, 1] + let q_pred: Tensor = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B] + + // ── TD targets: r + γ · max_next_q · (1 − done) ────────────────────── + let targets: Vec = batch + .iter() + .zip(target_max_q.iter()) + .map(|(s, &max_q)| { + if s.done { s.reward } else { s.reward + gamma * max_q } + }) + .collect(); + let target_tensor = Tensor::::from_data( + TensorData::new(targets, [batch_size]), + device, + ); + + // ── MSE loss ────────────────────────────────────────────────────────── + let diff = q_pred - target_tensor.detach(); + let loss = (diff.clone() * diff).mean(); + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &q_net); + let q_net = optimizer.step(lr, q_net, grads); + + (q_net, loss_scalar) +} + +// ── Target network update ───────────────────────────────────────────────────── + +/// Hard-copy the online Q-net weights to a new target network. +/// +/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an +/// inference-backend module with identical weights. +pub fn hard_update>(q_net: &Q) -> Q::InnerModule { + q_net.valid() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + use crate::network::{QNet, QNetConfig}; + + type InferB = NdArray; + type TrainB = Autodiff>; + + fn infer_device() -> ::Device { Default::default() } + fn train_device() -> ::Device { Default::default() } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| DqnSample { + obs: vec![0.5f32; obs_size], + action: i % action_size, + reward: if i % 2 == 0 { 1.0 } else { -1.0 }, + next_obs: vec![0.5f32; obs_size], + next_legal: vec![0, 1], + done: i == n - 1, + }) + .collect() + } + + #[test] + fn compute_target_q_length() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 8); + } + + #[test] + fn compute_target_q_done_is_zero() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + // Single done sample. + let batch = vec![DqnSample { + obs: vec![0.0; 4], + action: 0, + reward: 5.0, + next_obs: vec![0.0; 4], + next_legal: vec![], + done: true, + }]; + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 1); + assert_eq!(tq[0], 0.0); + } + + #[test] + fn train_step_returns_finite_loss() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + + #[test] + fn train_step_loss_decreases() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(16, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, + ); + q_net = q; + assert!(loss.is_finite()); + prev_loss = loss; + } + assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn hard_update_copies_weights() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = hard_update::(&q_net); + + let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); + let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); + // After hard_update the target produces finite outputs. + assert!(q_out.iter().all(|v| v.is_finite())); + } +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs index 99ba058..8dc3676 100644 --- a/spiel_bot/src/env/trictrac.rs +++ b/spiel_bot/src/env/trictrac.rs @@ -200,6 +200,18 @@ impl GameEnv for TrictracEnv { } } +// ── DQN helpers ─────────────────────────────────────────────────────────────── + +impl TrictracEnv { + /// Score snapshot for DQN reward computation. + /// + /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. + /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). + pub fn score_snapshot(s: &GameState) -> [i32; 2] { + [s.total_score(1), s.total_score(2)] + } +} + // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 23895b9..9dfb4de 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,4 +1,5 @@ pub mod alphazero; +pub mod dqn; pub mod env; pub mod mcts; pub mod network; diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs index df710e9..64f93ec 100644 --- a/spiel_bot/src/network/mod.rs +++ b/spiel_bot/src/network/mod.rs @@ -43,9 +43,11 @@ //! before passing to softmax. pub mod mlp; +pub mod qnet; pub mod resnet; pub use mlp::{MlpConfig, MlpNet}; +pub use qnet::{QNet, QNetConfig}; pub use resnet::{ResNet, ResNetConfig}; use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; @@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; /// - `obs`: `[batch, obs_size]` /// - policy output: `[batch, action_size]` — raw logits (no softmax applied) /// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) +/// /// Note: `Sync` is intentionally absent — Burn's `Module` internally uses /// `OnceCell` for lazy parameter initialisation, which is not `Sync`. /// Use an `Arc>` wrapper if cross-thread sharing is needed. pub trait PolicyValueNet: Module + Send + 'static { fn forward(&self, obs: Tensor) -> (Tensor, Tensor); } + +/// A neural network that outputs one Q-value per action. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - output: `[batch, action_size]` — raw Q-values (no activation) +/// +/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`]. +pub trait QValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> Tensor; +} diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs new file mode 100644 index 0000000..1737f72 --- /dev/null +++ b/spiel_bot/src/network/qnet.rs @@ -0,0 +1,147 @@ +//! Single-headed Q-value network for DQN. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! → Linear(hidden → action_size) ← raw Q-values, no activation +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{activation::relu, backend::Backend, Tensor}, +}; +use std::path::Path; + +use super::QValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`QNet`]. +#[derive(Debug, Clone)] +pub struct QNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for QNetConfig { + fn default() -> Self { + Self { obs_size: 217, action_size: 514, hidden_size: 256 } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Two-hidden-layer MLP that outputs one Q-value per action. +#[derive(Module, Debug)] +pub struct QNet { + fc1: Linear, + fc2: Linear, + q_head: Linear, +} + +impl QNet { + /// Construct a fresh network with random weights. + pub fn new(config: &QNetConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl QValueNet for QNet { + fn forward(&self, obs: Tensor) -> Tensor { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + self.q_head.forward(x) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + + fn default_net() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn forward_output_shape() { + let net = default_net(); + let obs = Tensor::zeros([4, 217], &device()); + let q = net.forward(obs); + assert_eq!(q.dims(), [4, 514]); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let q = net.forward(Tensor::zeros([1, 217], &device())); + assert_eq!(q.dims(), [1, 514]); + } + + #[test] + fn q_values_not_all_equal() { + let net = default_net(); + let q: Vec = net.forward(Tensor::zeros([1, 217], &device())) + .into_data().to_vec().unwrap(); + let first = q[0]; + assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6)); + } + + #[test] + fn custom_config_shapes() { + let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 }; + let net = QNet::::new(&cfg, &device()); + let q = net.forward(Tensor::zeros([3, 10], &device())); + assert_eq!(q.dims(), [3, 20]); + } + + #[test] + fn save_load_preserves_weights() { + let net = default_net(); + let obs = Tensor::::ones([2, 217], &device()); + let q_before: Vec = net.forward(obs.clone()).into_data().to_vec().unwrap(); + + let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk"); + net.save(&path).expect("save failed"); + + let loaded = QNet::::load(&QNetConfig::default(), &path, &device()).expect("load failed"); + let q_after: Vec = loaded.forward(obs).into_data().to_vec().unwrap(); + + for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}"); + } + let _ = std::fs::remove_file(path); + } +} diff --git a/store/src/game.rs b/store/src/game.rs index 2fde45c..e4e938c 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -1011,6 +1011,16 @@ impl GameState { self.mark_points(player_id, points) } + /// Total accumulated score for a player: `holes × 12 + points`. + /// + /// Returns `0` if `player_id` is not found (e.g. before `init_player`). + pub fn total_score(&self, player_id: PlayerId) -> i32 { + self.players + .get(&player_id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + } + fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { // Update player points and holes let mut new_hole = false; From ad30d09311aa34e3c5a48b3be4ac6b89ffa7fd8b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 12 Mar 2026 21:17:14 +0100 Subject: [PATCH 15/16] feat(spiel_bot): cli spiel_bot strategy --- Cargo.lock | 2 + client_cli/Cargo.toml | 7 +- client_cli/src/app.rs | 22 ++++ client_cli/src/main.rs | 10 +- spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/strategy.rs | 242 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 281 insertions(+), 4 deletions(-) create mode 100644 spiel_bot/src/strategy.rs diff --git a/Cargo.lock b/Cargo.lock index 34bfe80..fa260cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,6 +6009,7 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", + "trictrac-bot", "trictrac-store", ] @@ -6854,6 +6855,7 @@ dependencies = [ "pico-args", "pretty_assertions", "renet", + "spiel_bot", "trictrac-bot", "trictrac-store", ] diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index e48a249..52318cb 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -3,7 +3,9 @@ name = "trictrac-client_cli" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "client_cli" +path = "src/main.rs" [dependencies] anyhow = "1.0.75" @@ -12,7 +14,8 @@ pico-args = "0.5.0" pretty_assertions = "1.4.0" renet = "0.0.13" trictrac-store = { path = "../store" } -trictrac-bot = { path = "../bot" } +trictrac-bot = { path = "../bot" } +spiel_bot = { path = "../spiel_bot" } itertools = "0.13.0" env_logger = "0.11.6" log = "0.4.20" diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index b803efe..ab61451 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,3 +1,4 @@ +use spiel_bot::strategy::{AzBotStrategy, DqnSpielBotStrategy}; use trictrac_bot::{ BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy, StableBaselines3Strategy, @@ -56,6 +57,27 @@ impl App { Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string())) as Box) } + "az" => { + Some(Box::new(AzBotStrategy::new_mlp(None)) as Box) + } + s if s.starts_with("az:") && !s.starts_with("az-") => { + let path = s.trim_start_matches("az:"); + Some(Box::new(AzBotStrategy::new_mlp(Some(path))) as Box) + } + "az-resnet" => { + Some(Box::new(AzBotStrategy::new_resnet(None)) as Box) + } + s if s.starts_with("az-resnet:") => { + let path = s.trim_start_matches("az-resnet:"); + Some(Box::new(AzBotStrategy::new_resnet(Some(path))) as Box) + } + "az-dqn" => { + Some(Box::new(DqnSpielBotStrategy::new(None)) as Box) + } + s if s.starts_with("az-dqn:") => { + let path = s.trim_start_matches("az-dqn:"); + Some(Box::new(DqnSpielBotStrategy::new(Some(path))) as Box) + } _ => None, }) .collect() diff --git a/client_cli/src/main.rs b/client_cli/src/main.rs index 0107b43..e06299b 100644 --- a/client_cli/src/main.rs +++ b/client_cli/src/main.rs @@ -23,8 +23,14 @@ OPTIONS: - dummy: Default strategy selecting the first valid move - ai: AI strategy using the default model at models/trictrac_ppo.zip - ai:/path/to/model.zip: AI strategy using a custom model - - dqn: DQN strategy using native Rust implementation with Burn - - dqn:/path/to/model: DQN strategy using a custom model + - dqnburn: DQN strategy (burn-rl backend) + - dqnburn:/path/to/model: DQN strategy (burn-rl backend) with custom model + - az: AlphaZero MlpNet (random weights) + - az:/path/to/model.mpk: AlphaZero MlpNet checkpoint + - az-resnet: AlphaZero ResNet (random weights) + - az-resnet:/path/to/model.mpk: AlphaZero ResNet checkpoint + - az-dqn: DQN QNet (random weights, first-legal-move fallback) + - az-dqn:/path/to/model.mpk: DQN QNet checkpoint ARGS: diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 3848dce..b541adc 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] trictrac-store = { path = "../store" } +trictrac-bot = { path = "../bot" } anyhow = "1" rand = "0.9" rand_distr = "0.5" diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 9dfb4de..cf6d865 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -3,3 +3,4 @@ pub mod dqn; pub mod env; pub mod mcts; pub mod network; +pub mod strategy; diff --git a/spiel_bot/src/strategy.rs b/spiel_bot/src/strategy.rs new file mode 100644 index 0000000..8309bf3 --- /dev/null +++ b/spiel_bot/src/strategy.rs @@ -0,0 +1,242 @@ +//! [`BotStrategy`] implementations backed by `spiel_bot` models. +//! +//! | Strategy struct | Network | CLI token | +//! |-----------------|---------|-----------| +//! | [`AzBotStrategy`] (mlp) | MlpNet (AlphaZero) | `az` / `az:PATH` | +//! | [`AzBotStrategy`] (resnet) | ResNet (AlphaZero) | `az-resnet` / `az-resnet:PATH` | +//! | [`DqnSpielBotStrategy`] | QNet (DQN) | `az-dqn` / `az-dqn:PATH` | +//! +//! All strategies operate from **White's perspective** (player_id = 1) internally; +//! the [`Bot`](trictrac_bot::Bot) wrapper handles board mirroring for Black. + +use std::cell::RefCell; +use std::path::Path; + +use burn::{ + backend::NdArray, + tensor::{Tensor, TensorData}, +}; +use rand::{SeedableRng, rngs::SmallRng}; +use trictrac_bot::BotStrategy; +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction}, + CheckerMove, Color, GameEvent, GameState, MoveRules, PlayerId, +}; + +use crate::{ + alphazero::BurnEvaluator, + env::{GameEnv, TrictracEnv}, + mcts::{self, Evaluator, MctsConfig}, + network::{MlpConfig, MlpNet, QNet, QNetConfig, QValueNet, ResNet, ResNetConfig}, +}; + +type B = NdArray; + +/// Default MCTS simulations per move used by [`AzBotStrategy`]. +pub const AZ_BOT_N_SIM: usize = 50; + +// ── Shared helpers ───────────────────────────────────────────────────────────── + +/// Decode an action index → `(CheckerMove, CheckerMove)` using the game state. +fn action_to_moves(action: usize, game: &GameState) -> Option<(CheckerMove, CheckerMove)> { + match TrictracAction::from_action_index(action)?.to_event(game)? { + GameEvent::Move { moves, .. } => Some(moves), + _ => None, + } +} + +/// Fallback: return the first legal move from `MoveRules` (always succeeds). +fn fallback_move(game: &GameState) -> (CheckerMove, CheckerMove) { + let rules = MoveRules::new(&Color::White, &game.board, game.dice); + let moves = rules.get_possible_moves_sequences(true, vec![]); + *moves.first().unwrap_or(&(CheckerMove::default(), CheckerMove::default())) +} + +// ── AzBotStrategy ───────────────────────────────────────────────────────────── + +/// AlphaZero bot usable as a [`BotStrategy`]. +/// +/// Supports both MlpNet and ResNet checkpoints through separate constructors. +/// Uses greedy (temperature = 0) MCTS for action selection. +/// +/// # Construction +/// +/// ```rust,ignore +/// // MlpNet with random weights +/// AzBotStrategy::new_mlp(None); +/// +/// // MlpNet from a checkpoint +/// AzBotStrategy::new_mlp(Some("checkpoints/iter_0100.mpk")); +/// +/// // ResNet from a checkpoint +/// AzBotStrategy::new_resnet(Some("checkpoints/resnet_0200.mpk")); +/// ``` +pub struct AzBotStrategy { + game: GameState, + evaluator: Box, + mcts_config: MctsConfig, + /// Interior-mutable RNG so `choose_move(&self)` can drive MCTS. + rng: RefCell, +} + +impl std::fmt::Debug for AzBotStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AzBotStrategy") + .field("n_sim", &self.mcts_config.n_simulations) + .finish() + } +} + +impl AzBotStrategy { + fn from_evaluator(evaluator: Box) -> Self { + Self { + game: GameState::default(), + evaluator, + mcts_config: MctsConfig { + n_simulations: AZ_BOT_N_SIM, + dirichlet_alpha: 0.0, // no noise during play + dirichlet_eps: 0.0, + temperature: 0.0, // greedy selection + ..MctsConfig::default() + }, + rng: RefCell::new(SmallRng::seed_from_u64(42)), + } + } + + /// MlpNet-backed bot. `path = None` → random weights. + pub fn new_mlp(path: Option<&str>) -> Self { + let device: ::Device = Default::default(); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 256 }; + let model = match path { + Some(p) => MlpNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az: load failed ({e}), using random weights"); + MlpNet::::new(&cfg, &device) + }), + None => MlpNet::::new(&cfg, &device), + }; + Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) + } + + /// ResNet-backed bot. `path = None` → random weights. + pub fn new_resnet(path: Option<&str>) -> Self { + let device: ::Device = Default::default(); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: 512 }; + let model = match path { + Some(p) => ResNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az-resnet: load failed ({e}), using random weights"); + ResNet::::new(&cfg, &device) + }), + None => ResNet::::new(&cfg, &device), + }; + Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) + } + + /// Run MCTS and return the greedy best action index, or `None` if no legal moves. + fn best_action(&self) -> Option { + let env = TrictracEnv; + if env.legal_actions(&self.game).is_empty() { + return None; + } + let mut rng = self.rng.borrow_mut(); + let root = mcts::run_mcts( + &env, + &self.game, + self.evaluator.as_ref(), + &self.mcts_config, + &mut *rng, + ); + Some(mcts::select_action(&root, 0.0, &mut *rng)) + } +} + +impl BotStrategy for AzBotStrategy { + fn get_game(&self) -> &GameState { &self.game } + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } + fn calculate_points(&self) -> u8 { self.game.dice_points.0 } + fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } + fn set_player_id(&mut self, _player_id: PlayerId) {} + fn set_color(&mut self, _color: Color) {} + + fn choose_go(&self) -> bool { + // Action index 1 == TrictracAction::Go + self.best_action().map(|a| a == 1).unwrap_or(false) + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + self.best_action() + .and_then(|a| action_to_moves(a, &self.game)) + .unwrap_or_else(|| fallback_move(&self.game)) + } +} + +// ── DqnSpielBotStrategy ─────────────────────────────────────────────────────── + +/// DQN bot (QNet from `spiel_bot`) usable as a [`BotStrategy`]. +/// +/// Selects actions by greedy argmax over Q-values, masked to legal moves. +/// When no checkpoint is provided the model falls back to the first legal move. +/// +/// # Construction +/// +/// ```rust,ignore +/// // No model — always picks first legal move +/// DqnSpielBotStrategy::new(None); +/// +/// // Trained checkpoint +/// DqnSpielBotStrategy::new(Some("checkpoints/dqn_iter_0500.mpk")); +/// ``` +#[derive(Debug)] +pub struct DqnSpielBotStrategy { + game: GameState, + model: Option>, +} + +impl DqnSpielBotStrategy { + /// Create a DQN bot. `path = None` → falls back to first legal move. + pub fn new(path: Option<&str>) -> Self { + let model = path.map(|p| { + let device: ::Device = Default::default(); + let cfg = QNetConfig::default(); + QNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az-dqn: load failed ({e}), using random weights"); + QNet::::new(&cfg, &device) + }) + }); + Self { game: GameState::default(), model } + } + + /// Greedy Q-value selection masked to legal actions, or `None` if no model / no legal moves. + fn best_action(&self) -> Option { + let model = self.model.as_ref()?; + let legal = get_valid_action_indices(&self.game).unwrap_or_default(); + if legal.is_empty() { + return None; + } + let device: ::Device = Default::default(); + let obs = self.game.to_tensor(); + let obs_t = Tensor::::from_data(TensorData::new(obs, [1, 217]), &device); + let q_vals: Vec = model.forward(obs_t).into_data().to_vec().unwrap(); + legal.into_iter().max_by(|&a, &b| { + q_vals[a].partial_cmp(&q_vals[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + } +} + +impl BotStrategy for DqnSpielBotStrategy { + fn get_game(&self) -> &GameState { &self.game } + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } + fn calculate_points(&self) -> u8 { self.game.dice_points.0 } + fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } + fn set_player_id(&mut self, _player_id: PlayerId) {} + fn set_color(&mut self, _color: Color) {} + + fn choose_go(&self) -> bool { + self.best_action().map(|a| a == 1).unwrap_or(false) + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + self.best_action() + .and_then(|a| action_to_moves(a, &self.game)) + .unwrap_or_else(|| fallback_move(&self.game)) + } +} From cf50784a2387b976f87a1eb7795537016993fde9 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 11 Mar 2026 22:17:03 +0100 Subject: [PATCH 16/16] fix: --n-sim training parameter --- spiel_bot/src/mcts/mod.rs | 8 ++++---- spiel_bot/src/mcts/search.rs | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index a0a690d..eead171 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -403,10 +403,10 @@ mod tests { let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); // root.n = 1 (expansion) + n_simulations (one backup per simulation). assert_eq!(root.n, 1 + config.n_simulations as u32); - // Children visit counts may sum to less than n_simulations when some - // simulations cross a chance node at depth 1 (turn ends after one move) - // and evaluate with the network directly without updating child.n. + // Every simulation crosses a chance node at depth 1 (dice roll after + // the player's move). Since the fix now updates child.n in that case, + // children visit counts must sum to exactly n_simulations. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert!(total <= config.n_simulations as u32); + assert_eq!(total, config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 55db701..4d36acc 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -156,7 +156,13 @@ pub(super) fn simulate( let returns = env .returns(&next_state) .expect("terminal node must have returns"); - returns[player_idx] + let v = returns[player_idx]; + // Update child stats so PUCT and mcts_policy count terminal visits. + // Store from player_idx's perspective so child.q() is directly usable + // by the parent's PUCT selection (high = good for the selecting player). + child.n += 1; + child.w += v; + v } else { let child_player = next_cp.index().unwrap(); let v = if crossed_chance { @@ -166,6 +172,13 @@ pub(super) fn simulate( // previously cached children would be for a different outcome. let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); + // Store from player_idx's (parent's) perspective so PUCT works correctly. + // `value` is from child_player's POV; negate when child is the opponent + // so that child.q() = expected return for the player CHOOSING this child. + // Without the negation, root would maximise the opponent's Q-value and + // systematically pick the worst action. + child.n += 1; + child.w += if child_player == player_idx { value } else { -value }; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player)