doc: research parallel
This commit is contained in:
parent
31bb568c2a
commit
1554286f25
1 changed files with 121 additions and 0 deletions
121
doc/spiel_bot_parallel.md
Normal file
121
doc/spiel_bot_parallel.md
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
Part B — Batched MCTS leaf evaluation
|
||||||
|
|
||||||
|
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
|
||||||
|
|
||||||
|
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
|
||||||
|
|
||||||
|
pub trait Evaluator: Send + Sync {
    /// Evaluate one observation, returning the (policy, value) pair.
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);

    /// Evaluate several observations in one call.
    ///
    /// The default implementation just loops over [`Evaluator::evaluate`];
    /// batching backends override it to issue a single network forward pass.
    fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
        let mut results = Vec::with_capacity(obs_batch.len());
        for obs in obs_batch {
            results.push(self.evaluate(obs));
        }
        results
    }
}
|
||||||
|
|
||||||
|
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
|
||||||
|
|
||||||
|
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
|
||||||
|
|
||||||
|
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
|
||||||
|
let b = obs_batch.len();
|
||||||
|
let obs_size = obs_batch[0].len();
|
||||||
|
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
|
||||||
|
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
|
||||||
|
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
|
||||||
|
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
|
||||||
|
let action_size = policies.len() / b;
|
||||||
|
(0..b).map(|i| {
|
||||||
|
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
Step B3 — Add eval_batch_size to MctsConfig
|
||||||
|
|
||||||
|
/// MCTS search configuration (excerpt — only the new batching field is shown).
pub struct MctsConfig {
// ... existing fields ...
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
pub eval_batch_size: usize,
}
|
||||||
|
|
||||||
|
Default: 1 (backwards-compatible).
|
||||||
|
|
||||||
|
Step B4 — Make the simulation iterative (mcts/search.rs)
|
||||||
|
|
||||||
|
The current simulate is recursive. For batching we need to split it into two phases:
|
||||||
|
|
||||||
|
descend (pure tree traversal — no network call):
|
||||||
|
|
||||||
|
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
|
||||||
|
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
|
||||||
|
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
|
||||||
|
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
|
||||||
|
|
||||||
|
ascend (backup — no network call):
|
||||||
|
|
||||||
|
- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
|
||||||
|
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
|
||||||
|
|
||||||
|
Step B5 — Add run_mcts_batched to mcts/mod.rs
|
||||||
|
|
||||||
|
The new entry point, called by run_mcts when config.eval_batch_size > 1:
|
||||||
|
|
||||||
|
expand root (1 network call)
|
||||||
|
optionally add Dirichlet noise
|
||||||
|
|
||||||
|
for round in 0..(n_simulations / batch_size):
|
||||||
|
leaves = []
|
||||||
|
for _ in 0..batch_size:
|
||||||
|
leaf = descend(root, state, env, rng)
|
||||||
|
leaves.push(leaf)
|
||||||
|
|
||||||
|
obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
|
||||||
|
where leaf.kind == NeedsEval]
|
||||||
|
results = evaluator.evaluate_batch(obs_batch)
|
||||||
|
|
||||||
|
for (leaf, result) in zip(leaves, results):
|
||||||
|
expand the leaf node (insert children from result.policy)
|
||||||
|
ascend(root, leaf.path, result.value, leaf.player_idx)
|
||||||
|
// ascend also handles terminal and crossed-chance leaves
|
||||||
|
|
||||||
|
// handle remainder: n_simulations % batch_size
|
||||||
|
|
||||||
|
run_mcts becomes a thin dispatcher:
|
||||||
|
if config.eval_batch_size <= 1 {
|
||||||
|
// existing path (unchanged)
|
||||||
|
} else {
|
||||||
|
run_mcts_batched(...)
|
||||||
|
}
|
||||||
|
|
||||||
|
Step B6 — CLI flag in az_train.rs
|
||||||
|
|
||||||
|
--eval-batch N default: 8 Leaf batch size for MCTS network calls
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Summary of file changes
|
||||||
|
|
||||||
|
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ File │ Changes │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ spiel_bot/Cargo.toml │ add rayon │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │
|
||||||
|
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Key design constraint
|
||||||
|
|
||||||
|
Parts A and B are independent and composable:
|
||||||
|
|
||||||
|
- A only touches the outer game loop.
|
||||||
|
- B only touches the inner MCTS per game.
|
||||||
|
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.
|
||||||
Loading…
Add table
Add a link
Reference in a new issue