From 2e0a874879876ab159cb7f78f2977b0663692f03 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 1 Aug 2025 20:45:57 +0200 Subject: [PATCH] refacto --- bot/Cargo.toml | 8 - bot/src/bin/train_burn_rl.rs | 226 ------------- bot/src/bin/train_dqn.rs | 4 +- bot/src/bin/train_dqn_full.rs | 297 ----------------- bot/src/bot.rs | 0 bot/src/{ => dqn}/burnrl/dqn_model.rs | 7 +- bot/src/{ => dqn}/burnrl/environment.rs | 2 +- bot/src/{ => dqn}/burnrl/main.rs | 0 bot/src/{ => dqn}/burnrl/mod.rs | 0 bot/src/{ => dqn}/burnrl/utils.rs | 4 +- bot/src/{strategy => dqn}/dqn_common.rs | 3 +- bot/src/dqn/mod.rs | 3 + .../{strategy => dqn/simple}/dqn_trainer.rs | 2 +- bot/src/dqn/simple/mod.rs | 1 + bot/src/lib.rs | 3 +- bot/src/strategy/burn_dqn_agent.rs | 305 ------------------ bot/src/strategy/burn_dqn_strategy.rs | 192 ----------- bot/src/strategy/default.rs | 2 +- bot/src/strategy/dqn.rs | 4 +- bot/src/{strategy.rs => strategy/mod.rs} | 4 - justfile | 7 +- 21 files changed, 23 insertions(+), 1051 deletions(-) delete mode 100644 bot/src/bin/train_burn_rl.rs delete mode 100644 bot/src/bin/train_dqn_full.rs delete mode 100644 bot/src/bot.rs rename bot/src/{ => dqn}/burnrl/dqn_model.rs (95%) rename bot/src/{ => dqn}/burnrl/environment.rs (99%) rename bot/src/{ => dqn}/burnrl/main.rs (100%) rename bot/src/{ => dqn}/burnrl/mod.rs (100%) rename bot/src/{ => dqn}/burnrl/utils.rs (95%) rename bot/src/{strategy => dqn}/dqn_common.rs (99%) create mode 100644 bot/src/dqn/mod.rs rename bot/src/{strategy => dqn/simple}/dqn_trainer.rs (99%) create mode 100644 bot/src/dqn/simple/mod.rs delete mode 100644 bot/src/strategy/burn_dqn_agent.rs delete mode 100644 bot/src/strategy/burn_dqn_strategy.rs rename bot/src/{strategy.rs => strategy/mod.rs} (51%) diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 5578fae..4da2866 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,14 +13,6 @@ path = "src/burnrl/main.rs" name = "train_dqn" path = "src/bin/train_dqn.rs" -# [[bin]] -# name = "train_burn_rl" -# path = "src/bin/train_burn_rl.rs" - -[[bin]] -name = "train_dqn_full" -path = "src/bin/train_dqn_full.rs" - [dependencies] pretty_assertions = "1.4.0" serde = { version = "1.0", features = ["derive"] } diff --git a/bot/src/bin/train_burn_rl.rs b/bot/src/bin/train_burn_rl.rs deleted file mode 100644 index 73337cd..0000000 --- a/bot/src/bin/train_burn_rl.rs +++ /dev/null @@ -1,226 +0,0 @@ -use bot::burnrl::environment::{TrictracAction, TrictracEnvironment}; -use bot::strategy::dqn_common::get_valid_actions; -use burn_rl::base::Environment; -use rand::Rng; -use std::env; - -fn main() -> Result<(), Box> { - env_logger::init(); - - let args: Vec = env::args().collect(); - - // Paramètres par défaut - let mut episodes = 1000; - let mut save_every = 100; - let mut max_steps_per_episode = 500; - - // Parser les arguments de ligne de commande - let mut i = 1; - while i < args.len() { - match args[i].as_str() { - "--episodes" => { - if i + 1 < args.len() { - episodes = args[i + 1].parse().unwrap_or(1000); - i += 2; - } else { - eprintln!("Erreur : --episodes nécessite une valeur"); - std::process::exit(1); - } - } - "--save-every" => { - if i + 1 < args.len() { - save_every = args[i + 1].parse().unwrap_or(100); - i += 2; - } else { - eprintln!("Erreur : --save-every nécessite une valeur"); - std::process::exit(1); - } - } - "--max-steps" => { - if i + 1 < args.len() { - max_steps_per_episode = args[i + 1].parse().unwrap_or(500); - i += 2; - } else { - eprintln!("Erreur : --max-steps nécessite une valeur"); - std::process::exit(1); - } - } - "--help" | "-h" => { - print_help(); - std::process::exit(0); - } - _ => { - eprintln!("Argument inconnu : {}", args[i]); - print_help(); - std::process::exit(1); - } - } - } - - println!("=== Entraînement DQN avec Burn-RL ==="); - println!("Épisodes : {}", episodes); - println!("Sauvegarde tous les {} épisodes", save_every); - println!("Max steps par épisode : {}", max_steps_per_episode); - println!(); - - // Créer l'environnement - let mut env = TrictracEnvironment::new(true); - let mut rng = rand::thread_rng(); - - // Variables pour les statistiques - let mut total_rewards = Vec::new(); - let mut episode_lengths = Vec::new(); - let mut epsilon = 1.0; // Exploration rate - let epsilon_decay = 0.995; - let epsilon_min = 0.01; - - println!("Début de l'entraînement..."); - println!(); - - for episode in 1..=episodes { - // Reset de l'environnement - let mut snapshot = env.reset(); - let mut episode_reward = 0.0; - let mut step = 0; - - loop { - step += 1; - let current_state = snapshot.state(); - - // Obtenir les actions valides selon le contexte du jeu - let valid_actions = get_valid_actions(&env.game); - - if valid_actions.is_empty() { - if env.visualized && episode % 50 == 0 { - println!(" Pas d'actions valides disponibles à l'étape {}", step); - } - break; - } - - // Sélection d'action epsilon-greedy simple - let action = if rng.gen::() < epsilon { - // Exploration : action aléatoire parmi les valides - let random_valid_index = rng.gen_range(0..valid_actions.len()); - TrictracAction { - index: random_valid_index as u32, - } - } else { - // Exploitation : action simple (première action valide pour l'instant) - TrictracAction { index: 0 } - }; - - // Exécuter l'action - snapshot = env.step(action); - episode_reward += snapshot.reward(); - - if env.visualized && episode % 50 == 0 && step % 10 == 0 { - println!( - " Episode {}, Step {}, Reward: {:.3}, Action: {}", - episode, - step, - snapshot.reward(), - action.index - ); - } - - // Vérifier les conditions de fin - if snapshot.done() || step >= max_steps_per_episode { - break; - } - } - - // Décroissance epsilon - if epsilon > epsilon_min { - epsilon *= epsilon_decay; - } - - // Sauvegarder les statistiques - total_rewards.push(episode_reward); - episode_lengths.push(step); - - // Affichage des statistiques - if episode % save_every == 0 { - let avg_reward = - total_rewards.iter().rev().take(save_every).sum::() / save_every as f32; - let avg_length = - episode_lengths.iter().rev().take(save_every).sum::() / save_every; - - println!( - "Episode {} | Avg Reward: {:.3} | Avg Length: {} | Epsilon: {:.3}", - episode, avg_reward, avg_length, epsilon - ); - - // Ici on pourrait sauvegarder un modèle si on en avait un - println!(" → Checkpoint atteint (pas de modèle à sauvegarder pour l'instant)"); - } else if episode % 10 == 0 { - println!( - "Episode {} | Reward: {:.3} | Length: {} | Epsilon: {:.3}", - episode, episode_reward, step, epsilon - ); - } - } - - // Statistiques finales - println!(); - println!("=== Résultats de l'entraînement ==="); - let final_avg_reward = total_rewards - .iter() - .rev() - .take(100.min(episodes)) - .sum::() - / 100.min(episodes) as f32; - let final_avg_length = episode_lengths - .iter() - .rev() - .take(100.min(episodes)) - .sum::() - / 100.min(episodes); - - println!( - "Récompense moyenne (100 derniers épisodes) : {:.3}", - final_avg_reward - ); - println!( - "Longueur moyenne (100 derniers épisodes) : {}", - final_avg_length - ); - println!("Epsilon final : {:.3}", epsilon); - - // Statistiques globales - let max_reward = total_rewards - .iter() - .cloned() - .fold(f32::NEG_INFINITY, f32::max); - let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min); - println!("Récompense max : {:.3}", max_reward); - println!("Récompense min : {:.3}", min_reward); - - println!(); - println!("Entraînement terminé avec succès !"); - println!("L'environnement Burn-RL fonctionne correctement."); - - Ok(()) -} - -fn print_help() { - println!("Entraîneur DQN avec Burn-RL pour Trictrac"); - println!(); - println!("USAGE:"); - println!(" cargo run --bin=train_burn_rl [OPTIONS]"); - println!(); - println!("OPTIONS:"); - println!(" --episodes Nombre d'épisodes d'entraînement (défaut: 1000)"); - println!(" --save-every Afficher stats tous les N épisodes (défaut: 100)"); - println!(" --max-steps Nombre max de steps par épisode (défaut: 500)"); - println!(" -h, --help Afficher cette aide"); - println!(); - println!("EXEMPLES:"); - println!(" cargo run --bin=train_burn_rl"); - println!(" cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200"); - println!(" cargo run --bin=train_burn_rl -- --max-steps 1000 --episodes 500"); - println!(); - println!("NOTES:"); - println!(" - Utilise l'environnement Burn-RL avec l'espace d'actions compactes"); - println!(" - Pour l'instant, implémente seulement une politique epsilon-greedy simple"); - println!(" - L'intégration avec un vrai agent DQN peut être ajoutée plus tard"); -} diff --git a/bot/src/bin/train_dqn.rs b/bot/src/bin/train_dqn.rs index 8556e34..e0929fb 100644 --- a/bot/src/bin/train_dqn.rs +++ b/bot/src/bin/train_dqn.rs @@ -1,5 +1,5 @@ -use bot::strategy::dqn_common::{DqnConfig, TrictracAction}; -use bot::strategy::dqn_trainer::DqnTrainer; +use bot::dqn::dqn_common::{DqnConfig, TrictracAction}; +use bot::dqn::simple::dqn_trainer::DqnTrainer; use std::env; fn main() -> Result<(), Box> { diff --git a/bot/src/bin/train_dqn_full.rs b/bot/src/bin/train_dqn_full.rs deleted file mode 100644 index 42e90ae..0000000 --- a/bot/src/bin/train_dqn_full.rs +++ /dev/null @@ -1,297 +0,0 @@ -use bot::burnrl::environment::{TrictracAction, TrictracEnvironment}; -use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience}; -use bot::strategy::dqn_common::get_valid_actions; -use burn::optim::AdamConfig; -use burn_rl::base::Environment; -use std::env; - -fn main() -> Result<(), Box> { - env_logger::init(); - - let args: Vec = env::args().collect(); - - // Paramètres par défaut - let mut episodes = 1000; - let mut model_path = "models/burn_dqn_model".to_string(); - let mut save_every = 100; - let mut max_steps_per_episode = 500; - - // Parser les arguments de ligne de commande - let mut i = 1; - while i < args.len() { - match args[i].as_str() { - "--episodes" => { - if i + 1 < args.len() { - episodes = args[i + 1].parse().unwrap_or(1000); - i += 2; - } else { - eprintln!("Erreur : --episodes nécessite une valeur"); - std::process::exit(1); - } - } - "--model-path" => { - if i + 1 < args.len() { - model_path = args[i + 1].clone(); - i += 2; - } else { - eprintln!("Erreur : --model-path nécessite une valeur"); - std::process::exit(1); - } - } - "--save-every" => { - if i + 1 < args.len() { - save_every = args[i + 1].parse().unwrap_or(100); - i += 2; - } else { - eprintln!("Erreur : --save-every nécessite une valeur"); - std::process::exit(1); - } - } - "--max-steps" => { - if i + 1 < args.len() { - max_steps_per_episode = args[i + 1].parse().unwrap_or(500); - i += 2; - } else { - eprintln!("Erreur : --max-steps nécessite une valeur"); - std::process::exit(1); - } - } - "--help" | "-h" => { - print_help(); - std::process::exit(0); - } - _ => { - eprintln!("Argument inconnu : {}", args[i]); - print_help(); - std::process::exit(1); - } - } - } - - // Créer le dossier models s'il n'existe pas - std::fs::create_dir_all("models")?; - - println!("=== Entraînement DQN complet avec Burn ==="); - println!("Épisodes : {}", episodes); - println!("Modèle : {}", model_path); - println!("Sauvegarde tous les {} épisodes", save_every); - println!("Max steps par épisode : {}", max_steps_per_episode); - println!(); - - // Configuration DQN - let config = DqnConfig { - state_size: 36, - action_size: 1252, // Espace d'actions réduit via contexte - hidden_size: 256, - learning_rate: 0.001, - gamma: 0.99, - epsilon: 1.0, - epsilon_decay: 0.995, - epsilon_min: 0.01, - replay_buffer_size: 10000, - batch_size: 32, - target_update_freq: 100, - }; - - // Créer l'agent et l'environnement - let mut agent = BurnDqnAgent::new(config); - let mut optimizer = AdamConfig::new().init(); - - let mut env = TrictracEnvironment::new(true); - - // Variables pour les statistiques - let mut total_rewards = Vec::new(); - let mut episode_lengths = Vec::new(); - let mut losses = Vec::new(); - - println!("Début de l'entraînement avec agent DQN complet..."); - println!(); - - for episode in 1..=episodes { - // Reset de l'environnement - let mut snapshot = env.reset(); - let mut episode_reward = 0.0; - let mut step = 0; - let mut episode_loss = 0.0; - let mut loss_count = 0; - - loop { - step += 1; - let current_state_data = snapshot.state().data; - - // Obtenir les actions valides selon le contexte du jeu - let valid_actions = get_valid_actions(&env.game); - - if valid_actions.is_empty() { - break; - } - - // Convertir les actions Trictrac en indices pour l'agent - let valid_indices: Vec = (0..valid_actions.len()).collect(); - - // Sélectionner une action avec l'agent DQN - let action_index = agent.select_action(¤t_state_data, &valid_indices); - let action = TrictracAction { - index: action_index as u32, - }; - - // Exécuter l'action - snapshot = env.step(action); - episode_reward += *snapshot.reward(); - - // Préparer l'expérience pour l'agent - let experience = Experience { - state: current_state_data.to_vec(), - action: action_index, - reward: *snapshot.reward(), - next_state: if snapshot.done() { - None - } else { - Some(snapshot.state().data.to_vec()) - }, - done: snapshot.done(), - }; - - // Ajouter l'expérience au replay buffer - agent.add_experience(experience); - - // Entraîner l'agent - if let Some(loss) = agent.train_step(&mut optimizer) { - episode_loss += loss; - loss_count += 1; - } - - // Vérifier les conditions de fin - if snapshot.done() || step >= max_steps_per_episode { - break; - } - } - - // Calculer la loss moyenne de l'épisode - let avg_loss = if loss_count > 0 { - episode_loss / loss_count as f32 - } else { - 0.0 - }; - - // Sauvegarder les statistiques - total_rewards.push(episode_reward); - episode_lengths.push(step); - losses.push(avg_loss); - - // Affichage des statistiques - if episode % save_every == 0 { - let avg_reward = - total_rewards.iter().rev().take(save_every).sum::() / save_every as f32; - let avg_length = - episode_lengths.iter().rev().take(save_every).sum::() / save_every; - let avg_episode_loss = - losses.iter().rev().take(save_every).sum::() / save_every as f32; - - println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}", - episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size()); - - // Sauvegarder le modèle - let checkpoint_path = format!("{}_{}", model_path, episode); - if let Err(e) = agent.save_model(&checkpoint_path) { - eprintln!("Erreur lors de la sauvegarde : {}", e); - } else { - println!(" → Modèle sauvegardé : {}", checkpoint_path); - } - } else if episode % 10 == 0 { - println!( - "Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}", - episode, - episode_reward, - step, - avg_loss, - agent.get_epsilon() - ); - } - } - - // Sauvegarder le modèle final - let final_path = format!("{}_final", model_path); - agent.save_model(&final_path)?; - - // Statistiques finales - println!(); - println!("=== Résultats de l'entraînement ==="); - let final_avg_reward = total_rewards - .iter() - .rev() - .take(100.min(episodes)) - .sum::() - / 100.min(episodes) as f32; - let final_avg_length = episode_lengths - .iter() - .rev() - .take(100.min(episodes)) - .sum::() - / 100.min(episodes); - let final_avg_loss = - losses.iter().rev().take(100.min(episodes)).sum::() / 100.min(episodes) as f32; - - println!( - "Récompense moyenne (100 derniers épisodes) : {:.3}", - final_avg_reward - ); - println!( - "Longueur moyenne (100 derniers épisodes) : {}", - final_avg_length - ); - println!( - "Loss moyenne (100 derniers épisodes) : {:.6}", - final_avg_loss - ); - println!("Epsilon final : {:.3}", agent.get_epsilon()); - println!("Taille du buffer final : {}", agent.get_buffer_size()); - - // Statistiques globales - let max_reward = total_rewards - .iter() - .cloned() - .fold(f32::NEG_INFINITY, f32::max); - let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min); - println!("Récompense max : {:.3}", max_reward); - println!("Récompense min : {:.3}", min_reward); - - println!(); - println!("Entraînement terminé avec succès !"); - println!("Modèle final sauvegardé : {}", final_path); - println!(); - println!("Pour utiliser le modèle entraîné :"); - println!( - " cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy", - model_path - ); - - Ok(()) -} - -fn print_help() { - println!("Entraîneur DQN complet avec Burn pour Trictrac"); - println!(); - println!("USAGE:"); - println!(" cargo run --bin=train_dqn_full [OPTIONS]"); - println!(); - println!("OPTIONS:"); - println!(" --episodes Nombre d'épisodes d'entraînement (défaut: 1000)"); - println!(" --model-path Chemin de base pour sauvegarder les modèles (défaut: models/burn_dqn_model)"); - println!(" --save-every Sauvegarder le modèle tous les N épisodes (défaut: 100)"); - println!(" --max-steps Nombre max de steps par épisode (défaut: 500)"); - println!(" -h, --help Afficher cette aide"); - println!(); - println!("EXEMPLES:"); - println!(" cargo run --bin=train_dqn_full"); - println!(" cargo run --bin=train_dqn_full -- --episodes 2000 --save-every 200"); - println!(" cargo run --bin=train_dqn_full -- --model-path models/my_model --episodes 500"); - println!(); - println!("FONCTIONNALITÉS:"); - println!(" - Agent DQN complet avec réseau de neurones Burn"); - println!(" - Experience replay buffer avec échantillonnage aléatoire"); - println!(" - Epsilon-greedy avec décroissance automatique"); - println!(" - Target network avec mise à jour périodique"); - println!(" - Sauvegarde automatique des modèles"); - println!(" - Statistiques d'entraînement détaillées"); -} diff --git a/bot/src/bot.rs b/bot/src/bot.rs deleted file mode 100644 index e69de29..0000000 diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/dqn/burnrl/dqn_model.rs similarity index 95% rename from bot/src/burnrl/dqn_model.rs rename to bot/src/dqn/burnrl/dqn_model.rs index 5ceccaf..af0e2dd 100644 --- a/bot/src/burnrl/dqn_model.rs +++ b/bot/src/dqn/burnrl/dqn_model.rs @@ -1,15 +1,14 @@ -use crate::burnrl::utils::soft_update_linear; +use crate::dqn::burnrl::utils::soft_update_linear; use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::optim::AdamWConfig; -use burn::record::{CompactRecorder, Recorder}; use burn::tensor::activation::relu; use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::Tensor; use burn_rl::agent::DQN; use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use std::time::{Duration, SystemTime}; +use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; +use std::time::SystemTime; #[derive(Module, Debug)] pub struct Net { diff --git a/bot/src/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs similarity index 99% rename from bot/src/burnrl/environment.rs rename to bot/src/dqn/burnrl/environment.rs index 86ca586..40bcc29 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -1,4 +1,4 @@ -use crate::strategy::dqn_common; +use crate::dqn::dqn_common; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; diff --git a/bot/src/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs similarity index 100% rename from bot/src/burnrl/main.rs rename to bot/src/dqn/burnrl/main.rs diff --git a/bot/src/burnrl/mod.rs b/bot/src/dqn/burnrl/mod.rs similarity index 100% rename from bot/src/burnrl/mod.rs rename to bot/src/dqn/burnrl/mod.rs diff --git a/bot/src/burnrl/utils.rs b/bot/src/dqn/burnrl/utils.rs similarity index 95% rename from bot/src/burnrl/utils.rs rename to bot/src/dqn/burnrl/utils.rs index ece5761..ba04cb6 100644 --- a/bot/src/burnrl/utils.rs +++ b/bot/src/dqn/burnrl/utils.rs @@ -1,5 +1,5 @@ -use crate::burnrl::environment::{TrictracAction, TrictracEnvironment}; -use crate::strategy::dqn_common::get_valid_action_indices; +use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; +use crate::dqn::dqn_common::get_valid_action_indices; use burn::module::{Param, ParamId}; use burn::nn::Linear; use burn::tensor::backend::Backend; diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/dqn/dqn_common.rs similarity index 99% rename from bot/src/strategy/dqn_common.rs rename to bot/src/dqn/dqn_common.rs index 801e328..3ea0738 100644 --- a/bot/src/strategy/dqn_common.rs +++ b/bot/src/dqn/dqn_common.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min}; use serde::{Deserialize, Serialize}; -use store::{CheckerMove, Dice, GameEvent, PlayerId}; +use store::{CheckerMove, Dice}; /// Types d'actions possibles dans le jeu #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -259,7 +259,6 @@ impl SimpleNeuralNetwork { /// Obtient les actions valides pour l'état de jeu actuel pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - use crate::PointsRules; use store::TurnStage; let mut valid_actions = Vec::new(); diff --git a/bot/src/dqn/mod.rs b/bot/src/dqn/mod.rs new file mode 100644 index 0000000..6eafa27 --- /dev/null +++ b/bot/src/dqn/mod.rs @@ -0,0 +1,3 @@ +pub mod dqn_common; +pub mod simple; +pub mod burnrl; \ No newline at end of file diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/dqn/simple/dqn_trainer.rs similarity index 99% rename from bot/src/strategy/dqn_trainer.rs rename to bot/src/dqn/simple/dqn_trainer.rs index 8d9db57..c23b542 100644 --- a/bot/src/strategy/dqn_trainer.rs +++ b/bot/src/dqn/simple/dqn_trainer.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use std::collections::VecDeque; use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; -use super::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction}; +use crate::dqn::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction}; /// Expérience pour le buffer de replay #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/bot/src/dqn/simple/mod.rs b/bot/src/dqn/simple/mod.rs new file mode 100644 index 0000000..114bd10 --- /dev/null +++ b/bot/src/dqn/simple/mod.rs @@ -0,0 +1 @@ +pub mod dqn_trainer; diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 0dc60c0..65424fc 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,8 +1,7 @@ -pub mod burnrl; +pub mod dqn; pub mod strategy; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; -pub use strategy::burn_dqn_strategy::{create_burn_dqn_strategy, BurnDqnStrategy}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; diff --git a/bot/src/strategy/burn_dqn_agent.rs b/bot/src/strategy/burn_dqn_agent.rs deleted file mode 100644 index 3830fd1..0000000 --- a/bot/src/strategy/burn_dqn_agent.rs +++ /dev/null @@ -1,305 +0,0 @@ -use burn::{ - backend::{ndarray::NdArrayDevice, Autodiff, NdArray}, - module::Module, - nn::{loss::MseLoss, Linear, LinearConfig}, - optim::Optimizer, - record::{CompactRecorder, Recorder}, - tensor::Tensor, -}; -use serde::{Deserialize, Serialize}; -use std::collections::VecDeque; - -/// Backend utilisé pour l'entraînement (Autodiff + NdArray) -pub type MyBackend = Autodiff; -/// Backend utilisé pour l'inférence (NdArray) -pub type InferenceBackend = NdArray; -pub type MyDevice = NdArrayDevice; - -/// Réseau de neurones pour DQN -#[derive(Module, Debug)] -pub struct DqnNetwork { - fc1: Linear, - fc2: Linear, - fc3: Linear, -} - -impl DqnNetwork { - /// Crée un nouveau réseau DQN - pub fn new( - input_size: usize, - hidden_size: usize, - output_size: usize, - device: &B::Device, - ) -> Self { - let fc1 = LinearConfig::new(input_size, hidden_size).init(device); - let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device); - let fc3 = LinearConfig::new(hidden_size, output_size).init(device); - - Self { fc1, fc2, fc3 } - } - - /// Forward pass du réseau - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.fc1.forward(input); - let x = burn::tensor::activation::relu(x); - let x = self.fc2.forward(x); - let x = burn::tensor::activation::relu(x); - self.fc3.forward(x) - } -} - -/// Configuration pour l'entraînement DQN -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DqnConfig { - pub state_size: usize, - pub action_size: usize, - pub hidden_size: usize, - pub learning_rate: f64, - pub gamma: f32, - pub epsilon: f32, - pub epsilon_decay: f32, - pub epsilon_min: f32, - pub replay_buffer_size: usize, - pub batch_size: usize, - pub target_update_freq: usize, -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - state_size: 36, - action_size: 1000, - hidden_size: 256, - learning_rate: 0.001, - gamma: 0.99, - epsilon: 1.0, - epsilon_decay: 0.995, - epsilon_min: 0.01, - replay_buffer_size: 10000, - batch_size: 32, - target_update_freq: 100, - } - } -} - -/// Experience pour le replay buffer -#[derive(Debug, Clone)] -pub struct Experience { - pub state: Vec, - pub action: usize, - pub reward: f32, - pub next_state: Option>, - pub done: bool, -} - -/// Agent DQN utilisant Burn -pub struct BurnDqnAgent { - config: DqnConfig, - device: MyDevice, - q_network: DqnNetwork, - target_network: DqnNetwork, - replay_buffer: VecDeque, - epsilon: f32, - step_count: usize, -} - -impl BurnDqnAgent { - /// Crée un nouvel agent DQN - pub fn new(config: DqnConfig) -> Self { - let device = MyDevice::default(); - - let q_network = DqnNetwork::new( - config.state_size, - config.hidden_size, - config.action_size, - &device, - ); - - let target_network = DqnNetwork::new( - config.state_size, - config.hidden_size, - config.action_size, - &device, - ); - - Self { - config: config.clone(), - device, - q_network, - target_network, - replay_buffer: VecDeque::new(), - epsilon: config.epsilon, - step_count: 0, - } - } - - /// Sélectionne une action avec epsilon-greedy - pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize { - if valid_actions.is_empty() { - // Retourne une action par défaut ou une action "nulle" si aucune n'est valide - // Dans le contexte du jeu, cela ne devrait pas arriver si la logique de fin de partie est correcte - return 0; - } - - // Exploration epsilon-greedy - if rand::random::() < self.epsilon { - let random_index = rand::random::() % valid_actions.len(); - return valid_actions[random_index]; - } - - // Exploitation : choisir la meilleure action selon le Q-network - let state_tensor = Tensor::::from_floats(state, &self.device) - .reshape([1, self.config.state_size]); - let q_values = self.q_network.forward(state_tensor); - - // Convertir en vecteur pour traitement - let q_data = q_values.into_data().convert::().into_vec().unwrap(); - - // Trouver la meilleure action parmi les actions valides - let mut best_action = valid_actions[0]; - let mut best_q_value = f32::NEG_INFINITY; - - for &action in valid_actions { - if action < q_data.len() && q_data[action] > best_q_value { - best_q_value = q_data[action]; - best_action = action; - } - } - - best_action - } - - /// Ajoute une expérience au replay buffer - pub fn add_experience(&mut self, experience: Experience) { - if self.replay_buffer.len() >= self.config.replay_buffer_size { - self.replay_buffer.pop_front(); - } - self.replay_buffer.push_back(experience); - } - - /// Entraîne le réseau sur un batch d'expériences - pub fn train_step( - &mut self, - optimizer: &mut impl Optimizer, MyBackend>, - ) -> Option { - if self.replay_buffer.len() < self.config.batch_size { - return None; - } - - // Échantillonner un batch d'expériences - let batch = self.sample_batch(); - - // Préparer les tenseurs d'état - let states: Vec = batch.iter().flat_map(|exp| exp.state.clone()).collect(); - let state_tensor = Tensor::::from_floats(states.as_slice(), &self.device) - .reshape([self.config.batch_size, self.config.state_size]); - - // Calculer les Q-values actuelles - let current_q_values = self.q_network.forward(state_tensor); - - // Pour l'instant, version simplifiée sans calcul de target - let target_q_values = current_q_values.clone(); - - // Calculer la loss MSE - let loss = MseLoss::new().forward( - current_q_values, - target_q_values, - burn::nn::loss::Reduction::Mean, - ); - - // Backpropagation (version simplifiée) - let grads = loss.backward(); - // Gradients linked to each parameter of the model. - let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network); - self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads); - - // Mise à jour du réseau cible - self.step_count += 1; - if self.step_count % self.config.target_update_freq == 0 { - self.update_target_network(); - } - - // Décroissance d'epsilon - if self.epsilon > self.config.epsilon_min { - self.epsilon *= self.config.epsilon_decay; - } - - Some(loss.into_scalar()) - } - - /// Échantillonne un batch d'expériences du replay buffer - fn sample_batch(&self) -> Vec { - let mut batch = Vec::new(); - let buffer_size = self.replay_buffer.len(); - - for _ in 0..self.config.batch_size.min(buffer_size) { - let index = rand::random::() % buffer_size; - if let Some(exp) = self.replay_buffer.get(index) { - batch.push(exp.clone()); - } - } - - batch - } - - /// Met à jour le réseau cible avec les poids du réseau principal - fn update_target_network(&mut self) { - // Copie simple des poids - self.target_network = self.q_network.clone(); - } - - /// Sauvegarde le modèle - pub fn save_model(&self, path: &str) -> Result<(), Box> { - // Sauvegarder la configuration - let config_path = format!("{}_config.json", path); - let config_json = serde_json::to_string_pretty(&self.config)?; - std::fs::write(config_path, config_json)?; - - // Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend) - let inference_network = self.q_network.clone().into_record(); - let recorder = CompactRecorder::new(); - - let model_path = format!("{}_model.burn", path); - recorder.record(inference_network, model_path.into())?; - - println!("Modèle sauvegardé : {}", path); - Ok(()) - } - - /// Charge un modèle pour l'inférence - pub fn load_model_for_inference( - path: &str, - ) -> Result<(DqnNetwork, DqnConfig), Box> { - // Charger la configuration - let config_path = format!("{}_config.json", path); - let config_json = std::fs::read_to_string(config_path)?; - let config: DqnConfig = serde_json::from_str(&config_json)?; - - // Créer le réseau pour l'inférence - let device = NdArrayDevice::default(); - let network = DqnNetwork::::new( - config.state_size, - config.hidden_size, - config.action_size, - &device, - ); - - // Charger les poids - let model_path = format!("{}_model.burn", path); - let recorder = CompactRecorder::new(); - let record = recorder.load(model_path.into(), &device)?; - let network = network.load_record(record); - - Ok((network, config)) - } - - /// Retourne l'epsilon actuel - pub fn get_epsilon(&self) -> f32 { - self.epsilon - } - - /// Retourne la taille du replay buffer - pub fn get_buffer_size(&self) -> usize { - self.replay_buffer.len() - } -} diff --git a/bot/src/strategy/burn_dqn_strategy.rs b/bot/src/strategy/burn_dqn_strategy.rs deleted file mode 100644 index f111def..0000000 --- a/bot/src/strategy/burn_dqn_strategy.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; -use super::burn_dqn_agent::{DqnNetwork, DqnConfig, InferenceBackend}; -use super::dqn_common::get_valid_actions; -use burn::{backend::ndarray::NdArrayDevice, tensor::Tensor}; -use std::path::Path; - -/// Stratégie utilisant un modèle DQN Burn entraîné -#[derive(Debug)] -pub struct BurnDqnStrategy { - pub game: GameState, - pub player_id: PlayerId, - pub color: Color, - network: Option>, - config: Option, - device: NdArrayDevice, -} - -impl Default for BurnDqnStrategy { - fn default() -> Self { - Self { - game: GameState::default(), - player_id: 0, - color: Color::White, - network: None, - config: None, - device: NdArrayDevice::default(), - } - } -} - -impl BurnDqnStrategy { - /// Crée une nouvelle stratégie avec un modèle chargé - pub fn new(model_path: &str) -> Result> { - let mut strategy = Self::default(); - strategy.load_model(model_path)?; - Ok(strategy) - } - - /// Charge un modèle DQN depuis un fichier - pub fn load_model(&mut self, model_path: &str) -> Result<(), Box> { - if !Path::new(&format!("{}_config.json", model_path)).exists() { - return Err(format!("Modèle non trouvé : {}", model_path).into()); - } - - let (network, config) = super::burn_dqn_agent::BurnDqnAgent::load_model_for_inference(model_path)?; - - self.network = Some(network); - self.config = Some(config); - - println!("Modèle DQN Burn chargé depuis : {}", model_path); - Ok(()) - } - - /// Sélectionne la meilleure action selon le modèle DQN - fn select_best_action(&self, valid_actions: &[super::dqn_common::TrictracAction]) -> Option { - if valid_actions.is_empty() { - return None; - } - - // Si pas de réseau chargé, utiliser la première action valide - let Some(network) = &self.network else { - return Some(valid_actions[0].clone()); - }; - - // Convertir l'état du jeu en tensor - let state_vec = self.game.to_vec_float(); - let state_tensor = Tensor::::from_floats(state_vec.as_slice(), &self.device).reshape([1, self.config.as_ref().unwrap().state_size]); - - // Faire une prédiction - let q_values = network.forward(state_tensor); - let q_data = q_values.into_data().convert::().into_vec().unwrap(); - - // Trouver la meilleure action parmi les actions valides - let mut best_action = &valid_actions[0]; - let mut best_q_value = f32::NEG_INFINITY; - - for (i, action) in valid_actions.iter().enumerate() { - if i < q_data.len() && q_data[i] > best_q_value { - best_q_value = q_data[i]; - best_action = action; - } - } - - Some(best_action.clone()) - } - - /// Convertit une TrictracAction en CheckerMove pour les mouvements - fn trictrac_action_to_moves(&self, action: &super::dqn_common::TrictracAction) -> Option<(CheckerMove, CheckerMove)> { - match action { - super::dqn_common::TrictracAction::Move { dice_order, from1, from2 } => { - let dice = self.game.dice; - let (die1, die2) = if *dice_order { - (dice.values.0, dice.values.1) - } else { - (dice.values.1, dice.values.0) - }; - - // Calculer les destinations selon la couleur - let to1 = if self.color == Color::White { - from1 + die1 as usize - } else { - from1.saturating_sub(die1 as usize) - }; - let to2 = if self.color == Color::White { - from2 + die2 as usize - } else { - from2.saturating_sub(die2 as usize) - }; - - // Créer les mouvements - let move1 = CheckerMove::new(*from1, to1).ok()?; - let move2 = CheckerMove::new(*from2, to2).ok()?; - - Some((move1, move2)) - } - _ => None, - } - } -} - -impl BotStrategy for BurnDqnStrategy { - fn get_game(&self) -> &GameState { - &self.game - } - - fn get_mut_game(&mut self) -> &mut GameState { - &mut self.game - } - - fn calculate_points(&self) -> u8 { - // Utiliser le modèle DQN pour décider des points à marquer - // let valid_actions = get_valid_actions(&self.game); - - // Chercher une action Mark dans les actions valides - // for action in &valid_actions { - // if let super::dqn_common::TrictracAction::Mark { points } = action { - // return *points; - // } - // } - - // Par défaut, marquer 0 points - 0 - } - - fn calculate_adv_points(&self) -> u8 { - // Même logique que calculate_points pour les points d'avance - self.calculate_points() - } - - fn choose_move(&self) -> (CheckerMove, CheckerMove) { - let valid_actions = get_valid_actions(&self.game); - - if let Some(best_action) = self.select_best_action(&valid_actions) { - if let Some((move1, move2)) = self.trictrac_action_to_moves(&best_action) { - return (move1, move2); - } - } - - // Fallback: utiliser la stratégie par défaut - let default_strategy = super::default::DefaultStrategy::default(); - default_strategy.choose_move() - } - - fn choose_go(&self) -> bool { - let valid_actions = get_valid_actions(&self.game); - - if let Some(best_action) = self.select_best_action(&valid_actions) { - match best_action { - super::dqn_common::TrictracAction::Go => return true, - super::dqn_common::TrictracAction::Move { .. } => return false, - _ => {} - } - } - - // Par défaut, toujours choisir de continuer - true - } - - fn set_player_id(&mut self, player_id: PlayerId) { - self.player_id = player_id; - } - - fn set_color(&mut self, color: Color) { - self.color = color; - } -} - -/// Factory function pour créer une stratégie DQN Burn depuis un chemin de modèle -pub fn create_burn_dqn_strategy(model_path: &str) -> Result, Box> { - let strategy = BurnDqnStrategy::new(model_path)?; - Ok(Box::new(strategy)) -} \ No newline at end of file diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs index 81aa5f1..e01f406 100644 --- a/bot/src/strategy/default.rs +++ b/bot/src/strategy/default.rs @@ -1,4 +1,4 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use store::MoveRules; #[derive(Debug)] diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs index 779ce3d..af08341 100644 --- a/bot/src/strategy/dqn.rs +++ b/bot/src/strategy/dqn.rs @@ -1,8 +1,8 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules}; +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; use std::path::Path; use store::MoveRules; -use super::dqn_common::{ +use crate::dqn::dqn_common::{ get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction, }; diff --git a/bot/src/strategy.rs b/bot/src/strategy/mod.rs similarity index 51% rename from bot/src/strategy.rs rename to bot/src/strategy/mod.rs index a0ffc7a..3812188 100644 --- a/bot/src/strategy.rs +++ b/bot/src/strategy/mod.rs @@ -1,9 +1,5 @@ -pub mod burn_dqn_agent; -pub mod burn_dqn_strategy; pub mod client; pub mod default; pub mod dqn; -pub mod dqn_common; -pub mod dqn_trainer; pub mod erroneous_moves; pub mod stable_baselines3; diff --git a/justfile b/justfile index 465271e..e7d7222 100644 --- a/justfile +++ b/justfile @@ -9,7 +9,10 @@ shell: runcli: RUST_LOG=info cargo run --bin=client_cli runclibots: - RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,ai + RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy +match: + cargo build --release --bin=client_cli + LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn profile: echo '1' | sudo tee /proc/sys/kernel/perf_event_paranoid cargo build --profile profiling @@ -29,4 +32,4 @@ debugtrainbot: profiletrainbot: echo '1' | sudo tee /proc/sys/kernel/perf_event_paranoid cargo build --profile profiling --bin=train_dqn_burn - LD_LIBRARY_PATH=./target/debug samply record ./target/profiling/train_dqn_burn + LD_LIBRARY_PATH=./target/profiling samply record ./target/profiling/train_dqn_burn