feat(spiel_bot): benchmarks

This commit is contained in:
Henri Bourcereau 2026-03-07 22:49:55 +01:00
parent b074a401ba
commit 2e85c14dbb
3 changed files with 468 additions and 0 deletions

120
Cargo.lock generated
View file

@ -92,6 +92,12 @@ dependencies = [
"libc",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.21"
@ -1116,6 +1122,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cast_trait"
version = "0.1.2"
@ -1200,6 +1212,33 @@ dependencies = [
"rand 0.7.3",
]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
@ -1453,6 +1492,42 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "critical-section"
version = "1.2.0"
@ -4461,6 +4536,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "opaque-debug"
version = "0.3.1"
@ -4597,6 +4678,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.18.0"
@ -5897,6 +6006,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"burn",
"criterion",
"rand 0.9.2",
"rand_distr",
"trictrac-store",
@ -6310,6 +6420,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.10.0"

View file

@ -9,3 +9,10 @@ anyhow = "1"
rand = "0.9"
rand_distr = "0.5"
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
name = "alphazero"
harness = false

View file

@ -0,0 +1,341 @@
//! AlphaZero pipeline benchmarks.
//!
//! Run with:
//!
//! ```sh
//! cargo bench -p spiel_bot
//! ```
//!
//! Use `-- <filter>` to run a specific group, e.g.:
//!
//! ```sh
//! cargo bench -p spiel_bot -- env/
//! cargo bench -p spiel_bot -- network/
//! cargo bench -p spiel_bot -- mcts/
//! cargo bench -p spiel_bot -- episode/
//! cargo bench -p spiel_bot -- train/
//! ```
//!
//! Target: ≥ 500 games/s for random play on CPU (consistent with
//! `random_game` throughput in `trictrac-store`).
use std::time::Duration;
use burn::{
backend::NdArray,
tensor::{Tensor, TensorData, backend::Backend},
};
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts},
network::{MlpConfig, MlpNet, PolicyValueNet},
};
// ── Shared types ───────────────────────────────────────────────────────────
/// Backend used by the inference-side benchmarks (network forward pass, MCTS,
/// episode generation): plain `NdArray` over `f32`.
type InferB = NdArray<f32>;
/// Backend used by the training-step benchmark: the same `NdArray<f32>` backend
/// wrapped in `Autodiff`.
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
/// Returns the default device of the inference backend ([`InferB`]).
fn infer_device() -> <InferB as Backend>::Device {
    <InferB as Backend>::Device::default()
}
/// Returns the default device of the training backend ([`TrainB`]).
fn train_device() -> <TrainB as Backend>::Device {
    <TrainB as Backend>::Device::default()
}
/// Builds a `SmallRng` from the fixed seed `0`, so every caller starts from the
/// same generator state.
fn seeded() -> SmallRng {
    const SEED: u64 = 0;
    SmallRng::seed_from_u64(SEED)
}
/// Evaluator stub that ignores the observation and answers with all-zero
/// logits and a zero value.
///
/// The wrapped `usize` is the length of the returned logits vector. Plugging
/// this into MCTS isolates tree-traversal cost from network cost.
struct ZeroEval(usize);

impl Evaluator for ZeroEval {
    fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
        let Self(logit_count) = self;
        let logits = vec![0.0_f32; *logit_count];
        (logits, 0.0)
    }
}
// ── 1. Environment primitives ──────────────────────────────────────────────
/// Baseline performance of the raw Trictrac environment without MCTS.
///
/// Registers four benchmarks in the `env` group (10 s measurement time):
/// - `apply_chance`  — a single chance step applied to a freshly created game;
/// - `legal_actions` — legal-action enumeration on a state after one chance step;
/// - `observation`   — `env.observation(&s, 0)` on the same kind of state;
/// - `random_game`   — one complete game played with uniformly random actions
///   (sample size reduced to 50).
///
/// Target: ≥ 500 full games / second.
fn bench_env(c: &mut Criterion) {
    let env = TrictracEnv;
    let mut group = c.benchmark_group("env");
    group.measurement_time(Duration::from_secs(10));

    // ── apply_chance ──────────────────────────────────────────────────────
    group.bench_function("apply_chance", |b| {
        b.iter_batched(
            // A fresh game is always at RollDice (Chance) — ready for apply_chance.
            // The RNG is created here, in the untimed setup, so the measurement
            // covers the chance step only and not the cost of seeding a generator.
            || (env.new_game(), seeded()),
            |(mut s, mut rng)| {
                env.apply_chance(&mut s, &mut rng);
                black_box(s)
            },
            BatchSize::SmallInput,
        )
    });

    // ── legal_actions ─────────────────────────────────────────────────────
    group.bench_function("legal_actions", |b| {
        // Prepare the state once; only the enumeration itself is timed.
        let mut rng = seeded();
        let mut s = env.new_game();
        env.apply_chance(&mut s, &mut rng);
        b.iter(|| black_box(env.legal_actions(&s)))
    });

    // ── observation (to_tensor) ───────────────────────────────────────────
    group.bench_function("observation", |b| {
        let mut rng = seeded();
        let mut s = env.new_game();
        env.apply_chance(&mut s, &mut rng);
        // NOTE(review): the second argument is presumably the observing player's index — confirm.
        b.iter(|| black_box(env.observation(&s, 0)))
    });

    // ── full random game ──────────────────────────────────────────────────
    // A whole game per iteration is much slower than the primitives above,
    // so fewer samples are collected.
    group.sample_size(50);
    group.bench_function("random_game", |b| {
        b.iter_batched(
            seeded,
            |mut rng| {
                let mut s = env.new_game();
                loop {
                    match env.current_player(&s) {
                        Player::Terminal => break,
                        Player::Chance => env.apply_chance(&mut s, &mut rng),
                        _ => {
                            // Decision node: pick one legal action uniformly at random.
                            let actions = env.legal_actions(&s);
                            let idx = rng.random_range(0..actions.len());
                            env.apply(&mut s, actions[idx]);
                        }
                    }
                }
                black_box(s)
            },
            BatchSize::SmallInput,
        )
    });

    group.finish();
}
// ── 2. Network inference ───────────────────────────────────────────────────
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
///
/// For each hidden size, two benchmarks are registered in the `network` group:
/// - `mlp_b1`  — batch size 1, single-position evaluation as in MCTS;
/// - `mlp_b32` — batch size 32, a training mini-batch.
///
/// Each timed iteration clones the input buffer, builds the input tensor and
/// runs `model.forward`, so the reported time includes tensor construction.
fn bench_network(c: &mut Criterion) {
    const OBS_SIZE: usize = 217;
    const ACTION_SIZE: usize = 514;
    // (benchmark label, batch size), in registration order.
    const VARIANTS: [(&str, usize); 2] = [("mlp_b1", 1), ("mlp_b32", 32)];

    let mut group = c.benchmark_group("network");
    group.measurement_time(Duration::from_secs(5));

    for &hidden in &[64usize, 256] {
        let cfg = MlpConfig {
            obs_size: OBS_SIZE,
            action_size: ACTION_SIZE,
            hidden_size: hidden,
        };
        let model = MlpNet::<InferB>::new(&cfg, &infer_device());

        for &(label, batch) in &VARIANTS {
            // Constant-valued observations: `batch` rows of `OBS_SIZE` features.
            let input: Vec<f32> = vec![0.5; OBS_SIZE * batch];
            group.bench_with_input(BenchmarkId::new(label, hidden), &hidden, |b, _| {
                b.iter(|| {
                    let data = TensorData::new(input.clone(), [batch, OBS_SIZE]);
                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
                    black_box(model.forward(t))
                })
            });
        }
    }

    group.finish();
}
// ── 3. MCTS ───────────────────────────────────────────────────────────────
/// MCTS cost for simulation budgets of 1, 5 and 20, always searching from the
/// same prepared state, with two evaluator types:
/// - `zero_eval` — `ZeroEval`; isolates tree-traversal overhead (no network).
/// - `mlp64`     — real MLP with hidden size 64; shows end-to-end cost per move.
///
/// Every iteration receives a freshly seeded RNG from the untimed setup.
fn bench_mcts(c: &mut Criterion) {
    const OBS_SIZE: usize = 217;
    const ACTION_SIZE: usize = 514;

    let env = TrictracEnv;

    // Advance a new game past its chance nodes (the dice roll) so the search
    // starts from a decision node.
    let mut state = env.new_game();
    {
        let mut dice_rng = seeded();
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, &mut dice_rng);
        }
    }
    let state = state;

    let mut group = c.benchmark_group("mcts");
    group.measurement_time(Duration::from_secs(10));

    let zero_eval = ZeroEval(ACTION_SIZE);
    let mlp_cfg = MlpConfig {
        obs_size: OBS_SIZE,
        action_size: ACTION_SIZE,
        hidden_size: 64,
    };
    let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
    let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());

    // Search settings shared by all runs; only the simulation budget varies.
    // Both Dirichlet parameters are set to zero.
    let search_config = |n_simulations: usize| MctsConfig {
        n_simulations,
        c_puct: 1.5,
        dirichlet_alpha: 0.0,
        dirichlet_eps: 0.0,
        temperature: 1.0,
    };

    for &budget in &[1usize, 5, 20] {
        let cfg = search_config(budget);

        // Zero evaluator: tree traversal only.
        group.bench_with_input(BenchmarkId::new("zero_eval", budget), &budget, |b, _| {
            b.iter_batched(
                seeded,
                |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
                BatchSize::SmallInput,
            )
        });

        // MLP evaluator: full cost per decision.
        group.bench_with_input(BenchmarkId::new("mlp64", budget), &budget, |b, _| {
            b.iter_batched(
                seeded,
                |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
                BatchSize::SmallInput,
            )
        });
    }

    group.finish();
}
// ── 4. Episode generation ─────────────────────────────────────────────────
/// Full self-play episode latency (one complete game) at different MCTS
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
///
/// Only budgets 1 and 2 are measured here, with 10 samples over a 60 s
/// measurement window. The evaluator is an MLP with hidden size 64, and each
/// iteration starts from a freshly seeded RNG created in the untimed setup.
fn bench_episode(c: &mut Criterion) {
    const OBS_SIZE: usize = 217;
    const ACTION_SIZE: usize = 514;

    let env = TrictracEnv;
    let net_cfg = MlpConfig {
        obs_size: OBS_SIZE,
        action_size: ACTION_SIZE,
        hidden_size: 64,
    };
    let net = MlpNet::<InferB>::new(&net_cfg, &infer_device());
    let eval = BurnEvaluator::<InferB, _>::new(net, infer_device());

    let mut group = c.benchmark_group("episode");
    group.sample_size(10);
    group.measurement_time(Duration::from_secs(60));

    // Search settings shared by all runs; only the simulation budget varies.
    let search_config = |n_simulations: usize| MctsConfig {
        n_simulations,
        c_puct: 1.5,
        dirichlet_alpha: 0.0,
        dirichlet_eps: 0.0,
        temperature: 1.0,
    };

    for &budget in &[1usize, 2] {
        let mcts_cfg = search_config(budget);
        group.bench_with_input(BenchmarkId::new("trictrac", budget), &budget, |b, _| {
            b.iter_batched(
                seeded,
                |mut rng| {
                    // The closure argument always yields 1.0.
                    // NOTE(review): presumably a per-step temperature schedule — confirm against `generate_episode`.
                    let episode = generate_episode(&env, &eval, &mcts_cfg, &|_| 1.0, &mut rng);
                    black_box(episode)
                },
                BatchSize::SmallInput,
            )
        });
    }

    group.finish();
}
// ── 5. Training step ───────────────────────────────────────────────────────
/// Gradient-step latency for different batch sizes (16 and 64).
///
/// Each iteration gets a brand-new hidden-64 MLP and Adam optimizer from the
/// untimed setup; the timed part is a single `train_step` call on a fixed
/// batch of synthetic samples with a learning rate of `1e-3`.
fn bench_train(c: &mut Criterion) {
    use burn::optim::AdamConfig;

    const OBS_SIZE: usize = 217;
    const ACTION_SIZE: usize = 514;

    /// Builds `n` synthetic samples: constant observation, one-hot policy on
    /// action `i % ACTION_SIZE`, and a value alternating between +1 and -1.
    fn dummy_samples(n: usize) -> Vec<TrainSample> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let hot = i % ACTION_SIZE;
            let policy: Vec<f32> = (0..ACTION_SIZE)
                .map(|a| if a == hot { 1.0 } else { 0.0 })
                .collect();
            let value = if i % 2 == 0 { 1.0 } else { -1.0 };
            samples.push(TrainSample {
                obs: vec![0.5; OBS_SIZE],
                policy,
                value,
            });
        }
        samples
    }

    let mut group = c.benchmark_group("train");
    group.measurement_time(Duration::from_secs(10));

    let mlp_cfg = MlpConfig {
        obs_size: OBS_SIZE,
        action_size: ACTION_SIZE,
        hidden_size: 64,
    };

    for &batch_size in &[16usize, 64] {
        let batch = dummy_samples(batch_size);
        group.bench_with_input(
            BenchmarkId::new("mlp64_adam", batch_size),
            &batch_size,
            |b, _| {
                b.iter_batched(
                    || {
                        let model = MlpNet::<TrainB>::new(&mlp_cfg, &train_device());
                        let optimizer = AdamConfig::new().init::<TrainB, MlpNet<TrainB>>();
                        (model, optimizer)
                    },
                    |(model, mut optimizer)| {
                        black_box(train_step(
                            model,
                            &mut optimizer,
                            &batch,
                            &train_device(),
                            1e-3,
                        ))
                    },
                    BatchSize::SmallInput,
                )
            },
        );
    }

    group.finish();
}
// ── Criterion entry point ──────────────────────────────────────────────────
// Collects the five benchmark functions into a single Criterion group named
// `benches`; they are listed in the order env → network → mcts → episode → train.
criterion_group!(
benches,
bench_env,
bench_network,
bench_mcts,
bench_episode,
bench_train,
);
// Expands to the `main` function for this bench target (the manifest sets
// `harness = false`, so the default test harness does not provide one).
criterion_main!(benches);