trictrac/bot/src/dqn/simple/main.rs

110 lines
3.7 KiB
Rust
Raw Normal View History

2025-08-08 18:58:21 +02:00
use bot::dqn::dqn_common::TrictracAction;
use bot::dqn::simple::dqn_model::DqnConfig;
2025-08-01 20:45:57 +02:00
use bot::dqn::simple::dqn_trainer::DqnTrainer;
2025-05-26 20:44:35 +02:00
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
let args: Vec<String> = env::args().collect();
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
// Paramètres par défaut
let mut episodes = 1000;
let mut model_path = "models/dqn_model".to_string();
let mut save_every = 100;
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
// Parser les arguments de ligne de commande
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"--episodes" => {
if i + 1 < args.len() {
episodes = args[i + 1].parse().unwrap_or(1000);
i += 2;
} else {
eprintln!("Erreur : --episodes nécessite une valeur");
std::process::exit(1);
}
}
"--model-path" => {
if i + 1 < args.len() {
model_path = args[i + 1].clone();
i += 2;
} else {
eprintln!("Erreur : --model-path nécessite une valeur");
std::process::exit(1);
}
}
"--save-every" => {
if i + 1 < args.len() {
save_every = args[i + 1].parse().unwrap_or(100);
i += 2;
} else {
eprintln!("Erreur : --save-every nécessite une valeur");
std::process::exit(1);
}
}
"--help" | "-h" => {
print_help();
std::process::exit(0);
}
_ => {
eprintln!("Argument inconnu : {}", args[i]);
print_help();
std::process::exit(1);
}
}
}
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
// Créer le dossier models s'il n'existe pas
std::fs::create_dir_all("models")?;
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
println!("Configuration d'entraînement DQN :");
2025-08-17 15:59:53 +02:00
println!(" Épisodes : {episodes}");
println!(" Chemin du modèle : {model_path}");
println!(" Sauvegarde tous les {save_every} épisodes");
2025-05-26 20:44:35 +02:00
println!();
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
// Configuration DQN
let config = DqnConfig {
2025-05-30 20:32:00 +02:00
state_size: 36, // state.to_vec size
2025-05-26 20:44:35 +02:00
hidden_size: 256,
2025-06-01 20:00:15 +02:00
num_actions: TrictracAction::action_space_size(),
2025-05-26 20:44:35 +02:00
learning_rate: 0.001,
gamma: 0.99,
2025-05-30 20:32:00 +02:00
epsilon: 0.9, // Commencer avec plus d'exploration
2025-05-26 20:44:35 +02:00
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 10000,
batch_size: 32,
};
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
// Créer et lancer l'entraîneur
let mut trainer = DqnTrainer::new(config);
trainer.train(episodes, save_every, &model_path)?;
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
println!("Entraînement terminé avec succès !");
println!("Pour utiliser le modèle entraîné :");
2025-08-17 15:59:53 +02:00
println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");
2025-05-30 20:32:00 +02:00
2025-05-26 20:44:35 +02:00
Ok(())
}
fn print_help() {
println!("Entraîneur DQN pour Trictrac");
println!();
println!("USAGE:");
println!(" cargo run --bin=train_dqn [OPTIONS]");
println!();
println!("OPTIONS:");
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)");
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
println!(" -h, --help Afficher cette aide");
println!();
println!("EXEMPLES:");
println!(" cargo run --bin=train_dqn");
println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
2025-05-30 20:32:00 +02:00
}