diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 3191b4b..9a24ae6 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,4 +1,4 @@ -use std::cmp::max; +use std::cmp::{max, min}; use serde::{Deserialize, Serialize}; use store::{CheckerMove, Dice, GameEvent, PlayerId}; @@ -8,8 +8,6 @@ use store::{CheckerMove, Dice, GameEvent, PlayerId}; pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer les points - Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions @@ -18,6 +16,8 @@ pub enum TrictracAction { from1: usize, // position de départ du premier pion (0-24) from2: usize, // position de départ du deuxième pion (0-24) }, + // Marquer les points : à activer si support des écoles + // Mark, } impl TrictracAction { @@ -25,22 +25,22 @@ impl TrictracAction { pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark => 1, - TrictracAction::Go => 2, + TrictracAction::Go => 1, TrictracAction::Move { dice_order, from1, from2, } => { // Encoder les mouvements dans l'espace d'actions - // Indices 3+ pour les mouvements - let mut start = 3; + // Indices 2+ pour les mouvements + // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) + let mut start = 2; if !dice_order { // 25 * 25 = 625 start += 625; } start + from1 * 25 + from2 - } + } // TrictracAction::Mark => 1252, } } @@ -48,8 +48,8 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1 => Some(TrictracAction::Mark), - 2 => Some(TrictracAction::Go), + // 1252 => Some(TrictracAction::Mark), + 1 => Some(TrictracAction::Go), i if i >= 3 => { let move_code = i - 3; let (dice_order, from1, from2) = Self::decode_move(move_code); @@ -77,10 +77,10 @@ impl TrictracAction { /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // 1 (Roll) + 1 (Go) + mouvements possibles // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 3 + (2 * 25 * 25) // = 1253 + 2 + (2 * 25 * 25) // = 1252 } // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { @@ -273,35 +273,37 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - valid_actions.push(TrictracAction::Mark); + // valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); - // Ajouter aussi les mouvements possibles + // Ajoute aussi les mouvements possibles let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + // Modififier checker_moves_to_trictrac_action si on doit gérer Black + assert_eq!(color, store::Color::White); for (move1, move2) in possible_moves { - let diff_move1 = move1.get_to() - move1.get_from(); - valid_actions.push(TrictracAction::Move { - dice_order: diff_move1 == game_state.dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - }); + valid_actions.push(checker_moves_to_trictrac_action( + &move1, + &move2, + &game_state.dice, + )); } } TurnStage::Move => { let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + // Modififier checker_moves_to_trictrac_action si on doit gérer Black + assert_eq!(color, store::Color::White); for (move1, move2) in possible_moves { - let diff_move1 = move1.get_to() - move1.get_from(); - valid_actions.push(TrictracAction::Move { - dice_order: diff_move1 == game_state.dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - }); + valid_actions.push(checker_moves_to_trictrac_action( + &move1, + &move2, + &game_state.dice, + )); } } } @@ -310,6 +312,56 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions } +// Valid only for White player +fn checker_moves_to_trictrac_action( + move1: &CheckerMove, + move2: &CheckerMove, + dice: &Dice, +) -> TrictracAction { + let to1 = move1.get_to(); + let to2 = move2.get_to(); + let from1 = move1.get_from(); + let from2 = move2.get_from(); + + let mut diff_move1 = if to1 > 0 { + // Mouvement sans sortie + to1 - from1 + } else { + // sortie, on utilise la valeur du dé + if to2 > 0 { + // sortie pour le mouvement 1 uniquement + let dice2 = to2 - from2; + if dice2 == dice.values.0 as usize { + dice.values.1 as usize + } else { + dice.values.0 as usize + } + } else { + // double sortie + if from1 < from2 { + max(dice.values.0, dice.values.1) as usize + } else { + min(dice.values.0, dice.values.1) as usize + } + } + }; + + // modification de diff_move1 si on est dans le cas d'un mouvement par puissance + let rest_field = 12; + if to1 == rest_field + && to2 == rest_field + && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field + { + // prise par puissance + diff_move1 += 1; + } + TrictracAction::Move { + dice_order: diff_move1 == dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), + } +} + /// Retourne les indices des actions valides pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { get_valid_actions(game_state) diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 2b935f5..8d9db57 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -1,4 +1,4 @@ -use crate::{Color, GameState, PlayerId}; +use crate::{CheckerMove, Color, GameState, PlayerId}; use rand::prelude::SliceRandom; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; @@ -251,14 +251,15 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Mark { points } => { - // Marquer des points - reward += 0.1 * points as f32; - Some(GameEvent::Mark { - player_id: self.agent_player_id, - points, - }) - } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game_state. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.agent_player_id, + // points, + // }) + // } TrictracAction::Go => { // Continuer après avoir gagné un trou reward += 0.2; @@ -272,8 +273,23 @@ impl TrictracEnv { from2, } => { // Effectuer un mouvement - let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + let (dice1, dice2) = if dice_order { + (self.game_state.dice.values.0, self.game_state.dice.values.1) + } else { + (self.game_state.dice.values.1, self.game_state.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); reward += 0.2; Some(GameEvent::Move { @@ -360,7 +376,9 @@ impl TrictracEnv { // Stratégie simple : choix aléatoire let mut rng = thread_rng(); - let choosen_move = *possible_moves.choose(&mut rng).unwrap(); + let choosen_move = *possible_moves + .choose(&mut rng) + .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); GameEvent::Move { player_id: self.opponent_player_id, @@ -443,7 +461,6 @@ impl DqnTrainer { for episode in 1..=episodes { let reward = self.train_episode(); - print!("."); if episode % 100 == 0 { println!( "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}", diff --git a/store/src/game.rs b/store/src/game.rs index ed77519..fe2762f 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -757,6 +757,7 @@ mod tests { #[test] fn hold_or_go() { let mut game_state = init_test_gamestate(TurnStage::MarkPoints); + game_state.schools_enabled = true; let pid = game_state.active_player_id; game_state.consume( &(GameEvent::Mark { @@ -782,6 +783,7 @@ mod tests { // Hold let mut game_state = init_test_gamestate(TurnStage::MarkPoints); + game_state.schools_enabled = true; let pid = game_state.active_player_id; game_state.consume( &(GameEvent::Mark { @@ -802,6 +804,6 @@ mod tests { assert_ne!(game_state.active_player_id, pid); assert_eq!(game_state.players.get(&pid).unwrap().points, 1); assert_eq!(game_state.get_active_player().unwrap().points, 0); - assert_eq!(game_state.turn_stage, TurnStage::RollDice); + assert_eq!(game_state.turn_stage, TurnStage::MarkAdvPoints); } }