feat(spiel_bot): upgrade network
This commit is contained in:
parent
9c82692ddb
commit
822290d722
3 changed files with 131 additions and 3 deletions
|
|
@ -32,7 +32,7 @@ use spiel_bot::{
|
||||||
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
||||||
env::{GameEnv, Player, TrictracEnv},
|
env::{GameEnv, Player, TrictracEnv},
|
||||||
mcts::{Evaluator, MctsConfig, run_mcts},
|
mcts::{Evaluator, MctsConfig, run_mcts},
|
||||||
network::{MlpConfig, MlpNet, PolicyValueNet},
|
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
|
||||||
};
|
};
|
||||||
|
|
||||||
// ── Shared types ───────────────────────────────────────────────────────────
|
// ── Shared types ───────────────────────────────────────────────────────────
|
||||||
|
|
@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── ResNet (4 residual blocks) ────────────────────────────────────────
|
||||||
|
for &hidden in &[256usize, 512] {
|
||||||
|
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||||
|
let model = ResNet::<InferB>::new(&cfg, &infer_device());
|
||||||
|
let obs: Vec<f32> = vec![0.5; 217];
|
||||||
|
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("resnet_b1", hidden),
|
||||||
|
&hidden,
|
||||||
|
|b, _| {
|
||||||
|
b.iter(|| {
|
||||||
|
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||||
|
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||||
|
black_box(model.forward(t))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::new("resnet_b32", hidden),
|
||||||
|
&hidden,
|
||||||
|
|b, _| {
|
||||||
|
b.iter(|| {
|
||||||
|
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||||
|
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||||
|
black_box(model.forward(t))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
group.finish();
|
group.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ pub mod trainer;
|
||||||
|
|
||||||
pub use replay::{ReplayBuffer, TrainSample};
|
pub use replay::{ReplayBuffer, TrainSample};
|
||||||
pub use selfplay::{BurnEvaluator, generate_episode};
|
pub use selfplay::{BurnEvaluator, generate_episode};
|
||||||
pub use trainer::train_step;
|
pub use trainer::{cosine_lr, train_step};
|
||||||
|
|
||||||
use crate::mcts::MctsConfig;
|
use crate::mcts::MctsConfig;
|
||||||
|
|
||||||
|
|
@ -87,8 +87,17 @@ pub struct AlphaZeroConfig {
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
/// Maximum number of samples in the replay buffer.
|
/// Maximum number of samples in the replay buffer.
|
||||||
pub replay_capacity: usize,
|
pub replay_capacity: usize,
|
||||||
/// Adam learning rate.
|
/// Initial (peak) Adam learning rate.
|
||||||
pub learning_rate: f64,
|
pub learning_rate: f64,
|
||||||
|
/// Minimum learning rate for cosine annealing (floor of the schedule).
|
||||||
|
///
|
||||||
|
/// Set `lr_min` equal to `learning_rate` to disable scheduling (constant LR).
|
||||||
|
/// Compute the current LR with [`cosine_lr`]:
|
||||||
|
///
|
||||||
|
/// ```rust,ignore
|
||||||
|
/// let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_steps);
|
||||||
|
/// ```
|
||||||
|
pub lr_min: f64,
|
||||||
/// Number of outer iterations (self-play + train) to run.
|
/// Number of outer iterations (self-play + train) to run.
|
||||||
pub n_iterations: usize,
|
pub n_iterations: usize,
|
||||||
/// Move index after which the action temperature drops to 0 (greedy play).
|
/// Move index after which the action temperature drops to 0 (greedy play).
|
||||||
|
|
@ -110,6 +119,7 @@ impl Default for AlphaZeroConfig {
|
||||||
batch_size: 64,
|
batch_size: 64,
|
||||||
replay_capacity: 50_000,
|
replay_capacity: 50_000,
|
||||||
learning_rate: 1e-3,
|
learning_rate: 1e-3,
|
||||||
|
lr_min: 1e-4, // cosine annealing floor
|
||||||
n_iterations: 100,
|
n_iterations: 100,
|
||||||
temperature_drop_move: 30,
|
temperature_drop_move: 30,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,24 @@
|
||||||
//! - **Value loss** — mean-squared error between the predicted value and the
|
//! - **Value loss** — mean-squared error between the predicted value and the
|
||||||
//! actual game outcome.
|
//! actual game outcome.
|
||||||
//!
|
//!
|
||||||
|
//! # Learning-rate scheduling
|
||||||
|
//!
|
||||||
|
//! [`cosine_lr`] implements single half-period cosine annealing (no warmup, no restarts):
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! lr(t) = lr_min + 0.5 · (lr_max − lr_min) · (1 + cos(π · t / T))
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Typical usage in the outer loop:
|
||||||
|
//!
|
||||||
|
//! ```rust,ignore
|
||||||
|
//! for step in 0..total_train_steps {
|
||||||
|
//! let lr = cosine_lr(config.learning_rate, config.lr_min, step, total_train_steps);
|
||||||
|
//! let (m, loss) = train_step(model, &mut optimizer, &batch, &device, lr);
|
||||||
|
//! model = m;
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
//! # Backend
|
//! # Backend
|
||||||
//!
|
//!
|
||||||
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
||||||
|
|
@ -96,6 +114,30 @@ where
|
||||||
(model, loss_scalar)
|
(model, loss_scalar)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Learning-rate schedule ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Cosine learning-rate schedule (one half-period, no warmup).
///
/// Returns the learning rate for training step `step` out of `total_steps`:
///
/// ```text
/// lr(t) = lr_min + 0.5 · (initial − lr_min) · (1 + cos(π · t / total))
/// ```
///
/// # Arguments
///
/// - `initial` — learning rate returned at step 0 (start of the schedule).
/// - `lr_min` — learning rate returned once the schedule has finished.
/// - `step` — current training step, counted from 0.
/// - `total_steps` — length of the schedule in steps.
///
/// # Edge cases
///
/// - At `step == 0` (with `total_steps > 0`) returns `initial`.
/// - At `step >= total_steps` returns `lr_min`. This includes
///   `total_steps == 0`, so the function never divides by zero.
/// - When `initial == lr_min` the result is constant for every step.
///
/// This function never panics.
#[must_use]
pub fn cosine_lr(initial: f64, lr_min: f64, step: usize, total_steps: usize) -> f64 {
    // Schedule exhausted. Because both values are unsigned, this comparison
    // is also true whenever `total_steps == 0`, which keeps the division
    // below away from a zero denominator.
    if step >= total_steps {
        return lr_min;
    }

    // Fraction of the schedule completed, in [0, 1).
    let progress = step as f64 / total_steps as f64;
    let cosine = (std::f64::consts::PI * progress).cos();
    lr_min + 0.5 * (initial - lr_min) * (1.0 + cosine)
}
|
||||||
|
|
||||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -169,4 +211,48 @@ mod tests {
|
||||||
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
|
||||||
assert!(loss.is_finite());
|
assert!(loss.is_finite());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── cosine_lr ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// The schedule must start at the initial learning rate.
#[test]
fn cosine_lr_at_step_zero_is_initial() {
    let (initial, floor) = (1e-3, 1e-5);
    let lr = super::cosine_lr(initial, floor, 0, 100);
    let error = (lr - initial).abs();
    assert!(error < 1e-10, "expected initial lr, got {lr}");
}
|
||||||
|
|
||||||
|
/// At `step == total_steps` the schedule must have reached its floor.
#[test]
fn cosine_lr_at_end_is_min() {
    let (initial, floor) = (1e-3, 1e-5);
    let total = 100;
    let lr = super::cosine_lr(initial, floor, total, total);
    let error = (lr - floor).abs();
    assert!(error < 1e-10, "expected min lr, got {lr}");
}
|
||||||
|
|
||||||
|
/// Steps past the end of the schedule must stay at the floor.
#[test]
fn cosine_lr_beyond_end_is_min() {
    let (initial, floor) = (1e-3, 1e-5);
    let lr = super::cosine_lr(initial, floor, 200, 100);
    let error = (lr - floor).abs();
    assert!(error < 1e-10, "expected min lr beyond end, got {lr}");
}
|
||||||
|
|
||||||
|
/// Halfway through, `cos(π/2) = 0`, so the rate is the mean of both bounds.
#[test]
fn cosine_lr_midpoint_is_average() {
    let (initial, floor) = (1e-3, 1e-5);
    let expected = (initial + floor) / 2.0;
    let lr = super::cosine_lr(initial, floor, 50, 100);
    let error = (lr - expected).abs();
    assert!(error < 1e-10, "expected midpoint {expected}, got {lr}");
}
|
||||||
|
|
||||||
|
/// The rate must never increase from one step to the next over steps 0..=100.
#[test]
fn cosine_lr_monotone_decreasing() {
    // Sample the whole schedule once, then compare neighbouring steps.
    let schedule: Vec<f64> = (0..=100)
        .map(|step| super::cosine_lr(1e-3, 1e-5, step, 100))
        .collect();

    for (idx, pair) in schedule.windows(2).enumerate() {
        let (prev, lr) = (pair[0], pair[1]);
        // `pair[1]` belongs to step `idx + 1`.
        let step = idx + 1;
        assert!(lr <= prev + 1e-15, "lr increased at step {step}: {lr} > {prev}");
    }
}
|
||||||
|
|
||||||
|
/// A zero-length schedule must fall back to the floor rather than divide by zero.
#[test]
fn cosine_lr_zero_total_steps_returns_min() {
    let lr = super::cosine_lr(1e-3, 1e-5, 0, 0);
    // Failure message added so a regression reports the offending value,
    // matching the other `cosine_lr` tests.
    assert!(
        (lr - 1e-5).abs() < 1e-10,
        "expected min lr for zero total steps, got {lr}"
    );
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue