From ebe98ca229f9b17f29f11c473c8fa73454c5aa6a Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 1 Jun 2025 20:21:38 +0200 Subject: [PATCH] debug --- bot/src/strategy/dqn.rs | 2 +- bot/src/strategy/dqn_common.rs | 51 +++++++++++++------------- bot/src/strategy/dqn_trainer.rs | 31 +++++++++++----- bot/src/strategy/stable_baselines3.rs | 53 +++++++++++++-------------- store/src/game.rs | 2 +- 5 files changed, 73 insertions(+), 66 deletions(-) diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index acc6e88..d2fc9ed 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -2,7 +2,7 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; use std::path::Path; use store::MoveRules; -use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action}; +use super::dqn_common::{SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action}; /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné #[derive(Debug)] diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs index d7135ee..2390da4 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/strategy/dqn_common.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use crate::{CheckerMove}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -11,9 +10,9 @@ pub enum TrictracAction { /// Continuer après avoir gagné un trou Go, /// Effectuer un mouvement de pions - Move { - move1: (usize, usize), // (from, to) pour le premier pion - move2: (usize, usize), // (from, to) pour le deuxième pion + Move { + move1: (usize, usize), // (from, to) pour le premier pion + move2: (usize, usize), // (from, to) pour le deuxième pion }, } @@ -23,8 +22,8 @@ impl TrictracAction { match self { TrictracAction::Roll => 0, TrictracAction::Mark { points } => { - 1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points - }, + 1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points + } TrictracAction::Go => 14, TrictracAction::Move { move1, move2 } => { // Encoder les mouvements dans l'espace d'actions @@ -33,22 +32,24 @@ impl TrictracAction { } } } - + /// Décode un index d'action en TrictracAction pub fn from_action_index(index: usize) -> Option { match index { 0 => Some(TrictracAction::Roll), - 1..=13 => Some(TrictracAction::Mark { points: (index - 1) as u8 }), + 1..=13 => Some(TrictracAction::Mark { + points: (index - 1) as u8, + }), 14 => Some(TrictracAction::Go), i if i >= 15 => { let move_code = i - 15; let (move1, move2) = decode_move_pair(move_code); Some(TrictracAction::Move { move1, move2 }) - }, + } _ => None, } } - + /// Retourne la taille de l'espace d'actions total pub fn action_space_size() -> usize { // 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles @@ -67,7 +68,7 @@ fn encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize { let to1 = to1.min(24); let from2 = from2.min(24); let to2 = to2.min(24); - + from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2 } @@ -79,7 +80,7 @@ fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) { let remainder = remainder % (25 * 25); let from2 = remainder / 25; let to2 = remainder % 25; - + ((from1, to1), (from2, to2)) } @@ -102,7 +103,7 @@ impl Default for DqnConfig { fn default() -> Self { Self { state_size: 36, - hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi + hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi num_actions: TrictracAction::action_space_size(), learning_rate: 0.001, gamma: 0.99, @@ -236,14 +237,14 @@ impl SimpleNeuralNetwork { /// Obtient les actions valides pour l'état de jeu actuel pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - use crate::{Color, PointsRules}; + use crate::PointsRules; use store::{MoveRules, TurnStage}; - + let mut valid_actions = Vec::new(); - + let active_player_id = game_state.active_player_id; let player_color = game_state.player_color_by_id(&active_player_id); - + if let Some(color) = player_color { match game_state.turn_stage { TurnStage::RollDice | TurnStage::RollWaiting => { @@ -255,7 +256,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { let dice_roll_count = player.dice_roll_count; let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice); let (max_points, _) = points_rules.get_points(dice_roll_count); - + // Permettre de marquer entre 0 et max_points for points in 0..=max_points { valid_actions.push(TrictracAction::Mark { points }); @@ -264,11 +265,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { } TurnStage::HoldOrGoChoice => { valid_actions.push(TrictracAction::Go); - + // Ajouter aussi les mouvements possibles let rules = MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - + for (move1, move2) in possible_moves { valid_actions.push(TrictracAction::Move { move1: (move1.get_from(), move1.get_to()), @@ -279,7 +280,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { TurnStage::Move => { let rules = MoveRules::new(&color, &game_state.board, game_state.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - + for (move1, move2) in possible_moves { valid_actions.push(TrictracAction::Move { move1: (move1.get_from(), move1.get_to()), @@ -287,10 +288,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { }); } } - _ => {} } } - + valid_actions } @@ -304,10 +304,9 @@ pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { /// Sélectionne une action valide aléatoire pub fn sample_valid_action(game_state: &crate::GameState) -> Option { - use rand::{thread_rng, seq::SliceRandom}; - + use rand::{seq::SliceRandom, thread_rng}; + let valid_actions = get_valid_actions(game_state); let mut rng = thread_rng(); valid_actions.choose(&mut rng).cloned() } - diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs index de248c0..67c3e39 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/strategy/dqn_trainer.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use std::collections::VecDeque; use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; -use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, get_valid_action_indices, sample_valid_action}; +use super::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction}; /// Expérience pour le buffer de replay #[derive(Debug, Clone, Serialize, Deserialize)] @@ -90,23 +90,26 @@ impl DqnAgent { pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction { let valid_actions = get_valid_actions(game_state); - + if valid_actions.is_empty() { // Fallback si aucune action valide return TrictracAction::Roll; } - + let mut rng = thread_rng(); if rng.gen::() < self.epsilon { // Exploration : action valide aléatoire - valid_actions.choose(&mut rng).cloned().unwrap_or(TrictracAction::Roll) + valid_actions + .choose(&mut rng) + .cloned() + .unwrap_or(TrictracAction::Roll) } else { // Exploitation : meilleure action valide selon le modèle let q_values = self.model.forward(state); - + let mut best_action = &valid_actions[0]; let mut best_q_value = f32::NEG_INFINITY; - + for action in &valid_actions { let action_index = action.to_action_index(); if action_index < q_values.len() { @@ -117,7 +120,7 @@ impl DqnAgent { } } } - + best_action.clone() } } @@ -267,7 +270,7 @@ impl TrictracEnv { // Effectuer un mouvement let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - + reward += 0.2; Some(GameEvent::Move { player_id: self.agent_player_id, @@ -280,14 +283,16 @@ impl TrictracEnv { if let Some(event) = event { if self.game_state.validate(&event) { self.game_state.consume(&event); - + // Simuler le résultat des dés après un Roll if matches!(action, TrictracAction::Roll) { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); let dice_event = GameEvent::RollResult { player_id: self.agent_player_id, - dice: store::Dice { values: dice_values }, + dice: store::Dice { + values: dice_values, + }, }; if self.game_state.validate(&dice_event) { self.game_state.consume(&dice_event); @@ -393,8 +398,10 @@ impl DqnTrainer { pub fn train_episode(&mut self) -> f32 { let mut total_reward = 0.0; let mut state = self.env.reset(); + // let mut step_count = 0; loop { + // step_count += 1; let action = self.agent.select_action(&self.env.game_state, &state); let (next_state, reward, done) = self.env.step(action.clone()); total_reward += reward; @@ -412,6 +419,9 @@ impl DqnTrainer { if done { break; } + // if step_count % 100 == 0 { + // println!("{:?}", next_state); + // } state = next_state; } @@ -429,6 +439,7 @@ impl DqnTrainer { for episode in 1..=episodes { let reward = self.train_episode(); + print!("."); if episode % 100 == 0 { println!( "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}", diff --git a/bot/src/strategy/stable_baselines3.rs b/bot/src/strategy/stable_baselines3.rs index 124e95d..4b94311 100644 --- a/bot/src/strategy/stable_baselines3.rs +++ b/bot/src/strategy/stable_baselines3.rs @@ -1,11 +1,11 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; -use store::MoveRules; -use std::process::Command; -use std::io::Write; +use serde::{Deserialize, Serialize}; use std::fs::File; use std::io::Read; +use std::io::Write; use std::path::Path; -use serde::{Serialize, Deserialize}; +use std::process::Command; +use store::MoveRules; #[derive(Debug)] pub struct StableBaselines3Strategy { @@ -62,21 +62,21 @@ impl StableBaselines3Strategy { fn get_state_as_json(&self) -> GameStateJson { // Convertir l'état du jeu en un format compatible avec notre modèle Python let mut board = vec![0; 24]; - + // Remplir les positions des pièces blanches (valeurs positives) for (pos, count) in self.game.board.get_color_fields(Color::White) { if pos < 24 { board[pos] = count as i8; } } - + // Remplir les positions des pièces noires (valeurs négatives) for (pos, count) in self.game.board.get_color_fields(Color::Black) { if pos < 24 { board[pos] = -(count as i8); } } - + // Convertir l'étape du tour en entier let turn_stage = match self.game.turn_stage { store::TurnStage::RollDice => 0, @@ -85,15 +85,14 @@ impl StableBaselines3Strategy { store::TurnStage::HoldOrGoChoice => 3, store::TurnStage::Move => 4, store::TurnStage::MarkAdvPoints => 5, - _ => 0, }; - + // Récupérer les points et trous des joueurs let white_points = self.game.players.get(&1).map_or(0, |p| p.points); let white_holes = self.game.players.get(&1).map_or(0, |p| p.holes); let black_points = self.game.players.get(&2).map_or(0, |p| p.points); let black_holes = self.game.players.get(&2).map_or(0, |p| p.holes); - + // Créer l'objet JSON GameStateJson { board, @@ -111,12 +110,12 @@ impl StableBaselines3Strategy { // Convertir l'état du jeu en JSON let state_json = self.get_state_as_json(); let state_str = serde_json::to_string(&state_json).unwrap(); - + // Écrire l'état dans un fichier temporaire let temp_input_path = "temp_state.json"; let mut file = File::create(temp_input_path).ok()?; file.write_all(state_str.as_bytes()).ok()?; - + // Exécuter le script Python pour faire une prédiction let output_path = "temp_action.json"; let python_script = format!( @@ -164,32 +163,29 @@ with open("{}", "w") as f: "#, self.model_path, output_path ); - + let temp_script_path = "temp_predict.py"; let mut script_file = File::create(temp_script_path).ok()?; script_file.write_all(python_script.as_bytes()).ok()?; - + // Exécuter le script Python - let status = Command::new("python") - .arg(temp_script_path) - .status() - .ok()?; - + let status = Command::new("python").arg(temp_script_path).status().ok()?; + if !status.success() { return None; } - + // Lire la prédiction if Path::new(output_path).exists() { let mut file = File::open(output_path).ok()?; let mut contents = String::new(); file.read_to_string(&mut contents).ok()?; - + // Nettoyer les fichiers temporaires std::fs::remove_file(temp_input_path).ok(); std::fs::remove_file(temp_script_path).ok(); std::fs::remove_file(output_path).ok(); - + // Analyser la prédiction let action: ActionJson = serde_json::from_str(&contents).ok()?; Some(action) @@ -203,7 +199,7 @@ impl BotStrategy for StableBaselines3Strategy { fn get_game(&self) -> &GameState { &self.game } - + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } @@ -224,7 +220,7 @@ impl BotStrategy for StableBaselines3Strategy { return self.game.dice.values.0 + self.game.dice.values.1; } } - + // Fallback vers la méthode standard si la prédiction échoue let dice_roll_count = self .get_game() @@ -245,7 +241,7 @@ impl BotStrategy for StableBaselines3Strategy { if let Some(action) = self.predict_action() { return action.action_type == 2; } - + // Fallback vers la méthode standard si la prédiction échoue true } @@ -259,18 +255,19 @@ impl BotStrategy for StableBaselines3Strategy { return (move1, move2); } } - + // Fallback vers la méthode standard si la prédiction échoue let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let possible_moves = rules.get_possible_moves_sequences(true, vec![]); let choosen_move = *possible_moves .first() .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); - + if self.color == Color::White { choosen_move } else { (choosen_move.0.mirror(), choosen_move.1.mirror()) } } -} \ No newline at end of file +} + diff --git a/store/src/game.rs b/store/src/game.rs index 1ef8a39..477895f 100644 --- a/store/src/game.rs +++ b/store/src/game.rs @@ -174,7 +174,7 @@ impl GameState { state.push(self.dice.values.0 as i8); state.push(self.dice.values.1 as i8); - // points length=4 x2 joueurs = 8 + // points, trous, bredouille, grande bredouille length=4 x2 joueurs = 8 let white_player: Vec = self .get_white_player() .unwrap()