171 lines
6.2 KiB
Rust
171 lines
6.2 KiB
Rust
|
|
//! Simulation, expansion, backup, and noise helpers.
|
|||
|
|
//!
|
|||
|
|
//! These are internal to the `mcts` module; the public entry points are
|
|||
|
|
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
|
|||
|
|
|
|||
|
|
use rand::Rng;
|
|||
|
|
use rand_distr::{Gamma, Distribution};
|
|||
|
|
|
|||
|
|
use crate::env::GameEnv;
|
|||
|
|
use super::{Evaluator, MctsConfig};
|
|||
|
|
use super::node::MctsNode;
|
|||
|
|
|
|||
|
|
// ── Masked softmax ─────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
/// Numerically stable softmax over `legal` actions only.
|
|||
|
|
///
|
|||
|
|
/// Illegal logits are treated as `-∞` and receive probability `0.0`.
|
|||
|
|
/// Returns a probability vector of length `action_space`.
|
|||
|
|
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
|
|||
|
|
let mut probs = vec![0.0f32; action_space];
|
|||
|
|
if legal.is_empty() {
|
|||
|
|
return probs;
|
|||
|
|
}
|
|||
|
|
let max_logit = legal
|
|||
|
|
.iter()
|
|||
|
|
.map(|&a| logits[a])
|
|||
|
|
.fold(f32::NEG_INFINITY, f32::max);
|
|||
|
|
let mut sum = 0.0f32;
|
|||
|
|
for &a in legal {
|
|||
|
|
let e = (logits[a] - max_logit).exp();
|
|||
|
|
probs[a] = e;
|
|||
|
|
sum += e;
|
|||
|
|
}
|
|||
|
|
if sum > 0.0 {
|
|||
|
|
for &a in legal {
|
|||
|
|
probs[a] /= sum;
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
let uniform = 1.0 / legal.len() as f32;
|
|||
|
|
for &a in legal {
|
|||
|
|
probs[a] = uniform;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
probs
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ── Dirichlet noise ────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
|
|||
|
|
///
|
|||
|
|
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
|
|||
|
|
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
|
|||
|
|
pub(super) fn add_dirichlet_noise(
|
|||
|
|
node: &mut MctsNode,
|
|||
|
|
alpha: f32,
|
|||
|
|
eps: f32,
|
|||
|
|
rng: &mut impl Rng,
|
|||
|
|
) {
|
|||
|
|
let n = node.children.len();
|
|||
|
|
if n == 0 {
|
|||
|
|
return;
|
|||
|
|
}
|
|||
|
|
let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
|
|||
|
|
return;
|
|||
|
|
};
|
|||
|
|
let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
|
|||
|
|
let sum: f32 = samples.iter().sum();
|
|||
|
|
if sum <= 0.0 {
|
|||
|
|
return;
|
|||
|
|
}
|
|||
|
|
for (i, (_, child)) in node.children.iter_mut().enumerate() {
|
|||
|
|
let noise = samples[i] / sum;
|
|||
|
|
child.p = (1.0 - eps) * child.p + eps * noise;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ── Expansion ──────────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
/// Evaluate the network at `state` and populate `node` with children.
|
|||
|
|
///
|
|||
|
|
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
|
|||
|
|
/// Returns the network value estimate from `player_idx`'s perspective.
|
|||
|
|
pub(super) fn expand<E: GameEnv>(
|
|||
|
|
node: &mut MctsNode,
|
|||
|
|
state: &E::State,
|
|||
|
|
env: &E,
|
|||
|
|
evaluator: &dyn Evaluator,
|
|||
|
|
player_idx: usize,
|
|||
|
|
) -> f32 {
|
|||
|
|
let obs = env.observation(state, player_idx);
|
|||
|
|
let legal = env.legal_actions(state);
|
|||
|
|
let (logits, value) = evaluator.evaluate(&obs);
|
|||
|
|
let priors = masked_softmax(&logits, &legal, env.action_space());
|
|||
|
|
node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
|
|||
|
|
node.expanded = true;
|
|||
|
|
node.n = 1;
|
|||
|
|
node.w = value;
|
|||
|
|
value
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ── Simulation ─────────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
/// One MCTS simulation from an **already-expanded** decision node.
///
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
/// and backs up the result.
///
/// * `player_idx` — the player (0 or 1) who acts at `state`.
/// * Returns the backed-up value **from `player_idx`'s perspective**.
pub(super) fn simulate<E: GameEnv>(
    node: &mut MctsNode,
    state: E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    config: &MctsConfig,
    rng: &mut impl Rng,
    player_idx: usize,
) -> f32 {
    debug_assert!(node.expanded, "simulate called on unexpanded node");

    // ── Selection: child with highest PUCT ────────────────────────────────
    let parent_n = node.n;
    let best = node
        .children
        .iter()
        .enumerate()
        .max_by(|(_, (_, a)), (_, (_, b))| {
            // NaN-safe ordering: incomparable PUCT scores count as Equal,
            // so max_by still yields a deterministic (last-max) choice.
            a.puct(parent_n, config.c_puct)
                .partial_cmp(&b.puct(parent_n, config.c_puct))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .expect("expanded node must have at least one child");

    let (action, child) = &mut node.children[best];
    let action = *action; // copy out so `child` stays mutably borrowed below

    // ── Apply action + advance through any chance nodes ───────────────────
    // Chance nodes are sampled (not searched): stochastic events are rolled
    // forward with `rng` until a decision or terminal state is reached.
    let mut next_state = state;
    env.apply(&mut next_state, action);
    while env.current_player(&next_state).is_chance() {
        env.apply_chance(&mut next_state, rng);
    }

    let next_cp = env.current_player(&next_state);

    // ── Evaluate leaf or terminal ──────────────────────────────────────────
    // All values are converted to `player_idx`'s perspective before backup.
    let child_value = if next_cp.is_terminal() {
        // Terminal: use the true game outcome for `player_idx`.
        let returns = env
            .returns(&next_state)
            .expect("terminal node must have returns");
        returns[player_idx]
    } else {
        let child_player = next_cp.index().unwrap();
        let v = if child.expanded {
            // Recurse; the returned value is from `child_player`'s perspective.
            simulate(child, next_state, env, evaluator, config, rng, child_player)
        } else {
            // First visit to this child: expand it and take the network value.
            expand::<E>(child, &next_state, env, evaluator, child_player)
        };
        // Negate when the child belongs to the opponent.
        // NOTE(review): the recursive call above already incremented
        // `child.n`/`child.w` from the *child's* player's perspective, so
        // each node's `w` is stored from its own player's point of view.
        // This relies on `MctsNode::puct` (not visible here) negating or
        // otherwise accounting for the child's Q when the parent selects —
        // confirm against node.rs. Negation also assumes a two-player
        // zero-sum game.
        if child_player == player_idx { v } else { -v }
    };

    // ── Backup ────────────────────────────────────────────────────────────
    // Visit count and accumulated value, from `player_idx`'s perspective.
    node.n += 1;
    node.w += child_value;

    child_value
}
|