diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 4a0a95c..c043393 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,10 +13,6 @@ path = "src/dqn/burnrl_valid/main.rs" name = "train_dqn_burn_big" path = "src/dqn/burnrl_big/main.rs" -[[bin]] -name = "train_dqn_burn_before" -path = "src/dqn/burnrl_before/main.rs" - [[bin]] name = "train_dqn_burn" path = "src/dqn/burnrl/main.rs" diff --git a/bot/src/dqn/burnrl_before/dqn_model.rs b/bot/src/dqn/burnrl_before/dqn_model.rs deleted file mode 100644 index 02646eb..0000000 --- a/bot/src/dqn/burnrl_before/dqn_model.rs +++ /dev/null @@ -1,211 +0,0 @@ -use crate::dqn::burnrl_before::environment::TrictracEnvironment; -use crate::dqn::burnrl_before::utils::soft_update_linear; -use burn::module::Module; -use burn::nn::{Linear, LinearConfig}; -use burn::optim::AdamWConfig; -use burn::tensor::activation::relu; -use burn::tensor::backend::{AutodiffBackend, Backend}; -use burn::tensor::Tensor; -use burn_rl::agent::DQN; -use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; -use std::fmt; -use std::time::SystemTime; - -#[derive(Module, Debug)] -pub struct Net { - linear_0: Linear, - linear_1: Linear, - linear_2: Linear, -} - -impl Net { - #[allow(unused)] - pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { - Self { - linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), - linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), - linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), - } - } - - fn consume(self) -> (Linear, Linear, Linear) { - (self.linear_0, self.linear_1, self.linear_2) - } -} - -impl Model, Tensor> for Net { - fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input)); - let layer_1_output = relu(self.linear_1.forward(layer_0_output)); - - relu(self.linear_2.forward(layer_1_output)) - } - - fn infer(&self, input: Tensor) -> Tensor { - self.forward(input) - } -} - -impl DQNModel for Net { - fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { - let (linear_0, linear_1, linear_2) = this.consume(); - - Self { - linear_0: soft_update_linear(linear_0, &that.linear_0, tau), - linear_1: soft_update_linear(linear_1, &that.linear_1, tau), - linear_2: soft_update_linear(linear_2, &that.linear_2, tau), - } - } -} - -#[allow(unused)] -const MEMORY_SIZE: usize = 8192; - -pub struct DqnConfig { - pub min_steps: f32, - pub max_steps: usize, - pub num_episodes: usize, - pub dense_size: usize, - pub eps_start: f64, - pub eps_end: f64, - pub eps_decay: f64, - - pub gamma: f32, - pub tau: f32, - pub learning_rate: f32, - pub batch_size: usize, - pub clip_grad: f32, -} - -impl fmt::Display for DqnConfig { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut s = String::new(); - s.push_str(&format!("min_steps={:?}\n", self.min_steps)); - s.push_str(&format!("max_steps={:?}\n", self.max_steps)); - s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); - s.push_str(&format!("dense_size={:?}\n", self.dense_size)); - s.push_str(&format!("eps_start={:?}\n", self.eps_start)); - s.push_str(&format!("eps_end={:?}\n", self.eps_end)); - s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); - s.push_str(&format!("gamma={:?}\n", self.gamma)); - s.push_str(&format!("tau={:?}\n", self.tau)); - s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); - s.push_str(&format!("batch_size={:?}\n", 
self.batch_size)); - s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); - write!(f, "{s}") - } -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - min_steps: 250.0, - max_steps: 2000, - num_episodes: 1000, - dense_size: 256, - eps_start: 0.9, - eps_end: 0.05, - eps_decay: 1000.0, - - gamma: 0.999, - tau: 0.005, - learning_rate: 0.001, - batch_size: 32, - clip_grad: 100.0, - } - } -} - -type MyAgent = DQN>; - -#[allow(unused)] -pub fn run, B: AutodiffBackend>( - conf: &DqnConfig, - visualized: bool, -) -> DQN> { - // ) -> impl Agent { - let mut env = E::new(visualized); - env.as_mut().min_steps = conf.min_steps; - env.as_mut().max_steps = conf.max_steps; - - let model = Net::::new( - <::StateType as State>::size(), - conf.dense_size, - <::ActionType as Action>::size(), - ); - - let mut agent = MyAgent::new(model); - - // let config = DQNTrainingConfig::default(); - let config = DQNTrainingConfig { - gamma: conf.gamma, - tau: conf.tau, - learning_rate: conf.learning_rate, - batch_size: conf.batch_size, - clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( - conf.clip_grad, - )), - }; - - let mut memory = Memory::::default(); - - let mut optimizer = AdamWConfig::new() - .with_grad_clipping(config.clip_grad.clone()) - .init(); - - let mut policy_net = agent.model().as_ref().unwrap().clone(); - - let mut step = 0_usize; - - for episode in 0..conf.num_episodes { - let mut episode_done = false; - let mut episode_reward: ElemType = 0.0; - let mut episode_duration = 0_usize; - let mut state = env.state(); - let mut now = SystemTime::now(); - - while !episode_done { - let eps_threshold = conf.eps_end - + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); - let action = - DQN::>::react_with_exploration(&policy_net, state, eps_threshold); - let snapshot = env.step(action); - - episode_reward += - <::RewardType as Into>::into(snapshot.reward().clone()); - - memory.push( - state, - *snapshot.state(), - action, - snapshot.reward().clone(), - snapshot.done(), - ); - - if config.batch_size < memory.len() { - policy_net = - agent.train::(policy_net, &memory, &mut optimizer, &config); - } - - step += 1; - episode_duration += 1; - - if snapshot.done() || episode_duration >= conf.max_steps { - let envmut = env.as_mut(); - println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}", - envmut.goodmoves_count, - envmut.pointrolls_count, - now.elapsed().unwrap().as_secs(), - ); - env.reset(); - episode_done = true; - now = SystemTime::now(); - } else { - state = *snapshot.state(); - } - } - } - agent -} diff --git a/bot/src/dqn/burnrl_before/environment.rs b/bot/src/dqn/burnrl_before/environment.rs deleted file mode 100644 index 9925a9a..0000000 --- a/bot/src/dqn/burnrl_before/environment.rs +++ /dev/null @@ -1,449 +0,0 @@ -use crate::dqn::dqn_common_big; -use burn::{prelude::Backend, tensor::Tensor}; -use burn_rl::base::{Action, Environment, Snapshot, State}; -use rand::{thread_rng, Rng}; -use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; - -/// État du jeu Trictrac pour burn-rl -#[derive(Debug, Clone, Copy)] -pub struct TrictracState { - pub data: [i8; 36], // Représentation vectorielle de l'état du jeu -} - -impl State for TrictracState { - type Data = [i8; 36]; - - fn to_tensor(&self) -> Tensor { - Tensor::from_floats(self.data, &B::Device::default()) - } - - fn size() -> usize { 
- 36 - } -} - -impl TrictracState { - /// Convertit un GameState en TrictracState - pub fn from_game_state(game_state: &GameState) -> Self { - let state_vec = game_state.to_vec(); - let mut data = [0; 36]; - - // Copier les données en s'assurant qu'on ne dépasse pas la taille - let copy_len = state_vec.len().min(36); - data[..copy_len].copy_from_slice(&state_vec[..copy_len]); - - TrictracState { data } - } -} - -/// Actions possibles dans Trictrac pour burn-rl -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct TrictracAction { - // u32 as required by burn_rl::base::Action type - pub index: u32, -} - -impl Action for TrictracAction { - fn random() -> Self { - use rand::{thread_rng, Rng}; - let mut rng = thread_rng(); - TrictracAction { - index: rng.gen_range(0..Self::size() as u32), - } - } - - fn enumerate() -> Vec { - (0..Self::size() as u32) - .map(|index| TrictracAction { index }) - .collect() - } - - fn size() -> usize { - 1252 - } -} - -impl From for TrictracAction { - fn from(index: u32) -> Self { - TrictracAction { index } - } -} - -impl From for u32 { - fn from(action: TrictracAction) -> u32 { - action.index - } -} - -/// Environnement Trictrac pour burn-rl -#[derive(Debug)] -pub struct TrictracEnvironment { - pub game: GameState, - active_player_id: PlayerId, - opponent_id: PlayerId, - current_state: TrictracState, - episode_reward: f32, - pub step_count: usize, - pub min_steps: f32, - pub max_steps: usize, - pub pointrolls_count: usize, - pub goodmoves_count: usize, - pub goodmoves_ratio: f32, - pub visualized: bool, -} - -impl Environment for TrictracEnvironment { - type StateType = TrictracState; - type ActionType = TrictracAction; - type RewardType = f32; - - fn new(visualized: bool) -> Self { - let mut game = GameState::new(false); - - // Ajouter deux joueurs - game.init_player("DQN Agent"); - game.init_player("Opponent"); - let player1_id = 1; - let player2_id = 2; - - // Commencer la partie - game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - let current_state = TrictracState::from_game_state(&game); - TrictracEnvironment { - game, - active_player_id: player1_id, - opponent_id: player2_id, - current_state, - episode_reward: 0.0, - step_count: 0, - min_steps: 250.0, - max_steps: 2000, - pointrolls_count: 0, - goodmoves_count: 0, - goodmoves_ratio: 0.0, - visualized, - } - } - - fn state(&self) -> Self::StateType { - self.current_state - } - - fn reset(&mut self) -> Snapshot { - // Réinitialiser le jeu - self.game = GameState::new(false); - self.game.init_player("DQN Agent"); - self.game.init_player("Opponent"); - - // Commencer la partie - self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward = 0.0; - self.goodmoves_ratio = if self.step_count == 0 { - 0.0 - } else { - self.goodmoves_count as f32 / self.step_count as f32 - }; - println!( - "info: correct moves: {} ({}%)", - self.goodmoves_count, - (100.0 * self.goodmoves_ratio).round() as u32 - ); - self.step_count = 0; - self.pointrolls_count = 0; - self.goodmoves_count = 0; - - Snapshot::new(self.current_state, 0.0, false) - } - - fn step(&mut self, action: Self::ActionType) -> Snapshot { - self.step_count += 1; - - // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = Self::convert_action(action); - - let mut reward = 0.0; - let mut is_rollpoint = false; - let mut terminated = false; - - // Exécuter l'action si c'est le tour de l'agent DQN - if self.game.active_player_id == 
self.active_player_id { - if let Some(action) = trictrac_action { - (reward, is_rollpoint) = self.execute_action(action); - if is_rollpoint { - self.pointrolls_count += 1; - } - if reward != Self::ERROR_REWARD { - self.goodmoves_count += 1; - } - } else { - // Action non convertible, pénalité - reward = -0.5; - } - } - - // Faire jouer l'adversaire (stratégie simple) - while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - reward += self.play_opponent_if_needed(); - } - - // Vérifier si la partie est terminée - let max_steps = self.min_steps - + (self.max_steps as f32 - self.min_steps) - * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); - let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); - - if done { - // Récompense finale basée sur le résultat - if let Some(winner_id) = self.game.determine_winner() { - if winner_id == self.active_player_id { - reward += 50.0; // Victoire - } else { - reward -= 25.0; // Défaite - } - } - } - let terminated = done || self.step_count >= max_steps.round() as usize; - - // Mettre à jour l'état - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward += reward; - - if self.visualized && terminated { - println!( - "Episode terminé. Récompense totale: {:.2}, Étapes: {}", - self.episode_reward, self.step_count - ); - } - - Snapshot::new(self.current_state, reward, terminated) - } -} - -impl TrictracEnvironment { - const ERROR_REWARD: f32 = -1.12121; - const REWARD_RATIO: f32 = 1.0; - - /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) - } - - /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac - fn convert_valid_action_index( - &self, - action: TrictracAction, - game_state: &GameState, - ) -> Option { - use dqn_common_big::get_valid_actions; - - // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_actions(game_state); - - if valid_actions.is_empty() { - return None; - } - - // Mapper l'index d'action sur une action valide - let action_index = (action.index as usize) % valid_actions.len(); - Some(valid_actions[action_index].clone()) - } - - /// Exécute une action Trictrac dans le jeu - // fn execute_action( - // &mut self, - // action:dqn_common_big::TrictracAction, - // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { - use dqn_common_big::TrictracAction; - - let mut reward = 0.0; - let mut is_rollpoint = false; - - let event = match action { - TrictracAction::Roll => { - // Lancer les dés - reward += 0.1; - Some(GameEvent::Roll { - player_id: self.active_player_id, - }) - } - // TrictracAction::Mark => { - // // Marquer des points - // let points = self.game. 
- // reward += 0.1 * points as f32; - // Some(GameEvent::Mark { - // player_id: self.active_player_id, - // points, - // }) - // } - TrictracAction::Go => { - // Continuer après avoir gagné un trou - reward += 0.2; - Some(GameEvent::Go { - player_id: self.active_player_id, - }) - } - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let (dice1, dice2) = if dice_order { - (self.game.dice.values.0, self.game.dice.values.1) - } else { - (self.game.dice.values.1, self.game.dice.values.0) - }; - let mut to1 = from1 + dice1 as usize; - let mut to2 = from2 + dice2 as usize; - - // Gestion prise de coin par puissance - let opp_rest_field = 13; - if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; - } - - let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - - reward += 0.2; - Some(GameEvent::Move { - player_id: self.active_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - - // Appliquer l'événement si valide - if let Some(event) = event { - if self.game.validate(&event) { - self.game.consume(&event); - - // Simuler le résultat des dés après un Roll - if matches!(action, TrictracAction::Roll) { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - let dice_event = GameEvent::RollResult { - player_id: self.active_player_id, - dice: store::Dice { - values: dice_values, - }, - }; - if self.game.validate(&dice_event) { - self.game.consume(&dice_event); - let (points, adv_points) = self.game.dice_points; - reward += Self::REWARD_RATIO * (points - adv_points) as f32; - if points > 0 { - is_rollpoint = true; - // println!("info: rolled for {reward}"); - } - // Récompense proportionnelle aux points - } - } - } else { - // Pénalité pour action invalide - // on annule les précédents reward - // et on indique une valeur reconnaissable pour statistiques - reward = Self::ERROR_REWARD; - } - } - - (reward, is_rollpoint) - } - - /// Fait jouer l'adversaire avec une stratégie simple - fn play_opponent_if_needed(&mut self) -> f32 { - let mut reward = 0.0; - - // Si c'est le tour de l'adversaire, jouer automatiquement - if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - // Utiliser la stratégie default pour l'adversaire - use crate::BotStrategy; - - let mut strategy = crate::strategy::random::RandomStrategy::default(); - strategy.set_player_id(self.opponent_id); - if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { - strategy.set_color(color); - } - *strategy.get_mut_game() = self.game.clone(); - - // Exécuter l'action selon le turn_stage - let event = match self.game.turn_stage { - TurnStage::RollDice => GameEvent::Roll { - player_id: self.opponent_id, - }, - TurnStage::RollWaiting => { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - GameEvent::RollResult { - player_id: self.opponent_id, - dice: store::Dice { - values: dice_values, - }, - } - } - TurnStage::MarkPoints => { - panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); - let opponent_color = store::Color::Black; - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // reward -= Self::REWARD_RATIO * 
(points - adv_points) as f32; // Récompense proportionnelle aux points - - GameEvent::Mark { - player_id: self.opponent_id, - points, - } - } - TurnStage::MarkAdvPoints => { - let opponent_color = store::Color::Black; - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let points = points_rules.get_points(dice_roll_count).1; - // pas de reward : déjà comptabilisé lors du tour de blanc - GameEvent::Mark { - player_id: self.opponent_id, - points, - } - } - TurnStage::HoldOrGoChoice => { - // Stratégie simple : toujours continuer - GameEvent::Go { - player_id: self.opponent_id, - } - } - TurnStage::Move => GameEvent::Move { - player_id: self.opponent_id, - moves: strategy.choose_move(), - }, - }; - - if self.game.validate(&event) { - self.game.consume(&event); - } - } - reward - } -} - -impl AsMut for TrictracEnvironment { - fn as_mut(&mut self) -> &mut Self { - self - } -} diff --git a/bot/src/dqn/burnrl_before/main.rs b/bot/src/dqn/burnrl_before/main.rs deleted file mode 100644 index 602ff51..0000000 --- a/bot/src/dqn/burnrl_before/main.rs +++ /dev/null @@ -1,53 +0,0 @@ -use bot::dqn::burnrl_before::{ - dqn_model, environment, - utils::{demo_model, load_model, save_model}, -}; -use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::DQN; -use burn_rl::base::ElemType; - -type Backend = Autodiff>; -type Env = environment::TrictracEnvironment; - -fn main() { - // println!("> Entraînement"); - - // See also MEMORY_SIZE in dqn_model.rs : 8192 - let conf = dqn_model::DqnConfig { - // defaults - num_episodes: 40, // 40 - min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 3000, // 1000 max steps by episode - dense_size: 256, // 128 neural network complexity (default 128) - eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) - eps_end: 0.05, // 0.05 - // eps_decay higher = epsilon decrease slower - // used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay); - // epsilon is updated at the start of each episode - eps_decay: 2000.0, // 1000 ? - - gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme - tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation - // plus lente moins sensible aux coups de chance - learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais - // converger - batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy. 
- clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100) - }; - println!("{conf}----------"); - let agent = dqn_model::run::(&conf, false); //true); - - let valid_agent = agent.valid(); - - println!("> Sauvegarde du modèle de validation"); - - let path = "models/burn_dqn_40".to_string(); - save_model(valid_agent.model().as_ref().unwrap(), &path); - - println!("> Chargement du modèle pour test"); - let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); -} diff --git a/bot/src/dqn/burnrl_before/mod.rs b/bot/src/dqn/burnrl_before/mod.rs deleted file mode 100644 index f4380eb..0000000 --- a/bot/src/dqn/burnrl_before/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod dqn_model; -pub mod environment; -pub mod utils; diff --git a/bot/src/dqn/burnrl_before/utils.rs b/bot/src/dqn/burnrl_before/utils.rs deleted file mode 100644 index 6c25c5d..0000000 --- a/bot/src/dqn/burnrl_before/utils.rs +++ /dev/null @@ -1,114 +0,0 @@ -use crate::dqn::burnrl_before::{ - dqn_model, - environment::{TrictracAction, TrictracEnvironment}, -}; -use crate::dqn::dqn_common_big::get_valid_action_indices; -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; -use burn::module::{Module, Param, ParamId}; -use burn::nn::Linear; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::backend::Backend; -use burn::tensor::cast::ToElement; -use burn::tensor::Tensor; -use burn_rl::agent::{DQNModel, DQN}; -use burn_rl::base::{Action, ElemType, Environment, State}; - -pub fn save_model(model: &dqn_model::Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}_model.mpk"); - println!("Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}_model.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - -pub fn demo_model>(agent: DQN) { - let mut env = TrictracEnvironment::new(true); - let mut done = false; - while !done { - // let action = match infer_action(&agent, &env, state) { - let action = match infer_action(&agent, &env) { - Some(value) => value, - None => break, - }; - // Execute action - let snapshot = env.step(action); - done = snapshot.done(); - } -} - -fn infer_action>( - agent: &DQN, - env: &TrictracEnvironment, -) -> Option { - let state = env.state(); - // Get q-values - let q_values = agent - .model() - .as_ref() - .unwrap() - .infer(state.to_tensor().unsqueeze()); - // Get valid actions - let valid_actions_indices = get_valid_action_indices(&env.game); - if valid_actions_indices.is_empty() { - return None; // No valid actions, end of episode - } - // Set non valid actions q-values to lowest - let mut masked_q_values = q_values.clone(); - let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); - for (index, q_value) in q_values_vec.iter().enumerate() { - if !valid_actions_indices.contains(&index) { - masked_q_values = masked_q_values.clone().mask_fill( - masked_q_values.clone().equal_elem(*q_value), - f32::NEG_INFINITY, - ); - } - } - // Get best action (highest q-value) - let 
action_index = masked_q_values.argmax(1).into_scalar().to_u32();
-    let action = TrictracAction::from(action_index);
-    Some(action)
-}
-
-fn soft_update_tensor<const D: usize, B: Backend>(
-    this: &Param<Tensor<B, D>>,
-    that: &Param<Tensor<B, D>>,
-    tau: ElemType,
-) -> Param<Tensor<B, D>> {
-    let that_weight = that.val();
-    let this_weight = this.val();
-    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
-
-    Param::initialized(ParamId::new(), new_weight)
-}
-
-pub fn soft_update_linear<B: Backend>(
-    this: Linear<B>,
-    that: &Linear<B>,
-    tau: ElemType,
-) -> Linear<B> {
-    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
-    let bias = match (&this.bias, &that.bias) {
-        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
-        _ => None,
-    };
-
-    Linear::<B> { weight, bias }
-}
diff --git a/bot/src/dqn/burnrl_big/environment.rs b/bot/src/dqn/burnrl_big/environment.rs
index 9925a9a..ea5a9b4 100644
--- a/bot/src/dqn/burnrl_big/environment.rs
+++ b/bot/src/dqn/burnrl_big/environment.rs
@@ -165,8 +165,7 @@ impl Environment for TrictracEnvironment {
         let trictrac_action = Self::convert_action(action);
 
         let mut reward = 0.0;
-        let mut is_rollpoint = false;
-        let mut terminated = false;
+        let is_rollpoint;
 
         // Exécuter l'action si c'est le tour de l'agent DQN
         if self.game.active_player_id == self.active_player_id {
@@ -372,6 +371,8 @@ impl TrictracEnvironment {
             *strategy.get_mut_game() = self.game.clone();
 
             // Exécuter l'action selon le turn_stage
+            let mut calculate_points = false;
+            let opponent_color = store::Color::Black;
             let event = match self.game.turn_stage {
                 TurnStage::RollDice => GameEvent::Roll {
                     player_id: self.opponent_id,
@@ -379,6 +380,7 @@ TurnStage::RollWaiting => {
                     let mut rng = thread_rng();
                     let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
+                    // calculate_points = true; // comment to replicate burnrl_before
                     GameEvent::RollResult {
                         player_id: self.opponent_id,
                         dice: store::Dice {
                             values: dice_values,
@@ -388,7 +390,6 @@
                 }
                 TurnStage::MarkPoints => {
                     panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
-                    let opponent_color = store::Color::Black;
                     let dice_roll_count = self
                         .game
                         .players
@@ -397,16 +398,12 @@
                         .dice_roll_count;
                     let points_rules =
                         PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    let (points, adv_points) = points_rules.get_points(dice_roll_count);
-                    // reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
-
                     GameEvent::Mark {
                         player_id: self.opponent_id,
-                        points,
+                        points: points_rules.get_points(dice_roll_count).0,
                     }
                 }
                 TurnStage::MarkAdvPoints => {
-                    let opponent_color = store::Color::Black;
                     let dice_roll_count = self
                         .game
                         .players
@@ -415,11 +412,10 @@
                         .dice_roll_count;
                     let points_rules =
                         PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-                    let points = points_rules.get_points(dice_roll_count).1;
                     // pas de reward : déjà comptabilisé lors du tour de blanc
                     GameEvent::Mark {
                         player_id: self.opponent_id,
-                        points,
+                        points: points_rules.get_points(dice_roll_count).1,
                     }
                 }
                 TurnStage::HoldOrGoChoice => {
@@ -436,6 +432,19 @@
             if self.game.validate(&event) {
                 self.game.consume(&event);
+                if calculate_points {
+                    let dice_roll_count = self
+                        .game
+                        .players
+                        .get(&self.opponent_id)
+                        .unwrap()
+                        .dice_roll_count;
+                    let points_rules =
+                        PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
+                    let (points, adv_points) = points_rules.get_points(dice_roll_count);
+                    //
Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } } } reward diff --git a/bot/src/dqn/burnrl_big/environmentDiverge.rs b/bot/src/dqn/burnrl_big/environmentDiverge.rs deleted file mode 100644 index 6706163..0000000 --- a/bot/src/dqn/burnrl_big/environmentDiverge.rs +++ /dev/null @@ -1,459 +0,0 @@ -use crate::dqn::dqn_common_big; -use burn::{prelude::Backend, tensor::Tensor}; -use burn_rl::base::{Action, Environment, Snapshot, State}; -use rand::{thread_rng, Rng}; -use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; - -/// État du jeu Trictrac pour burn-rl -#[derive(Debug, Clone, Copy)] -pub struct TrictracState { - pub data: [i8; 36], // Représentation vectorielle de l'état du jeu -} - -impl State for TrictracState { - type Data = [i8; 36]; - - fn to_tensor(&self) -> Tensor { - Tensor::from_floats(self.data, &B::Device::default()) - } - - fn size() -> usize { - 36 - } -} - -impl TrictracState { - /// Convertit un GameState en TrictracState - pub fn from_game_state(game_state: &GameState) -> Self { - let state_vec = game_state.to_vec(); - let mut data = [0; 36]; - - // Copier les données en s'assurant qu'on ne dépasse pas la taille - let copy_len = state_vec.len().min(36); - data[..copy_len].copy_from_slice(&state_vec[..copy_len]); - - TrictracState { data } - } -} - -/// Actions possibles dans Trictrac pour burn-rl -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct TrictracAction { - // u32 as required by burn_rl::base::Action type - pub index: u32, -} - -impl Action for TrictracAction { - fn random() -> Self { - use rand::{thread_rng, Rng}; - let mut rng = thread_rng(); - TrictracAction { - index: rng.gen_range(0..Self::size() as u32), - } - } - - fn enumerate() -> Vec { - (0..Self::size() as u32) - .map(|index| TrictracAction { index }) - .collect() - } - - fn size() -> usize { - 1252 - } -} - -impl From for TrictracAction { - fn from(index: u32) -> Self { - TrictracAction { index } - } -} - -impl From for u32 { - fn from(action: TrictracAction) -> u32 { - action.index - } -} - -/// Environnement Trictrac pour burn-rl -#[derive(Debug)] -pub struct TrictracEnvironment { - pub game: GameState, - active_player_id: PlayerId, - opponent_id: PlayerId, - current_state: TrictracState, - episode_reward: f32, - pub step_count: usize, - pub min_steps: f32, - pub max_steps: usize, - pub pointrolls_count: usize, - pub goodmoves_count: usize, - pub goodmoves_ratio: f32, - pub visualized: bool, -} - -impl Environment for TrictracEnvironment { - type StateType = TrictracState; - type ActionType = TrictracAction; - type RewardType = f32; - - fn new(visualized: bool) -> Self { - let mut game = GameState::new(false); - - // Ajouter deux joueurs - game.init_player("DQN Agent"); - game.init_player("Opponent"); - let player1_id = 1; - let player2_id = 2; - - // Commencer la partie - game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - let current_state = TrictracState::from_game_state(&game); - TrictracEnvironment { - game, - active_player_id: player1_id, - opponent_id: player2_id, - current_state, - episode_reward: 0.0, - step_count: 0, - min_steps: 250.0, - max_steps: 2000, - pointrolls_count: 0, - goodmoves_count: 0, - goodmoves_ratio: 0.0, - visualized, - } - } - - fn state(&self) -> Self::StateType { - self.current_state - } - - fn reset(&mut self) -> Snapshot { - // Réinitialiser le jeu - self.game = GameState::new(false); - self.game.init_player("DQN Agent"); - self.game.init_player("Opponent"); - - // Commencer la 
partie - self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward = 0.0; - self.goodmoves_ratio = if self.step_count == 0 { - 0.0 - } else { - self.goodmoves_count as f32 / self.step_count as f32 - }; - println!( - "info: correct moves: {} ({}%)", - self.goodmoves_count, - (100.0 * self.goodmoves_ratio).round() as u32 - ); - self.step_count = 0; - self.pointrolls_count = 0; - self.goodmoves_count = 0; - - Snapshot::new(self.current_state, 0.0, false) - } - - fn step(&mut self, action: Self::ActionType) -> Snapshot { - self.step_count += 1; - - // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = Self::convert_action(action); - - let mut reward = 0.0; - let is_rollpoint; - - // Exécuter l'action si c'est le tour de l'agent DQN - if self.game.active_player_id == self.active_player_id { - if let Some(action) = trictrac_action { - (reward, is_rollpoint) = self.execute_action(action); - if is_rollpoint { - self.pointrolls_count += 1; - } - if reward != Self::ERROR_REWARD { - self.goodmoves_count += 1; - } - } else { - // Action non convertible, pénalité - reward = -0.5; - } - } - - // Faire jouer l'adversaire (stratégie simple) - while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - reward += self.play_opponent_if_needed(); - } - - // Vérifier si la partie est terminée - let max_steps = self.min_steps - + (self.max_steps as f32 - self.min_steps) - * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); - let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); - - if done { - // Récompense finale basée sur le résultat - if let Some(winner_id) = self.game.determine_winner() { - if winner_id == self.active_player_id { - reward += 50.0; // Victoire - } else { - reward -= 25.0; // Défaite - } - } - } - let terminated = done || self.step_count >= max_steps.round() as usize; - - // Mettre à jour l'état - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward += reward; - - if self.visualized && terminated { - println!( - "Episode terminé. 
Récompense totale: {:.2}, Étapes: {}", - self.episode_reward, self.step_count - ); - } - - Snapshot::new(self.current_state, reward, terminated) - } -} - -impl TrictracEnvironment { - const ERROR_REWARD: f32 = -1.12121; - const REWARD_RATIO: f32 = 1.0; - - /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) - } - - /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac - fn convert_valid_action_index( - &self, - action: TrictracAction, - game_state: &GameState, - ) -> Option { - use dqn_common_big::get_valid_actions; - - // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_actions(game_state); - - if valid_actions.is_empty() { - return None; - } - - // Mapper l'index d'action sur une action valide - let action_index = (action.index as usize) % valid_actions.len(); - Some(valid_actions[action_index].clone()) - } - - /// Exécute une action Trictrac dans le jeu - // fn execute_action( - // &mut self, - // action: dqn_common_big::TrictracAction, - // ) -> Result> { - fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) { - use dqn_common_big::TrictracAction; - - let mut reward = 0.0; - let mut is_rollpoint = false; - - let event = match action { - TrictracAction::Roll => { - // Lancer les dés - reward += 0.1; - Some(GameEvent::Roll { - player_id: self.active_player_id, - }) - } - // TrictracAction::Mark => { - // // Marquer des points - // let points = self.game. - // reward += 0.1 * points as f32; - // Some(GameEvent::Mark { - // player_id: self.active_player_id, - // points, - // }) - // } - TrictracAction::Go => { - // Continuer après avoir gagné un trou - reward += 0.2; - Some(GameEvent::Go { - player_id: self.active_player_id, - }) - } - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let (dice1, dice2) = if dice_order { - (self.game.dice.values.0, self.game.dice.values.1) - } else { - (self.game.dice.values.1, self.game.dice.values.0) - }; - let mut to1 = from1 + dice1 as usize; - let mut to2 = from2 + dice2 as usize; - - // Gestion prise de coin par puissance - let opp_rest_field = 13; - if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; - } - - let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - - reward += 0.2; - Some(GameEvent::Move { - player_id: self.active_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - - // Appliquer l'événement si valide - if let Some(event) = event { - if self.game.validate(&event) { - self.game.consume(&event); - - // Simuler le résultat des dés après un Roll - if matches!(action, TrictracAction::Roll) { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - let dice_event = GameEvent::RollResult { - player_id: self.active_player_id, - dice: store::Dice { - values: dice_values, - }, - }; - if self.game.validate(&dice_event) { - self.game.consume(&dice_event); - let (points, adv_points) = self.game.dice_points; - reward += Self::REWARD_RATIO * (points - adv_points) as f32; - if points > 0 { - is_rollpoint = true; - // println!("info: rolled for {reward}"); - } - // Récompense proportionnelle aux points - } - } - } else { - // Pénalité pour action invalide - // on annule les 
précédents reward - // et on indique une valeur reconnaissable pour statistiques - reward = Self::ERROR_REWARD; - } - } - - (reward, is_rollpoint) - } - - /// Fait jouer l'adversaire avec une stratégie simple - fn play_opponent_if_needed(&mut self) -> f32 { - let mut reward = 0.0; - - // Si c'est le tour de l'adversaire, jouer automatiquement - if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - // Utiliser la stratégie default pour l'adversaire - use crate::BotStrategy; - - let mut strategy = crate::strategy::random::RandomStrategy::default(); - strategy.set_player_id(self.opponent_id); - if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { - strategy.set_color(color); - } - *strategy.get_mut_game() = self.game.clone(); - - // Exécuter l'action selon le turn_stage - let mut calculate_points = false; - let opponent_color = store::Color::Black; - let event = match self.game.turn_stage { - TurnStage::RollDice => GameEvent::Roll { - player_id: self.opponent_id, - }, - TurnStage::RollWaiting => { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - // calculate_points = true; // comment to replicate burnrl_before - GameEvent::RollResult { - player_id: self.opponent_id, - dice: store::Dice { - values: dice_values, - }, - } - } - TurnStage::MarkPoints => { - panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - GameEvent::Mark { - player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).0, - } - } - TurnStage::MarkAdvPoints => { - let opponent_color = store::Color::Black; - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - // pas de reward : déjà comptabilisé lors du tour de blanc - GameEvent::Mark { - player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).1, - } - } - TurnStage::HoldOrGoChoice => { - // Stratégie simple : toujours continuer - GameEvent::Go { - player_id: self.opponent_id, - } - } - TurnStage::Move => GameEvent::Move { - player_id: self.opponent_id, - moves: strategy.choose_move(), - }, - }; - - if self.game.validate(&event) { - self.game.consume(&event); - if calculate_points { - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // Récompense proportionnelle aux points - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; - } - } - } - reward - } -} - -impl AsMut for TrictracEnvironment { - fn as_mut(&mut self) -> &mut Self { - self - } -} diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs index ebc01a4..7b12487 100644 --- a/bot/src/dqn/mod.rs +++ b/bot/src/dqn/mod.rs @@ -1,5 +1,4 @@ pub mod burnrl; -pub mod burnrl_before; pub mod burnrl_big; pub mod dqn_common; pub mod dqn_common_big;
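
Note on the hyper-parameter schedules referenced in the removed burnrl_before code: exploration follows the exponential epsilon decay documented in its main.rs (epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)), the target network is blended toward the policy network by soft_update_linear (weight = current * (1 - tau) + online * tau), and the episode-length cap grows with the ratio of valid moves (cap = min_steps + (max_steps - min_steps) * e^((goodmoves_ratio - 1) / 0.25)). The snippet below is a minimal, dependency-free Rust sketch of those three formulas for reference only; the free-standing helpers (eps_threshold, soft_update, step_cap) are illustrative names and do not exist in the crate.

/// Exploration rate after `step` environment steps.
fn eps_threshold(step: usize, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
    eps_end + (eps_start - eps_end) * f64::exp(-(step as f64) / eps_decay)
}

/// Soft target-network update applied per parameter (tau = 0.005 by default).
fn soft_update(current: f32, online: f32, tau: f32) -> f32 {
    current * (1.0 - tau) + online * tau
}

/// Episode-length cap that grows with the ratio of valid moves played so far.
fn step_cap(min_steps: f32, max_steps: f32, goodmoves_ratio: f32) -> f32 {
    min_steps + (max_steps - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25)
}

fn main() {
    // With the values from the removed main.rs (eps_start=0.9, eps_end=0.05,
    // eps_decay=2000), epsilon falls to about 0.36 after 2000 steps and 0.17 after 4000.
    println!("eps@2000 = {:.3}", eps_threshold(2000, 0.9, 0.05, 2000.0));
    println!("eps@4000 = {:.3}", eps_threshold(4000, 0.9, 0.05, 2000.0));
    // Target weights move 0.5% of the way toward the online weights per update.
    println!("soft     = {:.4}", soft_update(1.0, 0.0, 0.005));
    // Example: with min_steps=500, max_steps=3000 and 60% valid moves,
    // an episode is cut off after about 1005 steps.
    println!("cap      = {:.0}", step_cap(500.0, 3000.0, 0.6));
}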