feat(spiel_bot): dqn
This commit is contained in:
parent
7c0f230e3d
commit
e7d13c9a02
9 changed files with 1192 additions and 0 deletions
251
spiel_bot/src/bin/dqn_train.rs
Normal file
251
spiel_bot/src/bin/dqn_train.rs
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
//! DQN self-play training loop.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```sh
|
||||
//! # Start fresh with default settings
|
||||
//! cargo run -p spiel_bot --bin dqn_train --release
|
||||
//!
|
||||
//! # Custom hyperparameters
|
||||
//! cargo run -p spiel_bot --bin dqn_train --release -- \
|
||||
//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000
|
||||
//!
|
||||
//! # Resume from a checkpoint
|
||||
//! cargo run -p spiel_bot --bin dqn_train --release -- \
|
||||
//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100
|
||||
//! ```
|
||||
//!
|
||||
//! # Options
|
||||
//!
|
||||
//! | Flag | Default | Description |
|
||||
//! |------|---------|-------------|
|
||||
//! | `--hidden N` | 256 | Hidden layer width |
|
||||
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
|
||||
//! | `--n-iter N` | 500 | Training iterations |
|
||||
//! | `--n-games N` | 10 | Self-play games per iteration |
|
||||
//! | `--n-train N` | 20 | Gradient steps per iteration |
|
||||
//! | `--batch N` | 64 | Mini-batch size |
|
||||
//! | `--replay-cap N` | 50000 | Replay buffer capacity |
|
||||
//! | `--lr F` | 1e-3 | Adam learning rate |
|
||||
//! | `--epsilon-start F` | 1.0 | Initial exploration rate |
|
||||
//! | `--epsilon-end F` | 0.05 | Final exploration rate |
|
||||
//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor |
|
||||
//! | `--gamma F` | 0.99 | Discount factor |
|
||||
//! | `--target-update N` | 100 | Hard-update target net every N steps |
|
||||
//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) |
|
||||
//! | `--save-every N` | 10 | Save checkpoint every N iterations |
|
||||
//! | `--seed N` | 42 | RNG seed |
|
||||
//! | `--resume PATH` | (none) | Load weights before training |
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use burn::{
|
||||
backend::{Autodiff, NdArray},
|
||||
module::AutodiffModule,
|
||||
optim::AdamConfig,
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
use spiel_bot::{
|
||||
dqn::{
|
||||
DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step,
|
||||
generate_dqn_episode, hard_update, linear_epsilon,
|
||||
},
|
||||
env::TrictracEnv,
|
||||
network::{QNet, QNetConfig},
|
||||
};
|
||||
|
||||
/// Training backend: NdArray wrapped in `Autodiff` so gradients can be taken.
type TrainB = Autodiff<NdArray<f32>>;
/// Inference backend: plain NdArray (no gradient tape) for self-play and targets.
type InferB = NdArray<f32>;
|
||||
|
||||
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Command-line options for the DQN training binary.
struct Args {
    /// Hidden layer width of the Q-network.
    hidden: usize,
    /// Directory where checkpoint files are written.
    out_dir: PathBuf,
    /// Save a checkpoint every this many iterations (the last iteration always saves).
    save_every: usize,
    /// RNG seed for self-play and replay sampling.
    seed: u64,
    /// Optional checkpoint path to load weights from before training.
    resume: Option<PathBuf>,
    /// DQN hyperparameters (see [`DqnConfig`]).
    config: DqnConfig,
}
|
||||
|
||||
impl Default for Args {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hidden: 256,
|
||||
out_dir: PathBuf::from("checkpoints"),
|
||||
save_every: 10,
|
||||
seed: 42,
|
||||
resume: None,
|
||||
config: DqnConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let raw: Vec<String> = std::env::args().collect();
|
||||
let mut a = Args::default();
|
||||
let mut i = 1;
|
||||
while i < raw.len() {
|
||||
match raw[i].as_str() {
|
||||
"--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); }
|
||||
"--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); }
|
||||
"--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); }
|
||||
"--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); }
|
||||
"--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); }
|
||||
"--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); }
|
||||
"--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); }
|
||||
"--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); }
|
||||
"--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); }
|
||||
"--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); }
|
||||
"--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); }
|
||||
"--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); }
|
||||
"--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); }
|
||||
"--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); }
|
||||
"--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); }
|
||||
"--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); }
|
||||
"--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); }
|
||||
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
a
|
||||
}
|
||||
|
||||
// ── Training loop ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run the full DQN training loop: alternating self-play and gradient steps.
///
/// Each iteration: (1) play `n_games_per_iter` self-play games with the
/// current online net (ε-greedy) and push the transitions into the replay
/// buffer; (2) once the buffer holds at least one batch, take
/// `n_train_steps_per_iter` gradient steps; (3) log; (4) checkpoint every
/// `save_every` iterations and on the final iteration.
///
/// - `q_net`: online network on the autodiff backend; reassigned after every
///   optimizer step.
/// - `save_fn`: checkpoint writer; failures are logged as warnings, not fatal.
///
/// Note: ε only decays inside the training branch, so it stays at
/// `epsilon_start` until the replay buffer first reaches `batch_size`.
fn train_loop(
    mut q_net: QNet<TrainB>,
    cfg: &QNetConfig,
    save_fn: &dyn Fn(&QNet<TrainB>, &Path) -> anyhow::Result<()>,
    args: &Args,
) {
    let train_device: <TrainB as Backend>::Device = Default::default();
    let infer_device: <InferB as Backend>::Device = Default::default();

    let mut optimizer = AdamConfig::new().init();
    let mut replay = DqnReplayBuffer::new(args.config.replay_capacity);
    let mut rng = SmallRng::seed_from_u64(args.seed);
    let env = TrictracEnv;

    // Target net starts as a copy of the online net on the inference backend.
    let mut target_net: QNet<InferB> = hard_update::<TrainB, _>(&q_net);
    // Counts gradient steps (not games); drives ε decay and target updates.
    let mut global_step = 0usize;
    let mut epsilon = args.config.epsilon_start;

    println!(
        "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}",
        "", args.config.n_iterations, args.config.n_games_per_iter,
        args.config.n_train_steps_per_iter, ""
    );

    for iter in 0..args.config.n_iterations {
        let t0 = Instant::now();

        // ── Self-play ────────────────────────────────────────────────────
        // Strip the autodiff wrapper once per iteration; episodes only need
        // forward passes.
        let infer_q: QNet<InferB> = q_net.valid();
        let mut new_samples = 0usize;

        for _ in 0..args.config.n_games_per_iter {
            let samples = generate_dqn_episode(
                &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale,
            );
            new_samples += samples.len();
            replay.extend(samples);
        }

        // ── Training ─────────────────────────────────────────────────────
        let mut loss_sum = 0.0f32;
        let mut n_steps = 0usize;

        // Skip training until at least one full batch is available.
        if replay.len() >= args.config.batch_size {
            for _ in 0..args.config.n_train_steps_per_iter {
                // Clone the sampled transitions out of the buffer so the
                // borrow ends before the next push.
                let batch: Vec<_> = replay
                    .sample_batch(args.config.batch_size, &mut rng)
                    .into_iter()
                    .cloned()
                    .collect();

                // Target Q-values computed on the inference backend.
                let target_q = compute_target_q(
                    &target_net, &batch, cfg.action_size, &infer_device,
                );

                // dqn_train_step takes q_net by value and returns the
                // updated network.
                let (q, loss) = dqn_train_step(
                    q_net, &mut optimizer, &batch, &target_q,
                    &train_device, args.config.learning_rate, args.config.gamma,
                );
                q_net = q;
                loss_sum += loss;
                n_steps += 1;
                global_step += 1;

                // Hard-update target net every target_update_freq steps.
                if global_step % args.config.target_update_freq == 0 {
                    target_net = hard_update::<TrainB, _>(&q_net);
                }

                // Linear epsilon decay.
                epsilon = linear_epsilon(
                    args.config.epsilon_start,
                    args.config.epsilon_end,
                    global_step,
                    args.config.epsilon_decay_steps,
                );
            }
        }

        // ── Logging ──────────────────────────────────────────────────────
        let elapsed = t0.elapsed();
        // NaN signals "no training happened this iteration" in the log line.
        let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };

        println!(
            "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s",
            iter + 1,
            args.config.n_iterations,
            replay.len(),
            new_samples,
            avg_loss,
            epsilon,
            elapsed.as_secs_f32(),
        );

        // ── Checkpoint ───────────────────────────────────────────────────
        let is_last = iter + 1 == args.config.n_iterations;
        if (iter + 1) % args.save_every == 0 || is_last {
            let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1));
            // A failed save is non-fatal: training continues.
            match save_fn(&q_net, &path) {
                Ok(()) => println!(" -> saved {}", path.display()),
                Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
            }
        }
    }

    println!("\nDQN training complete.");
}
|
||||
|
||||
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Entry point: parse CLI flags, prepare the output directory, build or load
/// the Q-network, and run the training loop.
fn main() {
    let args = parse_args();

    if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
        eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
        std::process::exit(1);
    }

    let train_device: <TrainB as Backend>::Device = Default::default();
    // 217 observation floats / 514 actions — the same dimensions asserted in
    // the dqn/episode.rs tests for TrictracEnv.
    let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden };

    // Resume from a checkpoint when given, otherwise initialise fresh weights.
    // NOTE(review): resuming assumes the checkpoint was saved with the same
    // `--hidden` width — confirm QNet::load validates the shape.
    let q_net = match &args.resume {
        Some(path) => {
            println!("Resuming from {}", path.display());
            QNet::<TrainB>::load(&cfg, path, &train_device)
                .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
        }
        None => QNet::<TrainB>::new(&cfg, &train_device),
    };

    // Checkpoints are saved via the inference-backend view of the net.
    train_loop(q_net, &cfg, &|m: &QNet<TrainB>, path| m.valid().save(path), &args);
}
|
||||
247
spiel_bot/src/dqn/episode.rs
Normal file
247
spiel_bot/src/dqn/episode.rs
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
//! DQN self-play episode generation.
|
||||
//!
|
||||
//! Both players share the same Q-network (the [`TrictracEnv`] handles board
|
||||
//! mirroring so that each player always acts from "White's perspective").
|
||||
//! Transitions for both players are stored in the returned sample list.
|
||||
//!
|
||||
//! # Reward
|
||||
//!
|
||||
//! After each full decision (action applied and the state has advanced through
|
||||
//! any intervening chance nodes back to the same player's next turn), the
|
||||
//! reward is:
|
||||
//!
|
||||
//! ```text
|
||||
//! r = (my_total_score_now − my_total_score_then)
|
||||
//! − (opp_total_score_now − opp_total_score_then)
|
||||
//! ```
|
||||
//!
|
||||
//! where `total_score = holes × 12 + points`.
|
||||
//!
|
||||
//! # Transition structure
|
||||
//!
|
||||
//! We use a "pending transition" per player. When a player acts again, we
|
||||
//! *complete* the previous pending transition by filling in `next_obs`,
|
||||
//! `next_legal`, and computing `reward`. Terminal transitions are completed
|
||||
//! when the game ends.
|
||||
|
||||
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||
use rand::Rng;
|
||||
|
||||
use crate::env::{GameEnv, TrictracEnv};
|
||||
use crate::network::QValueNet;
|
||||
use super::DqnSample;
|
||||
|
||||
// ── Internals ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Half-finished transition for one player: `(s, a)` captured at action time,
/// completed into a full `DqnSample` when the same player acts again or the
/// game ends (which supplies `next_obs`, `next_legal`, and the reward).
struct PendingTransition {
    /// Observation the player acted from.
    obs: Vec<f32>,
    /// Action index the player took.
    action: usize,
    /// Score snapshot `[p1_total, p2_total]` at the moment of the action.
    score_before: [i32; 2],
}
|
||||
|
||||
/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise.
///
/// The greedy branch runs a single-row forward pass (`[1, obs_len]`) and takes
/// the argmax of the Q-values restricted to `legal` actions.
///
/// # Panics
/// Panics (via the final `unwrap`) if `legal` is empty; the `debug_assert`
/// only catches this in debug builds.
fn epsilon_greedy<B: Backend, Q: QValueNet<B>>(
    q_net: &Q,
    obs: &[f32],
    legal: &[usize],
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
) -> usize {
    debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions");
    if rng.random::<f32>() < epsilon {
        // Exploration: uniform over the legal actions.
        legal[rng.random_range(0..legal.len())]
    } else {
        let obs_tensor = Tensor::<B, 2>::from_data(
            TensorData::new(obs.to_vec(), [1, obs.len()]),
            device,
        );
        let q_values: Vec<f32> = q_net.forward(obs_tensor).into_data().to_vec().unwrap();
        // Argmax over legal actions. NaN comparisons fall back to Equal, so a
        // NaN Q-value is treated as a tie rather than poisoning the max.
        legal
            .iter()
            .copied()
            .max_by(|&a, &b| {
                q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap()
    }
}
|
||||
|
||||
/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after.
///
/// The per-turn reward is zero-sum: the player's own score gain minus the
/// opponent's gain over the same interval.
fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 {
    let my_delta = score_after[player_idx] - score_before[player_idx];
    let opp = 1 - player_idx;
    let opp_delta = score_after[opp] - score_before[opp];
    (my_delta - opp_delta) as f32
}
|
||||
|
||||
// ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Play one full game and return all transitions for both players.
///
/// - `q_net` uses the **inference backend** (no autodiff wrapper).
/// - `epsilon` in `[0, 1]`: probability of taking a random action.
/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`).
///
/// Rewards are computed when a transition is *completed* (the same player's
/// next turn, or game end), so they cover everything that happened in between,
/// including the opponent's moves and chance outcomes.
pub fn generate_dqn_episode<B: Backend, Q: QValueNet<B>>(
    env: &TrictracEnv,
    q_net: &Q,
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
    reward_scale: f32,
) -> Vec<DqnSample> {
    let obs_size = env.obs_size();
    let mut state = env.new_game();
    // One pending (s, a) per player, completed when that player acts again
    // or the game terminates.
    let mut pending: [Option<PendingTransition>; 2] = [None, None];
    let mut samples: Vec<DqnSample> = Vec::new();

    loop {
        // ── Advance past chance nodes ──────────────────────────────────────
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, rng);
        }

        let score_now = TrictracEnv::score_snapshot(&state);

        if env.current_player(&state).is_terminal() {
            // Complete all pending transitions as terminal.
            for player_idx in 0..2 {
                if let Some(prev) = pending[player_idx].take() {
                    let reward =
                        compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
                    samples.push(DqnSample {
                        obs: prev.obs,
                        action: prev.action,
                        reward,
                        // Terminal next state: zero observation, no legal
                        // actions (the TD target ignores both when done).
                        next_obs: vec![0.0; obs_size],
                        next_legal: vec![],
                        done: true,
                    });
                }
            }
            break;
        }

        let player_idx = env.current_player(&state).index().unwrap();
        let legal = env.legal_actions(&state);
        let obs = env.observation(&state, player_idx);

        // ── Complete the previous transition for this player ───────────────
        if let Some(prev) = pending[player_idx].take() {
            let reward =
                compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
            samples.push(DqnSample {
                obs: prev.obs,
                action: prev.action,
                reward,
                next_obs: obs.clone(),
                next_legal: legal.clone(),
                done: false,
            });
        }

        // ── Pick and apply action ──────────────────────────────────────────
        let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device);
        env.apply(&mut state, action);

        // ── Record new pending transition ──────────────────────────────────
        pending[player_idx] = Some(PendingTransition {
            obs,
            action,
            score_before: score_now,
        });
    }

    samples
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
// Smoke and invariant tests for episode generation. They use an untrained
// Q-net: weights are random, but the episode loop only needs forward passes.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use rand::{SeedableRng, rngs::SmallRng};

    use crate::network::{QNet, QNetConfig};

    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device { Default::default() }
    // Fixed seed so failures reproduce deterministically.
    fn rng() -> SmallRng { SmallRng::seed_from_u64(7) }

    // Fresh (untrained) Q-net with the default dimensions.
    fn tiny_q() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }

    #[test]
    fn episode_terminates_and_produces_samples() {
        let env = TrictracEnv;
        let q = tiny_q();
        // ε = 1.0: fully random play, so the test never depends on weights.
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        assert!(!samples.is_empty(), "episode must produce at least one sample");
    }

    #[test]
    fn episode_obs_size_correct() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        for s in &samples {
            assert_eq!(s.obs.len(), 217, "obs size mismatch");
            if s.done {
                // Terminal samples carry a zero next_obs of full size and no
                // legal next actions.
                assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size");
                assert!(s.next_legal.is_empty());
            } else {
                assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch");
                assert!(!s.next_legal.is_empty());
            }
        }
    }

    #[test]
    fn episode_actions_within_action_space() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        for s in &samples {
            assert!(s.action < 514, "action {} out of bounds", s.action);
        }
    }

    #[test]
    fn greedy_episode_also_terminates() {
        let env = TrictracEnv;
        let q = tiny_q();
        // ε = 0.0 exercises the greedy (forward-pass) branch end to end.
        let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0);
        assert!(!samples.is_empty());
    }

    #[test]
    fn at_least_one_done_sample() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        let n_done = samples.iter().filter(|s| s.done).count();
        // Two players, so 1 or 2 terminal transitions.
        assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}");
    }

    #[test]
    fn compute_reward_correct() {
        // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged.
        let before = [2 * 12 + 10, 0];
        let after = [3 * 12 + 2, 0];
        let r = compute_reward(0, &before, &after);
        assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}");
    }

    #[test]
    fn compute_reward_with_opponent_scoring() {
        // P1 gains 2, opp gains 3 → net = -1 from P1's perspective.
        let before = [0, 0];
        let after = [2, 3];
        let r = compute_reward(0, &before, &after);
        assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}");
    }
}
|
||||
232
spiel_bot/src/dqn/mod.rs
Normal file
232
spiel_bot/src/dqn/mod.rs
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
//! DQN: self-play data generation, replay buffer, and training step.
|
||||
//!
|
||||
//! # Algorithm
|
||||
//!
|
||||
//! Deep Q-Network with:
|
||||
//! - **ε-greedy** exploration (linearly decayed).
|
||||
//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where
|
||||
//! `score = holes × 12 + points`.
|
||||
//! - **Experience replay** with a fixed-capacity circular buffer.
|
||||
//! - **Target network**: hard-copied from the online Q-net every
|
||||
//! `target_update_freq` gradient steps for training stability.
|
||||
//!
|
||||
//! # Modules
|
||||
//!
|
||||
//! | Module | Contents |
|
||||
//! |--------|----------|
|
||||
//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] |
|
||||
//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] |
|
||||
|
||||
pub mod episode;
|
||||
pub mod trainer;
|
||||
|
||||
pub use episode::generate_dqn_episode;
|
||||
pub use trainer::{compute_target_q, dqn_train_step, hard_update};
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use rand::Rng;
|
||||
|
||||
// ── DqnSample ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// One transition `(s, a, r, s', done)` collected during self-play.
#[derive(Clone, Debug)]
pub struct DqnSample {
    /// Observation from the acting player's perspective (`obs_size` floats).
    pub obs: Vec<f32>,
    /// Action index taken.
    pub action: usize,
    /// Per-turn reward: `my_score_delta − opponent_score_delta`, already
    /// divided by `reward_scale` at collection time.
    pub reward: f32,
    /// Next observation from the same player's perspective.
    /// All-zeros when `done = true` (ignored by the TD target).
    pub next_obs: Vec<f32>,
    /// Legal actions at `next_obs`; the TD target maximises only over these.
    /// Empty when `done = true`.
    pub next_legal: Vec<usize>,
    /// `true` when `next_obs` is a terminal state.
    pub done: bool,
}
|
||||
|
||||
// ── DqnReplayBuffer ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Fixed-capacity circular replay buffer for [`DqnSample`]s.
|
||||
///
|
||||
/// When full, the oldest sample is evicted on push.
|
||||
/// Batches are drawn without replacement via a partial Fisher-Yates shuffle.
|
||||
pub struct DqnReplayBuffer {
|
||||
data: VecDeque<DqnSample>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl DqnReplayBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity }
|
||||
}
|
||||
|
||||
pub fn push(&mut self, sample: DqnSample) {
|
||||
if self.data.len() == self.capacity {
|
||||
self.data.pop_front();
|
||||
}
|
||||
self.data.push_back(sample);
|
||||
}
|
||||
|
||||
pub fn extend(&mut self, samples: impl IntoIterator<Item = DqnSample>) {
|
||||
for s in samples { self.push(s); }
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize { self.data.len() }
|
||||
pub fn is_empty(&self) -> bool { self.data.is_empty() }
|
||||
|
||||
/// Sample up to `n` distinct samples without replacement.
|
||||
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> {
|
||||
let len = self.data.len();
|
||||
let n = n.min(len);
|
||||
let mut indices: Vec<usize> = (0..len).collect();
|
||||
for i in 0..n {
|
||||
let j = rng.random_range(i..len);
|
||||
indices.swap(i, j);
|
||||
}
|
||||
indices[..n].iter().map(|&i| &self.data[i]).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── DqnConfig ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Top-level DQN hyperparameters for the training loop.
#[derive(Debug, Clone)]
pub struct DqnConfig {
    /// Initial exploration rate (1.0 = fully random).
    pub epsilon_start: f32,
    /// Final exploration rate after decay.
    pub epsilon_end: f32,
    /// Number of gradient steps over which ε decays linearly from start to end.
    ///
    /// Should be calibrated to the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`). A value larger than that
    /// means exploration never reaches `epsilon_end` during the run.
    pub epsilon_decay_steps: usize,
    /// Discount factor γ for the TD target. Typical: 0.99.
    pub gamma: f32,
    /// Hard-copy Q → target every this many gradient steps.
    ///
    /// Should be much smaller than the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`).
    pub target_update_freq: usize,
    /// Adam learning rate (f64 — passed straight to the optimizer's step).
    pub learning_rate: f64,
    /// Mini-batch size for each gradient step. Training is skipped until the
    /// replay buffer holds at least this many samples.
    pub batch_size: usize,
    /// Maximum number of samples in the replay buffer (oldest evicted first).
    pub replay_capacity: usize,
    /// Number of outer iterations (self-play + train).
    pub n_iterations: usize,
    /// Self-play games per iteration.
    pub n_games_per_iter: usize,
    /// Gradient steps per iteration.
    pub n_train_steps_per_iter: usize,
    /// Reward normalisation divisor.
    ///
    /// Per-turn rewards (score delta) are divided by this constant before being
    /// stored. Without normalisation, rewards can reach ±24 (jan with
    /// bredouille = 12 pts × 2), driving Q-values into the hundreds and
    /// causing MSE loss to grow unboundedly.
    ///
    /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping
    /// Q-value magnitudes in a stable range. Set to `1.0` to disable.
    pub reward_scale: f32,
}
|
||||
|
||||
impl Default for DqnConfig {
    /// Defaults tuned so the schedules are mutually consistent. These are
    /// also the CLI defaults of `bin/dqn_train.rs` (its `Args::default`
    /// embeds `DqnConfig::default()`).
    fn default() -> Self {
        // Total gradient steps with these defaults = 500 × 20 = 10_000,
        // so epsilon decays fully and the target is updated 100 times.
        Self {
            epsilon_start: 1.0,
            epsilon_end: 0.05,
            epsilon_decay_steps: 10_000,
            gamma: 0.99,
            target_update_freq: 100,
            learning_rate: 1e-3,
            batch_size: 64,
            replay_capacity: 50_000,
            n_iterations: 500,
            n_games_per_iter: 10,
            n_train_steps_per_iter: 20,
            reward_scale: 12.0,
        }
    }
}
|
||||
|
||||
/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps.
///
/// Steps at or beyond `decay_steps` — and a zero-length schedule — clamp
/// to `end`.
pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 {
    if decay_steps == 0 {
        return end;
    }
    match step {
        s if s >= decay_steps => end,
        s => {
            let frac = s as f32 / decay_steps as f32;
            start + (end - start) * frac
        }
    }
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
// Unit tests for the replay buffer and the ε schedule.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};

    // Minimal non-done sample; only `reward` varies so evictions are
    // observable.
    fn dummy(reward: f32) -> DqnSample {
        DqnSample {
            obs: vec![0.0],
            action: 0,
            reward,
            next_obs: vec![0.0],
            next_legal: vec![0],
            done: false,
        }
    }

    #[test]
    fn push_and_len() {
        let mut buf = DqnReplayBuffer::new(10);
        assert!(buf.is_empty());
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        assert_eq!(buf.len(), 2);
    }

    #[test]
    fn evicts_oldest_at_capacity() {
        let mut buf = DqnReplayBuffer::new(3);
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        buf.push(dummy(3.0));
        buf.push(dummy(4.0));
        assert_eq!(buf.len(), 3);
        // Oldest (reward 1.0) was dropped; the front is now 2.0.
        assert_eq!(buf.data[0].reward, 2.0);
    }

    #[test]
    fn sample_batch_size() {
        let mut buf = DqnReplayBuffer::new(20);
        for i in 0..10 { buf.push(dummy(i as f32)); }
        let mut rng = SmallRng::seed_from_u64(0);
        assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
    }

    #[test]
    fn linear_epsilon_start() {
        assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn linear_epsilon_end() {
        assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6);
    }

    #[test]
    fn linear_epsilon_monotone() {
        // ε must never increase as training progresses.
        let mut prev = f32::INFINITY;
        for step in 0..=100 {
            let e = linear_epsilon(1.0, 0.05, step, 100);
            assert!(e <= prev + 1e-6);
            prev = e;
        }
    }
}
|
||||
278
spiel_bot/src/dqn/trainer.rs
Normal file
278
spiel_bot/src/dqn/trainer.rs
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
//! DQN gradient step and target-network management.
|
||||
//!
|
||||
//! # TD target
|
||||
//!
|
||||
//! ```text
|
||||
//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done
|
||||
//! y_i = r_i if done
|
||||
//! ```
|
||||
//!
|
||||
//! # Loss
|
||||
//!
|
||||
//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
|
||||
//! and `y_i` (computed from the frozen target net).
|
||||
//!
|
||||
//! # Target network
|
||||
//!
|
||||
//! [`hard_update`] copies the online Q-net weights into the target net by
|
||||
//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
|
||||
|
||||
use burn::{
|
||||
module::AutodiffModule,
|
||||
optim::{GradientsParams, Optimizer},
|
||||
prelude::ElementConversion,
|
||||
tensor::{
|
||||
Int, Tensor, TensorData,
|
||||
backend::{AutodiffBackend, Backend},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::network::QValueNet;
|
||||
use super::DqnSample;
|
||||
|
||||
// ── Target Q computation ─────────────────────────────────────────────────────
|
||||
|
||||
/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
///
/// Returns a `Vec<f32>` of length `batch.len()`. Done samples get `0.0`
/// (their bootstrap term is dropped by the TD target anyway).
///
/// The target network runs on the **inference backend** (`InferB`) with no
/// gradient tape, so this function is backend-agnostic (`B: Backend`).
/// All non-done next-states are batched into a single forward pass.
pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
    target_net: &Q,
    batch: &[DqnSample],
    action_size: usize,
    device: &B::Device,
) -> Vec<f32> {
    let batch_size = batch.len();

    // Collect indices of non-done samples (done samples have no next state).
    let non_done: Vec<usize> = batch
        .iter()
        .enumerate()
        .filter(|(_, s)| !s.done)
        .map(|(i, _)| i)
        .collect();

    if non_done.is_empty() {
        return vec![0.0; batch_size];
    }

    let obs_size = batch[0].next_obs.len();
    let nd = non_done.len();

    // Stack next observations for non-done samples → [nd, obs_size].
    let obs_flat: Vec<f32> = non_done
        .iter()
        .flat_map(|&i| batch[i].next_obs.iter().copied())
        .collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [nd, obs_size]),
        device,
    );

    // Forward target net → [nd, action_size], then to Vec<f32>.
    let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();

    // For each non-done sample, pick max Q over legal next actions.
    let mut result = vec![0.0f32; batch_size];
    for (k, &i) in non_done.iter().enumerate() {
        let legal = &batch[i].next_legal;
        // Row k of the flattened [nd, action_size] output corresponds to
        // batch index i.
        let offset = k * action_size;
        let max_q = legal
            .iter()
            .map(|&a| q_flat[offset + a])
            .fold(f32::NEG_INFINITY, f32::max);
        // If legal is empty (shouldn't happen for non-done, but be safe):
        result[i] = if max_q.is_finite() { max_q } else { 0.0 };
    }
    result
}
|
||||
|
||||
// ── Training step ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run one gradient step on `q_net` using `batch`.
///
/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
/// frozen target network and passed in here so that this function only
/// needs the **autodiff backend**.
///
/// Returns the updated network and the scalar MSE loss.
///
/// # Panics
/// Panics if `batch` is empty or if `batch` and `target_max_q` differ in
/// length.
pub fn dqn_train_step<B, Q, O>(
    q_net: Q,
    optimizer: &mut O,
    batch: &[DqnSample],
    target_max_q: &[f32],
    device: &B::Device,
    lr: f64,
    gamma: f32,
) -> (Q, f32)
where
    B: AutodiffBackend,
    Q: QValueNet<B> + AutodiffModule<B>,
    O: Optimizer<Q, B>,
{
    assert!(!batch.is_empty(), "dqn_train_step: empty batch");
    assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");

    let batch_size = batch.len();
    let obs_size = batch[0].obs.len();

    // ── Build observation tensor [B, obs_size] ────────────────────────────
    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [batch_size, obs_size]),
        device,
    );

    // ── Forward Q-net → [B, action_size] ─────────────────────────────────
    let q_all = q_net.forward(obs_tensor);

    // ── Gather Q(s, a) for the taken action → [B] ────────────────────────
    let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
    let action_tensor: Tensor<B, 2, Int> = Tensor::<B, 1, Int>::from_data(
        TensorData::new(actions, [batch_size]),
        device,
    )
    .reshape([batch_size, 1]); // [B] → [B, 1]
    let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]

    // ── TD targets: r + γ · max_next_q · (1 − done) ──────────────────────
    // Done samples keep only the immediate reward (no bootstrap term).
    let targets: Vec<f32> = batch
        .iter()
        .zip(target_max_q.iter())
        .map(|(s, &max_q)| {
            if s.done { s.reward } else { s.reward + gamma * max_q }
        })
        .collect();
    let target_tensor = Tensor::<B, 1>::from_data(
        TensorData::new(targets, [batch_size]),
        device,
    );

    // ── MSE loss ──────────────────────────────────────────────────────────
    // detach(): gradients must flow only through q_pred, never the targets.
    let diff = q_pred - target_tensor.detach();
    let loss = (diff.clone() * diff).mean();
    // Read the scalar before backward() consumes the loss tensor.
    let loss_scalar: f32 = loss.clone().into_scalar().elem();

    // ── Backward + optimizer step ─────────────────────────────────────────
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &q_net);
    let q_net = optimizer.step(lr, q_net, grads);

    (q_net, loss_scalar)
}
|
||||
|
||||
// ── Target network update ─────────────────────────────────────────────────────
|
||||
|
||||
/// Hard-copy the online Q-net weights to a new target network.
///
/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
/// inference-backend module with identical weights.
///
/// In the DQN loop this is the periodic "hard" target-network sync (as
/// opposed to Polyak/soft averaging): the frozen target is replaced wholesale
/// by the current online network.
pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
    q_net.valid()
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::AdamConfig,
    };
    use crate::network::{QNet, QNetConfig};

    type InferB = NdArray<f32>;
    type TrainB = Autodiff<NdArray<f32>>;

    fn infer_device() -> <InferB as Backend>::Device { Default::default() }
    fn train_device() -> <TrainB as Backend>::Device { Default::default() }

    /// Build `n` synthetic transitions; only the last sample is terminal,
    /// rewards alternate between +1 and -1.
    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
        (0..n)
            .map(|i| DqnSample {
                obs: vec![0.5f32; obs_size],
                action: i % action_size,
                reward: if i % 2 == 0 { 1.0 } else { -1.0 },
                next_obs: vec![0.5f32; obs_size],
                next_legal: vec![0, 1],
                done: i == n - 1,
            })
            .collect()
    }

    #[test]
    fn compute_target_q_length() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let batch = dummy_batch(8, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        // One target value per sample in the batch.
        assert_eq!(tq.len(), 8);
    }

    #[test]
    fn compute_target_q_done_is_zero() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        // Single done sample: no bootstrapping, target max-Q must be 0.
        let batch = vec![DqnSample {
            obs: vec![0.0; 4],
            action: 0,
            reward: 5.0,
            next_obs: vec![0.0; 4],
            next_legal: vec![],
            done: true,
        }];
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        assert_eq!(tq.len(), 1);
        assert_eq!(tq[0], 0.0);
    }

    #[test]
    fn train_step_returns_finite_loss() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
        let q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(8, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99);
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
    }

    #[test]
    fn train_step_loss_decreases() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
        let mut q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(16, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());

        // Repeatedly fit the same batch against fixed targets: a plain MSE
        // regression, so the loss must actually go down.
        let mut first_loss = None;
        let mut last_loss = f32::INFINITY;
        for _ in 0..10 {
            let (q, loss) = dqn_train_step(
                q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99,
            );
            q_net = q;
            assert!(loss.is_finite());
            first_loss.get_or_insert(loss);
            last_loss = loss;
        }
        let first_loss = first_loss.expect("loop ran at least once");
        assert!(
            last_loss < first_loss,
            "loss did not decrease: first {first_loss}, last {last_loss}"
        );
    }

    #[test]
    fn hard_update_copies_weights() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = hard_update::<TrainB, _>(&q_net);

        let obs = burn::tensor::Tensor::<InferB, 2>::zeros([1, 4], &infer_device());
        let q_out: Vec<f32> = target.forward(obs).into_data().to_vec().unwrap();
        // After hard_update the target produces finite outputs.
        assert!(q_out.iter().all(|v| v.is_finite()));
    }
}
|
||||
12
spiel_bot/src/env/trictrac.rs
vendored
12
spiel_bot/src/env/trictrac.rs
vendored
|
|
@ -200,6 +200,18 @@ impl GameEnv for TrictracEnv {
|
|||
}
|
||||
}
|
||||
|
||||
// ── DQN helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
impl TrictracEnv {
|
||||
/// Score snapshot for DQN reward computation.
|
||||
///
|
||||
/// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`.
|
||||
/// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2).
|
||||
pub fn score_snapshot(s: &GameState) -> [i32; 2] {
|
||||
[s.total_score(1), s.total_score(2)]
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod alphazero;
|
||||
pub mod dqn;
|
||||
pub mod env;
|
||||
pub mod mcts;
|
||||
pub mod network;
|
||||
|
|
|
|||
|
|
@ -43,9 +43,11 @@
|
|||
//! before passing to softmax.
|
||||
|
||||
pub mod mlp;
|
||||
pub mod qnet;
|
||||
pub mod resnet;
|
||||
|
||||
pub use mlp::{MlpConfig, MlpNet};
|
||||
pub use qnet::{QNet, QNetConfig};
|
||||
pub use resnet::{ResNet, ResNetConfig};
|
||||
|
||||
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
||||
|
|
@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
|||
/// - `obs`: `[batch, obs_size]`
|
||||
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
|
||||
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
|
||||
///
|
||||
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
|
||||
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
|
||||
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
|
||||
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
|
||||
}
|
||||
|
||||
/// A neural network that outputs one Q-value per action.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - output: `[batch, action_size]` — raw Q-values (no activation)
///
/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
    /// Compute Q-values for a batch of observations.
    ///
    /// Takes `obs` of shape `[batch, obs_size]` and returns one raw
    /// (unactivated) Q-value per action: `[batch, action_size]`.
    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
}
|
||||
|
|
|
|||
147
spiel_bot/src/network/qnet.rs
Normal file
147
spiel_bot/src/network/qnet.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
//! Single-headed Q-value network for DQN.
|
||||
//!
|
||||
//! ```text
|
||||
//! Input [B, obs_size]
|
||||
//! → Linear(obs → hidden) → ReLU
|
||||
//! → Linear(hidden → hidden) → ReLU
|
||||
//! → Linear(hidden → action_size) ← raw Q-values, no activation
|
||||
//! ```
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Linear, LinearConfig},
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::{activation::relu, backend::Backend, Tensor},
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
use super::QValueNet;
|
||||
|
||||
// ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for [`QNet`].
///
/// The `Default` impl provides Trictrac-sized values
/// (217 inputs, 514 actions, 256-wide hidden layers).
#[derive(Debug, Clone)]
pub struct QNetConfig {
    /// Number of input features. 217 for Trictrac's `to_tensor()`.
    pub obs_size: usize,
    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
    pub action_size: usize,
    /// Width of both hidden layers.
    pub hidden_size: usize,
}
|
||||
|
||||
impl Default for QNetConfig {
|
||||
fn default() -> Self {
|
||||
Self { obs_size: 217, action_size: 514, hidden_size: 256 }
|
||||
}
|
||||
}
|
||||
|
||||
// ── Network ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Two-hidden-layer MLP that outputs one Q-value per action.
#[derive(Module, Debug)]
pub struct QNet<B: Backend> {
    // First hidden layer: obs_size → hidden_size.
    fc1: Linear<B>,
    // Second hidden layer: hidden_size → hidden_size.
    fc2: Linear<B>,
    // Output layer: hidden_size → action_size (raw Q-values, no activation).
    q_head: Linear<B>,
}
|
||||
|
||||
impl<B: Backend> QNet<B> {
    /// Construct a fresh network with random weights.
    ///
    /// Layer widths come from `config`: obs_size → hidden_size →
    /// hidden_size → action_size.
    pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
            fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
            q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
        }
    }

    /// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
    ///
    /// NOTE(review): Burn file recorders may append their own extension to
    /// `path` — confirm the on-disk filename matches what callers pass in
    /// (e.g. the `.mpk` checkpoint names used by `dqn_train`).
    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
        CompactRecorder::new()
            .record(self.clone().into_record(), path.to_path_buf())
            .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
    }

    /// Load weights from `path` into a fresh model built from `config`.
    ///
    /// `config` presumably must describe the same architecture that produced
    /// the record — verify before loading mismatched checkpoints.
    pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
        let record = CompactRecorder::new()
            .load(path.to_path_buf(), device)
            .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
        // Build a random network with the right shapes, then overwrite its
        // parameters from the record.
        Ok(Self::new(config, device).load_record(record))
    }
}
|
||||
|
||||
impl<B: Backend> QValueNet<B> for QNet<B> {
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let x = relu(self.fc1.forward(obs));
|
||||
let x = relu(self.fc2.forward(x));
|
||||
self.q_head.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device { Default::default() }

    /// Network with the Trictrac-sized default config (217 → 256 → 256 → 514).
    fn default_net() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }

    #[test]
    fn forward_output_shape() {
        let net = default_net();
        let obs = Tensor::zeros([4, 217], &device());
        let q = net.forward(obs);
        assert_eq!(q.dims(), [4, 514]);
    }

    #[test]
    fn forward_single_sample() {
        let net = default_net();
        let q = net.forward(Tensor::zeros([1, 217], &device()));
        assert_eq!(q.dims(), [1, 514]);
    }

    #[test]
    fn q_values_not_all_equal() {
        // With random init the head outputs should differ across actions.
        let net = default_net();
        let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
            .into_data().to_vec().unwrap();
        let first = q[0];
        assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
    }

    #[test]
    fn custom_config_shapes() {
        let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
        let net = QNet::<B>::new(&cfg, &device());
        let q = net.forward(Tensor::zeros([3, 10], &device()));
        assert_eq!(q.dims(), [3, 20]);
    }

    #[test]
    fn save_load_preserves_weights() {
        let net = default_net();
        let obs = Tensor::<B, 2>::ones([2, 217], &device());
        let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();

        // Unique per-process filename: a fixed name collides when two runs
        // of the suite execute concurrently on the same machine.
        let path = std::env::temp_dir()
            .join(format!("spiel_bot_test_qnet_{}.mpk", std::process::id()));
        net.save(&path).expect("save failed");

        let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
        let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();

        for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
        }
        let _ = std::fs::remove_file(path);
    }
}
|
||||
|
|
@ -1011,6 +1011,16 @@ impl GameState {
|
|||
self.mark_points(player_id, points)
|
||||
}
|
||||
|
||||
/// Total accumulated score for a player: `holes × 12 + points`.
|
||||
///
|
||||
/// Returns `0` if `player_id` is not found (e.g. before `init_player`).
|
||||
pub fn total_score(&self, player_id: PlayerId) -> i32 {
|
||||
self.players
|
||||
.get(&player_id)
|
||||
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
// Update player points and holes
|
||||
let mut new_hole = false;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue