From 31bb568c2a51f51abf0e962019d3e8360b7b281f Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Mon, 9 Mar 2026 19:43:52 +0100
Subject: [PATCH] feat(spiel_bot): az_train parallel games with rayon

---
 Cargo.lock                          |  1 +
 spiel_bot/Cargo.toml                |  1 +
 spiel_bot/src/alphazero/selfplay.rs |  4 ++++
 spiel_bot/src/bin/az_train.rs       | 25 +++++++++++++++++++++----
 4 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 34bfe80..a6c9481 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6009,6 +6009,7 @@
 dependencies = [
  "criterion",
  "rand 0.9.2",
  "rand_distr",
+ "rayon",
  "trictrac-store",
 ]

diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml
index 3848dce..682505b 100644
--- a/spiel_bot/Cargo.toml
+++ b/spiel_bot/Cargo.toml
@@ -9,6 +9,7 @@ anyhow = "1"
 rand = "0.9"
 rand_distr = "0.5"
 burn = { version = "0.20", features = ["ndarray", "autodiff"] }
+rayon = "1"
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }

diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs
index 6f10f8d..b38b7f4 100644
--- a/spiel_bot/src/alphazero/selfplay.rs
+++ b/spiel_bot/src/alphazero/selfplay.rs
@@ -31,6 +31,10 @@ impl<B, N> BurnEvaluator<B, N> {
     pub fn into_model(self) -> N {
         self.model
     }
+
+    pub fn model_ref(&self) -> &N {
+        &self.model
+    }
 }
 
 // Safety: NdArray modules are Send; we never share across threads without

diff --git a/spiel_bot/src/bin/az_train.rs b/spiel_bot/src/bin/az_train.rs
index ab385c2..824abe5 100644
--- a/spiel_bot/src/bin/az_train.rs
+++ b/spiel_bot/src/bin/az_train.rs
@@ -47,7 +47,8 @@ use burn::{
     optim::AdamConfig,
     tensor::backend::Backend,
 };
-use rand::{SeedableRng, rngs::SmallRng};
+use rand::{Rng, SeedableRng, rngs::SmallRng};
+use rayon::prelude::*;
 
 use spiel_bot::{
     alphazero::{
@@ -195,10 +196,26 @@ where
             if step < temp_drop { 1.0 } else { 0.0 }
         };
 
+        // Prepare per-game seeds and evaluators sequentially so the main RNG
+        // and model cloning stay deterministic regardless of thread scheduling.
+        // Burn modules are Send but not Sync, so each task must own its model.
+        let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
+        let game_evals: Vec<_> = (0..args.n_games)
+            .map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
+            .collect();
+        drop(evaluator);
+
+        let all_samples: Vec<Vec<_>> = game_seeds
+            .into_par_iter()
+            .zip(game_evals.into_par_iter())
+            .map(|(seed, game_eval)| {
+                let mut game_rng = SmallRng::seed_from_u64(seed);
+                generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
+            })
+            .collect();
+
         let mut new_samples = 0usize;
-        for _ in 0..args.n_games {
-            let samples =
-                generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
+        for samples in all_samples {
             new_samples += samples.len();
             replay.extend(samples);
         }