From 27e05369784afd21a32888e4983464aec00ae7d3 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sat, 7 Mar 2026 22:18:59 +0100 Subject: [PATCH] fix(spiel_bot): mcts fix --- spiel_bot/src/mcts/mod.rs | 8 ++++++-- spiel_bot/src/mcts/search.rs | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index e92bd09..a0a690d 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -401,8 +401,12 @@ mod tests { }; let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); - assert!(root.n > 0); + // root.n = 1 (expansion) + n_simulations (one backup per simulation). + assert_eq!(root.n, 1 + config.n_simulations as u32); + // Children visit counts may sum to less than n_simulations when some + // simulations cross a chance node at depth 1 (turn ends after one move) + // and evaluate with the network directly without updating child.n. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert_eq!(total, 5); + assert!(total <= config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index c4960c7..55db701 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -138,8 +138,14 @@ pub(super) fn simulate( // ── Apply action + advance through any chance nodes ─────────────────── let mut next_state = state; env.apply(&mut next_state, action); + + // Track whether we crossed a chance node (dice roll) on the way down. + // If we did, the child's cached legal actions are for a *different* dice + // outcome and must not be reused — evaluate with the network directly. + let mut crossed_chance = false; while env.current_player(&next_state).is_chance() { env.apply_chance(&mut next_state, rng); + crossed_chance = true; } let next_cp = env.current_player(&next_state); @@ -153,7 +159,15 @@ pub(super) fn simulate( returns[player_idx] } else { let child_player = next_cp.index().unwrap(); - let v = if child.expanded { + let v = if crossed_chance { + // Outcome sampling: after dice, evaluate the resulting position + // directly with the network. Do NOT build the tree across chance + // boundaries — the dice change which actions are legal, so any + // previously cached children would be for a different outcome. + let obs = env.observation(&next_state, child_player); + let (_, value) = evaluator.evaluate(&obs); + value + } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) } else { expand::(child, &next_state, env, evaluator, child_player)