From f6ec3ef5ae780337efd1c8fcd514ef3bd182a763 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 8 Jun 2025 21:20:04 +0200 Subject: [PATCH 1/5] wip --- bot/src/strategy/dqn_common.rs | 322 ++++++++------------------------ bot/src/strategy/dqn_trainer.rs | 6 +- store/src/game.rs | 6 + 3 files changed, 88 insertions(+), 246 deletions(-) diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 022e4fc..5cf30d5 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,133 +1,45 @@ +use std::cmp::max; + use serde::{Deserialize, Serialize}; +use store::{CheckerMove, Dice, GameEvent, PlayerId}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer des points - Mark { points: u8 }, + /// Marquer les points + Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions Move { - move1: (usize, usize), // (from, to) pour le premier pion - move2: (usize, usize), // (from, to) pour le deuxième pion + dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier + from1: usize, // position de départ du premier pion (0-24) + from2: usize, // position de départ du deuxième pion (0-24) }, } -/// Actions compactes basées sur le contexte du jeu -/// Réduit drastiquement l'espace d'actions en utilisant l'état du jeu -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum CompactAction { - /// Lancer les dés - Roll, - /// Marquer des points (0-12) - Mark { points: u8 }, - /// Continuer après avoir gagné un trou - Go, - /// Choix de mouvement simplifié - MoveChoice { - dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier - from1: usize, // position de départ du premier pion (0-24) - from2: usize, // position de départ du deuxième pion (0-24) - }, -} - -impl CompactAction { - /// Convertit CompactAction vers TrictracAction en utilisant l'état du jeu - pub fn to_trictrac_action(&self, game_state: &crate::GameState) -> Option { - match self { - CompactAction::Roll => Some(TrictracAction::Roll), - CompactAction::Mark { points } => Some(TrictracAction::Mark { points: *points }), - CompactAction::Go => Some(TrictracAction::Go), - CompactAction::MoveChoice { dice_order, from1, from2 } => { - // Calculer les positions de destination basées sur les dés - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let dice = game_state.dice; - let (die1, die2) = if *dice_order { (dice.values.0, dice.values.1) } else { (dice.values.1, dice.values.0) }; - - // Calculer les destinations (simplifiée - à adapter selon les règles de mouvement) - let to1 = if player_color == store::Color::White { - from1 + die1 as usize - } else { - from1.saturating_sub(die1 as usize) - }; - - let to2 = if player_color == store::Color::White { - from2 + die2 as usize - } else { - from2.saturating_sub(die2 as usize) - }; - - Some(TrictracAction::Move { - move1: (*from1, to1), - move2: (*from2, to2), - }) - } else { - None - } - } - } - } - - /// Taille de l'espace d'actions compactes selon le contexte - pub fn context_action_space_size(game_state: &crate::GameState) -> usize { - use store::TurnStage; - - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => 1, // Seulement Roll - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => 13, // Mark 0-12 points - TurnStage::HoldOrGoChoice => { - // Go + mouvements possibles - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let rules = store::MoveRules::new(&player_color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - 1 + Self::estimate_compact_moves(game_state, &possible_moves) - } else { - 1 - } - } - TurnStage::Move => { - // Seulement les mouvements - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let rules = store::MoveRules::new(&player_color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - Self::estimate_compact_moves(game_state, &possible_moves) - } else { - 0 - } - } - } - } - - /// Estime le nombre d'actions compactes pour les mouvements - fn estimate_compact_moves(game_state: &crate::GameState, _possible_moves: &[(store::CheckerMove, store::CheckerMove)]) -> usize { - // Au lieu d'encoder tous les mouvements possibles, - // on utilise : 2 (ordre des dés) * 25 (from1) * 25 (from2) = 1250 maximum - // En pratique, beaucoup moins car on ne peut partir que des positions avec des pions - - let max_dice_orders = if game_state.dice.values.0 != game_state.dice.values.1 { 2 } else { 1 }; - let _max_positions = 25; // positions 0-24 - - // Estimation conservatrice : environ 10 positions de départ possibles en moyenne - max_dice_orders * 10 * 10 // ≈ 200 au lieu de 331,791 - } -} - impl TrictracAction { /// Encode une action en index pour le réseau de neurones pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark { points } => { - 1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points - } - TrictracAction::Go => 14, - TrictracAction::Move { move1, move2 } => { + TrictracAction::Mark => 1, + TrictracAction::Go => 2, + TrictracAction::Move { + dice_order, + from1, + from2, + } => { // Encoder les mouvements dans l'espace d'actions - // Indices 15+ pour les mouvements - 15 + encode_move_pair(*move1, *move2) + // Indices 3+ pour les mouvements + let mut start = 3; + if !dice_order { + // 25 * 25 = 625 + start += 625; + } + start + from1 * 25 + from2 } } } @@ -136,51 +48,63 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1..=13 => Some(TrictracAction::Mark { - points: (index - 1) as u8, - }), - 14 => Some(TrictracAction::Go), - i if i >= 15 => { - let move_code = i - 15; - let (move1, move2) = decode_move_pair(move_code); - Some(TrictracAction::Move { move1, move2 }) + 1 => Some(TrictracAction::Mark), + 2 => Some(TrictracAction::Go), + i if i >= 3 => { + let move_code = i - 3; + let (dice_order, from1, from2) = decode_move(move_code); + Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) } _ => None, } } + /// Décode un entier en paire de mouvements + fn decode_move(code: usize) -> (bool, usize, usize) { + let mut encoded = code; + let dice_order = code < 626; + if !dice_order { + encoded -= 625 + } + let from1 = encoded / 25; + let from2 = encoded % 25; + (dice_order, from1, from2) + } + /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles - // Pour les mouvements : 25*25*25*25 = 390625 (position 0-24 pour chaque from/to) + // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 15 + (24 * 24 * 24 * 24) // = 331791 + 3 + (2 * 25 * 25) // = 1253 } -} -/// Encode une paire de mouvements en un seul entier -fn encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize { - let (from1, to1) = move1; - let (from2, to2) = move2; - // Assurer que les positions sont dans la plage 0-24 - let from1 = from1.min(24); - let to1 = to1.min(24); - let from2 = from2.min(24); - let to2 = to2.min(24); + pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { + match action { + TrictracAction::Roll => Some(GameEvent::Roll { player_id }), + TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), + TrictracAction::Go => Some(GameEvent::Go { player_id }), + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2 -} - -/// Décode un entier en paire de mouvements -fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) { - let from1 = code / (25 * 25 * 25); - let remainder = code % (25 * 25 * 25); - let to1 = remainder / (25 * 25); - let remainder = remainder % (25 * 25); - let from2 = remainder / 25; - let to2 = remainder % 25; - - ((from1, to1), (from2, to2)) + reward += 0.2; + Some(GameEvent::Move { + player_id: self.agent_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + } } /// Configuration pour l'agent DQN @@ -350,17 +274,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Calculer les points possibles - if let Some(player) = game_state.players.get(&active_player_id) { - let dice_roll_count = player.dice_roll_count; - let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); - let (max_points, _) = points_rules.get_points(dice_roll_count); - - // Permettre de marquer entre 0 et max_points - for points in 0..=max_points { - valid_actions.push(TrictracAction::Mark { points }); - } - } + valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); @@ -370,9 +284,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let possible_moves = rules.get_possible_moves_sequences(true, vec![]); for (move1, move2) in possible_moves { + let diff_move1 = move1.get_to() - move1.get_from(); valid_actions.push(TrictracAction::Move { - move1: (move1.get_from(), move1.get_to()), - move2: (move2.get_from(), move2.get_to()), + dice_order: diff_move1 == game_state.dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), }); } } @@ -381,9 +297,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let possible_moves = rules.get_possible_moves_sequences(true, vec![]); for (move1, move2) in possible_moves { + let diff_move1 = move1.get_to() - move1.get_from(); valid_actions.push(TrictracAction::Move { - move1: (move1.get_from(), move1.get_to()), - move2: (move2.get_from(), move2.get_to()), + dice_order: diff_move1 == game_state.dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), }); } } @@ -393,92 +311,6 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions } -/// Génère les actions compactes valides selon l'état du jeu -pub fn get_valid_compact_actions(game_state: &crate::GameState) -> Vec { - use crate::PointsRules; - use store::TurnStage; - - let mut valid_actions = Vec::new(); - - let active_player_id = game_state.active_player_id; - let player_color = game_state.player_color_by_id(&active_player_id); - - if let Some(color) = player_color { - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => { - valid_actions.push(CompactAction::Roll); - } - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Calculer les points possibles - if let Some(player) = game_state.players.get(&active_player_id) { - let dice_roll_count = player.dice_roll_count; - let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); - let (max_points, _) = points_rules.get_points(dice_roll_count); - - // Permettre de marquer entre 0 et max_points - for points in 0..=max_points { - valid_actions.push(CompactAction::Mark { points }); - } - } - } - TurnStage::HoldOrGoChoice => { - valid_actions.push(CompactAction::Go); - - // Ajouter les choix de mouvements compacts - add_compact_move_actions(game_state, &color, &mut valid_actions); - } - TurnStage::Move => { - // Seulement les mouvements compacts - add_compact_move_actions(game_state, &color, &mut valid_actions); - } - } - } - - valid_actions -} - -/// Ajoute les actions de mouvement compactes basées sur le contexte -fn add_compact_move_actions(game_state: &crate::GameState, color: &store::Color, valid_actions: &mut Vec) { - let rules = store::MoveRules::new(color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Extraire les positions de départ uniques des mouvements possibles - let mut valid_from_positions = std::collections::HashSet::new(); - for (move1, move2) in &possible_moves { - valid_from_positions.insert(move1.get_from()); - valid_from_positions.insert(move2.get_from()); - } - - let dice = game_state.dice; - let dice_orders = if dice.values.0 != dice.values.1 { vec![true, false] } else { vec![true] }; - - // Générer les combinaisons compactes valides - for dice_order in dice_orders { - for &from1 in &valid_from_positions { - for &from2 in &valid_from_positions { - // Vérifier si cette combinaison produit un mouvement valide - let compact_action = CompactAction::MoveChoice { - dice_order, - from1, - from2 - }; - - if let Some(trictrac_action) = compact_action.to_trictrac_action(game_state) { - // Vérifier si ce mouvement est dans la liste des mouvements possibles - if let TrictracAction::Move { move1, move2 } = trictrac_action { - if let (Ok(checker_move1), Ok(checker_move2)) = - (store::CheckerMove::new(move1.0, move1.1), store::CheckerMove::new(move2.0, move2.1)) { - if possible_moves.contains(&(checker_move1, checker_move2)) { - valid_actions.push(compact_action); - } - } - } - } - } - } - } -} - /// Retourne les indices des actions valides pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { get_valid_actions(game_state) diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 67c3e39..2b935f5 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -266,7 +266,11 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Move { move1, move2 } => { + TrictracAction::Move { + dice_order, + from1, + from2, + } => { // Effectuer un mouvement let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); diff --git a/store/src/game.rs b/store/src/game.rs index 477895f..90e905b 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -571,6 +571,12 @@ impl GameState { self.history.push(valid_event.clone()); } + fn dice_points(&self) -> u8 { + let player = self.players.get(&self.active_player_id).unwrap(); + let points_rules = PointsRules::new(&player.color, &self.board, self.dice); + let (points, adv_points) = points_rules.get_points(player.dice_roll_count); + } + /// Set a new pick up ('relevé') after a player won a hole and choose to 'go', /// or after a player has bore off (took of his men off the board) fn new_pick_up(&mut self) { From 5f33737c1bda8182c3190f58d5f1977926088a8d Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 9 Jun 2025 20:17:00 +0200 Subject: [PATCH 2/5] wip fix workflow --- bot/src/strategy/default.rs | 20 +++++---- bot/src/strategy/dqn.rs | 79 +++++++++++++++++++++------------- bot/src/strategy/dqn_common.rs | 45 ++++++++++--------- doc/workflow.md | 25 +++++++++++ store/src/game.rs | 45 +++++++++---------- 5 files changed, 126 insertions(+), 88 deletions(-) create mode 100644 doc/workflow.md diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index 98e8322..81aa5f1 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -36,18 +36,20 @@ impl BotStrategy for DefaultStrategy { } fn calculate_points(&self) -> u8 { - let dice_roll_count = self - .get_game() - .players - .get(&self.player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice); - points_rules.get_points(dice_roll_count).0 + // let dice_roll_count = self + // .get_game() + // .players + // .get(&self.player_id) + // .unwrap() + // .dice_roll_count; + // let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice); + // points_rules.get_points(dice_roll_count).0 + self.game.dice_points.0 } fn calculate_adv_points(&self) -> u8 { - self.calculate_points() + // self.calculate_points() + self.game.dice_points.1 } fn choose_go(&self) -> bool { diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index d2fc9ed..779ce3d 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -2,7 +2,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use std::path::Path; use store::MoveRules; -use super::dqn_common::{SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action}; +use super::dqn_common::{ + get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction, +}; /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné #[derive(Debug)] @@ -42,18 +44,18 @@ impl DqnStrategy { if let Some(ref model) = self.model { let state = self.game.to_vec_float(); let valid_actions = get_valid_actions(&self.game); - + if valid_actions.is_empty() { return None; } - + // Obtenir les Q-values pour toutes les actions let q_values = model.forward(&state); - + // Trouver la meilleure action valide let mut best_action = &valid_actions[0]; let mut best_q_value = f32::NEG_INFINITY; - + for action in &valid_actions { let action_index = action.to_action_index(); if action_index < q_values.len() { @@ -64,7 +66,7 @@ impl DqnStrategy { } } } - + Some(best_action.clone()) } else { // Fallback : action aléatoire valide @@ -91,26 +93,11 @@ impl BotStrategy for DqnStrategy { } fn calculate_points(&self) -> u8 { - // Utiliser le DQN pour choisir le nombre de points à marquer - if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Mark { points } = action { - return points; - } - } - - // Fallback : utiliser la méthode standard - let dice_roll_count = self - .get_game() - .players - .get(&self.player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new(&self.color, &self.game.board, self.game.dice); - points_rules.get_points(dice_roll_count).0 + self.game.dice_points.0 } fn calculate_adv_points(&self) -> u8 { - self.calculate_points() + self.game.dice_points.1 } fn choose_go(&self) -> bool { @@ -126,24 +113,55 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Move { move1, move2 } = action { - let checker_move1 = CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - + if let TrictracAction::Move { + dice_order, + from1, + from2, + } = action + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; + + if from1 == 0 { + // empty move + dice1 = 0; + } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + let chosen_move = if self.color == Color::White { (checker_move1, checker_move2) } else { (checker_move1.mirror(), checker_move2.mirror()) }; - + return chosen_move; } } - + // Fallback : utiliser la stratégie par défaut let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - + let chosen_move = *possible_moves .first() .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); @@ -155,4 +173,3 @@ impl BotStrategy for DqnStrategy { } } } - diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 5cf30d5..3191b4b 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -52,7 +52,7 @@ impl TrictracAction { 2 => Some(TrictracAction::Go), i if i >= 3 => { let move_code = i - 3; - let (dice_order, from1, from2) = decode_move(move_code); + let (dice_order, from1, from2) = Self::decode_move(move_code); Some(TrictracAction::Move { dice_order, from1, @@ -83,28 +83,27 @@ impl TrictracAction { 3 + (2 * 25 * 25) // = 1253 } - pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { - match action { - TrictracAction::Roll => Some(GameEvent::Roll { player_id }), - TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), - TrictracAction::Go => Some(GameEvent::Go { player_id }), - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - - reward += 0.2; - Some(GameEvent::Move { - player_id: self.agent_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - } + // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { + // match action { + // TrictracAction::Roll => Some(GameEvent::Roll { player_id }), + // TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), + // TrictracAction::Go => Some(GameEvent::Go { player_id }), + // TrictracAction::Move { + // dice_order, + // from1, + // from2, + // } => { + // // Effectuer un mouvement + // let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); + // let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + // + // Some(GameEvent::Move { + // player_id: self.agent_player_id, + // moves: (checker_move1, checker_move2), + // }) + // } + // }; + // } } /// Configuration pour l'agent DQN diff --git a/doc/workflow.md b/doc/workflow.md new file mode 100644 index 0000000..2139332 --- /dev/null +++ b/doc/workflow.md @@ -0,0 +1,25 @@ +# Workflow + +@startuml + +state c <> +state haswon <> +state MarkPoints #lightblue +state MarkAdvPoints #lightblue +note right of MarkPoints : automatic 'Mark' transition\nwhen no school +note right of MarkAdvPoints : automatic 'Mark' transition\nwhen no school + +[*] -> RollDice : BeginGame +RollDice --> RollWaiting : Roll (current player) +RollWaiting --> MarkPoints : RollResult (engine) +MarkPoints --> c : Mark (current player) +c --> HoldHorGoChoice : [new hole] +c --> [*] : [has won] +c --> Move : [not new hole] +HoldHorGoChoice --> RollDice : Go +HoldHorGoChoice --> MarkAdvPoints : Move +Move --> MarkAdvPoints : Move +MarkAdvPoints --> haswon : Mark (adversary) +haswon --> RollDice : [has not won] +haswon --> [*] : [has won] +@enduml diff --git a/store/src/game.rs b/store/src/game.rs index 90e905b..ed77519 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -71,7 +71,7 @@ pub struct GameState { /// last dice pair rolled pub dice: Dice, /// players points computed for the last dice pair rolled - dice_points: (u8, u8), + pub dice_points: (u8, u8), pub dice_moves: (CheckerMove, CheckerMove), pub dice_jans: PossibleJans, /// true if player needs to roll first @@ -505,13 +505,7 @@ impl GameState { self.players.remove(player_id); } Roll { player_id: _ } => { - // Opponent has moved, we can mark pending points earned during opponent's turn - let new_hole = self.mark_points(self.active_player_id, self.dice_points.1); - if new_hole && self.get_active_player().unwrap().holes > 12 { - self.stage = Stage::Ended; - } else { - self.turn_stage = TurnStage::RollWaiting; - } + self.turn_stage = TurnStage::RollWaiting; } RollResult { player_id: _, dice } => { self.dice = *dice; @@ -534,23 +528,25 @@ impl GameState { } } Mark { player_id, points } => { - let new_hole = self.mark_points(*player_id, *points); - if new_hole { - if self.get_active_player().unwrap().holes > 12 { - self.stage = Stage::Ended; + if self.schools_enabled { + let new_hole = self.mark_points(*player_id, *points); + if new_hole { + if self.get_active_player().unwrap().holes > 12 { + self.stage = Stage::Ended; + } else { + self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { + TurnStage::RollDice + } else { + TurnStage::HoldOrGoChoice + }; + } } else { self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { TurnStage::RollDice } else { - TurnStage::HoldOrGoChoice + TurnStage::Move }; } - } else { - self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { - TurnStage::RollDice - } else { - TurnStage::Move - }; } } Go { player_id: _ } => self.new_pick_up(), @@ -563,6 +559,11 @@ impl GameState { self.turn_stage = if self.schools_enabled { TurnStage::MarkAdvPoints } else { + // The player has moved, we can mark its opponent's points (which is now the current player) + let new_hole = self.mark_points(self.active_player_id, self.dice_points.1); + if new_hole && self.get_active_player().unwrap().holes > 12 { + self.stage = Stage::Ended; + } TurnStage::RollDice }; } @@ -571,12 +572,6 @@ impl GameState { self.history.push(valid_event.clone()); } - fn dice_points(&self) -> u8 { - let player = self.players.get(&self.active_player_id).unwrap(); - let points_rules = PointsRules::new(&player.color, &self.board, self.dice); - let (points, adv_points) = points_rules.get_points(player.dice_roll_count); - } - /// Set a new pick up ('relevé') after a player won a hole and choose to 'go', /// or after a player has bore off (took of his men off the board) fn new_pick_up(&mut self) { From 8e6a94425c9af98c8182f724adefd7089629d316 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 11 Jun 2025 17:31:35 +0200 Subject: [PATCH 3/5] dqn trainer --- bot/src/strategy/dqn_common.rs | 104 ++++++++++++++++++++++++-------- bot/src/strategy/dqn_trainer.rs | 37 ++++++++---- store/src/game.rs | 4 +- 3 files changed, 107 insertions(+), 38 deletions(-) diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 3191b4b..9a24ae6 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,4 +1,4 @@ -use std::cmp::max; +use std::cmp::{max, min}; use serde::{Deserialize, Serialize}; use store::{CheckerMove, Dice, GameEvent, PlayerId}; @@ -8,8 +8,6 @@ use store::{CheckerMove, Dice, GameEvent, PlayerId}; pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer les points - Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions @@ -18,6 +16,8 @@ pub enum TrictracAction { from1: usize, // position de départ du premier pion (0-24) from2: usize, // position de départ du deuxième pion (0-24) }, + // Marquer les points : à activer si support des écoles + // Mark, } impl TrictracAction { @@ -25,22 +25,22 @@ impl TrictracAction { pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark => 1, - TrictracAction::Go => 2, + TrictracAction::Go => 1, TrictracAction::Move { dice_order, from1, from2, } => { // Encoder les mouvements dans l'espace d'actions - // Indices 3+ pour les mouvements - let mut start = 3; + // Indices 2+ pour les mouvements + // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) + let mut start = 2; if !dice_order { // 25 * 25 = 625 start += 625; } start + from1 * 25 + from2 - } + } // TrictracAction::Mark => 1252, } } @@ -48,8 +48,8 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1 => Some(TrictracAction::Mark), - 2 => Some(TrictracAction::Go), + // 1252 => Some(TrictracAction::Mark), + 1 => Some(TrictracAction::Go), i if i >= 3 => { let move_code = i - 3; let (dice_order, from1, from2) = Self::decode_move(move_code); @@ -77,10 +77,10 @@ impl TrictracAction { /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // 1 (Roll) + 1 (Go) + mouvements possibles // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 3 + (2 * 25 * 25) // = 1253 + 2 + (2 * 25 * 25) // = 1252 } // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { @@ -273,35 +273,37 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - valid_actions.push(TrictracAction::Mark); + // valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); - // Ajouter aussi les mouvements possibles + // Ajoute aussi les mouvements possibles let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + // Modififier checker_moves_to_trictrac_action si on doit gérer Black + assert_eq!(color, store::Color::White); for (move1, move2) in possible_moves { - let diff_move1 = move1.get_to() - move1.get_from(); - valid_actions.push(TrictracAction::Move { - dice_order: diff_move1 == game_state.dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - }); + valid_actions.push(checker_moves_to_trictrac_action( + &move1, + &move2, + &game_state.dice, + )); } } TurnStage::Move => { let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + // Modififier checker_moves_to_trictrac_action si on doit gérer Black + assert_eq!(color, store::Color::White); for (move1, move2) in possible_moves { - let diff_move1 = move1.get_to() - move1.get_from(); - valid_actions.push(TrictracAction::Move { - dice_order: diff_move1 == game_state.dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - }); + valid_actions.push(checker_moves_to_trictrac_action( + &move1, + &move2, + &game_state.dice, + )); } } } @@ -310,6 +312,56 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions } +// Valid only for White player +fn checker_moves_to_trictrac_action( + move1: &CheckerMove, + move2: &CheckerMove, + dice: &Dice, +) -> TrictracAction { + let to1 = move1.get_to(); + let to2 = move2.get_to(); + let from1 = move1.get_from(); + let from2 = move2.get_from(); + + let mut diff_move1 = if to1 > 0 { + // Mouvement sans sortie + to1 - from1 + } else { + // sortie, on utilise la valeur du dé + if to2 > 0 { + // sortie pour le mouvement 1 uniquement + let dice2 = to2 - from2; + if dice2 == dice.values.0 as usize { + dice.values.1 as usize + } else { + dice.values.0 as usize + } + } else { + // double sortie + if from1 < from2 { + max(dice.values.0, dice.values.1) as usize + } else { + min(dice.values.0, dice.values.1) as usize + } + } + }; + + // modification de diff_move1 si on est dans le cas d'un mouvement par puissance + let rest_field = 12; + if to1 == rest_field + && to2 == rest_field + && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field + { + // prise par puissance + diff_move1 += 1; + } + TrictracAction::Move { + dice_order: diff_move1 == dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), + } +} + /// Retourne les indices des actions valides pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { get_valid_actions(game_state) diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 2b935f5..2cd0b3d 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -251,14 +251,15 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Mark { points } => { - // Marquer des points - reward += 0.1 * points as f32; - Some(GameEvent::Mark { - player_id: self.agent_player_id, - points, - }) - } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game_state. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.agent_player_id, + // points, + // }) + // } TrictracAction::Go => { // Continuer après avoir gagné un trou reward += 0.2; @@ -272,8 +273,23 @@ impl TrictracEnv { from2, } => { // Effectuer un mouvement - let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + let (dice1, dice2) = if dice_order { + (self.game_state.dice.values.0, self.game_state.dice.values.1) + } else { + (self.game_state.dice.values.1, self.game_state.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); reward += 0.2; Some(GameEvent::Move { @@ -443,7 +459,6 @@ impl DqnTrainer { for episode in 1..=episodes { let reward = self.train_episode(); - print!("."); if episode % 100 == 0 { println!( "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}", diff --git a/store/src/game.rs b/store/src/game.rs index ed77519..fe2762f 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -757,6 +757,7 @@ mod tests { #[test] fn hold_or_go() { let mut game_state = init_test_gamestate(TurnStage::MarkPoints); + game_state.schools_enabled = true; let pid = game_state.active_player_id; game_state.consume( &(GameEvent::Mark { @@ -782,6 +783,7 @@ mod tests { // Hold let mut game_state = init_test_gamestate(TurnStage::MarkPoints); + game_state.schools_enabled = true; let pid = game_state.active_player_id; game_state.consume( &(GameEvent::Mark { @@ -802,6 +804,6 @@ mod tests { assert_ne!(game_state.active_player_id, pid); assert_eq!(game_state.players.get(&pid).unwrap().points, 1); assert_eq!(game_state.get_active_player().unwrap().points, 0); - assert_eq!(game_state.turn_stage, TurnStage::RollDice); + assert_eq!(game_state.turn_stage, TurnStage::MarkAdvPoints); } } From 7507ea5d78338d87c06e92e12e5fabd44e5e5e25 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 8 Jun 2025 21:20:04 +0200 Subject: [PATCH 4/5] fix workflow --- bot/src/strategy/default.rs | 20 +- bot/src/strategy/dqn.rs | 79 +++++--- bot/src/strategy/dqn_common.rs | 323 ++++++++------------------------ bot/src/strategy/dqn_trainer.rs | 6 +- doc/workflow.md | 25 +++ store/src/game.rs | 39 ++-- 6 files changed, 186 insertions(+), 306 deletions(-) create mode 100644 doc/workflow.md diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index 98e8322..81aa5f1 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -36,18 +36,20 @@ impl BotStrategy for DefaultStrategy { } fn calculate_points(&self) -> u8 { - let dice_roll_count = self - .get_game() - .players - .get(&self.player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice); - points_rules.get_points(dice_roll_count).0 + // let dice_roll_count = self + // .get_game() + // .players + // .get(&self.player_id) + // .unwrap() + // .dice_roll_count; + // let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice); + // points_rules.get_points(dice_roll_count).0 + self.game.dice_points.0 } fn calculate_adv_points(&self) -> u8 { - self.calculate_points() + // self.calculate_points() + self.game.dice_points.1 } fn choose_go(&self) -> bool { diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index d2fc9ed..779ce3d 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -2,7 +2,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use std::path::Path; use store::MoveRules; -use super::dqn_common::{SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action}; +use super::dqn_common::{ + get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction, +}; /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné #[derive(Debug)] @@ -42,18 +44,18 @@ impl DqnStrategy { if let Some(ref model) = self.model { let state = self.game.to_vec_float(); let valid_actions = get_valid_actions(&self.game); - + if valid_actions.is_empty() { return None; } - + // Obtenir les Q-values pour toutes les actions let q_values = model.forward(&state); - + // Trouver la meilleure action valide let mut best_action = &valid_actions[0]; let mut best_q_value = f32::NEG_INFINITY; - + for action in &valid_actions { let action_index = action.to_action_index(); if action_index < q_values.len() { @@ -64,7 +66,7 @@ impl DqnStrategy { } } } - + Some(best_action.clone()) } else { // Fallback : action aléatoire valide @@ -91,26 +93,11 @@ impl BotStrategy for DqnStrategy { } fn calculate_points(&self) -> u8 { - // Utiliser le DQN pour choisir le nombre de points à marquer - if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Mark { points } = action { - return points; - } - } - - // Fallback : utiliser la méthode standard - let dice_roll_count = self - .get_game() - .players - .get(&self.player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new(&self.color, &self.game.board, self.game.dice); - points_rules.get_points(dice_roll_count).0 + self.game.dice_points.0 } fn calculate_adv_points(&self) -> u8 { - self.calculate_points() + self.game.dice_points.1 } fn choose_go(&self) -> bool { @@ -126,24 +113,55 @@ impl BotStrategy for DqnStrategy { fn choose_move(&self) -> (CheckerMove, CheckerMove) { // Utiliser le DQN pour choisir le mouvement if let Some(action) = self.get_dqn_action() { - if let TrictracAction::Move { move1, move2 } = action { - let checker_move1 = CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - + if let TrictracAction::Move { + dice_order, + from1, + from2, + } = action + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; + + if from1 == 0 { + // empty move + dice1 = 0; + } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + let chosen_move = if self.color == Color::White { (checker_move1, checker_move2) } else { (checker_move1.mirror(), checker_move2.mirror()) }; - + return chosen_move; } } - + // Fallback : utiliser la stratégie par défaut let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - + let chosen_move = *possible_moves .first() .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); @@ -155,4 +173,3 @@ impl BotStrategy for DqnStrategy { } } } - diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 022e4fc..3191b4b 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,133 +1,45 @@ +use std::cmp::max; + use serde::{Deserialize, Serialize}; +use store::{CheckerMove, Dice, GameEvent, PlayerId}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer des points - Mark { points: u8 }, + /// Marquer les points + Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions Move { - move1: (usize, usize), // (from, to) pour le premier pion - move2: (usize, usize), // (from, to) pour le deuxième pion + dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier + from1: usize, // position de départ du premier pion (0-24) + from2: usize, // position de départ du deuxième pion (0-24) }, } -/// Actions compactes basées sur le contexte du jeu -/// Réduit drastiquement l'espace d'actions en utilisant l'état du jeu -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum CompactAction { - /// Lancer les dés - Roll, - /// Marquer des points (0-12) - Mark { points: u8 }, - /// Continuer après avoir gagné un trou - Go, - /// Choix de mouvement simplifié - MoveChoice { - dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier - from1: usize, // position de départ du premier pion (0-24) - from2: usize, // position de départ du deuxième pion (0-24) - }, -} - -impl CompactAction { - /// Convertit CompactAction vers TrictracAction en utilisant l'état du jeu - pub fn to_trictrac_action(&self, game_state: &crate::GameState) -> Option { - match self { - CompactAction::Roll => Some(TrictracAction::Roll), - CompactAction::Mark { points } => Some(TrictracAction::Mark { points: *points }), - CompactAction::Go => Some(TrictracAction::Go), - CompactAction::MoveChoice { dice_order, from1, from2 } => { - // Calculer les positions de destination basées sur les dés - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let dice = game_state.dice; - let (die1, die2) = if *dice_order { (dice.values.0, dice.values.1) } else { (dice.values.1, dice.values.0) }; - - // Calculer les destinations (simplifiée - à adapter selon les règles de mouvement) - let to1 = if player_color == store::Color::White { - from1 + die1 as usize - } else { - from1.saturating_sub(die1 as usize) - }; - - let to2 = if player_color == store::Color::White { - from2 + die2 as usize - } else { - from2.saturating_sub(die2 as usize) - }; - - Some(TrictracAction::Move { - move1: (*from1, to1), - move2: (*from2, to2), - }) - } else { - None - } - } - } - } - - /// Taille de l'espace d'actions compactes selon le contexte - pub fn context_action_space_size(game_state: &crate::GameState) -> usize { - use store::TurnStage; - - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => 1, // Seulement Roll - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => 13, // Mark 0-12 points - TurnStage::HoldOrGoChoice => { - // Go + mouvements possibles - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let rules = store::MoveRules::new(&player_color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - 1 + Self::estimate_compact_moves(game_state, &possible_moves) - } else { - 1 - } - } - TurnStage::Move => { - // Seulement les mouvements - if let Some(player_color) = game_state.player_color_by_id(&game_state.active_player_id) { - let rules = store::MoveRules::new(&player_color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - Self::estimate_compact_moves(game_state, &possible_moves) - } else { - 0 - } - } - } - } - - /// Estime le nombre d'actions compactes pour les mouvements - fn estimate_compact_moves(game_state: &crate::GameState, _possible_moves: &[(store::CheckerMove, store::CheckerMove)]) -> usize { - // Au lieu d'encoder tous les mouvements possibles, - // on utilise : 2 (ordre des dés) * 25 (from1) * 25 (from2) = 1250 maximum - // En pratique, beaucoup moins car on ne peut partir que des positions avec des pions - - let max_dice_orders = if game_state.dice.values.0 != game_state.dice.values.1 { 2 } else { 1 }; - let _max_positions = 25; // positions 0-24 - - // Estimation conservatrice : environ 10 positions de départ possibles en moyenne - max_dice_orders * 10 * 10 // ≈ 200 au lieu de 331,791 - } -} - impl TrictracAction { /// Encode une action en index pour le réseau de neurones pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark { points } => { - 1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points - } - TrictracAction::Go => 14, - TrictracAction::Move { move1, move2 } => { + TrictracAction::Mark => 1, + TrictracAction::Go => 2, + TrictracAction::Move { + dice_order, + from1, + from2, + } => { // Encoder les mouvements dans l'espace d'actions - // Indices 15+ pour les mouvements - 15 + encode_move_pair(*move1, *move2) + // Indices 3+ pour les mouvements + let mut start = 3; + if !dice_order { + // 25 * 25 = 625 + start += 625; + } + start + from1 * 25 + from2 } } } @@ -136,51 +48,62 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1..=13 => Some(TrictracAction::Mark { - points: (index - 1) as u8, - }), - 14 => Some(TrictracAction::Go), - i if i >= 15 => { - let move_code = i - 15; - let (move1, move2) = decode_move_pair(move_code); - Some(TrictracAction::Move { move1, move2 }) + 1 => Some(TrictracAction::Mark), + 2 => Some(TrictracAction::Go), + i if i >= 3 => { + let move_code = i - 3; + let (dice_order, from1, from2) = Self::decode_move(move_code); + Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) } _ => None, } } + /// Décode un entier en paire de mouvements + fn decode_move(code: usize) -> (bool, usize, usize) { + let mut encoded = code; + let dice_order = code < 626; + if !dice_order { + encoded -= 625 + } + let from1 = encoded / 25; + let from2 = encoded % 25; + (dice_order, from1, from2) + } + /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles - // Pour les mouvements : 25*25*25*25 = 390625 (position 0-24 pour chaque from/to) + // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 15 + (24 * 24 * 24 * 24) // = 331791 + 3 + (2 * 25 * 25) // = 1253 } -} -/// Encode une paire de mouvements en un seul entier -fn encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize { - let (from1, to1) = move1; - let (from2, to2) = move2; - // Assurer que les positions sont dans la plage 0-24 - let from1 = from1.min(24); - let to1 = to1.min(24); - let from2 = from2.min(24); - let to2 = to2.min(24); - - from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2 -} - -/// Décode un entier en paire de mouvements -fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) { - let from1 = code / (25 * 25 * 25); - let remainder = code % (25 * 25 * 25); - let to1 = remainder / (25 * 25); - let remainder = remainder % (25 * 25); - let from2 = remainder / 25; - let to2 = remainder % 25; - - ((from1, to1), (from2, to2)) + // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { + // match action { + // TrictracAction::Roll => Some(GameEvent::Roll { player_id }), + // TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), + // TrictracAction::Go => Some(GameEvent::Go { player_id }), + // TrictracAction::Move { + // dice_order, + // from1, + // from2, + // } => { + // // Effectuer un mouvement + // let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); + // let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + // + // Some(GameEvent::Move { + // player_id: self.agent_player_id, + // moves: (checker_move1, checker_move2), + // }) + // } + // }; + // } } /// Configuration pour l'agent DQN @@ -350,17 +273,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Calculer les points possibles - if let Some(player) = game_state.players.get(&active_player_id) { - let dice_roll_count = player.dice_roll_count; - let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); - let (max_points, _) = points_rules.get_points(dice_roll_count); - - // Permettre de marquer entre 0 et max_points - for points in 0..=max_points { - valid_actions.push(TrictracAction::Mark { points }); - } - } + valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); @@ -370,9 +283,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let possible_moves = rules.get_possible_moves_sequences(true, vec![]); for (move1, move2) in possible_moves { + let diff_move1 = move1.get_to() - move1.get_from(); valid_actions.push(TrictracAction::Move { - move1: (move1.get_from(), move1.get_to()), - move2: (move2.get_from(), move2.get_to()), + dice_order: diff_move1 == game_state.dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), }); } } @@ -381,9 +296,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let possible_moves = rules.get_possible_moves_sequences(true, vec![]); for (move1, move2) in possible_moves { + let diff_move1 = move1.get_to() - move1.get_from(); valid_actions.push(TrictracAction::Move { - move1: (move1.get_from(), move1.get_to()), - move2: (move2.get_from(), move2.get_to()), + dice_order: diff_move1 == game_state.dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), }); } } @@ -393,92 +310,6 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions } -/// Génère les actions compactes valides selon l'état du jeu -pub fn get_valid_compact_actions(game_state: &crate::GameState) -> Vec { - use crate::PointsRules; - use store::TurnStage; - - let mut valid_actions = Vec::new(); - - let active_player_id = game_state.active_player_id; - let player_color = game_state.player_color_by_id(&active_player_id); - - if let Some(color) = player_color { - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => { - valid_actions.push(CompactAction::Roll); - } - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // Calculer les points possibles - if let Some(player) = game_state.players.get(&active_player_id) { - let dice_roll_count = player.dice_roll_count; - let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); - let (max_points, _) = points_rules.get_points(dice_roll_count); - - // Permettre de marquer entre 0 et max_points - for points in 0..=max_points { - valid_actions.push(CompactAction::Mark { points }); - } - } - } - TurnStage::HoldOrGoChoice => { - valid_actions.push(CompactAction::Go); - - // Ajouter les choix de mouvements compacts - add_compact_move_actions(game_state, &color, &mut valid_actions); - } - TurnStage::Move => { - // Seulement les mouvements compacts - add_compact_move_actions(game_state, &color, &mut valid_actions); - } - } - } - - valid_actions -} - -/// Ajoute les actions de mouvement compactes basées sur le contexte -fn add_compact_move_actions(game_state: &crate::GameState, color: &store::Color, valid_actions: &mut Vec) { - let rules = store::MoveRules::new(color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Extraire les positions de départ uniques des mouvements possibles - let mut valid_from_positions = std::collections::HashSet::new(); - for (move1, move2) in &possible_moves { - valid_from_positions.insert(move1.get_from()); - valid_from_positions.insert(move2.get_from()); - } - - let dice = game_state.dice; - let dice_orders = if dice.values.0 != dice.values.1 { vec![true, false] } else { vec![true] }; - - // Générer les combinaisons compactes valides - for dice_order in dice_orders { - for &from1 in &valid_from_positions { - for &from2 in &valid_from_positions { - // Vérifier si cette combinaison produit un mouvement valide - let compact_action = CompactAction::MoveChoice { - dice_order, - from1, - from2 - }; - - if let Some(trictrac_action) = compact_action.to_trictrac_action(game_state) { - // Vérifier si ce mouvement est dans la liste des mouvements possibles - if let TrictracAction::Move { move1, move2 } = trictrac_action { - if let (Ok(checker_move1), Ok(checker_move2)) = - (store::CheckerMove::new(move1.0, move1.1), store::CheckerMove::new(move2.0, move2.1)) { - if possible_moves.contains(&(checker_move1, checker_move2)) { - valid_actions.push(compact_action); - } - } - } - } - } - } - } -} - /// Retourne les indices des actions valides pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { get_valid_actions(game_state) diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 67c3e39..2b935f5 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -266,7 +266,11 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Move { move1, move2 } => { + TrictracAction::Move { + dice_order, + from1, + from2, + } => { // Effectuer un mouvement let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); diff --git a/doc/workflow.md b/doc/workflow.md new file mode 100644 index 0000000..2139332 --- /dev/null +++ b/doc/workflow.md @@ -0,0 +1,25 @@ +# Workflow + +@startuml + +state c <> +state haswon <> +state MarkPoints #lightblue +state MarkAdvPoints #lightblue +note right of MarkPoints : automatic 'Mark' transition\nwhen no school +note right of MarkAdvPoints : automatic 'Mark' transition\nwhen no school + +[*] -> RollDice : BeginGame +RollDice --> RollWaiting : Roll (current player) +RollWaiting --> MarkPoints : RollResult (engine) +MarkPoints --> c : Mark (current player) +c --> HoldHorGoChoice : [new hole] +c --> [*] : [has won] +c --> Move : [not new hole] +HoldHorGoChoice --> RollDice : Go +HoldHorGoChoice --> MarkAdvPoints : Move +Move --> MarkAdvPoints : Move +MarkAdvPoints --> haswon : Mark (adversary) +haswon --> RollDice : [has not won] +haswon --> [*] : [has won] +@enduml diff --git a/store/src/game.rs b/store/src/game.rs index 477895f..ed77519 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -71,7 +71,7 @@ pub struct GameState { /// last dice pair rolled pub dice: Dice, /// players points computed for the last dice pair rolled - dice_points: (u8, u8), + pub dice_points: (u8, u8), pub dice_moves: (CheckerMove, CheckerMove), pub dice_jans: PossibleJans, /// true if player needs to roll first @@ -505,13 +505,7 @@ impl GameState { self.players.remove(player_id); } Roll { player_id: _ } => { - // Opponent has moved, we can mark pending points earned during opponent's turn - let new_hole = self.mark_points(self.active_player_id, self.dice_points.1); - if new_hole && self.get_active_player().unwrap().holes > 12 { - self.stage = Stage::Ended; - } else { - self.turn_stage = TurnStage::RollWaiting; - } + self.turn_stage = TurnStage::RollWaiting; } RollResult { player_id: _, dice } => { self.dice = *dice; @@ -534,23 +528,25 @@ impl GameState { } } Mark { player_id, points } => { - let new_hole = self.mark_points(*player_id, *points); - if new_hole { - if self.get_active_player().unwrap().holes > 12 { - self.stage = Stage::Ended; + if self.schools_enabled { + let new_hole = self.mark_points(*player_id, *points); + if new_hole { + if self.get_active_player().unwrap().holes > 12 { + self.stage = Stage::Ended; + } else { + self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { + TurnStage::RollDice + } else { + TurnStage::HoldOrGoChoice + }; + } } else { self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { TurnStage::RollDice } else { - TurnStage::HoldOrGoChoice + TurnStage::Move }; } - } else { - self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints { - TurnStage::RollDice - } else { - TurnStage::Move - }; } } Go { player_id: _ } => self.new_pick_up(), @@ -563,6 +559,11 @@ impl GameState { self.turn_stage = if self.schools_enabled { TurnStage::MarkAdvPoints } else { + // The player has moved, we can mark its opponent's points (which is now the current player) + let new_hole = self.mark_points(self.active_player_id, self.dice_points.1); + if new_hole && self.get_active_player().unwrap().holes > 12 { + self.stage = Stage::Ended; + } TurnStage::RollDice }; } From dc197fbc6f62749d4c8a28ba7016016463244cb5 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 11 Jun 2025 17:31:35 +0200 Subject: [PATCH 5/5] dqn trainer --- bot/src/strategy/dqn_common.rs | 104 ++++++++++++++++++++++++-------- bot/src/strategy/dqn_trainer.rs | 43 +++++++++---- store/src/game.rs | 4 +- 3 files changed, 111 insertions(+), 40 deletions(-) diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index 3191b4b..9a24ae6 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,4 +1,4 @@ -use std::cmp::max; +use std::cmp::{max, min}; use serde::{Deserialize, Serialize}; use store::{CheckerMove, Dice, GameEvent, PlayerId}; @@ -8,8 +8,6 @@ use store::{CheckerMove, Dice, GameEvent, PlayerId}; pub enum TrictracAction { /// Lancer les dés Roll, - /// Marquer les points - Mark, /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions @@ -18,6 +16,8 @@ pub enum TrictracAction { from1: usize, // position de départ du premier pion (0-24) from2: usize, // position de départ du deuxième pion (0-24) }, + // Marquer les points : à activer si support des écoles + // Mark, } impl TrictracAction { @@ -25,22 +25,22 @@ impl TrictracAction { pub fn to_action_index(&self) -> usize { match self { TrictracAction::Roll => 0, - TrictracAction::Mark => 1, - TrictracAction::Go => 2, + TrictracAction::Go => 1, TrictracAction::Move { dice_order, from1, from2, } => { // Encoder les mouvements dans l'espace d'actions - // Indices 3+ pour les mouvements - let mut start = 3; + // Indices 2+ pour les mouvements + // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) + let mut start = 2; if !dice_order { // 25 * 25 = 625 start += 625; } start + from1 * 25 + from2 - } + } // TrictracAction::Mark => 1252, } } @@ -48,8 +48,8 @@ impl TrictracAction { pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1 => Some(TrictracAction::Mark), - 2 => Some(TrictracAction::Go), + // 1252 => Some(TrictracAction::Mark), + 1 => Some(TrictracAction::Go), i if i >= 3 => { let move_code = i - 3; let (dice_order, from1, from2) = Self::decode_move(move_code); @@ -77,10 +77,10 @@ impl TrictracAction { /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Mark) + 1 (Go) + mouvements possibles + // 1 (Roll) + 1 (Go) + mouvements possibles // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) // Mais on peut optimiser en limitant aux positions valides (1-24) - 3 + (2 * 25 * 25) // = 1253 + 2 + (2 * 25 * 25) // = 1252 } // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { @@ -273,35 +273,37 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions.push(TrictracAction::Roll); } TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - valid_actions.push(TrictracAction::Mark); + // valid_actions.push(TrictracAction::Mark); } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); - // Ajouter aussi les mouvements possibles + // Ajoute aussi les mouvements possibles let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + // Modififier checker_moves_to_trictrac_action si on doit gérer Black + assert_eq!(color, store::Color::White); for (move1, move2) in possible_moves { - let diff_move1 = move1.get_to() - move1.get_from(); - valid_actions.push(TrictracAction::Move { - dice_order: diff_move1 == game_state.dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - }); + valid_actions.push(checker_moves_to_trictrac_action( + &move1, + &move2, + &game_state.dice, + )); } } TurnStage::Move => { let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + // Modififier checker_moves_to_trictrac_action si on doit gérer Black + assert_eq!(color, store::Color::White); for (move1, move2) in possible_moves { - let diff_move1 = move1.get_to() - move1.get_from(); - valid_actions.push(TrictracAction::Move { - dice_order: diff_move1 == game_state.dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - }); + valid_actions.push(checker_moves_to_trictrac_action( + &move1, + &move2, + &game_state.dice, + )); } } } @@ -310,6 +312,56 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { valid_actions } +// Valid only for White player +fn checker_moves_to_trictrac_action( + move1: &CheckerMove, + move2: &CheckerMove, + dice: &Dice, +) -> TrictracAction { + let to1 = move1.get_to(); + let to2 = move2.get_to(); + let from1 = move1.get_from(); + let from2 = move2.get_from(); + + let mut diff_move1 = if to1 > 0 { + // Mouvement sans sortie + to1 - from1 + } else { + // sortie, on utilise la valeur du dé + if to2 > 0 { + // sortie pour le mouvement 1 uniquement + let dice2 = to2 - from2; + if dice2 == dice.values.0 as usize { + dice.values.1 as usize + } else { + dice.values.0 as usize + } + } else { + // double sortie + if from1 < from2 { + max(dice.values.0, dice.values.1) as usize + } else { + min(dice.values.0, dice.values.1) as usize + } + } + }; + + // modification de diff_move1 si on est dans le cas d'un mouvement par puissance + let rest_field = 12; + if to1 == rest_field + && to2 == rest_field + && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field + { + // prise par puissance + diff_move1 += 1; + } + TrictracAction::Move { + dice_order: diff_move1 == dice.values.0 as usize, + from1: move1.get_from(), + from2: move2.get_from(), + } +} + /// Retourne les indices des actions valides pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { get_valid_actions(game_state) diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 2b935f5..8d9db57 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -1,4 +1,4 @@ -use crate::{Color, GameState, PlayerId}; +use crate::{CheckerMove, Color, GameState, PlayerId}; use rand::prelude::SliceRandom; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; @@ -251,14 +251,15 @@ impl TrictracEnv { player_id: self.agent_player_id, }) } - TrictracAction::Mark { points } => { - // Marquer des points - reward += 0.1 * points as f32; - Some(GameEvent::Mark { - player_id: self.agent_player_id, - points, - }) - } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game_state. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.agent_player_id, + // points, + // }) + // } TrictracAction::Go => { // Continuer après avoir gagné un trou reward += 0.2; @@ -272,8 +273,23 @@ impl TrictracEnv { from2, } => { // Effectuer un mouvement - let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + let (dice1, dice2) = if dice_order { + (self.game_state.dice.values.0, self.game_state.dice.values.1) + } else { + (self.game_state.dice.values.1, self.game_state.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); reward += 0.2; Some(GameEvent::Move { @@ -360,7 +376,9 @@ impl TrictracEnv { // Stratégie simple : choix aléatoire let mut rng = thread_rng(); - let choosen_move = *possible_moves.choose(&mut rng).unwrap(); + let choosen_move = *possible_moves + .choose(&mut rng) + .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); GameEvent::Move { player_id: self.opponent_player_id, @@ -443,7 +461,6 @@ impl DqnTrainer { for episode in 1..=episodes { let reward = self.train_episode(); - print!("."); if episode % 100 == 0 { println!( "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}", diff --git a/store/src/game.rs b/store/src/game.rs index ed77519..fe2762f 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -757,6 +757,7 @@ mod tests { #[test] fn hold_or_go() { let mut game_state = init_test_gamestate(TurnStage::MarkPoints); + game_state.schools_enabled = true; let pid = game_state.active_player_id; game_state.consume( &(GameEvent::Mark { @@ -782,6 +783,7 @@ mod tests { // Hold let mut game_state = init_test_gamestate(TurnStage::MarkPoints); + game_state.schools_enabled = true; let pid = game_state.active_player_id; game_state.consume( &(GameEvent::Mark { @@ -802,6 +804,6 @@ mod tests { assert_ne!(game_state.active_player_id, pid); assert_eq!(game_state.players.get(&pid).unwrap().points, 1); assert_eq!(game_state.get_active_player().unwrap().points, 0); - assert_eq!(game_state.turn_stage, TurnStage::RollDice); + assert_eq!(game_state.turn_stage, TurnStage::MarkAdvPoints); } }