From 4691a84e23cb3919c282ca04f737d50a7643d54e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 9 Mar 2026 19:43:52 +0100 Subject: [PATCH 1/7] feat(spiel_bot): az_train parallel games with rayon --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/alphazero/selfplay.rs | 4 ++++ spiel_bot/src/bin/az_train.rs | 25 +++++++++++++++++++++---- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 34bfe80..a6c9481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,6 +6009,7 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", + "rayon", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 3848dce..682505b 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,6 +9,7 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } +rayon = "1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs index 6f10f8d..b38b7f4 100644 --- a/spiel_bot/src/alphazero/selfplay.rs +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -31,6 +31,10 @@ impl> BurnEvaluator { pub fn into_model(self) -> N { self.model } + + pub fn model_ref(&self) -> &N { + &self.model + } } // Safety: NdArray modules are Send; we never share across threads without diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs index ab385c2..824abe5 100644 --- a/spiel_bot/src/bin/az_train.rs +++ b/spiel_bot/src/bin/az_train.rs @@ -47,7 +47,8 @@ use burn::{ optim::AdamConfig, tensor::backend::Backend, }; -use rand::{SeedableRng, rngs::SmallRng}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; +use rayon::prelude::*; use spiel_bot::{ alphazero::{ @@ -195,10 +196,26 @@ where if step < temp_drop { 1.0 } else { 0.0 } }; + // Prepare per-game seeds and evaluators sequentially so the main RNG + // and model cloning stay deterministic regardless of 
thread scheduling. + // Burn modules are Send but not Sync, so each task must own its model. + let game_seeds: Vec = (0..args.n_games).map(|_| rng.random()).collect(); + let game_evals: Vec<_> = (0..args.n_games) + .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone())) + .collect(); + drop(evaluator); + + let all_samples: Vec> = game_seeds + .into_par_iter() + .zip(game_evals.into_par_iter()) + .map(|(seed, game_eval)| { + let mut game_rng = SmallRng::seed_from_u64(seed); + generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng) + }) + .collect(); + let mut new_samples = 0usize; - for _ in 0..args.n_games { - let samples = - generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng); + for samples in all_samples { new_samples += samples.len(); replay.extend(samples); } From bb6ef47a5f99c64c942ecc376a24c9e939f58c2b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 10 Mar 2026 08:17:43 +0100 Subject: [PATCH 2/7] doc: research parallel --- doc/spiel_bot_parallel.md | 121 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 doc/spiel_bot_parallel.md diff --git a/doc/spiel_bot_parallel.md b/doc/spiel_bot_parallel.md new file mode 100644 index 0000000..d9e021e --- /dev/null +++ b/doc/spiel_bot_parallel.md @@ -0,0 +1,121 @@ +Part B — Batched MCTS leaf evaluation + +Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls. + +Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs) + +pub trait Evaluator: Send + Sync { +fn evaluate(&self, obs: &[f32]) -> (Vec, f32); + + /// Evaluate a batch of observations at once. Default falls back to + /// sequential calls; backends override this for efficiency. 
+ fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec, f32)> { + obs_batch.iter().map(|obs| self.evaluate(obs)).collect() + } + +} + +Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs) + +Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows. + +fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec, f32)> { +let b = obs_batch.len(); +let obs_size = obs_batch[0].len(); +let flat: Vec = obs_batch.iter().flat_map(|o| o.iter().copied()).collect(); +let obs_tensor = Tensor::::from_data(TensorData::new(flat, [b, obs_size]), &self.device); +let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); +let policies: Vec = policy_tensor.into_data().to_vec().unwrap(); +let values: Vec = value_tensor.into_data().to_vec().unwrap(); +let action_size = policies.len() / b; +(0..b).map(|i| { +(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i]) +}).collect() +} + +Step B3 — Add eval_batch_size to MctsConfig + +pub struct MctsConfig { +// ... existing fields ... +/// Number of leaves to batch per network call. 1 = no batching (current behaviour). +pub eval_batch_size: usize, +} + +Default: 1 (backwards-compatible). + +Step B4 — Make the simulation iterative (mcts/search.rs) + +The current simulate is recursive. For batching we need to split it into two phases: + +descend (pure tree traversal — no network call): + +- Traverse from root following PUCT, advancing through chance nodes with apply_chance. +- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect). +- Return a LeafWork { path: Vec, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance. +- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. 
This steers the next concurrent descent away from the same path. + +ascend (backup — no network call): + +- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions. +- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value. + +Step B5 — Add run_mcts_batched to mcts/mod.rs + +The new entry point, called by run_mcts when config.eval_batch_size > 1: + +expand root (1 network call) +optionally add Dirichlet noise + +for round in 0..(n*simulations / batch_size): +leaves = [] +for * in 0..batch_size: +leaf = descend(root, state, env, rng) +leaves.push(leaf) + + obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves + where leaf.kind == NeedsEval] + results = evaluator.evaluate_batch(obs_batch) + + for (leaf, result) in zip(leaves, results): + expand the leaf node (insert children from result.policy) + ascend(root, leaf.path, result.value, leaf.player_idx) + // ascend also handles terminal and crossed-chance leaves + +// handle remainder: n_simulations % batch_size + +run_mcts becomes a thin dispatcher: +if config.eval_batch_size <= 1 { +// existing path (unchanged) +} else { +run_mcts_batched(...) 
+} + +Step B6 — CLI flag in az_train.rs + +--eval-batch N default: 8 Leaf batch size for MCTS network calls + +--- + +Summary of file changes + +┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ +│ File │ Changes │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ spiel_bot/Cargo.toml │ add rayon │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │ +└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ + +Key design constraint + +Parts A and B are independent and composable: + +- A only touches the outer game loop. +- B only touches the inner MCTS per game. +- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state. 
From 7c0f230e3de58319bc26558aed5ca3153d5a58fe Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 10 Mar 2026 08:19:24 +0100 Subject: [PATCH 3/7] doc: tensor research --- doc/tensor_research.md | 253 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 doc/tensor_research.md diff --git a/doc/tensor_research.md b/doc/tensor_research.md new file mode 100644 index 0000000..b0d0ede --- /dev/null +++ b/doc/tensor_research.md @@ -0,0 +1,253 @@ +# Tensor research + +## Current tensor anatomy + +[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!) +[24] active player color: 0 or 1 +[25] turn_stage: 1–5 +[26–27] dice values (raw 1–6) +[28–31] white: points, holes, can_bredouille, can_big_bredouille +[32–35] black: same +───────────────────────────────── +Total 36 floats + +The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's +AlphaZero uses a fully-connected network. + +### Fundamental problems with the current encoding + +1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network + must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with + all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from. + +2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient + flow during training is uneven. + +3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it + triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers + produces a score. Including this explicitly is the single highest-value addition. + +4. Exit readiness is absent. 
Whether all own checkers are in the last quarter (fields 19–24) governs an entirely + different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0. + +5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the + starting position). It's in the Player struct but not exported. + +## Key Trictrac distinctions from backgammon that shape the encoding + +| Concept | Backgammon | Trictrac | +| ------------------------- | ---------------------- | --------------------------------------------------------- | +| Hitting a blot | Removes checker to bar | Scores points, checker stays | +| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened | +| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) | +| 3-checker field | Safe with spare | Safe with spare | +| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) | +| Both colors on a field | Impossible | Perfectly legal | +| Rest corner (field 12/13) | Does not exist | Special two-checker rules | + +The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary +indicators directly teaches the network the phase transitions the game hinges on. + +## Options + +### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D) + +The minimum viable improvement. + +For each of the 24 fields, encode own and opponent separately with 4 indicators each: + +own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target) +own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill) +own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare) +own_x[i]: max(0, count − 3) (overflow) +opp_1[i]: same for opponent +… + +Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec(). 
+ +Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204 +Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured +overhead is negligible. +Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from +scratch will be substantially faster and the converged policy quality better. + +### Option B — Option A + Trictrac-specific derived features (flat 1D) + +Recommended starting point. + +Add on top of Option A: + +// Quarter fill status — the single most important derived feature +quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields) +quarter_filled_opp[q] (q=0..3): same for opponent +→ 8 values + +// Exit readiness +can_exit_own: 1.0 if all own checkers are in fields 19–24 +can_exit_opp: same for opponent +→ 2 values + +// Rest corner status (field 12/13) +own_corner_taken: 1.0 if field 12 has ≥2 own checkers +opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers +→ 2 values + +// Jan de 3 coups counter (normalized) +dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0) +→ 1 value + +Size: 204 + 8 + 2 + 2 + 1 = 217 +Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve +the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes +expensive inference from the network. + +### Option C — Option B + richer positional features (flat 1D) + +More complete, higher sample efficiency, minor extra cost. 
+ +Add on top of Option B: + +// Per-quarter fill fraction — how close to filling each quarter +own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0 +opp_quarter_fill_fraction[q] (q=0..3): same for opponent +→ 8 values + +// Blot counts — number of own/opponent single-checker fields globally +// (tells the network at a glance how much battage risk/opportunity exists) +own_blot_count: (number of own fields with exactly 1 checker) / 15.0 +opp_blot_count: same for opponent +→ 2 values + +// Bredouille would-double multiplier (already present, but explicitly scaled) +// No change needed, already binary + +Size: 217 + 8 + 2 = 227 +Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network +from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts). + +### Option D — 2D spatial tensor {K, 24} + +For CNN-based networks. Best eventual architecture but requires changing the training setup. + +Shape {14, 24} — 14 feature channels over 24 field positions: + +Channel 0: own_count_1 (blot) +Channel 1: own_count_2 +Channel 2: own_count_3 +Channel 3: own_count_overflow (float) +Channel 4: opp_count_1 +Channel 5: opp_count_2 +Channel 6: opp_count_3 +Channel 7: opp_count_overflow +Channel 8: own_corner_mask (1.0 at field 12) +Channel 9: opp_corner_mask (1.0 at field 13) +Channel 10: final_quarter_mask (1.0 at fields 19–24) +Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter) +Channel 12: quarter_filled_opp (same for opponent) +Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers) + +Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform +value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global" +row by returning shape {K, 25} with position 0 holding global features. 
+ +Size: 14 × 24 + few global channels ≈ 336–384 +C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated +accordingly. +Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's +alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H, +W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension. +Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel +size 2–3 captures adjacent-field "tout d'une" interactions. + +### On 3D tensors + +Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is +the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and +field-within-quarter patterns jointly. + +However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The +extra architecture work makes this premature unless you're already building a custom network. The information content +is the same as Option D. + +### Recommendation + +Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and +the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions +(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move. + +Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when +the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time. 
+ +Costs summary: + +| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain | +| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- | +| Current | 36 | — | — | — | baseline | +| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) | +| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) | +| C | 227 | to_vec() rewrite | constant update | none | large + moderate | +| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial | + +One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2, +the new to_vec() must express everything from the active player's perspective (which the mirror already handles for +the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state, +which they will be if computed inside to_vec(). + +## Other algorithms + +The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully. + +### 1. Without MCTS, feature quality matters more + +AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred +simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup — +the network must learn the full strategic structure directly from gradient updates. + +This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO, +not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less +sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from +scratch. + +### 2. 
Normalization becomes mandatory, not optional + +AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps +Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw +values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can +be unstable. + +For DQN/PPO, every feature in the tensor should be in [0, 1]: + +dice values: / 6.0 +checker counts: overflow channel / 12.0 +points: / 12.0 +holes: / 12.0 +dice_roll_count: / 3.0 (clamped) + +Booleans and the TD-Gammon binary indicators are already in [0, 1]. + +### 3. The shape question depends on architecture, not algorithm + +| Architecture | Shape | When to use | +| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- | +| MLP | {217} flat | Any algorithm, simplest baseline | +| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) | +| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network | +| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now | + +The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want +convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet +template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all +— you just reshape the flat tensor in Python before passing it to the network. + +### One thing that genuinely changes: value function perspective + +AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This +works well. 
+ +DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit +current-player indicator), because a single Q-network estimates action values for both players simultaneously. With +the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must +learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity. +This is a minor point for a symmetric game like Trictrac, but worth keeping in mind. + +Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm. From e7d13c9a02480da812e13ecb762440e3937d3e7f Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 10 Mar 2026 22:12:52 +0100 Subject: [PATCH 4/7] feat(spiel_bot): dqn --- spiel_bot/src/bin/dqn_train.rs | 251 +++++++++++++++++++++++++++++ spiel_bot/src/dqn/episode.rs | 247 +++++++++++++++++++++++++++++ spiel_bot/src/dqn/mod.rs | 232 +++++++++++++++++++++++++++ spiel_bot/src/dqn/trainer.rs | 278 +++++++++++++++++++++++++++++++++ spiel_bot/src/env/trictrac.rs | 12 ++ spiel_bot/src/lib.rs | 1 + spiel_bot/src/network/mod.rs | 14 ++ spiel_bot/src/network/qnet.rs | 147 +++++++++++++++++ store/src/game.rs | 10 ++ 9 files changed, 1192 insertions(+) create mode 100644 spiel_bot/src/bin/dqn_train.rs create mode 100644 spiel_bot/src/dqn/episode.rs create mode 100644 spiel_bot/src/dqn/mod.rs create mode 100644 spiel_bot/src/dqn/trainer.rs create mode 100644 spiel_bot/src/network/qnet.rs diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs new file mode 100644 index 0000000..0ebe978 --- /dev/null +++ b/spiel_bot/src/bin/dqn_train.rs @@ -0,0 +1,251 @@ +//! DQN self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh with default settings +//! 
cargo run -p spiel_bot --bin dqn_train --release +//! +//! # Custom hyperparameters +//! cargo run -p spiel_bot --bin dqn_train --release -- \ +//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000 +//! +//! # Resume from a checkpoint +//! cargo run -p spiel_bot --bin dqn_train --release -- \ +//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--hidden N` | 256 | Hidden layer width | +//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | +//! | `--n-iter N` | 100 | Training iterations | +//! | `--n-games N` | 10 | Self-play games per iteration | +//! | `--n-train N` | 20 | Gradient steps per iteration | +//! | `--batch N` | 64 | Mini-batch size | +//! | `--replay-cap N` | 50000 | Replay buffer capacity | +//! | `--lr F` | 1e-3 | Adam learning rate | +//! | `--epsilon-start F` | 1.0 | Initial exploration rate | +//! | `--epsilon-end F` | 0.05 | Final exploration rate | +//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor | +//! | `--gamma F` | 0.99 | Discount factor | +//! | `--target-update N` | 500 | Hard-update target net every N steps | +//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) | +//! | `--save-every N` | 10 | Save checkpoint every N iterations | +//! | `--seed N` | 42 | RNG seed | +//! 
| `--resume PATH` | (none) | Load weights before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + dqn::{ + DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, + generate_dqn_episode, hard_update, linear_epsilon, + }, + env::TrictracEnv, + network::{QNet, QNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + hidden: usize, + out_dir: PathBuf, + save_every: usize, + seed: u64, + resume: Option, + config: DqnConfig, +} + +impl Default for Args { + fn default() -> Self { + Self { + hidden: 256, + out_dir: PathBuf::from("checkpoints"), + save_every: 10, + seed: 42, + resume: None, + config: DqnConfig::default(), + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); } + "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } + "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } + "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } + 
"--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } + "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } + "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } + "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } + "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +fn train_loop( + mut q_net: QNet, + cfg: &QNetConfig, + save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, + args: &Args, +) { + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + let mut optimizer = AdamConfig::new().init(); + let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + let mut target_net: QNet = hard_update::(&q_net); + let mut global_step = 0usize; + let mut epsilon = args.config.epsilon_start; + + println!( + "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", + "", args.config.n_iterations, args.config.n_games_per_iter, + args.config.n_train_steps_per_iter, "" + ); + + for iter in 0..args.config.n_iterations { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + let infer_q: QNet = q_net.valid(); + let mut new_samples = 0usize; + + for _ in 0..args.config.n_games_per_iter { + let 
samples = generate_dqn_episode( + &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, + ); + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.config.batch_size { + for _ in 0..args.config.n_train_steps_per_iter { + let batch: Vec<_> = replay + .sample_batch(args.config.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + // Target Q-values computed on the inference backend. + let target_q = compute_target_q( + &target_net, &batch, cfg.action_size, &infer_device, + ); + + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &target_q, + &train_device, args.config.learning_rate, args.config.gamma, + ); + q_net = q; + loss_sum += loss; + n_steps += 1; + global_step += 1; + + // Hard-update target net every target_update_freq steps. + if global_step % args.config.target_update_freq == 0 { + target_net = hard_update::(&q_net); + } + + // Linear epsilon decay. 
+ epsilon = linear_epsilon( + args.config.epsilon_start, + args.config.epsilon_end, + global_step, + args.config.epsilon_decay_steps, + ); + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", + iter + 1, + args.config.n_iterations, + replay.len(), + new_samples, + avg_loss, + epsilon, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.config.n_iterations; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); + match save_fn(&q_net, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nDQN training complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; + + let q_net = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + QNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => QNet::::new(&cfg, &train_device), + }; + + train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); +} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs new file mode 100644 index 0000000..aca1343 --- /dev/null +++ b/spiel_bot/src/dqn/episode.rs @@ -0,0 +1,247 @@ +//! 
DQN self-play episode generation. +//! +//! Both players share the same Q-network (the [`TrictracEnv`] handles board +//! mirroring so that each player always acts from "White's perspective"). +//! Transitions for both players are stored in the returned sample list. +//! +//! # Reward +//! +//! After each full decision (action applied and the state has advanced through +//! any intervening chance nodes back to the same player's next turn), the +//! reward is: +//! +//! ```text +//! r = (my_total_score_now − my_total_score_then) +//! − (opp_total_score_now − opp_total_score_then) +//! ``` +//! +//! where `total_score = holes × 12 + points`. +//! +//! # Transition structure +//! +//! We use a "pending transition" per player. When a player acts again, we +//! *complete* the previous pending transition by filling in `next_obs`, +//! `next_legal`, and computing `reward`. Terminal transitions are completed +//! when the game ends. + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::{GameEnv, TrictracEnv}; +use crate::network::QValueNet; +use super::DqnSample; + +// ── Internals ───────────────────────────────────────────────────────────────── + +struct PendingTransition { + obs: Vec, + action: usize, + /// Score snapshot `[p1_total, p2_total]` at the moment of the action. + score_before: [i32; 2], +} + +/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. 
+fn epsilon_greedy>( + q_net: &Q, + obs: &[f32], + legal: &[usize], + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, +) -> usize { + debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); + if rng.random::() < epsilon { + legal[rng.random_range(0..legal.len())] + } else { + let obs_tensor = Tensor::::from_data( + TensorData::new(obs.to_vec(), [1, obs.len()]), + device, + ); + let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); + legal + .iter() + .copied() + .max_by(|&a, &b| { + q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap() + } +} + +/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. +fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { + let opp_idx = 1 - player_idx; + ((score_after[player_idx] - score_before[player_idx]) + - (score_after[opp_idx] - score_before[opp_idx])) as f32 +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Play one full game and return all transitions for both players. +/// +/// - `q_net` uses the **inference backend** (no autodiff wrapper). +/// - `epsilon` in `[0, 1]`: probability of taking a random action. +/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). +pub fn generate_dqn_episode>( + env: &TrictracEnv, + q_net: &Q, + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, + reward_scale: f32, +) -> Vec { + let obs_size = env.obs_size(); + let mut state = env.new_game(); + let mut pending: [Option; 2] = [None, None]; + let mut samples: Vec = Vec::new(); + + loop { + // ── Advance past chance nodes ────────────────────────────────────── + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + let score_now = TrictracEnv::score_snapshot(&state); + + if env.current_player(&state).is_terminal() { + // Complete all pending transitions as terminal. 
+ for player_idx in 0..2 { + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: vec![0.0; obs_size], + next_legal: vec![], + done: true, + }); + } + } + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + let legal = env.legal_actions(&state); + let obs = env.observation(&state, player_idx); + + // ── Complete the previous transition for this player ─────────────── + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: obs.clone(), + next_legal: legal.clone(), + done: false, + }); + } + + // ── Pick and apply action ────────────────────────────────────────── + let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); + env.apply(&mut state, action); + + // ── Record new pending transition ────────────────────────────────── + pending[player_idx] = Some(PendingTransition { + obs, + action, + score_before: score_now, + }); + } + + samples +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::network::{QNet, QNetConfig}; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } + + fn tiny_q() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn episode_terminates_and_produces_samples() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn 
episode_obs_size_correct() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert_eq!(s.obs.len(), 217, "obs size mismatch"); + if s.done { + assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); + assert!(s.next_legal.is_empty()); + } else { + assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); + assert!(!s.next_legal.is_empty()); + } + } + } + + #[test] + fn episode_actions_within_action_space() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert!(s.action < 514, "action {} out of bounds", s.action); + } + } + + #[test] + fn greedy_episode_also_terminates() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty()); + } + + #[test] + fn at_least_one_done_sample() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + let n_done = samples.iter().filter(|s| s.done).count(); + // Two players, so 1 or 2 terminal transitions. + assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); + } + + #[test] + fn compute_reward_correct() { + // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. + let before = [2 * 12 + 10, 0]; + let after = [3 * 12 + 2, 0]; + let r = compute_reward(0, &before, &after); + assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); + } + + #[test] + fn compute_reward_with_opponent_scoring() { + // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. 
+ let before = [0, 0]; + let after = [2, 3]; + let r = compute_reward(0, &before, &after); + assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); + } +} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs new file mode 100644 index 0000000..8c34fc1 --- /dev/null +++ b/spiel_bot/src/dqn/mod.rs @@ -0,0 +1,232 @@ +//! DQN: self-play data generation, replay buffer, and training step. +//! +//! # Algorithm +//! +//! Deep Q-Network with: +//! - **ε-greedy** exploration (linearly decayed). +//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where +//! `score = holes × 12 + points`. +//! - **Experience replay** with a fixed-capacity circular buffer. +//! - **Target network**: hard-copied from the online Q-net every +//! `target_update_freq` gradient steps for training stability. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | +//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | + +pub mod episode; +pub mod trainer; + +pub use episode::generate_dqn_episode; +pub use trainer::{compute_target_q, dqn_train_step, hard_update}; + +use std::collections::VecDeque; +use rand::Rng; + +// ── DqnSample ───────────────────────────────────────────────────────────────── + +/// One transition `(s, a, r, s', done)` collected during self-play. +#[derive(Clone, Debug)] +pub struct DqnSample { + /// Observation from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// Action index taken. + pub action: usize, + /// Per-turn reward: `my_score_delta − opponent_score_delta`. + pub reward: f32, + /// Next observation from the same player's perspective. + /// All-zeros when `done = true` (ignored by the TD target). + pub next_obs: Vec, + /// Legal actions at `next_obs`. Empty when `done = true`. + pub next_legal: Vec, + /// `true` when `next_obs` is a terminal state. 
+ pub done: bool, +} + +// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── + +/// Fixed-capacity circular replay buffer for [`DqnSample`]s. +/// +/// When full, the oldest sample is evicted on push. +/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. +pub struct DqnReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl DqnReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } + } + + pub fn push(&mut self, sample: DqnSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { self.push(s); } + } + + pub fn len(&self) -> usize { self.data.len() } + pub fn is_empty(&self) -> bool { self.data.is_empty() } + + /// Sample up to `n` distinct samples without replacement. + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { + let len = self.data.len(); + let n = n.min(len); + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── DqnConfig ───────────────────────────────────────────────────────────────── + +/// Top-level DQN hyperparameters for the training loop. +#[derive(Debug, Clone)] +pub struct DqnConfig { + /// Initial exploration rate (1.0 = fully random). + pub epsilon_start: f32, + /// Final exploration rate after decay. + pub epsilon_end: f32, + /// Number of gradient steps over which ε decays linearly from start to end. + /// + /// Should be calibrated to the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). A value larger than that + /// means exploration never reaches `epsilon_end` during the run. + pub epsilon_decay_steps: usize, + /// Discount factor γ for the TD target. Typical: 0.99. 
+ pub gamma: f32, + /// Hard-copy Q → target every this many gradient steps. + /// + /// Should be much smaller than the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). + pub target_update_freq: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Number of outer iterations (self-play + train). + pub n_iterations: usize, + /// Self-play games per iteration. + pub n_games_per_iter: usize, + /// Gradient steps per iteration. + pub n_train_steps_per_iter: usize, + /// Reward normalisation divisor. + /// + /// Per-turn rewards (score delta) are divided by this constant before being + /// stored. Without normalisation, rewards can reach ±24 (jan with + /// bredouille = 12 pts × 2), driving Q-values into the hundreds and + /// causing MSE loss to grow unboundedly. + /// + /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping + /// Q-value magnitudes in a stable range. Set to `1.0` to disable. + pub reward_scale: f32, +} + +impl Default for DqnConfig { + fn default() -> Self { + // Total gradient steps with these defaults = 500 × 20 = 10_000, + // so epsilon decays fully and the target is updated 100 times. + Self { + epsilon_start: 1.0, + epsilon_end: 0.05, + epsilon_decay_steps: 10_000, + gamma: 0.99, + target_update_freq: 100, + learning_rate: 1e-3, + batch_size: 64, + replay_capacity: 50_000, + n_iterations: 500, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + reward_scale: 12.0, + } + } +} + +/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. 
+pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { + if decay_steps == 0 || step >= decay_steps { + return end; + } + start + (end - start) * (step as f32 / decay_steps as f32) +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(reward: f32) -> DqnSample { + DqnSample { + obs: vec![0.0], + action: 0, + reward, + next_obs: vec![0.0], + next_legal: vec![0], + done: false, + } + } + + #[test] + fn push_and_len() { + let mut buf = DqnReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = DqnReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); + assert_eq!(buf.len(), 3); + assert_eq!(buf.data[0].reward, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = DqnReplayBuffer::new(20); + for i in 0..10 { buf.push(dummy(i as f32)); } + let mut rng = SmallRng::seed_from_u64(0); + assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); + } + + #[test] + fn linear_epsilon_start() { + assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_end() { + assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_monotone() { + let mut prev = f32::INFINITY; + for step in 0..=100 { + let e = linear_epsilon(1.0, 0.05, step, 100); + assert!(e <= prev + 1e-6); + prev = e; + } + } +} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs new file mode 100644 index 0000000..b8b0a02 --- /dev/null +++ b/spiel_bot/src/dqn/trainer.rs @@ -0,0 +1,278 @@ +//! DQN gradient step and target-network management. +//! +//! # TD target +//! +//! ```text +//! 
y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done +//! y_i = r_i if done +//! ``` +//! +//! # Loss +//! +//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net) +//! and `y_i` (computed from the frozen target net). +//! +//! # Target network +//! +//! [`hard_update`] copies the online Q-net weights into the target net by +//! stripping the autodiff wrapper via [`AutodiffModule::valid`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + Int, Tensor, TensorData, + backend::{AutodiffBackend, Backend}, + }, +}; + +use crate::network::QValueNet; +use super::DqnSample; + +// ── Target Q computation ───────────────────────────────────────────────────── + +/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample. +/// +/// Returns a `Vec` of length `batch.len()`. Done samples get `0.0` +/// (their bootstrap term is dropped by the TD target anyway). +/// +/// The target network runs on the **inference backend** (`InferB`) with no +/// gradient tape, so this function is backend-agnostic (`B: Backend`). +pub fn compute_target_q>( + target_net: &Q, + batch: &[DqnSample], + action_size: usize, + device: &B::Device, +) -> Vec { + let batch_size = batch.len(); + + // Collect indices of non-done samples (done samples have no next state). + let non_done: Vec = batch + .iter() + .enumerate() + .filter(|(_, s)| !s.done) + .map(|(i, _)| i) + .collect(); + + if non_done.is_empty() { + return vec![0.0; batch_size]; + } + + let obs_size = batch[0].next_obs.len(); + let nd = non_done.len(); + + // Stack next observations for non-done samples → [nd, obs_size]. + let obs_flat: Vec = non_done + .iter() + .flat_map(|&i| batch[i].next_obs.iter().copied()) + .collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [nd, obs_size]), + device, + ); + + // Forward target net → [nd, action_size], then to Vec. 
+ let q_flat: Vec = target_net.forward(obs_tensor).into_data().to_vec().unwrap(); + + // For each non-done sample, pick max Q over legal next actions. + let mut result = vec![0.0f32; batch_size]; + for (k, &i) in non_done.iter().enumerate() { + let legal = &batch[i].next_legal; + let offset = k * action_size; + let max_q = legal + .iter() + .map(|&a| q_flat[offset + a]) + .fold(f32::NEG_INFINITY, f32::max); + // If legal is empty (shouldn't happen for non-done, but be safe): + result[i] = if max_q.is_finite() { max_q } else { 0.0 }; + } + result +} + +// ── Training step ───────────────────────────────────────────────────────────── + +/// Run one gradient step on `q_net` using `batch`. +/// +/// `target_max_q` must be pre-computed via [`compute_target_q`] using the +/// frozen target network and passed in here so that this function only +/// needs the **autodiff backend**. +/// +/// Returns the updated network and the scalar MSE loss. +pub fn dqn_train_step( + q_net: Q, + optimizer: &mut O, + batch: &[DqnSample], + target_max_q: &[f32], + device: &B::Device, + lr: f64, + gamma: f32, +) -> (Q, f32) +where + B: AutodiffBackend, + Q: QValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "dqn_train_step: empty batch"); + assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + + // ── Build observation tensor [B, obs_size] ──────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + + // ── Forward Q-net → [B, action_size] ───────────────────────────────── + let q_all = q_net.forward(obs_tensor); + + // ── Gather Q(s, a) for the taken action → [B] ──────────────────────── + let actions: Vec = batch.iter().map(|s| s.action as i32).collect(); + let action_tensor: Tensor = 
Tensor::::from_data( + TensorData::new(actions, [batch_size]), + device, + ) + .reshape([batch_size, 1]); // [B] → [B, 1] + let q_pred: Tensor = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B] + + // ── TD targets: r + γ · max_next_q · (1 − done) ────────────────────── + let targets: Vec = batch + .iter() + .zip(target_max_q.iter()) + .map(|(s, &max_q)| { + if s.done { s.reward } else { s.reward + gamma * max_q } + }) + .collect(); + let target_tensor = Tensor::::from_data( + TensorData::new(targets, [batch_size]), + device, + ); + + // ── MSE loss ────────────────────────────────────────────────────────── + let diff = q_pred - target_tensor.detach(); + let loss = (diff.clone() * diff).mean(); + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &q_net); + let q_net = optimizer.step(lr, q_net, grads); + + (q_net, loss_scalar) +} + +// ── Target network update ───────────────────────────────────────────────────── + +/// Hard-copy the online Q-net weights to a new target network. +/// +/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an +/// inference-backend module with identical weights. 
+pub fn hard_update>(q_net: &Q) -> Q::InnerModule { + q_net.valid() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + use crate::network::{QNet, QNetConfig}; + + type InferB = NdArray; + type TrainB = Autodiff>; + + fn infer_device() -> ::Device { Default::default() } + fn train_device() -> ::Device { Default::default() } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| DqnSample { + obs: vec![0.5f32; obs_size], + action: i % action_size, + reward: if i % 2 == 0 { 1.0 } else { -1.0 }, + next_obs: vec![0.5f32; obs_size], + next_legal: vec![0, 1], + done: i == n - 1, + }) + .collect() + } + + #[test] + fn compute_target_q_length() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 8); + } + + #[test] + fn compute_target_q_done_is_zero() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + // Single done sample. 
+ let batch = vec![DqnSample { + obs: vec![0.0; 4], + action: 0, + reward: 5.0, + next_obs: vec![0.0; 4], + next_legal: vec![], + done: true, + }]; + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 1); + assert_eq!(tq[0], 0.0); + } + + #[test] + fn train_step_returns_finite_loss() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + + #[test] + fn train_step_loss_decreases() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(16, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, + ); + q_net = q; + assert!(loss.is_finite()); + prev_loss = loss; + } + assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn hard_update_copies_weights() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = hard_update::(&q_net); + + let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); + let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); + // After hard_update the target produces finite outputs. 
+ assert!(q_out.iter().all(|v| v.is_finite())); + } +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs index 99ba058..8dc3676 100644 --- a/spiel_bot/src/env/trictrac.rs +++ b/spiel_bot/src/env/trictrac.rs @@ -200,6 +200,18 @@ impl GameEnv for TrictracEnv { } } +// ── DQN helpers ─────────────────────────────────────────────────────────────── + +impl TrictracEnv { + /// Score snapshot for DQN reward computation. + /// + /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. + /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). + pub fn score_snapshot(s: &GameState) -> [i32; 2] { + [s.total_score(1), s.total_score(2)] + } +} + // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 23895b9..9dfb4de 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,4 +1,5 @@ pub mod alphazero; +pub mod dqn; pub mod env; pub mod mcts; pub mod network; diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs index df710e9..64f93ec 100644 --- a/spiel_bot/src/network/mod.rs +++ b/spiel_bot/src/network/mod.rs @@ -43,9 +43,11 @@ //! before passing to softmax. pub mod mlp; +pub mod qnet; pub mod resnet; pub use mlp::{MlpConfig, MlpNet}; +pub use qnet::{QNet, QNetConfig}; pub use resnet::{ResNet, ResNetConfig}; use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; @@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; /// - `obs`: `[batch, obs_size]` /// - policy output: `[batch, action_size]` — raw logits (no softmax applied) /// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) +/// /// Note: `Sync` is intentionally absent — Burn's `Module` internally uses /// `OnceCell` for lazy parameter initialisation, which is not `Sync`. /// Use an `Arc>` wrapper if cross-thread sharing is needed. 
pub trait PolicyValueNet: Module + Send + 'static { fn forward(&self, obs: Tensor) -> (Tensor, Tensor); } + +/// A neural network that outputs one Q-value per action. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - output: `[batch, action_size]` — raw Q-values (no activation) +/// +/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`]. +pub trait QValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> Tensor; +} diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs new file mode 100644 index 0000000..1737f72 --- /dev/null +++ b/spiel_bot/src/network/qnet.rs @@ -0,0 +1,147 @@ +//! Single-headed Q-value network for DQN. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! → Linear(hidden → action_size) ← raw Q-values, no activation +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{activation::relu, backend::Backend, Tensor}, +}; +use std::path::Path; + +use super::QValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`QNet`]. +#[derive(Debug, Clone)] +pub struct QNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for QNetConfig { + fn default() -> Self { + Self { obs_size: 217, action_size: 514, hidden_size: 256 } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Two-hidden-layer MLP that outputs one Q-value per action. +#[derive(Module, Debug)] +pub struct QNet { + fc1: Linear, + fc2: Linear, + q_head: Linear, +} + +impl QNet { + /// Construct a fresh network with random weights. 
+ pub fn new(config: &QNetConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl QValueNet for QNet { + fn forward(&self, obs: Tensor) -> Tensor { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + self.q_head.forward(x) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + + fn default_net() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn forward_output_shape() { + let net = default_net(); + let obs = Tensor::zeros([4, 217], &device()); + let q = net.forward(obs); + assert_eq!(q.dims(), [4, 514]); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let q = net.forward(Tensor::zeros([1, 217], &device())); + assert_eq!(q.dims(), [1, 514]); + } + + #[test] + fn q_values_not_all_equal() { + let net = default_net(); + let q: Vec = net.forward(Tensor::zeros([1, 217], &device())) + .into_data().to_vec().unwrap(); + let first = q[0]; + 
assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6)); + } + + #[test] + fn custom_config_shapes() { + let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 }; + let net = QNet::::new(&cfg, &device()); + let q = net.forward(Tensor::zeros([3, 10], &device())); + assert_eq!(q.dims(), [3, 20]); + } + + #[test] + fn save_load_preserves_weights() { + let net = default_net(); + let obs = Tensor::::ones([2, 217], &device()); + let q_before: Vec = net.forward(obs.clone()).into_data().to_vec().unwrap(); + + let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk"); + net.save(&path).expect("save failed"); + + let loaded = QNet::::load(&QNetConfig::default(), &path, &device()).expect("load failed"); + let q_after: Vec = loaded.forward(obs).into_data().to_vec().unwrap(); + + for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}"); + } + let _ = std::fs::remove_file(path); + } +} diff --git a/store/src/game.rs b/store/src/game.rs index 2fde45c..e4e938c 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -1011,6 +1011,16 @@ impl GameState { self.mark_points(player_id, points) } + /// Total accumulated score for a player: `holes × 12 + points`. + /// + /// Returns `0` if `player_id` is not found (e.g. before `init_player`). 
+ pub fn total_score(&self, player_id: PlayerId) -> i32 { + self.players + .get(&player_id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + } + fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { // Update player points and holes let mut new_hole = false; From e80dade303e983dc66599ad8d5cbd81e6da955a4 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 11 Mar 2026 22:17:03 +0100 Subject: [PATCH 5/7] fix: --n-sim training parameter --- spiel_bot/src/mcts/mod.rs | 8 ++++---- spiel_bot/src/mcts/search.rs | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index a0a690d..eead171 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -403,10 +403,10 @@ mod tests { let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); // root.n = 1 (expansion) + n_simulations (one backup per simulation). assert_eq!(root.n, 1 + config.n_simulations as u32); - // Children visit counts may sum to less than n_simulations when some - // simulations cross a chance node at depth 1 (turn ends after one move) - // and evaluate with the network directly without updating child.n. + // Every simulation crosses a chance node at depth 1 (dice roll after + // the player's move). Since the fix now updates child.n in that case, + // children visit counts must sum to exactly n_simulations. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert!(total <= config.n_simulations as u32); + assert_eq!(total, config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 55db701..1d9750d 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -166,6 +166,12 @@ pub(super) fn simulate( // previously cached children would be for a different outcome. 
let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); + // Record the visit so that PUCT and mcts_policy use real counts. + // Without this, child.n stays 0 for every simulation in games where + // every player action is immediately followed by a chance node (e.g. + // Trictrac), causing mcts_policy to always return a uniform policy. + child.n += 1; + child.w += value; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) From 31bb568c2a51f51abf0e962019d3e8360b7b281f Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 9 Mar 2026 19:43:52 +0100 Subject: [PATCH 6/7] feat(spiel_bot): az_train parallel games with rayon --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/alphazero/selfplay.rs | 4 ++++ spiel_bot/src/bin/az_train.rs | 25 +++++++++++++++++++++---- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 34bfe80..a6c9481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,6 +6009,7 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", + "rayon", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 3848dce..682505b 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,6 +9,7 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } +rayon = "1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs index 6f10f8d..b38b7f4 100644 --- a/spiel_bot/src/alphazero/selfplay.rs +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -31,6 +31,10 @@ impl> BurnEvaluator { pub fn into_model(self) -> N { self.model } + + pub fn model_ref(&self) -> &N { + &self.model + } } // Safety: NdArray modules are Send; we never share across threads without diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs index 
ab385c2..824abe5 100644 --- a/spiel_bot/src/bin/az_train.rs +++ b/spiel_bot/src/bin/az_train.rs @@ -47,7 +47,8 @@ use burn::{ optim::AdamConfig, tensor::backend::Backend, }; -use rand::{SeedableRng, rngs::SmallRng}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; +use rayon::prelude::*; use spiel_bot::{ alphazero::{ @@ -195,10 +196,26 @@ where if step < temp_drop { 1.0 } else { 0.0 } }; + // Prepare per-game seeds and evaluators sequentially so the main RNG + // and model cloning stay deterministic regardless of thread scheduling. + // Burn modules are Send but not Sync, so each task must own its model. + let game_seeds: Vec = (0..args.n_games).map(|_| rng.random()).collect(); + let game_evals: Vec<_> = (0..args.n_games) + .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone())) + .collect(); + drop(evaluator); + + let all_samples: Vec> = game_seeds + .into_par_iter() + .zip(game_evals.into_par_iter()) + .map(|(seed, game_eval)| { + let mut game_rng = SmallRng::seed_from_u64(seed); + generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng) + }) + .collect(); + let mut new_samples = 0usize; - for _ in 0..args.n_games { - let samples = - generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng); + for samples in all_samples { new_samples += samples.len(); replay.extend(samples); } From 1554286f25f887e8cb8bedd4c41c7d7e9e2eeba0 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Tue, 10 Mar 2026 08:17:43 +0100 Subject: [PATCH 7/7] doc: research parallel --- doc/spiel_bot_parallel.md | 121 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 doc/spiel_bot_parallel.md diff --git a/doc/spiel_bot_parallel.md b/doc/spiel_bot_parallel.md new file mode 100644 index 0000000..d9e021e --- /dev/null +++ b/doc/spiel_bot_parallel.md @@ -0,0 +1,121 @@ +Part B — Batched MCTS leaf evaluation + +Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the 
network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
+
+Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
+
+pub trait Evaluator: Send + Sync {
+fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
+
+    /// Evaluate a batch of observations at once. Default falls back to
+    /// sequential calls; backends override this for efficiency.
+    fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
+        obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
+    }
+
+}
+
+Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
+
+Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
+
+fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
+let b = obs_batch.len();
+let obs_size = obs_batch[0].len();
+let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
+let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
+let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
+let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
+let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
+let action_size = policies.len() / b;
+(0..b).map(|i| {
+(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
+}).collect()
+}
+
+Step B3 — Add eval_batch_size to MctsConfig
+
+pub struct MctsConfig {
+// ... existing fields ...
+/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
+pub eval_batch_size: usize,
+}
+
+Default: 1 (backwards-compatible).
+
+Step B4 — Make the simulation iterative (mcts/search.rs)
+
+The current simulate is recursive. For batching we need to split it into two phases:
+
+descend (pure tree traversal — no network call):
+
+- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
+- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
+- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
+- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
+
+ascend (backup — no network call):
+
+- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
+- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
+
+Step B5 — Add run_mcts_batched to mcts/mod.rs
+
+The new entry point, called by run_mcts when config.eval_batch_size > 1:
+
+expand root (1 network call)
+optionally add Dirichlet noise
+
+for round in 0..(n_simulations / batch_size):
+leaves = []
+for _ in 0..batch_size:
+leaf = descend(root, state, env, rng)
+leaves.push(leaf)
+
+    obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
+                 where leaf.kind == NeedsEval]
+    results = evaluator.evaluate_batch(obs_batch)
+
+    for (leaf, result) in zip(leaves, results):
+        expand the leaf node (insert children from result.policy)
+        ascend(root, leaf.path, result.value, leaf.player_idx)
+        // ascend also handles terminal and crossed-chance leaves
+
+// handle remainder: n_simulations % batch_size
+
+run_mcts becomes a thin dispatcher:
+if config.eval_batch_size <= 1 {
+// existing path (unchanged)
+} else {
+run_mcts_batched(...)
+} + +Step B6 — CLI flag in az_train.rs + +--eval-batch N default: 8 Leaf batch size for MCTS network calls + +--- + +Summary of file changes + +┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ +│ File │ Changes │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ spiel_bot/Cargo.toml │ add rayon │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │ +└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ + +Key design constraint + +Parts A and B are independent and composable: + +- A only touches the outer game loop. +- B only touches the inner MCTS per game. +- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.