diff --git a/Cargo.lock b/Cargo.lock index 0baa02a..a43261e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5891,17 +5891,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "spiel_bot" -version = "0.1.0" -dependencies = [ - "anyhow", - "burn", - "rand 0.9.2", - "rand_distr", - "trictrac-store", -] - [[package]] name = "spin" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index 4c2eb15..b9e6d45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store", "spiel_bot"] +members = ["client_cli", "bot", "store"] diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md deleted file mode 100644 index 7e5ed1f..0000000 --- a/doc/spiel_bot_research.md +++ /dev/null @@ -1,707 +0,0 @@ -# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac - -## 0. Context and Scope - -The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library -(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` -encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every -other stage to an inline random-opponent loop. - -`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency -for **self-play training**. Its goals: - -- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") - that works with Trictrac's multi-stage turn model and stochastic dice. -- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) - as the first algorithm. -- Remain **modular**: adding DQN or PPO later requires only a new - `impl Algorithm for Dqn` without touching the environment or network layers. -- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from - `trictrac-store`. - ---- - -## 1. 
Library Landscape - -### 1.1 Neural Network Frameworks - -| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | -|-------|----------|-----|-----------|----------|-------| -| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | -| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | -| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | -| ndarray alone | no | no | yes | mature | array ops only; no autograd | - -**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ -runtime needed, the `ndarray` backend is sufficient for CPU training and can -switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by -changing one type alias. - -`tch-rs` would be the best choice for raw training throughput (it is the most -battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the -pure-Rust constraint. If training speed becomes the bottleneck after prototyping, -switching `spiel_bot` to `tch-rs` is a one-line backend swap. - -### 1.2 Other Key Crates - -| Crate | Role | -|-------|------| -| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | -| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | -| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | -| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | -| `anyhow` | error propagation (already used everywhere) | -| `indicatif` | training progress bars | -| `tracing` | structured logging per episode/iteration | - -### 1.3 What `burn-rl` Provides (and Does Not) - -The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) -provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` -trait. 
It does **not** provide: - -- MCTS or any tree-search algorithm -- Two-player self-play -- Legal action masking during training -- Chance-node handling - -For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its -own (simpler, more targeted) traits and implement MCTS from scratch. - ---- - -## 2. Trictrac-Specific Design Constraints - -### 2.1 Multi-Stage Turn Model - -A Trictrac turn passes through up to six `TurnStage` values. Only two involve -genuine player choice: - -| TurnStage | Node type | Handler | -|-----------|-----------|---------| -| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | -| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | -| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | -| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | -| `Move` | **Player decision** | MCTS / policy network | -| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | - -The environment wrapper advances through forced/chance stages automatically so -that from the algorithm's perspective every node it sees is a genuine player -decision. - -### 2.2 Stochastic Dice in MCTS - -AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice -introduce stochasticity. Three approaches exist: - -**A. Outcome sampling (recommended)** -During each MCTS simulation, when a chance node is reached, sample one dice -outcome at random and continue. After many simulations the expected value -converges. This is the approach used by OpenSpiel's MCTS for stochastic games -and requires no changes to the standard PUCT formula. - -**B. Chance-node averaging (expectimax)** -At each chance node, expand all 21 unique dice pairs weighted by their -probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is -exact but multiplies the branching factor by ~21 at every dice roll, making it -prohibitively expensive. - -**C. 
Condition on dice in the observation (current approach)** -Dice values are already encoded at indices [192–193] of `to_tensor()`. The -network naturally conditions on the rolled dice when it evaluates a position. -MCTS only runs on player-decision nodes *after* the dice have been sampled; -chance nodes are bypassed by the environment wrapper (approach A). The policy -and value heads learn to play optimally given any dice pair. - -**Use approach A + C together**: the environment samples dice automatically -(chance node bypass), and the 217-dim tensor encodes the dice so the network -can exploit them. - -### 2.3 Perspective / Mirroring - -All move rules and tensor encoding are defined from White's perspective. -`to_tensor()` must always be called after mirroring the state for Black. -The environment wrapper handles this transparently: every observation returned -to an algorithm is already in the active player's perspective. - -### 2.4 Legal Action Masking - -A crucial difference from the existing `bot/` code: instead of penalizing -invalid actions with `ERROR_REWARD`, the policy head logits are **masked** -before softmax — illegal action logits are set to `-inf`. This prevents the -network from wasting capacity on illegal moves and eliminates the need for the -penalty-reward hack. - ---- - -## 3. 
Proposed Crate Architecture - -``` -spiel_bot/ -├── Cargo.toml -└── src/ - ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" - │ - ├── env/ - │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface - │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store - │ - ├── mcts/ - │ ├── mod.rs # MctsConfig, run_mcts() entry point - │ ├── node.rs # MctsNode (visit count, W, prior, children) - │ └── search.rs # simulate(), backup(), select_action() - │ - ├── network/ - │ ├── mod.rs # PolicyValueNet trait - │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads - │ - ├── alphazero/ - │ ├── mod.rs # AlphaZeroConfig - │ ├── selfplay.rs # generate_episode() -> Vec - │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) - │ └── trainer.rs # training loop: selfplay → sample → loss → update - │ - └── agent/ - ├── mod.rs # Agent trait - ├── random.rs # RandomAgent (baseline) - └── mcts_agent.rs # MctsAgent: uses trained network for inference -``` - -Future algorithms slot in without touching the above: - -``` - ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer - └── ppo/ # (future) PPO: impl Algorithm + rollout buffer -``` - ---- - -## 4. Core Traits - -### 4.1 `GameEnv` — the minimal OpenSpiel interface - -```rust -use rand::Rng; - -/// Who controls the current node. -pub enum Player { - P1, // player index 0 - P2, // player index 1 - Chance, // dice roll - Terminal, // game over -} - -pub trait GameEnv: Clone + Send + Sync + 'static { - type State: Clone + Send + Sync; - - /// Fresh game state. - fn new_game(&self) -> Self::State; - - /// Who acts at this node. - fn current_player(&self, s: &Self::State) -> Player; - - /// Legal action indices (always in [0, action_space())). - /// Empty only at Terminal nodes. - fn legal_actions(&self, s: &Self::State) -> Vec; - - /// Apply a player action (must be legal). 
- fn apply(&self, s: &mut Self::State, action: usize); - - /// Advance a Chance node by sampling dice; no-op at non-Chance nodes. - fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng); - - /// Observation tensor from `pov`'s perspective (0 or 1). - /// Returns 217 f32 values for Trictrac. - fn observation(&self, s: &Self::State, pov: usize) -> Vec; - - /// Flat observation size (217 for Trictrac). - fn obs_size(&self) -> usize; - - /// Total action-space size (514 for Trictrac). - fn action_space(&self) -> usize; - - /// Game outcome per player, or None if not Terminal. - /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw. - fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; -} -``` - -### 4.2 `PolicyValueNet` — neural network interface - -```rust -use burn::prelude::*; - -pub trait PolicyValueNet: Send + Sync { - /// Forward pass. - /// `obs`: [batch, obs_size] tensor. - /// Returns: (policy_logits [batch, action_space], value [batch]). - fn forward(&self, obs: Tensor) -> (Tensor, Tensor); - - /// Save weights to `path`. - fn save(&self, path: &std::path::Path) -> anyhow::Result<()>; - - /// Load weights from `path`. - fn load(path: &std::path::Path) -> anyhow::Result - where - Self: Sized; -} -``` - -### 4.3 `Agent` — player policy interface - -```rust -pub trait Agent: Send { - /// Select an action index given the current game state observation. - /// `legal`: mask of valid action indices. - fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize; -} -``` - ---- - -## 5. MCTS Implementation - -### 5.1 Node - -```rust -pub struct MctsNode { - n: u32, // visit count N(s, a) - w: f32, // sum of backed-up values W(s, a) - p: f32, // prior from policy head P(s, a) - children: Vec<(usize, MctsNode)>, // (action_idx, child) - is_expanded: bool, -} - -impl MctsNode { - pub fn q(&self) -> f32 { - if self.n == 0 { 0.0 } else { self.w / self.n as f32 } - } - - /// PUCT score used for selection. 
- pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { - self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) - } -} -``` - -### 5.2 Simulation Loop - -One MCTS simulation (for deterministic decision nodes): - -``` -1. SELECTION — traverse from root, always pick child with highest PUCT, - auto-advancing forced/chance nodes via env.apply_chance(). -2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get - (policy_logits, value). Mask illegal actions, softmax - the remaining logits → priors P(s,a) for each child. -3. BACKUP — propagate -value up the tree (negate at each level because - perspective alternates between P1 and P2). -``` - -After `n_simulations` iterations, action selection at the root: - -```rust -// During training: sample proportional to N^(1/temperature) -// During evaluation: argmax N -fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } -``` - -### 5.3 Configuration - -```rust -pub struct MctsConfig { - pub n_simulations: usize, // e.g. 200 - pub c_puct: f32, // exploration constant, e.g. 1.5 - pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 - pub dirichlet_eps: f32, // noise weight, e.g. 0.25 - pub temperature: f32, // action sampling temperature (anneals to 0) -} -``` - -### 5.4 Handling Chance Nodes Inside MCTS - -When simulation reaches a Chance node (dice roll), the environment automatically -samples dice and advances to the next decision node. The MCTS tree does **not** -branch on dice outcomes — it treats the sampled outcome as the state. This -corresponds to "outcome sampling" (approach A from §2.2). Because each -simulation independently samples dice, the Q-values at player nodes converge to -their expected value over many simulations. - ---- - -## 6. 
Network Architecture - -### 6.1 ResNet Policy-Value Network - -A single trunk with residual blocks, then two heads: - -``` -Input: [batch, 217] - ↓ -Linear(217 → 512) + ReLU - ↓ -ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU) - ↓ trunk output [batch, 512] - ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site) - └─ Value head: Linear(512 → 1) → tanh (output in [-1, 1]) -``` - -Burn implementation sketch: - -```rust -#[derive(Module, Debug)] -pub struct TrictracNet { - input: Linear, - res_blocks: Vec>, - policy_head: Linear, - value_head: Linear, -} - -impl TrictracNet { - pub fn forward(&self, obs: Tensor) - -> (Tensor, Tensor) - { - let x = activation::relu(self.input.forward(obs)); - let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x)); - let policy = self.policy_head.forward(x.clone()); // raw logits - let value = activation::tanh(self.value_head.forward(x)) - .squeeze(1); - (policy, value) - } -} -``` - -A simpler MLP (no residual blocks) is sufficient for a first version and much -faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two -heads. - -### 6.2 Loss Function - -``` -L = MSE(value_pred, z) - + CrossEntropy(policy_logits_masked, π_mcts) - - c_l2 * L2_regularization -``` - -Where: -- `z` = game outcome (±1) from the active player's perspective -- `π_mcts` = normalized MCTS visit counts at the root (the policy target) -- Legal action masking is applied before computing CrossEntropy - ---- - -## 7. 
AlphaZero Training Loop - -``` -INIT - network ← random weights - replay ← empty ReplayBuffer(capacity = 100_000) - -LOOP forever: - ── Self-play phase ────────────────────────────────────────────── - (parallel with rayon, n_workers games at once) - for each game: - state ← env.new_game() - samples = [] - while not terminal: - advance forced/chance nodes automatically - obs ← env.observation(state, current_player) - legal ← env.legal_actions(state) - π, root_value ← mcts.run(state, network, config) - action ← sample from π (with temperature) - samples.push((obs, π, current_player)) - env.apply(state, action) - z ← env.returns(state) // final scores - for (obs, π, player) in samples: - replay.push(TrainSample { obs, policy: π, value: z[player] }) - - ── Training phase ─────────────────────────────────────────────── - for each gradient step: - batch ← replay.sample(batch_size) - (policy_logits, value_pred) ← network.forward(batch.obs) - loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) - optimizer.step(loss.backward()) - - ── Evaluation (every N iterations) ───────────────────────────── - win_rate ← evaluate(network_new vs network_prev, n_eval_games) - if win_rate > 0.55: save checkpoint -``` - -### 7.1 Replay Buffer - -```rust -pub struct TrainSample { - pub obs: Vec, // 217 values - pub policy: Vec, // 514 values (normalized MCTS visit counts) - pub value: f32, // game outcome ∈ {-1, 0, +1} -} - -pub struct ReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - pub fn push(&mut self, s: TrainSample) { - if self.data.len() == self.capacity { self.data.pop_front(); } - self.data.push_back(s); - } - - pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { - // sample without replacement - } -} -``` - -### 7.2 Parallelism Strategy - -Self-play is embarrassingly parallel (each game is independent): - -```rust -let samples: Vec = (0..n_games) - .into_par_iter() // rayon - .flat_map(|_| 
generate_episode(&env, &network, &mcts_config)) - .collect(); -``` - -Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. -Self-play uses inference-only (no gradient tape), so a `NdArray` backend -(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with -`Autodiff>`. - -For larger scale, a producer-consumer architecture (crossbeam-channel) separates -self-play workers from the training thread, allowing continuous data generation -while the GPU trains. - ---- - -## 8. `TrictracEnv` Implementation Sketch - -```rust -use trictrac_store::{ - training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, - Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, -}; - -#[derive(Clone)] -pub struct TrictracEnv; - -impl GameEnv for TrictracEnv { - type State = GameState; - - fn new_game(&self) -> GameState { - GameState::new_with_players("P1", "P2") - } - - fn current_player(&self, s: &GameState) -> Player { - match s.stage { - Stage::Ended => Player::Terminal, - _ => match s.turn_stage { - TurnStage::RollWaiting => Player::Chance, - _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, - }, - } - } - - fn legal_actions(&self, s: &GameState) -> Vec { - let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; - get_valid_action_indices(&view).unwrap_or_default() - } - - fn apply(&self, s: &mut GameState, action_idx: usize) { - // advance all forced/chance nodes first, then apply the player action - self.advance_forced(s); - let needs_mirror = s.active_player_id == 2; - let view = if needs_mirror { s.mirror() } else { s.clone() }; - if let Some(event) = TrictracAction::from_action_index(action_idx) - .and_then(|a| a.to_event(&view)) - .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) - { - let _ = s.consume(&event); - } - // advance any forced stages that follow - self.advance_forced(s); - } - - fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { - // RollDice → 
RollWaiting - let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); - // RollWaiting → next stage - let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; - let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); - self.advance_forced(s); - } - - fn observation(&self, s: &GameState, pov: usize) -> Vec { - if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } - } - - fn obs_size(&self) -> usize { 217 } - fn action_space(&self) -> usize { ACTION_SPACE_SIZE } - - fn returns(&self, s: &GameState) -> Option<[f32; 2]> { - if s.stage != Stage::Ended { return None; } - // Convert hole+point scores to ±1 outcome - let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); - let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); - Some(match s1.cmp(&s2) { - std::cmp::Ordering::Greater => [ 1.0, -1.0], - std::cmp::Ordering::Less => [-1.0, 1.0], - std::cmp::Ordering::Equal => [ 0.0, 0.0], - }) - } -} - -impl TrictracEnv { - /// Advance through all forced (non-decision, non-chance) stages. - fn advance_forced(&self, s: &mut GameState) { - use trictrac_store::PointsRules; - loop { - match s.turn_stage { - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Scoring is deterministic; compute and apply automatically. - let color = s.player_color_by_id(&s.active_player_id) - .unwrap_or(trictrac_store::Color::White); - let drc = s.players.get(&s.active_player_id) - .map(|p| p.dice_roll_count).unwrap_or(0); - let pr = PointsRules::new(&color, &s.board, s.dice); - let pts = pr.get_points(drc); - let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; - let _ = s.consume(&GameEvent::Mark { - player_id: s.active_player_id, points, - }); - } - TurnStage::RollDice => { - // RollDice is a forced "initiate roll" action with no real choice. 
- let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); - } - _ => break, - } - } - } -} -``` - ---- - -## 9. Cargo.toml Changes - -### 9.1 Add `spiel_bot` to the workspace - -```toml -# Cargo.toml (workspace root) -[workspace] -resolver = "2" -members = ["client_cli", "bot", "store", "spiel_bot"] -``` - -### 9.2 `spiel_bot/Cargo.toml` - -```toml -[package] -name = "spiel_bot" -version = "0.1.0" -edition = "2021" - -[features] -default = ["alphazero"] -alphazero = [] -# dqn = [] # future -# ppo = [] # future - -[dependencies] -trictrac-store = { path = "../store" } -anyhow = "1" -rand = "0.9" -rayon = "1" -serde = { version = "1", features = ["derive"] } -serde_json = "1" - -# Burn: NdArray for pure-Rust CPU training -# Replace NdArray with Wgpu or Tch for GPU. -burn = { version = "0.20", features = ["ndarray", "autodiff"] } - -# Optional: progress display and structured logging -indicatif = "0.17" -tracing = "0.1" - -[[bin]] -name = "az_train" -path = "src/bin/az_train.rs" - -[[bin]] -name = "az_eval" -path = "src/bin/az_eval.rs" -``` - ---- - -## 10. 
Comparison: `bot` crate vs `spiel_bot` - -| Aspect | `bot` (existing) | `spiel_bot` (proposed) | -|--------|-----------------|------------------------| -| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | -| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | -| Opponent | hardcoded random | self-play | -| Invalid actions | penalise with reward | legal action mask (no penalty) | -| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | -| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | -| Burn dep | yes (0.20) | yes (0.20), same backend | -| `burn-rl` dep | yes | no | -| C++ dep | no | no | -| Python dep | no | no | -| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | - -The two crates are **complementary**: `bot` is a working DQN/PPO baseline; -`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The -`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just -replace `TrictracEnvironment` with `TrictracEnv`). - ---- - -## 11. Implementation Order - -1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game - through the trait, verify terminal state and returns). -2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + - Burn forward/backward pass test with dummy data. -3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests - (visit counts sum to `n_simulations`, legal mask respected). -4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub - (one iteration, check loss decreases). -5. **Integration test**: run 100 self-play games with a tiny network (1 res block, - 64 hidden units), verify the training loop completes without panics. -6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s - on CPU, consistent with `random_game` throughput). -7. 
**Upgrade network**: 4 residual blocks, 512 hidden units; schedule - hyperparameter sweep. -8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report - win rate every checkpoint. - ---- - -## 12. Key Open Questions - -1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. - AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever - scored more holes). Better option: normalize the score margin. The final - choice affects how the value head is trained. - -2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). - MCTS with 200 simulations per step means ~120k network evaluations per game. - At batch inference this is feasible on CPU; on GPU it becomes fast. Consider - limiting `n_simulations` to 50–100 for early training. - -3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). - This is a long-horizon decision that AlphaZero handles naturally via MCTS - lookahead, but needs careful value normalization (a "Go" restarts scoring - within the same game). - -4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated - to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. - This is optional but reduces code duplication. - -5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, - 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. - A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. 
diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml deleted file mode 100644 index 323c953..0000000 --- a/spiel_bot/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "spiel_bot" -version = "0.1.0" -edition = "2021" - -[dependencies] -trictrac-store = { path = "../store" } -anyhow = "1" -rand = "0.9" -rand_distr = "0.5" -burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs deleted file mode 100644 index bb86724..0000000 --- a/spiel_bot/src/alphazero/mod.rs +++ /dev/null @@ -1,117 +0,0 @@ -//! AlphaZero: self-play data generation, replay buffer, and training step. -//! -//! # Modules -//! -//! | Module | Contents | -//! |--------|----------| -//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | -//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | -//! | [`trainer`] | [`train_step`] | -//! -//! # Typical outer loop -//! -//! ```rust,ignore -//! use burn::backend::{Autodiff, NdArray}; -//! use burn::optim::AdamConfig; -//! use spiel_bot::{ -//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, -//! env::TrictracEnv, -//! mcts::MctsConfig, -//! network::{MlpConfig, MlpNet}, -//! }; -//! -//! type Infer = NdArray; -//! type Train = Autodiff>; -//! -//! let device = Default::default(); -//! let env = TrictracEnv; -//! let config = AlphaZeroConfig::default(); -//! -//! // Build training model and optimizer. -//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); -//! let mut optimizer = AdamConfig::new().init(); -//! let mut replay = ReplayBuffer::new(config.replay_capacity); -//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); -//! -//! for _iter in 0..config.n_iterations { -//! // Convert to inference backend for self-play. -//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) -//! .load_record(train_model.clone().into_record()); -//! 
let eval = BurnEvaluator::new(infer_model, device.clone()); -//! -//! // Self-play: generate episodes. -//! for _ in 0..config.n_games_per_iter { -//! let samples = generate_episode(&env, &eval, &config.mcts, -//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); -//! replay.extend(samples); -//! } -//! -//! // Training: gradient steps. -//! if replay.len() >= config.batch_size { -//! for _ in 0..config.n_train_steps_per_iter { -//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) -//! .into_iter().cloned().collect(); -//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, -//! config.learning_rate); -//! train_model = m; -//! } -//! } -//! } -//! ``` - -pub mod replay; -pub mod selfplay; -pub mod trainer; - -pub use replay::{ReplayBuffer, TrainSample}; -pub use selfplay::{BurnEvaluator, generate_episode}; -pub use trainer::train_step; - -use crate::mcts::MctsConfig; - -// ── Configuration ───────────────────────────────────────────────────────── - -/// Top-level AlphaZero hyperparameters. -/// -/// The MCTS parameters live in [`MctsConfig`]; this struct holds the -/// outer training-loop parameters. -#[derive(Debug, Clone)] -pub struct AlphaZeroConfig { - /// MCTS parameters for self-play. - pub mcts: MctsConfig, - /// Number of self-play games per training iteration. - pub n_games_per_iter: usize, - /// Number of gradient steps per training iteration. - pub n_train_steps_per_iter: usize, - /// Mini-batch size for each gradient step. - pub batch_size: usize, - /// Maximum number of samples in the replay buffer. - pub replay_capacity: usize, - /// Adam learning rate. - pub learning_rate: f64, - /// Number of outer iterations (self-play + train) to run. - pub n_iterations: usize, - /// Move index after which the action temperature drops to 0 (greedy play). 
- pub temperature_drop_move: usize, -} - -impl Default for AlphaZeroConfig { - fn default() -> Self { - Self { - mcts: MctsConfig { - n_simulations: 100, - c_puct: 1.5, - dirichlet_alpha: 0.1, - dirichlet_eps: 0.25, - temperature: 1.0, - }, - n_games_per_iter: 10, - n_train_steps_per_iter: 20, - batch_size: 64, - replay_capacity: 50_000, - learning_rate: 1e-3, - n_iterations: 100, - temperature_drop_move: 30, - } - } -} diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs deleted file mode 100644 index 5e64cc4..0000000 --- a/spiel_bot/src/alphazero/replay.rs +++ /dev/null @@ -1,144 +0,0 @@ -//! Replay buffer for AlphaZero self-play data. - -use std::collections::VecDeque; -use rand::Rng; - -// ── Training sample ──────────────────────────────────────────────────────── - -/// One training example produced by self-play. -#[derive(Clone, Debug)] -pub struct TrainSample { - /// Observation tensor from the acting player's perspective (`obs_size` floats). - pub obs: Vec, - /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1). - pub policy: Vec, - /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw. - pub value: f32, -} - -// ── Replay buffer ────────────────────────────────────────────────────────── - -/// Fixed-capacity circular buffer of [`TrainSample`]s. -/// -/// When the buffer is full, the oldest sample is evicted on push. -/// Samples are drawn without replacement using a Fisher-Yates partial shuffle. -pub struct ReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - /// Create a buffer with the given maximum capacity. - pub fn new(capacity: usize) -> Self { - Self { - data: VecDeque::with_capacity(capacity.min(1024)), - capacity, - } - } - - /// Add a sample; evicts the oldest if at capacity. 
- pub fn push(&mut self, sample: TrainSample) { - if self.data.len() == self.capacity { - self.data.pop_front(); - } - self.data.push_back(sample); - } - - /// Add all samples from an episode. - pub fn extend(&mut self, samples: impl IntoIterator) { - for s in samples { - self.push(s); - } - } - - pub fn len(&self) -> usize { - self.data.len() - } - - pub fn is_empty(&self) -> bool { - self.data.is_empty() - } - - /// Sample up to `n` distinct samples, without replacement. - /// - /// If the buffer has fewer than `n` samples, all are returned (shuffled). - pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { - let len = self.data.len(); - let n = n.min(len); - // Partial Fisher-Yates using index shuffling. - let mut indices: Vec = (0..len).collect(); - for i in 0..n { - let j = rng.random_range(i..len); - indices.swap(i, j); - } - indices[..n].iter().map(|&i| &self.data[i]).collect() - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - - fn dummy(value: f32) -> TrainSample { - TrainSample { obs: vec![value], policy: vec![1.0], value } - } - - #[test] - fn push_and_len() { - let mut buf = ReplayBuffer::new(10); - assert!(buf.is_empty()); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - assert_eq!(buf.len(), 2); - } - - #[test] - fn evicts_oldest_at_capacity() { - let mut buf = ReplayBuffer::new(3); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - buf.push(dummy(3.0)); - buf.push(dummy(4.0)); // evicts 1.0 - assert_eq!(buf.len(), 3); - // Oldest remaining should be 2.0 - assert_eq!(buf.data[0].value, 2.0); - } - - #[test] - fn sample_batch_size() { - let mut buf = ReplayBuffer::new(20); - for i in 0..10 { - buf.push(dummy(i as f32)); - } - let mut rng = SmallRng::seed_from_u64(0); - let batch = buf.sample_batch(5, &mut rng); - assert_eq!(batch.len(), 5); - } - - #[test] - fn sample_batch_capped_at_len() { - let mut 
buf = ReplayBuffer::new(20); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - let mut rng = SmallRng::seed_from_u64(0); - let batch = buf.sample_batch(100, &mut rng); - assert_eq!(batch.len(), 2); - } - - #[test] - fn sample_batch_no_duplicates() { - let mut buf = ReplayBuffer::new(20); - for i in 0..10 { - buf.push(dummy(i as f32)); - } - let mut rng = SmallRng::seed_from_u64(1); - let batch = buf.sample_batch(10, &mut rng); - let mut seen: Vec = batch.iter().map(|s| s.value).collect(); - seen.sort_by(f32::total_cmp); - seen.dedup(); - assert_eq!(seen.len(), 10, "sample contained duplicates"); - } -} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs deleted file mode 100644 index 6f10f8d..0000000 --- a/spiel_bot/src/alphazero/selfplay.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! Self-play episode generation and Burn-backed evaluator. - -use std::marker::PhantomData; - -use burn::tensor::{backend::Backend, Tensor, TensorData}; -use rand::Rng; - -use crate::env::GameEnv; -use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; -use crate::network::PolicyValueNet; - -use super::replay::TrainSample; - -// ── BurnEvaluator ────────────────────────────────────────────────────────── - -/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. -/// -/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so -/// that self-play generates no gradient tape overhead. -pub struct BurnEvaluator> { - model: N, - device: B::Device, - _b: PhantomData, -} - -impl> BurnEvaluator { - pub fn new(model: N, device: B::Device) -> Self { - Self { model, device, _b: PhantomData } - } - - pub fn into_model(self) -> N { - self.model - } -} - -// Safety: NdArray modules are Send; we never share across threads without -// external synchronisation. 
-unsafe impl> Send for BurnEvaluator {} -unsafe impl> Sync for BurnEvaluator {} - -impl> Evaluator for BurnEvaluator { - fn evaluate(&self, obs: &[f32]) -> (Vec, f32) { - let obs_size = obs.len(); - let data = TensorData::new(obs.to_vec(), [1, obs_size]); - let obs_tensor = Tensor::::from_data(data, &self.device); - - let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); - - let policy: Vec = policy_tensor.into_data().to_vec().unwrap(); - let value: Vec = value_tensor.into_data().to_vec().unwrap(); - - (policy, value[0]) - } -} - -// ── Episode generation ───────────────────────────────────────────────────── - -/// One pending observation waiting for its game-outcome value label. -struct PendingSample { - obs: Vec, - policy: Vec, - player: usize, -} - -/// Play one full game using MCTS guided by `evaluator`. -/// -/// Returns a [`TrainSample`] for every decision step in the game. -/// -/// `temperature_fn(step)` controls exploration: return `1.0` for early -/// moves and `0.0` after a fixed number of moves (e.g. move 30). -pub fn generate_episode( - env: &E, - evaluator: &dyn Evaluator, - mcts_config: &MctsConfig, - temperature_fn: &dyn Fn(usize) -> f32, - rng: &mut impl Rng, -) -> Vec { - let mut state = env.new_game(); - let mut pending: Vec = Vec::new(); - let mut step = 0usize; - - loop { - // Advance through chance nodes automatically. - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, rng); - } - - if env.current_player(&state).is_terminal() { - break; - } - - let player_idx = env.current_player(&state).index().unwrap(); - - // Run MCTS to get a policy. - let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng); - let policy = mcts::mcts_policy(&root, env.action_space()); - - // Record the observation from the acting player's perspective. 
- let obs = env.observation(&state, player_idx); - pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); - - // Select and apply the action. - let temperature = temperature_fn(step); - let action = mcts::select_action(&root, temperature, rng); - env.apply(&mut state, action); - step += 1; - } - - // Assign game outcomes. - let returns = env.returns(&state).unwrap_or([0.0; 2]); - pending - .into_iter() - .map(|s| TrainSample { - obs: s.obs, - policy: s.policy, - value: returns[s.player], - }) - .collect() -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - use rand::{SeedableRng, rngs::SmallRng}; - - use crate::env::Player; - use crate::mcts::{Evaluator, MctsConfig}; - use crate::network::{MlpConfig, MlpNet}; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn rng() -> SmallRng { - SmallRng::seed_from_u64(7) - } - - // Countdown game (same as in mcts tests). 
- #[derive(Clone, Debug)] - struct CState { remaining: u8, to_move: usize } - - #[derive(Clone)] - struct CountdownEnv; - - impl GameEnv for CountdownEnv { - type State = CState; - fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { Player::Terminal } - else if s.to_move == 0 { Player::P1 } else { Player::P2 } - } - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { s.remaining = 0; } - else { s.remaining -= sub; s.to_move = 1 - s.to_move; } - } - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / 4.0, s.to_move as f32] - } - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } - } - - fn tiny_config() -> MctsConfig { - MctsConfig { n_simulations: 5, c_puct: 1.5, - dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } - } - - // ── BurnEvaluator tests ─────────────────────────────────────────────── - - #[test] - fn burn_evaluator_output_shapes() { - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); - assert_eq!(policy.len(), 2, "policy length should equal action_space"); - assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); - } - - // ── generate_episode tests ──────────────────────────────────────────── - - #[test] - fn episode_terminates_and_has_samples() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, 
hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); - assert!(!samples.is_empty(), "episode must produce at least one sample"); - } - - #[test] - fn episode_sample_values_are_valid() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); - for s in &samples { - assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, - "unexpected value {}", s.value); - let sum: f32 = s.policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); - assert_eq!(s.obs.len(), 2); - } - } - - #[test] - fn episode_with_temperature_zero() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - // temperature=0 means greedy; episode must still terminate - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); - assert!(!samples.is_empty()); - } -} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs deleted file mode 100644 index d2482d1..0000000 --- a/spiel_bot/src/alphazero/trainer.rs +++ /dev/null @@ -1,172 +0,0 @@ -//! One gradient-descent training step for AlphaZero. -//! -//! The loss combines: -//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. -//! - **Value loss** — mean-squared error between the predicted value and the -//! actual game outcome. -//! -//! # Backend -//! -//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). -//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. -//! 
Weights are transferred between the two via [`burn::record`]. - -use burn::{ - module::AutodiffModule, - optim::{GradientsParams, Optimizer}, - prelude::ElementConversion, - tensor::{ - activation::log_softmax, - backend::AutodiffBackend, - Tensor, TensorData, - }, -}; - -use crate::network::PolicyValueNet; -use super::replay::TrainSample; - -/// Run one gradient step on `model` using `batch`. -/// -/// Returns the updated model and the scalar loss value for logging. -/// -/// # Parameters -/// -/// - `lr` — learning rate (e.g. `1e-3`). -/// - `batch` — slice of [`TrainSample`]s; must be non-empty. -pub fn train_step( - model: N, - optimizer: &mut O, - batch: &[TrainSample], - device: &B::Device, - lr: f64, -) -> (N, f32) -where - B: AutodiffBackend, - N: PolicyValueNet + AutodiffModule, - O: Optimizer, -{ - assert!(!batch.is_empty(), "train_step called with empty batch"); - - let batch_size = batch.len(); - let obs_size = batch[0].obs.len(); - let action_size = batch[0].policy.len(); - - // ── Build input tensors ──────────────────────────────────────────────── - let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); - let policy_flat: Vec = batch.iter().flat_map(|s| s.policy.iter().copied()).collect(); - let value_flat: Vec = batch.iter().map(|s| s.value).collect(); - - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [batch_size, obs_size]), - device, - ); - let policy_target = Tensor::::from_data( - TensorData::new(policy_flat, [batch_size, action_size]), - device, - ); - let value_target = Tensor::::from_data( - TensorData::new(value_flat, [batch_size, 1]), - device, - ); - - // ── Forward pass ────────────────────────────────────────────────────── - let (policy_logits, value_pred) = model.forward(obs_tensor); - - // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ────────────────── - let log_probs = log_softmax(policy_logits, 1); - let policy_loss = (policy_target.clone().neg() * log_probs) - .sum_dim(1) - .mean(); 
- - // ── Value loss: MSE(value_pred, z) ──────────────────────────────────── - let diff = value_pred - value_target; - let value_loss = (diff.clone() * diff).mean(); - - // ── Combined loss ───────────────────────────────────────────────────── - let loss = policy_loss + value_loss; - - // Extract scalar before backward (consumes the tensor). - let loss_scalar: f32 = loss.clone().into_scalar().elem(); - - // ── Backward + optimizer step ───────────────────────────────────────── - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &model); - let model = optimizer.step(lr, model, grads); - - (model, loss_scalar) -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::{ - backend::{Autodiff, NdArray}, - optim::AdamConfig, - }; - - use crate::network::{MlpConfig, MlpNet}; - use super::super::replay::TrainSample; - - type B = Autodiff>; - - fn device() -> ::Device { - Default::default() - } - - fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { - (0..n) - .map(|i| TrainSample { - obs: vec![0.5f32; obs_size], - policy: { - let mut p = vec![0.0f32; action_size]; - p[i % action_size] = 1.0; - p - }, - value: if i % 2 == 0 { 1.0 } else { -1.0 }, - }) - .collect() - } - - #[test] - fn train_step_returns_finite_loss() { - let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; - let model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(8, 4, 4); - - let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - assert!(loss > 0.0, "loss should be positive"); - } - - #[test] - fn loss_decreases_over_steps() { - let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let mut model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - // 
Same batch every step — loss should decrease. - let batch = dummy_batch(16, 4, 4); - - let mut prev_loss = f32::INFINITY; - for _ in 0..10 { - let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); - model = m; - assert!(loss.is_finite()); - prev_loss = loss; - } - // After 10 steps on fixed data, loss should be below a reasonable threshold. - assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); - } - - #[test] - fn train_step_batch_size_one() { - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(1, 2, 2); - let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); - assert!(loss.is_finite()); - } -} diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs deleted file mode 100644 index 42b4ae0..0000000 --- a/spiel_bot/src/env/mod.rs +++ /dev/null @@ -1,121 +0,0 @@ -//! Game environment abstraction — the minimal "Rust OpenSpiel". -//! -//! A `GameEnv` describes the rules of a two-player, zero-sum game that may -//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, -//! and PPO interact with a game exclusively through this trait. -//! -//! # Node taxonomy -//! -//! Every game position belongs to one of four categories, returned by -//! [`GameEnv::current_player`]: -//! -//! | [`Player`] | Meaning | -//! |-----------|---------| -//! | `P1` | Player 1 (index 0) must choose an action | -//! | `P2` | Player 2 (index 1) must choose an action | -//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | -//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | -//! -//! # Perspective convention -//! -//! [`GameEnv::observation`] always returns the board from *the requested -//! player's* point of view. Callers pass `pov = 0` for Player 1 and -//! `pov = 1` for Player 2. The implementation is responsible for any -//! 
mirroring required (e.g. Trictrac always reasons from White's side). - -pub mod trictrac; -pub use trictrac::TrictracEnv; - -/// Who controls the current game node. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Player { - /// Player 1 (index 0) is to move. - P1, - /// Player 2 (index 1) is to move. - P2, - /// A stochastic event (dice roll, etc.) must be resolved. - Chance, - /// The game is over. - Terminal, -} - -impl Player { - /// Returns the player index (0 or 1) if this is a decision node, - /// or `None` for `Chance` / `Terminal`. - pub fn index(self) -> Option { - match self { - Player::P1 => Some(0), - Player::P2 => Some(1), - _ => None, - } - } - - pub fn is_decision(self) -> bool { - matches!(self, Player::P1 | Player::P2) - } - - pub fn is_chance(self) -> bool { - self == Player::Chance - } - - pub fn is_terminal(self) -> bool { - self == Player::Terminal - } -} - -/// Trait that completely describes a two-player zero-sum game. -/// -/// Implementors must be cheaply cloneable (the type is used as a stateless -/// factory; the mutable game state lives in `Self::State`). -pub trait GameEnv: Clone + Send + Sync + 'static { - /// The mutable game state. Must be `Clone` so MCTS can copy - /// game trees without touching the environment. - type State: Clone + Send + Sync; - - // ── State creation ──────────────────────────────────────────────────── - - /// Create a fresh game state at the initial position. - fn new_game(&self) -> Self::State; - - // ── Node queries ────────────────────────────────────────────────────── - - /// Classify the current node. - fn current_player(&self, s: &Self::State) -> Player; - - /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). - /// - /// The returned indices are in `[0, action_space())`. - /// The result is unspecified (may panic or return empty) when called at a - /// `Chance` or `Terminal` node. 
- fn legal_actions(&self, s: &Self::State) -> Vec; - - // ── State mutation ──────────────────────────────────────────────────── - - /// Apply a player action. `action` must be a value returned by - /// [`legal_actions`] for the current state. - fn apply(&self, s: &mut Self::State, action: usize); - - /// Sample and apply a stochastic outcome. Must only be called when - /// `current_player(s) == Player::Chance`. - fn apply_chance(&self, s: &mut Self::State, rng: &mut R); - - // ── Observation ─────────────────────────────────────────────────────── - - /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). - /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. - fn observation(&self, s: &Self::State, pov: usize) -> Vec; - - /// Number of floats returned by [`observation`]. - fn obs_size(&self) -> usize; - - /// Total number of distinct action indices (the policy head output size). - fn action_space(&self) -> usize; - - // ── Terminal values ─────────────────────────────────────────────────── - - /// Game outcome for each player, or `None` if the game is not over. - /// - /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. - /// Index 0 = Player 1, index 1 = Player 2. - fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; -} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs deleted file mode 100644 index 99ba058..0000000 --- a/spiel_bot/src/env/trictrac.rs +++ /dev/null @@ -1,535 +0,0 @@ -//! [`GameEnv`] implementation for Trictrac. -//! -//! # Game flow (schools_enabled = false) -//! -//! With scoring schools disabled (the standard training configuration), -//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine -//! applies them automatically inside `RollResult` and `Move`. The only -//! four stages that actually occur are: -//! -//! | `TurnStage` | [`Player`] kind | Handled by | -//! |-------------|-----------------|------------| -//! 
| `RollDice` | `Chance` | [`apply_chance`] | -//! | `RollWaiting` | `Chance` | [`apply_chance`] | -//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | -//! | `Move` | `P1`/`P2` | [`apply`] | -//! -//! # Perspective -//! -//! The Trictrac engine always reasons from White's perspective. Player 1 is -//! White; Player 2 is Black. When Player 2 is active, the board is mirrored -//! before computing legal actions / the observation tensor, and the resulting -//! event is mirrored back before being applied to the real state. This -//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. - -use trictrac_store::{ - training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, - Dice, GameEvent, GameState, Stage, TurnStage, -}; - -use super::{GameEnv, Player}; - -/// Stateless factory that produces Trictrac [`GameState`] environments. -/// -/// Schools (`schools_enabled`) are always disabled — scoring is automatic. -#[derive(Clone, Debug, Default)] -pub struct TrictracEnv; - -impl GameEnv for TrictracEnv { - type State = GameState; - - // ── State creation ──────────────────────────────────────────────────── - - fn new_game(&self) -> GameState { - GameState::new_with_players("P1", "P2") - } - - // ── Node queries ────────────────────────────────────────────────────── - - fn current_player(&self, s: &GameState) -> Player { - if s.stage == Stage::Ended { - return Player::Terminal; - } - match s.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, - _ => { - if s.active_player_id == 1 { - Player::P1 - } else { - Player::P2 - } - } - } - } - - /// Returns the legal action indices for the active player. - /// - /// The board is automatically mirrored for Player 2 so that the engine - /// always reasons from White's perspective. The returned indices are - /// identical in meaning for both players (checker ordinals are - /// perspective-relative). 
- /// - /// # Panics - /// - /// Panics in debug builds if called at a `Chance` or `Terminal` node. - fn legal_actions(&self, s: &GameState) -> Vec { - debug_assert!( - self.current_player(s).is_decision(), - "legal_actions called at a non-decision node (turn_stage={:?})", - s.turn_stage - ); - let indices = if s.active_player_id == 2 { - get_valid_action_indices(&s.mirror()) - } else { - get_valid_action_indices(s) - }; - indices.unwrap_or_default() - } - - // ── State mutation ──────────────────────────────────────────────────── - - /// Apply a player action index to the game state. - /// - /// For Player 2, the action is decoded against the mirrored board and - /// the resulting event is un-mirrored before being applied. - /// - /// # Panics - /// - /// Panics in debug builds if `action` cannot be decoded or does not - /// produce a valid event for the current state. - fn apply(&self, s: &mut GameState, action: usize) { - let needs_mirror = s.active_player_id == 2; - - let event = if needs_mirror { - let view = s.mirror(); - TrictracAction::from_action_index(action) - .and_then(|a| a.to_event(&view)) - .map(|e| e.get_mirror(false)) - } else { - TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) - }; - - match event { - Some(e) => { - s.consume(&e).expect("apply: consume failed for valid action"); - } - None => { - panic!("apply: action index {action} produced no event in state {s}"); - } - } - } - - /// Sample dice and advance through a chance node. - /// - /// Handles both `RollDice` (triggers the roll mechanism, then samples - /// dice) and `RollWaiting` (only samples dice) in a single call so that - /// callers never need to distinguish the two. - /// - /// # Panics - /// - /// Panics in debug builds if called at a non-Chance node. 
- fn apply_chance(&self, s: &mut GameState, rng: &mut R) { - debug_assert!( - self.current_player(s).is_chance(), - "apply_chance called at a non-Chance node (turn_stage={:?})", - s.turn_stage - ); - - // Step 1: RollDice → RollWaiting (player initiates the roll). - if s.turn_stage == TurnStage::RollDice { - s.consume(&GameEvent::Roll { - player_id: s.active_player_id, - }) - .expect("apply_chance: Roll event failed"); - } - - // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. - // With schools_enabled=false, point marking is automatic inside consume(). - let dice = Dice { - values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), - }; - s.consume(&GameEvent::RollResult { - player_id: s.active_player_id, - dice, - }) - .expect("apply_chance: RollResult event failed"); - } - - // ── Observation ─────────────────────────────────────────────────────── - - fn observation(&self, s: &GameState, pov: usize) -> Vec { - if pov == 0 { - s.to_tensor() - } else { - s.mirror().to_tensor() - } - } - - fn obs_size(&self) -> usize { - 217 - } - - fn action_space(&self) -> usize { - ACTION_SPACE_SIZE - } - - // ── Terminal values ─────────────────────────────────────────────────── - - /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. - /// - /// The winner (higher cumulative score) receives `+1.0`; the loser - /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score - /// is `holes × 12 + points`. 
- fn returns(&self, s: &GameState) -> Option<[f32; 2]> { - if s.stage != Stage::Ended { - return None; - } - let score = |id: u64| -> i32 { - s.players - .get(&id) - .map(|p| p.holes as i32 * 12 + p.points as i32) - .unwrap_or(0) - }; - let s1 = score(1); - let s2 = score(2); - Some(match s1.cmp(&s2) { - std::cmp::Ordering::Greater => [1.0, -1.0], - std::cmp::Ordering::Less => [-1.0, 1.0], - std::cmp::Ordering::Equal => [0.0, 0.0], - }) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{rngs::SmallRng, Rng, SeedableRng}; - - fn env() -> TrictracEnv { - TrictracEnv - } - - fn seeded_rng(seed: u64) -> SmallRng { - SmallRng::seed_from_u64(seed) - } - - // ── Initial state ───────────────────────────────────────────────────── - - #[test] - fn new_game_is_chance_node() { - let e = env(); - let s = e.new_game(); - // A fresh game starts at RollDice — a Chance node. - assert_eq!(e.current_player(&s), Player::Chance); - assert!(e.returns(&s).is_none()); - } - - #[test] - fn new_game_is_not_terminal() { - let e = env(); - let s = e.new_game(); - assert_ne!(e.current_player(&s), Player::Terminal); - assert!(e.returns(&s).is_none()); - } - - // ── Chance nodes ────────────────────────────────────────────────────── - - #[test] - fn apply_chance_reaches_decision_node() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(1); - - // A single chance step must yield a decision node (or end the game, - // which only happens after 12 holes — impossible on the first roll). - e.apply_chance(&mut s, &mut rng); - let p = e.current_player(&s); - assert!( - p.is_decision(), - "expected decision node after first roll, got {p:?}" - ); - } - - #[test] - fn apply_chance_from_rollwaiting() { - // Check that apply_chance works when called mid-way (at RollWaiting). 
- let e = env(); - let mut s = e.new_game(); - assert_eq!(s.turn_stage, TurnStage::RollDice); - - // Manually advance to RollWaiting. - s.consume(&GameEvent::Roll { player_id: s.active_player_id }) - .unwrap(); - assert_eq!(s.turn_stage, TurnStage::RollWaiting); - - let mut rng = seeded_rng(2); - e.apply_chance(&mut s, &mut rng); - - let p = e.current_player(&s); - assert!(p.is_decision() || p.is_terminal()); - } - - // ── Legal actions ───────────────────────────────────────────────────── - - #[test] - fn legal_actions_nonempty_after_roll() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(3); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - let actions = e.legal_actions(&s); - assert!( - !actions.is_empty(), - "legal_actions must be non-empty at a decision node" - ); - } - - #[test] - fn legal_actions_within_action_space() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(4); - - e.apply_chance(&mut s, &mut rng); - for &a in e.legal_actions(&s).iter() { - assert!( - a < e.action_space(), - "action {a} out of bounds (action_space={})", - e.action_space() - ); - } - } - - // ── Observations ────────────────────────────────────────────────────── - - #[test] - fn observation_has_correct_size() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(5); - e.apply_chance(&mut s, &mut rng); - - assert_eq!(e.observation(&s, 0).len(), e.obs_size()); - assert_eq!(e.observation(&s, 1).len(), e.obs_size()); - } - - #[test] - fn observation_values_in_unit_interval() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(6); - e.apply_chance(&mut s, &mut rng); - - for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { - for (i, &v) in obs.iter().enumerate() { - assert!( - v >= 0.0 && v <= 1.0, - "pov={pov}: obs[{i}] = {v} is outside [0,1]" - ); - } - } - } - - #[test] - fn p1_and_p2_observations_differ() { - // The board is mirrored 
for P2, so the two observations should differ - // whenever there are checkers in non-symmetric positions (always true - // in a real game after a few moves). - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(7); - - // Advance far enough that the board is non-trivial. - for _ in 0..6 { - while e.current_player(&s).is_chance() { - e.apply_chance(&mut s, &mut rng); - } - if e.current_player(&s).is_terminal() { - break; - } - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - - if !e.current_player(&s).is_terminal() { - let obs0 = e.observation(&s, 0); - let obs1 = e.observation(&s, 1); - assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); - } - } - - // ── Applying actions ────────────────────────────────────────────────── - - #[test] - fn apply_changes_state() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(8); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - let before = s.clone(); - let action = e.legal_actions(&s)[0]; - e.apply(&mut s, action); - - assert_ne!( - before.turn_stage, s.turn_stage, - "state must change after apply" - ); - } - - #[test] - fn apply_all_legal_actions_do_not_panic() { - // Verify that every action returned by legal_actions can be applied - // without panicking (on several independent copies of the same state). - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(9); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - for action in e.legal_actions(&s) { - let mut copy = s.clone(); - e.apply(&mut copy, action); // must not panic - } - } - - // ── Full game ───────────────────────────────────────────────────────── - - /// Run a complete game with random actions through the `GameEnv` trait - /// and verify that: - /// - The game terminates. - /// - `returns()` is `Some` at the end. 
- /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's - /// score is ±1 / 0. - /// - No step panics. - #[test] - fn full_random_game_terminates() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(42); - let max_steps = 50_000; - - for step in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node"); - let idx = rng.random_range(0..actions.len()); - e.apply(&mut s, actions[idx]); - } - } - assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps"); - } - - let result = e.returns(&s); - assert!(result.is_some(), "returns() must be Some at Terminal"); - - let [r1, r2] = result.unwrap(); - let sum = r1 + r2; - assert!( - (sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5, - "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}" - ); - assert!( - r1.abs() <= 1.0 && r2.abs() <= 1.0, - "returns must be in [-1,1]: r1={r1}, r2={r2}" - ); - } - - /// Run multiple games with different seeds to stress-test for panics. - #[test] - fn multiple_games_no_panic() { - let e = env(); - let max_steps = 20_000; - - for seed in 0..10u64 { - let mut s = e.new_game(); - let mut rng = seeded_rng(seed); - - for _ in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - let idx = rng.random_range(0..actions.len()); - e.apply(&mut s, actions[idx]); - } - } - } - } - } - - // ── Returns ─────────────────────────────────────────────────────────── - - #[test] - fn returns_none_mid_game() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(11); - - // Advance a few steps but do not finish the game. 
- for _ in 0..4 { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - } - } - - if !e.current_player(&s).is_terminal() { - assert!( - e.returns(&s).is_none(), - "returns() must be None before the game ends" - ); - } - } - - // ── Player 2 actions ────────────────────────────────────────────────── - - /// Verify that Player 2 (Black) can take actions without panicking, - /// and that the state advances correctly. - #[test] - fn player2_can_act() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(12); - - // Keep stepping until Player 2 gets a turn. - let max_steps = 5_000; - let mut p2_acted = false; - - for _ in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P2 => { - let actions = e.legal_actions(&s); - assert!(!actions.is_empty()); - e.apply(&mut s, actions[0]); - p2_acted = true; - break; - } - Player::P1 => { - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - } - } - - assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); - } -} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs deleted file mode 100644 index 23895b9..0000000 --- a/spiel_bot/src/lib.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod alphazero; -pub mod env; -pub mod mcts; -pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs deleted file mode 100644 index e92bd09..0000000 --- a/spiel_bot/src/mcts/mod.rs +++ /dev/null @@ -1,408 +0,0 @@ -//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. -//! -//! # Algorithm -//! -//! The implementation follows AlphaZero's MCTS: -//! -//! 1. **Expand root** — run the network once to get priors and a value -//! estimate; optionally add Dirichlet noise for training-time exploration. -//! 2. 
**Simulate** `n_simulations` times: -//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. -//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; -//! chance nodes are **not** stored in the tree (outcome sampling). -//! - *Expansion* — evaluate the network at the leaf; populate children. -//! - *Backup* — propagate the value upward; negate at each player boundary. -//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). -//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). -//! -//! # Perspective convention -//! -//! Every [`MctsNode::w`] is stored **from the perspective of the player who -//! acts at that node**. The backup negates the child value whenever the -//! acting player differs between parent and child. -//! -//! # Stochastic games -//! -//! When [`GameEnv::current_player`] returns [`Player::Chance`], the -//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and -//! continues. Chance nodes are skipped transparently; Q-values converge to -//! their expectation over many simulations (outcome sampling). - -pub mod node; -mod search; - -pub use node::MctsNode; - -use rand::Rng; - -use crate::env::GameEnv; - -// ── Evaluator trait ──────────────────────────────────────────────────────── - -/// Evaluates a game position for use in MCTS. -/// -/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) -/// but the `mcts` module itself does **not** depend on Burn. -pub trait Evaluator: Send + Sync { - /// Evaluate `obs` (flat observation vector of length `obs_size`). - /// - /// Returns: - /// - `policy_logits`: one raw logit per action (`action_space` entries). - /// Illegal action entries are masked inside the search — no need to - /// zero them here. - /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. 
- fn evaluate(&self, obs: &[f32]) -> (Vec, f32); -} - -// ── Configuration ───────────────────────────────────────────────────────── - -/// Hyperparameters for [`run_mcts`]. -#[derive(Debug, Clone)] -pub struct MctsConfig { - /// Number of MCTS simulations per move. Typical: 50–800. - pub n_simulations: usize, - /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0. - pub c_puct: f32, - /// Dirichlet noise concentration α. Set to `0.0` to disable. - /// Typical: `0.3` for Chess, `0.1` for large action spaces. - pub dirichlet_alpha: f32, - /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`. - pub dirichlet_eps: f32, - /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax. - pub temperature: f32, -} - -impl Default for MctsConfig { - fn default() -> Self { - Self { - n_simulations: 200, - c_puct: 1.5, - dirichlet_alpha: 0.3, - dirichlet_eps: 0.25, - temperature: 1.0, - } - } -} - -// ── Public interface ─────────────────────────────────────────────────────── - -/// Run MCTS from `state` and return the populated root [`MctsNode`]. -/// -/// `state` must be a player-decision node (`P1` or `P2`). -/// Use [`mcts_policy`] and [`select_action`] on the returned root. -/// -/// # Panics -/// -/// Panics if `env.current_player(state)` is not `P1` or `P2`. 
-pub fn run_mcts( - env: &E, - state: &E::State, - evaluator: &dyn Evaluator, - config: &MctsConfig, - rng: &mut impl Rng, -) -> MctsNode { - let player_idx = env - .current_player(state) - .index() - .expect("run_mcts called at a non-decision node"); - - // ── Expand root (network called once here, not inside the loop) ──────── - let mut root = MctsNode::new(1.0); - search::expand::(&mut root, state, env, evaluator, player_idx); - - // ── Optional Dirichlet noise for training exploration ────────────────── - if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 { - search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng); - } - - // ── Simulations ──────────────────────────────────────────────────────── - for _ in 0..config.n_simulations { - search::simulate::( - &mut root, - state.clone(), - env, - evaluator, - config, - rng, - player_idx, - ); - } - - root -} - -/// Compute the MCTS policy: normalized visit counts at the root. -/// -/// Returns a vector of length `action_space` where `policy[a]` is the -/// fraction of simulations that visited action `a`. -pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec { - let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum(); - let mut policy = vec![0.0f32; action_space]; - if total > 0.0 { - for (a, child) in &root.children { - policy[*a] = child.n as f32 / total; - } - } else if !root.children.is_empty() { - // n_simulations = 0: uniform over legal actions. - let uniform = 1.0 / root.children.len() as f32; - for (a, _) in &root.children { - policy[*a] = uniform; - } - } - policy -} - -/// Select an action index from the root after MCTS. -/// -/// * `temperature = 0` — greedy argmax of visit counts. -/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`. -/// -/// # Panics -/// -/// Panics if the root has no children. 
-pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { - assert!(!root.children.is_empty(), "select_action called on a root with no children"); - if temperature <= 0.0 { - root.children - .iter() - .max_by_key(|(_, c)| c.n) - .map(|(a, _)| *a) - .unwrap() - } else { - let weights: Vec = root - .children - .iter() - .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) - .collect(); - let total: f32 = weights.iter().sum(); - let mut r: f32 = rng.random::() * total; - for (i, (a, _)) in root.children.iter().enumerate() { - r -= weights[i]; - if r <= 0.0 { - return *a; - } - } - root.children.last().map(|(a, _)| *a).unwrap() - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - use crate::env::Player; - - // ── Minimal deterministic test game ─────────────────────────────────── - // - // "Countdown" — two players alternate subtracting 1 or 2 from a counter. - // The player who brings the counter to 0 wins. - // No chance nodes, two legal actions (0 = -1, 1 = -2). 
- - #[derive(Clone, Debug)] - struct CState { - remaining: u8, - to_move: usize, // at terminal: last mover (winner) - } - - #[derive(Clone)] - struct CountdownEnv; - - impl crate::env::GameEnv for CountdownEnv { - type State = CState; - - fn new_game(&self) -> CState { - CState { remaining: 6, to_move: 0 } - } - - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { - Player::Terminal - } else if s.to_move == 0 { - Player::P1 - } else { - Player::P2 - } - } - - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { - s.remaining = 0; - // to_move stays as winner - } else { - s.remaining -= sub; - s.to_move = 1 - s.to_move; - } - } - - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / 6.0, s.to_move as f32] - } - - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } - } - - // Uniform evaluator: all logits = 0, value = 0. - // `action_space` must match the environment's `action_space()`. 
- struct ZeroEval(usize); - impl Evaluator for ZeroEval { - fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { - (vec![0.0f32; self.0], 0.0) - } - } - - fn rng() -> SmallRng { - SmallRng::seed_from_u64(42) - } - - fn config_n(n: usize) -> MctsConfig { - MctsConfig { - n_simulations: n, - c_puct: 1.5, - dirichlet_alpha: 0.0, // off for reproducibility - dirichlet_eps: 0.0, - temperature: 1.0, - } - } - - // ── Visit count tests ───────────────────────────────────────────────── - - #[test] - fn visit_counts_sum_to_n_simulations() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); - let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 50, "visit counts must sum to n_simulations"); - } - - #[test] - fn all_root_children_are_legal() { - let env = CountdownEnv; - let state = env.new_game(); - let legal = env.legal_actions(&state); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); - for (a, _) in &root.children { - assert!(legal.contains(a), "child action {a} is not legal"); - } - } - - // ── Policy tests ───────────────────────────────────────────────────── - - #[test] - fn policy_sums_to_one() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - let sum: f32 = policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); - } - - #[test] - fn policy_zero_for_illegal_actions() { - let env = CountdownEnv; - // remaining = 1 → only action 0 is legal - let state = CState { remaining: 1, to_move: 0 }; - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); - } - - // ── Action selection tests 
──────────────────────────────────────────── - - #[test] - fn greedy_selects_most_visited() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); - let greedy = select_action(&root, 0.0, &mut rng()); - let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); - assert_eq!(greedy, most_visited); - } - - #[test] - fn temperature_sampling_stays_legal() { - let env = CountdownEnv; - let state = env.new_game(); - let legal = env.legal_actions(&state); - let mut r = rng(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); - for _ in 0..20 { - let a = select_action(&root, 1.0, &mut r); - assert!(legal.contains(&a), "sampled action {a} is not legal"); - } - } - - // ── Zero-simulation edge case ───────────────────────────────────────── - - #[test] - fn zero_simulations_uniform_policy() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - // With 0 simulations, fallback is uniform over the 2 legal actions. - let sum: f32 = policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-5); - } - - // ── Root value ──────────────────────────────────────────────────────── - - #[test] - fn root_q_in_valid_range() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); - let q = root.q(); - assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); - } - - // ── Integration: run on a real Trictrac game ────────────────────────── - - #[test] - fn no_panic_on_trictrac_state() { - use crate::env::TrictracEnv; - - let env = TrictracEnv; - let mut state = env.new_game(); - let mut r = rng(); - - // Advance past the initial chance node to reach a decision node. 
- while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, &mut r); - } - - if env.current_player(&state).is_terminal() { - return; // unlikely but safe - } - - let config = MctsConfig { - n_simulations: 5, // tiny for speed - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - ..MctsConfig::default() - }; - - let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); - assert!(root.n > 0); - let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 5); - } -} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs deleted file mode 100644 index aff7735..0000000 --- a/spiel_bot/src/mcts/node.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! MCTS tree node. -//! -//! [`MctsNode`] holds the visit statistics for one player-decision position in -//! the search tree. A node is *expanded* the first time the policy-value -//! network is evaluated there; before that it is a leaf. - -/// One node in the MCTS tree, representing a player-decision position. -/// -/// `w` stores the sum of values backed up into this node, always from the -/// perspective of **the player who acts here**. `q()` therefore also returns -/// a value in `(-1, 1)` from that same perspective. -#[derive(Debug)] -pub struct MctsNode { - /// Visit count `N(s, a)`. - pub n: u32, - /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. - pub w: f32, - /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). - pub p: f32, - /// Children: `(action_index, child_node)`, populated on first expansion. - pub children: Vec<(usize, MctsNode)>, - /// `true` after the network has been evaluated and children have been set up. - pub expanded: bool, -} - -impl MctsNode { - /// Create a fresh, unexpanded leaf with the given prior probability. 
- pub fn new(prior: f32) -> Self { - Self { - n: 0, - w: 0.0, - p: prior, - children: Vec::new(), - expanded: false, - } - } - - /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. - #[inline] - pub fn q(&self) -> f32 { - if self.n == 0 { 0.0 } else { self.w / self.n as f32 } - } - - /// PUCT selection score: - /// - /// ```text - /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) - /// ``` - #[inline] - pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { - self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn q_zero_when_unvisited() { - let node = MctsNode::new(0.5); - assert_eq!(node.q(), 0.0); - } - - #[test] - fn q_reflects_w_over_n() { - let mut node = MctsNode::new(0.5); - node.n = 4; - node.w = 2.0; - assert!((node.q() - 0.5).abs() < 1e-6); - } - - #[test] - fn puct_exploration_dominates_unvisited() { - // Unvisited child should outscore a visited child with negative Q. - let mut visited = MctsNode::new(0.5); - visited.n = 10; - visited.w = -5.0; // Q = -0.5 - - let unvisited = MctsNode::new(0.5); - - let parent_n = 10; - let c = 1.5; - assert!( - unvisited.puct(parent_n, c) > visited.puct(parent_n, c), - "unvisited child should have higher PUCT than a negatively-valued visited child" - ); - } -} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs deleted file mode 100644 index c4960c7..0000000 --- a/spiel_bot/src/mcts/search.rs +++ /dev/null @@ -1,170 +0,0 @@ -//! Simulation, expansion, backup, and noise helpers. -//! -//! These are internal to the `mcts` module; the public entry points are -//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. 
- -use rand::Rng; -use rand_distr::{Gamma, Distribution}; - -use crate::env::GameEnv; -use super::{Evaluator, MctsConfig}; -use super::node::MctsNode; - -// ── Masked softmax ───────────────────────────────────────────────────────── - -/// Numerically stable softmax over `legal` actions only. -/// -/// Illegal logits are treated as `-∞` and receive probability `0.0`. -/// Returns a probability vector of length `action_space`. -pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec { - let mut probs = vec![0.0f32; action_space]; - if legal.is_empty() { - return probs; - } - let max_logit = legal - .iter() - .map(|&a| logits[a]) - .fold(f32::NEG_INFINITY, f32::max); - let mut sum = 0.0f32; - for &a in legal { - let e = (logits[a] - max_logit).exp(); - probs[a] = e; - sum += e; - } - if sum > 0.0 { - for &a in legal { - probs[a] /= sum; - } - } else { - let uniform = 1.0 / legal.len() as f32; - for &a in legal { - probs[a] = uniform; - } - } - probs -} - -// ── Dirichlet noise ──────────────────────────────────────────────────────── - -/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration. -/// -/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. -/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum. 
-pub(super) fn add_dirichlet_noise( - node: &mut MctsNode, - alpha: f32, - eps: f32, - rng: &mut impl Rng, -) { - let n = node.children.len(); - if n == 0 { - return; - } - let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else { - return; - }; - let samples: Vec = (0..n).map(|_| gamma.sample(rng) as f32).collect(); - let sum: f32 = samples.iter().sum(); - if sum <= 0.0 { - return; - } - for (i, (_, child)) in node.children.iter_mut().enumerate() { - let noise = samples[i] / sum; - child.p = (1.0 - eps) * child.p + eps * noise; - } -} - -// ── Expansion ────────────────────────────────────────────────────────────── - -/// Evaluate the network at `state` and populate `node` with children. -/// -/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`. -/// Returns the network value estimate from `player_idx`'s perspective. -pub(super) fn expand( - node: &mut MctsNode, - state: &E::State, - env: &E, - evaluator: &dyn Evaluator, - player_idx: usize, -) -> f32 { - let obs = env.observation(state, player_idx); - let legal = env.legal_actions(state); - let (logits, value) = evaluator.evaluate(&obs); - let priors = masked_softmax(&logits, &legal, env.action_space()); - node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect(); - node.expanded = true; - node.n = 1; - node.w = value; - value -} - -// ── Simulation ───────────────────────────────────────────────────────────── - -/// One MCTS simulation from an **already-expanded** decision node. -/// -/// Traverses the tree with PUCT selection, expands the first unvisited leaf, -/// and backs up the result. -/// -/// * `player_idx` — the player (0 or 1) who acts at `state`. -/// * Returns the backed-up value **from `player_idx`'s perspective**. 
-pub(super) fn simulate( - node: &mut MctsNode, - state: E::State, - env: &E, - evaluator: &dyn Evaluator, - config: &MctsConfig, - rng: &mut impl Rng, - player_idx: usize, -) -> f32 { - debug_assert!(node.expanded, "simulate called on unexpanded node"); - - // ── Selection: child with highest PUCT ──────────────────────────────── - let parent_n = node.n; - let best = node - .children - .iter() - .enumerate() - .max_by(|(_, (_, a)), (_, (_, b))| { - a.puct(parent_n, config.c_puct) - .partial_cmp(&b.puct(parent_n, config.c_puct)) - .unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(i, _)| i) - .expect("expanded node must have at least one child"); - - let (action, child) = &mut node.children[best]; - let action = *action; - - // ── Apply action + advance through any chance nodes ─────────────────── - let mut next_state = state; - env.apply(&mut next_state, action); - while env.current_player(&next_state).is_chance() { - env.apply_chance(&mut next_state, rng); - } - - let next_cp = env.current_player(&next_state); - - // ── Evaluate leaf or terminal ────────────────────────────────────────── - // All values are converted to `player_idx`'s perspective before backup. - let child_value = if next_cp.is_terminal() { - let returns = env - .returns(&next_state) - .expect("terminal node must have returns"); - returns[player_idx] - } else { - let child_player = next_cp.index().unwrap(); - let v = if child.expanded { - simulate(child, next_state, env, evaluator, config, rng, child_player) - } else { - expand::(child, &next_state, env, evaluator, child_player) - }; - // Negate when the child belongs to the opponent. 
- if child_player == player_idx { v } else { -v } - }; - - // ── Backup ──────────────────────────────────────────────────────────── - node.n += 1; - node.w += child_value; - - child_value -} diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs deleted file mode 100644 index eb6184e..0000000 --- a/spiel_bot/src/network/mlp.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! Two-hidden-layer MLP policy-value network. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU -//! → Linear(hidden → hidden) → ReLU -//! ├─ policy_head: Linear(hidden → action_size) [raw logits] -//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] -//! ``` - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{ - activation::{relu, tanh}, - backend::Backend, - Tensor, - }, -}; -use std::path::Path; - -use super::PolicyValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`MlpNet`]. -#[derive(Debug, Clone)] -pub struct MlpConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of both hidden layers. - pub hidden_size: usize, -} - -impl Default for MlpConfig { - fn default() -> Self { - Self { - obs_size: 217, - action_size: 514, - hidden_size: 256, - } - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Simple two-hidden-layer MLP with shared trunk and two heads. -/// -/// Prefer this over [`ResNet`](super::ResNet) when training time is a -/// priority, or as a fast baseline. -#[derive(Module, Debug)] -pub struct MlpNet { - fc1: Linear, - fc2: Linear, - policy_head: Linear, - value_head: Linear, -} - -impl MlpNet { - /// Construct a fresh network with random weights. 
- pub fn new(config: &MlpConfig, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), - fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), - policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), - value_head: LinearConfig::new(config.hidden_size, 1).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - /// - /// The file is written exactly at `path`; callers should append `.mpk` if - /// they want the conventional extension. - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. - pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl PolicyValueNet for MlpNet { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { - let x = relu(self.fc1.forward(obs)); - let x = relu(self.fc2.forward(x)); - let policy = self.policy_head.forward(x.clone()); - let value = tanh(self.value_head.forward(x)); - (policy, value) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn default_net() -> MlpNet { - MlpNet::new(&MlpConfig::default(), &device()) - } - - fn zeros_obs(batch: usize) -> Tensor { - Tensor::zeros([batch, 217], &device()) - } - - // ── Shape tests ─────────────────────────────────────────────────────── - - #[test] - fn forward_output_shapes() { - let net = 
default_net(); - let obs = zeros_obs(4); - let (policy, value) = net.forward(obs); - - assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); - assert_eq!(value.dims(), [4, 1], "value shape mismatch"); - } - - #[test] - fn forward_single_sample() { - let net = default_net(); - let (policy, value) = net.forward(zeros_obs(1)); - assert_eq!(policy.dims(), [1, 514]); - assert_eq!(value.dims(), [1, 1]); - } - - // ── Value bounds ────────────────────────────────────────────────────── - - #[test] - fn value_in_tanh_range() { - let net = default_net(); - // Use a non-zero input so the output is not trivially at 0. - let obs = Tensor::::ones([8, 217], &device()); - let (_, value) = net.forward(obs); - let data: Vec = value.into_data().to_vec().unwrap(); - for v in &data { - assert!( - *v > -1.0 && *v < 1.0, - "value {v} is outside open interval (-1, 1)" - ); - } - } - - // ── Policy logits ───────────────────────────────────────────────────── - - #[test] - fn policy_logits_not_all_equal() { - // With random weights the 514 logits should not all be identical. 
- let net = default_net(); - let (policy, _) = net.forward(zeros_obs(1)); - let data: Vec = policy.into_data().to_vec().unwrap(); - let first = data[0]; - let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); - assert!(!all_same, "all policy logits are identical — network may be degenerate"); - } - - // ── Config propagation ──────────────────────────────────────────────── - - #[test] - fn custom_config_shapes() { - let config = MlpConfig { - obs_size: 10, - action_size: 20, - hidden_size: 32, - }; - let net = MlpNet::::new(&config, &device()); - let obs = Tensor::zeros([3, 10], &device()); - let (policy, value) = net.forward(obs); - assert_eq!(policy.dims(), [3, 20]); - assert_eq!(value.dims(), [3, 1]); - } - - // ── Save / Load ─────────────────────────────────────────────────────── - - #[test] - fn save_load_preserves_weights() { - let config = MlpConfig::default(); - let net = default_net(); - - // Forward pass before saving. - let obs = Tensor::::ones([2, 217], &device()); - let (policy_before, value_before) = net.forward(obs.clone()); - - // Save to a temp file. - let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); - net.save(&path).expect("save failed"); - - // Load into a fresh model. - let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); - let (policy_after, value_after) = loaded.forward(obs); - - // Outputs must be bitwise identical. 
- let p_before: Vec = policy_before.into_data().to_vec().unwrap(); - let p_after: Vec = policy_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let v_before: Vec = value_before.into_data().to_vec().unwrap(); - let v_after: Vec = value_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let _ = std::fs::remove_file(path); - } -} diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs deleted file mode 100644 index df710e9..0000000 --- a/spiel_bot/src/network/mod.rs +++ /dev/null @@ -1,64 +0,0 @@ -//! Neural network abstractions for policy-value learning. -//! -//! # Trait -//! -//! [`PolicyValueNet`] is the single trait that all network architectures -//! implement. It takes an observation tensor and returns raw policy logits -//! plus a tanh-squashed scalar value estimate. -//! -//! # Architectures -//! -//! | Module | Description | Default hidden | -//! |--------|-------------|----------------| -//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | -//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | -//! -//! # Backend convention -//! -//! * **Inference / self-play** — use `NdArray` (no autodiff overhead). -//! * **Training** — use `Autodiff>` so Burn can differentiate -//! through the forward pass. -//! -//! Both modes use the exact same struct; only the type-level backend changes: -//! -//! ```rust,ignore -//! use burn::backend::{Autodiff, NdArray}; -//! type InferBackend = NdArray; -//! type TrainBackend = Autodiff>; -//! -//! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); -//! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); -//! ``` -//! 
-//! # Output shapes -//! -//! Given a batch of `B` observations of size `obs_size`: -//! -//! | Output | Shape | Range | -//! |--------|-------|-------| -//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | -//! | `value` | `[B, 1]` | (-1, 1) via tanh | -//! -//! Callers are responsible for masking illegal actions in `policy_logits` -//! before passing to softmax. - -pub mod mlp; -pub mod resnet; - -pub use mlp::{MlpConfig, MlpNet}; -pub use resnet::{ResNet, ResNetConfig}; - -use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; - -/// A neural network that produces a policy and a value from an observation. -/// -/// # Shapes -/// - `obs`: `[batch, obs_size]` -/// - policy output: `[batch, action_size]` — raw logits (no softmax applied) -/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) -/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses -/// `OnceCell` for lazy parameter initialisation, which is not `Sync`. -/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed. -pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static { - fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>); -} diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs deleted file mode 100644 index d20d5ad..0000000 --- a/spiel_bot/src/network/resnet.rs +++ /dev/null @@ -1,253 +0,0 @@ -//! Residual-block policy-value network. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU (input projection) -//! → ResBlock × 4 (residual trunk) -//! ├─ policy_head: Linear(hidden → action_size) [raw logits] -//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] -//! -//! ResBlock: -//! x → Linear → ReLU → Linear → (+x) → ReLU -//! ``` -//! -//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better -//! suited for long training runs where board-pattern recognition matters. 
- -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{ - activation::{relu, tanh}, - backend::Backend, - Tensor, - }, -}; -use std::path::Path; - -use super::PolicyValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`ResNet`]. -#[derive(Debug, Clone)] -pub struct ResNetConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of all hidden layers (input projection + residual blocks). - pub hidden_size: usize, -} - -impl Default for ResNetConfig { - fn default() -> Self { - Self { - obs_size: 217, - action_size: 514, - hidden_size: 512, - } - } -} - -// ── Residual block ──────────────────────────────────────────────────────────── - -/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`. -/// -/// Both linear layers preserve the hidden dimension so the skip connection -/// can be added without projection. -#[derive(Module, Debug)] -struct ResBlock { - fc1: Linear, - fc2: Linear, -} - -impl ResBlock { - fn new(hidden: usize, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(hidden, hidden).init(device), - fc2: LinearConfig::new(hidden, hidden).init(device), - } - } - - fn forward(&self, x: Tensor) -> Tensor { - let residual = x.clone(); - let out = relu(self.fc1.forward(x)); - relu(self.fc2.forward(out) + residual) - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Four-residual-block policy-value network. -/// -/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and -/// when representing complex positional patterns is important. 
-#[derive(Module, Debug)] -pub struct ResNet { - input: Linear, - block0: ResBlock, - block1: ResBlock, - block2: ResBlock, - block3: ResBlock, - policy_head: Linear, - value_head: Linear, -} - -impl ResNet { - /// Construct a fresh network with random weights. - pub fn new(config: &ResNetConfig, device: &B::Device) -> Self { - let h = config.hidden_size; - Self { - input: LinearConfig::new(config.obs_size, h).init(device), - block0: ResBlock::new(h, device), - block1: ResBlock::new(h, device), - block2: ResBlock::new(h, device), - block3: ResBlock::new(h, device), - policy_head: LinearConfig::new(h, config.action_size).init(device), - value_head: LinearConfig::new(h, 1).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. 
- pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl PolicyValueNet for ResNet { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { - let x = relu(self.input.forward(obs)); - let x = self.block0.forward(x); - let x = self.block1.forward(x); - let x = self.block2.forward(x); - let x = self.block3.forward(x); - let policy = self.policy_head.forward(x.clone()); - let value = tanh(self.value_head.forward(x)); - (policy, value) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn small_config() -> ResNetConfig { - // Use a small hidden size so tests are fast. 
- ResNetConfig { - obs_size: 217, - action_size: 514, - hidden_size: 64, - } - } - - fn net() -> ResNet { - ResNet::new(&small_config(), &device()) - } - - // ── Shape tests ─────────────────────────────────────────────────────── - - #[test] - fn forward_output_shapes() { - let obs = Tensor::zeros([4, 217], &device()); - let (policy, value) = net().forward(obs); - assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); - assert_eq!(value.dims(), [4, 1], "value shape mismatch"); - } - - #[test] - fn forward_single_sample() { - let (policy, value) = net().forward(Tensor::zeros([1, 217], &device())); - assert_eq!(policy.dims(), [1, 514]); - assert_eq!(value.dims(), [1, 1]); - } - - // ── Value bounds ────────────────────────────────────────────────────── - - #[test] - fn value_in_tanh_range() { - let obs = Tensor::::ones([8, 217], &device()); - let (_, value) = net().forward(obs); - let data: Vec = value.into_data().to_vec().unwrap(); - for v in &data { - assert!( - *v > -1.0 && *v < 1.0, - "value {v} is outside open interval (-1, 1)" - ); - } - } - - // ── Residual connections ────────────────────────────────────────────── - - #[test] - fn policy_logits_not_all_equal() { - let (policy, _) = net().forward(Tensor::zeros([1, 217], &device())); - let data: Vec = policy.into_data().to_vec().unwrap(); - let first = data[0]; - let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); - assert!(!all_same, "all policy logits are identical"); - } - - // ── Save / Load ─────────────────────────────────────────────────────── - - #[test] - fn save_load_preserves_weights() { - let config = small_config(); - let model = net(); - let obs = Tensor::::ones([2, 217], &device()); - - let (policy_before, value_before) = model.forward(obs.clone()); - - let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk"); - model.save(&path).expect("save failed"); - - let loaded = ResNet::::load(&config, &path, &device()).expect("load failed"); - let (policy_after, value_after) = 
loaded.forward(obs); - - let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap(); - let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap(); - let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let _ = std::fs::remove_file(path); - } - - // ── Integration: both architectures satisfy PolicyValueNet ──────────── - - #[test] - fn resnet_satisfies_trait() { - fn requires_net<N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) { - let (p, v) = net.forward(obs); - assert_eq!(p.dims()[1], 514); - assert_eq!(v.dims()[1], 1); - } - requires_net(&net(), Tensor::zeros([2, 217], &device())); - } -} diff --git a/store/src/game.rs b/store/src/game.rs index 2fde45c..f553bdb 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -225,26 +225,22 @@ impl GameState { let mut t = Vec::with_capacity(217); let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black - // [0..95] own (White) checkers, TD-Gammon encoding. - // Each field contributes 4 values: - // (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1] - // The overflow term is divided by 12 because the maximum excess is - // 15 (all checkers) − 3 = 12. + // [0..95] own (White) checkers, TD-Gammon encoding for &c in &pos { let own = c.max(0) as u8; t.push((own == 1) as u8 as f32); t.push((own == 2) as u8 as f32); t.push((own == 3) as u8 as f32); - t.push(own.saturating_sub(3) as f32 / 12.0); + t.push(own.saturating_sub(3) as f32); } - // [96..191] opp (Black) checkers, same encoding. 
+ // [96..191] opp (Black) checkers, TD-Gammon encoding for &c in &pos { let opp = (-c).max(0) as u8; t.push((opp == 1) as u8 as f32); t.push((opp == 2) as u8 as f32); t.push((opp == 3) as u8 as f32); - t.push(opp.saturating_sub(3) as f32 / 12.0); + t.push(opp.saturating_sub(3) as f32); } // [192..193] dice