Part B — Batched MCTS leaf evaluation
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
pub trait Evaluator: Send + Sync {
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);

    /// Evaluate a batch of observations at once. Default falls back to
    /// sequential calls; backends override this for efficiency.
    fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
        obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
    }
}
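The key contract is that the default evaluate_batch agrees with sequential evaluate calls, so a backend override is purely an optimization. A minimal standalone sketch (the trait is restated so the snippet compiles on its own; UniformEvaluator and its 3-action policy are illustrative, not part of the codebase):

```rust
pub trait Evaluator: Send + Sync {
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);

    fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
        obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
    }
}

/// Dummy backend: uniform policy over 3 actions, value = mean of the obs.
struct UniformEvaluator;

impl Evaluator for UniformEvaluator {
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
        let value = obs.iter().sum::<f32>() / obs.len() as f32;
        (vec![1.0 / 3.0; 3], value)
    }
}

fn main() {
    let e = UniformEvaluator;
    let a = [0.0f32, 1.0];
    let b = [1.0f32, 1.0];
    // The default batch path must match the sequential path exactly.
    let batch = e.evaluate_batch(&[&a[..], &b[..]]);
    assert_eq!(batch.len(), 2);
    assert_eq!(batch[0], e.evaluate(&a));
    assert_eq!(batch[1], e.evaluate(&b));
    println!("ok");
}
```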
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
    let b = obs_batch.len();
    let obs_size = obs_batch[0].len();
    let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
    let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
    let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
    let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
    let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
    let action_size = policies.len() / b;
    (0..b)
        .map(|i| (policies[i * action_size..(i + 1) * action_size].to_vec(), values[i]))
        .collect()
}
Step B3 — Add eval_batch_size to MctsConfig
pub struct MctsConfig {
    // ... existing fields ...
    /// Number of leaves to batch per network call. 1 = no batching (current behaviour).
    pub eval_batch_size: usize,
}
Default: 1 (backwards-compatible).
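A sketch of the backwards-compatible Default impl (n_simulations and its value of 100 are placeholders for the existing fields, which this plan does not enumerate):

```rust
pub struct MctsConfig {
    pub n_simulations: usize,   // placeholder for an existing field
    pub eval_batch_size: usize, // new: leaves per network call
}

impl Default for MctsConfig {
    fn default() -> Self {
        Self {
            n_simulations: 100,
            eval_batch_size: 1, // 1 = no batching, i.e. current behaviour
        }
    }
}

fn main() {
    assert_eq!(MctsConfig::default().eval_batch_size, 1);
}
```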
Step B4 — Make the simulation iterative (mcts/search.rs)
The current simulate is recursive. For batching we need to split it into two phases:
descend (pure tree traversal — no network call):
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
ascend (backup — no network call):
- Given the path and the evaluated value, walk back up the path, negating the value at each player-boundary transition.
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
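The virtual-loss bookkeeping above can be sketched in isolation (Node, apply_virtual_loss, and backup are illustrative names, not the actual search.rs types; the sketch assumes strictly alternating two-player turns, so it negates at every step rather than only at player boundaries):

```rust
#[derive(Default)]
struct Node {
    n: u32, // visit count
    w: f32, // total value from this node's player's perspective
}

/// During descend: pessimistically count the visit with value -1,
/// steering concurrent descents in the same batch away from this path.
fn apply_virtual_loss(path: &mut [Node]) {
    for node in path {
        node.n += 1;
        node.w -= 1.0;
    }
}

/// During ascend: undo the virtual loss, then apply the real update,
/// negating the value at each step (alternating two-player turns).
fn backup(path: &mut [Node], mut value: f32) {
    for node in path.iter_mut().rev() {
        node.n -= 1;
        node.w += 1.0; // undo virtual loss
        node.n += 1;
        node.w += value; // real update
        value = -value; // player boundary
    }
}

fn main() {
    let mut path: Vec<Node> = (0..3).map(|_| Node::default()).collect();
    apply_virtual_loss(&mut path);
    assert_eq!(path[0].n, 1);
    assert_eq!(path[0].w, -1.0);
    backup(&mut path, 0.5);
    // Net effect: one real visit per node, virtual loss fully undone.
    assert_eq!(path[2].n, 1);
    assert_eq!(path[2].w, 0.5); // nearest the leaf: value as evaluated
    assert_eq!(path[1].w, -0.5); // negated once
    assert_eq!(path[0].w, 0.5); // negated twice
    println!("ok");
}
```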
Step B5 — Add run_mcts_batched to mcts/mod.rs
The new entry point, called by run_mcts when config.eval_batch_size > 1:
expand root (1 network call); optionally add Dirichlet noise
for round in 0..(n_simulations / batch_size):
    leaves = []
    for _ in 0..batch_size:
        leaf = descend(root, state, env, rng)
        leaves.push(leaf)
    obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
                 where leaf.kind == NeedsEval]
    results = evaluator.evaluate_batch(obs_batch)
    for (leaf, result) in zip(leaves, results):
        expand the leaf node (insert children from result.policy)
        ascend(root, leaf.path, result.value, leaf.player_idx)
        // ascend also handles terminal and crossed-chance leaves
// handle remainder: n_simulations % batch_size
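The round arithmetic, including the remainder, can be isolated into a small helper (batch_sizes is a hypothetical name, shown only to pin down the split):

```rust
/// Split n_simulations into full batches of batch_size plus one
/// remainder batch, covering every simulation exactly once.
fn batch_sizes(n_simulations: usize, batch_size: usize) -> Vec<usize> {
    let full = n_simulations / batch_size;
    let rem = n_simulations % batch_size;
    let mut rounds = vec![batch_size; full];
    if rem > 0 {
        rounds.push(rem);
    }
    rounds
}

fn main() {
    // e.g. 50 simulations with eval_batch_size = 8: six full rounds + one of 2
    assert_eq!(batch_sizes(50, 8), vec![8, 8, 8, 8, 8, 8, 2]);
    assert_eq!(batch_sizes(16, 8), vec![8, 8]); // no remainder round
    println!("ok");
}
```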
run_mcts becomes a thin dispatcher:
if config.eval_batch_size <= 1 {
    // existing path (unchanged)
} else {
    run_mcts_batched(...)
}
Step B6 — CLI flag in az_train.rs
--eval-batch <N> (default: 8): leaf batch size for MCTS network calls
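If az_train.rs parses arguments by hand (the actual parser, possibly a CLI crate, may differ), the flag could look like this sketch:

```rust
/// Scan argv for `--eval-batch N`; fall back to the default of 8 when
/// the flag is absent or its value does not parse.
fn parse_eval_batch(args: &[String]) -> usize {
    args.windows(2)
        .find(|w| w[0] == "--eval-batch")
        .and_then(|w| w[1].parse().ok())
        .unwrap_or(8)
}

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let eval_batch = parse_eval_batch(&args);
    println!("eval_batch = {eval_batch}");
}
```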
Summary of file changes
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
│ File                      │ Changes                                                                  │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ spiel_bot/Cargo.toml      │ add rayon                                                                │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/mod.rs           │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/search.rs        │ descend (iterative, virtual loss); ascend (backup path); expand_at_path  │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch                                            │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/bin/az_train.rs       │ parallel game loop (rayon); --eval-batch flag                            │
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
Key design constraint
Parts A and B are independent and composable:
- A only touches the outer game loop.
- B only touches the inner MCTS per game.
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.