feat(spiel_bot): dqn
This commit is contained in:
parent
7c0f230e3d
commit
e7d13c9a02
9 changed files with 1192 additions and 0 deletions
|
|
@ -43,9 +43,11 @@
|
|||
//! before passing to softmax.
|
||||
|
||||
pub mod mlp;
|
||||
pub mod qnet;
|
||||
pub mod resnet;
|
||||
|
||||
pub use mlp::{MlpConfig, MlpNet};
|
||||
pub use qnet::{QNet, QNetConfig};
|
||||
pub use resnet::{ResNet, ResNetConfig};
|
||||
|
||||
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
||||
|
|
@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
|||
/// - `obs`: `[batch, obs_size]`
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
///
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
    /// Forward pass over a batch of observations.
    ///
    /// Returns `(policy_logits, value)` — see the shape notes above.
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
}
|
||||
|
||||
/// A neural network that outputs one Q-value per action.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - output: `[batch, action_size]` — raw Q-values (no activation)
///
/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
    /// Forward pass: maps a batch of observations to one Q-value per action.
    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
}
|
||||
|
|
|
|||
147
spiel_bot/src/network/qnet.rs
Normal file
147
spiel_bot/src/network/qnet.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
//! Single-headed Q-value network for DQN.
|
||||
//!
|
||||
//! ```text
|
||||
//! Input [B, obs_size]
|
||||
//! → Linear(obs → hidden) → ReLU
|
||||
//! → Linear(hidden → hidden) → ReLU
|
||||
//! → Linear(hidden → action_size) ← raw Q-values, no activation
|
||||
//! ```
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Linear, LinearConfig},
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::{activation::relu, backend::Backend, Tensor},
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
use super::QValueNet;
|
||||
|
||||
// ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for [`QNet`].
#[derive(Debug, Clone)]
pub struct QNetConfig {
    /// Number of input features. 217 for Trictrac's `to_tensor()`.
    pub obs_size: usize,
    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
    pub action_size: usize,
    /// Width of both hidden layers.
    pub hidden_size: usize,
}

impl Default for QNetConfig {
    /// Default layout matching Trictrac: 217 inputs, 514 actions, 256-wide hidden layers.
    fn default() -> Self {
        QNetConfig {
            obs_size: 217,
            action_size: 514,
            hidden_size: 256,
        }
    }
}
|
||||
|
||||
// ── Network ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Two-hidden-layer MLP that outputs one Q-value per action.
///
/// Construct with [`QNet::new`] from a [`QNetConfig`].
#[derive(Module, Debug)]
pub struct QNet<B: Backend> {
    /// First hidden layer: `obs_size → hidden_size`.
    fc1: Linear<B>,
    /// Second hidden layer: `hidden_size → hidden_size`.
    fc2: Linear<B>,
    /// Output head: `hidden_size → action_size`, raw Q-values (no activation).
    q_head: Linear<B>,
}
|
||||
|
||||
impl<B: Backend> QNet<B> {
|
||||
/// Construct a fresh network with random weights.
|
||||
pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
|
||||
Self {
|
||||
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
|
||||
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
|
||||
q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||
CompactRecorder::new()
|
||||
.record(self.clone().into_record(), path.to_path_buf())
|
||||
.map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
|
||||
}
|
||||
|
||||
/// Load weights from `path` into a fresh model built from `config`.
|
||||
pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||
let record = CompactRecorder::new()
|
||||
.load(path.to_path_buf(), device)
|
||||
.map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
|
||||
Ok(Self::new(config, device).load_record(record))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> QValueNet<B> for QNet<B> {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let x = relu(self.fc1.forward(obs));
|
||||
let x = relu(self.fc2.forward(x));
|
||||
self.q_head.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device { Default::default() }

    /// A net built from `QNetConfig::default()` (217 → 256 → 256 → 514).
    fn default_net() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }

    #[test]
    fn forward_output_shape() {
        let net = default_net();
        let obs = Tensor::zeros([4, 217], &device());
        let q = net.forward(obs);
        assert_eq!(q.dims(), [4, 514]);
    }

    #[test]
    fn forward_single_sample() {
        let net = default_net();
        let q = net.forward(Tensor::zeros([1, 217], &device()));
        assert_eq!(q.dims(), [1, 514]);
    }

    /// Even for an all-zero observation the layer biases should differentiate
    /// the Q-values; an all-equal output would suggest a degenerate network.
    #[test]
    fn q_values_not_all_equal() {
        let net = default_net();
        let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
            .into_data().to_vec().unwrap();
        let first = q[0];
        assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
    }

    #[test]
    fn custom_config_shapes() {
        let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
        let net = QNet::<B>::new(&cfg, &device());
        let q = net.forward(Tensor::zeros([3, 10], &device()));
        assert_eq!(q.dims(), [3, 20]);
    }

    #[test]
    fn save_load_preserves_weights() {
        let net = default_net();
        let obs = Tensor::<B, 2>::ones([2, 217], &device());
        let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();

        // Process-unique file name: a fixed name can collide when several test
        // processes (e.g. parallel `cargo test` invocations across a workspace)
        // run this test concurrently against the same temp dir.
        let path = std::env::temp_dir()
            .join(format!("spiel_bot_test_qnet_{}.mpk", std::process::id()));
        net.save(&path).expect("save failed");

        let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
        let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();

        for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
        }
        // NOTE(review): CompactRecorder may append its own file extension when
        // writing — confirm this removes the file it actually created.
        let _ = std::fs::remove_file(path);
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue