fix: --n-sim training parameter

This commit is contained in:
Henri Bourcereau 2026-03-11 22:17:03 +01:00
parent e7d13c9a02
commit e80dade303
2 changed files with 10 additions and 4 deletions

View file

@ -403,10 +403,10 @@ mod tests {
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
// root.n = 1 (expansion) + n_simulations (one backup per simulation). // root.n = 1 (expansion) + n_simulations (one backup per simulation).
assert_eq!(root.n, 1 + config.n_simulations as u32); assert_eq!(root.n, 1 + config.n_simulations as u32);
// Children visit counts may sum to less than n_simulations when some // Every simulation crosses a chance node at depth 1 (dice roll after
// simulations cross a chance node at depth 1 (turn ends after one move) // the player's move). Since the fix now updates child.n in that case,
// and evaluate with the network directly without updating child.n. // children visit counts must sum to exactly n_simulations.
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
assert!(total <= config.n_simulations as u32); assert_eq!(total, config.n_simulations as u32);
} }
} }

View file

@ -166,6 +166,12 @@ pub(super) fn simulate<E: GameEnv>(
// previously cached children would be for a different outcome. // previously cached children would be for a different outcome.
let obs = env.observation(&next_state, child_player); let obs = env.observation(&next_state, child_player);
let (_, value) = evaluator.evaluate(&obs); let (_, value) = evaluator.evaluate(&obs);
// Record the visit so that PUCT and mcts_policy use real counts.
// Without this, child.n stays 0 for every simulation in games where
// every player action is immediately followed by a chance node (e.g.
// Trictrac), causing mcts_policy to always return a uniform policy.
child.n += 1;
child.w += value;
value value
} else if child.expanded { } else if child.expanded {
simulate(child, next_state, env, evaluator, config, rng, child_player) simulate(child, next_state, env, evaluator, config, rng, child_player)