feat(spiel_bot): AlphaZero
This commit is contained in:
parent
0619cf6001
commit
0235fd46c2
5 changed files with 668 additions and 0 deletions
117
spiel_bot/src/alphazero/mod.rs
Normal file
117
spiel_bot/src/alphazero/mod.rs
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
//! AlphaZero: self-play data generation, replay buffer, and training step.
|
||||||
|
//!
|
||||||
|
//! # Modules
|
||||||
|
//!
|
||||||
|
//! | Module | Contents |
|
||||||
|
//! |--------|----------|
|
||||||
|
//! | [`replay`] | [`TrainSample`], [`ReplayBuffer`] |
|
||||||
|
//! | [`selfplay`] | [`BurnEvaluator`], [`generate_episode`] |
|
||||||
|
//! | [`trainer`] | [`train_step`] |
|
||||||
|
//!
|
||||||
|
//! # Typical outer loop
|
||||||
|
//!
|
||||||
|
//! ```rust,ignore
|
||||||
|
//! use burn::backend::{Autodiff, NdArray};
|
||||||
|
//! use burn::optim::AdamConfig;
|
||||||
|
//! use spiel_bot::{
|
||||||
|
//! alphazero::{AlphaZeroConfig, BurnEvaluator, ReplayBuffer, generate_episode, train_step},
|
||||||
|
//! env::TrictracEnv,
|
||||||
|
//! mcts::MctsConfig,
|
||||||
|
//! network::{MlpConfig, MlpNet},
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! type Infer = NdArray<f32>;
|
||||||
|
//! type Train = Autodiff<NdArray<f32>>;
|
||||||
|
//!
|
||||||
|
//! let device = Default::default();
|
||||||
|
//! let env = TrictracEnv;
|
||||||
|
//! let config = AlphaZeroConfig::default();
|
||||||
|
//!
|
||||||
|
//! // Build training model and optimizer.
|
||||||
|
//! let mut train_model = MlpNet::<Train>::new(&MlpConfig::default(), &device);
|
||||||
|
//! let mut optimizer = AdamConfig::new().init();
|
||||||
|
//! let mut replay = ReplayBuffer::new(config.replay_capacity);
|
||||||
|
//! let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||||
|
//!
|
||||||
|
//! for _iter in 0..config.n_iterations {
|
||||||
|
//! // Convert to inference backend for self-play.
|
||||||
|
//! let infer_model = MlpNet::<Infer>::new(&MlpConfig::default(), &device)
|
||||||
|
//! .load_record(train_model.clone().into_record());
|
||||||
|
//! let eval = BurnEvaluator::new(infer_model, device.clone());
|
||||||
|
//!
|
||||||
|
//! // Self-play: generate episodes.
|
||||||
|
//! for _ in 0..config.n_games_per_iter {
|
||||||
|
//!         let samples = generate_episode(&env, &eval, &config.mcts,
//!             &|step| if step < config.temperature_drop_move { 1.0 } else { 0.0 }, &mut rng);
|
||||||
|
//! replay.extend(samples);
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! // Training: gradient steps.
|
||||||
|
//! if replay.len() >= config.batch_size {
|
||||||
|
//! for _ in 0..config.n_train_steps_per_iter {
|
||||||
|
//! let batch: Vec<_> = replay.sample_batch(config.batch_size, &mut rng)
|
||||||
|
//! .into_iter().cloned().collect();
|
||||||
|
//! let (m, _loss) = train_step(train_model, &mut optimizer, &batch, &device,
|
||||||
|
//! config.learning_rate);
|
||||||
|
//! train_model = m;
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
pub mod replay;
|
||||||
|
pub mod selfplay;
|
||||||
|
pub mod trainer;
|
||||||
|
|
||||||
|
pub use replay::{ReplayBuffer, TrainSample};
|
||||||
|
pub use selfplay::{BurnEvaluator, generate_episode};
|
||||||
|
pub use trainer::train_step;
|
||||||
|
|
||||||
|
use crate::mcts::MctsConfig;
|
||||||
|
|
||||||
|
// ── Configuration ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Top-level AlphaZero hyperparameters.
|
||||||
|
///
|
||||||
|
/// The MCTS parameters live in [`MctsConfig`]; this struct holds the
|
||||||
|
/// outer training-loop parameters.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AlphaZeroConfig {
|
||||||
|
/// MCTS parameters for self-play.
|
||||||
|
pub mcts: MctsConfig,
|
||||||
|
/// Number of self-play games per training iteration.
|
||||||
|
pub n_games_per_iter: usize,
|
||||||
|
/// Number of gradient steps per training iteration.
|
||||||
|
pub n_train_steps_per_iter: usize,
|
||||||
|
/// Mini-batch size for each gradient step.
|
||||||
|
pub batch_size: usize,
|
||||||
|
/// Maximum number of samples in the replay buffer.
|
||||||
|
pub replay_capacity: usize,
|
||||||
|
/// Adam learning rate.
|
||||||
|
pub learning_rate: f64,
|
||||||
|
/// Number of outer iterations (self-play + train) to run.
|
||||||
|
pub n_iterations: usize,
|
||||||
|
/// Move index after which the action temperature drops to 0 (greedy play).
|
||||||
|
pub temperature_drop_move: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AlphaZeroConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
mcts: MctsConfig {
|
||||||
|
n_simulations: 100,
|
||||||
|
c_puct: 1.5,
|
||||||
|
dirichlet_alpha: 0.1,
|
||||||
|
dirichlet_eps: 0.25,
|
||||||
|
temperature: 1.0,
|
||||||
|
},
|
||||||
|
n_games_per_iter: 10,
|
||||||
|
n_train_steps_per_iter: 20,
|
||||||
|
batch_size: 64,
|
||||||
|
replay_capacity: 50_000,
|
||||||
|
learning_rate: 1e-3,
|
||||||
|
n_iterations: 100,
|
||||||
|
temperature_drop_move: 30,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
144
spiel_bot/src/alphazero/replay.rs
Normal file
144
spiel_bot/src/alphazero/replay.rs
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
//! Replay buffer for AlphaZero self-play data.
|
||||||
|
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
// ── Training sample ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One training example produced by self-play.
///
/// Derives `PartialEq` so that samples can be compared directly (for example
/// in tests); equality is plain element-wise `f32` comparison, so a sample
/// containing `NaN` is never equal to anything, including itself.
#[derive(Clone, Debug, PartialEq)]
pub struct TrainSample {
    /// Observation tensor from the acting player's perspective (`obs_size` floats).
    pub obs: Vec<f32>,
    /// MCTS policy target: normalized visit counts (`action_space` floats, sums to 1).
    pub policy: Vec<f32>,
    /// Game outcome from the acting player's perspective: +1 win, -1 loss, 0 draw.
    pub value: f32,
}
|
||||||
|
|
||||||
|
// ── Replay buffer ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Fixed-capacity circular buffer of [`TrainSample`]s.
|
||||||
|
///
|
||||||
|
/// When the buffer is full, the oldest sample is evicted on push.
|
||||||
|
/// Samples are drawn without replacement using a Fisher-Yates partial shuffle.
|
||||||
|
pub struct ReplayBuffer {
|
||||||
|
data: VecDeque<TrainSample>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayBuffer {
|
||||||
|
/// Create a buffer with the given maximum capacity.
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
data: VecDeque::with_capacity(capacity.min(1024)),
|
||||||
|
capacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a sample; evicts the oldest if at capacity.
|
||||||
|
pub fn push(&mut self, sample: TrainSample) {
|
||||||
|
if self.data.len() == self.capacity {
|
||||||
|
self.data.pop_front();
|
||||||
|
}
|
||||||
|
self.data.push_back(sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add all samples from an episode.
|
||||||
|
pub fn extend(&mut self, samples: impl IntoIterator<Item = TrainSample>) {
|
||||||
|
for s in samples {
|
||||||
|
self.push(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.data.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sample up to `n` distinct samples, without replacement.
|
||||||
|
///
|
||||||
|
/// If the buffer has fewer than `n` samples, all are returned (shuffled).
|
||||||
|
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
|
||||||
|
let len = self.data.len();
|
||||||
|
let n = n.min(len);
|
||||||
|
// Partial Fisher-Yates using index shuffling.
|
||||||
|
let mut indices: Vec<usize> = (0..len).collect();
|
||||||
|
for i in 0..n {
|
||||||
|
let j = rng.random_range(i..len);
|
||||||
|
indices.swap(i, j);
|
||||||
|
}
|
||||||
|
indices[..n].iter().map(|&i| &self.data[i]).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};

    /// Build a one-element sample whose observation and value are both `value`.
    fn dummy(value: f32) -> TrainSample {
        TrainSample {
            obs: vec![value],
            policy: vec![1.0],
            value,
        }
    }

    /// Buffer of the given capacity pre-filled with `dummy(0.0)`, `dummy(1.0)`, …
    fn filled(capacity: usize, count: usize) -> ReplayBuffer {
        let mut buf = ReplayBuffer::new(capacity);
        for i in 0..count {
            buf.push(dummy(i as f32));
        }
        buf
    }

    /// A fresh buffer is empty and `len` tracks pushes.
    #[test]
    fn push_and_len() {
        let mut buf = ReplayBuffer::new(10);
        assert!(buf.is_empty());
        for v in [1.0, 2.0] {
            buf.push(dummy(v));
        }
        assert_eq!(buf.len(), 2);
    }

    /// Pushing past capacity drops the oldest sample first.
    #[test]
    fn evicts_oldest_at_capacity() {
        let mut buf = ReplayBuffer::new(3);
        for v in [1.0, 2.0, 3.0, 4.0] {
            // The fourth push evicts 1.0.
            buf.push(dummy(v));
        }
        assert_eq!(buf.len(), 3);
        // Oldest remaining should be 2.0
        assert_eq!(buf.data[0].value, 2.0);
    }

    /// Requesting fewer samples than stored yields exactly that many.
    #[test]
    fn sample_batch_size() {
        let buf = filled(20, 10);
        let mut rng = SmallRng::seed_from_u64(0);
        assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
    }

    /// Requesting more samples than stored yields everything stored.
    #[test]
    fn sample_batch_capped_at_len() {
        let mut buf = ReplayBuffer::new(20);
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        let mut rng = SmallRng::seed_from_u64(0);
        assert_eq!(buf.sample_batch(100, &mut rng).len(), 2);
    }

    /// Drawing the whole buffer never repeats a sample.
    #[test]
    fn sample_batch_no_duplicates() {
        let buf = filled(20, 10);
        let mut rng = SmallRng::seed_from_u64(1);
        let mut seen: Vec<f32> = buf
            .sample_batch(10, &mut rng)
            .iter()
            .map(|s| s.value)
            .collect();
        seen.sort_by(f32::total_cmp);
        seen.dedup();
        assert_eq!(seen.len(), 10, "sample contained duplicates");
    }
}
|
||||||
234
spiel_bot/src/alphazero/selfplay.rs
Normal file
234
spiel_bot/src/alphazero/selfplay.rs
Normal file
|
|
@ -0,0 +1,234 @@
|
||||||
|
//! Self-play episode generation and Burn-backed evaluator.
|
||||||
|
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::env::GameEnv;
|
||||||
|
use crate::mcts::{self, Evaluator, MctsConfig, MctsNode};
|
||||||
|
use crate::network::PolicyValueNet;
|
||||||
|
|
||||||
|
use super::replay::TrainSample;
|
||||||
|
|
||||||
|
// ── BurnEvaluator ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Wraps a [`PolicyValueNet`] as an [`Evaluator`] for MCTS.
|
||||||
|
///
|
||||||
|
/// Use the **inference backend** (`NdArray<f32>`, no `Autodiff` wrapper) so
|
||||||
|
/// that self-play generates no gradient tape overhead.
|
||||||
|
pub struct BurnEvaluator<B: Backend, N: PolicyValueNet<B>> {
|
||||||
|
model: N,
|
||||||
|
device: B::Device,
|
||||||
|
_b: PhantomData<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
|
||||||
|
pub fn new(model: N, device: B::Device) -> Self {
|
||||||
|
Self { model, device, _b: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_model(self) -> N {
|
||||||
|
self.model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
||||||
|
// external synchronisation.
|
||||||
|
unsafe impl<B: Backend, N: PolicyValueNet<B>> Send for BurnEvaluator<B, N> {}
|
||||||
|
unsafe impl<B: Backend, N: PolicyValueNet<B>> Sync for BurnEvaluator<B, N> {}
|
||||||
|
|
||||||
|
impl<B: Backend, N: PolicyValueNet<B>> Evaluator for BurnEvaluator<B, N> {
|
||||||
|
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32) {
|
||||||
|
let obs_size = obs.len();
|
||||||
|
let data = TensorData::new(obs.to_vec(), [1, obs_size]);
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(data, &self.device);
|
||||||
|
|
||||||
|
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
|
||||||
|
|
||||||
|
let policy: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
|
||||||
|
let value: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
|
||||||
|
|
||||||
|
(policy, value[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Episode generation ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One pending observation waiting for its game-outcome value label.
|
||||||
|
struct PendingSample {
|
||||||
|
obs: Vec<f32>,
|
||||||
|
policy: Vec<f32>,
|
||||||
|
player: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Play one full game using MCTS guided by `evaluator`.
|
||||||
|
///
|
||||||
|
/// Returns a [`TrainSample`] for every decision step in the game.
|
||||||
|
///
|
||||||
|
/// `temperature_fn(step)` controls exploration: return `1.0` for early
|
||||||
|
/// moves and `0.0` after a fixed number of moves (e.g. move 30).
|
||||||
|
pub fn generate_episode<E: GameEnv>(
|
||||||
|
env: &E,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
mcts_config: &MctsConfig,
|
||||||
|
temperature_fn: &dyn Fn(usize) -> f32,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> Vec<TrainSample> {
|
||||||
|
let mut state = env.new_game();
|
||||||
|
let mut pending: Vec<PendingSample> = Vec::new();
|
||||||
|
let mut step = 0usize;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Advance through chance nodes automatically.
|
||||||
|
while env.current_player(&state).is_chance() {
|
||||||
|
env.apply_chance(&mut state, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
if env.current_player(&state).is_terminal() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let player_idx = env.current_player(&state).index().unwrap();
|
||||||
|
|
||||||
|
// Run MCTS to get a policy.
|
||||||
|
let root: MctsNode = mcts::run_mcts(env, &state, evaluator, mcts_config, rng);
|
||||||
|
let policy = mcts::mcts_policy(&root, env.action_space());
|
||||||
|
|
||||||
|
// Record the observation from the acting player's perspective.
|
||||||
|
let obs = env.observation(&state, player_idx);
|
||||||
|
pending.push(PendingSample { obs, policy: policy.clone(), player: player_idx });
|
||||||
|
|
||||||
|
// Select and apply the action.
|
||||||
|
let temperature = temperature_fn(step);
|
||||||
|
let action = mcts::select_action(&root, temperature, rng);
|
||||||
|
env.apply(&mut state, action);
|
||||||
|
step += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign game outcomes.
|
||||||
|
let returns = env.returns(&state).unwrap_or([0.0; 2]);
|
||||||
|
pending
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| TrainSample {
|
||||||
|
obs: s.obs,
|
||||||
|
policy: s.policy,
|
||||||
|
value: returns[s.player],
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use rand::{SeedableRng, rngs::SmallRng};

    use crate::env::Player;
    use crate::mcts::{Evaluator, MctsConfig};
    use crate::network::{MlpConfig, MlpNet};

    type B = NdArray<f32>;

    /// Default device of the inference backend.
    fn device() -> <B as Backend>::Device {
        Default::default()
    }

    /// Deterministically seeded RNG so episodes are reproducible.
    fn rng() -> SmallRng {
        SmallRng::seed_from_u64(7)
    }

    // Countdown game (same as in mcts tests).
    #[derive(Clone, Debug)]
    struct CState {
        remaining: u8,
        to_move: usize,
    }

    #[derive(Clone)]
    struct CountdownEnv;

    impl GameEnv for CountdownEnv {
        type State = CState;

        fn new_game(&self) -> CState {
            CState {
                remaining: 4,
                to_move: 0,
            }
        }

        fn current_player(&self, s: &CState) -> Player {
            match (s.remaining, s.to_move) {
                (0, _) => Player::Terminal,
                (_, 0) => Player::P1,
                _ => Player::P2,
            }
        }

        fn legal_actions(&self, s: &CState) -> Vec<usize> {
            if s.remaining < 2 {
                vec![0]
            } else {
                vec![0, 1]
            }
        }

        fn apply(&self, s: &mut CState, action: usize) {
            // Action `a` removes `a + 1` from the counter.
            let sub = (action as u8) + 1;
            if s.remaining > sub {
                s.remaining -= sub;
                s.to_move = 1 - s.to_move;
            } else {
                // The counter is exhausted; the mover stays unchanged.
                s.remaining = 0;
            }
        }

        fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}

        fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
            let scaled = s.remaining as f32 / 4.0;
            vec![scaled, s.to_move as f32]
        }

        fn obs_size(&self) -> usize {
            2
        }

        fn action_space(&self) -> usize {
            2
        }

        fn returns(&self, s: &CState) -> Option<[f32; 2]> {
            // Returns exist only once the counter has reached zero; the
            // player recorded in `to_move` gets +1, the other -1.
            (s.remaining == 0).then(|| {
                let mut r = [-1.0f32; 2];
                r[s.to_move] = 1.0;
                r
            })
        }
    }

    /// Very small search budget with Dirichlet noise disabled.
    fn tiny_config() -> MctsConfig {
        MctsConfig {
            n_simulations: 5,
            c_puct: 1.5,
            dirichlet_alpha: 0.0,
            dirichlet_eps: 0.0,
            temperature: 1.0,
        }
    }

    /// Evaluator around a freshly initialised 2-input / 2-action MLP.
    fn tiny_eval() -> BurnEvaluator<B, MlpNet<B>> {
        let config = MlpConfig {
            obs_size: 2,
            action_size: 2,
            hidden_size: 8,
        };
        let model = MlpNet::<B>::new(&config, &device());
        BurnEvaluator::new(model, device())
    }

    // ── BurnEvaluator tests ───────────────────────────────────────────────

    #[test]
    fn burn_evaluator_output_shapes() {
        let eval = tiny_eval();
        let (policy, value) = eval.evaluate(&[0.5f32, 0.5]);
        assert_eq!(policy.len(), 2, "policy length should equal action_space");
        assert!(value > -1.0 && value < 1.0, "value {value} should be in (-1,1)");
    }

    // ── generate_episode tests ────────────────────────────────────────────

    #[test]
    fn episode_terminates_and_has_samples() {
        let eval = tiny_eval();
        let samples =
            generate_episode(&CountdownEnv, &eval, &tiny_config(), &|_| 1.0, &mut rng());
        assert!(!samples.is_empty(), "episode must produce at least one sample");
    }

    #[test]
    fn episode_sample_values_are_valid() {
        let eval = tiny_eval();
        let samples =
            generate_episode(&CountdownEnv, &eval, &tiny_config(), &|_| 1.0, &mut rng());
        for s in &samples {
            assert!(
                s.value == 1.0 || s.value == -1.0 || s.value == 0.0,
                "unexpected value {}",
                s.value
            );
            let sum: f32 = s.policy.iter().sum();
            assert!((sum - 1.0).abs() < 1e-4, "policy sums to {sum}");
            assert_eq!(s.obs.len(), 2);
        }
    }

    #[test]
    fn episode_with_temperature_zero() {
        let eval = tiny_eval();
        // temperature=0 means greedy; episode must still terminate
        let samples =
            generate_episode(&CountdownEnv, &eval, &tiny_config(), &|_| 0.0, &mut rng());
        assert!(!samples.is_empty());
    }
}
|
||||||
172
spiel_bot/src/alphazero/trainer.rs
Normal file
172
spiel_bot/src/alphazero/trainer.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
//! One gradient-descent training step for AlphaZero.
|
||||||
|
//!
|
||||||
|
//! The loss combines:
|
||||||
|
//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
|
||||||
|
//! - **Value loss** — mean-squared error between the predicted value and the
|
||||||
|
//! actual game outcome.
|
||||||
|
//!
|
||||||
|
//! # Backend
|
||||||
|
//!
|
||||||
|
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
||||||
|
//! Self-play uses the inner backend (`NdArray<f32>`) for zero autodiff overhead.
|
||||||
|
//! Weights are transferred between the two via [`burn::record`].
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
module::AutodiffModule,
|
||||||
|
optim::{GradientsParams, Optimizer},
|
||||||
|
prelude::ElementConversion,
|
||||||
|
tensor::{
|
||||||
|
activation::log_softmax,
|
||||||
|
backend::AutodiffBackend,
|
||||||
|
Tensor, TensorData,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::network::PolicyValueNet;
|
||||||
|
use super::replay::TrainSample;
|
||||||
|
|
||||||
|
/// Run one gradient step on `model` using `batch`.
|
||||||
|
///
|
||||||
|
/// Returns the updated model and the scalar loss value for logging.
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// - `lr` — learning rate (e.g. `1e-3`).
|
||||||
|
/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
|
||||||
|
pub fn train_step<B, N, O>(
|
||||||
|
model: N,
|
||||||
|
optimizer: &mut O,
|
||||||
|
batch: &[TrainSample],
|
||||||
|
device: &B::Device,
|
||||||
|
lr: f64,
|
||||||
|
) -> (N, f32)
|
||||||
|
where
|
||||||
|
B: AutodiffBackend,
|
||||||
|
N: PolicyValueNet<B> + AutodiffModule<B>,
|
||||||
|
O: Optimizer<N, B>,
|
||||||
|
{
|
||||||
|
assert!(!batch.is_empty(), "train_step called with empty batch");
|
||||||
|
|
||||||
|
let batch_size = batch.len();
|
||||||
|
let obs_size = batch[0].obs.len();
|
||||||
|
let action_size = batch[0].policy.len();
|
||||||
|
|
||||||
|
// ── Build input tensors ────────────────────────────────────────────────
|
||||||
|
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
|
||||||
|
let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
|
||||||
|
let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
|
||||||
|
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(obs_flat, [batch_size, obs_size]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
let policy_target = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(policy_flat, [batch_size, action_size]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
let value_target = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(value_flat, [batch_size, 1]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
|
||||||
|
// ── Forward pass ──────────────────────────────────────────────────────
|
||||||
|
let (policy_logits, value_pred) = model.forward(obs_tensor);
|
||||||
|
|
||||||
|
// ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
|
||||||
|
let log_probs = log_softmax(policy_logits, 1);
|
||||||
|
let policy_loss = (policy_target.clone().neg() * log_probs)
|
||||||
|
.sum_dim(1)
|
||||||
|
.mean();
|
||||||
|
|
||||||
|
// ── Value loss: MSE(value_pred, z) ────────────────────────────────────
|
||||||
|
let diff = value_pred - value_target;
|
||||||
|
let value_loss = (diff.clone() * diff).mean();
|
||||||
|
|
||||||
|
// ── Combined loss ─────────────────────────────────────────────────────
|
||||||
|
let loss = policy_loss + value_loss;
|
||||||
|
|
||||||
|
// Extract scalar before backward (consumes the tensor).
|
||||||
|
let loss_scalar: f32 = loss.clone().into_scalar().elem();
|
||||||
|
|
||||||
|
// ── Backward + optimizer step ─────────────────────────────────────────
|
||||||
|
let grads = loss.backward();
|
||||||
|
let grads = GradientsParams::from_grads(grads, &model);
|
||||||
|
let model = optimizer.step(lr, model, grads);
|
||||||
|
|
||||||
|
(model, loss_scalar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::AdamConfig,
    };

    use crate::network::{MlpConfig, MlpNet};
    use super::super::replay::TrainSample;

    type B = Autodiff<NdArray<f32>>;

    /// Default device of the autodiff training backend.
    fn device() -> <B as burn::tensor::backend::Backend>::Device {
        Default::default()
    }

    /// Build `n` samples that all share the observation `[0.5; obs_size]`.
    ///
    /// Sample `i` has a one-hot policy on action `i % action_size` and a value
    /// of `+1` for even `i`, `-1` for odd `i`.
    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let mut policy = vec![0.0f32; action_size];
            policy[i % action_size] = 1.0;
            let value = if i % 2 == 0 { 1.0 } else { -1.0 };
            samples.push(TrainSample {
                obs: vec![0.5f32; obs_size],
                policy,
                value,
            });
        }
        samples
    }

    /// A single step on a small batch yields a finite, strictly positive loss.
    #[test]
    fn train_step_returns_finite_loss() {
        let config = MlpConfig {
            obs_size: 4,
            action_size: 4,
            hidden_size: 16,
        };
        let model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(8, 4, 4);

        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
        assert!(loss > 0.0, "loss should be positive");
    }

    /// Ten steps on a fixed batch keep the loss finite and under a loose bound.
    ///
    /// NOTE(review): despite the name, this only checks the final loss against
    /// the threshold `3.0`; it does not compare against the initial loss.
    #[test]
    fn loss_decreases_over_steps() {
        let config = MlpConfig {
            obs_size: 4,
            action_size: 4,
            hidden_size: 32,
        };
        let mut model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        // Same batch every step — loss should decrease.
        let batch = dummy_batch(16, 4, 4);

        let mut prev_loss = f32::INFINITY;
        for _ in 0..10 {
            let (next_model, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
            assert!(loss.is_finite());
            model = next_model;
            prev_loss = loss;
        }
        // After 10 steps on fixed data, loss should be below a reasonable threshold.
        assert!(prev_loss < 3.0, "loss did not decrease: {prev_loss}");
    }

    /// A batch holding a single sample is accepted.
    #[test]
    fn train_step_batch_size_one() {
        let config = MlpConfig {
            obs_size: 2,
            action_size: 2,
            hidden_size: 8,
        };
        let model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(1, 2, 2);
        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
        assert!(loss.is_finite());
    }
}
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod alphazero;
|
||||||
pub mod env;
|
pub mod env;
|
||||||
pub mod mcts;
|
pub mod mcts;
|
||||||
pub mod network;
|
pub mod network;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue