feat(spiel_bot): benchmarks
This commit is contained in:
parent
b074a401ba
commit
2e85c14dbb
3 changed files with 468 additions and 0 deletions
120
Cargo.lock
generated
120
Cargo.lock
generated
|
|
@ -92,6 +92,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anes"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.21"
|
||||
|
|
@ -1116,6 +1122,12 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||
|
||||
[[package]]
|
||||
name = "cast"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cast_trait"
|
||||
version = "0.1.2"
|
||||
|
|
@ -1200,6 +1212,33 @@ dependencies = [
|
|||
"rand 0.7.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"ciborium-ll",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-io"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-ll"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
|
|
@ -1453,6 +1492,42 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"is-terminal",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "critical-section"
|
||||
version = "1.2.0"
|
||||
|
|
@ -4461,6 +4536,12 @@ version = "1.70.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
version = "11.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
|
|
@ -4597,6 +4678,34 @@ version = "0.3.32"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"plotters-backend",
|
||||
"plotters-svg",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "plotters-backend"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
|
||||
|
||||
[[package]]
|
||||
name = "plotters-svg"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
|
||||
dependencies = [
|
||||
"plotters-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.18.0"
|
||||
|
|
@ -5897,6 +6006,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"burn",
|
||||
"criterion",
|
||||
"rand 0.9.2",
|
||||
"rand_distr",
|
||||
"trictrac-store",
|
||||
|
|
@ -6310,6 +6420,16 @@ dependencies = [
|
|||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
|
|
|
|||
|
|
@ -9,3 +9,10 @@ anyhow = "1"
|
|||
rand = "0.9"
|
||||
rand_distr = "0.5"
|
||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
|
||||
[[bench]]
|
||||
name = "alphazero"
|
||||
harness = false
|
||||
|
|
|
|||
341
spiel_bot/benches/alphazero.rs
Normal file
341
spiel_bot/benches/alphazero.rs
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
//! AlphaZero pipeline benchmarks.
|
||||
//!
|
||||
//! Run with:
|
||||
//!
|
||||
//! ```sh
|
||||
//! cargo bench -p spiel_bot
|
||||
//! ```
|
||||
//!
|
||||
//! Use `-- <filter>` to run a specific group, e.g.:
|
||||
//!
|
||||
//! ```sh
|
||||
//! cargo bench -p spiel_bot -- env/
|
||||
//! cargo bench -p spiel_bot -- network/
|
||||
//! cargo bench -p spiel_bot -- mcts/
|
||||
//! cargo bench -p spiel_bot -- episode/
|
||||
//! cargo bench -p spiel_bot -- train/
|
||||
//! ```
|
||||
//!
|
||||
//! Target: ≥ 500 games/s for random play on CPU (consistent with
|
||||
//! `random_game` throughput in `trictrac-store`).
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
tensor::{Tensor, TensorData, backend::Backend},
|
||||
};
|
||||
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||
|
||||
use spiel_bot::{
|
||||
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
||||
env::{GameEnv, Player, TrictracEnv},
|
||||
mcts::{Evaluator, MctsConfig, run_mcts},
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet},
|
||||
};
|
||||
|
||||
// ── Shared types ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Backend used by the inference-side benchmarks: `NdArray` with `f32` elements.
type InferB = NdArray<f32>;
|
||||
/// Backend used by the training benchmark: `Autodiff` wrapped around `NdArray<f32>`.
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
|
||||
|
||||
/// Returns the default device of the inference backend (`InferB`).
fn infer_device() -> <InferB as Backend>::Device {
    <<InferB as Backend>::Device as Default>::default()
}
|
||||
/// Returns the default device of the training backend (`TrainB`).
fn train_device() -> <TrainB as Backend>::Device {
    <<TrainB as Backend>::Device as Default>::default()
}
|
||||
|
||||
/// Builds a `SmallRng` from the fixed seed 0, so every caller starts from
/// the same random stream.
fn seeded() -> SmallRng {
    const SEED: u64 = 0;
    SmallRng::seed_from_u64(SEED)
}
|
||||
|
||||
/// Uniform evaluator (returns zero logits and zero value).
|
||||
/// Used to isolate MCTS tree-traversal cost from network cost.
|
||||
struct ZeroEval(usize);
|
||||
impl Evaluator for ZeroEval {
|
||||
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
|
||||
(vec![0.0f32; self.0], 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
// ── 1. Environment primitives ──────────────────────────────────────────────
|
||||
|
||||
/// Baseline performance of the raw Trictrac environment without MCTS.
|
||||
/// Target: ≥ 500 full games / second.
|
||||
fn bench_env(c: &mut Criterion) {
|
||||
let env = TrictracEnv;
|
||||
|
||||
let mut group = c.benchmark_group("env");
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
// ── apply_chance ──────────────────────────────────────────────────────
|
||||
group.bench_function("apply_chance", |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
// A fresh game is always at RollDice (Chance) — ready for apply_chance.
|
||||
env.new_game()
|
||||
},
|
||||
|mut s| {
|
||||
env.apply_chance(&mut s, &mut seeded());
|
||||
black_box(s)
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
// ── legal_actions ─────────────────────────────────────────────────────
|
||||
group.bench_function("legal_actions", |b| {
|
||||
let mut rng = seeded();
|
||||
let mut s = env.new_game();
|
||||
env.apply_chance(&mut s, &mut rng);
|
||||
b.iter(|| black_box(env.legal_actions(&s)))
|
||||
});
|
||||
|
||||
// ── observation (to_tensor) ───────────────────────────────────────────
|
||||
group.bench_function("observation", |b| {
|
||||
let mut rng = seeded();
|
||||
let mut s = env.new_game();
|
||||
env.apply_chance(&mut s, &mut rng);
|
||||
b.iter(|| black_box(env.observation(&s, 0)))
|
||||
});
|
||||
|
||||
// ── full random game ──────────────────────────────────────────────────
|
||||
group.sample_size(50);
|
||||
group.bench_function("random_game", |b| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| {
|
||||
let mut s = env.new_game();
|
||||
loop {
|
||||
match env.current_player(&s) {
|
||||
Player::Terminal => break,
|
||||
Player::Chance => env.apply_chance(&mut s, &mut rng),
|
||||
_ => {
|
||||
let actions = env.legal_actions(&s);
|
||||
let idx = rng.random_range(0..actions.len());
|
||||
env.apply(&mut s, actions[idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
black_box(s)
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 2. Network inference ───────────────────────────────────────────────────
|
||||
|
||||
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
|
||||
fn bench_network(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("network");
|
||||
group.measurement_time(Duration::from_secs(5));
|
||||
|
||||
for &hidden in &[64usize, 256] {
|
||||
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = MlpNet::<InferB>::new(&cfg, &infer_device());
|
||||
let obs: Vec<f32> = vec![0.5; 217];
|
||||
|
||||
// Batch size 1 — single-position evaluation as in MCTS.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp_b1", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
// Batch size 32 — training mini-batch.
|
||||
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp_b32", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 3. MCTS ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// MCTS cost at different simulation budgets with two evaluator types:
|
||||
/// - `zero` — isolates tree-traversal overhead (no network).
|
||||
/// - `mlp64` — real MLP, shows end-to-end cost per move.
|
||||
fn bench_mcts(c: &mut Criterion) {
|
||||
let env = TrictracEnv;
|
||||
|
||||
// Build a decision-node state (after dice roll).
|
||||
let state = {
|
||||
let mut s = env.new_game();
|
||||
let mut rng = seeded();
|
||||
while env.current_player(&s).is_chance() {
|
||||
env.apply_chance(&mut s, &mut rng);
|
||||
}
|
||||
s
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group("mcts");
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
let zero_eval = ZeroEval(514);
|
||||
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
|
||||
let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());
|
||||
|
||||
for &n_sim in &[1usize, 5, 20] {
|
||||
let cfg = MctsConfig {
|
||||
n_simulations: n_sim,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0,
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
// Zero evaluator: tree traversal only.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("zero_eval", n_sim),
|
||||
&n_sim,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
// MLP evaluator: full cost per decision.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp64", n_sim),
|
||||
&n_sim,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 4. Episode generation ─────────────────────────────────────────────────
|
||||
|
||||
/// Full self-play episode latency (one complete game) at different MCTS
|
||||
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
|
||||
fn bench_episode(c: &mut Criterion) {
|
||||
let env = TrictracEnv;
|
||||
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
|
||||
let eval = BurnEvaluator::<InferB, _>::new(model, infer_device());
|
||||
|
||||
let mut group = c.benchmark_group("episode");
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(60));
|
||||
|
||||
for &n_sim in &[1usize, 2] {
|
||||
let mcts_cfg = MctsConfig {
|
||||
n_simulations: n_sim,
|
||||
c_puct: 1.5,
|
||||
dirichlet_alpha: 0.0,
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("trictrac", n_sim),
|
||||
&n_sim,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
seeded,
|
||||
|mut rng| {
|
||||
black_box(generate_episode(
|
||||
&env,
|
||||
&eval,
|
||||
&mcts_cfg,
|
||||
&|_| 1.0,
|
||||
&mut rng,
|
||||
))
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── 5. Training step ───────────────────────────────────────────────────────
|
||||
|
||||
/// Gradient-step latency for different batch sizes.
|
||||
fn bench_train(c: &mut Criterion) {
|
||||
use burn::optim::AdamConfig;
|
||||
|
||||
let mut group = c.benchmark_group("train");
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||
|
||||
let dummy_samples = |n: usize| -> Vec<TrainSample> {
|
||||
(0..n)
|
||||
.map(|i| TrainSample {
|
||||
obs: vec![0.5; 217],
|
||||
policy: {
|
||||
let mut p = vec![0.0f32; 514];
|
||||
p[i % 514] = 1.0;
|
||||
p
|
||||
},
|
||||
value: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
for &batch_size in &[16usize, 64] {
|
||||
let batch = dummy_samples(batch_size);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("mlp64_adam", batch_size),
|
||||
&batch_size,
|
||||
|b, _| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
(
|
||||
MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
|
||||
AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
|
||||
)
|
||||
},
|
||||
|(model, mut opt)| {
|
||||
black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ── Criterion entry point ──────────────────────────────────────────────────
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_env,
|
||||
bench_network,
|
||||
bench_mcts,
|
||||
bench_episode,
|
||||
bench_train,
|
||||
);
|
||||
criterion_main!(benches);
|
||||
Loading…
Add table
Add a link
Reference in a new issue