diff --git a/Cargo.lock b/Cargo.lock
index d1f5a20..2e81285 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5896,6 +5896,7 @@ name = "spiel_bot"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "burn",
  "rand 0.9.2",
  "trictrac-store",
 ]
diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml
index 2459f51..fba2aab 100644
--- a/spiel_bot/Cargo.toml
+++ b/spiel_bot/Cargo.toml
@@ -7,3 +7,4 @@ edition = "2021"
 trictrac-store = { path = "../store" }
 anyhow = "1"
 rand = "0.9"
+burn = { version = "0.20", features = ["ndarray", "autodiff"] }
diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs
index 3d7924f..6e71016 100644
--- a/spiel_bot/src/lib.rs
+++ b/spiel_bot/src/lib.rs
@@ -1 +1,2 @@
 pub mod env;
+pub mod network;
diff --git a/spiel_bot/src/network/mlp.rs b/spiel_bot/src/network/mlp.rs
new file mode 100644
index 0000000..eb6184e
--- /dev/null
+++ b/spiel_bot/src/network/mlp.rs
@@ -0,0 +1,223 @@
+//! Two-hidden-layer MLP policy-value network.
+//!
+//! ```text
+//! Input [B, obs_size]
+//!   → Linear(obs → hidden) → ReLU
+//!   → Linear(hidden → hidden) → ReLU
+//!   ├─ policy_head: Linear(hidden → action_size)   [raw logits]
+//!   └─ value_head:  Linear(hidden → 1) → tanh      [∈ (-1, 1)]
+//! ```
+
+use burn::{
+    module::Module,
+    nn::{Linear, LinearConfig},
+    record::{CompactRecorder, Recorder},
+    tensor::{
+        activation::{relu, tanh},
+        backend::Backend,
+        Tensor,
+    },
+};
+use std::path::Path;
+
+use super::PolicyValueNet;
+
+// ── Config ──────────────────────────────────────────────────────────────────
+
+/// Configuration for [`MlpNet`].
+#[derive(Debug, Clone)]
+pub struct MlpConfig {
+    /// Number of input features. 217 for Trictrac's `to_tensor()`.
+    pub obs_size: usize,
+    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
+    pub action_size: usize,
+    /// Width of both hidden layers.
+    pub hidden_size: usize,
+}
+
+impl Default for MlpConfig {
+    fn default() -> Self {
+        Self {
+            obs_size: 217,
+            action_size: 514,
+            hidden_size: 256,
+        }
+    }
+}
+
+// ── Network ─────────────────────────────────────────────────────────────────
+
+/// Simple two-hidden-layer MLP with shared trunk and two heads.
+///
+/// Prefer this over [`ResNet`](super::ResNet) when training time is a
+/// priority, or as a fast baseline.
+#[derive(Module, Debug)]
+pub struct MlpNet<B: Backend> {
+    fc1: Linear<B>,
+    fc2: Linear<B>,
+    policy_head: Linear<B>,
+    value_head: Linear<B>,
+}
+
+impl<B: Backend> MlpNet<B> {
+    /// Construct a fresh network with random weights.
+    pub fn new(config: &MlpConfig, device: &B::Device) -> Self {
+        Self {
+            fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
+            fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
+            policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
+            value_head: LinearConfig::new(config.hidden_size, 1).init(device),
+        }
+    }
+
+    /// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
+    ///
+    /// The file is written exactly at `path`; callers should append `.mpk` if
+    /// they want the conventional extension.
+    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
+        CompactRecorder::new()
+            .record(self.clone().into_record(), path.to_path_buf())
+            .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}"))
+    }
+
+    /// Load weights from `path` into a fresh model built from `config`.
+    pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
+        let record = CompactRecorder::new()
+            .load(path.to_path_buf(), device)
+            .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?;
+        Ok(Self::new(config, device).load_record(record))
+    }
+}
+
+impl<B: Backend> PolicyValueNet<B> for MlpNet<B> {
+    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
+        let x = relu(self.fc1.forward(obs));
+        let x = relu(self.fc2.forward(x));
+        let policy = self.policy_head.forward(x.clone());
+        let value = tanh(self.value_head.forward(x));
+        (policy, value)
+    }
+}
+
+// ── Tests ───────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use burn::backend::NdArray;
+
+    type B = NdArray;
+
+    fn device() -> <B as Backend>::Device {
+        Default::default()
+    }
+
+    fn default_net() -> MlpNet<B> {
+        MlpNet::new(&MlpConfig::default(), &device())
+    }
+
+    fn zeros_obs(batch: usize) -> Tensor<B, 2> {
+        Tensor::zeros([batch, 217], &device())
+    }
+
+    // ── Shape tests ─────────────────────────────────────────────────────
+
+    #[test]
+    fn forward_output_shapes() {
+        let net = default_net();
+        let obs = zeros_obs(4);
+        let (policy, value) = net.forward(obs);
+
+        assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
+        assert_eq!(value.dims(), [4, 1], "value shape mismatch");
+    }
+
+    #[test]
+    fn forward_single_sample() {
+        let net = default_net();
+        let (policy, value) = net.forward(zeros_obs(1));
+        assert_eq!(policy.dims(), [1, 514]);
+        assert_eq!(value.dims(), [1, 1]);
+    }
+
+    // ── Value bounds ────────────────────────────────────────────────────
+
+    #[test]
+    fn value_in_tanh_range() {
+        let net = default_net();
+        // Use a non-zero input so the output is not trivially at 0.
+        let obs = Tensor::<B, 2>::ones([8, 217], &device());
+        let (_, value) = net.forward(obs);
+        let data: Vec<f32> = value.into_data().to_vec().unwrap();
+        for v in &data {
+            assert!(
+                *v > -1.0 && *v < 1.0,
+                "value {v} is outside open interval (-1, 1)"
+            );
+        }
+    }
+
+    // ── Policy logits ───────────────────────────────────────────────────
+
+    #[test]
+    fn policy_logits_not_all_equal() {
+        // With random weights the 514 logits should not all be identical.
+        let net = default_net();
+        let (policy, _) = net.forward(zeros_obs(1));
+        let data: Vec<f32> = policy.into_data().to_vec().unwrap();
+        let first = data[0];
+        let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
+        assert!(!all_same, "all policy logits are identical — network may be degenerate");
+    }
+
+    // ── Config propagation ──────────────────────────────────────────────
+
+    #[test]
+    fn custom_config_shapes() {
+        let config = MlpConfig {
+            obs_size: 10,
+            action_size: 20,
+            hidden_size: 32,
+        };
+        let net = MlpNet::<B>::new(&config, &device());
+        let obs = Tensor::zeros([3, 10], &device());
+        let (policy, value) = net.forward(obs);
+        assert_eq!(policy.dims(), [3, 20]);
+        assert_eq!(value.dims(), [3, 1]);
+    }
+
+    // ── Save / Load ─────────────────────────────────────────────────────
+
+    #[test]
+    fn save_load_preserves_weights() {
+        let config = MlpConfig::default();
+        let net = default_net();
+
+        // Forward pass before saving.
+        let obs = Tensor::<B, 2>::ones([2, 217], &device());
+        let (policy_before, value_before) = net.forward(obs.clone());
+
+        // Save to a temp file.
+        let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk");
+        net.save(&path).expect("save failed");
+
+        // Load into a fresh model.
+        let loaded = MlpNet::<B>::load(&config, &path, &device()).expect("load failed");
+        let (policy_after, value_after) = loaded.forward(obs);
+
+        // Outputs must be bitwise identical.
+        let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
+        let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
+        let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let _ = std::fs::remove_file(path);
+    }
+}
diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs
new file mode 100644
index 0000000..df710e9
--- /dev/null
+++ b/spiel_bot/src/network/mod.rs
@@ -0,0 +1,64 @@
+//! Neural network abstractions for policy-value learning.
+//!
+//! # Trait
+//!
+//! [`PolicyValueNet`] is the single trait that all network architectures
+//! implement. It takes an observation tensor and returns raw policy logits
+//! plus a tanh-squashed scalar value estimate.
+//!
+//! # Architectures
+//!
+//! | Module | Description | Default hidden |
+//! |--------|-------------|----------------|
+//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
+//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
+//!
+//! # Backend convention
+//!
+//! * **Inference / self-play** — use `NdArray` (no autodiff overhead).
+//! * **Training** — use `Autodiff<NdArray>` so Burn can differentiate
+//!   through the forward pass.
+//!
+//! Both modes use the exact same struct; only the type-level backend changes:
+//!
+//! ```rust,ignore
+//! use burn::backend::{Autodiff, NdArray};
+//! type InferBackend = NdArray;
+//! type TrainBackend = Autodiff<NdArray>;
+//!
+//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
+//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
+//! ```
+//!
+//! # Output shapes
+//!
+//! Given a batch of `B` observations of size `obs_size`:
+//!
+//! | Output | Shape | Range |
+//! |--------|-------|-------|
+//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) |
+//! | `value` | `[B, 1]` | (-1, 1) via tanh |
+//!
+//! Callers are responsible for masking illegal actions in `policy_logits`
+//! before passing to softmax.
+
+pub mod mlp;
+pub mod resnet;
+
+pub use mlp::{MlpConfig, MlpNet};
+pub use resnet::{ResNet, ResNetConfig};
+
+use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
+
+/// A neural network that produces a policy and a value from an observation.
+///
+/// # Shapes
+/// - `obs`: `[batch, obs_size]`
+/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
+/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
+///
+/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
+/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
+/// Use an `Arc<Mutex<_>>` wrapper if cross-thread sharing is needed.
+pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
+    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
+}
diff --git a/spiel_bot/src/network/resnet.rs b/spiel_bot/src/network/resnet.rs
new file mode 100644
index 0000000..d20d5ad
--- /dev/null
+++ b/spiel_bot/src/network/resnet.rs
@@ -0,0 +1,253 @@
+//! Residual-block policy-value network.
+//!
+//! ```text
+//! Input [B, obs_size]
+//!   → Linear(obs → hidden) → ReLU   (input projection)
+//!   → ResBlock × 4                  (residual trunk)
+//!   ├─ policy_head: Linear(hidden → action_size)   [raw logits]
+//!   └─ value_head:  Linear(hidden → 1) → tanh      [∈ (-1, 1)]
+//!
+//! ResBlock:
+//!   x → Linear → ReLU → Linear → (+x) → ReLU
+//! ```
+//!
+//! Compared to [`MlpNet`](super::MlpNet) this network is deeper and better
+//! suited for long training runs where board-pattern recognition matters.
+
+use burn::{
+    module::Module,
+    nn::{Linear, LinearConfig},
+    record::{CompactRecorder, Recorder},
+    tensor::{
+        activation::{relu, tanh},
+        backend::Backend,
+        Tensor,
+    },
+};
+use std::path::Path;
+
+use super::PolicyValueNet;
+
+// ── Config ──────────────────────────────────────────────────────────────────
+
+/// Configuration for [`ResNet`].
+#[derive(Debug, Clone)]
+pub struct ResNetConfig {
+    /// Number of input features. 217 for Trictrac's `to_tensor()`.
+    pub obs_size: usize,
+    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
+    pub action_size: usize,
+    /// Width of all hidden layers (input projection + residual blocks).
+    pub hidden_size: usize,
+}
+
+impl Default for ResNetConfig {
+    fn default() -> Self {
+        Self {
+            obs_size: 217,
+            action_size: 514,
+            hidden_size: 512,
+        }
+    }
+}
+
+// ── Residual block ──────────────────────────────────────────────────────────
+
+/// A single residual block: `x ↦ ReLU(fc2(ReLU(fc1(x))) + x)`.
+///
+/// Both linear layers preserve the hidden dimension so the skip connection
+/// can be added without projection.
+#[derive(Module, Debug)]
+struct ResBlock<B: Backend> {
+    fc1: Linear<B>,
+    fc2: Linear<B>,
+}
+
+impl<B: Backend> ResBlock<B> {
+    fn new(hidden: usize, device: &B::Device) -> Self {
+        Self {
+            fc1: LinearConfig::new(hidden, hidden).init(device),
+            fc2: LinearConfig::new(hidden, hidden).init(device),
+        }
+    }
+
+    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
+        let residual = x.clone();
+        let out = relu(self.fc1.forward(x));
+        relu(self.fc2.forward(out) + residual)
+    }
+}
+
+// ── Network ─────────────────────────────────────────────────────────────────
+
+/// Four-residual-block policy-value network.
+///
+/// Prefer this over [`MlpNet`](super::MlpNet) for longer training runs and
+/// when representing complex positional patterns is important.
+#[derive(Module, Debug)]
+pub struct ResNet<B: Backend> {
+    input: Linear<B>,
+    block0: ResBlock<B>,
+    block1: ResBlock<B>,
+    block2: ResBlock<B>,
+    block3: ResBlock<B>,
+    policy_head: Linear<B>,
+    value_head: Linear<B>,
+}
+
+impl<B: Backend> ResNet<B> {
+    /// Construct a fresh network with random weights.
+    pub fn new(config: &ResNetConfig, device: &B::Device) -> Self {
+        let h = config.hidden_size;
+        Self {
+            input: LinearConfig::new(config.obs_size, h).init(device),
+            block0: ResBlock::new(h, device),
+            block1: ResBlock::new(h, device),
+            block2: ResBlock::new(h, device),
+            block3: ResBlock::new(h, device),
+            policy_head: LinearConfig::new(h, config.action_size).init(device),
+            value_head: LinearConfig::new(h, 1).init(device),
+        }
+    }
+
+    /// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
+    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
+        CompactRecorder::new()
+            .record(self.clone().into_record(), path.to_path_buf())
+            .map_err(|e| anyhow::anyhow!("ResNet::save failed: {e:?}"))
+    }
+
+    /// Load weights from `path` into a fresh model built from `config`.
+    pub fn load(config: &ResNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
+        let record = CompactRecorder::new()
+            .load(path.to_path_buf(), device)
+            .map_err(|e| anyhow::anyhow!("ResNet::load failed: {e:?}"))?;
+        Ok(Self::new(config, device).load_record(record))
+    }
+}
+
+impl<B: Backend> PolicyValueNet<B> for ResNet<B> {
+    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
+        let x = relu(self.input.forward(obs));
+        let x = self.block0.forward(x);
+        let x = self.block1.forward(x);
+        let x = self.block2.forward(x);
+        let x = self.block3.forward(x);
+        let policy = self.policy_head.forward(x.clone());
+        let value = tanh(self.value_head.forward(x));
+        (policy, value)
+    }
+}
+
+// ── Tests ───────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use burn::backend::NdArray;
+
+    type B = NdArray;
+
+    fn device() -> <B as Backend>::Device {
+        Default::default()
+    }
+
+    fn small_config() -> ResNetConfig {
+        // Use a small hidden size so tests are fast.
+        ResNetConfig {
+            obs_size: 217,
+            action_size: 514,
+            hidden_size: 64,
+        }
+    }
+
+    fn net() -> ResNet<B> {
+        ResNet::new(&small_config(), &device())
+    }
+
+    // ── Shape tests ─────────────────────────────────────────────────────
+
+    #[test]
+    fn forward_output_shapes() {
+        let obs = Tensor::zeros([4, 217], &device());
+        let (policy, value) = net().forward(obs);
+        assert_eq!(policy.dims(), [4, 514], "policy shape mismatch");
+        assert_eq!(value.dims(), [4, 1], "value shape mismatch");
+    }
+
+    #[test]
+    fn forward_single_sample() {
+        let (policy, value) = net().forward(Tensor::zeros([1, 217], &device()));
+        assert_eq!(policy.dims(), [1, 514]);
+        assert_eq!(value.dims(), [1, 1]);
+    }
+
+    // ── Value bounds ────────────────────────────────────────────────────
+
+    #[test]
+    fn value_in_tanh_range() {
+        let obs = Tensor::<B, 2>::ones([8, 217], &device());
+        let (_, value) = net().forward(obs);
+        let data: Vec<f32> = value.into_data().to_vec().unwrap();
+        for v in &data {
+            assert!(
+                *v > -1.0 && *v < 1.0,
+                "value {v} is outside open interval (-1, 1)"
+            );
+        }
+    }
+
+    // ── Residual connections ────────────────────────────────────────────
+
+    #[test]
+    fn policy_logits_not_all_equal() {
+        let (policy, _) = net().forward(Tensor::zeros([1, 217], &device()));
+        let data: Vec<f32> = policy.into_data().to_vec().unwrap();
+        let first = data[0];
+        let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6);
+        assert!(!all_same, "all policy logits are identical");
+    }
+
+    // ── Save / Load ─────────────────────────────────────────────────────
+
+    #[test]
+    fn save_load_preserves_weights() {
+        let config = small_config();
+        let model = net();
+        let obs = Tensor::<B, 2>::ones([2, 217], &device());
+
+        let (policy_before, value_before) = model.forward(obs.clone());
+
+        let path = std::env::temp_dir().join("spiel_bot_test_resnet.mpk");
+        model.save(&path).expect("save failed");
+
+        let loaded = ResNet::<B>::load(&config, &path, &device()).expect("load failed");
+        let (policy_after, value_after) = loaded.forward(obs);
+
+        let p_before: Vec<f32> = policy_before.into_data().to_vec().unwrap();
+        let p_after: Vec<f32> = policy_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let v_before: Vec<f32> = value_before.into_data().to_vec().unwrap();
+        let v_after: Vec<f32> = value_after.into_data().to_vec().unwrap();
+        for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() {
+            assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance");
+        }
+
+        let _ = std::fs::remove_file(path);
+    }
+
+    // ── Integration: both architectures satisfy PolicyValueNet ──────────
+
+    #[test]
+    fn resnet_satisfies_trait() {
+        fn requires_net<N: PolicyValueNet<B>>(net: &N, obs: Tensor<B, 2>) {
+            let (p, v) = net.forward(obs);
+            assert_eq!(p.dims()[1], 514);
+            assert_eq!(v.dims()[1], 1);
+        }
+        requires_net(&net(), Tensor::zeros([2, 217], &device()));
+    }
+}