diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index 5716fa1..d5e0028 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -141,7 +141,7 @@ impl Environment for TrictracEnvironment { self.step_count += 1; // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = self.convert_action(action, &self.game); + let trictrac_action = Self::convert_action(action); let mut reward = 0.0; let mut terminated = false; @@ -203,11 +203,7 @@ impl Environment for TrictracEnvironment { impl TrictracEnvironment { /// Convertit une action burn-rl vers une action Trictrac - fn convert_action( - &self, - action: TrictracAction, - game_state: &GameState, - ) -> Option { + pub fn convert_action(action: TrictracAction) -> Option { dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) } diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index 8408e6a..4b3a789 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -31,7 +31,7 @@ fn main() { println!("> Chargement du modèle pour test"); let loaded_model = load_model(conf.dense_size, &path); - let loaded_agent = DQN::new(loaded_model); + let loaded_agent = DQN::new(loaded_model.unwrap()); println!("> Test avec le modèle chargé"); demo_model(loaded_agent); diff --git a/bot/src/dqn/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs index 66fa850..a1d5480 100644 --- a/bot/src/dqn/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -22,23 +22,21 @@ pub fn save_model(model: &dqn_model::Net>, path: &String) { .unwrap(); } -pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net> { +pub fn load_model(dense_size: usize, path: &String) -> Option>> { let model_path = format!("{path}_model.mpk"); - println!("Chargement du modèle depuis : {model_path}"); + // println!("Chargement du modèle depuis : {model_path}"); - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - - let record = recorder - .load(model_path.into(), &device) - .expect("Impossible de charger le modèle"); - - dqn_model::Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + dqn_model::Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() } pub fn demo_model>(agent: DQN) { diff --git a/bot/src/lib.rs b/bot/src/lib.rs index ca338e1..f9a4617 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -5,6 +5,7 @@ use log::{debug, error}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; +pub use strategy::dqnburn::DqnBurnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; pub use strategy::random::RandomStrategy; pub use strategy::stable_baselines3::StableBaselines3Strategy; diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs new file mode 100644 index 0000000..4fc0c06 --- /dev/null +++ b/bot/src/strategy/dqnburn.rs @@ -0,0 +1,176 @@ +use burn::backend::NdArray; +use burn::tensor::cast::ToElement; +use burn_rl::base::{ElemType, Model, State}; + +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use log::info; +use store::MoveRules; + +use crate::dqn::burnrl::{dqn_model, environment, utils}; +use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; + +type DqnBurnNetwork = dqn_model::Net>; + +/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné +#[derive(Debug)] +pub struct DqnBurnStrategy { + pub game: GameState, + pub player_id: PlayerId, + pub color: Color, + pub model: Option, +} + +impl Default for DqnBurnStrategy { + fn default() -> Self { + Self { + game: GameState::default(), + player_id: 1, + color: Color::White, + model: None, + } + } +} + +impl DqnBurnStrategy { + pub fn new() -> Self { + Self::default() + } + + pub fn new_with_model(model_path: &String) -> Self { + info!("Loading model {model_path:?}"); + let mut strategy = Self::new(); + strategy.model = utils::load_model(256, model_path); + strategy + } + + /// Utilise le modèle DQN pour choisir une action valide + fn get_dqn_action(&self) -> Option { + if let Some(ref model) = self.model { + let state = environment::TrictracState::from_game_state(&self.game); + let valid_actions_indices = get_valid_action_indices(&self.game); + if valid_actions_indices.is_empty() { + return None; // No valid actions, end of episode + } + + // Obtenir les Q-values pour toutes les actions + let q_values = model.infer(state.to_tensor().unsqueeze()); + + // Set non valid actions q-values to lowest + let mut masked_q_values = q_values.clone(); + let q_values_vec: Vec = q_values.into_data().into_vec().unwrap(); + for (index, q_value) in q_values_vec.iter().enumerate() { + if !valid_actions_indices.contains(&index) { + masked_q_values = masked_q_values.clone().mask_fill( + masked_q_values.clone().equal_elem(*q_value), + f32::NEG_INFINITY, + ); + } + } + // Get best action (highest q-value) + let action_index = masked_q_values.argmax(1).into_scalar().to_u32(); + environment::TrictracEnvironment::convert_action(environment::TrictracAction::from( + action_index, + )) + } else { + // Fallback : action aléatoire valide + sample_valid_action(&self.game) + } + } +} + +impl BotStrategy for DqnBurnStrategy { + fn get_game(&self) -> &GameState { + &self.game + } + + fn get_mut_game(&mut self) -> &mut GameState { + &mut self.game + } + + fn set_color(&mut self, color: Color) { + self.color = color; + } + + fn set_player_id(&mut self, player_id: PlayerId) { + self.player_id = player_id; + } + + fn calculate_points(&self) -> u8 { + self.game.dice_points.0 + } + + fn calculate_adv_points(&self) -> u8 { + self.game.dice_points.1 + } + + fn choose_go(&self) -> bool { + // Utiliser le DQN pour décider si on continue + if let Some(action) = self.get_dqn_action() { + matches!(action, TrictracAction::Go) + } else { + // Fallback : toujours continuer + true + } + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + // Utiliser le DQN pour choisir le mouvement + if let Some(TrictracAction::Move { + dice_order, + from1, + from2, + }) = self.get_dqn_action() + { + let dicevals = self.game.dice.values; + let (mut dice1, mut dice2) = if dice_order { + (dicevals.0, dicevals.1) + } else { + (dicevals.1, dicevals.0) + }; + + if from1 == 0 { + // empty move + dice1 = 0; + } + let mut to1 = from1 + dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } + if from2 == 0 { + // empty move + dice2 = 0; + } + let mut to2 = from2 + dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + + let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); + + let chosen_move = if self.color == Color::White { + (checker_move1, checker_move2) + } else { + (checker_move1.mirror(), checker_move2.mirror()) + }; + + return chosen_move; + } + + // Fallback : utiliser la stratégie par défaut + let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + let chosen_move = *possible_moves + .first() + .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); + + if self.color == Color::White { + chosen_move + } else { + (chosen_move.0.mirror(), chosen_move.1.mirror()) + } + } +} diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index 731d1b1..b9fa3b2 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -1,6 +1,7 @@ pub mod client; pub mod default; pub mod dqn; +pub mod dqnburn; pub mod erroneous_moves; pub mod random; pub mod stable_baselines3; diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index 8fb1c9e..519adf1 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,5 +1,5 @@ use bot::{ - BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, + BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, StableBaselines3Strategy, }; use itertools::Itertools; @@ -25,11 +25,11 @@ pub struct App { impl App { // Constructs a new instance of [`App`]. pub fn new(args: AppArgs) -> Self { - let bot_strategies: Vec> = - args.bot - .as_deref() - .map(|str_bots| { - str_bots + let bot_strategies: Vec> = args + .bot + .as_deref() + .map(|str_bots| { + str_bots .split(",") .filter_map(|s| match s.trim() { "dummy" => { @@ -44,6 +44,9 @@ impl App { "ai" => Some(Box::new(StableBaselines3Strategy::default()) as Box), "dqn" => Some(Box::new(DqnStrategy::default()) as Box), + "dqnburn" => { + Some(Box::new(DqnBurnStrategy::default()) as Box) + } s if s.starts_with("ai:") => { let path = s.trim_start_matches("ai:"); Some(Box::new(StableBaselines3Strategy::new(path)) @@ -54,11 +57,16 @@ impl App { Some(Box::new(DqnStrategy::new_with_model(path)) as Box) } + s if s.starts_with("dqnburn:") => { + let path = s.trim_start_matches("dqnburn:"); + Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}"))) + as Box) + } _ => None, }) .collect() - }) - .unwrap_or_default(); + }) + .unwrap_or_default(); let schools_enabled = false; let should_quit = bot_strategies.len() > 1; Self { diff --git a/justfile b/justfile index 32193af..dcb5117 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,8 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy + cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk + #cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy # RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn match: cargo build --release --bin=client_cli