use burn::{prelude::*, tensor::Tensor}; use crate::GameState; use store::{Color, PlayerId}; /// Trait pour les actions dans l'environnement pub trait Action: std::fmt::Debug + Clone + Copy { fn random() -> Self; fn enumerate() -> Vec; fn size() -> usize; } /// Trait pour les états dans l'environnement pub trait State: std::fmt::Debug + Clone + Copy { type Data; fn to_tensor(&self) -> Tensor; fn size() -> usize; } /// Snapshot d'un step dans l'environnement #[derive(Debug, Clone)] pub struct Snapshot { pub state: E::StateType, pub reward: E::RewardType, pub terminated: bool, } /// Trait pour l'environnement pub trait Environment: std::fmt::Debug { type StateType: State; type ActionType: Action; type RewardType: std::fmt::Debug + Clone; const MAX_STEPS: usize = usize::MAX; fn new(visualized: bool) -> Self; fn state(&self) -> Self::StateType; fn reset(&mut self) -> Snapshot; fn step(&mut self, action: Self::ActionType) -> Snapshot; } /// État du jeu Trictrac pour burn-rl #[derive(Debug, Clone, Copy)] pub struct TrictracState { pub data: [f32; 36], // Représentation vectorielle de l'état du jeu } impl State for TrictracState { type Data = [f32; 36]; fn to_tensor(&self) -> Tensor { Tensor::from_floats(self.data, &B::Device::default()) } fn size() -> usize { 36 } } impl TrictracState { /// Convertit un GameState en TrictracState pub fn from_game_state(game_state: &GameState) -> Self { let state_vec = game_state.to_vec(); let mut data = [0.0f32; 36]; // Copier les données en s'assurant qu'on ne dépasse pas la taille let copy_len = state_vec.len().min(36); for i in 0..copy_len { data[i] = state_vec[i] as f32; } TrictracState { data } } } /// Actions possibles dans Trictrac pour burn-rl #[derive(Debug, Clone, Copy, PartialEq)] pub struct TrictracAction { pub index: u32, } impl Action for TrictracAction { fn random() -> Self { use rand::{thread_rng, Rng}; let mut rng = thread_rng(); TrictracAction { index: rng.gen_range(0..Self::size() as u32), } } fn enumerate() -> Vec { (0..Self::size() as u32) .map(|index| TrictracAction { index }) .collect() } fn size() -> usize { // Utiliser l'espace d'actions compactes pour réduire la complexité // Maximum estimé basé sur les actions contextuelles 1000 // Estimation conservative, sera ajusté dynamiquement } } impl From for TrictracAction { fn from(index: u32) -> Self { TrictracAction { index } } } impl From for u32 { fn from(action: TrictracAction) -> u32 { action.index } } /// Environnement Trictrac pour burn-rl #[derive(Debug)] pub struct TrictracEnvironment { game_state: store::GameState, active_player_id: PlayerId, opponent_id: PlayerId, current_state: TrictracState, episode_reward: f32, step_count: usize, visualized: bool, } impl Environment for TrictracEnvironment { type StateType = TrictracState; type ActionType = TrictracAction; type RewardType = f32; const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies fn new(visualized: bool) -> Self { let mut game_state = store::GameState::new(false); // Pas d'écoles pour l'instant // Ajouter deux joueurs let player1_id = game_state.init_player("DQN Agent").unwrap(); let player2_id = game_state.init_player("Opponent").unwrap(); // Commencer le jeu game_state.stage = store::Stage::InGame; game_state.active_player_id = player1_id; let current_state = TrictracState::from_game_state(&game_state); TrictracEnvironment { game_state, active_player_id: player1_id, opponent_id: player2_id, current_state, episode_reward: 0.0, step_count: 0, visualized, } } fn state(&self) -> Self::StateType { self.current_state } fn reset(&mut self) -> Snapshot { // Réinitialiser le jeu self.game_state = store::GameState::new(false); self.active_player_id = self.game_state.init_player("DQN Agent").unwrap(); self.opponent_id = self.game_state.init_player("Opponent").unwrap(); self.game_state.stage = store::Stage::InGame; self.game_state.active_player_id = self.active_player_id; self.current_state = TrictracState::from_game_state(&self.game_state); self.episode_reward = 0.0; self.step_count = 0; Snapshot { state: self.current_state, reward: 0.0, terminated: false, } } fn step(&mut self, action: Self::ActionType) -> Snapshot { self.step_count += 1; // Convertir l'action burn-rl vers une action Trictrac let trictrac_action = self.convert_action(action, &self.game_state); let mut reward = 0.0; let mut terminated = false; // Simplification pour le moment - juste donner une récompense aléatoire reward = if trictrac_action.is_some() { 0.1 } else { -0.1 }; // Vérifier fin de partie (simplifiée) if self.step_count >= Self::MAX_STEPS { terminated = true; } // Mettre à jour l'état (simplifiée) self.current_state = TrictracState::from_game_state(&self.game_state); self.episode_reward += reward; if self.visualized && terminated { println!("Episode terminé. Récompense totale: {:.2}, Étapes: {}", self.episode_reward, self.step_count); } Snapshot { state: self.current_state, reward, terminated, } } } impl TrictracEnvironment { /// Convertit une action burn-rl vers une action Trictrac fn convert_action(&self, action: TrictracAction, game_state: &GameState) -> Option { use super::dqn_common::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(game_state); if valid_actions.is_empty() { return None; } // Mapper l'index d'action sur une action valide let action_index = (action.index as usize) % valid_actions.len(); Some(valid_actions[action_index].clone()) } /// Exécute une action Trictrac dans le jeu fn execute_action(&mut self, action: super::dqn_common::TrictracAction) -> Result> { use super::dqn_common::TrictracAction; let mut reward = 0.0; match action { TrictracAction::Roll => { self.game.roll_dice_for_player(&self.active_player_id)?; reward = 0.1; // Petite récompense pour une action valide } TrictracAction::Go => { self.game.go_for_player(&self.active_player_id)?; reward = 0.2; // Récompense pour continuer } TrictracAction::Move { dice_order, from1, from2 } => { // Convertir les positions compactes en mouvements réels let game_state = self.game.get_state(); let dice = game_state.dice; let (die1, die2) = if dice_order { (dice.values.0, dice.values.1) } else { (dice.values.1, dice.values.0) }; // Calculer les destinations selon la couleur du joueur let player_color = game_state.player_color_by_id(&self.active_player_id).unwrap_or(Color::White); let to1 = if player_color == Color::White { from1 + die1 as usize } else { from1.saturating_sub(die1 as usize) }; let to2 = if player_color == Color::White { from2 + die2 as usize } else { from2.saturating_sub(die2 as usize) }; let checker_move1 = store::CheckerMove::new(from1, to1)?; let checker_move2 = store::CheckerMove::new(from2, to2)?; self.game.move_checker_for_player(&self.active_player_id, checker_move1, checker_move2)?; reward = 0.3; // Récompense pour un mouvement réussi } } Ok(reward) } /// Fait jouer l'adversaire avec une stratégie simple fn play_opponent_if_needed(&mut self) { let game_state = self.game.get_state(); // Si c'est le tour de l'adversaire, jouer automatiquement if game_state.active_player_id == self.opponent_id && !game_state.is_finished() { // Utiliser la stratégie default pour l'adversaire use super::default::DefaultStrategy; use crate::BotStrategy; let mut default_strategy = DefaultStrategy::default(); default_strategy.set_player_id(self.opponent_id); if let Some(color) = game_state.player_color_by_id(&self.opponent_id) { default_strategy.set_color(color); } *default_strategy.get_mut_game() = game_state.clone(); // Exécuter l'action selon le turn_stage match game_state.turn_stage { store::TurnStage::RollDice => { let _ = self.game.roll_dice_for_player(&self.opponent_id); } store::TurnStage::MarkPoints | store::TurnStage::MarkAdvPoints => { let points = if game_state.turn_stage == store::TurnStage::MarkPoints { default_strategy.calculate_points() } else { default_strategy.calculate_adv_points() }; let _ = self.game.mark_points_for_player(&self.opponent_id, points); } store::TurnStage::HoldOrGoChoice => { if default_strategy.choose_go() { let _ = self.game.go_for_player(&self.opponent_id); } else { let (move1, move2) = default_strategy.choose_move(); let _ = self.game.move_checker_for_player(&self.opponent_id, move1, move2); } } store::TurnStage::Move => { let (move1, move2) = default_strategy.choose_move(); let _ = self.game.move_checker_for_player(&self.opponent_id, move1, move2); } _ => {} } } } }