From 5b133cfe0a58c0c310f1325854b5376ada3a9fd4 Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Sun, 22 Jun 2025 15:42:55 +0200
Subject: [PATCH] claude (compile fails)

---
 bot/Cargo.toml                       |   2 +
 bot/src/strategy/burn_environment.rs | 272 +++++++++++++++++++++++++++
 bot/src/strategy/mod.rs              |  47 +++++
 3 files changed, 321 insertions(+)
 create mode 100644 bot/src/strategy/burn_environment.rs
 create mode 100644 bot/src/strategy/mod.rs

diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 64a6d76..878f90f 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -16,3 +16,5 @@ serde_json = "1.0"
 store = { path = "../store" }
 rand = "0.8"
 env_logger = "0.10"
+burn = { version = "0.17", features = ["ndarray", "autodiff"] }
+burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
diff --git a/bot/src/strategy/burn_environment.rs b/bot/src/strategy/burn_environment.rs
new file mode 100644
index 0000000..aa103df
--- /dev/null
+++ b/bot/src/strategy/burn_environment.rs
@@ -0,0 +1,272 @@
+use burn::{backend::Backend, tensor::Tensor};
+use burn_rl::base::{Action, Environment, Snapshot, State};
+use crate::GameState;
+use store::{Color, Game, PlayerId};
+use std::collections::HashMap;
+
+/// Trictrac game state for burn-rl
+#[derive(Debug, Clone, Copy)]
+pub struct TrictracState {
+    pub data: [f32; 36], // Vector representation of the game state
+}
+
+impl State for TrictracState {
+    type Data = [f32; 36];
+
+    fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
+        Tensor::from_floats(self.data, &B::Device::default())
+    }
+
+    fn size() -> usize {
+        36
+    }
+}
+
+impl TrictracState {
+    /// Converts a GameState into a TrictracState
+    pub fn from_game_state(game_state: &GameState) -> Self {
+        let state_vec = game_state.to_vec();
+        let mut data = [0.0f32; 36];
+
+        // Copy the data, making sure we never overrun the fixed-size buffer
+        let copy_len = state_vec.len().min(36);
+        for i in 0..copy_len {
+            data[i] = state_vec[i];
+        }
+
+        TrictracState { data }
+    }
+}
+
+/// Possible Trictrac actions for burn-rl
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct TrictracAction {
+    pub index: u32,
+}
+
+impl Action for TrictracAction {
+    fn random() -> Self {
+        use rand::{thread_rng, Rng};
+        let mut rng = thread_rng();
+        TrictracAction {
+            index: rng.gen_range(0..Self::size() as u32),
+        }
+    }
+
+    fn enumerate() -> Vec<Self> {
+        (0..Self::size() as u32)
+            .map(|index| TrictracAction { index })
+            .collect()
+    }
+
+    fn size() -> usize {
+        // Use the compact action space to reduce complexity
+        // Estimated maximum based on the contextual actions
+        1000 // Conservative estimate, adjusted dynamically
+    }
+}
+
+impl From<u32> for TrictracAction {
+    fn from(index: u32) -> Self {
+        TrictracAction { index }
+    }
+}
+
+impl From<TrictracAction> for u32 {
+    fn from(action: TrictracAction) -> u32 {
+        action.index
+    }
+}
+
+/// Trictrac environment for burn-rl
+#[derive(Debug)]
+pub struct TrictracEnvironment {
+    game: Game,
+    active_player_id: PlayerId,
+    opponent_id: PlayerId,
+    current_state: TrictracState,
+    episode_reward: f32,
+    step_count: usize,
+    visualized: bool,
+}
+
+impl Environment for TrictracEnvironment {
+    type StateType = TrictracState;
+    type ActionType = TrictracAction;
+    type RewardType = f32;
+
+    const MAX_STEPS: usize = 1000; // Hard cap to avoid endless games
+
+    fn new(visualized: bool) -> Self {
+        let mut game = Game::new();
+
+        // Add the two players
+        let player1_id = game.add_player("DQN Agent".to_string(), Color::White);
+        let player2_id = game.add_player("Opponent".to_string(), Color::Black);
+
+        game.start();
+
+        let game_state = game.get_state();
+        let current_state = TrictracState::from_game_state(&game_state);
+
+        TrictracEnvironment {
+            game,
+            active_player_id: player1_id,
+            opponent_id: player2_id,
+            current_state,
+            episode_reward: 0.0,
+            step_count: 0,
+            visualized,
+        }
+    }
+
+    fn state(&self) -> Self::StateType {
+        self.current_state
+    }
+
+    fn reset(&mut self) -> Snapshot<Self::StateType> {
+        // Reset the game
+        self.game = Game::new();
+        self.active_player_id = self.game.add_player("DQN Agent".to_string(), Color::White);
+        self.opponent_id = self.game.add_player("Opponent".to_string(), Color::Black);
+        self.game.start();
+
+        let game_state = self.game.get_state();
+        self.current_state = TrictracState::from_game_state(&game_state);
+        self.episode_reward = 0.0;
+        self.step_count = 0;
+
+        Snapshot {
+            state: self.current_state,
+            reward: 0.0,
+            terminated: false,
+        }
+    }
+
+    fn step(&mut self, action: Self::ActionType) -> Snapshot<Self::StateType> {
+        self.step_count += 1;
+
+        let game_state = self.game.get_state();
+
+        // Convert the burn-rl action into a Trictrac action
+        let trictrac_action = self.convert_action(action, &game_state);
+
+        let mut reward = 0.0;
+        let mut terminated = false;
+
+        // Execute the action if it is the DQN agent's turn
+        if game_state.active_player_id == self.active_player_id {
+            if let Some(action) = trictrac_action {
+                match self.execute_action(action) {
+                    Ok(action_reward) => {
+                        reward = action_reward;
+                    }
+                    Err(_) => {
+                        // Invalid action, penalty
+                        reward = -1.0;
+                    }
+                }
+            } else {
+                // Unconvertible action, penalty
+                reward = -0.5;
+            }
+        }
+
+        // Let the opponent play if it is their turn
+        self.play_opponent_if_needed();
+
+        // Check for end of game
+        let updated_state = self.game.get_state();
+        if updated_state.is_finished() || self.step_count >= Self::MAX_STEPS {
+            terminated = true;
+
+            // Final reward based on the outcome
+            if let Some(winner_id) = updated_state.winner {
+                if winner_id == self.active_player_id {
+                    reward += 10.0; // Victory
+                } else {
+                    reward -= 10.0; // Defeat
+                }
+            }
+        }
+
+        // Update the cached state
+        self.current_state = TrictracState::from_game_state(&updated_state);
+        self.episode_reward += reward;
+
+        if self.visualized && terminated {
+            println!("Episode terminé. Récompense totale: {:.2}, Étapes: {}",
+                self.episode_reward, self.step_count);
+        }
+
+        Snapshot {
+            state: self.current_state,
+            reward,
+            terminated,
+        }
+    }
+}
+
+impl TrictracEnvironment {
+    /// Converts a burn-rl action into a Trictrac action
+    fn convert_action(&self, action: TrictracAction, game_state: &GameState) -> Option<super::dqn_common::TrictracAction> {
+        use super::dqn_common::{get_valid_compact_actions, CompactAction};
+
+        // Get the actions that are valid in the current context
+        let valid_actions = get_valid_compact_actions(game_state);
+
+        if valid_actions.is_empty() {
+            return None;
+        }
+
+        // Map the action index onto a valid action
+        let action_index = (action.index as usize) % valid_actions.len();
+        let compact_action = &valid_actions[action_index];
+
+        // Convert the compact action into a full Trictrac action
+        compact_action.to_trictrac_action(game_state)
+    }
+
+    /// Executes a Trictrac action in the game
+    fn execute_action(&mut self, action: super::dqn_common::TrictracAction) -> Result<f32, Box<dyn std::error::Error>> {
+        use super::dqn_common::TrictracAction;
+
+        let mut reward = 0.0;
+
+        match action {
+            TrictracAction::Roll => {
+                self.game.roll_dice_for_player(&self.active_player_id)?;
+                reward = 0.1; // Small reward for a valid action
+            }
+            TrictracAction::Mark { points } => {
+                self.game.mark_points_for_player(&self.active_player_id, points)?;
+                reward = points as f32 * 0.1; // Reward proportional to the points
+            }
+            TrictracAction::Go => {
+                self.game.go_for_player(&self.active_player_id)?;
+                reward = 0.2; // Reward for continuing
+            }
+            TrictracAction::Move { move1, move2 } => {
+                let checker_move1 = store::CheckerMove::new(move1.0, move1.1)?;
+                let checker_move2 = store::CheckerMove::new(move2.0, move2.1)?;
+                self.game.move_checker_for_player(&self.active_player_id, checker_move1, checker_move2)?;
+                reward = 0.3; // Reward for a successful move
+            }
+        }
+
+        Ok(reward)
+    }
+
+    /// Makes the opponent play with a simple strategy
+    fn play_opponent_if_needed(&mut self) {
+        let game_state = self.game.get_state();
+
+        // If it is the opponent's turn, play automatically
+        if game_state.active_player_id == self.opponent_id && !game_state.is_finished() {
+            // Use a simple strategy for the opponent (dummy bot)
+            if let Ok(_) = crate::strategy::dummy::get_dummy_action(&mut self.game, &self.opponent_id) {
+                // The action was executed by get_dummy_action
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs
new file mode 100644
index 0000000..cc690dd
--- /dev/null
+++ b/bot/src/strategy/mod.rs
@@ -0,0 +1,47 @@
+pub mod burn_environment;
+pub mod client;
+pub mod default;
+pub mod dqn;
+pub mod dqn_common;
+pub mod dqn_trainer;
+pub mod erroneous_moves;
+pub mod stable_baselines3;
+
+pub mod dummy {
+    use store::{Color, Game, PlayerId};
+
+    /// Simple action policy for the dummy opponent
+    pub fn get_dummy_action(game: &mut Game, player_id: &PlayerId) -> Result<(), Box<dyn std::error::Error>> {
+        let game_state = game.get_state();
+
+        match game_state.turn_stage {
+            store::TurnStage::RollDice => {
+                game.roll_dice_for_player(player_id)?;
+            }
+            store::TurnStage::MarkPoints | store::TurnStage::MarkAdvPoints => {
+                // Mark 0 points (conservative strategy)
+                game.mark_points_for_player(player_id, 0)?;
+            }
+            store::TurnStage::HoldOrGoChoice => {
+                // Always choose "Go" (simple strategy)
+                game.go_for_player(player_id)?;
+            }
+            store::TurnStage::Move => {
+                // Use the default movement logic
+                use super::default::DefaultStrategy;
+                use crate::BotStrategy;
+
+                let mut default_strategy = DefaultStrategy::default();
+                default_strategy.set_player_id(*player_id);
+                default_strategy.set_color(game_state.player_color_by_id(player_id).unwrap_or(Color::White));
+                *default_strategy.get_mut_game() = game_state.clone();
+
+                let (move1, move2) = default_strategy.choose_move();
+                game.move_checker_for_player(player_id, move1, move2)?;
+            }
+            _ => {}
+        }
+
+        Ok(())
+    }
+}
\ No newline at end of file