diff --git a/Cargo.lock b/Cargo.lock
index 0baa02a..34bfe80 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -92,6 +92,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "anes"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
+
 [[package]]
 name = "anstream"
 version = "0.6.21"
@@ -1116,6 +1122,12 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
 
+[[package]]
+name = "cast"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
+
 [[package]]
 name = "cast_trait"
 version = "0.1.2"
@@ -1200,6 +1212,33 @@ dependencies = [
  "rand 0.7.3",
 ]
 
+[[package]]
+name = "ciborium"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
+dependencies = [
+ "ciborium-io",
+ "ciborium-ll",
+ "serde",
+]
+
+[[package]]
+name = "ciborium-io"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
+
+[[package]]
+name = "ciborium-ll"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
+dependencies = [
+ "ciborium-io",
+ "half",
+]
+
 [[package]]
 name = "cipher"
 version = "0.4.4"
@@ -1453,6 +1492,42 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "criterion"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
+dependencies = [
+ "anes",
+ "cast",
+ "ciborium",
+ "clap",
+ "criterion-plot",
+ "is-terminal",
+ "itertools 0.10.5",
+ "num-traits",
+ "once_cell",
+ "oorandom",
+ "plotters",
+ "rayon",
+ "regex",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
+[[package]]
+name = "criterion-plot"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
+dependencies = [
+ "cast",
+ "itertools 0.10.5",
+]
+
 [[package]]
 name = "critical-section"
 version = "1.2.0"
@@ -4461,6 +4536,12 @@ version = "1.70.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
 
+[[package]]
+name = "oorandom"
+version = "11.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
+
 [[package]]
 name = "opaque-debug"
 version = "0.3.1"
@@ -4597,6 +4678,34 @@ version = "0.3.32"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
 
+[[package]]
+name = "plotters"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
+dependencies = [
+ "num-traits",
+ "plotters-backend",
+ "plotters-svg",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "plotters-backend"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
+
+[[package]]
+name = "plotters-svg"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
+dependencies = [
+ "plotters-backend",
+]
+
 [[package]]
 name = "png"
 version = "0.18.0"
@@ -5897,6 +6006,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "burn",
+ "criterion",
  "rand 0.9.2",
  "rand_distr",
  "trictrac-store",
@@ -6310,6 +6420,16 @@ dependencies = [
  "zerovec",
 ]
 
+[[package]]
+name = "tinytemplate"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "tinyvec"
 version = "1.10.0"
diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml
index 323c953..3848dce 100644
--- a/spiel_bot/Cargo.toml
+++ b/spiel_bot/Cargo.toml
@@ -9,3 +9,10 @@ anyhow = "1"
 rand = "0.9"
 rand_distr = "0.5"
 burn = { version = "0.20", features = ["ndarray", "autodiff"] }
+
+[dev-dependencies]
+criterion = { version = "0.5", features = ["html_reports"] }
+
+[[bench]]
+name = "alphazero"
+harness = false
diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs
new file mode 100644
index 0000000..00d5b02
--- /dev/null
+++ b/spiel_bot/benches/alphazero.rs
@@ -0,0 +1,373 @@
+//! AlphaZero pipeline benchmarks.
+//!
+//! Run with:
+//!
+//! ```sh
+//! cargo bench -p spiel_bot
+//! ```
+//!
+//! Use `-- <filter>` to run a specific group, e.g.:
+//!
+//! ```sh
+//! cargo bench -p spiel_bot -- env/
+//! cargo bench -p spiel_bot -- network/
+//! cargo bench -p spiel_bot -- mcts/
+//! cargo bench -p spiel_bot -- episode/
+//! cargo bench -p spiel_bot -- train/
+//! ```
+//!
+//! Target: ≥ 500 games/s for random play on CPU (consistent with
+//! `random_game` throughput in `trictrac-store`).
+
+use std::time::Duration;
+
+use burn::{
+    backend::NdArray,
+    tensor::{Tensor, TensorData, backend::Backend},
+};
+use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use rand::{Rng, SeedableRng, rngs::SmallRng};
+
+use spiel_bot::{
+    alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
+    env::{GameEnv, Player, TrictracEnv},
+    mcts::{Evaluator, MctsConfig, run_mcts},
+    network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
+};
+
+// ── Shared types ───────────────────────────────────────────────────────────
+
+type InferB = NdArray;
+type TrainB = burn::backend::Autodiff<NdArray>;
+
+fn infer_device() -> <InferB as Backend>::Device { Default::default() }
+fn train_device() -> <TrainB as Backend>::Device { Default::default() }
+
+fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
+
+/// Uniform evaluator (returns zero logits and zero value).
+/// Used to isolate MCTS tree-traversal cost from network cost.
+struct ZeroEval(usize);
+impl Evaluator for ZeroEval {
+    fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
+        (vec![0.0f32; self.0], 0.0)
+    }
+}
+
+// ── 1. Environment primitives ──────────────────────────────────────────────
+
+/// Baseline performance of the raw Trictrac environment without MCTS.
+/// Target: ≥ 500 full games / second.
+fn bench_env(c: &mut Criterion) {
+    let env = TrictracEnv;
+
+    let mut group = c.benchmark_group("env");
+    group.measurement_time(Duration::from_secs(10));
+
+    // ── apply_chance ────────────────────────────────────────────────────
+    group.bench_function("apply_chance", |b| {
+        b.iter_batched(
+            || {
+                // A fresh game is always at RollDice (Chance) — ready for apply_chance.
+                env.new_game()
+            },
+            |mut s| {
+                env.apply_chance(&mut s, &mut seeded());
+                black_box(s)
+            },
+            BatchSize::SmallInput,
+        )
+    });
+
+    // ── legal_actions ───────────────────────────────────────────────────
+    group.bench_function("legal_actions", |b| {
+        let mut rng = seeded();
+        let mut s = env.new_game();
+        env.apply_chance(&mut s, &mut rng);
+        b.iter(|| black_box(env.legal_actions(&s)))
+    });
+
+    // ── observation (to_tensor) ─────────────────────────────────────────
+    group.bench_function("observation", |b| {
+        let mut rng = seeded();
+        let mut s = env.new_game();
+        env.apply_chance(&mut s, &mut rng);
+        b.iter(|| black_box(env.observation(&s, 0)))
+    });
+
+    // ── full random game ────────────────────────────────────────────────
+    group.sample_size(50);
+    group.bench_function("random_game", |b| {
+        b.iter_batched(
+            seeded,
+            |mut rng| {
+                let mut s = env.new_game();
+                loop {
+                    match env.current_player(&s) {
+                        Player::Terminal => break,
+                        Player::Chance => env.apply_chance(&mut s, &mut rng),
+                        _ => {
+                            let actions = env.legal_actions(&s);
+                            let idx = rng.random_range(0..actions.len());
+                            env.apply(&mut s, actions[idx]);
+                        }
+                    }
+                }
+                black_box(s)
+            },
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.finish();
+}
+
+// ── 2. Network inference ───────────────────────────────────────────────────
+
+/// Forward-pass latency for MLP variants (hidden = 64 / 256).
+fn bench_network(c: &mut Criterion) {
+    let mut group = c.benchmark_group("network");
+    group.measurement_time(Duration::from_secs(5));
+
+    for &hidden in &[64usize, 256] {
+        let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
+        let model = MlpNet::<InferB>::new(&cfg, &infer_device());
+        let obs: Vec<f32> = vec![0.5; 217];
+
+        // Batch size 1 — single-position evaluation as in MCTS.
+        group.bench_with_input(
+            BenchmarkId::new("mlp_b1", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs.clone(), [1, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+
+        // Batch size 32 — training mini-batch.
+        let obs32: Vec<f32> = vec![0.5; 217 * 32];
+        group.bench_with_input(
+            BenchmarkId::new("mlp_b32", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs32.clone(), [32, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+    }
+
+    // ── ResNet (4 residual blocks) ──────────────────────────────────────
+    for &hidden in &[256usize, 512] {
+        let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
+        let model = ResNet::<InferB>::new(&cfg, &infer_device());
+        let obs: Vec<f32> = vec![0.5; 217];
+
+        group.bench_with_input(
+            BenchmarkId::new("resnet_b1", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs.clone(), [1, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+
+        let obs32: Vec<f32> = vec![0.5; 217 * 32];
+        group.bench_with_input(
+            BenchmarkId::new("resnet_b32", hidden),
+            &hidden,
+            |b, _| {
+                b.iter(|| {
+                    let data = TensorData::new(obs32.clone(), [32, 217]);
+                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
+                    black_box(model.forward(t))
+                })
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── 3. MCTS ────────────────────────────────────────────────────────────────
+
+/// MCTS cost at different simulation budgets with two evaluator types:
+/// - `zero` — isolates tree-traversal overhead (no network).
+/// - `mlp64` — real MLP, shows end-to-end cost per move.
+fn bench_mcts(c: &mut Criterion) {
+    let env = TrictracEnv;
+
+    // Build a decision-node state (after dice roll).
+    let state = {
+        let mut s = env.new_game();
+        let mut rng = seeded();
+        while env.current_player(&s).is_chance() {
+            env.apply_chance(&mut s, &mut rng);
+        }
+        s
+    };
+
+    let mut group = c.benchmark_group("mcts");
+    group.measurement_time(Duration::from_secs(10));
+
+    let zero_eval = ZeroEval(514);
+    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+    let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
+    let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());
+
+    for &n_sim in &[1usize, 5, 20] {
+        let cfg = MctsConfig {
+            n_simulations: n_sim,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.0,
+            dirichlet_eps: 0.0,
+            temperature: 1.0,
+        };
+
+        // Zero evaluator: tree traversal only.
+        group.bench_with_input(
+            BenchmarkId::new("zero_eval", n_sim),
+            &n_sim,
+            |b, _| {
+                b.iter_batched(
+                    seeded,
+                    |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+
+        // MLP evaluator: full cost per decision.
+        group.bench_with_input(
+            BenchmarkId::new("mlp64", n_sim),
+            &n_sim,
+            |b, _| {
+                b.iter_batched(
+                    seeded,
+                    |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── 4. Episode generation ──────────────────────────────────────────────────
+
+/// Full self-play episode latency (one complete game) at different MCTS
+/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
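+/// Episode latency roughly factors as (decisions per game) × (per-move MCTS
+/// cost from the `mcts` group), e.g. 100 decisions × 10 ms/move ≈ 1 s/game.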
+fn bench_episode(c: &mut Criterion) {
+    let env = TrictracEnv;
+    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+    let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
+    let eval = BurnEvaluator::<InferB, _>::new(model, infer_device());
+
+    let mut group = c.benchmark_group("episode");
+    group.sample_size(10);
+    group.measurement_time(Duration::from_secs(60));
+
+    for &n_sim in &[1usize, 2] {
+        let mcts_cfg = MctsConfig {
+            n_simulations: n_sim,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.0,
+            dirichlet_eps: 0.0,
+            temperature: 1.0,
+        };
+
+        group.bench_with_input(
+            BenchmarkId::new("trictrac", n_sim),
+            &n_sim,
+            |b, _| {
+                b.iter_batched(
+                    seeded,
+                    |mut rng| {
+                        black_box(generate_episode(
+                            &env,
+                            &eval,
+                            &mcts_cfg,
+                            &|_| 1.0,
+                            &mut rng,
+                        ))
+                    },
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── 5. Training step ───────────────────────────────────────────────────────
+
+/// Gradient-step latency for different batch sizes.
+fn bench_train(c: &mut Criterion) {
+    use burn::optim::AdamConfig;
+
+    let mut group = c.benchmark_group("train");
+    group.measurement_time(Duration::from_secs(10));
+
+    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+
+    let dummy_samples = |n: usize| -> Vec<TrainSample> {
+        (0..n)
+            .map(|i| TrainSample {
+                obs: vec![0.5; 217],
+                policy: {
+                    let mut p = vec![0.0f32; 514];
+                    p[i % 514] = 1.0;
+                    p
+                },
+                value: if i % 2 == 0 { 1.0 } else { -1.0 },
+            })
+            .collect()
+    };
+
+    for &batch_size in &[16usize, 64] {
+        let batch = dummy_samples(batch_size);
+
+        group.bench_with_input(
+            BenchmarkId::new("mlp64_adam", batch_size),
+            &batch_size,
+            |b, _| {
+                b.iter_batched(
+                    || {
+                        (
+                            MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
+                            AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
+                        )
+                    },
+                    |(model, mut opt)| {
+                        black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
+                    },
+                    BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+
+    group.finish();
+}
+
+// ── Criterion entry point ──────────────────────────────────────────────────
+
+criterion_group!(
+    benches,
+    bench_env,
+    bench_network,
+    bench_mcts,
+    bench_episode,
+    bench_train,
+);
+criterion_main!(benches);
diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs
index bb86724..d92224e 100644
--- a/spiel_bot/src/alphazero/mod.rs
+++ b/spiel_bot/src/alphazero/mod.rs
@@ -65,7 +65,7 @@ pub mod trainer;
 
 pub use replay::{ReplayBuffer, TrainSample};
 pub use selfplay::{BurnEvaluator, generate_episode};
-pub use trainer::train_step;
+pub use trainer::{cosine_lr, train_step};
 
 use crate::mcts::MctsConfig;
@@ -87,8 +87,17 @@ pub struct AlphaZeroConfig {
     pub batch_size: usize,
     /// Maximum number of samples in the replay buffer.
     pub replay_capacity: usize,
-    /// Adam learning rate.
+    /// Initial (peak) Adam learning rate.
     pub learning_rate: f64,
+    /// Minimum learning rate for cosine annealing (floor of the schedule).
+    ///
+    /// Set `lr_min` equal to `learning_rate` to disable scheduling (constant LR).
+    /// Compute the current LR with [`cosine_lr`]:
+    ///
+    /// ```rust,ignore
+    /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
+    /// ```
+    pub lr_min: f64,
     /// Number of outer iterations (self-play + train) to run.
     pub n_iterations: usize,
     /// Move index after which the action temperature drops to 0 (greedy play).
@@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig {
             batch_size: 64,
             replay_capacity: 50_000,
             learning_rate: 1e-3,
+            lr_min: 1e-4, // cosine annealing floor
             n_iterations: 100,
             temperature_drop_move: 30,
         }
diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs
index d2482d1..9075519 100644
--- a/spiel_bot/src/alphazero/trainer.rs
+++ b/spiel_bot/src/alphazero/trainer.rs
@@ -5,6 +5,24 @@
 //! - **Value loss** — mean-squared error between the predicted value and the
 //!   actual game outcome.
 //!
+//! # Learning-rate scheduling
+//!
+//! [`cosine_lr`] implements one-cycle cosine annealing:
+//!
+//! ```text
+//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T))
+//! ```
+//!
+//! Typical usage in the outer loop:
+//!
+//! ```rust,ignore
+//! for step in 0..total_train_steps {
+//!     let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
+//!     let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
+//!     model = m;
+//! }
+//! ```
+//!
 //! # Backend
 //!
 //! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray>`).
@@ -96,6 +114,30 @@ where
     (model, loss_scalar)
 }
 
+// ── Learning-rate schedule ─────────────────────────────────────────────────
+
+/// Cosine learning-rate schedule (one half-period, no warmup).
+///
+/// Returns the learning rate for training step `step` out of `total_steps`:
+///
+/// ```text
+/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total))
+/// ```
+///
+/// - At `t = 0` returns `initial`.
+/// - At `t = total_steps` (or beyond) returns `lr_min`.
+///
+/// # Panics
+///
+/// Does not panic. When `total_steps == 0`, returns `lr_min`.
+pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
+    if total_steps == 0 || step >= total_steps {
+        return lr_min;
+    }
+    let progress = step as f64 / total_steps as f64;
+    lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
+}
+
 // ── Tests ──────────────────────────────────────────────────────────────────
 
 #[cfg(test)]
@@ -169,4 +211,48 @@ mod tests {
         let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
         assert!(loss.is_finite());
     }
+
+    // ── cosine_lr ───────────────────────────────────────────────────────
+
+    #[test]
+    fn cosine_lr_at_step_zero_is_initial() {
+        let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
+        assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
+    }
+
+    #[test]
+    fn cosine_lr_at_end_is_min() {
+        let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
+        assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
+    }
+
+    #[test]
+    fn cosine_lr_beyond_end_is_min() {
+        let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
+        assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
+    }
+
+    #[test]
+    fn cosine_lr_midpoint_is_average() {
+        // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
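+        // e.g. initial = 1e-3, min = 1e-5 → (1e-3 + 1e-5) / 2 = 5.05e-4.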
+        let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
+        let expected = (1e-3 + 1e-5) / 2.0;
+        assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
+    }
+
+    #[test]
+    fn cosine_lr_monotone_decreasing() {
+        let mut prev = f64::INFINITY;
+        for step in 0..=100 {
+            let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
+            assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
+            prev = lr;
+        }
+    }
+
+    #[test]
+    fn cosine_lr_zero_total_steps_returns_min() {
+        let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
+        assert!((lr - 1e-5).abs() < 1e-10);
+    }
 }
diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs
new file mode 100644
index 0000000..3c82519
--- /dev/null
+++ b/spiel_bot/src/bin/az_eval.rs
@@ -0,0 +1,262 @@
+//! Evaluate a trained AlphaZero checkpoint against a random player.
+//!
+//! # Usage
+//!
+//! ```sh
+//! # Random weights (sanity check — should be ~50%)
+//! cargo run -p spiel_bot --bin az_eval --release
+//!
+//! # Trained MLP checkpoint
+//! cargo run -p spiel_bot --bin az_eval --release -- \
+//!     --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50
+//!
+//! # Trained ResNet checkpoint
+//! cargo run -p spiel_bot --bin az_eval --release -- \
+//!     --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100
+//! ```
+//!
+//! # Options
+//!
+//! | Flag | Default | Description |
+//! |------|---------|-------------|
+//! | `--checkpoint <PATH>` | (none) | Load weights from `.mpk` file; random weights if omitted |
+//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
+//! | `--hidden <N>` | 256 (mlp) / 512 (resnet) | Hidden size |
+//! | `--n-games <N>` | `100` | Games per side (total = 2 × N) |
+//! | `--n-sim <N>` | `50` | MCTS simulations per move |
+//! | `--seed <N>` | `42` | RNG seed |
+//! | `--c-puct <F>` | `1.5` | PUCT exploration constant |
+
+use std::path::PathBuf;
+
+use burn::backend::NdArray;
+use rand::{Rng, SeedableRng, rngs::SmallRng};
+
+use spiel_bot::{
+    alphazero::BurnEvaluator,
+    env::{GameEnv, Player, TrictracEnv},
+    mcts::{Evaluator, MctsConfig, run_mcts, select_action},
+    network::{MlpConfig, MlpNet, ResNet, ResNetConfig},
+};
+
+type InferB = NdArray;
+
+// ── CLI ────────────────────────────────────────────────────────────────────
+
+struct Args {
+    checkpoint: Option<PathBuf>,
+    arch: String,
+    hidden: Option<usize>,
+    n_games: usize,
+    n_sim: usize,
+    seed: u64,
+    c_puct: f32,
+}
+
+impl Default for Args {
+    fn default() -> Self {
+        Self {
+            checkpoint: None,
+            arch: "mlp".into(),
+            hidden: None,
+            n_games: 100,
+            n_sim: 50,
+            seed: 42,
+            c_puct: 1.5,
+        }
+    }
+}
+
+fn parse_args() -> Args {
+    let raw: Vec<String> = std::env::args().collect();
+    let mut args = Args::default();
+    let mut i = 1;
+    while i < raw.len() {
+        match raw[i].as_str() {
+            "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); }
+            "--arch" => { i += 1; args.arch = raw[i].clone(); }
+            "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); }
+            "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); }
+            "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); }
+            "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); }
+            "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); }
+            other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
+        }
+        i += 1;
+    }
+    args
+}
+
+// ── Game loop ──────────────────────────────────────────────────────────────
+
+/// Play one complete game.
+///
+/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black).
+/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0).
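+/// For example, with `mcts_side == 0` a return of `[1.0, -1.0]` means the
+/// MCTS side (playing P1) won.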
+fn play_game(
+    env: &TrictracEnv,
+    mcts_side: usize,
+    evaluator: &dyn Evaluator,
+    mcts_cfg: &MctsConfig,
+    rng: &mut SmallRng,
+) -> [f32; 2] {
+    let mut state = env.new_game();
+    loop {
+        match env.current_player(&state) {
+            Player::Terminal => {
+                return env.returns(&state).expect("Terminal state must have returns");
+            }
+            Player::Chance => env.apply_chance(&mut state, rng),
+            player => {
+                let side = player.index().unwrap(); // 0 = P1, 1 = P2
+                let action = if side == mcts_side {
+                    let root = run_mcts(env, &state, evaluator, mcts_cfg, rng);
+                    select_action(&root, 0.0, rng) // greedy (temperature = 0)
+                } else {
+                    let actions = env.legal_actions(&state);
+                    actions[rng.random_range(0..actions.len())]
+                };
+                env.apply(&mut state, action);
+            }
+        }
+    }
+}
+
+// ── Statistics ─────────────────────────────────────────────────────────────
+
+#[derive(Default)]
+struct Stats {
+    wins: u32,
+    draws: u32,
+    losses: u32,
+}
+
+impl Stats {
+    fn record(&mut self, mcts_return: f32) {
+        if mcts_return > 0.0 { self.wins += 1; }
+        else if mcts_return < 0.0 { self.losses += 1; }
+        else { self.draws += 1; }
+    }
+
+    fn total(&self) -> u32 { self.wins + self.draws + self.losses }
+
+    fn win_rate_decisive(&self) -> f64 {
+        let d = self.wins + self.losses;
+        if d == 0 { 0.5 } else { self.wins as f64 / d as f64 }
+    }
+
+    fn print(&self) {
+        let n = self.total();
+        let pct = |k: u32| 100.0 * k as f64 / n as f64;
+        println!(
+            "  Win {}/{n} ({:.1}%)  Draw {}/{n} ({:.1}%)  Loss {}/{n} ({:.1}%)",
+            self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses),
+        );
+    }
+}
+
+// ── Evaluation ─────────────────────────────────────────────────────────────
+
+fn run_evaluation(
+    evaluator: &dyn Evaluator,
+    n_games: usize,
+    mcts_cfg: &MctsConfig,
+    seed: u64,
+) -> (Stats, Stats) {
+    let env = TrictracEnv;
+    let total = n_games * 2;
+    let mut as_p1 = Stats::default();
+    let mut as_p2 = Stats::default();
+
+    for i in 0..total {
+        // Alternate sides: even games → MctsAgent as P1, odd → as P2.
+        let mcts_side = i % 2;
+        let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64));
+        let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng);
+
+        let mcts_return = result[mcts_side];
+        if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); }
+
+        let done = i + 1;
+        if done % 10 == 0 || done == total {
+            eprint!("\r  [{done}/{total}] ");
+        }
+    }
+    eprintln!();
+    (as_p1, as_p2)
+}
+
+// ── Main ───────────────────────────────────────────────────────────────────
+
+fn main() {
+    let args = parse_args();
+    let device: <InferB as burn::tensor::backend::Backend>::Device = Default::default();
+
+    // ── Load model ──────────────────────────────────────────────────────
+    let evaluator: Box<dyn Evaluator> = match args.arch.as_str() {
+        "resnet" => {
+            let hidden = args.hidden.unwrap_or(512);
+            let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
+            let model = match &args.checkpoint {
+                Some(path) => ResNet::<InferB>::load(&cfg, path, &device)
+                    .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
+                None => ResNet::new(&cfg, &device),
+            };
+            Box::new(BurnEvaluator::<InferB, ResNet<InferB>>::new(model, device))
+        }
+        "mlp" | _ => {
+            let hidden = args.hidden.unwrap_or(256);
+            let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
+            let model = match &args.checkpoint {
+                Some(path) => MlpNet::<InferB>::load(&cfg, path, &device)
+                    .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
+                None => MlpNet::new(&cfg, &device),
+            };
+            Box::new(BurnEvaluator::<InferB, MlpNet<InferB>>::new(model, device))
+        }
+    };
+
+    let mcts_cfg = MctsConfig {
+        n_simulations: args.n_sim,
+        c_puct: args.c_puct,
+        dirichlet_alpha: 0.0, // no exploration noise during evaluation
+        dirichlet_eps: 0.0,
+        temperature: 0.0, // greedy action selection
+    };
+
+    // ── Header ──────────────────────────────────────────────────────────
+    let ckpt_label = args.checkpoint
+        .as_deref()
+        .and_then(|p| p.file_name())
+        .and_then(|n| n.to_str())
+        .unwrap_or("random weights");
+
+    println!();
+    println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent",
+        args.arch, args.n_sim);
+    println!("Games per side: {} | Total: {} | Seed: {}",
+        args.n_games, args.n_games * 2, args.seed);
+    println!();
+
+    // ── Run ─────────────────────────────────────────────────────────────
+    let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed);
+
+    // ── Results ─────────────────────────────────────────────────────────
+    println!("MctsAgent as P1 (White):");
+    as_p1.print();
+
+    println!("MctsAgent as P2 (Black):");
+    as_p2.print();
+
+    let combined_wins = as_p1.wins + as_p2.wins;
+    let combined_decisive = combined_wins + as_p1.losses + as_p2.losses;
+    let combined_wr = if combined_decisive == 0 { 0.5 }
+        else { combined_wins as f64 / combined_decisive as f64 };
+
+    println!();
+    println!("Combined win rate (excluding draws): {:.1}% [{}/{}]",
+        combined_wr * 100.0, combined_wins, combined_decisive);
+    println!("  P1 decisive: {:.1}% | P2 decisive: {:.1}%",
+        as_p1.win_rate_decisive() * 100.0,
+        as_p2.win_rate_decisive() * 100.0);
+}
diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs
index e92bd09..a0a690d 100644
--- a/spiel_bot/src/mcts/mod.rs
+++ b/spiel_bot/src/mcts/mod.rs
@@ -401,8 +401,12 @@ mod tests {
         };
         let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
 
-        assert!(root.n > 0);
+        // root.n = 1 (expansion) + n_simulations (one backup per simulation).
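+        // e.g. with n_simulations = 5, expect root.n == 6.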
+        assert_eq!(root.n, 1 + config.n_simulations as u32);
+        // Children visit counts may sum to less than n_simulations when some
+        // simulations cross a chance node at depth 1 (turn ends after one move)
+        // and evaluate with the network directly without updating child.n.
         let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
-        assert_eq!(total, 5);
+        assert!(total <= config.n_simulations as u32);
     }
 }
diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs
index c4960c7..55db701 100644
--- a/spiel_bot/src/mcts/search.rs
+++ b/spiel_bot/src/mcts/search.rs
@@ -138,8 +138,14 @@ pub(super) fn simulate(
     // ── Apply action + advance through any chance nodes ─────────────────
     let mut next_state = state;
     env.apply(&mut next_state, action);
+
+    // Track whether we crossed a chance node (dice roll) on the way down.
+    // If we did, the child's cached legal actions are for a *different* dice
+    // outcome and must not be reused — evaluate with the network directly.
+    let mut crossed_chance = false;
     while env.current_player(&next_state).is_chance() {
         env.apply_chance(&mut next_state, rng);
+        crossed_chance = true;
     }
 
     let next_cp = env.current_player(&next_state);
@@ -153,7 +159,15 @@ pub(super) fn simulate(
         returns[player_idx]
     } else {
         let child_player = next_cp.index().unwrap();
-        let v = if child.expanded {
+        let v = if crossed_chance {
+            // Outcome sampling: after dice, evaluate the resulting position
+            // directly with the network. Do NOT build the tree across chance
+            // boundaries — the dice change which actions are legal, so any
+            // previously cached children would be for a different outcome.
+            let obs = env.observation(&next_state, child_player);
+            let (_, value) = evaluator.evaluate(&obs);
+            value
+        } else if child.expanded {
             simulate(child, next_state, env, evaluator, config, rng, child_player)
         } else {
             expand::<G, E>(child, &next_state, env, evaluator, child_player)
diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs
new file mode 100644
index 0000000..d73fda0
--- /dev/null
+++ b/spiel_bot/tests/integration.rs
@@ -0,0 +1,391 @@
+//! End-to-end integration tests for the AlphaZero training pipeline.
+//!
+//! Each test exercises the full chain:
+//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`]
+//!
+//! Two environments are used:
+//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves.
+//!   Used when we need many iterations without worrying about runtime.
+//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that
+//!   the full pipeline compiles and runs correctly with 217-dim observations
+//!   and 514-dim action spaces.
+//!
+//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep
+//! runtime minimal; correctness, not training quality, is what matters here.
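+//!
+//! The loop each test exercises, in sketch form (names as used below):
+//!
+//! ```rust,ignore
+//! let eval = BurnEvaluator::<Infer, _>::new(model.valid(), infer_device());
+//! let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
+//! replay.extend(samples);
+//! let batch: Vec<TrainSample> =
+//!     replay.sample_batch(batch_size, &mut rng).into_iter().cloned().collect();
+//! let (model, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
+//! ```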
+
+use burn::{
+    backend::{Autodiff, NdArray},
+    module::AutodiffModule,
+    optim::AdamConfig,
+};
+use rand::{SeedableRng, rngs::SmallRng};
+
+use spiel_bot::{
+    alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step},
+    env::{GameEnv, Player, TrictracEnv},
+    mcts::MctsConfig,
+    network::{MlpConfig, MlpNet, PolicyValueNet},
+};
+
+// ── Backend aliases ────────────────────────────────────────────────────────
+
+type Train = Autodiff<NdArray>;
+type Infer = NdArray;
+
+// ── Helpers ────────────────────────────────────────────────────────────────
+
+fn train_device() -> <Train as burn::tensor::backend::Backend>::Device {
+    Default::default()
+}
+
+fn infer_device() -> <Infer as burn::tensor::backend::Backend>::Device {
+    Default::default()
+}
+
+/// Tiny 64-unit MLP, compatible with an obs/action space of any size.
+fn tiny_mlp(obs: usize, actions: usize) -> MlpNet<Train> {
+    let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 };
+    MlpNet::new(&cfg, &train_device())
+}
+
+fn tiny_mcts(n: usize) -> MctsConfig {
+    MctsConfig {
+        n_simulations: n,
+        c_puct: 1.5,
+        dirichlet_alpha: 0.0,
+        dirichlet_eps: 0.0,
+        temperature: 1.0,
+    }
+}
+
+fn seeded() -> SmallRng {
+    SmallRng::seed_from_u64(0)
+}
+
+// ── Countdown environment (fast, local, no external deps) ─────────────────
+//
+// Two players alternate subtracting 1 or 2 from a counter that starts at N.
+// The player who brings the counter to 0 wins.
+
+#[derive(Clone, Debug)]
+struct CState {
+    remaining: u8,
+    to_move: usize,
+}
+
+#[derive(Clone)]
+struct CountdownEnv(u8); // starting value
+
+impl GameEnv for CountdownEnv {
+    type State = CState;
+
+    fn new_game(&self) -> CState {
+        CState { remaining: self.0, to_move: 0 }
+    }
+
+    fn current_player(&self, s: &CState) -> Player {
+        if s.remaining == 0 { Player::Terminal }
+        else if s.to_move == 0 { Player::P1 }
+        else { Player::P2 }
+    }
+
+    fn legal_actions(&self, s: &CState) -> Vec<usize> {
+        if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
+    }
+
+    fn apply(&self, s: &mut CState, action: usize) {
+        let sub = (action as u8) + 1;
+        if s.remaining <= sub {
+            s.remaining = 0;
+        } else {
+            s.remaining -= sub;
+            s.to_move = 1 - s.to_move;
+        }
+    }
+
+    fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
+
+    fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
+        vec![s.remaining as f32 / self.0 as f32, s.to_move as f32]
+    }
+
+    fn obs_size(&self) -> usize { 2 }
+    fn action_space(&self) -> usize { 2 }
+
+    fn returns(&self, s: &CState) -> Option<[f32; 2]> {
+        if s.remaining != 0 { return None; }
+        let mut r = [-1.0f32; 2];
+        r[s.to_move] = 1.0;
+        Some(r)
+    }
+}
+
+// ── 1. Full loop on CountdownEnv ───────────────────────────────────────────
+
+/// The canonical AlphaZero loop: self-play → replay → train, iterated.
+/// Uses CountdownEnv so each game terminates in < 10 moves.
+#[test]
+fn countdown_full_loop_no_panic() {
+    let env = CountdownEnv(8);
+    let mut rng = seeded();
+    let mcts = tiny_mcts(3);
+
+    let mut model = tiny_mlp(env.obs_size(), env.action_space());
+    let mut optimizer = AdamConfig::new().init();
+    let mut replay = ReplayBuffer::new(1_000);
+
+    for _iter in 0..5 {
+        // Self-play: 3 games per iteration.
+        for _ in 0..3 {
+            let infer = model.valid();
+            let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
+            let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
+            assert!(!samples.is_empty());
+            replay.extend(samples);
+        }
+
+        // Training: 4 gradient steps per iteration.
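+        // Guard: train only once the buffer holds at least one full batch (4 samples).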
+        if replay.len() >= 4 {
+            for _ in 0..4 {
+                let batch: Vec<TrainSample> = replay
+                    .sample_batch(4, &mut rng)
+                    .into_iter()
+                    .cloned()
+                    .collect();
+                let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
+                model = m;
+                assert!(loss.is_finite(), "loss must be finite, got {loss}");
+            }
+        }
+    }
+
+    assert!(replay.len() > 0);
+}
+
+// ── 2. Replay buffer invariants ────────────────────────────────────────────
+
+/// After several Countdown games, replay capacity is respected and batch
+/// shapes are consistent.
+#[test]
+fn replay_buffer_capacity_and_shapes() {
+    let env = CountdownEnv(6);
+    let mut rng = seeded();
+    let mcts = tiny_mcts(2);
+    let model = tiny_mlp(env.obs_size(), env.action_space());
+
+    let capacity = 50;
+    let mut replay = ReplayBuffer::new(capacity);
+
+    for _ in 0..20 {
+        let infer = model.valid();
+        let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
+        let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
+        replay.extend(samples);
+    }
+
+    assert!(replay.len() <= capacity, "buffer exceeded capacity");
+    assert!(replay.len() > 0);
+
+    let batch = replay.sample_batch(8, &mut rng);
+    assert_eq!(batch.len(), 8.min(replay.len()));
+    for s in &batch {
+        assert_eq!(s.obs.len(), env.obs_size());
+        assert_eq!(s.policy.len(), env.action_space());
+        let policy_sum: f32 = s.policy.iter().sum();
+        assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}");
+        assert!(s.value.abs() <= 1.0, "value {} out of range", s.value);
+    }
+}
+
+// ── 3. TrictracEnv: sample shapes ──────────────────────────────────────────
+
+/// Verify that one TrictracEnv episode produces samples with the correct
+/// tensor dimensions: obs = 217, policy = 514.
+#[test]
+fn trictrac_sample_shapes() {
+    let env = TrictracEnv;
+    let mut rng = seeded();
+    let mcts = tiny_mcts(2);
+    let model = tiny_mlp(env.obs_size(), env.action_space());
+
+    let infer = model.valid();
+    let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
+    let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
+
+    assert!(!samples.is_empty(), "Trictrac episode produced no samples");
+
+    for (i, s) in samples.iter().enumerate() {
+        assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len());
+        assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len());
+        let policy_sum: f32 = s.policy.iter().sum();
+        assert!(
+            (policy_sum - 1.0).abs() < 1e-4,
+            "sample {i}: policy sums to {policy_sum}"
+        );
+        assert!(
+            s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
+            "sample {i}: unexpected value {}",
+            s.value
+        );
+    }
+}
+
+// ── 4. TrictracEnv: training step after real self-play ─────────────────────
+
+/// Collect one Trictrac episode, then verify that a gradient step runs
+/// without panic and produces a finite loss.
+#[test]
+fn trictrac_train_step_finite_loss() {
+    let env = TrictracEnv;
+    let mut rng = seeded();
+    let mcts = tiny_mcts(2);
+    let model = tiny_mlp(env.obs_size(), env.action_space());
+    let mut optimizer = AdamConfig::new().init();
+    let mut replay = ReplayBuffer::new(10_000);
+
+    // Generate one episode.
+    let infer = model.valid();
+    let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
+    let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
+    assert!(!samples.is_empty());
+    let n_samples = samples.len();
+    replay.extend(samples);
+
+    // Train on a batch from this episode.
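+    // (cap at the episode length so sample_batch never requests more than exists)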
+    let batch_size = 8.min(n_samples);
+    let batch: Vec<TrainSample> = replay
+        .sample_batch(batch_size, &mut rng)
+        .into_iter()
+        .cloned()
+        .collect();
+
+    let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
+    assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}");
+    assert!(loss > 0.0, "loss should be positive");
+}
+
+// ── 5. Backend transfer: train → infer → same outputs ──────────────────────
+
+/// Weights transferred from the training backend to the inference backend
+/// (via `AutodiffModule::valid()`) must produce matching forward passes
+/// (identical up to float tolerance).
+#[test]
+fn valid_model_matches_train_model_outputs() {
+    use burn::tensor::{Tensor, TensorData};
+
+    let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
+    let train_model = MlpNet::<Train>::new(&cfg, &train_device());
+    let infer_model: MlpNet<Infer> = train_model.valid();
+
+    // Build the same input on both backends.
+    let obs_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
+
+    let obs_train = Tensor::<Train, 2>::from_data(
+        TensorData::new(obs_data.clone(), [1, 4]),
+        &train_device(),
+    );
+    let obs_infer = Tensor::<Infer, 2>::from_data(
+        TensorData::new(obs_data, [1, 4]),
+        &infer_device(),
+    );
+
+    let (p_train, v_train) = train_model.forward(obs_train);
+    let (p_infer, v_infer) = infer_model.forward(obs_infer);
+
+    let p_train: Vec<f32> = p_train.into_data().to_vec().unwrap();
+    let p_infer: Vec<f32> = p_infer.into_data().to_vec().unwrap();
+    let v_train: Vec<f32> = v_train.into_data().to_vec().unwrap();
+    let v_infer: Vec<f32> = v_infer.into_data().to_vec().unwrap();
+
+    for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() {
+        assert!(
+            (a - b).abs() < 1e-5,
+            "policy[{i}] differs after valid(): train={a}, infer={b}"
+        );
+    }
+    assert!(
+        (v_train[0] - v_infer[0]).abs() < 1e-5,
+        "value differs after valid(): train={}, infer={}",
+        v_train[0], v_infer[0]
+    );
+}
+
+// ── 6. Loss converges on a fixed batch ─────────────────────────────────────
+
+/// With repeated gradient steps on the same Countdown batch, the loss must
+/// end lower than it started.
+#[test]
+fn loss_decreases_on_fixed_batch() {
+    let env = CountdownEnv(6);
+    let mut rng = seeded();
+    let mcts = tiny_mcts(3);
+    let model = tiny_mlp(env.obs_size(), env.action_space());
+    let mut optimizer = AdamConfig::new().init();
+
+    // Collect a fixed batch from one episode.
+    let infer = model.valid();
+    let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
+    let samples: Vec<TrainSample> = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng);
+    assert!(!samples.is_empty());
+
+    let batch: Vec<TrainSample> = {
+        let mut replay = ReplayBuffer::new(1000);
+        replay.extend(samples);
+        replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect()
+    };
+
+    // Overfit on the same fixed batch for 20 steps.
+    let mut model = tiny_mlp(env.obs_size(), env.action_space());
+    let mut first_loss = f32::NAN;
+    let mut last_loss = f32::NAN;
+
+    for step in 0..20 {
+        let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2);
+        model = m;
+        assert!(loss.is_finite(), "loss is not finite at step {step}");
+        if step == 0 { first_loss = loss; }
+        last_loss = loss;
+    }
+
+    assert!(
+        last_loss < first_loss,
+        "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}"
+    );
+}
+
+// ── 7. Trictrac: multi-iteration loop ──────────────────────────────────────
+
+/// Two full self-play + train iterations on TrictracEnv.
+/// Verifies the entire pipeline runs without panic end-to-end.
+#[test]
+fn trictrac_two_iteration_loop() {
+    let env = TrictracEnv;
+    let mut rng = seeded();
+    let mcts = tiny_mcts(2);
+
+    let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
+    let mut model = MlpNet::<Train>::new(&cfg, &train_device());
+    let mut optimizer = AdamConfig::new().init();
+    let mut replay = ReplayBuffer::new(20_000);
+
+    for iter in 0..2 {
+        // Self-play: 1 game per iteration.
+        let infer: MlpNet<Infer> = model.valid();
+        let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
+        let samples = generate_episode(
+            &env,
+            &eval,
+            &mcts,
+            &|step| if step < 30 { 1.0 } else { 0.0 },
+            &mut rng,
+        );
+        assert!(!samples.is_empty(), "iter {iter}: episode was empty");
+        replay.extend(samples);
+
+        // Training: 3 gradient steps.
+        let batch_size = 16.min(replay.len());
+        for _ in 0..3 {
+            let batch: Vec<TrainSample> = replay
+                .sample_batch(batch_size, &mut rng)
+                .into_iter()
+                .cloned()
+                .collect();
+            let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
+            model = m;
+            assert!(loss.is_finite(), "iter {iter}: loss={loss}");
+        }
+    }
+}