diff --git a/Cargo.lock b/Cargo.lock index a43261e..0baa02a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5891,6 +5891,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spiel_bot" +version = "0.1.0" +dependencies = [ + "anyhow", + "burn", + "rand 0.9.2", + "rand_distr", + "trictrac-store", +] + [[package]] name = "spin" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index b9e6d45..4c2eb15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store"] +members = ["client_cli", "bot", "store", "spiel_bot"] diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md new file mode 100644 index 0000000..7e5ed1f --- /dev/null +++ b/doc/spiel_bot_research.md @@ -0,0 +1,707 @@ +# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac + +## 0. Context and Scope + +The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library +(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` +encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every +other stage to an inline random-opponent loop. + +`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency +for **self-play training**. Its goals: + +- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") + that works with Trictrac's multi-stage turn model and stochastic dice. +- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) + as the first algorithm. +- Remain **modular**: adding DQN or PPO later requires only a new + `impl Algorithm for Dqn` without touching the environment or network layers. +- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from + `trictrac-store`. + +--- + +## 1. 
Library Landscape + +### 1.1 Neural Network Frameworks + +| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | +|-------|----------|-----|-----------|----------|-------| +| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | +| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | +| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | +| ndarray alone | no | no | yes | mature | array ops only; no autograd | + +**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ +runtime needed, the `ndarray` backend is sufficient for CPU training and can +switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by +changing one type alias. + +`tch-rs` would be the best choice for raw training throughput (it is the most +battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the +pure-Rust constraint. If training speed becomes the bottleneck after prototyping, +switching `spiel_bot` to `tch-rs` is a one-line backend swap. + +### 1.2 Other Key Crates + +| Crate | Role | +|-------|------| +| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | +| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | +| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | +| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | +| `anyhow` | error propagation (already used everywhere) | +| `indicatif` | training progress bars | +| `tracing` | structured logging per episode/iteration | + +### 1.3 What `burn-rl` Provides (and Does Not) + +The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) +provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` +trait. 
It does **not** provide: + +- MCTS or any tree-search algorithm +- Two-player self-play +- Legal action masking during training +- Chance-node handling + +For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its +own (simpler, more targeted) traits and implement MCTS from scratch. + +--- + +## 2. Trictrac-Specific Design Constraints + +### 2.1 Multi-Stage Turn Model + +A Trictrac turn passes through up to six `TurnStage` values. Only two involve +genuine player choice: + +| TurnStage | Node type | Handler | +|-----------|-----------|---------| +| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | +| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | +| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | +| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | +| `Move` | **Player decision** | MCTS / policy network | +| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | + +The environment wrapper advances through forced/chance stages automatically so +that from the algorithm's perspective every node it sees is a genuine player +decision. + +### 2.2 Stochastic Dice in MCTS + +AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice +introduce stochasticity. Three approaches exist: + +**A. Outcome sampling (recommended)** +During each MCTS simulation, when a chance node is reached, sample one dice +outcome at random and continue. After many simulations the expected value +converges. This is the approach used by OpenSpiel's MCTS for stochastic games +and requires no changes to the standard PUCT formula. + +**B. Chance-node averaging (expectimax)** +At each chance node, expand all 21 unique dice pairs weighted by their +probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is +exact but multiplies the branching factor by ~21 at every dice roll, making it +prohibitively expensive. + +**C. 
Condition on dice in the observation (current approach)** +Dice values are already encoded at indices [192–193] of `to_tensor()`. The +network naturally conditions on the rolled dice when it evaluates a position. +MCTS only runs on player-decision nodes *after* the dice have been sampled; +chance nodes are bypassed by the environment wrapper (approach A). The policy +and value heads learn to play optimally given any dice pair. + +**Use approach A + C together**: the environment samples dice automatically +(chance node bypass), and the 217-dim tensor encodes the dice so the network +can exploit them. + +### 2.3 Perspective / Mirroring + +All move rules and tensor encoding are defined from White's perspective. +`to_tensor()` must always be called after mirroring the state for Black. +The environment wrapper handles this transparently: every observation returned +to an algorithm is already in the active player's perspective. + +### 2.4 Legal Action Masking + +A crucial difference from the existing `bot/` code: instead of penalizing +invalid actions with `ERROR_REWARD`, the policy head logits are **masked** +before softmax — illegal action logits are set to `-inf`. This prevents the +network from wasting capacity on illegal moves and eliminates the need for the +penalty-reward hack. + +--- + +## 3. 
Proposed Crate Architecture + +``` +spiel_bot/ +├── Cargo.toml +└── src/ + ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" + │ + ├── env/ + │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface + │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store + │ + ├── mcts/ + │ ├── mod.rs # MctsConfig, run_mcts() entry point + │ ├── node.rs # MctsNode (visit count, W, prior, children) + │ └── search.rs # simulate(), backup(), select_action() + │ + ├── network/ + │ ├── mod.rs # PolicyValueNet trait + │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads + │ + ├── alphazero/ + │ ├── mod.rs # AlphaZeroConfig + │ ├── selfplay.rs # generate_episode() -> Vec + │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) + │ └── trainer.rs # training loop: selfplay → sample → loss → update + │ + └── agent/ + ├── mod.rs # Agent trait + ├── random.rs # RandomAgent (baseline) + └── mcts_agent.rs # MctsAgent: uses trained network for inference +``` + +Future algorithms slot in without touching the above: + +``` + ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer + └── ppo/ # (future) PPO: impl Algorithm + rollout buffer +``` + +--- + +## 4. Core Traits + +### 4.1 `GameEnv` — the minimal OpenSpiel interface + +```rust +use rand::Rng; + +/// Who controls the current node. +pub enum Player { + P1, // player index 0 + P2, // player index 1 + Chance, // dice roll + Terminal, // game over +} + +pub trait GameEnv: Clone + Send + Sync + 'static { + type State: Clone + Send + Sync; + + /// Fresh game state. + fn new_game(&self) -> Self::State; + + /// Who acts at this node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices (always in [0, action_space())). + /// Empty only at Terminal nodes. + fn legal_actions(&self, s: &Self::State) -> Vec; + + /// Apply a player action (must be legal). 
+    fn apply(&self, s: &mut Self::State, action: usize);
+
+    /// Advance a Chance node by sampling dice; no-op at non-Chance nodes.
+    fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng);
+
+    /// Observation tensor from `pov`'s perspective (0 or 1).
+    /// Returns 217 f32 values for Trictrac.
+    fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
+
+    /// Flat observation size (217 for Trictrac).
+    fn obs_size(&self) -> usize;
+
+    /// Total action-space size (514 for Trictrac).
+    fn action_space(&self) -> usize;
+
+    /// Game outcome per player, or None if not Terminal.
+    /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw.
+    fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
+}
+```
+
+### 4.2 `PolicyValueNet` — neural network interface
+
+```rust
+use burn::prelude::*;
+
+pub trait PolicyValueNet<B: Backend>: Send + Sync {
+    /// Forward pass.
+    /// `obs`: [batch, obs_size] tensor.
+    /// Returns: (policy_logits [batch, action_space], value [batch]).
+    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
+
+    /// Save weights to `path`.
+    fn save(&self, path: &std::path::Path) -> anyhow::Result<()>;
+
+    /// Load weights from `path`.
+    fn load(path: &std::path::Path) -> anyhow::Result<Self>
+    where
+        Self: Sized;
+}
+```
+
+### 4.3 `Agent` — player policy interface
+
+```rust
+pub trait Agent: Send {
+    /// Select an action index given the current game state observation.
+    /// `legal`: mask of valid action indices.
+    fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize;
+}
+```
+
+---
+
+## 5. MCTS Implementation
+
+### 5.1 Node
+
+```rust
+pub struct MctsNode {
+    n: u32,                           // visit count N(s, a)
+    w: f32,                           // sum of backed-up values W(s, a)
+    p: f32,                           // prior from policy head P(s, a)
+    children: Vec<(usize, MctsNode)>, // (action_idx, child)
+    is_expanded: bool,
+}
+
+impl MctsNode {
+    pub fn q(&self) -> f32 {
+        if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
+    }
+
+    /// PUCT score used for selection.
+ pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} +``` + +### 5.2 Simulation Loop + +One MCTS simulation (for deterministic decision nodes): + +``` +1. SELECTION — traverse from root, always pick child with highest PUCT, + auto-advancing forced/chance nodes via env.apply_chance(). +2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get + (policy_logits, value). Mask illegal actions, softmax + the remaining logits → priors P(s,a) for each child. +3. BACKUP — propagate -value up the tree (negate at each level because + perspective alternates between P1 and P2). +``` + +After `n_simulations` iterations, action selection at the root: + +```rust +// During training: sample proportional to N^(1/temperature) +// During evaluation: argmax N +fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } +``` + +### 5.3 Configuration + +```rust +pub struct MctsConfig { + pub n_simulations: usize, // e.g. 200 + pub c_puct: f32, // exploration constant, e.g. 1.5 + pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 + pub dirichlet_eps: f32, // noise weight, e.g. 0.25 + pub temperature: f32, // action sampling temperature (anneals to 0) +} +``` + +### 5.4 Handling Chance Nodes Inside MCTS + +When simulation reaches a Chance node (dice roll), the environment automatically +samples dice and advances to the next decision node. The MCTS tree does **not** +branch on dice outcomes — it treats the sampled outcome as the state. This +corresponds to "outcome sampling" (approach A from §2.2). Because each +simulation independently samples dice, the Q-values at player nodes converge to +their expected value over many simulations. + +--- + +## 6. 
Network Architecture
+
+### 6.1 ResNet Policy-Value Network
+
+A single trunk with residual blocks, then two heads:
+
+```
+Input: [batch, 217]
+  ↓
+Linear(217 → 512) + ReLU
+  ↓
+ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU)
+  ↓ trunk output [batch, 512]
+  ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site)
+  └─ Value head:  Linear(512 → 1)  → tanh (output in [-1, 1])
+```
+
+Burn implementation sketch:
+
+```rust
+#[derive(Module, Debug)]
+pub struct TrictracNet<B: Backend> {
+    input: Linear<B>,
+    res_blocks: Vec<ResBlock<B>>,
+    policy_head: Linear<B>,
+    value_head: Linear<B>,
+}
+
+impl<B: Backend> TrictracNet<B> {
+    pub fn forward(&self, obs: Tensor<B, 2>)
+        -> (Tensor<B, 2>, Tensor<B, 1>)
+    {
+        let x = activation::relu(self.input.forward(obs));
+        let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x));
+        let policy = self.policy_head.forward(x.clone()); // raw logits
+        let value = activation::tanh(self.value_head.forward(x))
+            .squeeze(1);
+        (policy, value)
+    }
+}
+```
+
+A simpler MLP (no residual blocks) is sufficient for a first version and much
+faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two
+heads.
+
+### 6.2 Loss Function
+
+```
+L = MSE(value_pred, z)
+  + CrossEntropy(policy_logits_masked, π_mcts)
+  + c_l2 * L2_regularization
+```
+
+Where:
+- `z` = game outcome (±1) from the active player's perspective
+- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
+- `c_l2 * L2_regularization` is a weight-decay penalty **added** to the loss
+  (as in the AlphaZero loss l = (z − v)² − πᵀ log p + c‖θ‖²); subtracting it
+  would reward larger weights
+- Legal action masking is applied before computing CrossEntropy
+
+---
+
+## 7. 
AlphaZero Training Loop + +``` +INIT + network ← random weights + replay ← empty ReplayBuffer(capacity = 100_000) + +LOOP forever: + ── Self-play phase ────────────────────────────────────────────── + (parallel with rayon, n_workers games at once) + for each game: + state ← env.new_game() + samples = [] + while not terminal: + advance forced/chance nodes automatically + obs ← env.observation(state, current_player) + legal ← env.legal_actions(state) + π, root_value ← mcts.run(state, network, config) + action ← sample from π (with temperature) + samples.push((obs, π, current_player)) + env.apply(state, action) + z ← env.returns(state) // final scores + for (obs, π, player) in samples: + replay.push(TrainSample { obs, policy: π, value: z[player] }) + + ── Training phase ─────────────────────────────────────────────── + for each gradient step: + batch ← replay.sample(batch_size) + (policy_logits, value_pred) ← network.forward(batch.obs) + loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) + optimizer.step(loss.backward()) + + ── Evaluation (every N iterations) ───────────────────────────── + win_rate ← evaluate(network_new vs network_prev, n_eval_games) + if win_rate > 0.55: save checkpoint +``` + +### 7.1 Replay Buffer + +```rust +pub struct TrainSample { + pub obs: Vec, // 217 values + pub policy: Vec, // 514 values (normalized MCTS visit counts) + pub value: f32, // game outcome ∈ {-1, 0, +1} +} + +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + pub fn push(&mut self, s: TrainSample) { + if self.data.len() == self.capacity { self.data.pop_front(); } + self.data.push_back(s); + } + + pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + // sample without replacement + } +} +``` + +### 7.2 Parallelism Strategy + +Self-play is embarrassingly parallel (each game is independent): + +```rust +let samples: Vec = (0..n_games) + .into_par_iter() // rayon + .flat_map(|_| 
generate_episode(&env, &network, &mcts_config)) + .collect(); +``` + +Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. +Self-play uses inference-only (no gradient tape), so a `NdArray` backend +(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with +`Autodiff>`. + +For larger scale, a producer-consumer architecture (crossbeam-channel) separates +self-play workers from the training thread, allowing continuous data generation +while the GPU trains. + +--- + +## 8. `TrictracEnv` Implementation Sketch + +```rust +use trictrac_store::{ + training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, + Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, +}; + +#[derive(Clone)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + fn current_player(&self, s: &GameState) -> Player { + match s.stage { + Stage::Ended => Player::Terminal, + _ => match s.turn_stage { + TurnStage::RollWaiting => Player::Chance, + _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, + }, + } + } + + fn legal_actions(&self, s: &GameState) -> Vec { + let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; + get_valid_action_indices(&view).unwrap_or_default() + } + + fn apply(&self, s: &mut GameState, action_idx: usize) { + // advance all forced/chance nodes first, then apply the player action + self.advance_forced(s); + let needs_mirror = s.active_player_id == 2; + let view = if needs_mirror { s.mirror() } else { s.clone() }; + if let Some(event) = TrictracAction::from_action_index(action_idx) + .and_then(|a| a.to_event(&view)) + .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) + { + let _ = s.consume(&event); + } + // advance any forced stages that follow + self.advance_forced(s); + } + + fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { + // RollDice → 
RollWaiting + let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + // RollWaiting → next stage + let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; + let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); + self.advance_forced(s); + } + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } + } + + fn obs_size(&self) -> usize { 217 } + fn action_space(&self) -> usize { ACTION_SPACE_SIZE } + + fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { return None; } + // Convert hole+point scores to ±1 outcome + let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [ 1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [ 0.0, 0.0], + }) + } +} + +impl TrictracEnv { + /// Advance through all forced (non-decision, non-chance) stages. + fn advance_forced(&self, s: &mut GameState) { + use trictrac_store::PointsRules; + loop { + match s.turn_stage { + TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { + // Scoring is deterministic; compute and apply automatically. + let color = s.player_color_by_id(&s.active_player_id) + .unwrap_or(trictrac_store::Color::White); + let drc = s.players.get(&s.active_player_id) + .map(|p| p.dice_roll_count).unwrap_or(0); + let pr = PointsRules::new(&color, &s.board, s.dice); + let pts = pr.get_points(drc); + let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; + let _ = s.consume(&GameEvent::Mark { + player_id: s.active_player_id, points, + }); + } + TurnStage::RollDice => { + // RollDice is a forced "initiate roll" action with no real choice. 
+ let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); + } + _ => break, + } + } + } +} +``` + +--- + +## 9. Cargo.toml Changes + +### 9.1 Add `spiel_bot` to the workspace + +```toml +# Cargo.toml (workspace root) +[workspace] +resolver = "2" +members = ["client_cli", "bot", "store", "spiel_bot"] +``` + +### 9.2 `spiel_bot/Cargo.toml` + +```toml +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[features] +default = ["alphazero"] +alphazero = [] +# dqn = [] # future +# ppo = [] # future + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" +rayon = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# Burn: NdArray for pure-Rust CPU training +# Replace NdArray with Wgpu or Tch for GPU. +burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +# Optional: progress display and structured logging +indicatif = "0.17" +tracing = "0.1" + +[[bin]] +name = "az_train" +path = "src/bin/az_train.rs" + +[[bin]] +name = "az_eval" +path = "src/bin/az_eval.rs" +``` + +--- + +## 10. 
Comparison: `bot` crate vs `spiel_bot` + +| Aspect | `bot` (existing) | `spiel_bot` (proposed) | +|--------|-----------------|------------------------| +| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | +| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | +| Opponent | hardcoded random | self-play | +| Invalid actions | penalise with reward | legal action mask (no penalty) | +| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | +| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | +| Burn dep | yes (0.20) | yes (0.20), same backend | +| `burn-rl` dep | yes | no | +| C++ dep | no | no | +| Python dep | no | no | +| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | + +The two crates are **complementary**: `bot` is a working DQN/PPO baseline; +`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The +`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just +replace `TrictracEnvironment` with `TrictracEnv`). + +--- + +## 11. Implementation Order + +1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game + through the trait, verify terminal state and returns). +2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + + Burn forward/backward pass test with dummy data. +3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests + (visit counts sum to `n_simulations`, legal mask respected). +4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub + (one iteration, check loss decreases). +5. **Integration test**: run 100 self-play games with a tiny network (1 res block, + 64 hidden units), verify the training loop completes without panics. +6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s + on CPU, consistent with `random_game` throughput). +7. 
**Upgrade network**: 4 residual blocks, 512 hidden units; schedule + hyperparameter sweep. +8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report + win rate every checkpoint. + +--- + +## 12. Key Open Questions + +1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. + AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever + scored more holes). Better option: normalize the score margin. The final + choice affects how the value head is trained. + +2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). + MCTS with 200 simulations per step means ~120k network evaluations per game. + At batch inference this is feasible on CPU; on GPU it becomes fast. Consider + limiting `n_simulations` to 50–100 for early training. + +3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). + This is a long-horizon decision that AlphaZero handles naturally via MCTS + lookahead, but needs careful value normalization (a "Go" restarts scoring + within the same game). + +4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated + to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. + This is optional but reduces code duplication. + +5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, + 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. + A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. 
diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml new file mode 100644 index 0000000..323c953 --- /dev/null +++ b/spiel_bot/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "spiel_bot" +version = "0.1.0" +edition = "2021" + +[dependencies] +trictrac-store = { path = "../store" } +anyhow = "1" +rand = "0.9" +rand_distr = "0.5" +burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs new file mode 100644 index 0000000..bb86724 --- /dev/null +++ b/spiel_bot/src/alphazero/mod.rs @@ -0,0 +1,117 @@ +//! AlphaZero: self-play data generation, replay buffer, and training step. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | +//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | +//! | [`trainer`] | [`train_step`] | +//! +//! # Typical outer loop +//! +//! ```rust,ignore +//! use burn::backend::{Autodiff, NdArray}; +//! use burn::optim::AdamConfig; +//! use spiel_bot::{ +//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, +//! env::TrictracEnv, +//! mcts::MctsConfig, +//! network::{MlpConfig, MlpNet}, +//! }; +//! +//! type Infer = NdArray; +//! type Train = Autodiff>; +//! +//! let device = Default::default(); +//! let env = TrictracEnv; +//! let config = AlphaZeroConfig::default(); +//! +//! // Build training model and optimizer. +//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); +//! let mut optimizer = AdamConfig::new().init(); +//! let mut replay = ReplayBuffer::new(config.replay_capacity); +//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); +//! +//! for _iter in 0..config.n_iterations { +//! // Convert to inference backend for self-play. +//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) +//! .load_record(train_model.clone().into_record()); +//! 
let eval = BurnEvaluator::new(infer_model, device.clone()); +//! +//! // Self-play: generate episodes. +//! for _ in 0..config.n_games_per_iter { +//! let samples = generate_episode(&env, &eval, &config.mcts, +//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); +//! replay.extend(samples); +//! } +//! +//! // Training: gradient steps. +//! if replay.len() >= config.batch_size { +//! for _ in 0..config.n_train_steps_per_iter { +//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) +//! .into_iter().cloned().collect(); +//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, +//! config.learning_rate); +//! train_model = m; +//! } +//! } +//! } +//! ``` + +pub mod replay; +pub mod selfplay; +pub mod trainer; + +pub use replay::{ReplayBuffer, TrainSample}; +pub use selfplay::{BurnEvaluator, generate_episode}; +pub use trainer::train_step; + +use crate::mcts::MctsConfig; + +// ── Configuration ───────────────────────────────────────────────────────── + +/// Top-level AlphaZero hyperparameters. +/// +/// The MCTS parameters live in [`MctsConfig`]; this struct holds the +/// outer training-loop parameters. +#[derive(Debug, Clone)] +pub struct AlphaZeroConfig { + /// MCTS parameters for self-play. + pub mcts: MctsConfig, + /// Number of self-play games per training iteration. + pub n_games_per_iter: usize, + /// Number of gradient steps per training iteration. + pub n_train_steps_per_iter: usize, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Number of outer iterations (self-play + train) to run. + pub n_iterations: usize, + /// Move index after which the action temperature drops to 0 (greedy play). 
+ pub temperature_drop_move: usize, +} + +impl Default for AlphaZeroConfig { + fn default() -> Self { + Self { + mcts: MctsConfig { + n_simulations: 100, + c_puct: 1.5, + dirichlet_alpha: 0.1, + dirichlet_eps: 0.25, + temperature: 1.0, + }, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + batch_size: 64, + replay_capacity: 50_000, + learning_rate: 1e-3, + n_iterations: 100, + temperature_drop_move: 30, + } + } +} diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs new file mode 100644 index 0000000..5e64cc4 --- /dev/null +++ b/spiel_bot/src/alphazero/replay.rs @@ -0,0 +1,144 @@ +//! Replay buffer for AlphaZero self-play data. + +use std::collections::VecDeque; +use rand::Rng; + +// ── Training sample ──────────────────────────────────────────────────────── + +/// One training example produced by self-play. +#[derive(Clone, Debug)] +pub struct TrainSample { + /// Observation tensor from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1). + pub policy: Vec, + /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw. + pub value: f32, +} + +// ── Replay buffer ────────────────────────────────────────────────────────── + +/// Fixed-capacity circular buffer of [`TrainSample`]s. +/// +/// When the buffer is full, the oldest sample is evicted on push. +/// Samples are drawn without replacement using a Fisher-Yates partial shuffle. +pub struct ReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + /// Create a buffer with the given maximum capacity. + pub fn new(capacity: usize) -> Self { + Self { + data: VecDeque::with_capacity(capacity.min(1024)), + capacity, + } + } + + /// Add a sample; evicts the oldest if at capacity. 
+ pub fn push(&mut self, sample: TrainSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + /// Add all samples from an episode. + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { + self.push(s); + } + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Sample up to `n` distinct samples, without replacement. + /// + /// If the buffer has fewer than `n` samples, all are returned (shuffled). + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { + let len = self.data.len(); + let n = n.min(len); + // Partial Fisher-Yates using index shuffling. + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(value: f32) -> TrainSample { + TrainSample { obs: vec![value], policy: vec![1.0], value } + } + + #[test] + fn push_and_len() { + let mut buf = ReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = ReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); // evicts 1.0 + assert_eq!(buf.len(), 3); + // Oldest remaining should be 2.0 + assert_eq!(buf.data[0].value, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(5, &mut rng); + assert_eq!(batch.len(), 5); + } + + #[test] + fn sample_batch_capped_at_len() { + let mut 
buf = ReplayBuffer::new(20); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + let mut rng = SmallRng::seed_from_u64(0); + let batch = buf.sample_batch(100, &mut rng); + assert_eq!(batch.len(), 2); + } + + #[test] + fn sample_batch_no_duplicates() { + let mut buf = ReplayBuffer::new(20); + for i in 0..10 { + buf.push(dummy(i as f32)); + } + let mut rng = SmallRng::seed_from_u64(1); + let batch = buf.sample_batch(10, &mut rng); + let mut seen: Vec = batch.iter().map(|s| s.value).collect(); + seen.sort_by(f32::total_cmp); + seen.dedup(); + assert_eq!(seen.len(), 10, "sample contained duplicates"); + } +} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs new file mode 100644 index 0000000..6f10f8d --- /dev/null +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -0,0 +1,234 @@ +//! Self-play episode generation and Burn-backed evaluator. + +use std::marker::PhantomData; + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::GameEnv; +use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; +use crate::network::PolicyValueNet; + +use super::replay::TrainSample; + +// ── BurnEvaluator ────────────────────────────────────────────────────────── + +/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. +/// +/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so +/// that self-play generates no gradient tape overhead. +pub struct BurnEvaluator> { + model: N, + device: B::Device, + _b: PhantomData, +} + +impl> BurnEvaluator { + pub fn new(model: N, device: B::Device) -> Self { + Self { model, device, _b: PhantomData } + } + + pub fn into_model(self) -> N { + self.model + } +} + +// Safety: NdArray modules are Send; we never share across threads without +// external synchronisation. 
+unsafe impl> Send for BurnEvaluator {} +unsafe impl> Sync for BurnEvaluator {} + +impl> Evaluator for BurnEvaluator { + fn evaluate(&self, obs: &[f32]) -> (Vec, f32) { + let obs_size = obs.len(); + let data = TensorData::new(obs.to_vec(), [1, obs_size]); + let obs_tensor = Tensor::::from_data(data, &self.device); + + let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); + + let policy: Vec = policy_tensor.into_data().to_vec().unwrap(); + let value: Vec = value_tensor.into_data().to_vec().unwrap(); + + (policy, value[0]) + } +} + +// ── Episode generation ───────────────────────────────────────────────────── + +/// One pending observation waiting for its game-outcome value label. +struct PendingSample { + obs: Vec, + policy: Vec, + player: usize, +} + +/// Play one full game using MCTS guided by `evaluator`. +/// +/// Returns a [`TrainSample`] for every decision step in the game. +/// +/// `temperature_fn(step)` controls exploration: return `1.0` for early +/// moves and `0.0` after a fixed number of moves (e.g. move 30). +pub fn generate_episode( + env: &E, + evaluator: &dyn Evaluator, + mcts_config: &MctsConfig, + temperature_fn: &dyn Fn(usize) -> f32, + rng: &mut impl Rng, +) -> Vec { + let mut state = env.new_game(); + let mut pending: Vec = Vec::new(); + let mut step = 0usize; + + loop { + // Advance through chance nodes automatically. + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + if env.current_player(&state).is_terminal() { + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + + // Run MCTS to get a policy. + let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng); + let policy = mcts::mcts_policy(&root, env.action_space()); + + // Record the observation from the acting player's perspective. 
+ let obs = env.observation(&state, player_idx); + pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); + + // Select and apply the action. + let temperature = temperature_fn(step); + let action = mcts::select_action(&root, temperature, rng); + env.apply(&mut state, action); + step += 1; + } + + // Assign game outcomes. + let returns = env.returns(&state).unwrap_or([0.0; 2]); + pending + .into_iter() + .map(|s| TrainSample { + obs: s.obs, + policy: s.policy, + value: returns[s.player], + }) + .collect() +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::env::Player; + use crate::mcts::{Evaluator, MctsConfig}; + use crate::network::{MlpConfig, MlpNet}; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(7) + } + + // Countdown game (same as in mcts tests). 
+ #[derive(Clone, Debug)] + struct CState { remaining: u8, to_move: usize } + + #[derive(Clone)] + struct CountdownEnv; + + impl GameEnv for CountdownEnv { + type State = CState; + fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { Player::Terminal } + else if s.to_move == 0 { Player::P1 } else { Player::P2 } + } + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { s.remaining = 0; } + else { s.remaining -= sub; s.to_move = 1 - s.to_move; } + } + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 4.0, s.to_move as f32] + } + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + fn tiny_config() -> MctsConfig { + MctsConfig { n_simulations: 5, c_puct: 1.5, + dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } + } + + // ── BurnEvaluator tests ─────────────────────────────────────────────── + + #[test] + fn burn_evaluator_output_shapes() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); + assert_eq!(policy.len(), 2, "policy length should equal action_space"); + assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); + } + + // ── generate_episode tests ──────────────────────────────────────────── + + #[test] + fn episode_terminates_and_has_samples() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, 
hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn episode_sample_values_are_valid() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); + for s in &samples { + assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, + "unexpected value {}", s.value); + let sum: f32 = s.policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); + assert_eq!(s.obs.len(), 2); + } + } + + #[test] + fn episode_with_temperature_zero() { + let env = CountdownEnv; + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let eval = BurnEvaluator::new(model, device()); + // temperature=0 means greedy; episode must still terminate + let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); + assert!(!samples.is_empty()); + } +} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs new file mode 100644 index 0000000..d2482d1 --- /dev/null +++ b/spiel_bot/src/alphazero/trainer.rs @@ -0,0 +1,172 @@ +//! One gradient-descent training step for AlphaZero. +//! +//! The loss combines: +//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. +//! - **Value loss** — mean-squared error between the predicted value and the +//! actual game outcome. +//! +//! # Backend +//! +//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). +//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. +//! 
Weights are transferred between the two via [`burn::record`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + activation::log_softmax, + backend::AutodiffBackend, + Tensor, TensorData, + }, +}; + +use crate::network::PolicyValueNet; +use super::replay::TrainSample; + +/// Run one gradient step on `model` using `batch`. +/// +/// Returns the updated model and the scalar loss value for logging. +/// +/// # Parameters +/// +/// - `lr` — learning rate (e.g. `1e-3`). +/// - `batch` — slice of [`TrainSample`]s; must be non-empty. +pub fn train_step( + model: N, + optimizer: &mut O, + batch: &[TrainSample], + device: &B::Device, + lr: f64, +) -> (N, f32) +where + B: AutodiffBackend, + N: PolicyValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "train_step called with empty batch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + let action_size = batch[0].policy.len(); + + // ── Build input tensors ──────────────────────────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let policy_flat: Vec = batch.iter().flat_map(|s| s.policy.iter().copied()).collect(); + let value_flat: Vec = batch.iter().map(|s| s.value).collect(); + + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + let policy_target = Tensor::::from_data( + TensorData::new(policy_flat, [batch_size, action_size]), + device, + ); + let value_target = Tensor::::from_data( + TensorData::new(value_flat, [batch_size, 1]), + device, + ); + + // ── Forward pass ────────────────────────────────────────────────────── + let (policy_logits, value_pred) = model.forward(obs_tensor); + + // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ────────────────── + let log_probs = log_softmax(policy_logits, 1); + let policy_loss = (policy_target.clone().neg() * log_probs) + .sum_dim(1) + .mean(); 
+ + // ── Value loss: MSE(value_pred, z) ──────────────────────────────────── + let diff = value_pred - value_target; + let value_loss = (diff.clone() * diff).mean(); + + // ── Combined loss ───────────────────────────────────────────────────── + let loss = policy_loss + value_loss; + + // Extract scalar before backward (consumes the tensor). + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + let model = optimizer.step(lr, model, grads); + + (model, loss_scalar) +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + + use crate::network::{MlpConfig, MlpNet}; + use super::super::replay::TrainSample; + + type B = Autodiff>; + + fn device() -> ::Device { + Default::default() + } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5f32; obs_size], + policy: { + let mut p = vec![0.0f32; action_size]; + p[i % action_size] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + } + + #[test] + fn train_step_returns_finite_loss() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + assert!(loss > 0.0, "loss should be positive"); + } + + #[test] + fn loss_decreases_over_steps() { + let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + // 
Same batch every step — loss should decrease. + let batch = dummy_batch(16, 4, 4); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); + model = m; + assert!(loss.is_finite()); + prev_loss = loss; + } + // After 10 steps on fixed data, loss should be below a reasonable threshold. + assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn train_step_batch_size_one() { + let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; + let model = MlpNet::::new(&config, &device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(1, 2, 2); + let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); + assert!(loss.is_finite()); + } +} diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs new file mode 100644 index 0000000..42b4ae0 --- /dev/null +++ b/spiel_bot/src/env/mod.rs @@ -0,0 +1,121 @@ +//! Game environment abstraction — the minimal "Rust OpenSpiel". +//! +//! A `GameEnv` describes the rules of a two-player, zero-sum game that may +//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, +//! and PPO interact with a game exclusively through this trait. +//! +//! # Node taxonomy +//! +//! Every game position belongs to one of four categories, returned by +//! [`GameEnv::current_player`]: +//! +//! | [`Player`] | Meaning | +//! |-----------|---------| +//! | `P1` | Player 1 (index 0) must choose an action | +//! | `P2` | Player 2 (index 1) must choose an action | +//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | +//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | +//! +//! # Perspective convention +//! +//! [`GameEnv::observation`] always returns the board from *the requested +//! player's* point of view. Callers pass `pov = 0` for Player 1 and +//! `pov = 1` for Player 2. The implementation is responsible for any +//! 
mirroring required (e.g. Trictrac always reasons from White's side). + +pub mod trictrac; +pub use trictrac::TrictracEnv; + +/// Who controls the current game node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Player { + /// Player 1 (index 0) is to move. + P1, + /// Player 2 (index 1) is to move. + P2, + /// A stochastic event (dice roll, etc.) must be resolved. + Chance, + /// The game is over. + Terminal, +} + +impl Player { + /// Returns the player index (0 or 1) if this is a decision node, + /// or `None` for `Chance` / `Terminal`. + pub fn index(self) -> Option { + match self { + Player::P1 => Some(0), + Player::P2 => Some(1), + _ => None, + } + } + + pub fn is_decision(self) -> bool { + matches!(self, Player::P1 | Player::P2) + } + + pub fn is_chance(self) -> bool { + self == Player::Chance + } + + pub fn is_terminal(self) -> bool { + self == Player::Terminal + } +} + +/// Trait that completely describes a two-player zero-sum game. +/// +/// Implementors must be cheaply cloneable (the type is used as a stateless +/// factory; the mutable game state lives in `Self::State`). +pub trait GameEnv: Clone + Send + Sync + 'static { + /// The mutable game state. Must be `Clone` so MCTS can copy + /// game trees without touching the environment. + type State: Clone + Send + Sync; + + // ── State creation ──────────────────────────────────────────────────── + + /// Create a fresh game state at the initial position. + fn new_game(&self) -> Self::State; + + // ── Node queries ────────────────────────────────────────────────────── + + /// Classify the current node. + fn current_player(&self, s: &Self::State) -> Player; + + /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). + /// + /// The returned indices are in `[0, action_space())`. + /// The result is unspecified (may panic or return empty) when called at a + /// `Chance` or `Terminal` node. 
+ fn legal_actions(&self, s: &Self::State) -> Vec; + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action. `action` must be a value returned by + /// [`legal_actions`] for the current state. + fn apply(&self, s: &mut Self::State, action: usize); + + /// Sample and apply a stochastic outcome. Must only be called when + /// `current_player(s) == Player::Chance`. + fn apply_chance(&self, s: &mut Self::State, rng: &mut R); + + // ── Observation ─────────────────────────────────────────────────────── + + /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). + /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. + fn observation(&self, s: &Self::State, pov: usize) -> Vec; + + /// Number of floats returned by [`observation`]. + fn obs_size(&self) -> usize; + + /// Total number of distinct action indices (the policy head output size). + fn action_space(&self) -> usize; + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Game outcome for each player, or `None` if the game is not over. + /// + /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. + /// Index 0 = Player 1, index 1 = Player 2. + fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs new file mode 100644 index 0000000..99ba058 --- /dev/null +++ b/spiel_bot/src/env/trictrac.rs @@ -0,0 +1,535 @@ +//! [`GameEnv`] implementation for Trictrac. +//! +//! # Game flow (schools_enabled = false) +//! +//! With scoring schools disabled (the standard training configuration), +//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine +//! applies them automatically inside `RollResult` and `Move`. The only +//! four stages that actually occur are: +//! +//! | `TurnStage` | [`Player`] kind | Handled by | +//! |-------------|-----------------|------------| +//! 
| `RollDice` | `Chance` | [`apply_chance`] | +//! | `RollWaiting` | `Chance` | [`apply_chance`] | +//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | +//! | `Move` | `P1`/`P2` | [`apply`] | +//! +//! # Perspective +//! +//! The Trictrac engine always reasons from White's perspective. Player 1 is +//! White; Player 2 is Black. When Player 2 is active, the board is mirrored +//! before computing legal actions / the observation tensor, and the resulting +//! event is mirrored back before being applied to the real state. This +//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. + +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, + Dice, GameEvent, GameState, Stage, TurnStage, +}; + +use super::{GameEnv, Player}; + +/// Stateless factory that produces Trictrac [`GameState`] environments. +/// +/// Schools (`schools_enabled`) are always disabled — scoring is automatic. +#[derive(Clone, Debug, Default)] +pub struct TrictracEnv; + +impl GameEnv for TrictracEnv { + type State = GameState; + + // ── State creation ──────────────────────────────────────────────────── + + fn new_game(&self) -> GameState { + GameState::new_with_players("P1", "P2") + } + + // ── Node queries ────────────────────────────────────────────────────── + + fn current_player(&self, s: &GameState) -> Player { + if s.stage == Stage::Ended { + return Player::Terminal; + } + match s.turn_stage { + TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, + _ => { + if s.active_player_id == 1 { + Player::P1 + } else { + Player::P2 + } + } + } + } + + /// Returns the legal action indices for the active player. + /// + /// The board is automatically mirrored for Player 2 so that the engine + /// always reasons from White's perspective. The returned indices are + /// identical in meaning for both players (checker ordinals are + /// perspective-relative). 
+ /// + /// # Panics + /// + /// Panics in debug builds if called at a `Chance` or `Terminal` node. + fn legal_actions(&self, s: &GameState) -> Vec { + debug_assert!( + self.current_player(s).is_decision(), + "legal_actions called at a non-decision node (turn_stage={:?})", + s.turn_stage + ); + let indices = if s.active_player_id == 2 { + get_valid_action_indices(&s.mirror()) + } else { + get_valid_action_indices(s) + }; + indices.unwrap_or_default() + } + + // ── State mutation ──────────────────────────────────────────────────── + + /// Apply a player action index to the game state. + /// + /// For Player 2, the action is decoded against the mirrored board and + /// the resulting event is un-mirrored before being applied. + /// + /// # Panics + /// + /// Panics in debug builds if `action` cannot be decoded or does not + /// produce a valid event for the current state. + fn apply(&self, s: &mut GameState, action: usize) { + let needs_mirror = s.active_player_id == 2; + + let event = if needs_mirror { + let view = s.mirror(); + TrictracAction::from_action_index(action) + .and_then(|a| a.to_event(&view)) + .map(|e| e.get_mirror(false)) + } else { + TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) + }; + + match event { + Some(e) => { + s.consume(&e).expect("apply: consume failed for valid action"); + } + None => { + panic!("apply: action index {action} produced no event in state {s}"); + } + } + } + + /// Sample dice and advance through a chance node. + /// + /// Handles both `RollDice` (triggers the roll mechanism, then samples + /// dice) and `RollWaiting` (only samples dice) in a single call so that + /// callers never need to distinguish the two. + /// + /// # Panics + /// + /// Panics in debug builds if called at a non-Chance node. 
+ fn apply_chance(&self, s: &mut GameState, rng: &mut R) { + debug_assert!( + self.current_player(s).is_chance(), + "apply_chance called at a non-Chance node (turn_stage={:?})", + s.turn_stage + ); + + // Step 1: RollDice → RollWaiting (player initiates the roll). + if s.turn_stage == TurnStage::RollDice { + s.consume(&GameEvent::Roll { + player_id: s.active_player_id, + }) + .expect("apply_chance: Roll event failed"); + } + + // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. + // With schools_enabled=false, point marking is automatic inside consume(). + let dice = Dice { + values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), + }; + s.consume(&GameEvent::RollResult { + player_id: s.active_player_id, + dice, + }) + .expect("apply_chance: RollResult event failed"); + } + + // ── Observation ─────────────────────────────────────────────────────── + + fn observation(&self, s: &GameState, pov: usize) -> Vec { + if pov == 0 { + s.to_tensor() + } else { + s.mirror().to_tensor() + } + } + + fn obs_size(&self) -> usize { + 217 + } + + fn action_space(&self) -> usize { + ACTION_SPACE_SIZE + } + + // ── Terminal values ─────────────────────────────────────────────────── + + /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. + /// + /// The winner (higher cumulative score) receives `+1.0`; the loser + /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score + /// is `holes × 12 + points`. 
+ fn returns(&self, s: &GameState) -> Option<[f32; 2]> { + if s.stage != Stage::Ended { + return None; + } + let score = |id: u64| -> i32 { + s.players + .get(&id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + }; + let s1 = score(1); + let s2 = score(2); + Some(match s1.cmp(&s2) { + std::cmp::Ordering::Greater => [1.0, -1.0], + std::cmp::Ordering::Less => [-1.0, 1.0], + std::cmp::Ordering::Equal => [0.0, 0.0], + }) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{rngs::SmallRng, Rng, SeedableRng}; + + fn env() -> TrictracEnv { + TrictracEnv + } + + fn seeded_rng(seed: u64) -> SmallRng { + SmallRng::seed_from_u64(seed) + } + + // ── Initial state ───────────────────────────────────────────────────── + + #[test] + fn new_game_is_chance_node() { + let e = env(); + let s = e.new_game(); + // A fresh game starts at RollDice — a Chance node. + assert_eq!(e.current_player(&s), Player::Chance); + assert!(e.returns(&s).is_none()); + } + + #[test] + fn new_game_is_not_terminal() { + let e = env(); + let s = e.new_game(); + assert_ne!(e.current_player(&s), Player::Terminal); + assert!(e.returns(&s).is_none()); + } + + // ── Chance nodes ────────────────────────────────────────────────────── + + #[test] + fn apply_chance_reaches_decision_node() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(1); + + // A single chance step must yield a decision node (or end the game, + // which only happens after 12 holes — impossible on the first roll). + e.apply_chance(&mut s, &mut rng); + let p = e.current_player(&s); + assert!( + p.is_decision(), + "expected decision node after first roll, got {p:?}" + ); + } + + #[test] + fn apply_chance_from_rollwaiting() { + // Check that apply_chance works when called mid-way (at RollWaiting). 
+ let e = env(); + let mut s = e.new_game(); + assert_eq!(s.turn_stage, TurnStage::RollDice); + + // Manually advance to RollWaiting. + s.consume(&GameEvent::Roll { player_id: s.active_player_id }) + .unwrap(); + assert_eq!(s.turn_stage, TurnStage::RollWaiting); + + let mut rng = seeded_rng(2); + e.apply_chance(&mut s, &mut rng); + + let p = e.current_player(&s); + assert!(p.is_decision() || p.is_terminal()); + } + + // ── Legal actions ───────────────────────────────────────────────────── + + #[test] + fn legal_actions_nonempty_after_roll() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(3); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let actions = e.legal_actions(&s); + assert!( + !actions.is_empty(), + "legal_actions must be non-empty at a decision node" + ); + } + + #[test] + fn legal_actions_within_action_space() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(4); + + e.apply_chance(&mut s, &mut rng); + for &a in e.legal_actions(&s).iter() { + assert!( + a < e.action_space(), + "action {a} out of bounds (action_space={})", + e.action_space() + ); + } + } + + // ── Observations ────────────────────────────────────────────────────── + + #[test] + fn observation_has_correct_size() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(5); + e.apply_chance(&mut s, &mut rng); + + assert_eq!(e.observation(&s, 0).len(), e.obs_size()); + assert_eq!(e.observation(&s, 1).len(), e.obs_size()); + } + + #[test] + fn observation_values_in_unit_interval() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(6); + e.apply_chance(&mut s, &mut rng); + + for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { + for (i, &v) in obs.iter().enumerate() { + assert!( + v >= 0.0 && v <= 1.0, + "pov={pov}: obs[{i}] = {v} is outside [0,1]" + ); + } + } + } + + #[test] + fn p1_and_p2_observations_differ() { + // The board is mirrored 
for P2, so the two observations should differ + // whenever there are checkers in non-symmetric positions (always true + // in a real game after a few moves). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(7); + + // Advance far enough that the board is non-trivial. + for _ in 0..6 { + while e.current_player(&s).is_chance() { + e.apply_chance(&mut s, &mut rng); + } + if e.current_player(&s).is_terminal() { + break; + } + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + + if !e.current_player(&s).is_terminal() { + let obs0 = e.observation(&s, 0); + let obs1 = e.observation(&s, 1); + assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); + } + } + + // ── Applying actions ────────────────────────────────────────────────── + + #[test] + fn apply_changes_state() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(8); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + let before = s.clone(); + let action = e.legal_actions(&s)[0]; + e.apply(&mut s, action); + + assert_ne!( + before.turn_stage, s.turn_stage, + "state must change after apply" + ); + } + + #[test] + fn apply_all_legal_actions_do_not_panic() { + // Verify that every action returned by legal_actions can be applied + // without panicking (on several independent copies of the same state). + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(9); + + e.apply_chance(&mut s, &mut rng); + assert!(e.current_player(&s).is_decision()); + + for action in e.legal_actions(&s) { + let mut copy = s.clone(); + e.apply(&mut copy, action); // must not panic + } + } + + // ── Full game ───────────────────────────────────────────────────────── + + /// Run a complete game with random actions through the `GameEnv` trait + /// and verify that: + /// - The game terminates. + /// - `returns()` is `Some` at the end. 
+    /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's
+    ///   score is ±1 / 0.
+    /// - No step panics.
+    #[test]
+    fn full_random_game_terminates() {
+        let e = env();
+        let mut s = e.new_game();
+        let mut rng = seeded_rng(42);
+        let max_steps = 50_000;
+
+        for step in 0..max_steps {
+            match e.current_player(&s) {
+                Player::Terminal => break,
+                Player::Chance => e.apply_chance(&mut s, &mut rng),
+                Player::P1 | Player::P2 => {
+                    let actions = e.legal_actions(&s);
+                    assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node");
+                    let idx = rng.random_range(0..actions.len());
+                    e.apply(&mut s, actions[idx]);
+                }
+            }
+            assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps");
+        }
+
+        let result = e.returns(&s);
+        assert!(result.is_some(), "returns() must be Some at Terminal");
+
+        let [r1, r2] = result.unwrap();
+        let sum = r1 + r2;
+        assert!(
+            sum.abs() < 1e-5,
+            "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}"
+        );
+        assert!(
+            r1.abs() <= 1.0 && r2.abs() <= 1.0,
+            "returns must be in [-1,1]: r1={r1}, r2={r2}"
+        );
+    }
+
+    /// Run multiple games with different seeds to stress-test for panics.
+    #[test]
+    fn multiple_games_no_panic() {
+        let e = env();
+        let max_steps = 20_000;
+
+        for seed in 0..10u64 {
+            let mut s = e.new_game();
+            let mut rng = seeded_rng(seed);
+
+            for _ in 0..max_steps {
+                match e.current_player(&s) {
+                    Player::Terminal => break,
+                    Player::Chance => e.apply_chance(&mut s, &mut rng),
+                    Player::P1 | Player::P2 => {
+                        let actions = e.legal_actions(&s);
+                        let idx = rng.random_range(0..actions.len());
+                        e.apply(&mut s, actions[idx]);
+                    }
+                }
+            }
+        }
+    }
+
+    // ── Returns ───────────────────────────────────────────────────────────
+
+    #[test]
+    fn returns_none_mid_game() {
+        let e = env();
+        let mut s = e.new_game();
+        let mut rng = seeded_rng(11);
+
+        // Advance a few steps but do not finish the game.
+ for _ in 0..4 { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P1 | Player::P2 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + if !e.current_player(&s).is_terminal() { + assert!( + e.returns(&s).is_none(), + "returns() must be None before the game ends" + ); + } + } + + // ── Player 2 actions ────────────────────────────────────────────────── + + /// Verify that Player 2 (Black) can take actions without panicking, + /// and that the state advances correctly. + #[test] + fn player2_can_act() { + let e = env(); + let mut s = e.new_game(); + let mut rng = seeded_rng(12); + + // Keep stepping until Player 2 gets a turn. + let max_steps = 5_000; + let mut p2_acted = false; + + for _ in 0..max_steps { + match e.current_player(&s) { + Player::Terminal => break, + Player::Chance => e.apply_chance(&mut s, &mut rng), + Player::P2 => { + let actions = e.legal_actions(&s); + assert!(!actions.is_empty()); + e.apply(&mut s, actions[0]); + p2_acted = true; + break; + } + Player::P1 => { + let actions = e.legal_actions(&s); + e.apply(&mut s, actions[0]); + } + } + } + + assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); + } +} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs new file mode 100644 index 0000000..23895b9 --- /dev/null +++ b/spiel_bot/src/lib.rs @@ -0,0 +1,4 @@ +pub mod alphazero; +pub mod env; +pub mod mcts; +pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs new file mode 100644 index 0000000..e92bd09 --- /dev/null +++ b/spiel_bot/src/mcts/mod.rs @@ -0,0 +1,408 @@ +//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. +//! +//! # Algorithm +//! +//! The implementation follows AlphaZero's MCTS: +//! +//! 1. **Expand root** — run the network once to get priors and a value +//! estimate; optionally add Dirichlet noise for training-time exploration. +//! 2. 
**Simulate** `n_simulations` times: +//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. +//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; +//! chance nodes are **not** stored in the tree (outcome sampling). +//! - *Expansion* — evaluate the network at the leaf; populate children. +//! - *Backup* — propagate the value upward; negate at each player boundary. +//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). +//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). +//! +//! # Perspective convention +//! +//! Every [`MctsNode::w`] is stored **from the perspective of the player who +//! acts at that node**. The backup negates the child value whenever the +//! acting player differs between parent and child. +//! +//! # Stochastic games +//! +//! When [`GameEnv::current_player`] returns [`Player::Chance`], the +//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and +//! continues. Chance nodes are skipped transparently; Q-values converge to +//! their expectation over many simulations (outcome sampling). + +pub mod node; +mod search; + +pub use node::MctsNode; + +use rand::Rng; + +use crate::env::GameEnv; + +// ── Evaluator trait ──────────────────────────────────────────────────────── + +/// Evaluates a game position for use in MCTS. +/// +/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) +/// but the `mcts` module itself does **not** depend on Burn. +pub trait Evaluator: Send + Sync { + /// Evaluate `obs` (flat observation vector of length `obs_size`). + /// + /// Returns: + /// - `policy_logits`: one raw logit per action (`action_space` entries). + /// Illegal action entries are masked inside the search — no need to + /// zero them here. + /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. 
+    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
+}
+
+// ── Configuration ─────────────────────────────────────────────────────────
+
+/// Hyperparameters for [`run_mcts`].
+#[derive(Debug, Clone)]
+pub struct MctsConfig {
+    /// Number of MCTS simulations per move. Typical: 50–800.
+    pub n_simulations: usize,
+    /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0.
+    pub c_puct: f32,
+    /// Dirichlet noise concentration α. Set to `0.0` to disable.
+    /// Typical: `0.3` for Chess, `0.1` for large action spaces.
+    pub dirichlet_alpha: f32,
+    /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`.
+    pub dirichlet_eps: f32,
+    /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax.
+    pub temperature: f32,
+}
+
+impl Default for MctsConfig {
+    fn default() -> Self {
+        Self {
+            n_simulations: 200,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.3,
+            dirichlet_eps: 0.25,
+            temperature: 1.0,
+        }
+    }
+}
+
+// ── Public interface ───────────────────────────────────────────────────────
+
+/// Run MCTS from `state` and return the populated root [`MctsNode`].
+///
+/// `state` must be a player-decision node (`P1` or `P2`).
+/// Use [`mcts_policy`] and [`select_action`] on the returned root.
+///
+/// # Panics
+///
+/// Panics if `env.current_player(state)` is not `P1` or `P2`.
+pub fn run_mcts<E: GameEnv>(
+    env: &E,
+    state: &E::State,
+    evaluator: &dyn Evaluator,
+    config: &MctsConfig,
+    rng: &mut impl Rng,
+) -> MctsNode {
+    let player_idx = env
+        .current_player(state)
+        .index()
+        .expect("run_mcts called at a non-decision node");
+
+    // ── Expand root (network called once here, not inside the loop) ────────
+    let mut root = MctsNode::new(1.0);
+    search::expand::<E>(&mut root, state, env, evaluator, player_idx);
+
+    // ── Optional Dirichlet noise for training exploration ──────────────────
+    if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
+        search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
+    }
+
+    // ── Simulations ────────────────────────────────────────────────────────
+    for _ in 0..config.n_simulations {
+        search::simulate::<E>(
+            &mut root,
+            state.clone(),
+            env,
+            evaluator,
+            config,
+            rng,
+            player_idx,
+        );
+    }
+
+    root
+}
+
+/// Compute the MCTS policy: normalized visit counts at the root.
+///
+/// Returns a vector of length `action_space` where `policy[a]` is the
+/// fraction of simulations that visited action `a`.
+pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
+    let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum();
+    let mut policy = vec![0.0f32; action_space];
+    if total > 0.0 {
+        for (a, child) in &root.children {
+            policy[*a] = child.n as f32 / total;
+        }
+    } else if !root.children.is_empty() {
+        // n_simulations = 0: uniform over legal actions.
+        let uniform = 1.0 / root.children.len() as f32;
+        for (a, _) in &root.children {
+            policy[*a] = uniform;
+        }
+    }
+    policy
+}
+
+/// Select an action index from the root after MCTS.
+///
+/// * `temperature = 0` — greedy argmax of visit counts.
+/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
+///
+/// # Panics
+///
+/// Panics if the root has no children.
+pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { + assert!(!root.children.is_empty(), "select_action called on a root with no children"); + if temperature <= 0.0 { + root.children + .iter() + .max_by_key(|(_, c)| c.n) + .map(|(a, _)| *a) + .unwrap() + } else { + let weights: Vec = root + .children + .iter() + .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) + .collect(); + let total: f32 = weights.iter().sum(); + let mut r: f32 = rng.random::() * total; + for (i, (a, _)) in root.children.iter().enumerate() { + r -= weights[i]; + if r <= 0.0 { + return *a; + } + } + root.children.last().map(|(a, _)| *a).unwrap() + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + use crate::env::Player; + + // ── Minimal deterministic test game ─────────────────────────────────── + // + // "Countdown" — two players alternate subtracting 1 or 2 from a counter. + // The player who brings the counter to 0 wins. + // No chance nodes, two legal actions (0 = -1, 1 = -2). 
+ + #[derive(Clone, Debug)] + struct CState { + remaining: u8, + to_move: usize, // at terminal: last mover (winner) + } + + #[derive(Clone)] + struct CountdownEnv; + + impl crate::env::GameEnv for CountdownEnv { + type State = CState; + + fn new_game(&self) -> CState { + CState { remaining: 6, to_move: 0 } + } + + fn current_player(&self, s: &CState) -> Player { + if s.remaining == 0 { + Player::Terminal + } else if s.to_move == 0 { + Player::P1 + } else { + Player::P2 + } + } + + fn legal_actions(&self, s: &CState) -> Vec { + if s.remaining >= 2 { vec![0, 1] } else { vec![0] } + } + + fn apply(&self, s: &mut CState, action: usize) { + let sub = (action as u8) + 1; + if s.remaining <= sub { + s.remaining = 0; + // to_move stays as winner + } else { + s.remaining -= sub; + s.to_move = 1 - s.to_move; + } + } + + fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} + + fn observation(&self, s: &CState, _pov: usize) -> Vec { + vec![s.remaining as f32 / 6.0, s.to_move as f32] + } + + fn obs_size(&self) -> usize { 2 } + fn action_space(&self) -> usize { 2 } + + fn returns(&self, s: &CState) -> Option<[f32; 2]> { + if s.remaining != 0 { return None; } + let mut r = [-1.0f32; 2]; + r[s.to_move] = 1.0; + Some(r) + } + } + + // Uniform evaluator: all logits = 0, value = 0. + // `action_space` must match the environment's `action_space()`. 
+ struct ZeroEval(usize); + impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } + } + + fn rng() -> SmallRng { + SmallRng::seed_from_u64(42) + } + + fn config_n(n: usize) -> MctsConfig { + MctsConfig { + n_simulations: n, + c_puct: 1.5, + dirichlet_alpha: 0.0, // off for reproducibility + dirichlet_eps: 0.0, + temperature: 1.0, + } + } + + // ── Visit count tests ───────────────────────────────────────────────── + + #[test] + fn visit_counts_sum_to_n_simulations() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 50, "visit counts must sum to n_simulations"); + } + + #[test] + fn all_root_children_are_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); + for (a, _) in &root.children { + assert!(legal.contains(a), "child action {a} is not legal"); + } + } + + // ── Policy tests ───────────────────────────────────────────────────── + + #[test] + fn policy_sums_to_one() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); + } + + #[test] + fn policy_zero_for_illegal_actions() { + let env = CountdownEnv; + // remaining = 1 → only action 0 is legal + let state = CState { remaining: 1, to_move: 0 }; + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); + } + + // ── Action selection tests 
──────────────────────────────────────────── + + #[test] + fn greedy_selects_most_visited() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); + let greedy = select_action(&root, 0.0, &mut rng()); + let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); + assert_eq!(greedy, most_visited); + } + + #[test] + fn temperature_sampling_stays_legal() { + let env = CountdownEnv; + let state = env.new_game(); + let legal = env.legal_actions(&state); + let mut r = rng(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); + for _ in 0..20 { + let a = select_action(&root, 1.0, &mut r); + assert!(legal.contains(&a), "sampled action {a} is not legal"); + } + } + + // ── Zero-simulation edge case ───────────────────────────────────────── + + #[test] + fn zero_simulations_uniform_policy() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); + let policy = mcts_policy(&root, env.action_space()); + // With 0 simulations, fallback is uniform over the 2 legal actions. + let sum: f32 = policy.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5); + } + + // ── Root value ──────────────────────────────────────────────────────── + + #[test] + fn root_q_in_valid_range() { + let env = CountdownEnv; + let state = env.new_game(); + let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); + let q = root.q(); + assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); + } + + // ── Integration: run on a real Trictrac game ────────────────────────── + + #[test] + fn no_panic_on_trictrac_state() { + use crate::env::TrictracEnv; + + let env = TrictracEnv; + let mut state = env.new_game(); + let mut r = rng(); + + // Advance past the initial chance node to reach a decision node. 
+ while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, &mut r); + } + + if env.current_player(&state).is_terminal() { + return; // unlikely but safe + } + + let config = MctsConfig { + n_simulations: 5, // tiny for speed + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + ..MctsConfig::default() + }; + + let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); + assert!(root.n > 0); + let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); + assert_eq!(total, 5); + } +} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs new file mode 100644 index 0000000..aff7735 --- /dev/null +++ b/spiel_bot/src/mcts/node.rs @@ -0,0 +1,91 @@ +//! MCTS tree node. +//! +//! [`MctsNode`] holds the visit statistics for one player-decision position in +//! the search tree. A node is *expanded* the first time the policy-value +//! network is evaluated there; before that it is a leaf. + +/// One node in the MCTS tree, representing a player-decision position. +/// +/// `w` stores the sum of values backed up into this node, always from the +/// perspective of **the player who acts here**. `q()` therefore also returns +/// a value in `(-1, 1)` from that same perspective. +#[derive(Debug)] +pub struct MctsNode { + /// Visit count `N(s, a)`. + pub n: u32, + /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. + pub w: f32, + /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). + pub p: f32, + /// Children: `(action_index, child_node)`, populated on first expansion. + pub children: Vec<(usize, MctsNode)>, + /// `true` after the network has been evaluated and children have been set up. + pub expanded: bool, +} + +impl MctsNode { + /// Create a fresh, unexpanded leaf with the given prior probability. 
+ pub fn new(prior: f32) -> Self { + Self { + n: 0, + w: 0.0, + p: prior, + children: Vec::new(), + expanded: false, + } + } + + /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. + #[inline] + pub fn q(&self) -> f32 { + if self.n == 0 { 0.0 } else { self.w / self.n as f32 } + } + + /// PUCT selection score: + /// + /// ```text + /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) + /// ``` + #[inline] + pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { + self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn q_zero_when_unvisited() { + let node = MctsNode::new(0.5); + assert_eq!(node.q(), 0.0); + } + + #[test] + fn q_reflects_w_over_n() { + let mut node = MctsNode::new(0.5); + node.n = 4; + node.w = 2.0; + assert!((node.q() - 0.5).abs() < 1e-6); + } + + #[test] + fn puct_exploration_dominates_unvisited() { + // Unvisited child should outscore a visited child with negative Q. + let mut visited = MctsNode::new(0.5); + visited.n = 10; + visited.w = -5.0; // Q = -0.5 + + let unvisited = MctsNode::new(0.5); + + let parent_n = 10; + let c = 1.5; + assert!( + unvisited.puct(parent_n, c) > visited.puct(parent_n, c), + "unvisited child should have higher PUCT than a negatively-valued visited child" + ); + } +} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs new file mode 100644 index 0000000..c4960c7 --- /dev/null +++ b/spiel_bot/src/mcts/search.rs @@ -0,0 +1,170 @@ +//! Simulation, expansion, backup, and noise helpers. +//! +//! These are internal to the `mcts` module; the public entry points are +//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. 
+
+use rand::Rng;
+use rand_distr::{Gamma, Distribution};
+
+use crate::env::GameEnv;
+use super::{Evaluator, MctsConfig};
+use super::node::MctsNode;
+
+// ── Masked softmax ─────────────────────────────────────────────────────────
+
+/// Numerically stable softmax over `legal` actions only.
+///
+/// Illegal logits are treated as `-∞` and receive probability `0.0`.
+/// Returns a probability vector of length `action_space`.
+pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
+    let mut probs = vec![0.0f32; action_space];
+    if legal.is_empty() {
+        return probs;
+    }
+    let max_logit = legal
+        .iter()
+        .map(|&a| logits[a])
+        .fold(f32::NEG_INFINITY, f32::max);
+    let mut sum = 0.0f32;
+    for &a in legal {
+        let e = (logits[a] - max_logit).exp();
+        probs[a] = e;
+        sum += e;
+    }
+    if sum > 0.0 {
+        for &a in legal {
+            probs[a] /= sum;
+        }
+    } else {
+        let uniform = 1.0 / legal.len() as f32;
+        for &a in legal {
+            probs[a] = uniform;
+        }
+    }
+    probs
+}
+
+// ── Dirichlet noise ────────────────────────────────────────────────────────
+
+/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
+///
+/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
+/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
+pub(super) fn add_dirichlet_noise(
+    node: &mut MctsNode,
+    alpha: f32,
+    eps: f32,
+    rng: &mut impl Rng,
+) {
+    let n = node.children.len();
+    if n == 0 {
+        return;
+    }
+    let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
+        return;
+    };
+    let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
+    let sum: f32 = samples.iter().sum();
+    if sum <= 0.0 {
+        return;
+    }
+    for (i, (_, child)) in node.children.iter_mut().enumerate() {
+        let noise = samples[i] / sum;
+        child.p = (1.0 - eps) * child.p + eps * noise;
+    }
+}
+
+// ── Expansion ──────────────────────────────────────────────────────────────
+
+/// Evaluate the network at `state` and populate `node` with children.
+///
+/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
+/// Returns the network value estimate from `player_idx`'s perspective.
+pub(super) fn expand<E: GameEnv>(
+    node: &mut MctsNode,
+    state: &E::State,
+    env: &E,
+    evaluator: &dyn Evaluator,
+    player_idx: usize,
+) -> f32 {
+    let obs = env.observation(state, player_idx);
+    let legal = env.legal_actions(state);
+    let (logits, value) = evaluator.evaluate(&obs);
+    let priors = masked_softmax(&logits, &legal, env.action_space());
+    node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
+    node.expanded = true;
+    node.n = 1;
+    node.w = value;
+    value
+}
+
+// ── Simulation ─────────────────────────────────────────────────────────────
+
+/// One MCTS simulation from an **already-expanded** decision node.
+///
+/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
+/// and backs up the result.
+///
+/// * `player_idx` — the player (0 or 1) who acts at `state`.
+/// * Returns the backed-up value **from `player_idx`'s perspective**.
+pub(super) fn simulate<E: GameEnv>(
+    node: &mut MctsNode,
+    state: E::State,
+    env: &E,
+    evaluator: &dyn Evaluator,
+    config: &MctsConfig,
+    rng: &mut impl Rng,
+    player_idx: usize,
+) -> f32 {
+    debug_assert!(node.expanded, "simulate called on unexpanded node");
+
+    // ── Selection: child with highest PUCT ─────────────────────────────────
+    let parent_n = node.n;
+    let best = node
+        .children
+        .iter()
+        .enumerate()
+        .max_by(|(_, (_, a)), (_, (_, b))| {
+            a.puct(parent_n, config.c_puct)
+                .partial_cmp(&b.puct(parent_n, config.c_puct))
+                .unwrap_or(std::cmp::Ordering::Equal)
+        })
+        .map(|(i, _)| i)
+        .expect("expanded node must have at least one child");
+
+    let (action, child) = &mut node.children[best];
+    let action = *action;
+
+    // ── Apply action + advance through any chance nodes ────────────────────
+    let mut next_state = state;
+    env.apply(&mut next_state, action);
+    while env.current_player(&next_state).is_chance() {
+        env.apply_chance(&mut next_state, rng);
+    }
+
+    let next_cp = env.current_player(&next_state);
+
+    // ── Evaluate leaf or terminal ──────────────────────────────────────────
+    // All values are converted to `player_idx`'s perspective before backup.
+    let child_value = if next_cp.is_terminal() {
+        let returns = env
+            .returns(&next_state)
+            .expect("terminal node must have returns");
+        returns[player_idx]
+    } else {
+        let child_player = next_cp.index().unwrap();
+        let v = if child.expanded {
+            simulate(child, next_state, env, evaluator, config, rng, child_player)
+        } else {
+            expand::<E>(child, &next_state, env, evaluator, child_player)
+        };
+        // Negate when the child belongs to the opponent.
+ if child_player == player_idx { v } else { -v } + }; + + // ── Backup ──────────────────────────────────────────────────────────── + node.n += 1; + node.w += child_value; + + child_value +} diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs new file mode 100644 index 0000000..eb6184e --- /dev/null +++ b/spiel_bot/src/network/mlp.rs @@ -0,0 +1,223 @@ +//! Two-hidden-layer MLP policy-value network. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! ├─ policy_head: Linear(hidden → action_size) [raw logits] +//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{ + activation::{relu, tanh}, + backend::Backend, + Tensor, + }, +}; +use std::path::Path; + +use super::PolicyValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`MlpNet`]. +#[derive(Debug, Clone)] +pub struct MlpConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for MlpConfig { + fn default() -> Self { + Self { + obs_size: 217, + action_size: 514, + hidden_size: 256, + } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Simple two-hidden-layer MLP with shared trunk and two heads. +/// +/// Prefer this over [`ResNet`](super::ResNet) when training time is a +/// priority, or as a fast baseline. +#[derive(Module, Debug)] +pub struct MlpNet { + fc1: Linear, + fc2: Linear, + policy_head: Linear, + value_head: Linear, +} + +impl MlpNet { + /// Construct a fresh network with random weights. 
+ pub fn new(config: &MlpConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + value_head: LinearConfig::new(config.hidden_size, 1).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + /// + /// The file is written exactly at `path`; callers should append `.mpk` if + /// they want the conventional extension. + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl PolicyValueNet for MlpNet { + fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + let policy = self.policy_head.forward(x.clone()); + let value = tanh(self.value_head.forward(x)); + (policy, value) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { + Default::default() + } + + fn default_net() -> MlpNet { + MlpNet::new(&MlpConfig::default(), &device()) + } + + fn zeros_obs(batch: usize) -> Tensor { + Tensor::zeros([batch, 217], &device()) + } + + // ── Shape tests ─────────────────────────────────────────────────────── + + #[test] + fn forward_output_shapes() { + let net = 
default_net(); + let obs = zeros_obs(4); + let (policy, value) = net.forward(obs); + + assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); + assert_eq!(value.dims(), [4, 1], "value shape mismatch"); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let (policy, value) = net.forward(zeros_obs(1)); + assert_eq!(policy.dims(), [1, 514]); + assert_eq!(value.dims(), [1, 1]); + } + + // ── Value bounds ────────────────────────────────────────────────────── + + #[test] + fn value_in_tanh_range() { + let net = default_net(); + // Use a non-zero input so the output is not trivially at 0. + let obs = Tensor::::ones([8, 217], &device()); + let (_, value) = net.forward(obs); + let data: Vec = value.into_data().to_vec().unwrap(); + for v in &data { + assert!( + *v > -1.0 && *v < 1.0, + "value {v} is outside open interval (-1, 1)" + ); + } + } + + // ── Policy logits ───────────────────────────────────────────────────── + + #[test] + fn policy_logits_not_all_equal() { + // With random weights the 514 logits should not all be identical. 
+ let net = default_net(); + let (policy, _) = net.forward(zeros_obs(1)); + let data: Vec = policy.into_data().to_vec().unwrap(); + let first = data[0]; + let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); + assert!(!all_same, "all policy logits are identical — network may be degenerate"); + } + + // ── Config propagation ──────────────────────────────────────────────── + + #[test] + fn custom_config_shapes() { + let config = MlpConfig { + obs_size: 10, + action_size: 20, + hidden_size: 32, + }; + let net = MlpNet::::new(&config, &device()); + let obs = Tensor::zeros([3, 10], &device()); + let (policy, value) = net.forward(obs); + assert_eq!(policy.dims(), [3, 20]); + assert_eq!(value.dims(), [3, 1]); + } + + // ── Save / Load ─────────────────────────────────────────────────────── + + #[test] + fn save_load_preserves_weights() { + let config = MlpConfig::default(); + let net = default_net(); + + // Forward pass before saving. + let obs = Tensor::::ones([2, 217], &device()); + let (policy_before, value_before) = net.forward(obs.clone()); + + // Save to a temp file. + let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); + net.save(&path).expect("save failed"); + + // Load into a fresh model. + let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); + let (policy_after, value_after) = loaded.forward(obs); + + // Outputs must be bitwise identical. 
+        let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
+        let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
+        let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let _ = std::fs::remove_file(path);
+    }
+}
diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs
new file mode 100644
index 0000000..df710e9
--- /dev/null
+++ b/spiel_bot/src/network/mod.rs
@@ -0,0 +1,64 @@
+//! Neural network abstractions for policy-value learning.
+//!
+//! # Trait
+//!
+//! [`PolicyValueNet`] is the single trait that all network architectures
+//! implement. It takes an observation tensor and returns raw policy logits
+//! plus a tanh-squashed scalar value estimate.
+//!
+//! # Architectures
+//!
+//! | Module | Description | Default hidden |
+//! |--------|-------------|----------------|
+//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
+//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
+//!
+//! # Backend convention
+//!
+//! * **Inference / self-play** — use `NdArray` (no autodiff overhead).
+//! * **Training** — use `Autodiff<NdArray>` so Burn can differentiate
+//!   through the forward pass.
+//!
+//! Both modes use the exact same struct; only the type-level backend changes:
+//!
+//! ```rust,ignore
+//! use burn::backend::{Autodiff, NdArray};
+//! type InferBackend = NdArray;
+//! type TrainBackend = Autodiff<NdArray>;
+//!
+//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
+//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
+//! ```
+//!
+//! # Output shapes
+//!
+//! Given a batch of `B` observations of size `obs_size`:
+//!
+//! | Output | Shape | Range |
+//! |--------|-------|-------|
+//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) |
+//! | `value` | `[B, 1]` | (-1, 1) via tanh |
+//!
+//! Callers are responsible for masking illegal actions in `policy_logits`
+//! before passing to softmax.
+
+pub mod mlp;
+pub mod resnet;
+
+pub use mlp::{MlpConfig, MlpNet};
+pub use resnet::{ResNet, ResNetConfig};
+
+use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
+
+/// A neural network that produces a policy and a value from an observation.
+///
+/// # Shapes
+/// - `obs`: `[batch, obs_size]`
+/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
+/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
+///
+/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
+/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
+/// Use an `Arc<Mutex<…>>` wrapper if cross-thread sharing is needed.
+pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
+    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
+}
diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs
new file mode 100644
index 0000000..d20d5ad
--- /dev/null
+++ b/spiel_bot/src/network/resnet.rs
@@ -0,0 +1,253 @@
+//! Residual-block policy-value network.
+//!
+//! ```text
+//! Input [B, obs_size]
+//!   → Linear(obs → hidden) → ReLU                 (input projection)
+//!   → ResBlock × 4                                (residual trunk)
+//!   ├─ policy_head: Linear(hidden → action_size)  [raw logits]
+//!   └─ value_head:  Linear(hidden → 1) → tanh     [∈ (-1, 1)]
+//!
+//! ResBlock:
+//!   x → Linear → ReLU → Linear → (+x) → ReLU
+//! ```
+//!
+//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better
+//! suited for long training runs where board-pattern recognition matters.
+
+use burn::{
+    module::Module,
+    nn::{Linear, LinearConfig},
+    record::{CompactRecorder, Recorder},
+    tensor::{
+        activation::{relu, tanh},
+        backend::Backend,
+        Tensor,
+    },
+};
+use std::path::Path;
+
+use super::PolicyValueNet;
+
+// ── Config ────────────────────────────────────────────────────────────────────
+
+/// Configuration for [`ResNet`].
+#[derive(Debug, Clone)]
+pub struct ResNetConfig {
+    /// Number of input features. 217 for Trictrac's `to_tensor()`.
+    pub obs_size: usize,
+    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
+    pub action_size: usize,
+    /// Width of all hidden layers (input projection + residual blocks).
+    pub hidden_size: usize,
+}
+
+impl Default for ResNetConfig {
+    fn default() -> Self {
+        Self {
+            obs_size: 217,
+            action_size: 514,
+            hidden_size: 512,
+        }
+    }
+}
+
+// ── Residual block ────────────────────────────────────────────────────────────
+
+/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`.
+///
+/// Both linear layers preserve the hidden dimension so the skip connection
+/// can be added without projection.
+#[derive(Module, Debug)]
+struct ResBlock<B: Backend> {
+    fc1: Linear<B>,
+    fc2: Linear<B>,
+}
+
+impl<B: Backend> ResBlock<B> {
+    fn new(hidden: usize, device: &B::Device) -> Self {
+        Self {
+            fc1: LinearConfig::new(hidden, hidden).init(device),
+            fc2: LinearConfig::new(hidden, hidden).init(device),
+        }
+    }
+
+    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
+        let residual = x.clone();
+        let out = relu(self.fc1.forward(x));
+        relu(self.fc2.forward(out) + residual)
+    }
+}
+
+// ── Network ───────────────────────────────────────────────────────────────────
+
+/// Four-residual-block policy-value network.
+///
+/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and
+/// when representing complex positional patterns is important.
+#[derive(Module, Debug)]
+pub struct ResNet<B: Backend> {
+    input: Linear<B>,
+    block0: ResBlock<B>,
+    block1: ResBlock<B>,
+    block2: ResBlock<B>,
+    block3: ResBlock<B>,
+    policy_head: Linear<B>,
+    value_head: Linear<B>,
+}
+
+impl<B: Backend> ResNet<B> {
+    /// Construct a fresh network with random weights.
+    pub fn new(config: &ResNetConfig, device: &B::Device) -> Self {
+        let h = config.hidden_size;
+        Self {
+            input: LinearConfig::new(config.obs_size, h).init(device),
+            block0: ResBlock::new(h, device),
+            block1: ResBlock::new(h, device),
+            block2: ResBlock::new(h, device),
+            block3: ResBlock::new(h, device),
+            policy_head: LinearConfig::new(h, config.action_size).init(device),
+            value_head: LinearConfig::new(h, 1).init(device),
+        }
+    }
+
+    /// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
+    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
+        CompactRecorder::new()
+            .record(self.clone().into_record(), path.to_path_buf())
+            .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}"))
+    }
+
+    /// Load weights from `path` into a fresh model built from `config`.
+    pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
+        let record = CompactRecorder::new()
+            .load(path.to_path_buf(), device)
+            .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?;
+        Ok(Self::new(config, device).load_record(record))
+    }
+}
+
+impl<B: Backend> PolicyValueNet<B> for ResNet<B> {
+    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
+        let x = relu(self.input.forward(obs));
+        let x = self.block0.forward(x);
+        let x = self.block1.forward(x);
+        let x = self.block2.forward(x);
+        let x = self.block3.forward(x);
+        let policy = self.policy_head.forward(x.clone());
+        let value = tanh(self.value_head.forward(x));
+        (policy, value)
+    }
+}
+
+// ── Tests ─────────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use burn::backend::NdArray;
+
+    type B = NdArray;
+
+    fn device() -> <B as Backend>::Device {
+        Default::default()
+    }
+
+    fn small_config() -> ResNetConfig {
+        // Use a small hidden size so tests are fast.
+        ResNetConfig {
+            obs_size: 217,
+            action_size: 514,
+            hidden_size: 64,
+        }
+    }
+
+    fn net() -> ResNet<B> {
+        ResNet::new(&small_config(), &device())
+    }
+
+    // ── Shape tests ───────────────────────────────────────────────────────
+
+    #[test]
+    fn forward_output_shapes() {
+        let obs = Tensor::zeros([4, 217], &device());
+        let (policy, value) = net().forward(obs);
+        assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
+        assert_eq!(value.dims(), [4, 1], "value shape mismatch");
+    }
+
+    #[test]
+    fn forward_single_sample() {
+        let (policy, value) = net().forward(Tensor::zeros([1, 217], &device()));
+        assert_eq!(policy.dims(), [1, 514]);
+        assert_eq!(value.dims(), [1, 1]);
+    }
+
+    // ── Value bounds ──────────────────────────────────────────────────────
+
+    #[test]
+    fn value_in_tanh_range() {
+        let obs = Tensor::<B, 2>::ones([8, 217], &device());
+        let (_, value) = net().forward(obs);
+        let data: Vec<f32> = value.into_data().to_vec().unwrap();
+        for v in &data {
+            assert!(
+                *v > -1.0 && *v < 1.0,
+                "value {v} is outside open interval (-1, 1)"
+            );
+        }
+    }
+
+    // ── Residual connections ──────────────────────────────────────────────
+
+    #[test]
+    fn policy_logits_not_all_equal() {
+        let (policy, _) = net().forward(Tensor::zeros([1, 217], &device()));
+        let data: Vec<f32> = policy.into_data().to_vec().unwrap();
+        let first = data[0];
+        let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
+        assert!(!all_same, "all policy logits are identical");
+    }
+
+    // ── Save / Load ───────────────────────────────────────────────────────
+
+    #[test]
+    fn save_load_preserves_weights() {
+        let config = small_config();
+        let model = net();
+        let obs = Tensor::<B, 2>::ones([2, 217], &device());
+
+        let (policy_before, value_before) = model.forward(obs.clone());
+
+        let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk");
+        model.save(&path).expect("save failed");
+
+        let loaded = ResNet::<B>::load(&config, &path, &device()).expect("load failed");
+        let (policy_after, value_after) = loaded.forward(obs);
+
+        let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
+        let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
+        let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let _ = std::fs::remove_file(path);
+    }
+
+    // ── Integration: both architectures satisfy PolicyValueNet ────────────
+
+    #[test]
+    fn resnet_satisfies_trait() {
+        fn requires_net<N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) {
+            let (p, v) = net.forward(obs);
+            assert_eq!(p.dims()[1], 514);
+            assert_eq!(v.dims()[1], 1);
+        }
+        requires_net(&net(), Tensor::zeros([2, 217], &device()));
+    }
+}
diff --git a/store/src/game.rs b/store/src/game.rs
index f553bdb..2fde45c 100644
--- a/store/src/game.rs
+++ b/store/src/game.rs
@@ -225,22 +225,26 @@ impl GameState {
         let mut t = Vec::with_capacity(217);
         let pos: Vec = self.board.to_vec(); // 24 elements, positive=White, negative=Black
-        // [0..95] own (White) checkers, TD-Gammon encoding
+        // [0..95] own (White) checkers, TD-Gammon encoding.
+        // Each field contributes 4 values:
+        //   (count==1), (count==2), (count==3), (count-3)/12   ← all in [0,1]
+        // The overflow term is divided by 12 because the maximum excess is
+        // 15 (all checkers) − 3 = 12.
         for &c in &pos {
             let own = c.max(0) as u8;
             t.push((own == 1) as u8 as f32);
             t.push((own == 2) as u8 as f32);
             t.push((own == 3) as u8 as f32);
-            t.push(own.saturating_sub(3) as f32);
+            t.push(own.saturating_sub(3) as f32 / 12.0);
         }
-        // [96..191] opp (Black) checkers, TD-Gammon encoding
+        // [96..191] opp (Black) checkers, same encoding.
for &c in &pos { let opp = (-c).max(0) as u8; t.push((opp == 1) as u8 as f32); t.push((opp == 2) as u8 as f32); t.push((opp == 3) as u8 as f32); - t.push(opp.saturating_sub(3) as f32); + t.push(opp.saturating_sub(3) as f32 / 12.0); } // [192..193] dice