fix: remove panics on cxx hot path

This commit is contained in:
Henri Bourcereau 2026-02-27 22:25:06 +01:00
parent 72eb60f322
commit 953b5f451a
5 changed files with 51 additions and 48 deletions

View file

@ -259,15 +259,13 @@ impl TrictracEnvironment {
use training_common::get_valid_actions; use training_common::get_valid_actions;
// Obtenir les actions valides dans le contexte actuel // Obtenir les actions valides dans le contexte actuel
let valid_actions = get_valid_actions(game_state); if let Ok(valid_actions) = get_valid_actions(game_state) {
// Mapper l'index d'action sur une action valide
if valid_actions.is_empty() { let action_index = (action.index as usize) % valid_actions.len();
return None; Some(valid_actions[action_index].clone())
} else {
None
} }
// Mapper l'index d'action sur une action valide
let action_index = (action.index as usize) % valid_actions.len();
Some(valid_actions[action_index].clone())
} }
/// Exécute une action Trictrac dans le jeu /// Exécute une action Trictrac dans le jeu

View file

@ -229,15 +229,13 @@ impl TrictracEnvironment {
use training_common::get_valid_actions; use training_common::get_valid_actions;
// Obtenir les actions valides dans le contexte actuel // Obtenir les actions valides dans le contexte actuel
let valid_actions = get_valid_actions(&self.game); if let Ok(valid_actions) = get_valid_actions(&self.game) {
// Mapper l'index d'action sur une action valide
if valid_actions.is_empty() { let action_index = (action.index as usize) % valid_actions.len();
return None; Some(valid_actions[action_index])
} else {
None
} }
// Mapper l'index d'action sur une action valide
let action_index = (action.index as usize) % valid_actions.len();
Some(valid_actions[action_index])
} }
/// Exécute une action Trictrac dans le jeu /// Exécute une action Trictrac dans le jeu

View file

@ -50,30 +50,32 @@ impl DqnBurnStrategy {
fn get_dqn_action(&self) -> Option<TrictracAction> { fn get_dqn_action(&self) -> Option<TrictracAction> {
if let Some(ref model) = self.model { if let Some(ref model) = self.model {
let state = environment::TrictracState::from_game_state(&self.game); let state = environment::TrictracState::from_game_state(&self.game);
let valid_actions_indices = get_valid_action_indices(&self.game); if let Ok(valid_actions_indices) = get_valid_action_indices(&self.game) {
if valid_actions_indices.is_empty() { if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode return None; // No valid actions, end of episode
}
// Obtenir les Q-values pour toutes les actions
let q_values = model.infer(state.to_tensor().unsqueeze());
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
} }
// Obtenir les Q-values pour toutes les actions
let q_values = model.infer(state.to_tensor().unsqueeze());
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
return environment::TrictracEnvironment::convert_action(
environment::TrictracAction::from(action_index),
);
} }
// Get best action (highest q-value) return None;
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
action_index,
))
} else { } else {
// Fallback : action aléatoire valide // Fallback : action aléatoire valide
sample_valid_action(&self.game) sample_valid_action(&self.game)

View file

@ -159,6 +159,9 @@ impl InternalIterator for TrictracAvailableMovesIterator<'_> {
where where
F: FnMut(Self::Item) -> ControlFlow<R>, F: FnMut(Self::Item) -> ControlFlow<R>,
{ {
get_valid_actions(&self.board.0).into_iter().try_for_each(f) get_valid_actions(&self.board.0)
.unwrap()
.into_iter()
.try_for_each(f)
} }
} }

View file

@ -210,7 +210,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
} }
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => { TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
// valid_actions.push(TrictracAction::Mark); // valid_actions.push(TrictracAction::Mark);
panic!( anyhow::bail!(
"get_valid_actions not implemented for turn stage {:?}", "get_valid_actions not implemented for turn stage {:?}",
game_state.turn_stage game_state.turn_stage
); );
@ -225,7 +225,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state, &move1, &move2, &color, game_state,
)); )?);
} }
} }
TurnStage::Move => { TurnStage::Move => {
@ -239,7 +239,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state, &move1, &move2, &color, game_state,
)); )?);
} }
} }
} }
@ -256,7 +256,7 @@ fn checker_moves_to_trictrac_action(
move2: &CheckerMove, move2: &CheckerMove,
color: &crate::Color, color: &crate::Color,
state: &GameState, state: &GameState,
) -> TrictracAction { ) -> anyhow::Result<TrictracAction> {
let dice = &state.dice; let dice = &state.dice;
let board = &state.board; let board = &state.board;
@ -269,7 +269,7 @@ fn checker_moves_to_trictrac_action(
dice, dice,
&board.clone().mirror(), &board.clone().mirror(),
) )
.mirror() .map(|a| a.mirror())
} else { } else {
white_checker_moves_to_trictrac_action(move1, move2, dice, board) white_checker_moves_to_trictrac_action(move1, move2, dice, board)
} }
@ -280,7 +280,7 @@ fn white_checker_moves_to_trictrac_action(
move2: &CheckerMove, move2: &CheckerMove,
dice: &Dice, dice: &Dice,
board: &Board, board: &Board,
) -> TrictracAction { ) -> anyhow::Result<TrictracAction> {
let to1 = move1.get_to(); let to1 = move1.get_to();
let to2 = move2.get_to(); let to2 = move2.get_to();
let from1 = move1.get_from(); let from1 = move1.get_from();
@ -328,11 +328,11 @@ fn white_checker_moves_to_trictrac_action(
panic!("error while moving checker {move_res:?}"); panic!("error while moving checker {move_res:?}");
} }
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize;
TrictracAction::Move { Ok(TrictracAction::Move {
dice_order, dice_order,
checker1, checker1,
checker2, checker2,
} })
} }
/// Retourne les indices des actions valides /// Retourne les indices des actions valides
@ -350,7 +350,9 @@ pub fn sample_valid_action(game_state: &GameState) -> Option<TrictracAction> {
let valid_actions = get_valid_actions(game_state); let valid_actions = get_valid_actions(game_state);
let mut rng = rng(); let mut rng = rng();
valid_actions.unwrap().choose(&mut rng).cloned() valid_actions
.map(|va| va.choose(&mut rng).cloned())
.unwrap_or_default()
} }
#[cfg(test)] #[cfg(test)]