## 0. Context and Scope

The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
other stage to an inline random-opponent loop.

`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
for **self-play training**. Its goals:

- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
  that works with Trictrac's multi-stage turn model and stochastic dice.

### 1.1 Neural Network Frameworks

| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
|-------|----------|-----|-----------|----------|-------|
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
| ndarray alone | no | no | yes | mature | array ops only; no autograd |

**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
runtime needed, the `ndarray` backend is sufficient for CPU training, and can
be swapped for another backend by changing one type alias.

`tch-rs` would be the best choice for raw training throughput (it is the most
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
switching `spiel_bot` to `tch-rs` is a one-line backend swap.

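The "one type alias" swap pattern can be sketched with stub types standing in for Burn's backends (the names below are illustrative, not Burn's actual API; in the real crate the alias would name a Burn backend type):

```rust
// Stub backends standing in for Burn backend types such as NdArray or the
// tch-based backend. Everything downstream is generic over one alias, so
// switching backends is a one-line change.
trait Backend {
    fn name() -> &'static str;
}

struct NdArrayStub;
struct TchStub;

impl Backend for NdArrayStub {
    fn name() -> &'static str {
        "ndarray"
    }
}
impl Backend for TchStub {
    fn name() -> &'static str {
        "tch"
    }
}

/// The single line to edit when switching backends.
type TrainingBackend = NdArrayStub;

fn backend_in_use() -> &'static str {
    <TrainingBackend as Backend>::name()
}
```

All training and self-play code would take `B: Backend` generically, so only the alias line changes.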
### 1.2 Other Key Crates

| Crate | Role |
|-------|------|
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
| `anyhow` | error propagation (already used everywhere) |
| `indicatif` | training progress bars |
| `tracing` | structured logging per episode/iteration |

### 1.3 What `burn-rl` Provides (and Does Not)

The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
trait. It does **not** provide:

- MCTS or any tree-search algorithm
- Two-player self-play
- Legal action masking during training
- Chance-node handling

For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
own (simpler, more targeted) traits and implement MCTS from scratch.

---

### 2.1 Multi-Stage Turn Model

A Trictrac turn passes through up to six `TurnStage` values. Only two involve
genuine player choice:

| TurnStage | Node type | Handler |
|-----------|-----------|---------|
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
| `Move` | **Player decision** | MCTS / policy network |
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |

The environment wrapper advances through forced/chance stages automatically so
that from the algorithm's perspective every node it sees is a genuine player
decision.

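The auto-advance loop can be sketched as follows. The `TurnStage` variants mirror the table above, but the transitions are simplified stand-ins for the real rules engine, and `EnvSketch`/`advance_forced` are hypothetical names, not the crate's API:

```rust
// Hypothetical sketch of the env wrapper's forced/chance auto-advance.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum TurnStage {
    RollDice,
    RollWaiting,
    MarkPoints,
    HoldOrGoChoice,
    Move,
    MarkAdvPoints,
}

pub struct EnvSketch {
    pub stage: TurnStage,
}

impl EnvSketch {
    /// Auto-apply forced events and sample chance outcomes until the state
    /// reaches a genuine player decision (`HoldOrGoChoice` or `Move`).
    pub fn advance_forced(&mut self, mut sample_dice: impl FnMut() -> (u8, u8)) {
        loop {
            match self.stage {
                TurnStage::RollDice => self.stage = TurnStage::RollWaiting, // auto Roll
                TurnStage::RollWaiting => {
                    let _dice = sample_dice(); // chance node: sample the outcome
                    self.stage = TurnStage::MarkPoints;
                }
                TurnStage::MarkPoints => self.stage = TurnStage::Move, // auto Mark
                TurnStage::MarkAdvPoints => self.stage = TurnStage::RollDice, // auto Mark
                // Player decisions: stop and hand control to MCTS / the policy.
                TurnStage::HoldOrGoChoice | TurnStage::Move => return,
            }
        }
    }
}
```

The real wrapper would apply `GameEvent`s through the store crate instead of mutating an enum, but the control flow is the same: loop until the next node is a player decision.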
### 2.2 Stochastic Dice in MCTS

AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
introduce stochasticity. Three approaches exist:

**A. Outcome sampling (recommended)**
During each MCTS simulation, when a chance node is reached, sample one dice
outcome at random and continue. After many simulations the expected value
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
and requires no changes to the standard PUCT formula.

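For reference, the standard PUCT rule mentioned above — Q(s,a) + c_puct · P(s,a) · √N(s) / (1 + N(s,a)) — can be sketched as (function names hypothetical):

```rust
/// Standard PUCT score: Q(s,a) + c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a)).
/// `q` is the child's mean value, `prior` the network's policy prior.
fn puct_score(q: f64, prior: f64, parent_visits: u32, child_visits: u32, c_puct: f64) -> f64 {
    q + c_puct * prior * (parent_visits as f64).sqrt() / (1.0 + child_visits as f64)
}

/// Pick the child index maximizing the PUCT score.
fn select_child(qs: &[f64], priors: &[f64], visits: &[u32], parent_visits: u32, c_puct: f64) -> usize {
    (0..qs.len())
        .max_by(|&a, &b| {
            let sa = puct_score(qs[a], priors[a], parent_visits, visits[a], c_puct);
            let sb = puct_score(qs[b], priors[b], parent_visits, visits[b], c_puct);
            sa.partial_cmp(&sb).unwrap()
        })
        .expect("at least one child")
}
```

Unvisited children score purely on their prior, so the network's policy head steers early exploration.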
**B. Chance-node averaging (expectimax)**
At each chance node, expand all 21 unique dice pairs weighted by their
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
exact but multiplies the branching factor by ~21 at every dice roll, making it
prohibitively expensive.

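The 21-outcome chance distribution is easy to enumerate; the helper below is illustrative, not part of the crate:

```rust
/// Enumerate the 21 unique dice pairs (d1 <= d2) with their probabilities:
/// doublets have weight 1/36, non-doublets 2/36 (two orderings collapse).
fn dice_outcomes() -> Vec<((u8, u8), f64)> {
    let mut out = Vec::with_capacity(21);
    for d1 in 1..=6u8 {
        for d2 in d1..=6u8 {
            let p = if d1 == d2 { 1.0 / 36.0 } else { 2.0 / 36.0 };
            out.push(((d1, d2), p));
        }
    }
    out
}
```

The 6 × (1/36) + 15 × (2/36) weights sum to exactly 1, which is what makes expectimax exact — and what makes its ×21 branching unavoidable.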
**C. Condition on dice in the observation (current approach)**
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
network naturally conditions on the rolled dice when it evaluates a position.
MCTS only runs on player-decision nodes *after* the dice have been sampled;
chance nodes are bypassed by the environment wrapper (approach A). The policy
and value heads learn to play optimally given any dice pair.

**Use approach A + C together**: the environment samples dice automatically

A crucial difference from the existing `bot/` code: instead of penalizing
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
before softmax — illegal action logits are set to `-inf`. This prevents the
network from wasting capacity on illegal moves and eliminates the need for the
penalty-reward hack.

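The masking step can be sketched on plain slices (a stand-in for the equivalent Burn tensor op; assumes at least one legal action):

```rust
/// Set illegal logits to -inf, then softmax. Illegal actions get probability 0,
/// so no gradient ever flows through them.
fn masked_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
    let masked: Vec<f32> = logits
        .iter()
        .zip(legal)
        .map(|(&l, &ok)| if ok { l } else { f32::NEG_INFINITY })
        .collect();
    // Subtract the max for numerical stability; exp(-inf) is exactly 0.
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = masked.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

The same mask is reused at training time so the cross-entropy target and the network's output live on the same legal-action support.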
### 5.4 Handling Chance Nodes Inside MCTS

When simulation reaches a Chance node (dice roll), the environment automatically
samples dice and advances to the next decision node. The MCTS tree does **not**
branch on dice outcomes — it treats the sampled outcome as the state. This
corresponds to "outcome sampling" (approach A from §2.2). Because each
simulation independently samples dice, the Q-values at player nodes converge to
their expected value over many simulations.

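The convergence claim is just the law of large numbers; a tiny self-contained Monte Carlo illustrates it with a toy dice-dependent payoff (the payoff and the LCG are assumptions of this sketch, chosen so the exact expectation is 0: 1/6 · 1.0 + 5/6 · (−0.2) = 0):

```rust
/// Minimal LCG so the sketch needs no external crates.
struct Lcg(u64);

impl Lcg {
    fn next_u32(&mut self) -> u32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (self.0 >> 33) as u32
    }
    fn roll_die(&mut self) -> u8 {
        (self.next_u32() % 6 + 1) as u8
    }
}

/// Toy payoff depending on the dice, standing in for a rollout value.
fn payoff(d1: u8, d2: u8) -> f64 {
    if d1 == d2 { 1.0 } else { -0.2 }
}

/// Outcome sampling: each simulation samples one dice pair and backs up its
/// value; the running mean converges to the expectation over dice.
fn sampled_q(n_simulations: u32, rng: &mut Lcg) -> f64 {
    let mut total = 0.0;
    for _ in 0..n_simulations {
        let (d1, d2) = (rng.roll_die(), rng.roll_die());
        total += payoff(d1, d2);
    }
    total / n_simulations as f64
}
```

In real MCTS the "payoff" is the backed-up value of the subtree below the sampled roll, but the averaging mechanism at the player node is identical.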
```
L = MSE(value_pred, z) + CrossEntropy(policy_pred, π_mcts)
```

Where:

- `z` = game outcome (±1) from the active player's perspective
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
- Legal action masking is applied before computing CrossEntropy

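Computing `π_mcts` from root visit counts can be sketched as below; the temperature parameter is the same knob the training binary uses (temperature = 1 early, greedy later). Function name is illustrative; assumes at least one visit:

```rust
/// Policy target from root visit counts: pi(a) ∝ N(a)^(1/temperature).
/// temperature = 1.0 gives plain normalized counts; temperature → 0 approaches
/// a one-hot on the most-visited action (use argmax directly for greedy play).
fn policy_target(visits: &[u32], temperature: f64) -> Vec<f64> {
    let powered: Vec<f64> = visits
        .iter()
        .map(|&n| (n as f64).powf(1.0 / temperature))
        .collect();
    let sum: f64 = powered.iter().sum();
    powered.iter().map(|&p| p / sum).collect()
}
```

Visit counts are already restricted to legal actions (illegal ones are never expanded), so the target is consistent with the masked policy head.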
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
`Autodiff<NdArray<f32>>`.

For larger scale, a producer-consumer architecture (crossbeam-channel) separates

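That producer/consumer split can be sketched with `std::sync::mpsc` standing in for crossbeam-channel (crossbeam adds bounded queues and multiple consumers, but the shape is the same; `TrainSample` and `collect_samples` are hypothetical names):

```rust
use std::sync::mpsc;
use std::thread;

/// Stand-in for the real training sample (observation, MCTS policy, outcome z).
pub struct TrainSample {
    pub z: f32,
}

/// Self-play workers produce samples over a channel; the trainer consumes them
/// on the main thread, where the Autodiff backend lives.
pub fn collect_samples(n_workers: usize, games_per_worker: usize) -> Vec<TrainSample> {
    let (tx, rx) = mpsc::channel();
    for _ in 0..n_workers {
        let tx = tx.clone();
        thread::spawn(move || {
            for _ in 0..games_per_worker {
                // A real worker would play one full self-play game here and
                // send every (state, pi, z) sample from it.
                tx.send(TrainSample { z: 1.0 }).unwrap();
            }
        });
    }
    drop(tx); // close our handle so `rx` ends once all workers finish
    rx.into_iter().collect()
}
```

Each worker owns its own inference-only (`Send`) model copy, which is exactly what the note above makes possible.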
## 10. Comparison: `bot` crate vs `spiel_bot`

| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
|--------|------------------|------------------------|
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
| Opponent | hardcoded random | self-play |
| Invalid actions | penalise with reward | legal action mask (no penalty) |
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
| Burn dep | yes (0.20) | yes (0.20), same backend |
| `burn-rl` dep | yes | no |
| C++ dep | no | no |
| Python dep | no | no |
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |

The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
replace `TrictracEnvironment` with `TrictracEnv`).

## 12. Key Open Questions

1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
   AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
   scored more holes). Better option: normalize the score margin. The final
   choice affects how the value head is trained.

2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
   MCTS with 200 simulations per step means ~120k network evaluations per game.
   With batched inference this is feasible on CPU; on GPU it becomes fast.
   Consider limiting `n_simulations` to 50–100 for early training.

3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).

This is optional but reduces code duplication.

5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
   0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
   A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.

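For question 1, one way to realize the "normalize the score margin" option is a saturating map from the hole margin to [−1, 1] — an illustrative sketch, not the crate's settled choice; `scale` is a free parameter:

```rust
/// Value target in [-1, 1] from the final hole margin. `scale` controls how
/// quickly the target saturates; tanh keeps a blowout close to +1 while still
/// rewarding larger margins slightly more than narrow wins.
fn value_target(my_holes: i32, opp_holes: i32, scale: f64) -> f64 {
    (((my_holes - opp_holes) as f64) / scale).tanh()
}
```

With `scale = 6.0`, a 12-hole sweep maps to tanh(2) ≈ 0.96, a one-hole win to tanh(1/6) ≈ 0.17, keeping the value head's targets bounded as AlphaZero requires.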
## Implementation results

All benchmarks compile and run. Here's the complete results table:

| Group | Benchmark | Time |
|-------|-----------|------|
| env | apply_chance | 3.87 µs |
| | legal_actions | 1.91 µs |
| | observation (to_tensor) | 341 ns |
| | random_game (baseline) | 3.55 ms → 282 games/s |
| network | mlp_b1 hidden=64 | 94.9 µs |
| | mlp_b32 hidden=64 | 141 µs |
| | mlp_b1 hidden=256 | 352 µs |
| | mlp_b32 hidden=256 | 479 µs |
| mcts | zero_eval n=1 | 6.8 µs |
| | zero_eval n=5 | 23.9 µs |
| | zero_eval n=20 | 90.9 µs |
| | mlp64 n=1 | 203 µs |
| | mlp64 n=5 | 622 µs |
| | mlp64 n=20 | 2.30 ms |
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
| | trictrac n=2 | 145 ms → 7 games/s |
| train | mlp64 Adam b=16 | 1.93 ms |
| | mlp64 Adam b=64 | 2.68 ms |

Key observations:

- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
- observation (217-value tensor): only 341 ns — not a bottleneck
- legal_actions: 1.9 µs — well optimised
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
- Training: 2.7 ms/step at batch=64 → 370 steps/s

### Summary of Step 8

`spiel_bot/src/bin/az_eval.rs` — a self-contained evaluation binary:

- CLI flags: `--checkpoint`, `--arch mlp|resnet`, `--hidden`, `--n-games`, `--n-sim`, `--seed`, `--c-puct`
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
- Game loop: alternates `MctsAgent` as P1 / P2 against a `RandomAgent`, `n_games` per side
- `MctsAgent`: `run_mcts` + greedy `select_action` (temperature=0, no Dirichlet noise)
- Output: win/draw/loss per side + combined decisive win rate

Typical usage after training:

```
cargo run -p spiel_bot --bin az_eval --release -- \
  --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
```

### az_train

#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)

```
cargo run -p spiel_bot --bin az_train --release
```

#### ResNet, more games, custom output dir

```
cargo run -p spiel_bot --bin az_train --release -- \
  --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
  --save-every 20 --out checkpoints/
```

#### Resume from iteration 50

```
cargo run -p spiel_bot --bin az_train --release -- \
  --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
```

What the binary does each iteration:

1. Calls `model.valid()` to get a zero-overhead inference copy for self-play
2. Runs `n_games` episodes via `generate_episode` (temperature=1 for the first `--temp-drop` moves, then greedy)
3. Pushes samples into a `ReplayBuffer` (capacity `--replay-cap`)
4. Runs `n_train` gradient steps via `train_step` with cosine LR annealing from `--lr` down to `--lr-min`
5. Saves a `.mpk` checkpoint every `--save-every` iterations and always on the last

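The cosine annealing in step 4 can be sketched as follows (the binary's exact schedule may differ; this is the textbook form, mapping `--lr`/`--lr-min` to `lr_max`/`lr_min`):

```rust
use std::f64::consts::PI;

/// Cosine annealing from `lr_max` down to `lr_min` over `total` iterations:
/// lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / total)).
fn cosine_lr(iter: usize, total: usize, lr_max: f64, lr_min: f64) -> f64 {
    let t = iter as f64 / total as f64;
    lr_min + 0.5 * (lr_max - lr_min) * (1.0 + (PI * t).cos())
}
```

The schedule starts at `lr_max`, passes through the midpoint of the two rates at `total / 2`, and ends at `lr_min` on the final iteration.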