Compare commits

..

2 commits

5 changed files with 148 additions and 4 deletions

1
Cargo.lock generated
View file

@@ -6009,6 +6009,7 @@ dependencies = [
"criterion",
"rand 0.9.2",
"rand_distr",
"rayon",
"trictrac-store",
]

121
doc/spiel_bot_parallel.md Normal file
View file

@@ -0,0 +1,121 @@
Part B — Batched MCTS leaf evaluation
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
pub trait Evaluator: Send + Sync {
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
/// Evaluate a batch of observations at once. Default falls back to
/// sequential calls; backends override this for efficiency.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
}
}
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
let b = obs_batch.len();
let obs_size = obs_batch[0].len();
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
let action_size = policies.len() / b;
(0..b).map(|i| {
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
}).collect()
}
Step B3 — Add eval_batch_size to MctsConfig
pub struct MctsConfig {
// ... existing fields ...
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
pub eval_batch_size: usize,
}
Default: 1 (backwards-compatible).
Step B4 — Make the simulation iterative (mcts/search.rs)
The current simulate is recursive. For batching we need to split it into two phases:
descend (pure tree traversal — no network call):
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
ascend (backup — no network call):
- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
Step B5 — Add run_mcts_batched to mcts/mod.rs
The new entry point, called by run_mcts when config.eval_batch_size > 1:
expand root (1 network call)
optionally add Dirichlet noise
for round in 0..(n_simulations / batch_size):
leaves = []
for _ in 0..batch_size:
leaf = descend(root, state, env, rng)
leaves.push(leaf)
obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
where leaf.kind == NeedsEval]
results = evaluator.evaluate_batch(obs_batch)
for (leaf, result) in zip(leaves, results):
expand the leaf node (insert children from result.policy)
ascend(root, leaf.path, result.value, leaf.player_idx)
// ascend also handles terminal and crossed-chance leaves
// handle remainder: n_simulations % batch_size
run_mcts becomes a thin dispatcher:
if config.eval_batch_size <= 1 {
// existing path (unchanged)
} else {
run_mcts_batched(...)
}
Step B6 — CLI flag in az_train.rs
--eval-batch N default: 8 Leaf batch size for MCTS network calls
---
Summary of file changes
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
│ File │ Changes │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ spiel_bot/Cargo.toml │ add rayon │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
Key design constraint
Parts A and B are independent and composable:
- A only touches the outer game loop.
- B only touches the inner MCTS per game.
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.

View file

@@ -9,6 +9,7 @@ anyhow = "1"
rand = "0.9"
rand_distr = "0.5"
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
rayon = "1"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

View file

@@ -31,6 +31,10 @@ impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
pub fn into_model(self) -> N {
self.model
}
pub fn model_ref(&self) -> &N {
&self.model
}
}
// Safety: NdArray<f32> modules are Send; we never share across threads without

View file

@@ -47,7 +47,8 @@ use burn::{
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{SeedableRng, rngs::SmallRng};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use rayon::prelude::*;
use spiel_bot::{
alphazero::{
@@ -195,10 +196,26 @@ where
if step < temp_drop { 1.0 } else { 0.0 }
};
// Prepare per-game seeds and evaluators sequentially so the main RNG
// and model cloning stay deterministic regardless of thread scheduling.
// Burn modules are Send but not Sync, so each task must own its model.
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
let game_evals: Vec<_> = (0..args.n_games)
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
.collect();
drop(evaluator);
let all_samples: Vec<Vec<TrainSample>> = game_seeds
.into_par_iter()
.zip(game_evals.into_par_iter())
.map(|(seed, game_eval)| {
let mut game_rng = SmallRng::seed_from_u64(seed);
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
})
.collect();
let mut new_samples = 0usize;
for _ in 0..args.n_games {
let samples =
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
for samples in all_samples {
new_samples += samples.len();
replay.extend(samples);
}