### 1.1 Neural Network Frameworks

| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
### 1.2 Other Key Crates

| Crate | Role |
| -------------------- | ----------------------------------------------------------------- |
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
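As a sketch of the parallel self-play pattern named in the table, here is a std-only equivalent of the `rayon` one-liner, using scoped threads so the example stays dependency-free; `play_game` is a hypothetical stand-in for a full self-play rollout:

```rust
use std::thread;

// Hypothetical stand-in for one full self-play game; returns a per-game result.
fn play_game(seed: u64) -> u64 {
    // tiny deterministic LCG walk standing in for real moves
    let mut s = seed;
    for _ in 0..100 {
        s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    }
    s % 2 // pretend: 0 = loss, 1 = win
}

// std-only equivalent of `(0..n_games).into_par_iter().map(play_game)`:
// split the game indices across a fixed worker pool and collect all results.
fn parallel_self_play(n_games: u64, n_workers: u64) -> Vec<u64> {
    let mut results = Vec::new();
    thread::scope(|s| {
        let handles: Vec<_> = (0..n_workers)
            .map(|w| {
                s.spawn(move || {
                    (w..n_games)
                        .step_by(n_workers as usize)
                        .map(play_game)
                        .collect::<Vec<_>>()
                })
            })
            .collect();
        for h in handles {
            results.extend(h.join().unwrap());
        }
    });
    results
}
```

With `rayon`, the worker-pool bookkeeping above collapses into the single `into_par_iter()` call from the table.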
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
genuine player choice:

| TurnStage | Node type | Handler |
| ---------------- | ------------------------------- | ------------------------------- |
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
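The handler column suggests a wrapper loop that auto-advances forced and chance stages until a genuine decision node is reached. A minimal sketch of that idea follows; the enum variants beyond those in the table, the `Env` struct, and the method names are illustrative, not the crate's actual API:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum TurnStage { RollDice, RollWaiting, MarkPoints, MovePieces }

// Illustrative environment state: current stage plus the dice once rolled.
struct Env { stage: TurnStage, dice: Option<(u8, u8)>, rng: u64 }

impl Env {
    // Tiny LCG standing in for a real dice RNG; yields 1..=6.
    fn next_die(&mut self) -> u8 {
        self.rng = self.rng.wrapping_mul(6364136223846793005).wrapping_add(1);
        ((self.rng >> 33) % 6) as u8 + 1
    }

    // Advance through forced/chance stages; stop at a player-decision node.
    fn skip_to_decision(&mut self) -> TurnStage {
        loop {
            match self.stage {
                // forced: auto-apply the roll event
                TurnStage::RollDice => self.stage = TurnStage::RollWaiting,
                // chance: sample the dice outcome here, inside the wrapper
                TurnStage::RollWaiting => {
                    self.dice = Some((self.next_die(), self.next_die()));
                    self.stage = TurnStage::MarkPoints;
                }
                // forced: score is deterministic, auto-apply the mark event
                TurnStage::MarkPoints => self.stage = TurnStage::MovePieces,
                // genuine player choice: hand control back to the agent
                TurnStage::MovePieces => return self.stage,
            }
        }
    }
}
```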
**C. Condition on dice in the observation (current approach)**

Dice values are already encoded at indices [192–193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.
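As a sketch of approach C, the dice simply occupy two fixed slots of the observation vector; the slot indices match the text, but the normalization scheme here is an assumption, not necessarily the crate's exact encoding:

```rust
const OBS_LEN: usize = 217;
const DICE_IDX: usize = 192; // dice occupy indices 192 and 193 (per the text)

// Write the rolled dice into the observation so the network conditions on
// them; normalization of 1..=6 into (0, 1] is an assumed convention.
fn encode_dice(obs: &mut [f32; OBS_LEN], dice: (u8, u8)) {
    obs[DICE_IDX] = dice.0 as f32 / 6.0;
    obs[DICE_IDX + 1] = dice.1 as f32 / 6.0;
}
```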
```
L = CrossEntropy(policy_pred, π_mcts) + MSE(value_pred, z)
```

Where:
- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy
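A framework-free sketch of the masked policy cross-entropy (plain slices instead of Burn tensors; the real implementation would operate on batched tensors): illegal logits are masked to −∞ before the log-softmax, and the target is the visit-count distribution normalized over legal actions.

```rust
// Cross-entropy between normalized MCTS visit counts and the masked policy.
fn masked_policy_loss(logits: &[f32], visits: &[u32], legal: &[bool]) -> f32 {
    // mask illegal actions before the softmax
    let masked: Vec<f32> = logits
        .iter()
        .zip(legal)
        .map(|(&l, &ok)| if ok { l } else { f32::NEG_INFINITY })
        .collect();
    // numerically stable log-softmax normalizer
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_z = masked.iter().map(|&l| (l - max).exp()).sum::<f32>().ln() + max;
    // π_mcts = visit counts normalized into a probability distribution
    let total: f32 = visits.iter().sum::<u32>() as f32;
    masked
        .iter()
        .zip(visits)
        .filter(|(l, _)| l.is_finite()) // illegal actions contribute nothing
        .map(|(&l, &n)| -(n as f32 / total) * (l - log_z))
        .sum()
}
```

For uniform logits over two legal actions with equal visit counts, the loss reduces to the entropy of the target, `ln 2`.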
## 10. Comparison: `bot` crate vs `spiel_bot`

| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
| ---------------- | --------------------------- | -------------------------------------------- |
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
   0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
   A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
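The α heuristic and the standard root-noise mixing step can be sketched as follows; the noise vector is passed in rather than sampled, since Dirichlet sampling needs an external crate such as `rand_distr`, and the function names are illustrative:

```rust
// α heuristic from the text: 10 / mean number of legal actions.
fn dirichlet_alpha(mean_legal_actions: f32) -> f32 {
    10.0 / mean_legal_actions
}

// AlphaZero root-noise mixing: P'(a) = (1 - ε)·P(a) + ε·η(a),
// applied only at the root and only over legal actions.
fn mix_root_noise(priors: &mut [f32], noise: &[f32], eps: f32) {
    for (p, &n) in priors.iter_mut().zip(noise) {
        *p = (1.0 - eps) * *p + eps * n;
    }
}
```

Because both the priors and the noise sum to 1, the mixed distribution still sums to 1, so no renormalization is needed.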
## Implementation results
All benchmarks compile and run. Here's the complete results table:
| Group | Benchmark | Time |
| ------- | ----------------------- | --------------------- |
| env | apply_chance | 3.87 µs |
| | legal_actions | 1.91 µs |
| | observation (to_tensor) | 341 ns |
| | random_game (baseline) | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64 | 94.9 µs |
| | mlp_b32 hidden=64 | 141 µs |
| | mlp_b1 hidden=256 | 352 µs |
| | mlp_b32 hidden=256 | 479 µs |
| mcts | zero_eval n=1 | 6.8 µs |
| | zero_eval n=5 | 23.9 µs |
| | zero_eval n=20 | 90.9 µs |
| | mlp64 n=1 | 203 µs |
| | mlp64 n=5 | 622 µs |
| | mlp64 n=20 | 2.30 ms |
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
| | trictrac n=2 | 145 ms → 7 games/s |
| train | mlp64 Adam b=16 | 1.93 ms |
| | mlp64 Adam b=64 | 2.68 ms |
Key observations:
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimized
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); at ~95 µs per network call this corresponds to a few hundred calls per game, which accounts for most of the time
- Training: 2.7 ms/step at batch=64 → 370 steps/s
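The games/s figures in the table are just the reciprocal of the per-game time. A minimal std-only harness in that style, with a dummy `play_random_game` standing in for the real ~600-step rollout:

```rust
use std::time::Instant;

// Dummy stand-in for a real random rollout (~600 env steps per game).
fn play_random_game(seed: u64) -> u64 {
    let mut s = seed;
    for _ in 0..600 {
        s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
    }
    s
}

// Time n games and report throughput, mirroring the table's games/s column.
fn games_per_second(n: u64) -> f64 {
    let start = Instant::now();
    let mut acc = 0u64;
    for i in 0..n {
        acc = acc.wrapping_add(play_random_game(i));
    }
    std::hint::black_box(acc); // keep the work from being optimized away
    n as f64 / start.elapsed().as_secs_f64()
}
```

The reported numbers came from proper benchmarks, which add warm-up and statistical analysis on top of this basic measure-and-divide pattern.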
### Summary of Step 8
`spiel_bot/src/bin/az_eval.rs` — a self-contained evaluation binary:
- CLI flags: `--checkpoint`, `--arch mlp|resnet`, `--hidden`, `--n-games`, `--n-sim`, `--seed`, `--c-puct`
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates `MctsAgent` as P1 / P2 against a `RandomAgent`, `n_games` per side
- `MctsAgent`: `run_mcts` + greedy `select_action` (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate
Typical usage after training:
```shell
cargo run -p spiel_bot --bin az_eval --release -- \
  --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
```
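The side-alternation and win-rate bookkeeping described above can be sketched as follows; `play_one` and the `Outcome` type are illustrative placeholders for the actual game loop, here decided deterministically from a seed so the sketch stays self-contained:

```rust
// Outcome from player 1's perspective.
#[derive(Clone, Copy, PartialEq)]
enum Outcome { P1Win, P2Win, Draw }

// Illustrative stand-in for one full game between the two agents.
fn play_one(seed: u64) -> Outcome {
    match seed % 3 {
        0 => Outcome::P1Win,
        1 => Outcome::P2Win,
        _ => Outcome::Draw,
    }
}

// Play n_games with the evaluated agent as P1, then n_games as P2,
// and report its combined decisive win rate (draws count as non-wins).
fn decisive_win_rate(n_games: u64) -> f64 {
    let mut wins = 0u64;
    for g in 0..n_games {
        if play_one(g) == Outcome::P1Win { wins += 1; } // agent plays P1
    }
    for g in 0..n_games {
        if play_one(g + n_games) == Outcome::P2Win { wins += 1; } // sides swapped
    }
    wins as f64 / (2 * n_games) as f64
}
```

Playing both sides cancels any first-player advantage, which matters in a dice game where the side to move first can have a measurable edge.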
### az_train
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
```shell
cargo run -p spiel_bot --bin az_train --release
```
#### ResNet, more games, custom output dir
```shell
cargo run -p spiel_bot --bin az_train --release -- \
  --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
  --save-every 20 --out checkpoints/
```
#### Resume from iteration 50
```shell
cargo run -p spiel_bot --bin az_train --release -- \
  --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
```
What the binary does each iteration:
1. Calls `model.valid()` to get a zero-overhead inference copy for self-play
2. Runs `n_games` episodes via `generate_episode` (temperature=1 for the first `--temp-drop` moves, then greedy)
3. Pushes samples into a `ReplayBuffer` (capacity `--replay-cap`)
4. Runs `n_train` gradient steps via `train_step` with cosine LR annealing from `--lr` down to `--lr-min`
5. Saves a `.mpk` checkpoint every `--save-every` iterations and always on the last
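The cosine annealing in step 4 is the standard schedule; a sketch where the parameter names mirror the CLI flags:

```rust
use std::f64::consts::PI;

// Cosine LR annealing from `lr` (at step 0) down to `lr_min` (at the last
// step), matching the --lr / --lr-min flags.
fn cosine_lr(step: usize, total_steps: usize, lr: f64, lr_min: f64) -> f64 {
    // progress t runs from 0.0 at the first step to 1.0 at the last
    let t = step as f64 / total_steps.saturating_sub(1).max(1) as f64;
    lr_min + 0.5 * (lr - lr_min) * (1.0 + (PI * t).cos())
}
```

The schedule starts at `lr`, decays slowly at first, fastest at the midpoint, and flattens out as it approaches `lr_min`, which tends to stabilise the final iterations.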