// trictrac/spiel_bot/src/mcts/search.rs

//! Simulation, expansion, backup, and noise helpers.
//!
//! These are internal to the `mcts` module; the public entry points are
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
use rand::Rng;
use rand_distr::{Gamma, Distribution};
use crate::env::GameEnv;
use super::{Evaluator, MctsConfig};
use super::node::MctsNode;
// ── Masked softmax ─────────────────────────────────────────────────────────
/// Numerically stable softmax restricted to the `legal` action subset.
///
/// Illegal actions behave as if their logit were `-∞` and get probability
/// `0.0`. The returned vector always has length `action_space`; if `legal`
/// is empty, it is all zeros.
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
    let mut probs = vec![0.0f32; action_space];
    if legal.is_empty() {
        return probs;
    }
    // Subtract the max legal logit before exponentiating so large logits
    // cannot overflow to +∞.
    let max_logit = legal
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &a| f32::max(acc, logits[a]));
    for &a in legal {
        probs[a] = (logits[a] - max_logit).exp();
    }
    let total: f32 = legal.iter().map(|&a| probs[a]).sum();
    if total > 0.0 {
        for &a in legal {
            probs[a] /= total;
        }
    } else {
        // Degenerate case (e.g. every exp underflowed to 0): fall back to a
        // uniform distribution over the legal actions.
        let uniform = 1.0 / legal.len() as f32;
        for &a in legal {
            probs[a] = uniform;
        }
    }
    probs
}
// ── Dirichlet noise ────────────────────────────────────────────────────────
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
///
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
pub(super) fn add_dirichlet_noise(
    node: &mut MctsNode,
    alpha: f32,
    eps: f32,
    rng: &mut impl Rng,
) {
    let count = node.children.len();
    if count == 0 {
        return;
    }
    // Gamma::new rejects non-positive shape parameters; bail out quietly in
    // that case rather than panicking at the search root.
    let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
        return;
    };
    let mut draws = Vec::with_capacity(count);
    let mut total = 0.0f32;
    for _ in 0..count {
        let g = gamma.sample(rng) as f32;
        total += g;
        draws.push(g);
    }
    if total <= 0.0 {
        return;
    }
    // Normalized gamma draws are a Dirichlet sample; blend it into the priors.
    for ((_, child), g) in node.children.iter_mut().zip(draws) {
        child.p = (1.0 - eps) * child.p + eps * (g / total);
    }
}
// ── Expansion ──────────────────────────────────────────────────────────────
/// Evaluate the network at `state` and populate `node` with children.
///
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
/// Returns the network value estimate from `player_idx`'s perspective.
pub(super) fn expand<E: GameEnv>(
    node: &mut MctsNode,
    state: &E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    player_idx: usize,
) -> f32 {
    // Single network query yields both the policy logits and the value head.
    let observation = env.observation(state, player_idx);
    let legal = env.legal_actions(state);
    let (logits, value) = evaluator.evaluate(&observation);
    let priors = masked_softmax(&logits, &legal, env.action_space());
    // One child per legal action, seeded with its masked-softmax prior.
    let mut children = Vec::with_capacity(legal.len());
    for &action in &legal {
        children.push((action, MctsNode::new(priors[action])));
    }
    node.children = children;
    node.expanded = true;
    // The expansion itself counts as the node's first visit.
    node.n = 1;
    node.w = value;
    value
}
// ── Simulation ─────────────────────────────────────────────────────────────
/// One MCTS simulation from an **already-expanded** decision node.
///
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
/// and backs up the result.
///
/// * `player_idx` — the player (0 or 1) who acts at `state`.
/// * Returns the backed-up value **from `player_idx`'s perspective**.
pub(super) fn simulate<E: GameEnv>(
    node: &mut MctsNode,
    state: E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    config: &MctsConfig,
    rng: &mut impl Rng,
    player_idx: usize,
) -> f32 {
    debug_assert!(node.expanded, "simulate called on unexpanded node");
    // ── Selection: child with highest PUCT ────────────────────────────────
    let parent_n = node.n;
    let best = node
        .children
        .iter()
        .enumerate()
        .max_by(|(_, (_, a)), (_, (_, b))| {
            // partial_cmp can fail only if a PUCT score is NaN; treat such
            // pairs as equal rather than panicking mid-search.
            a.puct(parent_n, config.c_puct)
                .partial_cmp(&b.puct(parent_n, config.c_puct))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .expect("expanded node must have at least one child");
    let (action, child) = &mut node.children[best];
    // Copy the action id out so `child` remains the only live mutable borrow
    // into `node.children` for the rest of the function.
    let action = *action;
    // ── Apply action + advance through any chance nodes ───────────────────
    let mut next_state = state;
    env.apply(&mut next_state, action);
    // Track whether we crossed a chance node (dice roll) on the way down.
    // If we did, the child's cached legal actions are for a *different* dice
    // outcome and must not be reused — evaluate with the network directly.
    let mut crossed_chance = false;
    while env.current_player(&next_state).is_chance() {
        env.apply_chance(&mut next_state, rng);
        crossed_chance = true;
    }
    let next_cp = env.current_player(&next_state);
    // ── Evaluate leaf or terminal ──────────────────────────────────────────
    // All values are converted to `player_idx`'s perspective before backup.
    let child_value = if next_cp.is_terminal() {
        let returns = env
            .returns(&next_state)
            .expect("terminal node must have returns");
        // Terminal: use the exact game outcome for our player, no network.
        returns[player_idx]
    } else {
        let child_player = next_cp.index().unwrap();
        let v = if crossed_chance {
            // Outcome sampling: after dice, evaluate the resulting position
            // directly with the network. Do NOT build the tree across chance
            // boundaries — the dice change which actions are legal, so any
            // previously cached children would be for a different outcome.
            let obs = env.observation(&next_state, child_player);
            let (_, value) = evaluator.evaluate(&obs);
            // Record the visit so that PUCT and mcts_policy use real counts.
            // Without this, child.n stays 0 for every simulation in games where
            // every player action is immediately followed by a chance node (e.g.
            // Trictrac), causing mcts_policy to always return a uniform policy.
            child.n += 1;
            child.w += value;
            value
        } else if child.expanded {
            // Interior node: recurse; the child updates its own statistics.
            simulate(child, next_state, env, evaluator, config, rng, child_player)
        } else {
            // Unvisited leaf: expand it (sets n = 1, w = value) and stop here.
            expand::<E>(child, &next_state, env, evaluator, child_player)
        };
        // Negate when the child belongs to the opponent — `v` is from
        // `child_player`'s perspective, and backup needs `player_idx`'s.
        if child_player == player_idx { v } else { -v }
    };
    // ── Backup ────────────────────────────────────────────────────────────
    node.n += 1;
    node.w += child_value;
    child_value
}