From 7507ea5d78338d87c06e92e12e5fabd44e5e5e25 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 8 Jun 2025 21:20:04 +0200 Subject: [PATCH] fix workflow --- bot/src/strategy/default.rs | 20 +- bot/src/strategy/dqn.rs | 79 +++++--- bot/src/strategy/dqn_common.rs | 323 ++++++++------------------------ bot/src/strategy/dqn_trainer.rs | 6 +- doc/workflow.md | 25 +++ store/src/game.rs | 39 ++-- 6 files changed, 186 insertions(+), 306 deletions(-) create mode 100644 doc/workflow.md diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index 98e8322..81aa5f1 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -36,18 +36,20 @@ impl BotStrategy for DefaultStrategy { } fn calculate_points(&self) -> u8 { - let dice_roll_count = self - .get_game() - .players - .get(&self.player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice); - points_rules.get_points(dice_roll_count).0 + // let dice_roll_count = self + // .get_game() + // .players + // .get(&self.player_id) + // .unwrap() + // .dice_roll_count; + // let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice); + // points_rules.get_points(dice_roll_count).0 + self.game.dice_points.0 } fn calculate_adv_points(&self) -> u8 { - self.calculate_points() + // self.calculate_points() + self.game.dice_points.1 } fn choose_go(&self) -> bool { diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index d2fc9ed..779ce3d 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -2,7 +2,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use std::path::Path; use store::MoveRules; -use super::dqn_common::{SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action}; +use super::dqn_common::{ + get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction, +}; /// Stratégie DQN pour le bot - ne fait que charger 
et utiliser un modèle pré-entraîné #[derive(Debug)] @@ -42,18 +44,18 @@ impl DqnStrategy { if let Some(ref model) = self.model { let state = self.game.to_vec_float(); let valid_actions = get_valid_actions(&self.game); - + if valid_actions.is_empty() { return None; } - + // Obtenir les Q-values pour toutes les actions let q_values = model.forward(&state); - + // Trouver la meilleure action valide let mut best_action = &valid_actions[0]; let mut best_q_value = f32::NEG_INFINITY; - + for action in &valid_actions { let action_index = action.to_action_index(); if action_index < q_values.len() { @@ -64,7 +66,7 @@ impl DqnStrategy { } } } - + Some(best_action.clone()) } else { // Fallback : action aléatoire valide @@ -91,26 +93,11 @@ impl BotStrategy for DqnStrategy { } fn calculate_points(&self) -> u8 { - // Utiliser le DQN pour choisir le nombre de points à marquer - if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Mark { points } = action { - return points; - } - } - - // Fallback : utiliser la méthode standard - let dice_roll_count = self - .get_game() - .players - .get(&self.player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new(&self.color, &self.game.board, self.game.dice); - points_rules.get_points(dice_roll_count).0 + self.game.dice_points.0 } fn calculate_adv_points(&self) -> u8 { - self.calculate_points() + self.game.dice_points.1 } fn choose_go(&self) -> bool { @@ -126,24 +113,55 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Move { move1, move2 } = action { - let checker_move1 = CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - + if let TrictracAction::Move { + dice_order, + from1, + from2, + } = action + { + let dicevals = self.game.dice.values; + let (mut dice1, mut 
dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; + + if from1 == 0 { + // empty move + dice1 = 0; + } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + let chosen_move = if self.color == Color::White { (checker_move1, checker_move2) } else { (checker_move1.mirror(), checker_move2.mirror()) }; - + return chosen_move; } } - + // Fallback : utiliser la stratégie par défaut let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - + let chosen_move = *possible_moves .first() .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); @@ -155,4 +173,3 @@ impl BotStrategy for DqnStrategy { } } } - diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 022e4fc..3191b4b 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,133 +1,45 @@ +use std::cmp::max; + use serde::{Deserialize, Serialize}; +use store::{CheckerMove, Dice, GameEvent, PlayerId}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer des points - Mark { points: u8 }, + /// Marquer les points + Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions Move { - move1: (usize, usize), // (from, to) pour le premier pion - move2: (usize, usize), // (from, to) pour le deuxième pion + dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier + from1: usize, // position de départ du premier pion (0-24) + from2: usize, // position de départ du 
deuxième pion (0-24) }, } -/// Actions compactes basées sur le contexte du jeu -/// Réduit drastiquement l'espace d'actions en utilisant l'état du jeu -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum CompactAction { - /// Lancer les dés - Roll, - /// Marquer des points (0-12) - Mark { points: u8 }, - /// Continuer après avoir gagné un trou - Go, - /// Choix de mouvement simplifié - MoveChoice { - dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier - from1: usize, // position de départ du premier pion (0-24) - from2: usize, // position de départ du deuxième pion (0-24) - }, -} - -impl CompactAction { - /// Convertit CompactAction vers TrictracAction en utilisant l'état du jeu - pub fn to_trictrac_action(&self, game_state: &crate::GameState) -> Option { - match self { - CompactAction::Roll => Some(TrictracAction::Roll), - CompactAction::Mark { points } => Some(TrictracAction::Mark { points: *points }), - CompactAction::Go => Some(TrictracAction::Go), - CompactAction::MoveChoice { dice_order, from1, from2 } => { - // Calculer les positions de destination basées sur les dés - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let dice = game_state.dice; - let (die1, die2) = if *dice_order { (dice.values.0, dice.values.1) } else { (dice.values.1, dice.values.0) }; - - // Calculer les destinations (simplifiée - à adapter selon les règles de mouvement) - let to1 = if player_color == store::Color::White { - from1 + die1 as usize - } else { - from1.saturating_sub(die1 as usize) - }; - - let to2 = if player_color == store::Color::White { - from2 + die2 as usize - } else { - from2.saturating_sub(die2 as usize) - }; - - Some(TrictracAction::Move { - move1: (*from1, to1), - move2: (*from2, to2), - }) - } else { - None - } - } - } - } - - /// Taille de l'espace d'actions compactes selon le contexte - pub fn context_action_space_size(game_state: &crate::GameState) -> usize { - use 
store::TurnStage; - - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => 1, // Seulement Roll - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => 13, // Mark 0-12 points - TurnStage::HoldOrGoChoice => { - // Go + mouvements possibles - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let rules = store::MoveRules::new(&player_color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - 1 + Self::estimate_compact_moves(game_state, &possible_moves) - } else { - 1 - } - } - TurnStage::Move => { - // Seulement les mouvements - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let rules = store::MoveRules::new(&player_color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - Self::estimate_compact_moves(game_state, &possible_moves) - } else { - 0 - } - } - } - } - - /// Estime le nombre d'actions compactes pour les mouvements - fn estimate_compact_moves(game_state: &crate::GameState, _possible_moves: &[(store::CheckerMove, store::CheckerMove)]) -> usize { - // Au lieu d'encoder tous les mouvements possibles, - // on utilise : 2 (ordre des dés) * 25 (from1) * 25 (from2) = 1250 maximum - // En pratique, beaucoup moins car on ne peut partir que des positions avec des pions - - let max_dice_orders = if game_state.dice.values.0 != game_state.dice.values.1 { 2 } else { 1 }; - let _max_positions = 25; // positions 0-24 - - // Estimation conservatrice : environ 10 positions de départ possibles en moyenne - max_dice_orders * 10 * 10 // ≈ 200 au lieu de 331,791 - } -} - impl TrictracAction { /// Encode une action en index pour le réseau de neurones pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark { points } => { - 1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points - } - 
TrictracAction::Go => 14, - TrictracAction::Move { move1, move2 } => { + TrictracAction::Mark => 1, + TrictracAction::Go => 2, + TrictracAction::Move { + dice_order, + from1, + from2, + } => { // Encoder les mouvements dans l'espace d'actions - // Indices 15+ pour les mouvements - 15 + encode_move_pair(*move1, *move2) + // Indices 3+ pour les mouvements + let mut start = 3; + if !dice_order { + // 25 * 25 = 625 + start += 625; + } + start + from1 * 25 + from2 } } } @@ -136,51 +48,62 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1..=13 => Some(TrictracAction::Mark { - points: (index - 1) as u8, - }), - 14 => Some(TrictracAction::Go), - i if i >= 15 => { - let move_code = i - 15; - let (move1, move2) = decode_move_pair(move_code); - Some(TrictracAction::Move { move1, move2 }) + 1 => Some(TrictracAction::Mark), + 2 => Some(TrictracAction::Go), + i if i >= 3 => { + let move_code = i - 3; + let (dice_order, from1, from2) = Self::decode_move(move_code); + Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) } _ => None, } } + /// Décode un entier en paire de mouvements + fn decode_move(code: usize) -> (bool, usize, usize) { + let mut encoded = code; + let dice_order = code < 626; + if !dice_order { + encoded -= 625 + } + let from1 = encoded / 25; + let from2 = encoded % 25; + (dice_order, from1, from2) + } + /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles - // Pour les mouvements : 25*25*25*25 = 390625 (position 0-24 pour chaque from/to) + // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 15 + (24 * 24 * 24 * 24) // = 331791 + 3 + (2 * 25 * 25) // = 1253 } -} -/// Encode une paire de mouvements en un seul entier -fn 
encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize { - let (from1, to1) = move1; - let (from2, to2) = move2; - // Assurer que les positions sont dans la plage 0-24 - let from1 = from1.min(24); - let to1 = to1.min(24); - let from2 = from2.min(24); - let to2 = to2.min(24); - - from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2 -} - -/// Décode un entier en paire de mouvements -fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) { - let from1 = code / (25 * 25 * 25); - let remainder = code % (25 * 25 * 25); - let to1 = remainder / (25 * 25); - let remainder = remainder % (25 * 25); - let from2 = remainder / 25; - let to2 = remainder % 25; - - ((from1, to1), (from2, to2)) + // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { + // match action { + // TrictracAction::Roll => Some(GameEvent::Roll { player_id }), + // TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), + // TrictracAction::Go => Some(GameEvent::Go { player_id }), + // TrictracAction::Move { + // dice_order, + // from1, + // from2, + // } => { + // // Effectuer un mouvement + // let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); + // let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + // + // Some(GameEvent::Move { + // player_id: self.agent_player_id, + // moves: (checker_move1, checker_move2), + // }) + // } + // }; + // } } /// Configuration pour l'agent DQN @@ -350,17 +273,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Calculer les points possibles - if let Some(player) = game_state.players.get(&active_player_id) { - let dice_roll_count = player.dice_roll_count; - let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); - let (max_points, _) = points_rules.get_points(dice_roll_count); - - // Permettre 
de marquer entre 0 et max_points - for points in 0..=max_points { - valid_actions.push(TrictracAction::Mark { points }); - } - } + valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); @@ -370,9 +283,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let possible_moves = rules.get_possible_moves_sequences(true, vec![]); for (move1, move2) in possible_moves { + let diff_move1 = move1.get_to() - move1.get_from(); valid_actions.push(TrictracAction::Move { - move1: (move1.get_from(), move1.get_to()), - move2: (move2.get_from(), move2.get_to()), + dice_order: diff_move1 == game_state.dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), }); } } @@ -381,9 +296,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let possible_moves = rules.get_possible_moves_sequences(true, vec![]); for (move1, move2) in possible_moves { + let diff_move1 = move1.get_to() - move1.get_from(); valid_actions.push(TrictracAction::Move { - move1: (move1.get_from(), move1.get_to()), - move2: (move2.get_from(), move2.get_to()), + dice_order: diff_move1 == game_state.dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), }); } } @@ -393,92 +310,6 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions } -/// Génère les actions compactes valides selon l'état du jeu -pub fn get_valid_compact_actions(game_state: &crate::GameState) -> Vec { - use crate::PointsRules; - use store::TurnStage; - - let mut valid_actions = Vec::new(); - - let active_player_id = game_state.active_player_id; - let player_color = game_state.player_color_by_id(&active_player_id); - - if let Some(color) = player_color { - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => { - valid_actions.push(CompactAction::Roll); - } - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Calculer les points possibles - if let Some(player) 
= game_state.players.get(&active_player_id) { - let dice_roll_count = player.dice_roll_count; - let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); - let (max_points, _) = points_rules.get_points(dice_roll_count); - - // Permettre de marquer entre 0 et max_points - for points in 0..=max_points { - valid_actions.push(CompactAction::Mark { points }); - } - } - } - TurnStage::HoldOrGoChoice => { - valid_actions.push(CompactAction::Go); - - // Ajouter les choix de mouvements compacts - add_compact_move_actions(game_state, &color, &mut valid_actions); - } - TurnStage::Move => { - // Seulement les mouvements compacts - add_compact_move_actions(game_state, &color, &mut valid_actions); - } - } - } - - valid_actions -} - -/// Ajoute les actions de mouvement compactes basées sur le contexte -fn add_compact_move_actions(game_state: &crate::GameState, color: &store::Color, valid_actions: &mut Vec) { - let rules = store::MoveRules::new(color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Extraire les positions de départ uniques des mouvements possibles - let mut valid_from_positions = std::collections::HashSet::new(); - for (move1, move2) in &possible_moves { - valid_from_positions.insert(move1.get_from()); - valid_from_positions.insert(move2.get_from()); - } - - let dice = game_state.dice; - let dice_orders = if dice.values.0 != dice.values.1 { vec![true, false] } else { vec![true] }; - - // Générer les combinaisons compactes valides - for dice_order in dice_orders { - for &from1 in &valid_from_positions { - for &from2 in &valid_from_positions { - // Vérifier si cette combinaison produit un mouvement valide - let compact_action = CompactAction::MoveChoice { - dice_order, - from1, - from2 - }; - - if let Some(trictrac_action) = compact_action.to_trictrac_action(game_state) { - // Vérifier si ce mouvement est dans la liste des mouvements possibles - if let TrictracAction::Move { 
move1, move2 } = trictrac_action { - if let (Ok(checker_move1), Ok(checker_move2)) = - (store::CheckerMove::new(move1.0, move1.1), store::CheckerMove::new(move2.0, move2.1)) { - if possible_moves.contains(&(checker_move1, checker_move2)) { - valid_actions.push(compact_action); - } - } - } - } - } - } - } -} - /// Retourne les indices des actions valides pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { get_valid_actions(game_state) diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 67c3e39..2b935f5 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -266,7 +266,11 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Move { move1, move2 } => { + TrictracAction::Move { + dice_order, + from1, + from2, + } => { // Effectuer un mouvement let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); diff --git a/doc/workflow.md b/doc/workflow.md new file mode 100644 index 0000000..2139332 --- /dev/null +++ b/doc/workflow.md @@ -0,0 +1,25 @@ +# Workflow + +@startuml + +state c <<choice>> +state haswon <<choice>> +state MarkPoints #lightblue +state MarkAdvPoints #lightblue +note right of MarkPoints : automatic 'Mark' transition\nwhen no school +note right of MarkAdvPoints : automatic 'Mark' transition\nwhen no school + +[*] -> RollDice : BeginGame +RollDice --> RollWaiting : Roll (current player) +RollWaiting --> MarkPoints : RollResult (engine) +MarkPoints --> c : Mark (current player) +c --> HoldOrGoChoice : [new hole] +c --> [*] : [has won] +c --> Move : [not new hole] +HoldOrGoChoice --> RollDice : Go +HoldOrGoChoice --> MarkAdvPoints : Move +Move --> MarkAdvPoints : Move +MarkAdvPoints --> haswon : Mark (adversary) +haswon --> RollDice : [has not won] +haswon --> [*] : [has won] +@enduml diff --git a/store/src/game.rs b/store/src/game.rs index 477895f..ed77519 
100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -71,7 +71,7 @@ pub struct GameState { /// last dice pair rolled pub dice: Dice, /// players points computed for the last dice pair rolled - dice_points: (u8, u8), + pub dice_points: (u8, u8), pub dice_moves: (CheckerMove, CheckerMove), pub dice_jans: PossibleJans, /// true if player needs to roll first @@ -505,13 +505,7 @@ impl GameState { self.players.remove(player_id); } Roll { player_id: _ } => { - // Opponent has moved, we can mark pending points earned during opponent's turn - let new_hole = self.mark_points(self.active_player_id, self.dice_points.1); - if new_hole && self.get_active_player().unwrap().holes > 12 { - self.stage = Stage::Ended; - } else { - self.turn_stage = TurnStage::RollWaiting; - } + self.turn_stage = TurnStage::RollWaiting; } RollResult { player_id: _, dice } => { self.dice = *dice; @@ -534,23 +528,25 @@ impl GameState { } } Mark { player_id, points } => { - let new_hole = self.mark_points(*player_id, *points); - if new_hole { - if self.get_active_player().unwrap().holes > 12 { - self.stage = Stage::Ended; + if self.schools_enabled { + let new_hole = self.mark_points(*player_id, *points); + if new_hole { + if self.get_active_player().unwrap().holes > 12 { + self.stage = Stage::Ended; + } else { + self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { + TurnStage::RollDice + } else { + TurnStage::HoldOrGoChoice + }; + } } else { self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { TurnStage::RollDice } else { - TurnStage::HoldOrGoChoice + TurnStage::Move }; } - } else { - self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { - TurnStage::RollDice - } else { - TurnStage::Move - }; } } Go { player_id: _ } => self.new_pick_up(), @@ -563,6 +559,11 @@ impl GameState { self.turn_stage = if self.schools_enabled { TurnStage::MarkAdvPoints } else { + // The player has moved, we can mark its opponent's points (which is now the current player) + let 
new_hole = self.mark_points(self.active_player_id, self.dice_points.1); + if new_hole && self.get_active_player().unwrap().holes > 12 { + self.stage = Stage::Ended; + } TurnStage::RollDice }; }