From 822290d7224fa8cc6ce0aa324adcfe388dd705b2 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 23:05:53 +0100 Subject: [PATCH] feat(spiel_bot): upgrade network --- spiel_bot/benches/alphazero.rs | 34 +++++++++++- spiel_bot/src/alphazero/mod.rs | 14 ++++- spiel_bot/src/alphazero/trainer.rs | 86 ++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 3 deletions(-) diff --git a/spiel_bot/benches/alphazero.rs b/spiel_bot/benches/alphazero.rs index 2950b09..00d5b02 100644 --- a/spiel_bot/benches/alphazero.rs +++ b/spiel_bot/benches/alphazero.rs @@ -32,7 +32,7 @@ use spiel_bot::{ alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step}, env::{GameEnv, Player, TrictracEnv}, mcts::{Evaluator, MctsConfig, run_mcts}, - network::{MlpConfig, MlpNet, PolicyValueNet}, + network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig}, }; // ── Shared types ─────────────────────────────────────────────────────────── @@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) { ); } + // ── ResNet (4 residual blocks) ──────────────────────── + for &hidden in &[256usize, 512] { + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden }; + let model = ResNet::<B>::new(&cfg, &infer_device()); + let obs: Vec<f32> = vec![0.5; 217]; + + group.bench_with_input( + BenchmarkId::new("resnet_b1", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs.clone(), [1, 217]); + let t = Tensor::<B, 2>::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + + let obs32: Vec<f32> = vec![0.5; 217 * 32]; + group.bench_with_input( + BenchmarkId::new("resnet_b32", hidden), + &hidden, + |b, _| { + b.iter(|| { + let data = TensorData::new(obs32.clone(), [32, 217]); + let t = Tensor::<B, 2>::from_data(data, &infer_device()); + black_box(model.forward(t)) + }) + }, + ); + } + group.finish(); } diff --git a/spiel_bot/src/alphazero/mod.rs b/spiel_bot/src/alphazero/mod.rs index bb86724..d92224e 100644 ---
a/spiel_bot/src/alphazero/mod.rs +++ b/spiel_bot/src/alphazero/mod.rs @@ -65,7 +65,7 @@ pub mod trainer; pub use replay::{ReplayBuffer, TrainSample}; pub use selfplay::{BurnEvaluator, generate_episode}; -pub use trainer::train_step; +pub use trainer::{cosine_lr, train_step}; use crate::mcts::MctsConfig; @@ -87,8 +87,17 @@ pub struct AlphaZeroConfig { pub batch_size: usize, /// Maximum number of samples in the replay buffer. pub replay_capacity: usize, - /// Adam learning rate. + /// Initial (peak) Adam learning rate. pub learning_rate: f64, + /// Minimum learning rate for cosine annealing (floor of the schedule). + /// + /// Pass `learning_rate == lr_min` to disable scheduling (constant LR). + /// Compute the current LR with [`cosine_lr`]: + /// + /// ```rust,ignore + /// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps); + /// ``` + pub lr_min: f64, /// Number of outer iterations (self-play + train) to run. pub n_iterations: usize, /// Move index after which the action temperature drops to 0 (greedy play). @@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig { batch_size: 64, replay_capacity: 50_000, learning_rate: 1e-3, + lr_min: 1e-4, // cosine annealing floor n_iterations: 100, temperature_drop_move: 30, } diff --git a/spiel_bot/src/alphazero/trainer.rs b/spiel_bot/src/alphazero/trainer.rs index d2482d1..9075519 100644 --- a/spiel_bot/src/alphazero/trainer.rs +++ b/spiel_bot/src/alphazero/trainer.rs @@ -5,6 +5,24 @@ //! - **Value loss** — mean-squared error between the predicted value and the //! actual game outcome. //! +//! # Learning-rate scheduling +//! +//! [`cosine_lr`] implements one-cycle cosine annealing: +//! +//! ```text +//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T)) +//! ``` +//! +//! Typical usage in the outer loop: +//! +//! ```rust,ignore +//! for step in 0..total_train_steps { +//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps); +//! 
let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr); +//! model = m; +//! } +//! ``` +//! //! # Backend //! //! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`). @@ -96,6 +114,30 @@ where (model, loss_scalar) } +// ── Learning-rate schedule ───────────────────────────────────────────────── + +/// Cosine learning-rate schedule (one half-period, no warmup). +/// +/// Returns the learning rate for training step `step` out of `total_steps`: +/// +/// ```text +/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total)) +/// ``` +/// +/// - At `t = 0` returns `initial`. +/// - At `t = total_steps` (or beyond) returns `lr_min`. +/// +/// # Panics +/// +/// Does not panic. When `total_steps == 0`, returns `lr_min`. +pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 { + if total_steps == 0 || step >= total_steps { + return lr_min; + } + let progress = step as f64 / total_steps as f64; + lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos()) +} + // ── Tests ────────────────────────────────────────────────────────────────── #[cfg(test)] @@ -169,4 +211,48 @@ mod tests { let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3); assert!(loss.is_finite()); } + + // ── cosine_lr ───────────────────────────────────────── + + #[test] + fn cosine_lr_at_step_zero_is_initial() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 100); + assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}"); + } + + #[test] + fn cosine_lr_at_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 100, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}"); + } + + #[test] + fn cosine_lr_beyond_end_is_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 200, 100); + assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}"); + } + + #[test] + fn cosine_lr_midpoint_is_average() { + // At t = total/2, cos(π/2) = 0, so lr = (initial + 
min) / 2. + let lr = super::cosine_lr(1e-3, 1e-5, 50, 100); + let expected = (1e-3 + 1e-5) / 2.0; + assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}"); + } + + #[test] + fn cosine_lr_monotone_decreasing() { + let mut prev = f64::INFINITY; + for step in 0..=100 { + let lr = super::cosine_lr(1e-3, 1e-5, step, 100); + assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}"); + prev = lr; + } + } + + #[test] + fn cosine_lr_zero_total_steps_returns_min() { + let lr = super::cosine_lr(1e-3, 1e-5, 0, 0); + assert!((lr - 1e-5).abs() < 1e-10); + } }