fix: remove panics on cxx hot path
This commit is contained in:
parent
72eb60f322
commit
953b5f451a
5 changed files with 51 additions and 48 deletions
|
|
@ -259,15 +259,13 @@ impl TrictracEnvironment {
|
|||
use training_common::get_valid_actions;
|
||||
|
||||
// Obtenir les actions valides dans le contexte actuel
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
if let Ok(valid_actions) = get_valid_actions(game_state) {
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index].clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index].clone())
|
||||
}
|
||||
|
||||
/// Exécute une action Trictrac dans le jeu
|
||||
|
|
|
|||
|
|
@ -229,15 +229,13 @@ impl TrictracEnvironment {
|
|||
use training_common::get_valid_actions;
|
||||
|
||||
// Obtenir les actions valides dans le contexte actuel
|
||||
let valid_actions = get_valid_actions(&self.game);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
if let Ok(valid_actions) = get_valid_actions(&self.game) {
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index])
|
||||
}
|
||||
|
||||
/// Exécute une action Trictrac dans le jeu
|
||||
|
|
|
|||
|
|
@ -50,30 +50,32 @@ impl DqnBurnStrategy {
|
|||
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||
if let Some(ref model) = self.model {
|
||||
let state = environment::TrictracState::from_game_state(&self.game);
|
||||
let valid_actions_indices = get_valid_action_indices(&self.game);
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
|
||||
// Obtenir les Q-values pour toutes les actions
|
||||
let q_values = model.infer(state.to_tensor().unsqueeze());
|
||||
|
||||
// Set non valid actions q-values to lowest
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions_indices.contains(&index) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
if let Ok(valid_actions_indices) = get_valid_action_indices(&self.game) {
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
|
||||
// Obtenir les Q-values pour toutes les actions
|
||||
let q_values = model.infer(state.to_tensor().unsqueeze());
|
||||
|
||||
// Set non valid actions q-values to lowest
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions_indices.contains(&index) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
return environment::TrictracEnvironment::convert_action(
|
||||
environment::TrictracAction::from(action_index),
|
||||
);
|
||||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
|
||||
action_index,
|
||||
))
|
||||
return None;
|
||||
} else {
|
||||
// Fallback : action aléatoire valide
|
||||
sample_valid_action(&self.game)
|
||||
|
|
|
|||
|
|
@ -159,6 +159,9 @@ impl InternalIterator for TrictracAvailableMovesIterator<'_> {
|
|||
where
|
||||
F: FnMut(Self::Item) -> ControlFlow<R>,
|
||||
{
|
||||
get_valid_actions(&self.board.0).into_iter().try_for_each(f)
|
||||
get_valid_actions(&self.board.0)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.try_for_each(f)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue