From bfd2a4ed475c19f7bd621333be9558460223112f Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 11 Aug 2025 17:24:59 +0200 Subject: [PATCH] burn-rl with valid moves --- bot/Cargo.toml | 4 + bot/scripts/trainValid.sh | 38 +++ bot/src/dqn/burnrl/dqn_model.rs | 5 +- bot/src/dqn/burnrl_valid/dqn_model.rs | 206 ++++++++++++ bot/src/dqn/burnrl_valid/environment.rs | 422 ++++++++++++++++++++++++ bot/src/dqn/burnrl_valid/main.rs | 52 +++ bot/src/dqn/burnrl_valid/mod.rs | 3 + bot/src/dqn/burnrl_valid/utils.rs | 114 +++++++ bot/src/dqn/mod.rs | 4 +- justfile | 4 +- 10 files changed, 845 insertions(+), 7 deletions(-) create mode 100755 bot/scripts/trainValid.sh create mode 100644 bot/src/dqn/burnrl_valid/dqn_model.rs create mode 100644 bot/src/dqn/burnrl_valid/environment.rs create mode 100644 bot/src/dqn/burnrl_valid/main.rs create mode 100644 bot/src/dqn/burnrl_valid/mod.rs create mode 100644 bot/src/dqn/burnrl_valid/utils.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 68ff52d..135deae 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -5,6 +5,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "train_dqn_burn_valid" +path = "src/dqn/burnrl_valid/main.rs" + [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/scripts/trainValid.sh b/bot/scripts/trainValid.sh new file mode 100755 index 0000000..349517d --- /dev/null +++ b/bot/scripts/trainValid.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env sh + +ROOT="$(cd "$(dirname "$0")" && pwd)/../.." +LOGS_DIR="$ROOT/bot/models/logs" + +CFG_SIZE=11 +OPPONENT="random" + +PLOT_EXT="png" + +train() { + cargo build --release --bin=train_dqn_burn_valid + NAME="trainValid_$(date +%Y-%m-%d_%H:%M:%S)" + LOGS="$LOGS_DIR/$NAME.out" + mkdir -p "$LOGS_DIR" + LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn_valid" | tee "$LOGS" +} + +plot() { + NAME=$(ls "$LOGS_DIR" | tail -n 1) + LOGS="$LOGS_DIR/$NAME" + cfgs=$(head -n $CFG_SIZE "$LOGS") + for cfg in $cfgs; do + eval "$cfg" + done + + # tail -n +$((CFG_SIZE + 2)) "$LOGS" + tail -n +$((CFG_SIZE + 2)) "$LOGS" | + grep -v "info:" | + awk -F '[ ,]' '{print $5}' | + feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT" +} + +if [ "$1" = "plot" ]; then + plot +else + train +fi diff --git a/bot/src/dqn/burnrl/dqn_model.rs b/bot/src/dqn/burnrl/dqn_model.rs index 7e1c797..3e90904 100644 --- a/bot/src/dqn/burnrl/dqn_model.rs +++ b/bot/src/dqn/burnrl/dqn_model.rs @@ -164,7 +164,6 @@ pub fn run, B: AutodiffBackend>( let mut episode_duration = 0_usize; let mut state = env.state(); let mut now = SystemTime::now(); - let mut goodmoves_ratio = 0.0; while !episode_done { let eps_threshold = conf.eps_end @@ -195,13 +194,11 @@ pub fn run, B: AutodiffBackend>( if snapshot.done() || episode_duration >= conf.max_steps { let envmut = env.as_mut(); println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"gm%\": {:.1}, \"rollpoints\":{}, \"duration\": {}}}", + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", envmut.goodmoves_count, - goodmoves_ratio * 100.0, envmut.pointrolls_count, 
now.elapsed().unwrap().as_secs(), ); - goodmoves_ratio = envmut.goodmoves_ratio; env.reset(); episode_done = true; now = SystemTime::now(); diff --git a/bot/src/dqn/burnrl_valid/dqn_model.rs b/bot/src/dqn/burnrl_valid/dqn_model.rs new file mode 100644 index 0000000..4dd5180 --- /dev/null +++ b/bot/src/dqn/burnrl_valid/dqn_model.rs @@ -0,0 +1,206 @@ +use crate::dqn::burnrl_valid::environment::TrictracEnvironment; +use crate::dqn::burnrl_valid::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +pub struct DqnConfig { + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + 
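// Worked example: with these defaults, epsilon = 0.05 + 0.85 * exp(-step / 1000), i.e. about 0.36 after 1000 steps, 0.17 after 2000 and 0.09 after 3000.
+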
batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}", + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_valid/environment.rs b/bot/src/dqn/burnrl_valid/environment.rs new file mode 100644 index 0000000..93e6c14 --- /dev/null +++ b/bot/src/dqn/burnrl_valid/environment.rs @@ -0,0 +1,422 @@ +use crate::dqn::dqn_common; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl 
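+/// When the environment executes the action, the raw index is folded modulo the
+/// number of currently valid moves (see TrictracEnvironment::convert_valid_action_index),
+/// so every index produced by the network maps onto a legal move.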
+#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + // état avec le plus de choix : mouvement + // choix premier dé : 16 (15 pions + aucun pion), choix deuxième dé 16, x2 ordre dé + 64 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub max_steps: usize, + pub pointrolls_count: usize, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + max_steps: 2000, + pointrolls_count: 0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.step_count = 0; + self.pointrolls_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + // let trictrac_action = Self::convert_action(action); + let trictrac_action = self.convert_valid_action_index(action); + let mut reward = 0.0; + let is_rollpoint: bool; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -1.0; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 100.0; // 
Victoire + } else { + reward -= 100.0; // Défaite + } + } + } + let terminated = done || self.step_count >= self.max_steps; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + ) -> Option { + use dqn_common::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(&self.game); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action: dqn_common::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) { + use dqn_common::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. 
+ // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + 
player_id: self.opponent_id, + points, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_valid/main.rs b/bot/src/dqn/burnrl_valid/main.rs new file mode 100644 index 0000000..ee0dd1f --- /dev/null +++ b/bot/src/dqn/burnrl_valid/main.rs @@ -0,0 +1,52 @@ +use bot::dqn::burnrl_valid::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also MEMORY_SIZE in dqn_model.rs : 8192 + let conf = dqn_model::DqnConfig { + // defaults + num_episodes: 100, // 40 + max_steps: 1000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
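+ // Note: training only starts once the replay memory holds more than batch_size transitions (see dqn_model::run).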
+ clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) + }; + println!("{conf}----------"); + let agent = dqn_model::run::(&conf, false); //true); + + let valid_agent = agent.valid(); + + println!("> Sauvegarde du modèle de validation"); + + let path = "bot/models/burn_dqn_valid_40".to_string(); + save_model(valid_agent.model().as_ref().unwrap(), &path); + + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(conf.dense_size, &path); + let loaded_agent = DQN::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); +} diff --git a/bot/src/dqn/burnrl_valid/mod.rs b/bot/src/dqn/burnrl_valid/mod.rs new file mode 100644 index 0000000..f4380eb --- /dev/null +++ b/bot/src/dqn/burnrl_valid/mod.rs @@ -0,0 +1,3 @@ +pub mod dqn_model; +pub mod environment; +pub mod utils; diff --git a/bot/src/dqn/burnrl_valid/utils.rs b/bot/src/dqn/burnrl_valid/utils.rs new file mode 100644 index 0000000..61522e9 --- /dev/null +++ b/bot/src/dqn/burnrl_valid/utils.rs @@ -0,0 +1,114 @@ +use crate::dqn::burnrl_valid::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; +use crate::dqn::dqn_common::get_valid_action_indices; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; +use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::backend::Backend; +use burn::tensor::cast::ToElement; +use burn::tensor::Tensor; +use burn_rl::agent::{DQNModel, DQN}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}_model.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + +pub fn demo_model>(agent: DQN) { + let mut env = TrictracEnvironment::new(true); + let mut done = false; + while !done { + // let action = match infer_action(&agent, &env, state) { + let action = match infer_action(&agent, &env) { + Some(value) => value, + None => break, + }; + // Execute action + let snapshot = env.step(action); + done = snapshot.done(); + } +} + +fn infer_action>( + agent: &DQN, + env: &TrictracEnvironment, +) -> Option { + let state = env.state(); + // Get q-values + let q_values = agent + .model() + .as_ref() + .unwrap() + .infer(state.to_tensor().unsqueeze()); + // Get valid actions + let valid_actions_indices = get_valid_action_indices(&env.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let 
action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + let action = TrictracAction::from(action_index); + Some(action) +} + +fn soft_update_tensor<B: Backend, const D: usize>( + this: &Param<Tensor<B, D>>, + that: &Param<Tensor<B, D>>, + tau: ElemType, +) -> Param<Tensor<B, D>> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear<B: Backend>( + this: Linear<B>, + that: &Linear<B>, + tau: ElemType, +) -> Linear<B> { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear::<B> { weight, bias } +} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index 6eafa27..7f1572e 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -1,3 +1,5 @@ +pub mod burnrl; pub mod dqn_common; pub mod simple; -pub mod burnrl; \ No newline at end of file + +pub mod burnrl_valid; diff --git a/justfile b/justfile index 63a66ab..c35d494 100644 --- a/justfile +++ b/justfile @@ -28,9 +28,9 @@ trainsimple: trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok - ./bot/scripts/train.sh + ./bot/scripts/trainValid.sh plottrainbot: - ./bot/scripts/train.sh plot + ./bot/scripts/trainValid.sh plot debugtrainbot: cargo build --bin=train_dqn_burn RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
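Note on infer_action in bot/src/dqn/burnrl_valid/utils.rs: it masks invalid actions by comparing Q-values element-wise. A minimal, framework-free sketch of the same selection (assuming plain f32 Q-values and the index list returned by get_valid_action_indices; the helper name best_valid_action is illustrative, not part of the crate) works directly on indices, which also avoids masking a valid action whose Q-value happens to equal an invalid one's:

fn best_valid_action(q_values: &[f32], valid_indices: &[usize]) -> Option<usize> {
    // Keep only indices that are valid moves and within bounds,
    // then return the one with the highest Q-value.
    valid_indices
        .iter()
        .copied()
        .filter(|&i| i < q_values.len())
        .max_by(|&a, &b| {
            q_values[a]
                .partial_cmp(&q_values[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        })
}

As in infer_action, an empty list of valid indices yields None and ends the episode.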