trictrac/bot/src/bin/train_dqn_full.rs

298 lines
10 KiB
Rust

use bot::burnrl::environment::{TrictracAction, TrictracEnvironment};
use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
use bot::strategy::dqn_common::get_valid_actions;
use burn::optim::AdamConfig;
use burn_rl::base::Environment;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
let args: Vec<String> = env::args().collect();
// Paramètres par défaut
let mut episodes = 1000;
let mut model_path = "models/burn_dqn_model".to_string();
let mut save_every = 100;
let mut max_steps_per_episode = 500;
// Parser les arguments de ligne de commande
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"--episodes" => {
if i + 1 < args.len() {
episodes = args[i + 1].parse().unwrap_or(1000);
i += 2;
} else {
eprintln!("Erreur : --episodes nécessite une valeur");
std::process::exit(1);
}
}
"--model-path" => {
if i + 1 < args.len() {
model_path = args[i + 1].clone();
i += 2;
} else {
eprintln!("Erreur : --model-path nécessite une valeur");
std::process::exit(1);
}
}
"--save-every" => {
if i + 1 < args.len() {
save_every = args[i + 1].parse().unwrap_or(100);
i += 2;
} else {
eprintln!("Erreur : --save-every nécessite une valeur");
std::process::exit(1);
}
}
"--max-steps" => {
if i + 1 < args.len() {
max_steps_per_episode = args[i + 1].parse().unwrap_or(500);
i += 2;
} else {
eprintln!("Erreur : --max-steps nécessite une valeur");
std::process::exit(1);
}
}
"--help" | "-h" => {
print_help();
std::process::exit(0);
}
_ => {
eprintln!("Argument inconnu : {}", args[i]);
print_help();
std::process::exit(1);
}
}
}
// Créer le dossier models s'il n'existe pas
std::fs::create_dir_all("models")?;
println!("=== Entraînement DQN complet avec Burn ===");
println!("Épisodes : {}", episodes);
println!("Modèle : {}", model_path);
println!("Sauvegarde tous les {} épisodes", save_every);
println!("Max steps par épisode : {}", max_steps_per_episode);
println!();
// Configuration DQN
let config = DqnConfig {
state_size: 36,
action_size: 1252, // Espace d'actions réduit via contexte
hidden_size: 256,
learning_rate: 0.001,
gamma: 0.99,
epsilon: 1.0,
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 10000,
batch_size: 32,
target_update_freq: 100,
};
// Créer l'agent et l'environnement
let mut agent = BurnDqnAgent::new(config);
let mut optimizer = AdamConfig::new().init();
let mut env = TrictracEnvironment::new(true);
// Variables pour les statistiques
let mut total_rewards = Vec::new();
let mut episode_lengths = Vec::new();
let mut losses = Vec::new();
println!("Début de l'entraînement avec agent DQN complet...");
println!();
for episode in 1..=episodes {
// Reset de l'environnement
let mut snapshot = env.reset();
let mut episode_reward = 0.0;
let mut step = 0;
let mut episode_loss = 0.0;
let mut loss_count = 0;
loop {
step += 1;
let current_state_data = snapshot.state().data;
// Obtenir les actions valides selon le contexte du jeu
let valid_actions = get_valid_actions(&env.game);
if valid_actions.is_empty() {
break;
}
// Convertir les actions Trictrac en indices pour l'agent
let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();
// Sélectionner une action avec l'agent DQN
let action_index = agent.select_action(&current_state_data, &valid_indices);
let action = TrictracAction {
index: action_index as u32,
};
// Exécuter l'action
snapshot = env.step(action);
episode_reward += *snapshot.reward();
// Préparer l'expérience pour l'agent
let experience = Experience {
state: current_state_data.to_vec(),
action: action_index,
reward: *snapshot.reward(),
next_state: if snapshot.done() {
None
} else {
Some(snapshot.state().data.to_vec())
},
done: snapshot.done(),
};
// Ajouter l'expérience au replay buffer
agent.add_experience(experience);
// Entraîner l'agent
if let Some(loss) = agent.train_step(&mut optimizer) {
episode_loss += loss;
loss_count += 1;
}
// Vérifier les conditions de fin
if snapshot.done() || step >= max_steps_per_episode {
break;
}
}
// Calculer la loss moyenne de l'épisode
let avg_loss = if loss_count > 0 {
episode_loss / loss_count as f32
} else {
0.0
};
// Sauvegarder les statistiques
total_rewards.push(episode_reward);
episode_lengths.push(step);
losses.push(avg_loss);
// Affichage des statistiques
if episode % save_every == 0 {
let avg_reward =
total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
let avg_length =
episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
let avg_episode_loss =
losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());
// Sauvegarder le modèle
let checkpoint_path = format!("{}_{}", model_path, episode);
if let Err(e) = agent.save_model(&checkpoint_path) {
eprintln!("Erreur lors de la sauvegarde : {}", e);
} else {
println!(" → Modèle sauvegardé : {}", checkpoint_path);
}
} else if episode % 10 == 0 {
println!(
"Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
episode,
episode_reward,
step,
avg_loss,
agent.get_epsilon()
);
}
}
// Sauvegarder le modèle final
let final_path = format!("{}_final", model_path);
agent.save_model(&final_path)?;
// Statistiques finales
println!();
println!("=== Résultats de l'entraînement ===");
let final_avg_reward = total_rewards
.iter()
.rev()
.take(100.min(episodes))
.sum::<f32>()
/ 100.min(episodes) as f32;
let final_avg_length = episode_lengths
.iter()
.rev()
.take(100.min(episodes))
.sum::<usize>()
/ 100.min(episodes);
let final_avg_loss =
losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
println!(
"Récompense moyenne (100 derniers épisodes) : {:.3}",
final_avg_reward
);
println!(
"Longueur moyenne (100 derniers épisodes) : {}",
final_avg_length
);
println!(
"Loss moyenne (100 derniers épisodes) : {:.6}",
final_avg_loss
);
println!("Epsilon final : {:.3}", agent.get_epsilon());
println!("Taille du buffer final : {}", agent.get_buffer_size());
// Statistiques globales
let max_reward = total_rewards
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
println!("Récompense max : {:.3}", max_reward);
println!("Récompense min : {:.3}", min_reward);
println!();
println!("Entraînement terminé avec succès !");
println!("Modèle final sauvegardé : {}", final_path);
println!();
println!("Pour utiliser le modèle entraîné :");
println!(
" cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy",
model_path
);
Ok(())
}
fn print_help() {
println!("Entraîneur DQN complet avec Burn pour Trictrac");
println!();
println!("USAGE:");
println!(" cargo run --bin=train_dqn_full [OPTIONS]");
println!();
println!("OPTIONS:");
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/burn_dqn_model)");
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
println!(" --max-steps <NUM> Nombre max de steps par épisode (défaut: 500)");
println!(" -h, --help Afficher cette aide");
println!();
println!("EXEMPLES:");
println!(" cargo run --bin=train_dqn_full");
println!(" cargo run --bin=train_dqn_full -- --episodes 2000 --save-every 200");
println!(" cargo run --bin=train_dqn_full -- --model-path models/my_model --episodes 500");
println!();
println!("FONCTIONNALITÉS:");
println!(" - Agent DQN complet avec réseau de neurones Burn");
println!(" - Experience replay buffer avec échantillonnage aléatoire");
println!(" - Epsilon-greedy avec décroissance automatique");
println!(" - Target network avec mise à jour périodique");
println!(" - Sauvegarde automatique des modèles");
println!(" - Statistiques d'entraînement détaillées");
}