diff --git a/bot/src/strategy.rs b/bot/src/strategy.rs index d3d04ab..5c36e04 100644 --- a/bot/src/strategy.rs +++ b/bot/src/strategy.rs @@ -1,3 +1,4 @@ +pub mod burn_environment; pub mod client; pub mod default; pub mod dqn; diff --git a/bot/src/strategy/burn_environment.rs b/bot/src/strategy/burn_environment.rs index aa103df..a9f58ba 100644 --- a/bot/src/strategy/burn_environment.rs +++ b/bot/src/strategy/burn_environment.rs @@ -1,13 +1,12 @@ -use burn::{backend::Backend, tensor::Tensor}; +use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; -use crate::GameState; -use store::{Color, Game, PlayerId}; -use std::collections::HashMap; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; /// État du jeu Trictrac pour burn-rl #[derive(Debug, Clone, Copy)] pub struct TrictracState { - pub data: [f32; 36], // Représentation vectorielle de l'état du jeu + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu } impl State for TrictracState { @@ -26,14 +25,14 @@ impl TrictracState { /// Convertit un GameState en TrictracState pub fn from_game_state(game_state: &GameState) -> Self { let state_vec = game_state.to_vec(); - let mut data = [0.0f32; 36]; - + let mut data = [0; 36]; + // Copier les données en s'assurant qu'on ne dépasse pas la taille let copy_len = state_vec.len().min(36); for i in 0..copy_len { data[i] = state_vec[i]; } - + TrictracState { data } } } @@ -81,8 +80,8 @@ impl From for u32 { /// Environnement Trictrac pour burn-rl #[derive(Debug)] pub struct TrictracEnvironment { - game: Game, - active_player_id: PlayerId, + game: GameState, + active_player_id: PlayerId, opponent_id: PlayerId, current_state: TrictracState, episode_reward: f32, @@ -98,17 +97,15 @@ impl Environment for TrictracEnvironment { const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies fn new(visualized: bool) -> Self { - let mut game = Game::new(); - + let mut game = GameState::new(false); + // Ajouter deux joueurs - let player1_id = game.add_player("DQN Agent".to_string(), Color::White); - let player2_id = game.add_player("Opponent".to_string(), Color::Black); - - game.start(); - - let game_state = game.get_state(); - let current_state = TrictracState::from_game_state(&game_state); - + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + let current_state = TrictracState::from_game_state(&game); TrictracEnvironment { game, active_player_id: player1_id, @@ -126,36 +123,28 @@ impl Environment for TrictracEnvironment { fn reset(&mut self) -> Snapshot { // Réinitialiser le jeu - self.game = Game::new(); - self.active_player_id = self.game.add_player("DQN Agent".to_string(), Color::White); - self.opponent_id = self.game.add_player("Opponent".to_string(), Color::Black); - self.game.start(); - - let game_state = self.game.get_state(); - self.current_state = TrictracState::from_game_state(&game_state); + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + self.current_state = TrictracState::from_game_state(&self.game); self.episode_reward = 0.0; self.step_count = 0; - Snapshot { - state: self.current_state, - reward: 0.0, - terminated: false, - } + Snapshot::new(self.current_state, 0.0, false) } fn step(&mut self, action: Self::ActionType) -> Snapshot { self.step_count += 1; - - let game_state = self.game.get_state(); - + // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = self.convert_action(action, &game_state); - + let trictrac_action = self.convert_action(action, &self.game); + let mut reward = 0.0; let mut terminated = false; - + // Exécuter l'action si c'est le tour de l'agent DQN - if game_state.active_player_id == self.active_player_id { + if self.game.active_player_id == self.active_player_id { if let Some(action) = trictrac_action { match self.execute_action(action) { Ok(action_reward) => { @@ -171,102 +160,226 @@ impl Environment for TrictracEnvironment { reward = -0.5; } } - + // Jouer l'adversaire si c'est son tour - self.play_opponent_if_needed(); - - // Vérifier fin de partie - let updated_state = self.game.get_state(); - if updated_state.is_finished() || self.step_count >= Self::MAX_STEPS { + reward += self.play_opponent_if_needed(); + + // Vérifier si la partie est terminée + let done = self.game.stage == Stage::Ended + || self.game.determine_winner().is_some() + || self.step_count >= Self::MAX_STEPS; + + if done { terminated = true; - // Récompense finale basée sur le résultat - if let Some(winner_id) = updated_state.winner { + if let Some(winner_id) = self.game.determine_winner() { if winner_id == self.active_player_id { - reward += 10.0; // Victoire + reward += 100.0; // Victoire } else { - reward -= 10.0; // Défaite + reward -= 50.0; // Défaite } } } - + // Mettre à jour l'état - self.current_state = TrictracState::from_game_state(&updated_state); + self.current_state = TrictracState::from_game_state(&self.game); self.episode_reward += reward; - + if self.visualized && terminated { - println!("Episode terminé. Récompense totale: {:.2}, Étapes: {}", - self.episode_reward, self.step_count); + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); } - Snapshot { - state: self.current_state, - reward, - terminated, - } + Snapshot::new(self.current_state, reward, terminated) } } impl TrictracEnvironment { /// Convertit une action burn-rl vers une action Trictrac - fn convert_action(&self, action: TrictracAction, game_state: &GameState) -> Option { - use super::dqn_common::{get_valid_compact_actions, CompactAction}; - + fn convert_action( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use super::dqn_common::get_valid_actions; + // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_compact_actions(game_state); - + let valid_actions = get_valid_actions(game_state); + if valid_actions.is_empty() { return None; } - + // Mapper l'index d'action sur une action valide let action_index = (action.index as usize) % valid_actions.len(); - let compact_action = &valid_actions[action_index]; - - // Convertir l'action compacte vers une action Trictrac complète - compact_action.to_trictrac_action(game_state) + Some(valid_actions[action_index].clone()) } - + /// Exécute une action Trictrac dans le jeu - fn execute_action(&mut self, action: super::dqn_common::TrictracAction) -> Result> { + fn execute_action( + &mut self, + action: super::dqn_common::TrictracAction, + ) -> Result> { use super::dqn_common::TrictracAction; - + let mut reward = 0.0; - - match action { + + let event = match action { TrictracAction::Roll => { - self.game.roll_dice_for_player(&self.active_player_id)?; - reward = 0.1; // Petite récompense pour une action valide - } - TrictracAction::Mark { points } => { - self.game.mark_points_for_player(&self.active_player_id, points)?; - reward = points as f32 * 0.1; // Récompense proportionnelle aux points + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } TrictracAction::Go => { - self.game.go_for_player(&self.active_player_id)?; - reward = 0.2; // Récompense pour continuer + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) } - TrictracAction::Move { move1, move2 } => { - let checker_move1 = store::CheckerMove::new(move1.0, move1.1)?; - let checker_move2 = store::CheckerMove::new(move2.0, move2.1)?; - self.game.move_checker_for_player(&self.active_player_id, checker_move1, checker_move2)?; - reward = 0.3; // Récompense pour un mouvement réussi + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + } + } + } else { + // Pénalité pour action invalide + reward -= 2.0; } } - + Ok(reward) } - + /// Fait jouer l'adversaire avec une stratégie simple - fn play_opponent_if_needed(&mut self) { - let game_state = self.game.get_state(); - + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + // Si c'est le tour de l'adversaire, jouer automatiquement - if game_state.active_player_id == self.opponent_id && !game_state.is_finished() { - // Utiliser une stratégie simple pour l'adversaire (dummy bot) - if let Ok(_) = crate::strategy::dummy::get_dummy_action(&mut self.game, &self.opponent_id) { - // L'action a été exécutée par get_dummy_action + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use super::default::DefaultStrategy; + use crate::BotStrategy; + + let mut default_strategy = DefaultStrategy::default(); + default_strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + default_strategy.set_color(color); + } + *default_strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkAdvPoints | TurnStage::MarkPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).0; + reward -= 0.3 * points as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => { + let (move1, move2) = default_strategy.choose_move(); + GameEvent::Move { + player_id: self.opponent_id, + moves: (move1.mirror(), move2.mirror()), + } + } + }; + + if self.game.validate(&event) { + self.game.consume(&event); } } + reward } -} \ No newline at end of file +} + diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs deleted file mode 100644 index cc690dd..0000000 --- a/bot/src/strategy/mod.rs +++ /dev/null @@ -1,47 +0,0 @@ -pub mod burn_environment; -pub mod client; -pub mod default; -pub mod dqn; -pub mod dqn_common; -pub mod dqn_trainer; -pub mod erroneous_moves; -pub mod stable_baselines3; - -pub mod dummy { - use store::{Color, Game, PlayerId}; - - /// Action simple pour l'adversaire dummy - pub fn get_dummy_action(game: &mut Game, player_id: &PlayerId) -> Result<(), Box> { - let game_state = game.get_state(); - - match game_state.turn_stage { - store::TurnStage::RollDice => { - game.roll_dice_for_player(player_id)?; - } - store::TurnStage::MarkPoints | store::TurnStage::MarkAdvPoints => { - // Marquer 0 points (stratégie conservatrice) - game.mark_points_for_player(player_id, 0)?; - } - store::TurnStage::HoldOrGoChoice => { - // Toujours choisir "Go" (stratégie simple) - game.go_for_player(player_id)?; - } - store::TurnStage::Move => { - // Utiliser la logique de mouvement par défaut - use super::default::DefaultStrategy; - use crate::BotStrategy; - - let mut default_strategy = DefaultStrategy::default(); - default_strategy.set_player_id(*player_id); - default_strategy.set_color(game_state.player_color_by_id(player_id).unwrap_or(Color::White)); - *default_strategy.get_mut_game() = game_state.clone(); - - let (move1, move2) = default_strategy.choose_move(); - game.move_checker_for_player(player_id, move1, move2)?; - } - _ => {} - } - - Ok(()) - } -} \ No newline at end of file diff --git a/devenv.nix b/devenv.nix index c37b4ab..d41dbe8 100644 --- a/devenv.nix +++ b/devenv.nix @@ -4,7 +4,9 @@ packages = [ - # pour burn-rs (compilation sdl2-sys) + # pour burn-rs + pkgs.SDL2_gfx + # (compilation sdl2-sys) pkgs.cmake pkgs.libffi pkgs.wayland-scanner