burn dqn trainer
This commit is contained in:
parent
cf1175e497
commit
a06b47628e
|
|
@ -9,6 +9,10 @@ edition = "2021"
|
||||||
name = "train_dqn"
|
name = "train_dqn"
|
||||||
path = "src/bin/train_dqn.rs"
|
path = "src/bin/train_dqn.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "train_burn_rl"
|
||||||
|
path = "src/bin/train_burn_rl.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
pretty_assertions = "1.4.0"
|
pretty_assertions = "1.4.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
|
|
||||||
227
bot/src/bin/train_burn_rl.rs
Normal file
227
bot/src/bin/train_burn_rl.rs
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
|
||||||
|
use bot::strategy::dqn_common::get_valid_actions;
|
||||||
|
use burn_rl::base::Environment;
|
||||||
|
use rand::Rng;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
env_logger::init();
|
||||||
|
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
|
||||||
|
// Paramètres par défaut
|
||||||
|
let mut episodes = 1000;
|
||||||
|
let mut save_every = 100;
|
||||||
|
let mut max_steps_per_episode = 500;
|
||||||
|
|
||||||
|
// Parser les arguments de ligne de commande
|
||||||
|
let mut i = 1;
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--episodes" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
episodes = args[i + 1].parse().unwrap_or(1000);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --episodes nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--save-every" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
save_every = args[i + 1].parse().unwrap_or(100);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --save-every nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--max-steps" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
max_steps_per_episode = args[i + 1].parse().unwrap_or(500);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --max-steps nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--help" | "-h" => {
|
||||||
|
print_help();
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
eprintln!("Argument inconnu : {}", args[i]);
|
||||||
|
print_help();
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("=== Entraînement DQN avec Burn-RL ===");
|
||||||
|
println!("Épisodes : {}", episodes);
|
||||||
|
println!("Sauvegarde tous les {} épisodes", save_every);
|
||||||
|
println!("Max steps par épisode : {}", max_steps_per_episode);
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Créer l'environnement
|
||||||
|
let mut env = TrictracEnvironment::new(true);
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
|
||||||
|
// Variables pour les statistiques
|
||||||
|
let mut total_rewards = Vec::new();
|
||||||
|
let mut episode_lengths = Vec::new();
|
||||||
|
let mut epsilon = 1.0; // Exploration rate
|
||||||
|
let epsilon_decay = 0.995;
|
||||||
|
let epsilon_min = 0.01;
|
||||||
|
|
||||||
|
println!("Début de l'entraînement...");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
for episode in 1..=episodes {
|
||||||
|
// Reset de l'environnement
|
||||||
|
let mut snapshot = env.reset();
|
||||||
|
let mut episode_reward = 0.0;
|
||||||
|
let mut step = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
step += 1;
|
||||||
|
let current_state = snapshot.state();
|
||||||
|
|
||||||
|
// Obtenir les actions valides selon le contexte du jeu
|
||||||
|
let valid_actions = get_valid_actions(&env.game);
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
if env.visualized && episode % 50 == 0 {
|
||||||
|
println!(" Pas d'actions valides disponibles à l'étape {}", step);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sélection d'action epsilon-greedy simple
|
||||||
|
let action = if rng.gen::<f32>() < epsilon {
|
||||||
|
// Exploration : action aléatoire parmi les valides
|
||||||
|
let random_valid_index = rng.gen_range(0..valid_actions.len());
|
||||||
|
TrictracAction {
|
||||||
|
index: random_valid_index as u32,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Exploitation : action simple (première action valide pour l'instant)
|
||||||
|
TrictracAction { index: 0 }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Exécuter l'action
|
||||||
|
snapshot = env.step(action);
|
||||||
|
episode_reward += snapshot.reward();
|
||||||
|
|
||||||
|
if env.visualized && episode % 50 == 0 && step % 10 == 0 {
|
||||||
|
println!(
|
||||||
|
" Episode {}, Step {}, Reward: {:.3}, Action: {}",
|
||||||
|
episode,
|
||||||
|
step,
|
||||||
|
snapshot.reward(),
|
||||||
|
action.index
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vérifier les conditions de fin
|
||||||
|
if snapshot.done() || step >= max_steps_per_episode {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Décroissance epsilon
|
||||||
|
if epsilon > epsilon_min {
|
||||||
|
epsilon *= epsilon_decay;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sauvegarder les statistiques
|
||||||
|
total_rewards.push(episode_reward);
|
||||||
|
episode_lengths.push(step);
|
||||||
|
|
||||||
|
// Affichage des statistiques
|
||||||
|
if episode % save_every == 0 {
|
||||||
|
let avg_reward =
|
||||||
|
total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
|
||||||
|
let avg_length =
|
||||||
|
episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Episode {} | Avg Reward: {:.3} | Avg Length: {} | Epsilon: {:.3}",
|
||||||
|
episode, avg_reward, avg_length, epsilon
|
||||||
|
);
|
||||||
|
|
||||||
|
// Ici on pourrait sauvegarder un modèle si on en avait un
|
||||||
|
println!(" → Checkpoint atteint (pas de modèle à sauvegarder pour l'instant)");
|
||||||
|
} else if episode % 10 == 0 {
|
||||||
|
println!(
|
||||||
|
"Episode {} | Reward: {:.3} | Length: {} | Epsilon: {:.3}",
|
||||||
|
episode, episode_reward, step, epsilon
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Statistiques finales
|
||||||
|
println!();
|
||||||
|
println!("=== Résultats de l'entraînement ===");
|
||||||
|
let final_avg_reward = total_rewards
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.take(100.min(episodes))
|
||||||
|
.sum::<f32>()
|
||||||
|
/ 100.min(episodes) as f32;
|
||||||
|
let final_avg_length = episode_lengths
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.take(100.min(episodes))
|
||||||
|
.sum::<usize>()
|
||||||
|
/ 100.min(episodes);
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Récompense moyenne (100 derniers épisodes) : {:.3}",
|
||||||
|
final_avg_reward
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"Longueur moyenne (100 derniers épisodes) : {}",
|
||||||
|
final_avg_length
|
||||||
|
);
|
||||||
|
println!("Epsilon final : {:.3}", epsilon);
|
||||||
|
|
||||||
|
// Statistiques globales
|
||||||
|
let max_reward = total_rewards
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.fold(f32::NEG_INFINITY, f32::max);
|
||||||
|
let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||||
|
println!("Récompense max : {:.3}", max_reward);
|
||||||
|
println!("Récompense min : {:.3}", min_reward);
|
||||||
|
|
||||||
|
println!();
|
||||||
|
println!("Entraînement terminé avec succès !");
|
||||||
|
println!("L'environnement Burn-RL fonctionne correctement.");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_help() {
|
||||||
|
println!("Entraîneur DQN avec Burn-RL pour Trictrac");
|
||||||
|
println!();
|
||||||
|
println!("USAGE:");
|
||||||
|
println!(" cargo run --bin=train_burn_rl [OPTIONS]");
|
||||||
|
println!();
|
||||||
|
println!("OPTIONS:");
|
||||||
|
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
|
||||||
|
println!(" --save-every <NUM> Afficher stats tous les N épisodes (défaut: 100)");
|
||||||
|
println!(" --max-steps <NUM> Nombre max de steps par épisode (défaut: 500)");
|
||||||
|
println!(" -h, --help Afficher cette aide");
|
||||||
|
println!();
|
||||||
|
println!("EXEMPLES:");
|
||||||
|
println!(" cargo run --bin=train_burn_rl");
|
||||||
|
println!(" cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200");
|
||||||
|
println!(" cargo run --bin=train_burn_rl -- --max-steps 1000 --episodes 500");
|
||||||
|
println!();
|
||||||
|
println!("NOTES:");
|
||||||
|
println!(" - Utilise l'environnement Burn-RL avec l'espace d'actions compactes");
|
||||||
|
println!(" - Pour l'instant, implémente seulement une politique epsilon-greedy simple");
|
||||||
|
println!(" - L'intégration avec un vrai agent DQN peut être ajoutée plus tard");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -80,13 +80,13 @@ impl From<TrictracAction> for u32 {
|
||||||
/// Environnement Trictrac pour burn-rl
|
/// Environnement Trictrac pour burn-rl
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TrictracEnvironment {
|
pub struct TrictracEnvironment {
|
||||||
game: GameState,
|
pub game: GameState,
|
||||||
active_player_id: PlayerId,
|
active_player_id: PlayerId,
|
||||||
opponent_id: PlayerId,
|
opponent_id: PlayerId,
|
||||||
current_state: TrictracState,
|
current_state: TrictracState,
|
||||||
episode_reward: f32,
|
episode_reward: f32,
|
||||||
step_count: usize,
|
step_count: usize,
|
||||||
visualized: bool,
|
pub visualized: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Environment for TrictracEnvironment {
|
impl Environment for TrictracEnvironment {
|
||||||
|
|
@ -127,6 +127,9 @@ impl Environment for TrictracEnvironment {
|
||||||
self.game.init_player("DQN Agent");
|
self.game.init_player("DQN Agent");
|
||||||
self.game.init_player("Opponent");
|
self.game.init_player("Opponent");
|
||||||
|
|
||||||
|
// Commencer la partie
|
||||||
|
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||||
|
|
||||||
self.current_state = TrictracState::from_game_state(&self.game);
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
self.episode_reward = 0.0;
|
self.episode_reward = 0.0;
|
||||||
self.step_count = 0;
|
self.step_count = 0;
|
||||||
|
|
@ -161,8 +164,10 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Jouer l'adversaire si c'est son tour
|
// Faire jouer l'adversaire (stratégie simple)
|
||||||
reward += self.play_opponent_if_needed();
|
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||||
|
reward += self.play_opponent_if_needed();
|
||||||
|
}
|
||||||
|
|
||||||
// Vérifier si la partie est terminée
|
// Vérifier si la partie est terminée
|
||||||
let done = self.game.stage == Stage::Ended
|
let done = self.game.stage == Stage::Ended
|
||||||
|
|
@ -366,13 +371,10 @@ impl TrictracEnvironment {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::Move => {
|
TurnStage::Move => GameEvent::Move {
|
||||||
let (move1, move2) = default_strategy.choose_move();
|
player_id: self.opponent_id,
|
||||||
GameEvent::Move {
|
moves: default_strategy.choose_move(),
|
||||||
player_id: self.opponent_id,
|
},
|
||||||
moves: (move1.mirror(), move2.mirror()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
|
|
@ -382,4 +384,3 @@ impl TrictracEnvironment {
|
||||||
reward
|
reward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -200,3 +200,33 @@ L'approche utilise exactement votre suggestion : pour les mouvements, au lieu d'
|
||||||
3. Les destinations sont automatiquement calculées selon les règles
|
3. Les destinations sont automatiquement calculées selon les règles
|
||||||
|
|
||||||
Cela réduit l'espace d'actions de 99.94% tout en gardant toute la richesse du jeu !
|
Cela réduit l'espace d'actions de 99.94% tout en gardant toute la richesse du jeu !
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
● Parfait ! J'ai créé le fichier train_burn_rl.rs qui utilise votre environnement Burn-RL corrigé.
|
||||||
|
|
||||||
|
Pour lancer l'entraînement avec Burn-RL :
|
||||||
|
|
||||||
|
cargo run --bin=train_burn_rl
|
||||||
|
|
||||||
|
Ou avec des options personnalisées :
|
||||||
|
|
||||||
|
cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200 --max-steps 1000
|
||||||
|
|
||||||
|
Caractéristiques de cet entraîneur :
|
||||||
|
|
||||||
|
✅ Utilise l'environnement Burn-RL que vous avez corrigé
|
||||||
|
✅ Actions contextuelles via get_valid_actions()
|
||||||
|
✅ Politique epsilon-greedy simple pour commencer
|
||||||
|
✅ Statistiques détaillées avec moyennes mobiles
|
||||||
|
✅ Configuration flexible via arguments CLI
|
||||||
|
✅ Logging progressif pour suivre l'entraînement
|
||||||
|
|
||||||
|
Options disponibles :
|
||||||
|
|
||||||
|
- --episodes : nombre d'épisodes (défaut: 1000)
|
||||||
|
- --save-every : fréquence d'affichage des stats (défaut: 100)
|
||||||
|
- --max-steps : nombre max de steps par épisode (défaut: 500)
|
||||||
|
- --help : aide complète
|
||||||
|
|
||||||
|
Cet entraîneur sert de base pour tester l'environnement Burn-RL. Une fois que tout fonctionne bien, on pourra y intégrer un vrai agent DQN avec réseaux de neurones !
|
||||||
|
|
|
||||||
3
justfile
3
justfile
|
|
@ -19,4 +19,5 @@ pythonlib:
|
||||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||||
trainbot:
|
trainbot:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
cargo run --bin=train_dqn
|
# cargo run --bin=train_dqn
|
||||||
|
cargo run --bin=train_burn_rl
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue