diff --git a/Cargo.lock b/Cargo.lock index fa260cd..a43261e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,12 +92,6 @@ dependencies = [ "libc", ] -[[package]] -name = "anes" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" - [[package]] name = "anstream" version = "0.6.21" @@ -1122,12 +1116,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" -[[package]] -name = "cast" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" - [[package]] name = "cast_trait" version = "0.1.2" @@ -1212,33 +1200,6 @@ dependencies = [ "rand 0.7.3", ] -[[package]] -name = "ciborium" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" - -[[package]] -name = "ciborium-ll" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" -dependencies = [ - "ciborium-io", - "half", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1492,42 +1453,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "criterion" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" -dependencies = [ - "anes", - "cast", - "ciborium", - "clap", - "criterion-plot", - "is-terminal", - "itertools 0.10.5", - "num-traits", - "once_cell", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "walkdir", -] - -[[package]] -name = "criterion-plot" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" -dependencies = [ - "cast", - "itertools 0.10.5", -] - [[package]] name = "critical-section" version = "1.2.0" @@ -4536,12 +4461,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" -[[package]] -name = "oorandom" -version = "11.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" - [[package]] name = "opaque-debug" version = "0.3.1" @@ -4678,34 +4597,6 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" -[[package]] -name = "plotters" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" - -[[package]] -name = "plotters-svg" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" -dependencies = [ - "plotters-backend", -] - [[package]] name = "png" version = "0.18.0" @@ -6000,19 +5891,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "spiel_bot" -version = "0.1.0" -dependencies = [ - "anyhow", - "burn", - "criterion", - "rand 0.9.2", - "rand_distr", - "trictrac-bot", - "trictrac-store", -] - [[package]] name = "spin" version = "0.10.0" @@ -6421,16 +6299,6 @@ dependencies = [ "zerovec", ] -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "tinyvec" version = "1.10.0" @@ -6855,7 +6723,6 @@ dependencies = [ "pico-args", "pretty_assertions", "renet", - "spiel_bot", "trictrac-bot", "trictrac-store", ] diff --git a/Cargo.toml b/Cargo.toml index 4c2eb15..b9e6d45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store", "spiel_bot"] +members = ["client_cli", "bot", "store"] diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index 52318cb..e48a249 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -3,9 +3,7 @@ name = "trictrac-client_cli" version = "0.1.0" edition = "2021" -[[bin]] -name = "client_cli" -path = "src/main.rs" +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] anyhow = "1.0.75" @@ -14,8 +12,7 @@ pico-args = "0.5.0" pretty_assertions = "1.4.0" renet = "0.0.13" trictrac-store = { path = "../store" } -trictrac-bot = { path = "../bot" } -spiel_bot = { path = "../spiel_bot" } +trictrac-bot = { path = "../bot" } itertools = "0.13.0" env_logger = "0.11.6" log = "0.4.20" diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index ab61451..b803efe 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,4 +1,3 @@ -use spiel_bot::strategy::{AzBotStrategy, DqnSpielBotStrategy}; use trictrac_bot::{ BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy, StableBaselines3Strategy, @@ -57,27 +56,6 @@ impl App { Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string())) as Box) } - "az" => { - Some(Box::new(AzBotStrategy::new_mlp(None)) as Box) - } - s if s.starts_with("az:") && !s.starts_with("az-") => { - let path = s.trim_start_matches("az:"); - Some(Box::new(AzBotStrategy::new_mlp(Some(path))) as Box) - } - "az-resnet" => { - Some(Box::new(AzBotStrategy::new_resnet(None)) as Box) - } - s if s.starts_with("az-resnet:") => { - let path = s.trim_start_matches("az-resnet:"); - Some(Box::new(AzBotStrategy::new_resnet(Some(path))) as Box) - } - "az-dqn" => { - Some(Box::new(DqnSpielBotStrategy::new(None)) as Box) - } - s if s.starts_with("az-dqn:") => { - let path = s.trim_start_matches("az-dqn:"); - Some(Box::new(DqnSpielBotStrategy::new(Some(path))) as Box) - } _ => None, }) .collect() diff --git a/client_cli/src/main.rs b/client_cli/src/main.rs index e06299b..0107b43 100644 --- a/client_cli/src/main.rs +++ b/client_cli/src/main.rs @@ -23,14 +23,8 @@ OPTIONS: - dummy: Default strategy selecting the first valid move - ai: AI strategy using the default model at models/trictrac_ppo.zip - ai:/path/to/model.zip: AI strategy using a custom model - - dqnburn: DQN strategy (burn-rl backend) - - dqnburn:/path/to/model: DQN strategy (burn-rl backend) with custom model - - az: AlphaZero MlpNet (random weights) - - az:/path/to/model.mpk: AlphaZero MlpNet checkpoint - - az-resnet: AlphaZero ResNet (random weights) - - az-resnet:/path/to/model.mpk: AlphaZero ResNet checkpoint - - az-dqn: DQN QNet (random weights, first-legal-move fallback) - - az-dqn:/path/to/model.mpk: DQN QNet checkpoint + - dqn: DQN strategy using native Rust implementation with Burn + - dqn:/path/to/model: DQN strategy using a custom model ARGS: diff --git a/doc/spiel_bot_research.md b/doc/spiel_bot_research.md deleted file mode 100644 index a8863af..0000000 --- a/doc/spiel_bot_research.md +++ /dev/null @@ -1,782 +0,0 @@ -# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac - -## 0. Context and Scope - -The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library -(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()` -encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every -other stage to an inline random-opponent loop. - -`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency -for **self-play training**. Its goals: - -- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel") - that works with Trictrac's multi-stage turn model and stochastic dice. -- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer) - as the first algorithm. -- Remain **modular**: adding DQN or PPO later requires only a new - `impl Algorithm for Dqn` without touching the environment or network layers. -- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from - `trictrac-store`. - ---- - -## 1. Library Landscape - -### 1.1 Neural Network Frameworks - -| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes | -| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- | -| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` | -| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance | -| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training | -| ndarray alone | no | no | yes | mature | array ops only; no autograd | - -**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++ -runtime needed, the `ndarray` backend is sufficient for CPU training and can -switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by -changing one type alias. - -`tch-rs` would be the best choice for raw training throughput (it is the most -battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the -pure-Rust constraint. If training speed becomes the bottleneck after prototyping, -switching `spiel_bot` to `tch-rs` is a one-line backend swap. - -### 1.2 Other Key Crates - -| Crate | Role | -| -------------------- | ----------------------------------------------------------------- | -| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) | -| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` | -| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) | -| `serde / serde_json` | replay buffer snapshots, checkpoint metadata | -| `anyhow` | error propagation (already used everywhere) | -| `indicatif` | training progress bars | -| `tracing` | structured logging per episode/iteration | - -### 1.3 What `burn-rl` Provides (and Does Not) - -The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`) -provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}` -trait. It does **not** provide: - -- MCTS or any tree-search algorithm -- Two-player self-play -- Legal action masking during training -- Chance-node handling - -For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its -own (simpler, more targeted) traits and implement MCTS from scratch. - ---- - -## 2. Trictrac-Specific Design Constraints - -### 2.1 Multi-Stage Turn Model - -A Trictrac turn passes through up to six `TurnStage` values. Only two involve -genuine player choice: - -| TurnStage | Node type | Handler | -| ---------------- | ------------------------------- | ------------------------------- | -| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` | -| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` | -| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` | -| `HoldOrGoChoice` | **Player decision** | MCTS / policy network | -| `Move` | **Player decision** | MCTS / policy network | -| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` | - -The environment wrapper advances through forced/chance stages automatically so -that from the algorithm's perspective every node it sees is a genuine player -decision. - -### 2.2 Stochastic Dice in MCTS - -AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice -introduce stochasticity. Three approaches exist: - -**A. Outcome sampling (recommended)** -During each MCTS simulation, when a chance node is reached, sample one dice -outcome at random and continue. After many simulations the expected value -converges. This is the approach used by OpenSpiel's MCTS for stochastic games -and requires no changes to the standard PUCT formula. - -**B. Chance-node averaging (expectimax)** -At each chance node, expand all 21 unique dice pairs weighted by their -probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is -exact but multiplies the branching factor by ~21 at every dice roll, making it -prohibitively expensive. - -**C. Condition on dice in the observation (current approach)** -Dice values are already encoded at indices [192–193] of `to_tensor()`. The -network naturally conditions on the rolled dice when it evaluates a position. -MCTS only runs on player-decision nodes _after_ the dice have been sampled; -chance nodes are bypassed by the environment wrapper (approach A). The policy -and value heads learn to play optimally given any dice pair. - -**Use approach A + C together**: the environment samples dice automatically -(chance node bypass), and the 217-dim tensor encodes the dice so the network -can exploit them. - -### 2.3 Perspective / Mirroring - -All move rules and tensor encoding are defined from White's perspective. -`to_tensor()` must always be called after mirroring the state for Black. -The environment wrapper handles this transparently: every observation returned -to an algorithm is already in the active player's perspective. - -### 2.4 Legal Action Masking - -A crucial difference from the existing `bot/` code: instead of penalizing -invalid actions with `ERROR_REWARD`, the policy head logits are **masked** -before softmax — illegal action logits are set to `-inf`. This prevents the -network from wasting capacity on illegal moves and eliminates the need for the -penalty-reward hack. - ---- - -## 3. Proposed Crate Architecture - -``` -spiel_bot/ -├── Cargo.toml -└── src/ - ├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo" - │ - ├── env/ - │ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface - │ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store - │ - ├── mcts/ - │ ├── mod.rs # MctsConfig, run_mcts() entry point - │ ├── node.rs # MctsNode (visit count, W, prior, children) - │ └── search.rs # simulate(), backup(), select_action() - │ - ├── network/ - │ ├── mod.rs # PolicyValueNet trait - │ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads - │ - ├── alphazero/ - │ ├── mod.rs # AlphaZeroConfig - │ ├── selfplay.rs # generate_episode() -> Vec - │ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle) - │ └── trainer.rs # training loop: selfplay → sample → loss → update - │ - └── agent/ - ├── mod.rs # Agent trait - ├── random.rs # RandomAgent (baseline) - └── mcts_agent.rs # MctsAgent: uses trained network for inference -``` - -Future algorithms slot in without touching the above: - -``` - ├── dqn/ # (future) DQN: impl Algorithm + own replay buffer - └── ppo/ # (future) PPO: impl Algorithm + rollout buffer -``` - ---- - -## 4. Core Traits - -### 4.1 `GameEnv` — the minimal OpenSpiel interface - -```rust -use rand::Rng; - -/// Who controls the current node. -pub enum Player { - P1, // player index 0 - P2, // player index 1 - Chance, // dice roll - Terminal, // game over -} - -pub trait GameEnv: Clone + Send + Sync + 'static { - type State: Clone + Send + Sync; - - /// Fresh game state. - fn new_game(&self) -> Self::State; - - /// Who acts at this node. - fn current_player(&self, s: &Self::State) -> Player; - - /// Legal action indices (always in [0, action_space())). - /// Empty only at Terminal nodes. - fn legal_actions(&self, s: &Self::State) -> Vec; - - /// Apply a player action (must be legal). - fn apply(&self, s: &mut Self::State, action: usize); - - /// Advance a Chance node by sampling dice; no-op at non-Chance nodes. - fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng); - - /// Observation tensor from `pov`'s perspective (0 or 1). - /// Returns 217 f32 values for Trictrac. - fn observation(&self, s: &Self::State, pov: usize) -> Vec; - - /// Flat observation size (217 for Trictrac). - fn obs_size(&self) -> usize; - - /// Total action-space size (514 for Trictrac). - fn action_space(&self) -> usize; - - /// Game outcome per player, or None if not Terminal. - /// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw. - fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; -} -``` - -### 4.2 `PolicyValueNet` — neural network interface - -```rust -use burn::prelude::*; - -pub trait PolicyValueNet: Send + Sync { - /// Forward pass. - /// `obs`: [batch, obs_size] tensor. - /// Returns: (policy_logits [batch, action_space], value [batch]). - fn forward(&self, obs: Tensor) -> (Tensor, Tensor); - - /// Save weights to `path`. - fn save(&self, path: &std::path::Path) -> anyhow::Result<()>; - - /// Load weights from `path`. - fn load(path: &std::path::Path) -> anyhow::Result - where - Self: Sized; -} -``` - -### 4.3 `Agent` — player policy interface - -```rust -pub trait Agent: Send { - /// Select an action index given the current game state observation. - /// `legal`: mask of valid action indices. - fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize; -} -``` - ---- - -## 5. MCTS Implementation - -### 5.1 Node - -```rust -pub struct MctsNode { - n: u32, // visit count N(s, a) - w: f32, // sum of backed-up values W(s, a) - p: f32, // prior from policy head P(s, a) - children: Vec<(usize, MctsNode)>, // (action_idx, child) - is_expanded: bool, -} - -impl MctsNode { - pub fn q(&self) -> f32 { - if self.n == 0 { 0.0 } else { self.w / self.n as f32 } - } - - /// PUCT score used for selection. - pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { - self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) - } -} -``` - -### 5.2 Simulation Loop - -One MCTS simulation (for deterministic decision nodes): - -``` -1. SELECTION — traverse from root, always pick child with highest PUCT, - auto-advancing forced/chance nodes via env.apply_chance(). -2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get - (policy_logits, value). Mask illegal actions, softmax - the remaining logits → priors P(s,a) for each child. -3. BACKUP — propagate -value up the tree (negate at each level because - perspective alternates between P1 and P2). -``` - -After `n_simulations` iterations, action selection at the root: - -```rust -// During training: sample proportional to N^(1/temperature) -// During evaluation: argmax N -fn select_action(root: &MctsNode, temperature: f32) -> usize { ... } -``` - -### 5.3 Configuration - -```rust -pub struct MctsConfig { - pub n_simulations: usize, // e.g. 200 - pub c_puct: f32, // exploration constant, e.g. 1.5 - pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3 - pub dirichlet_eps: f32, // noise weight, e.g. 0.25 - pub temperature: f32, // action sampling temperature (anneals to 0) -} -``` - -### 5.4 Handling Chance Nodes Inside MCTS - -When simulation reaches a Chance node (dice roll), the environment automatically -samples dice and advances to the next decision node. The MCTS tree does **not** -branch on dice outcomes — it treats the sampled outcome as the state. This -corresponds to "outcome sampling" (approach A from §2.2). Because each -simulation independently samples dice, the Q-values at player nodes converge to -their expected value over many simulations. - ---- - -## 6. Network Architecture - -### 6.1 ResNet Policy-Value Network - -A single trunk with residual blocks, then two heads: - -``` -Input: [batch, 217] - ↓ -Linear(217 → 512) + ReLU - ↓ -ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU) - ↓ trunk output [batch, 512] - ├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site) - └─ Value head: Linear(512 → 1) → tanh (output in [-1, 1]) -``` - -Burn implementation sketch: - -```rust -#[derive(Module, Debug)] -pub struct TrictracNet { - input: Linear, - res_blocks: Vec>, - policy_head: Linear, - value_head: Linear, -} - -impl TrictracNet { - pub fn forward(&self, obs: Tensor) - -> (Tensor, Tensor) - { - let x = activation::relu(self.input.forward(obs)); - let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x)); - let policy = self.policy_head.forward(x.clone()); // raw logits - let value = activation::tanh(self.value_head.forward(x)) - .squeeze(1); - (policy, value) - } -} -``` - -A simpler MLP (no residual blocks) is sufficient for a first version and much -faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two -heads. - -### 6.2 Loss Function - -``` -L = MSE(value_pred, z) - + CrossEntropy(policy_logits_masked, π_mcts) - - c_l2 * L2_regularization -``` - -Where: - -- `z` = game outcome (±1) from the active player's perspective -- `π_mcts` = normalized MCTS visit counts at the root (the policy target) -- Legal action masking is applied before computing CrossEntropy - ---- - -## 7. AlphaZero Training Loop - -``` -INIT - network ← random weights - replay ← empty ReplayBuffer(capacity = 100_000) - -LOOP forever: - ── Self-play phase ────────────────────────────────────────────── - (parallel with rayon, n_workers games at once) - for each game: - state ← env.new_game() - samples = [] - while not terminal: - advance forced/chance nodes automatically - obs ← env.observation(state, current_player) - legal ← env.legal_actions(state) - π, root_value ← mcts.run(state, network, config) - action ← sample from π (with temperature) - samples.push((obs, π, current_player)) - env.apply(state, action) - z ← env.returns(state) // final scores - for (obs, π, player) in samples: - replay.push(TrainSample { obs, policy: π, value: z[player] }) - - ── Training phase ─────────────────────────────────────────────── - for each gradient step: - batch ← replay.sample(batch_size) - (policy_logits, value_pred) ← network.forward(batch.obs) - loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy) - optimizer.step(loss.backward()) - - ── Evaluation (every N iterations) ───────────────────────────── - win_rate ← evaluate(network_new vs network_prev, n_eval_games) - if win_rate > 0.55: save checkpoint -``` - -### 7.1 Replay Buffer - -```rust -pub struct TrainSample { - pub obs: Vec, // 217 values - pub policy: Vec, // 514 values (normalized MCTS visit counts) - pub value: f32, // game outcome ∈ {-1, 0, +1} -} - -pub struct ReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - pub fn push(&mut self, s: TrainSample) { - if self.data.len() == self.capacity { self.data.pop_front(); } - self.data.push_back(s); - } - - pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { - // sample without replacement - } -} -``` - -### 7.2 Parallelism Strategy - -Self-play is embarrassingly parallel (each game is independent): - -```rust -let samples: Vec = (0..n_games) - .into_par_iter() // rayon - .flat_map(|_| generate_episode(&env, &network, &mcts_config)) - .collect(); -``` - -Note: Burn's `NdArray` backend is not `Send` by default when using autodiff. -Self-play uses inference-only (no gradient tape), so a `NdArray` backend -(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with -`Autodiff>`. - -For larger scale, a producer-consumer architecture (crossbeam-channel) separates -self-play workers from the training thread, allowing continuous data generation -while the GPU trains. - ---- - -## 8. `TrictracEnv` Implementation Sketch - -```rust -use trictrac_store::{ - training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE}, - Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage, -}; - -#[derive(Clone)] -pub struct TrictracEnv; - -impl GameEnv for TrictracEnv { - type State = GameState; - - fn new_game(&self) -> GameState { - GameState::new_with_players("P1", "P2") - } - - fn current_player(&self, s: &GameState) -> Player { - match s.stage { - Stage::Ended => Player::Terminal, - _ => match s.turn_stage { - TurnStage::RollWaiting => Player::Chance, - _ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 }, - }, - } - } - - fn legal_actions(&self, s: &GameState) -> Vec { - let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() }; - get_valid_action_indices(&view).unwrap_or_default() - } - - fn apply(&self, s: &mut GameState, action_idx: usize) { - // advance all forced/chance nodes first, then apply the player action - self.advance_forced(s); - let needs_mirror = s.active_player_id == 2; - let view = if needs_mirror { s.mirror() } else { s.clone() }; - if let Some(event) = TrictracAction::from_action_index(action_idx) - .and_then(|a| a.to_event(&view)) - .map(|e| if needs_mirror { e.get_mirror(false) } else { e }) - { - let _ = s.consume(&event); - } - // advance any forced stages that follow - self.advance_forced(s); - } - - fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) { - // RollDice → RollWaiting - let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); - // RollWaiting → next stage - let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) }; - let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice }); - self.advance_forced(s); - } - - fn observation(&self, s: &GameState, pov: usize) -> Vec { - if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() } - } - - fn obs_size(&self) -> usize { 217 } - fn action_space(&self) -> usize { ACTION_SPACE_SIZE } - - fn returns(&self, s: &GameState) -> Option<[f32; 2]> { - if s.stage != Stage::Ended { return None; } - // Convert hole+point scores to ±1 outcome - let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); - let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0); - Some(match s1.cmp(&s2) { - std::cmp::Ordering::Greater => [ 1.0, -1.0], - std::cmp::Ordering::Less => [-1.0, 1.0], - std::cmp::Ordering::Equal => [ 0.0, 0.0], - }) - } -} - -impl TrictracEnv { - /// Advance through all forced (non-decision, non-chance) stages. - fn advance_forced(&self, s: &mut GameState) { - use trictrac_store::PointsRules; - loop { - match s.turn_stage { - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Scoring is deterministic; compute and apply automatically. - let color = s.player_color_by_id(&s.active_player_id) - .unwrap_or(trictrac_store::Color::White); - let drc = s.players.get(&s.active_player_id) - .map(|p| p.dice_roll_count).unwrap_or(0); - let pr = PointsRules::new(&color, &s.board, s.dice); - let pts = pr.get_points(drc); - let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 }; - let _ = s.consume(&GameEvent::Mark { - player_id: s.active_player_id, points, - }); - } - TurnStage::RollDice => { - // RollDice is a forced "initiate roll" action with no real choice. - let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id }); - } - _ => break, - } - } - } -} -``` - ---- - -## 9. Cargo.toml Changes - -### 9.1 Add `spiel_bot` to the workspace - -```toml -# Cargo.toml (workspace root) -[workspace] -resolver = "2" -members = ["client_cli", "bot", "store", "spiel_bot"] -``` - -### 9.2 `spiel_bot/Cargo.toml` - -```toml -[package] -name = "spiel_bot" -version = "0.1.0" -edition = "2021" - -[features] -default = ["alphazero"] -alphazero = [] -# dqn = [] # future -# ppo = [] # future - -[dependencies] -trictrac-store = { path = "../store" } -anyhow = "1" -rand = "0.9" -rayon = "1" -serde = { version = "1", features = ["derive"] } -serde_json = "1" - -# Burn: NdArray for pure-Rust CPU training -# Replace NdArray with Wgpu or Tch for GPU. -burn = { version = "0.20", features = ["ndarray", "autodiff"] } - -# Optional: progress display and structured logging -indicatif = "0.17" -tracing = "0.1" - -[[bin]] -name = "az_train" -path = "src/bin/az_train.rs" - -[[bin]] -name = "az_eval" -path = "src/bin/az_eval.rs" -``` - ---- - -## 10. Comparison: `bot` crate vs `spiel_bot` - -| Aspect | `bot` (existing) | `spiel_bot` (proposed) | -| ---------------- | --------------------------- | -------------------------------------------- | -| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` | -| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) | -| Opponent | hardcoded random | self-play | -| Invalid actions | penalise with reward | legal action mask (no penalty) | -| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait | -| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper | -| Burn dep | yes (0.20) | yes (0.20), same backend | -| `burn-rl` dep | yes | no | -| C++ dep | no | no | -| Python dep | no | no | -| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin | - -The two crates are **complementary**: `bot` is a working DQN/PPO baseline; -`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The -`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just -replace `TrictracEnvironment` with `TrictracEnv`). - ---- - -## 11. Implementation Order - -1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game - through the trait, verify terminal state and returns). -2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) + - Burn forward/backward pass test with dummy data. -3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests - (visit counts sum to `n_simulations`, legal mask respected). -4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub - (one iteration, check loss decreases). -5. **Integration test**: run 100 self-play games with a tiny network (1 res block, - 64 hidden units), verify the training loop completes without panics. -6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s - on CPU, consistent with `random_game` throughput). -7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule - hyperparameter sweep. -8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report - win rate every checkpoint. - ---- - -## 12. Key Open Questions - -1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded. - AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever - scored more holes). Better option: normalize the score margin. The final - choice affects how the value head is trained. - -2. **Episode length**: Trictrac games average ~600 steps (`random_game` data). - MCTS with 200 simulations per step means ~120k network evaluations per game. - At batch inference this is feasible on CPU; on GPU it becomes fast. Consider - limiting `n_simulations` to 50–100 for early training. - -3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé). - This is a long-horizon decision that AlphaZero handles naturally via MCTS - lookahead, but needs careful value normalization (a "Go" restarts scoring - within the same game). - -4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated - to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic. - This is optional but reduces code duplication. - -5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess, - 0.03 for Go. For Trictrac with action space 514, empirical tuning is needed. - A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1. - -## Implementation results - -All benchmarks compile and run. Here's the complete results table: - -| Group | Benchmark | Time | -| ------- | ----------------------- | --------------------- | -| env | apply_chance | 3.87 µs | -| | legal_actions | 1.91 µs | -| | observation (to_tensor) | 341 ns | -| | random_game (baseline) | 3.55 ms → 282 games/s | -| network | mlp_b1 hidden=64 | 94.9 µs | -| | mlp_b32 hidden=64 | 141 µs | -| | mlp_b1 hidden=256 | 352 µs | -| | mlp_b32 hidden=256 | 479 µs | -| mcts | zero_eval n=1 | 6.8 µs | -| | zero_eval n=5 | 23.9 µs | -| | zero_eval n=20 | 90.9 µs | -| | mlp64 n=1 | 203 µs | -| | mlp64 n=5 | 622 µs | -| | mlp64 n=20 | 2.30 ms | -| episode | trictrac n=1 | 51.8 ms → 19 games/s | -| | trictrac n=2 | 145 ms → 7 games/s | -| train | mlp64 Adam b=16 | 1.93 ms | -| | mlp64 Adam b=64 | 2.68 ms | - -Key observations: - -- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game) -- observation (217-value tensor): only 341 ns — not a bottleneck -- legal_actions: 1.9 µs — well optimised -- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs -- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal -- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it -- Training: 2.7 ms/step at batch=64 → 370 steps/s - -### Summary of Step 8 - -spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary: - -- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct -- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%) -- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side -- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise) -- Output: win/draw/loss per side + combined decisive win rate - -Typical usage after training: -cargo run -p spiel_bot --bin az_eval --release -- \ - --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100 - -### az_train - -#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10) - -cargo run -p spiel_bot --bin az_train --release - -#### ResNet, more games, custom output dir - -cargo run -p spiel_bot --bin az_train --release -- \ - --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \ - --save-every 20 --out checkpoints/ - -#### Resume from iteration 50 - -cargo run -p spiel_bot --bin az_train --release -- \ - --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50 - -What the binary does each iteration: - -1. Calls model.valid() to get a zero-overhead inference copy for self-play -2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy) -3. Pushes samples into a ReplayBuffer (capacity --replay-cap) -4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min -5. Saves a .mpk checkpoint every --save-every iterations and always on the last diff --git a/doc/tensor_research.md b/doc/tensor_research.md deleted file mode 100644 index b0d0ede..0000000 --- a/doc/tensor_research.md +++ /dev/null @@ -1,253 +0,0 @@ -# Tensor research - -## Current tensor anatomy - -[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!) -[24] active player color: 0 or 1 -[25] turn_stage: 1–5 -[26–27] dice values (raw 1–6) -[28–31] white: points, holes, can_bredouille, can_big_bredouille -[32–35] black: same -───────────────────────────────── -Total 36 floats - -The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's -AlphaZero uses a fully-connected network. - -### Fundamental problems with the current encoding - -1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network - must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with - all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from. - -2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient - flow during training is uneven. - -3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it - triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers - produces a score. Including this explicitly is the single highest-value addition. - -4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely - different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0. - -5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the - starting position). It's in the Player struct but not exported. - -## Key Trictrac distinctions from backgammon that shape the encoding - -| Concept | Backgammon | Trictrac | -| ------------------------- | ---------------------- | --------------------------------------------------------- | -| Hitting a blot | Removes checker to bar | Scores points, checker stays | -| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened | -| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) | -| 3-checker field | Safe with spare | Safe with spare | -| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) | -| Both colors on a field | Impossible | Perfectly legal | -| Rest corner (field 12/13) | Does not exist | Special two-checker rules | - -The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary -indicators directly teaches the network the phase transitions the game hinges on. - -## Options - -### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D) - -The minimum viable improvement. - -For each of the 24 fields, encode own and opponent separately with 4 indicators each: - -own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target) -own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill) -own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare) -own_x[i]: max(0, count − 3) (overflow) -opp_1[i]: same for opponent -… - -Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec(). - -Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204 -Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured -overhead is negligible. -Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from -scratch will be substantially faster and the converged policy quality better. - -### Option B — Option A + Trictrac-specific derived features (flat 1D) - -Recommended starting point. - -Add on top of Option A: - -// Quarter fill status — the single most important derived feature -quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields) -quarter_filled_opp[q] (q=0..3): same for opponent -→ 8 values - -// Exit readiness -can_exit_own: 1.0 if all own checkers are in fields 19–24 -can_exit_opp: same for opponent -→ 2 values - -// Rest corner status (field 12/13) -own_corner_taken: 1.0 if field 12 has ≥2 own checkers -opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers -→ 2 values - -// Jan de 3 coups counter (normalized) -dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0) -→ 1 value - -Size: 204 + 8 + 2 + 2 + 1 = 217 -Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve -the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes -expensive inference from the network. - -### Option C — Option B + richer positional features (flat 1D) - -More complete, higher sample efficiency, minor extra cost. - -Add on top of Option B: - -// Per-quarter fill fraction — how close to filling each quarter -own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0 -opp_quarter_fill_fraction[q] (q=0..3): same for opponent -→ 8 values - -// Blot counts — number of own/opponent single-checker fields globally -// (tells the network at a glance how much battage risk/opportunity exists) -own_blot_count: (number of own fields with exactly 1 checker) / 15.0 -opp_blot_count: same for opponent -→ 2 values - -// Bredouille would-double multiplier (already present, but explicitly scaled) -// No change needed, already binary - -Size: 217 + 8 + 2 = 227 -Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network -from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts). - -### Option D — 2D spatial tensor {K, 24} - -For CNN-based networks. Best eventual architecture but requires changing the training setup. - -Shape {14, 24} — 14 feature channels over 24 field positions: - -Channel 0: own_count_1 (blot) -Channel 1: own_count_2 -Channel 2: own_count_3 -Channel 3: own_count_overflow (float) -Channel 4: opp_count_1 -Channel 5: opp_count_2 -Channel 6: opp_count_3 -Channel 7: opp_count_overflow -Channel 8: own_corner_mask (1.0 at field 12) -Channel 9: opp_corner_mask (1.0 at field 13) -Channel 10: final_quarter_mask (1.0 at fields 19–24) -Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter) -Channel 12: quarter_filled_opp (same for opponent) -Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers) - -Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform -value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global" -row by returning shape {K, 25} with position 0 holding global features. - -Size: 14 × 24 + few global channels ≈ 336–384 -C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated -accordingly. -Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's -alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H, -W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension. -Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel -size 2–3 captures adjacent-field "tout d'une" interactions. - -### On 3D tensors - -Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is -the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and -field-within-quarter patterns jointly. - -However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The -extra architecture work makes this premature unless you're already building a custom network. The information content -is the same as Option D. - -### Recommendation - -Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and -the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions -(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move. - -Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when -the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time. - -Costs summary: - -| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain | -| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- | -| Current | 36 | — | — | — | baseline | -| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) | -| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) | -| C | 227 | to_vec() rewrite | constant update | none | large + moderate | -| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial | - -One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2, -the new to_vec() must express everything from the active player's perspective (which the mirror already handles for -the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state, -which they will be if computed inside to_vec(). - -## Other algorithms - -The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully. - -### 1. Without MCTS, feature quality matters more - -AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred -simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup — -the network must learn the full strategic structure directly from gradient updates. - -This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO, -not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less -sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from -scratch. - -### 2. Normalization becomes mandatory, not optional - -AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps -Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw -values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can -be unstable. - -For DQN/PPO, every feature in the tensor should be in [0, 1]: - -dice values: / 6.0 -checker counts: overflow channel / 12.0 -points: / 12.0 -holes: / 12.0 -dice_roll_count: / 3.0 (clamped) - -Booleans and the TD-Gammon binary indicators are already in [0, 1]. - -### 3. The shape question depends on architecture, not algorithm - -| Architecture | Shape | When to use | -| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- | -| MLP | {217} flat | Any algorithm, simplest baseline | -| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) | -| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network | -| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now | - -The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want -convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet -template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all -— you just reshape the flat tensor in Python before passing it to the network. - -### One thing that genuinely changes: value function perspective - -AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This -works well. - -DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit -current-player indicator), because a single Q-network estimates action values for both players simultaneously. With -the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must -learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity. -This is a minor point for a symmetric game like Trictrac, but worth keeping in mind. - -Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm. diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml deleted file mode 100644 index b541adc..0000000 --- a/spiel_bot/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "spiel_bot" -version = "0.1.0" -edition = "2021" - -[dependencies] -trictrac-store = { path = "../store" } -trictrac-bot = { path = "../bot" } -anyhow = "1" -rand = "0.9" -rand_distr = "0.5" -burn = { version = "0.20", features = ["ndarray", "autodiff"] } - -[dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } - -[[bench]] -name = "alphazero" -harness = false diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs deleted file mode 100644 index 00d5b02..0000000 --- a/spiel_bot/benches/alphazero.rs +++ /dev/null @@ -1,373 +0,0 @@ -//! AlphaZero pipeline benchmarks. -//! -//! Run with: -//! -//! ```sh -//! cargo bench -p spiel_bot -//! ``` -//! -//! Use `-- ` to run a specific group, e.g.: -//! -//! ```sh -//! cargo bench -p spiel_bot -- env/ -//! cargo bench -p spiel_bot -- network/ -//! cargo bench -p spiel_bot -- mcts/ -//! cargo bench -p spiel_bot -- episode/ -//! cargo bench -p spiel_bot -- train/ -//! ``` -//! -//! Target: ≥ 500 games/s for random play on CPU (consistent with -//! `random_game` throughput in `trictrac-store`). - -use std::time::Duration; - -use burn::{ - backend::NdArray, - tensor::{Tensor, TensorData, backend::Backend}, -}; -use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; -use rand::{Rng, SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, - env::{GameEnv, Player, TrictracEnv}, - mcts::{Evaluator, MctsConfig, run_mcts}, - network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, -}; - -// ── Shared types ─────────────────────────────────────────────────────────── - -type InferB = NdArray; -type TrainB = burn::backend::Autodiff>; - -fn infer_device() -> ::Device { Default::default() } -fn train_device() -> ::Device { Default::default() } - -fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) } - -/// Uniform evaluator (returns zero logits and zero value). -/// Used to isolate MCTS tree-traversal cost from network cost. -struct ZeroEval(usize); -impl Evaluator for ZeroEval { - fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { - (vec![0.0f32; self.0], 0.0) - } -} - -// ── 1. Environment primitives ────────────────────────────────────────────── - -/// Baseline performance of the raw Trictrac environment without MCTS. -/// Target: ≥ 500 full games / second. -fn bench_env(c: &mut Criterion) { - let env = TrictracEnv; - - let mut group = c.benchmark_group("env"); - group.measurement_time(Duration::from_secs(10)); - - // ── apply_chance ────────────────────────────────────────────────────── - group.bench_function("apply_chance", |b| { - b.iter_batched( - || { - // A fresh game is always at RollDice (Chance) — ready for apply_chance. - env.new_game() - }, - |mut s| { - env.apply_chance(&mut s, &mut seeded()); - black_box(s) - }, - BatchSize::SmallInput, - ) - }); - - // ── legal_actions ───────────────────────────────────────────────────── - group.bench_function("legal_actions", |b| { - let mut rng = seeded(); - let mut s = env.new_game(); - env.apply_chance(&mut s, &mut rng); - b.iter(|| black_box(env.legal_actions(&s))) - }); - - // ── observation (to_tensor) ─────────────────────────────────────────── - group.bench_function("observation", |b| { - let mut rng = seeded(); - let mut s = env.new_game(); - env.apply_chance(&mut s, &mut rng); - b.iter(|| black_box(env.observation(&s, 0))) - }); - - // ── full random game ────────────────────────────────────────────────── - group.sample_size(50); - group.bench_function("random_game", |b| { - b.iter_batched( - seeded, - |mut rng| { - let mut s = env.new_game(); - loop { - match env.current_player(&s) { - Player::Terminal => break, - Player::Chance => env.apply_chance(&mut s, &mut rng), - _ => { - let actions = env.legal_actions(&s); - let idx = rng.random_range(0..actions.len()); - env.apply(&mut s, actions[idx]); - } - } - } - black_box(s) - }, - BatchSize::SmallInput, - ) - }); - - group.finish(); -} - -// ── 2. Network inference ─────────────────────────────────────────────────── - -/// Forward-pass latency for MLP variants (hidden = 64 / 256). -fn bench_network(c: &mut Criterion) { - let mut group = c.benchmark_group("network"); - group.measurement_time(Duration::from_secs(5)); - - for &hidden in &[64usize, 256] { - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = MlpNet::::new(&cfg, &infer_device()); - let obs: Vec = vec![0.5; 217]; - - // Batch size 1 — single-position evaluation as in MCTS. - group.bench_with_input( - BenchmarkId::new("mlp_b1", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs.clone(), [1, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - - // Batch size 32 — training mini-batch. - let obs32: Vec = vec![0.5; 217 * 32]; - group.bench_with_input( - BenchmarkId::new("mlp_b32", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs32.clone(), [32, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - } - - // ── ResNet (4 residual blocks) ──────────────────────────────────────── - for &hidden in &[256usize, 512] { - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = ResNet::::new(&cfg, &infer_device()); - let obs: Vec = vec![0.5; 217]; - - group.bench_with_input( - BenchmarkId::new("resnet_b1", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs.clone(), [1, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - - let obs32: Vec = vec![0.5; 217 * 32]; - group.bench_with_input( - BenchmarkId::new("resnet_b32", hidden), - &hidden, - |b, _| { - b.iter(|| { - let data = TensorData::new(obs32.clone(), [32, 217]); - let t = Tensor::::from_data(data, &infer_device()); - black_box(model.forward(t)) - }) - }, - ); - } - - group.finish(); -} - -// ── 3. MCTS ─────────────────────────────────────────────────────────────── - -/// MCTS cost at different simulation budgets with two evaluator types: -/// - `zero` — isolates tree-traversal overhead (no network). -/// - `mlp64` — real MLP, shows end-to-end cost per move. -fn bench_mcts(c: &mut Criterion) { - let env = TrictracEnv; - - // Build a decision-node state (after dice roll). - let state = { - let mut s = env.new_game(); - let mut rng = seeded(); - while env.current_player(&s).is_chance() { - env.apply_chance(&mut s, &mut rng); - } - s - }; - - let mut group = c.benchmark_group("mcts"); - group.measurement_time(Duration::from_secs(10)); - - let zero_eval = ZeroEval(514); - let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - let mlp_model = MlpNet::::new(&mlp_cfg, &infer_device()); - let mlp_eval = BurnEvaluator::::new(mlp_model, infer_device()); - - for &n_sim in &[1usize, 5, 20] { - let cfg = MctsConfig { - n_simulations: n_sim, - c_puct: 1.5, - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - temperature: 1.0, - }; - - // Zero evaluator: tree traversal only. - group.bench_with_input( - BenchmarkId::new("zero_eval", n_sim), - &n_sim, - |b, _| { - b.iter_batched( - seeded, - |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)), - BatchSize::SmallInput, - ) - }, - ); - - // MLP evaluator: full cost per decision. - group.bench_with_input( - BenchmarkId::new("mlp64", n_sim), - &n_sim, - |b, _| { - b.iter_batched( - seeded, - |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)), - BatchSize::SmallInput, - ) - }, - ); - } - - group.finish(); -} - -// ── 4. Episode generation ───────────────────────────────────────────────── - -/// Full self-play episode latency (one complete game) at different MCTS -/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU. -fn bench_episode(c: &mut Criterion) { - let env = TrictracEnv; - let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - let model = MlpNet::::new(&mlp_cfg, &infer_device()); - let eval = BurnEvaluator::::new(model, infer_device()); - - let mut group = c.benchmark_group("episode"); - group.sample_size(10); - group.measurement_time(Duration::from_secs(60)); - - for &n_sim in &[1usize, 2] { - let mcts_cfg = MctsConfig { - n_simulations: n_sim, - c_puct: 1.5, - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - temperature: 1.0, - }; - - group.bench_with_input( - BenchmarkId::new("trictrac", n_sim), - &n_sim, - |b, _| { - b.iter_batched( - seeded, - |mut rng| { - black_box(generate_episode( - &env, - &eval, - &mcts_cfg, - &|_| 1.0, - &mut rng, - )) - }, - BatchSize::SmallInput, - ) - }, - ); - } - - group.finish(); -} - -// ── 5. Training step ─────────────────────────────────────────────────────── - -/// Gradient-step latency for different batch sizes. -fn bench_train(c: &mut Criterion) { - use burn::optim::AdamConfig; - - let mut group = c.benchmark_group("train"); - group.measurement_time(Duration::from_secs(10)); - - let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - - let dummy_samples = |n: usize| -> Vec { - (0..n) - .map(|i| TrainSample { - obs: vec![0.5; 217], - policy: { - let mut p = vec![0.0f32; 514]; - p[i % 514] = 1.0; - p - }, - value: if i % 2 == 0 { 1.0 } else { -1.0 }, - }) - .collect() - }; - - for &batch_size in &[16usize, 64] { - let batch = dummy_samples(batch_size); - - group.bench_with_input( - BenchmarkId::new("mlp64_adam", batch_size), - &batch_size, - |b, _| { - b.iter_batched( - || { - ( - MlpNet::::new(&mlp_cfg, &train_device()), - AdamConfig::new().init::>(), - ) - }, - |(model, mut opt)| { - black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3)) - }, - BatchSize::SmallInput, - ) - }, - ); - } - - group.finish(); -} - -// ── Criterion entry point ────────────────────────────────────────────────── - -criterion_group!( - benches, - bench_env, - bench_network, - bench_mcts, - bench_episode, - bench_train, -); -criterion_main!(benches); diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs deleted file mode 100644 index d92224e..0000000 --- a/spiel_bot/src/alphazero/mod.rs +++ /dev/null @@ -1,127 +0,0 @@ -//! AlphaZero: self-play data generation, replay buffer, and training step. -//! -//! # Modules -//! -//! | Module | Contents | -//! |--------|----------| -//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] | -//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] | -//! | [`trainer`] | [`train_step`] | -//! -//! # Typical outer loop -//! -//! ```rust,ignore -//! use burn::backend::{Autodiff, NdArray}; -//! use burn::optim::AdamConfig; -//! use spiel_bot::{ -//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step}, -//! env::TrictracEnv, -//! mcts::MctsConfig, -//! network::{MlpConfig, MlpNet}, -//! }; -//! -//! type Infer = NdArray; -//! type Train = Autodiff>; -//! -//! let device = Default::default(); -//! let env = TrictracEnv; -//! let config = AlphaZeroConfig::default(); -//! -//! // Build training model and optimizer. -//! let mut train_model = MlpNet::::new(&MlpConfig::default(), &device); -//! let mut optimizer = AdamConfig::new().init(); -//! let mut replay = ReplayBuffer::new(config.replay_capacity); -//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0); -//! -//! for _iter in 0..config.n_iterations { -//! // Convert to inference backend for self-play. -//! let infer_model = MlpNet::::new(&MlpConfig::default(), &device) -//! .load_record(train_model.clone().into_record()); -//! let eval = BurnEvaluator::new(infer_model, device.clone()); -//! -//! // Self-play: generate episodes. -//! for _ in 0..config.n_games_per_iter { -//! let samples = generate_episode(&env, &eval, &config.mcts, -//! &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); -//! replay.extend(samples); -//! } -//! -//! // Training: gradient steps. -//! if replay.len() >= config.batch_size { -//! for _ in 0..config.n_train_steps_per_iter { -//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng) -//! .into_iter().cloned().collect(); -//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device, -//! config.learning_rate); -//! train_model = m; -//! } -//! } -//! } -//! ``` - -pub mod replay; -pub mod selfplay; -pub mod trainer; - -pub use replay::{ReplayBuffer, TrainSample}; -pub use selfplay::{BurnEvaluator, generate_episode}; -pub use trainer::{cosine_lr, train_step}; - -use crate::mcts::MctsConfig; - -// ── Configuration ───────────────────────────────────────────────────────── - -/// Top-level AlphaZero hyperparameters. -/// -/// The MCTS parameters live in [`MctsConfig`]; this struct holds the -/// outer training-loop parameters. -#[derive(Debug, Clone)] -pub struct AlphaZeroConfig { - /// MCTS parameters for self-play. - pub mcts: MctsConfig, - /// Number of self-play games per training iteration. - pub n_games_per_iter: usize, - /// Number of gradient steps per training iteration. - pub n_train_steps_per_iter: usize, - /// Mini-batch size for each gradient step. - pub batch_size: usize, - /// Maximum number of samples in the replay buffer. - pub replay_capacity: usize, - /// Initial (peak) Adam learning rate. - pub learning_rate: f64, - /// Minimum learning rate for cosine annealing (floor of the schedule). - /// - /// Pass `learning_rate == lr_min` to disable scheduling (constant LR). - /// Compute the current LR with [`cosine_lr`]: - /// - /// ```rust,ignore - /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps); - /// ``` - pub lr_min: f64, - /// Number of outer iterations (self-play + train) to run. - pub n_iterations: usize, - /// Move index after which the action temperature drops to 0 (greedy play). - pub temperature_drop_move: usize, -} - -impl Default for AlphaZeroConfig { - fn default() -> Self { - Self { - mcts: MctsConfig { - n_simulations: 100, - c_puct: 1.5, - dirichlet_alpha: 0.1, - dirichlet_eps: 0.25, - temperature: 1.0, - }, - n_games_per_iter: 10, - n_train_steps_per_iter: 20, - batch_size: 64, - replay_capacity: 50_000, - learning_rate: 1e-3, - lr_min: 1e-4, // cosine annealing floor - n_iterations: 100, - temperature_drop_move: 30, - } - } -} diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs deleted file mode 100644 index 5e64cc4..0000000 --- a/spiel_bot/src/alphazero/replay.rs +++ /dev/null @@ -1,144 +0,0 @@ -//! Replay buffer for AlphaZero self-play data. - -use std::collections::VecDeque; -use rand::Rng; - -// ── Training sample ──────────────────────────────────────────────────────── - -/// One training example produced by self-play. -#[derive(Clone, Debug)] -pub struct TrainSample { - /// Observation tensor from the acting player's perspective (`obs_size` floats). - pub obs: Vec, - /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1). - pub policy: Vec, - /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw. - pub value: f32, -} - -// ── Replay buffer ────────────────────────────────────────────────────────── - -/// Fixed-capacity circular buffer of [`TrainSample`]s. -/// -/// When the buffer is full, the oldest sample is evicted on push. -/// Samples are drawn without replacement using a Fisher-Yates partial shuffle. -pub struct ReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - /// Create a buffer with the given maximum capacity. - pub fn new(capacity: usize) -> Self { - Self { - data: VecDeque::with_capacity(capacity.min(1024)), - capacity, - } - } - - /// Add a sample; evicts the oldest if at capacity. - pub fn push(&mut self, sample: TrainSample) { - if self.data.len() == self.capacity { - self.data.pop_front(); - } - self.data.push_back(sample); - } - - /// Add all samples from an episode. - pub fn extend(&mut self, samples: impl IntoIterator) { - for s in samples { - self.push(s); - } - } - - pub fn len(&self) -> usize { - self.data.len() - } - - pub fn is_empty(&self) -> bool { - self.data.is_empty() - } - - /// Sample up to `n` distinct samples, without replacement. - /// - /// If the buffer has fewer than `n` samples, all are returned (shuffled). - pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> { - let len = self.data.len(); - let n = n.min(len); - // Partial Fisher-Yates using index shuffling. - let mut indices: Vec = (0..len).collect(); - for i in 0..n { - let j = rng.random_range(i..len); - indices.swap(i, j); - } - indices[..n].iter().map(|&i| &self.data[i]).collect() - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - - fn dummy(value: f32) -> TrainSample { - TrainSample { obs: vec![value], policy: vec![1.0], value } - } - - #[test] - fn push_and_len() { - let mut buf = ReplayBuffer::new(10); - assert!(buf.is_empty()); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - assert_eq!(buf.len(), 2); - } - - #[test] - fn evicts_oldest_at_capacity() { - let mut buf = ReplayBuffer::new(3); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - buf.push(dummy(3.0)); - buf.push(dummy(4.0)); // evicts 1.0 - assert_eq!(buf.len(), 3); - // Oldest remaining should be 2.0 - assert_eq!(buf.data[0].value, 2.0); - } - - #[test] - fn sample_batch_size() { - let mut buf = ReplayBuffer::new(20); - for i in 0..10 { - buf.push(dummy(i as f32)); - } - let mut rng = SmallRng::seed_from_u64(0); - let batch = buf.sample_batch(5, &mut rng); - assert_eq!(batch.len(), 5); - } - - #[test] - fn sample_batch_capped_at_len() { - let mut buf = ReplayBuffer::new(20); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - let mut rng = SmallRng::seed_from_u64(0); - let batch = buf.sample_batch(100, &mut rng); - assert_eq!(batch.len(), 2); - } - - #[test] - fn sample_batch_no_duplicates() { - let mut buf = ReplayBuffer::new(20); - for i in 0..10 { - buf.push(dummy(i as f32)); - } - let mut rng = SmallRng::seed_from_u64(1); - let batch = buf.sample_batch(10, &mut rng); - let mut seen: Vec = batch.iter().map(|s| s.value).collect(); - seen.sort_by(f32::total_cmp); - seen.dedup(); - assert_eq!(seen.len(), 10, "sample contained duplicates"); - } -} diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs deleted file mode 100644 index 6f10f8d..0000000 --- a/spiel_bot/src/alphazero/selfplay.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! Self-play episode generation and Burn-backed evaluator. - -use std::marker::PhantomData; - -use burn::tensor::{backend::Backend, Tensor, TensorData}; -use rand::Rng; - -use crate::env::GameEnv; -use crate::mcts::{self, Evaluator, MctsConfig, MctsNode}; -use crate::network::PolicyValueNet; - -use super::replay::TrainSample; - -// ── BurnEvaluator ────────────────────────────────────────────────────────── - -/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS. -/// -/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so -/// that self-play generates no gradient tape overhead. -pub struct BurnEvaluator> { - model: N, - device: B::Device, - _b: PhantomData, -} - -impl> BurnEvaluator { - pub fn new(model: N, device: B::Device) -> Self { - Self { model, device, _b: PhantomData } - } - - pub fn into_model(self) -> N { - self.model - } -} - -// Safety: NdArray modules are Send; we never share across threads without -// external synchronisation. -unsafe impl> Send for BurnEvaluator {} -unsafe impl> Sync for BurnEvaluator {} - -impl> Evaluator for BurnEvaluator { - fn evaluate(&self, obs: &[f32]) -> (Vec, f32) { - let obs_size = obs.len(); - let data = TensorData::new(obs.to_vec(), [1, obs_size]); - let obs_tensor = Tensor::::from_data(data, &self.device); - - let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); - - let policy: Vec = policy_tensor.into_data().to_vec().unwrap(); - let value: Vec = value_tensor.into_data().to_vec().unwrap(); - - (policy, value[0]) - } -} - -// ── Episode generation ───────────────────────────────────────────────────── - -/// One pending observation waiting for its game-outcome value label. -struct PendingSample { - obs: Vec, - policy: Vec, - player: usize, -} - -/// Play one full game using MCTS guided by `evaluator`. -/// -/// Returns a [`TrainSample`] for every decision step in the game. -/// -/// `temperature_fn(step)` controls exploration: return `1.0` for early -/// moves and `0.0` after a fixed number of moves (e.g. move 30). -pub fn generate_episode( - env: &E, - evaluator: &dyn Evaluator, - mcts_config: &MctsConfig, - temperature_fn: &dyn Fn(usize) -> f32, - rng: &mut impl Rng, -) -> Vec { - let mut state = env.new_game(); - let mut pending: Vec = Vec::new(); - let mut step = 0usize; - - loop { - // Advance through chance nodes automatically. - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, rng); - } - - if env.current_player(&state).is_terminal() { - break; - } - - let player_idx = env.current_player(&state).index().unwrap(); - - // Run MCTS to get a policy. - let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng); - let policy = mcts::mcts_policy(&root, env.action_space()); - - // Record the observation from the acting player's perspective. - let obs = env.observation(&state, player_idx); - pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx }); - - // Select and apply the action. - let temperature = temperature_fn(step); - let action = mcts::select_action(&root, temperature, rng); - env.apply(&mut state, action); - step += 1; - } - - // Assign game outcomes. - let returns = env.returns(&state).unwrap_or([0.0; 2]); - pending - .into_iter() - .map(|s| TrainSample { - obs: s.obs, - policy: s.policy, - value: returns[s.player], - }) - .collect() -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - use rand::{SeedableRng, rngs::SmallRng}; - - use crate::env::Player; - use crate::mcts::{Evaluator, MctsConfig}; - use crate::network::{MlpConfig, MlpNet}; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn rng() -> SmallRng { - SmallRng::seed_from_u64(7) - } - - // Countdown game (same as in mcts tests). - #[derive(Clone, Debug)] - struct CState { remaining: u8, to_move: usize } - - #[derive(Clone)] - struct CountdownEnv; - - impl GameEnv for CountdownEnv { - type State = CState; - fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } } - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { Player::Terminal } - else if s.to_move == 0 { Player::P1 } else { Player::P2 } - } - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { s.remaining = 0; } - else { s.remaining -= sub; s.to_move = 1 - s.to_move; } - } - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / 4.0, s.to_move as f32] - } - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } - } - - fn tiny_config() -> MctsConfig { - MctsConfig { n_simulations: 5, c_puct: 1.5, - dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 } - } - - // ── BurnEvaluator tests ─────────────────────────────────────────────── - - #[test] - fn burn_evaluator_output_shapes() { - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let (policy, value) = eval.evaluate(&[0.5f32, 0.5]); - assert_eq!(policy.len(), 2, "policy length should equal action_space"); - assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)"); - } - - // ── generate_episode tests ──────────────────────────────────────────── - - #[test] - fn episode_terminates_and_has_samples() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); - assert!(!samples.is_empty(), "episode must produce at least one sample"); - } - - #[test] - fn episode_sample_values_are_valid() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng()); - for s in &samples { - assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0, - "unexpected value {}", s.value); - let sum: f32 = s.policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}"); - assert_eq!(s.obs.len(), 2); - } - } - - #[test] - fn episode_with_temperature_zero() { - let env = CountdownEnv; - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let eval = BurnEvaluator::new(model, device()); - // temperature=0 means greedy; episode must still terminate - let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng()); - assert!(!samples.is_empty()); - } -} diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs deleted file mode 100644 index 9075519..0000000 --- a/spiel_bot/src/alphazero/trainer.rs +++ /dev/null @@ -1,258 +0,0 @@ -//! One gradient-descent training step for AlphaZero. -//! -//! The loss combines: -//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits. -//! - **Value loss** — mean-squared error between the predicted value and the -//! actual game outcome. -//! -//! # Learning-rate scheduling -//! -//! [`cosine_lr`] implements one-cycle cosine annealing: -//! -//! ```text -//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T)) -//! ``` -//! -//! Typical usage in the outer loop: -//! -//! ```rust,ignore -//! for step in 0..total_train_steps { -//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps); -//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr); -//! model = m; -//! } -//! ``` -//! -//! # Backend -//! -//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff>`). -//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead. -//! Weights are transferred between the two via [`burn::record`]. - -use burn::{ - module::AutodiffModule, - optim::{GradientsParams, Optimizer}, - prelude::ElementConversion, - tensor::{ - activation::log_softmax, - backend::AutodiffBackend, - Tensor, TensorData, - }, -}; - -use crate::network::PolicyValueNet; -use super::replay::TrainSample; - -/// Run one gradient step on `model` using `batch`. -/// -/// Returns the updated model and the scalar loss value for logging. -/// -/// # Parameters -/// -/// - `lr` — learning rate (e.g. `1e-3`). -/// - `batch` — slice of [`TrainSample`]s; must be non-empty. -pub fn train_step( - model: N, - optimizer: &mut O, - batch: &[TrainSample], - device: &B::Device, - lr: f64, -) -> (N, f32) -where - B: AutodiffBackend, - N: PolicyValueNet + AutodiffModule, - O: Optimizer, -{ - assert!(!batch.is_empty(), "train_step called with empty batch"); - - let batch_size = batch.len(); - let obs_size = batch[0].obs.len(); - let action_size = batch[0].policy.len(); - - // ── Build input tensors ──────────────────────────────────────────────── - let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); - let policy_flat: Vec = batch.iter().flat_map(|s| s.policy.iter().copied()).collect(); - let value_flat: Vec = batch.iter().map(|s| s.value).collect(); - - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [batch_size, obs_size]), - device, - ); - let policy_target = Tensor::::from_data( - TensorData::new(policy_flat, [batch_size, action_size]), - device, - ); - let value_target = Tensor::::from_data( - TensorData::new(value_flat, [batch_size, 1]), - device, - ); - - // ── Forward pass ────────────────────────────────────────────────────── - let (policy_logits, value_pred) = model.forward(obs_tensor); - - // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ────────────────── - let log_probs = log_softmax(policy_logits, 1); - let policy_loss = (policy_target.clone().neg() * log_probs) - .sum_dim(1) - .mean(); - - // ── Value loss: MSE(value_pred, z) ──────────────────────────────────── - let diff = value_pred - value_target; - let value_loss = (diff.clone() * diff).mean(); - - // ── Combined loss ───────────────────────────────────────────────────── - let loss = policy_loss + value_loss; - - // Extract scalar before backward (consumes the tensor). - let loss_scalar: f32 = loss.clone().into_scalar().elem(); - - // ── Backward + optimizer step ───────────────────────────────────────── - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &model); - let model = optimizer.step(lr, model, grads); - - (model, loss_scalar) -} - -// ── Learning-rate schedule ───────────────────────────────────────────────── - -/// Cosine learning-rate schedule (one half-period, no warmup). -/// -/// Returns the learning rate for training step `step` out of `total_steps`: -/// -/// ```text -/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total)) -/// ``` -/// -/// - At `t = 0` returns `initial`. -/// - At `t = total_steps` (or beyond) returns `lr_min`. -/// -/// # Panics -/// -/// Does not panic. When `total_steps == 0`, returns `lr_min`. -pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 { - if total_steps == 0 || step >= total_steps { - return lr_min; - } - let progress = step as f64 / total_steps as f64; - lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos()) -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::{ - backend::{Autodiff, NdArray}, - optim::AdamConfig, - }; - - use crate::network::{MlpConfig, MlpNet}; - use super::super::replay::TrainSample; - - type B = Autodiff>; - - fn device() -> ::Device { - Default::default() - } - - fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { - (0..n) - .map(|i| TrainSample { - obs: vec![0.5f32; obs_size], - policy: { - let mut p = vec![0.0f32; action_size]; - p[i % action_size] = 1.0; - p - }, - value: if i % 2 == 0 { 1.0 } else { -1.0 }, - }) - .collect() - } - - #[test] - fn train_step_returns_finite_loss() { - let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; - let model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(8, 4, 4); - - let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - assert!(loss > 0.0, "loss should be positive"); - } - - #[test] - fn loss_decreases_over_steps() { - let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let mut model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - // Same batch every step — loss should decrease. - let batch = dummy_batch(16, 4, 4); - - let mut prev_loss = f32::INFINITY; - for _ in 0..10 { - let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2); - model = m; - assert!(loss.is_finite()); - prev_loss = loss; - } - // After 10 steps on fixed data, loss should be below a reasonable threshold. - assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}"); - } - - #[test] - fn train_step_batch_size_one() { - let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 }; - let model = MlpNet::::new(&config, &device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(1, 2, 2); - let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); - assert!(loss.is_finite()); - } - - // ── cosine_lr ───────────────────────────────────────────────────────── - - #[test] - fn cosine_lr_at_step_zero_is_initial() { - let lr = super::cosine_lr(1e-3, 1e-5, 0, 100); - assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}"); - } - - #[test] - fn cosine_lr_at_end_is_min() { - let lr = super::cosine_lr(1e-3, 1e-5, 100, 100); - assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}"); - } - - #[test] - fn cosine_lr_beyond_end_is_min() { - let lr = super::cosine_lr(1e-3, 1e-5, 200, 100); - assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}"); - } - - #[test] - fn cosine_lr_midpoint_is_average() { - // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2. - let lr = super::cosine_lr(1e-3, 1e-5, 50, 100); - let expected = (1e-3 + 1e-5) / 2.0; - assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}"); - } - - #[test] - fn cosine_lr_monotone_decreasing() { - let mut prev = f64::INFINITY; - for step in 0..=100 { - let lr = super::cosine_lr(1e-3, 1e-5, step, 100); - assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}"); - prev = lr; - } - } - - #[test] - fn cosine_lr_zero_total_steps_returns_min() { - let lr = super::cosine_lr(1e-3, 1e-5, 0, 0); - assert!((lr - 1e-5).abs() < 1e-10); - } -} diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs deleted file mode 100644 index 3c82519..0000000 --- a/spiel_bot/src/bin/az_eval.rs +++ /dev/null @@ -1,262 +0,0 @@ -//! Evaluate a trained AlphaZero checkpoint against a random player. -//! -//! # Usage -//! -//! ```sh -//! # Random weights (sanity check — should be ~50 %) -//! cargo run -p spiel_bot --bin az_eval --release -//! -//! # Trained MLP checkpoint -//! cargo run -p spiel_bot --bin az_eval --release -- \ -//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50 -//! -//! # Trained ResNet checkpoint -//! cargo run -p spiel_bot --bin az_eval --release -- \ -//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! | `--checkpoint ` | (none) | Load weights from `.mpk` file; random weights if omitted | -//! | `--arch mlp\|resnet` | `mlp` | Network architecture | -//! | `--hidden ` | 256 (mlp) / 512 (resnet) | Hidden size | -//! | `--n-games ` | `100` | Games per side (total = 2 × N) | -//! | `--n-sim ` | `50` | MCTS simulations per move | -//! | `--seed ` | `42` | RNG seed | -//! | `--c-puct ` | `1.5` | PUCT exploration constant | - -use std::path::PathBuf; - -use burn::backend::NdArray; -use rand::{SeedableRng, rngs::SmallRng, Rng}; - -use spiel_bot::{ - alphazero::BurnEvaluator, - env::{GameEnv, Player, TrictracEnv}, - mcts::{Evaluator, MctsConfig, run_mcts, select_action}, - network::{MlpConfig, MlpNet, ResNet, ResNetConfig}, -}; - -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - checkpoint: Option, - arch: String, - hidden: Option, - n_games: usize, - n_sim: usize, - seed: u64, - c_puct: f32, -} - -impl Default for Args { - fn default() -> Self { - Self { - checkpoint: None, - arch: "mlp".into(), - hidden: None, - n_games: 100, - n_sim: 50, - seed: 42, - c_puct: 1.5, - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut args = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); } - "--arch" => { i += 1; args.arch = raw[i].clone(); } - "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); } - "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); } - "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); } - "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); } - "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - args -} - -// ── Game loop ───────────────────────────────────────────────────────────────── - -/// Play one complete game. -/// -/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black). -/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0). -fn play_game( - env: &TrictracEnv, - mcts_side: usize, - evaluator: &dyn Evaluator, - mcts_cfg: &MctsConfig, - rng: &mut SmallRng, -) -> [f32; 2] { - let mut state = env.new_game(); - loop { - match env.current_player(&state) { - Player::Terminal => { - return env.returns(&state).expect("Terminal state must have returns"); - } - Player::Chance => env.apply_chance(&mut state, rng), - player => { - let side = player.index().unwrap(); // 0 = P1, 1 = P2 - let action = if side == mcts_side { - let root = run_mcts(env, &state, evaluator, mcts_cfg, rng); - select_action(&root, 0.0, rng) // greedy (temperature = 0) - } else { - let actions = env.legal_actions(&state); - actions[rng.random_range(0..actions.len())] - }; - env.apply(&mut state, action); - } - } - } -} - -// ── Statistics ──────────────────────────────────────────────────────────────── - -#[derive(Default)] -struct Stats { - wins: u32, - draws: u32, - losses: u32, -} - -impl Stats { - fn record(&mut self, mcts_return: f32) { - if mcts_return > 0.0 { self.wins += 1; } - else if mcts_return < 0.0 { self.losses += 1; } - else { self.draws += 1; } - } - - fn total(&self) -> u32 { self.wins + self.draws + self.losses } - - fn win_rate_decisive(&self) -> f64 { - let d = self.wins + self.losses; - if d == 0 { 0.5 } else { self.wins as f64 / d as f64 } - } - - fn print(&self) { - let n = self.total(); - let pct = |k: u32| 100.0 * k as f64 / n as f64; - println!( - " Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)", - self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses), - ); - } -} - -// ── Evaluation ──────────────────────────────────────────────────────────────── - -fn run_evaluation( - evaluator: &dyn Evaluator, - n_games: usize, - mcts_cfg: &MctsConfig, - seed: u64, -) -> (Stats, Stats) { - let env = TrictracEnv; - let total = n_games * 2; - let mut as_p1 = Stats::default(); - let mut as_p2 = Stats::default(); - - for i in 0..total { - // Alternate sides: even games → MctsAgent as P1, odd → as P2. - let mcts_side = i % 2; - let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64)); - let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng); - - let mcts_return = result[mcts_side]; - if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); } - - let done = i + 1; - if done % 10 == 0 || done == total { - eprint!("\r [{done}/{total}] ", ); - } - } - eprintln!(); - (as_p1, as_p2) -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - let device: ::Device = Default::default(); - - // ── Load model ──────────────────────────────────────────────────────── - let evaluator: Box = match args.arch.as_str() { - "resnet" => { - let hidden = args.hidden.unwrap_or(512); - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = match &args.checkpoint { - Some(path) => ResNet::::load(&cfg, path, &device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), - None => ResNet::new(&cfg, &device), - }; - Box::new(BurnEvaluator::>::new(model, device)) - } - "mlp" | _ => { - let hidden = args.hidden.unwrap_or(256); - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - let model = match &args.checkpoint { - Some(path) => MlpNet::::load(&cfg, path, &device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }), - None => MlpNet::new(&cfg, &device), - }; - Box::new(BurnEvaluator::>::new(model, device)) - } - }; - - let mcts_cfg = MctsConfig { - n_simulations: args.n_sim, - c_puct: args.c_puct, - dirichlet_alpha: 0.0, // no exploration noise during evaluation - dirichlet_eps: 0.0, - temperature: 0.0, // greedy action selection - }; - - // ── Header ──────────────────────────────────────────────────────────── - let ckpt_label = args.checkpoint - .as_deref() - .and_then(|p| p.file_name()) - .and_then(|n| n.to_str()) - .unwrap_or("random weights"); - - println!(); - println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent", - args.arch, args.n_sim); - println!("Games per side: {} | Total: {} | Seed: {}", - args.n_games, args.n_games * 2, args.seed); - println!(); - - // ── Run ─────────────────────────────────────────────────────────────── - let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed); - - // ── Results ─────────────────────────────────────────────────────────── - println!("MctsAgent as P1 (White):"); - as_p1.print(); - - println!("MctsAgent as P2 (Black):"); - as_p2.print(); - - let combined_wins = as_p1.wins + as_p2.wins; - let combined_decisive = combined_wins + as_p1.losses + as_p2.losses; - let combined_wr = if combined_decisive == 0 { 0.5 } - else { combined_wins as f64 / combined_decisive as f64 }; - - println!(); - println!("Combined win rate (excluding draws): {:.1}% [{}/{}]", - combined_wr * 100.0, combined_wins, combined_decisive); - println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%", - as_p1.win_rate_decisive() * 100.0, - as_p2.win_rate_decisive() * 100.0); -} diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs deleted file mode 100644 index ab385c2..0000000 --- a/spiel_bot/src/bin/az_train.rs +++ /dev/null @@ -1,314 +0,0 @@ -//! AlphaZero self-play training loop. -//! -//! # Usage -//! -//! ```sh -//! # Start fresh (MLP, default settings) -//! cargo run -p spiel_bot --bin az_train --release -//! -//! # ResNet, 200 iterations, save every 20 -//! cargo run -p spiel_bot --bin az_train --release -- \ -//! --arch resnet --n-iter 200 --save-every 20 --out checkpoints/ -//! -//! # Resume from a checkpoint -//! cargo run -p spiel_bot --bin az_train --release -- \ -//! --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! | `--arch mlp\|resnet` | `mlp` | Network architecture | -//! | `--hidden N` | 256/512 | Hidden layer width | -//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | -//! | `--n-iter N` | `100` | Training iterations | -//! | `--n-games N` | `10` | Self-play games per iteration | -//! | `--n-train N` | `20` | Gradient steps per iteration | -//! | `--n-sim N` | `100` | MCTS simulations per move | -//! | `--batch N` | `64` | Mini-batch size | -//! | `--replay-cap N` | `50000` | Replay buffer capacity | -//! | `--lr F` | `1e-3` | Peak (initial) learning rate | -//! | `--lr-min F` | `1e-4` | Floor learning rate (cosine annealing) | -//! | `--c-puct F` | `1.5` | PUCT exploration constant | -//! | `--dirichlet-alpha F` | `0.1` | Dirichlet noise alpha | -//! | `--dirichlet-eps F` | `0.25` | Dirichlet noise weight | -//! | `--temp-drop N` | `30` | Move after which temperature drops to 0 | -//! | `--save-every N` | `10` | Save checkpoint every N iterations | -//! | `--seed N` | `42` | RNG seed | -//! | `--resume PATH` | (none) | Load weights from checkpoint before training | - -use std::path::{Path, PathBuf}; -use std::time::Instant; - -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, - tensor::backend::Backend, -}; -use rand::{SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - alphazero::{ - BurnEvaluator, ReplayBuffer, TrainSample, cosine_lr, generate_episode, train_step, - }, - env::TrictracEnv, - mcts::MctsConfig, - network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, -}; - -type TrainB = Autodiff>; -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - arch: String, - hidden: Option, - out_dir: PathBuf, - n_iter: usize, - n_games: usize, - n_train: usize, - n_sim: usize, - batch_size: usize, - replay_cap: usize, - lr: f64, - lr_min: f64, - c_puct: f32, - dirichlet_alpha: f32, - dirichlet_eps: f32, - temp_drop: usize, - save_every: usize, - seed: u64, - resume: Option, -} - -impl Default for Args { - fn default() -> Self { - Self { - arch: "mlp".into(), - hidden: None, - out_dir: PathBuf::from("checkpoints"), - n_iter: 100, - n_games: 10, - n_train: 20, - n_sim: 100, - batch_size: 64, - replay_cap: 50_000, - lr: 1e-3, - lr_min: 1e-4, - c_puct: 1.5, - dirichlet_alpha: 0.1, - dirichlet_eps: 0.25, - temp_drop: 30, - save_every: 10, - seed: 42, - resume: None, - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut a = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--arch" => { i += 1; a.arch = raw[i].clone(); } - "--hidden" => { i += 1; a.hidden = Some(raw[i].parse().expect("--hidden: integer")); } - "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } - "--n-iter" => { i += 1; a.n_iter = raw[i].parse().expect("--n-iter: integer"); } - "--n-games" => { i += 1; a.n_games = raw[i].parse().expect("--n-games: integer"); } - "--n-train" => { i += 1; a.n_train = raw[i].parse().expect("--n-train: integer"); } - "--n-sim" => { i += 1; a.n_sim = raw[i].parse().expect("--n-sim: integer"); } - "--batch" => { i += 1; a.batch_size = raw[i].parse().expect("--batch: integer"); } - "--replay-cap" => { i += 1; a.replay_cap = raw[i].parse().expect("--replay-cap: integer"); } - "--lr" => { i += 1; a.lr = raw[i].parse().expect("--lr: float"); } - "--lr-min" => { i += 1; a.lr_min = raw[i].parse().expect("--lr-min: float"); } - "--c-puct" => { i += 1; a.c_puct = raw[i].parse().expect("--c-puct: float"); } - "--dirichlet-alpha" => { i += 1; a.dirichlet_alpha = raw[i].parse().expect("--dirichlet-alpha: float"); } - "--dirichlet-eps" => { i += 1; a.dirichlet_eps = raw[i].parse().expect("--dirichlet-eps: float"); } - "--temp-drop" => { i += 1; a.temp_drop = raw[i].parse().expect("--temp-drop: integer"); } - "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } - "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } - "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - a -} - -// ── Training loop ───────────────────────────────────────────────────────────── - -/// Generic training loop, parameterised over the network type. -/// -/// `save_fn` receives the **training-backend** model and the target path; -/// it is called in the match arm where the concrete network type is known. -fn train_loop( - mut model: N, - save_fn: &dyn Fn(&N, &Path) -> anyhow::Result<()>, - args: &Args, -) -where - N: PolicyValueNet + AutodiffModule + Clone, - >::InnerModule: PolicyValueNet + Send + 'static, -{ - let train_device: ::Device = Default::default(); - let infer_device: ::Device = Default::default(); - - // Type is inferred as OptimizerAdaptor at the call site. - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(args.replay_cap); - let mut rng = SmallRng::seed_from_u64(args.seed); - let env = TrictracEnv; - - // Total gradient steps (used for cosine LR denominator). - let total_train_steps = (args.n_iter * args.n_train).max(1); - let mut global_step = 0usize; - - println!( - "\n{:-<60}\n az_train — {} | {} iters | {} games/iter | {} sims/move\n{:-<60}", - "", args.arch, args.n_iter, args.n_games, args.n_sim, "" - ); - - for iter in 0..args.n_iter { - let t0 = Instant::now(); - - // ── Self-play ──────────────────────────────────────────────────── - // Convert to inference backend (zero autodiff overhead). - let infer_model: >::InnerModule = model.valid(); - let evaluator: BurnEvaluator>::InnerModule> = - BurnEvaluator::new(infer_model, infer_device.clone()); - - let mcts_cfg = MctsConfig { - n_simulations: args.n_sim, - c_puct: args.c_puct, - dirichlet_alpha: args.dirichlet_alpha, - dirichlet_eps: args.dirichlet_eps, - temperature: 1.0, - }; - - let temp_drop = args.temp_drop; - let temperature_fn = |step: usize| -> f32 { - if step < temp_drop { 1.0 } else { 0.0 } - }; - - let mut new_samples = 0usize; - for _ in 0..args.n_games { - let samples = - generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng); - new_samples += samples.len(); - replay.extend(samples); - } - - // ── Training ───────────────────────────────────────────────────── - let mut loss_sum = 0.0f32; - let mut n_steps = 0usize; - - if replay.len() >= args.batch_size { - for _ in 0..args.n_train { - let lr = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); - let batch: Vec = replay - .sample_batch(args.batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - let (m, loss) = - train_step(model, &mut optimizer, &batch, &train_device, lr); - model = m; - loss_sum += loss; - n_steps += 1; - global_step += 1; - } - } - - // ── Logging ────────────────────────────────────────────────────── - let elapsed = t0.elapsed(); - let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; - let lr_now = cosine_lr(args.lr, args.lr_min, global_step, total_train_steps); - - println!( - "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | lr {:.2e} | {:.1}s", - iter + 1, - args.n_iter, - replay.len(), - new_samples, - avg_loss, - lr_now, - elapsed.as_secs_f32(), - ); - - // ── Checkpoint ─────────────────────────────────────────────────── - let is_last = iter + 1 == args.n_iter; - if (iter + 1) % args.save_every == 0 || is_last { - let path = args.out_dir.join(format!("iter_{:04}.mpk", iter + 1)); - match save_fn(&model, &path) { - Ok(()) => println!(" -> saved {}", path.display()), - Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), - } - } - } - - println!("\nTraining complete."); -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - - // Create output directory if it doesn't exist. - if let Err(e) = std::fs::create_dir_all(&args.out_dir) { - eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); - std::process::exit(1); - } - - let train_device: ::Device = Default::default(); - - match args.arch.as_str() { - "resnet" => { - let hidden = args.hidden.unwrap_or(512); - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - - let model = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - ResNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => ResNet::::new(&cfg, &train_device), - }; - - train_loop( - model, - &|m: &ResNet, path: &Path| { - // Save via inference model to avoid autodiff record overhead. - m.valid().save(path) - }, - &args, - ); - } - - "mlp" | _ => { - let hidden = args.hidden.unwrap_or(256); - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; - - let model = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - MlpNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => MlpNet::::new(&cfg, &train_device), - }; - - train_loop( - model, - &|m: &MlpNet, path: &Path| m.valid().save(path), - &args, - ); - } - } -} diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs deleted file mode 100644 index 0ebe978..0000000 --- a/spiel_bot/src/bin/dqn_train.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! DQN self-play training loop. -//! -//! # Usage -//! -//! ```sh -//! # Start fresh with default settings -//! cargo run -p spiel_bot --bin dqn_train --release -//! -//! # Custom hyperparameters -//! cargo run -p spiel_bot --bin dqn_train --release -- \ -//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000 -//! -//! # Resume from a checkpoint -//! cargo run -p spiel_bot --bin dqn_train --release -- \ -//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! | `--hidden N` | 256 | Hidden layer width | -//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | -//! | `--n-iter N` | 100 | Training iterations | -//! | `--n-games N` | 10 | Self-play games per iteration | -//! | `--n-train N` | 20 | Gradient steps per iteration | -//! | `--batch N` | 64 | Mini-batch size | -//! | `--replay-cap N` | 50000 | Replay buffer capacity | -//! | `--lr F` | 1e-3 | Adam learning rate | -//! | `--epsilon-start F` | 1.0 | Initial exploration rate | -//! | `--epsilon-end F` | 0.05 | Final exploration rate | -//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor | -//! | `--gamma F` | 0.99 | Discount factor | -//! | `--target-update N` | 500 | Hard-update target net every N steps | -//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) | -//! | `--save-every N` | 10 | Save checkpoint every N iterations | -//! | `--seed N` | 42 | RNG seed | -//! | `--resume PATH` | (none) | Load weights before training | - -use std::path::{Path, PathBuf}; -use std::time::Instant; - -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, - tensor::backend::Backend, -}; -use rand::{SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - dqn::{ - DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, - generate_dqn_episode, hard_update, linear_epsilon, - }, - env::TrictracEnv, - network::{QNet, QNetConfig}, -}; - -type TrainB = Autodiff>; -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - hidden: usize, - out_dir: PathBuf, - save_every: usize, - seed: u64, - resume: Option, - config: DqnConfig, -} - -impl Default for Args { - fn default() -> Self { - Self { - hidden: 256, - out_dir: PathBuf::from("checkpoints"), - save_every: 10, - seed: 42, - resume: None, - config: DqnConfig::default(), - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut a = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } - "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } - "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } - "--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); } - "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } - "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } - "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } - "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } - "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } - "--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } - "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } - "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } - "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } - "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } - "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } - "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } - "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - a -} - -// ── Training loop ───────────────────────────────────────────────────────────── - -fn train_loop( - mut q_net: QNet, - cfg: &QNetConfig, - save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, - args: &Args, -) { - let train_device: ::Device = Default::default(); - let infer_device: ::Device = Default::default(); - - let mut optimizer = AdamConfig::new().init(); - let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); - let mut rng = SmallRng::seed_from_u64(args.seed); - let env = TrictracEnv; - - let mut target_net: QNet = hard_update::(&q_net); - let mut global_step = 0usize; - let mut epsilon = args.config.epsilon_start; - - println!( - "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", - "", args.config.n_iterations, args.config.n_games_per_iter, - args.config.n_train_steps_per_iter, "" - ); - - for iter in 0..args.config.n_iterations { - let t0 = Instant::now(); - - // ── Self-play ──────────────────────────────────────────────────── - let infer_q: QNet = q_net.valid(); - let mut new_samples = 0usize; - - for _ in 0..args.config.n_games_per_iter { - let samples = generate_dqn_episode( - &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, - ); - new_samples += samples.len(); - replay.extend(samples); - } - - // ── Training ───────────────────────────────────────────────────── - let mut loss_sum = 0.0f32; - let mut n_steps = 0usize; - - if replay.len() >= args.config.batch_size { - for _ in 0..args.config.n_train_steps_per_iter { - let batch: Vec<_> = replay - .sample_batch(args.config.batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - - // Target Q-values computed on the inference backend. - let target_q = compute_target_q( - &target_net, &batch, cfg.action_size, &infer_device, - ); - - let (q, loss) = dqn_train_step( - q_net, &mut optimizer, &batch, &target_q, - &train_device, args.config.learning_rate, args.config.gamma, - ); - q_net = q; - loss_sum += loss; - n_steps += 1; - global_step += 1; - - // Hard-update target net every target_update_freq steps. - if global_step % args.config.target_update_freq == 0 { - target_net = hard_update::(&q_net); - } - - // Linear epsilon decay. - epsilon = linear_epsilon( - args.config.epsilon_start, - args.config.epsilon_end, - global_step, - args.config.epsilon_decay_steps, - ); - } - } - - // ── Logging ────────────────────────────────────────────────────── - let elapsed = t0.elapsed(); - let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; - - println!( - "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", - iter + 1, - args.config.n_iterations, - replay.len(), - new_samples, - avg_loss, - epsilon, - elapsed.as_secs_f32(), - ); - - // ── Checkpoint ─────────────────────────────────────────────────── - let is_last = iter + 1 == args.config.n_iterations; - if (iter + 1) % args.save_every == 0 || is_last { - let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); - match save_fn(&q_net, &path) { - Ok(()) => println!(" -> saved {}", path.display()), - Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), - } - } - } - - println!("\nDQN training complete."); -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - - if let Err(e) = std::fs::create_dir_all(&args.out_dir) { - eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); - std::process::exit(1); - } - - let train_device: ::Device = Default::default(); - let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; - - let q_net = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - QNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => QNet::::new(&cfg, &train_device), - }; - - train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); -} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs deleted file mode 100644 index aca1343..0000000 --- a/spiel_bot/src/dqn/episode.rs +++ /dev/null @@ -1,247 +0,0 @@ -//! DQN self-play episode generation. -//! -//! Both players share the same Q-network (the [`TrictracEnv`] handles board -//! mirroring so that each player always acts from "White's perspective"). -//! Transitions for both players are stored in the returned sample list. -//! -//! # Reward -//! -//! After each full decision (action applied and the state has advanced through -//! any intervening chance nodes back to the same player's next turn), the -//! reward is: -//! -//! ```text -//! r = (my_total_score_now − my_total_score_then) -//! − (opp_total_score_now − opp_total_score_then) -//! ``` -//! -//! where `total_score = holes × 12 + points`. -//! -//! # Transition structure -//! -//! We use a "pending transition" per player. When a player acts again, we -//! *complete* the previous pending transition by filling in `next_obs`, -//! `next_legal`, and computing `reward`. Terminal transitions are completed -//! when the game ends. - -use burn::tensor::{backend::Backend, Tensor, TensorData}; -use rand::Rng; - -use crate::env::{GameEnv, TrictracEnv}; -use crate::network::QValueNet; -use super::DqnSample; - -// ── Internals ───────────────────────────────────────────────────────────────── - -struct PendingTransition { - obs: Vec, - action: usize, - /// Score snapshot `[p1_total, p2_total]` at the moment of the action. - score_before: [i32; 2], -} - -/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. -fn epsilon_greedy>( - q_net: &Q, - obs: &[f32], - legal: &[usize], - epsilon: f32, - rng: &mut impl Rng, - device: &B::Device, -) -> usize { - debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); - if rng.random::() < epsilon { - legal[rng.random_range(0..legal.len())] - } else { - let obs_tensor = Tensor::::from_data( - TensorData::new(obs.to_vec(), [1, obs.len()]), - device, - ); - let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); - legal - .iter() - .copied() - .max_by(|&a, &b| { - q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) - }) - .unwrap() - } -} - -/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. -fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { - let opp_idx = 1 - player_idx; - ((score_after[player_idx] - score_before[player_idx]) - - (score_after[opp_idx] - score_before[opp_idx])) as f32 -} - -// ── Public API ──────────────────────────────────────────────────────────────── - -/// Play one full game and return all transitions for both players. -/// -/// - `q_net` uses the **inference backend** (no autodiff wrapper). -/// - `epsilon` in `[0, 1]`: probability of taking a random action. -/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). -pub fn generate_dqn_episode>( - env: &TrictracEnv, - q_net: &Q, - epsilon: f32, - rng: &mut impl Rng, - device: &B::Device, - reward_scale: f32, -) -> Vec { - let obs_size = env.obs_size(); - let mut state = env.new_game(); - let mut pending: [Option; 2] = [None, None]; - let mut samples: Vec = Vec::new(); - - loop { - // ── Advance past chance nodes ────────────────────────────────────── - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, rng); - } - - let score_now = TrictracEnv::score_snapshot(&state); - - if env.current_player(&state).is_terminal() { - // Complete all pending transitions as terminal. - for player_idx in 0..2 { - if let Some(prev) = pending[player_idx].take() { - let reward = - compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; - samples.push(DqnSample { - obs: prev.obs, - action: prev.action, - reward, - next_obs: vec![0.0; obs_size], - next_legal: vec![], - done: true, - }); - } - } - break; - } - - let player_idx = env.current_player(&state).index().unwrap(); - let legal = env.legal_actions(&state); - let obs = env.observation(&state, player_idx); - - // ── Complete the previous transition for this player ─────────────── - if let Some(prev) = pending[player_idx].take() { - let reward = - compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; - samples.push(DqnSample { - obs: prev.obs, - action: prev.action, - reward, - next_obs: obs.clone(), - next_legal: legal.clone(), - done: false, - }); - } - - // ── Pick and apply action ────────────────────────────────────────── - let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); - env.apply(&mut state, action); - - // ── Record new pending transition ────────────────────────────────── - pending[player_idx] = Some(PendingTransition { - obs, - action, - score_before: score_now, - }); - } - - samples -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - use rand::{SeedableRng, rngs::SmallRng}; - - use crate::network::{QNet, QNetConfig}; - - type B = NdArray; - - fn device() -> ::Device { Default::default() } - fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } - - fn tiny_q() -> QNet { - QNet::new(&QNetConfig::default(), &device()) - } - - #[test] - fn episode_terminates_and_produces_samples() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - assert!(!samples.is_empty(), "episode must produce at least one sample"); - } - - #[test] - fn episode_obs_size_correct() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - for s in &samples { - assert_eq!(s.obs.len(), 217, "obs size mismatch"); - if s.done { - assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); - assert!(s.next_legal.is_empty()); - } else { - assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); - assert!(!s.next_legal.is_empty()); - } - } - } - - #[test] - fn episode_actions_within_action_space() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - for s in &samples { - assert!(s.action < 514, "action {} out of bounds", s.action); - } - } - - #[test] - fn greedy_episode_also_terminates() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); - assert!(!samples.is_empty()); - } - - #[test] - fn at_least_one_done_sample() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - let n_done = samples.iter().filter(|s| s.done).count(); - // Two players, so 1 or 2 terminal transitions. - assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); - } - - #[test] - fn compute_reward_correct() { - // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. - let before = [2 * 12 + 10, 0]; - let after = [3 * 12 + 2, 0]; - let r = compute_reward(0, &before, &after); - assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); - } - - #[test] - fn compute_reward_with_opponent_scoring() { - // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. - let before = [0, 0]; - let after = [2, 3]; - let r = compute_reward(0, &before, &after); - assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); - } -} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs deleted file mode 100644 index 8c34fc1..0000000 --- a/spiel_bot/src/dqn/mod.rs +++ /dev/null @@ -1,232 +0,0 @@ -//! DQN: self-play data generation, replay buffer, and training step. -//! -//! # Algorithm -//! -//! Deep Q-Network with: -//! - **ε-greedy** exploration (linearly decayed). -//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where -//! `score = holes × 12 + points`. -//! - **Experience replay** with a fixed-capacity circular buffer. -//! - **Target network**: hard-copied from the online Q-net every -//! `target_update_freq` gradient steps for training stability. -//! -//! # Modules -//! -//! | Module | Contents | -//! |--------|----------| -//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | -//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | - -pub mod episode; -pub mod trainer; - -pub use episode::generate_dqn_episode; -pub use trainer::{compute_target_q, dqn_train_step, hard_update}; - -use std::collections::VecDeque; -use rand::Rng; - -// ── DqnSample ───────────────────────────────────────────────────────────────── - -/// One transition `(s, a, r, s', done)` collected during self-play. -#[derive(Clone, Debug)] -pub struct DqnSample { - /// Observation from the acting player's perspective (`obs_size` floats). - pub obs: Vec, - /// Action index taken. - pub action: usize, - /// Per-turn reward: `my_score_delta − opponent_score_delta`. - pub reward: f32, - /// Next observation from the same player's perspective. - /// All-zeros when `done = true` (ignored by the TD target). - pub next_obs: Vec, - /// Legal actions at `next_obs`. Empty when `done = true`. - pub next_legal: Vec, - /// `true` when `next_obs` is a terminal state. - pub done: bool, -} - -// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── - -/// Fixed-capacity circular replay buffer for [`DqnSample`]s. -/// -/// When full, the oldest sample is evicted on push. -/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. -pub struct DqnReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl DqnReplayBuffer { - pub fn new(capacity: usize) -> Self { - Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } - } - - pub fn push(&mut self, sample: DqnSample) { - if self.data.len() == self.capacity { - self.data.pop_front(); - } - self.data.push_back(sample); - } - - pub fn extend(&mut self, samples: impl IntoIterator) { - for s in samples { self.push(s); } - } - - pub fn len(&self) -> usize { self.data.len() } - pub fn is_empty(&self) -> bool { self.data.is_empty() } - - /// Sample up to `n` distinct samples without replacement. - pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { - let len = self.data.len(); - let n = n.min(len); - let mut indices: Vec = (0..len).collect(); - for i in 0..n { - let j = rng.random_range(i..len); - indices.swap(i, j); - } - indices[..n].iter().map(|&i| &self.data[i]).collect() - } -} - -// ── DqnConfig ───────────────────────────────────────────────────────────────── - -/// Top-level DQN hyperparameters for the training loop. -#[derive(Debug, Clone)] -pub struct DqnConfig { - /// Initial exploration rate (1.0 = fully random). - pub epsilon_start: f32, - /// Final exploration rate after decay. - pub epsilon_end: f32, - /// Number of gradient steps over which ε decays linearly from start to end. - /// - /// Should be calibrated to the total number of gradient steps - /// (`n_iterations × n_train_steps_per_iter`). A value larger than that - /// means exploration never reaches `epsilon_end` during the run. - pub epsilon_decay_steps: usize, - /// Discount factor γ for the TD target. Typical: 0.99. - pub gamma: f32, - /// Hard-copy Q → target every this many gradient steps. - /// - /// Should be much smaller than the total number of gradient steps - /// (`n_iterations × n_train_steps_per_iter`). - pub target_update_freq: usize, - /// Adam learning rate. - pub learning_rate: f64, - /// Mini-batch size for each gradient step. - pub batch_size: usize, - /// Maximum number of samples in the replay buffer. - pub replay_capacity: usize, - /// Number of outer iterations (self-play + train). - pub n_iterations: usize, - /// Self-play games per iteration. - pub n_games_per_iter: usize, - /// Gradient steps per iteration. - pub n_train_steps_per_iter: usize, - /// Reward normalisation divisor. - /// - /// Per-turn rewards (score delta) are divided by this constant before being - /// stored. Without normalisation, rewards can reach ±24 (jan with - /// bredouille = 12 pts × 2), driving Q-values into the hundreds and - /// causing MSE loss to grow unboundedly. - /// - /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping - /// Q-value magnitudes in a stable range. Set to `1.0` to disable. - pub reward_scale: f32, -} - -impl Default for DqnConfig { - fn default() -> Self { - // Total gradient steps with these defaults = 500 × 20 = 10_000, - // so epsilon decays fully and the target is updated 100 times. - Self { - epsilon_start: 1.0, - epsilon_end: 0.05, - epsilon_decay_steps: 10_000, - gamma: 0.99, - target_update_freq: 100, - learning_rate: 1e-3, - batch_size: 64, - replay_capacity: 50_000, - n_iterations: 500, - n_games_per_iter: 10, - n_train_steps_per_iter: 20, - reward_scale: 12.0, - } - } -} - -/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. -pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { - if decay_steps == 0 || step >= decay_steps { - return end; - } - start + (end - start) * (step as f32 / decay_steps as f32) -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - - fn dummy(reward: f32) -> DqnSample { - DqnSample { - obs: vec![0.0], - action: 0, - reward, - next_obs: vec![0.0], - next_legal: vec![0], - done: false, - } - } - - #[test] - fn push_and_len() { - let mut buf = DqnReplayBuffer::new(10); - assert!(buf.is_empty()); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - assert_eq!(buf.len(), 2); - } - - #[test] - fn evicts_oldest_at_capacity() { - let mut buf = DqnReplayBuffer::new(3); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - buf.push(dummy(3.0)); - buf.push(dummy(4.0)); - assert_eq!(buf.len(), 3); - assert_eq!(buf.data[0].reward, 2.0); - } - - #[test] - fn sample_batch_size() { - let mut buf = DqnReplayBuffer::new(20); - for i in 0..10 { buf.push(dummy(i as f32)); } - let mut rng = SmallRng::seed_from_u64(0); - assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); - } - - #[test] - fn linear_epsilon_start() { - assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); - } - - #[test] - fn linear_epsilon_end() { - assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); - } - - #[test] - fn linear_epsilon_monotone() { - let mut prev = f32::INFINITY; - for step in 0..=100 { - let e = linear_epsilon(1.0, 0.05, step, 100); - assert!(e <= prev + 1e-6); - prev = e; - } - } -} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs deleted file mode 100644 index b8b0a02..0000000 --- a/spiel_bot/src/dqn/trainer.rs +++ /dev/null @@ -1,278 +0,0 @@ -//! DQN gradient step and target-network management. -//! -//! # TD target -//! -//! ```text -//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done -//! y_i = r_i if done -//! ``` -//! -//! # Loss -//! -//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net) -//! and `y_i` (computed from the frozen target net). -//! -//! # Target network -//! -//! [`hard_update`] copies the online Q-net weights into the target net by -//! stripping the autodiff wrapper via [`AutodiffModule::valid`]. - -use burn::{ - module::AutodiffModule, - optim::{GradientsParams, Optimizer}, - prelude::ElementConversion, - tensor::{ - Int, Tensor, TensorData, - backend::{AutodiffBackend, Backend}, - }, -}; - -use crate::network::QValueNet; -use super::DqnSample; - -// ── Target Q computation ───────────────────────────────────────────────────── - -/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample. -/// -/// Returns a `Vec` of length `batch.len()`. Done samples get `0.0` -/// (their bootstrap term is dropped by the TD target anyway). -/// -/// The target network runs on the **inference backend** (`InferB`) with no -/// gradient tape, so this function is backend-agnostic (`B: Backend`). -pub fn compute_target_q>( - target_net: &Q, - batch: &[DqnSample], - action_size: usize, - device: &B::Device, -) -> Vec { - let batch_size = batch.len(); - - // Collect indices of non-done samples (done samples have no next state). - let non_done: Vec = batch - .iter() - .enumerate() - .filter(|(_, s)| !s.done) - .map(|(i, _)| i) - .collect(); - - if non_done.is_empty() { - return vec![0.0; batch_size]; - } - - let obs_size = batch[0].next_obs.len(); - let nd = non_done.len(); - - // Stack next observations for non-done samples → [nd, obs_size]. - let obs_flat: Vec = non_done - .iter() - .flat_map(|&i| batch[i].next_obs.iter().copied()) - .collect(); - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [nd, obs_size]), - device, - ); - - // Forward target net → [nd, action_size], then to Vec. - let q_flat: Vec = target_net.forward(obs_tensor).into_data().to_vec().unwrap(); - - // For each non-done sample, pick max Q over legal next actions. - let mut result = vec![0.0f32; batch_size]; - for (k, &i) in non_done.iter().enumerate() { - let legal = &batch[i].next_legal; - let offset = k * action_size; - let max_q = legal - .iter() - .map(|&a| q_flat[offset + a]) - .fold(f32::NEG_INFINITY, f32::max); - // If legal is empty (shouldn't happen for non-done, but be safe): - result[i] = if max_q.is_finite() { max_q } else { 0.0 }; - } - result -} - -// ── Training step ───────────────────────────────────────────────────────────── - -/// Run one gradient step on `q_net` using `batch`. -/// -/// `target_max_q` must be pre-computed via [`compute_target_q`] using the -/// frozen target network and passed in here so that this function only -/// needs the **autodiff backend**. -/// -/// Returns the updated network and the scalar MSE loss. -pub fn dqn_train_step( - q_net: Q, - optimizer: &mut O, - batch: &[DqnSample], - target_max_q: &[f32], - device: &B::Device, - lr: f64, - gamma: f32, -) -> (Q, f32) -where - B: AutodiffBackend, - Q: QValueNet + AutodiffModule, - O: Optimizer, -{ - assert!(!batch.is_empty(), "dqn_train_step: empty batch"); - assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch"); - - let batch_size = batch.len(); - let obs_size = batch[0].obs.len(); - - // ── Build observation tensor [B, obs_size] ──────────────────────────── - let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); - let obs_tensor = Tensor::::from_data( - TensorData::new(obs_flat, [batch_size, obs_size]), - device, - ); - - // ── Forward Q-net → [B, action_size] ───────────────────────────────── - let q_all = q_net.forward(obs_tensor); - - // ── Gather Q(s, a) for the taken action → [B] ──────────────────────── - let actions: Vec = batch.iter().map(|s| s.action as i32).collect(); - let action_tensor: Tensor = Tensor::::from_data( - TensorData::new(actions, [batch_size]), - device, - ) - .reshape([batch_size, 1]); // [B] → [B, 1] - let q_pred: Tensor = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B] - - // ── TD targets: r + γ · max_next_q · (1 − done) ────────────────────── - let targets: Vec = batch - .iter() - .zip(target_max_q.iter()) - .map(|(s, &max_q)| { - if s.done { s.reward } else { s.reward + gamma * max_q } - }) - .collect(); - let target_tensor = Tensor::::from_data( - TensorData::new(targets, [batch_size]), - device, - ); - - // ── MSE loss ────────────────────────────────────────────────────────── - let diff = q_pred - target_tensor.detach(); - let loss = (diff.clone() * diff).mean(); - let loss_scalar: f32 = loss.clone().into_scalar().elem(); - - // ── Backward + optimizer step ───────────────────────────────────────── - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &q_net); - let q_net = optimizer.step(lr, q_net, grads); - - (q_net, loss_scalar) -} - -// ── Target network update ───────────────────────────────────────────────────── - -/// Hard-copy the online Q-net weights to a new target network. -/// -/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an -/// inference-backend module with identical weights. -pub fn hard_update>(q_net: &Q) -> Q::InnerModule { - q_net.valid() -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::{ - backend::{Autodiff, NdArray}, - optim::AdamConfig, - }; - use crate::network::{QNet, QNetConfig}; - - type InferB = NdArray; - type TrainB = Autodiff>; - - fn infer_device() -> ::Device { Default::default() } - fn train_device() -> ::Device { Default::default() } - - fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { - (0..n) - .map(|i| DqnSample { - obs: vec![0.5f32; obs_size], - action: i % action_size, - reward: if i % 2 == 0 { 1.0 } else { -1.0 }, - next_obs: vec![0.5f32; obs_size], - next_legal: vec![0, 1], - done: i == n - 1, - }) - .collect() - } - - #[test] - fn compute_target_q_length() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let target = QNet::::new(&cfg, &infer_device()); - let batch = dummy_batch(8, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - assert_eq!(tq.len(), 8); - } - - #[test] - fn compute_target_q_done_is_zero() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let target = QNet::::new(&cfg, &infer_device()); - // Single done sample. - let batch = vec![DqnSample { - obs: vec![0.0; 4], - action: 0, - reward: 5.0, - next_obs: vec![0.0; 4], - next_legal: vec![], - done: true, - }]; - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - assert_eq!(tq.len(), 1); - assert_eq!(tq[0], 0.0); - } - - #[test] - fn train_step_returns_finite_loss() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; - let q_net = QNet::::new(&cfg, &train_device()); - let target = QNet::::new(&cfg, &infer_device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(8, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - } - - #[test] - fn train_step_loss_decreases() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let mut q_net = QNet::::new(&cfg, &train_device()); - let target = QNet::::new(&cfg, &infer_device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(16, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - - let mut prev_loss = f32::INFINITY; - for _ in 0..10 { - let (q, loss) = dqn_train_step( - q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, - ); - q_net = q; - assert!(loss.is_finite()); - prev_loss = loss; - } - assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); - } - - #[test] - fn hard_update_copies_weights() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let q_net = QNet::::new(&cfg, &train_device()); - let target = hard_update::(&q_net); - - let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); - let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); - // After hard_update the target produces finite outputs. - assert!(q_out.iter().all(|v| v.is_finite())); - } -} diff --git a/spiel_bot/src/env/mod.rs b/spiel_bot/src/env/mod.rs deleted file mode 100644 index 42b4ae0..0000000 --- a/spiel_bot/src/env/mod.rs +++ /dev/null @@ -1,121 +0,0 @@ -//! Game environment abstraction — the minimal "Rust OpenSpiel". -//! -//! A `GameEnv` describes the rules of a two-player, zero-sum game that may -//! contain stochastic (chance) nodes. Algorithms such as AlphaZero, DQN, -//! and PPO interact with a game exclusively through this trait. -//! -//! # Node taxonomy -//! -//! Every game position belongs to one of four categories, returned by -//! [`GameEnv::current_player`]: -//! -//! | [`Player`] | Meaning | -//! |-----------|---------| -//! | `P1` | Player 1 (index 0) must choose an action | -//! | `P2` | Player 2 (index 1) must choose an action | -//! | `Chance` | A stochastic event must be sampled (dice roll, card draw…) | -//! | `Terminal` | The game is over; [`GameEnv::returns`] is meaningful | -//! -//! # Perspective convention -//! -//! [`GameEnv::observation`] always returns the board from *the requested -//! player's* point of view. Callers pass `pov = 0` for Player 1 and -//! `pov = 1` for Player 2. The implementation is responsible for any -//! mirroring required (e.g. Trictrac always reasons from White's side). - -pub mod trictrac; -pub use trictrac::TrictracEnv; - -/// Who controls the current game node. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Player { - /// Player 1 (index 0) is to move. - P1, - /// Player 2 (index 1) is to move. - P2, - /// A stochastic event (dice roll, etc.) must be resolved. - Chance, - /// The game is over. - Terminal, -} - -impl Player { - /// Returns the player index (0 or 1) if this is a decision node, - /// or `None` for `Chance` / `Terminal`. - pub fn index(self) -> Option { - match self { - Player::P1 => Some(0), - Player::P2 => Some(1), - _ => None, - } - } - - pub fn is_decision(self) -> bool { - matches!(self, Player::P1 | Player::P2) - } - - pub fn is_chance(self) -> bool { - self == Player::Chance - } - - pub fn is_terminal(self) -> bool { - self == Player::Terminal - } -} - -/// Trait that completely describes a two-player zero-sum game. -/// -/// Implementors must be cheaply cloneable (the type is used as a stateless -/// factory; the mutable game state lives in `Self::State`). -pub trait GameEnv: Clone + Send + Sync + 'static { - /// The mutable game state. Must be `Clone` so MCTS can copy - /// game trees without touching the environment. - type State: Clone + Send + Sync; - - // ── State creation ──────────────────────────────────────────────────── - - /// Create a fresh game state at the initial position. - fn new_game(&self) -> Self::State; - - // ── Node queries ────────────────────────────────────────────────────── - - /// Classify the current node. - fn current_player(&self, s: &Self::State) -> Player; - - /// Legal action indices at a decision node (`current_player` is `P1`/`P2`). - /// - /// The returned indices are in `[0, action_space())`. - /// The result is unspecified (may panic or return empty) when called at a - /// `Chance` or `Terminal` node. - fn legal_actions(&self, s: &Self::State) -> Vec; - - // ── State mutation ──────────────────────────────────────────────────── - - /// Apply a player action. `action` must be a value returned by - /// [`legal_actions`] for the current state. - fn apply(&self, s: &mut Self::State, action: usize); - - /// Sample and apply a stochastic outcome. Must only be called when - /// `current_player(s) == Player::Chance`. - fn apply_chance(&self, s: &mut Self::State, rng: &mut R); - - // ── Observation ─────────────────────────────────────────────────────── - - /// Observation tensor from player `pov`'s perspective (0 = P1, 1 = P2). - /// The returned slice has exactly [`obs_size()`] elements, all in `[0, 1]`. - fn observation(&self, s: &Self::State, pov: usize) -> Vec; - - /// Number of floats returned by [`observation`]. - fn obs_size(&self) -> usize; - - /// Total number of distinct action indices (the policy head output size). - fn action_space(&self) -> usize; - - // ── Terminal values ─────────────────────────────────────────────────── - - /// Game outcome for each player, or `None` if the game is not over. - /// - /// Values are in `[-1, 1]`: `+1.0` = win, `-1.0` = loss, `0.0` = draw. - /// Index 0 = Player 1, index 1 = Player 2. - fn returns(&self, s: &Self::State) -> Option<[f32; 2]>; -} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs deleted file mode 100644 index 8dc3676..0000000 --- a/spiel_bot/src/env/trictrac.rs +++ /dev/null @@ -1,547 +0,0 @@ -//! [`GameEnv`] implementation for Trictrac. -//! -//! # Game flow (schools_enabled = false) -//! -//! With scoring schools disabled (the standard training configuration), -//! `MarkPoints` and `MarkAdvPoints` stages are never reached — the engine -//! applies them automatically inside `RollResult` and `Move`. The only -//! four stages that actually occur are: -//! -//! | `TurnStage` | [`Player`] kind | Handled by | -//! |-------------|-----------------|------------| -//! | `RollDice` | `Chance` | [`apply_chance`] | -//! | `RollWaiting` | `Chance` | [`apply_chance`] | -//! | `HoldOrGoChoice` | `P1`/`P2` | [`apply`] | -//! | `Move` | `P1`/`P2` | [`apply`] | -//! -//! # Perspective -//! -//! The Trictrac engine always reasons from White's perspective. Player 1 is -//! White; Player 2 is Black. When Player 2 is active, the board is mirrored -//! before computing legal actions / the observation tensor, and the resulting -//! event is mirrored back before being applied to the real state. This -//! mirrors the pattern used in `cxxengine.rs` and `random_game.rs`. - -use trictrac_store::{ - training_common::{get_valid_action_indices, TrictracAction, ACTION_SPACE_SIZE}, - Dice, GameEvent, GameState, Stage, TurnStage, -}; - -use super::{GameEnv, Player}; - -/// Stateless factory that produces Trictrac [`GameState`] environments. -/// -/// Schools (`schools_enabled`) are always disabled — scoring is automatic. -#[derive(Clone, Debug, Default)] -pub struct TrictracEnv; - -impl GameEnv for TrictracEnv { - type State = GameState; - - // ── State creation ──────────────────────────────────────────────────── - - fn new_game(&self) -> GameState { - GameState::new_with_players("P1", "P2") - } - - // ── Node queries ────────────────────────────────────────────────────── - - fn current_player(&self, s: &GameState) -> Player { - if s.stage == Stage::Ended { - return Player::Terminal; - } - match s.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => Player::Chance, - _ => { - if s.active_player_id == 1 { - Player::P1 - } else { - Player::P2 - } - } - } - } - - /// Returns the legal action indices for the active player. - /// - /// The board is automatically mirrored for Player 2 so that the engine - /// always reasons from White's perspective. The returned indices are - /// identical in meaning for both players (checker ordinals are - /// perspective-relative). - /// - /// # Panics - /// - /// Panics in debug builds if called at a `Chance` or `Terminal` node. - fn legal_actions(&self, s: &GameState) -> Vec { - debug_assert!( - self.current_player(s).is_decision(), - "legal_actions called at a non-decision node (turn_stage={:?})", - s.turn_stage - ); - let indices = if s.active_player_id == 2 { - get_valid_action_indices(&s.mirror()) - } else { - get_valid_action_indices(s) - }; - indices.unwrap_or_default() - } - - // ── State mutation ──────────────────────────────────────────────────── - - /// Apply a player action index to the game state. - /// - /// For Player 2, the action is decoded against the mirrored board and - /// the resulting event is un-mirrored before being applied. - /// - /// # Panics - /// - /// Panics in debug builds if `action` cannot be decoded or does not - /// produce a valid event for the current state. - fn apply(&self, s: &mut GameState, action: usize) { - let needs_mirror = s.active_player_id == 2; - - let event = if needs_mirror { - let view = s.mirror(); - TrictracAction::from_action_index(action) - .and_then(|a| a.to_event(&view)) - .map(|e| e.get_mirror(false)) - } else { - TrictracAction::from_action_index(action).and_then(|a| a.to_event(s)) - }; - - match event { - Some(e) => { - s.consume(&e).expect("apply: consume failed for valid action"); - } - None => { - panic!("apply: action index {action} produced no event in state {s}"); - } - } - } - - /// Sample dice and advance through a chance node. - /// - /// Handles both `RollDice` (triggers the roll mechanism, then samples - /// dice) and `RollWaiting` (only samples dice) in a single call so that - /// callers never need to distinguish the two. - /// - /// # Panics - /// - /// Panics in debug builds if called at a non-Chance node. - fn apply_chance(&self, s: &mut GameState, rng: &mut R) { - debug_assert!( - self.current_player(s).is_chance(), - "apply_chance called at a non-Chance node (turn_stage={:?})", - s.turn_stage - ); - - // Step 1: RollDice → RollWaiting (player initiates the roll). - if s.turn_stage == TurnStage::RollDice { - s.consume(&GameEvent::Roll { - player_id: s.active_player_id, - }) - .expect("apply_chance: Roll event failed"); - } - - // Step 2: RollWaiting → Move / HoldOrGoChoice / Ended. - // With schools_enabled=false, point marking is automatic inside consume(). - let dice = Dice { - values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)), - }; - s.consume(&GameEvent::RollResult { - player_id: s.active_player_id, - dice, - }) - .expect("apply_chance: RollResult event failed"); - } - - // ── Observation ─────────────────────────────────────────────────────── - - fn observation(&self, s: &GameState, pov: usize) -> Vec { - if pov == 0 { - s.to_tensor() - } else { - s.mirror().to_tensor() - } - } - - fn obs_size(&self) -> usize { - 217 - } - - fn action_space(&self) -> usize { - ACTION_SPACE_SIZE - } - - // ── Terminal values ─────────────────────────────────────────────────── - - /// Returns `Some([r1, r2])` when the game is over, `None` otherwise. - /// - /// The winner (higher cumulative score) receives `+1.0`; the loser - /// receives `-1.0`; an exact tie gives `0.0` each. A cumulative score - /// is `holes × 12 + points`. - fn returns(&self, s: &GameState) -> Option<[f32; 2]> { - if s.stage != Stage::Ended { - return None; - } - let score = |id: u64| -> i32 { - s.players - .get(&id) - .map(|p| p.holes as i32 * 12 + p.points as i32) - .unwrap_or(0) - }; - let s1 = score(1); - let s2 = score(2); - Some(match s1.cmp(&s2) { - std::cmp::Ordering::Greater => [1.0, -1.0], - std::cmp::Ordering::Less => [-1.0, 1.0], - std::cmp::Ordering::Equal => [0.0, 0.0], - }) - } -} - -// ── DQN helpers ─────────────────────────────────────────────────────────────── - -impl TrictracEnv { - /// Score snapshot for DQN reward computation. - /// - /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. - /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). - pub fn score_snapshot(s: &GameState) -> [i32; 2] { - [s.total_score(1), s.total_score(2)] - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{rngs::SmallRng, Rng, SeedableRng}; - - fn env() -> TrictracEnv { - TrictracEnv - } - - fn seeded_rng(seed: u64) -> SmallRng { - SmallRng::seed_from_u64(seed) - } - - // ── Initial state ───────────────────────────────────────────────────── - - #[test] - fn new_game_is_chance_node() { - let e = env(); - let s = e.new_game(); - // A fresh game starts at RollDice — a Chance node. - assert_eq!(e.current_player(&s), Player::Chance); - assert!(e.returns(&s).is_none()); - } - - #[test] - fn new_game_is_not_terminal() { - let e = env(); - let s = e.new_game(); - assert_ne!(e.current_player(&s), Player::Terminal); - assert!(e.returns(&s).is_none()); - } - - // ── Chance nodes ────────────────────────────────────────────────────── - - #[test] - fn apply_chance_reaches_decision_node() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(1); - - // A single chance step must yield a decision node (or end the game, - // which only happens after 12 holes — impossible on the first roll). - e.apply_chance(&mut s, &mut rng); - let p = e.current_player(&s); - assert!( - p.is_decision(), - "expected decision node after first roll, got {p:?}" - ); - } - - #[test] - fn apply_chance_from_rollwaiting() { - // Check that apply_chance works when called mid-way (at RollWaiting). - let e = env(); - let mut s = e.new_game(); - assert_eq!(s.turn_stage, TurnStage::RollDice); - - // Manually advance to RollWaiting. - s.consume(&GameEvent::Roll { player_id: s.active_player_id }) - .unwrap(); - assert_eq!(s.turn_stage, TurnStage::RollWaiting); - - let mut rng = seeded_rng(2); - e.apply_chance(&mut s, &mut rng); - - let p = e.current_player(&s); - assert!(p.is_decision() || p.is_terminal()); - } - - // ── Legal actions ───────────────────────────────────────────────────── - - #[test] - fn legal_actions_nonempty_after_roll() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(3); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - let actions = e.legal_actions(&s); - assert!( - !actions.is_empty(), - "legal_actions must be non-empty at a decision node" - ); - } - - #[test] - fn legal_actions_within_action_space() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(4); - - e.apply_chance(&mut s, &mut rng); - for &a in e.legal_actions(&s).iter() { - assert!( - a < e.action_space(), - "action {a} out of bounds (action_space={})", - e.action_space() - ); - } - } - - // ── Observations ────────────────────────────────────────────────────── - - #[test] - fn observation_has_correct_size() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(5); - e.apply_chance(&mut s, &mut rng); - - assert_eq!(e.observation(&s, 0).len(), e.obs_size()); - assert_eq!(e.observation(&s, 1).len(), e.obs_size()); - } - - #[test] - fn observation_values_in_unit_interval() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(6); - e.apply_chance(&mut s, &mut rng); - - for (pov, obs) in [(0, e.observation(&s, 0)), (1, e.observation(&s, 1))] { - for (i, &v) in obs.iter().enumerate() { - assert!( - v >= 0.0 && v <= 1.0, - "pov={pov}: obs[{i}] = {v} is outside [0,1]" - ); - } - } - } - - #[test] - fn p1_and_p2_observations_differ() { - // The board is mirrored for P2, so the two observations should differ - // whenever there are checkers in non-symmetric positions (always true - // in a real game after a few moves). - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(7); - - // Advance far enough that the board is non-trivial. - for _ in 0..6 { - while e.current_player(&s).is_chance() { - e.apply_chance(&mut s, &mut rng); - } - if e.current_player(&s).is_terminal() { - break; - } - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - - if !e.current_player(&s).is_terminal() { - let obs0 = e.observation(&s, 0); - let obs1 = e.observation(&s, 1); - assert_ne!(obs0, obs1, "P1 and P2 observations should differ on a non-symmetric board"); - } - } - - // ── Applying actions ────────────────────────────────────────────────── - - #[test] - fn apply_changes_state() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(8); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - let before = s.clone(); - let action = e.legal_actions(&s)[0]; - e.apply(&mut s, action); - - assert_ne!( - before.turn_stage, s.turn_stage, - "state must change after apply" - ); - } - - #[test] - fn apply_all_legal_actions_do_not_panic() { - // Verify that every action returned by legal_actions can be applied - // without panicking (on several independent copies of the same state). - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(9); - - e.apply_chance(&mut s, &mut rng); - assert!(e.current_player(&s).is_decision()); - - for action in e.legal_actions(&s) { - let mut copy = s.clone(); - e.apply(&mut copy, action); // must not panic - } - } - - // ── Full game ───────────────────────────────────────────────────────── - - /// Run a complete game with random actions through the `GameEnv` trait - /// and verify that: - /// - The game terminates. - /// - `returns()` is `Some` at the end. - /// - The outcome is valid: scores sum to 0 (zero-sum) or each player's - /// score is ±1 / 0. - /// - No step panics. - #[test] - fn full_random_game_terminates() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(42); - let max_steps = 50_000; - - for step in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - assert!(!actions.is_empty(), "step {step}: empty legal actions at decision node"); - let idx = rng.random_range(0..actions.len()); - e.apply(&mut s, actions[idx]); - } - } - assert!(step < max_steps - 1, "game did not terminate within {max_steps} steps"); - } - - let result = e.returns(&s); - assert!(result.is_some(), "returns() must be Some at Terminal"); - - let [r1, r2] = result.unwrap(); - let sum = r1 + r2; - assert!( - (sum.abs() < 1e-5) || (sum - 0.0).abs() < 1e-5, - "game must be zero-sum: r1={r1}, r2={r2}, sum={sum}" - ); - assert!( - r1.abs() <= 1.0 && r2.abs() <= 1.0, - "returns must be in [-1,1]: r1={r1}, r2={r2}" - ); - } - - /// Run multiple games with different seeds to stress-test for panics. - #[test] - fn multiple_games_no_panic() { - let e = env(); - let max_steps = 20_000; - - for seed in 0..10u64 { - let mut s = e.new_game(); - let mut rng = seeded_rng(seed); - - for _ in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - let idx = rng.random_range(0..actions.len()); - e.apply(&mut s, actions[idx]); - } - } - } - } - } - - // ── Returns ─────────────────────────────────────────────────────────── - - #[test] - fn returns_none_mid_game() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(11); - - // Advance a few steps but do not finish the game. - for _ in 0..4 { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P1 | Player::P2 => { - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - } - } - - if !e.current_player(&s).is_terminal() { - assert!( - e.returns(&s).is_none(), - "returns() must be None before the game ends" - ); - } - } - - // ── Player 2 actions ────────────────────────────────────────────────── - - /// Verify that Player 2 (Black) can take actions without panicking, - /// and that the state advances correctly. - #[test] - fn player2_can_act() { - let e = env(); - let mut s = e.new_game(); - let mut rng = seeded_rng(12); - - // Keep stepping until Player 2 gets a turn. - let max_steps = 5_000; - let mut p2_acted = false; - - for _ in 0..max_steps { - match e.current_player(&s) { - Player::Terminal => break, - Player::Chance => e.apply_chance(&mut s, &mut rng), - Player::P2 => { - let actions = e.legal_actions(&s); - assert!(!actions.is_empty()); - e.apply(&mut s, actions[0]); - p2_acted = true; - break; - } - Player::P1 => { - let actions = e.legal_actions(&s); - e.apply(&mut s, actions[0]); - } - } - } - - assert!(p2_acted, "Player 2 never got a turn in {max_steps} steps"); - } -} diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs deleted file mode 100644 index cf6d865..0000000 --- a/spiel_bot/src/lib.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod alphazero; -pub mod dqn; -pub mod env; -pub mod mcts; -pub mod network; -pub mod strategy; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs deleted file mode 100644 index eead171..0000000 --- a/spiel_bot/src/mcts/mod.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance. -//! -//! # Algorithm -//! -//! The implementation follows AlphaZero's MCTS: -//! -//! 1. **Expand root** — run the network once to get priors and a value -//! estimate; optionally add Dirichlet noise for training-time exploration. -//! 2. **Simulate** `n_simulations` times: -//! - *Selection* — traverse the tree with PUCT until an unvisited leaf. -//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes; -//! chance nodes are **not** stored in the tree (outcome sampling). -//! - *Expansion* — evaluate the network at the leaf; populate children. -//! - *Backup* — propagate the value upward; negate at each player boundary. -//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]). -//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]). -//! -//! # Perspective convention -//! -//! Every [`MctsNode::w`] is stored **from the perspective of the player who -//! acts at that node**. The backup negates the child value whenever the -//! acting player differs between parent and child. -//! -//! # Stochastic games -//! -//! When [`GameEnv::current_player`] returns [`Player::Chance`], the -//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and -//! continues. Chance nodes are skipped transparently; Q-values converge to -//! their expectation over many simulations (outcome sampling). - -pub mod node; -mod search; - -pub use node::MctsNode; - -use rand::Rng; - -use crate::env::GameEnv; - -// ── Evaluator trait ──────────────────────────────────────────────────────── - -/// Evaluates a game position for use in MCTS. -/// -/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet) -/// but the `mcts` module itself does **not** depend on Burn. -pub trait Evaluator: Send + Sync { - /// Evaluate `obs` (flat observation vector of length `obs_size`). - /// - /// Returns: - /// - `policy_logits`: one raw logit per action (`action_space` entries). - /// Illegal action entries are masked inside the search — no need to - /// zero them here. - /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective. - fn evaluate(&self, obs: &[f32]) -> (Vec, f32); -} - -// ── Configuration ───────────────────────────────────────────────────────── - -/// Hyperparameters for [`run_mcts`]. -#[derive(Debug, Clone)] -pub struct MctsConfig { - /// Number of MCTS simulations per move. Typical: 50–800. - pub n_simulations: usize, - /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0. - pub c_puct: f32, - /// Dirichlet noise concentration α. Set to `0.0` to disable. - /// Typical: `0.3` for Chess, `0.1` for large action spaces. - pub dirichlet_alpha: f32, - /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`. - pub dirichlet_eps: f32, - /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax. - pub temperature: f32, -} - -impl Default for MctsConfig { - fn default() -> Self { - Self { - n_simulations: 200, - c_puct: 1.5, - dirichlet_alpha: 0.3, - dirichlet_eps: 0.25, - temperature: 1.0, - } - } -} - -// ── Public interface ─────────────────────────────────────────────────────── - -/// Run MCTS from `state` and return the populated root [`MctsNode`]. -/// -/// `state` must be a player-decision node (`P1` or `P2`). -/// Use [`mcts_policy`] and [`select_action`] on the returned root. -/// -/// # Panics -/// -/// Panics if `env.current_player(state)` is not `P1` or `P2`. -pub fn run_mcts( - env: &E, - state: &E::State, - evaluator: &dyn Evaluator, - config: &MctsConfig, - rng: &mut impl Rng, -) -> MctsNode { - let player_idx = env - .current_player(state) - .index() - .expect("run_mcts called at a non-decision node"); - - // ── Expand root (network called once here, not inside the loop) ──────── - let mut root = MctsNode::new(1.0); - search::expand::(&mut root, state, env, evaluator, player_idx); - - // ── Optional Dirichlet noise for training exploration ────────────────── - if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 { - search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng); - } - - // ── Simulations ──────────────────────────────────────────────────────── - for _ in 0..config.n_simulations { - search::simulate::( - &mut root, - state.clone(), - env, - evaluator, - config, - rng, - player_idx, - ); - } - - root -} - -/// Compute the MCTS policy: normalized visit counts at the root. -/// -/// Returns a vector of length `action_space` where `policy[a]` is the -/// fraction of simulations that visited action `a`. -pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec { - let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum(); - let mut policy = vec![0.0f32; action_space]; - if total > 0.0 { - for (a, child) in &root.children { - policy[*a] = child.n as f32 / total; - } - } else if !root.children.is_empty() { - // n_simulations = 0: uniform over legal actions. - let uniform = 1.0 / root.children.len() as f32; - for (a, _) in &root.children { - policy[*a] = uniform; - } - } - policy -} - -/// Select an action index from the root after MCTS. -/// -/// * `temperature = 0` — greedy argmax of visit counts. -/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`. -/// -/// # Panics -/// -/// Panics if the root has no children. -pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize { - assert!(!root.children.is_empty(), "select_action called on a root with no children"); - if temperature <= 0.0 { - root.children - .iter() - .max_by_key(|(_, c)| c.n) - .map(|(a, _)| *a) - .unwrap() - } else { - let weights: Vec = root - .children - .iter() - .map(|(_, c)| (c.n as f32).powf(1.0 / temperature)) - .collect(); - let total: f32 = weights.iter().sum(); - let mut r: f32 = rng.random::() * total; - for (i, (a, _)) in root.children.iter().enumerate() { - r -= weights[i]; - if r <= 0.0 { - return *a; - } - } - root.children.last().map(|(a, _)| *a).unwrap() - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - use crate::env::Player; - - // ── Minimal deterministic test game ─────────────────────────────────── - // - // "Countdown" — two players alternate subtracting 1 or 2 from a counter. - // The player who brings the counter to 0 wins. - // No chance nodes, two legal actions (0 = -1, 1 = -2). - - #[derive(Clone, Debug)] - struct CState { - remaining: u8, - to_move: usize, // at terminal: last mover (winner) - } - - #[derive(Clone)] - struct CountdownEnv; - - impl crate::env::GameEnv for CountdownEnv { - type State = CState; - - fn new_game(&self) -> CState { - CState { remaining: 6, to_move: 0 } - } - - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { - Player::Terminal - } else if s.to_move == 0 { - Player::P1 - } else { - Player::P2 - } - } - - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { - s.remaining = 0; - // to_move stays as winner - } else { - s.remaining -= sub; - s.to_move = 1 - s.to_move; - } - } - - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / 6.0, s.to_move as f32] - } - - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } - } - - // Uniform evaluator: all logits = 0, value = 0. - // `action_space` must match the environment's `action_space()`. - struct ZeroEval(usize); - impl Evaluator for ZeroEval { - fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { - (vec![0.0f32; self.0], 0.0) - } - } - - fn rng() -> SmallRng { - SmallRng::seed_from_u64(42) - } - - fn config_n(n: usize) -> MctsConfig { - MctsConfig { - n_simulations: n, - c_puct: 1.5, - dirichlet_alpha: 0.0, // off for reproducibility - dirichlet_eps: 0.0, - temperature: 1.0, - } - } - - // ── Visit count tests ───────────────────────────────────────────────── - - #[test] - fn visit_counts_sum_to_n_simulations() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng()); - let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 50, "visit counts must sum to n_simulations"); - } - - #[test] - fn all_root_children_are_legal() { - let env = CountdownEnv; - let state = env.new_game(); - let legal = env.legal_actions(&state); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng()); - for (a, _) in &root.children { - assert!(legal.contains(a), "child action {a} is not legal"); - } - } - - // ── Policy tests ───────────────────────────────────────────────────── - - #[test] - fn policy_sums_to_one() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - let sum: f32 = policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0"); - } - - #[test] - fn policy_zero_for_illegal_actions() { - let env = CountdownEnv; - // remaining = 1 → only action 0 is legal - let state = CState { remaining: 1, to_move: 0 }; - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass"); - } - - // ── Action selection tests ──────────────────────────────────────────── - - #[test] - fn greedy_selects_most_visited() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng()); - let greedy = select_action(&root, 0.0, &mut rng()); - let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap(); - assert_eq!(greedy, most_visited); - } - - #[test] - fn temperature_sampling_stays_legal() { - let env = CountdownEnv; - let state = env.new_game(); - let legal = env.legal_actions(&state); - let mut r = rng(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r); - for _ in 0..20 { - let a = select_action(&root, 1.0, &mut r); - assert!(legal.contains(&a), "sampled action {a} is not legal"); - } - } - - // ── Zero-simulation edge case ───────────────────────────────────────── - - #[test] - fn zero_simulations_uniform_policy() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng()); - let policy = mcts_policy(&root, env.action_space()); - // With 0 simulations, fallback is uniform over the 2 legal actions. - let sum: f32 = policy.iter().sum(); - assert!((sum - 1.0).abs() < 1e-5); - } - - // ── Root value ──────────────────────────────────────────────────────── - - #[test] - fn root_q_in_valid_range() { - let env = CountdownEnv; - let state = env.new_game(); - let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng()); - let q = root.q(); - assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]"); - } - - // ── Integration: run on a real Trictrac game ────────────────────────── - - #[test] - fn no_panic_on_trictrac_state() { - use crate::env::TrictracEnv; - - let env = TrictracEnv; - let mut state = env.new_game(); - let mut r = rng(); - - // Advance past the initial chance node to reach a decision node. - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, &mut r); - } - - if env.current_player(&state).is_terminal() { - return; // unlikely but safe - } - - let config = MctsConfig { - n_simulations: 5, // tiny for speed - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - ..MctsConfig::default() - }; - - let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); - // root.n = 1 (expansion) + n_simulations (one backup per simulation). - assert_eq!(root.n, 1 + config.n_simulations as u32); - // Every simulation crosses a chance node at depth 1 (dice roll after - // the player's move). Since the fix now updates child.n in that case, - // children visit counts must sum to exactly n_simulations. - let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, config.n_simulations as u32); - } -} diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs deleted file mode 100644 index aff7735..0000000 --- a/spiel_bot/src/mcts/node.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! MCTS tree node. -//! -//! [`MctsNode`] holds the visit statistics for one player-decision position in -//! the search tree. A node is *expanded* the first time the policy-value -//! network is evaluated there; before that it is a leaf. - -/// One node in the MCTS tree, representing a player-decision position. -/// -/// `w` stores the sum of values backed up into this node, always from the -/// perspective of **the player who acts here**. `q()` therefore also returns -/// a value in `(-1, 1)` from that same perspective. -#[derive(Debug)] -pub struct MctsNode { - /// Visit count `N(s, a)`. - pub n: u32, - /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective. - pub w: f32, - /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax). - pub p: f32, - /// Children: `(action_index, child_node)`, populated on first expansion. - pub children: Vec<(usize, MctsNode)>, - /// `true` after the network has been evaluated and children have been set up. - pub expanded: bool, -} - -impl MctsNode { - /// Create a fresh, unexpanded leaf with the given prior probability. - pub fn new(prior: f32) -> Self { - Self { - n: 0, - w: 0.0, - p: prior, - children: Vec::new(), - expanded: false, - } - } - - /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited. - #[inline] - pub fn q(&self) -> f32 { - if self.n == 0 { 0.0 } else { self.w / self.n as f32 } - } - - /// PUCT selection score: - /// - /// ```text - /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a)) - /// ``` - #[inline] - pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 { - self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32) - } -} - -// ── Tests ────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn q_zero_when_unvisited() { - let node = MctsNode::new(0.5); - assert_eq!(node.q(), 0.0); - } - - #[test] - fn q_reflects_w_over_n() { - let mut node = MctsNode::new(0.5); - node.n = 4; - node.w = 2.0; - assert!((node.q() - 0.5).abs() < 1e-6); - } - - #[test] - fn puct_exploration_dominates_unvisited() { - // Unvisited child should outscore a visited child with negative Q. - let mut visited = MctsNode::new(0.5); - visited.n = 10; - visited.w = -5.0; // Q = -0.5 - - let unvisited = MctsNode::new(0.5); - - let parent_n = 10; - let c = 1.5; - assert!( - unvisited.puct(parent_n, c) > visited.puct(parent_n, c), - "unvisited child should have higher PUCT than a negatively-valued visited child" - ); - } -} diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs deleted file mode 100644 index 4d36acc..0000000 --- a/spiel_bot/src/mcts/search.rs +++ /dev/null @@ -1,197 +0,0 @@ -//! Simulation, expansion, backup, and noise helpers. -//! -//! These are internal to the `mcts` module; the public entry points are -//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`]. - -use rand::Rng; -use rand_distr::{Gamma, Distribution}; - -use crate::env::GameEnv; -use super::{Evaluator, MctsConfig}; -use super::node::MctsNode; - -// ── Masked softmax ───────────────────────────────────────────────────────── - -/// Numerically stable softmax over `legal` actions only. -/// -/// Illegal logits are treated as `-∞` and receive probability `0.0`. -/// Returns a probability vector of length `action_space`. -pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec { - let mut probs = vec![0.0f32; action_space]; - if legal.is_empty() { - return probs; - } - let max_logit = legal - .iter() - .map(|&a| logits[a]) - .fold(f32::NEG_INFINITY, f32::max); - let mut sum = 0.0f32; - for &a in legal { - let e = (logits[a] - max_logit).exp(); - probs[a] = e; - sum += e; - } - if sum > 0.0 { - for &a in legal { - probs[a] /= sum; - } - } else { - let uniform = 1.0 / legal.len() as f32; - for &a in legal { - probs[a] = uniform; - } - } - probs -} - -// ── Dirichlet noise ──────────────────────────────────────────────────────── - -/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration. -/// -/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`. -/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum. -pub(super) fn add_dirichlet_noise( - node: &mut MctsNode, - alpha: f32, - eps: f32, - rng: &mut impl Rng, -) { - let n = node.children.len(); - if n == 0 { - return; - } - let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else { - return; - }; - let samples: Vec = (0..n).map(|_| gamma.sample(rng) as f32).collect(); - let sum: f32 = samples.iter().sum(); - if sum <= 0.0 { - return; - } - for (i, (_, child)) in node.children.iter_mut().enumerate() { - let noise = samples[i] / sum; - child.p = (1.0 - eps) * child.p + eps * noise; - } -} - -// ── Expansion ────────────────────────────────────────────────────────────── - -/// Evaluate the network at `state` and populate `node` with children. -/// -/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`. -/// Returns the network value estimate from `player_idx`'s perspective. -pub(super) fn expand( - node: &mut MctsNode, - state: &E::State, - env: &E, - evaluator: &dyn Evaluator, - player_idx: usize, -) -> f32 { - let obs = env.observation(state, player_idx); - let legal = env.legal_actions(state); - let (logits, value) = evaluator.evaluate(&obs); - let priors = masked_softmax(&logits, &legal, env.action_space()); - node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect(); - node.expanded = true; - node.n = 1; - node.w = value; - value -} - -// ── Simulation ───────────────────────────────────────────────────────────── - -/// One MCTS simulation from an **already-expanded** decision node. -/// -/// Traverses the tree with PUCT selection, expands the first unvisited leaf, -/// and backs up the result. -/// -/// * `player_idx` — the player (0 or 1) who acts at `state`. -/// * Returns the backed-up value **from `player_idx`'s perspective**. -pub(super) fn simulate( - node: &mut MctsNode, - state: E::State, - env: &E, - evaluator: &dyn Evaluator, - config: &MctsConfig, - rng: &mut impl Rng, - player_idx: usize, -) -> f32 { - debug_assert!(node.expanded, "simulate called on unexpanded node"); - - // ── Selection: child with highest PUCT ──────────────────────────────── - let parent_n = node.n; - let best = node - .children - .iter() - .enumerate() - .max_by(|(_, (_, a)), (_, (_, b))| { - a.puct(parent_n, config.c_puct) - .partial_cmp(&b.puct(parent_n, config.c_puct)) - .unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(i, _)| i) - .expect("expanded node must have at least one child"); - - let (action, child) = &mut node.children[best]; - let action = *action; - - // ── Apply action + advance through any chance nodes ─────────────────── - let mut next_state = state; - env.apply(&mut next_state, action); - - // Track whether we crossed a chance node (dice roll) on the way down. - // If we did, the child's cached legal actions are for a *different* dice - // outcome and must not be reused — evaluate with the network directly. - let mut crossed_chance = false; - while env.current_player(&next_state).is_chance() { - env.apply_chance(&mut next_state, rng); - crossed_chance = true; - } - - let next_cp = env.current_player(&next_state); - - // ── Evaluate leaf or terminal ────────────────────────────────────────── - // All values are converted to `player_idx`'s perspective before backup. - let child_value = if next_cp.is_terminal() { - let returns = env - .returns(&next_state) - .expect("terminal node must have returns"); - let v = returns[player_idx]; - // Update child stats so PUCT and mcts_policy count terminal visits. - // Store from player_idx's perspective so child.q() is directly usable - // by the parent's PUCT selection (high = good for the selecting player). - child.n += 1; - child.w += v; - v - } else { - let child_player = next_cp.index().unwrap(); - let v = if crossed_chance { - // Outcome sampling: after dice, evaluate the resulting position - // directly with the network. Do NOT build the tree across chance - // boundaries — the dice change which actions are legal, so any - // previously cached children would be for a different outcome. - let obs = env.observation(&next_state, child_player); - let (_, value) = evaluator.evaluate(&obs); - // Store from player_idx's (parent's) perspective so PUCT works correctly. - // `value` is from child_player's POV; negate when child is the opponent - // so that child.q() = expected return for the player CHOOSING this child. - // Without the negation, root would maximise the opponent's Q-value and - // systematically pick the worst action. - child.n += 1; - child.w += if child_player == player_idx { value } else { -value }; - value - } else if child.expanded { - simulate(child, next_state, env, evaluator, config, rng, child_player) - } else { - expand::(child, &next_state, env, evaluator, child_player) - }; - // Negate when the child belongs to the opponent. - if child_player == player_idx { v } else { -v } - }; - - // ── Backup ──────────────────────────────────────────────────────────── - node.n += 1; - node.w += child_value; - - child_value -} diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs deleted file mode 100644 index eb6184e..0000000 --- a/spiel_bot/src/network/mlp.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! Two-hidden-layer MLP policy-value network. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU -//! → Linear(hidden → hidden) → ReLU -//! ├─ policy_head: Linear(hidden → action_size) [raw logits] -//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] -//! ``` - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{ - activation::{relu, tanh}, - backend::Backend, - Tensor, - }, -}; -use std::path::Path; - -use super::PolicyValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`MlpNet`]. -#[derive(Debug, Clone)] -pub struct MlpConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of both hidden layers. - pub hidden_size: usize, -} - -impl Default for MlpConfig { - fn default() -> Self { - Self { - obs_size: 217, - action_size: 514, - hidden_size: 256, - } - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Simple two-hidden-layer MLP with shared trunk and two heads. -/// -/// Prefer this over [`ResNet`](super::ResNet) when training time is a -/// priority, or as a fast baseline. -#[derive(Module, Debug)] -pub struct MlpNet { - fc1: Linear, - fc2: Linear, - policy_head: Linear, - value_head: Linear, -} - -impl MlpNet { - /// Construct a fresh network with random weights. - pub fn new(config: &MlpConfig, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), - fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), - policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), - value_head: LinearConfig::new(config.hidden_size, 1).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - /// - /// The file is written exactly at `path`; callers should append `.mpk` if - /// they want the conventional extension. - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. - pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl PolicyValueNet for MlpNet { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { - let x = relu(self.fc1.forward(obs)); - let x = relu(self.fc2.forward(x)); - let policy = self.policy_head.forward(x.clone()); - let value = tanh(self.value_head.forward(x)); - (policy, value) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn default_net() -> MlpNet { - MlpNet::new(&MlpConfig::default(), &device()) - } - - fn zeros_obs(batch: usize) -> Tensor { - Tensor::zeros([batch, 217], &device()) - } - - // ── Shape tests ─────────────────────────────────────────────────────── - - #[test] - fn forward_output_shapes() { - let net = default_net(); - let obs = zeros_obs(4); - let (policy, value) = net.forward(obs); - - assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); - assert_eq!(value.dims(), [4, 1], "value shape mismatch"); - } - - #[test] - fn forward_single_sample() { - let net = default_net(); - let (policy, value) = net.forward(zeros_obs(1)); - assert_eq!(policy.dims(), [1, 514]); - assert_eq!(value.dims(), [1, 1]); - } - - // ── Value bounds ────────────────────────────────────────────────────── - - #[test] - fn value_in_tanh_range() { - let net = default_net(); - // Use a non-zero input so the output is not trivially at 0. - let obs = Tensor::::ones([8, 217], &device()); - let (_, value) = net.forward(obs); - let data: Vec = value.into_data().to_vec().unwrap(); - for v in &data { - assert!( - *v > -1.0 && *v < 1.0, - "value {v} is outside open interval (-1, 1)" - ); - } - } - - // ── Policy logits ───────────────────────────────────────────────────── - - #[test] - fn policy_logits_not_all_equal() { - // With random weights the 514 logits should not all be identical. - let net = default_net(); - let (policy, _) = net.forward(zeros_obs(1)); - let data: Vec = policy.into_data().to_vec().unwrap(); - let first = data[0]; - let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); - assert!(!all_same, "all policy logits are identical — network may be degenerate"); - } - - // ── Config propagation ──────────────────────────────────────────────── - - #[test] - fn custom_config_shapes() { - let config = MlpConfig { - obs_size: 10, - action_size: 20, - hidden_size: 32, - }; - let net = MlpNet::::new(&config, &device()); - let obs = Tensor::zeros([3, 10], &device()); - let (policy, value) = net.forward(obs); - assert_eq!(policy.dims(), [3, 20]); - assert_eq!(value.dims(), [3, 1]); - } - - // ── Save / Load ─────────────────────────────────────────────────────── - - #[test] - fn save_load_preserves_weights() { - let config = MlpConfig::default(); - let net = default_net(); - - // Forward pass before saving. - let obs = Tensor::::ones([2, 217], &device()); - let (policy_before, value_before) = net.forward(obs.clone()); - - // Save to a temp file. - let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); - net.save(&path).expect("save failed"); - - // Load into a fresh model. - let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); - let (policy_after, value_after) = loaded.forward(obs); - - // Outputs must be bitwise identical. - let p_before: Vec = policy_before.into_data().to_vec().unwrap(); - let p_after: Vec = policy_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let v_before: Vec = value_before.into_data().to_vec().unwrap(); - let v_after: Vec = value_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let _ = std::fs::remove_file(path); - } -} diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs deleted file mode 100644 index 64f93ec..0000000 --- a/spiel_bot/src/network/mod.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Neural network abstractions for policy-value learning. -//! -//! # Trait -//! -//! [`PolicyValueNet`] is the single trait that all network architectures -//! implement. It takes an observation tensor and returns raw policy logits -//! plus a tanh-squashed scalar value estimate. -//! -//! # Architectures -//! -//! | Module | Description | Default hidden | -//! |--------|-------------|----------------| -//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | -//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | -//! -//! # Backend convention -//! -//! * **Inference / self-play** — use `NdArray` (no autodiff overhead). -//! * **Training** — use `Autodiff>` so Burn can differentiate -//! through the forward pass. -//! -//! Both modes use the exact same struct; only the type-level backend changes: -//! -//! ```rust,ignore -//! use burn::backend::{Autodiff, NdArray}; -//! type InferBackend = NdArray; -//! type TrainBackend = Autodiff>; -//! -//! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); -//! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); -//! ``` -//! -//! # Output shapes -//! -//! Given a batch of `B` observations of size `obs_size`: -//! -//! | Output | Shape | Range | -//! |--------|-------|-------| -//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | -//! | `value` | `[B, 1]` | (-1, 1) via tanh | -//! -//! Callers are responsible for masking illegal actions in `policy_logits` -//! before passing to softmax. - -pub mod mlp; -pub mod qnet; -pub mod resnet; - -pub use mlp::{MlpConfig, MlpNet}; -pub use qnet::{QNet, QNetConfig}; -pub use resnet::{ResNet, ResNetConfig}; - -use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; - -/// A neural network that produces a policy and a value from an observation. -/// -/// # Shapes -/// - `obs`: `[batch, obs_size]` -/// - policy output: `[batch, action_size]` — raw logits (no softmax applied) -/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) -/// -/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses -/// `OnceCell` for lazy parameter initialisation, which is not `Sync`. -/// Use an `Arc>` wrapper if cross-thread sharing is needed. -pub trait PolicyValueNet: Module + Send + 'static { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor); -} - -/// A neural network that outputs one Q-value per action. -/// -/// # Shapes -/// - `obs`: `[batch, obs_size]` -/// - output: `[batch, action_size]` — raw Q-values (no activation) -/// -/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`]. -pub trait QValueNet: Module + Send + 'static { - fn forward(&self, obs: Tensor) -> Tensor; -} diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs deleted file mode 100644 index 1737f72..0000000 --- a/spiel_bot/src/network/qnet.rs +++ /dev/null @@ -1,147 +0,0 @@ -//! Single-headed Q-value network for DQN. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU -//! → Linear(hidden → hidden) → ReLU -//! → Linear(hidden → action_size) ← raw Q-values, no activation -//! ``` - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{activation::relu, backend::Backend, Tensor}, -}; -use std::path::Path; - -use super::QValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`QNet`]. -#[derive(Debug, Clone)] -pub struct QNetConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of both hidden layers. - pub hidden_size: usize, -} - -impl Default for QNetConfig { - fn default() -> Self { - Self { obs_size: 217, action_size: 514, hidden_size: 256 } - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Two-hidden-layer MLP that outputs one Q-value per action. -#[derive(Module, Debug)] -pub struct QNet { - fc1: Linear, - fc2: Linear, - q_head: Linear, -} - -impl QNet { - /// Construct a fresh network with random weights. - pub fn new(config: &QNetConfig, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), - fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), - q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. - pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl QValueNet for QNet { - fn forward(&self, obs: Tensor) -> Tensor { - let x = relu(self.fc1.forward(obs)); - let x = relu(self.fc2.forward(x)); - self.q_head.forward(x) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { Default::default() } - - fn default_net() -> QNet { - QNet::new(&QNetConfig::default(), &device()) - } - - #[test] - fn forward_output_shape() { - let net = default_net(); - let obs = Tensor::zeros([4, 217], &device()); - let q = net.forward(obs); - assert_eq!(q.dims(), [4, 514]); - } - - #[test] - fn forward_single_sample() { - let net = default_net(); - let q = net.forward(Tensor::zeros([1, 217], &device())); - assert_eq!(q.dims(), [1, 514]); - } - - #[test] - fn q_values_not_all_equal() { - let net = default_net(); - let q: Vec = net.forward(Tensor::zeros([1, 217], &device())) - .into_data().to_vec().unwrap(); - let first = q[0]; - assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6)); - } - - #[test] - fn custom_config_shapes() { - let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 }; - let net = QNet::::new(&cfg, &device()); - let q = net.forward(Tensor::zeros([3, 10], &device())); - assert_eq!(q.dims(), [3, 20]); - } - - #[test] - fn save_load_preserves_weights() { - let net = default_net(); - let obs = Tensor::::ones([2, 217], &device()); - let q_before: Vec = net.forward(obs.clone()).into_data().to_vec().unwrap(); - - let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk"); - net.save(&path).expect("save failed"); - - let loaded = QNet::::load(&QNetConfig::default(), &path, &device()).expect("load failed"); - let q_after: Vec = loaded.forward(obs).into_data().to_vec().unwrap(); - - for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}"); - } - let _ = std::fs::remove_file(path); - } -} diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs deleted file mode 100644 index d20d5ad..0000000 --- a/spiel_bot/src/network/resnet.rs +++ /dev/null @@ -1,253 +0,0 @@ -//! Residual-block policy-value network. -//! -//! ```text -//! Input [B, obs_size] -//! → Linear(obs → hidden) → ReLU (input projection) -//! → ResBlock × 4 (residual trunk) -//! ├─ policy_head: Linear(hidden → action_size) [raw logits] -//! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] -//! -//! ResBlock: -//! x → Linear → ReLU → Linear → (+x) → ReLU -//! ``` -//! -//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better -//! suited for long training runs where board-pattern recognition matters. - -use burn::{ - module::Module, - nn::{Linear, LinearConfig}, - record::{CompactRecorder, Recorder}, - tensor::{ - activation::{relu, tanh}, - backend::Backend, - Tensor, - }, -}; -use std::path::Path; - -use super::PolicyValueNet; - -// ── Config ──────────────────────────────────────────────────────────────────── - -/// Configuration for [`ResNet`]. -#[derive(Debug, Clone)] -pub struct ResNetConfig { - /// Number of input features. 217 for Trictrac's `to_tensor()`. - pub obs_size: usize, - /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. - pub action_size: usize, - /// Width of all hidden layers (input projection + residual blocks). - pub hidden_size: usize, -} - -impl Default for ResNetConfig { - fn default() -> Self { - Self { - obs_size: 217, - action_size: 514, - hidden_size: 512, - } - } -} - -// ── Residual block ──────────────────────────────────────────────────────────── - -/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`. -/// -/// Both linear layers preserve the hidden dimension so the skip connection -/// can be added without projection. -#[derive(Module, Debug)] -struct ResBlock { - fc1: Linear, - fc2: Linear, -} - -impl ResBlock { - fn new(hidden: usize, device: &B::Device) -> Self { - Self { - fc1: LinearConfig::new(hidden, hidden).init(device), - fc2: LinearConfig::new(hidden, hidden).init(device), - } - } - - fn forward(&self, x: Tensor) -> Tensor { - let residual = x.clone(); - let out = relu(self.fc1.forward(x)); - relu(self.fc2.forward(out) + residual) - } -} - -// ── Network ─────────────────────────────────────────────────────────────────── - -/// Four-residual-block policy-value network. -/// -/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and -/// when representing complex positional patterns is important. -#[derive(Module, Debug)] -pub struct ResNet { - input: Linear, - block0: ResBlock, - block1: ResBlock, - block2: ResBlock, - block3: ResBlock, - policy_head: Linear, - value_head: Linear, -} - -impl ResNet { - /// Construct a fresh network with random weights. - pub fn new(config: &ResNetConfig, device: &B::Device) -> Self { - let h = config.hidden_size; - Self { - input: LinearConfig::new(config.obs_size, h).init(device), - block0: ResBlock::new(h, device), - block1: ResBlock::new(h, device), - block2: ResBlock::new(h, device), - block3: ResBlock::new(h, device), - policy_head: LinearConfig::new(h, config.action_size).init(device), - value_head: LinearConfig::new(h, 1).init(device), - } - } - - /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). - pub fn save(&self, path: &Path) -> anyhow::Result<()> { - CompactRecorder::new() - .record(self.clone().into_record(), path.to_path_buf()) - .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}")) - } - - /// Load weights from `path` into a fresh model built from `config`. - pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { - let record = CompactRecorder::new() - .load(path.to_path_buf(), device) - .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?; - Ok(Self::new(config, device).load_record(record)) - } -} - -impl PolicyValueNet for ResNet { - fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { - let x = relu(self.input.forward(obs)); - let x = self.block0.forward(x); - let x = self.block1.forward(x); - let x = self.block2.forward(x); - let x = self.block3.forward(x); - let policy = self.policy_head.forward(x.clone()); - let value = tanh(self.value_head.forward(x)); - (policy, value) - } -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - - type B = NdArray; - - fn device() -> ::Device { - Default::default() - } - - fn small_config() -> ResNetConfig { - // Use a small hidden size so tests are fast. - ResNetConfig { - obs_size: 217, - action_size: 514, - hidden_size: 64, - } - } - - fn net() -> ResNet { - ResNet::new(&small_config(), &device()) - } - - // ── Shape tests ─────────────────────────────────────────────────────── - - #[test] - fn forward_output_shapes() { - let obs = Tensor::zeros([4, 217], &device()); - let (policy, value) = net().forward(obs); - assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); - assert_eq!(value.dims(), [4, 1], "value shape mismatch"); - } - - #[test] - fn forward_single_sample() { - let (policy, value) = net().forward(Tensor::zeros([1, 217], &device())); - assert_eq!(policy.dims(), [1, 514]); - assert_eq!(value.dims(), [1, 1]); - } - - // ── Value bounds ────────────────────────────────────────────────────── - - #[test] - fn value_in_tanh_range() { - let obs = Tensor::::ones([8, 217], &device()); - let (_, value) = net().forward(obs); - let data: Vec = value.into_data().to_vec().unwrap(); - for v in &data { - assert!( - *v > -1.0 && *v < 1.0, - "value {v} is outside open interval (-1, 1)" - ); - } - } - - // ── Residual connections ────────────────────────────────────────────── - - #[test] - fn policy_logits_not_all_equal() { - let (policy, _) = net().forward(Tensor::zeros([1, 217], &device())); - let data: Vec = policy.into_data().to_vec().unwrap(); - let first = data[0]; - let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); - assert!(!all_same, "all policy logits are identical"); - } - - // ── Save / Load ─────────────────────────────────────────────────────── - - #[test] - fn save_load_preserves_weights() { - let config = small_config(); - let model = net(); - let obs = Tensor::::ones([2, 217], &device()); - - let (policy_before, value_before) = model.forward(obs.clone()); - - let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk"); - model.save(&path).expect("save failed"); - - let loaded = ResNet::::load(&config, &path, &device()).expect("load failed"); - let (policy_after, value_after) = loaded.forward(obs); - - let p_before: Vec = policy_before.into_data().to_vec().unwrap(); - let p_after: Vec = policy_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let v_before: Vec = value_before.into_data().to_vec().unwrap(); - let v_after: Vec = value_after.into_data().to_vec().unwrap(); - for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { - assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); - } - - let _ = std::fs::remove_file(path); - } - - // ── Integration: both architectures satisfy PolicyValueNet ──────────── - - #[test] - fn resnet_satisfies_trait() { - fn requires_net>(net: &N, obs: Tensor) { - let (p, v) = net.forward(obs); - assert_eq!(p.dims()[1], 514); - assert_eq!(v.dims()[1], 1); - } - requires_net(&net(), Tensor::zeros([2, 217], &device())); - } -} diff --git a/spiel_bot/src/strategy.rs b/spiel_bot/src/strategy.rs deleted file mode 100644 index 8309bf3..0000000 --- a/spiel_bot/src/strategy.rs +++ /dev/null @@ -1,242 +0,0 @@ -//! [`BotStrategy`] implementations backed by `spiel_bot` models. -//! -//! | Strategy struct | Network | CLI token | -//! |-----------------|---------|-----------| -//! | [`AzBotStrategy`] (mlp) | MlpNet (AlphaZero) | `az` / `az:PATH` | -//! | [`AzBotStrategy`] (resnet) | ResNet (AlphaZero) | `az-resnet` / `az-resnet:PATH` | -//! | [`DqnSpielBotStrategy`] | QNet (DQN) | `az-dqn` / `az-dqn:PATH` | -//! -//! All strategies operate from **White's perspective** (player_id = 1) internally; -//! the [`Bot`](trictrac_bot::Bot) wrapper handles board mirroring for Black. - -use std::cell::RefCell; -use std::path::Path; - -use burn::{ - backend::NdArray, - tensor::{Tensor, TensorData}, -}; -use rand::{SeedableRng, rngs::SmallRng}; -use trictrac_bot::BotStrategy; -use trictrac_store::{ - training_common::{get_valid_action_indices, TrictracAction}, - CheckerMove, Color, GameEvent, GameState, MoveRules, PlayerId, -}; - -use crate::{ - alphazero::BurnEvaluator, - env::{GameEnv, TrictracEnv}, - mcts::{self, Evaluator, MctsConfig}, - network::{MlpConfig, MlpNet, QNet, QNetConfig, QValueNet, ResNet, ResNetConfig}, -}; - -type B = NdArray; - -/// Default MCTS simulations per move used by [`AzBotStrategy`]. -pub const AZ_BOT_N_SIM: usize = 50; - -// ── Shared helpers ───────────────────────────────────────────────────────────── - -/// Decode an action index → `(CheckerMove, CheckerMove)` using the game state. -fn action_to_moves(action: usize, game: &GameState) -> Option<(CheckerMove, CheckerMove)> { - match TrictracAction::from_action_index(action)?.to_event(game)? { - GameEvent::Move { moves, .. } => Some(moves), - _ => None, - } -} - -/// Fallback: return the first legal move from `MoveRules` (always succeeds). -fn fallback_move(game: &GameState) -> (CheckerMove, CheckerMove) { - let rules = MoveRules::new(&Color::White, &game.board, game.dice); - let moves = rules.get_possible_moves_sequences(true, vec![]); - *moves.first().unwrap_or(&(CheckerMove::default(), CheckerMove::default())) -} - -// ── AzBotStrategy ───────────────────────────────────────────────────────────── - -/// AlphaZero bot usable as a [`BotStrategy`]. -/// -/// Supports both MlpNet and ResNet checkpoints through separate constructors. -/// Uses greedy (temperature = 0) MCTS for action selection. -/// -/// # Construction -/// -/// ```rust,ignore -/// // MlpNet with random weights -/// AzBotStrategy::new_mlp(None); -/// -/// // MlpNet from a checkpoint -/// AzBotStrategy::new_mlp(Some("checkpoints/iter_0100.mpk")); -/// -/// // ResNet from a checkpoint -/// AzBotStrategy::new_resnet(Some("checkpoints/resnet_0200.mpk")); -/// ``` -pub struct AzBotStrategy { - game: GameState, - evaluator: Box, - mcts_config: MctsConfig, - /// Interior-mutable RNG so `choose_move(&self)` can drive MCTS. - rng: RefCell, -} - -impl std::fmt::Debug for AzBotStrategy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AzBotStrategy") - .field("n_sim", &self.mcts_config.n_simulations) - .finish() - } -} - -impl AzBotStrategy { - fn from_evaluator(evaluator: Box) -> Self { - Self { - game: GameState::default(), - evaluator, - mcts_config: MctsConfig { - n_simulations: AZ_BOT_N_SIM, - dirichlet_alpha: 0.0, // no noise during play - dirichlet_eps: 0.0, - temperature: 0.0, // greedy selection - ..MctsConfig::default() - }, - rng: RefCell::new(SmallRng::seed_from_u64(42)), - } - } - - /// MlpNet-backed bot. `path = None` → random weights. - pub fn new_mlp(path: Option<&str>) -> Self { - let device: ::Device = Default::default(); - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 256 }; - let model = match path { - Some(p) => MlpNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { - eprintln!("az: load failed ({e}), using random weights"); - MlpNet::::new(&cfg, &device) - }), - None => MlpNet::::new(&cfg, &device), - }; - Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) - } - - /// ResNet-backed bot. `path = None` → random weights. - pub fn new_resnet(path: Option<&str>) -> Self { - let device: ::Device = Default::default(); - let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: 512 }; - let model = match path { - Some(p) => ResNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { - eprintln!("az-resnet: load failed ({e}), using random weights"); - ResNet::::new(&cfg, &device) - }), - None => ResNet::::new(&cfg, &device), - }; - Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) - } - - /// Run MCTS and return the greedy best action index, or `None` if no legal moves. - fn best_action(&self) -> Option { - let env = TrictracEnv; - if env.legal_actions(&self.game).is_empty() { - return None; - } - let mut rng = self.rng.borrow_mut(); - let root = mcts::run_mcts( - &env, - &self.game, - self.evaluator.as_ref(), - &self.mcts_config, - &mut *rng, - ); - Some(mcts::select_action(&root, 0.0, &mut *rng)) - } -} - -impl BotStrategy for AzBotStrategy { - fn get_game(&self) -> &GameState { &self.game } - fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } - fn calculate_points(&self) -> u8 { self.game.dice_points.0 } - fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } - fn set_player_id(&mut self, _player_id: PlayerId) {} - fn set_color(&mut self, _color: Color) {} - - fn choose_go(&self) -> bool { - // Action index 1 == TrictracAction::Go - self.best_action().map(|a| a == 1).unwrap_or(false) - } - - fn choose_move(&self) -> (CheckerMove, CheckerMove) { - self.best_action() - .and_then(|a| action_to_moves(a, &self.game)) - .unwrap_or_else(|| fallback_move(&self.game)) - } -} - -// ── DqnSpielBotStrategy ─────────────────────────────────────────────────────── - -/// DQN bot (QNet from `spiel_bot`) usable as a [`BotStrategy`]. -/// -/// Selects actions by greedy argmax over Q-values, masked to legal moves. -/// When no checkpoint is provided the model falls back to the first legal move. -/// -/// # Construction -/// -/// ```rust,ignore -/// // No model — always picks first legal move -/// DqnSpielBotStrategy::new(None); -/// -/// // Trained checkpoint -/// DqnSpielBotStrategy::new(Some("checkpoints/dqn_iter_0500.mpk")); -/// ``` -#[derive(Debug)] -pub struct DqnSpielBotStrategy { - game: GameState, - model: Option>, -} - -impl DqnSpielBotStrategy { - /// Create a DQN bot. `path = None` → falls back to first legal move. - pub fn new(path: Option<&str>) -> Self { - let model = path.map(|p| { - let device: ::Device = Default::default(); - let cfg = QNetConfig::default(); - QNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { - eprintln!("az-dqn: load failed ({e}), using random weights"); - QNet::::new(&cfg, &device) - }) - }); - Self { game: GameState::default(), model } - } - - /// Greedy Q-value selection masked to legal actions, or `None` if no model / no legal moves. - fn best_action(&self) -> Option { - let model = self.model.as_ref()?; - let legal = get_valid_action_indices(&self.game).unwrap_or_default(); - if legal.is_empty() { - return None; - } - let device: ::Device = Default::default(); - let obs = self.game.to_tensor(); - let obs_t = Tensor::::from_data(TensorData::new(obs, [1, 217]), &device); - let q_vals: Vec = model.forward(obs_t).into_data().to_vec().unwrap(); - legal.into_iter().max_by(|&a, &b| { - q_vals[a].partial_cmp(&q_vals[b]).unwrap_or(std::cmp::Ordering::Equal) - }) - } -} - -impl BotStrategy for DqnSpielBotStrategy { - fn get_game(&self) -> &GameState { &self.game } - fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } - fn calculate_points(&self) -> u8 { self.game.dice_points.0 } - fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } - fn set_player_id(&mut self, _player_id: PlayerId) {} - fn set_color(&mut self, _color: Color) {} - - fn choose_go(&self) -> bool { - self.best_action().map(|a| a == 1).unwrap_or(false) - } - - fn choose_move(&self) -> (CheckerMove, CheckerMove) { - self.best_action() - .and_then(|a| action_to_moves(a, &self.game)) - .unwrap_or_else(|| fallback_move(&self.game)) - } -} diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs deleted file mode 100644 index d73fda0..0000000 --- a/spiel_bot/tests/integration.rs +++ /dev/null @@ -1,391 +0,0 @@ -//! End-to-end integration tests for the AlphaZero training pipeline. -//! -//! Each test exercises the full chain: -//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`] -//! -//! Two environments are used: -//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves. -//! Used when we need many iterations without worrying about runtime. -//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that -//! the full pipeline compiles and runs correctly with 217-dim observations -//! and 514-dim action spaces. -//! -//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep -//! runtime minimal; correctness, not training quality, is what matters here. - -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, -}; -use rand::{SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step}, - env::{GameEnv, Player, TrictracEnv}, - mcts::MctsConfig, - network::{MlpConfig, MlpNet, PolicyValueNet}, -}; - -// ── Backend aliases ──────────────────────────────────────────────────────── - -type Train = Autodiff>; -type Infer = NdArray; - -// ── Helpers ──────────────────────────────────────────────────────────────── - -fn train_device() -> ::Device { - Default::default() -} - -fn infer_device() -> ::Device { - Default::default() -} - -/// Tiny 64-unit MLP, compatible with an obs/action space of any size. -fn tiny_mlp(obs: usize, actions: usize) -> MlpNet { - let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 }; - MlpNet::new(&cfg, &train_device()) -} - -fn tiny_mcts(n: usize) -> MctsConfig { - MctsConfig { - n_simulations: n, - c_puct: 1.5, - dirichlet_alpha: 0.0, - dirichlet_eps: 0.0, - temperature: 1.0, - } -} - -fn seeded() -> SmallRng { - SmallRng::seed_from_u64(0) -} - -// ── Countdown environment (fast, local, no external deps) ───────────────── -// -// Two players alternate subtracting 1 or 2 from a counter that starts at N. -// The player who brings the counter to 0 wins. - -#[derive(Clone, Debug)] -struct CState { - remaining: u8, - to_move: usize, -} - -#[derive(Clone)] -struct CountdownEnv(u8); // starting value - -impl GameEnv for CountdownEnv { - type State = CState; - - fn new_game(&self) -> CState { - CState { remaining: self.0, to_move: 0 } - } - - fn current_player(&self, s: &CState) -> Player { - if s.remaining == 0 { Player::Terminal } - else if s.to_move == 0 { Player::P1 } - else { Player::P2 } - } - - fn legal_actions(&self, s: &CState) -> Vec { - if s.remaining >= 2 { vec![0, 1] } else { vec![0] } - } - - fn apply(&self, s: &mut CState, action: usize) { - let sub = (action as u8) + 1; - if s.remaining <= sub { - s.remaining = 0; - } else { - s.remaining -= sub; - s.to_move = 1 - s.to_move; - } - } - - fn apply_chance(&self, _s: &mut CState, _rng: &mut R) {} - - fn observation(&self, s: &CState, _pov: usize) -> Vec { - vec![s.remaining as f32 / self.0 as f32, s.to_move as f32] - } - - fn obs_size(&self) -> usize { 2 } - fn action_space(&self) -> usize { 2 } - - fn returns(&self, s: &CState) -> Option<[f32; 2]> { - if s.remaining != 0 { return None; } - let mut r = [-1.0f32; 2]; - r[s.to_move] = 1.0; - Some(r) - } -} - -// ── 1. Full loop on CountdownEnv ────────────────────────────────────────── - -/// The canonical AlphaZero loop: self-play → replay → train, iterated. -/// Uses CountdownEnv so each game terminates in < 10 moves. -#[test] -fn countdown_full_loop_no_panic() { - let env = CountdownEnv(8); - let mut rng = seeded(); - let mcts = tiny_mcts(3); - - let mut model = tiny_mlp(env.obs_size(), env.action_space()); - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(1_000); - - for _iter in 0..5 { - // Self-play: 3 games per iteration. - for _ in 0..3 { - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - assert!(!samples.is_empty()); - replay.extend(samples); - } - - // Training: 4 gradient steps per iteration. - if replay.len() >= 4 { - for _ in 0..4 { - let batch: Vec = replay - .sample_batch(4, &mut rng) - .into_iter() - .cloned() - .collect(); - let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); - model = m; - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - } - } - } - - assert!(replay.len() > 0); -} - -// ── 2. Replay buffer invariants ─────────────────────────────────────────── - -/// After several Countdown games, replay capacity is respected and batch -/// shapes are consistent. -#[test] -fn replay_buffer_capacity_and_shapes() { - let env = CountdownEnv(6); - let mut rng = seeded(); - let mcts = tiny_mcts(2); - let model = tiny_mlp(env.obs_size(), env.action_space()); - - let capacity = 50; - let mut replay = ReplayBuffer::new(capacity); - - for _ in 0..20 { - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - replay.extend(samples); - } - - assert!(replay.len() <= capacity, "buffer exceeded capacity"); - assert!(replay.len() > 0); - - let batch = replay.sample_batch(8, &mut rng); - assert_eq!(batch.len(), 8.min(replay.len())); - for s in &batch { - assert_eq!(s.obs.len(), env.obs_size()); - assert_eq!(s.policy.len(), env.action_space()); - let policy_sum: f32 = s.policy.iter().sum(); - assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}"); - assert!(s.value.abs() <= 1.0, "value {} out of range", s.value); - } -} - -// ── 3. TrictracEnv: sample shapes ───────────────────────────────────────── - -/// Verify that one TrictracEnv episode produces samples with the correct -/// tensor dimensions: obs = 217, policy = 514. -#[test] -fn trictrac_sample_shapes() { - let env = TrictracEnv; - let mut rng = seeded(); - let mcts = tiny_mcts(2); - let model = tiny_mlp(env.obs_size(), env.action_space()); - - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - - assert!(!samples.is_empty(), "Trictrac episode produced no samples"); - - for (i, s) in samples.iter().enumerate() { - assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len()); - assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len()); - let policy_sum: f32 = s.policy.iter().sum(); - assert!( - (policy_sum - 1.0).abs() < 1e-4, - "sample {i}: policy sums to {policy_sum}" - ); - assert!( - s.value == 1.0 || s.value == -1.0 || s.value == 0.0, - "sample {i}: unexpected value {}", - s.value - ); - } -} - -// ── 4. TrictracEnv: training step after real self-play ──────────────────── - -/// Collect one Trictrac episode, then verify that a gradient step runs -/// without panic and produces a finite loss. -#[test] -fn trictrac_train_step_finite_loss() { - let env = TrictracEnv; - let mut rng = seeded(); - let mcts = tiny_mcts(2); - let model = tiny_mlp(env.obs_size(), env.action_space()); - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(10_000); - - // Generate one episode. - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng); - assert!(!samples.is_empty()); - let n_samples = samples.len(); - replay.extend(samples); - - // Train on a batch from this episode. - let batch_size = 8.min(n_samples); - let batch: Vec = replay - .sample_batch(batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - - let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); - assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}"); - assert!(loss > 0.0, "loss should be positive"); -} - -// ── 5. Backend transfer: train → infer → same outputs ───────────────────── - -/// Weights transferred from the training backend to the inference backend -/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes. -#[test] -fn valid_model_matches_train_model_outputs() { - use burn::tensor::{Tensor, TensorData}; - - let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let train_model = MlpNet::::new(&cfg, &train_device()); - let infer_model: MlpNet = train_model.valid(); - - // Build the same input on both backends. - let obs_data: Vec = vec![0.1, 0.2, 0.3, 0.4]; - - let obs_train = Tensor::::from_data( - TensorData::new(obs_data.clone(), [1, 4]), - &train_device(), - ); - let obs_infer = Tensor::::from_data( - TensorData::new(obs_data, [1, 4]), - &infer_device(), - ); - - let (p_train, v_train) = train_model.forward(obs_train); - let (p_infer, v_infer) = infer_model.forward(obs_infer); - - let p_train: Vec = p_train.into_data().to_vec().unwrap(); - let p_infer: Vec = p_infer.into_data().to_vec().unwrap(); - let v_train: Vec = v_train.into_data().to_vec().unwrap(); - let v_infer: Vec = v_infer.into_data().to_vec().unwrap(); - - for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() { - assert!( - (a - b).abs() < 1e-5, - "policy[{i}] differs after valid(): train={a}, infer={b}" - ); - } - assert!( - (v_train[0] - v_infer[0]).abs() < 1e-5, - "value differs after valid(): train={}, infer={}", - v_train[0], v_infer[0] - ); -} - -// ── 6. Loss converges on a fixed batch ──────────────────────────────────── - -/// With repeated gradient steps on the same Countdown batch, the loss must -/// decrease monotonically (or at least end lower than it started). -#[test] -fn loss_decreases_on_fixed_batch() { - let env = CountdownEnv(6); - let mut rng = seeded(); - let mcts = tiny_mcts(3); - let model = tiny_mlp(env.obs_size(), env.action_space()); - let mut optimizer = AdamConfig::new().init(); - - // Collect a fixed batch from one episode. - let infer = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples: Vec = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng); - assert!(!samples.is_empty()); - - let batch: Vec = { - let mut replay = ReplayBuffer::new(1000); - replay.extend(samples); - replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect() - }; - - // Overfit on the same fixed batch for 20 steps. - let mut model = tiny_mlp(env.obs_size(), env.action_space()); - let mut first_loss = f32::NAN; - let mut last_loss = f32::NAN; - - for step in 0..20 { - let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2); - model = m; - assert!(loss.is_finite(), "loss is not finite at step {step}"); - if step == 0 { first_loss = loss; } - last_loss = loss; - } - - assert!( - last_loss < first_loss, - "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}" - ); -} - -// ── 7. Trictrac: multi-iteration loop ───────────────────────────────────── - -/// Two full self-play + train iterations on TrictracEnv. -/// Verifies the entire pipeline runs without panic end-to-end. -#[test] -fn trictrac_two_iteration_loop() { - let env = TrictracEnv; - let mut rng = seeded(); - let mcts = tiny_mcts(2); - - let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; - let mut model = MlpNet::::new(&cfg, &train_device()); - let mut optimizer = AdamConfig::new().init(); - let mut replay = ReplayBuffer::new(20_000); - - for iter in 0..2 { - // Self-play: 1 game per iteration. - let infer: MlpNet = model.valid(); - let eval = BurnEvaluator::::new(infer, infer_device()); - let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng); - assert!(!samples.is_empty(), "iter {iter}: episode was empty"); - replay.extend(samples); - - // Training: 3 gradient steps. - let batch_size = 16.min(replay.len()); - for _ in 0..3 { - let batch: Vec = replay - .sample_batch(batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3); - model = m; - assert!(loss.is_finite(), "iter {iter}: loss={loss}"); - } - } -} diff --git a/store/src/game.rs b/store/src/game.rs index e4e938c..f553bdb 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -225,26 +225,22 @@ impl GameState { let mut t = Vec::with_capacity(217); let pos: Vec = self.board.to_vec(); // 24 elements, positive=White, negative=Black - // [0..95] own (White) checkers, TD-Gammon encoding. - // Each field contributes 4 values: - // (count==1), (count==2), (count==3), (count-3)/12 ← all in [0,1] - // The overflow term is divided by 12 because the maximum excess is - // 15 (all checkers) − 3 = 12. + // [0..95] own (White) checkers, TD-Gammon encoding for &c in &pos { let own = c.max(0) as u8; t.push((own == 1) as u8 as f32); t.push((own == 2) as u8 as f32); t.push((own == 3) as u8 as f32); - t.push(own.saturating_sub(3) as f32 / 12.0); + t.push(own.saturating_sub(3) as f32); } - // [96..191] opp (Black) checkers, same encoding. + // [96..191] opp (Black) checkers, TD-Gammon encoding for &c in &pos { let opp = (-c).max(0) as u8; t.push((opp == 1) as u8 as f32); t.push((opp == 2) as u8 as f32); t.push((opp == 3) as u8 as f32); - t.push(opp.saturating_sub(3) as f32 / 12.0); + t.push(opp.saturating_sub(3) as f32); } // [192..193] dice @@ -1011,16 +1007,6 @@ impl GameState { self.mark_points(player_id, points) } - /// Total accumulated score for a player: `holes × 12 + points`. - /// - /// Returns `0` if `player_id` is not found (e.g. before `init_player`). - pub fn total_score(&self, player_id: PlayerId) -> i32 { - self.players - .get(&player_id) - .map(|p| p.holes as i32 * 12 + p.points as i32) - .unwrap_or(0) - } - fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { // Update player points and holes let mut new_hole = false;