Compare commits

...

5 commits

17 changed files with 1603 additions and 8 deletions

1
Cargo.lock generated
View file

@ -6009,6 +6009,7 @@ dependencies = [
"criterion",
"rand 0.9.2",
"rand_distr",
"rayon",
"trictrac-store",
]

121
doc/spiel_bot_parallel.md Normal file
View file

@ -0,0 +1,121 @@
Part B — Batched MCTS leaf evaluation
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
pub trait Evaluator: Send + Sync {
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
/// Evaluate a batch of observations at once. Default falls back to
/// sequential calls; backends override this for efficiency.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
}
}
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
let b = obs_batch.len();
let obs_size = obs_batch[0].len();
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
let action_size = policies.len() / b;
(0..b).map(|i| {
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
}).collect()
}
Step B3 — Add eval_batch_size to MctsConfig
pub struct MctsConfig {
// ... existing fields ...
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
pub eval_batch_size: usize,
}
Default: 1 (backwards-compatible).
Step B4 — Make the simulation iterative (mcts/search.rs)
The current simulate is recursive. For batching we need to split it into two phases:
descend (pure tree traversal — no network call):
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
ascend (backup — no network call):
- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
Step B5 — Add run_mcts_batched to mcts/mod.rs
The new entry point, called by run_mcts when config.eval_batch_size > 1:
expand root (1 network call)
optionally add Dirichlet noise
for round in 0..(n_simulations / batch_size):
leaves = []
for * in 0..batch_size:
leaf = descend(root, state, env, rng)
leaves.push(leaf)
obs_batch = [env.observation(leaf.state, leaf.player) for leaf in leaves
where leaf.kind == NeedsEval]
results = evaluator.evaluate_batch(obs_batch)
for (leaf, result) in zip(leaves, results):
expand the leaf node (insert children from result.policy)
ascend(root, leaf.path, result.value, leaf.player_idx)
// ascend also handles terminal and crossed-chance leaves
// handle remainder: n_simulations % batch_size
run_mcts becomes a thin dispatcher:
if config.eval_batch_size <= 1 {
// existing path (unchanged)
} else {
run_mcts_batched(...)
}
Step B6 — CLI flag in az_train.rs
--eval-batch N default: 8 Leaf batch size for MCTS network calls
---
Summary of file changes
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
│ File │ Changes │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ spiel_bot/Cargo.toml │ add rayon │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
Key design constraint
Parts A and B are independent and composable:
- A only touches the outer game loop.
- B only touches the inner MCTS per game.
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.

253
doc/tensor_research.md Normal file
View file

@ -0,0 +1,253 @@
# Tensor research
## Current tensor anatomy
[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!)
[24] active player color: 0 or 1
[25] turn_stage: 1–5
[26–27] dice values (raw 1–6)
[28–31] white: points, holes, can_bredouille, can_big_bredouille
[32–35] black: same
─────────────────────────────────
Total 36 floats
The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's
AlphaZero uses a fully-connected network.
### Fundamental problems with the current encoding
1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network
must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with
all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from.
2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient
flow during training is uneven.
3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it
triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers
produces a score. Including this explicitly is the single highest-value addition.
4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely
different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0.
5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the
starting position). It's in the Player struct but not exported.
## Key Trictrac distinctions from backgammon that shape the encoding
| Concept | Backgammon | Trictrac |
| ------------------------- | ---------------------- | --------------------------------------------------------- |
| Hitting a blot | Removes checker to bar | Scores points, checker stays |
| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened |
| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) |
| 3-checker field | Safe with spare | Safe with spare |
| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) |
| Both colors on a field | Impossible | Perfectly legal |
| Rest corner (field 12/13) | Does not exist | Special two-checker rules |
The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary
indicators directly teaches the network the phase transitions the game hinges on.
## Options
### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D)
The minimum viable improvement.
For each of the 24 fields, encode own and opponent separately with 4 indicators each:
own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target)
own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill)
own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare)
own_x[i]: max(0, count − 3) (overflow)
opp_1[i]: same for opponent
Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec().
Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204
Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured
overhead is negligible.
Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from
scratch will be substantially faster and the converged policy quality better.
### Option B — Option A + Trictrac-specific derived features (flat 1D)
Recommended starting point.
Add on top of Option A:
// Quarter fill status — the single most important derived feature
quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields)
quarter_filled_opp[q] (q=0..3): same for opponent
→ 8 values
// Exit readiness
can_exit_own: 1.0 if all own checkers are in fields 19–24
can_exit_opp: same for opponent
→ 2 values
// Rest corner status (field 12/13)
own_corner_taken: 1.0 if field 12 has ≥2 own checkers
opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers
→ 2 values
// Jan de 3 coups counter (normalized)
dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0)
→ 1 value
Size: 204 + 8 + 2 + 2 + 1 = 217
Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve
the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes
expensive inference from the network.
### Option C — Option B + richer positional features (flat 1D)
More complete, higher sample efficiency, minor extra cost.
Add on top of Option B:
// Per-quarter fill fraction — how close to filling each quarter
own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0
opp_quarter_fill_fraction[q] (q=0..3): same for opponent
→ 8 values
// Blot counts — number of own/opponent single-checker fields globally
// (tells the network at a glance how much battage risk/opportunity exists)
own_blot_count: (number of own fields with exactly 1 checker) / 15.0
opp_blot_count: same for opponent
→ 2 values
// Bredouille would-double multiplier (already present, but explicitly scaled)
// No change needed, already binary
Size: 217 + 8 + 2 = 227
Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network
from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts).
### Option D — 2D spatial tensor {K, 24}
For CNN-based networks. Best eventual architecture but requires changing the training setup.
Shape {14, 24} — 14 feature channels over 24 field positions:
Channel 0: own_count_1 (blot)
Channel 1: own_count_2
Channel 2: own_count_3
Channel 3: own_count_overflow (float)
Channel 4: opp_count_1
Channel 5: opp_count_2
Channel 6: opp_count_3
Channel 7: opp_count_overflow
Channel 8: own_corner_mask (1.0 at field 12)
Channel 9: opp_corner_mask (1.0 at field 13)
Channel 10: final_quarter_mask (1.0 at fields 19–24)
Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter)
Channel 12: quarter_filled_opp (same for opponent)
Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers)
Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform
value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global"
row by returning shape {K, 25} with position 0 holding global features.
Size: 14 × 24 + few global channels ≈ 336–384
C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated
accordingly.
Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's
alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H,
W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension.
Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel
size 2–3 captures adjacent-field "tout d'une" interactions.
### On 3D tensors
Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is
the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and
field-within-quarter patterns jointly.
However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The
extra architecture work makes this premature unless you're already building a custom network. The information content
is the same as Option D.
### Recommendation
Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and
the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions
(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move.
Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when
the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time.
Costs summary:
| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain |
| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- |
| Current | 36 | — | — | — | baseline |
| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) |
| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) |
| C | 227 | to_vec() rewrite | constant update | none | large + moderate |
| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial |
One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2,
the new to_vec() must express everything from the active player's perspective (which the mirror already handles for
the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state,
which they will be if computed inside to_vec().
## Other algorithms
The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully.
### 1. Without MCTS, feature quality matters more
AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred
simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup —
the network must learn the full strategic structure directly from gradient updates.
This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO,
not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less
sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from
scratch.
### 2. Normalization becomes mandatory, not optional
AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps
Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw
values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can
be unstable.
For DQN/PPO, every feature in the tensor should be in [0, 1]:
dice values: / 6.0
checker counts: overflow channel / 12.0
points: / 12.0
holes: / 12.0
dice_roll_count: / 3.0 (clamped)
Booleans and the TD-Gammon binary indicators are already in [0, 1].
### 3. The shape question depends on architecture, not algorithm
| Architecture | Shape | When to use |
| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- |
| MLP | {217} flat | Any algorithm, simplest baseline |
| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) |
| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network |
| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now |
The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want
convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet
template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all
— you just reshape the flat tensor in Python before passing it to the network.
### One thing that genuinely changes: value function perspective
AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This
works well.
DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit
current-player indicator), because a single Q-network estimates action values for both players simultaneously. With
the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must
learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity.
This is a minor point for a symmetric game like Trictrac, but worth keeping in mind.
Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm.

View file

@ -9,6 +9,7 @@ anyhow = "1"
rand = "0.9"
rand_distr = "0.5"
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
rayon = "1"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

View file

@ -31,6 +31,10 @@ impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
pub fn into_model(self) -> N {
self.model
}
pub fn model_ref(&self) -> &N {
&self.model
}
}
// Safety: NdArray<f32> modules are Send; we never share across threads without

View file

@ -47,7 +47,8 @@ use burn::{
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{SeedableRng, rngs::SmallRng};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use rayon::prelude::*;
use spiel_bot::{
alphazero::{
@ -195,10 +196,26 @@ where
if step < temp_drop { 1.0 } else { 0.0 }
};
// Prepare per-game seeds and evaluators sequentially so the main RNG
// and model cloning stay deterministic regardless of thread scheduling.
// Burn modules are Send but not Sync, so each task must own its model.
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
let game_evals: Vec<_> = (0..args.n_games)
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
.collect();
drop(evaluator);
let all_samples: Vec<Vec<TrainSample>> = game_seeds
.into_par_iter()
.zip(game_evals.into_par_iter())
.map(|(seed, game_eval)| {
let mut game_rng = SmallRng::seed_from_u64(seed);
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
})
.collect();
let mut new_samples = 0usize;
for _ in 0..args.n_games {
let samples =
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
for samples in all_samples {
new_samples += samples.len();
replay.extend(samples);
}

View file

@ -0,0 +1,251 @@
//! DQN self-play training loop.
//!
//! # Usage
//!
//! ```sh
//! # Start fresh with default settings
//! cargo run -p spiel_bot --bin dqn_train --release
//!
//! # Custom hyperparameters
//! cargo run -p spiel_bot --bin dqn_train --release -- \
//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000
//!
//! # Resume from a checkpoint
//! cargo run -p spiel_bot --bin dqn_train --release -- \
//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100
//! ```
//!
//! # Options
//!
//! | Flag | Default | Description |
//! |------|---------|-------------|
//! | `--hidden N` | 256 | Hidden layer width |
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
//! | `--n-iter N` | 100 | Training iterations |
//! | `--n-games N` | 10 | Self-play games per iteration |
//! | `--n-train N` | 20 | Gradient steps per iteration |
//! | `--batch N` | 64 | Mini-batch size |
//! | `--replay-cap N` | 50000 | Replay buffer capacity |
//! | `--lr F` | 1e-3 | Adam learning rate |
//! | `--epsilon-start F` | 1.0 | Initial exploration rate |
//! | `--epsilon-end F` | 0.05 | Final exploration rate |
//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor |
//! | `--gamma F` | 0.99 | Discount factor |
//! | `--target-update N` | 500 | Hard-update target net every N steps |
//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) |
//! | `--save-every N` | 10 | Save checkpoint every N iterations |
//! | `--seed N` | 42 | RNG seed |
//! | `--resume PATH` | (none) | Load weights before training |
use std::path::{Path, PathBuf};
use std::time::Instant;
use burn::{
backend::{Autodiff, NdArray},
module::AutodiffModule,
optim::AdamConfig,
tensor::backend::Backend,
};
use rand::{SeedableRng, rngs::SmallRng};
use spiel_bot::{
dqn::{
DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step,
generate_dqn_episode, hard_update, linear_epsilon,
},
env::TrictracEnv,
network::{QNet, QNetConfig},
};
type TrainB = Autodiff<NdArray<f32>>;
type InferB = NdArray<f32>;
// ── CLI ───────────────────────────────────────────────────────────────────────
/// CLI-configurable settings for the DQN training binary.
struct Args {
    /// Hidden layer width of the Q-network.
    hidden: usize,
    /// Directory where checkpoint files are written.
    out_dir: PathBuf,
    /// Save a checkpoint every N iterations (and always on the last one).
    save_every: usize,
    /// Seed for the training RNG.
    seed: u64,
    /// Optional checkpoint to load weights from before training.
    resume: Option<PathBuf>,
    /// Algorithm hyperparameters (epsilon schedule, gamma, buffer sizes, ...).
    config: DqnConfig,
}
impl Default for Args {
    /// Defaults match the option table in the module docs above.
    fn default() -> Self {
        Self {
            hidden: 256,
            out_dir: PathBuf::from("checkpoints"),
            save_every: 10,
            seed: 42,
            resume: None,
            config: DqnConfig::default(),
        }
    }
}
/// Parse command-line flags into [`Args`].
///
/// A flag given without its value (e.g. `--hidden` as the last argument)
/// exits with a clear error message instead of panicking with an
/// out-of-bounds index; unknown flags likewise exit with a message.
fn parse_args() -> Args {
    let raw: Vec<String> = std::env::args().collect();
    let mut a = Args::default();
    let mut i = 1;
    // Advance past the flag and return its value, or exit with a message.
    fn value<'a>(raw: &'a [String], i: &mut usize, flag: &str) -> &'a str {
        *i += 1;
        match raw.get(*i) {
            Some(v) => v,
            None => {
                eprintln!("Missing value for {flag}");
                std::process::exit(1);
            }
        }
    }
    while i < raw.len() {
        match raw[i].as_str() {
            "--hidden" => a.hidden = value(&raw, &mut i, "--hidden").parse().expect("--hidden: integer"),
            "--out" => a.out_dir = PathBuf::from(value(&raw, &mut i, "--out")),
            "--n-iter" => a.config.n_iterations = value(&raw, &mut i, "--n-iter").parse().expect("--n-iter: integer"),
            "--n-games" => a.config.n_games_per_iter = value(&raw, &mut i, "--n-games").parse().expect("--n-games: integer"),
            "--n-train" => a.config.n_train_steps_per_iter = value(&raw, &mut i, "--n-train").parse().expect("--n-train: integer"),
            "--batch" => a.config.batch_size = value(&raw, &mut i, "--batch").parse().expect("--batch: integer"),
            "--replay-cap" => a.config.replay_capacity = value(&raw, &mut i, "--replay-cap").parse().expect("--replay-cap: integer"),
            "--lr" => a.config.learning_rate = value(&raw, &mut i, "--lr").parse().expect("--lr: float"),
            "--epsilon-start" => a.config.epsilon_start = value(&raw, &mut i, "--epsilon-start").parse().expect("--epsilon-start: float"),
            "--epsilon-end" => a.config.epsilon_end = value(&raw, &mut i, "--epsilon-end").parse().expect("--epsilon-end: float"),
            "--epsilon-decay" => a.config.epsilon_decay_steps = value(&raw, &mut i, "--epsilon-decay").parse().expect("--epsilon-decay: integer"),
            "--gamma" => a.config.gamma = value(&raw, &mut i, "--gamma").parse().expect("--gamma: float"),
            "--target-update" => a.config.target_update_freq = value(&raw, &mut i, "--target-update").parse().expect("--target-update: integer"),
            "--reward-scale" => a.config.reward_scale = value(&raw, &mut i, "--reward-scale").parse().expect("--reward-scale: float"),
            "--save-every" => a.save_every = value(&raw, &mut i, "--save-every").parse().expect("--save-every: integer"),
            "--seed" => a.seed = value(&raw, &mut i, "--seed").parse().expect("--seed: integer"),
            "--resume" => a.resume = Some(PathBuf::from(value(&raw, &mut i, "--resume"))),
            other => {
                eprintln!("Unknown argument: {other}");
                std::process::exit(1);
            }
        }
        i += 1;
    }
    a
}
// ── Training loop ─────────────────────────────────────────────────────────────
/// Run the full DQN training loop: alternating self-play data collection and
/// gradient steps, with periodic target-network hard updates, linear epsilon
/// decay, per-iteration logging, and checkpointing via `save_fn`.
///
/// * `q_net` — online network on the autodiff backend; ownership is threaded
///   through `dqn_train_step` on every gradient step.
/// * `cfg` — network shape (only `action_size` is read here).
/// * `save_fn` — checkpoint writer; failures are logged, not fatal.
/// * `args` — CLI/hyperparameter bundle (see `DqnConfig`).
fn train_loop(
    mut q_net: QNet<TrainB>,
    cfg: &QNetConfig,
    save_fn: &dyn Fn(&QNet<TrainB>, &Path) -> anyhow::Result<()>,
    args: &Args,
) {
    let train_device: <TrainB as Backend>::Device = Default::default();
    let infer_device: <InferB as Backend>::Device = Default::default();
    let mut optimizer = AdamConfig::new().init();
    let mut replay = DqnReplayBuffer::new(args.config.replay_capacity);
    let mut rng = SmallRng::seed_from_u64(args.seed);
    let env = TrictracEnv;
    // Target net starts as a copy of the online net, on the inference backend.
    let mut target_net: QNet<InferB> = hard_update::<TrainB, _>(&q_net);
    let mut global_step = 0usize;
    let mut epsilon = args.config.epsilon_start;
    println!(
        "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}",
        "", args.config.n_iterations, args.config.n_games_per_iter,
        args.config.n_train_steps_per_iter, ""
    );
    for iter in 0..args.config.n_iterations {
        let t0 = Instant::now();
        // ── Self-play ────────────────────────────────────────────────────
        // Strip the autodiff wrapper for inference-only episode generation.
        let infer_q: QNet<InferB> = q_net.valid();
        let mut new_samples = 0usize;
        for _ in 0..args.config.n_games_per_iter {
            let samples = generate_dqn_episode(
                &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale,
            );
            new_samples += samples.len();
            replay.extend(samples);
        }
        // ── Training ─────────────────────────────────────────────────────
        let mut loss_sum = 0.0f32;
        let mut n_steps = 0usize;
        // Skip gradient steps until the buffer can fill one mini-batch.
        if replay.len() >= args.config.batch_size {
            for _ in 0..args.config.n_train_steps_per_iter {
                let batch: Vec<_> = replay
                    .sample_batch(args.config.batch_size, &mut rng)
                    .into_iter()
                    .cloned()
                    .collect();
                // Target Q-values computed on the inference backend.
                let target_q = compute_target_q(
                    &target_net, &batch, cfg.action_size, &infer_device,
                );
                // dqn_train_step consumes and returns the network.
                let (q, loss) = dqn_train_step(
                    q_net, &mut optimizer, &batch, &target_q,
                    &train_device, args.config.learning_rate, args.config.gamma,
                );
                q_net = q;
                loss_sum += loss;
                n_steps += 1;
                global_step += 1;
                // Hard-update target net every target_update_freq steps.
                if global_step % args.config.target_update_freq == 0 {
                    target_net = hard_update::<TrainB, _>(&q_net);
                }
                // Linear epsilon decay.
                epsilon = linear_epsilon(
                    args.config.epsilon_start,
                    args.config.epsilon_end,
                    global_step,
                    args.config.epsilon_decay_steps,
                );
            }
        }
        // ── Logging ──────────────────────────────────────────────────────
        let elapsed = t0.elapsed();
        // NaN loss signals "no gradient steps happened this iteration".
        let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };
        println!(
            "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s",
            iter + 1,
            args.config.n_iterations,
            replay.len(),
            new_samples,
            avg_loss,
            epsilon,
            elapsed.as_secs_f32(),
        );
        // ── Checkpoint ───────────────────────────────────────────────────
        let is_last = iter + 1 == args.config.n_iterations;
        if (iter + 1) % args.save_every == 0 || is_last {
            let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1));
            match save_fn(&q_net, &path) {
                Ok(()) => println!(" -> saved {}", path.display()),
                Err(e) => eprintln!(" Warning: checkpoint save failed: {e}"),
            }
        }
    }
    println!("\nDQN training complete.");
}
// ── Main ──────────────────────────────────────────────────────────────────────
fn main() {
let args = parse_args();
if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
std::process::exit(1);
}
let train_device: <TrainB as Backend>::Device = Default::default();
let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden };
let q_net = match &args.resume {
Some(path) => {
println!("Resuming from {}", path.display());
QNet::<TrainB>::load(&cfg, path, &train_device)
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
}
None => QNet::<TrainB>::new(&cfg, &train_device),
};
train_loop(q_net, &cfg, &|m: &QNet<TrainB>, path| m.valid().save(path), &args);
}

View file

@ -0,0 +1,247 @@
//! DQN self-play episode generation.
//!
//! Both players share the same Q-network (the [`TrictracEnv`] handles board
//! mirroring so that each player always acts from "White's perspective").
//! Transitions for both players are stored in the returned sample list.
//!
//! # Reward
//!
//! After each full decision (action applied and the state has advanced through
//! any intervening chance nodes back to the same player's next turn), the
//! reward is:
//!
//! ```text
//! r = (my_total_score_now - my_total_score_then)
//!     - (opp_total_score_now - opp_total_score_then)
//! ```
//!
//! where `total_score = holes × 12 + points`.
//!
//! # Transition structure
//!
//! We use a "pending transition" per player. When a player acts again, we
//! *complete* the previous pending transition by filling in `next_obs`,
//! `next_legal`, and computing `reward`. Terminal transitions are completed
//! when the game ends.
use burn::tensor::{backend::Backend, Tensor, TensorData};
use rand::Rng;
use crate::env::{GameEnv, TrictracEnv};
use crate::network::QValueNet;
use super::DqnSample;
// ── Internals ─────────────────────────────────────────────────────────────────
/// One half-completed transition: recorded when a player acts, completed
/// (reward / `next_obs` / `next_legal` filled in) the next time the same
/// player moves, or when the game ends.
struct PendingTransition {
    /// Observation the acting player saw when choosing `action`.
    obs: Vec<f32>,
    /// Index of the action taken.
    action: usize,
    /// Score snapshot `[p1_total, p2_total]` at the moment of the action.
    score_before: [i32; 2],
}
/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise.
///
/// The greedy branch runs a single `[1, obs_size]` forward pass and returns
/// the legal action with the highest Q-value. Comparison uses
/// `f32::total_cmp`, a total order, so a NaN Q-value can no longer make the
/// comparison silently collapse to `Ordering::Equal` and pick an arbitrary
/// action — NaN sorts deterministically above all real values instead.
fn epsilon_greedy<B: Backend, Q: QValueNet<B>>(
    q_net: &Q,
    obs: &[f32],
    legal: &[usize],
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
) -> usize {
    debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions");
    if rng.random::<f32>() < epsilon {
        // Exploration: uniform over the legal actions.
        legal[rng.random_range(0..legal.len())]
    } else {
        let obs_tensor = Tensor::<B, 2>::from_data(
            TensorData::new(obs.to_vec(), [1, obs.len()]),
            device,
        );
        let q_values: Vec<f32> = q_net.forward(obs_tensor).into_data().to_vec().unwrap();
        // NOTE(review): assumes every index in `legal` is < q_values.len();
        // holds as long as env action ids match the net's action_size — confirm.
        legal
            .iter()
            .copied()
            .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
            .unwrap()
    }
}
/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after:
/// the player's own score gain minus the opponent's score gain.
fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 {
    let opp_idx = 1 - player_idx;
    let my_gain = score_after[player_idx] - score_before[player_idx];
    let opp_gain = score_after[opp_idx] - score_before[opp_idx];
    (my_gain - opp_gain) as f32
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Play one full game and return all transitions for both players.
///
/// - `q_net` uses the **inference backend** (no autodiff wrapper).
/// - `epsilon` in `[0, 1]`: probability of taking a random action.
/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`).
///
/// Transitions are buffered per player in `pending` and completed the next
/// time the same player acts (or at game end), so each sample's reward spans
/// exactly one full decision of that player (see module docs).
pub fn generate_dqn_episode<B: Backend, Q: QValueNet<B>>(
    env: &TrictracEnv,
    q_net: &Q,
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
    reward_scale: f32,
) -> Vec<DqnSample> {
    let obs_size = env.obs_size();
    let mut state = env.new_game();
    // One pending (not yet completed) transition per player (0 = P1, 1 = P2).
    let mut pending: [Option<PendingTransition>; 2] = [None, None];
    let mut samples: Vec<DqnSample> = Vec::new();
    loop {
        // ── Advance past chance nodes ──────────────────────────────────────
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, rng);
        }
        // Snapshot scores *after* chance resolution so rewards include any
        // points scored between this player's consecutive turns.
        let score_now = TrictracEnv::score_snapshot(&state);
        if env.current_player(&state).is_terminal() {
            // Complete all pending transitions as terminal.
            for player_idx in 0..2 {
                if let Some(prev) = pending[player_idx].take() {
                    let reward =
                        compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
                    samples.push(DqnSample {
                        obs: prev.obs,
                        action: prev.action,
                        reward,
                        // Terminal next state: zero observation, no legal moves.
                        next_obs: vec![0.0; obs_size],
                        next_legal: vec![],
                        done: true,
                    });
                }
            }
            break;
        }
        let player_idx = env.current_player(&state).index().unwrap();
        let legal = env.legal_actions(&state);
        let obs = env.observation(&state, player_idx);
        // ── Complete the previous transition for this player ───────────────
        if let Some(prev) = pending[player_idx].take() {
            let reward =
                compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
            samples.push(DqnSample {
                obs: prev.obs,
                action: prev.action,
                reward,
                next_obs: obs.clone(),
                next_legal: legal.clone(),
                done: false,
            });
        }
        // ── Pick and apply action ──────────────────────────────────────────
        let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device);
        env.apply(&mut state, action);
        // ── Record new pending transition ──────────────────────────────────
        pending[player_idx] = Some(PendingTransition {
            obs,
            action,
            score_before: score_now,
        });
    }
    samples
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use rand::{SeedableRng, rngs::SmallRng};
    use crate::network::{QNet, QNetConfig};

    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device {
        Default::default()
    }

    fn rng() -> SmallRng {
        SmallRng::seed_from_u64(7)
    }

    fn tiny_q() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }

    /// Run one full self-play episode with the given ε (reward_scale = 1.0).
    fn run_episode(epsilon: f32) -> Vec<DqnSample> {
        generate_dqn_episode(&TrictracEnv, &tiny_q(), epsilon, &mut rng(), &device(), 1.0)
    }

    #[test]
    fn episode_terminates_and_produces_samples() {
        assert!(!run_episode(1.0).is_empty(), "episode must produce at least one sample");
    }

    #[test]
    fn episode_obs_size_correct() {
        for s in &run_episode(1.0) {
            assert_eq!(s.obs.len(), 217, "obs size mismatch");
            if s.done {
                assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size");
                assert!(s.next_legal.is_empty());
            } else {
                assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch");
                assert!(!s.next_legal.is_empty());
            }
        }
    }

    #[test]
    fn episode_actions_within_action_space() {
        for s in &run_episode(1.0) {
            assert!(s.action < 514, "action {} out of bounds", s.action);
        }
    }

    #[test]
    fn greedy_episode_also_terminates() {
        assert!(!run_episode(0.0).is_empty());
    }

    #[test]
    fn at_least_one_done_sample() {
        let n_done = run_episode(1.0).iter().filter(|s| s.done).count();
        // Two players, so 1 or 2 terminal transitions.
        assert!((1..=2).contains(&n_done), "expected 1-2 done samples, got {n_done}");
    }

    #[test]
    fn compute_reward_correct() {
        // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged.
        let r = compute_reward(0, &[2 * 12 + 10, 0], &[3 * 12 + 2, 0]);
        assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}");
    }

    #[test]
    fn compute_reward_with_opponent_scoring() {
        // P1 gains 2, opp gains 3 → net = -1 from P1's perspective.
        let r = compute_reward(0, &[0, 0], &[2, 3]);
        assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}");
    }
}

232
spiel_bot/src/dqn/mod.rs Normal file
View file

@ -0,0 +1,232 @@
//! DQN: self-play data generation, replay buffer, and training step.
//!
//! # Algorithm
//!
//! Deep Q-Network with:
//! - **ε-greedy** exploration (linearly decayed).
//! - **Dense per-turn rewards**: `my_score_delta - opponent_score_delta` where
//!   `score = holes × 12 + points`.
//! - **Experience replay** with a fixed-capacity circular buffer.
//! - **Target network**: hard-copied from the online Q-net every
//! `target_update_freq` gradient steps for training stability.
//!
//! # Modules
//!
//! | Module | Contents |
//! |--------|----------|
//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] |
//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] |
pub mod episode;
pub mod trainer;
pub use episode::generate_dqn_episode;
pub use trainer::{compute_target_q, dqn_train_step, hard_update};
use std::collections::VecDeque;
use rand::Rng;
// ── DqnSample ─────────────────────────────────────────────────────────────────
/// One transition `(s, a, r, s', done)` collected during self-play.
#[derive(Clone, Debug)]
pub struct DqnSample {
    /// Observation from the acting player's perspective (`obs_size` floats).
    pub obs: Vec<f32>,
    /// Action index taken.
    pub action: usize,
    /// Per-turn reward: `my_score_delta - opponent_score_delta`, already
    /// divided by the episode's `reward_scale`.
    pub reward: f32,
    /// Next observation from the same player's perspective.
    /// All-zeros when `done = true` (ignored by the TD target).
    pub next_obs: Vec<f32>,
    /// Legal actions at `next_obs`. Empty when `done = true`.
    pub next_legal: Vec<usize>,
    /// `true` when `next_obs` is a terminal state.
    pub done: bool,
}
// ── DqnReplayBuffer ───────────────────────────────────────────────────────────
/// Fixed-capacity circular replay buffer for [`DqnSample`]s.
///
/// When full, the oldest sample is evicted on push.
/// Batches are drawn without replacement via a partial Fisher-Yates shuffle.
pub struct DqnReplayBuffer {
    /// Stored samples, oldest at the front.
    data: VecDeque<DqnSample>,
    /// Maximum number of retained samples.
    capacity: usize,
}

impl DqnReplayBuffer {
    /// Create an empty buffer that retains at most `capacity` samples.
    ///
    /// The initial allocation is capped at 1024 entries so huge capacities do
    /// not reserve memory up front; the deque grows on demand afterwards.
    pub fn new(capacity: usize) -> Self {
        Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity }
    }

    /// Insert `sample`, evicting the oldest entry when the buffer is full.
    pub fn push(&mut self, sample: DqnSample) {
        // A zero-capacity buffer stores nothing. The previous `len() == capacity`
        // check only matched before the first push, after which such a buffer
        // grew without bound.
        if self.capacity == 0 {
            return;
        }
        // `>=` (rather than `==`) maintains `len() <= capacity` even if the
        // deque ever exceeded the limit.
        while self.data.len() >= self.capacity {
            self.data.pop_front();
        }
        self.data.push_back(sample);
    }

    /// Push every sample from `samples`, applying the same eviction policy.
    pub fn extend(&mut self, samples: impl IntoIterator<Item = DqnSample>) {
        for s in samples {
            self.push(s);
        }
    }

    /// Current number of stored samples.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when no samples are stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Sample up to `n` distinct samples without replacement.
    ///
    /// Partial Fisher-Yates: only the first `n` positions of the index
    /// permutation are materialised, so cost is O(len + n).
    pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> {
        let len = self.data.len();
        let n = n.min(len);
        let mut indices: Vec<usize> = (0..len).collect();
        for i in 0..n {
            let j = rng.random_range(i..len);
            indices.swap(i, j);
        }
        indices[..n].iter().map(|&i| &self.data[i]).collect()
    }
}
// ── DqnConfig ─────────────────────────────────────────────────────────────────
/// Top-level DQN hyperparameters for the training loop.
#[derive(Debug, Clone)]
pub struct DqnConfig {
    /// Initial exploration rate (1.0 = fully random).
    pub epsilon_start: f32,
    /// Final exploration rate after decay.
    pub epsilon_end: f32,
    /// Number of gradient steps over which ε decays linearly from start to end.
    ///
    /// Should be calibrated to the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`). A value larger than that
    /// means exploration never reaches `epsilon_end` during the run.
    pub epsilon_decay_steps: usize,
    /// Discount factor γ for the TD target. Typical: 0.99.
    pub gamma: f32,
    /// Hard-copy Q → target every this many gradient steps.
    ///
    /// Should be much smaller than the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`), otherwise the target net
    /// is rarely (or never) refreshed.
    pub target_update_freq: usize,
    /// Adam learning rate.
    pub learning_rate: f64,
    /// Mini-batch size for each gradient step.
    pub batch_size: usize,
    /// Maximum number of samples in the replay buffer.
    pub replay_capacity: usize,
    /// Number of outer iterations (self-play + train).
    pub n_iterations: usize,
    /// Self-play games per iteration.
    pub n_games_per_iter: usize,
    /// Gradient steps per iteration.
    pub n_train_steps_per_iter: usize,
    /// Reward normalisation divisor.
    ///
    /// Per-turn rewards (score delta) are divided by this constant before being
    /// stored. Without normalisation, rewards can reach ±24 (jan with
    /// bredouille = 12 pts × 2), driving Q-values into the hundreds and
    /// causing MSE loss to grow unboundedly.
    ///
    /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping
    /// Q-value magnitudes in a stable range. Set to `1.0` to disable.
    pub reward_scale: f32,
}

impl Default for DqnConfig {
    /// Defaults sized so one run fully anneals ε and refreshes the target net.
    fn default() -> Self {
        // Total gradient steps with these defaults = 500 × 20 = 10_000,
        // so epsilon decays fully and the target is updated 100 times.
        Self {
            epsilon_start: 1.0,
            epsilon_end: 0.05,
            epsilon_decay_steps: 10_000,
            gamma: 0.99,
            target_update_freq: 100,
            learning_rate: 1e-3,
            batch_size: 64,
            replay_capacity: 50_000,
            n_iterations: 500,
            n_games_per_iter: 10,
            n_train_steps_per_iter: 20,
            reward_scale: 12.0,
        }
    }
}
/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps.
///
/// Returns `end` once `step >= decay_steps`, and immediately when
/// `decay_steps == 0`.
pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 {
    if step < decay_steps {
        // decay_steps > 0 on this branch, so the division is well-defined.
        start + (end - start) * (step as f32 / decay_steps as f32)
    } else {
        // Covers decay_steps == 0 and the post-decay plateau.
        end
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};

    /// Minimal non-done sample whose only distinguishing field is `reward`.
    fn dummy(reward: f32) -> DqnSample {
        DqnSample {
            obs: vec![0.0],
            action: 0,
            reward,
            next_obs: vec![0.0],
            next_legal: vec![0],
            done: false,
        }
    }

    #[test]
    fn push_and_len() {
        let mut buf = DqnReplayBuffer::new(10);
        assert!(buf.is_empty());
        for r in [1.0, 2.0] {
            buf.push(dummy(r));
        }
        assert_eq!(buf.len(), 2);
    }

    #[test]
    fn evicts_oldest_at_capacity() {
        let mut buf = DqnReplayBuffer::new(3);
        for r in 1..=4 {
            buf.push(dummy(r as f32));
        }
        assert_eq!(buf.len(), 3);
        assert_eq!(buf.data[0].reward, 2.0);
    }

    #[test]
    fn sample_batch_size() {
        let mut buf = DqnReplayBuffer::new(20);
        (0..10).for_each(|i| buf.push(dummy(i as f32)));
        let mut rng = SmallRng::seed_from_u64(0);
        assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
    }

    #[test]
    fn linear_epsilon_start() {
        assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn linear_epsilon_end() {
        assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6);
    }

    #[test]
    fn linear_epsilon_monotone() {
        let mut prev = f32::INFINITY;
        for step in 0..=100 {
            let e = linear_epsilon(1.0, 0.05, step, 100);
            assert!(e <= prev + 1e-6);
            prev = e;
        }
    }
}

View file

@ -0,0 +1,278 @@
//! DQN gradient step and target-network management.
//!
//! # TD target
//!
//! ```text
//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done
//! y_i = r_i if done
//! ```
//!
//! # Loss
//!
//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
//! and `y_i` (computed from the frozen target net).
//!
//! # Target network
//!
//! [`hard_update`] copies the online Q-net weights into the target net by
//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
use burn::{
module::AutodiffModule,
optim::{GradientsParams, Optimizer},
prelude::ElementConversion,
tensor::{
Int, Tensor, TensorData,
backend::{AutodiffBackend, Backend},
},
};
use crate::network::QValueNet;
use super::DqnSample;
// ── Target Q computation ─────────────────────────────────────────────────────
/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
///
/// Returns a `Vec<f32>` of length `batch.len()`. Done samples get `0.0`
/// (their bootstrap term is dropped by the TD target anyway).
///
/// The target network runs on the **inference backend** with no gradient
/// tape, so this function is backend-agnostic (`B: Backend`).
pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
    target_net: &Q,
    batch: &[DqnSample],
    action_size: usize,
    device: &B::Device,
) -> Vec<f32> {
    let mut result = vec![0.0f32; batch.len()];

    // Only non-done samples have a next state worth evaluating.
    let live: Vec<usize> = (0..batch.len()).filter(|&i| !batch[i].done).collect();
    if live.is_empty() {
        return result;
    }

    // Stack the live next-observations into one [nd, obs_size] tensor.
    let obs_size = batch[0].next_obs.len();
    let flat: Vec<f32> = live
        .iter()
        .flat_map(|&i| batch[i].next_obs.iter().copied())
        .collect();
    let obs_tensor =
        Tensor::<B, 2>::from_data(TensorData::new(flat, [live.len(), obs_size]), device);

    // One forward pass → [nd, action_size], flattened row-major.
    let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();

    // Per live sample: max Q over its legal next actions.
    for (row, &i) in live.iter().enumerate() {
        let base = row * action_size;
        let best = batch[i]
            .next_legal
            .iter()
            .map(|&a| q_flat[base + a])
            .fold(f32::NEG_INFINITY, f32::max);
        // If legal is empty (shouldn't happen for non-done, but be safe):
        result[i] = if best.is_finite() { best } else { 0.0 };
    }
    result
}
// ── Training step ─────────────────────────────────────────────────────────────
/// Run one gradient step on `q_net` using `batch`.
///
/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
/// frozen target network and passed in here so that this function only
/// needs the **autodiff backend**.
///
/// # Panics
/// Panics when `batch` is empty or when `batch` and `target_max_q` differ
/// in length.
///
/// Returns the updated network and the scalar MSE loss.
pub fn dqn_train_step<B, Q, O>(
    q_net: Q,
    optimizer: &mut O,
    batch: &[DqnSample],
    target_max_q: &[f32],
    device: &B::Device,
    lr: f64,
    gamma: f32,
) -> (Q, f32)
where
    B: AutodiffBackend,
    Q: QValueNet<B> + AutodiffModule<B>,
    O: Optimizer<Q, B>,
{
    assert!(!batch.is_empty(), "dqn_train_step: empty batch");
    assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");
    let batch_size = batch.len();
    let obs_size = batch[0].obs.len();
    // ── Build observation tensor [B, obs_size] ────────────────────────────
    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [batch_size, obs_size]),
        device,
    );
    // ── Forward Q-net → [B, action_size] ─────────────────────────────────
    let q_all = q_net.forward(obs_tensor);
    // ── Gather Q(s, a) for the taken action → [B] ────────────────────────
    let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
    let action_tensor: Tensor<B, 2, Int> = Tensor::<B, 1, Int>::from_data(
        TensorData::new(actions, [batch_size]),
        device,
    )
    .reshape([batch_size, 1]); // [B] → [B, 1]
    let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]
    // ── TD targets: r + γ · max_next_q · (1 - done) ──────────────────────
    let targets: Vec<f32> = batch
        .iter()
        .zip(target_max_q.iter())
        .map(|(s, &max_q)| {
            // Terminal transitions have no bootstrap term.
            if s.done { s.reward } else { s.reward + gamma * max_q }
        })
        .collect();
    let target_tensor = Tensor::<B, 1>::from_data(
        TensorData::new(targets, [batch_size]),
        device,
    );
    // ── MSE loss ──────────────────────────────────────────────────────────
    // `detach` keeps the target out of the gradient graph (the targets are
    // built from plain data here, so this is belt-and-braces).
    let diff = q_pred - target_tensor.detach();
    let loss = (diff.clone() * diff).mean();
    let loss_scalar: f32 = loss.clone().into_scalar().elem();
    // ── Backward + optimizer step ─────────────────────────────────────────
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &q_net);
    let q_net = optimizer.step(lr, q_net, grads);
    (q_net, loss_scalar)
}
// ── Target network update ─────────────────────────────────────────────────────
/// Hard-copy the online Q-net weights to a new target network.
///
/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
/// inference-backend module with identical weights. Per the module docs,
/// the training loop calls this every `target_update_freq` gradient steps.
pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
    q_net.valid()
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::AdamConfig,
    };
    use crate::network::{QNet, QNetConfig};

    // Inference backend (no gradient tape) for the frozen target net.
    type InferB = NdArray<f32>;
    // Autodiff backend for the online net being trained.
    type TrainB = Autodiff<NdArray<f32>>;

    fn infer_device() -> <InferB as Backend>::Device { Default::default() }
    fn train_device() -> <TrainB as Backend>::Device { Default::default() }

    /// Batch of `n` constant-observation samples with alternating ±1 rewards;
    /// only the last sample is marked `done`.
    fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
        (0..n)
            .map(|i| DqnSample {
                obs: vec![0.5f32; obs_size],
                action: i % action_size,
                reward: if i % 2 == 0 { 1.0 } else { -1.0 },
                next_obs: vec![0.5f32; obs_size],
                next_legal: vec![0, 1],
                done: i == n - 1,
            })
            .collect()
    }

    #[test]
    fn compute_target_q_length() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let batch = dummy_batch(8, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        assert_eq!(tq.len(), 8);
    }

    #[test]
    fn compute_target_q_done_is_zero() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        // Single done sample.
        let batch = vec![DqnSample {
            obs: vec![0.0; 4],
            action: 0,
            reward: 5.0,
            next_obs: vec![0.0; 4],
            next_legal: vec![],
            done: true,
        }];
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        assert_eq!(tq.len(), 1);
        assert_eq!(tq[0], 0.0);
    }

    #[test]
    fn train_step_returns_finite_loss() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
        let q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(8, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99);
        assert!(loss.is_finite(), "loss must be finite, got {loss}");
    }

    // NOTE(review): despite the name, this test only asserts that the FINAL
    // loss after 10 steps is below 5.0 — `prev_loss` is overwritten each
    // iteration without comparing against the previous value, so strict
    // monotone decrease is not actually checked.
    #[test]
    fn train_step_loss_decreases() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
        let mut q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = QNet::<InferB>::new(&cfg, &infer_device());
        let mut optimizer = AdamConfig::new().init();
        let batch = dummy_batch(16, 4, 4);
        let tq = compute_target_q(&target, &batch, 4, &infer_device());
        let mut prev_loss = f32::INFINITY;
        for _ in 0..10 {
            let (q, loss) = dqn_train_step(
                q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99,
            );
            q_net = q;
            assert!(loss.is_finite());
            prev_loss = loss;
        }
        assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}");
    }

    #[test]
    fn hard_update_copies_weights() {
        let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
        let q_net = QNet::<TrainB>::new(&cfg, &train_device());
        let target = hard_update::<TrainB, _>(&q_net);
        let obs = burn::tensor::Tensor::<InferB, 2>::zeros([1, 4], &infer_device());
        let q_out: Vec<f32> = target.forward(obs).into_data().to_vec().unwrap();
        // After hard_update the target produces finite outputs.
        assert!(q_out.iter().all(|v| v.is_finite()));
    }
}

View file

@ -200,6 +200,18 @@ impl GameEnv for TrictracEnv {
}
}
// ── DQN helpers ───────────────────────────────────────────────────────────────
impl TrictracEnv {
    /// Score snapshot for DQN reward computation (consumed by
    /// `generate_dqn_episode` to build per-turn score deltas).
    ///
    /// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`
    /// as computed by `GameState::total_score`.
    /// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2).
    pub fn score_snapshot(s: &GameState) -> [i32; 2] {
        [s.total_score(1), s.total_score(2)]
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]

View file

@ -1,4 +1,5 @@
pub mod alphazero;
pub mod dqn;
pub mod env;
pub mod mcts;
pub mod network;

View file

@ -403,10 +403,10 @@ mod tests {
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
assert_eq!(root.n, 1 + config.n_simulations as u32);
// Children visit counts may sum to less than n_simulations when some
// simulations cross a chance node at depth 1 (turn ends after one move)
// and evaluate with the network directly without updating child.n.
// Every simulation crosses a chance node at depth 1 (dice roll after
// the player's move). Since the fix now updates child.n in that case,
// children visit counts must sum to exactly n_simulations.
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
assert!(total <= config.n_simulations as u32);
assert_eq!(total, config.n_simulations as u32);
}
}

View file

@ -166,6 +166,12 @@ pub(super) fn simulate<E: GameEnv>(
// previously cached children would be for a different outcome.
let obs = env.observation(&next_state, child_player);
let (_, value) = evaluator.evaluate(&obs);
// Record the visit so that PUCT and mcts_policy use real counts.
// Without this, child.n stays 0 for every simulation in games where
// every player action is immediately followed by a chance node (e.g.
// Trictrac), causing mcts_policy to always return a uniform policy.
child.n += 1;
child.w += value;
value
} else if child.expanded {
simulate(child, next_state, env, evaluator, config, rng, child_player)

View file

@ -43,9 +43,11 @@
//! before passing to softmax.
pub mod mlp;
pub mod qnet;
pub mod resnet;
pub use mlp::{MlpConfig, MlpNet};
pub use qnet::{QNet, QNetConfig};
pub use resnet::{ResNet, ResNetConfig};
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
/// - `obs`: `[batch, obs_size]`
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
///
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
}
/// A neural network that outputs one Q-value per action.
///
/// # Shapes
/// - `obs`: `[batch, obs_size]`
/// - output: `[batch, action_size]` — raw Q-values (no activation)
///
/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
    /// Forward pass mapping a batch of observations to per-action Q-values.
    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
}

View file

@ -0,0 +1,147 @@
//! Single-headed Q-value network for DQN.
//!
//! ```text
//! Input [B, obs_size]
//! → Linear(obs → hidden) → ReLU
//! → Linear(hidden → hidden) → ReLU
//! → Linear(hidden → action_size) ← raw Q-values, no activation
//! ```
use burn::{
module::Module,
nn::{Linear, LinearConfig},
record::{CompactRecorder, Recorder},
tensor::{activation::relu, backend::Backend, Tensor},
};
use std::path::Path;
use super::QValueNet;
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for [`QNet`].
#[derive(Debug, Clone)]
pub struct QNetConfig {
    /// Number of input features. 217 for Trictrac's `to_tensor()`.
    pub obs_size: usize,
    /// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
    pub action_size: usize,
    /// Width of both hidden layers.
    pub hidden_size: usize,
}

impl Default for QNetConfig {
    /// Trictrac-sized defaults: 217 inputs, 514 actions, 256 hidden units.
    fn default() -> Self {
        Self { obs_size: 217, action_size: 514, hidden_size: 256 }
    }
}

// ── Network ───────────────────────────────────────────────────────────────────
/// Two-hidden-layer MLP that outputs one Q-value per action.
///
/// Layer sizes come from [`QNetConfig`]; see the module docs for the topology.
#[derive(Module, Debug)]
pub struct QNet<B: Backend> {
    // Input → hidden projection.
    fc1: Linear<B>,
    // Hidden → hidden projection.
    fc2: Linear<B>,
    // Hidden → per-action Q-values (no activation).
    q_head: Linear<B>,
}
impl<B: Backend> QNet<B> {
    /// Construct a fresh network with random weights.
    pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
        let hidden = config.hidden_size;
        Self {
            fc1: LinearConfig::new(config.obs_size, hidden).init(device),
            fc2: LinearConfig::new(hidden, hidden).init(device),
            q_head: LinearConfig::new(hidden, config.action_size).init(device),
        }
    }

    /// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
    pub fn save(&self, path: &Path) -> anyhow::Result<()> {
        let record = self.clone().into_record();
        CompactRecorder::new()
            .record(record, path.to_path_buf())
            .map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
    }

    /// Load weights from `path` into a fresh model built from `config`.
    pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
        let record = CompactRecorder::new()
            .load(path.to_path_buf(), device)
            .map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
        let fresh = Self::new(config, device);
        Ok(fresh.load_record(record))
    }
}
impl<B: Backend> QValueNet<B> for QNet<B> {
    /// Two ReLU hidden layers followed by a linear Q head.
    fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
        let h1 = relu(self.fc1.forward(obs));
        let h2 = relu(self.fc2.forward(h1));
        self.q_head.forward(h2)
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device { Default::default() }

    /// Fresh randomly-initialised net with the Trictrac-sized default config.
    fn default_net() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }

    #[test]
    fn forward_output_shape() {
        let net = default_net();
        let obs = Tensor::zeros([4, 217], &device());
        let q = net.forward(obs);
        assert_eq!(q.dims(), [4, 514]);
    }

    #[test]
    fn forward_single_sample() {
        let net = default_net();
        let q = net.forward(Tensor::zeros([1, 217], &device()));
        assert_eq!(q.dims(), [1, 514]);
    }

    // Random init should give the head distinct biases, so a zero input must
    // not map every action to the same Q-value.
    #[test]
    fn q_values_not_all_equal() {
        let net = default_net();
        let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
            .into_data().to_vec().unwrap();
        let first = q[0];
        assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
    }

    #[test]
    fn custom_config_shapes() {
        let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
        let net = QNet::<B>::new(&cfg, &device());
        let q = net.forward(Tensor::zeros([3, 10], &device()));
        assert_eq!(q.dims(), [3, 20]);
    }

    // Round-trip through save/load and compare outputs on the same input
    // (loose 1e-3 tolerance to absorb any recorder precision loss).
    #[test]
    fn save_load_preserves_weights() {
        let net = default_net();
        let obs = Tensor::<B, 2>::ones([2, 217], &device());
        let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();
        let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk");
        net.save(&path).expect("save failed");
        let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
        let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();
        for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
            assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
        }
        let _ = std::fs::remove_file(path);
    }
}

View file

@ -1011,6 +1011,16 @@ impl GameState {
self.mark_points(player_id, points)
}
/// Total accumulated score for a player: `holes × 12 + points`.
///
/// Returns `0` if `player_id` is not found (e.g. before `init_player`).
pub fn total_score(&self, player_id: PlayerId) -> i32 {
    self.players
        .get(&player_id)
        .map_or(0, |p| p.holes as i32 * 12 + p.points as i32)
}
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
// Update player points and holes
let mut new_hole = false;