feat(spiel_bot): upgrade network

This commit is contained in:
Henri Bourcereau 2026-03-07 23:05:53 +01:00
parent 2e85c14dbb
commit c8f2a097cd
3 changed files with 131 additions and 3 deletions

View file

@ -32,7 +32,7 @@ use spiel_bot::{
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts},
network::{MlpConfig, MlpNet, PolicyValueNet},
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
// ── Shared types ───────────────────────────────────────────────────────────
@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) {
);
}
// ── ResNet (4 residual blocks) ────────────────────────────────────────
for &hidden in &[256usize, 512] {
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = ResNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
group.bench_with_input(
BenchmarkId::new("resnet_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("resnet_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
group.finish();
}

View file

@ -65,7 +65,7 @@ pub mod trainer;
pub use replay::{ReplayBuffer, TrainSample};
pub use selfplay::{BurnEvaluator, generate_episode};
pub use trainer::train_step;
pub use trainer::{cosine_lr, train_step};
use crate::mcts::MctsConfig;
@ -87,8 +87,17 @@ pub struct AlphaZeroConfig {
pub batch_size: usize,
/// Maximum number of samples in the replay buffer.
pub replay_capacity: usize,
/// Adam learning rate.
/// Initial (peak) Adam learning rate.
pub learning_rate: f64,
/// Minimum learning rate for cosine annealing (floor of the schedule).
///
/// Pass `learning_rate == lr_min` to disable scheduling (constant LR).
/// Compute the current LR with [`cosine_lr`]:
///
/// ```rust,ignore
/// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
/// ```
pub lr_min: f64,
/// Number of outer iterations (self-play + train) to run.
pub n_iterations: usize,
/// Move index after which the action temperature drops to 0 (greedy play).
@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig {
batch_size: 64,
replay_capacity: 50_000,
learning_rate: 1e-3,
lr_min: 1e-4, // cosine annealing floor
n_iterations: 100,
temperature_drop_move: 30,
}

View file

@ -5,6 +5,24 @@
//! - **Value loss** — mean-squared error between the predicted value and the
//! actual game outcome.
//!
//! # Learning-rate scheduling
//!
//! [`cosine_lr`] implements one-cycle cosine annealing:
//!
//! ```text
//! lr(t) = lr_min + 0.5 · (lr_max - lr_min) · (1 + cos(π · t / T))
//! ```
//!
//! Typical usage in the outer loop:
//!
//! ```rust,ignore
//! for step in 0..total_train_steps {
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
//! model = m;
//! }
//! ```
//!
//! # Backend
//!
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
@ -96,6 +114,30 @@ where
(model, loss_scalar)
}
// ── Learning-rate schedule ─────────────────────────────────────────────────
/// Cosine learning-rate schedule (one half-period, no warmup).
///
/// Returns the learning rate for training step `step` out of `total_steps`:
///
/// ```text
/// lr(t) = lr_min + 0.5 · (initial - lr_min) · (1 + cos(π · t / total))
/// ```
///
/// - At `t = 0` returns `initial` (cos(0) = 1, so the schedule starts at the peak).
/// - At `t = total_steps` (or beyond) returns `lr_min` — steps past the
///   horizon are clamped so the rate never oscillates back up.
///
/// # Panics
///
/// Does not panic. When `total_steps == 0`, returns `lr_min`.
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
    // Clamp: an empty schedule, or a step past its end, holds the floor rate.
    if total_steps == 0 || step >= total_steps {
        return lr_min;
    }
    // progress ∈ [0, 1) — guaranteed by the clamp above, so cos spans (−1, 1].
    let progress = step as f64 / total_steps as f64;
    lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
}
// ── Tests ──────────────────────────────────────────────────────────────────
#[cfg(test)]
@ -169,4 +211,48 @@ mod tests {
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
assert!(loss.is_finite());
}
// ── cosine_lr ─────────────────────────────────────────────────────────
#[test]
fn cosine_lr_at_step_zero_is_initial() {
    // At t = 0, cos(0) = 1, so the schedule sits exactly at the peak rate.
    let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
    let diff = (lr - 1e-3).abs();
    assert!(diff < 1e-10, "expected initial lr, got {lr}");
}
#[test]
fn cosine_lr_at_end_is_min() {
    // At t = total, cos(π) = -1, so the schedule bottoms out at the floor.
    let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
    let diff = (lr - 1e-5).abs();
    assert!(diff < 1e-10, "expected min lr, got {lr}");
}
#[test]
fn cosine_lr_beyond_end_is_min() {
    // Steps past the schedule horizon are clamped to the floor rate.
    let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
    let diff = (lr - 1e-5).abs();
    assert!(diff < 1e-10, "expected min lr beyond end, got {lr}");
}
#[test]
fn cosine_lr_midpoint_is_average() {
    // At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
    let expected = (1e-3 + 1e-5) / 2.0;
    let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
    let diff = (lr - expected).abs();
    assert!(diff < 1e-10, "expected midpoint {expected}, got {lr}");
}
#[test]
fn cosine_lr_monotone_decreasing() {
    // Sample the whole schedule up front, then verify no adjacent pair increases.
    let lrs: Vec<f64> = (0..=100).map(|s| super::cosine_lr(1e-3, 1e-5, s, 100)).collect();
    for (i, pair) in lrs.windows(2).enumerate() {
        let (prev, lr) = (pair[0], pair[1]);
        let step = i + 1;
        assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
    }
}
#[test]
fn cosine_lr_zero_total_steps_returns_min() {
    // A zero-length schedule degenerates to the constant floor rate.
    let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
    let diff = (lr - 1e-5).abs();
    assert!(diff < 1e-10);
}
}