Compare commits
6 commits
db5c1ea4f4
...
eadc101741
| Author | SHA1 | Date | |
|---|---|---|---|
| eadc101741 | |||
| baa47e996d | |||
| 7ba4b9bbf3 | |||
| a31d2c1f30 | |||
| e8d7e7b09d | |||
| 566cb6b476 |
18 changed files with 3274 additions and 5 deletions
11
Cargo.lock
generated
11
Cargo.lock
generated
|
|
@ -5891,6 +5891,17 @@ dependencies = [
|
||||||
"windows-sys 0.60.2",
|
"windows-sys 0.60.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "spiel_bot"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"burn",
|
||||||
|
"rand 0.9.2",
|
||||||
|
"rand_distr",
|
||||||
|
"trictrac-store",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spin"
|
name = "spin"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
members = ["client_cli", "bot", "store"]
|
members = ["client_cli", "bot", "store", "spiel_bot"]
|
||||||
|
|
|
||||||
707
doc/spiel_bot_research.md
Normal file
707
doc/spiel_bot_research.md
Normal file
|
|
@ -0,0 +1,707 @@
|
||||||
|
# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac
|
||||||
|
|
||||||
|
## 0. Context and Scope
|
||||||
|
|
||||||
|
The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
|
||||||
|
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
|
||||||
|
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
|
||||||
|
other stage to an inline random-opponent loop.
|
||||||
|
|
||||||
|
`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
|
||||||
|
for **self-play training**. Its goals:
|
||||||
|
|
||||||
|
- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
|
||||||
|
that works with Trictrac's multi-stage turn model and stochastic dice.
|
||||||
|
- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer)
|
||||||
|
as the first algorithm.
|
||||||
|
- Remain **modular**: adding DQN or PPO later requires only a new
|
||||||
|
`impl Algorithm for Dqn` without touching the environment or network layers.
|
||||||
|
- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from
|
||||||
|
`trictrac-store`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Library Landscape
|
||||||
|
|
||||||
|
### 1.1 Neural Network Frameworks
|
||||||
|
|
||||||
|
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
|
||||||
|
|-------|----------|-----|-----------|----------|-------|
|
||||||
|
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
|
||||||
|
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
|
||||||
|
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
|
||||||
|
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
|
||||||
|
|
||||||
|
**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
|
||||||
|
runtime needed, the `ndarray` backend is sufficient for CPU training and can
|
||||||
|
switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by
|
||||||
|
changing one type alias.
|
||||||
|
|
||||||
|
`tch-rs` would be the best choice for raw training throughput (it is the most
|
||||||
|
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
|
||||||
|
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
|
||||||
|
switching `spiel_bot` to `tch-rs` is a one-line backend swap.
|
||||||
|
|
||||||
|
### 1.2 Other Key Crates
|
||||||
|
|
||||||
|
| Crate | Role |
|
||||||
|
|-------|------|
|
||||||
|
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
|
||||||
|
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
|
||||||
|
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
|
||||||
|
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
|
||||||
|
| `anyhow` | error propagation (already used everywhere) |
|
||||||
|
| `indicatif` | training progress bars |
|
||||||
|
| `tracing` | structured logging per episode/iteration |
|
||||||
|
|
||||||
|
### 1.3 What `burn-rl` Provides (and Does Not)
|
||||||
|
|
||||||
|
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
|
||||||
|
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
|
||||||
|
trait. It does **not** provide:
|
||||||
|
|
||||||
|
- MCTS or any tree-search algorithm
|
||||||
|
- Two-player self-play
|
||||||
|
- Legal action masking during training
|
||||||
|
- Chance-node handling
|
||||||
|
|
||||||
|
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
|
||||||
|
own (simpler, more targeted) traits and implement MCTS from scratch.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Trictrac-Specific Design Constraints
|
||||||
|
|
||||||
|
### 2.1 Multi-Stage Turn Model
|
||||||
|
|
||||||
|
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
|
||||||
|
genuine player choice:
|
||||||
|
|
||||||
|
| TurnStage | Node type | Handler |
|
||||||
|
|-----------|-----------|---------|
|
||||||
|
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
|
||||||
|
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
|
||||||
|
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
|
||||||
|
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
|
||||||
|
| `Move` | **Player decision** | MCTS / policy network |
|
||||||
|
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
|
||||||
|
|
||||||
|
The environment wrapper advances through forced/chance stages automatically so
|
||||||
|
that from the algorithm's perspective every node it sees is a genuine player
|
||||||
|
decision.
|
||||||
|
|
||||||
|
### 2.2 Stochastic Dice in MCTS
|
||||||
|
|
||||||
|
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
|
||||||
|
introduce stochasticity. Three approaches exist:
|
||||||
|
|
||||||
|
**A. Outcome sampling (recommended)**
|
||||||
|
During each MCTS simulation, when a chance node is reached, sample one dice
|
||||||
|
outcome at random and continue. After many simulations the expected value
|
||||||
|
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
|
||||||
|
and requires no changes to the standard PUCT formula.
|
||||||
|
|
||||||
|
**B. Chance-node averaging (expectimax)**
|
||||||
|
At each chance node, expand all 21 unique dice pairs weighted by their
|
||||||
|
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
|
||||||
|
exact but multiplies the branching factor by ~21 at every dice roll, making it
|
||||||
|
prohibitively expensive.
|
||||||
|
|
||||||
|
**C. Condition on dice in the observation (current approach)**
|
||||||
|
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
|
||||||
|
network naturally conditions on the rolled dice when it evaluates a position.
|
||||||
|
MCTS only runs on player-decision nodes *after* the dice have been sampled;
|
||||||
|
chance nodes are bypassed by the environment wrapper (approach A). The policy
|
||||||
|
and value heads learn to play optimally given any dice pair.
|
||||||
|
|
||||||
|
**Use approach A + C together**: the environment samples dice automatically
|
||||||
|
(chance node bypass), and the 217-dim tensor encodes the dice so the network
|
||||||
|
can exploit them.
|
||||||
|
|
||||||
|
### 2.3 Perspective / Mirroring
|
||||||
|
|
||||||
|
All move rules and tensor encoding are defined from White's perspective.
|
||||||
|
`to_tensor()` must always be called after mirroring the state for Black.
|
||||||
|
The environment wrapper handles this transparently: every observation returned
|
||||||
|
to an algorithm is already in the active player's perspective.
|
||||||
|
|
||||||
|
### 2.4 Legal Action Masking
|
||||||
|
|
||||||
|
A crucial difference from the existing `bot/` code: instead of penalizing
|
||||||
|
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
|
||||||
|
before softmax — illegal action logits are set to `-inf`. This prevents the
|
||||||
|
network from wasting capacity on illegal moves and eliminates the need for the
|
||||||
|
penalty-reward hack.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Proposed Crate Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
spiel_bot/
|
||||||
|
├── Cargo.toml
|
||||||
|
└── src/
|
||||||
|
├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo"
|
||||||
|
│
|
||||||
|
├── env/
|
||||||
|
│ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface
|
||||||
|
│ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store
|
||||||
|
│
|
||||||
|
├── mcts/
|
||||||
|
│ ├── mod.rs # MctsConfig, run_mcts() entry point
|
||||||
|
│ ├── node.rs # MctsNode (visit count, W, prior, children)
|
||||||
|
│ └── search.rs # simulate(), backup(), select_action()
|
||||||
|
│
|
||||||
|
├── network/
|
||||||
|
│ ├── mod.rs # PolicyValueNet trait
|
||||||
|
│ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads
|
||||||
|
│
|
||||||
|
├── alphazero/
|
||||||
|
│ ├── mod.rs # AlphaZeroConfig
|
||||||
|
│ ├── selfplay.rs # generate_episode() -> Vec<TrainSample>
|
||||||
|
│ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle)
|
||||||
|
│ └── trainer.rs # training loop: selfplay → sample → loss → update
|
||||||
|
│
|
||||||
|
└── agent/
|
||||||
|
├── mod.rs # Agent trait
|
||||||
|
├── random.rs # RandomAgent (baseline)
|
||||||
|
└── mcts_agent.rs # MctsAgent: uses trained network for inference
|
||||||
|
```
|
||||||
|
|
||||||
|
Future algorithms slot in without touching the above:
|
||||||
|
|
||||||
|
```
|
||||||
|
├── dqn/ # (future) DQN: impl Algorithm + own replay buffer
|
||||||
|
└── ppo/ # (future) PPO: impl Algorithm + rollout buffer
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Core Traits
|
||||||
|
|
||||||
|
### 4.1 `GameEnv` — the minimal OpenSpiel interface
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
/// Who controls the current node.
|
||||||
|
pub enum Player {
|
||||||
|
P1, // player index 0
|
||||||
|
P2, // player index 1
|
||||||
|
Chance, // dice roll
|
||||||
|
Terminal, // game over
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait GameEnv: Clone + Send + Sync + 'static {
|
||||||
|
type State: Clone + Send + Sync;
|
||||||
|
|
||||||
|
/// Fresh game state.
|
||||||
|
fn new_game(&self) -> Self::State;
|
||||||
|
|
||||||
|
/// Who acts at this node.
|
||||||
|
fn current_player(&self, s: &Self::State) -> Player;
|
||||||
|
|
||||||
|
/// Legal action indices (always in [0, action_space())).
|
||||||
|
/// Empty only at Terminal nodes.
|
||||||
|
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
|
||||||
|
|
||||||
|
/// Apply a player action (must be legal).
|
||||||
|
fn apply(&self, s: &mut Self::State, action: usize);
|
||||||
|
|
||||||
|
/// Advance a Chance node by sampling dice; no-op at non-Chance nodes.
|
||||||
|
fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng);
|
||||||
|
|
||||||
|
/// Observation tensor from `pov`'s perspective (0 or 1).
|
||||||
|
/// Returns 217 f32 values for Trictrac.
|
||||||
|
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
|
||||||
|
|
||||||
|
/// Flat observation size (217 for Trictrac).
|
||||||
|
fn obs_size(&self) -> usize;
|
||||||
|
|
||||||
|
/// Total action-space size (514 for Trictrac).
|
||||||
|
fn action_space(&self) -> usize;
|
||||||
|
|
||||||
|
/// Game outcome per player, or None if not Terminal.
|
||||||
|
/// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw.
|
||||||
|
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 `PolicyValueNet` — neural network interface
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
pub trait PolicyValueNet<B: Backend>: Send + Sync {
|
||||||
|
/// Forward pass.
|
||||||
|
/// `obs`: [batch, obs_size] tensor.
|
||||||
|
/// Returns: (policy_logits [batch, action_space], value [batch]).
|
||||||
|
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
|
||||||
|
|
||||||
|
/// Save weights to `path`.
|
||||||
|
fn save(&self, path: &std::path::Path) -> anyhow::Result<()>;
|
||||||
|
|
||||||
|
/// Load weights from `path`.
|
||||||
|
fn load(path: &std::path::Path) -> anyhow::Result<Self>
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 `Agent` — player policy interface
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub trait Agent: Send {
|
||||||
|
/// Select an action index given the current game state observation.
|
||||||
|
/// `legal`: mask of valid action indices.
|
||||||
|
fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. MCTS Implementation
|
||||||
|
|
||||||
|
### 5.1 Node
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct MctsNode {
|
||||||
|
n: u32, // visit count N(s, a)
|
||||||
|
w: f32, // sum of backed-up values W(s, a)
|
||||||
|
p: f32, // prior from policy head P(s, a)
|
||||||
|
children: Vec<(usize, MctsNode)>, // (action_idx, child)
|
||||||
|
is_expanded: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MctsNode {
|
||||||
|
pub fn q(&self) -> f32 {
|
||||||
|
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PUCT score used for selection.
|
||||||
|
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
|
||||||
|
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 Simulation Loop
|
||||||
|
|
||||||
|
One MCTS simulation (for deterministic decision nodes):
|
||||||
|
|
||||||
|
```
|
||||||
|
1. SELECTION — traverse from root, always pick child with highest PUCT,
|
||||||
|
auto-advancing forced/chance nodes via env.apply_chance().
|
||||||
|
2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get
|
||||||
|
(policy_logits, value). Mask illegal actions, softmax
|
||||||
|
the remaining logits → priors P(s,a) for each child.
|
||||||
|
3. BACKUP — propagate -value up the tree (negate at each level because
|
||||||
|
perspective alternates between P1 and P2).
|
||||||
|
```
|
||||||
|
|
||||||
|
After `n_simulations` iterations, action selection at the root:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// During training: sample proportional to N^(1/temperature)
|
||||||
|
// During evaluation: argmax N
|
||||||
|
fn select_action(root: &MctsNode, temperature: f32) -> usize { ... }
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 Configuration
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct MctsConfig {
|
||||||
|
pub n_simulations: usize, // e.g. 200
|
||||||
|
pub c_puct: f32, // exploration constant, e.g. 1.5
|
||||||
|
pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3
|
||||||
|
pub dirichlet_eps: f32, // noise weight, e.g. 0.25
|
||||||
|
pub temperature: f32, // action sampling temperature (anneals to 0)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 Handling Chance Nodes Inside MCTS
|
||||||
|
|
||||||
|
When simulation reaches a Chance node (dice roll), the environment automatically
|
||||||
|
samples dice and advances to the next decision node. The MCTS tree does **not**
|
||||||
|
branch on dice outcomes — it treats the sampled outcome as the state. This
|
||||||
|
corresponds to "outcome sampling" (approach A from §2.2). Because each
|
||||||
|
simulation independently samples dice, the Q-values at player nodes converge to
|
||||||
|
their expected value over many simulations.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Network Architecture
|
||||||
|
|
||||||
|
### 6.1 ResNet Policy-Value Network
|
||||||
|
|
||||||
|
A single trunk with residual blocks, then two heads:
|
||||||
|
|
||||||
|
```
|
||||||
|
Input: [batch, 217]
|
||||||
|
↓
|
||||||
|
Linear(217 → 512) + ReLU
|
||||||
|
↓
|
||||||
|
ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU)
|
||||||
|
↓ trunk output [batch, 512]
|
||||||
|
├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site)
|
||||||
|
└─ Value head: Linear(512 → 1) → tanh (output in [-1, 1])
|
||||||
|
```
|
||||||
|
|
||||||
|
Burn implementation sketch:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct TrictracNet<B: Backend> {
|
||||||
|
input: Linear<B>,
|
||||||
|
res_blocks: Vec<ResBlock<B>>,
|
||||||
|
policy_head: Linear<B>,
|
||||||
|
value_head: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> TrictracNet<B> {
|
||||||
|
pub fn forward(&self, obs: Tensor<B, 2>)
|
||||||
|
-> (Tensor<B, 2>, Tensor<B, 1>)
|
||||||
|
{
|
||||||
|
let x = activation::relu(self.input.forward(obs));
|
||||||
|
let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x));
|
||||||
|
let policy = self.policy_head.forward(x.clone()); // raw logits
|
||||||
|
let value = activation::tanh(self.value_head.forward(x))
|
||||||
|
.squeeze(1);
|
||||||
|
(policy, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
A simpler MLP (no residual blocks) is sufficient for a first version and much
|
||||||
|
faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two
|
||||||
|
heads.
|
||||||
|
|
||||||
|
### 6.2 Loss Function
|
||||||
|
|
||||||
|
```
|
||||||
|
L = MSE(value_pred, z)
|
||||||
|
+ CrossEntropy(policy_logits_masked, π_mcts)
|
||||||
|
- c_l2 * L2_regularization
|
||||||
|
```
|
||||||
|
|
||||||
|
Where:
|
||||||
|
- `z` = game outcome (±1) from the active player's perspective
|
||||||
|
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
|
||||||
|
- Legal action masking is applied before computing CrossEntropy
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. AlphaZero Training Loop
|
||||||
|
|
||||||
|
```
|
||||||
|
INIT
|
||||||
|
network ← random weights
|
||||||
|
replay ← empty ReplayBuffer(capacity = 100_000)
|
||||||
|
|
||||||
|
LOOP forever:
|
||||||
|
── Self-play phase ──────────────────────────────────────────────
|
||||||
|
(parallel with rayon, n_workers games at once)
|
||||||
|
for each game:
|
||||||
|
state ← env.new_game()
|
||||||
|
samples = []
|
||||||
|
while not terminal:
|
||||||
|
advance forced/chance nodes automatically
|
||||||
|
obs ← env.observation(state, current_player)
|
||||||
|
legal ← env.legal_actions(state)
|
||||||
|
π, root_value ← mcts.run(state, network, config)
|
||||||
|
action ← sample from π (with temperature)
|
||||||
|
samples.push((obs, π, current_player))
|
||||||
|
env.apply(state, action)
|
||||||
|
z ← env.returns(state) // final scores
|
||||||
|
for (obs, π, player) in samples:
|
||||||
|
replay.push(TrainSample { obs, policy: π, value: z[player] })
|
||||||
|
|
||||||
|
── Training phase ───────────────────────────────────────────────
|
||||||
|
for each gradient step:
|
||||||
|
batch ← replay.sample(batch_size)
|
||||||
|
(policy_logits, value_pred) ← network.forward(batch.obs)
|
||||||
|
loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy)
|
||||||
|
optimizer.step(loss.backward())
|
||||||
|
|
||||||
|
── Evaluation (every N iterations) ─────────────────────────────
|
||||||
|
win_rate ← evaluate(network_new vs network_prev, n_eval_games)
|
||||||
|
if win_rate > 0.55: save checkpoint
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.1 Replay Buffer
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct TrainSample {
|
||||||
|
pub obs: Vec<f32>, // 217 values
|
||||||
|
pub policy: Vec<f32>, // 514 values (normalized MCTS visit counts)
|
||||||
|
pub value: f32, // game outcome ∈ {-1, 0, +1}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ReplayBuffer {
|
||||||
|
data: VecDeque<TrainSample>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayBuffer {
|
||||||
|
pub fn push(&mut self, s: TrainSample) {
|
||||||
|
if self.data.len() == self.capacity { self.data.pop_front(); }
|
||||||
|
self.data.push_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
|
||||||
|
// sample without replacement
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 Parallelism Strategy
|
||||||
|
|
||||||
|
Self-play is embarrassingly parallel (each game is independent):
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let samples: Vec<TrainSample> = (0..n_games)
|
||||||
|
.into_par_iter() // rayon
|
||||||
|
.flat_map(|_| generate_episode(&env, &network, &mcts_config))
|
||||||
|
.collect();
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
|
||||||
|
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
|
||||||
|
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
|
||||||
|
`Autodiff<NdArray<f32>>`.
|
||||||
|
|
||||||
|
For larger scale, a producer-consumer architecture (crossbeam-channel) separates
|
||||||
|
self-play workers from the training thread, allowing continuous data generation
|
||||||
|
while the GPU trains.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. `TrictracEnv` Implementation Sketch
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use trictrac_store::{
|
||||||
|
training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE},
|
||||||
|
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TrictracEnv;
|
||||||
|
|
||||||
|
impl GameEnv for TrictracEnv {
|
||||||
|
type State = GameState;
|
||||||
|
|
||||||
|
fn new_game(&self) -> GameState {
|
||||||
|
GameState::new_with_players("P1", "P2")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_player(&self, s: &GameState) -> Player {
|
||||||
|
match s.stage {
|
||||||
|
Stage::Ended => Player::Terminal,
|
||||||
|
_ => match s.turn_stage {
|
||||||
|
TurnStage::RollWaiting => Player::Chance,
|
||||||
|
_ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
|
||||||
|
let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() };
|
||||||
|
get_valid_action_indices(&view).unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(&self, s: &mut GameState, action_idx: usize) {
|
||||||
|
// advance all forced/chance nodes first, then apply the player action
|
||||||
|
self.advance_forced(s);
|
||||||
|
let needs_mirror = s.active_player_id == 2;
|
||||||
|
let view = if needs_mirror { s.mirror() } else { s.clone() };
|
||||||
|
if let Some(event) = TrictracAction::from_action_index(action_idx)
|
||||||
|
.and_then(|a| a.to_event(&view))
|
||||||
|
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||||
|
{
|
||||||
|
let _ = s.consume(&event);
|
||||||
|
}
|
||||||
|
// advance any forced stages that follow
|
||||||
|
self.advance_forced(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) {
|
||||||
|
// RollDice → RollWaiting
|
||||||
|
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
|
||||||
|
// RollWaiting → next stage
|
||||||
|
let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) };
|
||||||
|
let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice });
|
||||||
|
self.advance_forced(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
|
||||||
|
if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn obs_size(&self) -> usize { 217 }
|
||||||
|
fn action_space(&self) -> usize { ACTION_SPACE_SIZE }
|
||||||
|
|
||||||
|
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
|
||||||
|
if s.stage != Stage::Ended { return None; }
|
||||||
|
// Convert hole+point scores to ±1 outcome
|
||||||
|
let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
|
||||||
|
let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
|
||||||
|
Some(match s1.cmp(&s2) {
|
||||||
|
std::cmp::Ordering::Greater => [ 1.0, -1.0],
|
||||||
|
std::cmp::Ordering::Less => [-1.0, 1.0],
|
||||||
|
std::cmp::Ordering::Equal => [ 0.0, 0.0],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrictracEnv {
|
||||||
|
/// Advance through all forced (non-decision, non-chance) stages.
|
||||||
|
fn advance_forced(&self, s: &mut GameState) {
|
||||||
|
use trictrac_store::PointsRules;
|
||||||
|
loop {
|
||||||
|
match s.turn_stage {
|
||||||
|
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
|
||||||
|
// Scoring is deterministic; compute and apply automatically.
|
||||||
|
let color = s.player_color_by_id(&s.active_player_id)
|
||||||
|
.unwrap_or(trictrac_store::Color::White);
|
||||||
|
let drc = s.players.get(&s.active_player_id)
|
||||||
|
.map(|p| p.dice_roll_count).unwrap_or(0);
|
||||||
|
let pr = PointsRules::new(&color, &s.board, s.dice);
|
||||||
|
let pts = pr.get_points(drc);
|
||||||
|
let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 };
|
||||||
|
let _ = s.consume(&GameEvent::Mark {
|
||||||
|
player_id: s.active_player_id, points,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
TurnStage::RollDice => {
|
||||||
|
// RollDice is a forced "initiate roll" action with no real choice.
|
||||||
|
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
|
||||||
|
}
|
||||||
|
_ => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Cargo.toml Changes
|
||||||
|
|
||||||
|
### 9.1 Add `spiel_bot` to the workspace
|
||||||
|
|
||||||
|
```toml
|
||||||
|
# Cargo.toml (workspace root)
|
||||||
|
[workspace]
|
||||||
|
resolver = "2"
|
||||||
|
members = ["client_cli", "bot", "store", "spiel_bot"]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9.2 `spiel_bot/Cargo.toml`
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[package]
|
||||||
|
name = "spiel_bot"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["alphazero"]
|
||||||
|
alphazero = []
|
||||||
|
# dqn = [] # future
|
||||||
|
# ppo = [] # future
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
trictrac-store = { path = "../store" }
|
||||||
|
anyhow = "1"
|
||||||
|
rand = "0.9"
|
||||||
|
rayon = "1"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
|
||||||
|
# Burn: NdArray for pure-Rust CPU training
|
||||||
|
# Replace NdArray with Wgpu or Tch for GPU.
|
||||||
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
|
|
||||||
|
# Optional: progress display and structured logging
|
||||||
|
indicatif = "0.17"
|
||||||
|
tracing = "0.1"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "az_train"
|
||||||
|
path = "src/bin/az_train.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "az_eval"
|
||||||
|
path = "src/bin/az_eval.rs"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Comparison: `bot` crate vs `spiel_bot`
|
||||||
|
|
||||||
|
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
|
||||||
|
|--------|-----------------|------------------------|
|
||||||
|
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
|
||||||
|
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
|
||||||
|
| Opponent | hardcoded random | self-play |
|
||||||
|
| Invalid actions | penalise with reward | legal action mask (no penalty) |
|
||||||
|
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
|
||||||
|
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
|
||||||
|
| Burn dep | yes (0.20) | yes (0.20), same backend |
|
||||||
|
| `burn-rl` dep | yes | no |
|
||||||
|
| C++ dep | no | no |
|
||||||
|
| Python dep | no | no |
|
||||||
|
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
|
||||||
|
|
||||||
|
The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
|
||||||
|
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
|
||||||
|
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
|
||||||
|
replace `TrictracEnvironment` with `TrictracEnv`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Implementation Order
|
||||||
|
|
||||||
|
1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game
|
||||||
|
through the trait, verify terminal state and returns).
|
||||||
|
2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) +
|
||||||
|
Burn forward/backward pass test with dummy data.
|
||||||
|
3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests
|
||||||
|
(visit counts sum to `n_simulations`, legal mask respected).
|
||||||
|
4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub
|
||||||
|
(one iteration, check loss decreases).
|
||||||
|
5. **Integration test**: run 100 self-play games with a tiny network (1 res block,
|
||||||
|
64 hidden units), verify the training loop completes without panics.
|
||||||
|
6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s
|
||||||
|
on CPU, consistent with `random_game` throughput).
|
||||||
|
7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule
|
||||||
|
hyperparameter sweep.
|
||||||
|
8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report
|
||||||
|
win rate every checkpoint.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. Key Open Questions
|
||||||
|
|
||||||
|
1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
|
||||||
|
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
|
||||||
|
scored more holes). Better option: normalize the score margin. The final
|
||||||
|
choice affects how the value head is trained.
|
||||||
|
|
||||||
|
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
|
||||||
|
MCTS with 200 simulations per step means ~120k network evaluations per game.
|
||||||
|
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
|
||||||
|
limiting `n_simulations` to 50–100 for early training.
|
||||||
|
|
||||||
|
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
|
||||||
|
This is a long-horizon decision that AlphaZero handles naturally via MCTS
|
||||||
|
lookahead, but needs careful value normalization (a "Go" restarts scoring
|
||||||
|
within the same game).
|
||||||
|
|
||||||
|
4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated
|
||||||
|
to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic.
|
||||||
|
This is optional but reduces code duplication.
|
||||||
|
|
||||||
|
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
|
||||||
|
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
|
||||||
|
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
|
||||||
11
spiel_bot/Cargo.toml
Normal file
11
spiel_bot/Cargo.toml
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
[package]
|
||||||
|
name = "spiel_bot"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
trictrac-store = { path = "../store" }
|
||||||
|
anyhow = "1"
|
||||||
|
rand = "0.9"
|
||||||
|
rand_distr = "0.5"
|
||||||
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
117
spiel_bot/src/alphazero/mod.rs
Normal file
117
spiel_bot/src/alphazero/mod.rs
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
//! AlphaZero: self-play data generation, replay buffer, and training step.
|
||||||
|
//!
|
||||||
|
//! # Modules
|
||||||
|
//!
|
||||||
|
//! | Module | Contents |
|
||||||
|
//! |--------|----------|
|
||||||
|
//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] |
|
||||||
|
//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] |
|
||||||
|
//! | [`trainer`] | [`train_step`] |
|
||||||
|
//!
|
||||||
|
//! # Typical outer loop
|
||||||
|
//!
|
||||||
|
//! ```rust,ignore
|
||||||
|
//! use burn::backend::{Autodiff, NdArray};
|
||||||
|
//! use burn::optim::AdamConfig;
|
||||||
|
//! use spiel_bot::{
|
||||||
|
//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step},
|
||||||
|
//! env::TrictracEnv,
|
||||||
|
//! mcts::MctsConfig,
|
||||||
|
//! network::{MlpConfig, MlpNet},
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! type Infer = NdArray<f32>;
|
||||||
|
//! type Train = Autodiff<NdArray<f32>>;
|
||||||
|
//!
|
||||||
|
//! let device = Default::default();
|
||||||
|
//! let env = TrictracEnv;
|
||||||
|
//! let config = AlphaZeroConfig::default();
|
||||||
|
//!
|
||||||
|
//! // Build training model and optimizer.
|
||||||
|
//! let mut train_model = MlpNet::<Train>::new(&MlpConfig::default(), &device);
|
||||||
|
//! let mut optimizer = AdamConfig::new().init();
|
||||||
|
//! let mut replay = ReplayBuffer::new(config.replay_capacity);
|
||||||
|
//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||||
|
//!
|
||||||
|
//! for _iter in 0..config.n_iterations {
|
||||||
|
//! // Convert to inference backend for self-play.
|
||||||
|
//! let infer_model = MlpNet::<Infer>::new(&MlpConfig::default(), &device)
|
||||||
|
//! .load_record(train_model.clone().into_record());
|
||||||
|
//! let eval = BurnEvaluator::new(infer_model, device.clone());
|
||||||
|
//!
|
||||||
|
//! // Self-play: generate episodes.
|
||||||
|
//! for _ in 0..config.n_games_per_iter {
|
||||||
|
//! let samples = generate_episode(&env, &eval, &config.mcts,
|
||||||
|
//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
|
||||||
|
//! replay.extend(samples);
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! // Training: gradient steps.
|
||||||
|
//! if replay.len() >= config.batch_size {
|
||||||
|
//! for _ in 0..config.n_train_steps_per_iter {
|
||||||
|
//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng)
|
||||||
|
//! .into_iter().cloned().collect();
|
||||||
|
//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device,
|
||||||
|
//! config.learning_rate);
|
||||||
|
//! train_model = m;
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
pub mod replay;
|
||||||
|
pub mod selfplay;
|
||||||
|
pub mod trainer;
|
||||||
|
|
||||||
|
pub use replay::{ReplayBuffer, TrainSample};
|
||||||
|
pub use selfplay::{BurnEvaluator, generate_episode};
|
||||||
|
pub use trainer::train_step;
|
||||||
|
|
||||||
|
use crate::mcts::MctsConfig;
|
||||||
|
|
||||||
|
// ── Configuration ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Top-level AlphaZero hyperparameters.
|
||||||
|
///
|
||||||
|
/// The MCTS parameters live in [`MctsConfig`]; this struct holds the
|
||||||
|
/// outer training-loop parameters.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AlphaZeroConfig {
|
||||||
|
/// MCTS parameters for self-play.
|
||||||
|
pub mcts: MctsConfig,
|
||||||
|
/// Number of self-play games per training iteration.
|
||||||
|
pub n_games_per_iter: usize,
|
||||||
|
/// Number of gradient steps per training iteration.
|
||||||
|
pub n_train_steps_per_iter: usize,
|
||||||
|
/// Mini-batch size for each gradient step.
|
||||||
|
pub batch_size: usize,
|
||||||
|
/// Maximum number of samples in the replay buffer.
|
||||||
|
pub replay_capacity: usize,
|
||||||
|
/// Adam learning rate.
|
||||||
|
pub learning_rate: f64,
|
||||||
|
/// Number of outer iterations (self-play + train) to run.
|
||||||
|
pub n_iterations: usize,
|
||||||
|
/// Move index after which the action temperature drops to 0 (greedy play).
|
||||||
|
pub temperature_drop_move: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AlphaZeroConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
mcts: MctsConfig {
|
||||||
|
n_simulations: 100,
|
||||||
|
c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.1,
|
||||||
|
dirichlet_eps: 0.25,
|
||||||
|
temperature: 1.0,
|
||||||
|
},
|
||||||
|
n_games_per_iter: 10,
|
||||||
|
n_train_steps_per_iter: 20,
|
||||||
|
batch_size: 64,
|
||||||
|
replay_capacity: 50_000,
|
||||||
|
learning_rate: 1e-3,
|
||||||
|
n_iterations: 100,
|
||||||
|
temperature_drop_move: 30,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
144
spiel_bot/src/alphazero/replay.rs
Normal file
144
spiel_bot/src/alphazero/replay.rs
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
//! Replay buffer for AlphaZero self-play data.
|
||||||
|
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
// ── Training sample ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One training example produced by self-play.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct TrainSample {
|
||||||
|
/// Observation tensor from the acting player's perspective (`obs_size` floats).
|
||||||
|
pub obs: Vec<f32>,
|
||||||
|
/// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
|
||||||
|
pub policy: Vec<f32>,
|
||||||
|
/// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
|
||||||
|
pub value: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Replay buffer ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Fixed-capacity circular buffer of [`TrainSample`]s.
|
||||||
|
///
|
||||||
|
/// When the buffer is full, the oldest sample is evicted on push.
|
||||||
|
/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
|
||||||
|
pub struct ReplayBuffer {
|
||||||
|
data: VecDeque<TrainSample>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayBuffer {
|
||||||
|
/// Create a buffer with the given maximum capacity.
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
data: VecDeque::with_capacity(capacity.min(1024)),
|
||||||
|
capacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a sample; evicts the oldest if at capacity.
|
||||||
|
pub fn push(&mut self, sample: TrainSample) {
|
||||||
|
if self.data.len() == self.capacity {
|
||||||
|
self.data.pop_front();
|
||||||
|
}
|
||||||
|
self.data.push_back(sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add all samples from an episode.
|
||||||
|
pub fn extend(&mut self, samples: impl IntoIterator<Item = TrainSample>) {
|
||||||
|
for s in samples {
|
||||||
|
self.push(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.data.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sample up to `n` distinct samples, without replacement.
|
||||||
|
///
|
||||||
|
/// If the buffer has fewer than `n` samples, all are returned (shuffled).
|
||||||
|
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
|
||||||
|
let len = self.data.len();
|
||||||
|
let n = n.min(len);
|
||||||
|
// Partial Fisher-Yates using index shuffling.
|
||||||
|
let mut indices: Vec<usize> = (0..len).collect();
|
||||||
|
for i in 0..n {
|
||||||
|
let j = rng.random_range(i..len);
|
||||||
|
indices.swap(i, j);
|
||||||
|
}
|
||||||
|
indices[..n].iter().map(|&i| &self.data[i]).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rand::{SeedableRng, rngs::SmallRng};
|
||||||
|
|
||||||
|
fn dummy(value: f32) -> TrainSample {
|
||||||
|
TrainSample { obs: vec![value], policy: vec![1.0], value }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_and_len() {
|
||||||
|
let mut buf = ReplayBuffer::new(10);
|
||||||
|
assert!(buf.is_empty());
|
||||||
|
buf.push(dummy(1.0));
|
||||||
|
buf.push(dummy(2.0));
|
||||||
|
assert_eq!(buf.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn evicts_oldest_at_capacity() {
|
||||||
|
let mut buf = ReplayBuffer::new(3);
|
||||||
|
buf.push(dummy(1.0));
|
||||||
|
buf.push(dummy(2.0));
|
||||||
|
buf.push(dummy(3.0));
|
||||||
|
buf.push(dummy(4.0)); // evicts 1.0
|
||||||
|
assert_eq!(buf.len(), 3);
|
||||||
|
// Oldest remaining should be 2.0
|
||||||
|
assert_eq!(buf.data[0].value, 2.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sample_batch_size() {
|
||||||
|
let mut buf = ReplayBuffer::new(20);
|
||||||
|
for i in 0..10 {
|
||||||
|
buf.push(dummy(i as f32));
|
||||||
|
}
|
||||||
|
let mut rng = SmallRng::seed_from_u64(0);
|
||||||
|
let batch = buf.sample_batch(5, &mut rng);
|
||||||
|
assert_eq!(batch.len(), 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sample_batch_capped_at_len() {
|
||||||
|
let mut buf = ReplayBuffer::new(20);
|
||||||
|
buf.push(dummy(1.0));
|
||||||
|
buf.push(dummy(2.0));
|
||||||
|
let mut rng = SmallRng::seed_from_u64(0);
|
||||||
|
let batch = buf.sample_batch(100, &mut rng);
|
||||||
|
assert_eq!(batch.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sample_batch_no_duplicates() {
|
||||||
|
let mut buf = ReplayBuffer::new(20);
|
||||||
|
for i in 0..10 {
|
||||||
|
buf.push(dummy(i as f32));
|
||||||
|
}
|
||||||
|
let mut rng = SmallRng::seed_from_u64(1);
|
||||||
|
let batch = buf.sample_batch(10, &mut rng);
|
||||||
|
let mut seen: Vec<f32> = batch.iter().map(|s| s.value).collect();
|
||||||
|
seen.sort_by(f32::total_cmp);
|
||||||
|
seen.dedup();
|
||||||
|
assert_eq!(seen.len(), 10, "sample contained duplicates");
|
||||||
|
}
|
||||||
|
}
|
||||||
234
spiel_bot/src/alphazero/selfplay.rs
Normal file
234
spiel_bot/src/alphazero/selfplay.rs
Normal file
|
|
@ -0,0 +1,234 @@
|
||||||
|
//! Self-play episode generation and Burn-backed evaluator.
|
||||||
|
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::env::GameEnv;
|
||||||
|
use crate::mcts::{self, Evaluator, MctsConfig, MctsNode};
|
||||||
|
use crate::network::PolicyValueNet;
|
||||||
|
|
||||||
|
use super::replay::TrainSample;
|
||||||
|
|
||||||
|
// ── BurnEvaluator ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS.
|
||||||
|
///
|
||||||
|
/// Use the **inference backend** (`NdArray<f32>`, no `Autodiff` wrapper) so
|
||||||
|
/// that self-play generates no gradient tape overhead.
|
||||||
|
pub struct BurnEvaluator<B: Backend, N: PolicyValueNet<B>> {
|
||||||
|
model: N,
|
||||||
|
device: B::Device,
|
||||||
|
_b: PhantomData<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
|
||||||
|
pub fn new(model: N, device: B::Device) -> Self {
|
||||||
|
Self { model, device, _b: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_model(self) -> N {
|
||||||
|
self.model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
||||||
|
// external synchronisation.
|
||||||
|
unsafe impl<B: Backend, N: PolicyValueNet<B>> Send for BurnEvaluator<B, N> {}
|
||||||
|
unsafe impl<B: Backend, N: PolicyValueNet<B>> Sync for BurnEvaluator<B, N> {}
|
||||||
|
|
||||||
|
impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
|
||||||
|
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
|
||||||
|
let obs_size = obs.len();
|
||||||
|
let data = TensorData::new(obs.to_vec(), [1, obs_size]);
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(data, &self.device);
|
||||||
|
|
||||||
|
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
|
||||||
|
|
||||||
|
let policy: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
|
||||||
|
let value: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
|
||||||
|
|
||||||
|
(policy, value[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Episode generation ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One pending observation waiting for its game-outcome value label.
|
||||||
|
struct PendingSample {
|
||||||
|
obs: Vec<f32>,
|
||||||
|
policy: Vec<f32>,
|
||||||
|
player: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Play one full game using MCTS guided by `evaluator`.
|
||||||
|
///
|
||||||
|
/// Returns a [`TrainSample`] for every decision step in the game.
|
||||||
|
///
|
||||||
|
/// `temperature_fn(step)` controls exploration: return `1.0` for early
|
||||||
|
/// moves and `0.0` after a fixed number of moves (e.g. move 30).
|
||||||
|
pub fn generate_episode<E: GameEnv>(
|
||||||
|
env: &E,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
mcts_config: &MctsConfig,
|
||||||
|
temperature_fn: &dyn Fn(usize) -> f32,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> Vec<TrainSample> {
|
||||||
|
let mut state = env.new_game();
|
||||||
|
let mut pending: Vec<PendingSample> = Vec::new();
|
||||||
|
let mut step = 0usize;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Advance through chance nodes automatically.
|
||||||
|
while env.current_player(&state).is_chance() {
|
||||||
|
env.apply_chance(&mut state, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
if env.current_player(&state).is_terminal() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let player_idx = env.current_player(&state).index().unwrap();
|
||||||
|
|
||||||
|
// Run MCTS to get a policy.
|
||||||
|
let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
|
||||||
|
let policy = mcts::mcts_policy(&root, env.action_space());
|
||||||
|
|
||||||
|
// Record the observation from the acting player's perspective.
|
||||||
|
let obs = env.observation(&state, player_idx);
|
||||||
|
pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx });
|
||||||
|
|
||||||
|
// Select and apply the action.
|
||||||
|
let temperature = temperature_fn(step);
|
||||||
|
let action = mcts::select_action(&root, temperature, rng);
|
||||||
|
env.apply(&mut state, action);
|
||||||
|
step += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign game outcomes.
|
||||||
|
let returns = env.returns(&state).unwrap_or([0.0; 2]);
|
||||||
|
pending
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| TrainSample {
|
||||||
|
obs: s.obs,
|
||||||
|
policy: s.policy,
|
||||||
|
value: returns[s.player],
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn::backend::NdArray;
|
||||||
|
use rand::{SeedableRng, rngs::SmallRng};
|
||||||
|
|
||||||
|
use crate::env::Player;
|
||||||
|
use crate::mcts::{Evaluator, MctsConfig};
|
||||||
|
use crate::network::{MlpConfig, MlpNet};
|
||||||
|
|
||||||
|
type B = NdArray<f32>;
|
||||||
|
|
||||||
|
fn device() -> <B as Backend>::Device {
|
||||||
|
Default::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rng() -> SmallRng {
|
||||||
|
SmallRng::seed_from_u64(7)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Countdown game (same as in mcts tests).
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct CState { remaining: u8, to_move: usize }
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CountdownEnv;
|
||||||
|
|
||||||
|
impl GameEnv for CountdownEnv {
|
||||||
|
type State = CState;
|
||||||
|
fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } }
|
||||||
|
fn current_player(&self, s: &CState) -> Player {
|
||||||
|
if s.remaining == 0 { Player::Terminal }
|
||||||
|
else if s.to_move == 0 { Player::P1 } else { Player::P2 }
|
||||||
|
}
|
||||||
|
fn legal_actions(&self, s: &CState) -> Vec<usize> {
|
||||||
|
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
|
||||||
|
}
|
||||||
|
fn apply(&self, s: &mut CState, action: usize) {
|
||||||
|
let sub = (action as u8) + 1;
|
||||||
|
if s.remaining <= sub { s.remaining = 0; }
|
||||||
|
else { s.remaining -= sub; s.to_move = 1 - s.to_move; }
|
||||||
|
}
|
||||||
|
fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}
|
||||||
|
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
|
||||||
|
vec![s.remaining as f32 / 4.0, s.to_move as f32]
|
||||||
|
}
|
||||||
|
fn obs_size(&self) -> usize { 2 }
|
||||||
|
fn action_space(&self) -> usize { 2 }
|
||||||
|
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
|
||||||
|
if s.remaining != 0 { return None; }
|
||||||
|
let mut r = [-1.0f32; 2];
|
||||||
|
r[s.to_move] = 1.0;
|
||||||
|
Some(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tiny_config() -> MctsConfig {
|
||||||
|
MctsConfig { n_simulations: 5, c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── BurnEvaluator tests ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn burn_evaluator_output_shapes() {
|
||||||
|
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||||
|
let model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let eval = BurnEvaluator::new(model, device());
|
||||||
|
let (policy, value) = eval.evaluate(&[0.5f32, 0.5]);
|
||||||
|
assert_eq!(policy.len(), 2, "policy length should equal action_space");
|
||||||
|
assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── generate_episode tests ────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn episode_terminates_and_has_samples() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||||
|
let model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let eval = BurnEvaluator::new(model, device());
|
||||||
|
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
|
||||||
|
assert!(!samples.is_empty(), "episode must produce at least one sample");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn episode_sample_values_are_valid() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||||
|
let model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let eval = BurnEvaluator::new(model, device());
|
||||||
|
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
|
||||||
|
for s in &samples {
|
||||||
|
assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
|
||||||
|
"unexpected value {}", s.value);
|
||||||
|
let sum: f32 = s.policy.iter().sum();
|
||||||
|
assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}");
|
||||||
|
assert_eq!(s.obs.len(), 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn episode_with_temperature_zero() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||||
|
let model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let eval = BurnEvaluator::new(model, device());
|
||||||
|
// temperature=0 means greedy; episode must still terminate
|
||||||
|
let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng());
|
||||||
|
assert!(!samples.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
172
spiel_bot/src/alphazero/trainer.rs
Normal file
172
spiel_bot/src/alphazero/trainer.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
//! One gradient-descent training step for AlphaZero.
|
||||||
|
//!
|
||||||
|
//! The loss combines:
|
||||||
|
//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
|
||||||
|
//! - **Value loss** — mean-squared error between the predicted value and the
|
||||||
|
//! actual game outcome.
|
||||||
|
//!
|
||||||
|
//! # Backend
|
||||||
|
//!
|
||||||
|
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
||||||
|
//! Self-play uses the inner backend (`NdArray<f32>`) for zero autodiff overhead.
|
||||||
|
//! Weights are transferred between the two via [`burn::record`].
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
module::AutodiffModule,
|
||||||
|
optim::{GradientsParams, Optimizer},
|
||||||
|
prelude::ElementConversion,
|
||||||
|
tensor::{
|
||||||
|
activation::log_softmax,
|
||||||
|
backend::AutodiffBackend,
|
||||||
|
Tensor, TensorData,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::network::PolicyValueNet;
|
||||||
|
use super::replay::TrainSample;
|
||||||
|
|
||||||
|
/// Run one gradient step on `model` using `batch`.
|
||||||
|
///
|
||||||
|
/// Returns the updated model and the scalar loss value for logging.
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// - `lr` — learning rate (e.g. `1e-3`).
|
||||||
|
/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
|
||||||
|
pub fn train_step<B, N, O>(
|
||||||
|
model: N,
|
||||||
|
optimizer: &mut O,
|
||||||
|
batch: &[TrainSample],
|
||||||
|
device: &B::Device,
|
||||||
|
lr: f64,
|
||||||
|
) -> (N, f32)
|
||||||
|
where
|
||||||
|
B: AutodiffBackend,
|
||||||
|
N: PolicyValueNet<B> + AutodiffModule<B>,
|
||||||
|
O: Optimizer<N, B>,
|
||||||
|
{
|
||||||
|
assert!(!batch.is_empty(), "train_step called with empty batch");
|
||||||
|
|
||||||
|
let batch_size = batch.len();
|
||||||
|
let obs_size = batch[0].obs.len();
|
||||||
|
let action_size = batch[0].policy.len();
|
||||||
|
|
||||||
|
// ── Build input tensors ────────────────────────────────────────────────
|
||||||
|
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
|
||||||
|
let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
|
||||||
|
let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
|
||||||
|
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(obs_flat, [batch_size, obs_size]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
let policy_target = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(policy_flat, [batch_size, action_size]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
let value_target = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(value_flat, [batch_size, 1]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
|
||||||
|
// ── Forward pass ──────────────────────────────────────────────────────
|
||||||
|
let (policy_logits, value_pred) = model.forward(obs_tensor);
|
||||||
|
|
||||||
|
// ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
|
||||||
|
let log_probs = log_softmax(policy_logits, 1);
|
||||||
|
let policy_loss = (policy_target.clone().neg() * log_probs)
|
||||||
|
.sum_dim(1)
|
||||||
|
.mean();
|
||||||
|
|
||||||
|
// ── Value loss: MSE(value_pred, z) ────────────────────────────────────
|
||||||
|
let diff = value_pred - value_target;
|
||||||
|
let value_loss = (diff.clone() * diff).mean();
|
||||||
|
|
||||||
|
// ── Combined loss ─────────────────────────────────────────────────────
|
||||||
|
let loss = policy_loss + value_loss;
|
||||||
|
|
||||||
|
// Extract scalar before backward (consumes the tensor).
|
||||||
|
let loss_scalar: f32 = loss.clone().into_scalar().elem();
|
||||||
|
|
||||||
|
// ── Backward + optimizer step ─────────────────────────────────────────
|
||||||
|
let grads = loss.backward();
|
||||||
|
let grads = GradientsParams::from_grads(grads, &model);
|
||||||
|
let model = optimizer.step(lr, model, grads);
|
||||||
|
|
||||||
|
(model, loss_scalar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn::{
|
||||||
|
backend::{Autodiff, NdArray},
|
||||||
|
optim::AdamConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::network::{MlpConfig, MlpNet};
|
||||||
|
use super::super::replay::TrainSample;
|
||||||
|
|
||||||
|
type B = Autodiff<NdArray<f32>>;
|
||||||
|
|
||||||
|
fn device() -> <B as burn::tensor::backend::Backend>::Device {
|
||||||
|
Default::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
|
||||||
|
(0..n)
|
||||||
|
.map(|i| TrainSample {
|
||||||
|
obs: vec![0.5f32; obs_size],
|
||||||
|
policy: {
|
||||||
|
let mut p = vec![0.0f32; action_size];
|
||||||
|
p[i % action_size] = 1.0;
|
||||||
|
p
|
||||||
|
},
|
||||||
|
value: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn train_step_returns_finite_loss() {
|
||||||
|
let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
|
||||||
|
let model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let mut optimizer = AdamConfig::new().init();
|
||||||
|
let batch = dummy_batch(8, 4, 4);
|
||||||
|
|
||||||
|
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
||||||
|
assert!(loss.is_finite(), "loss must be finite, got {loss}");
|
||||||
|
assert!(loss > 0.0, "loss should be positive");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn loss_decreases_over_steps() {
|
||||||
|
let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
|
||||||
|
let mut model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let mut optimizer = AdamConfig::new().init();
|
||||||
|
// Same batch every step — loss should decrease.
|
||||||
|
let batch = dummy_batch(16, 4, 4);
|
||||||
|
|
||||||
|
let mut prev_loss = f32::INFINITY;
|
||||||
|
for _ in 0..10 {
|
||||||
|
let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
|
||||||
|
model = m;
|
||||||
|
assert!(loss.is_finite());
|
||||||
|
prev_loss = loss;
|
||||||
|
}
|
||||||
|
// After 10 steps on fixed data, loss should be below a reasonable threshold.
|
||||||
|
assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn train_step_batch_size_one() {
|
||||||
|
let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
|
||||||
|
let model = MlpNet::<B>::new(&config, &device());
|
||||||
|
let mut optimizer = AdamConfig::new().init();
|
||||||
|
let batch = dummy_batch(1, 2, 2);
|
||||||
|
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
||||||
|
assert!(loss.is_finite());
|
||||||
|
}
|
||||||
|
}
|
||||||
121
spiel_bot/src/env/mod.rs
vendored
Normal file
121
spiel_bot/src/env/mod.rs
vendored
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
//! Game environment abstraction — the minimal "Rust OpenSpiel".
|
||||||
|
//!
|
||||||
|
//! A `GameEnv` describes the rules of a two-player, zero-sum game that may
|
||||||
|
//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN,
|
||||||
|
//! and PPO interact with a game exclusively through this trait.
|
||||||
|
//!
|
||||||
|
//! # Node taxonomy
|
||||||
|
//!
|
||||||
|
//! Every game position belongs to one of four categories, returned by
|
||||||
|
//! [`GameEnv::current_player`]:
|
||||||
|
//!
|
||||||
|
//! | [`Player`] | Meaning |
|
||||||
|
//! |-----------|---------|
|
||||||
|
//! | `P1` | Player 1 (index 0) must choose an action |
|
||||||
|
//! | `P2` | Player 2 (index 1) must choose an action |
|
||||||
|
//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) |
|
||||||
|
//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful |
|
||||||
|
//!
|
||||||
|
//! # Perspective convention
|
||||||
|
//!
|
||||||
|
//! [`GameEnv::observation`] always returns the board from *the requested
|
||||||
|
//! player's* point of view. Callers pass `pov = 0` for Player 1 and
|
||||||
|
//! `pov = 1` for Player 2. The implementation is responsible for any
|
||||||
|
//! mirroring required (e.g. Trictrac always reasons from White's side).
|
||||||
|
|
||||||
|
pub mod trictrac;
|
||||||
|
pub use trictrac::TrictracEnv;
|
||||||
|
|
||||||
|
/// Who controls the current game node.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum Player {
|
||||||
|
/// Player 1 (index 0) is to move.
|
||||||
|
P1,
|
||||||
|
/// Player 2 (index 1) is to move.
|
||||||
|
P2,
|
||||||
|
/// A stochastic event (dice roll, etc.) must be resolved.
|
||||||
|
Chance,
|
||||||
|
/// The game is over.
|
||||||
|
Terminal,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Player {
|
||||||
|
/// Returns the player index (0 or 1) if this is a decision node,
|
||||||
|
/// or `None` for `Chance` / `Terminal`.
|
||||||
|
pub fn index(self) -> Option<usize> {
|
||||||
|
match self {
|
||||||
|
Player::P1 => Some(0),
|
||||||
|
Player::P2 => Some(1),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_decision(self) -> bool {
|
||||||
|
matches!(self, Player::P1 | Player::P2)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_chance(self) -> bool {
|
||||||
|
self == Player::Chance
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_terminal(self) -> bool {
|
||||||
|
self == Player::Terminal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait that completely describes a two-player zero-sum game.
|
||||||
|
///
|
||||||
|
/// Implementors must be cheaply cloneable (the type is used as a stateless
|
||||||
|
/// factory; the mutable game state lives in `Self::State`).
|
||||||
|
pub trait GameEnv: Clone + Send + Sync + 'static {
|
||||||
|
/// The mutable game state. Must be `Clone` so MCTS can copy
|
||||||
|
/// game trees without touching the environment.
|
||||||
|
type State: Clone + Send + Sync;
|
||||||
|
|
||||||
|
// ── State creation ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Create a fresh game state at the initial position.
|
||||||
|
fn new_game(&self) -> Self::State;
|
||||||
|
|
||||||
|
// ── Node queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Classify the current node.
|
||||||
|
fn current_player(&self, s: &Self::State) -> Player;
|
||||||
|
|
||||||
|
/// Legal action indices at a decision node (`current_player` is `P1`/`P2`).
|
||||||
|
///
|
||||||
|
/// The returned indices are in `[0, action_space())`.
|
||||||
|
/// The result is unspecified (may panic or return empty) when called at a
|
||||||
|
/// `Chance` or `Terminal` node.
|
||||||
|
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
|
||||||
|
|
||||||
|
// ── State mutation ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Apply a player action. `action` must be a value returned by
|
||||||
|
/// [`legal_actions`] for the current state.
|
||||||
|
fn apply(&self, s: &mut Self::State, action: usize);
|
||||||
|
|
||||||
|
/// Sample and apply a stochastic outcome. Must only be called when
|
||||||
|
/// `current_player(s) == Player::Chance`.
|
||||||
|
fn apply_chance<R: rand::Rng>(&self, s: &mut Self::State, rng: &mut R);
|
||||||
|
|
||||||
|
// ── Observation ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2).
|
||||||
|
/// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`.
|
||||||
|
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
|
||||||
|
|
||||||
|
/// Number of floats returned by [`observation`].
|
||||||
|
fn obs_size(&self) -> usize;
|
||||||
|
|
||||||
|
/// Total number of distinct action indices (the policy head output size).
|
||||||
|
fn action_space(&self) -> usize;
|
||||||
|
|
||||||
|
// ── Terminal values ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Game outcome for each player, or `None` if the game is not over.
|
||||||
|
///
|
||||||
|
/// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw.
|
||||||
|
/// Index 0 = Player 1, index 1 = Player 2.
|
||||||
|
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
|
||||||
|
}
|
||||||
535
spiel_bot/src/env/trictrac.rs
vendored
Normal file
535
spiel_bot/src/env/trictrac.rs
vendored
Normal file
|
|
@ -0,0 +1,535 @@
|
||||||
|
//! [`GameEnv`] implementation for Trictrac.
|
||||||
|
//!
|
||||||
|
//! # Game flow (schools_enabled = false)
|
||||||
|
//!
|
||||||
|
//! With scoring schools disabled (the standard training configuration),
|
||||||
|
//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine
|
||||||
|
//! applies them automatically inside `RollResult` and `Move`. The only
|
||||||
|
//! four stages that actually occur are:
|
||||||
|
//!
|
||||||
|
//! | `TurnStage` | [`Player`] kind | Handled by |
|
||||||
|
//! |-------------|-----------------|------------|
|
||||||
|
//! | `RollDice` | `Chance` | [`apply_chance`] |
|
||||||
|
//! | `RollWaiting` | `Chance` | [`apply_chance`] |
|
||||||
|
//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] |
|
||||||
|
//! | `Move` | `P1`/`P2` | [`apply`] |
|
||||||
|
//!
|
||||||
|
//! # Perspective
|
||||||
|
//!
|
||||||
|
//! The Trictrac engine always reasons from White's perspective. Player 1 is
|
||||||
|
//! White; Player 2 is Black. When Player 2 is active, the board is mirrored
|
||||||
|
//! before computing legal actions / the observation tensor, and the resulting
|
||||||
|
//! event is mirrored back before being applied to the real state. This
|
||||||
|
//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`.
|
||||||
|
|
||||||
|
use trictrac_store::{
|
||||||
|
training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE},
|
||||||
|
Dice, GameEvent, GameState, Stage, TurnStage,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{GameEnv, Player};
|
||||||
|
|
||||||
|
/// Stateless factory that produces Trictrac [`GameState`] environments.
|
||||||
|
///
|
||||||
|
/// Schools (`schools_enabled`) are always disabled — scoring is automatic.
|
||||||
|
#[derive(Clone, Debug, Default)]
|
||||||
|
pub struct TrictracEnv;
|
||||||
|
|
||||||
|
impl GameEnv for TrictracEnv {
|
||||||
|
type State = GameState;
|
||||||
|
|
||||||
|
// ── State creation ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn new_game(&self) -> GameState {
|
||||||
|
GameState::new_with_players("P1", "P2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Node queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn current_player(&self, s: &GameState) -> Player {
|
||||||
|
if s.stage == Stage::Ended {
|
||||||
|
return Player::Terminal;
|
||||||
|
}
|
||||||
|
match s.turn_stage {
|
||||||
|
TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance,
|
||||||
|
_ => {
|
||||||
|
if s.active_player_id == 1 {
|
||||||
|
Player::P1
|
||||||
|
} else {
|
||||||
|
Player::P2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the legal action indices for the active player.
|
||||||
|
///
|
||||||
|
/// The board is automatically mirrored for Player 2 so that the engine
|
||||||
|
/// always reasons from White's perspective. The returned indices are
|
||||||
|
/// identical in meaning for both players (checker ordinals are
|
||||||
|
/// perspective-relative).
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics in debug builds if called at a `Chance` or `Terminal` node.
|
||||||
|
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
|
||||||
|
debug_assert!(
|
||||||
|
self.current_player(s).is_decision(),
|
||||||
|
"legal_actions called at a non-decision node (turn_stage={:?})",
|
||||||
|
s.turn_stage
|
||||||
|
);
|
||||||
|
let indices = if s.active_player_id == 2 {
|
||||||
|
get_valid_action_indices(&s.mirror())
|
||||||
|
} else {
|
||||||
|
get_valid_action_indices(s)
|
||||||
|
};
|
||||||
|
indices.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── State mutation ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Apply a player action index to the game state.
|
||||||
|
///
|
||||||
|
/// For Player 2, the action is decoded against the mirrored board and
|
||||||
|
/// the resulting event is un-mirrored before being applied.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics in debug builds if `action` cannot be decoded or does not
|
||||||
|
/// produce a valid event for the current state.
|
||||||
|
fn apply(&self, s: &mut GameState, action: usize) {
|
||||||
|
let needs_mirror = s.active_player_id == 2;
|
||||||
|
|
||||||
|
let event = if needs_mirror {
|
||||||
|
let view = s.mirror();
|
||||||
|
TrictracAction::from_action_index(action)
|
||||||
|
.and_then(|a| a.to_event(&view))
|
||||||
|
.map(|e| e.get_mirror(false))
|
||||||
|
} else {
|
||||||
|
TrictracAction::from_action_index(action).and_then(|a| a.to_event(s))
|
||||||
|
};
|
||||||
|
|
||||||
|
match event {
|
||||||
|
Some(e) => {
|
||||||
|
s.consume(&e).expect("apply: consume failed for valid action");
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
panic!("apply: action index {action} produced no event in state {s}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sample dice and advance through a chance node.
|
||||||
|
///
|
||||||
|
/// Handles both `RollDice` (triggers the roll mechanism, then samples
|
||||||
|
/// dice) and `RollWaiting` (only samples dice) in a single call so that
|
||||||
|
/// callers never need to distinguish the two.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics in debug builds if called at a non-Chance node.
|
||||||
|
fn apply_chance<R: rand::Rng>(&self, s: &mut GameState, rng: &mut R) {
|
||||||
|
debug_assert!(
|
||||||
|
self.current_player(s).is_chance(),
|
||||||
|
"apply_chance called at a non-Chance node (turn_stage={:?})",
|
||||||
|
s.turn_stage
|
||||||
|
);
|
||||||
|
|
||||||
|
// Step 1: RollDice → RollWaiting (player initiates the roll).
|
||||||
|
if s.turn_stage == TurnStage::RollDice {
|
||||||
|
s.consume(&GameEvent::Roll {
|
||||||
|
player_id: s.active_player_id,
|
||||||
|
})
|
||||||
|
.expect("apply_chance: Roll event failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: RollWaiting → Move / HoldOrGoChoice / Ended.
|
||||||
|
// With schools_enabled=false, point marking is automatic inside consume().
|
||||||
|
let dice = Dice {
|
||||||
|
values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)),
|
||||||
|
};
|
||||||
|
s.consume(&GameEvent::RollResult {
|
||||||
|
player_id: s.active_player_id,
|
||||||
|
dice,
|
||||||
|
})
|
||||||
|
.expect("apply_chance: RollResult event failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Observation ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
|
||||||
|
if pov == 0 {
|
||||||
|
s.to_tensor()
|
||||||
|
} else {
|
||||||
|
s.mirror().to_tensor()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn obs_size(&self) -> usize {
|
||||||
|
217
|
||||||
|
}
|
||||||
|
|
||||||
|
fn action_space(&self) -> usize {
|
||||||
|
ACTION_SPACE_SIZE
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Terminal values ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Returns `Some([r1, r2])` when the game is over, `None` otherwise.
|
||||||
|
///
|
||||||
|
/// The winner (higher cumulative score) receives `+1.0`; the loser
|
||||||
|
/// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score
|
||||||
|
/// is `holes × 12 + points`.
|
||||||
|
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
|
||||||
|
if s.stage != Stage::Ended {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let score = |id: u64| -> i32 {
|
||||||
|
s.players
|
||||||
|
.get(&id)
|
||||||
|
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||||
|
.unwrap_or(0)
|
||||||
|
};
|
||||||
|
let s1 = score(1);
|
||||||
|
let s2 = score(2);
|
||||||
|
Some(match s1.cmp(&s2) {
|
||||||
|
std::cmp::Ordering::Greater => [1.0, -1.0],
|
||||||
|
std::cmp::Ordering::Less => [-1.0, 1.0],
|
||||||
|
std::cmp::Ordering::Equal => [0.0, 0.0],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rand::{rngs::SmallRng, Rng, SeedableRng};
|
||||||
|
|
||||||
|
fn env() -> TrictracEnv {
|
||||||
|
TrictracEnv
|
||||||
|
}
|
||||||
|
|
||||||
|
fn seeded_rng(seed: u64) -> SmallRng {
|
||||||
|
SmallRng::seed_from_u64(seed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Initial state ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_game_is_chance_node() {
|
||||||
|
let e = env();
|
||||||
|
let s = e.new_game();
|
||||||
|
// A fresh game starts at RollDice — a Chance node.
|
||||||
|
assert_eq!(e.current_player(&s), Player::Chance);
|
||||||
|
assert!(e.returns(&s).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_game_is_not_terminal() {
|
||||||
|
let e = env();
|
||||||
|
let s = e.new_game();
|
||||||
|
assert_ne!(e.current_player(&s), Player::Terminal);
|
||||||
|
assert!(e.returns(&s).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Chance nodes ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_chance_reaches_decision_node() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(1);
|
||||||
|
|
||||||
|
// A single chance step must yield a decision node (or end the game,
|
||||||
|
// which only happens after 12 holes — impossible on the first roll).
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
let p = e.current_player(&s);
|
||||||
|
assert!(
|
||||||
|
p.is_decision(),
|
||||||
|
"expected decision node after first roll, got {p:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_chance_from_rollwaiting() {
|
||||||
|
// Check that apply_chance works when called mid-way (at RollWaiting).
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
assert_eq!(s.turn_stage, TurnStage::RollDice);
|
||||||
|
|
||||||
|
// Manually advance to RollWaiting.
|
||||||
|
s.consume(&GameEvent::Roll { player_id: s.active_player_id })
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(s.turn_stage, TurnStage::RollWaiting);
|
||||||
|
|
||||||
|
let mut rng = seeded_rng(2);
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
|
||||||
|
let p = e.current_player(&s);
|
||||||
|
assert!(p.is_decision() || p.is_terminal());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Legal actions ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn legal_actions_nonempty_after_roll() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(3);
|
||||||
|
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
assert!(e.current_player(&s).is_decision());
|
||||||
|
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
assert!(
|
||||||
|
!actions.is_empty(),
|
||||||
|
"legal_actions must be non-empty at a decision node"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn legal_actions_within_action_space() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(4);
|
||||||
|
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
for &a in e.legal_actions(&s).iter() {
|
||||||
|
assert!(
|
||||||
|
a < e.action_space(),
|
||||||
|
"action {a} out of bounds (action_space={})",
|
||||||
|
e.action_space()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Observations ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn observation_has_correct_size() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(5);
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
|
||||||
|
assert_eq!(e.observation(&s, 0).len(), e.obs_size());
|
||||||
|
assert_eq!(e.observation(&s, 1).len(), e.obs_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn observation_values_in_unit_interval() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(6);
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
|
||||||
|
for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] {
|
||||||
|
for (i, &v) in obs.iter().enumerate() {
|
||||||
|
assert!(
|
||||||
|
v >= 0.0 && v <= 1.0,
|
||||||
|
"pov={pov}: obs[{i}] = {v} is outside [0,1]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn p1_and_p2_observations_differ() {
|
||||||
|
// The board is mirrored for P2, so the two observations should differ
|
||||||
|
// whenever there are checkers in non-symmetric positions (always true
|
||||||
|
// in a real game after a few moves).
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(7);
|
||||||
|
|
||||||
|
// Advance far enough that the board is non-trivial.
|
||||||
|
for _ in 0..6 {
|
||||||
|
while e.current_player(&s).is_chance() {
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
}
|
||||||
|
if e.current_player(&s).is_terminal() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
e.apply(&mut s, actions[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.current_player(&s).is_terminal() {
|
||||||
|
let obs0 = e.observation(&s, 0);
|
||||||
|
let obs1 = e.observation(&s, 1);
|
||||||
|
assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Applying actions ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_changes_state() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(8);
|
||||||
|
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
assert!(e.current_player(&s).is_decision());
|
||||||
|
|
||||||
|
let before = s.clone();
|
||||||
|
let action = e.legal_actions(&s)[0];
|
||||||
|
e.apply(&mut s, action);
|
||||||
|
|
||||||
|
assert_ne!(
|
||||||
|
before.turn_stage, s.turn_stage,
|
||||||
|
"state must change after apply"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn apply_all_legal_actions_do_not_panic() {
|
||||||
|
// Verify that every action returned by legal_actions can be applied
|
||||||
|
// without panicking (on several independent copies of the same state).
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(9);
|
||||||
|
|
||||||
|
e.apply_chance(&mut s, &mut rng);
|
||||||
|
assert!(e.current_player(&s).is_decision());
|
||||||
|
|
||||||
|
for action in e.legal_actions(&s) {
|
||||||
|
let mut copy = s.clone();
|
||||||
|
e.apply(&mut copy, action); // must not panic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Full game ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Run a complete game with random actions through the `GameEnv` trait
|
||||||
|
/// and verify that:
|
||||||
|
/// - The game terminates.
|
||||||
|
/// - `returns()` is `Some` at the end.
|
||||||
|
/// - The outcome is valid: scores sum to 0 (zero-sum) or each player's
|
||||||
|
/// score is ±1 / 0.
|
||||||
|
/// - No step panics.
|
||||||
|
#[test]
|
||||||
|
fn full_random_game_terminates() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(42);
|
||||||
|
let max_steps = 50_000;
|
||||||
|
|
||||||
|
for step in 0..max_steps {
|
||||||
|
match e.current_player(&s) {
|
||||||
|
Player::Terminal => break,
|
||||||
|
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||||
|
Player::P1 | Player::P2 => {
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node");
|
||||||
|
let idx = rng.random_range(0..actions.len());
|
||||||
|
e.apply(&mut s, actions[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps");
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = e.returns(&s);
|
||||||
|
assert!(result.is_some(), "returns() must be Some at Terminal");
|
||||||
|
|
||||||
|
let [r1, r2] = result.unwrap();
|
||||||
|
let sum = r1 + r2;
|
||||||
|
assert!(
|
||||||
|
(sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5,
|
||||||
|
"game must be zero-sum: r1={r1}, r2={r2}, sum={sum}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
r1.abs() <= 1.0 && r2.abs() <= 1.0,
|
||||||
|
"returns must be in [-1,1]: r1={r1}, r2={r2}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run multiple games with different seeds to stress-test for panics.
|
||||||
|
#[test]
|
||||||
|
fn multiple_games_no_panic() {
|
||||||
|
let e = env();
|
||||||
|
let max_steps = 20_000;
|
||||||
|
|
||||||
|
for seed in 0..10u64 {
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(seed);
|
||||||
|
|
||||||
|
for _ in 0..max_steps {
|
||||||
|
match e.current_player(&s) {
|
||||||
|
Player::Terminal => break,
|
||||||
|
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||||
|
Player::P1 | Player::P2 => {
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
let idx = rng.random_range(0..actions.len());
|
||||||
|
e.apply(&mut s, actions[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Returns ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_none_mid_game() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(11);
|
||||||
|
|
||||||
|
// Advance a few steps but do not finish the game.
|
||||||
|
for _ in 0..4 {
|
||||||
|
match e.current_player(&s) {
|
||||||
|
Player::Terminal => break,
|
||||||
|
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||||
|
Player::P1 | Player::P2 => {
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
e.apply(&mut s, actions[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.current_player(&s).is_terminal() {
|
||||||
|
assert!(
|
||||||
|
e.returns(&s).is_none(),
|
||||||
|
"returns() must be None before the game ends"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Player 2 actions ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Verify that Player 2 (Black) can take actions without panicking,
|
||||||
|
/// and that the state advances correctly.
|
||||||
|
#[test]
|
||||||
|
fn player2_can_act() {
|
||||||
|
let e = env();
|
||||||
|
let mut s = e.new_game();
|
||||||
|
let mut rng = seeded_rng(12);
|
||||||
|
|
||||||
|
// Keep stepping until Player 2 gets a turn.
|
||||||
|
let max_steps = 5_000;
|
||||||
|
let mut p2_acted = false;
|
||||||
|
|
||||||
|
for _ in 0..max_steps {
|
||||||
|
match e.current_player(&s) {
|
||||||
|
Player::Terminal => break,
|
||||||
|
Player::Chance => e.apply_chance(&mut s, &mut rng),
|
||||||
|
Player::P2 => {
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
assert!(!actions.is_empty());
|
||||||
|
e.apply(&mut s, actions[0]);
|
||||||
|
p2_acted = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Player::P1 => {
|
||||||
|
let actions = e.legal_actions(&s);
|
||||||
|
e.apply(&mut s, actions[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps");
|
||||||
|
}
|
||||||
|
}
|
||||||
4
spiel_bot/src/lib.rs
Normal file
4
spiel_bot/src/lib.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
pub mod alphazero;
|
||||||
|
pub mod env;
|
||||||
|
pub mod mcts;
|
||||||
|
pub mod network;
|
||||||
408
spiel_bot/src/mcts/mod.rs
Normal file
408
spiel_bot/src/mcts/mod.rs
Normal file
|
|
@ -0,0 +1,408 @@
|
||||||
|
//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance.
|
||||||
|
//!
|
||||||
|
//! # Algorithm
|
||||||
|
//!
|
||||||
|
//! The implementation follows AlphaZero's MCTS:
|
||||||
|
//!
|
||||||
|
//! 1. **Expand root** — run the network once to get priors and a value
|
||||||
|
//! estimate; optionally add Dirichlet noise for training-time exploration.
|
||||||
|
//! 2. **Simulate** `n_simulations` times:
|
||||||
|
//! - *Selection* — traverse the tree with PUCT until an unvisited leaf.
|
||||||
|
//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes;
|
||||||
|
//! chance nodes are **not** stored in the tree (outcome sampling).
|
||||||
|
//! - *Expansion* — evaluate the network at the leaf; populate children.
|
||||||
|
//! - *Backup* — propagate the value upward; negate at each player boundary.
|
||||||
|
//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]).
|
||||||
|
//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]).
|
||||||
|
//!
|
||||||
|
//! # Perspective convention
|
||||||
|
//!
|
||||||
|
//! Every [`MctsNode::w`] is stored **from the perspective of the player who
|
||||||
|
//! acts at that node**. The backup negates the child value whenever the
|
||||||
|
//! acting player differs between parent and child.
|
||||||
|
//!
|
||||||
|
//! # Stochastic games
|
||||||
|
//!
|
||||||
|
//! When [`GameEnv::current_player`] returns [`Player::Chance`], the
|
||||||
|
//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and
|
||||||
|
//! continues. Chance nodes are skipped transparently; Q-values converge to
|
||||||
|
//! their expectation over many simulations (outcome sampling).
|
||||||
|
|
||||||
|
pub mod node;
|
||||||
|
mod search;
|
||||||
|
|
||||||
|
pub use node::MctsNode;
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::env::GameEnv;
|
||||||
|
|
||||||
|
// ── Evaluator trait ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Evaluates a game position for use in MCTS.
|
||||||
|
///
|
||||||
|
/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet)
|
||||||
|
/// but the `mcts` module itself does **not** depend on Burn.
|
||||||
|
pub trait Evaluator: Send + Sync {
|
||||||
|
/// Evaluate `obs` (flat observation vector of length `obs_size`).
|
||||||
|
///
|
||||||
|
/// Returns:
|
||||||
|
/// - `policy_logits`: one raw logit per action (`action_space` entries).
|
||||||
|
/// Illegal action entries are masked inside the search — no need to
|
||||||
|
/// zero them here.
|
||||||
|
/// - `value`: scalar in `(-1, 1)` from **the current player's** perspective.
|
||||||
|
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Configuration ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Hyperparameters for [`run_mcts`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MctsConfig {
|
||||||
|
/// Number of MCTS simulations per move. Typical: 50–800.
|
||||||
|
pub n_simulations: usize,
|
||||||
|
/// PUCT exploration constant `c_puct`. Typical: 1.0–2.0.
|
||||||
|
pub c_puct: f32,
|
||||||
|
/// Dirichlet noise concentration α. Set to `0.0` to disable.
|
||||||
|
/// Typical: `0.3` for Chess, `0.1` for large action spaces.
|
||||||
|
pub dirichlet_alpha: f32,
|
||||||
|
/// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`.
|
||||||
|
pub dirichlet_eps: f32,
|
||||||
|
/// Action sampling temperature. `> 0` = proportional sample, `0` = argmax.
|
||||||
|
pub temperature: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MctsConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
n_simulations: 200,
|
||||||
|
c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.3,
|
||||||
|
dirichlet_eps: 0.25,
|
||||||
|
temperature: 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Public interface ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Run MCTS from `state` and return the populated root [`MctsNode`].
|
||||||
|
///
|
||||||
|
/// `state` must be a player-decision node (`P1` or `P2`).
|
||||||
|
/// Use [`mcts_policy`] and [`select_action`] on the returned root.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `env.current_player(state)` is not `P1` or `P2`.
|
||||||
|
pub fn run_mcts<E: GameEnv>(
|
||||||
|
env: &E,
|
||||||
|
state: &E::State,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
config: &MctsConfig,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> MctsNode {
|
||||||
|
let player_idx = env
|
||||||
|
.current_player(state)
|
||||||
|
.index()
|
||||||
|
.expect("run_mcts called at a non-decision node");
|
||||||
|
|
||||||
|
// ── Expand root (network called once here, not inside the loop) ────────
|
||||||
|
let mut root = MctsNode::new(1.0);
|
||||||
|
search::expand::<E>(&mut root, state, env, evaluator, player_idx);
|
||||||
|
|
||||||
|
// ── Optional Dirichlet noise for training exploration ──────────────────
|
||||||
|
if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
|
||||||
|
search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Simulations ────────────────────────────────────────────────────────
|
||||||
|
for _ in 0..config.n_simulations {
|
||||||
|
search::simulate::<E>(
|
||||||
|
&mut root,
|
||||||
|
state.clone(),
|
||||||
|
env,
|
||||||
|
evaluator,
|
||||||
|
config,
|
||||||
|
rng,
|
||||||
|
player_idx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
root
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the MCTS policy: normalized visit counts at the root.
|
||||||
|
///
|
||||||
|
/// Returns a vector of length `action_space` where `policy[a]` is the
|
||||||
|
/// fraction of simulations that visited action `a`.
|
||||||
|
pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
|
||||||
|
let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum();
|
||||||
|
let mut policy = vec![0.0f32; action_space];
|
||||||
|
if total > 0.0 {
|
||||||
|
for (a, child) in &root.children {
|
||||||
|
policy[*a] = child.n as f32 / total;
|
||||||
|
}
|
||||||
|
} else if !root.children.is_empty() {
|
||||||
|
// n_simulations = 0: uniform over legal actions.
|
||||||
|
let uniform = 1.0 / root.children.len() as f32;
|
||||||
|
for (a, _) in &root.children {
|
||||||
|
policy[*a] = uniform;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
policy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Select an action index from the root after MCTS.
|
||||||
|
///
|
||||||
|
/// * `temperature = 0` — greedy argmax of visit counts.
|
||||||
|
/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the root has no children.
|
||||||
|
pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize {
|
||||||
|
assert!(!root.children.is_empty(), "select_action called on a root with no children");
|
||||||
|
if temperature <= 0.0 {
|
||||||
|
root.children
|
||||||
|
.iter()
|
||||||
|
.max_by_key(|(_, c)| c.n)
|
||||||
|
.map(|(a, _)| *a)
|
||||||
|
.unwrap()
|
||||||
|
} else {
|
||||||
|
let weights: Vec<f32> = root
|
||||||
|
.children
|
||||||
|
.iter()
|
||||||
|
.map(|(_, c)| (c.n as f32).powf(1.0 / temperature))
|
||||||
|
.collect();
|
||||||
|
let total: f32 = weights.iter().sum();
|
||||||
|
let mut r: f32 = rng.random::<f32>() * total;
|
||||||
|
for (i, (a, _)) in root.children.iter().enumerate() {
|
||||||
|
r -= weights[i];
|
||||||
|
if r <= 0.0 {
|
||||||
|
return *a;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
root.children.last().map(|(a, _)| *a).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rand::{SeedableRng, rngs::SmallRng};
|
||||||
|
use crate::env::Player;
|
||||||
|
|
||||||
|
// ── Minimal deterministic test game ───────────────────────────────────
|
||||||
|
//
|
||||||
|
// "Countdown" — two players alternate subtracting 1 or 2 from a counter.
|
||||||
|
// The player who brings the counter to 0 wins.
|
||||||
|
// No chance nodes, two legal actions (0 = -1, 1 = -2).
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct CState {
|
||||||
|
remaining: u8,
|
||||||
|
to_move: usize, // at terminal: last mover (winner)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CountdownEnv;
|
||||||
|
|
||||||
|
impl crate::env::GameEnv for CountdownEnv {
|
||||||
|
type State = CState;
|
||||||
|
|
||||||
|
fn new_game(&self) -> CState {
|
||||||
|
CState { remaining: 6, to_move: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_player(&self, s: &CState) -> Player {
|
||||||
|
if s.remaining == 0 {
|
||||||
|
Player::Terminal
|
||||||
|
} else if s.to_move == 0 {
|
||||||
|
Player::P1
|
||||||
|
} else {
|
||||||
|
Player::P2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn legal_actions(&self, s: &CState) -> Vec<usize> {
|
||||||
|
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(&self, s: &mut CState, action: usize) {
|
||||||
|
let sub = (action as u8) + 1;
|
||||||
|
if s.remaining <= sub {
|
||||||
|
s.remaining = 0;
|
||||||
|
// to_move stays as winner
|
||||||
|
} else {
|
||||||
|
s.remaining -= sub;
|
||||||
|
s.to_move = 1 - s.to_move;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
|
||||||
|
|
||||||
|
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
|
||||||
|
vec![s.remaining as f32 / 6.0, s.to_move as f32]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn obs_size(&self) -> usize { 2 }
|
||||||
|
fn action_space(&self) -> usize { 2 }
|
||||||
|
|
||||||
|
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
|
||||||
|
if s.remaining != 0 { return None; }
|
||||||
|
let mut r = [-1.0f32; 2];
|
||||||
|
r[s.to_move] = 1.0;
|
||||||
|
Some(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uniform evaluator: all logits = 0, value = 0.
|
||||||
|
// `action_space` must match the environment's `action_space()`.
|
||||||
|
struct ZeroEval(usize);
|
||||||
|
impl Evaluator for ZeroEval {
|
||||||
|
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
|
||||||
|
(vec![0.0f32; self.0], 0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rng() -> SmallRng {
|
||||||
|
SmallRng::seed_from_u64(42)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_n(n: usize) -> MctsConfig {
|
||||||
|
MctsConfig {
|
||||||
|
n_simulations: n,
|
||||||
|
c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.0, // off for reproducibility
|
||||||
|
dirichlet_eps: 0.0,
|
||||||
|
temperature: 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Visit count tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn visit_counts_sum_to_n_simulations() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng());
|
||||||
|
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||||
|
assert_eq!(total, 50, "visit counts must sum to n_simulations");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn all_root_children_are_legal() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let legal = env.legal_actions(&state);
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng());
|
||||||
|
for (a, _) in &root.children {
|
||||||
|
assert!(legal.contains(a), "child action {a} is not legal");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Policy tests ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn policy_sums_to_one() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng());
|
||||||
|
let policy = mcts_policy(&root, env.action_space());
|
||||||
|
let sum: f32 = policy.iter().sum();
|
||||||
|
assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn policy_zero_for_illegal_actions() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
// remaining = 1 → only action 0 is legal
|
||||||
|
let state = CState { remaining: 1, to_move: 0 };
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng());
|
||||||
|
let policy = mcts_policy(&root, env.action_space());
|
||||||
|
assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Action selection tests ────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn greedy_selects_most_visited() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng());
|
||||||
|
let greedy = select_action(&root, 0.0, &mut rng());
|
||||||
|
let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap();
|
||||||
|
assert_eq!(greedy, most_visited);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temperature_sampling_stays_legal() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let legal = env.legal_actions(&state);
|
||||||
|
let mut r = rng();
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r);
|
||||||
|
for _ in 0..20 {
|
||||||
|
let a = select_action(&root, 1.0, &mut r);
|
||||||
|
assert!(legal.contains(&a), "sampled action {a} is not legal");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Zero-simulation edge case ─────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn zero_simulations_uniform_policy() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng());
|
||||||
|
let policy = mcts_policy(&root, env.action_space());
|
||||||
|
// With 0 simulations, fallback is uniform over the 2 legal actions.
|
||||||
|
let sum: f32 = policy.iter().sum();
|
||||||
|
assert!((sum - 1.0).abs() < 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Root value ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn root_q_in_valid_range() {
|
||||||
|
let env = CountdownEnv;
|
||||||
|
let state = env.new_game();
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng());
|
||||||
|
let q = root.q();
|
||||||
|
assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Integration: run on a real Trictrac game ──────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_panic_on_trictrac_state() {
|
||||||
|
use crate::env::TrictracEnv;
|
||||||
|
|
||||||
|
let env = TrictracEnv;
|
||||||
|
let mut state = env.new_game();
|
||||||
|
let mut r = rng();
|
||||||
|
|
||||||
|
// Advance past the initial chance node to reach a decision node.
|
||||||
|
while env.current_player(&state).is_chance() {
|
||||||
|
env.apply_chance(&mut state, &mut r);
|
||||||
|
}
|
||||||
|
|
||||||
|
if env.current_player(&state).is_terminal() {
|
||||||
|
return; // unlikely but safe
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = MctsConfig {
|
||||||
|
n_simulations: 5, // tiny for speed
|
||||||
|
dirichlet_alpha: 0.0,
|
||||||
|
dirichlet_eps: 0.0,
|
||||||
|
..MctsConfig::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
||||||
|
assert!(root.n > 0);
|
||||||
|
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||||
|
assert_eq!(total, 5);
|
||||||
|
}
|
||||||
|
}
|
||||||
91
spiel_bot/src/mcts/node.rs
Normal file
91
spiel_bot/src/mcts/node.rs
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
//! MCTS tree node.
|
||||||
|
//!
|
||||||
|
//! [`MctsNode`] holds the visit statistics for one player-decision position in
|
||||||
|
//! the search tree. A node is *expanded* the first time the policy-value
|
||||||
|
//! network is evaluated there; before that it is a leaf.
|
||||||
|
|
||||||
|
/// One node in the MCTS tree, representing a player-decision position.
|
||||||
|
///
|
||||||
|
/// `w` stores the sum of values backed up into this node, always from the
|
||||||
|
/// perspective of **the player who acts here**. `q()` therefore also returns
|
||||||
|
/// a value in `(-1, 1)` from that same perspective.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MctsNode {
|
||||||
|
/// Visit count `N(s, a)`.
|
||||||
|
pub n: u32,
|
||||||
|
/// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective.
|
||||||
|
pub w: f32,
|
||||||
|
/// Prior probability `P(s, a)` assigned by the policy head (after masked softmax).
|
||||||
|
pub p: f32,
|
||||||
|
/// Children: `(action_index, child_node)`, populated on first expansion.
|
||||||
|
pub children: Vec<(usize, MctsNode)>,
|
||||||
|
/// `true` after the network has been evaluated and children have been set up.
|
||||||
|
pub expanded: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MctsNode {
|
||||||
|
/// Create a fresh, unexpanded leaf with the given prior probability.
|
||||||
|
pub fn new(prior: f32) -> Self {
|
||||||
|
Self {
|
||||||
|
n: 0,
|
||||||
|
w: 0.0,
|
||||||
|
p: prior,
|
||||||
|
children: Vec::new(),
|
||||||
|
expanded: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `Q(s, a) = W / N`, or `0.0` if this node has never been visited.
|
||||||
|
#[inline]
|
||||||
|
pub fn q(&self) -> f32 {
|
||||||
|
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PUCT selection score:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a))
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
|
||||||
|
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn q_zero_when_unvisited() {
|
||||||
|
let node = MctsNode::new(0.5);
|
||||||
|
assert_eq!(node.q(), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn q_reflects_w_over_n() {
|
||||||
|
let mut node = MctsNode::new(0.5);
|
||||||
|
node.n = 4;
|
||||||
|
node.w = 2.0;
|
||||||
|
assert!((node.q() - 0.5).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn puct_exploration_dominates_unvisited() {
|
||||||
|
// Unvisited child should outscore a visited child with negative Q.
|
||||||
|
let mut visited = MctsNode::new(0.5);
|
||||||
|
visited.n = 10;
|
||||||
|
visited.w = -5.0; // Q = -0.5
|
||||||
|
|
||||||
|
let unvisited = MctsNode::new(0.5);
|
||||||
|
|
||||||
|
let parent_n = 10;
|
||||||
|
let c = 1.5;
|
||||||
|
assert!(
|
||||||
|
unvisited.puct(parent_n, c) > visited.puct(parent_n, c),
|
||||||
|
"unvisited child should have higher PUCT than a negatively-valued visited child"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
170
spiel_bot/src/mcts/search.rs
Normal file
170
spiel_bot/src/mcts/search.rs
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
//! Simulation, expansion, backup, and noise helpers.
|
||||||
|
//!
|
||||||
|
//! These are internal to the `mcts` module; the public entry points are
|
||||||
|
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
use rand_distr::{Gamma, Distribution};
|
||||||
|
|
||||||
|
use crate::env::GameEnv;
|
||||||
|
use super::{Evaluator, MctsConfig};
|
||||||
|
use super::node::MctsNode;
|
||||||
|
|
||||||
|
// ── Masked softmax ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Numerically stable softmax over `legal` actions only.
|
||||||
|
///
|
||||||
|
/// Illegal logits are treated as `-∞` and receive probability `0.0`.
|
||||||
|
/// Returns a probability vector of length `action_space`.
|
||||||
|
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
|
||||||
|
let mut probs = vec![0.0f32; action_space];
|
||||||
|
if legal.is_empty() {
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
let max_logit = legal
|
||||||
|
.iter()
|
||||||
|
.map(|&a| logits[a])
|
||||||
|
.fold(f32::NEG_INFINITY, f32::max);
|
||||||
|
let mut sum = 0.0f32;
|
||||||
|
for &a in legal {
|
||||||
|
let e = (logits[a] - max_logit).exp();
|
||||||
|
probs[a] = e;
|
||||||
|
sum += e;
|
||||||
|
}
|
||||||
|
if sum > 0.0 {
|
||||||
|
for &a in legal {
|
||||||
|
probs[a] /= sum;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let uniform = 1.0 / legal.len() as f32;
|
||||||
|
for &a in legal {
|
||||||
|
probs[a] = uniform;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
probs
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Dirichlet noise ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
|
||||||
|
///
|
||||||
|
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
|
||||||
|
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
|
||||||
|
pub(super) fn add_dirichlet_noise(
|
||||||
|
node: &mut MctsNode,
|
||||||
|
alpha: f32,
|
||||||
|
eps: f32,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) {
|
||||||
|
let n = node.children.len();
|
||||||
|
if n == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
|
||||||
|
let sum: f32 = samples.iter().sum();
|
||||||
|
if sum <= 0.0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (i, (_, child)) in node.children.iter_mut().enumerate() {
|
||||||
|
let noise = samples[i] / sum;
|
||||||
|
child.p = (1.0 - eps) * child.p + eps * noise;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Expansion ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Evaluate the network at `state` and populate `node` with children.
|
||||||
|
///
|
||||||
|
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
|
||||||
|
/// Returns the network value estimate from `player_idx`'s perspective.
|
||||||
|
pub(super) fn expand<E: GameEnv>(
|
||||||
|
node: &mut MctsNode,
|
||||||
|
state: &E::State,
|
||||||
|
env: &E,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
player_idx: usize,
|
||||||
|
) -> f32 {
|
||||||
|
let obs = env.observation(state, player_idx);
|
||||||
|
let legal = env.legal_actions(state);
|
||||||
|
let (logits, value) = evaluator.evaluate(&obs);
|
||||||
|
let priors = masked_softmax(&logits, &legal, env.action_space());
|
||||||
|
node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
|
||||||
|
node.expanded = true;
|
||||||
|
node.n = 1;
|
||||||
|
node.w = value;
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Simulation ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One MCTS simulation from an **already-expanded** decision node.
|
||||||
|
///
|
||||||
|
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
|
||||||
|
/// and backs up the result.
|
||||||
|
///
|
||||||
|
/// * `player_idx` — the player (0 or 1) who acts at `state`.
|
||||||
|
/// * Returns the backed-up value **from `player_idx`'s perspective**.
|
||||||
|
pub(super) fn simulate<E: GameEnv>(
|
||||||
|
node: &mut MctsNode,
|
||||||
|
state: E::State,
|
||||||
|
env: &E,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
config: &MctsConfig,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
player_idx: usize,
|
||||||
|
) -> f32 {
|
||||||
|
debug_assert!(node.expanded, "simulate called on unexpanded node");
|
||||||
|
|
||||||
|
// ── Selection: child with highest PUCT ────────────────────────────────
|
||||||
|
let parent_n = node.n;
|
||||||
|
let best = node
|
||||||
|
.children
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, (_, a)), (_, (_, b))| {
|
||||||
|
a.puct(parent_n, config.c_puct)
|
||||||
|
.partial_cmp(&b.puct(parent_n, config.c_puct))
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.expect("expanded node must have at least one child");
|
||||||
|
|
||||||
|
let (action, child) = &mut node.children[best];
|
||||||
|
let action = *action;
|
||||||
|
|
||||||
|
// ── Apply action + advance through any chance nodes ───────────────────
|
||||||
|
let mut next_state = state;
|
||||||
|
env.apply(&mut next_state, action);
|
||||||
|
while env.current_player(&next_state).is_chance() {
|
||||||
|
env.apply_chance(&mut next_state, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
let next_cp = env.current_player(&next_state);
|
||||||
|
|
||||||
|
// ── Evaluate leaf or terminal ──────────────────────────────────────────
|
||||||
|
// All values are converted to `player_idx`'s perspective before backup.
|
||||||
|
let child_value = if next_cp.is_terminal() {
|
||||||
|
let returns = env
|
||||||
|
.returns(&next_state)
|
||||||
|
.expect("terminal node must have returns");
|
||||||
|
returns[player_idx]
|
||||||
|
} else {
|
||||||
|
let child_player = next_cp.index().unwrap();
|
||||||
|
let v = if child.expanded {
|
||||||
|
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
||||||
|
} else {
|
||||||
|
expand::<E>(child, &next_state, env, evaluator, child_player)
|
||||||
|
};
|
||||||
|
// Negate when the child belongs to the opponent.
|
||||||
|
if child_player == player_idx { v } else { -v }
|
||||||
|
};
|
||||||
|
|
||||||
|
// ── Backup ────────────────────────────────────────────────────────────
|
||||||
|
node.n += 1;
|
||||||
|
node.w += child_value;
|
||||||
|
|
||||||
|
child_value
|
||||||
|
}
|
||||||
223
spiel_bot/src/network/mlp.rs
Normal file
223
spiel_bot/src/network/mlp.rs
Normal file
|
|
@ -0,0 +1,223 @@
|
||||||
|
//! Two-hidden-layer MLP policy-value network.
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! Input [B, obs_size]
|
||||||
|
//! → Linear(obs → hidden) → ReLU
|
||||||
|
//! → Linear(hidden → hidden) → ReLU
|
||||||
|
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
|
||||||
|
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
module::Module,
|
||||||
|
nn::{Linear, LinearConfig},
|
||||||
|
record::{CompactRecorder, Recorder},
|
||||||
|
tensor::{
|
||||||
|
activation::{relu, tanh},
|
||||||
|
backend::Backend,
|
||||||
|
Tensor,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use super::PolicyValueNet;
|
||||||
|
|
||||||
|
// ── Config ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Configuration for [`MlpNet`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MlpConfig {
|
||||||
|
/// Number of input features. 217 for Trictrac's `to_tensor()`.
|
||||||
|
pub obs_size: usize,
|
||||||
|
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
|
||||||
|
pub action_size: usize,
|
||||||
|
/// Width of both hidden layers.
|
||||||
|
pub hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MlpConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
obs_size: 217,
|
||||||
|
action_size: 514,
|
||||||
|
hidden_size: 256,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Network ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Simple two-hidden-layer MLP with shared trunk and two heads.
|
||||||
|
///
|
||||||
|
/// Prefer this over [`ResNet`](super::ResNet) when training time is a
|
||||||
|
/// priority, or as a fast baseline.
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct MlpNet<B: Backend> {
|
||||||
|
fc1: Linear<B>,
|
||||||
|
fc2: Linear<B>,
|
||||||
|
policy_head: Linear<B>,
|
||||||
|
value_head: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> MlpNet<B> {
|
||||||
|
/// Construct a fresh network with random weights.
|
||||||
|
pub fn new(config: &MlpConfig, device: &B::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
|
||||||
|
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
|
||||||
|
policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
|
||||||
|
value_head: LinearConfig::new(config.hidden_size, 1).init(device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||||
|
///
|
||||||
|
/// The file is written exactly at `path`; callers should append `.mpk` if
|
||||||
|
/// they want the conventional extension.
|
||||||
|
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||||
|
CompactRecorder::new()
|
||||||
|
.record(self.clone().into_record(), path.to_path_buf())
|
||||||
|
.map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load weights from `path` into a fresh model built from `config`.
|
||||||
|
pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||||
|
let record = CompactRecorder::new()
|
||||||
|
.load(path.to_path_buf(), device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?;
|
||||||
|
Ok(Self::new(config, device).load_record(record))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> PolicyValueNet<B> for MlpNet<B> {
|
||||||
|
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
||||||
|
let x = relu(self.fc1.forward(obs));
|
||||||
|
let x = relu(self.fc2.forward(x));
|
||||||
|
let policy = self.policy_head.forward(x.clone());
|
||||||
|
let value = tanh(self.value_head.forward(x));
|
||||||
|
(policy, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn::backend::NdArray;
|
||||||
|
|
||||||
|
type B = NdArray<f32>;
|
||||||
|
|
||||||
|
fn device() -> <B as Backend>::Device {
|
||||||
|
Default::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_net() -> MlpNet<B> {
|
||||||
|
MlpNet::new(&MlpConfig::default(), &device())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_obs(batch: usize) -> Tensor<B, 2> {
|
||||||
|
Tensor::zeros([batch, 217], &device())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Shape tests ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_output_shapes() {
|
||||||
|
let net = default_net();
|
||||||
|
let obs = zeros_obs(4);
|
||||||
|
let (policy, value) = net.forward(obs);
|
||||||
|
|
||||||
|
assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
|
||||||
|
assert_eq!(value.dims(), [4, 1], "value shape mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_single_sample() {
|
||||||
|
let net = default_net();
|
||||||
|
let (policy, value) = net.forward(zeros_obs(1));
|
||||||
|
assert_eq!(policy.dims(), [1, 514]);
|
||||||
|
assert_eq!(value.dims(), [1, 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Value bounds ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn value_in_tanh_range() {
|
||||||
|
let net = default_net();
|
||||||
|
// Use a non-zero input so the output is not trivially at 0.
|
||||||
|
let obs = Tensor::<B, 2>::ones([8, 217], &device());
|
||||||
|
let (_, value) = net.forward(obs);
|
||||||
|
let data: Vec<f32> = value.into_data().to_vec().unwrap();
|
||||||
|
for v in &data {
|
||||||
|
assert!(
|
||||||
|
*v > -1.0 && *v < 1.0,
|
||||||
|
"value {v} is outside open interval (-1, 1)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Policy logits ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn policy_logits_not_all_equal() {
|
||||||
|
// With random weights the 514 logits should not all be identical.
|
||||||
|
let net = default_net();
|
||||||
|
let (policy, _) = net.forward(zeros_obs(1));
|
||||||
|
let data: Vec<f32> = policy.into_data().to_vec().unwrap();
|
||||||
|
let first = data[0];
|
||||||
|
let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
|
||||||
|
assert!(!all_same, "all policy logits are identical — network may be degenerate");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Config propagation ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn custom_config_shapes() {
|
||||||
|
let config = MlpConfig {
|
||||||
|
obs_size: 10,
|
||||||
|
action_size: 20,
|
||||||
|
hidden_size: 32,
|
||||||
|
};
|
||||||
|
let net = MlpNet::<B>::new(&config, &device());
|
||||||
|
let obs = Tensor::zeros([3, 10], &device());
|
||||||
|
let (policy, value) = net.forward(obs);
|
||||||
|
assert_eq!(policy.dims(), [3, 20]);
|
||||||
|
assert_eq!(value.dims(), [3, 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Save / Load ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_load_preserves_weights() {
|
||||||
|
let config = MlpConfig::default();
|
||||||
|
let net = default_net();
|
||||||
|
|
||||||
|
// Forward pass before saving.
|
||||||
|
let obs = Tensor::<B, 2>::ones([2, 217], &device());
|
||||||
|
let (policy_before, value_before) = net.forward(obs.clone());
|
||||||
|
|
||||||
|
// Save to a temp file.
|
||||||
|
let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk");
|
||||||
|
net.save(&path).expect("save failed");
|
||||||
|
|
||||||
|
// Load into a fresh model.
|
||||||
|
let loaded = MlpNet::<B>::load(&config, &path, &device()).expect("load failed");
|
||||||
|
let (policy_after, value_after) = loaded.forward(obs);
|
||||||
|
|
||||||
|
// Outputs must be bitwise identical.
|
||||||
|
let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
|
||||||
|
let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
|
||||||
|
for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
|
||||||
|
assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
|
||||||
|
}
|
||||||
|
|
||||||
|
let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
|
||||||
|
let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
|
||||||
|
for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
|
||||||
|
assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
64
spiel_bot/src/network/mod.rs
Normal file
64
spiel_bot/src/network/mod.rs
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
//! Neural network abstractions for policy-value learning.
|
||||||
|
//!
|
||||||
|
//! # Trait
|
||||||
|
//!
|
||||||
|
//! [`PolicyValueNet<B>`] is the single trait that all network architectures
|
||||||
|
//! implement. It takes an observation tensor and returns raw policy logits
|
||||||
|
//! plus a tanh-squashed scalar value estimate.
|
||||||
|
//!
|
||||||
|
//! # Architectures
|
||||||
|
//!
|
||||||
|
//! | Module | Description | Default hidden |
|
||||||
|
//! |--------|-------------|----------------|
|
||||||
|
//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
|
||||||
|
//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
|
||||||
|
//!
|
||||||
|
//! # Backend convention
|
||||||
|
//!
|
||||||
|
//! * **Inference / self-play** — use `NdArray<f32>` (no autodiff overhead).
|
||||||
|
//! * **Training** — use `Autodiff<NdArray<f32>>` so Burn can differentiate
|
||||||
|
//! through the forward pass.
|
||||||
|
//!
|
||||||
|
//! Both modes use the exact same struct; only the type-level backend changes:
|
||||||
|
//!
|
||||||
|
//! ```rust,ignore
|
||||||
|
//! use burn::backend::{Autodiff, NdArray};
|
||||||
|
//! type InferBackend = NdArray<f32>;
|
||||||
|
//! type TrainBackend = Autodiff<NdArray<f32>>;
|
||||||
|
//!
|
||||||
|
//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
|
||||||
|
//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! # Output shapes
|
||||||
|
//!
|
||||||
|
//! Given a batch of `B` observations of size `obs_size`:
|
||||||
|
//!
|
||||||
|
//! | Output | Shape | Range |
|
||||||
|
//! |--------|-------|-------|
|
||||||
|
//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) |
|
||||||
|
//! | `value` | `[B, 1]` | (-1, 1) via tanh |
|
||||||
|
//!
|
||||||
|
//! Callers are responsible for masking illegal actions in `policy_logits`
|
||||||
|
//! before passing to softmax.
|
||||||
|
|
||||||
|
pub mod mlp;
|
||||||
|
pub mod resnet;
|
||||||
|
|
||||||
|
pub use mlp::{MlpConfig, MlpNet};
|
||||||
|
pub use resnet::{ResNet, ResNetConfig};
|
||||||
|
|
||||||
|
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
||||||
|
|
||||||
|
/// A neural network that produces a policy and a value from an observation.
|
||||||
|
///
|
||||||
|
/// # Shapes
|
||||||
|
/// - `obs`: `[batch, obs_size]`
|
||||||
|
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
|
||||||
|
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
|
||||||
|
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
|
||||||
|
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
|
||||||
|
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
|
||||||
|
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
|
||||||
|
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
|
||||||
|
}
|
||||||
253
spiel_bot/src/network/resnet.rs
Normal file
253
spiel_bot/src/network/resnet.rs
Normal file
|
|
@ -0,0 +1,253 @@
|
||||||
|
//! Residual-block policy-value network.
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! Input [B, obs_size]
|
||||||
|
//! → Linear(obs → hidden) → ReLU (input projection)
|
||||||
|
//! → ResBlock × 4 (residual trunk)
|
||||||
|
//! ├─ policy_head: Linear(hidden → action_size) [raw logits]
|
||||||
|
//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)]
|
||||||
|
//!
|
||||||
|
//! ResBlock:
|
||||||
|
//! x → Linear → ReLU → Linear → (+x) → ReLU
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better
|
||||||
|
//! suited for long training runs where board-pattern recognition matters.
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
module::Module,
|
||||||
|
nn::{Linear, LinearConfig},
|
||||||
|
record::{CompactRecorder, Recorder},
|
||||||
|
tensor::{
|
||||||
|
activation::{relu, tanh},
|
||||||
|
backend::Backend,
|
||||||
|
Tensor,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use super::PolicyValueNet;
|
||||||
|
|
||||||
|
// ── Config ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Configuration for [`ResNet`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ResNetConfig {
|
||||||
|
/// Number of input features. 217 for Trictrac's `to_tensor()`.
|
||||||
|
pub obs_size: usize,
|
||||||
|
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
|
||||||
|
pub action_size: usize,
|
||||||
|
/// Width of all hidden layers (input projection + residual blocks).
|
||||||
|
pub hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ResNetConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
obs_size: 217,
|
||||||
|
action_size: 514,
|
||||||
|
hidden_size: 512,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Residual block ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`.
|
||||||
|
///
|
||||||
|
/// Both linear layers preserve the hidden dimension so the skip connection
|
||||||
|
/// can be added without projection.
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
struct ResBlock<B: Backend> {
|
||||||
|
fc1: Linear<B>,
|
||||||
|
fc2: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ResBlock<B> {
|
||||||
|
fn new(hidden: usize, device: &B::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
fc1: LinearConfig::new(hidden, hidden).init(device),
|
||||||
|
fc2: LinearConfig::new(hidden, hidden).init(device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let residual = x.clone();
|
||||||
|
let out = relu(self.fc1.forward(x));
|
||||||
|
relu(self.fc2.forward(out) + residual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Network ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Four-residual-block policy-value network.
|
||||||
|
///
|
||||||
|
/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and
|
||||||
|
/// when representing complex positional patterns is important.
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct ResNet<B: Backend> {
|
||||||
|
input: Linear<B>,
|
||||||
|
block0: ResBlock<B>,
|
||||||
|
block1: ResBlock<B>,
|
||||||
|
block2: ResBlock<B>,
|
||||||
|
block3: ResBlock<B>,
|
||||||
|
policy_head: Linear<B>,
|
||||||
|
value_head: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ResNet<B> {
|
||||||
|
/// Construct a fresh network with random weights.
|
||||||
|
pub fn new(config: &ResNetConfig, device: &B::Device) -> Self {
|
||||||
|
let h = config.hidden_size;
|
||||||
|
Self {
|
||||||
|
input: LinearConfig::new(config.obs_size, h).init(device),
|
||||||
|
block0: ResBlock::new(h, device),
|
||||||
|
block1: ResBlock::new(h, device),
|
||||||
|
block2: ResBlock::new(h, device),
|
||||||
|
block3: ResBlock::new(h, device),
|
||||||
|
policy_head: LinearConfig::new(h, config.action_size).init(device),
|
||||||
|
value_head: LinearConfig::new(h, 1).init(device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||||
|
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||||
|
CompactRecorder::new()
|
||||||
|
.record(self.clone().into_record(), path.to_path_buf())
|
||||||
|
.map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load weights from `path` into a fresh model built from `config`.
|
||||||
|
pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||||
|
let record = CompactRecorder::new()
|
||||||
|
.load(path.to_path_buf(), device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?;
|
||||||
|
Ok(Self::new(config, device).load_record(record))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> PolicyValueNet<B> for ResNet<B> {
|
||||||
|
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
|
||||||
|
let x = relu(self.input.forward(obs));
|
||||||
|
let x = self.block0.forward(x);
|
||||||
|
let x = self.block1.forward(x);
|
||||||
|
let x = self.block2.forward(x);
|
||||||
|
let x = self.block3.forward(x);
|
||||||
|
let policy = self.policy_head.forward(x.clone());
|
||||||
|
let value = tanh(self.value_head.forward(x));
|
||||||
|
(policy, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn::backend::NdArray;
|
||||||
|
|
||||||
|
type B = NdArray<f32>;
|
||||||
|
|
||||||
|
fn device() -> <B as Backend>::Device {
|
||||||
|
Default::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn small_config() -> ResNetConfig {
|
||||||
|
// Use a small hidden size so tests are fast.
|
||||||
|
ResNetConfig {
|
||||||
|
obs_size: 217,
|
||||||
|
action_size: 514,
|
||||||
|
hidden_size: 64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn net() -> ResNet<B> {
|
||||||
|
ResNet::new(&small_config(), &device())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Shape tests ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_output_shapes() {
|
||||||
|
let obs = Tensor::zeros([4, 217], &device());
|
||||||
|
let (policy, value) = net().forward(obs);
|
||||||
|
assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
|
||||||
|
assert_eq!(value.dims(), [4, 1], "value shape mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_single_sample() {
|
||||||
|
let (policy, value) = net().forward(Tensor::zeros([1, 217], &device()));
|
||||||
|
assert_eq!(policy.dims(), [1, 514]);
|
||||||
|
assert_eq!(value.dims(), [1, 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Value bounds ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn value_in_tanh_range() {
|
||||||
|
let obs = Tensor::<B, 2>::ones([8, 217], &device());
|
||||||
|
let (_, value) = net().forward(obs);
|
||||||
|
let data: Vec<f32> = value.into_data().to_vec().unwrap();
|
||||||
|
for v in &data {
|
||||||
|
assert!(
|
||||||
|
*v > -1.0 && *v < 1.0,
|
||||||
|
"value {v} is outside open interval (-1, 1)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Residual connections ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn policy_logits_not_all_equal() {
|
||||||
|
let (policy, _) = net().forward(Tensor::zeros([1, 217], &device()));
|
||||||
|
let data: Vec<f32> = policy.into_data().to_vec().unwrap();
|
||||||
|
let first = data[0];
|
||||||
|
let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
|
||||||
|
assert!(!all_same, "all policy logits are identical");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Save / Load ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_load_preserves_weights() {
|
||||||
|
let config = small_config();
|
||||||
|
let model = net();
|
||||||
|
let obs = Tensor::<B, 2>::ones([2, 217], &device());
|
||||||
|
|
||||||
|
let (policy_before, value_before) = model.forward(obs.clone());
|
||||||
|
|
||||||
|
let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk");
|
||||||
|
model.save(&path).expect("save failed");
|
||||||
|
|
||||||
|
let loaded = ResNet::<B>::load(&config, &path, &device()).expect("load failed");
|
||||||
|
let (policy_after, value_after) = loaded.forward(obs);
|
||||||
|
|
||||||
|
let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
|
||||||
|
let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
|
||||||
|
for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
|
||||||
|
assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
|
||||||
|
}
|
||||||
|
|
||||||
|
let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
|
||||||
|
let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
|
||||||
|
for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
|
||||||
|
assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Integration: both architectures satisfy PolicyValueNet ────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resnet_satisfies_trait() {
|
||||||
|
fn requires_net<B: Backend, N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) {
|
||||||
|
let (p, v) = net.forward(obs);
|
||||||
|
assert_eq!(p.dims()[1], 514);
|
||||||
|
assert_eq!(v.dims()[1], 1);
|
||||||
|
}
|
||||||
|
requires_net(&net(), Tensor::zeros([2, 217], &device()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -225,22 +225,26 @@ impl GameState {
|
||||||
let mut t = Vec::with_capacity(217);
|
let mut t = Vec::with_capacity(217);
|
||||||
let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
|
let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
|
||||||
|
|
||||||
// [0..95] own (White) checkers, TD-Gammon encoding
|
// [0..95] own (White) checkers, TD-Gammon encoding.
|
||||||
|
// Each field contributes 4 values:
|
||||||
|
// (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1]
|
||||||
|
// The overflow term is divided by 12 because the maximum excess is
|
||||||
|
// 15 (all checkers) − 3 = 12.
|
||||||
for &c in &pos {
|
for &c in &pos {
|
||||||
let own = c.max(0) as u8;
|
let own = c.max(0) as u8;
|
||||||
t.push((own == 1) as u8 as f32);
|
t.push((own == 1) as u8 as f32);
|
||||||
t.push((own == 2) as u8 as f32);
|
t.push((own == 2) as u8 as f32);
|
||||||
t.push((own == 3) as u8 as f32);
|
t.push((own == 3) as u8 as f32);
|
||||||
t.push(own.saturating_sub(3) as f32);
|
t.push(own.saturating_sub(3) as f32 / 12.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// [96..191] opp (Black) checkers, TD-Gammon encoding
|
// [96..191] opp (Black) checkers, same encoding.
|
||||||
for &c in &pos {
|
for &c in &pos {
|
||||||
let opp = (-c).max(0) as u8;
|
let opp = (-c).max(0) as u8;
|
||||||
t.push((opp == 1) as u8 as f32);
|
t.push((opp == 1) as u8 as f32);
|
||||||
t.push((opp == 2) as u8 as f32);
|
t.push((opp == 2) as u8 as f32);
|
||||||
t.push((opp == 3) as u8 as f32);
|
t.push((opp == 3) as u8 as f32);
|
||||||
t.push(opp.saturating_sub(3) as f32);
|
t.push(opp.saturating_sub(3) as f32 / 12.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// [192..193] dice
|
// [192..193] dice
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue