diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md index 7e5ed1f..a8863af 100644 --- a/doc/spiel_bot_research.md +++ b/doc/spiel_bot_research.md @@ -3,12 +3,12 @@ ## 0. Context and Scope The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library -(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` +(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every other stage to an inline random-opponent loop. `spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency -for **self-play training**. Its goals: +for **self-play training**. Its goals: - Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") that works with Trictrac's multi-stage turn model and stochastic dice. @@ -25,12 +25,12 @@ for **self-play training**. Its goals: ### 1.1 Neural Network Frameworks -| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | -|-------|----------|-----|-----------|----------|-------| -| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | -| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | -| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | -| ndarray alone | no | no | yes | mature | array ops only; no autograd | +| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | +| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- | +| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | +| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw 
performance | +| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | +| ndarray alone | no | no | yes | mature | array ops only; no autograd | **Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ runtime needed, the `ndarray` backend is sufficient for CPU training and can @@ -39,33 +39,33 @@ changing one type alias. `tch-rs` would be the best choice for raw training throughput (it is the most battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the -pure-Rust constraint. If training speed becomes the bottleneck after prototyping, +pure-Rust constraint. If training speed becomes the bottleneck after prototyping, switching `spiel_bot` to `tch-rs` is a one-line backend swap. ### 1.2 Other Key Crates -| Crate | Role | -|-------|------| -| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | -| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | -| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | -| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | -| `anyhow` | error propagation (already used everywhere) | -| `indicatif` | training progress bars | -| `tracing` | structured logging per episode/iteration | +| Crate | Role | +| -------------------- | ----------------------------------------------------------------- | +| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | +| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | +| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | +| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | +| `anyhow` | error propagation (already used everywhere) | +| `indicatif` | training progress bars | +| `tracing` | structured logging per episode/iteration | ### 1.3 What `burn-rl` Provides (and Does Not) The external `burn-rl` crate 
(from `github.com/yunjhongwu/burn-rl-examples`) provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` -trait. It does **not** provide: +trait. It does **not** provide: - MCTS or any tree-search algorithm - Two-player self-play - Legal action masking during training - Chance-node handling -For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its +For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its own (simpler, more targeted) traits and implement MCTS from scratch. --- @@ -74,17 +74,17 @@ own (simpler, more targeted) traits and implement MCTS from scratch. ### 2.1 Multi-Stage Turn Model -A Trictrac turn passes through up to six `TurnStage` values. Only two involve +A Trictrac turn passes through up to six `TurnStage` values. Only two involve genuine player choice: -| TurnStage | Node type | Handler | -|-----------|-----------|---------| -| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | -| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | -| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | -| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | -| `Move` | **Player decision** | MCTS / policy network | -| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | +| TurnStage | Node type | Handler | +| ---------------- | ------------------------------- | ------------------------------- | +| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | +| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | +| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | +| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | +| `Move` | **Player decision** | MCTS / policy network | +| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | The environment wrapper advances through forced/chance stages automatically so that 
from the algorithm's perspective every node it sees is a genuine player @@ -92,26 +92,26 @@ decision. ### 2.2 Stochastic Dice in MCTS -AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice -introduce stochasticity. Three approaches exist: +AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice +introduce stochasticity. Three approaches exist: **A. Outcome sampling (recommended)** During each MCTS simulation, when a chance node is reached, sample one dice -outcome at random and continue. After many simulations the expected value -converges. This is the approach used by OpenSpiel's MCTS for stochastic games +outcome at random and continue. After many simulations the expected value +converges. This is the approach used by OpenSpiel's MCTS for stochastic games and requires no changes to the standard PUCT formula. **B. Chance-node averaging (expectimax)** At each chance node, expand all 21 unique dice pairs weighted by their -probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is +probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is exact but multiplies the branching factor by ~21 at every dice roll, making it prohibitively expensive. **C. Condition on dice in the observation (current approach)** -Dice values are already encoded at indices [192–193] of `to_tensor()`. The +Dice values are already encoded at indices [192–193] of `to_tensor()`. The network naturally conditions on the rolled dice when it evaluates a position. -MCTS only runs on player-decision nodes *after* the dice have been sampled; -chance nodes are bypassed by the environment wrapper (approach A). The policy +MCTS only runs on player-decision nodes _after_ the dice have been sampled; +chance nodes are bypassed by the environment wrapper (approach A). The policy and value heads learn to play optimally given any dice pair. 
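The chance-node weights quoted for approach B can be verified in a few lines. This is an illustrative sketch (`dice_outcomes` is a hypothetical helper, not part of `spiel_bot`):

```rust
// Enumerate the 21 unique (unordered) dice pairs and their probabilities.
// A doublet (a == b) has one ordered realisation, so p = 1/36; a non-doublet
// has two ((a, b) and (b, a)), so p = 2/36.
fn dice_outcomes() -> Vec<((u8, u8), f64)> {
    let mut outcomes = Vec::new();
    for a in 1..=6u8 {
        for b in a..=6 {
            let p = if a == b { 1.0 / 36.0 } else { 2.0 / 36.0 };
            outcomes.push(((a, b), p));
        }
    }
    outcomes
}

fn main() {
    let outcomes = dice_outcomes();
    assert_eq!(outcomes.len(), 21); // 6 doublets + 15 non-doublets
    let total: f64 = outcomes.iter().map(|(_, p)| p).sum();
    assert!((total - 1.0).abs() < 1e-12); // weights form a distribution
}
```

Outcome sampling (approach A) draws from this same distribution implicitly by rolling two independent dice per simulation, which is why it needs no explicit weighting.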
**Use approach A + C together**: the environment samples dice automatically @@ -129,7 +129,7 @@ to an algorithm is already in the active player's perspective. A crucial difference from the existing `bot/` code: instead of penalizing invalid actions with `ERROR_REWARD`, the policy head logits are **masked** -before softmax — illegal action logits are set to `-inf`. This prevents the +before softmax — illegal action logits are set to `-inf`. This prevents the network from wasting capacity on illegal moves and eliminates the need for the penalty-reward hack. @@ -322,9 +322,9 @@ pub struct MctsConfig { ### 5.4 Handling Chance Nodes Inside MCTS When simulation reaches a Chance node (dice roll), the environment automatically -samples dice and advances to the next decision node. The MCTS tree does **not** -branch on dice outcomes — it treats the sampled outcome as the state. This -corresponds to "outcome sampling" (approach A from §2.2). Because each +samples dice and advances to the next decision node. The MCTS tree does **not** +branch on dice outcomes — it treats the sampled outcome as the state. This +corresponds to "outcome sampling" (approach A from §2.2). Because each simulation independently samples dice, the Q-values at player nodes converge to their expected value over many simulations. @@ -385,6 +385,7 @@ L = MSE(value_pred, z) ``` Where: + - `z` = game outcome (±1) from the active player's perspective - `π_mcts` = normalized MCTS visit counts at the root (the policy target) - Legal action masking is applied before computing CrossEntropy @@ -467,7 +468,7 @@ let samples: Vec<_> = (0..n_games) Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. Self-play uses inference-only (no gradient tape), so a `NdArray` backend -(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with +(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with `Autodiff<NdArray>`.
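The fan-out/join pattern above can be sketched with scoped threads from the standard library; the actual code uses `rayon`'s `into_par_iter`, and `play_game` here is a hypothetical stand-in for a full MCTS self-play episode:

```rust
use std::thread;

// Hypothetical stand-in for one self-play episode. The real play_game would
// run MCTS with an inference-only model (no Autodiff wrapper), which is Send.
fn play_game(seed: u64) -> Vec<u64> {
    vec![seed, seed + 1] // placeholder for (state, policy, value) samples
}

// Play n_games in parallel and collect all samples. rayon's
// (0..n_games).into_par_iter().map(play_game) expresses the same pattern
// with a bounded worker pool instead of one thread per game.
fn parallel_self_play(n_games: u64) -> Vec<u64> {
    thread::scope(|s| {
        let handles: Vec<_> = (0..n_games)
            .map(|g| s.spawn(move || play_game(g)))
            .collect();
        handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect()
    })
}

fn main() {
    let samples = parallel_self_play(4);
    assert_eq!(samples.len(), 8); // 4 games x 2 samples each
}
```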
For larger scale, a producer-consumer architecture (crossbeam-channel) separates @@ -639,22 +640,22 @@ path = "src/bin/az_eval.rs" ## 10. Comparison: `bot` crate vs `spiel_bot` -| Aspect | `bot` (existing) | `spiel_bot` (proposed) | -|--------|-----------------|------------------------| -| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | -| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | -| Opponent | hardcoded random | self-play | -| Invalid actions | penalise with reward | legal action mask (no penalty) | -| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | -| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | -| Burn dep | yes (0.20) | yes (0.20), same backend | -| `burn-rl` dep | yes | no | -| C++ dep | no | no | -| Python dep | no | no | -| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | +| Aspect | `bot` (existing) | `spiel_bot` (proposed) | +| ---------------- | --------------------------- | -------------------------------------------- | +| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | +| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | +| Opponent | hardcoded random | self-play | +| Invalid actions | penalise with reward | legal action mask (no penalty) | +| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | +| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | +| Burn dep | yes (0.20) | yes (0.20), same backend | +| `burn-rl` dep | yes | no | +| C++ dep | no | no | +| Python dep | no | no | +| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | The two crates are **complementary**: `bot` is a working DQN/PPO baseline; -`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The +`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. 
The `TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just replace `TrictracEnvironment` with `TrictracEnv`). @@ -684,13 +685,13 @@ replace `TrictracEnvironment` with `TrictracEnv`). ## 12. Key Open Questions 1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. - AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever - scored more holes). Better option: normalize the score margin. The final + AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever + scored more holes). Better option: normalize the score margin. The final choice affects how the value head is trained. 2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). MCTS with 200 simulations per step means ~120k network evaluations per game. - At batch inference this is feasible on CPU; on GPU it becomes fast. Consider + At batch inference this is feasible on CPU; on GPU it becomes fast. Consider limiting `n_simulations` to 50–100 for early training. 3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). @@ -703,5 +704,79 @@ replace `TrictracEnvironment` with `TrictracEnv`). This is optional but reduces code duplication. 5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, - 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. + 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. + +## Implementation results + +All benchmarks compile and run. 
Here's the complete results table: + +| Group | Benchmark | Time | +| ------- | ----------------------- | --------------------- | +| env | apply_chance | 3.87 µs | +| | legal_actions | 1.91 µs | +| | observation (to_tensor) | 341 ns | +| | random_game (baseline) | 3.55 ms → 282 games/s | +| network | mlp_b1 hidden=64 | 94.9 µs | +| | mlp_b32 hidden=64 | 141 µs | +| | mlp_b1 hidden=256 | 352 µs | +| | mlp_b32 hidden=256 | 479 µs | +| mcts | zero_eval n=1 | 6.8 µs | +| | zero_eval n=5 | 23.9 µs | +| | zero_eval n=20 | 90.9 µs | +| | mlp64 n=1 | 203 µs | +| | mlp64 n=5 | 622 µs | +| | mlp64 n=20 | 2.30 ms | +| episode | trictrac n=1 | 51.8 ms → 19 games/s | +| | trictrac n=2 | 145 ms → 7 games/s | +| train | mlp64 Adam b=16 | 1.93 ms | +| | mlp64 Adam b=64 | 2.68 ms | + +Key observations: + +- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game) +- observation (217-value tensor): only 341 ns — not a bottleneck +- legal_actions: 1.9 µs — well optimised +- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs +- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal +- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it +- Training: 2.7 ms/step at batch=64 → 370 steps/s + +### Summary of Step 8 + +spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary: + +- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct +- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%) +- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side +- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise) +- Output: win/draw/loss per side + combined decisive win rate + +Typical usage after training: +cargo run -p spiel_bot --bin az_eval --release -- \ + --checkpoint 
checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100 ### az_train #### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10) cargo run -p spiel_bot --bin az_train --release #### ResNet, more games, custom output dir cargo run -p spiel_bot --bin az_train --release -- \ --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \ --save-every 20 --out checkpoints/ #### Resume from iteration 50 cargo run -p spiel_bot --bin az_train --release -- \ --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50 What the binary does each iteration: 1. Calls model.valid() to get a zero-overhead inference copy for self-play 2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy) 3. Pushes samples into a ReplayBuffer (capacity --replay-cap) 4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min 5. Saves a .mpk checkpoint every --save-every iterations and always on the last iteration.
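The cosine annealing in step 4 follows the standard schedule. A sketch of the formula, assuming `--lr` and `--lr-min` map to `lr_max` and `lr_min` (the function name is illustrative, not spiel_bot's API):

```rust
// Standard cosine annealing: lr(t) = lr_min + (lr_max - lr_min)/2 * (1 + cos(pi * t/T)).
// Starts at lr_max (iter = 0) and decays smoothly to lr_min (iter = n_iter).
fn cosine_lr(lr_max: f64, lr_min: f64, iter: usize, n_iter: usize) -> f64 {
    let t = iter as f64 / n_iter.max(1) as f64;
    lr_min + 0.5 * (lr_max - lr_min) * (1.0 + (std::f64::consts::PI * t).cos())
}

fn main() {
    // With --lr 1e-3 and --lr-min 1e-5 over 100 iterations:
    assert!((cosine_lr(1e-3, 1e-5, 0, 100) - 1e-3).abs() < 1e-12); // first iter: --lr
    assert!((cosine_lr(1e-3, 1e-5, 100, 100) - 1e-5).abs() < 1e-12); // last iter: --lr-min
    let mid = cosine_lr(1e-3, 1e-5, 50, 100);
    assert!((mid - (1e-3 + 1e-5) / 2.0).abs() < 1e-12); // midpoint of the two
}
```

Compared with a linear ramp, the cosine schedule spends more iterations near both endpoints, which tends to stabilise early self-play and late fine-tuning.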