From eadc1017413dece45ff3c12f420e081a0362f6a9 Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Sat, 7 Mar 2026 21:00:27 +0100
Subject: [PATCH] feat(spiel_bot): AlphaZero

---
 spiel_bot/src/alphazero/mod.rs      | 117 ++++++++++++++
 spiel_bot/src/alphazero/replay.rs   | 144 +++++++++++++++++
 spiel_bot/src/alphazero/selfplay.rs | 234 ++++++++++++++++++++++++++++
 spiel_bot/src/alphazero/trainer.rs  | 172 ++++++++++++++++++++
 spiel_bot/src/lib.rs                |   1 +
 5 files changed, 668 insertions(+)
 create mode 100644 spiel_bot/src/alphazero/mod.rs
 create mode 100644 spiel_bot/src/alphazero/replay.rs
 create mode 100644 spiel_bot/src/alphazero/selfplay.rs
 create mode 100644 spiel_bot/src/alphazero/trainer.rs

diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs
new file mode 100644
index 0000000..bb86724
--- /dev/null
+++ b/spiel_bot/src/alphazero/mod.rs
@@ -0,0 +1,117 @@
+//! AlphaZero: self-play data generation, replay buffer, and training step.
+//!
+//! # Modules
+//!
+//! | Module | Contents |
+//! |--------|----------|
+//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] |
+//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] |
+//! | [`trainer`] | [`train_step`] |
+//!
+//! # Typical outer loop
+//!
+//! ```rust,ignore
+//! use burn::backend::{Autodiff, NdArray};
+//! use burn::optim::AdamConfig;
+//! use spiel_bot::{
+//!     alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step},
+//!     env::TrictracEnv,
+//!     mcts::MctsConfig,
+//!     network::{MlpConfig, MlpNet},
+//! };
+//!
+//! type Infer = NdArray;
+//! type Train = Autodiff<NdArray>;
+//!
+//! let device = Default::default();
+//! let env = TrictracEnv;
+//! let config = AlphaZeroConfig::default();
+//!
+//! // Build training model and optimizer.
+//! let mut train_model = MlpNet::<Train>::new(&MlpConfig::default(), &device);
+//! let mut optimizer = AdamConfig::new().init();
+//! let mut replay = ReplayBuffer::new(config.replay_capacity);
+//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
+//!
+//! for _iter in 0..config.n_iterations {
+//!     // Convert to inference backend for self-play.
+//!     let infer_model = MlpNet::<Infer>::new(&MlpConfig::default(), &device)
+//!         .load_record(train_model.clone().into_record());
+//!     let eval = BurnEvaluator::new(infer_model, device.clone());
+//!
+//!     // Self-play: generate episodes.
+//!     for _ in 0..config.n_games_per_iter {
+//!         let samples = generate_episode(&env, &eval, &config.mcts,
+//!             &|step| if step < 30 { 1.0 } else { 0.0 }, &mut rng);
+//!         replay.extend(samples);
+//!     }
+//!
+//!     // Training: gradient steps.
+//!     if replay.len() >= config.batch_size {
+//!         for _ in 0..config.n_train_steps_per_iter {
+//!             let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng)
+//!                 .into_iter().cloned().collect();
+//!             let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device,
+//!                 config.learning_rate);
+//!             train_model = m;
+//!         }
+//!     }
+//! }
+//! ```
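+//!
+//! The hard-coded `30` in the temperature closure above matches the default
+//! [`AlphaZeroConfig::temperature_drop_move`]; in real code the closure would
+//! normally read `|step| if step < config.temperature_drop_move { 1.0 } else { 0.0 }`.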
+
+pub mod replay;
+pub mod selfplay;
+pub mod trainer;
+
+pub use replay::{ReplayBuffer, TrainSample};
+pub use selfplay::{BurnEvaluator, generate_episode};
+pub use trainer::train_step;
+
+use crate::mcts::MctsConfig;
+
+// ── Configuration ─────────────────────────────────────────────────────────
+
+/// Top-level AlphaZero hyperparameters.
+///
+/// The MCTS parameters live in [`MctsConfig`]; this struct holds the
+/// outer training-loop parameters.
+#[derive(Debug, Clone)]
+pub struct AlphaZeroConfig {
+    /// MCTS parameters for self-play.
+    pub mcts: MctsConfig,
+    /// Number of self-play games per training iteration.
+    pub n_games_per_iter: usize,
+    /// Number of gradient steps per training iteration.
+    pub n_train_steps_per_iter: usize,
+    /// Mini-batch size for each gradient step.
+    pub batch_size: usize,
+    /// Maximum number of samples in the replay buffer.
+    pub replay_capacity: usize,
+    /// Adam learning rate.
+    pub learning_rate: f64,
+    /// Number of outer iterations (self-play + train) to run.
+    pub n_iterations: usize,
+    /// Move index after which the action temperature drops to 0 (greedy play).
+    pub temperature_drop_move: usize,
+}
+
+impl Default for AlphaZeroConfig {
+    fn default() -> Self {
+        Self {
+            mcts: MctsConfig {
+                n_simulations: 100,
+                c_puct: 1.5,
+                dirichlet_alpha: 0.1,
+                dirichlet_eps: 0.25,
+                temperature: 1.0,
+            },
+            n_games_per_iter: 10,
+            n_train_steps_per_iter: 20,
+            batch_size: 64,
+            replay_capacity: 50_000,
+            learning_rate: 1e-3,
+            n_iterations: 100,
+            temperature_drop_move: 30,
+        }
+    }
+}
diff --git a/spiel_bot/src/alphazero/replay.rs b/spiel_bot/src/alphazero/replay.rs
new file mode 100644
index 0000000..5e64cc4
--- /dev/null
+++ b/spiel_bot/src/alphazero/replay.rs
@@ -0,0 +1,144 @@
+//! Replay buffer for AlphaZero self-play data.
+
+use std::collections::VecDeque;
+use rand::Rng;
+
+// ── Training sample ────────────────────────────────────────────────────────
+
+/// One training example produced by self-play.
+#[derive(Clone, Debug)]
+pub struct TrainSample {
+    /// Observation tensor from the acting player's perspective (`obs_size` floats).
+    pub obs: Vec<f32>,
+    /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
+    pub policy: Vec<f32>,
+    /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
+    pub value: f32,
+}
+
+// ── Replay buffer ──────────────────────────────────────────────────────────
+
+/// Fixed-capacity circular buffer of [`TrainSample`]s.
+///
+/// When the buffer is full, the oldest sample is evicted on push.
+/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
+pub struct ReplayBuffer {
+    data: VecDeque<TrainSample>,
+    capacity: usize,
+}
+
+impl ReplayBuffer {
+    /// Create a buffer with the given maximum capacity.
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            data: VecDeque::with_capacity(capacity.min(1024)),
+            capacity,
+        }
+    }
+
+    /// Add a sample; evicts the oldest if at capacity.
+    pub fn push(&mut self, sample: TrainSample) {
+        if self.data.len() == self.capacity {
+            self.data.pop_front();
+        }
+        self.data.push_back(sample);
+    }
+
+    /// Add all samples from an episode.
+    pub fn extend(&mut self, samples: impl IntoIterator<Item = TrainSample>) {
+        for s in samples {
+            self.push(s);
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.data.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.data.is_empty()
+    }
+
+    /// Sample up to `n` distinct samples, without replacement.
+    ///
+    /// If the buffer has fewer than `n` samples, all are returned (shuffled).
+    pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
+        let len = self.data.len();
+        let n = n.min(len);
+        // Partial Fisher-Yates using index shuffling.
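+        // For example, with len = 5 and n = 2: `indices` starts as [0, 1, 2, 3, 4];
+        // i = 0 swaps slot 0 with a random j in 0..5, then i = 1 swaps slot 1 with a
+        // random j in 1..5. The first n slots then hold n distinct indices, so the
+        // batch is drawn without replacement in O(n) swaps.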
+        let mut indices: Vec<usize> = (0..len).collect();
+        for i in 0..n {
+            let j = rng.random_range(i..len);
+            indices.swap(i, j);
+        }
+        indices[..n].iter().map(|&i| &self.data[i]).collect()
+    }
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::{SeedableRng, rngs::SmallRng};
+
+    fn dummy(value: f32) -> TrainSample {
+        TrainSample { obs: vec![value], policy: vec![1.0], value }
+    }
+
+    #[test]
+    fn push_and_len() {
+        let mut buf = ReplayBuffer::new(10);
+        assert!(buf.is_empty());
+        buf.push(dummy(1.0));
+        buf.push(dummy(2.0));
+        assert_eq!(buf.len(), 2);
+    }
+
+    #[test]
+    fn evicts_oldest_at_capacity() {
+        let mut buf = ReplayBuffer::new(3);
+        buf.push(dummy(1.0));
+        buf.push(dummy(2.0));
+        buf.push(dummy(3.0));
+        buf.push(dummy(4.0)); // evicts 1.0
+        assert_eq!(buf.len(), 3);
+        // Oldest remaining should be 2.0
+        assert_eq!(buf.data[0].value, 2.0);
+    }
+
+    #[test]
+    fn sample_batch_size() {
+        let mut buf = ReplayBuffer::new(20);
+        for i in 0..10 {
+            buf.push(dummy(i as f32));
+        }
+        let mut rng = SmallRng::seed_from_u64(0);
+        let batch = buf.sample_batch(5, &mut rng);
+        assert_eq!(batch.len(), 5);
+    }
+
+    #[test]
+    fn sample_batch_capped_at_len() {
+        let mut buf = ReplayBuffer::new(20);
+        buf.push(dummy(1.0));
+        buf.push(dummy(2.0));
+        let mut rng = SmallRng::seed_from_u64(0);
+        let batch = buf.sample_batch(100, &mut rng);
+        assert_eq!(batch.len(), 2);
+    }
+
+    #[test]
+    fn sample_batch_no_duplicates() {
+        let mut buf = ReplayBuffer::new(20);
+        for i in 0..10 {
+            buf.push(dummy(i as f32));
+        }
+        let mut rng = SmallRng::seed_from_u64(1);
+        let batch = buf.sample_batch(10, &mut rng);
+        let mut seen: Vec<f32> = batch.iter().map(|s| s.value).collect();
+        seen.sort_by(f32::total_cmp);
+        seen.dedup();
+        assert_eq!(seen.len(), 10, "sample contained duplicates");
+    }
+}
diff --git a/spiel_bot/src/alphazero/selfplay.rs b/spiel_bot/src/alphazero/selfplay.rs
new file mode 100644
index 0000000..6f10f8d
--- /dev/null
+++ b/spiel_bot/src/alphazero/selfplay.rs
@@ -0,0 +1,234 @@
+//! Self-play episode generation and Burn-backed evaluator.
+
+use std::marker::PhantomData;
+
+use burn::tensor::{backend::Backend, Tensor, TensorData};
+use rand::Rng;
+
+use crate::env::GameEnv;
+use crate::mcts::{self, Evaluator, MctsConfig, MctsNode};
+use crate::network::PolicyValueNet;
+
+use super::replay::TrainSample;
+
+// ── BurnEvaluator ──────────────────────────────────────────────────────────
+
+/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS.
+///
+/// Use the **inference backend** (`NdArray`, no `Autodiff` wrapper) so
+/// that self-play incurs no gradient-tape overhead.
+pub struct BurnEvaluator<B: Backend, N: PolicyValueNet<B>> {
+    model: N,
+    device: B::Device,
+    _b: PhantomData<B>,
+}
+
+impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
+    pub fn new(model: N, device: B::Device) -> Self {
+        Self { model, device, _b: PhantomData }
+    }
+
+    pub fn into_model(self) -> N {
+        self.model
+    }
+}
+
+// Safety: NdArray modules are Send; we never share across threads without
+// external synchronisation.
+unsafe impl<B: Backend, N: PolicyValueNet<B>> Send for BurnEvaluator<B, N> {}
+unsafe impl<B: Backend, N: PolicyValueNet<B>> Sync for BurnEvaluator<B, N> {}
+
+impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
+    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
+        let obs_size = obs.len();
+        let data = TensorData::new(obs.to_vec(), [1, obs_size]);
+        let obs_tensor = Tensor::<B, 2>::from_data(data, &self.device);
+
+        let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
+
+        let policy: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
+        let value: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
+
+        (policy, value[0])
+    }
+}
+
+// ── Episode generation ─────────────────────────────────────────────────────
+
+/// One pending observation waiting for its game-outcome value label.
+struct PendingSample {
+    obs: Vec<f32>,
+    policy: Vec<f32>,
+    player: usize,
+}
+
+/// Play one full game using MCTS guided by `evaluator`.
+///
+/// Returns a [`TrainSample`] for every decision step in the game.
+///
+/// `temperature_fn(step)` controls exploration: return `1.0` for early
+/// moves and `0.0` after a fixed number of moves (e.g. move 30).
+pub fn generate_episode<E: GameEnv>(
+    env: &E,
+    evaluator: &dyn Evaluator,
+    mcts_config: &MctsConfig,
+    temperature_fn: &dyn Fn(usize) -> f32,
+    rng: &mut impl Rng,
+) -> Vec<TrainSample> {
+    let mut state = env.new_game();
+    let mut pending: Vec<PendingSample> = Vec::new();
+    let mut step = 0usize;
+
+    loop {
+        // Advance through chance nodes automatically.
+        while env.current_player(&state).is_chance() {
+            env.apply_chance(&mut state, rng);
+        }
+
+        if env.current_player(&state).is_terminal() {
+            break;
+        }
+
+        let player_idx = env.current_player(&state).index().unwrap();
+
+        // Run MCTS to get a policy.
+        let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
+        let policy = mcts::mcts_policy(&root, env.action_space());
+
+        // Record the observation from the acting player's perspective.
+        let obs = env.observation(&state, player_idx);
+        pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx });
+
+        // Select and apply the action.
+        let temperature = temperature_fn(step);
+        let action = mcts::select_action(&root, temperature, rng);
+        env.apply(&mut state, action);
+        step += 1;
+    }
+
+    // Assign game outcomes.
+    let returns = env.returns(&state).unwrap_or([0.0; 2]);
+    pending
+        .into_iter()
+        .map(|s| TrainSample {
+            obs: s.obs,
+            policy: s.policy,
+            value: returns[s.player],
+        })
+        .collect()
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use burn::backend::NdArray;
+    use rand::{SeedableRng, rngs::SmallRng};
+
+    use crate::env::Player;
+    use crate::mcts::{Evaluator, MctsConfig};
+    use crate::network::{MlpConfig, MlpNet};
+
+    type B = NdArray;
+
+    fn device() -> <B as Backend>::Device {
+        Default::default()
+    }
+
+    fn rng() -> SmallRng {
+        SmallRng::seed_from_u64(7)
+    }
+
+    // Countdown game (same as in mcts tests).
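+    // Rules (as implemented below): the game starts with 4 counters; players
+    // alternately remove 1 or 2 of them (actions 0 and 1), and whoever removes
+    // the last counter wins (+1 for them, -1 for the opponent).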
+    #[derive(Clone, Debug)]
+    struct CState { remaining: u8, to_move: usize }
+
+    #[derive(Clone)]
+    struct CountdownEnv;
+
+    impl GameEnv for CountdownEnv {
+        type State = CState;
+        fn new_game(&self) -> CState { CState { remaining: 4, to_move: 0 } }
+        fn current_player(&self, s: &CState) -> Player {
+            if s.remaining == 0 { Player::Terminal }
+            else if s.to_move == 0 { Player::P1 } else { Player::P2 }
+        }
+        fn legal_actions(&self, s: &CState) -> Vec<usize> {
+            if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
+        }
+        fn apply(&self, s: &mut CState, action: usize) {
+            let sub = (action as u8) + 1;
+            if s.remaining <= sub { s.remaining = 0; }
+            else { s.remaining -= sub; s.to_move = 1 - s.to_move; }
+        }
+        fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}
+        fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
+            vec![s.remaining as f32 / 4.0, s.to_move as f32]
+        }
+        fn obs_size(&self) -> usize { 2 }
+        fn action_space(&self) -> usize { 2 }
+        fn returns(&self, s: &CState) -> Option<[f32; 2]> {
+            if s.remaining != 0 { return None; }
+            let mut r = [-1.0f32; 2];
+            r[s.to_move] = 1.0;
+            Some(r)
+        }
+    }
+
+    fn tiny_config() -> MctsConfig {
+        MctsConfig { n_simulations: 5, c_puct: 1.5,
+            dirichlet_alpha: 0.0, dirichlet_eps: 0.0, temperature: 1.0 }
+    }
+
+    // ── BurnEvaluator tests ─────────────────────────────────────────────────
+
+    #[test]
+    fn burn_evaluator_output_shapes() {
+        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
+        let model = MlpNet::<B>::new(&config, &device());
+        let eval = BurnEvaluator::new(model, device());
+        let (policy, value) = eval.evaluate(&[0.5f32, 0.5]);
+        assert_eq!(policy.len(), 2, "policy length should equal action_space");
+        assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)");
+    }
+
+    // ── generate_episode tests ──────────────────────────────────────────────
+
+    #[test]
+    fn episode_terminates_and_has_samples() {
+        let env = CountdownEnv;
+        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
+        let model = MlpNet::<B>::new(&config, &device());
+        let eval = BurnEvaluator::new(model, device());
+        let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
+        assert!(!samples.is_empty(), "episode must produce at least one sample");
+    }
+
+    #[test]
+    fn episode_sample_values_are_valid() {
+        let env = CountdownEnv;
+        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
+        let model = MlpNet::<B>::new(&config, &device());
+        let eval = BurnEvaluator::new(model, device());
+        let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 1.0, &mut rng());
+        for s in &samples {
+            assert!(s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
+                "unexpected value {}", s.value);
+            let sum: f32 = s.policy.iter().sum();
+            assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}");
+            assert_eq!(s.obs.len(), 2);
+        }
+    }
+
+    #[test]
+    fn episode_with_temperature_zero() {
+        let env = CountdownEnv;
+        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
+        let model = MlpNet::<B>::new(&config, &device());
+        let eval = BurnEvaluator::new(model, device());
+        // temperature=0 means greedy; episode must still terminate
+        let samples = generate_episode(&env, &eval, &tiny_config(), &|_| 0.0, &mut rng());
+        assert!(!samples.is_empty());
+    }
+}
diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs
new file mode 100644
index 0000000..d2482d1
--- /dev/null
+++ b/spiel_bot/src/alphazero/trainer.rs
@@ -0,0 +1,172 @@
+//! One gradient-descent training step for AlphaZero.
+//!
+//! The loss combines:
+//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
+//! - **Value loss** — mean-squared error between the predicted value and the
+//!   actual game outcome.
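+//!
+//! Per sample this is `L = -Σ_a π(a) · log softmax(logits)_a + (v_θ(s) - z)²`,
+//! where `π` is the MCTS visit distribution and `z` the final game outcome;
+//! the batch loss is the mean of `L`, with both terms weighted equally.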
+//!
+//! # Backend
+//!
+//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray>`).
+//! Self-play uses the inner backend (`NdArray`) for zero autodiff overhead.
+//! Weights are transferred between the two via [`burn::record`].
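+//!
+//! # Example
+//!
+//! A minimal sketch of one step (it assumes the `model`, `optimizer`, `replay`
+//! buffer, `rng`, and `device` set up as in the [`crate::alphazero`] module docs):
+//!
+//! ```rust,ignore
+//! let batch: Vec<TrainSample> = replay
+//!     .sample_batch(64, &mut rng)
+//!     .into_iter()
+//!     .cloned()
+//!     .collect();
+//! let (updated_model, loss) = train_step(model, &mut optimizer, &batch, &device, 1e-3);
+//! model = updated_model;
+//! println!("loss = {loss:.4}");
+//! ```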
+
+use burn::{
+    module::AutodiffModule,
+    optim::{GradientsParams, Optimizer},
+    prelude::ElementConversion,
+    tensor::{
+        activation::log_softmax,
+        backend::AutodiffBackend,
+        Tensor, TensorData,
+    },
+};
+
+use crate::network::PolicyValueNet;
+use super::replay::TrainSample;
+
+/// Run one gradient step on `model` using `batch`.
+///
+/// Returns the updated model and the scalar loss value for logging.
+///
+/// # Parameters
+///
+/// - `lr` — learning rate (e.g. `1e-3`).
+/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
+pub fn train_step<B, N, O>(
+    model: N,
+    optimizer: &mut O,
+    batch: &[TrainSample],
+    device: &B::Device,
+    lr: f64,
+) -> (N, f32)
+where
+    B: AutodiffBackend,
+    N: PolicyValueNet<B> + AutodiffModule<B>,
+    O: Optimizer<N, B>,
+{
+    assert!(!batch.is_empty(), "train_step called with empty batch");
+
+    let batch_size = batch.len();
+    let obs_size = batch[0].obs.len();
+    let action_size = batch[0].policy.len();
+
+    // ── Build input tensors ────────────────────────────────────────────────
+    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
+    let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
+    let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
+
+    let obs_tensor = Tensor::<B, 2>::from_data(
+        TensorData::new(obs_flat, [batch_size, obs_size]),
+        device,
+    );
+    let policy_target = Tensor::<B, 2>::from_data(
+        TensorData::new(policy_flat, [batch_size, action_size]),
+        device,
+    );
+    let value_target = Tensor::<B, 2>::from_data(
+        TensorData::new(value_flat, [batch_size, 1]),
+        device,
+    );
+
+    // ── Forward pass ─────────────────────────────────────────────────────────
+    let (policy_logits, value_pred) = model.forward(obs_tensor);
+
+    // ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────────
+    let log_probs = log_softmax(policy_logits, 1);
+    let policy_loss = (policy_target.clone().neg() * log_probs)
+        .sum_dim(1)
+        .mean();
+
+    // ── Value loss: MSE(value_pred, z) ────────────────────────────────────────
+    let diff = value_pred - value_target;
+    let value_loss = (diff.clone() * diff).mean();
+
+    // ── Combined loss ─────────────────────────────────────────────────────────
+    let loss = policy_loss + value_loss;
+
+    // Extract the scalar loss before backward (which consumes the tensor).
+    let loss_scalar: f32 = loss.clone().into_scalar().elem();
+
+    // ── Backward + optimizer step ─────────────────────────────────────────────
+    let grads = loss.backward();
+    let grads = GradientsParams::from_grads(grads, &model);
+    let model = optimizer.step(lr, model, grads);
+
+    (model, loss_scalar)
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use burn::{
+        backend::{Autodiff, NdArray},
+        optim::AdamConfig,
+        tensor::backend::Backend,
+    };
+
+    use crate::network::{MlpConfig, MlpNet};
+    use super::super::replay::TrainSample;
+
+    type B = Autodiff<NdArray>;
+
+    fn device() -> <B as Backend>::Device {
+        Default::default()
+    }
+
+    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
+        (0..n)
+            .map(|i| TrainSample {
+                obs: vec![0.5f32; obs_size],
+                policy: {
+                    let mut p = vec![0.0f32; action_size];
+                    p[i % action_size] = 1.0;
+                    p
+                },
+                value: if i % 2 == 0 { 1.0 } else { -1.0 },
+            })
+            .collect()
+    }
+
+    #[test]
+    fn train_step_returns_finite_loss() {
+        let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
+        let model = MlpNet::<B>::new(&config, &device());
+        let mut optimizer = AdamConfig::new().init();
+        let batch = dummy_batch(8, 4, 4);
+
+        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
+        assert!(loss.is_finite(), "loss must be finite, got {loss}");
+        assert!(loss > 0.0, "loss should be positive");
+    }
+
+    #[test]
+    fn loss_decreases_over_steps() {
+        let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
+        let mut model = MlpNet::<B>::new(&config, &device());
+        let mut optimizer = AdamConfig::new().init();
+        // Same batch every step — loss should decrease.
+        let batch = dummy_batch(16, 4, 4);
+
+        let mut prev_loss = f32::INFINITY;
+        for _ in 0..10 {
+            let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
+            model = m;
+            assert!(loss.is_finite());
+            prev_loss = loss;
+        }
+        // After 10 steps on fixed data, loss should be below a reasonable threshold.
+        assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}");
+    }
+
+    #[test]
+    fn train_step_batch_size_one() {
+        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
+        let model = MlpNet::<B>::new(&config, &device());
+        let mut optimizer = AdamConfig::new().init();
+        let batch = dummy_batch(1, 2, 2);
+        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
+        assert!(loss.is_finite());
+    }
+}
diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs
index 5beb37c..23895b9 100644
--- a/spiel_bot/src/lib.rs
+++ b/spiel_bot/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod alphazero;
 pub mod env;
 pub mod mcts;
 pub mod network;