fix: remove panics on cxx hot path

This commit is contained in:
Henri Bourcereau 2026-02-27 22:25:06 +01:00
parent 72eb60f322
commit 953b5f451a
5 changed files with 51 additions and 48 deletions

View file

@ -50,30 +50,32 @@ impl DqnBurnStrategy {
fn get_dqn_action(&self) -> Option<TrictracAction> {
if let Some(ref model) = self.model {
let state = environment::TrictracState::from_game_state(&self.game);
let valid_actions_indices = get_valid_action_indices(&self.game);
if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode
}
// Obtenir les Q-values pour toutes les actions
let q_values = model.infer(state.to_tensor().unsqueeze());
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
if let Ok(valid_actions_indices) = get_valid_action_indices(&self.game) {
if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode
}
// Obtenir les Q-values pour toutes les actions
let q_values = model.infer(state.to_tensor().unsqueeze());
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
return environment::TrictracEnvironment::convert_action(
environment::TrictracAction::from(action_index),
);
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
action_index,
))
return None;
} else {
// Fallback : action aléatoire valide
sample_valid_action(&self.game)