Compare commits

...

5 commits

9 changed files with 1272 additions and 5 deletions

Cargo.lock (generated, 120 changed lines)
View file

@@ -92,6 +92,12 @@ dependencies = [
"libc",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.21"
@@ -1116,6 +1122,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cast_trait"
version = "0.1.2"
@@ -1200,6 +1212,33 @@ dependencies = [
"rand 0.7.3",
]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
@@ -1453,6 +1492,42 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "critical-section"
version = "1.2.0"
@@ -4461,6 +4536,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "opaque-debug"
version = "0.3.1"
@@ -4597,6 +4678,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.18.0"
@@ -5897,6 +6006,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"burn",
"criterion",
"rand 0.9.2",
"rand_distr",
"trictrac-store",
@@ -6310,6 +6420,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.10.0"

View file

@@ -9,3 +9,10 @@ anyhow = "1"
rand = "0.9"
rand_distr = "0.5"
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
name = "alphazero"
harness = false
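
For context (standard Criterion conventions, not part of this diff): `harness = false` tells Cargo not to link the default libtest bench harness, so the bench target must supply its own `main`, which `criterion_main!` does in the new bench file below. With the `html_reports` feature enabled, Criterion writes its plots and summary to its default output directory:

```sh
cargo bench -p spiel_bot --bench alphazero
# HTML summary (Criterion's default path): target/criterion/report/index.html
```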

View file

@@ -0,0 +1,373 @@
//! AlphaZero pipeline benchmarks.
//!
//! Run with:
//!
//! ```sh
//! cargo bench -p spiel_bot
//! ```
//!
//! Use `-- <filter>` to run a specific group, e.g.:
//!
//! ```sh
//! cargo bench -p spiel_bot -- env/
//! cargo bench -p spiel_bot -- network/
//! cargo bench -p spiel_bot -- mcts/
//! cargo bench -p spiel_bot -- episode/
//! cargo bench -p spiel_bot -- train/
//! ```
//!
//! Target: ≥ 500 games/s for random play on CPU (consistent with
//! `random_game` throughput in `trictrac-store`).
use std::time::Duration;
use burn::{
backend::NdArray,
tensor::{Tensor, TensorData, backend::Backend},
};
use criterion::{BatchSize, BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts},
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
// ── Shared types ───────────────────────────────────────────────────────────
type InferB = NdArray<f32>;
type TrainB = burn::backend::Autodiff<NdArray<f32>>;
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
fn seeded() -> SmallRng { SmallRng::seed_from_u64(0) }
/// Uniform evaluator (returns zero logits and zero value).
/// Used to isolate MCTS tree-traversal cost from network cost.
struct ZeroEval(usize);
impl Evaluator for ZeroEval {
fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
(vec![0.0f32; self.0], 0.0)
}
}
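// With an all-zero logit vector the policy prior is uniform after softmax,
// so PUCT selection is driven purely by visit counts and c_puct.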
// ── 1. Environment primitives ──────────────────────────────────────────────
/// Baseline performance of the raw Trictrac environment without MCTS.
/// Target: ≥ 500 full games / second.
fn bench_env(c: &mut Criterion) {
let env = TrictracEnv;
let mut group = c.benchmark_group("env");
group.measurement_time(Duration::from_secs(10));
// ── apply_chance ──────────────────────────────────────────────────────
group.bench_function("apply_chance", |b| {
b.iter_batched(
|| {
// A fresh game is always at RollDice (Chance) — ready for apply_chance.
env.new_game()
},
|mut s| {
env.apply_chance(&mut s, &mut seeded());
black_box(s)
},
BatchSize::SmallInput,
)
});
// ── legal_actions ─────────────────────────────────────────────────────
group.bench_function("legal_actions", |b| {
let mut rng = seeded();
let mut s = env.new_game();
env.apply_chance(&mut s, &mut rng);
b.iter(|| black_box(env.legal_actions(&s)))
});
// ── observation (to_tensor) ───────────────────────────────────────────
group.bench_function("observation", |b| {
let mut rng = seeded();
let mut s = env.new_game();
env.apply_chance(&mut s, &mut rng);
b.iter(|| black_box(env.observation(&s, 0)))
});
// ── full random game ──────────────────────────────────────────────────
group.sample_size(50);
group.bench_function("random_game", |b| {
b.iter_batched(
seeded,
|mut rng| {
let mut s = env.new_game();
loop {
match env.current_player(&s) {
Player::Terminal => break,
Player::Chance => env.apply_chance(&mut s, &mut rng),
_ => {
let actions = env.legal_actions(&s);
let idx = rng.random_range(0..actions.len());
env.apply(&mut s, actions[idx]);
}
}
}
black_box(s)
},
BatchSize::SmallInput,
)
});
group.finish();
}
// ── 2. Network inference ───────────────────────────────────────────────────
/// Forward-pass latency for MLP variants (hidden = 64 / 256).
fn bench_network(c: &mut Criterion) {
let mut group = c.benchmark_group("network");
group.measurement_time(Duration::from_secs(5));
for &hidden in &[64usize, 256] {
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = MlpNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
// Batch size 1 — single-position evaluation as in MCTS.
group.bench_with_input(
BenchmarkId::new("mlp_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
// Batch size 32 — training mini-batch.
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("mlp_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
// ── ResNet (4 residual blocks) ────────────────────────────────────────
for &hidden in &[256usize, 512] {
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = ResNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
group.bench_with_input(
BenchmarkId::new("resnet_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("resnet_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
group.finish();
}
// ── 3. MCTS ───────────────────────────────────────────────────────────────
/// MCTS cost at different simulation budgets with two evaluator types:
/// - `zero` — isolates tree-traversal overhead (no network).
/// - `mlp64` — real MLP, shows end-to-end cost per move.
fn bench_mcts(c: &mut Criterion) {
let env = TrictracEnv;
// Build a decision-node state (after dice roll).
let state = {
let mut s = env.new_game();
let mut rng = seeded();
while env.current_player(&s).is_chance() {
env.apply_chance(&mut s, &mut rng);
}
s
};
let mut group = c.benchmark_group("mcts");
group.measurement_time(Duration::from_secs(10));
let zero_eval = ZeroEval(514);
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let mlp_model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
let mlp_eval = BurnEvaluator::<InferB, _>::new(mlp_model, infer_device());
for &n_sim in &[1usize, 5, 20] {
let cfg = MctsConfig {
n_simulations: n_sim,
c_puct: 1.5,
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
temperature: 1.0,
};
// Zero evaluator: tree traversal only.
group.bench_with_input(
BenchmarkId::new("zero_eval", n_sim),
&n_sim,
|b, _| {
b.iter_batched(
seeded,
|mut rng| black_box(run_mcts(&env, &state, &zero_eval, &cfg, &mut rng)),
BatchSize::SmallInput,
)
},
);
// MLP evaluator: full cost per decision.
group.bench_with_input(
BenchmarkId::new("mlp64", n_sim),
&n_sim,
|b, _| {
b.iter_batched(
seeded,
|mut rng| black_box(run_mcts(&env, &state, &mlp_eval, &cfg, &mut rng)),
BatchSize::SmallInput,
)
},
);
}
group.finish();
}
// ── 4. Episode generation ─────────────────────────────────────────────────
/// Full self-play episode latency (one complete game) at different MCTS
/// simulation budgets. Target: ≥ 1 game/s at n_sim=20 on CPU; only
/// n_sim ∈ {1, 2} is exercised below to keep bench runtime bounded.
fn bench_episode(c: &mut Criterion) {
let env = TrictracEnv;
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let model = MlpNet::<InferB>::new(&mlp_cfg, &infer_device());
let eval = BurnEvaluator::<InferB, _>::new(model, infer_device());
let mut group = c.benchmark_group("episode");
group.sample_size(10);
group.measurement_time(Duration::from_secs(60));
for &n_sim in &[1usize, 2] {
let mcts_cfg = MctsConfig {
n_simulations: n_sim,
c_puct: 1.5,
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
temperature: 1.0,
};
group.bench_with_input(
BenchmarkId::new("trictrac", n_sim),
&n_sim,
|b, _| {
b.iter_batched(
seeded,
|mut rng| {
black_box(generate_episode(
&env,
&eval,
&mcts_cfg,
&|_| 1.0,
&mut rng,
))
},
BatchSize::SmallInput,
)
},
);
}
group.finish();
}
// ── 5. Training step ───────────────────────────────────────────────────────
/// Gradient-step latency for different batch sizes.
fn bench_train(c: &mut Criterion) {
use burn::optim::AdamConfig;
let mut group = c.benchmark_group("train");
group.measurement_time(Duration::from_secs(10));
let mlp_cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let dummy_samples = |n: usize| -> Vec<TrainSample> {
(0..n)
.map(|i| TrainSample {
obs: vec![0.5; 217],
policy: {
let mut p = vec![0.0f32; 514];
p[i % 514] = 1.0;
p
},
value: if i % 2 == 0 { 1.0 } else { -1.0 },
})
.collect()
};
for &batch_size in &[16usize, 64] {
let batch = dummy_samples(batch_size);
group.bench_with_input(
BenchmarkId::new("mlp64_adam", batch_size),
&batch_size,
|b, _| {
b.iter_batched(
|| {
(
MlpNet::<TrainB>::new(&mlp_cfg, &train_device()),
AdamConfig::new().init::<TrainB, MlpNet<TrainB>>(),
)
},
|(model, mut opt)| {
black_box(train_step(model, &mut opt, &batch, &train_device(), 1e-3))
},
BatchSize::SmallInput,
)
},
);
}
group.finish();
}
// ── Criterion entry point ──────────────────────────────────────────────────
criterion_group!(
benches,
bench_env,
bench_network,
bench_mcts,
bench_episode,
bench_train,
);
criterion_main!(benches);

View file

@@ -65,7 +65,7 @@ pub mod trainer;
pub use replay::{ReplayBuffer, TrainSample};
pub use selfplay::{BurnEvaluator, generate_episode};
pub use trainer::train_step;
pub use trainer::{cosine_lr, train_step};
use crate::mcts::MctsConfig;
@@ -87,8 +87,17 @@ pub struct AlphaZeroConfig {
pub batch_size: usize,
/// Maximum number of samples in the replay buffer.
pub replay_capacity: usize,
/// Adam learning rate.
/// Initial (peak) Adam learning rate.
pub learning_rate: f64,
/// Minimum learning rate for cosine annealing (floor of the schedule).
///
/// Set `lr_min` equal to `learning_rate` to disable scheduling (constant LR).
/// Compute the current LR with [`cosine_lr`]:
///
/// ```rust,ignore
/// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
/// ```
pub lr_min: f64,
/// Number of outer iterations (self-play + train) to run.
pub n_iterations: usize,
/// Move index after which the action temperature drops to 0 (greedy play).
@@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig {
batch_size: 64,
replay_capacity: 50_000,
learning_rate: 1e-3,
lr_min: 1e-4, // cosine annealing floor
n_iterations: 100,
temperature_drop_move: 30,
}

View file

@@ -5,6 +5,24 @@
//! - **Value loss** — mean-squared error between the predicted value and the
//! actual game outcome.
//!
//! # Learning-rate scheduling
//!
//! [`cosine_lr`] implements half-period cosine annealing (decay only, no warmup):
//!
//! ```text
//! lr(t) = lr_min + 0.5 · (lr_max - lr_min) · (1 + cos(π · t / T))
//! ```
//!
//! Typical usage in the outer loop:
//!
//! ```rust,ignore
//! for step in 0..total_train_steps {
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
//! model = m;
//! }
//! ```
//!
//! # Backend
//!
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
@@ -96,6 +114,30 @@ where
(model, loss_scalar)
}
// ── Learning-rate schedule ─────────────────────────────────────────────────
/// Cosine learning-rate schedule (one half-period, no warmup).
///
/// Returns the learning rate for training step `step` out of `total_steps`:
///
/// ```text
/// lr(t) = lr_min + 0.5 · (initial - lr_min) · (1 + cos(π · t / total))
/// ```
///
/// - At `t = 0` returns `initial`.
/// - At `t = total_steps` (or beyond) returns `lr_min`.
/// - When `total_steps == 0`, returns `lr_min` (never panics).
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
if total_steps == 0 || step >= total_steps {
return lr_min;
}
let progress = step as f64 / total_steps as f64;
lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
@@ -169,4 +211,48 @@ mod tests {
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
assert!(loss.is_finite());
}
// ── cosine_lr ─────────────────────────────────────────────────────────
#[test]
fn cosine_lr_at_step_zero_is_initial() {
let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
}
#[test]
fn cosine_lr_at_end_is_min() {
let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
}
#[test]
fn cosine_lr_beyond_end_is_min() {
let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
}
#[test]
fn cosine_lr_midpoint_is_average() {
// At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
let expected = (1e-3 + 1e-5) / 2.0;
assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
}
#[test]
fn cosine_lr_monotone_decreasing() {
let mut prev = f64::INFINITY;
for step in 0..=100 {
let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
prev = lr;
}
}
#[test]
fn cosine_lr_zero_total_steps_returns_min() {
let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
assert!((lr - 1e-5).abs() < 1e-10);
}
}
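
To make the schedule concrete, here is a standalone sketch (not part of the diff: it reimplements the formula above verbatim rather than importing the crate) that prints a few sample points using the `AlphaZeroConfig` defaults (1e-3 down to 1e-4):

```rust
fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
    if total_steps == 0 || step >= total_steps {
        return lr_min;
    }
    let progress = step as f64 / total_steps as f64;
    lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
}

fn main() {
    // AlphaZeroConfig defaults: learning_rate = 1e-3, lr_min = 1e-4.
    for step in [0, 25, 50, 75, 100] {
        println!("step {step:>3}: lr = {:.2e}", cosine_lr(1e-3, 1e-4, step, 100));
    }
    // step 0: 1.00e-3, step 50: 5.50e-4 (the midpoint average), step 100: 1.00e-4.
}
```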

View file

@@ -0,0 +1,262 @@
//! Evaluate a trained AlphaZero checkpoint against a random player.
//!
//! # Usage
//!
//! ```sh
//! # Random weights (sanity check — should be ~50 %)
//! cargo run -p spiel_bot --bin az_eval --release
//!
//! # Trained MLP checkpoint
//! cargo run -p spiel_bot --bin az_eval --release -- \
//! --checkpoint model.mpk --arch mlp --n-games 200 --n-sim 50
//!
//! # Trained ResNet checkpoint
//! cargo run -p spiel_bot --bin az_eval --release -- \
//! --checkpoint model.mpk --arch resnet --hidden 512 --n-games 100 --n-sim 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--checkpoint <path>` | (none) | Load weights from `.mpk` file; random weights if omitted |
//! | `--arch mlp\|resnet` | `mlp` | Network architecture |
//! | `--hidden <N>` | 256 (mlp) / 512 (resnet) | Hidden size |
//! | `--n-games <N>` | `100` | Games per side (total = 2 × N) |
//! | `--n-sim <N>` | `50` | MCTS simulations per move |
//! | `--seed <N>` | `42` | RNG seed |
//! | `--c-puct <F>` | `1.5` | PUCT exploration constant |
use std::path::PathBuf;
use burn::backend::NdArray;
use rand::{Rng, SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::BurnEvaluator,
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts, select_action},
network::{MlpConfig, MlpNet, ResNet, ResNetConfig},
};
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
struct Args {
checkpoint: Option<PathBuf>,
arch: String,
hidden: Option<usize>,
n_games: usize,
n_sim: usize,
seed: u64,
c_puct: f32,
}
impl Default for Args {
fn default() -> Self {
Self {
checkpoint: None,
arch: "mlp".into(),
hidden: None,
n_games: 100,
n_sim: 50,
seed: 42,
c_puct: 1.5,
}
}
}
fn parse_args() -> Args {
let raw: Vec<String> = std::env::args().collect();
let mut args = Args::default();
let mut i = 1;
while i < raw.len() {
match raw[i].as_str() {
"--checkpoint" => { i += 1; args.checkpoint = Some(PathBuf::from(&raw[i])); }
"--arch" => { i += 1; args.arch = raw[i].clone(); }
"--hidden" => { i += 1; args.hidden = Some(raw[i].parse().expect("--hidden must be an integer")); }
"--n-games" => { i += 1; args.n_games = raw[i].parse().expect("--n-games must be an integer"); }
"--n-sim" => { i += 1; args.n_sim = raw[i].parse().expect("--n-sim must be an integer"); }
"--seed" => { i += 1; args.seed = raw[i].parse().expect("--seed must be an integer"); }
"--c-puct" => { i += 1; args.c_puct = raw[i].parse().expect("--c-puct must be a float"); }
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
}
i += 1;
}
args
}
// ── Game loop ─────────────────────────────────────────────────────────────────
/// Play one complete game.
///
/// `mcts_side` — 0 means MctsAgent plays as P1 (White), 1 means P2 (Black).
/// Returns `[r1, r2]` — P1 and P2 outcomes (+1 / -1 / 0).
fn play_game(
env: &TrictracEnv,
mcts_side: usize,
evaluator: &dyn Evaluator,
mcts_cfg: &MctsConfig,
rng: &mut SmallRng,
) -> [f32; 2] {
let mut state = env.new_game();
loop {
match env.current_player(&state) {
Player::Terminal => {
return env.returns(&state).expect("Terminal state must have returns");
}
Player::Chance => env.apply_chance(&mut state, rng),
player => {
let side = player.index().unwrap(); // 0 = P1, 1 = P2
let action = if side == mcts_side {
let root = run_mcts(env, &state, evaluator, mcts_cfg, rng);
select_action(&root, 0.0, rng) // greedy (temperature = 0)
} else {
let actions = env.legal_actions(&state);
actions[rng.random_range(0..actions.len())]
};
env.apply(&mut state, action);
}
}
}
}
// ── Statistics ────────────────────────────────────────────────────────────────
#[derive(Default)]
struct Stats {
wins: u32,
draws: u32,
losses: u32,
}
impl Stats {
fn record(&mut self, mcts_return: f32) {
if mcts_return > 0.0 { self.wins += 1; }
else if mcts_return < 0.0 { self.losses += 1; }
else { self.draws += 1; }
}
fn total(&self) -> u32 { self.wins + self.draws + self.losses }
fn win_rate_decisive(&self) -> f64 {
let d = self.wins + self.losses;
if d == 0 { 0.5 } else { self.wins as f64 / d as f64 }
}
fn print(&self) {
let n = self.total();
let pct = |k: u32| 100.0 * k as f64 / n as f64;
println!(
" Win {}/{n} ({:.1}%) Draw {}/{n} ({:.1}%) Loss {}/{n} ({:.1}%)",
self.wins, pct(self.wins), self.draws, pct(self.draws), self.losses, pct(self.losses),
);
}
}
// ── Evaluation ────────────────────────────────────────────────────────────────
fn run_evaluation(
evaluator: &dyn Evaluator,
n_games: usize,
mcts_cfg: &MctsConfig,
seed: u64,
) -> (Stats, Stats) {
let env = TrictracEnv;
let total = n_games * 2;
let mut as_p1 = Stats::default();
let mut as_p2 = Stats::default();
for i in 0..total {
// Alternate sides: even games → MctsAgent as P1, odd → as P2.
let mcts_side = i % 2;
let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(i as u64));
let result = play_game(&env, mcts_side, evaluator, mcts_cfg, &mut rng);
let mcts_return = result[mcts_side];
if mcts_side == 0 { as_p1.record(mcts_return); } else { as_p2.record(mcts_return); }
let done = i + 1;
if done % 10 == 0 || done == total {
eprint!("\r [{done}/{total}] ");
}
}
eprintln!();
(as_p1, as_p2)
}
// ── Main ──────────────────────────────────────────────────────────────────────
fn main() {
let args = parse_args();
let device: <InferB as burn::tensor::backend::Backend>::Device = Default::default();
// ── Load model ────────────────────────────────────────────────────────
let evaluator: Box<dyn Evaluator> = match args.arch.as_str() {
"resnet" => {
let hidden = args.hidden.unwrap_or(512);
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = match &args.checkpoint {
Some(path) => ResNet::<InferB>::load(&cfg, path, &device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
None => ResNet::new(&cfg, &device),
};
Box::new(BurnEvaluator::<InferB, ResNet<InferB>>::new(model, device))
}
"mlp" | _ => {
let hidden = args.hidden.unwrap_or(256);
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = match &args.checkpoint {
Some(path) => MlpNet::<InferB>::load(&cfg, path, &device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }),
None => MlpNet::new(&cfg, &device),
};
Box::new(BurnEvaluator::<InferB, MlpNet<InferB>>::new(model, device))
}
};
let mcts_cfg = MctsConfig {
n_simulations: args.n_sim,
c_puct: args.c_puct,
dirichlet_alpha: 0.0, // no exploration noise during evaluation
dirichlet_eps: 0.0,
temperature: 0.0, // greedy action selection
};
// ── Header ────────────────────────────────────────────────────────────
let ckpt_label = args.checkpoint
.as_deref()
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
.unwrap_or("random weights");
println!();
println!("az_eval — MctsAgent ({}, {ckpt_label}, n_sim={}) vs RandomAgent",
args.arch, args.n_sim);
println!("Games per side: {} | Total: {} | Seed: {}",
args.n_games, args.n_games * 2, args.seed);
println!();
// ── Run ───────────────────────────────────────────────────────────────
let (as_p1, as_p2) = run_evaluation(evaluator.as_ref(), args.n_games, &mcts_cfg, args.seed);
// ── Results ───────────────────────────────────────────────────────────
println!("MctsAgent as P1 (White):");
as_p1.print();
println!("MctsAgent as P2 (Black):");
as_p2.print();
let combined_wins = as_p1.wins + as_p2.wins;
let combined_decisive = combined_wins + as_p1.losses + as_p2.losses;
let combined_wr = if combined_decisive == 0 { 0.5 }
else { combined_wins as f64 / combined_decisive as f64 };
println!();
println!("Combined win rate (excluding draws): {:.1}% [{}/{}]",
combined_wr * 100.0, combined_wins, combined_decisive);
println!(" P1 decisive: {:.1}% | P2 decisive: {:.1}%",
as_p1.win_rate_decisive() * 100.0,
as_p2.win_rate_decisive() * 100.0);
}
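
A caveat when reading the combined win rate (a general statistics note, not from this diff): with the default 2 × 100 games, sampling noise alone is sizable. A quick normal-approximation sketch of the 95% half-width around a measured rate:

```rust
fn main() {
    // 95% normal-approximation half-width: 1.96 * sqrt(p * (1 - p) / n).
    let (p, n) = (0.5_f64, 200.0_f64); // worst case p = 0.5 over 200 decisive games
    let half_width = 1.96 * (p * (1.0 - p) / n).sqrt();
    println!("~±{:.1} percentage points", half_width * 100.0); // ≈ ±6.9
}
```

So differences smaller than roughly seven points are within noise at these sample sizes; raising `--n-games` tightens the interval.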

View file

@@ -401,8 +401,12 @@ mod tests {
};
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
assert!(root.n > 0);
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
assert_eq!(root.n, 1 + config.n_simulations as u32);
// Children visit counts may sum to less than n_simulations when some
// simulations cross a chance node at depth 1 (turn ends after one move)
// and evaluate with the network directly without updating child.n.
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
assert_eq!(total, 5);
assert!(total <= config.n_simulations as u32);
}
}

View file

@@ -138,8 +138,14 @@
// ── Apply action + advance through any chance nodes ───────────────────
let mut next_state = state;
env.apply(&mut next_state, action);
// Track whether we crossed a chance node (dice roll) on the way down.
// If we did, the child's cached legal actions are for a *different* dice
// outcome and must not be reused — evaluate with the network directly.
let mut crossed_chance = false;
while env.current_player(&next_state).is_chance() {
env.apply_chance(&mut next_state, rng);
crossed_chance = true;
}
let next_cp = env.current_player(&next_state);
@@ -153,7 +159,15 @@
returns[player_idx]
} else {
let child_player = next_cp.index().unwrap();
let v = if child.expanded {
let v = if crossed_chance {
// Outcome sampling: after dice, evaluate the resulting position
// directly with the network. Do NOT build the tree across chance
// boundaries — the dice change which actions are legal, so any
// previously cached children would be for a different outcome.
let obs = env.observation(&next_state, child_player);
let (_, value) = evaluator.evaluate(&obs);
value
} else if child.expanded {
simulate(child, next_state, env, evaluator, config, rng, child_player)
} else {
expand::<E>(child, &next_state, env, evaluator, child_player)

View file

@@ -0,0 +1,391 @@
//! End-to-end integration tests for the AlphaZero training pipeline.
//!
//! Each test exercises the full chain:
//! [`GameEnv`] → MCTS → [`generate_episode`] → [`ReplayBuffer`] → [`train_step`]
//!
//! Two environments are used:
//! - **CountdownEnv** — trivial deterministic game, terminates in < 10 moves.
//! Used when we need many iterations without worrying about runtime.
//! - **TrictracEnv** — the real game. Used to verify tensor shapes and that
//! the full pipeline compiles and runs correctly with 217-dim observations
//! and 514-dim action spaces.
//!
//! All tests use `n_simulations = 2` and `hidden_size = 64` to keep
//! runtime minimal; correctness, not training quality, is what matters here.
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
};
use rand::{SeedableRng, rngs::SmallRng};
use spiel_bot::{
alphazero::{BurnEvaluator, ReplayBuffer, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::MctsConfig,
network::{MlpConfig, MlpNet, PolicyValueNet},
};
// ── Backend aliases ────────────────────────────────────────────────────────
type Train = Autodiff<NdArray<f32>>;
type Infer = NdArray<f32>;
// ── Helpers ────────────────────────────────────────────────────────────────
fn train_device() -> <Train as burn::tensor::backend::Backend>::Device {
Default::default()
}
fn infer_device() -> <Infer as burn::tensor::backend::Backend>::Device {
Default::default()
}
/// Tiny 64-unit MLP, compatible with an obs/action space of any size.
fn tiny_mlp(obs: usize, actions: usize) -> MlpNet<Train> {
let cfg = MlpConfig { obs_size: obs, action_size: actions, hidden_size: 64 };
MlpNet::new(&cfg, &train_device())
}
fn tiny_mcts(n: usize) -> MctsConfig {
MctsConfig {
n_simulations: n,
c_puct: 1.5,
dirichlet_alpha: 0.0,
dirichlet_eps: 0.0,
temperature: 1.0,
}
}
fn seeded() -> SmallRng {
SmallRng::seed_from_u64(0)
}
// ── Countdown environment (fast, local, no external deps) ─────────────────
//
// Two players alternate subtracting 1 or 2 from a counter that starts at N.
// The player who brings the counter to 0 wins.
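// (Optimal play: positions where the counter is a multiple of 3 are lost for
// the side to move, so a game starting at 8 is a first-player win.)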
#[derive(Clone, Debug)]
struct CState {
remaining: u8,
to_move: usize,
}
#[derive(Clone)]
struct CountdownEnv(u8); // starting value
impl GameEnv for CountdownEnv {
type State = CState;
fn new_game(&self) -> CState {
CState { remaining: self.0, to_move: 0 }
}
fn current_player(&self, s: &CState) -> Player {
if s.remaining == 0 { Player::Terminal }
else if s.to_move == 0 { Player::P1 }
else { Player::P2 }
}
fn legal_actions(&self, s: &CState) -> Vec<usize> {
if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
}
fn apply(&self, s: &mut CState, action: usize) {
let sub = (action as u8) + 1;
if s.remaining <= sub {
// Terminal: `to_move` stays on the winner so `returns` can credit them.
s.remaining = 0;
} else {
s.remaining -= sub;
s.to_move = 1 - s.to_move;
}
}
fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}
fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
vec![s.remaining as f32 / self.0 as f32, s.to_move as f32]
}
fn obs_size(&self) -> usize { 2 }
fn action_space(&self) -> usize { 2 }
fn returns(&self, s: &CState) -> Option<[f32; 2]> {
if s.remaining != 0 { return None; }
let mut r = [-1.0f32; 2];
r[s.to_move] = 1.0;
Some(r)
}
}
// ── 1. Full loop on CountdownEnv ──────────────────────────────────────────
/// The canonical AlphaZero loop: self-play → replay → train, iterated.
/// Uses CountdownEnv so each game terminates in < 10 moves.
#[test]
fn countdown_full_loop_no_panic() {
let env = CountdownEnv(8);
let mut rng = seeded();
let mcts = tiny_mcts(3);
let mut model = tiny_mlp(env.obs_size(), env.action_space());
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(1_000);
for _iter in 0..5 {
// Self-play: 3 games per iteration.
for _ in 0..3 {
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
assert!(!samples.is_empty());
replay.extend(samples);
}
// Training: 4 gradient steps per iteration.
if replay.len() >= 4 {
for _ in 0..4 {
let batch: Vec<TrainSample> = replay
.sample_batch(4, &mut rng)
.into_iter()
.cloned()
.collect();
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
model = m;
assert!(loss.is_finite(), "loss must be finite, got {loss}");
}
}
}
assert!(!replay.is_empty());
}
// ── 2. Replay buffer invariants ───────────────────────────────────────────
/// After several Countdown games, replay capacity is respected and batch
/// shapes are consistent.
#[test]
fn replay_buffer_capacity_and_shapes() {
let env = CountdownEnv(6);
let mut rng = seeded();
let mcts = tiny_mcts(2);
let model = tiny_mlp(env.obs_size(), env.action_space());
let capacity = 50;
let mut replay = ReplayBuffer::new(capacity);
for _ in 0..20 {
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
replay.extend(samples);
}
assert!(replay.len() <= capacity, "buffer exceeded capacity");
assert!(!replay.is_empty());
let batch = replay.sample_batch(8, &mut rng);
assert_eq!(batch.len(), 8.min(replay.len()));
for s in &batch {
assert_eq!(s.obs.len(), env.obs_size());
assert_eq!(s.policy.len(), env.action_space());
let policy_sum: f32 = s.policy.iter().sum();
assert!((policy_sum - 1.0).abs() < 1e-4, "policy sums to {policy_sum}");
assert!(s.value.abs() <= 1.0, "value {} out of range", s.value);
}
}
// ── 3. TrictracEnv: sample shapes ─────────────────────────────────────────
/// Verify that one TrictracEnv episode produces samples with the correct
/// tensor dimensions: obs = 217, policy = 514.
#[test]
fn trictrac_sample_shapes() {
let env = TrictracEnv;
let mut rng = seeded();
let mcts = tiny_mcts(2);
let model = tiny_mlp(env.obs_size(), env.action_space());
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
assert!(!samples.is_empty(), "Trictrac episode produced no samples");
for (i, s) in samples.iter().enumerate() {
assert_eq!(s.obs.len(), 217, "sample {i}: obs.len() = {}", s.obs.len());
assert_eq!(s.policy.len(), 514, "sample {i}: policy.len() = {}", s.policy.len());
let policy_sum: f32 = s.policy.iter().sum();
assert!(
(policy_sum - 1.0).abs() < 1e-4,
"sample {i}: policy sums to {policy_sum}"
);
assert!(
s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
"sample {i}: unexpected value {}",
s.value
);
}
}
// ── 4. TrictracEnv: training step after real self-play ────────────────────
/// Collect one Trictrac episode, then verify that a gradient step runs
/// without panic and produces a finite loss.
#[test]
fn trictrac_train_step_finite_loss() {
let env = TrictracEnv;
let mut rng = seeded();
let mcts = tiny_mcts(2);
let model = tiny_mlp(env.obs_size(), env.action_space());
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(10_000);
// Generate one episode.
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|_| 1.0, &mut rng);
assert!(!samples.is_empty());
let n_samples = samples.len();
replay.extend(samples);
// Train on a batch from this episode.
let batch_size = 8.min(n_samples);
let batch: Vec<TrainSample> = replay
.sample_batch(batch_size, &mut rng)
.into_iter()
.cloned()
.collect();
let (_, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
assert!(loss.is_finite(), "loss must be finite after Trictrac training, got {loss}");
assert!(loss > 0.0, "loss should be positive");
}
// ── 5. Backend transfer: train → infer → same outputs ─────────────────────
/// Weights transferred from the training backend to the inference backend
/// (via `AutodiffModule::valid()`) must produce bit-identical forward passes.
#[test]
fn valid_model_matches_train_model_outputs() {
use burn::tensor::{Tensor, TensorData};
let cfg = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
let train_model = MlpNet::<Train>::new(&cfg, &train_device());
let infer_model: MlpNet<Infer> = train_model.valid();
// Build the same input on both backends.
let obs_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
let obs_train = Tensor::<Train, 2>::from_data(
TensorData::new(obs_data.clone(), [1, 4]),
&train_device(),
);
let obs_infer = Tensor::<Infer, 2>::from_data(
TensorData::new(obs_data, [1, 4]),
&infer_device(),
);
let (p_train, v_train) = train_model.forward(obs_train);
let (p_infer, v_infer) = infer_model.forward(obs_infer);
let p_train: Vec<f32> = p_train.into_data().to_vec().unwrap();
let p_infer: Vec<f32> = p_infer.into_data().to_vec().unwrap();
let v_train: Vec<f32> = v_train.into_data().to_vec().unwrap();
let v_infer: Vec<f32> = v_infer.into_data().to_vec().unwrap();
for (i, (a, b)) in p_train.iter().zip(p_infer.iter()).enumerate() {
assert!(
(a - b).abs() < 1e-5,
"policy[{i}] differs after valid(): train={a}, infer={b}"
);
}
assert!(
(v_train[0] - v_infer[0]).abs() < 1e-5,
"value differs after valid(): train={}, infer={}",
v_train[0], v_infer[0]
);
}
// ── 6. Loss converges on a fixed batch ────────────────────────────────────
/// With repeated gradient steps on the same Countdown batch, the loss must
/// decrease monotonically (or at least end lower than it started).
#[test]
fn loss_decreases_on_fixed_batch() {
let env = CountdownEnv(6);
let mut rng = seeded();
let mcts = tiny_mcts(3);
let model = tiny_mlp(env.obs_size(), env.action_space());
let mut optimizer = AdamConfig::new().init();
// Collect a fixed batch from one episode.
let infer = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples: Vec<TrainSample> = generate_episode(&env, &eval, &mcts, &|_| 0.0, &mut rng);
assert!(!samples.is_empty());
let batch: Vec<TrainSample> = {
let mut replay = ReplayBuffer::new(1000);
replay.extend(samples);
replay.sample_batch(replay.len(), &mut rng).into_iter().cloned().collect()
};
// Overfit on the same fixed batch for 20 steps.
let mut model = tiny_mlp(env.obs_size(), env.action_space());
let mut first_loss = f32::NAN;
let mut last_loss = f32::NAN;
for step in 0..20 {
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-2);
model = m;
assert!(loss.is_finite(), "loss is not finite at step {step}");
if step == 0 { first_loss = loss; }
last_loss = loss;
}
assert!(
last_loss < first_loss,
"loss did not decrease after 20 steps: first={first_loss}, last={last_loss}"
);
}
// ── 7. Trictrac: multi-iteration loop ─────────────────────────────────────
/// Two full self-play + train iterations on TrictracEnv.
/// Verifies the entire pipeline runs without panic end-to-end.
#[test]
fn trictrac_two_iteration_loop() {
let env = TrictracEnv;
let mut rng = seeded();
let mcts = tiny_mcts(2);
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 64 };
let mut model = MlpNet::<Train>::new(&cfg, &train_device());
let mut optimizer = AdamConfig::new().init();
let mut replay = ReplayBuffer::new(20_000);
for iter in 0..2 {
// Self-play: 1 game per iteration.
let infer: MlpNet<Infer> = model.valid();
let eval = BurnEvaluator::<Infer, _>::new(infer, infer_device());
let samples = generate_episode(&env, &eval, &mcts, &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
assert!(!samples.is_empty(), "iter {iter}: episode was empty");
replay.extend(samples);
// Training: 3 gradient steps.
let batch_size = 16.min(replay.len());
for _ in 0..3 {
let batch: Vec<TrainSample> = replay
.sample_batch(batch_size, &mut rng)
.into_iter()
.cloned()
.collect();
let (m, loss) = train_step(model, &mut optimizer, &batch, &train_device(), 1e-3);
model = m;
assert!(loss.is_finite(), "iter {iter}: loss={loss}");
}
}
}