Compare commits
5 commits
bb6ef47a5f
...
1554286f25
| Author | SHA1 | Date | |
|---|---|---|---|
| 1554286f25 | |||
| 31bb568c2a | |||
| e80dade303 | |||
| e7d13c9a02 | |||
| 7c0f230e3d |
17 changed files with 1603 additions and 8 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -6009,6 +6009,7 @@ dependencies = [
|
||||||
"criterion",
|
"criterion",
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
|
"rayon",
|
||||||
"trictrac-store",
|
"trictrac-store",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
121
doc/spiel_bot_parallel.md
Normal file
121
doc/spiel_bot_parallel.md
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
Part B — Batched MCTS leaf evaluation
|
||||||
|
|
||||||
|
Goal: during a single game's MCTS, accumulate eval_batch_size leaf observations and call the network once with a [B, obs_size] tensor instead of B separate [1, obs_size] calls.
|
||||||
|
|
||||||
|
Step B1 — Add evaluate_batch to the Evaluator trait (mcts/mod.rs)
|
||||||
|
|
||||||
|
pub trait Evaluator: Send + Sync {
|
||||||
|
fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
|
||||||
|
|
||||||
|
/// Evaluate a batch of observations at once. Default falls back to
|
||||||
|
/// sequential calls; backends override this for efficiency.
|
||||||
|
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
|
||||||
|
obs_batch.iter().map(|obs| self.evaluate(obs)).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
Step B2 — Implement evaluate_batch in BurnEvaluator (selfplay.rs)
|
||||||
|
|
||||||
|
Stack all observations into one [B, obs_size] tensor, call model.forward once, split the output tensors back into B rows.
|
||||||
|
|
||||||
|
fn evaluate_batch(&self, obs_batch: &[&[f32]]) -> Vec<(Vec<f32>, f32)> {
|
||||||
|
let b = obs_batch.len();
|
||||||
|
let obs_size = obs_batch[0].len();
|
||||||
|
let flat: Vec<f32> = obs_batch.iter().flat_map(|o| o.iter().copied()).collect();
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(TensorData::new(flat, [b, obs_size]), &self.device);
|
||||||
|
let (policy_tensor, value_tensor) = self.model.forward(obs_tensor);
|
||||||
|
let policies: Vec<f32> = policy_tensor.into_data().to_vec().unwrap();
|
||||||
|
let values: Vec<f32> = value_tensor.into_data().to_vec().unwrap();
|
||||||
|
let action_size = policies.len() / b;
|
||||||
|
(0..b).map(|i| {
|
||||||
|
(policies[i * action_size..(i + 1) * action_size].to_vec(), values[i])
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
Step B3 — Add eval_batch_size to MctsConfig
|
||||||
|
|
||||||
|
pub struct MctsConfig {
|
||||||
|
// ... existing fields ...
|
||||||
|
/// Number of leaves to batch per network call. 1 = no batching (current behaviour).
|
||||||
|
pub eval_batch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
Default: 1 (backwards-compatible).
|
||||||
|
|
||||||
|
Step B4 — Make the simulation iterative (mcts/search.rs)
|
||||||
|
|
||||||
|
The current simulate is recursive. For batching we need to split it into two phases:
|
||||||
|
|
||||||
|
descend (pure tree traversal — no network call):
|
||||||
|
|
||||||
|
- Traverse from root following PUCT, advancing through chance nodes with apply_chance.
|
||||||
|
- Stop when reaching: an unvisited leaf, a terminal node, or a node whose child was already selected by another in-flight descent (virtual loss in effect).
|
||||||
|
- Return a LeafWork { path: Vec<usize>, state: E::State, player_idx: usize, kind: LeafKind } where path is the sequence of child indices taken from the root and kind is NeedsEval | Terminal(value) | CrossedChance.
|
||||||
|
- Apply virtual loss along the path during descent: n += 1, w -= 1 at every node traversed. This steers the next concurrent descent away from the same path.
|
||||||
|
|
||||||
|
ascend (backup — no network call):
|
||||||
|
|
||||||
|
- Given the path and the evaluated value, walk back up the path re-negating at player-boundary transitions.
|
||||||
|
- Undo the virtual loss: n -= 1, w += 1, then add the real update: n += 1, w += value.
|
||||||
|
|
||||||
|
Step B5 — Add run_mcts_batched to mcts/mod.rs
|
||||||
|
|
||||||
|
The new entry point, called by run_mcts when config.eval_batch_size > 1:
|
||||||
|
|
||||||
|
expand root (1 network call)
|
||||||
|
optionally add Dirichlet noise
|
||||||
|
|
||||||
|
for round in 0..(n_simulations / batch_size):
|
||||||
|
leaves = []
|
||||||
|
for _ in 0..batch_size:
|
||||||
|
leaf = descend(root, state, env, rng)
|
||||||
|
leaves.push(leaf)
|
||||||
|
|
||||||
|
obs_batch = [env.observation(leaf.state, leaf.player_idx) for leaf in leaves
|
||||||
|
where leaf.kind == NeedsEval]
|
||||||
|
results = evaluator.evaluate_batch(obs_batch)
|
||||||
|
|
||||||
|
for (leaf, result) in zip(leaves, results):
|
||||||
|
expand the leaf node (insert children from result.policy)
|
||||||
|
ascend(root, leaf.path, result.value, leaf.player_idx)
|
||||||
|
// ascend also handles terminal and crossed-chance leaves
|
||||||
|
|
||||||
|
// handle remainder: n_simulations % batch_size
|
||||||
|
|
||||||
|
run_mcts becomes a thin dispatcher:
|
||||||
|
if config.eval_batch_size <= 1 {
|
||||||
|
// existing path (unchanged)
|
||||||
|
} else {
|
||||||
|
run_mcts_batched(...)
|
||||||
|
}
|
||||||
|
|
||||||
|
Step B6 — CLI flag in az_train.rs
|
||||||
|
|
||||||
|
--eval-batch N default: 8 Leaf batch size for MCTS network calls
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Summary of file changes
|
||||||
|
|
||||||
|
┌───────────────────────────┬──────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ File │ Changes │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ spiel_bot/Cargo.toml │ add rayon │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/mcts/mod.rs │ evaluate_batch on trait; eval_batch_size in MctsConfig; run_mcts_batched │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/mcts/search.rs │ descend (iterative, virtual loss); ascend (backup path); expand_at_path │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/alphazero/selfplay.rs │ BurnEvaluator::evaluate_batch │
|
||||||
|
├───────────────────────────┼──────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ src/bin/az_train.rs │ parallel game loop (rayon); --eval-batch flag │
|
||||||
|
└───────────────────────────┴──────────────────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Key design constraint
|
||||||
|
|
||||||
|
Parts A and B are independent and composable:
|
||||||
|
|
||||||
|
- A only touches the outer game loop.
|
||||||
|
- B only touches the inner MCTS per game.
|
||||||
|
- Together: each of the N parallel games runs its own batched MCTS tree entirely independently with no shared state.
|
||||||
253
doc/tensor_research.md
Normal file
253
doc/tensor_research.md
Normal file
|
|
@ -0,0 +1,253 @@
|
||||||
|
# Tensor research
|
||||||
|
|
||||||
|
## Current tensor anatomy
|
||||||
|
|
||||||
|
[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!)
|
||||||
|
[24] active player color: 0 or 1
|
||||||
|
[25] turn_stage: 1–5
|
||||||
|
[26–27] dice values (raw 1–6)
|
||||||
|
[28–31] white: points, holes, can_bredouille, can_big_bredouille
|
||||||
|
[32–35] black: same
|
||||||
|
─────────────────────────────────
|
||||||
|
Total 36 floats
|
||||||
|
|
||||||
|
The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's
|
||||||
|
AlphaZero uses a fully-connected network.
|
||||||
|
|
||||||
|
### Fundamental problems with the current encoding
|
||||||
|
|
||||||
|
1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network
|
||||||
|
must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with
|
||||||
|
all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from.
|
||||||
|
|
||||||
|
2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient
|
||||||
|
flow during training is uneven.
|
||||||
|
|
||||||
|
3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it
|
||||||
|
triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers
|
||||||
|
produces a score. Including this explicitly is the single highest-value addition.
|
||||||
|
|
||||||
|
4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely
|
||||||
|
different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0.
|
||||||
|
|
||||||
|
5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the
|
||||||
|
starting position). It's in the Player struct but not exported.
|
||||||
|
|
||||||
|
## Key Trictrac distinctions from backgammon that shape the encoding
|
||||||
|
|
||||||
|
| Concept | Backgammon | Trictrac |
|
||||||
|
| ------------------------- | ---------------------- | --------------------------------------------------------- |
|
||||||
|
| Hitting a blot | Removes checker to bar | Scores points, checker stays |
|
||||||
|
| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened |
|
||||||
|
| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) |
|
||||||
|
| 3-checker field | Safe with spare | Safe with spare |
|
||||||
|
| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) |
|
||||||
|
| Both colors on a field | Impossible | Perfectly legal |
|
||||||
|
| Rest corner (field 12/13) | Does not exist | Special two-checker rules |
|
||||||
|
|
||||||
|
The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary
|
||||||
|
indicators directly teaches the network the phase transitions the game hinges on.
|
||||||
|
|
||||||
|
## Options
|
||||||
|
|
||||||
|
### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D)
|
||||||
|
|
||||||
|
The minimum viable improvement.
|
||||||
|
|
||||||
|
For each of the 24 fields, encode own and opponent separately with 4 indicators each:
|
||||||
|
|
||||||
|
own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target)
|
||||||
|
own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill)
|
||||||
|
own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare)
|
||||||
|
own_x[i]: max(0, count − 3) (overflow)
|
||||||
|
opp_1[i]: same for opponent
|
||||||
|
…
|
||||||
|
|
||||||
|
Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec().
|
||||||
|
|
||||||
|
Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204
|
||||||
|
Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured
|
||||||
|
overhead is negligible.
|
||||||
|
Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from
|
||||||
|
scratch will be substantially faster and the converged policy quality better.
|
||||||
|
|
||||||
|
### Option B — Option A + Trictrac-specific derived features (flat 1D)
|
||||||
|
|
||||||
|
Recommended starting point.
|
||||||
|
|
||||||
|
Add on top of Option A:
|
||||||
|
|
||||||
|
// Quarter fill status — the single most important derived feature
|
||||||
|
quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields)
|
||||||
|
quarter_filled_opp[q] (q=0..3): same for opponent
|
||||||
|
→ 8 values
|
||||||
|
|
||||||
|
// Exit readiness
|
||||||
|
can_exit_own: 1.0 if all own checkers are in fields 19–24
|
||||||
|
can_exit_opp: same for opponent
|
||||||
|
→ 2 values
|
||||||
|
|
||||||
|
// Rest corner status (field 12/13)
|
||||||
|
own_corner_taken: 1.0 if field 12 has ≥2 own checkers
|
||||||
|
opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers
|
||||||
|
→ 2 values
|
||||||
|
|
||||||
|
// Jan de 3 coups counter (normalized)
|
||||||
|
dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0)
|
||||||
|
→ 1 value
|
||||||
|
|
||||||
|
Size: 204 + 8 + 2 + 2 + 1 = 217
|
||||||
|
Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve
|
||||||
|
the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes
|
||||||
|
expensive inference from the network.
|
||||||
|
|
||||||
|
### Option C — Option B + richer positional features (flat 1D)
|
||||||
|
|
||||||
|
More complete, higher sample efficiency, minor extra cost.
|
||||||
|
|
||||||
|
Add on top of Option B:
|
||||||
|
|
||||||
|
// Per-quarter fill fraction — how close to filling each quarter
|
||||||
|
own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0
|
||||||
|
opp_quarter_fill_fraction[q] (q=0..3): same for opponent
|
||||||
|
→ 8 values
|
||||||
|
|
||||||
|
// Blot counts — number of own/opponent single-checker fields globally
|
||||||
|
// (tells the network at a glance how much battage risk/opportunity exists)
|
||||||
|
own_blot_count: (number of own fields with exactly 1 checker) / 15.0
|
||||||
|
opp_blot_count: same for opponent
|
||||||
|
→ 2 values
|
||||||
|
|
||||||
|
// Bredouille would-double multiplier (already present, but explicitly scaled)
|
||||||
|
// No change needed, already binary
|
||||||
|
|
||||||
|
Size: 217 + 8 + 2 = 227
|
||||||
|
Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network
|
||||||
|
from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts).
|
||||||
|
|
||||||
|
### Option D — 2D spatial tensor {K, 24}
|
||||||
|
|
||||||
|
For CNN-based networks. Best eventual architecture but requires changing the training setup.
|
||||||
|
|
||||||
|
Shape {14, 24} — 14 feature channels over 24 field positions:
|
||||||
|
|
||||||
|
Channel 0: own_count_1 (blot)
|
||||||
|
Channel 1: own_count_2
|
||||||
|
Channel 2: own_count_3
|
||||||
|
Channel 3: own_count_overflow (float)
|
||||||
|
Channel 4: opp_count_1
|
||||||
|
Channel 5: opp_count_2
|
||||||
|
Channel 6: opp_count_3
|
||||||
|
Channel 7: opp_count_overflow
|
||||||
|
Channel 8: own_corner_mask (1.0 at field 12)
|
||||||
|
Channel 9: opp_corner_mask (1.0 at field 13)
|
||||||
|
Channel 10: final_quarter_mask (1.0 at fields 19–24)
|
||||||
|
Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter)
|
||||||
|
Channel 12: quarter_filled_opp (same for opponent)
|
||||||
|
Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers)
|
||||||
|
|
||||||
|
Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform
|
||||||
|
value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global"
|
||||||
|
row by returning shape {K, 25} with position 0 holding global features.
|
||||||
|
|
||||||
|
Size: 14 × 24 + few global channels ≈ 336–384
|
||||||
|
C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated
|
||||||
|
accordingly.
|
||||||
|
Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's
|
||||||
|
alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H,
|
||||||
|
W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension.
|
||||||
|
Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel
|
||||||
|
size 2–3 captures adjacent-field "tout d'une" interactions.
|
||||||
|
|
||||||
|
### On 3D tensors
|
||||||
|
|
||||||
|
Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is
|
||||||
|
the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and
|
||||||
|
field-within-quarter patterns jointly.
|
||||||
|
|
||||||
|
However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The
|
||||||
|
extra architecture work makes this premature unless you're already building a custom network. The information content
|
||||||
|
is the same as Option D.
|
||||||
|
|
||||||
|
### Recommendation
|
||||||
|
|
||||||
|
Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and
|
||||||
|
the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions
|
||||||
|
(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move.
|
||||||
|
|
||||||
|
Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when
|
||||||
|
the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time.
|
||||||
|
|
||||||
|
Costs summary:
|
||||||
|
|
||||||
|
| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain |
|
||||||
|
| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- |
|
||||||
|
| Current | 36 | — | — | — | baseline |
|
||||||
|
| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) |
|
||||||
|
| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) |
|
||||||
|
| C | 227 | to_vec() rewrite | constant update | none | large + moderate |
|
||||||
|
| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial |
|
||||||
|
|
||||||
|
One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2,
|
||||||
|
the new to_vec() must express everything from the active player's perspective (which the mirror already handles for
|
||||||
|
the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state,
|
||||||
|
which they will be if computed inside to_vec().
|
||||||
|
|
||||||
|
## Other algorithms
|
||||||
|
|
||||||
|
The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully.
|
||||||
|
|
||||||
|
### 1. Without MCTS, feature quality matters more
|
||||||
|
|
||||||
|
AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred
|
||||||
|
simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup —
|
||||||
|
the network must learn the full strategic structure directly from gradient updates.
|
||||||
|
|
||||||
|
This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO,
|
||||||
|
not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less
|
||||||
|
sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from
|
||||||
|
scratch.
|
||||||
|
|
||||||
|
### 2. Normalization becomes mandatory, not optional
|
||||||
|
|
||||||
|
AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps
|
||||||
|
Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw
|
||||||
|
values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can
|
||||||
|
be unstable.
|
||||||
|
|
||||||
|
For DQN/PPO, every feature in the tensor should be in [0, 1]:
|
||||||
|
|
||||||
|
dice values: / 6.0
|
||||||
|
checker counts: overflow channel / 12.0
|
||||||
|
points: / 12.0
|
||||||
|
holes: / 12.0
|
||||||
|
dice_roll_count: / 3.0 (clamped)
|
||||||
|
|
||||||
|
Booleans and the TD-Gammon binary indicators are already in [0, 1].
|
||||||
|
|
||||||
|
### 3. The shape question depends on architecture, not algorithm
|
||||||
|
|
||||||
|
| Architecture | Shape | When to use |
|
||||||
|
| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- |
|
||||||
|
| MLP | {217} flat | Any algorithm, simplest baseline |
|
||||||
|
| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) |
|
||||||
|
| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network |
|
||||||
|
| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now |
|
||||||
|
|
||||||
|
The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want
|
||||||
|
convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet
|
||||||
|
template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all
|
||||||
|
— you just reshape the flat tensor in Python before passing it to the network.
|
||||||
|
|
||||||
|
### One thing that genuinely changes: value function perspective
|
||||||
|
|
||||||
|
AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This
|
||||||
|
works well.
|
||||||
|
|
||||||
|
DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit
|
||||||
|
current-player indicator), because a single Q-network estimates action values for both players simultaneously. With
|
||||||
|
the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must
|
||||||
|
learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity.
|
||||||
|
This is a minor point for a symmetric game like Trictrac, but worth keeping in mind.
|
||||||
|
|
||||||
|
Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm.
|
||||||
|
|
@ -9,6 +9,7 @@ anyhow = "1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
rand_distr = "0.5"
|
rand_distr = "0.5"
|
||||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
|
rayon = "1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,10 @@ impl<B: Backend, N: PolicyValueNet<B>> BurnEvaluator<B, N> {
|
||||||
pub fn into_model(self) -> N {
|
pub fn into_model(self) -> N {
|
||||||
self.model
|
self.model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn model_ref(&self) -> &N {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
// Safety: NdArray<f32> modules are Send; we never share across threads without
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,8 @@ use burn::{
|
||||||
optim::AdamConfig,
|
optim::AdamConfig,
|
||||||
tensor::backend::Backend,
|
tensor::backend::Backend,
|
||||||
};
|
};
|
||||||
use rand::{SeedableRng, rngs::SmallRng};
|
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
use spiel_bot::{
|
use spiel_bot::{
|
||||||
alphazero::{
|
alphazero::{
|
||||||
|
|
@ -195,10 +196,26 @@ where
|
||||||
if step < temp_drop { 1.0 } else { 0.0 }
|
if step < temp_drop { 1.0 } else { 0.0 }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Prepare per-game seeds and evaluators sequentially so the main RNG
|
||||||
|
// and model cloning stay deterministic regardless of thread scheduling.
|
||||||
|
// Burn modules are Send but not Sync, so each task must own its model.
|
||||||
|
let game_seeds: Vec<u64> = (0..args.n_games).map(|_| rng.random()).collect();
|
||||||
|
let game_evals: Vec<_> = (0..args.n_games)
|
||||||
|
.map(|_| BurnEvaluator::new(evaluator.model_ref().clone(), infer_device.clone()))
|
||||||
|
.collect();
|
||||||
|
drop(evaluator);
|
||||||
|
|
||||||
|
let all_samples: Vec<Vec<TrainSample>> = game_seeds
|
||||||
|
.into_par_iter()
|
||||||
|
.zip(game_evals.into_par_iter())
|
||||||
|
.map(|(seed, game_eval)| {
|
||||||
|
let mut game_rng = SmallRng::seed_from_u64(seed);
|
||||||
|
generate_episode(&env, &game_eval, &mcts_cfg, &temperature_fn, &mut game_rng)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
let mut new_samples = 0usize;
|
let mut new_samples = 0usize;
|
||||||
for _ in 0..args.n_games {
|
for samples in all_samples {
|
||||||
let samples =
|
|
||||||
generate_episode(&env, &evaluator, &mcts_cfg, &temperature_fn, &mut rng);
|
|
||||||
new_samples += samples.len();
|
new_samples += samples.len();
|
||||||
replay.extend(samples);
|
replay.extend(samples);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
251
spiel_bot/src/bin/dqn_train.rs
Normal file
251
spiel_bot/src/bin/dqn_train.rs
Normal file
|
|
@ -0,0 +1,251 @@
|
||||||
|
//! DQN self-play training loop.
|
||||||
|
//!
|
||||||
|
//! # Usage
|
||||||
|
//!
|
||||||
|
//! ```sh
|
||||||
|
//! # Start fresh with default settings
|
||||||
|
//! cargo run -p spiel_bot --bin dqn_train --release
|
||||||
|
//!
|
||||||
|
//! # Custom hyperparameters
|
||||||
|
//! cargo run -p spiel_bot --bin dqn_train --release -- \
|
||||||
|
//! --hidden 512 --n-iter 200 --n-games 20 --epsilon-decay 5000
|
||||||
|
//!
|
||||||
|
//! # Resume from a checkpoint
|
||||||
|
//! cargo run -p spiel_bot --bin dqn_train --release -- \
|
||||||
|
//! --resume checkpoints/dqn_iter_0050.mpk --n-iter 100
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! # Options
|
||||||
|
//!
|
||||||
|
//! | Flag | Default | Description |
|
||||||
|
//! |------|---------|-------------|
|
||||||
|
//! | `--hidden N` | 256 | Hidden layer width |
|
||||||
|
//! | `--out DIR` | `checkpoints/` | Directory for checkpoint files |
|
||||||
|
//! | `--n-iter N` | 100 | Training iterations |
|
||||||
|
//! | `--n-games N` | 10 | Self-play games per iteration |
|
||||||
|
//! | `--n-train N` | 20 | Gradient steps per iteration |
|
||||||
|
//! | `--batch N` | 64 | Mini-batch size |
|
||||||
|
//! | `--replay-cap N` | 50000 | Replay buffer capacity |
|
||||||
|
//! | `--lr F` | 1e-3 | Adam learning rate |
|
||||||
|
//! | `--epsilon-start F` | 1.0 | Initial exploration rate |
|
||||||
|
//! | `--epsilon-end F` | 0.05 | Final exploration rate |
|
||||||
|
//! | `--epsilon-decay N` | 10000 | Gradient steps for ε to reach its floor |
|
||||||
|
//! | `--gamma F` | 0.99 | Discount factor |
|
||||||
|
//! | `--target-update N` | 500 | Hard-update target net every N steps |
|
||||||
|
//! | `--reward-scale F` | 12.0 | Divide raw rewards by this (12 = one hole → ±1) |
|
||||||
|
//! | `--save-every N` | 10 | Save checkpoint every N iterations |
|
||||||
|
//! | `--seed N` | 42 | RNG seed |
|
||||||
|
//! | `--resume PATH` | (none) | Load weights before training |
|
||||||
|
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
backend::{Autodiff, NdArray},
|
||||||
|
module::AutodiffModule,
|
||||||
|
optim::AdamConfig,
|
||||||
|
tensor::backend::Backend,
|
||||||
|
};
|
||||||
|
use rand::{SeedableRng, rngs::SmallRng};
|
||||||
|
|
||||||
|
use spiel_bot::{
|
||||||
|
dqn::{
|
||||||
|
DqnConfig, DqnReplayBuffer, compute_target_q, dqn_train_step,
|
||||||
|
generate_dqn_episode, hard_update, linear_epsilon,
|
||||||
|
},
|
||||||
|
env::TrictracEnv,
|
||||||
|
network::{QNet, QNetConfig},
|
||||||
|
};
|
||||||
|
|
||||||
|
type TrainB = Autodiff<NdArray<f32>>;
|
||||||
|
type InferB = NdArray<f32>;
|
||||||
|
|
||||||
|
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
struct Args {
|
||||||
|
hidden: usize,
|
||||||
|
out_dir: PathBuf,
|
||||||
|
save_every: usize,
|
||||||
|
seed: u64,
|
||||||
|
resume: Option<PathBuf>,
|
||||||
|
config: DqnConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Args {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
hidden: 256,
|
||||||
|
out_dir: PathBuf::from("checkpoints"),
|
||||||
|
save_every: 10,
|
||||||
|
seed: 42,
|
||||||
|
resume: None,
|
||||||
|
config: DqnConfig::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_args() -> Args {
|
||||||
|
let raw: Vec<String> = std::env::args().collect();
|
||||||
|
let mut a = Args::default();
|
||||||
|
let mut i = 1;
|
||||||
|
while i < raw.len() {
|
||||||
|
match raw[i].as_str() {
|
||||||
|
"--hidden" => { i += 1; a.hidden = raw[i].parse().expect("--hidden: integer"); }
|
||||||
|
"--out" => { i += 1; a.out_dir = PathBuf::from(&raw[i]); }
|
||||||
|
"--n-iter" => { i += 1; a.config.n_iterations = raw[i].parse().expect("--n-iter: integer"); }
|
||||||
|
"--n-games" => { i += 1; a.config.n_games_per_iter = raw[i].parse().expect("--n-games: integer"); }
|
||||||
|
"--n-train" => { i += 1; a.config.n_train_steps_per_iter = raw[i].parse().expect("--n-train: integer"); }
|
||||||
|
"--batch" => { i += 1; a.config.batch_size = raw[i].parse().expect("--batch: integer"); }
|
||||||
|
"--replay-cap" => { i += 1; a.config.replay_capacity = raw[i].parse().expect("--replay-cap: integer"); }
|
||||||
|
"--lr" => { i += 1; a.config.learning_rate = raw[i].parse().expect("--lr: float"); }
|
||||||
|
"--epsilon-start" => { i += 1; a.config.epsilon_start = raw[i].parse().expect("--epsilon-start: float"); }
|
||||||
|
"--epsilon-end" => { i += 1; a.config.epsilon_end = raw[i].parse().expect("--epsilon-end: float"); }
|
||||||
|
"--epsilon-decay" => { i += 1; a.config.epsilon_decay_steps = raw[i].parse().expect("--epsilon-decay: integer"); }
|
||||||
|
"--gamma" => { i += 1; a.config.gamma = raw[i].parse().expect("--gamma: float"); }
|
||||||
|
"--target-update" => { i += 1; a.config.target_update_freq = raw[i].parse().expect("--target-update: integer"); }
|
||||||
|
"--reward-scale" => { i += 1; a.config.reward_scale = raw[i].parse().expect("--reward-scale: float"); }
|
||||||
|
"--save-every" => { i += 1; a.save_every = raw[i].parse().expect("--save-every: integer"); }
|
||||||
|
"--seed" => { i += 1; a.seed = raw[i].parse().expect("--seed: integer"); }
|
||||||
|
"--resume" => { i += 1; a.resume = Some(PathBuf::from(&raw[i])); }
|
||||||
|
other => { eprintln!("Unknown argument: {other}"); std::process::exit(1); }
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
a
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Training loop ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Run the full DQN training loop: alternating self-play data collection and
/// mini-batch gradient steps, with periodic target-network syncs, ε decay,
/// per-iteration logging, and checkpointing via `save_fn`.
///
/// - `q_net`: the online Q-network on the autodiff (training) backend; moved
///   through each gradient step and returned updated by `dqn_train_step`.
/// - `cfg`: network architecture config (only `action_size` is read here).
/// - `save_fn`: checkpoint callback; failures are logged as warnings, not fatal.
/// - `args`: CLI arguments, including the `DqnConfig` hyperparameters.
fn train_loop(
    mut q_net: QNet<TrainB>,
    cfg: &QNetConfig,
    save_fn: &dyn Fn(&QNet<TrainB>, &Path) -> anyhow::Result<()>,
    args: &Args,
) {
    let train_device: <TrainB as Backend>::Device = Default::default();
    let infer_device: <InferB as Backend>::Device = Default::default();

    let mut optimizer = AdamConfig::new().init();
    let mut replay = DqnReplayBuffer::new(args.config.replay_capacity);
    let mut rng = SmallRng::seed_from_u64(args.seed);
    let env = TrictracEnv;

    // Target net starts as a copy of the online net (on the inference backend).
    let mut target_net: QNet<InferB> = hard_update::<TrainB, _>(&q_net);
    // Counts gradient steps across iterations; drives target syncs and ε decay.
    let mut global_step = 0usize;
    let mut epsilon = args.config.epsilon_start;

    println!(
        "\n{:-<60}\n dqn_train | {} iters | {} games/iter | {} train-steps/iter\n{:-<60}",
        "", args.config.n_iterations, args.config.n_games_per_iter,
        args.config.n_train_steps_per_iter, ""
    );

    for iter in 0..args.config.n_iterations {
        let t0 = Instant::now();

        // ── Self-play ────────────────────────────────────────────────────
        // Strip the autodiff wrapper so episode generation runs without a
        // gradient tape.
        let infer_q: QNet<InferB> = q_net.valid();
        let mut new_samples = 0usize;

        for _ in 0..args.config.n_games_per_iter {
            let samples = generate_dqn_episode(
                &env, &infer_q, epsilon, &mut rng, &infer_device, args.config.reward_scale,
            );
            new_samples += samples.len();
            replay.extend(samples);
        }

        // ── Training ─────────────────────────────────────────────────────
        let mut loss_sum = 0.0f32;
        let mut n_steps = 0usize;

        // Skip gradient steps until the buffer can fill one full batch.
        if replay.len() >= args.config.batch_size {
            for _ in 0..args.config.n_train_steps_per_iter {
                let batch: Vec<_> = replay
                    .sample_batch(args.config.batch_size, &mut rng)
                    .into_iter()
                    .cloned()
                    .collect();

                // Target Q-values computed on the inference backend.
                let target_q = compute_target_q(
                    &target_net, &batch, cfg.action_size, &infer_device,
                );

                // The step consumes and returns the network (Burn optimizer API).
                let (q, loss) = dqn_train_step(
                    q_net, &mut optimizer, &batch, &target_q,
                    &train_device, args.config.learning_rate, args.config.gamma,
                );
                q_net = q;
                loss_sum += loss;
                n_steps += 1;
                global_step += 1;

                // Hard-update target net every target_update_freq steps.
                if global_step % args.config.target_update_freq == 0 {
                    target_net = hard_update::<TrainB, _>(&q_net);
                }

                // Linear epsilon decay.
                epsilon = linear_epsilon(
                    args.config.epsilon_start,
                    args.config.epsilon_end,
                    global_step,
                    args.config.epsilon_decay_steps,
                );
            }
        }

        // ── Logging ──────────────────────────────────────────────────────
        let elapsed = t0.elapsed();
        // NaN signals "no gradient steps this iteration" in the log line.
        let avg_loss = if n_steps > 0 { loss_sum / n_steps as f32 } else { f32::NAN };

        println!(
            "iter {:4}/{} | buf {:6} | +{:<4} samples | loss {:7.4} | ε {:.3} | {:.1}s",
            iter + 1,
            args.config.n_iterations,
            replay.len(),
            new_samples,
            avg_loss,
            epsilon,
            elapsed.as_secs_f32(),
        );

        // ── Checkpoint ───────────────────────────────────────────────────
        // Always save on the final iteration, even if it is off-cadence.
        let is_last = iter + 1 == args.config.n_iterations;
        if (iter + 1) % args.save_every == 0 || is_last {
            let path = args.out_dir.join(format!("dqn_iter_{:04}.mpk", iter + 1));
            match save_fn(&q_net, &path) {
                Ok(()) => println!("  -> saved {}", path.display()),
                Err(e) => eprintln!("  Warning: checkpoint save failed: {e}"),
            }
        }
    }

    println!("\nDQN training complete.");
}
|
||||||
|
|
||||||
|
// ── Main ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let args = parse_args();
|
||||||
|
|
||||||
|
if let Err(e) = std::fs::create_dir_all(&args.out_dir) {
|
||||||
|
eprintln!("Cannot create output directory {}: {e}", args.out_dir.display());
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let train_device: <TrainB as Backend>::Device = Default::default();
|
||||||
|
let cfg = QNetConfig { obs_size: 217, action_size: 514, hidden_size: args.hidden };
|
||||||
|
|
||||||
|
let q_net = match &args.resume {
|
||||||
|
Some(path) => {
|
||||||
|
println!("Resuming from {}", path.display());
|
||||||
|
QNet::<TrainB>::load(&cfg, path, &train_device)
|
||||||
|
.unwrap_or_else(|e| { eprintln!("Load failed: {e}"); std::process::exit(1); })
|
||||||
|
}
|
||||||
|
None => QNet::<TrainB>::new(&cfg, &train_device),
|
||||||
|
};
|
||||||
|
|
||||||
|
train_loop(q_net, &cfg, &|m: &QNet<TrainB>, path| m.valid().save(path), &args);
|
||||||
|
}
|
||||||
247
spiel_bot/src/dqn/episode.rs
Normal file
247
spiel_bot/src/dqn/episode.rs
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
//! DQN self-play episode generation.
|
||||||
|
//!
|
||||||
|
//! Both players share the same Q-network (the [`TrictracEnv`] handles board
|
||||||
|
//! mirroring so that each player always acts from "White's perspective").
|
||||||
|
//! Transitions for both players are stored in the returned sample list.
|
||||||
|
//!
|
||||||
|
//! # Reward
|
||||||
|
//!
|
||||||
|
//! After each full decision (action applied and the state has advanced through
|
||||||
|
//! any intervening chance nodes back to the same player's next turn), the
|
||||||
|
//! reward is:
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! r = (my_total_score_now − my_total_score_then)
|
||||||
|
//! − (opp_total_score_now − opp_total_score_then)
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! where `total_score = holes × 12 + points`.
|
||||||
|
//!
|
||||||
|
//! # Transition structure
|
||||||
|
//!
|
||||||
|
//! We use a "pending transition" per player. When a player acts again, we
|
||||||
|
//! *complete* the previous pending transition by filling in `next_obs`,
|
||||||
|
//! `next_legal`, and computing `reward`. Terminal transitions are completed
|
||||||
|
//! when the game ends.
|
||||||
|
|
||||||
|
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::env::{GameEnv, TrictracEnv};
|
||||||
|
use crate::network::QValueNet;
|
||||||
|
use super::DqnSample;
|
||||||
|
|
||||||
|
// ── Internals ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// A player's not-yet-completed transition: `(s, a)` is known, but `s'` and the
/// reward can only be observed on that player's *next* turn (or at game end).
struct PendingTransition {
    /// Observation at the moment of the action (acting player's perspective).
    obs: Vec<f32>,
    /// Action index that was taken from `obs`.
    action: usize,
    /// Score snapshot `[p1_total, p2_total]` at the moment of the action.
    score_before: [i32; 2],
}
|
||||||
|
|
||||||
|
/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise.
|
||||||
|
fn epsilon_greedy<B: Backend, Q: QValueNet<B>>(
|
||||||
|
q_net: &Q,
|
||||||
|
obs: &[f32],
|
||||||
|
legal: &[usize],
|
||||||
|
epsilon: f32,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> usize {
|
||||||
|
debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions");
|
||||||
|
if rng.random::<f32>() < epsilon {
|
||||||
|
legal[rng.random_range(0..legal.len())]
|
||||||
|
} else {
|
||||||
|
let obs_tensor = Tensor::<B, 2>::from_data(
|
||||||
|
TensorData::new(obs.to_vec(), [1, obs.len()]),
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
let q_values: Vec<f32> = q_net.forward(obs_tensor).into_data().to_vec().unwrap();
|
||||||
|
legal
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.max_by(|&a, &b| {
|
||||||
|
q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after.
///
/// Returns `my_delta − opponent_delta`, where each delta is the change in that
/// player's total score (`holes × 12 + points`) across the turn.
fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 {
    let me = player_idx;
    let opp = 1 - player_idx;
    let my_delta = score_after[me] - score_before[me];
    let opp_delta = score_after[opp] - score_before[opp];
    (my_delta - opp_delta) as f32
}
|
||||||
|
|
||||||
|
// ── Public API ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Play one full game and return all transitions for both players.
///
/// - `q_net` uses the **inference backend** (no autodiff wrapper).
/// - `epsilon` in `[0, 1]`: probability of taking a random action.
/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`).
///
/// Each player's transition is left "pending" after they act and is completed
/// with `next_obs` / `next_legal` / `reward` when that player acts again, or as
/// a terminal transition when the game ends.
pub fn generate_dqn_episode<B: Backend, Q: QValueNet<B>>(
    env: &TrictracEnv,
    q_net: &Q,
    epsilon: f32,
    rng: &mut impl Rng,
    device: &B::Device,
    reward_scale: f32,
) -> Vec<DqnSample> {
    let obs_size = env.obs_size();
    let mut state = env.new_game();
    // One pending (s, a) per player, indexed by player (0 = P1, 1 = P2).
    let mut pending: [Option<PendingTransition>; 2] = [None, None];
    let mut samples: Vec<DqnSample> = Vec::new();

    loop {
        // ── Advance past chance nodes ──────────────────────────────────────
        // Resolve dice rolls (and any other chance events) until a player
        // must decide or the game is over.
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, rng);
        }

        let score_now = TrictracEnv::score_snapshot(&state);

        if env.current_player(&state).is_terminal() {
            // Complete all pending transitions as terminal.
            for player_idx in 0..2 {
                if let Some(prev) = pending[player_idx].take() {
                    let reward =
                        compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
                    samples.push(DqnSample {
                        obs: prev.obs,
                        action: prev.action,
                        reward,
                        // Terminal next state: zero observation, no legal moves;
                        // the TD target drops the bootstrap term for done samples.
                        next_obs: vec![0.0; obs_size],
                        next_legal: vec![],
                        done: true,
                    });
                }
            }
            break;
        }

        let player_idx = env.current_player(&state).index().unwrap();
        let legal = env.legal_actions(&state);
        let obs = env.observation(&state, player_idx);

        // ── Complete the previous transition for this player ───────────────
        if let Some(prev) = pending[player_idx].take() {
            let reward =
                compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
            samples.push(DqnSample {
                obs: prev.obs,
                action: prev.action,
                reward,
                next_obs: obs.clone(),
                next_legal: legal.clone(),
                done: false,
            });
        }

        // ── Pick and apply action ──────────────────────────────────────────
        let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device);
        env.apply(&mut state, action);

        // ── Record new pending transition ──────────────────────────────────
        pending[player_idx] = Some(PendingTransition {
            obs,
            action,
            score_before: score_now,
        });
    }

    samples
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// Episode-generation smoke tests: run full games on the CPU NdArray backend
// with a randomly initialised Q-net and check structural invariants of the
// produced samples (termination, sizes, action bounds, done flags) plus the
// reward arithmetic.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use rand::{SeedableRng, rngs::SmallRng};

    use crate::network::{QNet, QNetConfig};

    type B = NdArray<f32>;

    fn device() -> <B as Backend>::Device { Default::default() }
    fn rng() -> SmallRng { SmallRng::seed_from_u64(7) }

    // Fresh randomly-initialised Q-net with the default (full-size) config.
    fn tiny_q() -> QNet<B> {
        QNet::new(&QNetConfig::default(), &device())
    }

    #[test]
    fn episode_terminates_and_produces_samples() {
        let env = TrictracEnv;
        let q = tiny_q();
        // ε = 1.0 → fully random play; the game must still reach a terminal state.
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        assert!(!samples.is_empty(), "episode must produce at least one sample");
    }

    #[test]
    fn episode_obs_size_correct() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        for s in &samples {
            assert_eq!(s.obs.len(), 217, "obs size mismatch");
            if s.done {
                assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size");
                assert!(s.next_legal.is_empty());
            } else {
                assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch");
                assert!(!s.next_legal.is_empty());
            }
        }
    }

    #[test]
    fn episode_actions_within_action_space() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        for s in &samples {
            assert!(s.action < 514, "action {} out of bounds", s.action);
        }
    }

    #[test]
    fn greedy_episode_also_terminates() {
        let env = TrictracEnv;
        let q = tiny_q();
        // ε = 0.0 exercises the network-argmax path on every decision.
        let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0);
        assert!(!samples.is_empty());
    }

    #[test]
    fn at_least_one_done_sample() {
        let env = TrictracEnv;
        let q = tiny_q();
        let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
        let n_done = samples.iter().filter(|s| s.done).count();
        // Two players, so 1 or 2 terminal transitions.
        assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}");
    }

    #[test]
    fn compute_reward_correct() {
        // P1 gains 4 points (2 holes 10 pts → 3 holes 2 pts), opp unchanged.
        let before = [2 * 12 + 10, 0];
        let after = [3 * 12 + 2, 0];
        let r = compute_reward(0, &before, &after);
        assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}");
    }

    #[test]
    fn compute_reward_with_opponent_scoring() {
        // P1 gains 2, opp gains 3 → net = -1 from P1's perspective.
        let before = [0, 0];
        let after = [2, 3];
        let r = compute_reward(0, &before, &after);
        assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}");
    }
}
|
||||||
232
spiel_bot/src/dqn/mod.rs
Normal file
232
spiel_bot/src/dqn/mod.rs
Normal file
|
|
@ -0,0 +1,232 @@
|
||||||
|
//! DQN: self-play data generation, replay buffer, and training step.
|
||||||
|
//!
|
||||||
|
//! # Algorithm
|
||||||
|
//!
|
||||||
|
//! Deep Q-Network with:
|
||||||
|
//! - **ε-greedy** exploration (linearly decayed).
|
||||||
|
//! - **Dense per-turn rewards**: `my_score_delta − opponent_score_delta` where
|
||||||
|
//! `score = holes × 12 + points`.
|
||||||
|
//! - **Experience replay** with a fixed-capacity circular buffer.
|
||||||
|
//! - **Target network**: hard-copied from the online Q-net every
|
||||||
|
//! `target_update_freq` gradient steps for training stability.
|
||||||
|
//!
|
||||||
|
//! # Modules
|
||||||
|
//!
|
||||||
|
//! | Module | Contents |
|
||||||
|
//! |--------|----------|
|
||||||
|
//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] |
|
||||||
|
//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] |
|
||||||
|
|
||||||
|
pub mod episode;
|
||||||
|
pub mod trainer;
|
||||||
|
|
||||||
|
pub use episode::generate_dqn_episode;
|
||||||
|
pub use trainer::{compute_target_q, dqn_train_step, hard_update};
|
||||||
|
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
// ── DqnSample ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One transition `(s, a, r, s', done)` collected during self-play.
///
/// Stored in the replay buffer and consumed by `compute_target_q` /
/// `dqn_train_step`.
#[derive(Clone, Debug)]
pub struct DqnSample {
    /// Observation from the acting player's perspective (`obs_size` floats).
    pub obs: Vec<f32>,
    /// Action index taken.
    pub action: usize,
    /// Per-turn reward: `my_score_delta − opponent_score_delta`
    /// (already divided by the configured `reward_scale`).
    pub reward: f32,
    /// Next observation from the same player's perspective.
    /// All-zeros when `done = true` (ignored by the TD target).
    pub next_obs: Vec<f32>,
    /// Legal actions at `next_obs`. Empty when `done = true`.
    pub next_legal: Vec<usize>,
    /// `true` when `next_obs` is a terminal state.
    pub done: bool,
}
|
||||||
|
|
||||||
|
// ── DqnReplayBuffer ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Fixed-capacity circular replay buffer for [`DqnSample`]s.
|
||||||
|
///
|
||||||
|
/// When full, the oldest sample is evicted on push.
|
||||||
|
/// Batches are drawn without replacement via a partial Fisher-Yates shuffle.
|
||||||
|
pub struct DqnReplayBuffer {
|
||||||
|
data: VecDeque<DqnSample>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnReplayBuffer {
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, sample: DqnSample) {
|
||||||
|
if self.data.len() == self.capacity {
|
||||||
|
self.data.pop_front();
|
||||||
|
}
|
||||||
|
self.data.push_back(sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extend(&mut self, samples: impl IntoIterator<Item = DqnSample>) {
|
||||||
|
for s in samples { self.push(s); }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize { self.data.len() }
|
||||||
|
pub fn is_empty(&self) -> bool { self.data.is_empty() }
|
||||||
|
|
||||||
|
/// Sample up to `n` distinct samples without replacement.
|
||||||
|
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> {
|
||||||
|
let len = self.data.len();
|
||||||
|
let n = n.min(len);
|
||||||
|
let mut indices: Vec<usize> = (0..len).collect();
|
||||||
|
for i in 0..n {
|
||||||
|
let j = rng.random_range(i..len);
|
||||||
|
indices.swap(i, j);
|
||||||
|
}
|
||||||
|
indices[..n].iter().map(|&i| &self.data[i]).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── DqnConfig ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Top-level DQN hyperparameters for the training loop.
///
/// Defaults are provided by the `Default` impl; see its comment for how the
/// decay/update cadences relate to the total gradient-step budget.
#[derive(Debug, Clone)]
pub struct DqnConfig {
    /// Initial exploration rate (1.0 = fully random).
    pub epsilon_start: f32,
    /// Final exploration rate after decay.
    pub epsilon_end: f32,
    /// Number of gradient steps over which ε decays linearly from start to end.
    ///
    /// Should be calibrated to the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`). A value larger than that
    /// means exploration never reaches `epsilon_end` during the run.
    pub epsilon_decay_steps: usize,
    /// Discount factor γ for the TD target. Typical: 0.99.
    pub gamma: f32,
    /// Hard-copy Q → target every this many gradient steps.
    ///
    /// Should be much smaller than the total number of gradient steps
    /// (`n_iterations × n_train_steps_per_iter`).
    pub target_update_freq: usize,
    /// Adam learning rate.
    pub learning_rate: f64,
    /// Mini-batch size for each gradient step.
    pub batch_size: usize,
    /// Maximum number of samples in the replay buffer.
    pub replay_capacity: usize,
    /// Number of outer iterations (self-play + train).
    pub n_iterations: usize,
    /// Self-play games per iteration.
    pub n_games_per_iter: usize,
    /// Gradient steps per iteration.
    pub n_train_steps_per_iter: usize,
    /// Reward normalisation divisor.
    ///
    /// Per-turn rewards (score delta) are divided by this constant before being
    /// stored. Without normalisation, rewards can reach ±24 (jan with
    /// bredouille = 12 pts × 2), driving Q-values into the hundreds and
    /// causing MSE loss to grow unboundedly.
    ///
    /// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping
    /// Q-value magnitudes in a stable range. Set to `1.0` to disable.
    pub reward_scale: f32,
}
|
||||||
|
|
||||||
|
impl Default for DqnConfig {
    /// Conservative defaults sized so the cadences are mutually consistent.
    fn default() -> Self {
        // Total gradient steps with these defaults = 500 × 20 = 10_000,
        // so epsilon decays fully and the target is updated 100 times.
        Self {
            epsilon_start: 1.0,
            epsilon_end: 0.05,
            epsilon_decay_steps: 10_000,
            gamma: 0.99,
            target_update_freq: 100,
            learning_rate: 1e-3,
            batch_size: 64,
            replay_capacity: 50_000,
            n_iterations: 500,
            n_games_per_iter: 10,
            n_train_steps_per_iter: 20,
            // One hole (12 points) maps to a reward of ±1.0.
            reward_scale: 12.0,
        }
    }
}
|
||||||
|
|
||||||
|
/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps.
///
/// Before `decay_steps` the value interpolates linearly; at or after
/// `decay_steps` (or when `decay_steps == 0`) it is exactly `end`.
pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 {
    if decay_steps == 0 || step >= decay_steps {
        end
    } else {
        let fraction = step as f32 / decay_steps as f32;
        start + (end - start) * fraction
    }
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// Unit tests for the replay buffer mechanics and the ε schedule.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};

    // Minimal one-float sample used purely to exercise buffer mechanics.
    fn dummy(reward: f32) -> DqnSample {
        DqnSample {
            obs: vec![0.0],
            action: 0,
            reward,
            next_obs: vec![0.0],
            next_legal: vec![0],
            done: false,
        }
    }

    #[test]
    fn push_and_len() {
        let mut buf = DqnReplayBuffer::new(10);
        assert!(buf.is_empty());
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        assert_eq!(buf.len(), 2);
    }

    #[test]
    fn evicts_oldest_at_capacity() {
        let mut buf = DqnReplayBuffer::new(3);
        buf.push(dummy(1.0));
        buf.push(dummy(2.0));
        buf.push(dummy(3.0));
        buf.push(dummy(4.0));
        assert_eq!(buf.len(), 3);
        // Oldest sample (reward 1.0) was evicted; the front is now 2.0.
        assert_eq!(buf.data[0].reward, 2.0);
    }

    #[test]
    fn sample_batch_size() {
        let mut buf = DqnReplayBuffer::new(20);
        for i in 0..10 { buf.push(dummy(i as f32)); }
        let mut rng = SmallRng::seed_from_u64(0);
        assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
    }

    #[test]
    fn linear_epsilon_start() {
        assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn linear_epsilon_end() {
        assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6);
    }

    #[test]
    fn linear_epsilon_monotone() {
        // ε must never increase as the step counter advances.
        let mut prev = f32::INFINITY;
        for step in 0..=100 {
            let e = linear_epsilon(1.0, 0.05, step, 100);
            assert!(e <= prev + 1e-6);
            prev = e;
        }
    }
}
|
||||||
278
spiel_bot/src/dqn/trainer.rs
Normal file
278
spiel_bot/src/dqn/trainer.rs
Normal file
|
|
@ -0,0 +1,278 @@
|
||||||
|
//! DQN gradient step and target-network management.
|
||||||
|
//!
|
||||||
|
//! # TD target
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done
|
||||||
|
//! y_i = r_i if done
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! # Loss
|
||||||
|
//!
|
||||||
|
//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
|
||||||
|
//! and `y_i` (computed from the frozen target net).
|
||||||
|
//!
|
||||||
|
//! # Target network
|
||||||
|
//!
|
||||||
|
//! [`hard_update`] copies the online Q-net weights into the target net by
|
||||||
|
//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
module::AutodiffModule,
|
||||||
|
optim::{GradientsParams, Optimizer},
|
||||||
|
prelude::ElementConversion,
|
||||||
|
tensor::{
|
||||||
|
Int, Tensor, TensorData,
|
||||||
|
backend::{AutodiffBackend, Backend},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::network::QValueNet;
|
||||||
|
use super::DqnSample;
|
||||||
|
|
||||||
|
// ── Target Q computation ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
///
/// Returns a `Vec<f32>` of length `batch.len()`. Done samples get `0.0`
/// (their bootstrap term is dropped by the TD target anyway).
///
/// The target network runs on the **inference backend** (`InferB`) with no
/// gradient tape, so this function is backend-agnostic (`B: Backend`).
pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
    target_net: &Q,
    batch: &[DqnSample],
    action_size: usize,
    device: &B::Device,
) -> Vec<f32> {
    let batch_size = batch.len();

    // Collect indices of non-done samples (done samples have no next state).
    let non_done: Vec<usize> = batch
        .iter()
        .enumerate()
        .filter(|(_, s)| !s.done)
        .map(|(i, _)| i)
        .collect();

    if non_done.is_empty() {
        return vec![0.0; batch_size];
    }

    let obs_size = batch[0].next_obs.len();
    let nd = non_done.len();

    // Stack next observations for non-done samples → [nd, obs_size].
    let obs_flat: Vec<f32> = non_done
        .iter()
        .flat_map(|&i| batch[i].next_obs.iter().copied())
        .collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [nd, obs_size]),
        device,
    );

    // Forward target net → [nd, action_size], then to Vec<f32>.
    // q_flat is row-major: row k (the k-th non-done sample) starts at
    // k * action_size.
    let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();

    // For each non-done sample, pick max Q over legal next actions.
    let mut result = vec![0.0f32; batch_size];
    for (k, &i) in non_done.iter().enumerate() {
        let legal = &batch[i].next_legal;
        let offset = k * action_size;
        let max_q = legal
            .iter()
            .map(|&a| q_flat[offset + a])
            .fold(f32::NEG_INFINITY, f32::max);
        // If legal is empty (shouldn't happen for non-done, but be safe):
        // the fold yields -inf, which we clamp to 0.0.
        result[i] = if max_q.is_finite() { max_q } else { 0.0 };
    }
    result
}
|
||||||
|
|
||||||
|
// ── Training step ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Run one gradient step on `q_net` using `batch`.
///
/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
/// frozen target network and passed in here so that this function only
/// needs the **autodiff backend**.
///
/// Returns the updated network and the scalar MSE loss.
///
/// # Panics
/// Panics if `batch` is empty or `batch.len() != target_max_q.len()`.
pub fn dqn_train_step<B, Q, O>(
    q_net: Q,
    optimizer: &mut O,
    batch: &[DqnSample],
    target_max_q: &[f32],
    device: &B::Device,
    lr: f64,
    gamma: f32,
) -> (Q, f32)
where
    B: AutodiffBackend,
    Q: QValueNet<B> + AutodiffModule<B>,
    O: Optimizer<Q, B>,
{
    assert!(!batch.is_empty(), "dqn_train_step: empty batch");
    assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");

    let batch_size = batch.len();
    let obs_size = batch[0].obs.len();

    // ── Build observation tensor [B, obs_size] ────────────────────────────
    let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
    let obs_tensor = Tensor::<B, 2>::from_data(
        TensorData::new(obs_flat, [batch_size, obs_size]),
        device,
    );

    // ── Forward Q-net → [B, action_size] ─────────────────────────────────
    let q_all = q_net.forward(obs_tensor);

    // ── Gather Q(s, a) for the taken action → [B] ────────────────────────
    let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
    let action_tensor: Tensor<B, 2, Int> = Tensor::<B, 1, Int>::from_data(
        TensorData::new(actions, [batch_size]),
        device,
    )
    .reshape([batch_size, 1]); // [B] → [B, 1]
    let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]

    // ── TD targets: r + γ · max_next_q · (1 − done) ──────────────────────
    let targets: Vec<f32> = batch
        .iter()
        .zip(target_max_q.iter())
        .map(|(s, &max_q)| {
            if s.done { s.reward } else { s.reward + gamma * max_q }
        })
        .collect();
    let target_tensor = Tensor::<B, 1>::from_data(
        TensorData::new(targets, [batch_size]),
        device,
    );

    // ── MSE loss ──────────────────────────────────────────────────────────
    // detach() keeps gradients from flowing into the (constant) targets.
    let diff = q_pred - target_tensor.detach();
    let loss = (diff.clone() * diff).mean();
    let loss_scalar: f32 = loss.clone().into_scalar().elem();

    // ── Backward + optimizer step ─────────────────────────────────────────
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &q_net);
    let q_net = optimizer.step(lr, q_net, grads);

    (q_net, loss_scalar)
}
|
||||||
|
|
||||||
|
// ── Target network update ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Hard-copy the online Q-net weights to a new target network.
///
/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
/// inference-backend module with identical weights. Called once at start-up
/// and every `target_update_freq` gradient steps.
pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
    q_net.valid()
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn::{
|
||||||
|
backend::{Autodiff, NdArray},
|
||||||
|
optim::AdamConfig,
|
||||||
|
};
|
||||||
|
use crate::network::{QNet, QNetConfig};
|
||||||
|
|
||||||
|
type InferB = NdArray<f32>;
|
||||||
|
type TrainB = Autodiff<NdArray<f32>>;
|
||||||
|
|
||||||
|
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
|
||||||
|
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
|
||||||
|
|
||||||
|
fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
|
||||||
|
(0..n)
|
||||||
|
.map(|i| DqnSample {
|
||||||
|
obs: vec![0.5f32; obs_size],
|
||||||
|
action: i % action_size,
|
||||||
|
reward: if i % 2 == 0 { 1.0 } else { -1.0 },
|
||||||
|
next_obs: vec![0.5f32; obs_size],
|
||||||
|
next_legal: vec![0, 1],
|
||||||
|
done: i == n - 1,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compute_target_q_length() {
|
||||||
|
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
|
||||||
|
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||||
|
let batch = dummy_batch(8, 4, 4);
|
||||||
|
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||||
|
assert_eq!(tq.len(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compute_target_q_done_is_zero() {
|
||||||
|
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
|
||||||
|
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||||
|
// Single done sample.
|
||||||
|
let batch = vec![DqnSample {
|
||||||
|
obs: vec![0.0; 4],
|
||||||
|
action: 0,
|
||||||
|
reward: 5.0,
|
||||||
|
next_obs: vec![0.0; 4],
|
||||||
|
next_legal: vec![],
|
||||||
|
done: true,
|
||||||
|
}];
|
||||||
|
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||||
|
assert_eq!(tq.len(), 1);
|
||||||
|
assert_eq!(tq[0], 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn train_step_returns_finite_loss() {
|
||||||
|
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
|
||||||
|
let q_net = QNet::<TrainB>::new(&cfg, &train_device());
|
||||||
|
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||||
|
let mut optimizer = AdamConfig::new().init();
|
||||||
|
let batch = dummy_batch(8, 4, 4);
|
||||||
|
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||||
|
let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99);
|
||||||
|
assert!(loss.is_finite(), "loss must be finite, got {loss}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn train_step_loss_decreases() {
|
||||||
|
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
|
||||||
|
let mut q_net = QNet::<TrainB>::new(&cfg, &train_device());
|
||||||
|
let target = QNet::<InferB>::new(&cfg, &infer_device());
|
||||||
|
let mut optimizer = AdamConfig::new().init();
|
||||||
|
let batch = dummy_batch(16, 4, 4);
|
||||||
|
let tq = compute_target_q(&target, &batch, 4, &infer_device());
|
||||||
|
|
||||||
|
let mut prev_loss = f32::INFINITY;
|
||||||
|
for _ in 0..10 {
|
||||||
|
let (q, loss) = dqn_train_step(
|
||||||
|
q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99,
|
||||||
|
);
|
||||||
|
q_net = q;
|
||||||
|
assert!(loss.is_finite());
|
||||||
|
prev_loss = loss;
|
||||||
|
}
|
||||||
|
assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hard_update_copies_weights() {
|
||||||
|
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
|
||||||
|
let q_net = QNet::<TrainB>::new(&cfg, &train_device());
|
||||||
|
let target = hard_update::<TrainB, _>(&q_net);
|
||||||
|
|
||||||
|
let obs = burn::tensor::Tensor::<InferB, 2>::zeros([1, 4], &infer_device());
|
||||||
|
let q_out: Vec<f32> = target.forward(obs).into_data().to_vec().unwrap();
|
||||||
|
// After hard_update the target produces finite outputs.
|
||||||
|
assert!(q_out.iter().all(|v| v.is_finite()));
|
||||||
|
}
|
||||||
|
}
|
||||||
12
spiel_bot/src/env/trictrac.rs
vendored
12
spiel_bot/src/env/trictrac.rs
vendored
|
|
@ -200,6 +200,18 @@ impl GameEnv for TrictracEnv {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── DQN helpers ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
impl TrictracEnv {
|
||||||
|
/// Score snapshot for DQN reward computation.
|
||||||
|
///
|
||||||
|
/// Returns `[p1_total, p2_total]` where `total = holes × 12 + points`.
|
||||||
|
/// Index 0 = Player 1 (White, player_id 1), index 1 = Player 2 (Black, player_id 2).
|
||||||
|
pub fn score_snapshot(s: &GameState) -> [i32; 2] {
|
||||||
|
[s.total_score(1), s.total_score(2)]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
pub mod alphazero;
|
pub mod alphazero;
|
||||||
|
pub mod dqn;
|
||||||
pub mod env;
|
pub mod env;
|
||||||
pub mod mcts;
|
pub mod mcts;
|
||||||
pub mod network;
|
pub mod network;
|
||||||
|
|
|
||||||
|
|
@ -403,10 +403,10 @@ mod tests {
|
||||||
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
||||||
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
|
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
|
||||||
assert_eq!(root.n, 1 + config.n_simulations as u32);
|
assert_eq!(root.n, 1 + config.n_simulations as u32);
|
||||||
// Children visit counts may sum to less than n_simulations when some
|
// Every simulation crosses a chance node at depth 1 (dice roll after
|
||||||
// simulations cross a chance node at depth 1 (turn ends after one move)
|
// the player's move). Since the fix now updates child.n in that case,
|
||||||
// and evaluate with the network directly without updating child.n.
|
// children visit counts must sum to exactly n_simulations.
|
||||||
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||||
assert!(total <= config.n_simulations as u32);
|
assert_eq!(total, config.n_simulations as u32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -166,6 +166,12 @@ pub(super) fn simulate<E: GameEnv>(
|
||||||
// previously cached children would be for a different outcome.
|
// previously cached children would be for a different outcome.
|
||||||
let obs = env.observation(&next_state, child_player);
|
let obs = env.observation(&next_state, child_player);
|
||||||
let (_, value) = evaluator.evaluate(&obs);
|
let (_, value) = evaluator.evaluate(&obs);
|
||||||
|
// Record the visit so that PUCT and mcts_policy use real counts.
|
||||||
|
// Without this, child.n stays 0 for every simulation in games where
|
||||||
|
// every player action is immediately followed by a chance node (e.g.
|
||||||
|
// Trictrac), causing mcts_policy to always return a uniform policy.
|
||||||
|
child.n += 1;
|
||||||
|
child.w += value;
|
||||||
value
|
value
|
||||||
} else if child.expanded {
|
} else if child.expanded {
|
||||||
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
||||||
|
|
|
||||||
|
|
@ -43,9 +43,11 @@
|
||||||
//! before passing to softmax.
|
//! before passing to softmax.
|
||||||
|
|
||||||
pub mod mlp;
|
pub mod mlp;
|
||||||
|
pub mod qnet;
|
||||||
pub mod resnet;
|
pub mod resnet;
|
||||||
|
|
||||||
pub use mlp::{MlpConfig, MlpNet};
|
pub use mlp::{MlpConfig, MlpNet};
|
||||||
|
pub use qnet::{QNet, QNetConfig};
|
||||||
pub use resnet::{ResNet, ResNetConfig};
|
pub use resnet::{ResNet, ResNetConfig};
|
||||||
|
|
||||||
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
||||||
|
|
@ -56,9 +58,21 @@ use burn::{module::Module, tensor::backend::Backend, tensor::Tensor};
|
||||||
/// - `obs`: `[batch, obs_size]`
|
/// - `obs`: `[batch, obs_size]`
|
||||||
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
|
/// - policy output: `[batch, action_size]` — raw logits (no softmax applied)
|
||||||
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
|
/// - value output: `[batch, 1]` — tanh-squashed ∈ (-1, 1)
|
||||||
|
///
|
||||||
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
|
/// Note: `Sync` is intentionally absent — Burn's `Module` internally uses
|
||||||
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
|
/// `OnceCell` for lazy parameter initialisation, which is not `Sync`.
|
||||||
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
|
/// Use an `Arc<Mutex<N>>` wrapper if cross-thread sharing is needed.
|
||||||
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
|
pub trait PolicyValueNet<B: Backend>: Module<B> + Send + 'static {
|
||||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
|
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A neural network that outputs one Q-value per action.
|
||||||
|
///
|
||||||
|
/// # Shapes
|
||||||
|
/// - `obs`: `[batch, obs_size]`
|
||||||
|
/// - output: `[batch, action_size]` — raw Q-values (no activation)
|
||||||
|
///
|
||||||
|
/// Note: `Sync` is intentionally absent for the same reason as [`PolicyValueNet`].
|
||||||
|
pub trait QValueNet<B: Backend>: Module<B> + Send + 'static {
|
||||||
|
fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
|
||||||
|
}
|
||||||
|
|
|
||||||
147
spiel_bot/src/network/qnet.rs
Normal file
147
spiel_bot/src/network/qnet.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
//! Single-headed Q-value network for DQN.
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! Input [B, obs_size]
|
||||||
|
//! → Linear(obs → hidden) → ReLU
|
||||||
|
//! → Linear(hidden → hidden) → ReLU
|
||||||
|
//! → Linear(hidden → action_size) ← raw Q-values, no activation
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
module::Module,
|
||||||
|
nn::{Linear, LinearConfig},
|
||||||
|
record::{CompactRecorder, Recorder},
|
||||||
|
tensor::{activation::relu, backend::Backend, Tensor},
|
||||||
|
};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use super::QValueNet;
|
||||||
|
|
||||||
|
// ── Config ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Configuration for [`QNet`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct QNetConfig {
|
||||||
|
/// Number of input features. 217 for Trictrac's `to_tensor()`.
|
||||||
|
pub obs_size: usize,
|
||||||
|
/// Number of output actions. 514 for Trictrac's `ACTION_SPACE_SIZE`.
|
||||||
|
pub action_size: usize,
|
||||||
|
/// Width of both hidden layers.
|
||||||
|
pub hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for QNetConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self { obs_size: 217, action_size: 514, hidden_size: 256 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Network ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Two-hidden-layer MLP that outputs one Q-value per action.
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct QNet<B: Backend> {
|
||||||
|
fc1: Linear<B>,
|
||||||
|
fc2: Linear<B>,
|
||||||
|
q_head: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> QNet<B> {
|
||||||
|
/// Construct a fresh network with random weights.
|
||||||
|
pub fn new(config: &QNetConfig, device: &B::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
fc1: LinearConfig::new(config.obs_size, config.hidden_size).init(device),
|
||||||
|
fc2: LinearConfig::new(config.hidden_size, config.hidden_size).init(device),
|
||||||
|
q_head: LinearConfig::new(config.hidden_size, config.action_size).init(device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save weights to `path` (MessagePack format via [`CompactRecorder`]).
|
||||||
|
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||||
|
CompactRecorder::new()
|
||||||
|
.record(self.clone().into_record(), path.to_path_buf())
|
||||||
|
.map_err(|e| anyhow::anyhow!("QNet::save failed: {e:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load weights from `path` into a fresh model built from `config`.
|
||||||
|
pub fn load(config: &QNetConfig, path: &Path, device: &B::Device) -> anyhow::Result<Self> {
|
||||||
|
let record = CompactRecorder::new()
|
||||||
|
.load(path.to_path_buf(), device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("QNet::load failed: {e:?}"))?;
|
||||||
|
Ok(Self::new(config, device).load_record(record))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> QValueNet<B> for QNet<B> {
|
||||||
|
fn forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let x = relu(self.fc1.forward(obs));
|
||||||
|
let x = relu(self.fc2.forward(x));
|
||||||
|
self.q_head.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn::backend::NdArray;
|
||||||
|
|
||||||
|
type B = NdArray<f32>;
|
||||||
|
|
||||||
|
fn device() -> <B as Backend>::Device { Default::default() }
|
||||||
|
|
||||||
|
fn default_net() -> QNet<B> {
|
||||||
|
QNet::new(&QNetConfig::default(), &device())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_output_shape() {
|
||||||
|
let net = default_net();
|
||||||
|
let obs = Tensor::zeros([4, 217], &device());
|
||||||
|
let q = net.forward(obs);
|
||||||
|
assert_eq!(q.dims(), [4, 514]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_single_sample() {
|
||||||
|
let net = default_net();
|
||||||
|
let q = net.forward(Tensor::zeros([1, 217], &device()));
|
||||||
|
assert_eq!(q.dims(), [1, 514]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn q_values_not_all_equal() {
|
||||||
|
let net = default_net();
|
||||||
|
let q: Vec<f32> = net.forward(Tensor::zeros([1, 217], &device()))
|
||||||
|
.into_data().to_vec().unwrap();
|
||||||
|
let first = q[0];
|
||||||
|
assert!(!q.iter().all(|&x| (x - first).abs() < 1e-6));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn custom_config_shapes() {
|
||||||
|
let cfg = QNetConfig { obs_size: 10, action_size: 20, hidden_size: 32 };
|
||||||
|
let net = QNet::<B>::new(&cfg, &device());
|
||||||
|
let q = net.forward(Tensor::zeros([3, 10], &device()));
|
||||||
|
assert_eq!(q.dims(), [3, 20]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_load_preserves_weights() {
|
||||||
|
let net = default_net();
|
||||||
|
let obs = Tensor::<B, 2>::ones([2, 217], &device());
|
||||||
|
let q_before: Vec<f32> = net.forward(obs.clone()).into_data().to_vec().unwrap();
|
||||||
|
|
||||||
|
let path = std::env::temp_dir().join("spiel_bot_test_qnet.mpk");
|
||||||
|
net.save(&path).expect("save failed");
|
||||||
|
|
||||||
|
let loaded = QNet::<B>::load(&QNetConfig::default(), &path, &device()).expect("load failed");
|
||||||
|
let q_after: Vec<f32> = loaded.forward(obs).into_data().to_vec().unwrap();
|
||||||
|
|
||||||
|
for (i, (a, b)) in q_before.iter().zip(q_after.iter()).enumerate() {
|
||||||
|
assert!((a - b).abs() < 1e-3, "q[{i}]: {a} vs {b}");
|
||||||
|
}
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1011,6 +1011,16 @@ impl GameState {
|
||||||
self.mark_points(player_id, points)
|
self.mark_points(player_id, points)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Total accumulated score for a player: `holes × 12 + points`.
|
||||||
|
///
|
||||||
|
/// Returns `0` if `player_id` is not found (e.g. before `init_player`).
|
||||||
|
pub fn total_score(&self, player_id: PlayerId) -> i32 {
|
||||||
|
self.players
|
||||||
|
.get(&player_id)
|
||||||
|
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||||
// Update player points and holes
|
// Update player points and holes
|
||||||
let mut new_hole = false;
|
let mut new_hole = false;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue