diff --git a/Cargo.lock b/Cargo.lock index a6c9481..34bfe80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,7 +6009,6 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", - "rayon", "trictrac-store", ] diff --git a/doc/spiel_bot_parallel.md b/doc/spiel_bot_parallel.md deleted file mode 100644 index d9e021e..0000000 --- a/doc/spiel_bot_parallel.md +++ /dev/null @@ -1,121 +0,0 @@ -Part B — Batched MCTS leaf evaluation - -Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls. - -Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs) - -pub trait Evaluator: Send + Sync { -fn evaluate(&self, obs: &[f32]) -> (Vec, f32); - - /// Evaluate a batch of observations at once. Default falls back to - /// sequential calls; backends override this for efficiency. - fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec, f32)> { - obs_batch.iter().map(|obs| self.evaluate(obs)).collect() - } - -} - -Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs) - -Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows. - -fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec, f32)> { -let b = obs_batch.len(); -let obs_size = obs_batch[0].len(); -let flat: Vec = obs_batch.iter().flat_map(|o| o.iter().copied()).collect(); -let obs_tensor = Tensor::::from_data(TensorData::new(flat, [b, obs_size]), &self.device); -let (policy_tensor, value_tensor) = self.model.forward(obs_tensor); -let policies: Vec = policy_tensor.into_data().to_vec().unwrap(); -let values: Vec = value_tensor.into_data().to_vec().unwrap(); -let action_size = policies.len() / b; -(0..b).map(|i| { -(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i]) -}).collect() -} - -Step B3 — Add eval_batch_size to MctsConfig - -pub struct MctsConfig { -// ... existing fields ... -/// Number of leaves to batch per network call. 1 = no batching (current behaviour). -pub eval_batch_size: usize, -} - -Default: 1 (backwards-compatible). - -Step B4 — Make the simulation iterative (mcts/search.rs) - -The current simulate is recursive. For batching we need to split it into two phases: - -descend (pure tree traversal — no network call): - -- Traverse from root following PUCT, advancing through chance nodes with apply_chance. -- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect). -- Return a LeafWork { path: Vec, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance. -- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path. - -ascend (backup — no network call): - -- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions. -- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value. - -Step B5 — Add run_mcts_batched to mcts/mod.rs - -The new entry point, called by run_mcts when config.eval_batch_size > 1: - -expand root (1 network call) -optionally add Dirichlet noise - -for round in 0..(n*simulations / batch_size): -leaves = [] -for * in 0..batch_size: -leaf = descend(root, state, env, rng) -leaves.push(leaf) - - obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves - where leaf.kind == NeedsEval] - results = evaluator.evaluate_batch(obs_batch) - - for (leaf, result) in zip(leaves, results): - expand the leaf node (insert children from result.policy) - ascend(root, leaf.path, result.value, leaf.player_idx) - // ascend also handles terminal and crossed-chance leaves - -// handle remainder: n_simulations % batch_size - -run_mcts becomes a thin dispatcher: -if config.eval_batch_size <= 1 { -// existing path (unchanged) -} else { -run_mcts_batched(...) -} - -Step B6 — CLI flag in az_train.rs - ---eval-batch N default: 8 Leaf batch size for MCTS network calls - ---- - -Summary of file changes - -┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐ -│ File │ Changes │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ spiel_bot/Cargo.toml │ add rayon │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │ -├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤ -│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │ -└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘ - -Key design constraint - -Parts A and B are independent and composable: - -- A only touches the outer game loop. -- B only touches the inner MCTS per game. -- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state. diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 682505b..3848dce 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,7 +9,6 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } -rayon = "1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs index b38b7f4..6f10f8d 100644 --- a/spiel_bot/src/alphazero/selfplay.rs +++ b/spiel_bot/src/alphazero/selfplay.rs @@ -31,10 +31,6 @@ impl> BurnEvaluator { pub fn into_model(self) -> N { self.model } - - pub fn model_ref(&self) -> &N { - &self.model - } } // Safety: NdArray modules are Send; we never share across threads without diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs index 824abe5..ab385c2 100644 --- a/spiel_bot/src/bin/az_train.rs +++ b/spiel_bot/src/bin/az_train.rs @@ -47,8 +47,7 @@ use burn::{ optim::AdamConfig, tensor::backend::Backend, }; -use rand::{Rng, SeedableRng, rngs::SmallRng}; -use rayon::prelude::*; +use rand::{SeedableRng, rngs::SmallRng}; use spiel_bot::{ alphazero::{ @@ -196,26 +195,10 @@ where if step < temp_drop { 1.0 } else { 0.0 } }; - // Prepare per-game seeds and evaluators sequentially so the main RNG - // and model cloning stay deterministic regardless of thread scheduling. - // Burn modules are Send but not Sync, so each task must own its model. - let game_seeds: Vec = (0..args.n_games).map(|_| rng.random()).collect(); - let game_evals: Vec<_> = (0..args.n_games) - .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone())) - .collect(); - drop(evaluator); - - let all_samples: Vec> = game_seeds - .into_par_iter() - .zip(game_evals.into_par_iter()) - .map(|(seed, game_eval)| { - let mut game_rng = SmallRng::seed_from_u64(seed); - generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng) - }) - .collect(); - let mut new_samples = 0usize; - for samples in all_samples { + for _ in 0..args.n_games { + let samples = + generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng); new_samples += samples.len(); replay.extend(samples); }