## 0. Context and Scope

The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
other stage to an inline random-opponent loop.

`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
for **self-play training**. Its goals:

- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
  that works with Trictrac's multi-stage turn model and stochastic dice.
### 1.1 Neural Network Frameworks

| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
| ndarray alone | no | no | yes | mature | array ops only; no autograd |

**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
runtime needed, the `ndarray` backend is sufficient for CPU training and can
be swapped for a GPU backend by changing one type alias.
`tch-rs` would be the best choice for raw training throughput (it is the most
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
switching `spiel_bot` to `tch-rs` is a one-line backend swap.

### 1.2 Other Key Crates

| Crate | Role |
| -------------------- | ----------------------------------------------------------------- |
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
| `anyhow` | error propagation (already used everywhere) |
| `indicatif` | training progress bars |
| `tracing` | structured logging per episode/iteration |

### 1.3 What `burn-rl` Provides (and Does Not)
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
provides DQN, PPO, SAC agents via the `burn_rl::base::{Environment, State, Action}`
traits. It does **not** provide:

- MCTS or any tree-search algorithm
- Two-player self-play
- Legal action masking during training
- Chance-node handling

For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
own (simpler, more targeted) traits and implement MCTS from scratch.
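To make the contrast concrete, here is a sketch of the kind of trait `spiel_bot` could define in place of `burn_rl::base::Environment`. The method names and signatures are illustrative assumptions, not a final API, and `CoinGame` is a throwaway toy whose only purpose is to show the trait is implementable:

```rust
// Illustrative sketch of a `GameEnv` trait with explicit node kinds.
// Names follow the design text; exact signatures are assumptions.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum NodeKind {
    Player(usize), // genuine decision: Move / HoldOrGoChoice
    Chance,        // dice roll, resolved by the env itself
    Terminal,      // game over
}

pub trait GameEnv {
    type Action: Copy;
    fn node_kind(&self) -> NodeKind;
    fn legal_actions(&self) -> Vec<Self::Action>;
    fn observation(&self) -> Vec<f32>; // fixed-size, active player's view
    fn apply(&mut self, a: Self::Action);
    fn returns(&self) -> f32; // ±1 at Terminal, from player 0's perspective
}

// A one-decision toy game, only to demonstrate the trait shape.
pub struct CoinGame {
    pub done: bool,
    pub picked: u8,
}

impl GameEnv for CoinGame {
    type Action = u8;
    fn node_kind(&self) -> NodeKind {
        if self.done { NodeKind::Terminal } else { NodeKind::Player(0) }
    }
    fn legal_actions(&self) -> Vec<u8> {
        vec![0, 1]
    }
    fn observation(&self) -> Vec<f32> {
        vec![if self.done { 1.0 } else { 0.0 }]
    }
    fn apply(&mut self, a: u8) {
        self.picked = a;
        self.done = true;
    }
    fn returns(&self) -> f32 {
        if self.picked == 1 { 1.0 } else { -1.0 }
    }
}

fn main() {
    let mut g = CoinGame { done: false, picked: 0 };
    g.apply(1);
    println!("terminal return = {}", g.returns());
}
```

An MCTS or training loop would then depend only on `GameEnv`, so swapping Trictrac for another game (or a test stub) needs no algorithm changes.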
---
### 2.1 Multi-Stage Turn Model

A Trictrac turn passes through up to six `TurnStage` values. Only two involve
genuine player choice:

| TurnStage | Node type | Handler |
| ---------------- | ------------------------------- | ------------------------------- |
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
| `Move` | **Player decision** | MCTS / policy network |
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |

The environment wrapper advances through forced/chance stages automatically so
that from the algorithm's perspective every node it sees is a genuine player
decision.
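The wrapper loop can be sketched as follows. `Stage` mirrors `TurnStage`, but the transitions are simplified stand-ins for the real game logic (in particular, the `HoldOrGoChoice` branch is elided from the forced path):

```rust
// Sketch of the stage-advancing loop: keep applying forced/chance
// transitions until the next genuine player decision.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Stage {
    RollDice,
    RollWaiting,
    MarkPoints,
    HoldOrGoChoice,
    Move,
    MarkAdvPoints,
    Over,
}

fn is_decision(s: Stage) -> bool {
    matches!(s, Stage::HoldOrGoChoice | Stage::Move | Stage::Over)
}

/// Advance until the next decision node, auto-handling forced/chance stages.
/// `roll` stands in for the dice sampler injected by the environment.
fn advance_forced(stage: &mut Stage, roll: &mut impl FnMut() -> (u8, u8)) {
    while !is_decision(*stage) {
        *stage = match *stage {
            Stage::RollDice => Stage::RollWaiting, // auto-apply Roll event
            Stage::RollWaiting => {
                let _dice = roll(); // chance node: sample and continue
                Stage::MarkPoints
            }
            Stage::MarkPoints => Stage::Move, // auto-apply Mark event
            Stage::MarkAdvPoints => Stage::RollDice, // auto-apply Mark event
            other => other,
        };
    }
}

fn main() {
    let mut s = Stage::RollDice;
    advance_forced(&mut s, &mut || (3, 4));
    println!("next decision stage: {:?}", s);
}
```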
### 2.2 Stochastic Dice in MCTS

AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
introduce stochasticity. Three approaches exist:

**A. Outcome sampling (recommended)**
During each MCTS simulation, when a chance node is reached, sample one dice
outcome at random and continue. After many simulations the expected value
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
and requires no changes to the standard PUCT formula.
**B. Chance-node averaging (expectimax)**
At each chance node, expand all 21 unique dice pairs weighted by their
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
exact but multiplies the branching factor by ~21 at every dice roll, making it
prohibitively expensive.
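As a quick sanity check of the weights above (21 unique pairs: 6 doublets at 1/36 and 15 non-doublets at 2/36, summing to 1):

```rust
// Enumerate the 21 unordered dice pairs with the weights from the text:
// a doublet (d, d) has probability 1/36; a non-doublet {a, b} has 2/36.
fn dice_outcomes() -> Vec<((u8, u8), f64)> {
    let mut out = Vec::new();
    for a in 1..=6u8 {
        for b in a..=6u8 {
            let p = if a == b { 1.0 / 36.0 } else { 2.0 / 36.0 };
            out.push(((a, b), p));
        }
    }
    out
}

fn main() {
    let outcomes = dice_outcomes();
    let total: f64 = outcomes.iter().map(|(_, p)| p).sum();
    println!("{} outcomes, total probability {}", outcomes.len(), total);
}
```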
**C. Condition on dice in the observation (current approach)**
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.

**Use approach A + C together**: the environment samples dice automatically
(approach A), and the network conditions on the sampled dice through the
observation (approach C).
A crucial difference from the existing `bot/` code: instead of penalizing
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
before softmax — illegal action logits are set to `-inf`. This prevents the
network from wasting capacity on illegal moves and eliminates the need for the
penalty-reward hack.
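The masking itself is a few lines. A plain-Rust sketch (the real code would apply the same idea with Burn tensor ops on batched logits):

```rust
// Set illegal logits to -inf before softmax: exp(-inf) = 0, so illegal
// actions get exactly zero probability and receive no probability mass.
fn masked_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
    let masked: Vec<f32> = logits
        .iter()
        .zip(legal)
        .map(|(&l, &ok)| if ok { l } else { f32::NEG_INFINITY })
        .collect();
    // Subtract the max for numerical stability.
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = masked.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Action 1 is illegal: its probability is exactly 0.
    let p = masked_softmax(&[1.0, 2.0, 3.0], &[true, false, true]);
    println!("{:?}", p);
}
```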
### 5.4 Handling Chance Nodes Inside MCTS

When simulation reaches a Chance node (dice roll), the environment automatically
samples dice and advances to the next decision node. The MCTS tree does **not**
branch on dice outcomes — it treats the sampled outcome as the state. This
corresponds to "outcome sampling" (approach A from §2.2). Because each
simulation independently samples dice, the Q-values at player nodes converge to
their expected value over many simulations.
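Because chance is resolved by the environment, node selection keeps the unmodified PUCT rule. A sketch with illustrative field names (`c_puct` as in `MctsConfig`; the struct layout is an assumption):

```rust
// Standard PUCT score used at player nodes. Outcome sampling means no
// chance-specific term is needed in the formula.
pub struct ChildStats {
    pub prior: f32,     // network policy prior P(s, a)
    pub visits: u32,    // N(s, a)
    pub value_sum: f32, // W(s, a); Q = W / N
}

fn puct_score(c: &ChildStats, parent_visits: u32, c_puct: f32) -> f32 {
    let q = if c.visits == 0 { 0.0 } else { c.value_sum / c.visits as f32 };
    let u = c_puct * c.prior * (parent_visits as f32).sqrt() / (1.0 + c.visits as f32);
    q + u
}

fn main() {
    // With equal priors, the unvisited child gets the larger exploration bonus.
    let unvisited = ChildStats { prior: 0.5, visits: 0, value_sum: 0.0 };
    let visited = ChildStats { prior: 0.5, visits: 10, value_sum: 2.0 };
    println!(
        "unvisited: {}, visited: {}",
        puct_score(&unvisited, 100, 1.5),
        puct_score(&visited, 100, 1.5)
    );
}
```

Selection picks the child maximizing this score; the sampled dice only change which children exist, not how they are scored.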
```
L = CrossEntropy(policy_pred, π_mcts) + MSE(value_pred, z)
```

Where:

- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy
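A scalar, single-sample sketch of this loss (the real implementation would use batched Burn tensor ops; `policy_target` and `az_loss` are illustrative names):

```rust
// Visit counts at the MCTS root become the policy target π_mcts. Illegal
// actions have zero visits, so they contribute nothing to the CE term.
fn policy_target(visits: &[u32]) -> Vec<f32> {
    let total: u32 = visits.iter().sum();
    visits.iter().map(|&v| v as f32 / total as f32).collect()
}

fn az_loss(policy_pred: &[f32], visits: &[u32], value_pred: f32, z: f32) -> f32 {
    let pi = policy_target(visits);
    // CrossEntropy(policy_pred, π_mcts) = -Σ_a π_a · ln p_a
    let ce: f32 = pi
        .iter()
        .zip(policy_pred)
        .filter(|&(&t, _)| t > 0.0)
        .map(|(&t, &p)| -t * p.ln())
        .sum();
    ce + (value_pred - z).powi(2) // + MSE(value_pred, z)
}

fn main() {
    // Uniform target over two actions, uniform prediction: CE = ln 2.
    let loss = az_loss(&[0.5, 0.5], &[10, 10], 0.0, 0.0);
    println!("loss = {loss}");
}
```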
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
`Autodiff<NdArray<f32>>`.

For larger scale, a producer-consumer architecture (crossbeam-channel) separates
self-play workers from the trainer.
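The pipeline shape can be sketched with std channels standing in for crossbeam-channel, to keep the example dependency-free; `TrainSample` and `collect_samples` are illustrative stand-ins, not real `spiel_bot` types:

```rust
use std::sync::mpsc;
use std::thread;

// Self-play workers send finished samples over a channel; the trainer
// drains it into a buffer on its own thread.
pub struct TrainSample {
    pub obs: Vec<f32>, // 217-value observation
    pub z: f32,        // game outcome target
}

pub fn collect_samples(n_workers: usize, games_each: usize) -> Vec<TrainSample> {
    let (tx, rx) = mpsc::channel();
    let workers: Vec<_> = (0..n_workers)
        .map(|_| {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..games_each {
                    // A real worker would run a self-play episode here.
                    tx.send(TrainSample { obs: vec![0.0; 217], z: 1.0 }).unwrap();
                }
            })
        })
        .collect();
    drop(tx); // rx.iter() ends once every worker has hung up
    let samples: Vec<TrainSample> = rx.iter().collect();
    for w in workers {
        w.join().unwrap();
    }
    samples
}

fn main() {
    println!("collected {} samples", collect_samples(4, 3).len()); // 12
}
```

Swapping `mpsc` for `crossbeam-channel` would add bounded capacity and multi-consumer support without changing the overall shape.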
## 10. Comparison: `bot` crate vs `spiel_bot`

| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
| ---------------- | --------------------------- | -------------------------------------------- |
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |
| Invalid actions | penalise with reward | legal action mask (no penalty) |
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
| Burn dep | yes (0.20) | yes (0.20), same backend |
| `burn-rl` dep | yes | no |
| C++ dep | no | no |
| Python dep | no | no |
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |

The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
replace `TrictracEnvironment` with `TrictracEnv`).
## 12. Key Open Questions

1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
   AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
   scored more holes). Better option: normalize the score margin. The final
   choice affects how the value head is trained.
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
   MCTS with 200 simulations per step means ~120k network evaluations per game.
   At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
   limiting `n_simulations` to 50–100 for early training.
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
   This is optional but reduces code duplication.
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
   0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
   A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
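For question 1, one hypothetical way to "normalize the score margin" is to squash the hole margin through tanh. The function name and scale constant below are assumptions for illustration, not a settled design:

```rust
// Hypothetical score-margin normalization: the value target stays in
// (-1, 1), keeps the sign of the result, and distinguishes narrow wins
// from crushing ones. `scale` is a tuning knob, not a derived value.
fn margin_to_return(my_holes: i32, opp_holes: i32, scale: f32) -> f32 {
    ((my_holes - opp_holes) as f32 / scale).tanh()
}

fn main() {
    println!("crushing win: {:.3}", margin_to_return(12, 0, 6.0));
    println!("narrow win:   {:.3}", margin_to_return(7, 6, 6.0));
}
```

The pure win/loss option is the special case of replacing `tanh` with `signum`; either way the value head's target stays bounded.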
## Implementation Results

All benchmarks compile and run. Here's the complete results table:
| Group | Benchmark | Time |
| ------- | ----------------------- | --------------------- |
| env | apply_chance | 3.87 µs |
| | legal_actions | 1.91 µs |
| | observation (to_tensor) | 341 ns |
| | random_game (baseline) | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64 | 94.9 µs |
| | mlp_b32 hidden=64 | 141 µs |
| | mlp_b1 hidden=256 | 352 µs |
| | mlp_b32 hidden=256 | 479 µs |
| mcts | zero_eval n=1 | 6.8 µs |
| | zero_eval n=5 | 23.9 µs |
| | zero_eval n=20 | 90.9 µs |
| | mlp64 n=1 | 203 µs |
| | mlp64 n=5 | 622 µs |
| | mlp64 n=20 | 2.30 ms |
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
| | trictrac n=2 | 145 ms → 7 games/s |
| train | mlp64 Adam b=16 | 1.93 ms |
| | mlp64 Adam b=64 | 2.68 ms |
Key observations:
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimised
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
- Training: 2.7 ms/step at batch=64 → 370 steps/s
### Summary of Step 8
`spiel_bot/src/bin/az_eval.rs` — a self-contained evaluation binary:

- CLI flags: `--checkpoint`, `--arch mlp|resnet`, `--hidden`, `--n-games`, `--n-sim`, `--seed`, `--c-puct`
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates `MctsAgent` as P1 / P2 against a `RandomAgent`, `n_games` per side
- `MctsAgent`: `run_mcts` + greedy `select_action` (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate
Typical usage after training:

```shell
cargo run -p spiel_bot --bin az_eval --release -- \
  --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
```
### az_train

#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)

```shell
cargo run -p spiel_bot --bin az_train --release
```

#### ResNet, more games, custom output dir

```shell
cargo run -p spiel_bot --bin az_train --release -- \
  --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
  --save-every 20 --out checkpoints/
```

#### Resume from iteration 50

```shell
cargo run -p spiel_bot --bin az_train --release -- \
  --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
```
What the binary does each iteration:

1. Calls `model.valid()` to get a zero-overhead inference copy for self-play
2. Runs `n_games` episodes via `generate_episode` (temperature=1 for the first `--temp-drop` moves, then greedy)
3. Pushes samples into a `ReplayBuffer` (capacity `--replay-cap`)
4. Runs `n_train` gradient steps via `train_step` with cosine LR annealing from `--lr` down to `--lr-min`
5. Saves a `.mpk` checkpoint every `--save-every` iterations and always on the last