//! Neural network abstractions for policy-value learning. //! //! # Trait //! //! [`PolicyValueNet`] is the single trait that all network architectures //! implement. It takes an observation tensor and returns raw policy logits //! plus a tanh-squashed scalar value estimate. //! //! # Architectures //! //! | Module | Description | Default hidden | //! |--------|-------------|----------------| //! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 | //! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 | //! //! # Backend convention //! //! * **Inference / self-play** — use `NdArray` (no autodiff overhead). //! * **Training** — use `Autodiff>` so Burn can differentiate //! through the forward pass. //! //! Both modes use the exact same struct; only the type-level backend changes: //! //! ```rust,ignore //! use burn::backend::{Autodiff, NdArray}; //! type InferBackend = NdArray; //! type TrainBackend = Autodiff>; //! //! let infer_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); //! let train_net = MlpNet::::new(&MlpConfig::default(), &Default::default()); //! ``` //! //! # Output shapes //! //! Given a batch of `B` observations of size `obs_size`: //! //! | Output | Shape | Range | //! |--------|-------|-------| //! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) | //! | `value` | `[B, 1]` | (-1, 1) via tanh | //! //! Callers are responsible for masking illegal actions in `policy_logits` //! before passing to softmax. pub mod mlp; pub mod resnet; pub use mlp::{MlpConfig, MlpNet}; pub use resnet::{ResNet, ResNetConfig}; use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; /// A neural network that produces a policy and a value from an observation. /// /// # Shapes /// - `obs`: `[batch, obs_size]` /// - policy output: `[batch, action_size]` — raw logits (no softmax applied) /// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) /// Note: `Sync` is intentionally absent — Burn's `Module` internally uses /// `OnceCell` for lazy parameter initialisation, which is not `Sync`. /// Use an `Arc>` wrapper if cross-thread sharing is needed. pub trait PolicyValueNet: Module + Send + 'static { fn forward(&self, obs: Tensor) -> (Tensor, Tensor); }