fix(spiel_bot): mcts fix
This commit is contained in:
parent
eadc101741
commit
53eeda349e
2 changed files with 21 additions and 3 deletions
|
|
@ -401,8 +401,12 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r);
|
||||||
assert!(root.n > 0);
|
// root.n = 1 (expansion) + n_simulations (one backup per simulation).
|
||||||
|
assert_eq!(root.n, 1 + config.n_simulations as u32);
|
||||||
|
// Children visit counts may sum to less than n_simulations when some
|
||||||
|
// simulations cross a chance node at depth 1 (turn ends after one move)
|
||||||
|
// and evaluate with the network directly without updating child.n.
|
||||||
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
let total: u32 = root.children.iter().map(|(_, c)| c.n).sum();
|
||||||
assert_eq!(total, 5);
|
assert!(total <= config.n_simulations as u32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -138,8 +138,14 @@ pub(super) fn simulate<E: GameEnv>(
|
||||||
// ── Apply action + advance through any chance nodes ───────────────────
|
// ── Apply action + advance through any chance nodes ───────────────────
|
||||||
let mut next_state = state;
|
let mut next_state = state;
|
||||||
env.apply(&mut next_state, action);
|
env.apply(&mut next_state, action);
|
||||||
|
|
||||||
|
// Track whether we crossed a chance node (dice roll) on the way down.
|
||||||
|
// If we did, the child's cached legal actions are for a *different* dice
|
||||||
|
// outcome and must not be reused — evaluate with the network directly.
|
||||||
|
let mut crossed_chance = false;
|
||||||
while env.current_player(&next_state).is_chance() {
|
while env.current_player(&next_state).is_chance() {
|
||||||
env.apply_chance(&mut next_state, rng);
|
env.apply_chance(&mut next_state, rng);
|
||||||
|
crossed_chance = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
let next_cp = env.current_player(&next_state);
|
let next_cp = env.current_player(&next_state);
|
||||||
|
|
@ -153,7 +159,15 @@ pub(super) fn simulate<E: GameEnv>(
|
||||||
returns[player_idx]
|
returns[player_idx]
|
||||||
} else {
|
} else {
|
||||||
let child_player = next_cp.index().unwrap();
|
let child_player = next_cp.index().unwrap();
|
||||||
let v = if child.expanded {
|
let v = if crossed_chance {
|
||||||
|
// Outcome sampling: after dice, evaluate the resulting position
|
||||||
|
// directly with the network. Do NOT build the tree across chance
|
||||||
|
// boundaries — the dice change which actions are legal, so any
|
||||||
|
// previously cached children would be for a different outcome.
|
||||||
|
let obs = env.observation(&next_state, child_player);
|
||||||
|
let (_, value) = evaluator.evaluate(&obs);
|
||||||
|
value
|
||||||
|
} else if child.expanded {
|
||||||
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
simulate(child, next_state, env, evaluator, config, rng, child_player)
|
||||||
} else {
|
} else {
|
||||||
expand::<E>(child, &next_state, env, evaluator, child_player)
|
expand::<E>(child, &next_state, env, evaluator, child_player)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue