2025-08-08 18:58:21 +02:00
|
|
|
use bot::dqn::dqn_common::TrictracAction;
|
|
|
|
|
use bot::dqn::simple::dqn_model::DqnConfig;
|
2025-08-01 20:45:57 +02:00
|
|
|
use bot::dqn::simple::dqn_trainer::DqnTrainer;
|
2025-05-26 20:44:35 +02:00
|
|
|
use std::env;
|
|
|
|
|
|
|
|
|
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
env_logger::init();
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
let args: Vec<String> = env::args().collect();
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
// Paramètres par défaut
|
|
|
|
|
let mut episodes = 1000;
|
|
|
|
|
let mut model_path = "models/dqn_model".to_string();
|
|
|
|
|
let mut save_every = 100;
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
// Parser les arguments de ligne de commande
|
|
|
|
|
let mut i = 1;
|
|
|
|
|
while i < args.len() {
|
|
|
|
|
match args[i].as_str() {
|
|
|
|
|
"--episodes" => {
|
|
|
|
|
if i + 1 < args.len() {
|
|
|
|
|
episodes = args[i + 1].parse().unwrap_or(1000);
|
|
|
|
|
i += 2;
|
|
|
|
|
} else {
|
|
|
|
|
eprintln!("Erreur : --episodes nécessite une valeur");
|
|
|
|
|
std::process::exit(1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
"--model-path" => {
|
|
|
|
|
if i + 1 < args.len() {
|
|
|
|
|
model_path = args[i + 1].clone();
|
|
|
|
|
i += 2;
|
|
|
|
|
} else {
|
|
|
|
|
eprintln!("Erreur : --model-path nécessite une valeur");
|
|
|
|
|
std::process::exit(1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
"--save-every" => {
|
|
|
|
|
if i + 1 < args.len() {
|
|
|
|
|
save_every = args[i + 1].parse().unwrap_or(100);
|
|
|
|
|
i += 2;
|
|
|
|
|
} else {
|
|
|
|
|
eprintln!("Erreur : --save-every nécessite une valeur");
|
|
|
|
|
std::process::exit(1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
"--help" | "-h" => {
|
|
|
|
|
print_help();
|
|
|
|
|
std::process::exit(0);
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
eprintln!("Argument inconnu : {}", args[i]);
|
|
|
|
|
print_help();
|
|
|
|
|
std::process::exit(1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
// Créer le dossier models s'il n'existe pas
|
|
|
|
|
std::fs::create_dir_all("models")?;
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
println!("Configuration d'entraînement DQN :");
|
2025-08-17 15:59:53 +02:00
|
|
|
println!(" Épisodes : {episodes}");
|
|
|
|
|
println!(" Chemin du modèle : {model_path}");
|
|
|
|
|
println!(" Sauvegarde tous les {save_every} épisodes");
|
2025-05-26 20:44:35 +02:00
|
|
|
println!();
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
// Configuration DQN
|
|
|
|
|
let config = DqnConfig {
|
2025-05-30 20:32:00 +02:00
|
|
|
state_size: 36, // state.to_vec size
|
2025-05-26 20:44:35 +02:00
|
|
|
hidden_size: 256,
|
2025-06-01 20:00:15 +02:00
|
|
|
num_actions: TrictracAction::action_space_size(),
|
2025-05-26 20:44:35 +02:00
|
|
|
learning_rate: 0.001,
|
|
|
|
|
gamma: 0.99,
|
2025-05-30 20:32:00 +02:00
|
|
|
epsilon: 0.9, // Commencer avec plus d'exploration
|
2025-05-26 20:44:35 +02:00
|
|
|
epsilon_decay: 0.995,
|
|
|
|
|
epsilon_min: 0.01,
|
|
|
|
|
replay_buffer_size: 10000,
|
|
|
|
|
batch_size: 32,
|
|
|
|
|
};
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
// Créer et lancer l'entraîneur
|
|
|
|
|
let mut trainer = DqnTrainer::new(config);
|
|
|
|
|
trainer.train(episodes, save_every, &model_path)?;
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
println!("Entraînement terminé avec succès !");
|
|
|
|
|
println!("Pour utiliser le modèle entraîné :");
|
2025-08-17 15:59:53 +02:00
|
|
|
println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");
|
2025-05-30 20:32:00 +02:00
|
|
|
|
2025-05-26 20:44:35 +02:00
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn print_help() {
|
|
|
|
|
println!("Entraîneur DQN pour Trictrac");
|
|
|
|
|
println!();
|
|
|
|
|
println!("USAGE:");
|
|
|
|
|
println!(" cargo run --bin=train_dqn [OPTIONS]");
|
|
|
|
|
println!();
|
|
|
|
|
println!("OPTIONS:");
|
|
|
|
|
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
|
|
|
|
|
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)");
|
|
|
|
|
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
|
|
|
|
|
println!(" -h, --help Afficher cette aide");
|
|
|
|
|
println!();
|
|
|
|
|
println!("EXEMPLES:");
|
|
|
|
|
println!(" cargo run --bin=train_dqn");
|
|
|
|
|
println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
|
|
|
|
|
println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
|
2025-05-30 20:32:00 +02:00
|
|
|
}
|