burn dqn trainer

This commit is contained in:
Henri Bourcereau 2025-06-22 21:25:45 +02:00
parent cf1175e497
commit a06b47628e
5 changed files with 276 additions and 13 deletions

View file

@ -9,6 +9,10 @@ edition = "2021"
name = "train_dqn"
path = "src/bin/train_dqn.rs"
[[bin]]
name = "train_burn_rl"
path = "src/bin/train_burn_rl.rs"
[dependencies]
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }

View file

@ -0,0 +1,227 @@
use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
use bot::strategy::dqn_common::get_valid_actions;
use burn_rl::base::Environment;
use rand::Rng;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
let args: Vec<String> = env::args().collect();
// Paramètres par défaut
let mut episodes = 1000;
let mut save_every = 100;
let mut max_steps_per_episode = 500;
// Parser les arguments de ligne de commande
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"--episodes" => {
if i + 1 < args.len() {
episodes = args[i + 1].parse().unwrap_or(1000);
i += 2;
} else {
eprintln!("Erreur : --episodes nécessite une valeur");
std::process::exit(1);
}
}
"--save-every" => {
if i + 1 < args.len() {
save_every = args[i + 1].parse().unwrap_or(100);
i += 2;
} else {
eprintln!("Erreur : --save-every nécessite une valeur");
std::process::exit(1);
}
}
"--max-steps" => {
if i + 1 < args.len() {
max_steps_per_episode = args[i + 1].parse().unwrap_or(500);
i += 2;
} else {
eprintln!("Erreur : --max-steps nécessite une valeur");
std::process::exit(1);
}
}
"--help" | "-h" => {
print_help();
std::process::exit(0);
}
_ => {
eprintln!("Argument inconnu : {}", args[i]);
print_help();
std::process::exit(1);
}
}
}
println!("=== Entraînement DQN avec Burn-RL ===");
println!("Épisodes : {}", episodes);
println!("Sauvegarde tous les {} épisodes", save_every);
println!("Max steps par épisode : {}", max_steps_per_episode);
println!();
// Créer l'environnement
let mut env = TrictracEnvironment::new(true);
let mut rng = rand::thread_rng();
// Variables pour les statistiques
let mut total_rewards = Vec::new();
let mut episode_lengths = Vec::new();
let mut epsilon = 1.0; // Exploration rate
let epsilon_decay = 0.995;
let epsilon_min = 0.01;
println!("Début de l'entraînement...");
println!();
for episode in 1..=episodes {
// Reset de l'environnement
let mut snapshot = env.reset();
let mut episode_reward = 0.0;
let mut step = 0;
loop {
step += 1;
let current_state = snapshot.state();
// Obtenir les actions valides selon le contexte du jeu
let valid_actions = get_valid_actions(&env.game);
if valid_actions.is_empty() {
if env.visualized && episode % 50 == 0 {
println!(" Pas d'actions valides disponibles à l'étape {}", step);
}
break;
}
// Sélection d'action epsilon-greedy simple
let action = if rng.gen::<f32>() < epsilon {
// Exploration : action aléatoire parmi les valides
let random_valid_index = rng.gen_range(0..valid_actions.len());
TrictracAction {
index: random_valid_index as u32,
}
} else {
// Exploitation : action simple (première action valide pour l'instant)
TrictracAction { index: 0 }
};
// Exécuter l'action
snapshot = env.step(action);
episode_reward += snapshot.reward();
if env.visualized && episode % 50 == 0 && step % 10 == 0 {
println!(
" Episode {}, Step {}, Reward: {:.3}, Action: {}",
episode,
step,
snapshot.reward(),
action.index
);
}
// Vérifier les conditions de fin
if snapshot.done() || step >= max_steps_per_episode {
break;
}
}
// Décroissance epsilon
if epsilon > epsilon_min {
epsilon *= epsilon_decay;
}
// Sauvegarder les statistiques
total_rewards.push(episode_reward);
episode_lengths.push(step);
// Affichage des statistiques
if episode % save_every == 0 {
let avg_reward =
total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
let avg_length =
episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
println!(
"Episode {} | Avg Reward: {:.3} | Avg Length: {} | Epsilon: {:.3}",
episode, avg_reward, avg_length, epsilon
);
// Ici on pourrait sauvegarder un modèle si on en avait un
println!(" → Checkpoint atteint (pas de modèle à sauvegarder pour l'instant)");
} else if episode % 10 == 0 {
println!(
"Episode {} | Reward: {:.3} | Length: {} | Epsilon: {:.3}",
episode, episode_reward, step, epsilon
);
}
}
// Statistiques finales
println!();
println!("=== Résultats de l'entraînement ===");
let final_avg_reward = total_rewards
.iter()
.rev()
.take(100.min(episodes))
.sum::<f32>()
/ 100.min(episodes) as f32;
let final_avg_length = episode_lengths
.iter()
.rev()
.take(100.min(episodes))
.sum::<usize>()
/ 100.min(episodes);
println!(
"Récompense moyenne (100 derniers épisodes) : {:.3}",
final_avg_reward
);
println!(
"Longueur moyenne (100 derniers épisodes) : {}",
final_avg_length
);
println!("Epsilon final : {:.3}", epsilon);
// Statistiques globales
let max_reward = total_rewards
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
println!("Récompense max : {:.3}", max_reward);
println!("Récompense min : {:.3}", min_reward);
println!();
println!("Entraînement terminé avec succès !");
println!("L'environnement Burn-RL fonctionne correctement.");
Ok(())
}
fn print_help() {
println!("Entraîneur DQN avec Burn-RL pour Trictrac");
println!();
println!("USAGE:");
println!(" cargo run --bin=train_burn_rl [OPTIONS]");
println!();
println!("OPTIONS:");
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
println!(" --save-every <NUM> Afficher stats tous les N épisodes (défaut: 100)");
println!(" --max-steps <NUM> Nombre max de steps par épisode (défaut: 500)");
println!(" -h, --help Afficher cette aide");
println!();
println!("EXEMPLES:");
println!(" cargo run --bin=train_burn_rl");
println!(" cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200");
println!(" cargo run --bin=train_burn_rl -- --max-steps 1000 --episodes 500");
println!();
println!("NOTES:");
println!(" - Utilise l'environnement Burn-RL avec l'espace d'actions compactes");
println!(" - Pour l'instant, implémente seulement une politique epsilon-greedy simple");
println!(" - L'intégration avec un vrai agent DQN peut être ajoutée plus tard");
}

View file

@ -80,13 +80,13 @@ impl From<TrictracAction> for u32 {
/// Environnement Trictrac pour burn-rl
#[derive(Debug)]
pub struct TrictracEnvironment {
game: GameState,
pub game: GameState,
active_player_id: PlayerId,
opponent_id: PlayerId,
current_state: TrictracState,
episode_reward: f32,
step_count: usize,
visualized: bool,
pub visualized: bool,
}
impl Environment for TrictracEnvironment {
@ -127,6 +127,9 @@ impl Environment for TrictracEnvironment {
self.game.init_player("DQN Agent");
self.game.init_player("Opponent");
// Commencer la partie
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.step_count = 0;
@ -161,8 +164,10 @@ impl Environment for TrictracEnvironment {
}
}
// Jouer l'adversaire si c'est son tour
// Faire jouer l'adversaire (stratégie simple)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
reward += self.play_opponent_if_needed();
}
// Vérifier si la partie est terminée
let done = self.game.stage == Stage::Ended
@ -366,13 +371,10 @@ impl TrictracEnvironment {
player_id: self.opponent_id,
}
}
TurnStage::Move => {
let (move1, move2) = default_strategy.choose_move();
GameEvent::Move {
TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id,
moves: (move1.mirror(), move2.mirror()),
}
}
moves: default_strategy.choose_move(),
},
};
if self.game.validate(&event) {
@ -382,4 +384,3 @@ impl TrictracEnvironment {
reward
}
}

View file

@ -200,3 +200,33 @@ L'approche utilise exactement votre suggestion : pour les mouvements, au lieu d'
3. Les destinations sont automatiquement calculées selon les règles
Cela réduit l'espace d'actions de 99.94% tout en gardant toute la richesse du jeu !
---
● Parfait ! J'ai créé le fichier train_burn_rl.rs qui utilise votre environnement Burn-RL corrigé.
Pour lancer l'entraînement avec Burn-RL :
cargo run --bin=train_burn_rl
Ou avec des options personnalisées :
cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200 --max-steps 1000
Caractéristiques de cet entraîneur :
✅ Utilise l'environnement Burn-RL que vous avez corrigé
✅ Actions contextuelles via get_valid_actions()
✅ Politique epsilon-greedy simple pour commencer
✅ Statistiques détaillées avec moyennes mobiles
✅ Configuration flexible via arguments CLI
✅ Logging progressif pour suivre l'entraînement
Options disponibles :
- --episodes : nombre d'épisodes (défaut: 1000)
- --save-every : fréquence d'affichage des stats (défaut: 100)
- --max-steps : nombre max de steps par épisode (défaut: 500)
- --help : aide complète
Cet entraîneur sert de base pour tester l'environnement Burn-RL. Une fois que tout fonctionne bien, on pourra y intégrer un vrai agent DQN avec réseaux de neurones !

View file

@ -19,4 +19,5 @@ pythonlib:
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainbot:
#python ./store/python/trainModel.py
cargo run --bin=train_dqn
# cargo run --bin=train_dqn
cargo run --bin=train_burn_rl