feat(spiel_bot): network with mlp and resnet

Henri Bourcereau 2026-03-07 20:30:27 +01:00
parent df05a43022
commit d5cd4c2402
6 changed files with 543 additions and 0 deletions

@@ -0,0 +1,64 @@
//! Neural network abstractions for policy-value learning.
//!
//! # Trait
//!
//! [`PolicyValueNet<B>`] is the single trait that all network architectures
//! implement. It takes an observation tensor and returns raw policy logits
//! plus a tanh-squashed scalar value estimate.
//!
//! # Architectures
//!
//! | Module | Description | Default hidden |
//! |--------|-------------|----------------|
//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
//! | [`ResNet`] | 4-residual-block network — slower per step, stronger given longer training | 512 |
//!
//! # Backend convention
//!
//! * **Inference / self-play** — use `NdArray<f32>` (no autodiff overhead).
//! * **Training** — use `Autodiff<NdArray<f32>>` so Burn can differentiate
//! through the forward pass.
//!
//! Both modes use the exact same struct; only the type-level backend changes:
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, NdArray};
//! type InferBackend = NdArray<f32>;
//! type TrainBackend = Autodiff<NdArray<f32>>;
//!
//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
//! ```
//!
//! # Output shapes
//!
//! Given a batch of `B` observations of size `obs_size`:
//!
//! | Output | Shape | Range |
//! |--------|-------|-------|
//! | `policy_logits` | `[B, action_size]` | unbounded (raw logits) |
//! | `value` | `[B, 1]` | (-1, 1) via tanh |
//!
//! Callers are responsible for masking illegal actions in `policy_logits`
//! before applying softmax.
pub mod mlp;
pub mod resnet;
pub use mlp::{MlpConfig, MlpNet};
pub use resnet::{ResNet, ResNetConfig};
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
/// A neural network that produces a policy and a value from an observation.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
///
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
}
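The caller-side contract above (mask illegal actions to `-inf`, then softmax; value already tanh-squashed) can be sketched without Burn, using plain slices in place of tensors. The `masked_softmax` helper below is illustrative only and not part of this crate's API:

```rust
/// Softmax over `logits` with illegal actions masked to negative infinity
/// first, so they receive exactly zero probability mass.
fn masked_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
    let masked: Vec<f32> = logits
        .iter()
        .zip(legal)
        .map(|(&l, &ok)| if ok { l } else { f32::NEG_INFINITY })
        .collect();
    // Subtract the max for numerical stability before exponentiating.
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = masked.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Raw policy logits for 3 actions; action 1 is illegal in this state.
    let logits = [2.0_f32, 5.0, 1.0];
    let legal = [true, false, true];
    let probs = masked_softmax(&logits, &legal);
    assert_eq!(probs[1], 0.0); // masked action gets zero probability
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6); // valid distribution
    println!("{probs:?}");
}
```

With real Burn tensors the same effect can be had by adding a large negative constant to illegal positions of `policy_logits` before softmax; the key invariant is that masking happens after `forward` and before normalisation.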