diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 135deae..68ff52d 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -5,10 +5,6 @@ edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
-[[bin]]
-name = "train_dqn_burn_valid"
-path = "src/dqn/burnrl_valid/main.rs"
-
 [[bin]]
 name = "train_dqn_burn"
 path = "src/dqn/burnrl/main.rs"
diff --git a/bot/scripts/trainValid.sh b/bot/scripts/trainValid.sh
deleted file mode 100755
index 349517d..0000000
--- a/bot/scripts/trainValid.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env sh
-
-ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
-LOGS_DIR="$ROOT/bot/models/logs"
-
-CFG_SIZE=11
-OPPONENT="random"
-
-PLOT_EXT="png"
-
-train() {
-    cargo build --release --bin=train_dqn_burn_valid
-    NAME="trainValid_$(date +%Y-%m-%d_%H:%M:%S)"
-    LOGS="$LOGS_DIR/$NAME.out"
-    mkdir -p "$LOGS_DIR"
-    LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn_valid" | tee "$LOGS"
-}
-
-plot() {
-    NAME=$(ls "$LOGS_DIR" | tail -n 1)
-    LOGS="$LOGS_DIR/$NAME"
-    cfgs=$(head -n $CFG_SIZE "$LOGS")
-    for cfg in $cfgs; do
-        eval "$cfg"
-    done
-
-    # tail -n +$((CFG_SIZE + 2)) "$LOGS"
-    tail -n +$((CFG_SIZE + 2)) "$LOGS" |
-        grep -v "info:" |
-        awk -F '[ ,]' '{print $5}' |
-        feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
-}
-
-if [ "$1" = "plot" ]; then
-    plot
-else
-    train
-fi
diff --git a/bot/src/dqn/burnrl/dqn_model.rs b/bot/src/dqn/burnrl/dqn_model.rs
index 3e90904..9cf72a1 100644
--- a/bot/src/dqn/burnrl/dqn_model.rs
+++ b/bot/src/dqn/burnrl/dqn_model.rs
@@ -192,15 +192,13 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
             episode_duration += 1;
 
             if snapshot.done() || episode_duration >= conf.max_steps {
-                let envmut = env.as_mut();
-                println!(
-                    "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
-                    envmut.goodmoves_count,
-                    envmut.pointrolls_count,
-                    now.elapsed().unwrap().as_secs(),
-                );
                 env.reset();
                 episode_done = true;
+
+                println!(
+                    "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
+                    now.elapsed().unwrap().as_secs(),
+                );
                 now = SystemTime::now();
             } else {
                 state = *snapshot.state();
diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs
index a774b12..5cc37c4 100644
--- a/bot/src/dqn/burnrl/environment.rs
+++ b/bot/src/dqn/burnrl/environment.rs
@@ -86,7 +86,6 @@ pub struct TrictracEnvironment {
     pub step_count: usize,
     pub min_steps: f32,
     pub max_steps: usize,
-    pub pointrolls_count: usize,
     pub goodmoves_count: usize,
     pub goodmoves_ratio: f32,
     pub visualized: bool,
@@ -119,7 +118,6 @@ impl Environment for TrictracEnvironment {
             step_count: 0,
             min_steps: 250.0,
             max_steps: 2000,
-            pointrolls_count: 0,
             goodmoves_count: 0,
             goodmoves_ratio: 0.0,
             visualized,
@@ -152,7 +150,6 @@ impl Environment for TrictracEnvironment {
             (100.0 * self.goodmoves_ratio).round() as u32
         );
         self.step_count = 0;
-        self.pointrolls_count = 0;
        self.goodmoves_count = 0;
 
        Snapshot::new(self.current_state, 0.0, false)
@@ -165,16 +162,12 @@
         let trictrac_action = Self::convert_action(action);
         let mut reward = 0.0;
-        let mut is_rollpoint = false;
         let mut terminated = false;
 
         // Exécuter l'action si c'est le tour de l'agent DQN
         if self.game.active_player_id == self.active_player_id {
             if let Some(action) = trictrac_action {
-                (reward, is_rollpoint) = self.execute_action(action);
-                if is_rollpoint {
-                    self.pointrolls_count += 1;
-                }
+                reward = self.execute_action(action);
                 if reward != Self::ERROR_REWARD {
                     self.goodmoves_count += 1;
                 }
             }
@@ -256,11 +249,10 @@
     //     &mut self,
     //     action: dqn_common::TrictracAction,
     // ) -> Result<f32, Box<dyn std::error::Error>> {
-    fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
+    fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
         use dqn_common::TrictracAction;
 
         let mut reward = 0.0;
-        let mut is_rollpoint = false;
 
         let event = match action {
             TrictracAction::Roll => {
@@ -338,8 +330,7 @@
                         let (points, adv_points) = self.game.dice_points;
                         reward += Self::REWARD_RATIO * (points - adv_points) as f32;
                         if points > 0 {
-                            is_rollpoint = true;
-                            // println!("info: rolled for {reward}");
+                            println!("info: rolled for {reward}");
                         }
                         // Récompense proportionnelle aux points
                     }
@@ -352,7 +343,7 @@
             }
         }
 
-        (reward, is_rollpoint)
+        reward
     }
 
     /// Fait jouer l'adversaire avec une stratégie simple
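
Note on the contract above: with `is_rollpoint` gone, `execute_action` now returns a bare `f32`, and the only way `step()` can tell a rejected action from a legal one is the `ERROR_REWARD` sentinel (`-1.12121`). A minimal sketch of the invariant these hunks rely on; the helper function is hypothetical, only the constant and the counting rule come from the patch:

```rust
const ERROR_REWARD: f32 = -1.12121;

/// Counts a move as "good" unless execute_action returned the sentinel.
/// The sentinel must be propagated unchanged: any arithmetic applied to the
/// reward before this check would silently break the good-move statistics.
fn record_move(reward: f32, goodmoves_count: &mut usize) {
    if reward != ERROR_REWARD {
        *goodmoves_count += 1;
    }
}
```
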
diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs
index d8b200f..dbd6e53 100644
--- a/bot/src/dqn/burnrl/main.rs
+++ b/bot/src/dqn/burnrl/main.rs
@@ -14,25 +14,24 @@ fn main() {
 
     // See also MEMORY_SIZE in dqn_model.rs : 8192
     let conf = dqn_model::DqnConfig {
-        // defaults
-        num_episodes: 40,  // 40
-        min_steps: 500.0,  // 1000 min of max steps by episode (mise à jour par la fonction)
-        max_steps: 3000,   // 1000 max steps by episode
-        dense_size: 256,   // 128 neural network complexity (default 128)
-        eps_start: 0.9,    // 0.9 epsilon initial value (0.9 => more exploration)
-        eps_end: 0.05,     // 0.05
+        num_episodes: 40, // default : 40
+        min_steps: 250.0, // min of max steps by episode (mise à jour par la fonction)(default 1000 ?)
+        max_steps: 3000,  // max steps by episode (default 1000 ?)
+        dense_size: 256,  // neural network complexity (default 128)
+        eps_start: 0.9,   // epsilon initial value (0.9 => more exploration) (default 0.9)
+        eps_end: 0.05,    // (default 0.05)
         // eps_decay higher = epsilon decrease slower
         // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
         // epsilon is updated at the start of each episode
-        eps_decay: 2000.0, // 1000 ?
+        eps_decay: 5000.0, // default 1000 ?
 
-        gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
-        tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
+        gamma: 0.999, // discount factor. Plus élevé = encourage stratégies à long terme
+        tau: 0.005,   // soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
         // plus lente moins sensible aux coups de chance
-        learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
+        learning_rate: 0.001, // taille du pas. Bas : plus lent, haut : risque de ne jamais
         // converger
-        batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
-        clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
+        batch_size: 32,  // nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
+        clip_grad: 50.0, // limite max de correction à apporter au gradient (default 100)
     };
     println!("{conf}----------");
     let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
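
Two hyperparameter changes ride along with the comment cleanup above: `eps_decay` goes from 2000 to 5000 and `clip_grad` from 100 to 50. Using the schedule quoted verbatim in the config comments, a standalone sketch (not part of the patch) of what the decay change means in practice:

```rust
/// epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
fn epsilon(step: usize, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
    eps_end + (eps_start - eps_end) * f64::exp(-(step as f64) / eps_decay)
}

fn main() {
    // With eps_decay = 2000.0, exploration is nearly exhausted by step 10_000...
    assert!(epsilon(10_000, 0.9, 0.05, 2000.0) < 0.06);
    // ...while 5000.0 still leaves epsilon around 0.16 at the same step,
    // so the agent keeps exploring for considerably longer.
    assert!(epsilon(10_000, 0.9, 0.05, 5000.0) > 0.15);
}
```
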
diff --git a/bot/src/dqn/burnrl_valid/dqn_model.rs b/bot/src/dqn/burnrl_valid/dqn_model.rs
deleted file mode 100644
index 4dd5180..0000000
--- a/bot/src/dqn/burnrl_valid/dqn_model.rs
+++ /dev/null
@@ -1,206 +0,0 @@
-use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
-use crate::dqn::burnrl_valid::utils::soft_update_linear;
-use burn::module::Module;
-use burn::nn::{Linear, LinearConfig};
-use burn::optim::AdamWConfig;
-use burn::tensor::activation::relu;
-use burn::tensor::backend::{AutodiffBackend, Backend};
-use burn::tensor::Tensor;
-use burn_rl::agent::DQN;
-use burn_rl::agent::{DQNModel, DQNTrainingConfig};
-use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
-use std::fmt;
-use std::time::SystemTime;
-
-#[derive(Module, Debug)]
-pub struct Net<B: Backend> {
-    linear_0: Linear<B>,
-    linear_1: Linear<B>,
-    linear_2: Linear<B>,
-}
-
-impl<B: Backend> Net<B> {
-    #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
-        Self {
-            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
-            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
-            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
-        }
-    }
-
-    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
-        (self.linear_0, self.linear_1, self.linear_2)
-    }
-}
-
-impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
-    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
-        let layer_0_output = relu(self.linear_0.forward(input));
-        let layer_1_output = relu(self.linear_1.forward(layer_0_output));
-
-        relu(self.linear_2.forward(layer_1_output))
-    }
-
-    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
-        self.forward(input)
-    }
-}
-
-impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = this.consume();
-
-        Self {
-            linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
-            linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
-            linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
-        }
-    }
-}
-
-#[allow(unused)]
-const MEMORY_SIZE: usize = 8192;
-
-pub struct DqnConfig {
-    pub max_steps: usize,
-    pub num_episodes: usize,
-    pub dense_size: usize,
-    pub eps_start: f64,
-    pub eps_end: f64,
-    pub eps_decay: f64,
-
-    pub gamma: f32,
-    pub tau: f32,
-    pub learning_rate: f32,
-    pub batch_size: usize,
-    pub clip_grad: f32,
-}
-
-impl fmt::Display for DqnConfig {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        let mut s = String::new();
-        s.push_str(&format!("max_steps={:?}\n", self.max_steps));
-        s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
-        s.push_str(&format!("dense_size={:?}\n", self.dense_size));
-        s.push_str(&format!("eps_start={:?}\n", self.eps_start));
-        s.push_str(&format!("eps_end={:?}\n", self.eps_end));
-        s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
-        s.push_str(&format!("gamma={:?}\n", self.gamma));
-        s.push_str(&format!("tau={:?}\n", self.tau));
-        s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
-        s.push_str(&format!("batch_size={:?}\n", self.batch_size));
-        s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
-        write!(f, "{s}")
-    }
-}
-
-impl Default for DqnConfig {
-    fn default() -> Self {
-        Self {
-            max_steps: 2000,
-            num_episodes: 1000,
-            dense_size: 256,
-            eps_start: 0.9,
-            eps_end: 0.05,
-            eps_decay: 1000.0,
-
-            gamma: 0.999,
-            tau: 0.005,
-            learning_rate: 0.001,
-            batch_size: 32,
-            clip_grad: 100.0,
-        }
-    }
-}
-
-type MyAgent<E, B> = DQN<E, B, Net<B>>;
-
-#[allow(unused)]
-pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
-    conf: &DqnConfig,
-    visualized: bool,
-) -> DQN<E, B, Net<B>> {
-    // ) -> impl Agent<E> {
-    let mut env = E::new(visualized);
-    env.as_mut().max_steps = conf.max_steps;
-
-    let model = Net::<B>::new(
-        <<E as Environment>::StateType as State>::size(),
-        conf.dense_size,
-        <<E as Environment>::ActionType as Action>::size(),
-    );
-
-    let mut agent = MyAgent::new(model);
-
-    // let config = DQNTrainingConfig::default();
-    let config = DQNTrainingConfig {
-        gamma: conf.gamma,
-        tau: conf.tau,
-        learning_rate: conf.learning_rate,
-        batch_size: conf.batch_size,
-        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
-            conf.clip_grad,
-        )),
-    };
-
-    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
-
-    let mut optimizer = AdamWConfig::new()
-        .with_grad_clipping(config.clip_grad.clone())
-        .init();
-
-    let mut policy_net = agent.model().as_ref().unwrap().clone();
-
-    let mut step = 0_usize;
-
-    for episode in 0..conf.num_episodes {
-        let mut episode_done = false;
-        let mut episode_reward: ElemType = 0.0;
-        let mut episode_duration = 0_usize;
-        let mut state = env.state();
-        let mut now = SystemTime::now();
-
-        while !episode_done {
-            let eps_threshold = conf.eps_end
-                + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
-            let action =
-                DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
-            let snapshot = env.step(action);
-
-            episode_reward +=
-                <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
-
-            memory.push(
-                state,
-                *snapshot.state(),
-                action,
-                snapshot.reward().clone(),
-                snapshot.done(),
-            );
-
-            if config.batch_size < memory.len() {
-                policy_net =
-                    agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
-            }
-
-            step += 1;
-            episode_duration += 1;
-
-            if snapshot.done() || episode_duration >= conf.max_steps {
-                let envmut = env.as_mut();
-                println!(
-                    "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
-                    envmut.pointrolls_count,
-                    now.elapsed().unwrap().as_secs(),
-                );
-                env.reset();
-                episode_done = true;
-                now = SystemTime::now();
-            } else {
-                state = *snapshot.state();
-            }
-        }
-    }
-    agent
-}
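
The deleted trainer above is structurally identical to the retained `bot/src/dqn/burnrl/dqn_model.rs`; the DQN-specific piece it delegates to is `soft_update_linear` (from the `utils.rs` removed further down), the usual Polyak averaging of the target network toward the policy network. Reduced to plain slices, the rule is the following sketch (assuming, as the imports suggest, that the retained variant keeps an identical helper):

```rust
/// target' = target * (1 - tau) + policy * tau
/// With tau = 0.005 the target network trails the policy network slowly,
/// which is what keeps the TD targets stable between updates.
fn soft_update(target: &mut [f32], policy: &[f32], tau: f32) {
    for (t, p) in target.iter_mut().zip(policy) {
        *t = *t * (1.0 - tau) + *p * tau;
    }
}
```
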
diff --git a/bot/src/dqn/burnrl_valid/environment.rs b/bot/src/dqn/burnrl_valid/environment.rs
deleted file mode 100644
index 93e6c14..0000000
--- a/bot/src/dqn/burnrl_valid/environment.rs
+++ /dev/null
@@ -1,422 +0,0 @@
-use crate::dqn::dqn_common;
-use burn::{prelude::Backend, tensor::Tensor};
-use burn_rl::base::{Action, Environment, Snapshot, State};
-use rand::{thread_rng, Rng};
-use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
-
-/// État du jeu Trictrac pour burn-rl
-#[derive(Debug, Clone, Copy)]
-pub struct TrictracState {
-    pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
-}
-
-impl State for TrictracState {
-    type Data = [i8; 36];
-
-    fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
-        Tensor::from_floats(self.data, &B::Device::default())
-    }
-
-    fn size() -> usize {
-        36
-    }
-}
-
-impl TrictracState {
-    /// Convertit un GameState en TrictracState
-    pub fn from_game_state(game_state: &GameState) -> Self {
-        let state_vec = game_state.to_vec();
-        let mut data = [0; 36];
-
-        // Copier les données en s'assurant qu'on ne dépasse pas la taille
-        let copy_len = state_vec.len().min(36);
-        data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
-
-        TrictracState { data }
-    }
-}
-
-/// Actions possibles dans Trictrac pour burn-rl
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub struct TrictracAction {
-    // u32 as required by burn_rl::base::Action type
-    pub index: u32,
-}
-
-impl Action for TrictracAction {
-    fn random() -> Self {
-        use rand::{thread_rng, Rng};
-        let mut rng = thread_rng();
-        TrictracAction {
-            index: rng.gen_range(0..Self::size() as u32),
-        }
-    }
-
-    fn enumerate() -> Vec<TrictracAction> {
-        (0..Self::size() as u32)
-            .map(|index| TrictracAction { index })
-            .collect()
-    }
-
-    fn size() -> usize {
-        // état avec le plus de choix : mouvement
-        // choix premier dé : 16 (15 pions + aucun pion), choix deuxième dé 16, x2 ordre dé
-        64
-    }
-}
-
-impl From<u32> for TrictracAction {
-    fn from(index: u32) -> Self {
-        TrictracAction { index }
-    }
-}
-
-impl From<TrictracAction> for u32 {
-    fn from(action: TrictracAction) -> u32 {
-        action.index
-    }
-}
-
-/// Environnement Trictrac pour burn-rl
-#[derive(Debug)]
-pub struct TrictracEnvironment {
-    pub game: GameState,
-    active_player_id: PlayerId,
-    opponent_id: PlayerId,
-    current_state: TrictracState,
-    episode_reward: f32,
-    pub step_count: usize,
-    pub max_steps: usize,
-    pub pointrolls_count: usize,
-    pub visualized: bool,
-}
-
-impl Environment for TrictracEnvironment {
-    type StateType = TrictracState;
-    type ActionType = TrictracAction;
-    type RewardType = f32;
-
-    fn new(visualized: bool) -> Self {
-        let mut game = GameState::new(false);
-
-        // Ajouter deux joueurs
-        game.init_player("DQN Agent");
-        game.init_player("Opponent");
-        let player1_id = 1;
-        let player2_id = 2;
-
-        // Commencer la partie
-        game.consume(&GameEvent::BeginGame { goes_first: 1 });
-
-        let current_state = TrictracState::from_game_state(&game);
-        TrictracEnvironment {
-            game,
-            active_player_id: player1_id,
-            opponent_id: player2_id,
-            current_state,
-            episode_reward: 0.0,
-            step_count: 0,
-            max_steps: 2000,
-            pointrolls_count: 0,
-            visualized,
-        }
-    }
-
-    fn state(&self) -> Self::StateType {
-        self.current_state
-    }
-
-    fn reset(&mut self) -> Snapshot<Self> {
-        // Réinitialiser le jeu
-        self.game = GameState::new(false);
-        self.game.init_player("DQN Agent");
-        self.game.init_player("Opponent");
-
-        // Commencer la partie
-        self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
-
-        self.current_state = TrictracState::from_game_state(&self.game);
-        self.episode_reward = 0.0;
-        self.step_count = 0;
-        self.pointrolls_count = 0;
-
-        Snapshot::new(self.current_state, 0.0, false)
-    }
-
-    fn step(&mut self, action: Self::ActionType) -> Snapshot<Self> {
-        self.step_count += 1;
-
-        // Convertir l'action burn-rl vers une action Trictrac
-        // let trictrac_action = Self::convert_action(action);
-        let trictrac_action = self.convert_valid_action_index(action);
-        let mut reward = 0.0;
-        let is_rollpoint: bool;
-
-        // Exécuter l'action si c'est le tour de l'agent DQN
-        if self.game.active_player_id == self.active_player_id {
-            if let Some(action) = trictrac_action {
-                (reward, is_rollpoint) = self.execute_action(action);
-                if is_rollpoint {
-                    self.pointrolls_count += 1;
-                }
-            } else {
-                // Action non convertible, pénalité
-                reward = -1.0;
-            }
-        }
-
-        // Faire jouer l'adversaire (stratégie simple)
-        while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
-            reward += self.play_opponent_if_needed();
-        }
-
-        // Vérifier si la partie est terminée
-        let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
-
-        if done {
-            // Récompense finale basée sur le résultat
-            if let Some(winner_id) = self.game.determine_winner() {
-                if winner_id == self.active_player_id {
-                    reward += 100.0; // Victoire
-                } else {
-                    reward -= 100.0; // Défaite
-                }
-            }
-        }
-        let terminated = done || self.step_count >= self.max_steps;
-
-        // Mettre à jour l'état
-        self.current_state = TrictracState::from_game_state(&self.game);
-        self.episode_reward += reward;
-
-        if self.visualized && terminated {
-            println!(
-                "Episode terminé. Récompense totale: {:.2}, Étapes: {}",
-                self.episode_reward, self.step_count
-            );
-        }
-
-        Snapshot::new(self.current_state, reward, terminated)
-    }
-}
-
-impl TrictracEnvironment {
-    const ERROR_REWARD: f32 = -1.12121;
-    const REWARD_RATIO: f32 = 1.0;
-
-    /// Convertit une action burn-rl vers une action Trictrac
-    pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
-        dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
-    }
-
-    /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
-    fn convert_valid_action_index(
-        &self,
-        action: TrictracAction,
-    ) -> Option<dqn_common::TrictracAction> {
-        use dqn_common::get_valid_actions;
-
-        // Obtenir les actions valides dans le contexte actuel
-        let valid_actions = get_valid_actions(&self.game);
-
-        if valid_actions.is_empty() {
-            return None;
-        }
-
-        // Mapper l'index d'action sur une action valide
-        let action_index = (action.index as usize) % valid_actions.len();
-        Some(valid_actions[action_index].clone())
-    }
-
-    /// Exécute une action Trictrac dans le jeu
-    // fn execute_action(
-    //     &mut self,
-    //     action: dqn_common::TrictracAction,
-    // ) -> Result<f32, Box<dyn std::error::Error>> {
-    fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
-        use dqn_common::TrictracAction;
-
-        let mut reward = 0.0;
-        let mut is_rollpoint = false;
-
-        let event = match action {
-            TrictracAction::Roll => {
-                // Lancer les dés
-                Some(GameEvent::Roll {
-                    player_id: self.active_player_id,
-                })
-            }
-            // TrictracAction::Mark => {
-            //     // Marquer des points
-            //     let points = self.game.
-            //     reward += 0.1 * points as f32;
-            //     Some(GameEvent::Mark {
-            //         player_id: self.active_player_id,
-            //         points,
-            //     })
-            // }
-            TrictracAction::Go => {
-                // Continuer après avoir gagné un trou
-                Some(GameEvent::Go {
-                    player_id: self.active_player_id,
-                })
-            }
-            TrictracAction::Move {
-                dice_order,
-                from1,
-                from2,
-            } => {
-                // Effectuer un mouvement
-                let (dice1, dice2) = if dice_order {
-                    (self.game.dice.values.0, self.game.dice.values.1)
-                } else {
-                    (self.game.dice.values.1, self.game.dice.values.0)
-                };
-                let mut to1 = from1 + dice1 as usize;
-                let mut to2 = from2 + dice2 as usize;
-
-                // Gestion prise de coin par puissance
-                let opp_rest_field = 13;
-                if to1 == opp_rest_field && to2 == opp_rest_field {
-                    to1 -= 1;
-                    to2 -= 1;
-                }
-
-                let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
-                let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
-
-                Some(GameEvent::Move {
-                    player_id: self.active_player_id,
-                    moves: (checker_move1, checker_move2),
-                })
-            }
-        };
-
-        // Appliquer l'événement si valide
-        if let Some(event) = event {
-            if self.game.validate(&event) {
-                self.game.consume(&event);
-
-                // Simuler le résultat des dés après un Roll
-                if matches!(action, TrictracAction::Roll) {
-                    let mut rng = thread_rng();
-                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                    let dice_event = GameEvent::RollResult {
-                        player_id: self.active_player_id,
-                        dice: store::Dice {
-                            values: dice_values,
-                        },
-                    };
-                    if self.game.validate(&dice_event) {
-                        self.game.consume(&dice_event);
-                        let (points, adv_points) = self.game.dice_points;
-                        reward += Self::REWARD_RATIO * (points - adv_points) as f32;
-                        if points > 0 {
-                            is_rollpoint = true;
-                            // println!("info: rolled for {reward}");
-                        }
-                        // Récompense proportionnelle aux points
-                    }
-                }
-            } else {
-                // Pénalité pour action invalide
-                // on annule les précédents reward
-                // et on indique une valeur reconnaissable pour statistiques
-                reward = Self::ERROR_REWARD;
-            }
-        }
-
-        (reward, is_rollpoint)
-    }
-
-    /// Fait jouer l'adversaire avec une stratégie simple
-    fn play_opponent_if_needed(&mut self) -> f32 {
-        let mut reward = 0.0;
-
-        // Si c'est le tour de l'adversaire, jouer automatiquement
-        if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
-            // Utiliser la stratégie default pour l'adversaire
-            use crate::BotStrategy;
-
-            let mut strategy = crate::strategy::random::RandomStrategy::default();
-            strategy.set_player_id(self.opponent_id);
-            if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
-                strategy.set_color(color);
-            }
-            *strategy.get_mut_game() = self.game.clone();
-
-            // Exécuter l'action selon le turn_stage
-            let event = match self.game.turn_stage {
-                TurnStage::RollDice => GameEvent::Roll {
-                    player_id: self.opponent_id,
-                },
-                TurnStage::RollWaiting => {
-                    let mut rng = thread_rng();
-                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                    GameEvent::RollResult {
-                        player_id: self.opponent_id,
-                        dice: store::Dice {
-                            values: dice_values,
-                        },
-                    }
-                }
-                TurnStage::MarkPoints => {
-                    let opponent_color = store::Color::Black;
-                    let dice_roll_count = self
-                        .game
-                        .players
-                        .get(&self.opponent_id)
-                        .unwrap()
-                        .dice_roll_count;
-                    let points_rules =
-                        PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    let (points, adv_points) = points_rules.get_points(dice_roll_count);
-                    reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
-
-                    GameEvent::Mark {
-                        player_id: self.opponent_id,
-                        points,
-                    }
-                }
-                TurnStage::MarkAdvPoints => {
-                    let opponent_color = store::Color::Black;
-                    let dice_roll_count = self
-                        .game
-                        .players
-                        .get(&self.opponent_id)
-                        .unwrap()
-                        .dice_roll_count;
-                    let points_rules =
-                        PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    let points = points_rules.get_points(dice_roll_count).1;
-                    // pas de reward : déjà comptabilisé lors du tour de blanc
-                    GameEvent::Mark {
-                        player_id: self.opponent_id,
-                        points,
-                    }
-                }
-                TurnStage::HoldOrGoChoice => {
-                    // Stratégie simple : toujours continuer
-                    GameEvent::Go {
-                        player_id: self.opponent_id,
-                    }
-                }
-                TurnStage::Move => GameEvent::Move {
-                    player_id: self.opponent_id,
-                    moves: strategy.choose_move(),
-                },
-            };
-
-            if self.game.validate(&event) {
-                self.game.consume(&event);
-            }
-        }
-        reward
-    }
-}
-
-impl AsMut<TrictracEnvironment> for TrictracEnvironment {
-    fn as_mut(&mut self) -> &mut Self {
-        self
-    }
-}
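
The defining feature of this deleted environment, and what the `_valid` suffix referred to, is `convert_valid_action_index`: instead of decoding the network output as an absolute action id (`Self::convert_action`, as the retained environment does), it wrapped the index onto the list of currently legal actions, so the agent could never play an invalid move. The core of that mapping, as a sketch with simplified types (`valid_actions` stands in for the result of `dqn_common::get_valid_actions(&game)`):

```rust
fn remap_to_valid(raw_index: usize, valid_actions: &[u32]) -> Option<u32> {
    if valid_actions.is_empty() {
        return None; // nothing is playable in this state
    }
    // Modulo wrapping: any raw network output lands on *some* legal action.
    Some(valid_actions[raw_index % valid_actions.len()])
}
```

The trade-off, visible in the hunks kept in `burnrl/environment.rs`, is that the retained approach instead lets the agent propose illegal actions and learn from the `ERROR_REWARD` penalty.
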
diff --git a/bot/src/dqn/burnrl_valid/main.rs b/bot/src/dqn/burnrl_valid/main.rs
deleted file mode 100644
index ee0dd1f..0000000
--- a/bot/src/dqn/burnrl_valid/main.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-use bot::dqn::burnrl_valid::{
-    dqn_model, environment,
-    utils::{demo_model, load_model, save_model},
-};
-use burn::backend::{Autodiff, NdArray};
-use burn_rl::agent::DQN;
-use burn_rl::base::ElemType;
-
-type Backend = Autodiff<NdArray<ElemType>>;
-type Env = environment::TrictracEnvironment;
-
-fn main() {
-    // println!("> Entraînement");
-
-    // See also MEMORY_SIZE in dqn_model.rs : 8192
-    let conf = dqn_model::DqnConfig {
-        // defaults
-        num_episodes: 100, // 40
-        max_steps: 1000,   // 1000 max steps by episode
-        dense_size: 256,   // 128 neural network complexity (default 128)
-        eps_start: 0.9,    // 0.9 epsilon initial value (0.9 => more exploration)
-        eps_end: 0.05,     // 0.05
-        // eps_decay higher = epsilon decrease slower
-        // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
-        // epsilon is updated at the start of each episode
-        eps_decay: 2000.0, // 1000 ?
-
-        gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
-        tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
-        // plus lente moins sensible aux coups de chance
-        learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
-        // converger
-        batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
-        clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
-    };
-    println!("{conf}----------");
-    let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
-
-    let valid_agent = agent.valid();
-
-    println!("> Sauvegarde du modèle de validation");
-
-    let path = "bot/models/burn_dqn_valid_40".to_string();
-    save_model(valid_agent.model().as_ref().unwrap(), &path);
-
-    println!("> Chargement du modèle pour test");
-    let loaded_model = load_model(conf.dense_size, &path);
-    let loaded_agent = DQN::new(loaded_model.unwrap());
-
-    println!("> Test avec le modèle chargé");
-    demo_model(loaded_agent);
-}
diff --git a/bot/src/dqn/burnrl_valid/mod.rs b/bot/src/dqn/burnrl_valid/mod.rs
deleted file mode 100644
index f4380eb..0000000
--- a/bot/src/dqn/burnrl_valid/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod dqn_model;
-pub mod environment;
-pub mod utils;
diff --git a/bot/src/dqn/burnrl_valid/utils.rs b/bot/src/dqn/burnrl_valid/utils.rs
deleted file mode 100644
index 61522e9..0000000
--- a/bot/src/dqn/burnrl_valid/utils.rs
+++ /dev/null
@@ -1,114 +0,0 @@
-use crate::dqn::burnrl_valid::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common::get_valid_action_indices;
-use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
-use burn::module::{Module, Param, ParamId};
-use burn::nn::Linear;
-use burn::record::{CompactRecorder, Recorder};
-use burn::tensor::backend::Backend;
-use burn::tensor::cast::ToElement;
-use burn::tensor::Tensor;
-use burn_rl::agent::{DQNModel, DQN};
-use burn_rl::base::{Action, ElemType, Environment, State};
-
-pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
-    let recorder = CompactRecorder::new();
-    let model_path = format!("{path}_model.mpk");
-    println!("Modèle de validation sauvegardé : {model_path}");
-    recorder
-        .record(model.clone().into_record(), model_path.into())
-        .unwrap();
-}
-
-pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
-    let model_path = format!("{path}_model.mpk");
-    // println!("Chargement du modèle depuis : {model_path}");
-
-    CompactRecorder::new()
-        .load(model_path.into(), &NdArrayDevice::default())
-        .map(|record| {
-            dqn_model::Net::new(
-                <TrictracEnvironment as Environment>::StateType::size(),
-                dense_size,
-                <TrictracEnvironment as Environment>::ActionType::size(),
-            )
-            .load_record(record)
-        })
-        .ok()
-}
-
-pub fn demo_model<B: Backend>(agent: DQN<TrictracEnvironment, B, dqn_model::Net<B>>) {
-    let mut env = TrictracEnvironment::new(true);
-    let mut done = false;
-    while !done {
-        // let action = match infer_action(&agent, &env, state) {
-        let action = match infer_action(&agent, &env) {
-            Some(value) => value,
-            None => break,
-        };
-        // Execute action
-        let snapshot = env.step(action);
-        done = snapshot.done();
-    }
-}
-
-fn infer_action<B: Backend>(
-    agent: &DQN<TrictracEnvironment, B, dqn_model::Net<B>>,
-    env: &TrictracEnvironment,
-) -> Option<TrictracAction> {
-    let state = env.state();
-    // Get q-values
-    let q_values = agent
-        .model()
-        .as_ref()
-        .unwrap()
-        .infer(state.to_tensor().unsqueeze());
-    // Get valid actions
-    let valid_actions_indices = get_valid_action_indices(&env.game);
-    if valid_actions_indices.is_empty() {
-        return None; // No valid actions, end of episode
-    }
-    // Set non valid actions q-values to lowest
-    let mut masked_q_values = q_values.clone();
-    let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
-    for (index, q_value) in q_values_vec.iter().enumerate() {
-        if !valid_actions_indices.contains(&index) {
-            masked_q_values = masked_q_values.clone().mask_fill(
-                masked_q_values.clone().equal_elem(*q_value),
-                f32::NEG_INFINITY,
-            );
-        }
-    }
-    // Get best action (highest q-value)
-    let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
-    let action = TrictracAction::from(action_index);
-    Some(action)
-}
-
-fn soft_update_tensor<const D: usize, B: Backend>(
-    this: &Param<Tensor<B, D>>,
-    that: &Param<Tensor<B, D>>,
-    tau: ElemType,
-) -> Param<Tensor<B, D>> {
-    let that_weight = that.val();
-    let this_weight = this.val();
-    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
-
-    Param::initialized(ParamId::new(), new_weight)
-}
-
-pub fn soft_update_linear<B: Backend>(
-    this: Linear<B>,
-    that: &Linear<B>,
-    tau: ElemType,
-) -> Linear<B> {
-    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
-    let bias = match (&this.bias, &that.bias) {
-        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
-        _ => None,
-    };
-
-    Linear::<B> { weight, bias }
-}
diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs
index 7f1572e..6eafa27 100644
--- a/bot/src/dqn/mod.rs
+++ b/bot/src/dqn/mod.rs
@@ -1,5 +1,3 @@
-pub mod burnrl;
 pub mod dqn_common;
 pub mod simple;
-
-pub mod burnrl_valid;
+pub mod burnrl;
\ No newline at end of file
diff --git a/justfile b/justfile
index c35d494..63a66ab 100644
--- a/justfile
+++ b/justfile
@@ -28,9 +28,9 @@ trainsimple:
 trainbot:
     #python ./store/python/trainModel.py
    # cargo run --bin=train_dqn # ok
-    ./bot/scripts/trainValid.sh
+    ./bot/scripts/train.sh
 plottrainbot:
-    ./bot/scripts/trainValid.sh plot
+    ./bot/scripts/train.sh plot
 debugtrainbot:
    cargo build --bin=train_dqn_burn
    RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
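
One capability disappears entirely with `burnrl_valid/utils.rs`: its `infer_action` masked the q-values of invalid actions to negative infinity before taking the argmax, so even greedy play stayed legal. The equivalent logic on plain slices, as a sketch (the deleted code did this on burn tensors):

```rust
fn best_valid_action(q_values: &[f32], valid: &[usize]) -> Option<usize> {
    q_values
        .iter()
        .enumerate()
        .filter(|(i, _)| valid.contains(i)) // drop invalid actions
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) // argmax over the rest
        .map(|(i, _)| i)
}
```

Worth noting: the deleted tensor version masked by value equality (`equal_elem(*q_value)`), which also hides any valid action that happens to share a q-value with an invalid one; an index-based mask like the sketch above avoids that.
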