This commit is contained in:
Henri Bourcereau 2025-08-01 20:45:57 +02:00
parent ad58c0ec60
commit 2e0a874879
21 changed files with 23 additions and 1051 deletions

View file

@ -13,14 +13,6 @@ path = "src/burnrl/main.rs"
name = "train_dqn"
path = "src/bin/train_dqn.rs"
# [[bin]]
# name = "train_burn_rl"
# path = "src/bin/train_burn_rl.rs"
[[bin]]
name = "train_dqn_full"
path = "src/bin/train_dqn_full.rs"
[dependencies]
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }

View file

@ -1,226 +0,0 @@
use bot::burnrl::environment::{TrictracAction, TrictracEnvironment};
use bot::strategy::dqn_common::get_valid_actions;
use burn_rl::base::Environment;
use rand::Rng;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
let args: Vec<String> = env::args().collect();
// Paramètres par défaut
let mut episodes = 1000;
let mut save_every = 100;
let mut max_steps_per_episode = 500;
// Parser les arguments de ligne de commande
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"--episodes" => {
if i + 1 < args.len() {
episodes = args[i + 1].parse().unwrap_or(1000);
i += 2;
} else {
eprintln!("Erreur : --episodes nécessite une valeur");
std::process::exit(1);
}
}
"--save-every" => {
if i + 1 < args.len() {
save_every = args[i + 1].parse().unwrap_or(100);
i += 2;
} else {
eprintln!("Erreur : --save-every nécessite une valeur");
std::process::exit(1);
}
}
"--max-steps" => {
if i + 1 < args.len() {
max_steps_per_episode = args[i + 1].parse().unwrap_or(500);
i += 2;
} else {
eprintln!("Erreur : --max-steps nécessite une valeur");
std::process::exit(1);
}
}
"--help" | "-h" => {
print_help();
std::process::exit(0);
}
_ => {
eprintln!("Argument inconnu : {}", args[i]);
print_help();
std::process::exit(1);
}
}
}
println!("=== Entraînement DQN avec Burn-RL ===");
println!("Épisodes : {}", episodes);
println!("Sauvegarde tous les {} épisodes", save_every);
println!("Max steps par épisode : {}", max_steps_per_episode);
println!();
// Créer l'environnement
let mut env = TrictracEnvironment::new(true);
let mut rng = rand::thread_rng();
// Variables pour les statistiques
let mut total_rewards = Vec::new();
let mut episode_lengths = Vec::new();
let mut epsilon = 1.0; // Exploration rate
let epsilon_decay = 0.995;
let epsilon_min = 0.01;
println!("Début de l'entraînement...");
println!();
for episode in 1..=episodes {
// Reset de l'environnement
let mut snapshot = env.reset();
let mut episode_reward = 0.0;
let mut step = 0;
loop {
step += 1;
let current_state = snapshot.state();
// Obtenir les actions valides selon le contexte du jeu
let valid_actions = get_valid_actions(&env.game);
if valid_actions.is_empty() {
if env.visualized && episode % 50 == 0 {
println!(" Pas d'actions valides disponibles à l'étape {}", step);
}
break;
}
// Sélection d'action epsilon-greedy simple
let action = if rng.gen::<f32>() < epsilon {
// Exploration : action aléatoire parmi les valides
let random_valid_index = rng.gen_range(0..valid_actions.len());
TrictracAction {
index: random_valid_index as u32,
}
} else {
// Exploitation : action simple (première action valide pour l'instant)
TrictracAction { index: 0 }
};
// Exécuter l'action
snapshot = env.step(action);
episode_reward += snapshot.reward();
if env.visualized && episode % 50 == 0 && step % 10 == 0 {
println!(
" Episode {}, Step {}, Reward: {:.3}, Action: {}",
episode,
step,
snapshot.reward(),
action.index
);
}
// Vérifier les conditions de fin
if snapshot.done() || step >= max_steps_per_episode {
break;
}
}
// Décroissance epsilon
if epsilon > epsilon_min {
epsilon *= epsilon_decay;
}
// Sauvegarder les statistiques
total_rewards.push(episode_reward);
episode_lengths.push(step);
// Affichage des statistiques
if episode % save_every == 0 {
let avg_reward =
total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
let avg_length =
episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
println!(
"Episode {} | Avg Reward: {:.3} | Avg Length: {} | Epsilon: {:.3}",
episode, avg_reward, avg_length, epsilon
);
// Ici on pourrait sauvegarder un modèle si on en avait un
println!(" → Checkpoint atteint (pas de modèle à sauvegarder pour l'instant)");
} else if episode % 10 == 0 {
println!(
"Episode {} | Reward: {:.3} | Length: {} | Epsilon: {:.3}",
episode, episode_reward, step, epsilon
);
}
}
// Statistiques finales
println!();
println!("=== Résultats de l'entraînement ===");
let final_avg_reward = total_rewards
.iter()
.rev()
.take(100.min(episodes))
.sum::<f32>()
/ 100.min(episodes) as f32;
let final_avg_length = episode_lengths
.iter()
.rev()
.take(100.min(episodes))
.sum::<usize>()
/ 100.min(episodes);
println!(
"Récompense moyenne (100 derniers épisodes) : {:.3}",
final_avg_reward
);
println!(
"Longueur moyenne (100 derniers épisodes) : {}",
final_avg_length
);
println!("Epsilon final : {:.3}", epsilon);
// Statistiques globales
let max_reward = total_rewards
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
println!("Récompense max : {:.3}", max_reward);
println!("Récompense min : {:.3}", min_reward);
println!();
println!("Entraînement terminé avec succès !");
println!("L'environnement Burn-RL fonctionne correctement.");
Ok(())
}
fn print_help() {
println!("Entraîneur DQN avec Burn-RL pour Trictrac");
println!();
println!("USAGE:");
println!(" cargo run --bin=train_burn_rl [OPTIONS]");
println!();
println!("OPTIONS:");
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
println!(" --save-every <NUM> Afficher stats tous les N épisodes (défaut: 100)");
println!(" --max-steps <NUM> Nombre max de steps par épisode (défaut: 500)");
println!(" -h, --help Afficher cette aide");
println!();
println!("EXEMPLES:");
println!(" cargo run --bin=train_burn_rl");
println!(" cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200");
println!(" cargo run --bin=train_burn_rl -- --max-steps 1000 --episodes 500");
println!();
println!("NOTES:");
println!(" - Utilise l'environnement Burn-RL avec l'espace d'actions compactes");
println!(" - Pour l'instant, implémente seulement une politique epsilon-greedy simple");
println!(" - L'intégration avec un vrai agent DQN peut être ajoutée plus tard");
}

View file

@ -1,5 +1,5 @@
use bot::strategy::dqn_common::{DqnConfig, TrictracAction};
use bot::strategy::dqn_trainer::DqnTrainer;
use bot::dqn::dqn_common::{DqnConfig, TrictracAction};
use bot::dqn::simple::dqn_trainer::DqnTrainer;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {

View file

@ -1,297 +0,0 @@
use bot::burnrl::environment::{TrictracAction, TrictracEnvironment};
use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
use bot::strategy::dqn_common::get_valid_actions;
use burn::optim::AdamConfig;
use burn_rl::base::Environment;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
let args: Vec<String> = env::args().collect();
// Paramètres par défaut
let mut episodes = 1000;
let mut model_path = "models/burn_dqn_model".to_string();
let mut save_every = 100;
let mut max_steps_per_episode = 500;
// Parser les arguments de ligne de commande
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"--episodes" => {
if i + 1 < args.len() {
episodes = args[i + 1].parse().unwrap_or(1000);
i += 2;
} else {
eprintln!("Erreur : --episodes nécessite une valeur");
std::process::exit(1);
}
}
"--model-path" => {
if i + 1 < args.len() {
model_path = args[i + 1].clone();
i += 2;
} else {
eprintln!("Erreur : --model-path nécessite une valeur");
std::process::exit(1);
}
}
"--save-every" => {
if i + 1 < args.len() {
save_every = args[i + 1].parse().unwrap_or(100);
i += 2;
} else {
eprintln!("Erreur : --save-every nécessite une valeur");
std::process::exit(1);
}
}
"--max-steps" => {
if i + 1 < args.len() {
max_steps_per_episode = args[i + 1].parse().unwrap_or(500);
i += 2;
} else {
eprintln!("Erreur : --max-steps nécessite une valeur");
std::process::exit(1);
}
}
"--help" | "-h" => {
print_help();
std::process::exit(0);
}
_ => {
eprintln!("Argument inconnu : {}", args[i]);
print_help();
std::process::exit(1);
}
}
}
// Créer le dossier models s'il n'existe pas
std::fs::create_dir_all("models")?;
println!("=== Entraînement DQN complet avec Burn ===");
println!("Épisodes : {}", episodes);
println!("Modèle : {}", model_path);
println!("Sauvegarde tous les {} épisodes", save_every);
println!("Max steps par épisode : {}", max_steps_per_episode);
println!();
// Configuration DQN
let config = DqnConfig {
state_size: 36,
action_size: 1252, // Espace d'actions réduit via contexte
hidden_size: 256,
learning_rate: 0.001,
gamma: 0.99,
epsilon: 1.0,
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 10000,
batch_size: 32,
target_update_freq: 100,
};
// Créer l'agent et l'environnement
let mut agent = BurnDqnAgent::new(config);
let mut optimizer = AdamConfig::new().init();
let mut env = TrictracEnvironment::new(true);
// Variables pour les statistiques
let mut total_rewards = Vec::new();
let mut episode_lengths = Vec::new();
let mut losses = Vec::new();
println!("Début de l'entraînement avec agent DQN complet...");
println!();
for episode in 1..=episodes {
// Reset de l'environnement
let mut snapshot = env.reset();
let mut episode_reward = 0.0;
let mut step = 0;
let mut episode_loss = 0.0;
let mut loss_count = 0;
loop {
step += 1;
let current_state_data = snapshot.state().data;
// Obtenir les actions valides selon le contexte du jeu
let valid_actions = get_valid_actions(&env.game);
if valid_actions.is_empty() {
break;
}
// Convertir les actions Trictrac en indices pour l'agent
let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();
// Sélectionner une action avec l'agent DQN
let action_index = agent.select_action(&current_state_data, &valid_indices);
let action = TrictracAction {
index: action_index as u32,
};
// Exécuter l'action
snapshot = env.step(action);
episode_reward += *snapshot.reward();
// Préparer l'expérience pour l'agent
let experience = Experience {
state: current_state_data.to_vec(),
action: action_index,
reward: *snapshot.reward(),
next_state: if snapshot.done() {
None
} else {
Some(snapshot.state().data.to_vec())
},
done: snapshot.done(),
};
// Ajouter l'expérience au replay buffer
agent.add_experience(experience);
// Entraîner l'agent
if let Some(loss) = agent.train_step(&mut optimizer) {
episode_loss += loss;
loss_count += 1;
}
// Vérifier les conditions de fin
if snapshot.done() || step >= max_steps_per_episode {
break;
}
}
// Calculer la loss moyenne de l'épisode
let avg_loss = if loss_count > 0 {
episode_loss / loss_count as f32
} else {
0.0
};
// Sauvegarder les statistiques
total_rewards.push(episode_reward);
episode_lengths.push(step);
losses.push(avg_loss);
// Affichage des statistiques
if episode % save_every == 0 {
let avg_reward =
total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
let avg_length =
episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
let avg_episode_loss =
losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());
// Sauvegarder le modèle
let checkpoint_path = format!("{}_{}", model_path, episode);
if let Err(e) = agent.save_model(&checkpoint_path) {
eprintln!("Erreur lors de la sauvegarde : {}", e);
} else {
println!(" → Modèle sauvegardé : {}", checkpoint_path);
}
} else if episode % 10 == 0 {
println!(
"Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
episode,
episode_reward,
step,
avg_loss,
agent.get_epsilon()
);
}
}
// Sauvegarder le modèle final
let final_path = format!("{}_final", model_path);
agent.save_model(&final_path)?;
// Statistiques finales
println!();
println!("=== Résultats de l'entraînement ===");
let final_avg_reward = total_rewards
.iter()
.rev()
.take(100.min(episodes))
.sum::<f32>()
/ 100.min(episodes) as f32;
let final_avg_length = episode_lengths
.iter()
.rev()
.take(100.min(episodes))
.sum::<usize>()
/ 100.min(episodes);
let final_avg_loss =
losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
println!(
"Récompense moyenne (100 derniers épisodes) : {:.3}",
final_avg_reward
);
println!(
"Longueur moyenne (100 derniers épisodes) : {}",
final_avg_length
);
println!(
"Loss moyenne (100 derniers épisodes) : {:.6}",
final_avg_loss
);
println!("Epsilon final : {:.3}", agent.get_epsilon());
println!("Taille du buffer final : {}", agent.get_buffer_size());
// Statistiques globales
let max_reward = total_rewards
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
println!("Récompense max : {:.3}", max_reward);
println!("Récompense min : {:.3}", min_reward);
println!();
println!("Entraînement terminé avec succès !");
println!("Modèle final sauvegardé : {}", final_path);
println!();
println!("Pour utiliser le modèle entraîné :");
println!(
" cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy",
model_path
);
Ok(())
}
fn print_help() {
println!("Entraîneur DQN complet avec Burn pour Trictrac");
println!();
println!("USAGE:");
println!(" cargo run --bin=train_dqn_full [OPTIONS]");
println!();
println!("OPTIONS:");
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/burn_dqn_model)");
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
println!(" --max-steps <NUM> Nombre max de steps par épisode (défaut: 500)");
println!(" -h, --help Afficher cette aide");
println!();
println!("EXEMPLES:");
println!(" cargo run --bin=train_dqn_full");
println!(" cargo run --bin=train_dqn_full -- --episodes 2000 --save-every 200");
println!(" cargo run --bin=train_dqn_full -- --model-path models/my_model --episodes 500");
println!();
println!("FONCTIONNALITÉS:");
println!(" - Agent DQN complet avec réseau de neurones Burn");
println!(" - Experience replay buffer avec échantillonnage aléatoire");
println!(" - Epsilon-greedy avec décroissance automatique");
println!(" - Target network avec mise à jour périodique");
println!(" - Sauvegarde automatique des modèles");
println!(" - Statistiques d'entraînement détaillées");
}

View file

View file

@ -1,15 +1,14 @@
use crate::burnrl::utils::soft_update_linear;
use crate::dqn::burnrl::utils::soft_update_linear;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::{Duration, SystemTime};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {

View file

@ -1,4 +1,4 @@
use crate::strategy::dqn_common;
use crate::dqn::dqn_common;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};

View file

@ -1,5 +1,5 @@
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::strategy::dqn_common::get_valid_action_indices;
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;

View file

@ -1,7 +1,7 @@
use std::cmp::{max, min};
use serde::{Deserialize, Serialize};
use store::{CheckerMove, Dice, GameEvent, PlayerId};
use store::{CheckerMove, Dice};
/// Types d'actions possibles dans le jeu
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@ -259,7 +259,6 @@ impl SimpleNeuralNetwork {
/// Obtient les actions valides pour l'état de jeu actuel
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
use crate::PointsRules;
use store::TurnStage;
let mut valid_actions = Vec::new();

3
bot/src/dqn/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod dqn_common;
pub mod simple;
pub mod burnrl;

View file

@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use super::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};
use crate::dqn::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};
/// Expérience pour le buffer de replay
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -0,0 +1 @@
pub mod dqn_trainer;

View file

@ -1,8 +1,7 @@
pub mod burnrl;
pub mod dqn;
pub mod strategy;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::burn_dqn_strategy::{create_burn_dqn_strategy, BurnDqnStrategy};
pub use strategy::default::DefaultStrategy;
pub use strategy::dqn::DqnStrategy;
pub use strategy::erroneous_moves::ErroneousStrategy;

View file

@ -1,305 +0,0 @@
use burn::{
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
module::Module,
nn::{loss::MseLoss, Linear, LinearConfig},
optim::Optimizer,
record::{CompactRecorder, Recorder},
tensor::Tensor,
};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// Backend utilisé pour l'entraînement (Autodiff + NdArray)
pub type MyBackend = Autodiff<NdArray>;
/// Backend utilisé pour l'inférence (NdArray)
pub type InferenceBackend = NdArray;
pub type MyDevice = NdArrayDevice;
/// Réseau de neurones pour DQN
#[derive(Module, Debug)]
pub struct DqnNetwork<B: burn::prelude::Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
fc3: Linear<B>,
}
impl<B: burn::prelude::Backend> DqnNetwork<B> {
/// Crée un nouveau réseau DQN
pub fn new(
input_size: usize,
hidden_size: usize,
output_size: usize,
device: &B::Device,
) -> Self {
let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
let fc3 = LinearConfig::new(hidden_size, output_size).init(device);
Self { fc1, fc2, fc3 }
}
/// Forward pass du réseau
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.fc1.forward(input);
let x = burn::tensor::activation::relu(x);
let x = self.fc2.forward(x);
let x = burn::tensor::activation::relu(x);
self.fc3.forward(x)
}
}
/// Configuration pour l'entraînement DQN
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
pub state_size: usize,
pub action_size: usize,
pub hidden_size: usize,
pub learning_rate: f64,
pub gamma: f32,
pub epsilon: f32,
pub epsilon_decay: f32,
pub epsilon_min: f32,
pub replay_buffer_size: usize,
pub batch_size: usize,
pub target_update_freq: usize,
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
state_size: 36,
action_size: 1000,
hidden_size: 256,
learning_rate: 0.001,
gamma: 0.99,
epsilon: 1.0,
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 10000,
batch_size: 32,
target_update_freq: 100,
}
}
}
/// Experience pour le replay buffer
#[derive(Debug, Clone)]
pub struct Experience {
pub state: Vec<f32>,
pub action: usize,
pub reward: f32,
pub next_state: Option<Vec<f32>>,
pub done: bool,
}
/// Agent DQN utilisant Burn
pub struct BurnDqnAgent {
config: DqnConfig,
device: MyDevice,
q_network: DqnNetwork<MyBackend>,
target_network: DqnNetwork<MyBackend>,
replay_buffer: VecDeque<Experience>,
epsilon: f32,
step_count: usize,
}
impl BurnDqnAgent {
/// Crée un nouvel agent DQN
pub fn new(config: DqnConfig) -> Self {
let device = MyDevice::default();
let q_network = DqnNetwork::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);
let target_network = DqnNetwork::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);
Self {
config: config.clone(),
device,
q_network,
target_network,
replay_buffer: VecDeque::new(),
epsilon: config.epsilon,
step_count: 0,
}
}
/// Sélectionne une action avec epsilon-greedy
pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize {
if valid_actions.is_empty() {
// Retourne une action par défaut ou une action "nulle" si aucune n'est valide
// Dans le contexte du jeu, cela ne devrait pas arriver si la logique de fin de partie est correcte
return 0;
}
// Exploration epsilon-greedy
if rand::random::<f32>() < self.epsilon {
let random_index = rand::random::<usize>() % valid_actions.len();
return valid_actions[random_index];
}
// Exploitation : choisir la meilleure action selon le Q-network
let state_tensor = Tensor::<MyBackend, 1>::from_floats(state, &self.device)
.reshape([1, self.config.state_size]);
let q_values = self.q_network.forward(state_tensor);
// Convertir en vecteur pour traitement
let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();
// Trouver la meilleure action parmi les actions valides
let mut best_action = valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for &action in valid_actions {
if action < q_data.len() && q_data[action] > best_q_value {
best_q_value = q_data[action];
best_action = action;
}
}
best_action
}
/// Ajoute une expérience au replay buffer
pub fn add_experience(&mut self, experience: Experience) {
if self.replay_buffer.len() >= self.config.replay_buffer_size {
self.replay_buffer.pop_front();
}
self.replay_buffer.push_back(experience);
}
/// Entraîne le réseau sur un batch d'expériences
pub fn train_step(
&mut self,
optimizer: &mut impl Optimizer<DqnNetwork<MyBackend>, MyBackend>,
) -> Option<f32> {
if self.replay_buffer.len() < self.config.batch_size {
return None;
}
// Échantillonner un batch d'expériences
let batch = self.sample_batch();
// Préparer les tenseurs d'état
let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
let state_tensor = Tensor::<MyBackend, 1>::from_floats(states.as_slice(), &self.device)
.reshape([self.config.batch_size, self.config.state_size]);
// Calculer les Q-values actuelles
let current_q_values = self.q_network.forward(state_tensor);
// Pour l'instant, version simplifiée sans calcul de target
let target_q_values = current_q_values.clone();
// Calculer la loss MSE
let loss = MseLoss::new().forward(
current_q_values,
target_q_values,
burn::nn::loss::Reduction::Mean,
);
// Backpropagation (version simplifiée)
let grads = loss.backward();
// Gradients linked to each parameter of the model.
let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network);
self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads);
// Mise à jour du réseau cible
self.step_count += 1;
if self.step_count % self.config.target_update_freq == 0 {
self.update_target_network();
}
// Décroissance d'epsilon
if self.epsilon > self.config.epsilon_min {
self.epsilon *= self.config.epsilon_decay;
}
Some(loss.into_scalar())
}
/// Échantillonne un batch d'expériences du replay buffer
fn sample_batch(&self) -> Vec<Experience> {
let mut batch = Vec::new();
let buffer_size = self.replay_buffer.len();
for _ in 0..self.config.batch_size.min(buffer_size) {
let index = rand::random::<usize>() % buffer_size;
if let Some(exp) = self.replay_buffer.get(index) {
batch.push(exp.clone());
}
}
batch
}
/// Met à jour le réseau cible avec les poids du réseau principal
fn update_target_network(&mut self) {
// Copie simple des poids
self.target_network = self.q_network.clone();
}
/// Sauvegarde le modèle
pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
// Sauvegarder la configuration
let config_path = format!("{}_config.json", path);
let config_json = serde_json::to_string_pretty(&self.config)?;
std::fs::write(config_path, config_json)?;
// Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend)
let inference_network = self.q_network.clone().into_record();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
recorder.record(inference_network, model_path.into())?;
println!("Modèle sauvegardé : {}", path);
Ok(())
}
/// Charge un modèle pour l'inférence
pub fn load_model_for_inference(
path: &str,
) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
// Charger la configuration
let config_path = format!("{}_config.json", path);
let config_json = std::fs::read_to_string(config_path)?;
let config: DqnConfig = serde_json::from_str(&config_json)?;
// Créer le réseau pour l'inférence
let device = NdArrayDevice::default();
let network = DqnNetwork::<InferenceBackend>::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);
// Charger les poids
let model_path = format!("{}_model.burn", path);
let recorder = CompactRecorder::new();
let record = recorder.load(model_path.into(), &device)?;
let network = network.load_record(record);
Ok((network, config))
}
/// Retourne l'epsilon actuel
pub fn get_epsilon(&self) -> f32 {
self.epsilon
}
/// Retourne la taille du replay buffer
pub fn get_buffer_size(&self) -> usize {
self.replay_buffer.len()
}
}

View file

@ -1,192 +0,0 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use super::burn_dqn_agent::{DqnNetwork, DqnConfig, InferenceBackend};
use super::dqn_common::get_valid_actions;
use burn::{backend::ndarray::NdArrayDevice, tensor::Tensor};
use std::path::Path;
/// Stratégie utilisant un modèle DQN Burn entraîné
#[derive(Debug)]
pub struct BurnDqnStrategy {
pub game: GameState,
pub player_id: PlayerId,
pub color: Color,
network: Option<DqnNetwork<InferenceBackend>>,
config: Option<DqnConfig>,
device: NdArrayDevice,
}
impl Default for BurnDqnStrategy {
fn default() -> Self {
Self {
game: GameState::default(),
player_id: 0,
color: Color::White,
network: None,
config: None,
device: NdArrayDevice::default(),
}
}
}
impl BurnDqnStrategy {
/// Crée une nouvelle stratégie avec un modèle chargé
pub fn new(model_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let mut strategy = Self::default();
strategy.load_model(model_path)?;
Ok(strategy)
}
/// Charge un modèle DQN depuis un fichier
pub fn load_model(&mut self, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
if !Path::new(&format!("{}_config.json", model_path)).exists() {
return Err(format!("Modèle non trouvé : {}", model_path).into());
}
let (network, config) = super::burn_dqn_agent::BurnDqnAgent::load_model_for_inference(model_path)?;
self.network = Some(network);
self.config = Some(config);
println!("Modèle DQN Burn chargé depuis : {}", model_path);
Ok(())
}
/// Sélectionne la meilleure action selon le modèle DQN
fn select_best_action(&self, valid_actions: &[super::dqn_common::TrictracAction]) -> Option<super::dqn_common::TrictracAction> {
if valid_actions.is_empty() {
return None;
}
// Si pas de réseau chargé, utiliser la première action valide
let Some(network) = &self.network else {
return Some(valid_actions[0].clone());
};
// Convertir l'état du jeu en tensor
let state_vec = self.game.to_vec_float();
let state_tensor = Tensor::<InferenceBackend, 2>::from_floats(state_vec.as_slice(), &self.device).reshape([1, self.config.as_ref().unwrap().state_size]);
// Faire une prédiction
let q_values = network.forward(state_tensor);
let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();
// Trouver la meilleure action parmi les actions valides
let mut best_action = &valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for (i, action) in valid_actions.iter().enumerate() {
if i < q_data.len() && q_data[i] > best_q_value {
best_q_value = q_data[i];
best_action = action;
}
}
Some(best_action.clone())
}
/// Convertit une TrictracAction en CheckerMove pour les mouvements
fn trictrac_action_to_moves(&self, action: &super::dqn_common::TrictracAction) -> Option<(CheckerMove, CheckerMove)> {
match action {
super::dqn_common::TrictracAction::Move { dice_order, from1, from2 } => {
let dice = self.game.dice;
let (die1, die2) = if *dice_order {
(dice.values.0, dice.values.1)
} else {
(dice.values.1, dice.values.0)
};
// Calculer les destinations selon la couleur
let to1 = if self.color == Color::White {
from1 + die1 as usize
} else {
from1.saturating_sub(die1 as usize)
};
let to2 = if self.color == Color::White {
from2 + die2 as usize
} else {
from2.saturating_sub(die2 as usize)
};
// Créer les mouvements
let move1 = CheckerMove::new(*from1, to1).ok()?;
let move2 = CheckerMove::new(*from2, to2).ok()?;
Some((move1, move2))
}
_ => None,
}
}
}
impl BotStrategy for BurnDqnStrategy {
fn get_game(&self) -> &GameState {
&self.game
}
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
fn calculate_points(&self) -> u8 {
// Utiliser le modèle DQN pour décider des points à marquer
// let valid_actions = get_valid_actions(&self.game);
// Chercher une action Mark dans les actions valides
// for action in &valid_actions {
// if let super::dqn_common::TrictracAction::Mark { points } = action {
// return *points;
// }
// }
// Par défaut, marquer 0 points
0
}
fn calculate_adv_points(&self) -> u8 {
// Même logique que calculate_points pour les points d'avance
self.calculate_points()
}
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
let valid_actions = get_valid_actions(&self.game);
if let Some(best_action) = self.select_best_action(&valid_actions) {
if let Some((move1, move2)) = self.trictrac_action_to_moves(&best_action) {
return (move1, move2);
}
}
// Fallback: utiliser la stratégie par défaut
let default_strategy = super::default::DefaultStrategy::default();
default_strategy.choose_move()
}
fn choose_go(&self) -> bool {
let valid_actions = get_valid_actions(&self.game);
if let Some(best_action) = self.select_best_action(&valid_actions) {
match best_action {
super::dqn_common::TrictracAction::Go => return true,
super::dqn_common::TrictracAction::Move { .. } => return false,
_ => {}
}
}
// Par défaut, toujours choisir de continuer
true
}
fn set_player_id(&mut self, player_id: PlayerId) {
self.player_id = player_id;
}
fn set_color(&mut self, color: Color) {
self.color = color;
}
}
/// Factory function pour créer une stratégie DQN Burn depuis un chemin de modèle
pub fn create_burn_dqn_strategy(model_path: &str) -> Result<Box<dyn BotStrategy>, Box<dyn std::error::Error>> {
let strategy = BurnDqnStrategy::new(model_path)?;
Ok(Box::new(strategy))
}

View file

@ -1,4 +1,4 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use store::MoveRules;
#[derive(Debug)]

View file

@ -1,8 +1,8 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use std::path::Path;
use store::MoveRules;
use super::dqn_common::{
use crate::dqn::dqn_common::{
get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction,
};

View file

@ -1,9 +1,5 @@
pub mod burn_dqn_agent;
pub mod burn_dqn_strategy;
pub mod client;
pub mod default;
pub mod dqn;
pub mod dqn_common;
pub mod dqn_trainer;
pub mod erroneous_moves;
pub mod stable_baselines3;

View file

@ -9,7 +9,10 @@ shell:
runcli:
RUST_LOG=info cargo run --bin=client_cli
runclibots:
RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,ai
RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy
match:
cargo build --release --bin=client_cli
LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn
profile:
echo '1' | sudo tee /proc/sys/kernel/perf_event_paranoid
cargo build --profile profiling
@ -29,4 +32,4 @@ debugtrainbot:
profiletrainbot:
echo '1' | sudo tee /proc/sys/kernel/perf_event_paranoid
cargo build --profile profiling --bin=train_dqn_burn
LD_LIBRARY_PATH=./target/debug samply record ./target/profiling/train_dqn_burn
LD_LIBRARY_PATH=./target/profiling samply record ./target/profiling/train_dqn_burn