From a06b47628e979d073f08af016c79ddbcbe865691 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Sun, 22 Jun 2025 21:25:45 +0200 Subject: [PATCH] burn dqn trainer --- bot/Cargo.toml | 4 + bot/src/bin/train_burn_rl.rs | 227 +++++++++++++++++++++++++++ bot/src/strategy/burn_environment.rs | 25 +-- doc/refs/claudeAIquestionOnlyRust.md | 30 ++++ justfile | 3 +- 5 files changed, 276 insertions(+), 13 deletions(-) create mode 100644 bot/src/bin/train_burn_rl.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 878f90f..2da1ac1 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -9,6 +9,10 @@ edition = "2021" name = "train_dqn" path = "src/bin/train_dqn.rs" +[[bin]] +name = "train_burn_rl" +path = "src/bin/train_burn_rl.rs" + [dependencies] pretty_assertions = "1.4.0" serde = { version = "1.0", features = ["derive"] } diff --git a/bot/src/bin/train_burn_rl.rs b/bot/src/bin/train_burn_rl.rs new file mode 100644 index 0000000..6962f84 --- /dev/null +++ b/bot/src/bin/train_burn_rl.rs @@ -0,0 +1,227 @@ +use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment}; +use bot::strategy::dqn_common::get_valid_actions; +use burn_rl::base::Environment; +use rand::Rng; +use std::env; + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = env::args().collect(); + + // Paramètres par défaut + let mut episodes = 1000; + let mut save_every = 100; + let mut max_steps_per_episode = 500; + + // Parser les arguments de ligne de commande + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--episodes" => { + if i + 1 < args.len() { + episodes = args[i + 1].parse().unwrap_or(1000); + i += 2; + } else { + eprintln!("Erreur : --episodes nécessite une valeur"); + std::process::exit(1); + } + } + "--save-every" => { + if i + 1 < args.len() { + save_every = args[i + 1].parse().unwrap_or(100); + i += 2; + } else { + eprintln!("Erreur : --save-every nécessite une valeur"); + std::process::exit(1); + } + } + "--max-steps" => { + if i + 1 < args.len() { + max_steps_per_episode = args[i + 1].parse().unwrap_or(500); + i += 2; + } else { + eprintln!("Erreur : --max-steps nécessite une valeur"); + std::process::exit(1); + } + } + "--help" | "-h" => { + print_help(); + std::process::exit(0); + } + _ => { + eprintln!("Argument inconnu : {}", args[i]); + print_help(); + std::process::exit(1); + } + } + } + + println!("=== Entraînement DQN avec Burn-RL ==="); + println!("Épisodes : {}", episodes); + println!("Sauvegarde tous les {} épisodes", save_every); + println!("Max steps par épisode : {}", max_steps_per_episode); + println!(); + + // Créer l'environnement + let mut env = TrictracEnvironment::new(true); + let mut rng = rand::thread_rng(); + + // Variables pour les statistiques + let mut total_rewards = Vec::new(); + let mut episode_lengths = Vec::new(); + let mut epsilon = 1.0; // Exploration rate + let epsilon_decay = 0.995; + let epsilon_min = 0.01; + + println!("Début de l'entraînement..."); + println!(); + + for episode in 1..=episodes { + // Reset de l'environnement + let mut snapshot = env.reset(); + let mut episode_reward = 0.0; + let mut step = 0; + + loop { + step += 1; + let current_state = snapshot.state(); + + // Obtenir les actions valides selon le contexte du jeu + let valid_actions = get_valid_actions(&env.game); + + if valid_actions.is_empty() { + if env.visualized && episode % 50 == 0 { + println!(" Pas d'actions valides disponibles à l'étape {}", step); + } + break; + } + + // Sélection d'action epsilon-greedy simple + let action = if rng.gen::() < epsilon { + // Exploration : action aléatoire parmi les valides + let random_valid_index = rng.gen_range(0..valid_actions.len()); + TrictracAction { + index: random_valid_index as u32, + } + } else { + // Exploitation : action simple (première action valide pour l'instant) + TrictracAction { index: 0 } + }; + + // Exécuter l'action + snapshot = env.step(action); + episode_reward += snapshot.reward(); + + if env.visualized && episode % 50 == 0 && step % 10 == 0 { + println!( + " Episode {}, Step {}, Reward: {:.3}, Action: {}", + episode, + step, + snapshot.reward(), + action.index + ); + } + + // Vérifier les conditions de fin + if snapshot.done() || step >= max_steps_per_episode { + break; + } + } + + // Décroissance epsilon + if epsilon > epsilon_min { + epsilon *= epsilon_decay; + } + + // Sauvegarder les statistiques + total_rewards.push(episode_reward); + episode_lengths.push(step); + + // Affichage des statistiques + if episode % save_every == 0 { + let avg_reward = + total_rewards.iter().rev().take(save_every).sum::() / save_every as f32; + let avg_length = + episode_lengths.iter().rev().take(save_every).sum::() / save_every; + + println!( + "Episode {} | Avg Reward: {:.3} | Avg Length: {} | Epsilon: {:.3}", + episode, avg_reward, avg_length, epsilon + ); + + // Ici on pourrait sauvegarder un modèle si on en avait un + println!(" → Checkpoint atteint (pas de modèle à sauvegarder pour l'instant)"); + } else if episode % 10 == 0 { + println!( + "Episode {} | Reward: {:.3} | Length: {} | Epsilon: {:.3}", + episode, episode_reward, step, epsilon + ); + } + } + + // Statistiques finales + println!(); + println!("=== Résultats de l'entraînement ==="); + let final_avg_reward = total_rewards + .iter() + .rev() + .take(100.min(episodes)) + .sum::() + / 100.min(episodes) as f32; + let final_avg_length = episode_lengths + .iter() + .rev() + .take(100.min(episodes)) + .sum::() + / 100.min(episodes); + + println!( + "Récompense moyenne (100 derniers épisodes) : {:.3}", + final_avg_reward + ); + println!( + "Longueur moyenne (100 derniers épisodes) : {}", + final_avg_length + ); + println!("Epsilon final : {:.3}", epsilon); + + // Statistiques globales + let max_reward = total_rewards + .iter() + .cloned() + .fold(f32::NEG_INFINITY, f32::max); + let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min); + println!("Récompense max : {:.3}", max_reward); + println!("Récompense min : {:.3}", min_reward); + + println!(); + println!("Entraînement terminé avec succès !"); + println!("L'environnement Burn-RL fonctionne correctement."); + + Ok(()) +} + +fn print_help() { + println!("Entraîneur DQN avec Burn-RL pour Trictrac"); + println!(); + println!("USAGE:"); + println!(" cargo run --bin=train_burn_rl [OPTIONS]"); + println!(); + println!("OPTIONS:"); + println!(" --episodes Nombre d'épisodes d'entraînement (défaut: 1000)"); + println!(" --save-every Afficher stats tous les N épisodes (défaut: 100)"); + println!(" --max-steps Nombre max de steps par épisode (défaut: 500)"); + println!(" -h, --help Afficher cette aide"); + println!(); + println!("EXEMPLES:"); + println!(" cargo run --bin=train_burn_rl"); + println!(" cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200"); + println!(" cargo run --bin=train_burn_rl -- --max-steps 1000 --episodes 500"); + println!(); + println!("NOTES:"); + println!(" - Utilise l'environnement Burn-RL avec l'espace d'actions compactes"); + println!(" - Pour l'instant, implémente seulement une politique epsilon-greedy simple"); + println!(" - L'intégration avec un vrai agent DQN peut être ajoutée plus tard"); +} + diff --git a/bot/src/strategy/burn_environment.rs b/bot/src/strategy/burn_environment.rs index a9f58ba..df44398 100644 --- a/bot/src/strategy/burn_environment.rs +++ b/bot/src/strategy/burn_environment.rs @@ -80,13 +80,13 @@ impl From for u32 { /// Environnement Trictrac pour burn-rl #[derive(Debug)] pub struct TrictracEnvironment { - game: GameState, + pub game: GameState, active_player_id: PlayerId, opponent_id: PlayerId, current_state: TrictracState, episode_reward: f32, step_count: usize, - visualized: bool, + pub visualized: bool, } impl Environment for TrictracEnvironment { @@ -127,6 +127,9 @@ impl Environment for TrictracEnvironment { self.game.init_player("DQN Agent"); self.game.init_player("Opponent"); + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + self.current_state = TrictracState::from_game_state(&self.game); self.episode_reward = 0.0; self.step_count = 0; @@ -161,8 +164,10 @@ impl Environment for TrictracEnvironment { } } - // Jouer l'adversaire si c'est son tour - reward += self.play_opponent_if_needed(); + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + reward += self.play_opponent_if_needed(); + } // Vérifier si la partie est terminée let done = self.game.stage == Stage::Ended @@ -366,13 +371,10 @@ impl TrictracEnvironment { player_id: self.opponent_id, } } - TurnStage::Move => { - let (move1, move2) = default_strategy.choose_move(); - GameEvent::Move { - player_id: self.opponent_id, - moves: (move1.mirror(), move2.mirror()), - } - } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: default_strategy.choose_move(), + }, }; if self.game.validate(&event) { @@ -382,4 +384,3 @@ impl TrictracEnvironment { reward } } - diff --git a/doc/refs/claudeAIquestionOnlyRust.md b/doc/refs/claudeAIquestionOnlyRust.md index 6739dac..9d4c108 100644 --- a/doc/refs/claudeAIquestionOnlyRust.md +++ b/doc/refs/claudeAIquestionOnlyRust.md @@ -200,3 +200,33 @@ L'approche utilise exactement votre suggestion : pour les mouvements, au lieu d' 3. Les destinations sont automatiquement calculées selon les règles Cela réduit l'espace d'actions de 99.94% tout en gardant toute la richesse du jeu ! + +--- + +● Parfait ! J'ai créé le fichier train_burn_rl.rs qui utilise votre environnement Burn-RL corrigé. + +Pour lancer l'entraînement avec Burn-RL : + +cargo run --bin=train_burn_rl + +Ou avec des options personnalisées : + +cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200 --max-steps 1000 + +Caractéristiques de cet entraîneur : + +✅ Utilise l'environnement Burn-RL que vous avez corrigé +✅ Actions contextuelles via get_valid_actions() +✅ Politique epsilon-greedy simple pour commencer +✅ Statistiques détaillées avec moyennes mobiles +✅ Configuration flexible via arguments CLI +✅ Logging progressif pour suivre l'entraînement + +Options disponibles : + +- --episodes : nombre d'épisodes (défaut: 1000) +- --save-every : fréquence d'affichage des stats (défaut: 100) +- --max-steps : nombre max de steps par épisode (défaut: 500) +- --help : aide complète + +Cet entraîneur sert de base pour tester l'environnement Burn-RL. Une fois que tout fonctionne bien, on pourra y intégrer un vrai agent DQN avec réseaux de neurones ! diff --git a/justfile b/justfile index 4d75790..b4e2c4b 100644 --- a/justfile +++ b/justfile @@ -19,4 +19,5 @@ pythonlib: pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl trainbot: #python ./store/python/trainModel.py - cargo run --bin=train_dqn + # cargo run --bin=train_dqn + cargo run --bin=train_burn_rl