diff --git a/Cargo.lock b/Cargo.lock
index 2e81285..0baa02a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5898,6 +5898,7 @@ dependencies = [
  "anyhow",
  "burn",
  "rand 0.9.2",
+ "rand_distr",
  "trictrac-store",
 ]
diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml
index fba2aab..323c953 100644
--- a/spiel_bot/Cargo.toml
+++ b/spiel_bot/Cargo.toml
@@ -7,4 +7,5 @@ edition = "2021"
 trictrac-store = { path = "../store" }
 anyhow = "1"
 rand = "0.9"
+rand_distr = "0.5"
 burn = { version = "0.20", features = ["ndarray", "autodiff"] }
diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs
index 6e71016..5beb37c 100644
--- a/spiel_bot/src/lib.rs
+++ b/spiel_bot/src/lib.rs
@@ -1,2 +1,3 @@
 pub mod env;
+pub mod mcts;
 pub mod network;
diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs
new file mode 100644
index 0000000..e92bd09
--- /dev/null
+++ b/spiel_bot/src/mcts/mod.rs
@@ -0,0 +1,408 @@
+//! Monte Carlo Tree Search with PUCT selection and policy-value network guidance.
+//!
+//! # Algorithm
+//!
+//! The implementation follows AlphaZero's MCTS:
+//!
+//! 1. **Expand root** — run the network once to get priors and a value
+//!    estimate; optionally add Dirichlet noise for training-time exploration.
+//! 2. **Simulate** `n_simulations` times:
+//!    - *Selection* — traverse the tree with PUCT until an unvisited leaf.
+//!    - *Chance bypass* — call [`GameEnv::apply_chance`] at chance nodes;
+//!      chance nodes are **not** stored in the tree (outcome sampling).
+//!    - *Expansion* — evaluate the network at the leaf; populate children.
+//!    - *Backup* — propagate the value upward; negate at each player boundary.
+//! 3. **Policy** — normalized visit counts at the root ([`mcts_policy`]).
+//! 4. **Action** — greedy (temperature = 0) or sampled ([`select_action`]).
+//!
+//! # Perspective convention
+//!
+//! Every [`MctsNode::w`] is stored **from the perspective of the player who
+//! acts at that node**. The backup negates the child value whenever the
+//! acting player differs between parent and child.
+//!
+//! # Stochastic games
+//!
+//! When [`GameEnv::current_player`] returns [`Player::Chance`], the
+//! simulation calls [`GameEnv::apply_chance`] to sample a random outcome and
+//! continues. Chance nodes are skipped transparently; Q-values converge to
+//! their expectation over many simulations (outcome sampling).
+
+pub mod node;
+mod search;
+
+pub use node::MctsNode;
+
+use rand::Rng;
+
+use crate::env::GameEnv;
+
+// ── Evaluator trait ────────────────────────────────────────────────────────
+
+/// Evaluates a game position for use in MCTS.
+///
+/// Implementations typically wrap a [`PolicyValueNet`](crate::network::PolicyValueNet)
+/// but the `mcts` module itself does **not** depend on Burn.
+pub trait Evaluator: Send + Sync {
+    /// Evaluate `obs` (flat observation vector of length `obs_size`).
+    ///
+    /// Returns:
+    /// - `policy_logits`: one raw logit per action (`action_space` entries).
+    ///   Illegal action entries are masked inside the search — no need to
+    ///   zero them here.
+    /// - `value`: scalar in `(-1, 1)` from **the current player's** perspective.
+    fn evaluate(&self, obs: &[f32]) -> (Vec<f32>, f32);
+}
+
+// ── Configuration ─────────────────────────────────────────────────────────
+
+/// Hyperparameters for [`run_mcts`].
+#[derive(Debug, Clone)]
+pub struct MctsConfig {
+    /// Number of MCTS simulations per move. Typical: 50–800.
+    pub n_simulations: usize,
+    /// PUCT exploration constant `c_puct`. Typical: 1.0–2.0.
+    pub c_puct: f32,
+    /// Dirichlet noise concentration α. Set to `0.0` to disable.
+    /// Typical: `0.3` for Chess, `0.1` for large action spaces.
+    pub dirichlet_alpha: f32,
+    /// Weight of Dirichlet noise mixed into root priors. Typical: `0.25`.
+    pub dirichlet_eps: f32,
+    /// Action sampling temperature. `> 0` = proportional sample, `0` = argmax.
+    pub temperature: f32,
+}
+
+impl Default for MctsConfig {
+    fn default() -> Self {
+        Self {
+            n_simulations: 200,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.3,
+            dirichlet_eps: 0.25,
+            temperature: 1.0,
+        }
+    }
+}
+
+// ── Public interface ───────────────────────────────────────────────────────
+
+/// Run MCTS from `state` and return the populated root [`MctsNode`].
+///
+/// `state` must be a player-decision node (`P1` or `P2`).
+/// Use [`mcts_policy`] and [`select_action`] on the returned root.
+///
+/// # Panics
+///
+/// Panics if `env.current_player(state)` is not `P1` or `P2`.
+pub fn run_mcts<E: GameEnv>(
+    env: &E,
+    state: &E::State,
+    evaluator: &dyn Evaluator,
+    config: &MctsConfig,
+    rng: &mut impl Rng,
+) -> MctsNode {
+    let player_idx = env
+        .current_player(state)
+        .index()
+        .expect("run_mcts called at a non-decision node");
+
+    // ── Expand root (network called once here, not inside the loop) ────────
+    let mut root = MctsNode::new(1.0);
+    search::expand::<E>(&mut root, state, env, evaluator, player_idx);
+
+    // ── Optional Dirichlet noise for training exploration ──────────────────
+    if config.dirichlet_alpha > 0.0 && config.dirichlet_eps > 0.0 {
+        search::add_dirichlet_noise(&mut root, config.dirichlet_alpha, config.dirichlet_eps, rng);
+    }
+
+    // ── Simulations ────────────────────────────────────────────────────────
+    for _ in 0..config.n_simulations {
+        search::simulate::<E>(
+            &mut root,
+            state.clone(),
+            env,
+            evaluator,
+            config,
+            rng,
+            player_idx,
+        );
+    }
+
+    root
+}
+
+/// Compute the MCTS policy: normalized visit counts at the root.
+///
+/// Returns a vector of length `action_space` where `policy[a]` is the
+/// fraction of simulations that visited action `a`.
+pub fn mcts_policy(root: &MctsNode, action_space: usize) -> Vec<f32> {
+    let total: f32 = root.children.iter().map(|(_, c)| c.n as f32).sum();
+    let mut policy = vec![0.0f32; action_space];
+    if total > 0.0 {
+        for (a, child) in &root.children {
+            policy[*a] = child.n as f32 / total;
+        }
+    } else if !root.children.is_empty() {
+        // n_simulations = 0: uniform over legal actions.
+        let uniform = 1.0 / root.children.len() as f32;
+        for (a, _) in &root.children {
+            policy[*a] = uniform;
+        }
+    }
+    policy
+}
+
+/// Select an action index from the root after MCTS.
+///
+/// * `temperature = 0` — greedy argmax of visit counts.
+/// * `temperature > 0` — sample proportionally to `N^(1 / temperature)`.
+///
+/// # Panics
+///
+/// Panics if the root has no children.
+pub fn select_action(root: &MctsNode, temperature: f32, rng: &mut impl Rng) -> usize {
+    assert!(!root.children.is_empty(), "select_action called on a root with no children");
+    if temperature <= 0.0 {
+        root.children
+            .iter()
+            .max_by_key(|(_, c)| c.n)
+            .map(|(a, _)| *a)
+            .unwrap()
+    } else {
+        let weights: Vec<f32> = root
+            .children
+            .iter()
+            .map(|(_, c)| (c.n as f32).powf(1.0 / temperature))
+            .collect();
+        let total: f32 = weights.iter().sum();
+        let mut r: f32 = rng.random::<f32>() * total;
+        for (i, (a, _)) in root.children.iter().enumerate() {
+            r -= weights[i];
+            if r <= 0.0 {
+                return *a;
+            }
+        }
+        root.children.last().map(|(a, _)| *a).unwrap()
+    }
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::{SeedableRng, rngs::SmallRng};
+    use crate::env::Player;
+
+    // ── Minimal deterministic test game ───────────────────────────────────
+    //
+    // "Countdown" — two players alternate subtracting 1 or 2 from a counter.
+    // The player who brings the counter to 0 wins.
+    // No chance nodes, two legal actions (0 = -1, 1 = -2).
+
+    #[derive(Clone, Debug)]
+    struct CState {
+        remaining: u8,
+        to_move: usize, // at terminal: last mover (winner)
+    }
+
+    #[derive(Clone)]
+    struct CountdownEnv;
+
+    impl crate::env::GameEnv for CountdownEnv {
+        type State = CState;
+
+        fn new_game(&self) -> CState {
+            CState { remaining: 6, to_move: 0 }
+        }
+
+        fn current_player(&self, s: &CState) -> Player {
+            if s.remaining == 0 {
+                Player::Terminal
+            } else if s.to_move == 0 {
+                Player::P1
+            } else {
+                Player::P2
+            }
+        }
+
+        fn legal_actions(&self, s: &CState) -> Vec<usize> {
+            if s.remaining >= 2 { vec![0, 1] } else { vec![0] }
+        }
+
+        fn apply(&self, s: &mut CState, action: usize) {
+            let sub = (action as u8) + 1;
+            if s.remaining <= sub {
+                s.remaining = 0;
+                // to_move stays as winner
+            } else {
+                s.remaining -= sub;
+                s.to_move = 1 - s.to_move;
+            }
+        }
+
+        fn apply_chance<R: Rng>(&self, _s: &mut CState, _rng: &mut R) {}
+
+        fn observation(&self, s: &CState, _pov: usize) -> Vec<f32> {
+            vec![s.remaining as f32 / 6.0, s.to_move as f32]
+        }
+
+        fn obs_size(&self) -> usize { 2 }
+        fn action_space(&self) -> usize { 2 }
+
+        fn returns(&self, s: &CState) -> Option<[f32; 2]> {
+            if s.remaining != 0 { return None; }
+            let mut r = [-1.0f32; 2];
+            r[s.to_move] = 1.0;
+            Some(r)
+        }
+    }
+
+    // Uniform evaluator: all logits = 0, value = 0.
+    // `action_space` must match the environment's `action_space()`.
+    struct ZeroEval(usize);
+    impl Evaluator for ZeroEval {
+        fn evaluate(&self, _obs: &[f32]) -> (Vec<f32>, f32) {
+            (vec![0.0f32; self.0], 0.0)
+        }
+    }
+
+    fn rng() -> SmallRng {
+        SmallRng::seed_from_u64(42)
+    }
+
+    fn config_n(n: usize) -> MctsConfig {
+        MctsConfig {
+            n_simulations: n,
+            c_puct: 1.5,
+            dirichlet_alpha: 0.0, // off for reproducibility
+            dirichlet_eps: 0.0,
+            temperature: 1.0,
+        }
+    }
+
+    // ── Visit count tests ─────────────────────────────────────────────────
+
+    #[test]
+    fn visit_counts_sum_to_n_simulations() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(50), &mut rng());
+        let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
+        assert_eq!(total, 50, "visit counts must sum to n_simulations");
+    }
+
+    #[test]
+    fn all_root_children_are_legal() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let legal = env.legal_actions(&state);
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut rng());
+        for (a, _) in &root.children {
+            assert!(legal.contains(a), "child action {a} is not legal");
+        }
+    }
+
+    // ── Policy tests ─────────────────────────────────────────────────────
+
+    #[test]
+    fn policy_sums_to_one() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(20), &mut rng());
+        let policy = mcts_policy(&root, env.action_space());
+        let sum: f32 = policy.iter().sum();
+        assert!((sum - 1.0).abs() < 1e-5, "policy sums to {sum}, expected 1.0");
+    }
+
+    #[test]
+    fn policy_zero_for_illegal_actions() {
+        let env = CountdownEnv;
+        // remaining = 1 → only action 0 is legal
+        let state = CState { remaining: 1, to_move: 0 };
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(10), &mut rng());
+        let policy = mcts_policy(&root, env.action_space());
+        assert_eq!(policy[1], 0.0, "illegal action must have zero policy mass");
+    }
+
+    // ── Action selection tests ────────────────────────────────────────────
+
+    #[test]
+    fn greedy_selects_most_visited() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(60), &mut rng());
+        let greedy = select_action(&root, 0.0, &mut rng());
+        let most_visited = root.children.iter().max_by_key(|(_, c)| c.n).map(|(a, _)| *a).unwrap();
+        assert_eq!(greedy, most_visited);
+    }
+
+    #[test]
+    fn temperature_sampling_stays_legal() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let legal = env.legal_actions(&state);
+        let mut r = rng();
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(30), &mut r);
+        for _ in 0..20 {
+            let a = select_action(&root, 1.0, &mut r);
+            assert!(legal.contains(&a), "sampled action {a} is not legal");
+        }
+    }
+
+    // ── Zero-simulation edge case ─────────────────────────────────────────
+
+    #[test]
+    fn zero_simulations_uniform_policy() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(0), &mut rng());
+        let policy = mcts_policy(&root, env.action_space());
+        // With 0 simulations, fallback is uniform over the 2 legal actions.
+        let sum: f32 = policy.iter().sum();
+        assert!((sum - 1.0).abs() < 1e-5);
+    }
+
+    // ── Root value ────────────────────────────────────────────────────────
+
+    #[test]
+    fn root_q_in_valid_range() {
+        let env = CountdownEnv;
+        let state = env.new_game();
+        let root = run_mcts(&env, &state, &ZeroEval(2), &config_n(40), &mut rng());
+        let q = root.q();
+        assert!(q >= -1.0 && q <= 1.0, "root Q={q} outside [-1, 1]");
+    }
+
+    // ── Integration: run on a real Trictrac game ──────────────────────────
+
+    #[test]
+    fn no_panic_on_trictrac_state() {
+        use crate::env::TrictracEnv;
+
+        let env = TrictracEnv;
+        let mut state = env.new_game();
+        let mut r = rng();
+
+        // Advance past the initial chance node to reach a decision node.
+        while env.current_player(&state).is_chance() {
+            env.apply_chance(&mut state, &mut r);
+        }
+
+        if env.current_player(&state).is_terminal() {
+            return; // unlikely but safe
+        }
+
+        let config = MctsConfig {
+            n_simulations: 5, // tiny for speed
+            dirichlet_alpha: 0.0,
+            dirichlet_eps: 0.0,
+            ..MctsConfig::default()
+        };
+
+        let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
+        assert!(root.n > 0);
+        let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
+        assert_eq!(total, 5);
+    }
+}
diff --git a/spiel_bot/src/mcts/node.rs b/spiel_bot/src/mcts/node.rs
new file mode 100644
index 0000000..aff7735
--- /dev/null
+++ b/spiel_bot/src/mcts/node.rs
@@ -0,0 +1,91 @@
+//! MCTS tree node.
+//!
+//! [`MctsNode`] holds the visit statistics for one player-decision position in
+//! the search tree. A node is *expanded* the first time the policy-value
+//! network is evaluated there; before that it is a leaf.
+
+/// One node in the MCTS tree, representing a player-decision position.
+///
+/// `w` stores the sum of values backed up into this node, always from the
+/// perspective of **the player who acts here**. `q()` therefore also returns
+/// a value in `(-1, 1)` from that same perspective.
+#[derive(Debug)]
+pub struct MctsNode {
+    /// Visit count `N(s, a)`.
+    pub n: u32,
+    /// Sum of backed-up values `W(s, a)` — from **this node's player's** perspective.
+    pub w: f32,
+    /// Prior probability `P(s, a)` assigned by the policy head (after masked softmax).
+    pub p: f32,
+    /// Children: `(action_index, child_node)`, populated on first expansion.
+    pub children: Vec<(usize, MctsNode)>,
+    /// `true` after the network has been evaluated and children have been set up.
+    pub expanded: bool,
+}
+
+impl MctsNode {
+    /// Create a fresh, unexpanded leaf with the given prior probability.
+    pub fn new(prior: f32) -> Self {
+        Self {
+            n: 0,
+            w: 0.0,
+            p: prior,
+            children: Vec::new(),
+            expanded: false,
+        }
+    }
+
+    /// `Q(s, a) = W / N`, or `0.0` if this node has never been visited.
+    #[inline]
+    pub fn q(&self) -> f32 {
+        if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
+    }
+
+    /// PUCT selection score:
+    ///
+    /// ```text
+    /// Q(s,a) + c_puct · P(s,a) · √N_parent / (1 + N(s,a))
+    /// ```
+    #[inline]
+    pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
+        self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
+    }
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn q_zero_when_unvisited() {
+        let node = MctsNode::new(0.5);
+        assert_eq!(node.q(), 0.0);
+    }
+
+    #[test]
+    fn q_reflects_w_over_n() {
+        let mut node = MctsNode::new(0.5);
+        node.n = 4;
+        node.w = 2.0;
+        assert!((node.q() - 0.5).abs() < 1e-6);
+    }
+
+    #[test]
+    fn puct_exploration_dominates_unvisited() {
+        // Unvisited child should outscore a visited child with negative Q.
+        let mut visited = MctsNode::new(0.5);
+        visited.n = 10;
+        visited.w = -5.0; // Q = -0.5
+
+        let unvisited = MctsNode::new(0.5);
+
+        let parent_n = 10;
+        let c = 1.5;
+        assert!(
+            unvisited.puct(parent_n, c) > visited.puct(parent_n, c),
+            "unvisited child should have higher PUCT than a negatively-valued visited child"
+        );
+    }
+}
diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs
new file mode 100644
index 0000000..c4960c7
--- /dev/null
+++ b/spiel_bot/src/mcts/search.rs
@@ -0,0 +1,170 @@
+//! Simulation, expansion, backup, and noise helpers.
+//!
+//! These are internal to the `mcts` module; the public entry points are
+//! [`super::run_mcts`], [`super::mcts_policy`], and [`super::select_action`].
+
+use rand::Rng;
+use rand_distr::{Gamma, Distribution};
+
+use crate::env::GameEnv;
+use super::{Evaluator, MctsConfig};
+use super::node::MctsNode;
+
+// ── Masked softmax ─────────────────────────────────────────────────────────
+
+/// Numerically stable softmax over `legal` actions only.
+///
+/// Illegal logits are treated as `-∞` and receive probability `0.0`.
+/// Returns a probability vector of length `action_space`.
+pub(super) fn masked_softmax(logits: &[f32], legal: &[usize], action_space: usize) -> Vec<f32> {
+    let mut probs = vec![0.0f32; action_space];
+    if legal.is_empty() {
+        return probs;
+    }
+    let max_logit = legal
+        .iter()
+        .map(|&a| logits[a])
+        .fold(f32::NEG_INFINITY, f32::max);
+    let mut sum = 0.0f32;
+    for &a in legal {
+        let e = (logits[a] - max_logit).exp();
+        probs[a] = e;
+        sum += e;
+    }
+    if sum > 0.0 {
+        for &a in legal {
+            probs[a] /= sum;
+        }
+    } else {
+        let uniform = 1.0 / legal.len() as f32;
+        for &a in legal {
+            probs[a] = uniform;
+        }
+    }
+    probs
+}
+
+// ── Dirichlet noise ────────────────────────────────────────────────────────
+
+/// Mix Dirichlet(α, …, α) noise into the root's children priors for exploration.
+///
+/// Standard AlphaZero parameters: `alpha = 0.3`, `eps = 0.25`.
+/// Uses the Gamma-distribution trick: Dir(α,…,α) = Gamma(α,1)^n / sum.
+pub(super) fn add_dirichlet_noise(
+    node: &mut MctsNode,
+    alpha: f32,
+    eps: f32,
+    rng: &mut impl Rng,
+) {
+    let n = node.children.len();
+    if n == 0 {
+        return;
+    }
+    let Ok(gamma) = Gamma::new(alpha as f64, 1.0_f64) else {
+        return;
+    };
+    let samples: Vec<f32> = (0..n).map(|_| gamma.sample(rng) as f32).collect();
+    let sum: f32 = samples.iter().sum();
+    if sum <= 0.0 {
+        return;
+    }
+    for (i, (_, child)) in node.children.iter_mut().enumerate() {
+        let noise = samples[i] / sum;
+        child.p = (1.0 - eps) * child.p + eps * noise;
+    }
+}
+
+// ── Expansion ──────────────────────────────────────────────────────────────
+
+/// Evaluate the network at `state` and populate `node` with children.
+///
+/// Sets `node.n = 1`, `node.w = value`, `node.expanded = true`.
+/// Returns the network value estimate from `player_idx`'s perspective.
+pub(super) fn expand<E: GameEnv>(
+    node: &mut MctsNode,
+    state: &E::State,
+    env: &E,
+    evaluator: &dyn Evaluator,
+    player_idx: usize,
+) -> f32 {
+    let obs = env.observation(state, player_idx);
+    let legal = env.legal_actions(state);
+    let (logits, value) = evaluator.evaluate(&obs);
+    let priors = masked_softmax(&logits, &legal, env.action_space());
+    node.children = legal.iter().map(|&a| (a, MctsNode::new(priors[a]))).collect();
+    node.expanded = true;
+    node.n = 1;
+    node.w = value;
+    value
+}
+
+// ── Simulation ─────────────────────────────────────────────────────────────
+
+/// One MCTS simulation from an **already-expanded** decision node.
+///
+/// Traverses the tree with PUCT selection, expands the first unvisited leaf,
+/// and backs up the result.
+///
+/// * `player_idx` — the player (0 or 1) who acts at `state`.
+/// * Returns the backed-up value **from `player_idx`'s perspective**.
+pub(super) fn simulate<E: GameEnv>(
+    node: &mut MctsNode,
+    state: E::State,
+    env: &E,
+    evaluator: &dyn Evaluator,
+    config: &MctsConfig,
+    rng: &mut impl Rng,
+    player_idx: usize,
+) -> f32 {
+    debug_assert!(node.expanded, "simulate called on unexpanded node");
+
+    // ── Selection: child with highest PUCT ────────────────────────────────
+    let parent_n = node.n;
+    let best = node
+        .children
+        .iter()
+        .enumerate()
+        .max_by(|(_, (_, a)), (_, (_, b))| {
+            a.puct(parent_n, config.c_puct)
+                .partial_cmp(&b.puct(parent_n, config.c_puct))
+                .unwrap_or(std::cmp::Ordering::Equal)
+        })
+        .map(|(i, _)| i)
+        .expect("expanded node must have at least one child");
+
+    let (action, child) = &mut node.children[best];
+    let action = *action;
+
+    // ── Apply action + advance through any chance nodes ───────────────────
+    let mut next_state = state;
+    env.apply(&mut next_state, action);
+    while env.current_player(&next_state).is_chance() {
+        env.apply_chance(&mut next_state, rng);
+    }
+
+    let next_cp = env.current_player(&next_state);
+
+    // ── Evaluate leaf or terminal ──────────────────────────────────────────
+    // All values are converted to `player_idx`'s perspective before backup.
+    let child_value = if next_cp.is_terminal() {
+        let returns = env
+            .returns(&next_state)
+            .expect("terminal node must have returns");
+        returns[player_idx]
+    } else {
+        let child_player = next_cp.index().unwrap();
+        let v = if child.expanded {
+            simulate(child, next_state, env, evaluator, config, rng, child_player)
+        } else {
+            expand::<E>(child, &next_state, env, evaluator, child_player)
+        };
+        // Negate when the child belongs to the opponent.
+        if child_player == player_idx { v } else { -v }
+    };
+
+    // ── Backup ────────────────────────────────────────────────────────────
+    node.n += 1;
+    node.w += child_value;
+
+    child_value
+}