diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs new file mode 100644 index 0000000..0ebe978 --- /dev/null +++ b/spiel_bot/src/bin/dqn_train.rs @@ -0,0 +1,251 @@ +//! DQN self-play training loop. +//! +//! # Usage +//! +//! ```sh +//! # Start fresh with default settings +//! cargo run -p spiel_bot --bin dqn_train --release +//! +//! # Custom hyperparameters +//! cargo run -p spiel_bot --bin dqn_train --release -- \ +//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000 +//! +//! # Resume from a checkpoint +//! cargo run -p spiel_bot --bin dqn_train --release -- \ +//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100 +//! ``` +//! +//! # Options +//! +//! | Flag | Default | Description | +//! |------|---------|-------------| +//! | `--hidden N` | 256 | Hidden layer width | +//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | +//! | `--n-iter N` | 100 | Training iterations | +//! | `--n-games N` | 10 | Self-play games per iteration | +//! | `--n-train N` | 20 | Gradient steps per iteration | +//! | `--batch N` | 64 | Mini-batch size | +//! | `--replay-cap N` | 50000 | Replay buffer capacity | +//! | `--lr F` | 1e-3 | Adam learning rate | +//! | `--epsilon-start F` | 1.0 | Initial exploration rate | +//! | `--epsilon-end F` | 0.05 | Final exploration rate | +//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor | +//! | `--gamma F` | 0.99 | Discount factor | +//! | `--target-update N` | 500 | Hard-update target net every N steps | +//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) | +//! | `--save-every N` | 10 | Save checkpoint every N iterations | +//! | `--seed N` | 42 | RNG seed | +//! | `--resume PATH` | (none) | Load weights before training | + +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use burn::{ + backend::{Autodiff, NdArray}, + module::AutodiffModule, + optim::AdamConfig, + tensor::backend::Backend, +}; +use rand::{SeedableRng, rngs::SmallRng}; + +use spiel_bot::{ + dqn::{ + DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, + generate_dqn_episode, hard_update, linear_epsilon, + }, + env::TrictracEnv, + network::{QNet, QNetConfig}, +}; + +type TrainB = Autodiff>; +type InferB = NdArray; + +// ── CLI ─────────────────────────────────────────────────────────────────────── + +struct Args { + hidden: usize, + out_dir: PathBuf, + save_every: usize, + seed: u64, + resume: Option, + config: DqnConfig, +} + +impl Default for Args { + fn default() -> Self { + Self { + hidden: 256, + out_dir: PathBuf::from("checkpoints"), + save_every: 10, + seed: 42, + resume: None, + config: DqnConfig::default(), + } + } +} + +fn parse_args() -> Args { + let raw: Vec = std::env::args().collect(); + let mut a = Args::default(); + let mut i = 1; + while i < raw.len() { + match raw[i].as_str() { + "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } + "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } + "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } + "--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); } + "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } + "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } + "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } + "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } + "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } + "--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } + "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } + "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } + "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } + "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } + "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } + "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } + "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } + other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } + } + i += 1; + } + a +} + +// ── Training loop ───────────────────────────────────────────────────────────── + +fn train_loop( + mut q_net: QNet, + cfg: &QNetConfig, + save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, + args: &Args, +) { + let train_device: ::Device = Default::default(); + let infer_device: ::Device = Default::default(); + + let mut optimizer = AdamConfig::new().init(); + let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); + let mut rng = SmallRng::seed_from_u64(args.seed); + let env = TrictracEnv; + + let mut target_net: QNet = hard_update::(&q_net); + let mut global_step = 0usize; + let mut epsilon = args.config.epsilon_start; + + println!( + "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", + "", args.config.n_iterations, args.config.n_games_per_iter, + args.config.n_train_steps_per_iter, "" + ); + + for iter in 0..args.config.n_iterations { + let t0 = Instant::now(); + + // ── Self-play ──────────────────────────────────────────────────── + let infer_q: QNet = q_net.valid(); + let mut new_samples = 0usize; + + for _ in 0..args.config.n_games_per_iter { + let samples = generate_dqn_episode( + &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, + ); + new_samples += samples.len(); + replay.extend(samples); + } + + // ── Training ───────────────────────────────────────────────────── + let mut loss_sum = 0.0f32; + let mut n_steps = 0usize; + + if replay.len() >= args.config.batch_size { + for _ in 0..args.config.n_train_steps_per_iter { + let batch: Vec<_> = replay + .sample_batch(args.config.batch_size, &mut rng) + .into_iter() + .cloned() + .collect(); + + // Target Q-values computed on the inference backend. + let target_q = compute_target_q( + &target_net, &batch, cfg.action_size, &infer_device, + ); + + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &target_q, + &train_device, args.config.learning_rate, args.config.gamma, + ); + q_net = q; + loss_sum += loss; + n_steps += 1; + global_step += 1; + + // Hard-update target net every target_update_freq steps. + if global_step % args.config.target_update_freq == 0 { + target_net = hard_update::(&q_net); + } + + // Linear epsilon decay. + epsilon = linear_epsilon( + args.config.epsilon_start, + args.config.epsilon_end, + global_step, + args.config.epsilon_decay_steps, + ); + } + } + + // ── Logging ────────────────────────────────────────────────────── + let elapsed = t0.elapsed(); + let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; + + println!( + "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", + iter + 1, + args.config.n_iterations, + replay.len(), + new_samples, + avg_loss, + epsilon, + elapsed.as_secs_f32(), + ); + + // ── Checkpoint ─────────────────────────────────────────────────── + let is_last = iter + 1 == args.config.n_iterations; + if (iter + 1) % args.save_every == 0 || is_last { + let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); + match save_fn(&q_net, &path) { + Ok(()) => println!(" -> saved {}", path.display()), + Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), + } + } + } + + println!("\nDQN training complete."); +} + +// ── Main ────────────────────────────────────────────────────────────────────── + +fn main() { + let args = parse_args(); + + if let Err(e) = std::fs::create_dir_all(&args.out_dir) { + eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); + std::process::exit(1); + } + + let train_device: ::Device = Default::default(); + let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; + + let q_net = match &args.resume { + Some(path) => { + println!("Resuming from {}", path.display()); + QNet::::load(&cfg, path, &train_device) + .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) + } + None => QNet::::new(&cfg, &train_device), + }; + + train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); +} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs new file mode 100644 index 0000000..aca1343 --- /dev/null +++ b/spiel_bot/src/dqn/episode.rs @@ -0,0 +1,247 @@ +//! DQN self-play episode generation. +//! +//! Both players share the same Q-network (the [`TrictracEnv`] handles board +//! mirroring so that each player always acts from "White's perspective"). +//! Transitions for both players are stored in the returned sample list. +//! +//! # Reward +//! +//! After each full decision (action applied and the state has advanced through +//! any intervening chance nodes back to the same player's next turn), the +//! reward is: +//! +//! ```text +//! r = (my_total_score_now − my_total_score_then) +//! − (opp_total_score_now − opp_total_score_then) +//! ``` +//! +//! where `total_score = holes × 12 + points`. +//! +//! # Transition structure +//! +//! We use a "pending transition" per player. When a player acts again, we +//! *complete* the previous pending transition by filling in `next_obs`, +//! `next_legal`, and computing `reward`. Terminal transitions are completed +//! when the game ends. + +use burn::tensor::{backend::Backend, Tensor, TensorData}; +use rand::Rng; + +use crate::env::{GameEnv, TrictracEnv}; +use crate::network::QValueNet; +use super::DqnSample; + +// ── Internals ───────────────────────────────────────────────────────────────── + +struct PendingTransition { + obs: Vec, + action: usize, + /// Score snapshot `[p1_total, p2_total]` at the moment of the action. + score_before: [i32; 2], +} + +/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. +fn epsilon_greedy>( + q_net: &Q, + obs: &[f32], + legal: &[usize], + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, +) -> usize { + debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); + if rng.random::() < epsilon { + legal[rng.random_range(0..legal.len())] + } else { + let obs_tensor = Tensor::::from_data( + TensorData::new(obs.to_vec(), [1, obs.len()]), + device, + ); + let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); + legal + .iter() + .copied() + .max_by(|&a, &b| { + q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap() + } +} + +/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. +fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { + let opp_idx = 1 - player_idx; + ((score_after[player_idx] - score_before[player_idx]) + - (score_after[opp_idx] - score_before[opp_idx])) as f32 +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Play one full game and return all transitions for both players. +/// +/// - `q_net` uses the **inference backend** (no autodiff wrapper). +/// - `epsilon` in `[0, 1]`: probability of taking a random action. +/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). +pub fn generate_dqn_episode>( + env: &TrictracEnv, + q_net: &Q, + epsilon: f32, + rng: &mut impl Rng, + device: &B::Device, + reward_scale: f32, +) -> Vec { + let obs_size = env.obs_size(); + let mut state = env.new_game(); + let mut pending: [Option; 2] = [None, None]; + let mut samples: Vec = Vec::new(); + + loop { + // ── Advance past chance nodes ────────────────────────────────────── + while env.current_player(&state).is_chance() { + env.apply_chance(&mut state, rng); + } + + let score_now = TrictracEnv::score_snapshot(&state); + + if env.current_player(&state).is_terminal() { + // Complete all pending transitions as terminal. + for player_idx in 0..2 { + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: vec![0.0; obs_size], + next_legal: vec![], + done: true, + }); + } + } + break; + } + + let player_idx = env.current_player(&state).index().unwrap(); + let legal = env.legal_actions(&state); + let obs = env.observation(&state, player_idx); + + // ── Complete the previous transition for this player ─────────────── + if let Some(prev) = pending[player_idx].take() { + let reward = + compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; + samples.push(DqnSample { + obs: prev.obs, + action: prev.action, + reward, + next_obs: obs.clone(), + next_legal: legal.clone(), + done: false, + }); + } + + // ── Pick and apply action ────────────────────────────────────────── + let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); + env.apply(&mut state, action); + + // ── Record new pending transition ────────────────────────────────── + pending[player_idx] = Some(PendingTransition { + obs, + action, + score_before: score_now, + }); + } + + samples +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + use rand::{SeedableRng, rngs::SmallRng}; + + use crate::network::{QNet, QNetConfig}; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } + + fn tiny_q() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn episode_terminates_and_produces_samples() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty(), "episode must produce at least one sample"); + } + + #[test] + fn episode_obs_size_correct() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert_eq!(s.obs.len(), 217, "obs size mismatch"); + if s.done { + assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); + assert!(s.next_legal.is_empty()); + } else { + assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); + assert!(!s.next_legal.is_empty()); + } + } + } + + #[test] + fn episode_actions_within_action_space() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + for s in &samples { + assert!(s.action < 514, "action {} out of bounds", s.action); + } + } + + #[test] + fn greedy_episode_also_terminates() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); + assert!(!samples.is_empty()); + } + + #[test] + fn at_least_one_done_sample() { + let env = TrictracEnv; + let q = tiny_q(); + let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); + let n_done = samples.iter().filter(|s| s.done).count(); + // Two players, so 1 or 2 terminal transitions. + assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); + } + + #[test] + fn compute_reward_correct() { + // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. + let before = [2 * 12 + 10, 0]; + let after = [3 * 12 + 2, 0]; + let r = compute_reward(0, &before, &after); + assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); + } + + #[test] + fn compute_reward_with_opponent_scoring() { + // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. + let before = [0, 0]; + let after = [2, 3]; + let r = compute_reward(0, &before, &after); + assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); + } +} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs new file mode 100644 index 0000000..8c34fc1 --- /dev/null +++ b/spiel_bot/src/dqn/mod.rs @@ -0,0 +1,232 @@ +//! DQN: self-play data generation, replay buffer, and training step. +//! +//! # Algorithm +//! +//! Deep Q-Network with: +//! - **ε-greedy** exploration (linearly decayed). +//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where +//! `score = holes × 12 + points`. +//! - **Experience replay** with a fixed-capacity circular buffer. +//! - **Target network**: hard-copied from the online Q-net every +//! `target_update_freq` gradient steps for training stability. +//! +//! # Modules +//! +//! | Module | Contents | +//! |--------|----------| +//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | +//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | + +pub mod episode; +pub mod trainer; + +pub use episode::generate_dqn_episode; +pub use trainer::{compute_target_q, dqn_train_step, hard_update}; + +use std::collections::VecDeque; +use rand::Rng; + +// ── DqnSample ───────────────────────────────────────────────────────────────── + +/// One transition `(s, a, r, s', done)` collected during self-play. +#[derive(Clone, Debug)] +pub struct DqnSample { + /// Observation from the acting player's perspective (`obs_size` floats). + pub obs: Vec, + /// Action index taken. + pub action: usize, + /// Per-turn reward: `my_score_delta − opponent_score_delta`. + pub reward: f32, + /// Next observation from the same player's perspective. + /// All-zeros when `done = true` (ignored by the TD target). + pub next_obs: Vec, + /// Legal actions at `next_obs`. Empty when `done = true`. + pub next_legal: Vec, + /// `true` when `next_obs` is a terminal state. + pub done: bool, +} + +// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── + +/// Fixed-capacity circular replay buffer for [`DqnSample`]s. +/// +/// When full, the oldest sample is evicted on push. +/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. +pub struct DqnReplayBuffer { + data: VecDeque, + capacity: usize, +} + +impl DqnReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } + } + + pub fn push(&mut self, sample: DqnSample) { + if self.data.len() == self.capacity { + self.data.pop_front(); + } + self.data.push_back(sample); + } + + pub fn extend(&mut self, samples: impl IntoIterator) { + for s in samples { self.push(s); } + } + + pub fn len(&self) -> usize { self.data.len() } + pub fn is_empty(&self) -> bool { self.data.is_empty() } + + /// Sample up to `n` distinct samples without replacement. + pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { + let len = self.data.len(); + let n = n.min(len); + let mut indices: Vec = (0..len).collect(); + for i in 0..n { + let j = rng.random_range(i..len); + indices.swap(i, j); + } + indices[..n].iter().map(|&i| &self.data[i]).collect() + } +} + +// ── DqnConfig ───────────────────────────────────────────────────────────────── + +/// Top-level DQN hyperparameters for the training loop. +#[derive(Debug, Clone)] +pub struct DqnConfig { + /// Initial exploration rate (1.0 = fully random). + pub epsilon_start: f32, + /// Final exploration rate after decay. + pub epsilon_end: f32, + /// Number of gradient steps over which ε decays linearly from start to end. + /// + /// Should be calibrated to the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). A value larger than that + /// means exploration never reaches `epsilon_end` during the run. + pub epsilon_decay_steps: usize, + /// Discount factor γ for the TD target. Typical: 0.99. + pub gamma: f32, + /// Hard-copy Q → target every this many gradient steps. + /// + /// Should be much smaller than the total number of gradient steps + /// (`n_iterations × n_train_steps_per_iter`). + pub target_update_freq: usize, + /// Adam learning rate. + pub learning_rate: f64, + /// Mini-batch size for each gradient step. + pub batch_size: usize, + /// Maximum number of samples in the replay buffer. + pub replay_capacity: usize, + /// Number of outer iterations (self-play + train). + pub n_iterations: usize, + /// Self-play games per iteration. + pub n_games_per_iter: usize, + /// Gradient steps per iteration. + pub n_train_steps_per_iter: usize, + /// Reward normalisation divisor. + /// + /// Per-turn rewards (score delta) are divided by this constant before being + /// stored. Without normalisation, rewards can reach ±24 (jan with + /// bredouille = 12 pts × 2), driving Q-values into the hundreds and + /// causing MSE loss to grow unboundedly. + /// + /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping + /// Q-value magnitudes in a stable range. Set to `1.0` to disable. + pub reward_scale: f32, +} + +impl Default for DqnConfig { + fn default() -> Self { + // Total gradient steps with these defaults = 500 × 20 = 10_000, + // so epsilon decays fully and the target is updated 100 times. + Self { + epsilon_start: 1.0, + epsilon_end: 0.05, + epsilon_decay_steps: 10_000, + gamma: 0.99, + target_update_freq: 100, + learning_rate: 1e-3, + batch_size: 64, + replay_capacity: 50_000, + n_iterations: 500, + n_games_per_iter: 10, + n_train_steps_per_iter: 20, + reward_scale: 12.0, + } + } +} + +/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. +pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { + if decay_steps == 0 || step >= decay_steps { + return end; + } + start + (end - start) * (step as f32 / decay_steps as f32) +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use rand::{SeedableRng, rngs::SmallRng}; + + fn dummy(reward: f32) -> DqnSample { + DqnSample { + obs: vec![0.0], + action: 0, + reward, + next_obs: vec![0.0], + next_legal: vec![0], + done: false, + } + } + + #[test] + fn push_and_len() { + let mut buf = DqnReplayBuffer::new(10); + assert!(buf.is_empty()); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + assert_eq!(buf.len(), 2); + } + + #[test] + fn evicts_oldest_at_capacity() { + let mut buf = DqnReplayBuffer::new(3); + buf.push(dummy(1.0)); + buf.push(dummy(2.0)); + buf.push(dummy(3.0)); + buf.push(dummy(4.0)); + assert_eq!(buf.len(), 3); + assert_eq!(buf.data[0].reward, 2.0); + } + + #[test] + fn sample_batch_size() { + let mut buf = DqnReplayBuffer::new(20); + for i in 0..10 { buf.push(dummy(i as f32)); } + let mut rng = SmallRng::seed_from_u64(0); + assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); + } + + #[test] + fn linear_epsilon_start() { + assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_end() { + assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); + } + + #[test] + fn linear_epsilon_monotone() { + let mut prev = f32::INFINITY; + for step in 0..=100 { + let e = linear_epsilon(1.0, 0.05, step, 100); + assert!(e <= prev + 1e-6); + prev = e; + } + } +} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs new file mode 100644 index 0000000..b8b0a02 --- /dev/null +++ b/spiel_bot/src/dqn/trainer.rs @@ -0,0 +1,278 @@ +//! DQN gradient step and target-network management. +//! +//! # TD target +//! +//! ```text +//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done +//! y_i = r_i if done +//! ``` +//! +//! # Loss +//! +//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net) +//! and `y_i` (computed from the frozen target net). +//! +//! # Target network +//! +//! [`hard_update`] copies the online Q-net weights into the target net by +//! stripping the autodiff wrapper via [`AutodiffModule::valid`]. + +use burn::{ + module::AutodiffModule, + optim::{GradientsParams, Optimizer}, + prelude::ElementConversion, + tensor::{ + Int, Tensor, TensorData, + backend::{AutodiffBackend, Backend}, + }, +}; + +use crate::network::QValueNet; +use super::DqnSample; + +// ── Target Q computation ───────────────────────────────────────────────────── + +/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample. +/// +/// Returns a `Vec` of length `batch.len()`. Done samples get `0.0` +/// (their bootstrap term is dropped by the TD target anyway). +/// +/// The target network runs on the **inference backend** (`InferB`) with no +/// gradient tape, so this function is backend-agnostic (`B: Backend`). +pub fn compute_target_q>( + target_net: &Q, + batch: &[DqnSample], + action_size: usize, + device: &B::Device, +) -> Vec { + let batch_size = batch.len(); + + // Collect indices of non-done samples (done samples have no next state). + let non_done: Vec = batch + .iter() + .enumerate() + .filter(|(_, s)| !s.done) + .map(|(i, _)| i) + .collect(); + + if non_done.is_empty() { + return vec![0.0; batch_size]; + } + + let obs_size = batch[0].next_obs.len(); + let nd = non_done.len(); + + // Stack next observations for non-done samples → [nd, obs_size]. + let obs_flat: Vec = non_done + .iter() + .flat_map(|&i| batch[i].next_obs.iter().copied()) + .collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [nd, obs_size]), + device, + ); + + // Forward target net → [nd, action_size], then to Vec. + let q_flat: Vec = target_net.forward(obs_tensor).into_data().to_vec().unwrap(); + + // For each non-done sample, pick max Q over legal next actions. + let mut result = vec![0.0f32; batch_size]; + for (k, &i) in non_done.iter().enumerate() { + let legal = &batch[i].next_legal; + let offset = k * action_size; + let max_q = legal + .iter() + .map(|&a| q_flat[offset + a]) + .fold(f32::NEG_INFINITY, f32::max); + // If legal is empty (shouldn't happen for non-done, but be safe): + result[i] = if max_q.is_finite() { max_q } else { 0.0 }; + } + result +} + +// ── Training step ───────────────────────────────────────────────────────────── + +/// Run one gradient step on `q_net` using `batch`. +/// +/// `target_max_q` must be pre-computed via [`compute_target_q`] using the +/// frozen target network and passed in here so that this function only +/// needs the **autodiff backend**. +/// +/// Returns the updated network and the scalar MSE loss. +pub fn dqn_train_step( + q_net: Q, + optimizer: &mut O, + batch: &[DqnSample], + target_max_q: &[f32], + device: &B::Device, + lr: f64, + gamma: f32, +) -> (Q, f32) +where + B: AutodiffBackend, + Q: QValueNet + AutodiffModule, + O: Optimizer, +{ + assert!(!batch.is_empty(), "dqn_train_step: empty batch"); + assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch"); + + let batch_size = batch.len(); + let obs_size = batch[0].obs.len(); + + // ── Build observation tensor [B, obs_size] ──────────────────────────── + let obs_flat: Vec = batch.iter().flat_map(|s| s.obs.iter().copied()).collect(); + let obs_tensor = Tensor::::from_data( + TensorData::new(obs_flat, [batch_size, obs_size]), + device, + ); + + // ── Forward Q-net → [B, action_size] ───────────────────────────────── + let q_all = q_net.forward(obs_tensor); + + // ── Gather Q(s, a) for the taken action → [B] ──────────────────────── + let actions: Vec = batch.iter().map(|s| s.action as i32).collect(); + let action_tensor: Tensor = Tensor::::from_data( + TensorData::new(actions, [batch_size]), + device, + ) + .reshape([batch_size, 1]); // [B] → [B, 1] + let q_pred: Tensor = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B] + + // ── TD targets: r + γ · max_next_q · (1 − done) ────────────────────── + let targets: Vec = batch + .iter() + .zip(target_max_q.iter()) + .map(|(s, &max_q)| { + if s.done { s.reward } else { s.reward + gamma * max_q } + }) + .collect(); + let target_tensor = Tensor::::from_data( + TensorData::new(targets, [batch_size]), + device, + ); + + // ── MSE loss ────────────────────────────────────────────────────────── + let diff = q_pred - target_tensor.detach(); + let loss = (diff.clone() * diff).mean(); + let loss_scalar: f32 = loss.clone().into_scalar().elem(); + + // ── Backward + optimizer step ───────────────────────────────────────── + let grads = loss.backward(); + let grads = GradientsParams::from_grads(grads, &q_net); + let q_net = optimizer.step(lr, q_net, grads); + + (q_net, loss_scalar) +} + +// ── Target network update ───────────────────────────────────────────────────── + +/// Hard-copy the online Q-net weights to a new target network. +/// +/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an +/// inference-backend module with identical weights. +pub fn hard_update>(q_net: &Q) -> Q::InnerModule { + q_net.valid() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::{ + backend::{Autodiff, NdArray}, + optim::AdamConfig, + }; + use crate::network::{QNet, QNetConfig}; + + type InferB = NdArray; + type TrainB = Autodiff>; + + fn infer_device() -> ::Device { Default::default() } + fn train_device() -> ::Device { Default::default() } + + fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec { + (0..n) + .map(|i| DqnSample { + obs: vec![0.5f32; obs_size], + action: i % action_size, + reward: if i % 2 == 0 { 1.0 } else { -1.0 }, + next_obs: vec![0.5f32; obs_size], + next_legal: vec![0, 1], + done: i == n - 1, + }) + .collect() + } + + #[test] + fn compute_target_q_length() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 8); + } + + #[test] + fn compute_target_q_done_is_zero() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let target = QNet::::new(&cfg, &infer_device()); + // Single done sample. + let batch = vec![DqnSample { + obs: vec![0.0; 4], + action: 0, + reward: 5.0, + next_obs: vec![0.0; 4], + next_legal: vec![], + done: true, + }]; + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + assert_eq!(tq.len(), 1); + assert_eq!(tq[0], 0.0); + } + + #[test] + fn train_step_returns_finite_loss() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(8, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); + assert!(loss.is_finite(), "loss must be finite, got {loss}"); + } + + #[test] + fn train_step_loss_decreases() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; + let mut q_net = QNet::::new(&cfg, &train_device()); + let target = QNet::::new(&cfg, &infer_device()); + let mut optimizer = AdamConfig::new().init(); + let batch = dummy_batch(16, 4, 4); + let tq = compute_target_q(&target, &batch, 4, &infer_device()); + + let mut prev_loss = f32::INFINITY; + for _ in 0..10 { + let (q, loss) = dqn_train_step( + q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, + ); + q_net = q; + assert!(loss.is_finite()); + prev_loss = loss; + } + assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); + } + + #[test] + fn hard_update_copies_weights() { + let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; + let q_net = QNet::::new(&cfg, &train_device()); + let target = hard_update::(&q_net); + + let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); + let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); + // After hard_update the target produces finite outputs. + assert!(q_out.iter().all(|v| v.is_finite())); + } +} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs index 99ba058..8dc3676 100644 --- a/spiel_bot/src/env/trictrac.rs +++ b/spiel_bot/src/env/trictrac.rs @@ -200,6 +200,18 @@ impl GameEnv for TrictracEnv { } } +// ── DQN helpers ─────────────────────────────────────────────────────────────── + +impl TrictracEnv { + /// Score snapshot for DQN reward computation. + /// + /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. + /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). + pub fn score_snapshot(s: &GameState) -> [i32; 2] { + [s.total_score(1), s.total_score(2)] + } +} + // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 23895b9..9dfb4de 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,4 +1,5 @@ pub mod alphazero; +pub mod dqn; pub mod env; pub mod mcts; pub mod network; diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs index df710e9..64f93ec 100644 --- a/spiel_bot/src/network/mod.rs +++ b/spiel_bot/src/network/mod.rs @@ -43,9 +43,11 @@ //! before passing to softmax. pub mod mlp; +pub mod qnet; pub mod resnet; pub use mlp::{MlpConfig, MlpNet}; +pub use qnet::{QNet, QNetConfig}; pub use resnet::{ResNet, ResNetConfig}; use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; @@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; /// - `obs`: `[batch, obs_size]` /// - policy output: `[batch, action_size]` — raw logits (no softmax applied) /// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) +/// /// Note: `Sync` is intentionally absent — Burn's `Module` internally uses /// `OnceCell` for lazy parameter initialisation, which is not `Sync`. /// Use an `Arc>` wrapper if cross-thread sharing is needed. pub trait PolicyValueNet: Module + Send + 'static { fn forward(&self, obs: Tensor) -> (Tensor, Tensor); } + +/// A neural network that outputs one Q-value per action. +/// +/// # Shapes +/// - `obs`: `[batch, obs_size]` +/// - output: `[batch, action_size]` — raw Q-values (no activation) +/// +/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`]. +pub trait QValueNet: Module + Send + 'static { + fn forward(&self, obs: Tensor) -> Tensor; +} diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs new file mode 100644 index 0000000..1737f72 --- /dev/null +++ b/spiel_bot/src/network/qnet.rs @@ -0,0 +1,147 @@ +//! Single-headed Q-value network for DQN. +//! +//! ```text +//! Input [B, obs_size] +//! → Linear(obs → hidden) → ReLU +//! → Linear(hidden → hidden) → ReLU +//! → Linear(hidden → action_size) ← raw Q-values, no activation +//! ``` + +use burn::{ + module::Module, + nn::{Linear, LinearConfig}, + record::{CompactRecorder, Recorder}, + tensor::{activation::relu, backend::Backend, Tensor}, +}; +use std::path::Path; + +use super::QValueNet; + +// ── Config ──────────────────────────────────────────────────────────────────── + +/// Configuration for [`QNet`]. +#[derive(Debug, Clone)] +pub struct QNetConfig { + /// Number of input features. 217 for Trictrac's `to_tensor()`. + pub obs_size: usize, + /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`. + pub action_size: usize, + /// Width of both hidden layers. + pub hidden_size: usize, +} + +impl Default for QNetConfig { + fn default() -> Self { + Self { obs_size: 217, action_size: 514, hidden_size: 256 } + } +} + +// ── Network ─────────────────────────────────────────────────────────────────── + +/// Two-hidden-layer MLP that outputs one Q-value per action. +#[derive(Module, Debug)] +pub struct QNet { + fc1: Linear, + fc2: Linear, + q_head: Linear, +} + +impl QNet { + /// Construct a fresh network with random weights. + pub fn new(config: &QNetConfig, device: &B::Device) -> Self { + Self { + fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device), + fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device), + q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device), + } + } + + /// Save weights to `path` (MessagePack format via [`CompactRecorder`]). + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + CompactRecorder::new() + .record(self.clone().into_record(), path.to_path_buf()) + .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}")) + } + + /// Load weights from `path` into a fresh model built from `config`. + pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result { + let record = CompactRecorder::new() + .load(path.to_path_buf(), device) + .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?; + Ok(Self::new(config, device).load_record(record)) + } +} + +impl QValueNet for QNet { + fn forward(&self, obs: Tensor) -> Tensor { + let x = relu(self.fc1.forward(obs)); + let x = relu(self.fc2.forward(x)); + self.q_head.forward(x) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use burn::backend::NdArray; + + type B = NdArray; + + fn device() -> ::Device { Default::default() } + + fn default_net() -> QNet { + QNet::new(&QNetConfig::default(), &device()) + } + + #[test] + fn forward_output_shape() { + let net = default_net(); + let obs = Tensor::zeros([4, 217], &device()); + let q = net.forward(obs); + assert_eq!(q.dims(), [4, 514]); + } + + #[test] + fn forward_single_sample() { + let net = default_net(); + let q = net.forward(Tensor::zeros([1, 217], &device())); + assert_eq!(q.dims(), [1, 514]); + } + + #[test] + fn q_values_not_all_equal() { + let net = default_net(); + let q: Vec = net.forward(Tensor::zeros([1, 217], &device())) + .into_data().to_vec().unwrap(); + let first = q[0]; + assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6)); + } + + #[test] + fn custom_config_shapes() { + let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 }; + let net = QNet::::new(&cfg, &device()); + let q = net.forward(Tensor::zeros([3, 10], &device())); + assert_eq!(q.dims(), [3, 20]); + } + + #[test] + fn save_load_preserves_weights() { + let net = default_net(); + let obs = Tensor::::ones([2, 217], &device()); + let q_before: Vec = net.forward(obs.clone()).into_data().to_vec().unwrap(); + + let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk"); + net.save(&path).expect("save failed"); + + let loaded = QNet::::load(&QNetConfig::default(), &path, &device()).expect("load failed"); + let q_after: Vec = loaded.forward(obs).into_data().to_vec().unwrap(); + + for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() { + assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}"); + } + let _ = std::fs::remove_file(path); + } +} diff --git a/store/src/game.rs b/store/src/game.rs index 2fde45c..e4e938c 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -1011,6 +1011,16 @@ impl GameState { self.mark_points(player_id, points) } + /// Total accumulated score for a player: `holes × 12 + points`. + /// + /// Returns `0` if `player_id` is not found (e.g. before `init_player`). + pub fn total_score(&self, player_id: PlayerId) -> i32 { + self.players + .get(&player_id) + .map(|p| p.holes as i32 * 12 + p.points as i32) + .unwrap_or(0) + } + fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { // Update player points and holes let mut new_hole = false;