From bcc4b977c481d0c74212ff043cde11b73008a914 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 14 Aug 2025 15:53:59 +0200 Subject: [PATCH] wip not yet --- bot/scripts/train.sh | 4 +- bot/src/dqn/burnrl_before/environment.rs | 19 +- bot/src/dqn/burnrl_before/utils.rs | 2 +- bot/src/dqn/burnrl_big/environment.rs | 30 +- bot/src/dqn/burnrl_big/environmentDiverge.rs | 459 +++++++++++++++++++ bot/src/dqn/dqn_common_before.rs | 255 ----------- bot/src/dqn/mod.rs | 1 - 7 files changed, 482 insertions(+), 288 deletions(-) create mode 100644 bot/src/dqn/burnrl_big/environmentDiverge.rs delete mode 100644 bot/src/dqn/dqn_common_before.rs diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 9da60a0..a3be831 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -13,14 +13,14 @@ PLOT_EXT="png" train() { cargo build --release --bin=$BINBOT - NAME="train_$(date +%Y-%m-%d_%H:%M:%S)" + NAME=$BINBOT"_$(date +%Y-%m-%d_%H:%M:%S)" LOGS="$LOGS_DIR/$NAME.out" mkdir -p "$LOGS_DIR" LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" | tee "$LOGS" } plot() { - NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | grep $BINBOT | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs index 6ce01c9..9925a9a 100644 --- a/bot/src/dqn/burnrl_before/environment.rs +++ b/bot/src/dqn/burnrl_before/environment.rs @@ -1,4 +1,4 @@ -use crate::dqn::dqn_common_before; +use crate::dqn::dqn_common_big; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; @@ -227,8 +227,8 @@ impl TrictracEnvironment { const REWARD_RATIO: f32 = 1.0; /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_before::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac @@ -236,8 +236,8 @@ impl TrictracEnvironment { &self, action: TrictracAction, game_state: &GameState, - ) -> Option { - use dqn_common_before::get_valid_actions; + ) -> Option { + use dqn_common_big::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(game_state); @@ -254,10 +254,10 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: dqn_common_before::TrictracAction, + // action:dqn_common_big::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_before::TrictracAction) -> (f32, bool) { - use dqn_common_before::TrictracAction; + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; @@ -387,6 +387,7 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); let opponent_color = store::Color::Black; let dice_roll_count = self .game @@ -397,7 +398,7 @@ impl TrictracEnvironment { let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points GameEvent::Mark { player_id: self.opponent_id, diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs index e6b4330..6c25c5d 100644 --- a/bot/src/dqn/burnrl_before/utils.rs +++ b/bot/src/dqn/burnrl_before/utils.rs @@ -2,7 +2,7 @@ use crate::dqn::burnrl_before::{ dqn_model, environment::{TrictracAction, TrictracEnvironment}, }; -use crate::dqn::dqn_common_before::get_valid_action_indices; +use crate::dqn::dqn_common_big::get_valid_action_indices; use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs index 6706163..9925a9a 100644 --- a/bot/src/dqn/burnrl_big/environment.rs +++ b/bot/src/dqn/burnrl_big/environment.rs @@ -165,7 +165,8 @@ impl Environment for TrictracEnvironment { let trictrac_action = Self::convert_action(action); let mut reward = 0.0; - let is_rollpoint; + let mut is_rollpoint = false; + let mut terminated = false; // Exécuter l'action si c'est le tour de l'agent DQN if self.game.active_player_id == self.active_player_id { @@ -253,7 +254,7 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: dqn_common_big::TrictracAction, + // action:dqn_common_big::TrictracAction, // ) -> Result> { fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { use dqn_common_big::TrictracAction; @@ -371,8 +372,6 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage - let mut calculate_points = false; - let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -380,7 +379,6 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - // calculate_points = true; // comment to replicate burnrl_before GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -390,6 +388,7 @@ impl TrictracEnvironment { } TurnStage::MarkPoints => { panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -398,9 +397,12 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + GameEvent::Mark { player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).0, + points, } } TurnStage::MarkAdvPoints => { @@ -413,10 +415,11 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; // pas de reward : déjà comptabilisé lors du tour de blanc GameEvent::Mark { player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).1, + points, } } TurnStage::HoldOrGoChoice => { @@ -433,19 +436,6 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); - if calculate_points { - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // Récompense proportionnelle aux points - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; - } } } reward diff --git a/bot/src/dqn/burnrl_big/environmentDiverge.rs b/bot/src/dqn/burnrl_big/environmentDiverge.rs new file mode 100644 index 0000000..6706163 --- /dev/null +++ b/bot/src/dqn/burnrl_big/environmentDiverge.rs @@ -0,0 +1,459 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let is_rollpoint; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action: dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + // calculate_points = true; // comment to replicate burnrl_before + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).0, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).1, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/dqn_common_before.rs b/bot/src/dqn/dqn_common_before.rs deleted file mode 100644 index 2da4aa5..0000000 --- a/bot/src/dqn/dqn_common_before.rs +++ /dev/null @@ -1,255 +0,0 @@ -use std::cmp::{max, min}; - -use serde::{Deserialize, Serialize}; -use store::{CheckerMove, Dice}; - -/// Types d'actions possibles dans le jeu -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum TrictracAction { - /// Lancer les dés - Roll, - /// Continuer après avoir gagné un trou - Go, - /// Effectuer un mouvement de pions - Move { - dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier - from1: usize, // position de départ du premier pion (0-24) - from2: usize, // position de départ du deuxième pion (0-24) - }, - // Marquer les points : à activer si support des écoles - // Mark, -} - -impl TrictracAction { - /// Encode une action en index pour le réseau de neurones - pub fn to_action_index(&self) -> usize { - match self { - TrictracAction::Roll => 0, - TrictracAction::Go => 1, - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Encoder les mouvements dans l'espace d'actions - // Indices 2+ pour les mouvements - // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) - let mut start = 2; - if !dice_order { - // 25 * 25 = 625 - start += 625; - } - start + from1 * 25 + from2 - } // TrictracAction::Mark => 1252, - } - } - - /// Décode un index d'action en TrictracAction - pub fn from_action_index(index: usize) -> Option { - match index { - 0 => Some(TrictracAction::Roll), - // 1252 => Some(TrictracAction::Mark), - 1 => Some(TrictracAction::Go), - i if i >= 3 => { - let move_code = i - 3; - let (dice_order, from1, from2) = Self::decode_move(move_code); - Some(TrictracAction::Move { - dice_order, - from1, - from2, - }) - } - _ => None, - } - } - - /// Décode un entier en paire de mouvements - fn decode_move(code: usize) -> (bool, usize, usize) { - let mut encoded = code; - let dice_order = code < 626; - if !dice_order { - encoded -= 625 - } - let from1 = encoded / 25; - let from2 = 1 + encoded % 25; - (dice_order, from1, from2) - } - - /// Retourne la taille de l'espace d'actions total - pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Go) + mouvements possibles - // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) - // Mais on peut optimiser en limitant aux positions valides (1-24) - 2 + (2 * 25 * 25) // = 1252 - } - - // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { - // match action { - // TrictracAction::Roll => Some(GameEvent::Roll { player_id }), - // TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), - // TrictracAction::Go => Some(GameEvent::Go { player_id }), - // TrictracAction::Move { - // dice_order, - // from1, - // from2, - // } => { - // // Effectuer un mouvement - // let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - // let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - // - // Some(GameEvent::Move { - // player_id: self.agent_player_id, - // moves: (checker_move1, checker_move2), - // }) - // } - // }; - // } -} - -/// Obtient les actions valides pour l'état de jeu actuel -pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - use store::TurnStage; - - let mut valid_actions = Vec::new(); - - let active_player_id = game_state.active_player_id; - let player_color = game_state.player_color_by_id(&active_player_id); - - if let Some(color) = player_color { - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => { - valid_actions.push(TrictracAction::Roll); - } - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // valid_actions.push(TrictracAction::Mark); - } - TurnStage::HoldOrGoChoice => { - valid_actions.push(TrictracAction::Go); - - // Ajoute aussi les mouvements possibles - let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Modififier checker_moves_to_trictrac_action si on doit gérer Black - assert_eq!(color, store::Color::White); - for (move1, move2) in possible_moves { - valid_actions.push(checker_moves_to_trictrac_action( - &move1, - &move2, - &game_state.dice, - )); - } - } - TurnStage::Move => { - let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Modififier checker_moves_to_trictrac_action si on doit gérer Black - assert_eq!(color, store::Color::White); - for (move1, move2) in possible_moves { - valid_actions.push(checker_moves_to_trictrac_action( - &move1, - &move2, - &game_state.dice, - )); - } - } - } - } - - valid_actions -} - -// Valid only for White player -fn checker_moves_to_trictrac_action( - move1: &CheckerMove, - move2: &CheckerMove, - dice: &Dice, -) -> TrictracAction { - let to1 = move1.get_to(); - let to2 = move2.get_to(); - let from1 = move1.get_from(); - let from2 = move2.get_from(); - - let mut diff_move1 = if to1 > 0 { - // Mouvement sans sortie - to1 - from1 - } else { - // sortie, on utilise la valeur du dé - if to2 > 0 { - // sortie pour le mouvement 1 uniquement - let dice2 = to2 - from2; - if dice2 == dice.values.0 as usize { - dice.values.1 as usize - } else { - dice.values.0 as usize - } - } else { - // double sortie - if from1 < from2 { - max(dice.values.0, dice.values.1) as usize - } else { - min(dice.values.0, dice.values.1) as usize - } - } - }; - - // modification de diff_move1 si on est dans le cas d'un mouvement par puissance - let rest_field = 12; - if to1 == rest_field - && to2 == rest_field - && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field - { - // prise par puissance - diff_move1 += 1; - } - TrictracAction::Move { - dice_order: diff_move1 == dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - } -} - -/// Retourne les indices des actions valides -pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { - get_valid_actions(game_state) - .into_iter() - .map(|action| action.to_action_index()) - .collect() -} - -/// Sélectionne une action valide aléatoire -pub fn sample_valid_action(game_state: &crate::GameState) -> Option { - use rand::{seq::SliceRandom, thread_rng}; - - let valid_actions = get_valid_actions(game_state); - let mut rng = thread_rng(); - valid_actions.choose(&mut rng).cloned() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn to_action_index() { - let action = TrictracAction::Move { - dice_order: true, - from1: 3, - from2: 4, - }; - let index = action.to_action_index(); - assert_eq!(Some(action), TrictracAction::from_action_index(index)); - assert_eq!(81, index); - } - - #[test] - fn from_action_index() { - let action = TrictracAction::Move { - dice_order: true, - from1: 3, - from2: 4, - }; - assert_eq!(Some(action), TrictracAction::from_action_index(81)); - } -} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index 1edf4f7..ebc01a4 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -2,7 +2,6 @@ pub mod burnrl; pub mod burnrl_before; pub mod burnrl_big; pub mod dqn_common; -pub mod dqn_common_before; pub mod dqn_common_big; pub mod simple;