feat(spiel_bot): Monte-Carlo tree search
This commit is contained in:
parent
7ba4b9bbf3
commit
baa47e996d
6 changed files with 672 additions and 0 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -5898,6 +5898,7 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"burn",
|
"burn",
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
|
"rand_distr",
|
||||||
"trictrac-store",
|
"trictrac-store",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,4 +7,5 @@ edition = "2021"
|
||||||
trictrac-store = { path = "../store" }
|
trictrac-store = { path = "../store" }
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
|
rand_distr = "0.5"
|
||||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
pub mod env;
|
pub mod env;
|
||||||
|
pub mod mcts;
|
||||||
pub mod network;
|
pub mod network;
|
||||||
|
|
|
||||||
408
spiel_bot/src/mcts/mod.rs
Normal file
408
spiel_bot/src/mcts/mod.rs
Normal file
|
|
@ -0,0 +1,408 @@
|
||||||
|
//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance.
|
||||||
|
//!
|
||||||
|
//! # Algorithm
|
||||||
|
//!
|
||||||
|
//! The implementation follows AlphaZero's MCTS:
|
||||||
|
//!
|
||||||
|
//! 1. **Expand root** — run the network once to get priors and a value
|
||||||
|
//! estimate; optionally add Dirichlet noise for training-time exploration.
|
||||||
|
//! 2. **Simulate** `n_simulations` times:
|
||||||
|
//! - *Selection* — traverse the tree with PUCT until an unvisited leaf.
|
||||||
|
//! - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes;
|
||||||
|
//! chance nodes are **not** stored in the tree (outcome sampling).
|
||||||
|
//! - *Expansion* — evaluate the network at the leaf; populate children.
|
||||||
|
//! - *Backup* — propagate the value upward; negate at each player boundary.
|
||||||
|
//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]).
|
||||||
|
//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]).
|
||||||
|
//!
|
||||||
|
//! # Perspective convention
|
||||||
|
//!
|
||||||
|
//! Every [`MctsNode::w`] is stored **from the perspective of the player who
|
||||||
|
//! acts at that node**. The backup negates the child value whenever the
|
||||||
|
//! acting player differs between parent and child.
|
||||||
|
//!
|
||||||
|
//! # Stochastic games
|
||||||
|
//!
|
||||||
|
//! When [`GameEnv::current_player`] returns [`Player::Chance`], the
|
||||||
|
//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and
|
||||||
|
//! continues. Chance nodes are skipped transparently; Q-values converge to
|
||||||
|
//! their expectation over many simulations (outcome sampling).
|
||||||
|
|
||||||
|
pub mod node;
|
||||||
|
mod search;
|
||||||
|
|
||||||
|
pub use node::MctsNode;
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
use crate::env::GameEnv;
|
||||||
|
|
||||||
|
// ── Evaluator trait ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Evaluates a game position for use in MCTS.
///
/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet)
/// but the `mcts` module itself does **not** depend on Burn.
///
/// The `Send + Sync` bounds let one evaluator instance be shared across
/// self-play threads without wrapping.
pub trait Evaluator: Send + Sync {
    /// Evaluate `obs` (flat observation vector of length `obs_size`).
    ///
    /// Returns:
    /// - `policy_logits`: one raw logit per action (`action_space` entries).
    ///   Illegal action entries are masked inside the search — no need to
    ///   zero them here.
    /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective.
    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
}
|
||||||
|
|
||||||
|
// ── Configuration ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Hyperparameters for [`run_mcts`].
#[derive(Debug, Clone)]
pub struct MctsConfig {
    /// How many simulations to run per move (commonly 50–800).
    pub n_simulations: usize,
    /// Exploration constant used by the PUCT formula (commonly 1.0–2.0).
    pub c_puct: f32,
    /// Concentration α of the root Dirichlet noise; `0.0` switches it off.
    /// AlphaZero used `0.3` for Chess; smaller values suit larger action spaces.
    pub dirichlet_alpha: f32,
    /// Fraction of the root prior replaced by noise (commonly `0.25`).
    pub dirichlet_eps: f32,
    /// Sampling temperature: `0` picks the argmax, `> 0` samples proportionally.
    pub temperature: f32,
}

impl Default for MctsConfig {
    /// AlphaZero-style defaults suitable for small board games.
    fn default() -> Self {
        MctsConfig {
            n_simulations: 200,
            c_puct: 1.5,
            dirichlet_alpha: 0.3,
            dirichlet_eps: 0.25,
            temperature: 1.0,
        }
    }
}
|
||||||
|
|
||||||
|
// ── Public interface ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Run MCTS from `state` and return the populated root [`MctsNode`].
|
||||||
|
///
|
||||||
|
/// `state` must be a player-decision node (`P1` or `P2`).
|
||||||
|
/// Use [`mcts_policy`] and [`select_action`] on the returned root.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `env.current_player(state)` is not `P1` or `P2`.
|
||||||
|
pub fn run_mcts<E: GameEnv>(
|
||||||
|
env: &E,
|
||||||
|
state: &E::State,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
config: &MctsConfig,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> MctsNode {
|
||||||
|
let player_idx = env
|
||||||
|
.current_player(state)
|
||||||
|
.index()
|
||||||
|
.expect("run_mcts called at a non-decision node");
|
||||||
|
|
||||||
|
// ── Expand root (network called once here, not inside the loop) ────────
|
||||||
|
let mut root = MctsNode::new(1.0);
|
||||||
|
search::expand::<E>(&mut root, state, env, evaluator, player_idx);
|
||||||
|
|
||||||
|
// ── Optional Dirichlet noise for training exploration ──────────────────
|
||||||
|
if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
|
||||||
|
search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Simulations ────────────────────────────────────────────────────────
|
||||||
|
for _ in 0..config.n_simulations {
|
||||||
|
search::simulate::<E>(
|
||||||
|
&mut root,
|
||||||
|
state.clone(),
|
||||||
|
env,
|
||||||
|
evaluator,
|
||||||
|
config,
|
||||||
|
rng,
|
||||||
|
player_idx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
root
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the MCTS policy: normalized visit counts at the root.
|
||||||
|
///
|
||||||
|
/// Returns a vector of length `action_space` where `policy[a]` is the
|
||||||
|
/// fraction of simulations that visited action `a`.
|
||||||
|
pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
|
||||||
|
let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum();
|
||||||
|
let mut policy = vec![0.0f32; action_space];
|
||||||
|
if total > 0.0 {
|
||||||
|
for (a, child) in &root.children {
|
||||||
|
policy[*a] = child.n as f32 / total;
|
||||||
|
}
|
||||||
|
} else if !root.children.is_empty() {
|
||||||
|
// n_simulations = 0: uniform over legal actions.
|
||||||
|
let uniform = 1.0 / root.children.len() as f32;
|
||||||
|
for (a, _) in &root.children {
|
||||||
|
policy[*a] = uniform;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
policy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Select an action index from the root after MCTS.
|
||||||
|
///
|
||||||
|
/// * `temperature = 0` — greedy argmax of visit counts.
|
||||||
|
/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the root has no children.
|
||||||
|
pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize {
|
||||||
|
assert!(!root.children.is_empty(), "select_action called on a root with no children");
|
||||||
|
if temperature <= 0.0 {
|
||||||
|
root.children
|
||||||
|
.iter()
|
||||||
|
.max_by_key(|(_, c)| c.n)
|
||||||
|
.map(|(a, _)| *a)
|
||||||
|
.unwrap()
|
||||||
|
} else {
|
||||||
|
let weights: Vec<f32> = root
|
||||||
|
.children
|
||||||
|
.iter()
|
||||||
|
.map(|(_, c)| (c.n as f32).powf(1.0 / temperature))
|
||||||
|
.collect();
|
||||||
|
let total: f32 = weights.iter().sum();
|
||||||
|
let mut r: f32 = rng.random::<f32>() * total;
|
||||||
|
for (i, (a, _)) in root.children.iter().enumerate() {
|
||||||
|
r -= weights[i];
|
||||||
|
if r <= 0.0 {
|
||||||
|
return *a;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
root.children.last().map(|(a, _)| *a).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::SmallRng};
    use crate::env::Player;

    // ── Minimal deterministic test game ───────────────────────────────────
    //
    // "Countdown" — two players alternate subtracting 1 or 2 from a counter.
    // The player who brings the counter to 0 wins.
    // No chance nodes, two legal actions (0 = -1, 1 = -2).

    #[derive(Clone, Debug)]
    struct CState {
        remaining: u8,       // counter; 0 means the game is over
        to_move: usize, // at terminal: last mover (winner)
    }

    // Zero-sized env: all game logic lives in the trait impl below.
    #[derive(Clone)]
    struct CountdownEnv;

    impl crate::env::GameEnv for CountdownEnv {
        type State = CState;

        fn new_game(&self) -> CState {
            CState { remaining: 6, to_move: 0 }
        }

        fn current_player(&self, s: &CState) -> Player {
            if s.remaining == 0 {
                Player::Terminal
            } else if s.to_move == 0 {
                Player::P1
            } else {
                Player::P2
            }
        }

        // Subtracting 2 is only legal while at least 2 remain.
        fn legal_actions(&self, s: &CState) -> Vec<usize> {
            if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
        }

        fn apply(&self, s: &mut CState, action: usize) {
            let sub = (action as u8) + 1;
            if s.remaining <= sub {
                s.remaining = 0;
                // to_move stays as winner
            } else {
                s.remaining -= sub;
                s.to_move = 1 - s.to_move;
            }
        }

        // Countdown has no chance nodes, so this is a no-op.
        fn apply_chance<R: rand::Rng>(&self, _s: &mut CState, _rng: &mut R) {}

        fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
            vec![s.remaining as f32 / 6.0, s.to_move as f32]
        }

        fn obs_size(&self) -> usize { 2 }
        fn action_space(&self) -> usize { 2 }

        // +1 for the winner (last mover), -1 for the loser; None mid-game.
        fn returns(&self, s: &CState) -> Option<[f32; 2]> {
            if s.remaining != 0 { return None; }
            let mut r = [-1.0f32; 2];
            r[s.to_move] = 1.0;
            Some(r)
        }
    }

    // Uniform evaluator: all logits = 0, value = 0.
    // `action_space` must match the environment's `action_space()`.
    struct ZeroEval(usize);
    impl Evaluator for ZeroEval {
        fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
            (vec![0.0f32; self.0], 0.0)
        }
    }

    // Fixed seed so every test run is reproducible.
    fn rng() -> SmallRng {
        SmallRng::seed_from_u64(42)
    }

    // Config helper: n simulations, noise disabled for determinism.
    fn config_n(n: usize) -> MctsConfig {
        MctsConfig {
            n_simulations: n,
            c_puct: 1.5,
            dirichlet_alpha: 0.0, // off for reproducibility
            dirichlet_eps: 0.0,
            temperature: 1.0,
        }
    }

    // ── Visit count tests ─────────────────────────────────────────────────

    #[test]
    fn visit_counts_sum_to_n_simulations() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng());
        let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
        assert_eq!(total, 50, "visit counts must sum to n_simulations");
    }

    #[test]
    fn all_root_children_are_legal() {
        let env = CountdownEnv;
        let state = env.new_game();
        let legal = env.legal_actions(&state);
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng());
        for (a, _) in &root.children {
            assert!(legal.contains(a), "child action {a} is not legal");
        }
    }

    // ── Policy tests ─────────────────────────────────────────────────────

    #[test]
    fn policy_sums_to_one() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng());
        let policy = mcts_policy(&root, env.action_space());
        let sum: f32 = policy.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0");
    }

    #[test]
    fn policy_zero_for_illegal_actions() {
        let env = CountdownEnv;
        // remaining = 1 → only action 0 is legal
        let state = CState { remaining: 1, to_move: 0 };
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng());
        let policy = mcts_policy(&root, env.action_space());
        assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass");
    }

    // ── Action selection tests ────────────────────────────────────────────

    #[test]
    fn greedy_selects_most_visited() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng());
        let greedy = select_action(&root, 0.0, &mut rng());
        let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap();
        assert_eq!(greedy, most_visited);
    }

    #[test]
    fn temperature_sampling_stays_legal() {
        let env = CountdownEnv;
        let state = env.new_game();
        let legal = env.legal_actions(&state);
        let mut r = rng();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r);
        for _ in 0..20 {
            let a = select_action(&root, 1.0, &mut r);
            assert!(legal.contains(&a), "sampled action {a} is not legal");
        }
    }

    // ── Zero-simulation edge case ─────────────────────────────────────────

    #[test]
    fn zero_simulations_uniform_policy() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng());
        let policy = mcts_policy(&root, env.action_space());
        // With 0 simulations, fallback is uniform over the 2 legal actions.
        let sum: f32 = policy.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    // ── Root value ────────────────────────────────────────────────────────

    #[test]
    fn root_q_in_valid_range() {
        let env = CountdownEnv;
        let state = env.new_game();
        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng());
        let q = root.q();
        assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]");
    }

    // ── Integration: run on a real Trictrac game ──────────────────────────

    #[test]
    fn no_panic_on_trictrac_state() {
        use crate::env::TrictracEnv;

        let env = TrictracEnv;
        let mut state = env.new_game();
        let mut r = rng();

        // Advance past the initial chance node to reach a decision node.
        while env.current_player(&state).is_chance() {
            env.apply_chance(&mut state, &mut r);
        }

        if env.current_player(&state).is_terminal() {
            return; // unlikely but safe
        }

        let config = MctsConfig {
            n_simulations: 5, // tiny for speed
            dirichlet_alpha: 0.0,
            dirichlet_eps: 0.0,
            ..MctsConfig::default()
        };

        // 514 assumed to be TrictracEnv's action_space — TODO confirm.
        let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
        assert!(root.n > 0);
        let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
        assert_eq!(total, 5);
    }
}
|
||||||
91
spiel_bot/src/mcts/node.rs
Normal file
91
spiel_bot/src/mcts/node.rs
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
//! MCTS tree node.
|
||||||
|
//!
|
||||||
|
//! [`MctsNode`] holds the visit statistics for one player-decision position in
|
||||||
|
//! the search tree. A node is *expanded* the first time the policy-value
|
||||||
|
//! network is evaluated there; before that it is a leaf.
|
||||||
|
|
||||||
|
/// One node in the MCTS tree, representing a player-decision position.
///
/// `w` stores the sum of values backed up into this node, always from the
/// perspective of **the player who acts here**; `q()` reports the mean of
/// those values from the same perspective.
#[derive(Debug)]
pub struct MctsNode {
    /// Visit count `N(s, a)`.
    pub n: u32,
    /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective.
    pub w: f32,
    /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax).
    pub p: f32,
    /// Children: `(action_index, child_node)`, populated on first expansion.
    pub children: Vec<(usize, MctsNode)>,
    /// `true` after the network has been evaluated and children have been set up.
    pub expanded: bool,
}

impl MctsNode {
    /// Build an unexpanded leaf carrying prior `prior` and no statistics yet.
    pub fn new(prior: f32) -> Self {
        Self {
            p: prior,
            n: 0,
            w: 0.0,
            children: Vec::new(),
            expanded: false,
        }
    }

    /// Mean action value `Q(s, a) = W / N`; an unvisited node reports `0.0`.
    #[inline]
    pub fn q(&self) -> f32 {
        match self.n {
            0 => 0.0,
            visits => self.w / visits as f32,
        }
    }

    /// PUCT selection score:
    ///
    /// ```text
    /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a))
    /// ```
    #[inline]
    pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
        // Exploration term shrinks as this child accumulates visits.
        let exploration = c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32);
        self.q() + exploration
    }
}
|
||||||
|
|
||||||
|
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn q_zero_when_unvisited() {
        // A brand-new leaf has N = 0, so Q must default to zero.
        assert_eq!(MctsNode::new(0.5).q(), 0.0);
    }

    #[test]
    fn q_reflects_w_over_n() {
        // Q = W / N = 2.0 / 4 = 0.5.
        let node = MctsNode { n: 4, w: 2.0, ..MctsNode::new(0.5) };
        assert!((node.q() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn puct_exploration_dominates_unvisited() {
        // Unvisited child should outscore a visited child with negative Q.
        let visited = MctsNode { n: 10, w: -5.0, ..MctsNode::new(0.5) }; // Q = -0.5
        let unvisited = MctsNode::new(0.5);

        let (parent_n, c) = (10, 1.5);
        assert!(
            unvisited.puct(parent_n, c) > visited.puct(parent_n, c),
            "unvisited child should have higher PUCT than a negatively-valued visited child"
        );
    }
}
|
||||||
170
spiel_bot/src/mcts/search.rs
Normal file
170
spiel_bot/src/mcts/search.rs
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
//! Simulation, expansion, backup, and noise helpers.
|
||||||
|
//!
|
||||||
|
//! These are internal to the `mcts` module; the public entry points are
|
||||||
|
//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
use rand_distr::{Gamma, Distribution};
|
||||||
|
|
||||||
|
use crate::env::GameEnv;
|
||||||
|
use super::{Evaluator, MctsConfig};
|
||||||
|
use super::node::MctsNode;
|
||||||
|
|
||||||
|
// ── Masked softmax ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Numerically stable softmax over `legal` actions only.
|
||||||
|
///
|
||||||
|
/// Illegal logits are treated as `-∞` and receive probability `0.0`.
|
||||||
|
/// Returns a probability vector of length `action_space`.
|
||||||
|
pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
|
||||||
|
let mut probs = vec![0.0f32; action_space];
|
||||||
|
if legal.is_empty() {
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
let max_logit = legal
|
||||||
|
.iter()
|
||||||
|
.map(|&a| logits[a])
|
||||||
|
.fold(f32::NEG_INFINITY, f32::max);
|
||||||
|
let mut sum = 0.0f32;
|
||||||
|
for &a in legal {
|
||||||
|
let e = (logits[a] - max_logit).exp();
|
||||||
|
probs[a] = e;
|
||||||
|
sum += e;
|
||||||
|
}
|
||||||
|
if sum > 0.0 {
|
||||||
|
for &a in legal {
|
||||||
|
probs[a] /= sum;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let uniform = 1.0 / legal.len() as f32;
|
||||||
|
for &a in legal {
|
||||||
|
probs[a] = uniform;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
probs
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Dirichlet noise ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
|
||||||
|
///
|
||||||
|
/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
|
||||||
|
/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
|
||||||
|
pub(super) fn add_dirichlet_noise(
|
||||||
|
node: &mut MctsNode,
|
||||||
|
alpha: f32,
|
||||||
|
eps: f32,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) {
|
||||||
|
let n = node.children.len();
|
||||||
|
if n == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
|
||||||
|
let sum: f32 = samples.iter().sum();
|
||||||
|
if sum <= 0.0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (i, (_, child)) in node.children.iter_mut().enumerate() {
|
||||||
|
let noise = samples[i] / sum;
|
||||||
|
child.p = (1.0 - eps) * child.p + eps * noise;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Expansion ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Evaluate the network at `state` and populate `node` with children.
|
||||||
|
///
|
||||||
|
/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
|
||||||
|
/// Returns the network value estimate from `player_idx`'s perspective.
|
||||||
|
pub(super) fn expand<E: GameEnv>(
|
||||||
|
node: &mut MctsNode,
|
||||||
|
state: &E::State,
|
||||||
|
env: &E,
|
||||||
|
evaluator: &dyn Evaluator,
|
||||||
|
player_idx: usize,
|
||||||
|
) -> f32 {
|
||||||
|
let obs = env.observation(state, player_idx);
|
||||||
|
let legal = env.legal_actions(state);
|
||||||
|
let (logits, value) = evaluator.evaluate(&obs);
|
||||||
|
let priors = masked_softmax(&logits, &legal, env.action_space());
|
||||||
|
node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
|
||||||
|
node.expanded = true;
|
||||||
|
node.n = 1;
|
||||||
|
node.w = value;
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Simulation ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// One MCTS simulation from an **already-expanded** decision node.
///
/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
/// and backs up the result.
///
/// * `state` is taken by value: this simulation mutates its private copy as
///   it walks down the tree.
/// * `player_idx` — the player (0 or 1) who acts at `state`.
/// * Returns the backed-up value **from `player_idx`'s perspective**.
pub(super) fn simulate<E: GameEnv>(
    node: &mut MctsNode,
    state: E::State,
    env: &E,
    evaluator: &dyn Evaluator,
    config: &MctsConfig,
    rng: &mut impl Rng,
    player_idx: usize,
) -> f32 {
    debug_assert!(node.expanded, "simulate called on unexpanded node");

    // ── Selection: child with highest PUCT ────────────────────────────────
    // NaN scores compare as Equal, so a NaN never wins over a finite score
    // to its right; ties resolve to the later child (max_by keeps the last max).
    let parent_n = node.n;
    let best = node
        .children
        .iter()
        .enumerate()
        .max_by(|(_, (_, a)), (_, (_, b))| {
            a.puct(parent_n, config.c_puct)
                .partial_cmp(&b.puct(parent_n, config.c_puct))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .expect("expanded node must have at least one child");

    // Copy the action index out so only `child` stays mutably borrowed.
    let (action, child) = &mut node.children[best];
    let action = *action;

    // ── Apply action + advance through any chance nodes ───────────────────
    // Chance nodes are not stored in the tree: each simulation samples one
    // concrete outcome and continues (outcome sampling).
    let mut next_state = state;
    env.apply(&mut next_state, action);
    while env.current_player(&next_state).is_chance() {
        env.apply_chance(&mut next_state, rng);
    }

    let next_cp = env.current_player(&next_state);

    // ── Evaluate leaf or terminal ──────────────────────────────────────────
    // All values are converted to `player_idx`'s perspective before backup.
    let child_value = if next_cp.is_terminal() {
        // Terminal: use the exact game outcome instead of a network estimate.
        let returns = env
            .returns(&next_state)
            .expect("terminal node must have returns");
        returns[player_idx]
    } else {
        let child_player = next_cp.index().unwrap();
        // Recurse into an expanded child; expand (network call) otherwise.
        let v = if child.expanded {
            simulate(child, next_state, env, evaluator, config, rng, child_player)
        } else {
            expand::<E>(child, &next_state, env, evaluator, child_player)
        };
        // Negate when the child belongs to the opponent.
        if child_player == player_idx { v } else { -v }
    };

    // ── Backup ────────────────────────────────────────────────────────────
    node.n += 1;
    node.w += child_value;

    child_value
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue