Compare commits
12 commits
ff989a7ca0
...
3c0316e1b7
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c0316e1b7 | |||
| 8146733157 | |||
| 41b3fc5dad | |||
| 7d37eebe52 | |||
| 2329b76f7e | |||
| 27e0536978 | |||
| 0235fd46c2 | |||
| 0619cf6001 | |||
| 9606e175b8 | |||
| c0d844fdf6 | |||
| 0a4f139e1d | |||
| bc94d1fb3d |
1 changed files with 135 additions and 60 deletions
|
|
@ -3,12 +3,12 @@
|
|||
## 0. Context and Scope
|
||||
|
||||
The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
|
||||
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
|
||||
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
|
||||
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
|
||||
other stage to an inline random-opponent loop.
|
||||
|
||||
`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
|
||||
for **self-play training**. Its goals:
|
||||
for **self-play training**. Its goals:
|
||||
|
||||
- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
|
||||
that works with Trictrac's multi-stage turn model and stochastic dice.
|
||||
|
|
@ -25,12 +25,12 @@ for **self-play training**. Its goals:
|
|||
|
||||
### 1.1 Neural Network Frameworks
|
||||
|
||||
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
|
||||
|-------|----------|-----|-----------|----------|-------|
|
||||
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
|
||||
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
|
||||
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
|
||||
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
|
||||
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
|
||||
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
|
||||
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
|
||||
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
|
||||
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
|
||||
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
|
||||
|
||||
**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
|
||||
runtime needed, the `ndarray` backend is sufficient for CPU training and can
|
||||
|
|
@ -39,33 +39,33 @@ changing one type alias.
|
|||
|
||||
`tch-rs` would be the best choice for raw training throughput (it is the most
|
||||
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
|
||||
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
|
||||
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
|
||||
switching `spiel_bot` to `tch-rs` is a one-line backend swap.
|
||||
|
||||
### 1.2 Other Key Crates
|
||||
|
||||
| Crate | Role |
|
||||
|-------|------|
|
||||
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
|
||||
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
|
||||
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
|
||||
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
|
||||
| `anyhow` | error propagation (already used everywhere) |
|
||||
| `indicatif` | training progress bars |
|
||||
| `tracing` | structured logging per episode/iteration |
|
||||
| Crate | Role |
|
||||
| -------------------- | ----------------------------------------------------------------- |
|
||||
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
|
||||
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
|
||||
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
|
||||
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
|
||||
| `anyhow` | error propagation (already used everywhere) |
|
||||
| `indicatif` | training progress bars |
|
||||
| `tracing` | structured logging per episode/iteration |
|
||||
|
||||
### 1.3 What `burn-rl` Provides (and Does Not)
|
||||
|
||||
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
|
||||
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
|
||||
trait. It does **not** provide:
|
||||
trait. It does **not** provide:
|
||||
|
||||
- MCTS or any tree-search algorithm
|
||||
- Two-player self-play
|
||||
- Legal action masking during training
|
||||
- Chance-node handling
|
||||
|
||||
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
|
||||
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
|
||||
own (simpler, more targeted) traits and implement MCTS from scratch.
|
||||
|
||||
---
|
||||
|
|
@ -74,17 +74,17 @@ own (simpler, more targeted) traits and implement MCTS from scratch.
|
|||
|
||||
### 2.1 Multi-Stage Turn Model
|
||||
|
||||
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
|
||||
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
|
||||
genuine player choice:
|
||||
|
||||
| TurnStage | Node type | Handler |
|
||||
|-----------|-----------|---------|
|
||||
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
|
||||
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
|
||||
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
|
||||
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
|
||||
| `Move` | **Player decision** | MCTS / policy network |
|
||||
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
|
||||
| TurnStage | Node type | Handler |
|
||||
| ---------------- | ------------------------------- | ------------------------------- |
|
||||
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
|
||||
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
|
||||
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
|
||||
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
|
||||
| `Move` | **Player decision** | MCTS / policy network |
|
||||
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
|
||||
|
||||
The environment wrapper advances through forced/chance stages automatically so
|
||||
that from the algorithm's perspective every node it sees is a genuine player
|
||||
|
|
@ -92,26 +92,26 @@ decision.
|
|||
|
||||
### 2.2 Stochastic Dice in MCTS
|
||||
|
||||
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
|
||||
introduce stochasticity. Three approaches exist:
|
||||
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
|
||||
introduce stochasticity. Three approaches exist:
|
||||
|
||||
**A. Outcome sampling (recommended)**
|
||||
During each MCTS simulation, when a chance node is reached, sample one dice
|
||||
outcome at random and continue. After many simulations the expected value
|
||||
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
|
||||
outcome at random and continue. After many simulations the expected value
|
||||
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
|
||||
and requires no changes to the standard PUCT formula.
|
||||
|
||||
**B. Chance-node averaging (expectimax)**
|
||||
At each chance node, expand all 21 unique dice pairs weighted by their
|
||||
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
|
||||
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
|
||||
exact but multiplies the branching factor by ~21 at every dice roll, making it
|
||||
prohibitively expensive.
|
||||
|
||||
**C. Condition on dice in the observation (current approach)**
|
||||
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
|
||||
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
|
||||
network naturally conditions on the rolled dice when it evaluates a position.
|
||||
MCTS only runs on player-decision nodes *after* the dice have been sampled;
|
||||
chance nodes are bypassed by the environment wrapper (approach A). The policy
|
||||
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
|
||||
chance nodes are bypassed by the environment wrapper (approach A). The policy
|
||||
and value heads learn to play optimally given any dice pair.
|
||||
|
||||
**Use approach A + C together**: the environment samples dice automatically
|
||||
|
|
@ -129,7 +129,7 @@ to an algorithm is already in the active player's perspective.
|
|||
|
||||
A crucial difference from the existing `bot/` code: instead of penalizing
|
||||
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
|
||||
before softmax — illegal action logits are set to `-inf`. This prevents the
|
||||
before softmax — illegal action logits are set to `-inf`. This prevents the
|
||||
network from wasting capacity on illegal moves and eliminates the need for the
|
||||
penalty-reward hack.
|
||||
|
||||
|
|
@ -322,9 +322,9 @@ pub struct MctsConfig {
|
|||
### 5.4 Handling Chance Nodes Inside MCTS
|
||||
|
||||
When simulation reaches a Chance node (dice roll), the environment automatically
|
||||
samples dice and advances to the next decision node. The MCTS tree does **not**
|
||||
branch on dice outcomes — it treats the sampled outcome as the state. This
|
||||
corresponds to "outcome sampling" (approach A from §2.2). Because each
|
||||
samples dice and advances to the next decision node. The MCTS tree does **not**
|
||||
branch on dice outcomes — it treats the sampled outcome as the state. This
|
||||
corresponds to "outcome sampling" (approach A from §2.2). Because each
|
||||
simulation independently samples dice, the Q-values at player nodes converge to
|
||||
their expected value over many simulations.
|
||||
|
||||
|
|
@ -385,6 +385,7 @@ L = MSE(value_pred, z)
|
|||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `z` = game outcome (±1) from the active player's perspective
|
||||
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
|
||||
- Legal action masking is applied before computing CrossEntropy
|
||||
|
|
@ -467,7 +468,7 @@ let samples: Vec<TrainSample> = (0..n_games)
|
|||
|
||||
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
|
||||
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
|
||||
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
|
||||
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
|
||||
`Autodiff<NdArray<f32>>`.
|
||||
|
||||
For larger scale, a producer-consumer architecture (crossbeam-channel) separates
|
||||
|
|
@ -639,22 +640,22 @@ path = "src/bin/az_eval.rs"
|
|||
|
||||
## 10. Comparison: `bot` crate vs `spiel_bot`
|
||||
|
||||
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
|
||||
|--------|-----------------|------------------------|
|
||||
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
|
||||
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
|
||||
| Opponent | hardcoded random | self-play |
|
||||
| Invalid actions | penalise with reward | legal action mask (no penalty) |
|
||||
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
|
||||
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
|
||||
| Burn dep | yes (0.20) | yes (0.20), same backend |
|
||||
| `burn-rl` dep | yes | no |
|
||||
| C++ dep | no | no |
|
||||
| Python dep | no | no |
|
||||
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
|
||||
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
|
||||
| ---------------- | --------------------------- | -------------------------------------------- |
|
||||
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
|
||||
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
|
||||
| Opponent | hardcoded random | self-play |
|
||||
| Invalid actions | penalise with reward | legal action mask (no penalty) |
|
||||
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
|
||||
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
|
||||
| Burn dep | yes (0.20) | yes (0.20), same backend |
|
||||
| `burn-rl` dep | yes | no |
|
||||
| C++ dep | no | no |
|
||||
| Python dep | no | no |
|
||||
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
|
||||
|
||||
The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
|
||||
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
|
||||
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
|
||||
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
|
||||
replace `TrictracEnvironment` with `TrictracEnv`).
|
||||
|
||||
|
|
@ -684,13 +685,13 @@ replace `TrictracEnvironment` with `TrictracEnv`).
|
|||
## 12. Key Open Questions
|
||||
|
||||
1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
|
||||
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
|
||||
scored more holes). Better option: normalize the score margin. The final
|
||||
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
|
||||
scored more holes). Better option: normalize the score margin. The final
|
||||
choice affects how the value head is trained.
|
||||
|
||||
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
|
||||
MCTS with 200 simulations per step means ~120k network evaluations per game.
|
||||
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
|
||||
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
|
||||
limiting `n_simulations` to 50–100 for early training.
|
||||
|
||||
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
|
||||
|
|
@ -703,5 +704,79 @@ replace `TrictracEnvironment` with `TrictracEnv`).
|
|||
This is optional but reduces code duplication.
|
||||
|
||||
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
|
||||
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
|
||||
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
|
||||
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
|
||||
|
||||
## Implementation results
|
||||
|
||||
All benchmarks compile and run. Here's the complete results table:
|
||||
|
||||
| Group | Benchmark | Time |
|
||||
| ------- | ----------------------- | --------------------- |
|
||||
| env | apply_chance | 3.87 µs |
|
||||
| | legal_actions | 1.91 µs |
|
||||
| | observation (to_tensor) | 341 ns |
|
||||
| | random_game (baseline) | 3.55 ms → 282 games/s |
|
||||
| network | mlp_b1 hidden=64 | 94.9 µs |
|
||||
| | mlp_b32 hidden=64 | 141 µs |
|
||||
| | mlp_b1 hidden=256 | 352 µs |
|
||||
| | mlp_b32 hidden=256 | 479 µs |
|
||||
| mcts | zero_eval n=1 | 6.8 µs |
|
||||
| | zero_eval n=5 | 23.9 µs |
|
||||
| | zero_eval n=20 | 90.9 µs |
|
||||
| | mlp64 n=1 | 203 µs |
|
||||
| | mlp64 n=5 | 622 µs |
|
||||
| | mlp64 n=20 | 2.30 ms |
|
||||
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
|
||||
| | trictrac n=2 | 145 ms → 7 games/s |
|
||||
| train | mlp64 Adam b=16 | 1.93 ms |
|
||||
| | mlp64 Adam b=64 | 2.68 ms |
|
||||
|
||||
Key observations:
|
||||
|
||||
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
|
||||
- observation (217-value tensor): only 341 ns — not a bottleneck
|
||||
- legal_actions: 1.9 µs — well optimised
|
||||
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
|
||||
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
|
||||
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
|
||||
- Training: 2.7 ms/step at batch=64 → 370 steps/s
|
||||
|
||||
### Summary of Step 8
|
||||
|
||||
spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary:
|
||||
|
||||
- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct
|
||||
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
|
||||
- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side
|
||||
- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise)
|
||||
- Output: win/draw/loss per side + combined decisive win rate
|
||||
|
||||
Typical usage after training:
|
||||
cargo run -p spiel_bot --bin az_eval --release -- \
|
||||
--checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
|
||||
|
||||
### az_train
|
||||
|
||||
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
|
||||
|
||||
cargo run -p spiel_bot --bin az_train --release
|
||||
|
||||
#### ResNet, more games, custom output dir
|
||||
|
||||
cargo run -p spiel_bot --bin az_train --release -- \
|
||||
--arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
|
||||
--save-every 20 --out checkpoints/
|
||||
|
||||
#### Resume from iteration 50
|
||||
|
||||
cargo run -p spiel_bot --bin az_train --release -- \
|
||||
--resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
|
||||
|
||||
What the binary does each iteration:
|
||||
|
||||
1. Calls model.valid() to get a zero-overhead inference copy for self-play
|
||||
2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy)
|
||||
3. Pushes samples into a ReplayBuffer (capacity --replay-cap)
|
||||
4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min
|
||||
5. Saves a .mpk checkpoint every --save-every iterations and always on the last
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue