feat(spiel_bot): upgrade network
This commit is contained in:
parent
2e85c14dbb
commit
c8f2a097cd
3 changed files with 131 additions and 3 deletions
|
|
@ -32,7 +32,7 @@ use spiel_bot::{
|
|||
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
||||
env::{GameEnv, Player, TrictracEnv},
|
||||
mcts::{Evaluator, MctsConfig, run_mcts},
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet},
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
|
||||
};
|
||||
|
||||
// ── Shared types ───────────────────────────────────────────────────────────
|
||||
|
|
@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) {
|
|||
);
|
||||
}
|
||||
|
||||
// ── ResNet (4 residual blocks) ────────────────────────────────────────
|
||||
for &hidden in &[256usize, 512] {
|
||||
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = ResNet::<InferB>::new(&cfg, &infer_device());
|
||||
let obs: Vec<f32> = vec![0.5; 217];
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("resnet_b1", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("resnet_b32", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ pub mod trainer;
|
|||
|
||||
pub use replay::{ReplayBuffer, TrainSample};
|
||||
pub use selfplay::{BurnEvaluator, generate_episode};
|
||||
pub use trainer::train_step;
|
||||
pub use trainer::{cosine_lr, train_step};
|
||||
|
||||
use crate::mcts::MctsConfig;
|
||||
|
||||
|
|
@ -87,8 +87,17 @@ pub struct AlphaZeroConfig {
|
|||
pub batch_size: usize,
|
||||
/// Maximum number of samples in the replay buffer.
|
||||
pub replay_capacity: usize,
|
||||
/// Adam learning rate.
|
||||
/// Initial (peak) Adam learning rate.
|
||||
pub learning_rate: f64,
|
||||
/// Minimum learning rate for cosine annealing (floor of the schedule).
|
||||
///
|
||||
/// Pass `learning_rate == lr_min` to disable scheduling (constant LR).
|
||||
/// Compute the current LR with [`cosine_lr`]:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
|
||||
/// ```
|
||||
pub lr_min: f64,
|
||||
/// Number of outer iterations (self-play + train) to run.
|
||||
pub n_iterations: usize,
|
||||
/// Move index after which the action temperature drops to 0 (greedy play).
|
||||
|
|
@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig {
|
|||
batch_size: 64,
|
||||
replay_capacity: 50_000,
|
||||
learning_rate: 1e-3,
|
||||
lr_min: 1e-4, // cosine annealing floor
|
||||
n_iterations: 100,
|
||||
temperature_drop_move: 30,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,24 @@
|
|||
//! - **Value loss** — mean-squared error between the predicted value and the
|
||||
//! actual game outcome.
|
||||
//!
|
||||
//! # Learning-rate scheduling
|
||||
//!
|
||||
//! [`cosine_lr`] implements one-cycle cosine annealing:
|
||||
//!
|
||||
//! ```text
|
||||
//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T))
|
||||
//! ```
|
||||
//!
|
||||
//! Typical usage in the outer loop:
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! for step in 0..total_train_steps {
|
||||
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
|
||||
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
|
||||
//! model = m;
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! # Backend
|
||||
//!
|
||||
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
||||
|
|
@ -96,6 +114,30 @@ where
|
|||
(model, loss_scalar)
|
||||
}
|
||||
|
||||
// ── Learning-rate schedule ─────────────────────────────────────────────────
|
||||
|
||||
/// Cosine learning-rate schedule (one half-period, no warmup).
|
||||
///
|
||||
/// Returns the learning rate for training step `step` out of `total_steps`:
|
||||
///
|
||||
/// ```text
|
||||
/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total))
|
||||
/// ```
|
||||
///
|
||||
/// - At `t = 0` returns `initial`.
|
||||
/// - At `t = total_steps` (or beyond) returns `lr_min`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Does not panic. When `total_steps == 0`, returns `lr_min`.
|
||||
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
|
||||
if total_steps == 0 || step >= total_steps {
|
||||
return lr_min;
|
||||
}
|
||||
let progress = step as f64 / total_steps as f64;
|
||||
lr_min + 0.5 * (initial - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos())
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -169,4 +211,48 @@ mod tests {
|
|||
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
||||
assert!(loss.is_finite());
|
||||
}
|
||||
|
||||
// ── cosine_lr ─────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_at_step_zero_is_initial() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 0, 100);
|
||||
assert!((lr - 1e-3).abs() < 1e-10, "expected initial lr, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_at_end_is_min() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 100, 100);
|
||||
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_beyond_end_is_min() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 200, 100);
|
||||
assert!((lr - 1e-5).abs() < 1e-10, "expected min lr beyond end, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_midpoint_is_average() {
|
||||
// At t = total/2, cos(π/2) = 0, so lr = (initial + min) / 2.
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 50, 100);
|
||||
let expected = (1e-3 + 1e-5) / 2.0;
|
||||
assert!((lr - expected).abs() < 1e-10, "expected midpoint {expected}, got {lr}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_monotone_decreasing() {
|
||||
let mut prev = f64::INFINITY;
|
||||
for step in 0..=100 {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, step, 100);
|
||||
assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
|
||||
prev = lr;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_lr_zero_total_steps_returns_min() {
|
||||
let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
|
||||
assert!((lr - 1e-5).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue