feat(spiel_bot): az_train parallel games with rayon
This commit is contained in:
parent
150efe302f
commit
4691a84e23
4 changed files with 27 additions and 4 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -6009,6 +6009,7 @@ dependencies = [
|
||||||
"criterion",
|
"criterion",
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
|
"rayon",
|
||||||
"trictrac-store",
|
"trictrac-store",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ anyhow = "1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
rand_distr = "0.5"
|
rand_distr = "0.5"
|
||||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
|
rayon = "1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,10 @@ impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
|
||||||
pub fn into_model(self) -> N {
|
pub fn into_model(self) -> N {
|
||||||
self.model
|
self.model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn model_ref(&self) -> &N {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,8 @@ use burn::{
|
||||||
optim::AdamConfig,
|
optim::AdamConfig,
|
||||||
tensor::backend::Backend,
|
tensor::backend::Backend,
|
||||||
};
|
};
|
||||||
use rand::{SeedableRng, rngs::SmallRng};
|
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
use spiel_bot::{
|
use spiel_bot::{
|
||||||
alphazero::{
|
alphazero::{
|
||||||
|
|
@ -195,10 +196,26 @@ where
|
||||||
if step < temp_drop { 1.0 } else { 0.0 }
|
if step < temp_drop { 1.0 } else { 0.0 }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Prepare per-game seeds and evaluators sequentially so the main RNG
|
||||||
|
// and model cloning stay deterministic regardless of thread scheduling.
|
||||||
|
// Burn modules are Send but not Sync, so each task must own its model.
|
||||||
|
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
|
||||||
|
let game_evals: Vec<_> = (0..args.n_games)
|
||||||
|
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
|
||||||
|
.collect();
|
||||||
|
drop(evaluator);
|
||||||
|
|
||||||
|
let all_samples: Vec<Vec<TrainSample>> = game_seeds
|
||||||
|
.into_par_iter()
|
||||||
|
.zip(game_evals.into_par_iter())
|
||||||
|
.map(|(seed, game_eval)| {
|
||||||
|
let mut game_rng = SmallRng::seed_from_u64(seed);
|
||||||
|
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
let mut new_samples = 0usize;
|
let mut new_samples = 0usize;
|
||||||
for _ in 0..args.n_games {
|
for samples in all_samples {
|
||||||
let samples =
|
|
||||||
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
|
|
||||||
new_samples += samples.len();
|
new_samples += samples.len();
|
||||||
replay.extend(samples);
|
replay.extend(samples);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue