fix: remove panics on cxx hot path
This commit is contained in:
parent
72eb60f322
commit
953b5f451a
5 changed files with 51 additions and 48 deletions
|
|
@ -259,15 +259,13 @@ impl TrictracEnvironment {
|
||||||
use training_common::get_valid_actions;
|
use training_common::get_valid_actions;
|
||||||
|
|
||||||
// Obtenir les actions valides dans le contexte actuel
|
// Obtenir les actions valides dans le contexte actuel
|
||||||
let valid_actions = get_valid_actions(game_state);
|
if let Ok(valid_actions) = get_valid_actions(game_state) {
|
||||||
|
// Mapper l'index d'action sur une action valide
|
||||||
if valid_actions.is_empty() {
|
let action_index = (action.index as usize) % valid_actions.len();
|
||||||
return None;
|
Some(valid_actions[action_index].clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mapper l'index d'action sur une action valide
|
|
||||||
let action_index = (action.index as usize) % valid_actions.len();
|
|
||||||
Some(valid_actions[action_index].clone())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Exécute une action Trictrac dans le jeu
|
/// Exécute une action Trictrac dans le jeu
|
||||||
|
|
|
||||||
|
|
@ -229,15 +229,13 @@ impl TrictracEnvironment {
|
||||||
use training_common::get_valid_actions;
|
use training_common::get_valid_actions;
|
||||||
|
|
||||||
// Obtenir les actions valides dans le contexte actuel
|
// Obtenir les actions valides dans le contexte actuel
|
||||||
let valid_actions = get_valid_actions(&self.game);
|
if let Ok(valid_actions) = get_valid_actions(&self.game) {
|
||||||
|
// Mapper l'index d'action sur une action valide
|
||||||
if valid_actions.is_empty() {
|
let action_index = (action.index as usize) % valid_actions.len();
|
||||||
return None;
|
Some(valid_actions[action_index])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mapper l'index d'action sur une action valide
|
|
||||||
let action_index = (action.index as usize) % valid_actions.len();
|
|
||||||
Some(valid_actions[action_index])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Exécute une action Trictrac dans le jeu
|
/// Exécute une action Trictrac dans le jeu
|
||||||
|
|
|
||||||
|
|
@ -50,30 +50,32 @@ impl DqnBurnStrategy {
|
||||||
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||||
if let Some(ref model) = self.model {
|
if let Some(ref model) = self.model {
|
||||||
let state = environment::TrictracState::from_game_state(&self.game);
|
let state = environment::TrictracState::from_game_state(&self.game);
|
||||||
let valid_actions_indices = get_valid_action_indices(&self.game);
|
if let Ok(valid_actions_indices) = get_valid_action_indices(&self.game) {
|
||||||
if valid_actions_indices.is_empty() {
|
if valid_actions_indices.is_empty() {
|
||||||
return None; // No valid actions, end of episode
|
return None; // No valid actions, end of episode
|
||||||
}
|
|
||||||
|
|
||||||
// Obtenir les Q-values pour toutes les actions
|
|
||||||
let q_values = model.infer(state.to_tensor().unsqueeze());
|
|
||||||
|
|
||||||
// Set non valid actions q-values to lowest
|
|
||||||
let mut masked_q_values = q_values.clone();
|
|
||||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
|
||||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
|
||||||
if !valid_actions_indices.contains(&index) {
|
|
||||||
masked_q_values = masked_q_values.clone().mask_fill(
|
|
||||||
masked_q_values.clone().equal_elem(*q_value),
|
|
||||||
f32::NEG_INFINITY,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Obtenir les Q-values pour toutes les actions
|
||||||
|
let q_values = model.infer(state.to_tensor().unsqueeze());
|
||||||
|
|
||||||
|
// Set non valid actions q-values to lowest
|
||||||
|
let mut masked_q_values = q_values.clone();
|
||||||
|
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||||
|
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||||
|
if !valid_actions_indices.contains(&index) {
|
||||||
|
masked_q_values = masked_q_values.clone().mask_fill(
|
||||||
|
masked_q_values.clone().equal_elem(*q_value),
|
||||||
|
f32::NEG_INFINITY,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get best action (highest q-value)
|
||||||
|
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||||
|
return environment::TrictracEnvironment::convert_action(
|
||||||
|
environment::TrictracAction::from(action_index),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
// Get best action (highest q-value)
|
return None;
|
||||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
|
||||||
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
|
|
||||||
action_index,
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
// Fallback : action aléatoire valide
|
// Fallback : action aléatoire valide
|
||||||
sample_valid_action(&self.game)
|
sample_valid_action(&self.game)
|
||||||
|
|
|
||||||
|
|
@ -159,6 +159,9 @@ impl InternalIterator for TrictracAvailableMovesIterator<'_> {
|
||||||
where
|
where
|
||||||
F: FnMut(Self::Item) -> ControlFlow<R>,
|
F: FnMut(Self::Item) -> ControlFlow<R>,
|
||||||
{
|
{
|
||||||
get_valid_actions(&self.board.0).into_iter().try_for_each(f)
|
get_valid_actions(&self.board.0)
|
||||||
|
.unwrap()
|
||||||
|
.into_iter()
|
||||||
|
.try_for_each(f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
||||||
}
|
}
|
||||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
||||||
// valid_actions.push(TrictracAction::Mark);
|
// valid_actions.push(TrictracAction::Mark);
|
||||||
panic!(
|
anyhow::bail!(
|
||||||
"get_valid_actions not implemented for turn stage {:?}",
|
"get_valid_actions not implemented for turn stage {:?}",
|
||||||
game_state.turn_stage
|
game_state.turn_stage
|
||||||
);
|
);
|
||||||
|
|
@ -225,7 +225,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
||||||
for (move1, move2) in possible_moves {
|
for (move1, move2) in possible_moves {
|
||||||
valid_actions.push(checker_moves_to_trictrac_action(
|
valid_actions.push(checker_moves_to_trictrac_action(
|
||||||
&move1, &move2, &color, game_state,
|
&move1, &move2, &color, game_state,
|
||||||
));
|
)?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::Move => {
|
TurnStage::Move => {
|
||||||
|
|
@ -239,7 +239,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracA
|
||||||
for (move1, move2) in possible_moves {
|
for (move1, move2) in possible_moves {
|
||||||
valid_actions.push(checker_moves_to_trictrac_action(
|
valid_actions.push(checker_moves_to_trictrac_action(
|
||||||
&move1, &move2, &color, game_state,
|
&move1, &move2, &color, game_state,
|
||||||
));
|
)?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -256,7 +256,7 @@ fn checker_moves_to_trictrac_action(
|
||||||
move2: &CheckerMove,
|
move2: &CheckerMove,
|
||||||
color: &crate::Color,
|
color: &crate::Color,
|
||||||
state: &GameState,
|
state: &GameState,
|
||||||
) -> TrictracAction {
|
) -> anyhow::Result<TrictracAction> {
|
||||||
let dice = &state.dice;
|
let dice = &state.dice;
|
||||||
let board = &state.board;
|
let board = &state.board;
|
||||||
|
|
||||||
|
|
@ -269,7 +269,7 @@ fn checker_moves_to_trictrac_action(
|
||||||
dice,
|
dice,
|
||||||
&board.clone().mirror(),
|
&board.clone().mirror(),
|
||||||
)
|
)
|
||||||
.mirror()
|
.map(|a| a.mirror())
|
||||||
} else {
|
} else {
|
||||||
white_checker_moves_to_trictrac_action(move1, move2, dice, board)
|
white_checker_moves_to_trictrac_action(move1, move2, dice, board)
|
||||||
}
|
}
|
||||||
|
|
@ -280,7 +280,7 @@ fn white_checker_moves_to_trictrac_action(
|
||||||
move2: &CheckerMove,
|
move2: &CheckerMove,
|
||||||
dice: &Dice,
|
dice: &Dice,
|
||||||
board: &Board,
|
board: &Board,
|
||||||
) -> TrictracAction {
|
) -> anyhow::Result<TrictracAction> {
|
||||||
let to1 = move1.get_to();
|
let to1 = move1.get_to();
|
||||||
let to2 = move2.get_to();
|
let to2 = move2.get_to();
|
||||||
let from1 = move1.get_from();
|
let from1 = move1.get_from();
|
||||||
|
|
@ -328,11 +328,11 @@ fn white_checker_moves_to_trictrac_action(
|
||||||
panic!("error while moving checker {move_res:?}");
|
panic!("error while moving checker {move_res:?}");
|
||||||
}
|
}
|
||||||
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize;
|
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize;
|
||||||
TrictracAction::Move {
|
Ok(TrictracAction::Move {
|
||||||
dice_order,
|
dice_order,
|
||||||
checker1,
|
checker1,
|
||||||
checker2,
|
checker2,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retourne les indices des actions valides
|
/// Retourne les indices des actions valides
|
||||||
|
|
@ -350,7 +350,9 @@ pub fn sample_valid_action(game_state: &GameState) -> Option<TrictracAction> {
|
||||||
|
|
||||||
let valid_actions = get_valid_actions(game_state);
|
let valid_actions = get_valid_actions(game_state);
|
||||||
let mut rng = rng();
|
let mut rng = rng();
|
||||||
valid_actions.unwrap().choose(&mut rng).cloned()
|
valid_actions
|
||||||
|
.map(|va| va.choose(&mut rng).cloned())
|
||||||
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue