Compare commits

..

12 commits

View file

@ -26,7 +26,7 @@ for **self-play training**. Its goals:
### 1.1 Neural Network Frameworks
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
|-------|----------|-----|-----------|----------|-------|
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
@ -45,7 +45,7 @@ switching `spiel_bot` to `tch-rs` is a one-line backend swap.
### 1.2 Other Key Crates
| Crate | Role |
| -------------------- | ----------------------------------------------------------------- |
|-------|------|
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
@ -78,7 +78,7 @@ A Trictrac turn passes through up to six `TurnStage` values. Only two involve
genuine player choice:
| TurnStage | Node type | Handler |
| ---------------- | ------------------------------- | ------------------------------- |
|-----------|-----------|---------|
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
@ -110,7 +110,7 @@ prohibitively expensive.
**C. Condition on dice in the observation (current approach)**
Dice values are already encoded at indices [192193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
MCTS only runs on player-decision nodes *after* the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.
@ -385,7 +385,6 @@ L = MSE(value_pred, z)
```
Where:
- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy
@ -641,7 +640,7 @@ path = "src/bin/az_eval.rs"
## 10. Comparison: `bot` crate vs `spiel_bot`
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
| ---------------- | --------------------------- | -------------------------------------------- |
|--------|-----------------|------------------------|
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |
@ -706,77 +705,3 @@ replace `TrictracEnvironment` with `TrictracEnv`).
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
## Implementation results
All benchmarks compile and run. Here's the complete results table:
| Group | Benchmark | Time |
| ------- | ----------------------- | --------------------- |
| env | apply_chance | 3.87 µs |
| | legal_actions | 1.91 µs |
| | observation (to_tensor) | 341 ns |
| | random_game (baseline) | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64 | 94.9 µs |
| | mlp_b32 hidden=64 | 141 µs |
| | mlp_b1 hidden=256 | 352 µs |
| | mlp_b32 hidden=256 | 479 µs |
| mcts | zero_eval n=1 | 6.8 µs |
| | zero_eval n=5 | 23.9 µs |
| | zero_eval n=20 | 90.9 µs |
| | mlp64 n=1 | 203 µs |
| | mlp64 n=5 | 622 µs |
| | mlp64 n=20 | 2.30 ms |
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
| | trictrac n=2 | 145 ms → 7 games/s |
| train | mlp64 Adam b=16 | 1.93 ms |
| | mlp64 Adam b=64 | 2.68 ms |
Key observations:
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimised
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
- Training: 2.7 ms/step at batch=64 → 370 steps/s
### Summary of Step 8
spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary:
- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side
- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate
Typical usage after training:
cargo run -p spiel_bot --bin az_eval --release -- \
--checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
### az_train
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
cargo run -p spiel_bot --bin az_train --release
#### ResNet, more games, custom output dir
cargo run -p spiel_bot --bin az_train --release -- \
--arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
--save-every 20 --out checkpoints/
#### Resume from iteration 50
cargo run -p spiel_bot --bin az_train --release -- \
--resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
What the binary does each iteration:
1. Calls model.valid() to get a zero-overhead inference copy for self-play
2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy)
3. Pushes samples into a ReplayBuffer (capacity --replay-cap)
4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min
5. Saves a .mpk checkpoint every --save-every iterations and always on the last