From e15dba167bb9d738f154772d82a8ed317939e8d7 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 13 Aug 2025 18:16:30 +0200 Subject: [PATCH 1/7] wip burnrl_big --- bot/Cargo.toml | 4 + bot/scripts/train.sh | 6 +- bot/src/dqn/burnrl_big/dqn_model.rs | 211 ++++++++++++ bot/src/dqn/burnrl_big/environment.rs | 459 ++++++++++++++++++++++++++ bot/src/dqn/burnrl_big/main.rs | 53 +++ bot/src/dqn/burnrl_big/mod.rs | 3 + bot/src/dqn/burnrl_big/utils.rs | 114 +++++++ bot/src/dqn/mod.rs | 1 + 8 files changed, 849 insertions(+), 2 deletions(-) create mode 100644 bot/src/dqn/burnrl_big/dqn_model.rs create mode 100644 bot/src/dqn/burnrl_big/environment.rs create mode 100644 bot/src/dqn/burnrl_big/main.rs create mode 100644 bot/src/dqn/burnrl_big/mod.rs create mode 100644 bot/src/dqn/burnrl_big/utils.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 135deae..c043393 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -9,6 +9,10 @@ edition = "2021" name = "train_dqn_burn_valid" path = "src/dqn/burnrl_valid/main.rs" +[[bin]] +name = "train_dqn_burn_big" +path = "src/dqn/burnrl_big/main.rs" + [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 9e54c7a..cfa29fb 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -4,16 +4,18 @@ ROOT="$(cd "$(dirname "$0")" && pwd)/../.." LOGS_DIR="$ROOT/bot/models/logs" CFG_SIZE=12 +# BINBOT=train_dqn_burn +BINBOT=train_dqn_burn_big OPPONENT="random" PLOT_EXT="png" train() { - cargo build --release --bin=train_dqn_burn + cargo build --release --bin=$BINBOT NAME="train_$(date +%Y-%m-%d_%H:%M:%S)" LOGS="$LOGS_DIR/$NAME.out" mkdir -p "$LOGS_DIR" - LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS" + LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" | tee "$LOGS" } plot() { diff --git a/bot/src/dqn/burnrl_big/dqn_model.rs b/bot/src/dqn/burnrl_big/dqn_model.rs new file mode 100644 index 0000000..f50bf31 --- /dev/null +++ b/bot/src/dqn/burnrl_big/dqn_model.rs @@ -0,0 +1,211 @@ +use crate::dqn::burnrl_big::environment::TrictracEnvironment; +use crate::dqn::burnrl_big::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for 
Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +pub struct DqnConfig { + pub min_steps: f32, + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("min_steps={:?}\n", self.min_steps)); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + min_steps: 250.0, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().min_steps = conf.min_steps; + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + 
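+                // (Assumption about burn-rl internals, not verified here:
+                // each train() call samples a random batch from `memory`,
+                // forms the TD target y = r + gamma * max_a' Q_target(s', a'),
+                // takes one AdamW step on the policy network, then
+                // soft-updates the target network with `tau` via
+                // DQNModel::soft_update.)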
policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs new file mode 100644 index 0000000..d6b8721 --- /dev/null +++ b/bot/src/dqn/burnrl_big/environment.rs @@ -0,0 +1,459 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + 
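+            // The learning agent is always player 1 (White, goes first);
+            // player 2 is a scripted opponent (see play_opponent_if_needed).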
active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut terminated = false; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. 
Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action: dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les 
précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + calculate_points = true; + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).0, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).1, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_big/main.rs b/bot/src/dqn/burnrl_big/main.rs new file mode 100644 index 0000000..097a27b --- /dev/null +++ b/bot/src/dqn/burnrl_big/main.rs @@ -0,0 +1,53 @@ +use bot::dqn::burnrl::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also MEMORY_SIZE in dqn_model.rs : 8192 + let conf = dqn_model::DqnConfig { + // defaults + num_episodes: 40, // 
40 + min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) + max_steps: 1000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. + clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) + }; + println!("{conf}----------"); + let agent = dqn_model::run::(&conf, false); //true); + + let valid_agent = agent.valid(); + + println!("> Sauvegarde du modèle de validation"); + + let path = "models/burn_dqn_40".to_string(); + save_model(valid_agent.model().as_ref().unwrap(), &path); + + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(conf.dense_size, &path); + let loaded_agent = DQN::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); +} diff --git a/bot/src/dqn/burnrl_big/mod.rs b/bot/src/dqn/burnrl_big/mod.rs new file mode 100644 index 0000000..f4380eb --- /dev/null +++ b/bot/src/dqn/burnrl_big/mod.rs @@ -0,0 +1,3 @@ +pub mod dqn_model; +pub mod environment; +pub mod utils; diff --git a/bot/src/dqn/burnrl_big/utils.rs b/bot/src/dqn/burnrl_big/utils.rs new file mode 100644 index 0000000..9159d57 --- /dev/null +++ b/bot/src/dqn/burnrl_big/utils.rs @@ -0,0 +1,114 @@ +use crate::dqn::burnrl_big::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; +use crate::dqn::dqn_common_big::get_valid_action_indices; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; +use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::backend::Backend; +use burn::tensor::cast::ToElement; +use burn::tensor::Tensor; +use burn_rl::agent::{DQNModel, DQN}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}_model.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + +pub fn demo_model>(agent: DQN) { + let mut env = TrictracEnvironment::new(true); + let mut done = false; + while !done { + // let action = match infer_action(&agent, &env, state) { + let action = match infer_action(&agent, &env) { + Some(value) => value, + None => break, + }; + // Execute 
action + let snapshot = env.step(action); + done = snapshot.done(); + } +} + +fn infer_action>( + agent: &DQN, + env: &TrictracEnvironment, +) -> Option { + let state = env.state(); + // Get q-values + let q_values = agent + .model() + .as_ref() + .unwrap() + .infer(state.to_tensor().unsqueeze()); + // Get valid actions + let valid_actions_indices = get_valid_action_indices(&env.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + let action = TrictracAction::from(action_index); + Some(action) +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: ElemType, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: ElemType, +) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index ab75746..7b12487 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -1,4 +1,5 @@ pub mod burnrl; +pub mod burnrl_big; pub mod dqn_common; pub mod dqn_common_big; pub mod simple; From f06d38fd3246a1077c9d58b91529dbf9c6cc7fe6 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 13 Aug 2025 21:30:24 +0200 Subject: [PATCH 2/7] burnrl_before --- bot/Cargo.toml | 4 + bot/scripts/train.sh | 3 +- bot/src/dqn/burnrl_before/dqn_model.rs | 211 +++++++++++ bot/src/dqn/burnrl_before/environment.rs | 448 +++++++++++++++++++++++ bot/src/dqn/burnrl_before/main.rs | 53 +++ bot/src/dqn/burnrl_before/mod.rs | 3 + bot/src/dqn/burnrl_before/utils.rs | 114 ++++++ bot/src/dqn/dqn_common_before.rs | 255 +++++++++++++ bot/src/dqn/mod.rs | 2 + 9 files changed, 1092 insertions(+), 1 deletion(-) create mode 100644 bot/src/dqn/burnrl_before/dqn_model.rs create mode 100644 bot/src/dqn/burnrl_before/environment.rs create mode 100644 bot/src/dqn/burnrl_before/main.rs create mode 100644 bot/src/dqn/burnrl_before/mod.rs create mode 100644 bot/src/dqn/burnrl_before/utils.rs create mode 100644 bot/src/dqn/dqn_common_before.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index c043393..4a0a95c 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,6 +13,10 @@ path = "src/dqn/burnrl_valid/main.rs" name = "train_dqn_burn_big" path = "src/dqn/burnrl_big/main.rs" +[[bin]] +name = "train_dqn_burn_before" +path = "src/dqn/burnrl_before/main.rs" + [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index cfa29fb..889dc9a 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -5,7 +5,8 @@ LOGS_DIR="$ROOT/bot/models/logs" CFG_SIZE=12 # BINBOT=train_dqn_burn -BINBOT=train_dqn_burn_big +# BINBOT=train_dqn_burn_big 
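+# Select the binary to train; each BINBOT value must match a [[bin]] target in bot/Cargo.toml.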
+BINBOT=train_dqn_burn_before OPPONENT="random" PLOT_EXT="png" diff --git a/bot/src/dqn/burnrl_before/dqn_model.rs b/bot/src/dqn/burnrl_before/dqn_model.rs new file mode 100644 index 0000000..02646eb --- /dev/null +++ b/bot/src/dqn/burnrl_before/dqn_model.rs @@ -0,0 +1,211 @@ +use crate::dqn::burnrl_before::environment::TrictracEnvironment; +use crate::dqn::burnrl_before::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +pub struct DqnConfig { + pub min_steps: f32, + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("min_steps={:?}\n", self.min_steps)); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + min_steps: 250.0, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + 
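+            // tau is the soft-update coefficient applied after each training
+            // step: target = tau * policy + (1 - tau) * target
+            // (see soft_update_linear in utils.rs).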
learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().min_steps = conf.min_steps; + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs new file mode 100644 index 0000000..6ce01c9 --- /dev/null +++ b/bot/src/dqn/burnrl_before/environment.rs @@ -0,0 +1,448 @@ +use crate::dqn::dqn_common_before; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + 
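+        // Copy at most 36 features; if the game state vector is shorter,
+        // the remaining entries stay zero (the array is zero-initialized).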
data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut terminated = false; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer 
l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_before::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_before::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action: dqn_common_before::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_before::TrictracAction) -> (f32, bool) { + use dqn_common_before::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. 
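+            // (Mark is deliberately absent from the action space for now; see
+            //  the note on TrictracAction in dqn_common_before: to be enabled
+            //  once écoles are supported.)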
+ // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux 
points + + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_before/main.rs b/bot/src/dqn/burnrl_before/main.rs new file mode 100644 index 0000000..602ff51 --- /dev/null +++ b/bot/src/dqn/burnrl_before/main.rs @@ -0,0 +1,53 @@ +use bot::dqn::burnrl_before::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also MEMORY_SIZE in dqn_model.rs : 8192 + let conf = dqn_model::DqnConfig { + // defaults + num_episodes: 40, // 40 + min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) + max_steps: 3000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
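+        // clip_grad bounds each gradient value via
+        // GradientClippingConfig::Value before the optimizer applies it.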
clip_grad: 100.0, // 100 max correction applied to the gradient (default 100)
+    };
+    println!("{conf}----------");
+    let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
+
+    let valid_agent = agent.valid();
+
+    println!("> Sauvegarde du modèle de validation");
+
+    let path = "models/burn_dqn_40".to_string();
+    save_model(valid_agent.model().as_ref().unwrap(), &path);
+
+    println!("> Chargement du modèle pour test");
+    let loaded_model = load_model(conf.dense_size, &path);
+    let loaded_agent = DQN::new(loaded_model.unwrap());
+
+    println!("> Test avec le modèle chargé");
+    demo_model(loaded_agent);
+}
diff --git a/bot/src/dqn/burnrl_before/mod.rs b/bot/src/dqn/burnrl_before/mod.rs
new file mode 100644
index 0000000..f4380eb
--- /dev/null
+++ b/bot/src/dqn/burnrl_before/mod.rs
@@ -0,0 +1,3 @@
+pub mod dqn_model;
+pub mod environment;
+pub mod utils;
diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs
new file mode 100644
index 0000000..e6b4330
--- /dev/null
+++ b/bot/src/dqn/burnrl_before/utils.rs
@@ -0,0 +1,114 @@
+use crate::dqn::burnrl_before::{
+    dqn_model,
+    environment::{TrictracAction, TrictracEnvironment},
+};
+use crate::dqn::dqn_common_before::get_valid_action_indices;
+use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use burn::module::{Module, Param, ParamId};
+use burn::nn::Linear;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
+use burn::tensor::Tensor;
+use burn_rl::agent::{DQNModel, DQN};
+use burn_rl::base::{Action, ElemType, Environment, State};
+
+pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}_model.mpk");
+    println!("Modèle de validation sauvegardé : {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
+    let model_path = format!("{path}_model.mpk");
+    // println!("Chargement du modèle depuis : {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            dqn_model::Net::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
+
+pub fn demo_model<B: Backend>(agent: DQN<TrictracEnvironment, B, dqn_model::Net<B>>) {
+    let mut env = TrictracEnvironment::new(true);
+    let mut done = false;
+    while !done {
+        // let action = match infer_action(&agent, &env, state) {
+        let action = match infer_action(&agent, &env) {
+            Some(value) => value,
+            None => break,
+        };
+        // Execute action
+        let snapshot = env.step(action);
+        done = snapshot.done();
+    }
+}
+
+fn infer_action<B: Backend>(
+    agent: &DQN<TrictracEnvironment, B, dqn_model::Net<B>>,
+    env: &TrictracEnvironment,
+) -> Option<TrictracAction> {
+    let state = env.state();
+    // Get q-values
+    let q_values = agent
+        .model()
+        .as_ref()
+        .unwrap()
+        .infer(state.to_tensor().unsqueeze());
+    // Get valid actions
+    let valid_actions_indices = get_valid_action_indices(&env.game);
+    if valid_actions_indices.is_empty() {
+        return None; // No valid actions, end of episode
+    }
+    // Set non valid actions q-values to lowest
+    let mut masked_q_values = q_values.clone();
+    let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
+    for (index, q_value) in q_values_vec.iter().enumerate() {
+        if !valid_actions_indices.contains(&index) {
+            masked_q_values = masked_q_values.clone().mask_fill(
+                masked_q_values.clone().equal_elem(*q_value),
+                f32::NEG_INFINITY,
+            );
+        }
+    }
+    // Get best action (highest q-value)
+    let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+    let action = TrictracAction::from(action_index);
+    Some(action)
+}
+
+fn soft_update_tensor<const D: usize, B: Backend>(
+    this: &Param<Tensor<B, D>>,
+    that: &Param<Tensor<B, D>>,
+    tau: ElemType,
+) -> Param<Tensor<B, D>> {
+    let that_weight = that.val();
+    let this_weight = this.val();
+    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
+
+    Param::initialized(ParamId::new(), new_weight)
+}
+
+pub fn soft_update_linear<B: Backend>(
+    this: Linear<B>,
+    that: &Linear<B>,
+    tau: ElemType,
+) -> Linear<B> {
+    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
+    let bias = match (&this.bias, &that.bias) {
+        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
+        _ => None,
+    };
+
+    Linear::<B> { weight, bias }
+}
diff --git a/bot/src/dqn/dqn_common_before.rs b/bot/src/dqn/dqn_common_before.rs
new file mode 100644
index 0000000..2da4aa5
--- /dev/null
+++ b/bot/src/dqn/dqn_common_before.rs
@@ -0,0 +1,255 @@
+use std::cmp::{max, min};
+
+use serde::{Deserialize, Serialize};
+use store::{CheckerMove, Dice};
+
+/// Action types available in the game
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum TrictracAction {
+    /// Roll the dice
+    Roll,
+    /// Keep going after winning a hole
+    Go,
+    /// Move two checkers
+    Move {
+        dice_order: bool, // true = use dice[0] first, false = dice[1] first
+        from1: usize,     // starting position of the first checker (0-24)
+        from2: usize,     // starting position of the second checker (0-24)
+    },
+    // Mark points: to enable once écoles are supported
+    // Mark,
+}
+
+impl TrictracAction {
+    /// Encode an action as an index for the neural network
+    pub fn to_action_index(&self) -> usize {
+        match self {
+            TrictracAction::Roll => 0,
+            TrictracAction::Go => 1,
+            TrictracAction::Move {
+                dice_order,
+                from1,
+                from2,
+            } => {
+                // Encode moves into the action space.
+                // Indices 2+ are moves:
+                // from 2 to 1251 (2 to 626 with die 1 first, 627 to 1251 with die 2 first)
+                let mut start = 2;
+                if !dice_order {
+                    // 25 * 25 = 625
+                    start += 625;
+                }
+                start + from1 * 25 + from2
+            } // TrictracAction::Mark => 1252,
+        }
+    }
+
+    /// Decode an action index into a TrictracAction
+    pub fn from_action_index(index: usize) -> Option<Self> {
+        match index {
+            0 => Some(TrictracAction::Roll),
+            // 1252 => Some(TrictracAction::Mark),
+            1 => Some(TrictracAction::Go),
+            i if i >= 2 => {
+                let move_code = i - 2;
+                let (dice_order, from1, from2) = Self::decode_move(move_code);
+                Some(TrictracAction::Move {
+                    dice_order,
+                    from1,
+                    from2,
+                })
+            }
+            _ => None,
+        }
+    }
+
+    /// Decode an integer into a move pair (exact inverse of the encoding above)
+    fn decode_move(code: usize) -> (bool, usize, usize) {
+        let mut encoded = code;
+        let dice_order = code < 625;
+        if !dice_order {
+            encoded -= 625
+        }
+        let from1 = encoded / 25;
+        let from2 = encoded % 25;
+        (dice_order, from1, from2)
+    }
+
+    /// Return the size of the full action space
+    pub fn action_space_size() -> usize {
+        // 1 (Roll) + 1 (Go) + possible moves
+        // For moves: 2*25*25 = 1250 (die choice + position 0-24 for each from)
+        // This could be trimmed down to the valid positions (1-24)
+        2 + (2 * 25 * 25) // = 1252
+    }
+
+    // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
+    //     match action {
+    //         TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
+    //         TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
+    //         TrictracAction::Go => Some(GameEvent::Go { player_id }),
+    //         TrictracAction::Move {
+    //             dice_order,
+    //             from1,
+    //             from2,
+    //         } => {
+    //             // Perform a move
+    //             let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
+    //             let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
+    //
+    //             Some(GameEvent::Move {
+    //                 player_id: self.agent_player_id,
+    //                 moves: (checker_move1, checker_move2),
+    //             })
+    //         }
+    //     };
+    // }
+}
+
+/// Get the valid actions for the current game state
+pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
+    use store::TurnStage;
+
+    let mut valid_actions = Vec::new();
+
+    let active_player_id = game_state.active_player_id;
+    let player_color = game_state.player_color_by_id(&active_player_id);
+
+    if let Some(color) = player_color {
+        match game_state.turn_stage {
+            TurnStage::RollDice | TurnStage::RollWaiting => {
+                valid_actions.push(TrictracAction::Roll);
+            }
+            TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
+                // valid_actions.push(TrictracAction::Mark);
+            }
+            TurnStage::HoldOrGoChoice => {
+                valid_actions.push(TrictracAction::Go);
+
+                // Also add the possible moves
+                let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
+                let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+
+                // Modify checker_moves_to_trictrac_action if Black must be handled
+                assert_eq!(color, store::Color::White);
+                for (move1, move2) in possible_moves {
+                    valid_actions.push(checker_moves_to_trictrac_action(
+                        &move1,
+                        &move2,
+                        &game_state.dice,
+                    ));
+                }
+            }
+            TurnStage::Move => {
+                let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
+                let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+
+                // Modify checker_moves_to_trictrac_action if Black must be handled
+                assert_eq!(color, store::Color::White);
+                for (move1, move2) in possible_moves {
+                    valid_actions.push(checker_moves_to_trictrac_action(
+                        &move1,
+                        &move2,
+                        &game_state.dice,
+                    ));
+                }
+            }
+        }
+    }
+
+    valid_actions
+}
+
+// Valid only for the White player
+fn checker_moves_to_trictrac_action(
+    move1: &CheckerMove,
+    move2: &CheckerMove,
+    dice: &Dice,
+) -> TrictracAction {
+    let to1 = move1.get_to();
+    let to2 = move2.get_to();
+    let from1 = move1.get_from();
+    let from2 = move2.get_from();
+
+    let mut diff_move1 = if to1 > 0 {
+        // Move without bearing off
+        to1 - from1
+    } else {
+        // Bearing off: use the die value
+        if to2 > 0 {
+            // only the first move bears off
+            let dice2 = to2 - from2;
+            if dice2 == dice.values.0 as usize {
+                dice.values.1 as usize
+            } else {
+                dice.values.0 as usize
+            }
+        } else {
+            // both moves bear off
+            if from1 < from2 {
+                max(dice.values.0, dice.values.1) as usize
+            } else {
+                min(dice.values.0, dice.values.1) as usize
+            }
+        }
+    };
+
+    // Adjust diff_move1 when the corner is taken by power ("par puissance")
+    let rest_field = 12;
+    if to1 == rest_field
+        && to2 == rest_field
+        && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
+    {
+        // corner taken par puissance
+        diff_move1 += 1;
+    }
+    TrictracAction::Move {
+        dice_order: diff_move1 == dice.values.0 as usize,
+        from1: move1.get_from(),
+        from2: move2.get_from(),
+    }
+}
+
+/// Return the indices of the valid actions
+pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
+    get_valid_actions(game_state)
+        .into_iter()
+        .map(|action| action.to_action_index())
+        .collect()
+}
+
+/// Pick a random valid action
+pub fn sample_valid_action(game_state: 
&crate::GameState) -> Option { + use rand::{seq::SliceRandom, thread_rng}; + + let valid_actions = get_valid_actions(game_state); + let mut rng = thread_rng(); + valid_actions.choose(&mut rng).cloned() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn to_action_index() { + let action = TrictracAction::Move { + dice_order: true, + from1: 3, + from2: 4, + }; + let index = action.to_action_index(); + assert_eq!(Some(action), TrictracAction::from_action_index(index)); + assert_eq!(81, index); + } + + #[test] + fn from_action_index() { + let action = TrictracAction::Move { + dice_order: true, + from1: 3, + from2: 4, + }; + assert_eq!(Some(action), TrictracAction::from_action_index(81)); + } +} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index 7b12487..1edf4f7 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -1,6 +1,8 @@ pub mod burnrl; +pub mod burnrl_before; pub mod burnrl_big; pub mod dqn_common; +pub mod dqn_common_before; pub mod dqn_common_big; pub mod simple; From c0da21dfec1dd495aeb4f46012dbdcd28bdb73ba Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 13 Aug 2025 21:36:34 +0200 Subject: [PATCH 3/7] wip burnrl_big --- bot/src/dqn/burnrl_big/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/src/dqn/burnrl_big/main.rs b/bot/src/dqn/burnrl_big/main.rs index 097a27b..d8b200f 100644 --- a/bot/src/dqn/burnrl_big/main.rs +++ b/bot/src/dqn/burnrl_big/main.rs @@ -17,7 +17,7 @@ fn main() { // defaults num_episodes: 40, // 40 min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 1000, // 1000 max steps by episode + max_steps: 3000, // 1000 max steps by episode dense_size: 256, // 128 neural network complexity (default 128) eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) eps_end: 0.05, // 0.05 From 522f7470b292c46a2ef5374f693d9ffde6cab9b9 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 13 Aug 2025 21:55:37 +0200 Subject: [PATCH 4/7] not yet burnrl_big --- bot/scripts/train.sh | 4 ++-- bot/src/dqn/burnrl_big/environment.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 889dc9a..9da60a0 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -5,8 +5,8 @@ LOGS_DIR="$ROOT/bot/models/logs" CFG_SIZE=12 # BINBOT=train_dqn_burn -# BINBOT=train_dqn_burn_big -BINBOT=train_dqn_burn_before +BINBOT=train_dqn_burn_big +# BINBOT=train_dqn_burn_before OPPONENT="random" PLOT_EXT="png" diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs index d6b8721..6706163 100644 --- a/bot/src/dqn/burnrl_big/environment.rs +++ b/bot/src/dqn/burnrl_big/environment.rs @@ -165,8 +165,7 @@ impl Environment for TrictracEnvironment { let trictrac_action = Self::convert_action(action); let mut reward = 0.0; - let mut is_rollpoint = false; - let mut terminated = false; + let is_rollpoint; // Exécuter l'action si c'est le tour de l'agent DQN if self.game.active_player_id == self.active_player_id { @@ -381,7 +380,7 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - calculate_points = true; + // calculate_points = true; // comment to replicate burnrl_before GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -390,6 +389,7 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); let 
dice_roll_count = self .game .players From bcc4b977c481d0c74212ff043cde11b73008a914 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 14 Aug 2025 15:53:59 +0200 Subject: [PATCH 5/7] wip not yet --- bot/scripts/train.sh | 4 +- bot/src/dqn/burnrl_before/environment.rs | 19 +- bot/src/dqn/burnrl_before/utils.rs | 2 +- bot/src/dqn/burnrl_big/environment.rs | 30 +- bot/src/dqn/burnrl_big/environmentDiverge.rs | 459 +++++++++++++++++++ bot/src/dqn/dqn_common_before.rs | 255 ----------- bot/src/dqn/mod.rs | 1 - 7 files changed, 482 insertions(+), 288 deletions(-) create mode 100644 bot/src/dqn/burnrl_big/environmentDiverge.rs delete mode 100644 bot/src/dqn/dqn_common_before.rs diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 9da60a0..a3be831 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -13,14 +13,14 @@ PLOT_EXT="png" train() { cargo build --release --bin=$BINBOT - NAME="train_$(date +%Y-%m-%d_%H:%M:%S)" + NAME=$BINBOT"_$(date +%Y-%m-%d_%H:%M:%S)" LOGS="$LOGS_DIR/$NAME.out" mkdir -p "$LOGS_DIR" LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" | tee "$LOGS" } plot() { - NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | grep $BINBOT | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs index 6ce01c9..9925a9a 100644 --- a/bot/src/dqn/burnrl_before/environment.rs +++ b/bot/src/dqn/burnrl_before/environment.rs @@ -1,4 +1,4 @@ -use crate::dqn::dqn_common_before; +use crate::dqn::dqn_common_big; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; @@ -227,8 +227,8 @@ impl TrictracEnvironment { const REWARD_RATIO: f32 = 1.0; /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_before::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac @@ -236,8 +236,8 @@ impl TrictracEnvironment { &self, action: TrictracAction, game_state: &GameState, - ) -> Option { - use dqn_common_before::get_valid_actions; + ) -> Option { + use dqn_common_big::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(game_state); @@ -254,10 +254,10 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: dqn_common_before::TrictracAction, + // action:dqn_common_big::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_before::TrictracAction) -> (f32, bool) { - use dqn_common_before::TrictracAction; + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; @@ -387,6 +387,7 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); let opponent_color = store::Color::Black; let dice_roll_count = self .game @@ -397,7 +398,7 @@ impl TrictracEnvironment { let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); let (points, adv_points) = 
points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points GameEvent::Mark { player_id: self.opponent_id, diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs index e6b4330..6c25c5d 100644 --- a/bot/src/dqn/burnrl_before/utils.rs +++ b/bot/src/dqn/burnrl_before/utils.rs @@ -2,7 +2,7 @@ use crate::dqn::burnrl_before::{ dqn_model, environment::{TrictracAction, TrictracEnvironment}, }; -use crate::dqn::dqn_common_before::get_valid_action_indices; +use crate::dqn::dqn_common_big::get_valid_action_indices; use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; use burn::module::{Module, Param, ParamId}; use burn::nn::Linear; diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs index 6706163..9925a9a 100644 --- a/bot/src/dqn/burnrl_big/environment.rs +++ b/bot/src/dqn/burnrl_big/environment.rs @@ -165,7 +165,8 @@ impl Environment for TrictracEnvironment { let trictrac_action = Self::convert_action(action); let mut reward = 0.0; - let is_rollpoint; + let mut is_rollpoint = false; + let mut terminated = false; // Exécuter l'action si c'est le tour de l'agent DQN if self.game.active_player_id == self.active_player_id { @@ -253,7 +254,7 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: dqn_common_big::TrictracAction, + // action:dqn_common_big::TrictracAction, // ) -> Result> { fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { use dqn_common_big::TrictracAction; @@ -371,8 +372,6 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage - let mut calculate_points = false; - let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -380,7 +379,6 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - // calculate_points = true; // comment to replicate burnrl_before GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -390,6 +388,7 @@ impl TrictracEnvironment { } TurnStage::MarkPoints => { panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -398,9 +397,12 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + GameEvent::Mark { player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).0, + points, } } TurnStage::MarkAdvPoints => { @@ -413,10 +415,11 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; // pas de reward : déjà comptabilisé lors du tour de blanc GameEvent::Mark { player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).1, + points, } } TurnStage::HoldOrGoChoice => { @@ -433,19 +436,6 @@ impl 
TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); - if calculate_points { - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // Récompense proportionnelle aux points - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; - } } } reward diff --git a/bot/src/dqn/burnrl_big/environmentDiverge.rs b/bot/src/dqn/burnrl_big/environmentDiverge.rs new file mode 100644 index 0000000..6706163 --- /dev/null +++ b/bot/src/dqn/burnrl_big/environmentDiverge.rs @@ -0,0 +1,459 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 
250.0,
+            max_steps: 2000,
+            pointrolls_count: 0,
+            goodmoves_count: 0,
+            goodmoves_ratio: 0.0,
+            visualized,
+        }
+    }
+
+    fn state(&self) -> Self::StateType {
+        self.current_state
+    }
+
+    fn reset(&mut self) -> Snapshot {
+        // Réinitialiser le jeu
+        self.game = GameState::new(false);
+        self.game.init_player("DQN Agent");
+        self.game.init_player("Opponent");
+
+        // Commencer la partie
+        self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
+
+        self.current_state = TrictracState::from_game_state(&self.game);
+        self.episode_reward = 0.0;
+        self.goodmoves_ratio = if self.step_count == 0 {
+            0.0
+        } else {
+            self.goodmoves_count as f32 / self.step_count as f32
+        };
+        println!(
+            "info: correct moves: {} ({}%)",
+            self.goodmoves_count,
+            (100.0 * self.goodmoves_ratio).round() as u32
+        );
+        self.step_count = 0;
+        self.pointrolls_count = 0;
+        self.goodmoves_count = 0;
+
+        Snapshot::new(self.current_state, 0.0, false)
+    }
+
+    fn step(&mut self, action: Self::ActionType) -> Snapshot {
+        self.step_count += 1;
+
+        // Convertir l'action burn-rl vers une action Trictrac
+        let trictrac_action = Self::convert_action(action);
+
+        let mut reward = 0.0;
+        let is_rollpoint;
+
+        // Exécuter l'action si c'est le tour de l'agent DQN
+        if self.game.active_player_id == self.active_player_id {
+            if let Some(action) = trictrac_action {
+                (reward, is_rollpoint) = self.execute_action(action);
+                if is_rollpoint {
+                    self.pointrolls_count += 1;
+                }
+                if reward != Self::ERROR_REWARD {
+                    self.goodmoves_count += 1;
+                }
+            } else {
+                // Action non convertible, pénalité
+                reward = -0.5;
+            }
+        }
+
+        // Faire jouer l'adversaire (stratégie simple)
+        while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
+            reward += self.play_opponent_if_needed();
+        }
+
+        // Vérifier si la partie est terminée
+        let max_steps = self.min_steps
+            + (self.max_steps as f32 - self.min_steps)
+                * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
+        let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
+
+        if done {
+            // Récompense finale basée sur le résultat
+            if let Some(winner_id) = self.game.determine_winner() {
+                if winner_id == self.active_player_id {
+                    reward += 50.0; // Victoire
+                } else {
+                    reward -= 25.0; // Défaite
+                }
+            }
+        }
+        let terminated = done || self.step_count >= max_steps.round() as usize;
+
+        // Mettre à jour l'état
+        self.current_state = TrictracState::from_game_state(&self.game);
+        self.episode_reward += reward;
+
+        if self.visualized && terminated {
+            println!(
+                "Episode terminé. Récompense totale: {:.2}, Étapes: {}",
+                self.episode_reward, self.step_count
+            );
+        }
+
+        Snapshot::new(self.current_state, reward, terminated)
+    }
+}
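+
+// Editor's illustration (not part of the original commit): the adaptive step
+// budget computed in `step` above interpolates between `min_steps` and
+// `max_steps` based on the previous episode's ratio of legal moves. With
+// min_steps = 250.0 and max_steps = 2000: ratio 0.0 gives ~282 steps,
+// 0.5 gives ~487, 0.75 gives ~894, and 1.0 gives the full 2000.
+#[allow(dead_code)]
+fn episode_step_budget(min_steps: f32, max_steps: usize, goodmoves_ratio: f32) -> usize {
+    let budget =
+        min_steps + (max_steps as f32 - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25);
+    budget.round() as usize
+}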
+
+impl TrictracEnvironment {
+    const ERROR_REWARD: f32 = -1.12121;
+    const REWARD_RATIO: f32 = 1.0;
+
+    /// Convertit une action burn-rl vers une action Trictrac
+    pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
+        dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
+    }
+
+    /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
+    fn convert_valid_action_index(
+        &self,
+        action: TrictracAction,
+        game_state: &GameState,
+    ) -> Option<dqn_common_big::TrictracAction> {
+        use dqn_common_big::get_valid_actions;
+
+        // Obtenir les actions valides dans le contexte actuel
+        let valid_actions = get_valid_actions(game_state);
+
+        if valid_actions.is_empty() {
+            return None;
+        }
+
+        // Mapper l'index d'action sur une action valide
+        let action_index = (action.index as usize) % valid_actions.len();
+        Some(valid_actions[action_index].clone())
+    }
+
+    /// Exécute une action Trictrac dans le jeu
+    // fn execute_action(
+    //     &mut self,
+    //     action: dqn_common_big::TrictracAction,
+    // ) -> Result> {
+    fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
+        use dqn_common_big::TrictracAction;
+
+        let mut reward = 0.0;
+        let mut is_rollpoint = false;
+
+        let event = match action {
+            TrictracAction::Roll => {
+                // Lancer les dés
+                reward += 0.1;
+                Some(GameEvent::Roll {
+                    player_id: self.active_player_id,
+                })
+            }
+            // TrictracAction::Mark => {
+            //     // Marquer des points
+            //     let points = self.game.
+            //     reward += 0.1 * points as f32;
+            //     Some(GameEvent::Mark {
+            //         player_id: self.active_player_id,
+            //         points,
+            //     })
+            // }
+            TrictracAction::Go => {
+                // Continuer après avoir gagné un trou
+                reward += 0.2;
+                Some(GameEvent::Go {
+                    player_id: self.active_player_id,
+                })
+            }
+            TrictracAction::Move {
+                dice_order,
+                from1,
+                from2,
+            } => {
+                // Effectuer un mouvement
+                let (dice1, dice2) = if dice_order {
+                    (self.game.dice.values.0, self.game.dice.values.1)
+                } else {
+                    (self.game.dice.values.1, self.game.dice.values.0)
+                };
+                let mut to1 = from1 + dice1 as usize;
+                let mut to2 = from2 + dice2 as usize;
+
+                // Gestion prise de coin par puissance
+                let opp_rest_field = 13;
+                if to1 == opp_rest_field && to2 == opp_rest_field {
+                    to1 -= 1;
+                    to2 -= 1;
+                }
+
+                let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
+                let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
+
+                reward += 0.2;
+                Some(GameEvent::Move {
+                    player_id: self.active_player_id,
+                    moves: (checker_move1, checker_move2),
+                })
+            }
+        };
+
+        // Appliquer l'événement si valide
+        if let Some(event) = event {
+            if self.game.validate(&event) {
+                self.game.consume(&event);
+
+                // Simuler le résultat des dés après un Roll
+                if matches!(action, TrictracAction::Roll) {
+                    let mut rng = thread_rng();
+                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
+                    let dice_event = GameEvent::RollResult {
+                        player_id: self.active_player_id,
+                        dice: store::Dice {
+                            values: dice_values,
+                        },
+                    };
+                    if self.game.validate(&dice_event) {
+                        self.game.consume(&dice_event);
+                        let (points, adv_points) = self.game.dice_points;
+                        reward += Self::REWARD_RATIO * (points - adv_points) as f32;
+                        if points > 0 {
+                            is_rollpoint = true;
+                            // println!("info: rolled for {reward}");
+                        }
+                        // Récompense proportionnelle aux points
+                    }
+                }
+            } else {
+                // Pénalité pour action invalide
+                // on annule les 
précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + // calculate_points = true; // comment to replicate burnrl_before + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).0, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).1, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/dqn_common_before.rs b/bot/src/dqn/dqn_common_before.rs deleted file mode 100644 index 2da4aa5..0000000 --- a/bot/src/dqn/dqn_common_before.rs +++ /dev/null @@ -1,255 +0,0 @@ -use std::cmp::{max, min}; - -use serde::{Deserialize, Serialize}; -use store::{CheckerMove, Dice}; - -/// Types d'actions possibles dans le jeu -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum TrictracAction { - /// Lancer les dés - Roll, - /// Continuer après avoir gagné un trou - Go, - /// Effectuer un 
mouvement de pions - Move { - dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier - from1: usize, // position de départ du premier pion (0-24) - from2: usize, // position de départ du deuxième pion (0-24) - }, - // Marquer les points : à activer si support des écoles - // Mark, -} - -impl TrictracAction { - /// Encode une action en index pour le réseau de neurones - pub fn to_action_index(&self) -> usize { - match self { - TrictracAction::Roll => 0, - TrictracAction::Go => 1, - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Encoder les mouvements dans l'espace d'actions - // Indices 2+ pour les mouvements - // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) - let mut start = 2; - if !dice_order { - // 25 * 25 = 625 - start += 625; - } - start + from1 * 25 + from2 - } // TrictracAction::Mark => 1252, - } - } - - /// Décode un index d'action en TrictracAction - pub fn from_action_index(index: usize) -> Option { - match index { - 0 => Some(TrictracAction::Roll), - // 1252 => Some(TrictracAction::Mark), - 1 => Some(TrictracAction::Go), - i if i >= 3 => { - let move_code = i - 3; - let (dice_order, from1, from2) = Self::decode_move(move_code); - Some(TrictracAction::Move { - dice_order, - from1, - from2, - }) - } - _ => None, - } - } - - /// Décode un entier en paire de mouvements - fn decode_move(code: usize) -> (bool, usize, usize) { - let mut encoded = code; - let dice_order = code < 626; - if !dice_order { - encoded -= 625 - } - let from1 = encoded / 25; - let from2 = 1 + encoded % 25; - (dice_order, from1, from2) - } - - /// Retourne la taille de l'espace d'actions total - pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Go) + mouvements possibles - // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) - // Mais on peut optimiser en limitant aux positions valides (1-24) - 2 + (2 * 25 * 25) // = 1252 - } - - // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { - // match action { - // TrictracAction::Roll => Some(GameEvent::Roll { player_id }), - // TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), - // TrictracAction::Go => Some(GameEvent::Go { player_id }), - // TrictracAction::Move { - // dice_order, - // from1, - // from2, - // } => { - // // Effectuer un mouvement - // let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - // let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - // - // Some(GameEvent::Move { - // player_id: self.agent_player_id, - // moves: (checker_move1, checker_move2), - // }) - // } - // }; - // } -} - -/// Obtient les actions valides pour l'état de jeu actuel -pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - use store::TurnStage; - - let mut valid_actions = Vec::new(); - - let active_player_id = game_state.active_player_id; - let player_color = game_state.player_color_by_id(&active_player_id); - - if let Some(color) = player_color { - match game_state.turn_stage { - TurnStage::RollDice | TurnStage::RollWaiting => { - valid_actions.push(TrictracAction::Roll); - } - TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { - // valid_actions.push(TrictracAction::Mark); - } - TurnStage::HoldOrGoChoice => { - valid_actions.push(TrictracAction::Go); - - // Ajoute aussi les mouvements possibles - let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); - let possible_moves = 
rules.get_possible_moves_sequences(true, vec![]); - - // Modififier checker_moves_to_trictrac_action si on doit gérer Black - assert_eq!(color, store::Color::White); - for (move1, move2) in possible_moves { - valid_actions.push(checker_moves_to_trictrac_action( - &move1, - &move2, - &game_state.dice, - )); - } - } - TurnStage::Move => { - let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Modififier checker_moves_to_trictrac_action si on doit gérer Black - assert_eq!(color, store::Color::White); - for (move1, move2) in possible_moves { - valid_actions.push(checker_moves_to_trictrac_action( - &move1, - &move2, - &game_state.dice, - )); - } - } - } - } - - valid_actions -} - -// Valid only for White player -fn checker_moves_to_trictrac_action( - move1: &CheckerMove, - move2: &CheckerMove, - dice: &Dice, -) -> TrictracAction { - let to1 = move1.get_to(); - let to2 = move2.get_to(); - let from1 = move1.get_from(); - let from2 = move2.get_from(); - - let mut diff_move1 = if to1 > 0 { - // Mouvement sans sortie - to1 - from1 - } else { - // sortie, on utilise la valeur du dé - if to2 > 0 { - // sortie pour le mouvement 1 uniquement - let dice2 = to2 - from2; - if dice2 == dice.values.0 as usize { - dice.values.1 as usize - } else { - dice.values.0 as usize - } - } else { - // double sortie - if from1 < from2 { - max(dice.values.0, dice.values.1) as usize - } else { - min(dice.values.0, dice.values.1) as usize - } - } - }; - - // modification de diff_move1 si on est dans le cas d'un mouvement par puissance - let rest_field = 12; - if to1 == rest_field - && to2 == rest_field - && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field - { - // prise par puissance - diff_move1 += 1; - } - TrictracAction::Move { - dice_order: diff_move1 == dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - } -} - -/// Retourne les indices des actions valides -pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { - get_valid_actions(game_state) - .into_iter() - .map(|action| action.to_action_index()) - .collect() -} - -/// Sélectionne une action valide aléatoire -pub fn sample_valid_action(game_state: &crate::GameState) -> Option { - use rand::{seq::SliceRandom, thread_rng}; - - let valid_actions = get_valid_actions(game_state); - let mut rng = thread_rng(); - valid_actions.choose(&mut rng).cloned() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn to_action_index() { - let action = TrictracAction::Move { - dice_order: true, - from1: 3, - from2: 4, - }; - let index = action.to_action_index(); - assert_eq!(Some(action), TrictracAction::from_action_index(index)); - assert_eq!(81, index); - } - - #[test] - fn from_action_index() { - let action = TrictracAction::Move { - dice_order: true, - from1: 3, - from2: 4, - }; - assert_eq!(Some(action), TrictracAction::from_action_index(81)); - } -} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index 1edf4f7..ebc01a4 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -2,7 +2,6 @@ pub mod burnrl; pub mod burnrl_before; pub mod burnrl_big; pub mod dqn_common; -pub mod dqn_common_before; pub mod dqn_common_big; pub mod simple; From 93624c425d09f305b1ed0c0067b168729765f74c Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 13 Aug 2025 18:16:30 +0200 Subject: [PATCH 6/7] wip burnrl_big --- bot/Cargo.toml | 8 + bot/scripts/train.sh | 11 +- 
bot/src/dqn/burnrl_before/dqn_model.rs | 211 +++++++++ bot/src/dqn/burnrl_before/environment.rs | 449 ++++++++++++++++++ bot/src/dqn/burnrl_before/main.rs | 53 +++ bot/src/dqn/burnrl_before/mod.rs | 3 + bot/src/dqn/burnrl_before/utils.rs | 114 +++++ bot/src/dqn/burnrl_big/dqn_model.rs | 211 +++++++++ bot/src/dqn/burnrl_big/environment.rs | 449 ++++++++++++++++++ bot/src/dqn/burnrl_big/environmentDiverge.rs | 459 +++++++++++++++++++ bot/src/dqn/burnrl_big/main.rs | 53 +++ bot/src/dqn/burnrl_big/mod.rs | 3 + bot/src/dqn/burnrl_big/utils.rs | 114 +++++ bot/src/dqn/mod.rs | 2 + 14 files changed, 2136 insertions(+), 4 deletions(-) create mode 100644 bot/src/dqn/burnrl_before/dqn_model.rs create mode 100644 bot/src/dqn/burnrl_before/environment.rs create mode 100644 bot/src/dqn/burnrl_before/main.rs create mode 100644 bot/src/dqn/burnrl_before/mod.rs create mode 100644 bot/src/dqn/burnrl_before/utils.rs create mode 100644 bot/src/dqn/burnrl_big/dqn_model.rs create mode 100644 bot/src/dqn/burnrl_big/environment.rs create mode 100644 bot/src/dqn/burnrl_big/environmentDiverge.rs create mode 100644 bot/src/dqn/burnrl_big/main.rs create mode 100644 bot/src/dqn/burnrl_big/mod.rs create mode 100644 bot/src/dqn/burnrl_big/utils.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 135deae..4a0a95c 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -9,6 +9,14 @@ edition = "2021" name = "train_dqn_burn_valid" path = "src/dqn/burnrl_valid/main.rs" +[[bin]] +name = "train_dqn_burn_big" +path = "src/dqn/burnrl_big/main.rs" + +[[bin]] +name = "train_dqn_burn_before" +path = "src/dqn/burnrl_before/main.rs" + [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 9e54c7a..a3be831 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -4,20 +4,23 @@ ROOT="$(cd "$(dirname "$0")" && pwd)/../.." 
LOGS_DIR="$ROOT/bot/models/logs" CFG_SIZE=12 +# BINBOT=train_dqn_burn +BINBOT=train_dqn_burn_big +# BINBOT=train_dqn_burn_before OPPONENT="random" PLOT_EXT="png" train() { - cargo build --release --bin=train_dqn_burn - NAME="train_$(date +%Y-%m-%d_%H:%M:%S)" + cargo build --release --bin=$BINBOT + NAME=$BINBOT"_$(date +%Y-%m-%d_%H:%M:%S)" LOGS="$LOGS_DIR/$NAME.out" mkdir -p "$LOGS_DIR" - LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS" + LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" | tee "$LOGS" } plot() { - NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | grep $BINBOT | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do diff --git a/bot/src/dqn/burnrl_before/dqn_model.rs b/bot/src/dqn/burnrl_before/dqn_model.rs new file mode 100644 index 0000000..02646eb --- /dev/null +++ b/bot/src/dqn/burnrl_before/dqn_model.rs @@ -0,0 +1,211 @@ +use crate::dqn::burnrl_before::environment::TrictracEnvironment; +use crate::dqn::burnrl_before::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +pub struct DqnConfig { + pub min_steps: f32, + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("min_steps={:?}\n", self.min_steps)); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + 
s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + min_steps: 250.0, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().min_steps = conf.min_steps; + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs new file mode 100644 index 0000000..9925a9a --- /dev/null +++ b/bot/src/dqn/burnrl_before/environment.rs @@ -0,0 +1,449 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, 
PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn 
step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut terminated = false; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action:dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. 
+ // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // reward -= Self::REWARD_RATIO * 
(points - adv_points) as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_before/main.rs b/bot/src/dqn/burnrl_before/main.rs new file mode 100644 index 0000000..602ff51 --- /dev/null +++ b/bot/src/dqn/burnrl_before/main.rs @@ -0,0 +1,53 @@ +use bot::dqn::burnrl_before::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also MEMORY_SIZE in dqn_model.rs : 8192 + let conf = dqn_model::DqnConfig { + // defaults + num_episodes: 40, // 40 + min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) + max_steps: 3000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
+ clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) + }; + println!("{conf}----------"); + let agent = dqn_model::run::(&conf, false); //true); + + let valid_agent = agent.valid(); + + println!("> Sauvegarde du modèle de validation"); + + let path = "models/burn_dqn_40".to_string(); + save_model(valid_agent.model().as_ref().unwrap(), &path); + + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(conf.dense_size, &path); + let loaded_agent = DQN::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); +} diff --git a/bot/src/dqn/burnrl_before/mod.rs b/bot/src/dqn/burnrl_before/mod.rs new file mode 100644 index 0000000..f4380eb --- /dev/null +++ b/bot/src/dqn/burnrl_before/mod.rs @@ -0,0 +1,3 @@ +pub mod dqn_model; +pub mod environment; +pub mod utils; diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs new file mode 100644 index 0000000..6c25c5d --- /dev/null +++ b/bot/src/dqn/burnrl_before/utils.rs @@ -0,0 +1,114 @@ +use crate::dqn::burnrl_before::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; +use crate::dqn::dqn_common_big::get_valid_action_indices; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; +use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::backend::Backend; +use burn::tensor::cast::ToElement; +use burn::tensor::Tensor; +use burn_rl::agent::{DQNModel, DQN}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}_model.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + +pub fn demo_model>(agent: DQN) { + let mut env = TrictracEnvironment::new(true); + let mut done = false; + while !done { + // let action = match infer_action(&agent, &env, state) { + let action = match infer_action(&agent, &env) { + Some(value) => value, + None => break, + }; + // Execute action + let snapshot = env.step(action); + done = snapshot.done(); + } +} + +fn infer_action>( + agent: &DQN, + env: &TrictracEnvironment, +) -> Option { + let state = env.state(); + // Get q-values + let q_values = agent + .model() + .as_ref() + .unwrap() + .infer(state.to_tensor().unsqueeze()); + // Get valid actions + let valid_actions_indices = get_valid_action_indices(&env.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let 
action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + let action = TrictracAction::from(action_index); + Some(action) +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: ElemType, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: ElemType, +) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} diff --git a/bot/src/dqn/burnrl_big/dqn_model.rs b/bot/src/dqn/burnrl_big/dqn_model.rs new file mode 100644 index 0000000..f50bf31 --- /dev/null +++ b/bot/src/dqn/burnrl_big/dqn_model.rs @@ -0,0 +1,211 @@ +use crate::dqn::burnrl_big::environment::TrictracEnvironment; +use crate::dqn::burnrl_big::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +pub struct DqnConfig { + pub min_steps: f32, + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("min_steps={:?}\n", self.min_steps)); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", 
self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + min_steps: 250.0, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().min_steps = conf.min_steps; + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs new file mode 100644 index 0000000..9925a9a --- /dev/null +++ b/bot/src/dqn/burnrl_big/environment.rs @@ -0,0 +1,449 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu 
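+// Worked example of the epsilon-greedy schedule used by run() in
+// dqn_model.rs, epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay),
+// with eps_start = 0.9, eps_end = 0.05, eps_decay = 1000:
+//   step = 0    -> 0.05 + 0.85 * 1.000 = 0.900  (almost always explore)
+//   step = 1000 -> 0.05 + 0.85 * 0.368 ~ 0.363
+//   step = 3000 -> 0.05 + 0.85 * 0.050 ~ 0.092  (almost always greedy)
+// `step` is the global step counter, so epsilon keeps decaying across episodes.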
Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + 
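+        // Step flow: convert the flat action index into a Trictrac action,
+        // play it if it is the agent's turn, let the opponent play until
+        // control comes back, then test termination against an adaptive step
+        // budget. The sketch below (a hypothetical helper, not part of the
+        // original code) restates the budget formula used further down:
+        fn _effective_budget(min_steps: f32, max_steps: usize, goodmoves_ratio: f32) -> usize {
+            // goodmoves_ratio 1.0  -> exp(0.0)  = 1.00 -> full max_steps
+            // goodmoves_ratio 0.75 -> exp(-1.0) ~ 0.37 -> ~37% of the range above min_steps
+            // goodmoves_ratio 0.5  -> exp(-2.0) ~ 0.14 -> close to min_steps
+            (min_steps + (max_steps as f32 - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25))
+                .round() as usize
+        }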
self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut terminated = false; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action:dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. 
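+                // Reward shaping used by this environment (gathered from the
+                // arms of this match and from step()):
+                //   Roll            +0.1, then +(points - adv_points) after RollResult
+                //   Go              +0.2
+                //   Move            +0.2
+                //   invalid action  reward = ERROR_REWARD (recognisable sentinel)
+                //   win / loss      +50.0 / -25.0 at the end of the episode
+                // The commented-out Mark arm is disabled in this revision.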
+ // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // reward -= Self::REWARD_RATIO * 
(points - adv_points) as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_big/environmentDiverge.rs b/bot/src/dqn/burnrl_big/environmentDiverge.rs new file mode 100644 index 0000000..6706163 --- /dev/null +++ b/bot/src/dqn/burnrl_big/environmentDiverge.rs @@ -0,0 +1,459 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType 
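+    // State encoding note: TrictracState packs the whole game into 36 i8
+    // values produced by GameState::to_vec(). from_game_state copies
+    // min(state_vec.len(), 36) entries, so a longer encoding is silently
+    // truncated and a shorter one leaves trailing zeros; State::size() must
+    // stay in sync with the input size given to dqn_model::Net::new.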
= f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let is_rollpoint; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. 
Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action: dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les 
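+                // (sentinel note) ERROR_REWARD = -1.12121 is deliberately a value
+                // no normal transition can produce: step() only counts a good move
+                // when the returned reward differs from it, so it doubles as an
+                // invalid-action marker in the statistics. The f32 != comparison is
+                // safe here because the sentinel is assigned verbatim, never computed.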
précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + // calculate_points = true; // comment to replicate burnrl_before + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).0, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).1, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_big/main.rs b/bot/src/dqn/burnrl_big/main.rs new file mode 100644 index 0000000..3b72ef8 --- /dev/null +++ b/bot/src/dqn/burnrl_big/main.rs @@ -0,0 +1,53 @@ +use bot::dqn::burnrl_big::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also 
MEMORY_SIZE in dqn_model.rs : 8192
+    let conf = dqn_model::DqnConfig {
+        // defaults
+        num_episodes: 40,  // 40
+        min_steps: 500.0,  // 1000 lower bound of the per-episode step budget (adjusted by the environment)
+        max_steps: 3000,   // 1000 max steps per episode
+        dense_size: 256,   // 128 neural network complexity (default 128)
+        eps_start: 0.9,    // 0.9 initial epsilon value (0.9 => more exploration)
+        eps_end: 0.05,     // 0.05
+        // higher eps_decay = slower epsilon decrease
+        // used as: epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
+        // epsilon is recomputed at each step from the global step counter
+        eps_decay: 2000.0, // 1000 ?
+
+        gamma: 0.999, // 0.999 discount factor; higher encourages long-term strategies
+        tau: 0.005,   // 0.005 soft update rate of the target network; lower = slower
+        // adaptation, less sensitive to lucky streaks
+        learning_rate: 0.001, // 0.001 step size; low = slower training, high = may never
+        // converge
+        batch_size: 32,   // 32 number of past experiences used to compute the mean error
+        clip_grad: 100.0, // 100 maximum correction applied to the gradient (default 100)
+    };
+    println!("{conf}----------");
+    let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
+
+    let valid_agent = agent.valid();
+
+    println!("> Sauvegarde du modèle de validation");
+
+    let path = "models/burn_dqn_40".to_string();
+    save_model(valid_agent.model().as_ref().unwrap(), &path);
+
+    println!("> Chargement du modèle pour test");
+    let loaded_model = load_model(conf.dense_size, &path);
+    let loaded_agent = DQN::new(loaded_model.unwrap());
+
+    println!("> Test avec le modèle chargé");
+    demo_model(loaded_agent);
+}
diff --git a/bot/src/dqn/burnrl_big/mod.rs b/bot/src/dqn/burnrl_big/mod.rs
new file mode 100644
index 0000000..f4380eb
--- /dev/null
+++ b/bot/src/dqn/burnrl_big/mod.rs
@@ -0,0 +1,3 @@
+pub mod dqn_model;
+pub mod environment;
+pub mod utils;
diff --git a/bot/src/dqn/burnrl_big/utils.rs b/bot/src/dqn/burnrl_big/utils.rs
new file mode 100644
index 0000000..9159d57
--- /dev/null
+++ b/bot/src/dqn/burnrl_big/utils.rs
@@ -0,0 +1,114 @@
+use crate::dqn::burnrl_big::{
+    dqn_model,
+    environment::{TrictracAction, TrictracEnvironment},
+};
+use crate::dqn::dqn_common_big::get_valid_action_indices;
+use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use burn::module::{Module, Param, ParamId};
+use burn::nn::Linear;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
+use burn::tensor::Tensor;
+use burn_rl::agent::{DQNModel, DQN};
+use burn_rl::base::{Action, ElemType, Environment, State};
+
+pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}_model.mpk");
+    println!("Modèle de validation sauvegardé : {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
+    let model_path = format!("{path}_model.mpk");
+    // println!("Chargement du modèle depuis : {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            dqn_model::Net::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
+
+pub fn demo_model<B: Backend>(agent: DQN<TrictracEnvironment, B, dqn_model::Net<B>>) {
+    let mut env = TrictracEnvironment::new(true);
+    let mut done = false;
+    while !done {
+        // let action = match infer_action(&agent, &env, state) {
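+        // infer_action below evaluates the network once, masks the q-values of
+        // invalid actions to -inf, then takes the argmax. Note that the mask is
+        // applied by *value* (mask_fill on equal_elem): if a valid and an invalid
+        // action happen to share exactly the same q-value, the valid one gets
+        // masked as well. Masking by index would avoid this edge case.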
let action = match infer_action(&agent, &env) { + Some(value) => value, + None => break, + }; + // Execute action + let snapshot = env.step(action); + done = snapshot.done(); + } +} + +fn infer_action>( + agent: &DQN, + env: &TrictracEnvironment, +) -> Option { + let state = env.state(); + // Get q-values + let q_values = agent + .model() + .as_ref() + .unwrap() + .infer(state.to_tensor().unsqueeze()); + // Get valid actions + let valid_actions_indices = get_valid_action_indices(&env.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + let action = TrictracAction::from(action_index); + Some(action) +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: ElemType, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: ElemType, +) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index ab75746..ebc01a4 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -1,4 +1,6 @@ pub mod burnrl; +pub mod burnrl_before; +pub mod burnrl_big; pub mod dqn_common; pub mod dqn_common_big; pub mod simple; From d313cb615163d9f80c32caa9cb87a5b9ef6d99e0 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 15 Aug 2025 21:08:23 +0200 Subject: [PATCH 7/7] burnrl_big like before --- bot/Cargo.toml | 4 - bot/src/dqn/burnrl_before/dqn_model.rs | 211 --------- bot/src/dqn/burnrl_before/environment.rs | 449 ------------------ bot/src/dqn/burnrl_before/main.rs | 53 --- bot/src/dqn/burnrl_before/mod.rs | 3 - bot/src/dqn/burnrl_before/utils.rs | 114 ----- bot/src/dqn/burnrl_big/environment.rs | 29 +- bot/src/dqn/burnrl_big/environmentDiverge.rs | 459 ------------------- bot/src/dqn/mod.rs | 1 - 9 files changed, 19 insertions(+), 1304 deletions(-) delete mode 100644 bot/src/dqn/burnrl_before/dqn_model.rs delete mode 100644 bot/src/dqn/burnrl_before/environment.rs delete mode 100644 bot/src/dqn/burnrl_before/main.rs delete mode 100644 bot/src/dqn/burnrl_before/mod.rs delete mode 100644 bot/src/dqn/burnrl_before/utils.rs delete mode 100644 bot/src/dqn/burnrl_big/environmentDiverge.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 4a0a95c..c043393 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,10 +13,6 @@ path = "src/dqn/burnrl_valid/main.rs" name = "train_dqn_burn_big" path = "src/dqn/burnrl_big/main.rs" -[[bin]] -name = "train_dqn_burn_before" -path = "src/dqn/burnrl_before/main.rs" - [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/src/dqn/burnrl_before/dqn_model.rs b/bot/src/dqn/burnrl_before/dqn_model.rs deleted file 
mode 100644 index 02646eb..0000000 --- a/bot/src/dqn/burnrl_before/dqn_model.rs +++ /dev/null @@ -1,211 +0,0 @@ -use crate::dqn::burnrl_before::environment::TrictracEnvironment; -use crate::dqn::burnrl_before::utils::soft_update_linear; -use burn::module::Module; -use burn::nn::{Linear, LinearConfig}; -use burn::optim::AdamWConfig; -use burn::tensor::activation::relu; -use burn::tensor::backend::{AutodiffBackend, Backend}; -use burn::tensor::Tensor; -use burn_rl::agent::DQN; -use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; -use std::fmt; -use std::time::SystemTime; - -#[derive(Module, Debug)] -pub struct Net { - linear_0: Linear, - linear_1: Linear, - linear_2: Linear, -} - -impl Net { - #[allow(unused)] - pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { - Self { - linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), - linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), - linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), - } - } - - fn consume(self) -> (Linear, Linear, Linear) { - (self.linear_0, self.linear_1, self.linear_2) - } -} - -impl Model, Tensor> for Net { - fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input)); - let layer_1_output = relu(self.linear_1.forward(layer_0_output)); - - relu(self.linear_2.forward(layer_1_output)) - } - - fn infer(&self, input: Tensor) -> Tensor { - self.forward(input) - } -} - -impl DQNModel for Net { - fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { - let (linear_0, linear_1, linear_2) = this.consume(); - - Self { - linear_0: soft_update_linear(linear_0, &that.linear_0, tau), - linear_1: soft_update_linear(linear_1, &that.linear_1, tau), - linear_2: soft_update_linear(linear_2, &that.linear_2, tau), - } - } -} - -#[allow(unused)] -const MEMORY_SIZE: usize = 8192; - -pub struct DqnConfig { - pub min_steps: f32, - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - pub eps_start: f64, - pub eps_end: f64, - pub eps_decay: f64, - - pub gamma: f32, - pub tau: f32, - pub learning_rate: f32, - pub batch_size: usize, - pub clip_grad: f32, -} - -impl fmt::Display for DqnConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("min_steps={:?}\n", self.min_steps)); - s.push_str(&format!("max_steps={:?}\n", self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("eps_start={:?}\n", self.eps_start)); - s.push_str(&format!("eps_end={:?}\n", self.eps_end)); - s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("tau={:?}\n", self.tau)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("batch_size={:?}\n", self.batch_size)); - s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); - write!(f, "{s}") - } -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - min_steps: 250.0, - max_steps: 2000, - num_episodes: 1000, - dense_size: 256, - eps_start: 0.9, - eps_end: 0.05, - eps_decay: 1000.0, - - gamma: 0.999, - tau: 0.005, - learning_rate: 0.001, - batch_size: 32, - clip_grad: 100.0, - } - } -} - -type MyAgent = DQN>; - -#[allow(unused)] -pub fn run, B: AutodiffBackend>( - conf: 
&DqnConfig, - visualized: bool, -) -> DQN> { - // ) -> impl Agent { - let mut env = E::new(visualized); - env.as_mut().min_steps = conf.min_steps; - env.as_mut().max_steps = conf.max_steps; - - let model = Net::::new( - <::StateType as State>::size(), - conf.dense_size, - <::ActionType as Action>::size(), - ); - - let mut agent = MyAgent::new(model); - - // let config = DQNTrainingConfig::default(); - let config = DQNTrainingConfig { - gamma: conf.gamma, - tau: conf.tau, - learning_rate: conf.learning_rate, - batch_size: conf.batch_size, - clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( - conf.clip_grad, - )), - }; - - let mut memory = Memory::::default(); - - let mut optimizer = AdamWConfig::new() - .with_grad_clipping(config.clip_grad.clone()) - .init(); - - let mut policy_net = agent.model().as_ref().unwrap().clone(); - - let mut step = 0_usize; - - for episode in 0..conf.num_episodes { - let mut episode_done = false; - let mut episode_reward: ElemType = 0.0; - let mut episode_duration = 0_usize; - let mut state = env.state(); - let mut now = SystemTime::now(); - - while !episode_done { - let eps_threshold = conf.eps_end - + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); - let action = - DQN::>::react_with_exploration(&policy_net, state, eps_threshold); - let snapshot = env.step(action); - - episode_reward += - <::RewardType as Into>::into(snapshot.reward().clone()); - - memory.push( - state, - *snapshot.state(), - action, - snapshot.reward().clone(), - snapshot.done(), - ); - - if config.batch_size < memory.len() { - policy_net = - agent.train::(policy_net, &memory, &mut optimizer, &config); - } - - step += 1; - episode_duration += 1; - - if snapshot.done() || episode_duration >= conf.max_steps { - let envmut = env.as_mut(); - println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", - envmut.goodmoves_count, - envmut.pointrolls_count, - now.elapsed().unwrap().as_secs(), - ); - env.reset(); - episode_done = true; - now = SystemTime::now(); - } else { - state = *snapshot.state(); - } - } - } - agent -} diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs deleted file mode 100644 index 9925a9a..0000000 --- a/bot/src/dqn/burnrl_before/environment.rs +++ /dev/null @@ -1,449 +0,0 @@ -use crate::dqn::dqn_common_big; -use burn::{prelude::Backend, tensor::Tensor}; -use burn_rl::base::{Action, Environment, Snapshot, State}; -use rand::{thread_rng, Rng}; -use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; - -/// État du jeu Trictrac pour burn-rl -#[derive(Debug, Clone, Copy)] -pub struct TrictracState { - pub data: [i8; 36], // Représentation vectorielle de l'état du jeu -} - -impl State for TrictracState { - type Data = [i8; 36]; - - fn to_tensor(&self) -> Tensor { - Tensor::from_floats(self.data, &B::Device::default()) - } - - fn size() -> usize { - 36 - } -} - -impl TrictracState { - /// Convertit un GameState en TrictracState - pub fn from_game_state(game_state: &GameState) -> Self { - let state_vec = game_state.to_vec(); - let mut data = [0; 36]; - - // Copier les données en s'assurant qu'on ne dépasse pas la taille - let copy_len = state_vec.len().min(36); - data[..copy_len].copy_from_slice(&state_vec[..copy_len]); - - TrictracState { data } - } -} - -/// Actions possibles dans Trictrac pour burn-rl -#[derive(Debug, Clone, Copy, 
PartialEq)] -pub struct TrictracAction { - // u32 as required by burn_rl::base::Action type - pub index: u32, -} - -impl Action for TrictracAction { - fn random() -> Self { - use rand::{thread_rng, Rng}; - let mut rng = thread_rng(); - TrictracAction { - index: rng.gen_range(0..Self::size() as u32), - } - } - - fn enumerate() -> Vec { - (0..Self::size() as u32) - .map(|index| TrictracAction { index }) - .collect() - } - - fn size() -> usize { - 1252 - } -} - -impl From for TrictracAction { - fn from(index: u32) -> Self { - TrictracAction { index } - } -} - -impl From for u32 { - fn from(action: TrictracAction) -> u32 { - action.index - } -} - -/// Environnement Trictrac pour burn-rl -#[derive(Debug)] -pub struct TrictracEnvironment { - pub game: GameState, - active_player_id: PlayerId, - opponent_id: PlayerId, - current_state: TrictracState, - episode_reward: f32, - pub step_count: usize, - pub min_steps: f32, - pub max_steps: usize, - pub pointrolls_count: usize, - pub goodmoves_count: usize, - pub goodmoves_ratio: f32, - pub visualized: bool, -} - -impl Environment for TrictracEnvironment { - type StateType = TrictracState; - type ActionType = TrictracAction; - type RewardType = f32; - - fn new(visualized: bool) -> Self { - let mut game = GameState::new(false); - - // Ajouter deux joueurs - game.init_player("DQN Agent"); - game.init_player("Opponent"); - let player1_id = 1; - let player2_id = 2; - - // Commencer la partie - game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - let current_state = TrictracState::from_game_state(&game); - TrictracEnvironment { - game, - active_player_id: player1_id, - opponent_id: player2_id, - current_state, - episode_reward: 0.0, - step_count: 0, - min_steps: 250.0, - max_steps: 2000, - pointrolls_count: 0, - goodmoves_count: 0, - goodmoves_ratio: 0.0, - visualized, - } - } - - fn state(&self) -> Self::StateType { - self.current_state - } - - fn reset(&mut self) -> Snapshot { - // Réinitialiser le jeu - self.game = GameState::new(false); - self.game.init_player("DQN Agent"); - self.game.init_player("Opponent"); - - // Commencer la partie - self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward = 0.0; - self.goodmoves_ratio = if self.step_count == 0 { - 0.0 - } else { - self.goodmoves_count as f32 / self.step_count as f32 - }; - println!( - "info: correct moves: {} ({}%)", - self.goodmoves_count, - (100.0 * self.goodmoves_ratio).round() as u32 - ); - self.step_count = 0; - self.pointrolls_count = 0; - self.goodmoves_count = 0; - - Snapshot::new(self.current_state, 0.0, false) - } - - fn step(&mut self, action: Self::ActionType) -> Snapshot { - self.step_count += 1; - - // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = Self::convert_action(action); - - let mut reward = 0.0; - let mut is_rollpoint = false; - let mut terminated = false; - - // Exécuter l'action si c'est le tour de l'agent DQN - if self.game.active_player_id == self.active_player_id { - if let Some(action) = trictrac_action { - (reward, is_rollpoint) = self.execute_action(action); - if is_rollpoint { - self.pointrolls_count += 1; - } - if reward != Self::ERROR_REWARD { - self.goodmoves_count += 1; - } - } else { - // Action non convertible, pénalité - reward = -0.5; - } - } - - // Faire jouer l'adversaire (stratégie simple) - while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - reward += self.play_opponent_if_needed(); - } - - 
// Vérifier si la partie est terminée - let max_steps = self.min_steps - + (self.max_steps as f32 - self.min_steps) - * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); - let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); - - if done { - // Récompense finale basée sur le résultat - if let Some(winner_id) = self.game.determine_winner() { - if winner_id == self.active_player_id { - reward += 50.0; // Victoire - } else { - reward -= 25.0; // Défaite - } - } - } - let terminated = done || self.step_count >= max_steps.round() as usize; - - // Mettre à jour l'état - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward += reward; - - if self.visualized && terminated { - println!( - "Episode terminé. Récompense totale: {:.2}, Étapes: {}", - self.episode_reward, self.step_count - ); - } - - Snapshot::new(self.current_state, reward, terminated) - } -} - -impl TrictracEnvironment { - const ERROR_REWARD: f32 = -1.12121; - const REWARD_RATIO: f32 = 1.0; - - /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) - } - - /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac - fn convert_valid_action_index( - &self, - action: TrictracAction, - game_state: &GameState, - ) -> Option { - use dqn_common_big::get_valid_actions; - - // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_actions(game_state); - - if valid_actions.is_empty() { - return None; - } - - // Mapper l'index d'action sur une action valide - let action_index = (action.index as usize) % valid_actions.len(); - Some(valid_actions[action_index].clone()) - } - - /// Exécute une action Trictrac dans le jeu - // fn execute_action( - // &mut self, - // action:dqn_common_big::TrictracAction, - // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { - use dqn_common_big::TrictracAction; - - let mut reward = 0.0; - let mut is_rollpoint = false; - - let event = match action { - TrictracAction::Roll => { - // Lancer les dés - reward += 0.1; - Some(GameEvent::Roll { - player_id: self.active_player_id, - }) - } - // TrictracAction::Mark => { - // // Marquer des points - // let points = self.game. 
- // reward += 0.1 * points as f32; - // Some(GameEvent::Mark { - // player_id: self.active_player_id, - // points, - // }) - // } - TrictracAction::Go => { - // Continuer après avoir gagné un trou - reward += 0.2; - Some(GameEvent::Go { - player_id: self.active_player_id, - }) - } - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let (dice1, dice2) = if dice_order { - (self.game.dice.values.0, self.game.dice.values.1) - } else { - (self.game.dice.values.1, self.game.dice.values.0) - }; - let mut to1 = from1 + dice1 as usize; - let mut to2 = from2 + dice2 as usize; - - // Gestion prise de coin par puissance - let opp_rest_field = 13; - if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; - } - - let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - - reward += 0.2; - Some(GameEvent::Move { - player_id: self.active_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - - // Appliquer l'événement si valide - if let Some(event) = event { - if self.game.validate(&event) { - self.game.consume(&event); - - // Simuler le résultat des dés après un Roll - if matches!(action, TrictracAction::Roll) { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - let dice_event = GameEvent::RollResult { - player_id: self.active_player_id, - dice: store::Dice { - values: dice_values, - }, - }; - if self.game.validate(&dice_event) { - self.game.consume(&dice_event); - let (points, adv_points) = self.game.dice_points; - reward += Self::REWARD_RATIO * (points - adv_points) as f32; - if points > 0 { - is_rollpoint = true; - // println!("info: rolled for {reward}"); - } - // Récompense proportionnelle aux points - } - } - } else { - // Pénalité pour action invalide - // on annule les précédents reward - // et on indique une valeur reconnaissable pour statistiques - reward = Self::ERROR_REWARD; - } - } - - (reward, is_rollpoint) - } - - /// Fait jouer l'adversaire avec une stratégie simple - fn play_opponent_if_needed(&mut self) -> f32 { - let mut reward = 0.0; - - // Si c'est le tour de l'adversaire, jouer automatiquement - if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - // Utiliser la stratégie default pour l'adversaire - use crate::BotStrategy; - - let mut strategy = crate::strategy::random::RandomStrategy::default(); - strategy.set_player_id(self.opponent_id); - if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { - strategy.set_color(color); - } - *strategy.get_mut_game() = self.game.clone(); - - // Exécuter l'action selon le turn_stage - let event = match self.game.turn_stage { - TurnStage::RollDice => GameEvent::Roll { - player_id: self.opponent_id, - }, - TurnStage::RollWaiting => { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - GameEvent::RollResult { - player_id: self.opponent_id, - dice: store::Dice { - values: dice_values, - }, - } - } - TurnStage::MarkPoints => { - panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); - let opponent_color = store::Color::Black; - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // reward -= Self::REWARD_RATIO * 
(points - adv_points) as f32; // Récompense proportionnelle aux points - - GameEvent::Mark { - player_id: self.opponent_id, - points, - } - } - TurnStage::MarkAdvPoints => { - let opponent_color = store::Color::Black; - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let points = points_rules.get_points(dice_roll_count).1; - // pas de reward : déjà comptabilisé lors du tour de blanc - GameEvent::Mark { - player_id: self.opponent_id, - points, - } - } - TurnStage::HoldOrGoChoice => { - // Stratégie simple : toujours continuer - GameEvent::Go { - player_id: self.opponent_id, - } - } - TurnStage::Move => GameEvent::Move { - player_id: self.opponent_id, - moves: strategy.choose_move(), - }, - }; - - if self.game.validate(&event) { - self.game.consume(&event); - } - } - reward - } -} - -impl AsMut for TrictracEnvironment { - fn as_mut(&mut self) -> &mut Self { - self - } -} diff --git a/bot/src/dqn/burnrl_before/main.rs b/bot/src/dqn/burnrl_before/main.rs deleted file mode 100644 index 602ff51..0000000 --- a/bot/src/dqn/burnrl_before/main.rs +++ /dev/null @@ -1,53 +0,0 @@ -use bot::dqn::burnrl_before::{ - dqn_model, environment, - utils::{demo_model, load_model, save_model}, -}; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::DQN; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; - -fn main() { - // println!("> Entraînement"); - - // See also MEMORY_SIZE in dqn_model.rs : 8192 - let conf = dqn_model::DqnConfig { - // defaults - num_episodes: 40, // 40 - min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 3000, // 1000 max steps by episode - dense_size: 256, // 128 neural network complexity (default 128) - eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) - eps_end: 0.05, // 0.05 - // eps_decay higher = epsilon decrease slower - // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); - // epsilon is updated at the start of each episode - eps_decay: 2000.0, // 1000 ? - - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
- clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) - }; - println!("{conf}----------"); - let agent = dqn_model::run::(&conf, false); //true); - - let valid_agent = agent.valid(); - - println!("> Sauvegarde du modèle de validation"); - - let path = "models/burn_dqn_40".to_string(); - save_model(valid_agent.model().as_ref().unwrap(), &path); - - println!("> Chargement du modèle pour test"); - let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); -} diff --git a/bot/src/dqn/burnrl_before/mod.rs b/bot/src/dqn/burnrl_before/mod.rs deleted file mode 100644 index f4380eb..0000000 --- a/bot/src/dqn/burnrl_before/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod dqn_model; -pub mod environment; -pub mod utils; diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs deleted file mode 100644 index 6c25c5d..0000000 --- a/bot/src/dqn/burnrl_before/utils.rs +++ /dev/null @@ -1,114 +0,0 @@ -use crate::dqn::burnrl_before::{ - dqn_model, - environment::{TrictracAction, TrictracEnvironment}, -}; -use crate::dqn::dqn_common_big::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}_model.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}_model.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - -pub fn demo_model>(agent: DQN) { - let mut env = TrictracEnvironment::new(true); - let mut done = false; - while !done { - // let action = match infer_action(&agent, &env, state) { - let action = match infer_action(&agent, &env) { - Some(value) => value, - None => break, - }; - // Execute action - let snapshot = env.step(action); - done = snapshot.done(); - } -} - -fn infer_action>( - agent: &DQN, - env: &TrictracEnvironment, -) -> Option { - let state = env.state(); - // Get q-values - let q_values = agent - .model() - .as_ref() - .unwrap() - .infer(state.to_tensor().unsqueeze()); - // Get valid actions - let valid_actions_indices = get_valid_action_indices(&env.game); - if valid_actions_indices.is_empty() { - return None; // No valid actions, end of episode - } - // Set non valid actions q-values to lowest - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions_indices.contains(&index) { - masked_q_values = masked_q_values.clone().mask_fill( - masked_q_values.clone().equal_elem(*q_value), - f32::NEG_INFINITY, - ); - } - } - // Get best action (highest q-value) - let 
diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs
index 9925a9a..ea5a9b4 100644
--- a/bot/src/dqn/burnrl_big/environment.rs
+++ b/bot/src/dqn/burnrl_big/environment.rs
@@ -165,8 +165,7 @@ impl Environment for TrictracEnvironment {
         let trictrac_action = Self::convert_action(action);
 
         let mut reward = 0.0;
-        let mut is_rollpoint = false;
-        let mut terminated = false;
+        let is_rollpoint;
 
         // Execute the action if it is the DQN agent's turn
         if self.game.active_player_id == self.active_player_id {
@@ -372,6 +371,8 @@ impl TrictracEnvironment {
         *strategy.get_mut_game() = self.game.clone();
 
         // Execute the action according to the turn_stage
+        let mut calculate_points = false;
+        let opponent_color = store::Color::Black;
         let event = match self.game.turn_stage {
             TurnStage::RollDice => GameEvent::Roll {
                 player_id: self.opponent_id,
@@ -379,6 +380,7 @@ impl TrictracEnvironment {
             TurnStage::RollWaiting => {
                 let mut rng = thread_rng();
                 let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
+                // calculate_points = true; // comment to replicate burnrl_before
                 GameEvent::RollResult {
                     player_id: self.opponent_id,
                     dice: store::Dice {
@@ -388,7 +390,6 @@ impl TrictracEnvironment {
             }
             TurnStage::MarkPoints => {
                 panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
-                let opponent_color = store::Color::Black;
                 let dice_roll_count = self
                     .game
                     .players
@@ -397,16 +398,12 @@ impl TrictracEnvironment {
                     .dice_roll_count;
                 let points_rules =
                     PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                let (points, adv_points) = points_rules.get_points(dice_roll_count);
-                // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Reward proportional to the points
-
                 GameEvent::Mark {
                     player_id: self.opponent_id,
-                    points,
+                    points: points_rules.get_points(dice_roll_count).0,
                 }
             }
             TurnStage::MarkAdvPoints => {
-                let opponent_color = store::Color::Black;
                 let dice_roll_count = self
                     .game
                     .players
@@ -415,11 +412,10 @@ impl TrictracEnvironment {
                     .dice_roll_count;
                 let points_rules =
                     PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                let points = points_rules.get_points(dice_roll_count).1;
                 // no reward: already accounted for during White's turn
                 GameEvent::Mark {
                     player_id: self.opponent_id,
-                    points,
+                    points: points_rules.get_points(dice_roll_count).1,
                 }
             }
             TurnStage::HoldOrGoChoice => {
@@ -436,6 +432,19 @@ impl TrictracEnvironment {
 
         if self.game.validate(&event) {
             self.game.consume(&event);
+            if calculate_points {
+                let dice_roll_count = self
+                    .game
+                    .players
+                    .get(&self.opponent_id)
+                    .unwrap()
+                    .dice_roll_count;
+                let points_rules =
+                    PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
+                let (points, adv_points) = points_rules.get_points(dice_roll_count);
+                // Reward proportional to the points scored
+                reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
+            }
         }
     }
     reward
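The new `calculate_points` block moves the opponent's point computation to *after* the `RollResult` event has been consumed, so `dice_roll_count` reflects the roll being scored; the shaping term itself is just the signed point difference. A minimal standalone sketch of that rule (`REWARD_RATIO` mirrors the `Self::REWARD_RATIO` constant, 1.0; everything else is illustrative):

```rust
// Opponent-turn reward shaping: the agent is penalised in proportion to
// the net points the opponent scores on its roll.
const REWARD_RATIO: f32 = 1.0;

fn opponent_shaping(points: i32, adv_points: i32) -> f32 {
    // Negative when the opponent outscores the agent on this roll.
    -REWARD_RATIO * (points - adv_points) as f32
}

fn main() {
    // Opponent rolls for 4 points, the agent would mark 1 from the same dice:
    assert_eq!(opponent_shaping(4, 1), -3.0);
    // The agent's own turn applies the symmetric positive term.
    println!("shaping = {}", opponent_shaping(4, 1));
}
```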
diff --git a/bot/src/dqn/burnrl_big/environmentDiverge.rs b/bot/src/dqn/burnrl_big/environmentDiverge.rs
deleted file mode 100644
index 6706163..0000000
--- a/bot/src/dqn/burnrl_big/environmentDiverge.rs
+++ /dev/null
@@ -1,459 +0,0 @@
-use crate::dqn::dqn_common_big;
-use burn::{prelude::Backend, tensor::Tensor};
-use burn_rl::base::{Action, Environment, Snapshot, State};
-use rand::{thread_rng, Rng};
-use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
-
-/// Trictrac game state for burn-rl
-#[derive(Debug, Clone, Copy)]
-pub struct TrictracState {
-    pub data: [i8; 36], // Vector representation of the game state
-}
-
-impl State for TrictracState {
-    type Data = [i8; 36];
-
-    fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
-        Tensor::from_floats(self.data, &B::Device::default())
-    }
-
-    fn size() -> usize {
-        36
-    }
-}
-
-impl TrictracState {
-    /// Converts a GameState into a TrictracState
-    pub fn from_game_state(game_state: &GameState) -> Self {
-        let state_vec = game_state.to_vec();
-        let mut data = [0; 36];
-
-        // Copy the data, making sure not to overrun the array
-        let copy_len = state_vec.len().min(36);
-        data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
-
-        TrictracState { data }
-    }
-}
-
-/// Possible Trictrac actions for burn-rl
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub struct TrictracAction {
-    // u32 as required by burn_rl::base::Action type
-    pub index: u32,
-}
-
-impl Action for TrictracAction {
-    fn random() -> Self {
-        use rand::{thread_rng, Rng};
-        let mut rng = thread_rng();
-        TrictracAction {
-            index: rng.gen_range(0..Self::size() as u32),
-        }
-    }
-
-    fn enumerate() -> Vec<Self> {
-        (0..Self::size() as u32)
-            .map(|index| TrictracAction { index })
-            .collect()
-    }
-
-    fn size() -> usize {
-        1252
-    }
-}
-
-impl From<u32> for TrictracAction {
-    fn from(index: u32) -> Self {
-        TrictracAction { index }
-    }
-}
-
-impl From<TrictracAction> for u32 {
-    fn from(action: TrictracAction) -> u32 {
-        action.index
-    }
-}
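The `from_game_state` helper above fixes the state width at 36 slots regardless of the length of the underlying vector. A self-contained sketch of that padding/truncation behaviour (`STATE_SIZE` mirrors `State::size()`; the helper name is illustrative):

```rust
// Fixed-size state encoding: the game's variable-length vector is copied
// into a zero-padded array; any extra entries are dropped.
const STATE_SIZE: usize = 36;

fn encode(state_vec: &[i8]) -> [i8; STATE_SIZE] {
    let mut data = [0i8; STATE_SIZE];
    let copy_len = state_vec.len().min(STATE_SIZE);
    data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
    data
}

fn main() {
    // A short vector is zero-padded on the right...
    assert_eq!(encode(&[1, 2, 3])[..4], [1, 2, 3, 0]);
    // ...and an over-long one is silently truncated to 36 entries.
    let long = [7i8; 40];
    assert_eq!(encode(&long), [7i8; STATE_SIZE]);
}
```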
-
-/// Trictrac environment for burn-rl
-#[derive(Debug)]
-pub struct TrictracEnvironment {
-    pub game: GameState,
-    active_player_id: PlayerId,
-    opponent_id: PlayerId,
-    current_state: TrictracState,
-    episode_reward: f32,
-    pub step_count: usize,
-    pub min_steps: f32,
-    pub max_steps: usize,
-    pub pointrolls_count: usize,
-    pub goodmoves_count: usize,
-    pub goodmoves_ratio: f32,
-    pub visualized: bool,
-}
-
-impl Environment for TrictracEnvironment {
-    type StateType = TrictracState;
-    type ActionType = TrictracAction;
-    type RewardType = f32;
-
-    fn new(visualized: bool) -> Self {
-        let mut game = GameState::new(false);
-
-        // Add the two players
-        game.init_player("DQN Agent");
-        game.init_player("Opponent");
-        let player1_id = 1;
-        let player2_id = 2;
-
-        // Start the game
-        game.consume(&GameEvent::BeginGame { goes_first: 1 });
-
-        let current_state = TrictracState::from_game_state(&game);
-        TrictracEnvironment {
-            game,
-            active_player_id: player1_id,
-            opponent_id: player2_id,
-            current_state,
-            episode_reward: 0.0,
-            step_count: 0,
-            min_steps: 250.0,
-            max_steps: 2000,
-            pointrolls_count: 0,
-            goodmoves_count: 0,
-            goodmoves_ratio: 0.0,
-            visualized,
-        }
-    }
-
-    fn state(&self) -> Self::StateType {
-        self.current_state
-    }
-
-    fn reset(&mut self) -> Snapshot<Self::StateType> {
-        // Reset the game
-        self.game = GameState::new(false);
-        self.game.init_player("DQN Agent");
-        self.game.init_player("Opponent");
-
-        // Start the game
-        self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
-
-        self.current_state = TrictracState::from_game_state(&self.game);
-        self.episode_reward = 0.0;
-        self.goodmoves_ratio = if self.step_count == 0 {
-            0.0
-        } else {
-            self.goodmoves_count as f32 / self.step_count as f32
-        };
-        println!(
-            "info: correct moves: {} ({}%)",
-            self.goodmoves_count,
-            (100.0 * self.goodmoves_ratio).round() as u32
-        );
-        self.step_count = 0;
-        self.pointrolls_count = 0;
-        self.goodmoves_count = 0;
-
-        Snapshot::new(self.current_state, 0.0, false)
-    }
-
-    fn step(&mut self, action: Self::ActionType) -> Snapshot<Self::StateType> {
-        self.step_count += 1;
-
-        // Convert the burn-rl action into a Trictrac action
-        let trictrac_action = Self::convert_action(action);
-
-        let mut reward = 0.0;
-        let is_rollpoint;
-
-        // Execute the action if it is the DQN agent's turn
-        if self.game.active_player_id == self.active_player_id {
-            if let Some(action) = trictrac_action {
-                (reward, is_rollpoint) = self.execute_action(action);
-                if is_rollpoint {
-                    self.pointrolls_count += 1;
-                }
-                if reward != Self::ERROR_REWARD {
-                    self.goodmoves_count += 1;
-                }
-            } else {
-                // Unconvertible action: penalty
-                reward = -0.5;
-            }
-        }
-
-        // Let the opponent play (simple strategy)
-        while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
-            reward += self.play_opponent_if_needed();
-        }
-
-        // Check whether the game is over
-        let max_steps = self.min_steps
-            + (self.max_steps as f32 - self.min_steps)
-                * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
-        let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
-
-        if done {
-            // Final reward based on the result
-            if let Some(winner_id) = self.game.determine_winner() {
-                if winner_id == self.active_player_id {
-                    reward += 50.0; // Win
-                } else {
-                    reward -= 25.0; // Loss
-                }
-            }
-        }
-        let terminated = done || self.step_count >= max_steps.round() as usize;
-
-        // Update the state
-        self.current_state = TrictracState::from_game_state(&self.game);
-        self.episode_reward += reward;
-
-        if self.visualized && terminated {
-            println!(
-                "Episode finished. Total reward: {:.2}, steps: {}",
-                self.episode_reward, self.step_count
-            );
-        }
-
-        Snapshot::new(self.current_state, reward, terminated)
-    }
-}
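The episode cap computed in `step()` above grows exponentially with the share of valid moves (`goodmoves_ratio`), so episodes are allowed to run longer as the agent learns to play legally. A standalone sketch of that schedule, using the defaults from `new()` (min_steps = 250, max_steps = 2000):

```rust
// Adaptive episode cap: interpolates between min_steps and max_steps
// as goodmoves_ratio approaches 1, with an exp((ratio - 1) / 0.25) shape.
fn episode_cap(min_steps: f32, max_steps: usize, goodmoves_ratio: f32) -> usize {
    let cap =
        min_steps + (max_steps as f32 - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25);
    cap.round() as usize
}

fn main() {
    // ratio 0.0 -> ~282 steps, ratio 0.5 -> ~487, ratio 1.0 -> 2000
    for ratio in [0.0, 0.5, 1.0] {
        println!("ratio {ratio}: cap = {}", episode_cap(250.0, 2000, ratio));
    }
}
```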
-
-impl TrictracEnvironment {
-    const ERROR_REWARD: f32 = -1.12121;
-    const REWARD_RATIO: f32 = 1.0;
-
-    /// Converts a burn-rl action into a Trictrac action
-    pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
-        dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
-    }
-
-    /// Converts an index into the set of valid actions into a Trictrac action
-    fn convert_valid_action_index(
-        &self,
-        action: TrictracAction,
-        game_state: &GameState,
-    ) -> Option<dqn_common_big::TrictracAction> {
-        use dqn_common_big::get_valid_actions;
-
-        // Get the valid actions in the current context
-        let valid_actions = get_valid_actions(game_state);
-
-        if valid_actions.is_empty() {
-            return None;
-        }
-
-        // Map the action index onto a valid action
-        let action_index = (action.index as usize) % valid_actions.len();
-        Some(valid_actions[action_index].clone())
-    }
-
-    /// Executes a Trictrac action in the game
-    // fn execute_action(
-    //     &mut self,
-    //     action: dqn_common_big::TrictracAction,
-    // ) -> Result<(f32, bool), Box<dyn Error>> {
-    fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
-        use dqn_common_big::TrictracAction;
-
-        let mut reward = 0.0;
-        let mut is_rollpoint = false;
-
-        let event = match action {
-            TrictracAction::Roll => {
-                // Roll the dice
-                reward += 0.1;
-                Some(GameEvent::Roll {
-                    player_id: self.active_player_id,
-                })
-            }
-            // TrictracAction::Mark => {
-            //     // Mark points
-            //     let points = self.game.
-            //     reward += 0.1 * points as f32;
-            //     Some(GameEvent::Mark {
-            //         player_id: self.active_player_id,
-            //         points,
-            //     })
-            // }
-            TrictracAction::Go => {
-                // Keep going after winning a hole
-                reward += 0.2;
-                Some(GameEvent::Go {
-                    player_id: self.active_player_id,
-                })
-            }
-            TrictracAction::Move {
-                dice_order,
-                from1,
-                from2,
-            } => {
-                // Perform a move
-                let (dice1, dice2) = if dice_order {
-                    (self.game.dice.values.0, self.game.dice.values.1)
-                } else {
-                    (self.game.dice.values.1, self.game.dice.values.0)
-                };
-                let mut to1 = from1 + dice1 as usize;
-                let mut to2 = from2 + dice2 as usize;
-
-                // Handle taking the corner by power
-                let opp_rest_field = 13;
-                if to1 == opp_rest_field && to2 == opp_rest_field {
-                    to1 -= 1;
-                    to2 -= 1;
-                }
-
-                let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
-                let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
-
-                reward += 0.2;
-                Some(GameEvent::Move {
-                    player_id: self.active_player_id,
-                    moves: (checker_move1, checker_move2),
-                })
-            }
-        };
-
-        // Apply the event if valid
-        if let Some(event) = event {
-            if self.game.validate(&event) {
-                self.game.consume(&event);
-
-                // Simulate the dice result after a Roll
-                if matches!(action, TrictracAction::Roll) {
-                    let mut rng = thread_rng();
-                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                    let dice_event = GameEvent::RollResult {
-                        player_id: self.active_player_id,
-                        dice: store::Dice {
-                            values: dice_values,
-                        },
-                    };
-                    if self.game.validate(&dice_event) {
-                        self.game.consume(&dice_event);
-                        let (points, adv_points) = self.game.dice_points;
-                        reward += Self::REWARD_RATIO * (points - adv_points) as f32;
-                        if points > 0 {
-                            is_rollpoint = true;
-                            // println!("info: rolled for {reward}");
-                        }
-                        // Reward proportional to the points scored
-                    }
-                }
-            } else {
-                // Penalty for an invalid action:
-                // cancel the previous rewards and use a recognizable
-                // value for statistics
-                reward = Self::ERROR_REWARD;
-            }
-        }
-
-        (reward, is_rollpoint)
-    }
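A standalone sketch of the move decoding performed by the `Move` arm above: the action carries the dice order and two source squares, destinations are source + die value, and when both moves land on square 13 (named `opp_rest_field` in the original, the opponent's rest corner) the "corner by power" rule diverts both checkers to square 12. Function and parameter names are illustrative:

```rust
// Decode a Move action into two (from, to) checker moves.
fn decode_move(
    dice: (u8, u8),
    dice_order: bool,
    from1: usize,
    from2: usize,
) -> ((usize, usize), (usize, usize)) {
    // dice_order picks which die applies to which source square.
    let (d1, d2) = if dice_order { dice } else { (dice.1, dice.0) };
    let (mut to1, mut to2) = (from1 + d1 as usize, from2 + d2 as usize);
    // Taking the corner by power: both checkers aimed at the opponent's
    // rest corner (13) are diverted to the player's own corner (12).
    let opp_rest_field = 13;
    if to1 == opp_rest_field && to2 == opp_rest_field {
        to1 -= 1;
        to2 -= 1;
    }
    ((from1, to1), (from2, to2))
}

fn main() {
    // Dice (5, 3), both moves aimed at square 13: diverted to 12.
    assert_eq!(decode_move((5, 3), true, 8, 10), ((8, 12), (10, 12)));
    // Otherwise the destinations are just source + die value.
    assert_eq!(decode_move((5, 3), false, 1, 2), ((1, 4), (2, 7)));
}
```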
-
-    /// Plays the opponent using a simple strategy
-    fn play_opponent_if_needed(&mut self) -> f32 {
-        let mut reward = 0.0;
-
-        // If it is the opponent's turn, play automatically
-        if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
-            // Use the default strategy for the opponent
-            use crate::BotStrategy;
-
-            let mut strategy = crate::strategy::random::RandomStrategy::default();
-            strategy.set_player_id(self.opponent_id);
-            if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
-                strategy.set_color(color);
-            }
-            *strategy.get_mut_game() = self.game.clone();
-
-            // Execute the action according to the turn_stage
-            let mut calculate_points = false;
-            let opponent_color = store::Color::Black;
-            let event = match self.game.turn_stage {
-                TurnStage::RollDice => GameEvent::Roll {
-                    player_id: self.opponent_id,
-                },
-                TurnStage::RollWaiting => {
-                    let mut rng = thread_rng();
-                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                    // calculate_points = true; // comment to replicate burnrl_before
-                    GameEvent::RollResult {
-                        player_id: self.opponent_id,
-                        dice: store::Dice {
-                            values: dice_values,
-                        },
-                    }
-                }
-                TurnStage::MarkPoints => {
-                    panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
-                    let dice_roll_count = self
-                        .game
-                        .players
-                        .get(&self.opponent_id)
-                        .unwrap()
-                        .dice_roll_count;
-                    let points_rules =
-                        PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    GameEvent::Mark {
-                        player_id: self.opponent_id,
-                        points: points_rules.get_points(dice_roll_count).0,
-                    }
-                }
-                TurnStage::MarkAdvPoints => {
-                    let dice_roll_count = self
-                        .game
-                        .players
-                        .get(&self.opponent_id)
-                        .unwrap()
-                        .dice_roll_count;
-                    let points_rules =
-                        PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    // no reward: already accounted for during White's turn
-                    GameEvent::Mark {
-                        player_id: self.opponent_id,
-                        points: points_rules.get_points(dice_roll_count).1,
-                    }
-                }
-                TurnStage::HoldOrGoChoice => {
-                    // Simple strategy: always keep going
-                    GameEvent::Go {
-                        player_id: self.opponent_id,
-                    }
-                }
-                TurnStage::Move => GameEvent::Move {
-                    player_id: self.opponent_id,
-                    moves: strategy.choose_move(),
-                },
-            };
-
-            if self.game.validate(&event) {
-                self.game.consume(&event);
-                if calculate_points {
-                    let dice_roll_count = self
-                        .game
-                        .players
-                        .get(&self.opponent_id)
-                        .unwrap()
-                        .dice_roll_count;
-                    let points_rules =
-                        PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    let (points, adv_points) = points_rules.get_points(dice_roll_count);
-                    // Reward proportional to the points scored
-                    reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
-                }
-            }
-        }
-        reward
-    }
-}
-
-impl AsMut<TrictracEnvironment> for TrictracEnvironment {
-    fn as_mut(&mut self) -> &mut Self {
-        self
-    }
-}
diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs
index ebc01a4..7b12487 100644
--- a/bot/src/dqn/mod.rs
+++ b/bot/src/dqn/mod.rs
@@ -1,5 +1,4 @@
 pub mod burnrl;
-pub mod burnrl_before;
 pub mod burnrl_big;
 pub mod dqn_common;
 pub mod dqn_common_big;