//! Simulation, expansion, backup, and noise helpers.
//!
//! These are internal to the `mcts` module; the public entry points are
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].

use rand::Rng;
use rand_distr::{Gamma, Distribution};

use crate::env::GameEnv;
use super::{Evaluator, MctsConfig};
use super::node::MctsNode;
// ── Masked softmax ─────────────────────────────────────────────────────────
|
||
|
||
/// Numerically stable softmax over `legal` actions only.
|
||
///
|
||
/// Illegal logits are treated as `-∞` and receive probability `0.0`.
|
||
/// Returns a probability vector of length `action_space`.
|
||
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
|
||
let mut probs = vec![0.0f32; action_space];
|
||
if legal.is_empty() {
|
||
return probs;
|
||
}
|
||
let max_logit = legal
|
||
.iter()
|
||
.map(|&a| logits[a])
|
||
.fold(f32::NEG_INFINITY, f32::max);
|
||
let mut sum = 0.0f32;
|
||
for &a in legal {
|
||
let e = (logits[a] - max_logit).exp();
|
||
probs[a] = e;
|
||
sum += e;
|
||
}
|
||
if sum > 0.0 {
|
||
for &a in legal {
|
||
probs[a] /= sum;
|
||
}
|
||
} else {
|
||
let uniform = 1.0 / legal.len() as f32;
|
||
for &a in legal {
|
||
probs[a] = uniform;
|
||
}
|
||
}
|
||
probs
|
||
}
|
||
|
||
// ── Dirichlet noise ────────────────────────────────────────────────────────
|
||
|
||
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
|
||
///
|
||
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
|
||
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
|
||
pub(super) fn add_dirichlet_noise(
|
||
node: &mut MctsNode,
|
||
alpha: f32,
|
||
eps: f32,
|
||
rng: &mut impl Rng,
|
||
) {
|
||
let n = node.children.len();
|
||
if n == 0 {
|
||
return;
|
||
}
|
||
let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
|
||
return;
|
||
};
|
||
let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
|
||
let sum: f32 = samples.iter().sum();
|
||
if sum <= 0.0 {
|
||
return;
|
||
}
|
||
for (i, (_, child)) in node.children.iter_mut().enumerate() {
|
||
let noise = samples[i] / sum;
|
||
child.p = (1.0 - eps) * child.p + eps * noise;
|
||
}
|
||
}
|
||
|
||
// ── Expansion ──────────────────────────────────────────────────────────────
|
||
|
||
/// Evaluate the network at `state` and populate `node` with children.
|
||
///
|
||
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
|
||
/// Returns the network value estimate from `player_idx`'s perspective.
|
||
pub(super) fn expand<E: GameEnv>(
|
||
node: &mut MctsNode,
|
||
state: &E::State,
|
||
env: &E,
|
||
evaluator: &dyn Evaluator,
|
||
player_idx: usize,
|
||
) -> f32 {
|
||
let obs = env.observation(state, player_idx);
|
||
let legal = env.legal_actions(state);
|
||
let (logits, value) = evaluator.evaluate(&obs);
|
||
let priors = masked_softmax(&logits, &legal, env.action_space());
|
||
node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
|
||
node.expanded = true;
|
||
node.n = 1;
|
||
node.w = value;
|
||
value
|
||
}
|
||
|
||
// ── Simulation ─────────────────────────────────────────────────────────────

/// One MCTS simulation from an **already-expanded** decision node.
///
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
/// and backs up the result.
///
/// * `player_idx` — the player (0 or 1) who acts at `state`.
/// * Returns the backed-up value **from `player_idx`'s perspective**.
pub(super) fn simulate<E: GameEnv>(
    node: &mut MctsNode,
    state: E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    config: &MctsConfig,
    rng: &mut impl Rng,
    player_idx: usize,
) -> f32 {
    debug_assert!(node.expanded, "simulate called on unexpanded node");

    // ── Selection: child with highest PUCT ────────────────────────────────
    // `max_by` returns the *last* maximum, so exact PUCT ties favour later
    // children; NaN scores compare as Equal and are effectively skipped.
    let parent_n = node.n;
    let best = node
        .children
        .iter()
        .enumerate()
        .max_by(|(_, (_, a)), (_, (_, b))| {
            a.puct(parent_n, config.c_puct)
                .partial_cmp(&b.puct(parent_n, config.c_puct))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .expect("expanded node must have at least one child");

    // Copy the action id out so the mutable borrow only pins `child`.
    let (action, child) = &mut node.children[best];
    let action = *action;

    // ── Apply action + advance through any chance nodes ───────────────────
    let mut next_state = state;
    env.apply(&mut next_state, action);

    // Track whether we crossed a chance node (dice roll) on the way down.
    // If we did, the child's cached legal actions are for a *different* dice
    // outcome and must not be reused — evaluate with the network directly.
    let mut crossed_chance = false;
    while env.current_player(&next_state).is_chance() {
        env.apply_chance(&mut next_state, rng);
        crossed_chance = true;
    }

    let next_cp = env.current_player(&next_state);

    // ── Evaluate leaf or terminal ──────────────────────────────────────────
    // All values are converted to `player_idx`'s perspective before backup.
    let child_value = if next_cp.is_terminal() {
        // Terminal: use the true game outcome, not a network estimate.
        let returns = env
            .returns(&next_state)
            .expect("terminal node must have returns");
        returns[player_idx]
    } else {
        // NOTE(review): `index()` is assumed to be Some for any non-chance,
        // non-terminal player — confirm against the `GameEnv` contract.
        let child_player = next_cp.index().unwrap();
        let v = if crossed_chance {
            // Outcome sampling: after dice, evaluate the resulting position
            // directly with the network. Do NOT build the tree across chance
            // boundaries — the dice change which actions are legal, so any
            // previously cached children would be for a different outcome.
            let obs = env.observation(&next_state, child_player);
            let (_, value) = evaluator.evaluate(&obs);
            // Record the visit so that PUCT and mcts_policy use real counts.
            // Without this, child.n stays 0 for every simulation in games where
            // every player action is immediately followed by a chance node (e.g.
            // Trictrac), causing mcts_policy to always return a uniform policy.
            child.n += 1;
            child.w += value; // accumulated from the *child's* perspective
            value
        } else if child.expanded {
            // Interior node: recurse; the result is from `child_player`'s view.
            simulate(child, next_state, env, evaluator, config, rng, child_player)
        } else {
            // First visit: expand the leaf and take the network value estimate.
            expand::<E>(child, &next_state, env, evaluator, child_player)
        };
        // Negate when the child belongs to the opponent.
        if child_player == player_idx { v } else { -v }
    };

    // ── Backup ────────────────────────────────────────────────────────────
    node.n += 1;
    node.w += child_value;

    child_value
}