diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs index 6b5b2be..99f1f1f 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -259,15 +259,13 @@ impl TrictracEnvironment { use training_common::get_valid_actions; // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_actions(game_state); - - if valid_actions.is_empty() { - return None; + if let Ok(valid_actions) = get_valid_actions(game_state) { + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } else { + None } - - // Mapper l'index d'action sur une action valide - let action_index = (action.index as usize) % valid_actions.len(); - Some(valid_actions[action_index].clone()) } /// Exécute une action Trictrac dans le jeu diff --git a/bot/src/burnrl/environment_valid.rs b/bot/src/burnrl/environment_valid.rs index df51781..2648831 100644 --- a/bot/src/burnrl/environment_valid.rs +++ b/bot/src/burnrl/environment_valid.rs @@ -229,15 +229,13 @@ impl TrictracEnvironment { use training_common::get_valid_actions; // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_actions(&self.game); - - if valid_actions.is_empty() { - return None; + if let Ok(valid_actions) = get_valid_actions(&self.game) { + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index]) + } else { + None } - - // Mapper l'index d'action sur une action valide - let action_index = (action.index as usize) % valid_actions.len(); - Some(valid_actions[action_index]) } /// Exécute une action Trictrac dans le jeu diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs index ea58bc5..37f9b7f 100644 --- a/bot/src/strategy/dqnburn.rs +++ b/bot/src/strategy/dqnburn.rs @@ -50,30 +50,32 @@ impl DqnBurnStrategy { fn get_dqn_action(&self) -> Option { if let Some(ref model) = self.model { let state = environment::TrictracState::from_game_state(&self.game); - let valid_actions_indices = get_valid_action_indices(&self.game); - if valid_actions_indices.is_empty() { - return None; // No valid actions, end of episode - } - - // Obtenir les Q-values pour toutes les actions - let q_values = model.infer(state.to_tensor().unsqueeze()); - - // Set non valid actions q-values to lowest - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions_indices.contains(&index) { - masked_q_values = masked_q_values.clone().mask_fill( - masked_q_values.clone().equal_elem(*q_value), - f32::NEG_INFINITY, - ); + if let Ok(valid_actions_indices) = get_valid_action_indices(&self.game) { + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode } + + // Obtenir les Q-values pour toutes les actions + let q_values = model.infer(state.to_tensor().unsqueeze()); + + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + return environment::TrictracEnvironment::convert_action( + environment::TrictracAction::from(action_index), + ); } - // Get best action (highest q-value) - let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); - environment::TrictracEnvironment::convert_action(environment::TrictracAction::from( - action_index, - )) + return None; } else { // Fallback : action aléatoire valide sample_valid_action(&self.game) diff --git a/bot/src/trictrac_board.rs b/bot/src/trictrac_board.rs index ecc5dcc..24dce0d 100644 --- a/bot/src/trictrac_board.rs +++ b/bot/src/trictrac_board.rs @@ -159,6 +159,9 @@ impl InternalIterator for TrictracAvailableMovesIterator<'_> { where F: FnMut(Self::Item) -> ControlFlow, { - get_valid_actions(&self.board.0).into_iter().try_for_each(f) + get_valid_actions(&self.board.0) + .unwrap() + .into_iter() + .try_for_each(f) } } diff --git a/store/src/training_common.rs b/store/src/training_common.rs index 0cc635c..32fefd7 100644 --- a/store/src/training_common.rs +++ b/store/src/training_common.rs @@ -210,7 +210,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result { // valid_actions.push(TrictracAction::Mark); - panic!( + anyhow::bail!( "get_valid_actions not implemented for turn stage {:?}", game_state.turn_stage ); @@ -225,7 +225,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result { @@ -239,7 +239,7 @@ pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result TrictracAction { +) -> anyhow::Result { let dice = &state.dice; let board = &state.board; @@ -269,7 +269,7 @@ fn checker_moves_to_trictrac_action( dice, &board.clone().mirror(), ) - .mirror() + .map(|a| a.mirror()) } else { white_checker_moves_to_trictrac_action(move1, move2, dice, board) } @@ -280,7 +280,7 @@ fn white_checker_moves_to_trictrac_action( move2: &CheckerMove, dice: &Dice, board: &Board, -) -> TrictracAction { +) -> anyhow::Result { let to1 = move1.get_to(); let to2 = move2.get_to(); let from1 = move1.get_from(); @@ -328,11 +328,11 @@ fn white_checker_moves_to_trictrac_action( panic!("error while moving checker {move_res:?}"); } let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize; - TrictracAction::Move { + Ok(TrictracAction::Move { dice_order, checker1, checker2, - } + }) } /// Retourne les indices des actions valides @@ -350,7 +350,9 @@ pub fn sample_valid_action(game_state: &GameState) -> Option { let valid_actions = get_valid_actions(game_state); let mut rng = rng(); - valid_actions.unwrap().choose(&mut rng).cloned() + valid_actions + .map(|va| va.choose(&mut rng).cloned()) + .unwrap_or_default() } #[cfg(test)]