feat(spiel_bot): az_train parallel games with rayon

This commit is contained in:
Henri Bourcereau 2026-03-09 19:43:52 +01:00
parent e80dade303
commit 31bb568c2a
4 changed files with 27 additions and 4 deletions

1
Cargo.lock generated
View file

@ -6009,6 +6009,7 @@ dependencies = [
"criterion",
"rand 0.9.2",
"rand_distr",
"rayon",
"trictrac-store",
]

View file

@ -9,6 +9,7 @@ anyhow = "1"
rand = "0.9"
rand_distr = "0.5"
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
rayon = "1"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

View file

@ -31,6 +31,10 @@ impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
pub fn into_model(self) -> N {
self.model
}
pub fn model_ref(&self) -> &N {
&self.model
}
}
// Safety: NdArray<f32> modules are Send; we never share across threads without

View file

@ -47,7 +47,8 @@ use burn::{
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{SeedableRng, rngs::SmallRng};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use rayon::prelude::*;
use spiel_bot::{
alphazero::{
@ -195,10 +196,26 @@ where
if step < temp_drop { 1.0 } else { 0.0 }
};
// Prepare per-game seeds and evaluators sequentially so the main RNG
// and model cloning stay deterministic regardless of thread scheduling.
// Burn modules are Send but not Sync, so each task must own its model.
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
let game_evals: Vec<_> = (0..args.n_games)
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
.collect();
drop(evaluator);
let all_samples: Vec<Vec<TrainSample>> = game_seeds
.into_par_iter()
.zip(game_evals.into_par_iter())
.map(|(seed, game_eval)| {
let mut game_rng = SmallRng::seed_from_u64(seed);
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
})
.collect();
let mut new_samples = 0usize;
for _ in 0..args.n_games {
let samples =
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
for samples in all_samples {
new_samples += samples.len();
replay.extend(samples);
}