fix: --n-sim training parameter
This commit is contained in:
parent
ad30d09311
commit
cf50784a23
2 changed files with 18 additions and 5 deletions
|
|
@ -403,10 +403,10 @@ mod tests {
|
||||||
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
||||||
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
|
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
|
||||||
assert_eq!(root.n, 1 + config.n_simulations as u32);
|
assert_eq!(root.n, 1 + config.n_simulations as u32);
|
||||||
// Children visit counts may sum to less than n_simulations when some
|
// Every simulation crosses a chance node at depth 1 (dice roll after
|
||||||
// simulations cross a chance node at depth 1 (turn ends after one move)
|
// the player's move). Since the fix now updates child.n in that case,
|
||||||
// and evaluate with the network directly without updating child.n.
|
// children visit counts must sum to exactly n_simulations.
|
||||||
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||||
assert!(total <= config.n_simulations as u32);
|
assert_eq!(total, config.n_simulations as u32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -156,7 +156,13 @@ pub(super) fn simulate<E: GameEnv>(
|
||||||
let returns = env
|
let returns = env
|
||||||
.returns(&next_state)
|
.returns(&next_state)
|
||||||
.expect("terminal node must have returns");
|
.expect("terminal node must have returns");
|
||||||
returns[player_idx]
|
let v = returns[player_idx];
|
||||||
|
// Update child stats so PUCT and mcts_policy count terminal visits.
|
||||||
|
// Store from player_idx's perspective so child.q() is directly usable
|
||||||
|
// by the parent's PUCT selection (high = good for the selecting player).
|
||||||
|
child.n += 1;
|
||||||
|
child.w += v;
|
||||||
|
v
|
||||||
} else {
|
} else {
|
||||||
let child_player = next_cp.index().unwrap();
|
let child_player = next_cp.index().unwrap();
|
||||||
let v = if crossed_chance {
|
let v = if crossed_chance {
|
||||||
|
|
@ -166,6 +172,13 @@ pub(super) fn simulate<E: GameEnv>(
|
||||||
// previously cached children would be for a different outcome.
|
// previously cached children would be for a different outcome.
|
||||||
let obs = env.observation(&next_state, child_player);
|
let obs = env.observation(&next_state, child_player);
|
||||||
let (_, value) = evaluator.evaluate(&obs);
|
let (_, value) = evaluator.evaluate(&obs);
|
||||||
|
// Store from player_idx's (parent's) perspective so PUCT works correctly.
|
||||||
|
// `value` is from child_player's POV; negate when child is the opponent
|
||||||
|
// so that child.q() = expected return for the player CHOOSING this child.
|
||||||
|
// Without the negation, root would maximise the opponent's Q-value and
|
||||||
|
// systematically pick the worst action.
|
||||||
|
child.n += 1;
|
||||||
|
child.w += if child_player == player_idx { value } else { -value };
|
||||||
value
|
value
|
||||||
} else if child.expanded {
|
} else if child.expanded {
|
||||||
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue