From 2e85c14dbbaa9f04a9fe44c876140bef2b0ad28b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:49:55 +0100 Subject: [PATCH] feat(spiel_bot): benchmarks --- Cargo.lock | 120 ++++++++++++ spiel_bot/Cargo.toml | 7 + spiel_bot/benches/alphazero.rs | 341 +++++++++++++++++++++++++++++++++ 3 files changed, 468 insertions(+) create mode 100644 spiel_bot/benches/alphazero.rs diff --git a/Cargo.lock b/Cargo.lock index 0baa02a..34bfe80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.21" @@ -1116,6 +1122,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cast_trait" version = "0.1.2" @@ -1200,6 +1212,33 @@ dependencies = [ "rand 0.7.3", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1453,6 +1492,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -4461,6 +4536,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -4597,6 +4678,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.18.0" @@ -5897,6 +6006,7 @@ version = "0.1.0" dependencies = [ "anyhow", "burn", + "criterion", "rand 0.9.2", "rand_distr", "trictrac-store", @@ -6310,6 +6420,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.10.0" diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 323c953..3848dce 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -9,3 +9,10 @@ anyhow = "1" rand = "0.9" rand_distr = "0.5" burn = { version = "0.20", features = ["ndarray", "autodiff"] } + +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "alphazero" +harness = false diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs new file mode 100644 index 0000000..2950b09 --- /dev/null +++ b/spiel_bot/benches/alphazero.rs @@ -0,0 +1,341 @@ +//! AlphaZero pipeline benchmarks. +//! +//! Run with: +//! +//! ```sh +//! cargo bench -p spiel_bot +//! ``` +//! +//! Use `-- ` to run a specific group, e.g.: +//! +//! ```sh +//! cargo bench -p spiel_bot -- env/ +//! cargo bench -p spiel_bot -- network/ +//! cargo bench -p spiel_bot -- mcts/ +//! cargo bench -p spiel_bot -- episode/ +//! cargo bench -p spiel_bot -- train/ +//! ``` +//! +//! Target: ≥ 500 games/s for random play on CPU (consistent with +//! `random_game` throughput in `trictrac-store`). + +use std::time::Duration; + +use burn::{ + backend::NdArray, + tensor::{Tensor, TensorData, backend::Backend}, +}; +use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, + env::{GameEnv, Player, TrictracEnv}, + mcts::{Evaluator, MctsConfig, run_mcts}, + network::{MlpConfig, MlpNet, PolicyValueNet}, +}; + +// ── Shared types ─────────────────────────────────────────────────────────── + +type InferB = NdArray; +type TrainB = burn::backend::Autodiff>; + +fn infer_device() -> ::Device { Default::default() } +fn train_device() -> ::Device { Default::default() } + +fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) } + +/// Uniform evaluator (returns zero logits and zero value). +/// Used to isolate MCTS tree-traversal cost from network cost. +struct ZeroEval(usize); +impl Evaluator for ZeroEval { + fn evaluate(&self, _obs: &[f32]) -> (Vec, f32) { + (vec![0.0f32; self.0], 0.0) + } +} + +// ── 1. Environment primitives ────────────────────────────────────────────── + +/// Baseline performance of the raw Trictrac environment without MCTS. +/// Target: ≥ 500 full games / second. +fn bench_env(c: &mut Criterion) { + let env = TrictracEnv; + + let mut group = c.benchmark_group("env"); + group.measurement_time(Duration::from_secs(10)); + + // ── apply_chance ────────────────────────────────────────────────────── + group.bench_function("apply_chance", |b| { + b.iter_batched( + || { + // A fresh game is always at RollDice (Chance) — ready for apply_chance. + env.new_game() + }, + |mut s| { + env.apply_chance(&mut s, &mut seeded()); + black_box(s) + }, + BatchSize::SmallInput, + ) + }); + + // ── legal_actions ───────────────────────────────────────────────────── + group.bench_function("legal_actions", |b| { + let mut rng = seeded(); + let mut s = env.new_game(); + env.apply_chance(&mut s, &mut rng); + b.iter(|| black_box(env.legal_actions(&s))) + }); + + // ── observation (to_tensor) ─────────────────────────────────────────── + group.bench_function("observation", |b| { + let mut rng = seeded(); + let mut s = env.new_game(); + env.apply_chance(&mut s, &mut rng); + b.iter(|| black_box(env.observation(&s, 0))) + }); + + // ── full random game ────────────────────────────────────────────────── + group.sample_size(50); + group.bench_function("random_game", |b| { + b.iter_batched( + seeded, + |mut rng| { + let mut s = env.new_game(); + loop { + match env.current_player(&s) { + Player::Terminal => break, + Player::Chance => env.apply_chance(&mut s, &mut rng), + _ => { + let actions = env.legal_actions(&s); + let idx = rng.random_range(0..actions.len()); + env.apply(&mut s, actions[idx]); + } + } + } + black_box(s) + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +// ── 2. Network inference ─────────────────────────────────────────────────── + +/// Forward-pass latency for MLP variants (hidden = 64 / 256). +fn bench_network(c: &mut Criterion) { + let mut group = c.benchmark_group("network"); + group.measurement_time(Duration::from_secs(5)); + + for &hidden in &[64usize, 256] { + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = MlpNet::::new(&cfg, &infer_device()); + let obs: Vec = vec![0.5; 217]; + + // Batch size 1 — single-position evaluation as in MCTS. + group.bench_with_input( + BenchmarkId::new("mlp_b1", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs.clone(), [1, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + + // Batch size 32 — training mini-batch. + let obs32: Vec = vec![0.5; 217 * 32]; + group.bench_with_input( + BenchmarkId::new("mlp_b32", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs32.clone(), [32, 217]); + let t = Tensor::::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + } + + group.finish(); +} + +// ── 3. MCTS ─────────────────────────────────────────────────────────────── + +/// MCTS cost at different simulation budgets with two evaluator types: +/// - `zero` — isolates tree-traversal overhead (no network). +/// - `mlp64` — real MLP, shows end-to-end cost per move. +fn bench_mcts(c: &mut Criterion) { + let env = TrictracEnv; + + // Build a decision-node state (after dice roll). + let state = { + let mut s = env.new_game(); + let mut rng = seeded(); + while env.current_player(&s).is_chance() { + env.apply_chance(&mut s, &mut rng); + } + s + }; + + let mut group = c.benchmark_group("mcts"); + group.measurement_time(Duration::from_secs(10)); + + let zero_eval = ZeroEval(514); + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let mlp_model = MlpNet::::new(&mlp_cfg, &infer_device()); + let mlp_eval = BurnEvaluator::::new(mlp_model, infer_device()); + + for &n_sim in &[1usize, 5, 20] { + let cfg = MctsConfig { + n_simulations: n_sim, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + }; + + // Zero evaluator: tree traversal only. + group.bench_with_input( + BenchmarkId::new("zero_eval", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)), + BatchSize::SmallInput, + ) + }, + ); + + // MLP evaluator: full cost per decision. + group.bench_with_input( + BenchmarkId::new("mlp64", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)), + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── 4. Episode generation ───────────────────────────────────────────────── + +/// Full self-play episode latency (one complete game) at different MCTS +/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU. +fn bench_episode(c: &mut Criterion) { + let env = TrictracEnv; + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + let model = MlpNet::::new(&mlp_cfg, &infer_device()); + let eval = BurnEvaluator::::new(model, infer_device()); + + let mut group = c.benchmark_group("episode"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(60)); + + for &n_sim in &[1usize, 2] { + let mcts_cfg = MctsConfig { + n_simulations: n_sim, + c_puct: 1.5, + dirichlet_alpha: 0.0, + dirichlet_eps: 0.0, + temperature: 1.0, + }; + + group.bench_with_input( + BenchmarkId::new("trictrac", n_sim), + &n_sim, + |b, _| { + b.iter_batched( + seeded, + |mut rng| { + black_box(generate_episode( + &env, + &eval, + &mcts_cfg, + &|_| 1.0, + &mut rng, + )) + }, + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── 5. Training step ─────────────────────────────────────────────────────── + +/// Gradient-step latency for different batch sizes. +fn bench_train(c: &mut Criterion) { + use burn::optim::AdamConfig; + + let mut group = c.benchmark_group("train"); + group.measurement_time(Duration::from_secs(10)); + + let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 }; + + let dummy_samples = |n: usize| -> Vec { + (0..n) + .map(|i| TrainSample { + obs: vec![0.5; 217], + policy: { + let mut p = vec![0.0f32; 514]; + p[i % 514] = 1.0; + p + }, + value: if i % 2 == 0 { 1.0 } else { -1.0 }, + }) + .collect() + }; + + for &batch_size in &[16usize, 64] { + let batch = dummy_samples(batch_size); + + group.bench_with_input( + BenchmarkId::new("mlp64_adam", batch_size), + &batch_size, + |b, _| { + b.iter_batched( + || { + ( + MlpNet::::new(&mlp_cfg, &train_device()), + AdamConfig::new().init::>(), + ) + }, + |(model, mut opt)| { + black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3)) + }, + BatchSize::SmallInput, + ) + }, + ); + } + + group.finish(); +} + +// ── Criterion entry point ────────────────────────────────────────────────── + +criterion_group!( + benches, + bench_env, + bench_network, + bench_mcts, + bench_episode, + bench_train, +); +criterion_main!(benches);