diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index a0a690d..eead171 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -403,10 +403,10 @@ mod tests { let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); // root.n = 1 (expansion) + n_simulations (one backup per simulation). assert_eq!(root.n, 1 + config.n_simulations as u32); - // Children visit counts may sum to less than n_simulations when some - // simulations cross a chance node at depth 1 (turn ends after one move) - // and evaluate with the network directly without updating child.n. + // Every simulation crosses a chance node at depth 1 (dice roll after + // the player's move). Since the fix now updates child.n in that case, + // children visit counts must sum to exactly n_simulations. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert!(total <= config.n_simulations as u32); + assert_eq!(total, config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 55db701..1d9750d 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -166,6 +166,12 @@ pub(super) fn simulate( // previously cached children would be for a different outcome. let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); + // Record the visit so that PUCT and mcts_policy use real counts. + // Without this, child.n stays 0 for every simulation in games where + // every player action is immediately followed by a chance node (e.g. + // Trictrac), causing mcts_policy to always return a uniform policy. + child.n += 1; + child.w += value; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player)