From 4691a84e23cb3919c282ca04f737d50a7643d54e Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 9 Mar 2026 19:43:52 +0100 Subject: [PATCH 1/2] feat(spiel_bot): az_train parallel games with rayon --- Cargo.lock | 1 + spiel_bot/Cargo.toml | 1 + spiel_bot/src/alphazero/selfplay.rs | 4 ++++ spiel_bot/src/bin/az_train.rs | 25 +++++++++++++++++++++---- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 34bfe80..a6c9481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,6 +6009,7 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", + "rayon", "trictrac-store", ] diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 3848dce..682505b 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,6 +9,7 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } +rayon = "1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs index 6f10f8d..b38b7f4 100644 --- a/spiel_bot/src/alphazero/selfplay.rs +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -31,6 +31,10 @@ impl> BurnEvaluator { pub fn into_model(self) -> N { self.model } + + pub fn model_ref(&self) -> &N { + &self.model + } } // Safety: NdArray modules are Send; we never share across threads without diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs index ab385c2..824abe5 100644 --- a/spiel_bot/src/bin/az_train.rs +++ b/spiel_bot/src/bin/az_train.rs @@ -47,7 +47,8 @@ use burn::{ optim::AdamConfig, tensor::backend::Backend, }; -use rand::{SeedableRng, rngs::SmallRng}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; +use rayon::prelude::*; use spiel_bot::{ alphazero::{ @@ -195,10 +196,26 @@ where if step < temp_drop { 1.0 } else { 0.0 } }; + // Prepare per-game seeds and evaluators sequentially so the main RNG + // and model cloning stay deterministic regardless of 
thread scheduling.
+ // Burn modules are Send but not Sync, so each task must own its model.
+ let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
+ let game_evals: Vec<_> = (0..args.n_games)
+ .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
+ .collect();
+ drop(evaluator);
+
+ let all_samples: Vec<Vec<_>> = game_seeds
+ .into_par_iter()
+ .zip(game_evals.into_par_iter())
+ .map(|(seed, game_eval)| {
+ let mut game_rng = SmallRng::seed_from_u64(seed);
+ generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
+ })
+ .collect();
+
 let mut new_samples = 0usize;
- for _ in 0..args.n_games {
- let samples =
- generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
+ for samples in all_samples {
 new_samples += samples.len();
 replay.extend(samples);
 }

From bb6ef47a5f99c64c942ecc376a24c9e939f58c2b Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Tue, 10 Mar 2026 08:17:43 +0100
Subject: [PATCH 2/2] doc: research parallel

---
 doc/spiel_bot_parallel.md | 121 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 doc/spiel_bot_parallel.md

diff --git a/doc/spiel_bot_parallel.md b/doc/spiel_bot_parallel.md
new file mode 100644
index 0000000..d9e021e
--- /dev/null
+++ b/doc/spiel_bot_parallel.md
@@ -0,0 +1,121 @@
Part B — Batched MCTS leaf evaluation

Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.

Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)

pub trait Evaluator: Send + Sync {
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);

    /// Evaluate a batch of observations at once. Default falls back to
    /// sequential calls; backends override this for efficiency. 
+ fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
+ obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
+ }
+
}

Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)

Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.

fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
let b = obs_batch.len();
let obs_size = obs_batch[0].len();
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
let action_size = policies.len() / b;
(0..b).map(|i| {
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
}).collect()
}

Step B3 — Add eval_batch_size to MctsConfig

pub struct MctsConfig {
// ... existing fields ...
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
pub eval_batch_size: usize,
}

Default: 1 (backwards-compatible).

Step B4 — Make the simulation iterative (mcts/search.rs)

The current simulate is recursive. For batching we need to split it into two phases:

descend (pure tree traversal — no network call):

- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. 
This steers the next concurrent descent away from the same path.

ascend (backup — no network call):

- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.

Step B5 — Add run_mcts_batched to mcts/mod.rs

The new entry point, called by run_mcts when config.eval_batch_size > 1:

expand root (1 network call)
optionally add Dirichlet noise

for round in 0..(n_simulations / batch_size):
leaves = []
for _ in 0..batch_size:
leaf = descend(root, state, env, rng)
leaves.push(leaf)

    obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
                 where leaf.kind == NeedsEval]
    results = evaluator.evaluate_batch(obs_batch)

    for (leaf, result) in zip(leaves, results):
        expand the leaf node (insert children from result.policy)
        ascend(root, leaf.path, result.value, leaf.player_idx)
        // ascend also handles terminal and crossed-chance leaves

// handle remainder: n_simulations % batch_size

run_mcts becomes a thin dispatcher:
if config.eval_batch_size <= 1 {
// existing path (unchanged)
} else {
run_mcts_batched(...) 
+} + +Step B6 — CLI flag in az_train.rs + +--eval-batch N default: 8 Leaf batch size for MCTS network calls + +--- + +Summary of file changes + +┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ +│ File │ Changes │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ spiel_bot/Cargo.toml │ add rayon │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │ +├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ +│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │ +└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ + +Key design constraint + +Parts A and B are independent and composable: + +- A only touches the outer game loop. +- B only touches the inner MCTS per game. +- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.