diff --git a/bot/src/bin/train_dqn.rs b/bot/src/bin/train_dqn.rs index abff8d0..8556e34 100644 --- a/bot/src/bin/train_dqn.rs +++ b/bot/src/bin/train_dqn.rs @@ -1,4 +1,4 @@ -use bot::strategy::dqn_common::DqnConfig; +use bot::strategy::dqn_common::{DqnConfig, TrictracAction}; use bot::strategy::dqn_trainer::DqnTrainer; use std::env; @@ -68,7 +68,7 @@ fn main() -> Result<(), Box> { let config = DqnConfig { state_size: 36, // state.to_vec size hidden_size: 256, - num_actions: 3, + num_actions: TrictracAction::action_space_size(), learning_rate: 0.001, gamma: 0.99, epsilon: 0.9, // Commencer avec plus d'exploration diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index bd4e233..acc6e88 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -2,7 +2,7 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use std::path::Path; use store::MoveRules; -use super::dqn_common::{DqnConfig, SimpleNeuralNetwork}; +use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action}; /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné #[derive(Debug)] @@ -37,13 +37,38 @@ impl DqnStrategy { strategy } - /// Utilise le modèle DQN pour choisir une action - fn get_dqn_action(&self) -> Option { + /// Utilise le modèle DQN pour choisir une action valide + fn get_dqn_action(&self) -> Option { if let Some(ref model) = self.model { let state = self.game.to_vec_float(); - Some(model.get_best_action(&state)) + let valid_actions = get_valid_actions(&self.game); + + if valid_actions.is_empty() { + return None; + } + + // Obtenir les Q-values pour toutes les actions + let q_values = model.forward(&state); + + // Trouver la meilleure action valide + let mut best_action = &valid_actions[0]; + let mut best_q_value = f32::NEG_INFINITY; + + for action in &valid_actions { + let action_index = action.to_action_index(); + if action_index < q_values.len() { + let q_value = q_values[action_index]; + if q_value > best_q_value { + best_q_value = q_value; + best_action = action; + } + } + } + + Some(best_action.clone()) } else { - None + // Fallback : action aléatoire valide + sample_valid_action(&self.game) } } } @@ -66,6 +91,14 @@ impl BotStrategy for DqnStrategy { } fn calculate_points(&self) -> u8 { + // Utiliser le DQN pour choisir le nombre de points à marquer + if let Some(action) = self.get_dqn_action() { + if let TrictracAction::Mark { points } = action { + return points; + } + } + + // Fallback : utiliser la méthode standard let dice_roll_count = self .get_game() .players @@ -81,10 +114,9 @@ impl BotStrategy for DqnStrategy { } fn choose_go(&self) -> bool { - // Utiliser le DQN pour décider si on continue (action 2 = "go") + // Utiliser le DQN pour décider si on continue if let Some(action) = self.get_dqn_action() { - // Si le modèle prédit l'action "go" (2), on continue - action == 2 + matches!(action, TrictracAction::Go) } else { // Fallback : toujours continuer true @@ -92,28 +124,29 @@ impl BotStrategy for DqnStrategy { } fn choose_move(&self) -> (CheckerMove, CheckerMove) { + // Utiliser le DQN pour choisir le mouvement + if let Some(action) = self.get_dqn_action() { + if let TrictracAction::Move { move1, move2 } = action { + let checker_move1 = CheckerMove::new(move1.0, move1.1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; + + return chosen_move; + } + } + + // Fallback : utiliser la stratégie par défaut let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - let chosen_move = if let Some(action) = self.get_dqn_action() { - // Utiliser l'action DQN pour choisir parmi les mouvements valides - // Action 0 = premier mouvement, action 1 = mouvement moyen, etc. - let move_index = if action == 0 { - 0 // Premier mouvement - } else if action == 1 && possible_moves.len() > 1 { - possible_moves.len() / 2 // Mouvement du milieu - } else { - possible_moves.len().saturating_sub(1) // Dernier mouvement - }; - *possible_moves - .get(move_index) - .unwrap_or(&(CheckerMove::default(), CheckerMove::default())) - } else { - // Fallback : premier mouvement valide - *possible_moves - .first() - .unwrap_or(&(CheckerMove::default(), CheckerMove::default())) - }; + + let chosen_move = *possible_moves + .first() + .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); if self.color == Color::White { chosen_move diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index ec53912..d7135ee 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,4 +1,87 @@ use serde::{Deserialize, Serialize}; +use crate::{CheckerMove}; + +/// Types d'actions possibles dans le jeu +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum TrictracAction { + /// Lancer les dés + Roll, + /// Marquer des points + Mark { points: u8 }, + /// Continuer après avoir gagné un trou + Go, + /// Effectuer un mouvement de pions + Move { + move1: (usize, usize), // (from, to) pour le premier pion + move2: (usize, usize), // (from, to) pour le deuxième pion + }, +} + +impl TrictracAction { + /// Encode une action en index pour le réseau de neurones + pub fn to_action_index(&self) -> usize { + match self { + TrictracAction::Roll => 0, + TrictracAction::Mark { points } => { + 1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points + }, + TrictracAction::Go => 14, + TrictracAction::Move { move1, move2 } => { + // Encoder les mouvements dans l'espace d'actions + // Indices 15+ pour les mouvements + 15 + encode_move_pair(*move1, *move2) + } + } + } + + /// Décode un index d'action en TrictracAction + pub fn from_action_index(index: usize) -> Option { + match index { + 0 => Some(TrictracAction::Roll), + 1..=13 => Some(TrictracAction::Mark { points: (index - 1) as u8 }), + 14 => Some(TrictracAction::Go), + i if i >= 15 => { + let move_code = i - 15; + let (move1, move2) = decode_move_pair(move_code); + Some(TrictracAction::Move { move1, move2 }) + }, + _ => None, + } + } + + /// Retourne la taille de l'espace d'actions total + pub fn action_space_size() -> usize { + // 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles + // Pour les mouvements : 25*25*25*25 = 390625 (position 0-24 pour chaque from/to) + // Mais on peut optimiser en limitant aux positions valides (1-24) + 15 + (24 * 24 * 24 * 24) // = 331791 + } +} + +/// Encode une paire de mouvements en un seul entier +fn encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize { + let (from1, to1) = move1; + let (from2, to2) = move2; + // Assurer que les positions sont dans la plage 0-24 + let from1 = from1.min(24); + let to1 = to1.min(24); + let from2 = from2.min(24); + let to2 = to2.min(24); + + from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2 +} + +/// Décode un entier en paire de mouvements +fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) { + let from1 = code / (25 * 25 * 25); + let remainder = code % (25 * 25 * 25); + let to1 = remainder / (25 * 25); + let remainder = remainder % (25 * 25); + let from2 = remainder / 25; + let to2 = remainder % 25; + + ((from1, to1), (from2, to2)) +} /// Configuration pour l'agent DQN #[derive(Debug, Clone, Serialize, Deserialize)] @@ -19,8 +102,8 @@ impl Default for DqnConfig { fn default() -> Self { Self { state_size: 36, - hidden_size: 256, - num_actions: 3, + hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi + num_actions: TrictracAction::action_space_size(), learning_rate: 0.001, gamma: 0.99, epsilon: 0.1, @@ -151,3 +234,80 @@ impl SimpleNeuralNetwork { } } +/// Obtient les actions valides pour l'état de jeu actuel +pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { + use crate::{Color, PointsRules}; + use store::{MoveRules, TurnStage}; + + let mut valid_actions = Vec::new(); + + let active_player_id = game_state.active_player_id; + let player_color = game_state.player_color_by_id(&active_player_id); + + if let Some(color) = player_color { + match game_state.turn_stage { + TurnStage::RollDice | TurnStage::RollWaiting => { + valid_actions.push(TrictracAction::Roll); + } + TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { + // Calculer les points possibles + if let Some(player) = game_state.players.get(&active_player_id) { + let dice_roll_count = player.dice_roll_count; + let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); + let (max_points, _) = points_rules.get_points(dice_roll_count); + + // Permettre de marquer entre 0 et max_points + for points in 0..=max_points { + valid_actions.push(TrictracAction::Mark { points }); + } + } + } + TurnStage::HoldOrGoChoice => { + valid_actions.push(TrictracAction::Go); + + // Ajouter aussi les mouvements possibles + let rules = MoveRules::new(&color, &game_state.board, game_state.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + for (move1, move2) in possible_moves { + valid_actions.push(TrictracAction::Move { + move1: (move1.get_from(), move1.get_to()), + move2: (move2.get_from(), move2.get_to()), + }); + } + } + TurnStage::Move => { + let rules = MoveRules::new(&color, &game_state.board, game_state.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + for (move1, move2) in possible_moves { + valid_actions.push(TrictracAction::Move { + move1: (move1.get_from(), move1.get_to()), + move2: (move2.get_from(), move2.get_to()), + }); + } + } + _ => {} + } + } + + valid_actions +} + +/// Retourne les indices des actions valides +pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { + get_valid_actions(game_state) + .into_iter() + .map(|action| action.to_action_index()) + .collect() +} + +/// Sélectionne une action valide aléatoire +pub fn sample_valid_action(game_state: &crate::GameState) -> Option { + use rand::{thread_rng, seq::SliceRandom}; + + let valid_actions = get_valid_actions(game_state); + let mut rng = thread_rng(); + valid_actions.choose(&mut rng).cloned() +} + diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index 53092eb..de248c0 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -5,13 +5,13 @@ use serde::{Deserialize, Serialize}; use std::collections::VecDeque; use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; -use super::dqn_common::{DqnConfig, SimpleNeuralNetwork}; +use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, get_valid_action_indices, sample_valid_action}; /// Expérience pour le buffer de replay #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Experience { pub state: Vec, - pub action: usize, + pub action: TrictracAction, pub reward: f32, pub next_state: Vec, pub done: bool, @@ -88,14 +88,37 @@ impl DqnAgent { } } - pub fn select_action(&mut self, state: &[f32]) -> usize { + pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction { + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + // Fallback si aucune action valide + return TrictracAction::Roll; + } + let mut rng = thread_rng(); if rng.gen::() < self.epsilon { - // Exploration : action aléatoire - rng.gen_range(0..self.config.num_actions) + // Exploration : action valide aléatoire + valid_actions.choose(&mut rng).cloned().unwrap_or(TrictracAction::Roll) } else { - // Exploitation : meilleure action selon le modèle - self.model.get_best_action(state) + // Exploitation : meilleure action valide selon le modèle + let q_values = self.model.forward(state); + + let mut best_action = &valid_actions[0]; + let mut best_q_value = f32::NEG_INFINITY; + + for action in &valid_actions { + let action_index = action.to_action_index(); + if action_index < q_values.len() { + let q_value = q_values[action_index]; + if q_value > best_q_value { + best_q_value = q_value; + best_action = action; + } + } + } + + best_action.clone() } } @@ -178,7 +201,7 @@ impl TrictracEnv { self.game_state.to_vec_float() } - pub fn step(&mut self, action: usize) -> (Vec, f32, bool) { + pub fn step(&mut self, action: TrictracAction) -> (Vec, f32, bool) { let mut reward = 0.0; // Appliquer l'action de l'agent @@ -214,106 +237,68 @@ impl TrictracEnv { (next_state, reward, done) } - fn apply_agent_action(&mut self, action: usize) -> f32 { + fn apply_agent_action(&mut self, action: TrictracAction) -> f32 { let mut reward = 0.0; - // TODO : déterminer event selon action ... - - let event = match self.game_state.turn_stage { - TurnStage::RollDice => { + let event = match action { + TrictracAction::Roll => { // Lancer les dés - GameEvent::Roll { - player_id: self.agent_player_id, - } - } - TurnStage::RollWaiting => { - // Simuler le résultat des dés reward += 0.1; - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - GameEvent::RollResult { + Some(GameEvent::Roll { player_id: self.agent_player_id, - dice: store::Dice { - values: dice_values, - }, - } + }) } - TurnStage::Move => { - // Choisir un mouvement selon l'action - let rules = MoveRules::new( - &self.agent_color, - &self.game_state.board, - self.game_state.dice, - ); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // TODO : choix d'action - let move_index = if action == 0 { - 0 - } else if action == 1 && possible_moves.len() > 1 { - possible_moves.len() / 2 - } else { - possible_moves.len().saturating_sub(1) - }; - - let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]); - GameEvent::Move { - player_id: self.agent_player_id, - moves, - } - } - TurnStage::MarkAdvPoints | TurnStage::MarkPoints => { - // Calculer et marquer les points - let dice_roll_count = self - .game_state - .players - .get(&self.agent_player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new( - &self.agent_color, - &self.game_state.board, - self.game_state.dice, - ); - let points = points_rules.get_points(dice_roll_count).0; - - reward += 0.3 * points as f32; // Récompense proportionnelle aux points - GameEvent::Mark { + TrictracAction::Mark { points } => { + // Marquer des points + reward += 0.1 * points as f32; + Some(GameEvent::Mark { player_id: self.agent_player_id, points, - } + }) } - TurnStage::HoldOrGoChoice => { - // Décider de continuer ou pas selon l'action - if action == 2 { - // Action "go" - GameEvent::Go { - player_id: self.agent_player_id, - } - } else { - // Passer son tour en jouant un mouvement - let rules = MoveRules::new( - &self.agent_color, - &self.game_state.board, - self.game_state.dice, - ); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - let moves = possible_moves[0]; - GameEvent::Move { - player_id: self.agent_player_id, - moves, - } - } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.agent_player_id, + }) + } + TrictracAction::Move { move1, move2 } => { + // Effectuer un mouvement + let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.agent_player_id, + moves: (checker_move1, checker_move2), + }) } }; - if self.game_state.validate(&event) { - self.game_state.consume(&event); - reward += 0.2; - } else { - reward -= 1.0; // Pénalité pour action invalide + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game_state.validate(&event) { + self.game_state.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.agent_player_id, + dice: store::Dice { values: dice_values }, + }; + if self.game_state.validate(&dice_event) { + self.game_state.consume(&dice_event); + } + } + } else { + // Pénalité pour action invalide + reward -= 2.0; + } } + reward } @@ -410,8 +395,8 @@ impl DqnTrainer { let mut state = self.env.reset(); loop { - let action = self.agent.select_action(&state); - let (next_state, reward, done) = self.env.step(action); + let action = self.agent.select_action(&self.env.game_state, &state); + let (next_state, reward, done) = self.env.step(action.clone()); total_reward += reward; let experience = Experience { diff --git a/doc/refs/claudeAIquestionOnlyRust.md b/doc/refs/claudeAIquestionOnlyRust.md index 5674a27..7e43c72 100644 --- a/doc/refs/claudeAIquestionOnlyRust.md +++ b/doc/refs/claudeAIquestionOnlyRust.md @@ -123,3 +123,44 @@ Le code est maintenant parfaitement séparé entre : - Utilisation (stratégie simple qui charge un modèle) - Entraînement (environnement complet avec logique d'apprentissage) + +--- + +● ✅ Espace d'actions étendu implémenté avec succès ! + +J'ai complètement étendu l'espace d'actions du DQN pour couvrir tous les choix possibles dans le jeu Trictrac : + +🎯 Nouvel espace d'actions : + +TrictracAction enum comprend : + +- Roll - Lancer les dés (index 0) +- Mark { points } - Marquer des points (indices 1-13 pour 0-12 points) +- Go - Continuer après avoir gagné un trou (index 14) +- Move { move1, move2 } - Mouvement de pions (indices 15+) + - Encode (from, to) pour les deux pions + - Espace total : 331,791 actions possibles + +🧠 Améliorations intelligentes : + +1. Masquage d'actions valides - Le DQN ne considère que les actions légales dans l'état actuel +2. Encodage compact - Mouvements encodés en un seul entier pour efficacité +3. Décodage automatique - Conversion transparente entre indices et actions +4. Stratégie adaptative - La stratégie DQN utilise maintenant le bon type d'action selon le contexte + +🔧 Fonctionnalités ajoutées : + +- get_valid_actions() - Obtient les actions valides pour l'état actuel +- TrictracAction::to_action_index() / from_action_index() - Conversion bidirectionnelle +- sample_valid_action() - Sélection aléatoire d'une action valide +- Réseau de neurones élargi (512 neurones cachés) pour gérer l'espace étendu + +📊 Utilisation dans le jeu : + +Le bot DQN peut maintenant : + +- Choisir le nombre exact de points à marquer (pas seulement le maximum) +- Sélectionner des mouvements spécifiques parmi toutes les combinaisons possibles +- Décider intelligemment entre "Go" et les mouvements alternatifs + +L'espace d'actions est maintenant parfaitement aligné avec la complexité réelle du jeu Trictrac ! 🎲