diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 3191b4b..7c2cc46 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -8,8 +8,6 @@ use store::{CheckerMove, Dice, GameEvent, PlayerId}; pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer les points - Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions @@ -18,6 +16,8 @@ pub enum TrictracAction { from1: usize, // position de départ du premier pion (0-24) from2: usize, // position de départ du deuxième pion (0-24) }, + // Marquer les points : à activer si support des écoles + // Mark, } impl TrictracAction { @@ -25,22 +25,22 @@ impl TrictracAction { pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark => 1, - TrictracAction::Go => 2, + TrictracAction::Go => 1, TrictracAction::Move { dice_order, from1, from2, } => { // Encoder les mouvements dans l'espace d'actions - // Indices 3+ pour les mouvements - let mut start = 3; + // Indices 2+ pour les mouvements + // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) + let mut start = 2; if !dice_order { // 25 * 25 = 625 start += 625; } start + from1 * 25 + from2 - } + } // TrictracAction::Mark => 1252, } } @@ -48,8 +48,8 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1 => Some(TrictracAction::Mark), - 2 => Some(TrictracAction::Go), + // 1252 => Some(TrictracAction::Mark), + 1 => Some(TrictracAction::Go), i if i >= 3 => { let move_code = i - 3; let (dice_order, from1, from2) = Self::decode_move(move_code); @@ -77,10 +77,10 @@ impl TrictracAction { /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // 1 (Roll) + 1 (Go) + mouvements possibles // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 3 + (2 * 25 * 25) // = 1253 + 2 + (2 * 25 * 25) // = 1252 } // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { @@ -273,7 +273,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - valid_actions.push(TrictracAction::Mark); + // valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 2b935f5..f47e1a9 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -251,14 +251,15 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Mark { points } => { - // Marquer des points - reward += 0.1 * points as f32; - Some(GameEvent::Mark { - player_id: self.agent_player_id, - points, - }) - } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game_state. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.agent_player_id, + // points, + // }) + // } TrictracAction::Go => { // Continuer après avoir gagné un trou reward += 0.2;