//! Neural network abstractions for policy-value learning.
//!
//! # Trait
//!
//! [`PolicyValueNet<B>`] is the single trait that all network architectures
//! implement. It takes an observation tensor and returns raw policy logits
//! plus a tanh-squashed scalar value estimate.
//!
//! # Architectures
//!
//! | Module | Description | Default hidden |
//! |--------|-------------|----------------|
//! | [`MlpNet`] | 2-hidden-layer MLP — fast to train, good baseline | 256 |
//! | [`ResNet`] | 4-residual-block network — stronger long-term | 512 |
//!
//! # Backend convention
//!
//! * **Inference / self-play** — use `NdArray<f32>` (no autodiff overhead).
//! * **Training** — use `Autodiff<NdArray<f32>>` so Burn can differentiate
//!   through the forward pass.
//!
//! Both modes use the exact same struct; only the type-level backend changes:
//!
//! ```rust,ignore
//! use burn::backend::{Autodiff, NdArray};
//! type InferBackend = NdArray<f32>;
//! type TrainBackend = Autodiff<NdArray<f32>>;
//!
//! let infer_net = MlpNet::<InferBackend>::new(&MlpConfig::default(), &Default::default());
//! let train_net = MlpNet::<TrainBackend>::new(&MlpConfig::default(), &Default::default());
//! ```
//!
//! # Output shapes
//!
//! Given a batch of `B` observations of size `obs_size`:
//!
//! | Output | Shape | Range |
//! |--------|-------|-------|
//! | `policy_logits` | `[B, action_size]` | ℝ (unnormalised) |
//! | `value` | `[B, 1]` | (-1, 1) via tanh |
//!
//! Callers are responsible for masking illegal actions in `policy_logits`
//! before passing to softmax.
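//!
//! A minimal masking sketch on plain `f32` slices, independent of Burn
//! (`masked_softmax` is an illustrative helper, not part of this crate):
//!
//! ```
//! fn masked_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
//!     // Illegal actions get -inf so they receive exactly zero probability.
//!     let masked: Vec<f32> = logits
//!         .iter()
//!         .zip(legal)
//!         .map(|(&l, &ok)| if ok { l } else { f32::NEG_INFINITY })
//!         .collect();
//!     // Subtract the max for numerical stability before exponentiating.
//!     let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
//!     let exps: Vec<f32> = masked.iter().map(|&l| (l - max).exp()).collect();
//!     let sum: f32 = exps.iter().sum();
//!     exps.iter().map(|&e| e / sum).collect()
//! }
//!
//! let p = masked_softmax(&[1.0, 2.0, 3.0], &[true, false, true]);
//! assert_eq!(p[1], 0.0);
//! assert!((p.iter().sum::<f32>() - 1.0).abs() < 1e-6);
//! ```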

pub mod mlp;
pub mod resnet;

pub use mlp::{MlpConfig, MlpNet};
pub use resnet::{ResNet, ResNetConfig};

use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};

/// A neural network that produces a policy and a value from an observation.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
///
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
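///
/// A hypothetical sharing sketch (`net` and its construction are assumed,
/// not defined here; `Mutex<N>` is `Sync` whenever `N` is `Send`):
///
/// ```rust,ignore
/// use std::sync::{Arc, Mutex};
/// use std::thread;
///
/// // Wrap the !Sync network so it can be shared across threads.
/// let shared = Arc::new(Mutex::new(net));
/// let worker = Arc::clone(&shared);
/// thread::spawn(move || {
///     let net = worker.lock().unwrap();
///     // run inference behind the lock...
/// });
/// ```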
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
}