diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 135deae..4a0a95c 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -9,6 +9,14 @@ edition = "2021" name = "train_dqn_burn_valid" path = "src/dqn/burnrl_valid/main.rs" +[[bin]] +name = "train_dqn_burn_big" +path = "src/dqn/burnrl_big/main.rs" + +[[bin]] +name = "train_dqn_burn_before" +path = "src/dqn/burnrl_before/main.rs" + [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 9e54c7a..a3be831 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -4,20 +4,23 @@ ROOT="$(cd "$(dirname "$0")" && pwd)/../.." LOGS_DIR="$ROOT/bot/models/logs" CFG_SIZE=12 +# BINBOT=train_dqn_burn +BINBOT=train_dqn_burn_big +# BINBOT=train_dqn_burn_before OPPONENT="random" PLOT_EXT="png" train() { - cargo build --release --bin=train_dqn_burn - NAME="train_$(date +%Y-%m-%d_%H:%M:%S)" + cargo build --release --bin=$BINBOT + NAME=$BINBOT"_$(date +%Y-%m-%d_%H:%M:%S)" LOGS="$LOGS_DIR/$NAME.out" mkdir -p "$LOGS_DIR" - LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS" + LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" | tee "$LOGS" } plot() { - NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | grep $BINBOT | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do diff --git a/bot/src/dqn/burnrl_before/dqn_model.rs b/bot/src/dqn/burnrl_before/dqn_model.rs new file mode 100644 index 0000000..02646eb --- /dev/null +++ b/bot/src/dqn/burnrl_before/dqn_model.rs @@ -0,0 +1,211 @@ +use crate::dqn::burnrl_before::environment::TrictracEnvironment; +use crate::dqn::burnrl_before::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +pub struct DqnConfig { + pub min_steps: f32, + pub 
max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("min_steps={:?}\n", self.min_steps)); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + min_steps: 250.0, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().min_steps = conf.min_steps; + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + 
envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs new file mode 100644 index 0000000..9925a9a --- /dev/null +++ b/bot/src/dqn/burnrl_before/environment.rs @@ -0,0 +1,449 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN 
Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut terminated = false; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. 
Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action:dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les 
précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_before/main.rs b/bot/src/dqn/burnrl_before/main.rs new file mode 100644 index 0000000..602ff51 --- /dev/null +++ b/bot/src/dqn/burnrl_before/main.rs @@ -0,0 +1,53 @@ +use bot::dqn::burnrl_before::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also MEMORY_SIZE in dqn_model.rs : 8192 + let conf = dqn_model::DqnConfig { + // defaults + num_episodes: 40, // 40 + min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) + max_steps: 3000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + eps_start: 0.9, // 0.9 epsilon 
initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. + clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) + }; + println!("{conf}----------"); + let agent = dqn_model::run::(&conf, false); //true); + + let valid_agent = agent.valid(); + + println!("> Sauvegarde du modèle de validation"); + + let path = "models/burn_dqn_40".to_string(); + save_model(valid_agent.model().as_ref().unwrap(), &path); + + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(conf.dense_size, &path); + let loaded_agent = DQN::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); +} diff --git a/bot/src/dqn/burnrl_before/mod.rs b/bot/src/dqn/burnrl_before/mod.rs new file mode 100644 index 0000000..f4380eb --- /dev/null +++ b/bot/src/dqn/burnrl_before/mod.rs @@ -0,0 +1,3 @@ +pub mod dqn_model; +pub mod environment; +pub mod utils; diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs new file mode 100644 index 0000000..6c25c5d --- /dev/null +++ b/bot/src/dqn/burnrl_before/utils.rs @@ -0,0 +1,114 @@ +use crate::dqn::burnrl_before::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; +use crate::dqn::dqn_common_big::get_valid_action_indices; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; +use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::backend::Backend; +use burn::tensor::cast::ToElement; +use burn::tensor::Tensor; +use burn_rl::agent::{DQNModel, DQN}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}_model.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + +pub fn demo_model>(agent: DQN) { + let mut env = TrictracEnvironment::new(true); + let mut done = false; + while !done { + // let action = match infer_action(&agent, &env, state) { + let action = match infer_action(&agent, &env) { + Some(value) => value, + None => break, + }; + // Execute action + let snapshot = env.step(action); + done = snapshot.done(); + } +} + +fn infer_action>( + agent: &DQN, + env: &TrictracEnvironment, +) -> Option { + let state = env.state(); + // Get q-values + let q_values = 
agent + .model() + .as_ref() + .unwrap() + .infer(state.to_tensor().unsqueeze()); + // Get valid actions + let valid_actions_indices = get_valid_action_indices(&env.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + let action = TrictracAction::from(action_index); + Some(action) +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: ElemType, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: ElemType, +) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} diff --git a/bot/src/dqn/burnrl_big/dqn_model.rs b/bot/src/dqn/burnrl_big/dqn_model.rs new file mode 100644 index 0000000..f50bf31 --- /dev/null +++ b/bot/src/dqn/burnrl_big/dqn_model.rs @@ -0,0 +1,211 @@ +use crate::dqn::burnrl_big::environment::TrictracEnvironment; +use crate::dqn::burnrl_big::utils::soft_update_linear; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::fmt; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; 
+ +pub struct DqnConfig { + pub min_steps: f32, + pub max_steps: usize, + pub num_episodes: usize, + pub dense_size: usize, + pub eps_start: f64, + pub eps_end: f64, + pub eps_decay: f64, + + pub gamma: f32, + pub tau: f32, + pub learning_rate: f32, + pub batch_size: usize, + pub clip_grad: f32, +} + +impl fmt::Display for DqnConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = String::new(); + s.push_str(&format!("min_steps={:?}\n", self.min_steps)); + s.push_str(&format!("max_steps={:?}\n", self.max_steps)); + s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); + s.push_str(&format!("dense_size={:?}\n", self.dense_size)); + s.push_str(&format!("eps_start={:?}\n", self.eps_start)); + s.push_str(&format!("eps_end={:?}\n", self.eps_end)); + s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); + s.push_str(&format!("gamma={:?}\n", self.gamma)); + s.push_str(&format!("tau={:?}\n", self.tau)); + s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); + s.push_str(&format!("batch_size={:?}\n", self.batch_size)); + s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); + write!(f, "{s}") + } +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + min_steps: 250.0, + max_steps: 2000, + num_episodes: 1000, + dense_size: 256, + eps_start: 0.9, + eps_end: 0.05, + eps_decay: 1000.0, + + gamma: 0.999, + tau: 0.005, + learning_rate: 0.001, + batch_size: 32, + clip_grad: 100.0, + } + } +} + +type MyAgent = DQN>; + +#[allow(unused)] +pub fn run, B: AutodiffBackend>( + conf: &DqnConfig, + visualized: bool, +) -> DQN> { + // ) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().min_steps = conf.min_steps; + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, 
\"duration\": {}}}", + envmut.goodmoves_count, + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + agent +} diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs new file mode 100644 index 0000000..9925a9a --- /dev/null +++ b/bot/src/dqn/burnrl_big/environment.rs @@ -0,0 +1,449 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = 
GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut terminated = false; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. 
Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action:dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les 
précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_big/environmentDiverge.rs b/bot/src/dqn/burnrl_big/environmentDiverge.rs new file mode 100644 index 0000000..6706163 --- /dev/null +++ b/bot/src/dqn/burnrl_big/environmentDiverge.rs @@ -0,0 +1,459 @@ +use crate::dqn::dqn_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un 
GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub min_steps: f32, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + min_steps: 250.0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let is_rollpoint; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = 
self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != Self::ERROR_REWARD { + self.goodmoves_count += 1; + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + let max_steps = self.min_steps + + (self.max_steps as f32 - self.min_steps) + * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += 50.0; // Victoire + } else { + reward -= 25.0; // Défaite + } + } + } + let terminated = done || self.step_count >= max_steps.round() as usize; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + const ERROR_REWARD: f32 = -1.12121; + const REWARD_RATIO: f32 = 1.0; + + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use dqn_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action: dqn_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { + use dqn_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. 
+ // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += Self::REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = Self::ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + // calculate_points = true; // comment to replicate burnrl_before + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + 
GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).0, + } + } + TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).1, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/dqn/burnrl_big/main.rs b/bot/src/dqn/burnrl_big/main.rs new file mode 100644 index 0000000..3b72ef8 --- /dev/null +++ b/bot/src/dqn/burnrl_big/main.rs @@ -0,0 +1,53 @@ +use bot::dqn::burnrl_big::{ + dqn_model, environment, + utils::{demo_model, load_model, save_model}, +}; +use burn::backend::{Autodiff, NdArray}; +use burn_rl::agent::DQN; +use burn_rl::base::ElemType; + +type Backend = Autodiff>; +type Env = environment::TrictracEnvironment; + +fn main() { + // println!("> Entraînement"); + + // See also MEMORY_SIZE in dqn_model.rs : 8192 + let conf = dqn_model::DqnConfig { + // defaults + num_episodes: 40, // 40 + min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) + max_steps: 3000, // 1000 max steps by episode + dense_size: 256, // 128 neural network complexity (default 128) + eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) + eps_end: 0.05, // 0.05 + // eps_decay higher = epsilon decrease slower + // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); + // epsilon is updated at the start of each episode + eps_decay: 2000.0, // 1000 ? + + gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme + tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation + // plus lente moins sensible aux coups de chance + learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais + // converger + batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
+ clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) + }; + println!("{conf}----------"); + let agent = dqn_model::run::(&conf, false); //true); + + let valid_agent = agent.valid(); + + println!("> Sauvegarde du modèle de validation"); + + let path = "models/burn_dqn_40".to_string(); + save_model(valid_agent.model().as_ref().unwrap(), &path); + + println!("> Chargement du modèle pour test"); + let loaded_model = load_model(conf.dense_size, &path); + let loaded_agent = DQN::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); +} diff --git a/bot/src/dqn/burnrl_big/mod.rs b/bot/src/dqn/burnrl_big/mod.rs new file mode 100644 index 0000000..f4380eb --- /dev/null +++ b/bot/src/dqn/burnrl_big/mod.rs @@ -0,0 +1,3 @@ +pub mod dqn_model; +pub mod environment; +pub mod utils; diff --git a/bot/src/dqn/burnrl_big/utils.rs b/bot/src/dqn/burnrl_big/utils.rs new file mode 100644 index 0000000..9159d57 --- /dev/null +++ b/bot/src/dqn/burnrl_big/utils.rs @@ -0,0 +1,114 @@ +use crate::dqn::burnrl_big::{ + dqn_model, + environment::{TrictracAction, TrictracEnvironment}, +}; +use crate::dqn::dqn_common_big::get_valid_action_indices; +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::module::{Module, Param, ParamId}; +use burn::nn::Linear; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::backend::Backend; +use burn::tensor::cast::ToElement; +use burn::tensor::Tensor; +use burn_rl::agent::{DQNModel, DQN}; +use burn_rl::base::{Action, ElemType, Environment, State}; + +pub fn save_model(model: &dqn_model::Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}_model.mpk"); + println!("Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}_model.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + +pub fn demo_model>(agent: DQN) { + let mut env = TrictracEnvironment::new(true); + let mut done = false; + while !done { + // let action = match infer_action(&agent, &env, state) { + let action = match infer_action(&agent, &env) { + Some(value) => value, + None => break, + }; + // Execute action + let snapshot = env.step(action); + done = snapshot.done(); + } +} + +fn infer_action>( + agent: &DQN, + env: &TrictracEnvironment, +) -> Option { + let state = env.state(); + // Get q-values + let q_values = agent + .model() + .as_ref() + .unwrap() + .infer(state.to_tensor().unsqueeze()); + // Get valid actions + let valid_actions_indices = get_valid_action_indices(&env.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = 
masked_q_values.argmax(1).into_scalar().to_u32(); + let action = TrictracAction::from(action_index); + Some(action) +} + +fn soft_update_tensor( + this: &Param>, + that: &Param>, + tau: ElemType, +) -> Param> { + let that_weight = that.val(); + let this_weight = this.val(); + let new_weight = this_weight * (1.0 - tau) + that_weight * tau; + + Param::initialized(ParamId::new(), new_weight) +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: ElemType, +) -> Linear { + let weight = soft_update_tensor(&this.weight, &that.weight, tau); + let bias = match (&this.bias, &that.bias) { + (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), + _ => None, + }; + + Linear:: { weight, bias } +} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index ab75746..ebc01a4 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -1,4 +1,6 @@ pub mod burnrl; +pub mod burnrl_before; +pub mod burnrl_big; pub mod dqn_common; pub mod dqn_common_big; pub mod simple;
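
In `utils.rs`, `infer_action` runs the network on the current state, pushes the Q-values of invalid actions to negative infinity, and takes the argmax, so the agent only ever plays moves allowed by `get_valid_action_indices`. The same selection can be written directly over the Q-values indexed by action id; a minimal sketch (`best_valid_action` is a hypothetical helper, not part of this diff):

use std::cmp::Ordering;

/// Pick the valid action with the highest Q-value, by index.
/// `q_values[i]` is the network output for action index `i`;
/// `valid` holds the indices returned by `get_valid_action_indices`.
fn best_valid_action(q_values: &[f32], valid: &[usize]) -> Option<u32> {
    valid
        .iter()
        .copied()
        .filter(|&i| i < q_values.len())
        .max_by(|&a, &b| {
            q_values[a]
                .partial_cmp(&q_values[b])
                .unwrap_or(Ordering::Equal)
        })
        .map(|i| i as u32)
}

Selecting by index also sidesteps a corner case of the `equal_elem` masking above, where a valid action that happens to share the exact Q-value of an invalid one would be masked as well.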
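
`soft_update_tensor` and `soft_update_linear` implement Polyak averaging of the target network toward the policy network, driven by `tau` from the training config. A minimal sketch of the same rule on plain `f32` slices, assuming nothing from Burn's `Param`/`Tensor` API:

/// Polyak (soft) update: target <- target * (1 - tau) + policy * tau.
fn soft_update_slice(target: &mut [f32], policy: &[f32], tau: f32) {
    assert_eq!(target.len(), policy.len());
    for (t, p) in target.iter_mut().zip(policy.iter()) {
        *t = *t * (1.0 - tau) + *p * tau;
    }
}

fn main() {
    let mut target = vec![0.0_f32; 4];
    let policy = vec![1.0_f32; 4];
    // With tau = 0.005 (the configured value) the target drifts slowly:
    // after 1000 updates it has covered about 99% of the gap to the policy.
    for _ in 0..1000 {
        soft_update_slice(&mut target, &policy, 0.005);
    }
    println!("{target:?}");
}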
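
The `eps_decay` comment in `main.rs` spells out the exploration schedule, eps_threshold = eps_end + (eps_start - eps_end) * exp(-step / eps_decay), where `step` in `run()` is a global counter incremented once per environment step across all episodes. A standalone sketch of that formula with the values configured above (the printed table is illustrative only):

fn eps_threshold(eps_start: f64, eps_end: f64, eps_decay: f64, step: usize) -> f64 {
    eps_end + (eps_start - eps_end) * f64::exp(-(step as f64) / eps_decay)
}

fn main() {
    // With eps_start = 0.9, eps_end = 0.05 and eps_decay = 2000.0, the agent still
    // explores roughly a third of the time after 2000 steps and only gets close
    // to eps_end after about 10000 steps.
    for step in [0_usize, 500, 2_000, 5_000, 10_000] {
        println!(
            "step {:>6}: epsilon = {:.3}",
            step,
            eps_threshold(0.9, 0.05, 2000.0, step)
        );
    }
}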
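
In `TrictracEnvironment::step`, the episode length limit is not a constant: it interpolates between `min_steps` and `max_steps` with exp((goodmoves_ratio - 1.0) / 0.25), where `goodmoves_ratio` is the share of valid moves recorded for the previous episode in `reset`. A small sketch of how that budget grows, using the `min_steps`/`max_steps` values set in `main.rs`:

/// Effective per-episode step budget: grows toward `max_steps`
/// as the previous episode's ratio of valid moves approaches 1.0.
fn effective_max_steps(min_steps: f32, max_steps: f32, goodmoves_ratio: f32) -> f32 {
    min_steps + (max_steps - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25)
}

fn main() {
    for ratio in [0.0_f32, 0.5, 0.8, 1.0] {
        // With min_steps = 500 and max_steps = 3000: about 546, 838, 1623 and 3000 steps.
        println!(
            "goodmoves_ratio {:.1} -> step budget {:.0}",
            ratio,
            effective_max_steps(500.0, 3000.0, ratio)
        );
    }
}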