feat(spiel_bot): az_train parallel games with rayon
This commit is contained in:
parent
150efe302f
commit
4691a84e23
4 changed files with 27 additions and 4 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -6009,6 +6009,7 @@ dependencies = [
|
|||
"criterion",
|
||||
"rand 0.9.2",
|
||||
"rand_distr",
|
||||
"rayon",
|
||||
"trictrac-store",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ anyhow = "1"
|
|||
rand = "0.9"
|
||||
rand_distr = "0.5"
|
||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||
rayon = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
|
|
|
|||
|
|
@ -31,6 +31,10 @@ impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
|
|||
pub fn into_model(self) -> N {
|
||||
self.model
|
||||
}
|
||||
|
||||
pub fn model_ref(&self) -> &N {
|
||||
&self.model
|
||||
}
|
||||
}
|
||||
|
||||
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
||||
|
|
|
|||
|
|
@ -47,7 +47,8 @@ use burn::{
|
|||
optim::AdamConfig,
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use spiel_bot::{
|
||||
alphazero::{
|
||||
|
|
@ -195,10 +196,26 @@ where
|
|||
if step < temp_drop { 1.0 } else { 0.0 }
|
||||
};
|
||||
|
||||
// Prepare per-game seeds and evaluators sequentially so the main RNG
|
||||
// and model cloning stay deterministic regardless of thread scheduling.
|
||||
// Burn modules are Send but not Sync, so each task must own its model.
|
||||
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
|
||||
let game_evals: Vec<_> = (0..args.n_games)
|
||||
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
|
||||
.collect();
|
||||
drop(evaluator);
|
||||
|
||||
let all_samples: Vec<Vec<TrainSample>> = game_seeds
|
||||
.into_par_iter()
|
||||
.zip(game_evals.into_par_iter())
|
||||
.map(|(seed, game_eval)| {
|
||||
let mut game_rng = SmallRng::seed_from_u64(seed);
|
||||
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut new_samples = 0usize;
|
||||
for _ in 0..args.n_games {
|
||||
let samples =
|
||||
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
|
||||
for samples in all_samples {
|
||||
new_samples += samples.len();
|
||||
replay.extend(samples);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue