Compare commits

12 commits: `ff989a7ca0` ... `3c0316e1b7`

| SHA1 |
|------|
| 3c0316e1b7 |
| 8146733157 |
| 41b3fc5dad |
| 7d37eebe52 |
| 2329b76f7e |
| 27e0536978 |
| 0235fd46c2 |
| 0619cf6001 |
| 9606e175b8 |
| c0d844fdf6 |
| 0a4f139e1d |
| bc94d1fb3d |

1 changed file with 135 additions and 60 deletions
@@ -26,7 +26,7 @@ for **self-play training**. Its goals:

### 1.1 Neural Network Frameworks

| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |

@@ -45,7 +45,7 @@ switching `spiel_bot` to `tch-rs` is a one-line backend swap.

### 1.2 Other Key Crates

| Crate | Role |
| -------------------- | ----------------------------------------------------------------- |
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |

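The producer/consumer pipeline mentioned for `crossbeam-channel` can be sketched with `std::sync::mpsc` alone (same pattern, different crate; `play_game`, `Sample`, and the worker split are hypothetical placeholders, not the real `spiel_bot` API):

```rust
use std::sync::mpsc;
use std::thread;

#[derive(Debug)]
struct Sample {
    game_id: usize,
    outcome: f32, // ±1 from the active player's perspective
}

// Placeholder for a real self-play episode.
fn play_game(game_id: usize) -> Sample {
    Sample { game_id, outcome: if game_id % 2 == 0 { 1.0 } else { -1.0 } }
}

/// Self-play workers send samples over a channel; the trainer side drains
/// them into a replay buffer. Returns the number of samples collected.
fn run_pipeline(n_games: usize) -> usize {
    let (tx, rx) = mpsc::channel::<Sample>();

    // Two worker threads, each playing half of the games.
    let workers: Vec<_> = (0..2)
        .map(|w| {
            let tx = tx.clone();
            thread::spawn(move || {
                for g in (w * n_games / 2)..((w + 1) * n_games / 2) {
                    tx.send(play_game(g)).unwrap();
                }
            })
        })
        .collect();
    drop(tx); // close the channel: workers hold their own clones

    // Trainer side: iterate until all senders are dropped.
    let buffer: Vec<Sample> = rx.iter().collect();
    for w in workers {
        w.join().unwrap();
    }
    buffer.len()
}

fn main() {
    let n = run_pipeline(8);
    assert_eq!(n, 8);
    println!("collected {n} samples");
}
```

With `crossbeam-channel` the structure is identical; its bounded channels additionally give back-pressure when self-play outpaces the trainer.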
@@ -78,7 +78,7 @@ A Trictrac turn passes through up to six `TurnStage` values. Only two involve

genuine player choice:

| TurnStage | Node type | Handler |
| ---------------- | ------------------------------- | ------------------------------- |
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |

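A minimal sketch of the wrapper behaviour implied by this table: forced and chance stages are auto-applied until a genuine decision node is reached. The `TurnStage` names come from the document; the wrapper function, the `Move` stage label, and the xorshift dice sampler are illustrative assumptions:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum TurnStage {
    RollDice,    // forced: player initiates the roll
    RollWaiting, // chance: dice outcome
    MarkPoints,  // forced: score is deterministic
    Move,        // genuine player choice: stop here
}

// Tiny xorshift stand-in for `rand` dice sampling.
fn sample_dice(rng_state: &mut u64) -> (u8, u8) {
    let mut next = || {
        *rng_state ^= *rng_state << 13;
        *rng_state ^= *rng_state >> 7;
        *rng_state ^= *rng_state << 17;
        (*rng_state % 6) as u8 + 1
    };
    (next(), next())
}

/// Advance through forced/chance stages until a decision node is reached,
/// so MCTS never sees a chance node.
fn fast_forward(mut stage: TurnStage, rng_state: &mut u64) -> TurnStage {
    loop {
        stage = match stage {
            TurnStage::RollDice => TurnStage::RollWaiting, // auto-apply Roll
            TurnStage::RollWaiting => {
                let _dice = sample_dice(rng_state);        // apply RollResult
                TurnStage::MarkPoints
            }
            TurnStage::MarkPoints => TurnStage::Move,      // auto-apply Mark
            TurnStage::Move => return stage,               // player decides
        };
    }
}

fn main() {
    let mut rng = 0x9E3779B97F4A7C15u64;
    assert_eq!(fast_forward(TurnStage::RollDice, &mut rng), TurnStage::Move);
    println!("fast-forwarded to decision node");
}
```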
@@ -110,7 +110,7 @@ prohibitively expensive.

**C. Condition on dice in the observation (current approach)**

Dice values are already encoded at indices [192–193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes *after* the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.

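A small sketch of what this dice conditioning could look like in the observation encoder. The tensor length (217) and dice indices (192–193) follow the document; the normalization scale (dividing by 6) is an assumption:

```rust
const OBS_LEN: usize = 217;
const DICE_IDX: usize = 192; // dice occupy indices 192 and 193

/// Write the rolled dice pair into the observation so the network
/// conditions on them when evaluating the position.
fn encode_dice(obs: &mut [f32; OBS_LEN], dice: (u8, u8)) {
    // Assumption: die faces 1..=6 normalized into (0, 1].
    obs[DICE_IDX] = dice.0 as f32 / 6.0;
    obs[DICE_IDX + 1] = dice.1 as f32 / 6.0;
}

fn main() {
    let mut obs = [0.0f32; OBS_LEN];
    encode_dice(&mut obs, (3, 5));
    assert!((obs[192] - 0.5).abs() < 1e-6);
    assert!((obs[193] - 5.0 / 6.0).abs() < 1e-6);
    println!("dice encoded at indices 192-193");
}
```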
@@ -385,6 +385,7 @@ L = MSE(value_pred, z)
```

Where:

- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy

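The value-MSE term shown in the hunk context is only part of the objective; per the bullets, the policy head is trained against `π_mcts` with legal actions masked before the cross-entropy. A dependency-free sketch of how the combined loss could be computed, in plain f32 math (the real implementation would use Burn tensors; the combination of the two terms is an assumption from the bullet list):

```rust
/// Log-softmax with illegal actions masked to -inf before normalization.
fn masked_log_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
    let masked: Vec<f32> = logits
        .iter()
        .zip(legal)
        .map(|(&l, &ok)| if ok { l } else { f32::NEG_INFINITY })
        .collect();
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_z = masked.iter().map(|&l| (l - max).exp()).sum::<f32>().ln() + max;
    masked.iter().map(|&l| l - log_z).collect()
}

/// L = (value_pred - z)^2 - sum_a pi_mcts(a) * log p(a)
fn az_loss(value_pred: f32, z: f32, logits: &[f32], pi_mcts: &[f32], legal: &[bool]) -> f32 {
    let log_p = masked_log_softmax(logits, legal);
    let ce: f32 = pi_mcts
        .iter()
        .zip(&log_p)
        .filter(|&(&pi, _)| pi > 0.0) // illegal actions have pi = 0
        .map(|(&pi, &lp)| -pi * lp)
        .sum();
    (value_pred - z).powi(2) + ce
}

fn main() {
    let logits = [1.0f32, 0.0, -1.0];
    let legal = [true, true, false]; // action 2 is illegal
    let pi = [0.7f32, 0.3, 0.0];     // normalized MCTS visit counts
    let loss = az_loss(0.2, 1.0, &logits, &pi, &legal);
    assert!(loss > 0.0);
    println!("loss = {loss:.4}");
}
```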
@@ -640,7 +641,7 @@ path = "src/bin/az_eval.rs"

## 10. Comparison: `bot` crate vs `spiel_bot`

| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
| ---------------- | --------------------------- | -------------------------------------------- |
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |

@@ -705,3 +706,77 @@ replace `TrictracEnvironment` with `TrictracEnv`).

5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
   0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
   A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.

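Once α is chosen, the noise is mixed into the root priors the standard AlphaZero way: P'(a) = (1 − ε)·P(a) + ε·η(a) over the legal actions, where η ~ Dirichlet(α). A sketch of the mixing step (ε = 0.25 is the standard AlphaZero value, an assumption here; η is passed in pre-sampled to keep the sketch dependency-free, whereas `spiel_bot` would draw it from Dirichlet(α)):

```rust
/// Mix Dirichlet exploration noise into the root priors in place:
/// p'(a) = (1 - eps) * p(a) + eps * eta(a).
fn mix_dirichlet_noise(priors: &mut [f32], eta: &[f32], eps: f32) {
    for (p, &n) in priors.iter_mut().zip(eta) {
        *p = (1.0 - eps) * *p + eps * n;
    }
}

fn main() {
    // Priors over 4 legal root actions; eta is a (pre-sampled) Dirichlet draw.
    let mut priors = [0.4f32, 0.3, 0.2, 0.1];
    let eta = [0.05f32, 0.05, 0.8, 0.1];
    mix_dirichlet_noise(&mut priors, &eta, 0.25);
    // A convex combination of two distributions is still a distribution.
    let sum: f32 = priors.iter().sum();
    assert!((sum - 1.0).abs() < 1e-5);
    println!("mixed priors: {priors:?}");
}
```

Because both P and η sum to 1, the mixed priors remain a valid distribution for any ε in [0, 1].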
## Implementation results

All benchmarks compile and run. Here's the complete results table:

| Group   | Benchmark               | Time                  |
| ------- | ----------------------- | --------------------- |
| env     | apply_chance            | 3.87 µs               |
|         | legal_actions           | 1.91 µs               |
|         | observation (to_tensor) | 341 ns                |
|         | random_game (baseline)  | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64        | 94.9 µs               |
|         | mlp_b32 hidden=64       | 141 µs                |
|         | mlp_b1 hidden=256       | 352 µs                |
|         | mlp_b32 hidden=256      | 479 µs                |
| mcts    | zero_eval n=1           | 6.8 µs                |
|         | zero_eval n=5           | 23.9 µs               |
|         | zero_eval n=20          | 90.9 µs               |
|         | mlp64 n=1               | 203 µs                |
|         | mlp64 n=5               | 622 µs                |
|         | mlp64 n=20              | 2.30 ms               |
| episode | trictrac n=1            | 51.8 ms → 19 games/s  |
|         | trictrac n=2            | 145 ms → 7 games/s    |
| train   | mlp64 Adam b=16         | 1.93 ms               |
|         | mlp64 Adam b=64         | 2.68 ms               |

Key observations:

- random_game baseline: 282 games/s (short of the ≥ 500 target — game-state ops dominate at 3.9 µs/apply_chance over ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimised
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each decision step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); at ~200 µs per decision step, that implies ≈ 260 network-backed decisions per game, so network inference accounts for most of the time
- Training: 2.7 ms/step at batch=64 → 370 steps/s

### Summary of Step 8

`spiel_bot/src/bin/az_eval.rs` — a self-contained evaluation binary:

- CLI flags: `--checkpoint`, `--arch mlp|resnet`, `--hidden`, `--n-games`, `--n-sim`, `--seed`, `--c-puct`
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates `MctsAgent` as P1 / P2 against a `RandomAgent`, `n_games` per side
- `MctsAgent`: `run_mcts` + greedy `select_action` (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate

Typical usage after training:

```
cargo run -p spiel_bot --bin az_eval --release -- \
  --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
```

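The side-alternating match loop can be sketched with toy stand-ins. Everything here is hypothetical simplification: a trivial `Agent` trait, two deterministic dummy agents in place of the real `MctsAgent`/`RandomAgent`, and a one-move toy game instead of Trictrac; only the shape of the loop (n_games per side, decisive-win tally) mirrors the binary:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Outcome { P1Win, P2Win, Draw }

trait Agent {
    fn pick(&mut self, legal: &[usize]) -> usize;
}

struct FirstMoveAgent; // stand-in for the greedy MctsAgent
impl Agent for FirstMoveAgent {
    fn pick(&mut self, legal: &[usize]) -> usize { legal[0] }
}

struct LastMoveAgent; // stand-in for RandomAgent
impl Agent for LastMoveAgent {
    fn pick(&mut self, legal: &[usize]) -> usize { *legal.last().unwrap() }
}

// Toy one-move game: whoever picks action 0 first wins.
fn play_game(p1: &mut dyn Agent, p2: &mut dyn Agent) -> Outcome {
    let legal = [0usize, 1, 2];
    let a1 = p1.pick(&legal);
    let a2 = p2.pick(&legal);
    if a1 == 0 { Outcome::P1Win } else if a2 == 0 { Outcome::P2Win } else { Outcome::Draw }
}

fn main() {
    let n_games = 10;
    let mut wins = 0;
    // Evaluated agent as P1 for n_games, then as P2 for n_games.
    for _ in 0..n_games {
        if play_game(&mut FirstMoveAgent, &mut LastMoveAgent) == Outcome::P1Win { wins += 1; }
    }
    for _ in 0..n_games {
        if play_game(&mut LastMoveAgent, &mut FirstMoveAgent) == Outcome::P2Win { wins += 1; }
    }
    println!("decisive wins: {wins}/{}", 2 * n_games);
}
```

Swapping sides like this cancels any first-player advantage, which matters in Trictrac where the roller of the opening dice is not symmetric.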
### az_train

#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)

```
cargo run -p spiel_bot --bin az_train --release
```

#### ResNet, more games, custom output dir

```
cargo run -p spiel_bot --bin az_train --release -- \
  --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
  --save-every 20 --out checkpoints/
```

#### Resume from iteration 50

```
cargo run -p spiel_bot --bin az_train --release -- \
  --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
```

What the binary does each iteration:

1. Calls `model.valid()` to get a zero-overhead inference copy for self-play
2. Runs `n_games` episodes via `generate_episode` (temperature=1 for the first `--temp-drop` moves, then greedy)
3. Pushes samples into a `ReplayBuffer` (capacity `--replay-cap`)
4. Runs `n_train` gradient steps via `train_step` with cosine LR annealing from `--lr` down to `--lr-min`
5. Saves a `.mpk` checkpoint every `--save-every` iterations and always on the last
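The cosine LR annealing in step 4 can be sketched as a pure function. The exact schedule in `az_train` is not shown here, so the formula below is an assumption (the standard cosine decay from `--lr` to `--lr-min` over the run):

```rust
/// Cosine annealing: lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T)).
/// Starts at lr_max (step 0) and decays smoothly to lr_min (step T).
fn cosine_lr(lr_max: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
    let t = step as f64 / total_steps as f64;
    lr_min + 0.5 * (lr_max - lr_min) * (1.0 + (std::f64::consts::PI * t).cos())
}

fn main() {
    let (lr, lr_min, total) = (1e-3, 1e-5, 100);
    assert!((cosine_lr(lr, lr_min, 0, total) - lr).abs() < 1e-12);         // starts at --lr
    assert!((cosine_lr(lr, lr_min, total, total) - lr_min).abs() < 1e-12); // ends at --lr-min
    let mid = cosine_lr(lr, lr_min, total / 2, total);
    assert!(mid > lr_min && mid < lr); // halfway point lies strictly between
    println!("lr at mid-run = {mid:.6}");
}
```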