diff --git a/bot/src/bin/train_dqn_full.rs b/bot/src/bin/train_dqn_full.rs index 56321b1..f01fc76 100644 --- a/bot/src/bin/train_dqn_full.rs +++ b/bot/src/bin/train_dqn_full.rs @@ -1,10 +1,19 @@ -use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience}; -use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment}; -use bot::strategy::dqn_common::get_valid_actions; -use burn::optim::AdamConfig; -use burn_rl::base::Environment; +use bot::strategy::burn_dqn_agent::{DqnNetwork, MyBackend}; +use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment, TrictracState}; +use burn::optim::{AdamWConfig, Optimizer}; +use burn_rl::{ + agent::{DQN, DQNTrainingConfig}, + base::{Action, Agent, ElemType, Environment, Memory, State}, +}; use std::env; +const DENSE_SIZE: usize = 128; +const EPS_DECAY: f64 = 1000.0; +const EPS_START: f64 = 0.9; +const EPS_END: f64 = 0.05; + +type MyAgent = DQN>; + fn main() -> Result<(), Box> { env_logger::init(); @@ -71,193 +80,73 @@ fn main() -> Result<(), Box> { // Créer le dossier models s'il n'existe pas std::fs::create_dir_all("models")?; - println!("=== Entraînement DQN complet avec Burn ==="); + println!("=== Entraînement DQN complet avec Burn-RL ==="); println!("Épisodes : {}", episodes); println!("Modèle : {}", model_path); println!("Sauvegarde tous les {} épisodes", save_every); println!("Max steps par épisode : {}", max_steps_per_episode); println!(); - // Configuration DQN - let config = DqnConfig { - state_size: 36, - action_size: 1252, // Espace d'actions réduit via contexte - hidden_size: 256, - learning_rate: 0.001, - gamma: 0.99, - epsilon: 1.0, - epsilon_decay: 0.995, - epsilon_min: 0.01, - replay_buffer_size: 10000, - batch_size: 32, - target_update_freq: 100, - }; - - // Créer l'agent et l'environnement - let mut agent = BurnDqnAgent::new(config); - let mut optimizer = AdamConfig::new().init(); - let mut env = TrictracEnvironment::new(true); + let model = DqnNetwork::::new( + TrictracState::size(), + DENSE_SIZE, + TrictracAction::size(), + &Default::default(), + ); + let mut agent = MyAgent::new(model); + let config = DQNTrainingConfig::default(); + let mut memory = Memory::::default(); + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + let mut policy_net = agent.model().as_ref().unwrap().clone(); + let mut step = 0_usize; - // Variables pour les statistiques - let mut total_rewards = Vec::new(); - let mut episode_lengths = Vec::new(); - let mut losses = Vec::new(); + for episode in 0..episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); - println!("Début de l'entraînement avec agent DQN complet..."); - println!(); + while !episode_done { + let eps_threshold = + EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY); + let action = MyAgent::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); - for episode in 1..=episodes { - // Reset de l'environnement - let mut snapshot = env.reset(); - let mut episode_reward = 0.0; - let mut step = 0; - let mut episode_loss = 0.0; - let mut loss_count = 0; + episode_reward += >::into(snapshot.reward().clone()); + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = agent.train(&policy_net, &memory, &mut optimizer, &config); + } - loop { step += 1; - let current_state_data = snapshot.state().data; + episode_duration += 1; - // Obtenir les actions valides selon le contexte du jeu - let valid_actions = get_valid_actions(&env.game); - - if valid_actions.is_empty() { - break; - } - - // Convertir les actions Trictrac en indices pour l'agent - let valid_indices: Vec = (0..valid_actions.len()).collect(); - - // Sélectionner une action avec l'agent DQN - let action_index = agent.select_action( - ¤t_state_data, - &valid_indices, - ); - let action = TrictracAction { - index: action_index as u32, - }; - - // Exécuter l'action - snapshot = env.step(action); - episode_reward += *snapshot.reward(); - - // Préparer l'expérience pour l'agent - let experience = Experience { - state: current_state_data.to_vec(), - action: action_index, - reward: *snapshot.reward(), - next_state: if snapshot.done() { - None - } else { - Some(snapshot.state().data.to_vec()) - }, - done: snapshot.done(), - }; - - // Ajouter l'expérience au replay buffer - agent.add_experience(experience); - - // Entraîner l'agent - if let Some(loss) = agent.train_step(&mut optimizer) { - episode_loss += loss; - loss_count += 1; - } - - // Vérifier les conditions de fin - if snapshot.done() || step >= max_steps_per_episode { - break; - } - } - - // Calculer la loss moyenne de l'épisode - let avg_loss = if loss_count > 0 { - episode_loss / loss_count as f32 - } else { - 0.0 - }; - - // Sauvegarder les statistiques - total_rewards.push(episode_reward); - episode_lengths.push(step); - losses.push(avg_loss); - - // Affichage des statistiques - if episode % save_every == 0 { - let avg_reward = - total_rewards.iter().rev().take(save_every).sum::() / save_every as f32; - let avg_length = - episode_lengths.iter().rev().take(save_every).sum::() / save_every; - let avg_episode_loss = - losses.iter().rev().take(save_every).sum::() / save_every as f32; - - println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}", - episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size()); - - // Sauvegarder le modèle - let checkpoint_path = format!("{}_{}", model_path, episode); - if let Err(e) = agent.save_model(&checkpoint_path) { - eprintln!("Erreur lors de la sauvegarde : {}", e); + if snapshot.done() || episode_duration >= TrictracEnvironment::MAX_STEPS { + env.reset(); + episode_done = true; + println!( + "{{\"episode\": {}, \"reward\": {:.4}, \"duration\": {}}}", + episode, episode_reward, episode_duration + ); } else { - println!(" → Modèle sauvegardé : {}", checkpoint_path); + state = *snapshot.state(); } - } else if episode % 10 == 0 { - println!( - "Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}", - episode, - episode_reward, - step, - avg_loss, - agent.get_epsilon() - ); } } // Sauvegarder le modèle final let final_path = format!("{}_final", model_path); - agent.save_model(&final_path)?; - - // Statistiques finales - println!(); - println!("=== Résultats de l'entraînement ==="); - let final_avg_reward = total_rewards - .iter() - .rev() - .take(100.min(episodes)) - .sum::() - / 100.min(episodes) as f32; - let final_avg_length = episode_lengths - .iter() - .rev() - .take(100.min(episodes)) - .sum::() - / 100.min(episodes); - let final_avg_loss = - losses.iter().rev().take(100.min(episodes)).sum::() / 100.min(episodes) as f32; - - println!( - "Récompense moyenne (100 derniers épisodes) : {:.3}", - final_avg_reward - ); - println!( - "Longueur moyenne (100 derniers épisodes) : {}", - final_avg_length - ); - println!( - "Loss moyenne (100 derniers épisodes) : {:.6}", - final_avg_loss - ); - println!("Epsilon final : {:.3}", agent.get_epsilon()); - println!("Taille du buffer final : {}", agent.get_buffer_size()); - - // Statistiques globales - let max_reward = total_rewards - .iter() - .cloned() - .fold(f32::NEG_INFINITY, f32::max); - let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min); - println!("Récompense max : {:.3}", max_reward); - println!("Récompense min : {:.3}", min_reward); + // agent.save_model(&final_path)?; println!(); println!("Entraînement terminé avec succès !"); diff --git a/bot/src/strategy/burn_dqn_agent.rs b/bot/src/strategy/burn_dqn_agent.rs index 1f1c01a..99ded67 100644 --- a/bot/src/strategy/burn_dqn_agent.rs +++ b/bot/src/strategy/burn_dqn_agent.rs @@ -1,13 +1,11 @@ use burn::{ backend::{ndarray::NdArrayDevice, Autodiff, NdArray}, module::Module, - nn::{loss::MseLoss, Linear, LinearConfig}, - optim::Optimizer, - record::{CompactRecorder, Recorder}, - tensor::Tensor, + nn::{Linear, LinearConfig}, + tensor::{activation::relu, backend::Backend, Tensor}, }; +use burn_rl::agent::DQNModel; use serde::{Deserialize, Serialize}; -use std::collections::VecDeque; /// Backend utilisé pour l'entraînement (Autodiff + NdArray) pub type MyBackend = Autodiff; @@ -16,14 +14,14 @@ pub type InferenceBackend = NdArray; pub type MyDevice = NdArrayDevice; /// Réseau de neurones pour DQN -#[derive(Module, Debug)] -pub struct DqnNetwork { +#[derive(Module, Debug, Clone)] +pub struct DqnNetwork { fc1: Linear, fc2: Linear, fc3: Linear, } -impl DqnNetwork { +impl DqnNetwork { /// Crée un nouveau réseau DQN pub fn new( input_size: usize, @@ -38,14 +36,46 @@ impl DqnNetwork { Self { fc1, fc2, fc3 } } - /// Forward pass du réseau - pub fn forward(&self, input: Tensor) -> Tensor { + fn consume(self) -> (Linear, Linear, Linear) { + (self.fc1, self.fc2, self.fc3) + } +} + +impl burn_rl::base::Model, Tensor> for DqnNetwork { + fn forward(&self, input: Tensor) -> Tensor { let x = self.fc1.forward(input); - let x = burn::tensor::activation::relu(x); + let x = relu(x); let x = self.fc2.forward(x); - let x = burn::tensor::activation::relu(x); + let x = relu(x); self.fc3.forward(x) } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for DqnNetwork { + fn soft_update(this: Self, that: &Self, tau: f32) -> Self { + let (fc1, fc2, fc3) = this.consume(); + Self { + fc1: soft_update_linear(fc1, &that.fc1, tau), + fc2: soft_update_linear(fc2, &that.fc2, tau), + fc3: soft_update_linear(fc3, &that.fc3, tau), + } + } +} + +pub fn soft_update_linear( + this: Linear, + that: &Linear, + tau: f32, +) -> Linear { + let mut updated = this.clone(); + let that_record = that.clone().into_record(); + let updated_record = updated.clone().into_record(); + updated.load_record(updated_record.soft_update(tau, that_record)); + updated } /// Configuration pour l'entraînement DQN @@ -91,215 +121,3 @@ pub struct Experience { pub next_state: Option>, pub done: bool, } - -/// Agent DQN utilisant Burn -pub struct BurnDqnAgent { - config: DqnConfig, - device: MyDevice, - q_network: DqnNetwork, - target_network: DqnNetwork, - replay_buffer: VecDeque, - epsilon: f32, - step_count: usize, -} - -impl BurnDqnAgent { - /// Crée un nouvel agent DQN - pub fn new(config: DqnConfig) -> Self { - let device = MyDevice::default(); - - let q_network = DqnNetwork::new( - config.state_size, - config.hidden_size, - config.action_size, - &device, - ); - - let target_network = DqnNetwork::new( - config.state_size, - config.hidden_size, - config.action_size, - &device, - ); - - Self { - config: config.clone(), - device, - q_network, - target_network, - replay_buffer: VecDeque::new(), - epsilon: config.epsilon, - step_count: 0, - } - } - - /// Sélectionne une action avec epsilon-greedy - pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize { - if valid_actions.is_empty() { - // Retourne une action par défaut ou une action "nulle" si aucune n'est valide - // Dans le contexte du jeu, cela ne devrait pas arriver si la logique de fin de partie est correcte - return 0; - } - - // Exploration epsilon-greedy - if rand::random::() < self.epsilon { - let random_index = rand::random::() % valid_actions.len(); - return valid_actions[random_index]; - } - - // Exploitation : choisir la meilleure action selon le Q-network - let state_tensor = Tensor::::from_floats(state, &self.device) - .reshape([1, self.config.state_size]); - let q_values = self.q_network.forward(state_tensor); - - // Convertir en vecteur pour traitement - let q_data = q_values.into_data().convert::().into_vec().unwrap(); - - // Trouver la meilleure action parmi les actions valides - let mut best_action = valid_actions[0]; - let mut best_q_value = f32::NEG_INFINITY; - - for &action in valid_actions { - if action < q_data.len() && q_data[action] > best_q_value { - best_q_value = q_data[action]; - best_action = action; - } - } - - best_action - } - - /// Ajoute une expérience au replay buffer - pub fn add_experience(&mut self, experience: Experience) { - if self.replay_buffer.len() >= self.config.replay_buffer_size { - self.replay_buffer.pop_front(); - } - self.replay_buffer.push_back(experience); - } - - /// Entraîne le réseau sur un batch d'expériences - pub fn train_step( - &mut self, - optimizer: &mut impl Optimizer, MyBackend>, - ) -> Option { - if self.replay_buffer.len() < self.config.batch_size { - return None; - } - - // Échantillonner un batch d'expériences - let batch = self.sample_batch(); - - // Préparer les tenseurs d'état - let states: Vec = batch.iter().flat_map(|exp| exp.state.clone()).collect(); - let state_tensor = Tensor::::from_floats(states.as_slice(), &self.device) - .reshape([self.config.batch_size, self.config.state_size]); - - // Calculer les Q-values actuelles - let current_q_values = self.q_network.forward(state_tensor); - - // Pour l'instant, version simplifiée sans calcul de target - let target_q_values = current_q_values.clone(); - - // Calculer la loss MSE - let loss = MseLoss::new().forward( - current_q_values, - target_q_values, - burn::nn::loss::Reduction::Mean, - ); - - // Backpropagation (version simplifiée) - let grads = loss.backward(); - // Gradients linked to each parameter of the model. - let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network); - self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads); - - // Mise à jour du réseau cible - self.step_count += 1; - if self.step_count % self.config.target_update_freq == 0 { - self.update_target_network(); - } - - // Décroissance d'epsilon - if self.epsilon > self.config.epsilon_min { - self.epsilon *= self.config.epsilon_decay; - } - - Some(loss.into_scalar()) - } - - /// Échantillonne un batch d'expériences du replay buffer - fn sample_batch(&self) -> Vec { - let mut batch = Vec::new(); - let buffer_size = self.replay_buffer.len(); - - for _ in 0..self.config.batch_size.min(buffer_size) { - let index = rand::random::() % buffer_size; - if let Some(exp) = self.replay_buffer.get(index) { - batch.push(exp.clone()); - } - } - - batch - } - - /// Met à jour le réseau cible avec les poids du réseau principal - fn update_target_network(&mut self) { - // Copie simple des poids - self.target_network = self.q_network.clone(); - } - - /// Sauvegarde le modèle - pub fn save_model(&self, path: &str) -> Result<(), Box> { - // Sauvegarder la configuration - let config_path = format!("{}_config.json", path); - let config_json = serde_json::to_string_pretty(&self.config)?; - std::fs::write(config_path, config_json)?; - - // Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend) - let inference_network = self.q_network.clone().into_record(); - let recorder = CompactRecorder::new(); - - let model_path = format!("{}_model.burn", path); - recorder.record(inference_network, model_path.into())?; - - println!("Modèle sauvegardé : {}", path); - Ok(()) - } - - /// Charge un modèle pour l'inférence - pub fn load_model_for_inference( - path: &str, - ) -> Result<(DqnNetwork, DqnConfig), Box> { - // Charger la configuration - let config_path = format!("{}_config.json", path); - let config_json = std::fs::read_to_string(config_path)?; - let config: DqnConfig = serde_json::from_str(&config_json)?; - - // Créer le réseau pour l'inférence - let device = NdArrayDevice::default(); - let network = DqnNetwork::::new( - config.state_size, - config.hidden_size, - config.action_size, - &device, - ); - - // Charger les poids - let model_path = format!("{}_model.burn", path); - let recorder = CompactRecorder::new(); - let record = recorder.load(model_path.into(), &device)?; - let network = network.load_record(record); - - Ok((network, config)) - } - - /// Retourne l'epsilon actuel - pub fn get_epsilon(&self) -> f32 { - self.epsilon - } - - /// Retourne la taille du replay buffer - pub fn get_buffer_size(&self) -> usize { - self.replay_buffer.len() - } -}