diff --git a/Cargo.lock b/Cargo.lock
index 34bfe80..0baa02a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -92,12 +92,6 @@ dependencies = [
  "libc",
 ]
 
-[[package]]
-name = "anes"
-version = "0.1.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
-
 [[package]]
 name = "anstream"
 version = "0.6.21"
@@ -1122,12 +1116,6 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
 
-[[package]]
-name = "cast"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
-
 [[package]]
 name = "cast_trait"
 version = "0.1.2"
@@ -1212,33 +1200,6 @@ dependencies = [
  "rand 0.7.3",
 ]
 
-[[package]]
-name = "ciborium"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
-dependencies = [
- "ciborium-io",
- "ciborium-ll",
- "serde",
-]
-
-[[package]]
-name = "ciborium-io"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
-
-[[package]]
-name = "ciborium-ll"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
-dependencies = [
- "ciborium-io",
- "half",
-]
-
 [[package]]
 name = "cipher"
 version = "0.4.4"
@@ -1492,42 +1453,6 @@ dependencies = [
  "cfg-if",
 ]
 
-[[package]]
-name = "criterion"
-version = "0.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
-dependencies = [
- "anes",
- "cast",
- "ciborium",
- "clap",
- "criterion-plot",
- "is-terminal",
- "itertools 0.10.5",
- "num-traits",
- "once_cell",
- "oorandom",
- "plotters",
- "rayon",
- "regex",
- "serde",
- "serde_derive",
- "serde_json",
- "tinytemplate",
- "walkdir",
-]
-
-[[package]]
-name = "criterion-plot"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
-dependencies = [
- "cast",
- "itertools 0.10.5",
-]
-
 [[package]]
 name = "critical-section"
 version = "1.2.0"
@@ -4536,12 +4461,6 @@ version = "1.70.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
 
-[[package]]
-name = "oorandom"
-version = "11.1.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
-
 [[package]]
 name = "opaque-debug"
 version = "0.3.1"
@@ -4678,34 +4597,6 @@ version = "0.3.32"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
 
-[[package]]
-name = "plotters"
-version = "0.3.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
-dependencies = [
- "num-traits",
- "plotters-backend",
- "plotters-svg",
- "wasm-bindgen",
- "web-sys",
-]
-
-[[package]]
-name = "plotters-backend"
-version = "0.3.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
-
-[[package]]
-name = "plotters-svg"
-version = "0.3.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
-dependencies = [
- "plotters-backend",
-]
-
 [[package]]
 name = "png"
 version = "0.18.0"
@@ -6006,7 +5897,6 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "burn",
- "criterion",
  "rand 0.9.2",
  "rand_distr",
  "trictrac-store",
@@ -6420,16 +6310,6 @@ dependencies = [
  "zerovec",
 ]
 
-[[package]]
-name = "tinytemplate"
-version = "1.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
-dependencies = [
- "serde",
- "serde_json",
-]
-
 [[package]]
 name = "tinyvec"
 version = "1.10.0"
diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml
index 3848dce..323c953 100644
--- a/spiel_bot/Cargo.toml
+++ b/spiel_bot/Cargo.toml
@@ -9,10 +9,3 @@ anyhow = "1"
 rand = "0.9"
 rand_distr = "0.5"
 burn = { version = "0.20", features = ["ndarray", "autodiff"] }
-
-[dev-dependencies]
-criterion = { version = "0.5", features = ["html_reports"] }
-
-[[bench]]
-name = "alphazero"
-harness = false
diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs
deleted file mode 100644
index 00d5b02..0000000
--- a/spiel_bot/benches/alphazero.rs
+++ /dev/null
@@ -1,373 +0,0 @@
-//! AlphaZero pipeline benchmarks.
-//!
-//! Run with:
-//!
-//! ```sh
-//! cargo bench -p spiel_bot
-//! ```
-//!
-//! Use `-- <filter>` to run a specific group, e.g.:
-//!
-//! ```sh
-//! cargo bench -p spiel_bot -- env/
-//! cargo bench -p spiel_bot -- network/
-//! cargo bench -p spiel_bot -- mcts/
-//! cargo bench -p spiel_bot -- episode/
-//! cargo bench -p spiel_bot -- train/
-//! ```
-//!
-//! Target: ≥ 500 games/s for random play on CPU (consistent with
-//! `random_game` throughput in `trictrac-store`).
-
-use std::time::Duration;
-
-use burn::{
-    backend::NdArray,
-    tensor::{Tensor, TensorData, backend::Backend},
-};
-use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
-use rand::{Rng, SeedableRng, rngs::SmallRng};
-
-use spiel_bot::{
-    alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
-    env::{GameEnv, Player, TrictracEnv},
-    mcts::{Evaluator, MctsConfig, run_mcts},
-    network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
-};
-
-// ── Shared types ───────────────────────────────────────────────────────────
-
-type InferB = NdArray<f32>;
-type TrainB = burn::backend::Autodiff<NdArray<f32>>;
-
-fn infer_device() -> <InferB as Backend>::Device { Default::default() }
-fn train_device() -> <TrainB as Backend>::Device { Default::default() }
-
-fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
-
-/// Uniform evaluator (returns zero logits and zero value).
-/// Used to isolate MCTS tree-traversal cost from network cost.
-struct ZeroEval(usize);
-impl Evaluator for ZeroEval {
-    fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
-        (vec![0.0f32; self.0], 0.0)
-    }
-}
-
-// ── 1. Environment primitives ──────────────────────────────────────────────
-
-/// Baseline performance of the raw Trictrac environment without MCTS.
-/// Target: ≥ 500 full games / second.
-fn bench_env(c: &mut Criterion) {
-    let env = TrictracEnv;
-
-    let mut group = c.benchmark_group("env");
-    group.measurement_time(Duration::from_secs(10));
-
-    // ── apply_chance ────────────────────────────────────────────────────
-    group.bench_function("apply_chance", |b| {
-        b.iter_batched(
-            || {
-                // A fresh game is always at RollDice (Chance) — ready for apply_chance.
-                env.new_game()
-            },
-            |mut s| {
-                env.apply_chance(&mut s, &mut seeded());
-                black_box(s)
-            },
-            BatchSize::SmallInput,
-        )
-    });
-
-    // ── legal_actions ───────────────────────────────────────────────────
-    group.bench_function("legal_actions", |b| {
-        let mut rng = seeded();
-        let mut s = env.new_game();
-        env.apply_chance(&mut s, &mut rng);
-        b.iter(|| black_box(env.legal_actions(&s)))
-    });
-
-    // ── observation (to_tensor) ─────────────────────────────────────────
-    group.bench_function("observation", |b| {
-        let mut rng = seeded();
-        let mut s = env.new_game();
-        env.apply_chance(&mut s, &mut rng);
-        b.iter(|| black_box(env.observation(&s, 0)))
-    });
-
-    // ── full random game ────────────────────────────────────────────────
-    group.sample_size(50);
-    group.bench_function("random_game", |b| {
-        b.iter_batched(
-            seeded,
-            |mut rng| {
-                let mut s = env.new_game();
-                loop {
-                    match env.current_player(&s) {
-                        Player::Terminal => break,
-                        Player::Chance => env.apply_chance(&mut s, &mut rng),
-                        _ => {
-                            let actions = env.legal_actions(&s);
-                            let idx = rng.random_range(0..actions.len());
-                            env.apply(&mut s, actions[idx]);
-                        }
-                    }
-                }
-                black_box(s)
-            },
-            BatchSize::SmallInput,
-        )
-    });
-
-    group.finish();
-}
-
-// ── 2. Network inference ───────────────────────────────────────────────────
-
-/// Forward-pass latency for MLP variants (hidden = 64 / 256).
-fn bench_network(c: &mut Criterion) {
-    let mut group = c.benchmark_group("network");
-    group.measurement_time(Duration::from_secs(5));
-
-    for &hidden in &[64usize, 256] {
-        let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
-        let model = MlpNet::<InferB>::new(&cfg, &infer_device());
-        let obs: Vec<f32> = vec![0.5; 217];
-
-        // Batch size 1 — single-position evaluation as in MCTS.
-        group.bench_with_input(
-            BenchmarkId::new("mlp_b1", hidden),
-            &hidden,
-            |b, _| {
-                b.iter(|| {
-                    let data = TensorData::new(obs.clone(), [1, 217]);
-                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
-                    black_box(model.forward(t))
-                })
-            },
-        );
-
-        // Batch size 32 — training mini-batch.
-        let obs32: Vec<f32> = vec![0.5; 217 * 32];
-        group.bench_with_input(
-            BenchmarkId::new("mlp_b32", hidden),
-            &hidden,
-            |b, _| {
-                b.iter(|| {
-                    let data = TensorData::new(obs32.clone(), [32, 217]);
-                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
-                    black_box(model.forward(t))
-                })
-            },
-        );
-    }
-
-    // ── ResNet (4 residual blocks) ──────────────────────────────────────
-    for &hidden in &[256usize, 512] {
-        let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
-        let model = ResNet::<InferB>::new(&cfg, &infer_device());
-        let obs: Vec<f32> = vec![0.5; 217];
-
-        group.bench_with_input(
-            BenchmarkId::new("resnet_b1", hidden),
-            &hidden,
-            |b, _| {
-                b.iter(|| {
-                    let data = TensorData::new(obs.clone(), [1, 217]);
-                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
-                    black_box(model.forward(t))
-                })
-            },
-        );
-
-        let obs32: Vec<f32> = vec![0.5; 217 * 32];
-        group.bench_with_input(
-            BenchmarkId::new("resnet_b32", hidden),
-            &hidden,
-            |b, _| {
-                b.iter(|| {
-                    let data = TensorData::new(obs32.clone(), [32, 217]);
-                    let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
-                    black_box(model.forward(t))
-                })
-            },
-        );
-    }
-
-    group.finish();
-}
-
-// ── 3. MCTS ─────────────────────────────────────────────────────────────
-
-/// MCTS cost at different simulation budgets with two evaluator types:
-/// - `zero` — isolates tree-traversal overhead (no network).
-/// - `mlp64` — real MLP, shows end-to-end cost per move.
-fn bench_mcts(c: &mut Criterion) {
-    let env = TrictracEnv;
-
-    // Build a decision-node state (after dice roll).
-    let state = {
-        let mut s = env.new_game();
-        let mut rng = seeded();
-        while env.current_player(&s).is_chance() {
-            env.apply_chance(&mut s, &mut rng);
-        }
-        s
-    };
-
-    let mut group = c.benchmark_group("mcts");
-    group.measurement_time(Duration::from_secs(10));
-
-    let zero_eval = ZeroEval(514);
-    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
-    let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
-    let mlp_eval = BurnEvaluator::<InferB>::new(mlp_model, infer_device());
-
-    for &n_sim in &[1usize, 5, 20] {
-        let cfg = MctsConfig {
-            n_simulations: n_sim,
-            c_puct: 1.5,
-            dirichlet_alpha: 0.0,
-            dirichlet_eps: 0.0,
-            temperature: 1.0,
-        };
-
-        // Zero evaluator: tree traversal only.
-        group.bench_with_input(
-            BenchmarkId::new("zero_eval", n_sim),
-            &n_sim,
-            |b, _| {
-                b.iter_batched(
-                    seeded,
-                    |mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
-                    BatchSize::SmallInput,
-                )
-            },
-        );
-
-        // MLP evaluator: full cost per decision.
-        group.bench_with_input(
-            BenchmarkId::new("mlp64", n_sim),
-            &n_sim,
-            |b, _| {
-                b.iter_batched(
-                    seeded,
-                    |mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
-                    BatchSize::SmallInput,
-                )
-            },
-        );
-    }
-
-    group.finish();
-}
-
-// ── 4. Episode generation ───────────────────────────────────────────────
-
-/// Full self-play episode latency (one complete game) at different MCTS
-/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU.
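-/// Only `n_sim` ∈ {1, 2} is measured here: each iteration plays a complete
-/// game, so larger simulation budgets would quickly exceed the 60 s
-/// measurement window configured below.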
-fn bench_episode(c: &mut Criterion) {
-    let env = TrictracEnv;
-    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
-    let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
-    let eval = BurnEvaluator::<InferB>::new(model, infer_device());
-
-    let mut group = c.benchmark_group("episode");
-    group.sample_size(10);
-    group.measurement_time(Duration::from_secs(60));
-
-    for &n_sim in &[1usize, 2] {
-        let mcts_cfg = MctsConfig {
-            n_simulations: n_sim,
-            c_puct: 1.5,
-            dirichlet_alpha: 0.0,
-            dirichlet_eps: 0.0,
-            temperature: 1.0,
-        };
-
-        group.bench_with_input(
-            BenchmarkId::new("trictrac", n_sim),
-            &n_sim,
-            |b, _| {
-                b.iter_batched(
-                    seeded,
-                    |mut rng| {
-                        black_box(generate_episode(
-                            &env,
-                            &eval,
-                            &mcts_cfg,
-                            &|_| 1.0,
-                            &mut rng,
-                        ))
-                    },
-                    BatchSize::SmallInput,
-                )
-            },
-        );
-    }
-
-    group.finish();
-}
-
-// ── 5. Training step ─────────────────────────────────────────────────────
-
-/// Gradient-step latency for different batch sizes.
-fn bench_train(c: &mut Criterion) {
-    use burn::optim::AdamConfig;
-
-    let mut group = c.benchmark_group("train");
-    group.measurement_time(Duration::from_secs(10));
-
-    let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
-
-    let dummy_samples = |n: usize| -> Vec<TrainSample> {
-        (0..n)
-            .map(|i| TrainSample {
-                obs: vec![0.5; 217],
-                policy: {
-                    let mut p = vec![0.0f32; 514];
-                    p[i % 514] = 1.0;
-                    p
-                },
-                value: if i % 2 == 0 { 1.0 } else { -1.0 },
-            })
-            .collect()
-    };
-
-    for &batch_size in &[16usize, 64] {
-        let batch = dummy_samples(batch_size);
-
-        group.bench_with_input(
-            BenchmarkId::new("mlp64_adam", batch_size),
-            &batch_size,
-            |b, _| {
-                b.iter_batched(
-                    || {
-                        (
-                            MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
-                            AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
-                        )
-                    },
-                    |(model, mut opt)| {
-                        black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
-                    },
-                    BatchSize::SmallInput,
-                )
-            },
-        );
-    }
-
-    group.finish();
-}
-
-// ── Criterion entry point ────────────────────────────────────────────────
-
-criterion_group!(
-    benches,
-    bench_env,
-    bench_network,
-    bench_mcts,
-    bench_episode,
-    bench_train,
-);
-criterion_main!(benches);
diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs
index d92224e..bb86724 100644
--- a/spiel_bot/src/alphazero/mod.rs
+++ b/spiel_bot/src/alphazero/mod.rs
@@ -65,7 +65,7 @@ pub mod trainer;
 
 pub use replay::{ReplayBuffer, TrainSample};
 pub use selfplay::{BurnEvaluator, generate_episode};
-pub use trainer::{cosine_lr, train_step};
+pub use trainer::train_step;
 
 use crate::mcts::MctsConfig;
 
@@ -87,17 +87,8 @@ pub struct AlphaZeroConfig {
     pub batch_size: usize,
     /// Maximum number of samples in the replay buffer.
     pub replay_capacity: usize,
-    /// Initial (peak) Adam learning rate.
+    /// Adam learning rate.
     pub learning_rate: f64,
-    /// Minimum learning rate for cosine annealing (floor of the schedule).
-    ///
-    /// Pass `learning_rate == lr_min` to disable scheduling (constant LR).
-    /// Compute the current LR with [`cosine_lr`]:
-    ///
-    /// ```rust,ignore
-    /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
-    /// ```
-    pub lr_min: f64,
     /// Number of outer iterations (self-play + train) to run.
     pub n_iterations: usize,
     /// Move index after which the action temperature drops to 0 (greedy play).
@@ -119,7 +110,6 @@ impl Default for AlphaZeroConfig {
             batch_size: 64,
             replay_capacity: 50_000,
             learning_rate: 1e-3,
-            lr_min: 1e-4, // cosine annealing floor
             n_iterations: 100,
             temperature_drop_move: 30,
         }
diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs
index 9075519..d2482d1 100644
--- a/spiel_bot/src/alphazero/trainer.rs
+++ b/spiel_bot/src/alphazero/trainer.rs
@@ -5,24 +5,6 @@
 //! - **Value loss** — mean-squared error between the predicted value and the
 //!   actual game outcome.
 //!
-//! # Learning-rate scheduling
-//!
-//! [`cosine_lr`] implements one-cycle cosine annealing:
-//!
-//! ```text
-//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T))
-//! ```
-//!
-//! Typical usage in the outer loop:
-//!
-//! ```rust,ignore
-//! for step in 0..total_train_steps {
-//!     let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
-//!     let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
-//!     model = m;
-//! }
-//! ```
-//!
 //! # Backend
 //!
 //! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
@@ -114,30 +96,6 @@ where
 {
     (model, loss_scalar)
 }
 
-// ── Learning-rate schedule ─────────────────────────────────────────────────
-
-/// Cosine learning-rate schedule (one half-period, no warmup).
-///
-/// Returns the learning rate for training step `step` out of `total_steps`:
-///
-/// ```text
-/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total))
-/// ```
-///
-/// - At `t = 0` returns `initial`.
-/// - At `t = total_steps` (or beyond) returns `lr_min`.
-///
-/// # Panics
-///
-/// Does not panic. When `total_steps == 0`, returns `lr_min`.
-pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
-    if total_steps == 0 || step >= total_steps {
-        return lr_min;
-    }
-    let progress = step as f64 / total_steps as f64;
-    lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
-}
-
 // ── Tests ────────────────────────────────────────────────────────────────
 
 #[cfg(test)]
@@ -211,48 +169,4 @@ mod tests {
         let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
         assert!(loss.is_finite());
     }
-
-    // ── cosine_lr ─────────────────────────────────────────────────────────
-
-    #[test]
-    fn cosine_lr_at_step_zero_is_initial() {
-        let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
-        assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
-    }
-
-    #[test]
-    fn cosine_lr_at_end_is_min() {
-        let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
-        assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
-    }
-
-    #[test]
-    fn cosine_lr_beyond_end_is_min() {
-        let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
-        assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
-    }
-
-    #[test]
-    fn cosine_lr_midpoint_is_average() {
-        // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
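-        // For these inputs: (1e-3 + 1e-5) / 2 = 5.05e-4.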
-        let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
-        let expected = (1e-3 + 1e-5) / 2.0;
-        assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
-    }
-
-    #[test]
-    fn cosine_lr_monotone_decreasing() {
-        let mut prev = f64::INFINITY;
-        for step in 0..=100 {
-            let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
-            assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
-            prev = lr;
-        }
-    }
-
-    #[test]
-    fn cosine_lr_zero_total_steps_returns_min() {
-        let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
-        assert!((lr - 1e-5).abs() < 1e-10);
-    }
 }
diff --git a/spiel_bot/src/bin/az_eval.rs b/spiel_bot/src/bin/az_eval.rs
deleted file mode 100644
index 3c82519..0000000
--- a/spiel_bot/src/bin/az_eval.rs
+++ /dev/null
@@ -1,262 +0,0 @@
-//! Evaluate a trained AlphaZero checkpoint against a random player.
-//!
-//! # Usage
-//!
-//! ```sh
-//! # Random weights (sanity check — should be ~50 %)
-//! cargo run -p spiel_bot --bin az_eval --release
-//!
-//! # Trained MLP checkpoint
-//! cargo run -p spiel_bot --bin az_eval --release -- \
-//!     --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50
-//!
-//! # Trained ResNet checkpoint
-//! cargo run -p spiel_bot --bin az_eval --release -- \
-//!     --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100
-//! ```
-//!
-//! # Options
-//!
-//! | Flag | Default | Description |
-//! |------|---------|-------------|
-//! | `--checkpoint <path>` | (none) | Load weights from `.mpk` file; random weights if omitted |
-//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
-//! | `--hidden <n>` | 256 (mlp) / 512 (resnet) | Hidden size |
-//! | `--n-games <n>` | `100` | Games per side (total = 2 × N) |
-//! | `--n-sim <n>` | `50` | MCTS simulations per move |
-//! | `--seed <n>` | `42` | RNG seed |
-//! | `--c-puct <f>` | `1.5` | PUCT exploration constant |
-
-use std::path::PathBuf;
-
-use burn::backend::NdArray;
-use rand::{SeedableRng, rngs::SmallRng, Rng};
-
-use spiel_bot::{
-    alphazero::BurnEvaluator,
-    env::{GameEnv, Player, TrictracEnv},
-    mcts::{Evaluator, MctsConfig, run_mcts, select_action},
-    network::{MlpConfig, MlpNet, ResNet, ResNetConfig},
-};
-
-type InferB = NdArray<f32>;
-
-// ── CLI ─────────────────────────────────────────────────────────────────────
-
-struct Args {
-    checkpoint: Option<PathBuf>,
-    arch: String,
-    hidden: Option<usize>,
-    n_games: usize,
-    n_sim: usize,
-    seed: u64,
-    c_puct: f32,
-}
-
-impl Default for Args {
-    fn default() -> Self {
-        Self {
-            checkpoint: None,
-            arch: "mlp".into(),
-            hidden: None,
-            n_games: 100,
-            n_sim: 50,
-            seed: 42,
-            c_puct: 1.5,
-        }
-    }
-}
-
-fn parse_args() -> Args {
-    let raw: Vec<String> = std::env::args().collect();
-    let mut args = Args::default();
-    let mut i = 1;
-    while i < raw.len() {
-        match raw[i].as_str() {
-            "--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); }
-            "--arch" => { i += 1; args.arch = raw[i].clone(); }
-            "--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); }
-            "--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); }
-            "--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); }
-            "--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); }
-            "--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); }
-            other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
-        }
-        i += 1;
-    }
-    args
-}
-
-// ── Game loop ───────────────────────────────────────────────────────────────
-
-/// Play one complete game.
-///
-/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black).
-/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0).
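-///
-/// The returned array is indexed by player, so callers read the agent's
-/// outcome as `result[mcts_side]`.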
-fn play_game(
-    env: &TrictracEnv,
-    mcts_side: usize,
-    evaluator: &dyn Evaluator,
-    mcts_cfg: &MctsConfig,
-    rng: &mut SmallRng,
-) -> [f32; 2] {
-    let mut state = env.new_game();
-    loop {
-        match env.current_player(&state) {
-            Player::Terminal => {
-                return env.returns(&state).expect("Terminal state must have returns");
-            }
-            Player::Chance => env.apply_chance(&mut state, rng),
-            player => {
-                let side = player.index().unwrap(); // 0 = P1, 1 = P2
-                let action = if side == mcts_side {
-                    let root = run_mcts(env, &state, evaluator, mcts_cfg, rng);
-                    select_action(&root, 0.0, rng) // greedy (temperature = 0)
-                } else {
-                    let actions = env.legal_actions(&state);
-                    actions[rng.random_range(0..actions.len())]
-                };
-                env.apply(&mut state, action);
-            }
-        }
-    }
-}
-
-// ── Statistics ──────────────────────────────────────────────────────────────
-
-#[derive(Default)]
-struct Stats {
-    wins: u32,
-    draws: u32,
-    losses: u32,
-}
-
-impl Stats {
-    fn record(&mut self, mcts_return: f32) {
-        if mcts_return > 0.0 { self.wins += 1; }
-        else if mcts_return < 0.0 { self.losses += 1; }
-        else { self.draws += 1; }
-    }
-
-    fn total(&self) -> u32 { self.wins + self.draws + self.losses }
-
-    fn win_rate_decisive(&self) -> f64 {
-        let d = self.wins + self.losses;
-        if d == 0 { 0.5 } else { self.wins as f64 / d as f64 }
-    }
-
-    fn print(&self) {
-        let n = self.total();
-        let pct = |k: u32| 100.0 * k as f64 / n as f64;
-        println!(
-            "  Win {}/{n} ({:.1}%)  Draw {}/{n} ({:.1}%)  Loss {}/{n} ({:.1}%)",
-            self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses),
-        );
-    }
-}
-
-// ── Evaluation ──────────────────────────────────────────────────────────────
-
-fn run_evaluation(
-    evaluator: &dyn Evaluator,
-    n_games: usize,
-    mcts_cfg: &MctsConfig,
-    seed: u64,
-) -> (Stats, Stats) {
-    let env = TrictracEnv;
-    let total = n_games * 2;
-    let mut as_p1 = Stats::default();
-    let mut as_p2 = Stats::default();
-
-    for i in 0..total {
-        // Alternate sides: even games → MctsAgent as P1, odd → as P2.
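-        // Each game gets its own RNG, seeded with `seed + i`, so results are
-        // reproducible regardless of how earlier games consumed randomness.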
-        let mcts_side = i % 2;
-        let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64));
-        let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng);
-
-        let mcts_return = result[mcts_side];
-        if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); }
-
-        let done = i + 1;
-        if done % 10 == 0 || done == total {
-            eprint!("\r [{done}/{total}] ");
-        }
-    }
-    eprintln!();
-    (as_p1, as_p2)
-}
-
-// ── Main ────────────────────────────────────────────────────────────────────
-
-fn main() {
-    let args = parse_args();
-    let device: <InferB as burn::tensor::backend::Backend>::Device = Default::default();
-
-    // ── Load model ──────────────────────────────────────────────────────
-    let evaluator: Box<dyn Evaluator> = match args.arch.as_str() {
-        "resnet" => {
-            let hidden = args.hidden.unwrap_or(512);
-            let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
-            let model = match &args.checkpoint {
-                Some(path) => ResNet::<InferB>::load(&cfg, path, &device)
-                    .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
-                None => ResNet::new(&cfg, &device),
-            };
-            Box::new(BurnEvaluator::<NdArray<f32>>::new(model, device))
-        }
-        "mlp" | _ => {
-            let hidden = args.hidden.unwrap_or(256);
-            let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
-            let model = match &args.checkpoint {
-                Some(path) => MlpNet::<InferB>::load(&cfg, path, &device)
-                    .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
-                None => MlpNet::new(&cfg, &device),
-            };
-            Box::new(BurnEvaluator::<NdArray<f32>>::new(model, device))
-        }
-    };
-
-    let mcts_cfg = MctsConfig {
-        n_simulations: args.n_sim,
-        c_puct: args.c_puct,
-        dirichlet_alpha: 0.0, // no exploration noise during evaluation
-        dirichlet_eps: 0.0,
-        temperature: 0.0, // greedy action selection
-    };
-
-    // ── Header ──────────────────────────────────────────────────────────
-    let ckpt_label = args.checkpoint
-        .as_deref()
-        .and_then(|p| p.file_name())
-        .and_then(|n| n.to_str())
-        .unwrap_or("random weights");
-
-    println!();
-    println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent",
-        args.arch, args.n_sim);
-    println!("Games per side: {} | Total: {} | Seed: {}",
-        args.n_games, args.n_games * 2, args.seed);
-    println!();
-
-    // ── Run ─────────────────────────────────────────────────────────────
-    let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed);
-
-    // ── Results ─────────────────────────────────────────────────────────
-    println!("MctsAgent as P1 (White):");
-    as_p1.print();
-
-    println!("MctsAgent as P2 (Black):");
-    as_p2.print();
-
-    let combined_wins = as_p1.wins + as_p2.wins;
-    let combined_decisive = combined_wins + as_p1.losses + as_p2.losses;
-    let combined_wr = if combined_decisive == 0 { 0.5 }
-        else { combined_wins as f64 / combined_decisive as f64 };
-
-    println!();
-    println!("Combined win rate (excluding draws): {:.1}% [{}/{}]",
-        combined_wr * 100.0, combined_wins, combined_decisive);
-    println!("  P1 decisive: {:.1}% | P2 decisive: {:.1}%",
-        as_p1.win_rate_decisive() * 100.0,
-        as_p2.win_rate_decisive() * 100.0);
-}
diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs
index a0a690d..e92bd09 100644
--- a/spiel_bot/src/mcts/mod.rs
+++ b/spiel_bot/src/mcts/mod.rs
@@ -401,12 +401,8 @@ mod tests {
         };
         let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
 
-        // root.n = 1 (expansion) + n_simulations (one backup per simulation).
-        assert_eq!(root.n, 1 + config.n_simulations as u32);
-        // Children visit counts may sum to less than n_simulations when some
-        // simulations cross a chance node at depth 1 (turn ends after one move)
-        // and evaluate with the network directly without updating child.n.
+        assert_eq!(root.n, 1 + config.n_simulations as u32);
         let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
-        assert!(total <= config.n_simulations as u32);
+        assert_eq!(total, config.n_simulations as u32);
     }
 }
diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs
index 55db701..c4960c7 100644
--- a/spiel_bot/src/mcts/search.rs
+++ b/spiel_bot/src/mcts/search.rs
@@ -138,14 +138,8 @@ pub(super) fn simulate(
     // ── Apply action + advance through any chance nodes ─────────────────
     let mut next_state = state;
     env.apply(&mut next_state, action);
-
-    // Track whether we crossed a chance node (dice roll) on the way down.
-    // If we did, the child's cached legal actions are for a *different* dice
-    // outcome and must not be reused — evaluate with the network directly.
-    let mut crossed_chance = false;
     while env.current_player(&next_state).is_chance() {
         env.apply_chance(&mut next_state, rng);
-        crossed_chance = true;
     }
 
     let next_cp = env.current_player(&next_state);
@@ -159,15 +153,7 @@
         returns[player_idx]
     } else {
         let child_player = next_cp.index().unwrap();
-        let v = if crossed_chance {
-            // Outcome sampling: after dice, evaluate the resulting position
-            // directly with the network. Do NOT build the tree across chance
-            // boundaries — the dice change which actions are legal, so any
-            // previously cached children would be for a different outcome.
-            let obs = env.observation(&next_state, child_player);
-            let (_, value) = evaluator.evaluate(&obs);
-            value
-        } else if child.expanded {
+        let v = if child.expanded {
             simulate(child, next_state, env, evaluator, config, rng, child_player)
         } else {
             expand::<E>(child, &next_state, env, evaluator, child_player)
diff --git a/spiel_bot/tests/integration.rs b/spiel_bot/tests/integration.rs
deleted file mode 100644
index d73fda0..0000000
--- a/spiel_bot/tests/integration.rs
+++ /dev/null
@@ -1,391 +0,0 @@
-//! End-to-end integration tests for the AlphaZero training pipeline.
-//!
-//! Each test exercises the full chain:
-//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`]
-//!
-//! Two environments are used:
-//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves.
-//!   Used when we need many iterations without worrying about runtime.
-//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that
-//!   the full pipeline compiles and runs correctly with 217-dim observations
-//!   and 514-dim action spaces.
-//!
-//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep
-//! runtime minimal; correctness, not training quality, is what matters here.
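-//!
-//! The canonical loop exercised below, as a minimal sketch:
-//!
-//! ```rust,ignore
-//! let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
-//! replay.extend(samples);
-//! let batch: Vec<TrainSample> = replay
-//!     .sample_batch(batch_size, &mut rng)
-//!     .into_iter()
-//!     .cloned()
-//!     .collect();
-//! let (model, loss) = train_step(model, &mut optimizer, &batch, &device, 1e-3);
-//! ```
-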
-use burn::{
-    backend::{Autodiff, NdArray},
-    module::AutodiffModule,
-    optim::AdamConfig,
-};
-use rand::{SeedableRng, rngs::SmallRng};
-
-use spiel_bot::{
-    alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step},
-    env::{GameEnv, Player, TrictracEnv},
-    mcts::MctsConfig,
-    network::{MlpConfig, MlpNet, PolicyValueNet},
-};
-
-// ── Backend aliases ──────────────────────────────────────────────────────
-
-type Train = Autodiff<NdArray<f32>>;
-type Infer = NdArray<f32>;
-
-// ── Helpers ──────────────────────────────────────────────────────────────
-
-fn train_device() -> <Train as burn::tensor::backend::Backend>::Device {
-    Default::default()
-}
-
-fn infer_device() -> <Infer as burn::tensor::backend::Backend>::Device {
-    Default::default()
-}
-
-/// Tiny 64-unit MLP, compatible with an obs/action space of any size.
-fn tiny_mlp(obs: usize, actions: usize) -> MlpNet<Train> {
-    let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 };
-    MlpNet::new(&cfg, &train_device())
-}
-
-fn tiny_mcts(n: usize) -> MctsConfig {
-    MctsConfig {
-        n_simulations: n,
-        c_puct: 1.5,
-        dirichlet_alpha: 0.0,
-        dirichlet_eps: 0.0,
-        temperature: 1.0,
-    }
-}
-
-fn seeded() -> SmallRng {
-    SmallRng::seed_from_u64(0)
-}
-
-// ── Countdown environment (fast, local, no external deps) ────────────────
-//
-// Two players alternate subtracting 1 or 2 from a counter that starts at N.
-// The player who brings the counter to 0 wins.
-
-#[derive(Clone, Debug)]
-struct CState {
-    remaining: u8,
-    to_move: usize,
-}
-
-#[derive(Clone)]
-struct CountdownEnv(u8); // starting value
-
-impl GameEnv for CountdownEnv {
-    type State = CState;
-
-    fn new_game(&self) -> CState {
-        CState { remaining: self.0, to_move: 0 }
-    }
-
-    fn current_player(&self, s: &CState) -> Player {
-        if s.remaining == 0 { Player::Terminal }
-        else if s.to_move == 0 { Player::P1 }
-        else { Player::P2 }
-    }
-
-    fn legal_actions(&self, s: &CState) -> Vec<usize> {
-        if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
-    }
-
-    fn apply(&self, s: &mut CState, action: usize) {
-        let sub = (action as u8) + 1;
-        if s.remaining <= sub {
-            s.remaining = 0;
-        } else {
-            s.remaining -= sub;
-            s.to_move = 1 - s.to_move;
-        }
-    }
-
-    fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
-
-    fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
-        vec![s.remaining as f32 / self.0 as f32, s.to_move as f32]
-    }
-
-    fn obs_size(&self) -> usize { 2 }
-    fn action_space(&self) -> usize { 2 }
-
-    fn returns(&self, s: &CState) -> Option<[f32; 2]> {
-        if s.remaining != 0 { return None; }
-        let mut r = [-1.0f32; 2];
-        r[s.to_move] = 1.0;
-        Some(r)
-    }
-}
-
-// ── 1. Full loop on CountdownEnv ──────────────────────────────────────────
-
-/// The canonical AlphaZero loop: self-play → replay → train, iterated.
-/// Uses CountdownEnv so each game terminates in < 10 moves.
-#[test]
-fn countdown_full_loop_no_panic() {
-    let env = CountdownEnv(8);
-    let mut rng = seeded();
-    let mcts = tiny_mcts(3);
-
-    let mut model = tiny_mlp(env.obs_size(), env.action_space());
-    let mut optimizer = AdamConfig::new().init();
-    let mut replay = ReplayBuffer::new(1_000);
-
-    for _iter in 0..5 {
-        // Self-play: 3 games per iteration.
-        for _ in 0..3 {
-            let infer = model.valid();
-            let eval = BurnEvaluator::<Infer>::new(infer, infer_device());
-            let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
-            assert!(!samples.is_empty());
-            replay.extend(samples);
-        }
-
-        // Training: 4 gradient steps per iteration.
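-        // (`train_step` consumes the model and returns the updated one, hence
-        // the `model = m` rebinding below.)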
-        if replay.len() >= 4 {
-            for _ in 0..4 {
-                let batch: Vec<TrainSample> = replay
-                    .sample_batch(4, &mut rng)
-                    .into_iter()
-                    .cloned()
-                    .collect();
-                let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
-                model = m;
-                assert!(loss.is_finite(), "loss must be finite, got {loss}");
-            }
-        }
-    }
-
-    assert!(replay.len() > 0);
-}
-
-// ── 2. Replay buffer invariants ───────────────────────────────────────────
-
-/// After several Countdown games, replay capacity is respected and batch
-/// shapes are consistent.
-#[test]
-fn replay_buffer_capacity_and_shapes() {
-    let env = CountdownEnv(6);
-    let mut rng = seeded();
-    let mcts = tiny_mcts(2);
-    let model = tiny_mlp(env.obs_size(), env.action_space());
-
-    let capacity = 50;
-    let mut replay = ReplayBuffer::new(capacity);
-
-    for _ in 0..20 {
-        let infer = model.valid();
-        let eval = BurnEvaluator::<Infer>::new(infer, infer_device());
-        let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
-        replay.extend(samples);
-    }
-
-    assert!(replay.len() <= capacity, "buffer exceeded capacity");
-    assert!(replay.len() > 0);
-
-    let batch = replay.sample_batch(8, &mut rng);
-    assert_eq!(batch.len(), 8.min(replay.len()));
-    for s in &batch {
-        assert_eq!(s.obs.len(), env.obs_size());
-        assert_eq!(s.policy.len(), env.action_space());
-        let policy_sum: f32 = s.policy.iter().sum();
-        assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}");
-        assert!(s.value.abs() <= 1.0, "value {} out of range", s.value);
-    }
-}
-
-// ── 3. TrictracEnv: sample shapes ─────────────────────────────────────────
-
-/// Verify that one TrictracEnv episode produces samples with the correct
-/// tensor dimensions: obs = 217, policy = 514.
-#[test]
-fn trictrac_sample_shapes() {
-    let env = TrictracEnv;
-    let mut rng = seeded();
-    let mcts = tiny_mcts(2);
-    let model = tiny_mlp(env.obs_size(), env.action_space());
-
-    let infer = model.valid();
-    let eval = BurnEvaluator::<Infer>::new(infer, infer_device());
-    let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
-
-    assert!(!samples.is_empty(), "Trictrac episode produced no samples");
-
-    for (i, s) in samples.iter().enumerate() {
-        assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len());
-        assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len());
-        let policy_sum: f32 = s.policy.iter().sum();
-        assert!(
-            (policy_sum - 1.0).abs() < 1e-4,
-            "sample {i}: policy sums to {policy_sum}"
-        );
-        assert!(
-            s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
-            "sample {i}: unexpected value {}",
-            s.value
-        );
-    }
-}
-
-// ── 4. TrictracEnv: training step after real self-play ────────────────────
-
-/// Collect one Trictrac episode, then verify that a gradient step runs
-/// without panic and produces a finite loss.
-#[test]
-fn trictrac_train_step_finite_loss() {
-    let env = TrictracEnv;
-    let mut rng = seeded();
-    let mcts = tiny_mcts(2);
-    let model = tiny_mlp(env.obs_size(), env.action_space());
-    let mut optimizer = AdamConfig::new().init();
-    let mut replay = ReplayBuffer::new(10_000);
-
-    // Generate one episode.
-    let infer = model.valid();
-    let eval = BurnEvaluator::<Infer>::new(infer, infer_device());
-    let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
-    assert!(!samples.is_empty());
-    let n_samples = samples.len();
-    replay.extend(samples);
-
-    // Train on a batch from this episode.
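-    // (`sample_batch` hands out references into the buffer; the batch is
-    // cloned into owned `TrainSample`s before training.)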
-    let batch_size = 8.min(n_samples);
-    let batch: Vec<TrainSample> = replay
-        .sample_batch(batch_size, &mut rng)
-        .into_iter()
-        .cloned()
-        .collect();
-
-    let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
-    assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}");
-    assert!(loss > 0.0, "loss should be positive");
-}
-
-// ── 5. Backend transfer: train → infer → same outputs ─────────────────────
-
-/// Weights transferred from the training backend to the inference backend
-/// (via `AutodiffModule::valid()`) must produce matching forward passes,
-/// identical up to float tolerance.
-#[test]
-fn valid_model_matches_train_model_outputs() {
-    use burn::tensor::{Tensor, TensorData};
-
-    let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
-    let train_model = MlpNet::<Train>::new(&cfg, &train_device());
-    let infer_model: MlpNet<Infer> = train_model.valid();
-
-    // Build the same input on both backends.
-    let obs_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
-
-    let obs_train = Tensor::<Train, 2>::from_data(
-        TensorData::new(obs_data.clone(), [1, 4]),
-        &train_device(),
-    );
-    let obs_infer = Tensor::<Infer, 2>::from_data(
-        TensorData::new(obs_data, [1, 4]),
-        &infer_device(),
-    );
-
-    let (p_train, v_train) = train_model.forward(obs_train);
-    let (p_infer, v_infer) = infer_model.forward(obs_infer);
-
-    let p_train: Vec<f32> = p_train.into_data().to_vec().unwrap();
-    let p_infer: Vec<f32> = p_infer.into_data().to_vec().unwrap();
-    let v_train: Vec<f32> = v_train.into_data().to_vec().unwrap();
-    let v_infer: Vec<f32> = v_infer.into_data().to_vec().unwrap();
-
-    for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() {
-        assert!(
-            (a - b).abs() < 1e-5,
-            "policy[{i}] differs after valid(): train={a}, infer={b}"
-        );
-    }
-    assert!(
-        (v_train[0] - v_infer[0]).abs() < 1e-5,
-        "value differs after valid(): train={}, infer={}",
-        v_train[0], v_infer[0]
-    );
-}
-
-// ── 6. Loss converges on a fixed batch ─────────────────────────────────────
-
-/// With repeated gradient steps on the same Countdown batch, the loss must
-/// end lower than it started.
-#[test]
-fn loss_decreases_on_fixed_batch() {
-    let env = CountdownEnv(6);
-    let mut rng = seeded();
-    let mcts = tiny_mcts(3);
-    let model = tiny_mlp(env.obs_size(), env.action_space());
-    let mut optimizer = AdamConfig::new().init();
-
-    // Collect a fixed batch from one episode.
-    let infer = model.valid();
-    let eval = BurnEvaluator::<Infer>::new(infer, infer_device());
-    let samples: Vec<TrainSample> = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng);
-    assert!(!samples.is_empty());
-
-    let batch: Vec<TrainSample> = {
-        let mut replay = ReplayBuffer::new(1000);
-        replay.extend(samples);
-        replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect()
-    };
-
-    // Overfit on the same fixed batch for 20 steps.
-    let mut model = tiny_mlp(env.obs_size(), env.action_space());
-    let mut first_loss = f32::NAN;
-    let mut last_loss = f32::NAN;
-
-    for step in 0..20 {
-        let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2);
-        model = m;
-        assert!(loss.is_finite(), "loss is not finite at step {step}");
-        if step == 0 { first_loss = loss; }
-        last_loss = loss;
-    }
-
-    assert!(
-        last_loss < first_loss,
-        "loss did not decrease after 20 steps: first={first_loss}, last={last_loss}"
-    );
-}
-
-// ── 7. Trictrac: multi-iteration loop ──────────────────────────────────────
-
-/// Two full self-play + train iterations on TrictracEnv.
-/// Verifies the entire pipeline runs without panic end-to-end.
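-/// The temperature closure passed to `generate_episode` mirrors the default
-/// `temperature_drop_move = 30` from `AlphaZeroConfig`.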
-#[test]
-fn trictrac_two_iteration_loop() {
-    let env = TrictracEnv;
-    let mut rng = seeded();
-    let mcts = tiny_mcts(2);
-
-    let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
-    let mut model = MlpNet::<Train>::new(&cfg, &train_device());
-    let mut optimizer = AdamConfig::new().init();
-    let mut replay = ReplayBuffer::new(20_000);
-
-    for iter in 0..2 {
-        // Self-play: 1 game per iteration.
-        let infer: MlpNet<Infer> = model.valid();
-        let eval = BurnEvaluator::<Infer>::new(infer, infer_device());
-        let samples =
-            generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
-        assert!(!samples.is_empty(), "iter {iter}: episode was empty");
-        replay.extend(samples);
-
-        // Training: 3 gradient steps.
-        let batch_size = 16.min(replay.len());
-        for _ in 0..3 {
-            let batch: Vec<TrainSample> = replay
-                .sample_batch(batch_size, &mut rng)
-                .into_iter()
-                .cloned()
-                .collect();
-            let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
-            model = m;
-            assert!(loss.is_finite(), "iter {iter}: loss={loss}");
-        }
-    }
-}