feat(spiel_bot): AlphaZero
This commit is contained in:
parent
58ae8ad3b3
commit
b0ae4db2d9
5 changed files with 668 additions and 0 deletions
172
spiel_bot/src/alphazero/trainer.rs
Normal file
172
spiel_bot/src/alphazero/trainer.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
//! One gradient-descent training step for AlphaZero.
|
||||
//!
|
||||
//! The loss combines:
|
||||
//! - **Policy loss** — cross-entropy between MCTS visit counts and network logits.
|
||||
//! - **Value loss** — mean-squared error between the predicted value and the
|
||||
//! actual game outcome.
|
||||
//!
|
||||
//! # Backend
|
||||
//!
|
||||
//! `train_step` requires an `AutodiffBackend` (e.g. `Autodiff<NdArray<f32>>`).
|
||||
//! Self-play uses the inner backend (`NdArray<f32>`) for zero autodiff overhead.
|
||||
//! Weights are transferred between the two via [`burn::record`].
|
||||
|
||||
use burn::{
|
||||
module::AutodiffModule,
|
||||
optim::{GradientsParams, Optimizer},
|
||||
prelude::ElementConversion,
|
||||
tensor::{
|
||||
activation::log_softmax,
|
||||
backend::AutodiffBackend,
|
||||
Tensor, TensorData,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::network::PolicyValueNet;
|
||||
use super::replay::TrainSample;
|
||||
|
||||
/// Run one gradient step on `model` using `batch`.
|
||||
///
|
||||
/// Returns the updated model and the scalar loss value for logging.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `lr` — learning rate (e.g. `1e-3`).
|
||||
/// - `batch` — slice of [`TrainSample`]s; must be non-empty.
|
||||
pub fn train_step<B, N, O>(
|
||||
model: N,
|
||||
optimizer: &mut O,
|
||||
batch: &[TrainSample],
|
||||
device: &B::Device,
|
||||
lr: f64,
|
||||
) -> (N, f32)
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
N: PolicyValueNet<B> + AutodiffModule<B>,
|
||||
O: Optimizer<N, B>,
|
||||
{
|
||||
assert!(!batch.is_empty(), "train_step called with empty batch");
|
||||
|
||||
let batch_size = batch.len();
|
||||
let obs_size = batch[0].obs.len();
|
||||
let action_size = batch[0].policy.len();
|
||||
|
||||
// ── Build input tensors ────────────────────────────────────────────────
|
||||
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
|
||||
let policy_flat: Vec<f32> = batch.iter().flat_map(|s| s.policy.iter().copied()).collect();
|
||||
let value_flat: Vec<f32> = batch.iter().map(|s| s.value).collect();
|
||||
|
||||
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(obs_flat, [batch_size, obs_size]),
|
||||
device,
|
||||
);
|
||||
let policy_target = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(policy_flat, [batch_size, action_size]),
|
||||
device,
|
||||
);
|
||||
let value_target = Tensor::<B, 2>::from_data(
|
||||
TensorData::new(value_flat, [batch_size, 1]),
|
||||
device,
|
||||
);
|
||||
|
||||
// ── Forward pass ──────────────────────────────────────────────────────
|
||||
let (policy_logits, value_pred) = model.forward(obs_tensor);
|
||||
|
||||
// ── Policy loss: -sum(π_mcts · log_softmax(logits)) ──────────────────
|
||||
let log_probs = log_softmax(policy_logits, 1);
|
||||
let policy_loss = (policy_target.clone().neg() * log_probs)
|
||||
.sum_dim(1)
|
||||
.mean();
|
||||
|
||||
// ── Value loss: MSE(value_pred, z) ────────────────────────────────────
|
||||
let diff = value_pred - value_target;
|
||||
let value_loss = (diff.clone() * diff).mean();
|
||||
|
||||
// ── Combined loss ─────────────────────────────────────────────────────
|
||||
let loss = policy_loss + value_loss;
|
||||
|
||||
// Extract scalar before backward (consumes the tensor).
|
||||
let loss_scalar: f32 = loss.clone().into_scalar().elem();
|
||||
|
||||
// ── Backward + optimizer step ─────────────────────────────────────────
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
let model = optimizer.step(lr, model, grads);
|
||||
|
||||
(model, loss_scalar)
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::AdamConfig,
    };

    use crate::network::{MlpConfig, MlpNet};
    use super::super::replay::TrainSample;

    type B = Autodiff<NdArray<f32>>;

    /// Default device for the autodiff-NdArray test backend.
    fn device() -> <B as burn::tensor::backend::Backend>::Device {
        Default::default()
    }

    /// Build `n` synthetic samples: constant observations, a one-hot policy
    /// rotating through the action space, and alternating ±1 outcomes.
    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<TrainSample> {
        (0..n)
            .map(|i| TrainSample {
                obs: vec![0.5f32; obs_size],
                policy: {
                    let mut p = vec![0.0f32; action_size];
                    p[i % action_size] = 1.0;
                    p
                },
                value: if i % 2 == 0 { 1.0 } else { -1.0 },
            })
            .collect()
    }

    #[test]
    fn train_step_returns_finite_loss() {
        let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
        let model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(8, 4, 4);

        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
        assert!(loss > 0.0, "loss should be positive");
    }

    #[test]
    fn loss_decreases_over_steps() {
        let config = MlpConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
        let mut model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        // Same batch every step — loss should decrease.
        let batch = dummy_batch(16, 4, 4);

        // Fix: the original only checked a fixed threshold and never actually
        // compared losses, despite the test name and failure message. Record
        // the first-step loss and require the final loss to beat it.
        let mut first_loss = None;
        let mut last_loss = f32::INFINITY;
        for _ in 0..10 {
            let (m, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-2);
            model = m;
            assert!(loss.is_finite());
            first_loss.get_or_insert(loss);
            last_loss = loss;
        }
        let first = first_loss.expect("loop ran at least once");
        assert!(
            last_loss < first,
            "loss did not decrease: {first} -> {last_loss}"
        );
        // After 10 steps on fixed data, loss should also be below a reasonable threshold.
        assert!(last_loss < 3.0, "final loss too high: {last_loss}");
    }

    #[test]
    fn train_step_batch_size_one() {
        let config = MlpConfig { obs_size: 2, action_size: 2, hidden_size: 8 };
        let model = MlpNet::<B>::new(&config, &device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(1, 2, 2);
        let (_, loss) = train_step(model, &mut optimizer, &batch, &device(), 1e-3);
        assert!(loss.is_finite());
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue