diff --git a/doc/tensor_research.md b/doc/tensor_research.md deleted file mode 100644 index b0d0ede..0000000 --- a/doc/tensor_research.md +++ /dev/null @@ -1,253 +0,0 @@ -# Tensor research - -## Current tensor anatomy - -[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!) -[24] active player color: 0 or 1 -[25] turn_stage: 1–5 -[26–27] dice values (raw 1–6) -[28–31] white: points, holes, can_bredouille, can_big_bredouille -[32–35] black: same -───────────────────────────────── -Total 36 floats - -The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's -AlphaZero uses a fully-connected network. - -### Fundamental problems with the current encoding - -1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network - must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with - all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from. - -2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient - flow during training is uneven. - -3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it - triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers - produces a score. Including this explicitly is the single highest-value addition. - -4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely - different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0. - -5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the - starting position). It's in the Player struct but not exported. 
- -## Key Trictrac distinctions from backgammon that shape the encoding - -| Concept | Backgammon | Trictrac | -| ------------------------- | ---------------------- | --------------------------------------------------------- | -| Hitting a blot | Removes checker to bar | Scores points, checker stays | -| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened | -| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) | -| 3-checker field | Safe with spare | Safe with spare | -| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) | -| Both colors on a field | Impossible | Perfectly legal | -| Rest corner (field 12/13) | Does not exist | Special two-checker rules | - -The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary -indicators directly teaches the network the phase transitions the game hinges on. - -## Options - -### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D) - -The minimum viable improvement. - -For each of the 24 fields, encode own and opponent separately with 4 indicators each: - -own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target) -own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill) -own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare) -own_x[i]: max(0, count − 3) (overflow) -opp_1[i]: same for opponent -… - -Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec(). - -Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204 -Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured -overhead is negligible. -Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from -scratch will be substantially faster and the converged policy quality better. 
- -### Option B — Option A + Trictrac-specific derived features (flat 1D) - -Recommended starting point. - -Add on top of Option A: - -// Quarter fill status — the single most important derived feature -quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields) -quarter_filled_opp[q] (q=0..3): same for opponent -→ 8 values - -// Exit readiness -can_exit_own: 1.0 if all own checkers are in fields 19–24 -can_exit_opp: same for opponent -→ 2 values - -// Rest corner status (field 12/13) -own_corner_taken: 1.0 if field 12 has ≥2 own checkers -opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers -→ 2 values - -// Jan de 3 coups counter (normalized) -dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0) -→ 1 value - -Size: 204 + 8 + 2 + 2 + 1 = 217 -Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve -the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes -expensive inference from the network. - -### Option C — Option B + richer positional features (flat 1D) - -More complete, higher sample efficiency, minor extra cost. 
- -Add on top of Option B: - -// Per-quarter fill fraction — how close to filling each quarter -own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0 -opp_quarter_fill_fraction[q] (q=0..3): same for opponent -→ 8 values - -// Blot counts — number of own/opponent single-checker fields globally -// (tells the network at a glance how much battage risk/opportunity exists) -own_blot_count: (number of own fields with exactly 1 checker) / 15.0 -opp_blot_count: same for opponent -→ 2 values - -// Bredouille would-double multiplier (already present, but explicitly scaled) -// No change needed, already binary - -Size: 217 + 8 + 2 = 227 -Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network -from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts). - -### Option D — 2D spatial tensor {K, 24} - -For CNN-based networks. Best eventual architecture but requires changing the training setup. - -Shape {14, 24} — 14 feature channels over 24 field positions: - -Channel 0: own_count_1 (blot) -Channel 1: own_count_2 -Channel 2: own_count_3 -Channel 3: own_count_overflow (float) -Channel 4: opp_count_1 -Channel 5: opp_count_2 -Channel 6: opp_count_3 -Channel 7: opp_count_overflow -Channel 8: own_corner_mask (1.0 at field 12) -Channel 9: opp_corner_mask (1.0 at field 13) -Channel 10: final_quarter_mask (1.0 at fields 19–24) -Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter) -Channel 12: quarter_filled_opp (same for opponent) -Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers) - -Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform -value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global" -row by returning shape {K, 25} with position 0 holding global features. 
- -Size: 14 × 24 + few global channels ≈ 336–384 -C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated -accordingly. -Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's -alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H, -W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension. -Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel -size 2–3 captures adjacent-field "tout d'une" interactions. - -### On 3D tensors - -Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is -the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and -field-within-quarter patterns jointly. - -However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The -extra architecture work makes this premature unless you're already building a custom network. The information content -is the same as Option D. - -### Recommendation - -Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and -the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions -(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move. - -Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when -the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time. 
- -Costs summary: - -| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain | -| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- | -| Current | 36 | — | — | — | baseline | -| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) | -| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) | -| C | 227 | to_vec() rewrite | constant update | none | large + moderate | -| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial | - -One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2, -the new to_vec() must express everything from the active player's perspective (which the mirror already handles for -the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state, -which they will be if computed inside to_vec(). - -## Other algorithms - -The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully. - -### 1. Without MCTS, feature quality matters more - -AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred -simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup — -the network must learn the full strategic structure directly from gradient updates. - -This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO, -not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less -sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from -scratch. - -### 2. 
Normalization becomes mandatory, not optional - -AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps -Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw -values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can -be unstable. - -For DQN/PPO, every feature in the tensor should be in [0, 1]: - -dice values: / 6.0 -checker counts: overflow channel / 12.0 -points: / 12.0 -holes: / 12.0 -dice_roll_count: / 3.0 (clamped) - -Booleans and the TD-Gammon binary indicators are already in [0, 1]. - -### 3. The shape question depends on architecture, not algorithm - -| Architecture | Shape | When to use | -| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- | -| MLP | {217} flat | Any algorithm, simplest baseline | -| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) | -| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network | -| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now | - -The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want -convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet -template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all -— you just reshape the flat tensor in Python before passing it to the network. - -### One thing that genuinely changes: value function perspective - -AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This -works well. 
- -DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit -current-player indicator), because a single Q-network estimates action values for both players simultaneously. With -the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must -learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity. -This is a minor point for a symmetric game like Trictrac, but worth keeping in mind. - -Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm. diff --git a/spiel_bot/src/bin/dqn_train.rs b/spiel_bot/src/bin/dqn_train.rs deleted file mode 100644 index 0ebe978..0000000 --- a/spiel_bot/src/bin/dqn_train.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! DQN self-play training loop. -//! -//! # Usage -//! -//! ```sh -//! # Start fresh with default settings -//! cargo run -p spiel_bot --bin dqn_train --release -//! -//! # Custom hyperparameters -//! cargo run -p spiel_bot --bin dqn_train --release -- \ -//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000 -//! -//! # Resume from a checkpoint -//! cargo run -p spiel_bot --bin dqn_train --release -- \ -//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100 -//! ``` -//! -//! # Options -//! -//! | Flag | Default | Description | -//! |------|---------|-------------| -//! | `--hidden N` | 256 | Hidden layer width | -//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files | -//! | `--n-iter N` | 100 | Training iterations | -//! | `--n-games N` | 10 | Self-play games per iteration | -//! | `--n-train N` | 20 | Gradient steps per iteration | -//! | `--batch N` | 64 | Mini-batch size | -//! | `--replay-cap N` | 50000 | Replay buffer capacity | -//! 
| `--lr F` | 1e-3 | Adam learning rate | -//! | `--epsilon-start F` | 1.0 | Initial exploration rate | -//! | `--epsilon-end F` | 0.05 | Final exploration rate | -//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor | -//! | `--gamma F` | 0.99 | Discount factor | -//! | `--target-update N` | 500 | Hard-update target net every N steps | -//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) | -//! | `--save-every N` | 10 | Save checkpoint every N iterations | -//! | `--seed N` | 42 | RNG seed | -//! | `--resume PATH` | (none) | Load weights before training | - -use std::path::{Path, PathBuf}; -use std::time::Instant; - -use burn::{ - backend::{Autodiff, NdArray}, - module::AutodiffModule, - optim::AdamConfig, - tensor::backend::Backend, -}; -use rand::{SeedableRng, rngs::SmallRng}; - -use spiel_bot::{ - dqn::{ - DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step, - generate_dqn_episode, hard_update, linear_epsilon, - }, - env::TrictracEnv, - network::{QNet, QNetConfig}, -}; - -type TrainB = Autodiff>; -type InferB = NdArray; - -// ── CLI ─────────────────────────────────────────────────────────────────────── - -struct Args { - hidden: usize, - out_dir: PathBuf, - save_every: usize, - seed: u64, - resume: Option, - config: DqnConfig, -} - -impl Default for Args { - fn default() -> Self { - Self { - hidden: 256, - out_dir: PathBuf::from("checkpoints"), - save_every: 10, - seed: 42, - resume: None, - config: DqnConfig::default(), - } - } -} - -fn parse_args() -> Args { - let raw: Vec = std::env::args().collect(); - let mut a = Args::default(); - let mut i = 1; - while i < raw.len() { - match raw[i].as_str() { - "--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); } - "--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); } - "--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); } - "--n-games" => { i += 1; a.config.n_games_per_iter = 
raw[i].parse().expect("--n-games: integer"); } - "--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); } - "--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); } - "--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); } - "--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); } - "--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); } - "--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); } - "--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); } - "--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); } - "--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); } - "--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); } - "--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); } - "--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); } - "--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); } - other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); } - } - i += 1; - } - a -} - -// ── Training loop ───────────────────────────────────────────────────────────── - -fn train_loop( - mut q_net: QNet, - cfg: &QNetConfig, - save_fn: &dyn Fn(&QNet, &Path) -> anyhow::Result<()>, - args: &Args, -) { - let train_device: ::Device = Default::default(); - let infer_device: ::Device = Default::default(); - - let mut optimizer = AdamConfig::new().init(); - let mut replay = DqnReplayBuffer::new(args.config.replay_capacity); - let mut rng = SmallRng::seed_from_u64(args.seed); - let env = TrictracEnv; - - let mut target_net: QNet = hard_update::(&q_net); - let 
mut global_step = 0usize; - let mut epsilon = args.config.epsilon_start; - - println!( - "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}", - "", args.config.n_iterations, args.config.n_games_per_iter, - args.config.n_train_steps_per_iter, "" - ); - - for iter in 0..args.config.n_iterations { - let t0 = Instant::now(); - - // ── Self-play ──────────────────────────────────────────────────── - let infer_q: QNet = q_net.valid(); - let mut new_samples = 0usize; - - for _ in 0..args.config.n_games_per_iter { - let samples = generate_dqn_episode( - &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale, - ); - new_samples += samples.len(); - replay.extend(samples); - } - - // ── Training ───────────────────────────────────────────────────── - let mut loss_sum = 0.0f32; - let mut n_steps = 0usize; - - if replay.len() >= args.config.batch_size { - for _ in 0..args.config.n_train_steps_per_iter { - let batch: Vec<_> = replay - .sample_batch(args.config.batch_size, &mut rng) - .into_iter() - .cloned() - .collect(); - - // Target Q-values computed on the inference backend. - let target_q = compute_target_q( - &target_net, &batch, cfg.action_size, &infer_device, - ); - - let (q, loss) = dqn_train_step( - q_net, &mut optimizer, &batch, &target_q, - &train_device, args.config.learning_rate, args.config.gamma, - ); - q_net = q; - loss_sum += loss; - n_steps += 1; - global_step += 1; - - // Hard-update target net every target_update_freq steps. - if global_step % args.config.target_update_freq == 0 { - target_net = hard_update::(&q_net); - } - - // Linear epsilon decay. 
- epsilon = linear_epsilon( - args.config.epsilon_start, - args.config.epsilon_end, - global_step, - args.config.epsilon_decay_steps, - ); - } - } - - // ── Logging ────────────────────────────────────────────────────── - let elapsed = t0.elapsed(); - let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN }; - - println!( - "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s", - iter + 1, - args.config.n_iterations, - replay.len(), - new_samples, - avg_loss, - epsilon, - elapsed.as_secs_f32(), - ); - - // ── Checkpoint ─────────────────────────────────────────────────── - let is_last = iter + 1 == args.config.n_iterations; - if (iter + 1) % args.save_every == 0 || is_last { - let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1)); - match save_fn(&q_net, &path) { - Ok(()) => println!(" -> saved {}", path.display()), - Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"), - } - } - } - - println!("\nDQN training complete."); -} - -// ── Main ────────────────────────────────────────────────────────────────────── - -fn main() { - let args = parse_args(); - - if let Err(e) = std::fs::create_dir_all(&args.out_dir) { - eprintln!("Cannot create output directory {}: {e}", args.out_dir.display()); - std::process::exit(1); - } - - let train_device: ::Device = Default::default(); - let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden }; - - let q_net = match &args.resume { - Some(path) => { - println!("Resuming from {}", path.display()); - QNet::::load(&cfg, path, &train_device) - .unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); }) - } - None => QNet::::new(&cfg, &train_device), - }; - - train_loop(q_net, &cfg, &|m: &QNet, path| m.valid().save(path), &args); -} diff --git a/spiel_bot/src/dqn/episode.rs b/spiel_bot/src/dqn/episode.rs deleted file mode 100644 index aca1343..0000000 --- a/spiel_bot/src/dqn/episode.rs +++ /dev/null @@ -1,247 +0,0 @@ -//! 
DQN self-play episode generation. -//! -//! Both players share the same Q-network (the [`TrictracEnv`] handles board -//! mirroring so that each player always acts from "White's perspective"). -//! Transitions for both players are stored in the returned sample list. -//! -//! # Reward -//! -//! After each full decision (action applied and the state has advanced through -//! any intervening chance nodes back to the same player's next turn), the -//! reward is: -//! -//! ```text -//! r = (my_total_score_now − my_total_score_then) -//! − (opp_total_score_now − opp_total_score_then) -//! ``` -//! -//! where `total_score = holes × 12 + points`. -//! -//! # Transition structure -//! -//! We use a "pending transition" per player. When a player acts again, we -//! *complete* the previous pending transition by filling in `next_obs`, -//! `next_legal`, and computing `reward`. Terminal transitions are completed -//! when the game ends. - -use burn::tensor::{backend::Backend, Tensor, TensorData}; -use rand::Rng; - -use crate::env::{GameEnv, TrictracEnv}; -use crate::network::QValueNet; -use super::DqnSample; - -// ── Internals ───────────────────────────────────────────────────────────────── - -struct PendingTransition { - obs: Vec, - action: usize, - /// Score snapshot `[p1_total, p2_total]` at the moment of the action. - score_before: [i32; 2], -} - -/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise. 
-fn epsilon_greedy>( - q_net: &Q, - obs: &[f32], - legal: &[usize], - epsilon: f32, - rng: &mut impl Rng, - device: &B::Device, -) -> usize { - debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions"); - if rng.random::() < epsilon { - legal[rng.random_range(0..legal.len())] - } else { - let obs_tensor = Tensor::::from_data( - TensorData::new(obs.to_vec(), [1, obs.len()]), - device, - ); - let q_values: Vec = q_net.forward(obs_tensor).into_data().to_vec().unwrap(); - legal - .iter() - .copied() - .max_by(|&a, &b| { - q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal) - }) - .unwrap() - } -} - -/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after. -fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 { - let opp_idx = 1 - player_idx; - ((score_after[player_idx] - score_before[player_idx]) - - (score_after[opp_idx] - score_before[opp_idx])) as f32 -} - -// ── Public API ──────────────────────────────────────────────────────────────── - -/// Play one full game and return all transitions for both players. -/// -/// - `q_net` uses the **inference backend** (no autodiff wrapper). -/// - `epsilon` in `[0, 1]`: probability of taking a random action. -/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`). -pub fn generate_dqn_episode>( - env: &TrictracEnv, - q_net: &Q, - epsilon: f32, - rng: &mut impl Rng, - device: &B::Device, - reward_scale: f32, -) -> Vec { - let obs_size = env.obs_size(); - let mut state = env.new_game(); - let mut pending: [Option; 2] = [None, None]; - let mut samples: Vec = Vec::new(); - - loop { - // ── Advance past chance nodes ────────────────────────────────────── - while env.current_player(&state).is_chance() { - env.apply_chance(&mut state, rng); - } - - let score_now = TrictracEnv::score_snapshot(&state); - - if env.current_player(&state).is_terminal() { - // Complete all pending transitions as terminal. 
- for player_idx in 0..2 { - if let Some(prev) = pending[player_idx].take() { - let reward = - compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; - samples.push(DqnSample { - obs: prev.obs, - action: prev.action, - reward, - next_obs: vec![0.0; obs_size], - next_legal: vec![], - done: true, - }); - } - } - break; - } - - let player_idx = env.current_player(&state).index().unwrap(); - let legal = env.legal_actions(&state); - let obs = env.observation(&state, player_idx); - - // ── Complete the previous transition for this player ─────────────── - if let Some(prev) = pending[player_idx].take() { - let reward = - compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale; - samples.push(DqnSample { - obs: prev.obs, - action: prev.action, - reward, - next_obs: obs.clone(), - next_legal: legal.clone(), - done: false, - }); - } - - // ── Pick and apply action ────────────────────────────────────────── - let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device); - env.apply(&mut state, action); - - // ── Record new pending transition ────────────────────────────────── - pending[player_idx] = Some(PendingTransition { - obs, - action, - score_before: score_now, - }); - } - - samples -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use burn::backend::NdArray; - use rand::{SeedableRng, rngs::SmallRng}; - - use crate::network::{QNet, QNetConfig}; - - type B = NdArray; - - fn device() -> ::Device { Default::default() } - fn rng() -> SmallRng { SmallRng::seed_from_u64(7) } - - fn tiny_q() -> QNet { - QNet::new(&QNetConfig::default(), &device()) - } - - #[test] - fn episode_terminates_and_produces_samples() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - assert!(!samples.is_empty(), "episode must produce at least one sample"); - } - - #[test] - fn 
episode_obs_size_correct() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - for s in &samples { - assert_eq!(s.obs.len(), 217, "obs size mismatch"); - if s.done { - assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size"); - assert!(s.next_legal.is_empty()); - } else { - assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch"); - assert!(!s.next_legal.is_empty()); - } - } - } - - #[test] - fn episode_actions_within_action_space() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - for s in &samples { - assert!(s.action < 514, "action {} out of bounds", s.action); - } - } - - #[test] - fn greedy_episode_also_terminates() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0); - assert!(!samples.is_empty()); - } - - #[test] - fn at_least_one_done_sample() { - let env = TrictracEnv; - let q = tiny_q(); - let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0); - let n_done = samples.iter().filter(|s| s.done).count(); - // Two players, so 1 or 2 terminal transitions. - assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}"); - } - - #[test] - fn compute_reward_correct() { - // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged. - let before = [2 * 12 + 10, 0]; - let after = [3 * 12 + 2, 0]; - let r = compute_reward(0, &before, &after); - assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}"); - } - - #[test] - fn compute_reward_with_opponent_scoring() { - // P1 gains 2, opp gains 3 → net = -1 from P1's perspective. 
- let before = [0, 0]; - let after = [2, 3]; - let r = compute_reward(0, &before, &after); - assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}"); - } -} diff --git a/spiel_bot/src/dqn/mod.rs b/spiel_bot/src/dqn/mod.rs deleted file mode 100644 index 8c34fc1..0000000 --- a/spiel_bot/src/dqn/mod.rs +++ /dev/null @@ -1,232 +0,0 @@ -//! DQN: self-play data generation, replay buffer, and training step. -//! -//! # Algorithm -//! -//! Deep Q-Network with: -//! - **ε-greedy** exploration (linearly decayed). -//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where -//! `score = holes × 12 + points`. -//! - **Experience replay** with a fixed-capacity circular buffer. -//! - **Target network**: hard-copied from the online Q-net every -//! `target_update_freq` gradient steps for training stability. -//! -//! # Modules -//! -//! | Module | Contents | -//! |--------|----------| -//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] | -//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] | - -pub mod episode; -pub mod trainer; - -pub use episode::generate_dqn_episode; -pub use trainer::{compute_target_q, dqn_train_step, hard_update}; - -use std::collections::VecDeque; -use rand::Rng; - -// ── DqnSample ───────────────────────────────────────────────────────────────── - -/// One transition `(s, a, r, s', done)` collected during self-play. -#[derive(Clone, Debug)] -pub struct DqnSample { - /// Observation from the acting player's perspective (`obs_size` floats). - pub obs: Vec, - /// Action index taken. - pub action: usize, - /// Per-turn reward: `my_score_delta − opponent_score_delta`. - pub reward: f32, - /// Next observation from the same player's perspective. - /// All-zeros when `done = true` (ignored by the TD target). - pub next_obs: Vec, - /// Legal actions at `next_obs`. Empty when `done = true`. - pub next_legal: Vec, - /// `true` when `next_obs` is a terminal state. 
- pub done: bool, -} - -// ── DqnReplayBuffer ─────────────────────────────────────────────────────────── - -/// Fixed-capacity circular replay buffer for [`DqnSample`]s. -/// -/// When full, the oldest sample is evicted on push. -/// Batches are drawn without replacement via a partial Fisher-Yates shuffle. -pub struct DqnReplayBuffer { - data: VecDeque, - capacity: usize, -} - -impl DqnReplayBuffer { - pub fn new(capacity: usize) -> Self { - Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity } - } - - pub fn push(&mut self, sample: DqnSample) { - if self.data.len() == self.capacity { - self.data.pop_front(); - } - self.data.push_back(sample); - } - - pub fn extend(&mut self, samples: impl IntoIterator) { - for s in samples { self.push(s); } - } - - pub fn len(&self) -> usize { self.data.len() } - pub fn is_empty(&self) -> bool { self.data.is_empty() } - - /// Sample up to `n` distinct samples without replacement. - pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> { - let len = self.data.len(); - let n = n.min(len); - let mut indices: Vec = (0..len).collect(); - for i in 0..n { - let j = rng.random_range(i..len); - indices.swap(i, j); - } - indices[..n].iter().map(|&i| &self.data[i]).collect() - } -} - -// ── DqnConfig ───────────────────────────────────────────────────────────────── - -/// Top-level DQN hyperparameters for the training loop. -#[derive(Debug, Clone)] -pub struct DqnConfig { - /// Initial exploration rate (1.0 = fully random). - pub epsilon_start: f32, - /// Final exploration rate after decay. - pub epsilon_end: f32, - /// Number of gradient steps over which ε decays linearly from start to end. - /// - /// Should be calibrated to the total number of gradient steps - /// (`n_iterations × n_train_steps_per_iter`). A value larger than that - /// means exploration never reaches `epsilon_end` during the run. - pub epsilon_decay_steps: usize, - /// Discount factor γ for the TD target. Typical: 0.99. 
- pub gamma: f32, - /// Hard-copy Q → target every this many gradient steps. - /// - /// Should be much smaller than the total number of gradient steps - /// (`n_iterations × n_train_steps_per_iter`). - pub target_update_freq: usize, - /// Adam learning rate. - pub learning_rate: f64, - /// Mini-batch size for each gradient step. - pub batch_size: usize, - /// Maximum number of samples in the replay buffer. - pub replay_capacity: usize, - /// Number of outer iterations (self-play + train). - pub n_iterations: usize, - /// Self-play games per iteration. - pub n_games_per_iter: usize, - /// Gradient steps per iteration. - pub n_train_steps_per_iter: usize, - /// Reward normalisation divisor. - /// - /// Per-turn rewards (score delta) are divided by this constant before being - /// stored. Without normalisation, rewards can reach ±24 (jan with - /// bredouille = 12 pts × 2), driving Q-values into the hundreds and - /// causing MSE loss to grow unboundedly. - /// - /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping - /// Q-value magnitudes in a stable range. Set to `1.0` to disable. - pub reward_scale: f32, -} - -impl Default for DqnConfig { - fn default() -> Self { - // Total gradient steps with these defaults = 500 × 20 = 10_000, - // so epsilon decays fully and the target is updated 100 times. - Self { - epsilon_start: 1.0, - epsilon_end: 0.05, - epsilon_decay_steps: 10_000, - gamma: 0.99, - target_update_freq: 100, - learning_rate: 1e-3, - batch_size: 64, - replay_capacity: 50_000, - n_iterations: 500, - n_games_per_iter: 10, - n_train_steps_per_iter: 20, - reward_scale: 12.0, - } - } -} - -/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps. 
-pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 { - if decay_steps == 0 || step >= decay_steps { - return end; - } - start + (end - start) * (step as f32 / decay_steps as f32) -} - -// ── Tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use rand::{SeedableRng, rngs::SmallRng}; - - fn dummy(reward: f32) -> DqnSample { - DqnSample { - obs: vec![0.0], - action: 0, - reward, - next_obs: vec![0.0], - next_legal: vec![0], - done: false, - } - } - - #[test] - fn push_and_len() { - let mut buf = DqnReplayBuffer::new(10); - assert!(buf.is_empty()); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - assert_eq!(buf.len(), 2); - } - - #[test] - fn evicts_oldest_at_capacity() { - let mut buf = DqnReplayBuffer::new(3); - buf.push(dummy(1.0)); - buf.push(dummy(2.0)); - buf.push(dummy(3.0)); - buf.push(dummy(4.0)); - assert_eq!(buf.len(), 3); - assert_eq!(buf.data[0].reward, 2.0); - } - - #[test] - fn sample_batch_size() { - let mut buf = DqnReplayBuffer::new(20); - for i in 0..10 { buf.push(dummy(i as f32)); } - let mut rng = SmallRng::seed_from_u64(0); - assert_eq!(buf.sample_batch(5, &mut rng).len(), 5); - } - - #[test] - fn linear_epsilon_start() { - assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6); - } - - #[test] - fn linear_epsilon_end() { - assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6); - } - - #[test] - fn linear_epsilon_monotone() { - let mut prev = f32::INFINITY; - for step in 0..=100 { - let e = linear_epsilon(1.0, 0.05, step, 100); - assert!(e <= prev + 1e-6); - prev = e; - } - } -} diff --git a/spiel_bot/src/dqn/trainer.rs b/spiel_bot/src/dqn/trainer.rs deleted file mode 100644 index b8b0a02..0000000 --- a/spiel_bot/src/dqn/trainer.rs +++ /dev/null @@ -1,278 +0,0 @@ -//! DQN gradient step and target-network management. -//! -//! # TD target -//! -//! ```text -//! 
y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a)   if not done
-//! y_i = r_i                                                   if done
-//! ```
-//!
-//! # Loss
-//!
-//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
-//! and `y_i` (computed from the frozen target net).
-//!
-//! # Target network
-//!
-//! [`hard_update`] copies the online Q-net weights into the target net by
-//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
-
-use burn::{
-    module::AutodiffModule,
-    optim::{GradientsParams, Optimizer},
-    prelude::ElementConversion,
-    tensor::{
-        Int, Tensor, TensorData,
-        backend::{AutodiffBackend, Backend},
-    },
-};
-
-use crate::network::QValueNet;
-use super::DqnSample;
-
-// ── Target Q computation ─────────────────────────────────────────────────────
-
-/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
-///
-/// Returns a `Vec<f32>` of length `batch.len()`. Done samples get `0.0`
-/// (their bootstrap term is dropped by the TD target anyway).
-///
-/// The target network runs on the **inference backend** (`InferB`) with no
-/// gradient tape, so this function is backend-agnostic (`B: Backend`).
-pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
-    target_net: &Q,
-    batch: &[DqnSample],
-    action_size: usize,
-    device: &B::Device,
-) -> Vec<f32> {
-    let batch_size = batch.len();
-
-    // Collect indices of non-done samples (done samples have no next state).
-    let non_done: Vec<usize> = batch
-        .iter()
-        .enumerate()
-        .filter(|(_, s)| !s.done)
-        .map(|(i, _)| i)
-        .collect();
-
-    if non_done.is_empty() {
-        return vec![0.0; batch_size];
-    }
-
-    let obs_size = batch[0].next_obs.len();
-    let nd = non_done.len();
-
-    // Stack next observations for non-done samples → [nd, obs_size].
-    let obs_flat: Vec<f32> = non_done
-        .iter()
-        .flat_map(|&i| batch[i].next_obs.iter().copied())
-        .collect();
-    let obs_tensor = Tensor::<B, 2>::from_data(
-        TensorData::new(obs_flat, [nd, obs_size]),
-        device,
-    );
-
-    // Forward target net → [nd, action_size], then to Vec.
-    let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();
-
-    // For each non-done sample, pick max Q over legal next actions.
-    let mut result = vec![0.0f32; batch_size];
-    for (k, &i) in non_done.iter().enumerate() {
-        let legal = &batch[i].next_legal;
-        let offset = k * action_size;
-        let max_q = legal
-            .iter()
-            .map(|&a| q_flat[offset + a])
-            .fold(f32::NEG_INFINITY, f32::max);
-        // If legal is empty (shouldn't happen for non-done, but be safe):
-        result[i] = if max_q.is_finite() { max_q } else { 0.0 };
-    }
-    result
-}
-
-// ── Training step ─────────────────────────────────────────────────────────────
-
-/// Run one gradient step on `q_net` using `batch`.
-///
-/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
-/// frozen target network and passed in here so that this function only
-/// needs the **autodiff backend**.
-///
-/// Returns the updated network and the scalar MSE loss.
-pub fn dqn_train_step<B, Q, O>(
-    q_net: Q,
-    optimizer: &mut O,
-    batch: &[DqnSample],
-    target_max_q: &[f32],
-    device: &B::Device,
-    lr: f64,
-    gamma: f32,
-) -> (Q, f32)
-where
-    B: AutodiffBackend,
-    Q: QValueNet<B> + AutodiffModule<B>,
-    O: Optimizer<Q, B>,
-{
-    assert!(!batch.is_empty(), "dqn_train_step: empty batch");
-    assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");
-
-    let batch_size = batch.len();
-    let obs_size = batch[0].obs.len();
-
-    // ── Build observation tensor [B, obs_size] ────────────────────────────
-    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
-    let obs_tensor = Tensor::<B, 2>::from_data(
-        TensorData::new(obs_flat, [batch_size, obs_size]),
-        device,
-    );
-
-    // ── Forward Q-net → [B, action_size] ─────────────────────────────────
-    let q_all = q_net.forward(obs_tensor);
-
-    // ── Gather Q(s, a) for the taken action → [B] ────────────────────────
-    let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
-    let action_tensor: Tensor<B, 2, Int> =
-        Tensor::<B, 1, Int>::from_data(
-            TensorData::new(actions, [batch_size]),
-            device,
-        )
-        .reshape([batch_size, 1]); // [B] → [B, 1]
-    let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]
-
-    // ── TD targets: r + γ · max_next_q · (1 − done) ──────────────────────
-    let targets: Vec<f32> = batch
-        .iter()
-        .zip(target_max_q.iter())
-        .map(|(s, &max_q)| {
-            if s.done { s.reward } else { s.reward + gamma * max_q }
-        })
-        .collect();
-    let target_tensor = Tensor::<B, 1>::from_data(
-        TensorData::new(targets, [batch_size]),
-        device,
-    );
-
-    // ── MSE loss ──────────────────────────────────────────────────────────
-    let diff = q_pred - target_tensor.detach();
-    let loss = (diff.clone() * diff).mean();
-    let loss_scalar: f32 = loss.clone().into_scalar().elem();
-
-    // ── Backward + optimizer step ─────────────────────────────────────────
-    let grads = loss.backward();
-    let grads = GradientsParams::from_grads(grads, &q_net);
-    let q_net = optimizer.step(lr, q_net, grads);
-
-    (q_net, loss_scalar)
-}
-
-// ── Target network update ─────────────────────────────────────────────────────
-
-/// Hard-copy the online Q-net weights to a new target network.
-///
-/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
-/// inference-backend module with identical weights.
-pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
-    q_net.valid()
-}
-
-// ── Tests ─────────────────────────────────────────────────────────────────────
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use burn::{
-        backend::{Autodiff, NdArray},
-        optim::AdamConfig,
-    };
-    use crate::network::{QNet, QNetConfig};
-
-    type InferB = NdArray;
-    type TrainB = Autodiff<NdArray<f32>>;
-
-    fn infer_device() -> <InferB as Backend>::Device { Default::default() }
-    fn train_device() -> <TrainB as Backend>::Device { Default::default() }
-
-    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
-        (0..n)
-            .map(|i| DqnSample {
-                obs: vec![0.5f32; obs_size],
-                action: i % action_size,
-                reward: if i % 2 == 0 { 1.0 } else { -1.0 },
-                next_obs: vec![0.5f32; obs_size],
-                next_legal: vec![0, 1],
-                done: i == n - 1,
-            })
-            .collect()
-    }
-
-    #[test]
-    fn compute_target_q_length() {
-        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
-        let target = QNet::<InferB>::new(&cfg, &infer_device());
-        let batch = dummy_batch(8, 4, 4);
-        let tq = compute_target_q(&target, &batch, 4, &infer_device());
-        assert_eq!(tq.len(), 8);
-    }
-
-    #[test]
-    fn compute_target_q_done_is_zero() {
-        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
-        let target = QNet::<InferB>::new(&cfg, &infer_device());
-        // Single done sample.
- let batch = vec![DqnSample { - obs: vec![0.0; 4], - action: 0, - reward: 5.0, - next_obs: vec![0.0; 4], - next_legal: vec![], - done: true, - }]; - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - assert_eq!(tq.len(), 1); - assert_eq!(tq[0], 0.0); - } - - #[test] - fn train_step_returns_finite_loss() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 }; - let q_net = QNet::::new(&cfg, &train_device()); - let target = QNet::::new(&cfg, &infer_device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(8, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99); - assert!(loss.is_finite(), "loss must be finite, got {loss}"); - } - - #[test] - fn train_step_loss_decreases() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 }; - let mut q_net = QNet::::new(&cfg, &train_device()); - let target = QNet::::new(&cfg, &infer_device()); - let mut optimizer = AdamConfig::new().init(); - let batch = dummy_batch(16, 4, 4); - let tq = compute_target_q(&target, &batch, 4, &infer_device()); - - let mut prev_loss = f32::INFINITY; - for _ in 0..10 { - let (q, loss) = dqn_train_step( - q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99, - ); - q_net = q; - assert!(loss.is_finite()); - prev_loss = loss; - } - assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}"); - } - - #[test] - fn hard_update_copies_weights() { - let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 }; - let q_net = QNet::::new(&cfg, &train_device()); - let target = hard_update::(&q_net); - - let obs = burn::tensor::Tensor::::zeros([1, 4], &infer_device()); - let q_out: Vec = target.forward(obs).into_data().to_vec().unwrap(); - // After hard_update the target produces finite outputs. 
- assert!(q_out.iter().all(|v| v.is_finite())); - } -} diff --git a/spiel_bot/src/env/trictrac.rs b/spiel_bot/src/env/trictrac.rs index 8dc3676..99ba058 100644 --- a/spiel_bot/src/env/trictrac.rs +++ b/spiel_bot/src/env/trictrac.rs @@ -200,18 +200,6 @@ impl GameEnv for TrictracEnv { } } -// ── DQN helpers ─────────────────────────────────────────────────────────────── - -impl TrictracEnv { - /// Score snapshot for DQN reward computation. - /// - /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`. - /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2). - pub fn score_snapshot(s: &GameState) -> [i32; 2] { - [s.total_score(1), s.total_score(2)] - } -} - // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 9dfb4de..23895b9 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -1,5 +1,4 @@ pub mod alphazero; -pub mod dqn; pub mod env; pub mod mcts; pub mod network; diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index eead171..a0a690d 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -403,10 +403,10 @@ mod tests { let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); // root.n = 1 (expansion) + n_simulations (one backup per simulation). assert_eq!(root.n, 1 + config.n_simulations as u32); - // Every simulation crosses a chance node at depth 1 (dice roll after - // the player's move). Since the fix now updates child.n in that case, - // children visit counts must sum to exactly n_simulations. + // Children visit counts may sum to less than n_simulations when some + // simulations cross a chance node at depth 1 (turn ends after one move) + // and evaluate with the network directly without updating child.n. 
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, config.n_simulations as u32); + assert!(total <= config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 1d9750d..55db701 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -166,12 +166,6 @@ pub(super) fn simulate( // previously cached children would be for a different outcome. let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); - // Record the visit so that PUCT and mcts_policy use real counts. - // Without this, child.n stays 0 for every simulation in games where - // every player action is immediately followed by a chance node (e.g. - // Trictrac), causing mcts_policy to always return a uniform policy. - child.n += 1; - child.w += value; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) diff --git a/spiel_bot/src/network/mod.rs b/spiel_bot/src/network/mod.rs index 64f93ec..df710e9 100644 --- a/spiel_bot/src/network/mod.rs +++ b/spiel_bot/src/network/mod.rs @@ -43,11 +43,9 @@ //! before passing to softmax. pub mod mlp; -pub mod qnet; pub mod resnet; pub use mlp::{MlpConfig, MlpNet}; -pub use qnet::{QNet, QNetConfig}; pub use resnet::{ResNet, ResNetConfig}; use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; @@ -58,21 +56,9 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor}; /// - `obs`: `[batch, obs_size]` /// - policy output: `[batch, action_size]` — raw logits (no softmax applied) /// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1) -/// /// Note: `Sync` is intentionally absent — Burn's `Module` internally uses /// `OnceCell` for lazy parameter initialisation, which is not `Sync`. /// Use an `Arc>` wrapper if cross-thread sharing is needed. 
 pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
     fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
 }
-
-/// A neural network that outputs one Q-value per action.
-///
-/// # Shapes
-/// - `obs`: `[batch, obs_size]`
-/// - output: `[batch, action_size]` — raw Q-values (no activation)
-///
-/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
-pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
-    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
-}
diff --git a/spiel_bot/src/network/qnet.rs b/spiel_bot/src/network/qnet.rs
deleted file mode 100644
index 1737f72..0000000
--- a/spiel_bot/src/network/qnet.rs
+++ /dev/null
@@ -1,147 +0,0 @@
-//! Single-headed Q-value network for DQN.
-//!
-//! ```text
-//! Input [B, obs_size]
-//!   → Linear(obs → hidden) → ReLU
-//!   → Linear(hidden → hidden) → ReLU
-//!   → Linear(hidden → action_size)   ← raw Q-values, no activation
-//! ```
-
-use burn::{
-    module::Module,
-    nn::{Linear, LinearConfig},
-    record::{CompactRecorder, Recorder},
-    tensor::{activation::relu, backend::Backend, Tensor},
-};
-use std::path::Path;
-
-use super::QValueNet;
-
-// ── Config ────────────────────────────────────────────────────────────────────
-
-/// Configuration for [`QNet`].
-#[derive(Debug, Clone)]
-pub struct QNetConfig {
-    /// Number of input features. 217 for Trictrac's `to_tensor()`.
-    pub obs_size: usize,
-    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
-    pub action_size: usize,
-    /// Width of both hidden layers.
-    pub hidden_size: usize,
-}
-
-impl Default for QNetConfig {
-    fn default() -> Self {
-        Self { obs_size: 217, action_size: 514, hidden_size: 256 }
-    }
-}
-
-// ── Network ───────────────────────────────────────────────────────────────────
-
-/// Two-hidden-layer MLP that outputs one Q-value per action.
-#[derive(Module, Debug)]
-pub struct QNet<B: Backend> {
-    fc1: Linear<B>,
-    fc2: Linear<B>,
-    q_head: Linear<B>,
-}
-
-impl<B: Backend> QNet<B> {
-    /// Construct a fresh network with random weights.
-    pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
-        Self {
-            fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
-            fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
-            q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
-        }
-    }
-
-    /// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
-    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
-        CompactRecorder::new()
-            .record(self.clone().into_record(), path.to_path_buf())
-            .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
-    }
-
-    /// Load weights from `path` into a fresh model built from `config`.
-    pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
-        let record = CompactRecorder::new()
-            .load(path.to_path_buf(), device)
-            .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
-        Ok(Self::new(config, device).load_record(record))
-    }
-}
-
-impl<B: Backend> QValueNet<B> for QNet<B> {
-    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
-        let x = relu(self.fc1.forward(obs));
-        let x = relu(self.fc2.forward(x));
-        self.q_head.forward(x)
-    }
-}
-
-// ── Tests ─────────────────────────────────────────────────────────────────────
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use burn::backend::NdArray;
-
-    type B = NdArray;
-
-    fn device() -> <B as Backend>::Device { Default::default() }
-
-    fn default_net() -> QNet<B> {
-        QNet::new(&QNetConfig::default(), &device())
-    }
-
-    #[test]
-    fn forward_output_shape() {
-        let net = default_net();
-        let obs = Tensor::zeros([4, 217], &device());
-        let q = net.forward(obs);
-        assert_eq!(q.dims(), [4, 514]);
-    }
-
-    #[test]
-    fn forward_single_sample() {
-        let net = default_net();
-        let q = net.forward(Tensor::zeros([1, 217], &device()));
-        assert_eq!(q.dims(), [1, 514]);
-    }
-
-    #[test]
-    fn q_values_not_all_equal() {
-        let net = default_net();
-        let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
-            .into_data().to_vec().unwrap();
-        let first = q[0];
-        assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
-    }
-
-    #[test]
-    fn custom_config_shapes() {
-        let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
-        let net = QNet::<B>::new(&cfg, &device());
-        let q = net.forward(Tensor::zeros([3, 10], &device()));
-        assert_eq!(q.dims(), [3, 20]);
-    }
-
-    #[test]
-    fn save_load_preserves_weights() {
-        let net = default_net();
-        let obs = Tensor::<B, 2>::ones([2, 217], &device());
-        let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();
-
-        let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk");
-        net.save(&path).expect("save failed");
-
-        let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
-        let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();
-
-        for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
-            assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
-        }
-        let _ = std::fs::remove_file(path);
-    }
-}
diff --git a/store/src/game.rs b/store/src/game.rs
index e4e938c..2fde45c 100644
--- a/store/src/game.rs
+++ b/store/src/game.rs
@@ -1011,16 +1011,6 @@ impl GameState {
         self.mark_points(player_id, points)
     }
 
-    /// Total accumulated score for a player: `holes × 12 + points`.
-    ///
-    /// Returns `0` if `player_id` is not found (e.g. before `init_player`).
-    pub fn total_score(&self, player_id: PlayerId) -> i32 {
-        self.players
-            .get(&player_id)
-            .map(|p| p.holes as i32 * 12 + p.points as i32)
-            .unwrap_or(0)
-    }
-
     fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
         // Update player points and holes
         let mut new_hole = false;