diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index a0a690d..eead171 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -403,10 +403,10 @@ mod tests { let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); // root.n = 1 (expansion) + n_simulations (one backup per simulation). assert_eq!(root.n, 1 + config.n_simulations as u32); - // Children visit counts may sum to less than n_simulations when some - // simulations cross a chance node at depth 1 (turn ends after one move) - // and evaluate with the network directly without updating child.n. + // Every simulation crosses a chance node at depth 1 (dice roll after + // the player's move). Since the fix now updates child.n in that case, + // children visit counts must sum to exactly n_simulations. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert!(total <= config.n_simulations as u32); + assert_eq!(total, config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 55db701..4d36acc 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -156,7 +156,13 @@ pub(super) fn simulate( let returns = env .returns(&next_state) .expect("terminal node must have returns"); - returns[player_idx] + let v = returns[player_idx]; + // Update child stats so PUCT and mcts_policy count terminal visits. + // Store from player_idx's perspective so child.q() is directly usable + // by the parent's PUCT selection (high = good for the selecting player). + child.n += 1; + child.w += v; + v } else { let child_player = next_cp.index().unwrap(); let v = if crossed_chance { @@ -166,6 +172,13 @@ pub(super) fn simulate( // previously cached children would be for a different outcome. let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); + // Store from player_idx's (parent's) perspective so PUCT works correctly. + // `value` is from child_player's POV; negate when child is the opponent + // so that child.q() = expected return for the player CHOOSING this child. + // Without the negation, root would maximise the opponent's Q-value and + // systematically pick the worst action. + child.n += 1; + child.w += if child_player == player_idx { value } else { -value }; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player)