//! Two-hidden-layer MLP policy-value network. //! //! ```text //! Input [B, obs_size] //! → Linear(obs → hidden) → ReLU //! → Linear(hidden → hidden) → ReLU //! ├─ policy_head: Linear(hidden → action_size) [raw logits] //! └─ value_head: Linear(hidden → 1) → tanh [∈ (-1, 1)] //! ``` use burn::{ module::Module, nn::{Linear, LinearConfig}, record::{CompactRecorder, Recorder}, tensor::{ activation::{relu, tanh}, backend::Backend, Tensor, }, }; use std::path::Path; use super::PolicyValueNet; // ── Config ──────────────────────────────────────────────────────────────────── /// Configuration for [`MlpNet`]. #[derive(Debug, Clone)] pub struct MlpConfig { /// Number of input features. 217 for Trictrac's `to_tensor()`. pub obs_size: usize, /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. pub action_size: usize, /// Width of both hidden layers. pub hidden_size: usize, } impl Default for MlpConfig { fn default() -> Self { Self { obs_size: 217, action_size: 514, hidden_size: 256, } } } // ── Network ─────────────────────────────────────────────────────────────────── /// Simple two-hidden-layer MLP with shared trunk and two heads. /// /// Prefer this over [`ResNet`](super::ResNet) when training time is a /// priority, or as a fast baseline. #[derive(Module, Debug)] pub struct MlpNet { fc1: Linear, fc2: Linear, policy_head: Linear, value_head: Linear, } impl MlpNet { /// Construct a fresh network with random weights. pub fn new(config: &MlpConfig, device: &B::Device) -> Self { Self { fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), policy_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), value_head: LinearConfig::new(config.hidden_size, 1).init(device), } } /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). /// /// The file is written exactly at `path`; callers should append `.mpk` if /// they want the conventional extension. pub fn save(&self, path: &Path) -> anyhow::Result<()> { CompactRecorder::new() .record(self.clone().into_record(), path.to_path_buf()) .map_err(|e| anyhow::anyhow!("MlpNet::save failed: {e:?}")) } /// Load weights from `path` into a fresh model built from `config`. pub fn load(config: &MlpConfig, path: &Path, device: &B::Device) -> anyhow::Result { let record = CompactRecorder::new() .load(path.to_path_buf(), device) .map_err(|e| anyhow::anyhow!("MlpNet::load failed: {e:?}"))?; Ok(Self::new(config, device).load_record(record)) } } impl PolicyValueNet for MlpNet { fn forward(&self, obs: Tensor) -> (Tensor, Tensor) { let x = relu(self.fc1.forward(obs)); let x = relu(self.fc2.forward(x)); let policy = self.policy_head.forward(x.clone()); let value = tanh(self.value_head.forward(x)); (policy, value) } } // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] mod tests { use super::*; use burn::backend::NdArray; type B = NdArray; fn device() -> ::Device { Default::default() } fn default_net() -> MlpNet { MlpNet::new(&MlpConfig::default(), &device()) } fn zeros_obs(batch: usize) -> Tensor { Tensor::zeros([batch, 217], &device()) } // ── Shape tests ─────────────────────────────────────────────────────── #[test] fn forward_output_shapes() { let net = default_net(); let obs = zeros_obs(4); let (policy, value) = net.forward(obs); assert_eq!(policy.dims(), [4, 514], "policy shape mismatch"); assert_eq!(value.dims(), [4, 1], "value shape mismatch"); } #[test] fn forward_single_sample() { let net = default_net(); let (policy, value) = net.forward(zeros_obs(1)); assert_eq!(policy.dims(), [1, 514]); assert_eq!(value.dims(), [1, 1]); } // ── Value bounds ────────────────────────────────────────────────────── #[test] fn value_in_tanh_range() { let net = default_net(); // Use a non-zero input so the output is not trivially at 0. let obs = Tensor::::ones([8, 217], &device()); let (_, value) = net.forward(obs); let data: Vec = value.into_data().to_vec().unwrap(); for v in &data { assert!( *v > -1.0 && *v < 1.0, "value {v} is outside open interval (-1, 1)" ); } } // ── Policy logits ───────────────────────────────────────────────────── #[test] fn policy_logits_not_all_equal() { // With random weights the 514 logits should not all be identical. let net = default_net(); let (policy, _) = net.forward(zeros_obs(1)); let data: Vec = policy.into_data().to_vec().unwrap(); let first = data[0]; let all_same = data.iter().all(|&x| (x - first).abs() < 1e-6); assert!(!all_same, "all policy logits are identical — network may be degenerate"); } // ── Config propagation ──────────────────────────────────────────────── #[test] fn custom_config_shapes() { let config = MlpConfig { obs_size: 10, action_size: 20, hidden_size: 32, }; let net = MlpNet::::new(&config, &device()); let obs = Tensor::zeros([3, 10], &device()); let (policy, value) = net.forward(obs); assert_eq!(policy.dims(), [3, 20]); assert_eq!(value.dims(), [3, 1]); } // ── Save / Load ─────────────────────────────────────────────────────── #[test] fn save_load_preserves_weights() { let config = MlpConfig::default(); let net = default_net(); // Forward pass before saving. let obs = Tensor::::ones([2, 217], &device()); let (policy_before, value_before) = net.forward(obs.clone()); // Save to a temp file. let path = std::env::temp_dir().join("spiel_bot_test_mlp.mpk"); net.save(&path).expect("save failed"); // Load into a fresh model. let loaded = MlpNet::::load(&config, &path, &device()).expect("load failed"); let (policy_after, value_after) = loaded.forward(obs); // Outputs must be bitwise identical. let p_before: Vec = policy_before.into_data().to_vec().unwrap(); let p_after: Vec = policy_after.into_data().to_vec().unwrap(); for (i, (a, b)) in p_before.iter().zip(p_after.iter()).enumerate() { assert!((a - b).abs() < 1e-3, "policy[{i}]: {a} vs {b} differ by more than tolerance"); } let v_before: Vec = value_before.into_data().to_vec().unwrap(); let v_after: Vec = value_after.into_data().to_vec().unwrap(); for (i, (a, b)) in v_before.iter().zip(v_after.iter()).enumerate() { assert!((a - b).abs() < 1e-3, "value[{i}]: {a} vs {b} differ by more than tolerance"); } let _ = std::fs::remove_file(path); } }