feat(spiel_bot): benchmarks
This commit is contained in:
parent
2329b76f7e
commit
7d37eebe52
3 changed files with 468 additions and 0 deletions
120
Cargo.lock
generated
120
Cargo.lock
generated
|
|
@ -92,6 +92,12 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anes"
|
||||||
|
version = "0.1.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstream"
|
name = "anstream"
|
||||||
version = "0.6.21"
|
version = "0.6.21"
|
||||||
|
|
@ -1116,6 +1122,12 @@ version = "0.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cast"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cast_trait"
|
name = "cast_trait"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
|
|
@ -1200,6 +1212,33 @@ dependencies = [
|
||||||
"rand 0.7.3",
|
"rand 0.7.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ciborium"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
|
||||||
|
dependencies = [
|
||||||
|
"ciborium-io",
|
||||||
|
"ciborium-ll",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ciborium-io"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ciborium-ll"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
|
||||||
|
dependencies = [
|
||||||
|
"ciborium-io",
|
||||||
|
"half",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cipher"
|
name = "cipher"
|
||||||
version = "0.4.4"
|
version = "0.4.4"
|
||||||
|
|
@ -1453,6 +1492,42 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "criterion"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
|
||||||
|
dependencies = [
|
||||||
|
"anes",
|
||||||
|
"cast",
|
||||||
|
"ciborium",
|
||||||
|
"clap",
|
||||||
|
"criterion-plot",
|
||||||
|
"is-terminal",
|
||||||
|
"itertools 0.10.5",
|
||||||
|
"num-traits",
|
||||||
|
"once_cell",
|
||||||
|
"oorandom",
|
||||||
|
"plotters",
|
||||||
|
"rayon",
|
||||||
|
"regex",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"serde_json",
|
||||||
|
"tinytemplate",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "criterion-plot"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||||
|
dependencies = [
|
||||||
|
"cast",
|
||||||
|
"itertools 0.10.5",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "critical-section"
|
name = "critical-section"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
|
|
@ -4461,6 +4536,12 @@ version = "1.70.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "oorandom"
|
||||||
|
version = "11.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opaque-debug"
|
name = "opaque-debug"
|
||||||
version = "0.3.1"
|
version = "0.3.1"
|
||||||
|
|
@ -4597,6 +4678,34 @@ version = "0.3.32"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "plotters"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
"plotters-backend",
|
||||||
|
"plotters-svg",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "plotters-backend"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "plotters-svg"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
|
||||||
|
dependencies = [
|
||||||
|
"plotters-backend",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "png"
|
name = "png"
|
||||||
version = "0.18.0"
|
version = "0.18.0"
|
||||||
|
|
@ -5897,6 +6006,7 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"burn",
|
"burn",
|
||||||
|
"criterion",
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
"trictrac-store",
|
"trictrac-store",
|
||||||
|
|
@ -6310,6 +6420,16 @@ dependencies = [
|
||||||
"zerovec",
|
"zerovec",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tinytemplate"
|
||||||
|
version = "1.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tinyvec"
|
name = "tinyvec"
|
||||||
version = "1.10.0"
|
version = "1.10.0"
|
||||||
|
|
|
||||||
|
|
@ -9,3 +9,10 @@ anyhow = "1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
rand_distr = "0.5"
|
rand_distr = "0.5"
|
||||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "alphazero"
|
||||||
|
harness = false
|
||||||
|
|
|
||||||
341
spiel_bot/benches/alphazero.rs
Normal file
341
spiel_bot/benches/alphazero.rs
Normal file
|
|
@ -0,0 +1,341 @@
|
||||||
|
//! AlphaZero pipeline benchmarks.
|
||||||
|
//!
|
||||||
|
//! Run with:
|
||||||
|
//!
|
||||||
|
//! ```sh
|
||||||
|
//! cargo bench -p spiel_bot
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Use `-- <filter>` to run a specific group, e.g.:
|
||||||
|
//!
|
||||||
|
//! ```sh
|
||||||
|
//! cargo bench -p spiel_bot -- env/
|
||||||
|
//! cargo bench -p spiel_bot -- network/
|
||||||
|
//! cargo bench -p spiel_bot -- mcts/
|
||||||
|
//! cargo bench -p spiel_bot -- episode/
|
||||||
|
//! cargo bench -p spiel_bot -- train/
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Target: ≥ 500 games/s for random play on CPU (consistent with
|
||||||
|
//! `random_game` throughput in `trictrac-store`).
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
backend::NdArray,
|
||||||
|
tensor::{Tensor, TensorData, backend::Backend},
|
||||||
|
};
|
||||||
|
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||||
|
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||||
|
|
||||||
|
use spiel_bot::{
|
||||||
|
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
||||||
|
env::{GameEnv, Player, TrictracEnv},
|
||||||
|
mcts::{Evaluator, MctsConfig, run_mcts},
|
||||||
|
network::{MlpConfig, MlpNet, PolicyValueNet},
|
||||||
|
};
|
||||||
|
|
||||||
|
// ── Shared types ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type InferB = NdArray<f32>;
|
||||||
|
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
|
||||||
|
|
||||||
|
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
|
||||||
|
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
|
||||||
|
|
||||||
|
fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
|
||||||
|
|
||||||
|
/// Uniform evaluator (returns zero logits and zero value).
|
||||||
|
/// Used to isolate MCTS tree-traversal cost from network cost.
|
||||||
|
struct ZeroEval(usize);
|
||||||
|
impl Evaluator for ZeroEval {
|
||||||
|
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
|
||||||
|
(vec![0.0f32; self.0], 0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 1. Environment primitives ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Baseline performance of the raw Trictrac environment without MCTS.
|
||||||
|
/// Target: ≥ 500 full games / second.
|
||||||
|
fn bench_env(c: &mut Criterion) {
|
||||||
|
let env = TrictracEnv;
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group("env");
|
||||||
|
group.measurement_time(Duration::from_secs(10));
|
||||||
|
|
||||||
|
// ── apply_chance ──────────────────────────────────────────────────────
|
||||||
|
group.bench_function("apply_chance", |b| {
|
||||||
|
b.iter_batched(
|
||||||
|
|| {
|
||||||
|
// A fresh game is always at RollDice (Chance) — ready for apply_chance.
|
||||||
|
env.new_game()
|
||||||
|
},
|
||||||
|
|mut s| {
|
||||||
|
env.apply_chance(&mut s, &mut seeded());
|
||||||
|
black_box(s)
|
||||||
|
},
|
||||||
|
BatchSize::SmallInput,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// ── legal_actions ─────────────────────────────────────────────────────
|
||||||
|
group.bench_function("legal_actions", |b| {
|
||||||
|
let mut rng = seeded();
|
||||||
|
let mut s = env.new_game();
|
||||||
|
env.apply_chance(&mut s, &mut rng);
|
||||||
|
b.iter(|| black_box(env.legal_actions(&s)))
|
||||||
|
});
|
||||||
|
|
||||||
|
// ── observation (to_tensor) ───────────────────────────────────────────
|
||||||
|
group.bench_function("observation", |b| {
|
||||||
|
let mut rng = seeded();
|
||||||
|
let mut s = env.new_game();
|
||||||
|
env.apply_chance(&mut s, &mut rng);
|
||||||
|
b.iter(|| black_box(env.observation(&s, 0)))
|
||||||
|
});
|
||||||
|
|
||||||
|
// ── full random game ──────────────────────────────────────────────────
|
||||||
|
group.sample_size(50);
|
||||||
|
group.bench_function("random_game", |b| {
|
||||||
|
b.iter_batched(
|
||||||
|
seeded,
|
||||||
|
|mut rng| {
|
||||||
|
let mut s = env.new_game();
|
||||||
|
loop {
|
||||||
|
match env.current_player(&s) {
|
||||||
|
Player::Terminal => break,
|
||||||
|
Player::Chance => env.apply_chance(&mut s, &mut rng),
|
||||||
|
_ => {
|
||||||
|
let actions = env.legal_actions(&s);
|
||||||
|
let idx = rng.random_range(0..actions.len());
|
||||||
|
env.apply(&mut s, actions[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
black_box(s)
|
||||||
|
},
|
||||||
|
BatchSize::SmallInput,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 2. Network inference ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
|
||||||
|
fn bench_network(c: &mut Criterion) {
|
||||||
|
let mut group = c.benchmark_group("network");
|
||||||
|
group.measurement_time(Duration::from_secs(5));
|
||||||
|
|
||||||
|
for &hidden in &[64usize, 256] {
|
||||||
|
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||||
|
let model = MlpNet::<InferB>::new(&cfg, &infer_device());
|
||||||
|
let obs: Vec<f32> = vec![0.5; 217];
|
||||||
|
|
||||||
|
// Batch size 1 — single-position evaluation as in MCTS.
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("mlp_b1", hidden),
|
||||||
|
&hidden,
|
||||||
|
|b, _| {
|
||||||
|
b.iter(|| {
|
||||||
|
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||||
|
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||||
|
black_box(model.forward(t))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Batch size 32 — training mini-batch.
|
||||||
|
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("mlp_b32", hidden),
|
||||||
|
&hidden,
|
||||||
|
|b, _| {
|
||||||
|
b.iter(|| {
|
||||||
|
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||||
|
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||||
|
black_box(model.forward(t))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 3. MCTS ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// MCTS cost at different simulation budgets with two evaluator types:
|
||||||
|
/// - `zero` — isolates tree-traversal overhead (no network).
|
||||||
|
/// - `mlp64` — real MLP, shows end-to-end cost per move.
|
||||||
|
fn bench_mcts(c: &mut Criterion) {
|
||||||
|
let env = TrictracEnv;
|
||||||
|
|
||||||
|
// Build a decision-node state (after dice roll).
|
||||||
|
let state = {
|
||||||
|
let mut s = env.new_game();
|
||||||
|
let mut rng = seeded();
|
||||||
|
while env.current_player(&s).is_chance() {
|
||||||
|
env.apply_chance(&mut s, &mut rng);
|
||||||
|
}
|
||||||
|
s
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group("mcts");
|
||||||
|
group.measurement_time(Duration::from_secs(10));
|
||||||
|
|
||||||
|
let zero_eval = ZeroEval(514);
|
||||||
|
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||||
|
let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
|
||||||
|
let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());
|
||||||
|
|
||||||
|
for &n_sim in &[1usize, 5, 20] {
|
||||||
|
let cfg = MctsConfig {
|
||||||
|
n_simulations: n_sim,
|
||||||
|
c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.0,
|
||||||
|
dirichlet_eps: 0.0,
|
||||||
|
temperature: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Zero evaluator: tree traversal only.
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("zero_eval", n_sim),
|
||||||
|
&n_sim,
|
||||||
|
|b, _| {
|
||||||
|
b.iter_batched(
|
||||||
|
seeded,
|
||||||
|
|mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
|
||||||
|
BatchSize::SmallInput,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// MLP evaluator: full cost per decision.
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("mlp64", n_sim),
|
||||||
|
&n_sim,
|
||||||
|
|b, _| {
|
||||||
|
b.iter_batched(
|
||||||
|
seeded,
|
||||||
|
|mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
|
||||||
|
BatchSize::SmallInput,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 4. Episode generation ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Full self-play episode latency (one complete game) at different MCTS
|
||||||
|
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
|
||||||
|
fn bench_episode(c: &mut Criterion) {
|
||||||
|
let env = TrictracEnv;
|
||||||
|
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||||
|
let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
|
||||||
|
let eval = BurnEvaluator::<InferB, _>::new(model, infer_device());
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group("episode");
|
||||||
|
group.sample_size(10);
|
||||||
|
group.measurement_time(Duration::from_secs(60));
|
||||||
|
|
||||||
|
for &n_sim in &[1usize, 2] {
|
||||||
|
let mcts_cfg = MctsConfig {
|
||||||
|
n_simulations: n_sim,
|
||||||
|
c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.0,
|
||||||
|
dirichlet_eps: 0.0,
|
||||||
|
temperature: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("trictrac", n_sim),
|
||||||
|
&n_sim,
|
||||||
|
|b, _| {
|
||||||
|
b.iter_batched(
|
||||||
|
seeded,
|
||||||
|
|mut rng| {
|
||||||
|
black_box(generate_episode(
|
||||||
|
&env,
|
||||||
|
&eval,
|
||||||
|
&mcts_cfg,
|
||||||
|
&|_| 1.0,
|
||||||
|
&mut rng,
|
||||||
|
))
|
||||||
|
},
|
||||||
|
BatchSize::SmallInput,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 5. Training step ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Gradient-step latency for different batch sizes.
|
||||||
|
fn bench_train(c: &mut Criterion) {
|
||||||
|
use burn::optim::AdamConfig;
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group("train");
|
||||||
|
group.measurement_time(Duration::from_secs(10));
|
||||||
|
|
||||||
|
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
|
||||||
|
|
||||||
|
let dummy_samples = |n: usize| -> Vec<TrainSample> {
|
||||||
|
(0..n)
|
||||||
|
.map(|i| TrainSample {
|
||||||
|
obs: vec![0.5; 217],
|
||||||
|
policy: {
|
||||||
|
let mut p = vec![0.0f32; 514];
|
||||||
|
p[i % 514] = 1.0;
|
||||||
|
p
|
||||||
|
},
|
||||||
|
value: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
for &batch_size in &[16usize, 64] {
|
||||||
|
let batch = dummy_samples(batch_size);
|
||||||
|
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("mlp64_adam", batch_size),
|
||||||
|
&batch_size,
|
||||||
|
|b, _| {
|
||||||
|
b.iter_batched(
|
||||||
|
|| {
|
||||||
|
(
|
||||||
|
MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
|
||||||
|
AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
|(model, mut opt)| {
|
||||||
|
black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
|
||||||
|
},
|
||||||
|
BatchSize::SmallInput,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Criterion entry point ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
criterion_group!(
|
||||||
|
benches,
|
||||||
|
bench_env,
|
||||||
|
bench_network,
|
||||||
|
bench_mcts,
|
||||||
|
bench_episode,
|
||||||
|
bench_train,
|
||||||
|
);
|
||||||
|
criterion_main!(benches);
|
||||||
Loading…
Add table
Add a link
Reference in a new issue