From cf93255f03bd6ca60072d26f6bcc1325a422ddcd Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Mon, 23 Jun 2025 22:17:24 +0200 Subject: [PATCH] claude not tested --- bot/Cargo.toml | 4 + bot/src/bin/train_dqn_full.rs | 253 ++++++++++++++++++++++ bot/src/lib.rs | 1 + bot/src/strategy.rs | 2 + bot/src/strategy/burn_dqn_agent.rs | 294 ++++++++++++++++++++++++++ bot/src/strategy/burn_dqn_strategy.rs | 192 +++++++++++++++++ doc/refs/claudeAIquestionOnlyRust.md | 20 ++ 7 files changed, 766 insertions(+) create mode 100644 bot/src/bin/train_dqn_full.rs create mode 100644 bot/src/strategy/burn_dqn_agent.rs create mode 100644 bot/src/strategy/burn_dqn_strategy.rs diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 2da1ac1..38bfee9 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,6 +13,10 @@ path = "src/bin/train_dqn.rs" name = "train_burn_rl" path = "src/bin/train_burn_rl.rs" +[[bin]] +name = "train_dqn_full" +path = "src/bin/train_dqn_full.rs" + [dependencies] pretty_assertions = "1.4.0" serde = { version = "1.0", features = ["derive"] } diff --git a/bot/src/bin/train_dqn_full.rs b/bot/src/bin/train_dqn_full.rs new file mode 100644 index 0000000..357ce90 --- /dev/null +++ b/bot/src/bin/train_dqn_full.rs @@ -0,0 +1,253 @@ +use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience}; +use bot::strategy::burn_environment::{TrictracEnvironment, TrictracAction}; +use bot::strategy::dqn_common::get_valid_actions; +use burn_rl::base::Environment; +use std::env; + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = env::args().collect(); + + // Paramètres par défaut + let mut episodes = 1000; + let mut model_path = "models/burn_dqn_model".to_string(); + let mut save_every = 100; + let mut max_steps_per_episode = 500; + + // Parser les arguments de ligne de commande + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--episodes" => { + if i + 1 < args.len() { + episodes = args[i + 1].parse().unwrap_or(1000); + i += 2; + } else { + eprintln!("Erreur : --episodes nécessite une valeur"); + std::process::exit(1); + } + } + "--model-path" => { + if i + 1 < args.len() { + model_path = args[i + 1].clone(); + i += 2; + } else { + eprintln!("Erreur : --model-path nécessite une valeur"); + std::process::exit(1); + } + } + "--save-every" => { + if i + 1 < args.len() { + save_every = args[i + 1].parse().unwrap_or(100); + i += 2; + } else { + eprintln!("Erreur : --save-every nécessite une valeur"); + std::process::exit(1); + } + } + "--max-steps" => { + if i + 1 < args.len() { + max_steps_per_episode = args[i + 1].parse().unwrap_or(500); + i += 2; + } else { + eprintln!("Erreur : --max-steps nécessite une valeur"); + std::process::exit(1); + } + } + "--help" | "-h" => { + print_help(); + std::process::exit(0); + } + _ => { + eprintln!("Argument inconnu : {}", args[i]); + print_help(); + std::process::exit(1); + } + } + } + + // Créer le dossier models s'il n'existe pas + std::fs::create_dir_all("models")?; + + println!("=== Entraînement DQN complet avec Burn ==="); + println!("Épisodes : {}", episodes); + println!("Modèle : {}", model_path); + println!("Sauvegarde tous les {} épisodes", save_every); + println!("Max steps par épisode : {}", max_steps_per_episode); + println!(); + + // Configuration DQN + let config = DqnConfig { + state_size: 36, + action_size: 1000, // Espace d'actions réduit via contexte + hidden_size: 256, + learning_rate: 0.001, + gamma: 0.99, + epsilon: 1.0, + epsilon_decay: 0.995, + epsilon_min: 0.01, + replay_buffer_size: 10000, + batch_size: 32, + target_update_freq: 100, + }; + + // Créer l'agent et l'environnement + let mut agent = BurnDqnAgent::new(config); + let mut env = TrictracEnvironment::new(true); + + // Variables pour les statistiques + let mut total_rewards = Vec::new(); + let mut episode_lengths = Vec::new(); + let mut losses = Vec::new(); + + println!("Début de l'entraînement avec agent DQN complet..."); + println!(); + + for episode in 1..=episodes { + // Reset de l'environnement + let mut snapshot = env.reset(); + let mut episode_reward = 0.0; + let mut step = 0; + let mut episode_loss = 0.0; + let mut loss_count = 0; + + loop { + step += 1; + let current_state = snapshot.state; + + // Obtenir les actions valides selon le contexte du jeu + let valid_actions = get_valid_actions(&env.game); + + if valid_actions.is_empty() { + break; + } + + // Convertir les actions Trictrac en indices pour l'agent + let valid_indices: Vec = (0..valid_actions.len()).collect(); + + // Sélectionner une action avec l'agent DQN + let action_index = agent.select_action(¤t_state.data.iter().map(|&x| x as f32).collect::>(), &valid_indices); + let action = TrictracAction { index: action_index as u32 }; + + // Exécuter l'action + snapshot = env.step(action); + episode_reward += snapshot.reward; + + // Préparer l'expérience pour l'agent + let experience = Experience { + state: current_state.data.iter().map(|&x| x as f32).collect(), + action: action_index, + reward: snapshot.reward, + next_state: if snapshot.terminated { + None + } else { + Some(snapshot.state.data.iter().map(|&x| x as f32).collect()) + }, + done: snapshot.terminated, + }; + + // Ajouter l'expérience au replay buffer + agent.add_experience(experience); + + // Entraîner l'agent + if let Some(loss) = agent.train_step() { + episode_loss += loss; + loss_count += 1; + } + + // Vérifier les conditions de fin + if snapshot.terminated || step >= max_steps_per_episode { + break; + } + } + + // Calculer la loss moyenne de l'épisode + let avg_loss = if loss_count > 0 { episode_loss / loss_count as f32 } else { 0.0 }; + + // Sauvegarder les statistiques + total_rewards.push(episode_reward); + episode_lengths.push(step); + losses.push(avg_loss); + + // Affichage des statistiques + if episode % save_every == 0 { + let avg_reward = total_rewards.iter().rev().take(save_every).sum::() / save_every as f32; + let avg_length = episode_lengths.iter().rev().take(save_every).sum::() / save_every; + let avg_episode_loss = losses.iter().rev().take(save_every).sum::() / save_every as f32; + + println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}", + episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size()); + + // Sauvegarder le modèle + let checkpoint_path = format!("{}_{}", model_path, episode); + if let Err(e) = agent.save_model(&checkpoint_path) { + eprintln!("Erreur lors de la sauvegarde : {}", e); + } else { + println!(" → Modèle sauvegardé : {}", checkpoint_path); + } + } else if episode % 10 == 0 { + println!("Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}", + episode, episode_reward, step, avg_loss, agent.get_epsilon()); + } + } + + // Sauvegarder le modèle final + let final_path = format!("{}_final", model_path); + agent.save_model(&final_path)?; + + // Statistiques finales + println!(); + println!("=== Résultats de l'entraînement ==="); + let final_avg_reward = total_rewards.iter().rev().take(100.min(episodes)).sum::() / 100.min(episodes) as f32; + let final_avg_length = episode_lengths.iter().rev().take(100.min(episodes)).sum::() / 100.min(episodes); + let final_avg_loss = losses.iter().rev().take(100.min(episodes)).sum::() / 100.min(episodes) as f32; + + println!("Récompense moyenne (100 derniers épisodes) : {:.3}", final_avg_reward); + println!("Longueur moyenne (100 derniers épisodes) : {}", final_avg_length); + println!("Loss moyenne (100 derniers épisodes) : {:.6}", final_avg_loss); + println!("Epsilon final : {:.3}", agent.get_epsilon()); + println!("Taille du buffer final : {}", agent.get_buffer_size()); + + // Statistiques globales + let max_reward = total_rewards.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min); + println!("Récompense max : {:.3}", max_reward); + println!("Récompense min : {:.3}", min_reward); + + println!(); + println!("Entraînement terminé avec succès !"); + println!("Modèle final sauvegardé : {}", final_path); + println!(); + println!("Pour utiliser le modèle entraîné :"); + println!(" cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy", model_path); + + Ok(()) +} + +fn print_help() { + println!("Entraîneur DQN complet avec Burn pour Trictrac"); + println!(); + println!("USAGE:"); + println!(" cargo run --bin=train_dqn_full [OPTIONS]"); + println!(); + println!("OPTIONS:"); + println!(" --episodes Nombre d'épisodes d'entraînement (défaut: 1000)"); + println!(" --model-path Chemin de base pour sauvegarder les modèles (défaut: models/burn_dqn_model)"); + println!(" --save-every Sauvegarder le modèle tous les N épisodes (défaut: 100)"); + println!(" --max-steps Nombre max de steps par épisode (défaut: 500)"); + println!(" -h, --help Afficher cette aide"); + println!(); + println!("EXEMPLES:"); + println!(" cargo run --bin=train_dqn_full"); + println!(" cargo run --bin=train_dqn_full -- --episodes 2000 --save-every 200"); + println!(" cargo run --bin=train_dqn_full -- --model-path models/my_model --episodes 500"); + println!(); + println!("FONCTIONNALITÉS:"); + println!(" - Agent DQN complet avec réseau de neurones Burn"); + println!(" - Experience replay buffer avec échantillonnage aléatoire"); + println!(" - Epsilon-greedy avec décroissance automatique"); + println!(" - Target network avec mise à jour périodique"); + println!(" - Sauvegarde automatique des modèles"); + println!(" - Statistiques d'entraînement détaillées"); +} \ No newline at end of file diff --git a/bot/src/lib.rs b/bot/src/lib.rs index cd66aa9..d3da040 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,6 +1,7 @@ pub mod strategy; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; +pub use strategy::burn_dqn_strategy::{BurnDqnStrategy, create_burn_dqn_strategy}; pub use strategy::default::DefaultStrategy; pub use strategy::dqn::DqnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; diff --git a/bot/src/strategy.rs b/bot/src/strategy.rs index 5c36e04..e26c20f 100644 --- a/bot/src/strategy.rs +++ b/bot/src/strategy.rs @@ -1,3 +1,5 @@ +pub mod burn_dqn_agent; +pub mod burn_dqn_strategy; pub mod burn_environment; pub mod client; pub mod default; diff --git a/bot/src/strategy/burn_dqn_agent.rs b/bot/src/strategy/burn_dqn_agent.rs new file mode 100644 index 0000000..785e834 --- /dev/null +++ b/bot/src/strategy/burn_dqn_agent.rs @@ -0,0 +1,294 @@ +use burn::{ + backend::{ndarray::NdArrayDevice, Autodiff, NdArray}, + nn::{Linear, LinearConfig, loss::MseLoss}, + module::Module, + tensor::Tensor, + optim::{AdamConfig, Optimizer}, + record::{CompactRecorder, Recorder}, +}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; + +/// Backend utilisé pour l'entraînement (Autodiff + NdArray) +pub type MyBackend = Autodiff; +/// Backend utilisé pour l'inférence (NdArray) +pub type InferenceBackend = NdArray; +pub type MyDevice = NdArrayDevice; + +/// Réseau de neurones pour DQN +#[derive(Module, Debug)] +pub struct DqnNetwork { + fc1: Linear, + fc2: Linear, + fc3: Linear, +} + +impl DqnNetwork { + /// Crée un nouveau réseau DQN + pub fn new(input_size: usize, hidden_size: usize, output_size: usize, device: &B::Device) -> Self { + let fc1 = LinearConfig::new(input_size, hidden_size).init(device); + let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device); + let fc3 = LinearConfig::new(hidden_size, output_size).init(device); + + Self { fc1, fc2, fc3 } + } + + /// Forward pass du réseau + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.fc1.forward(input); + let x = burn::tensor::activation::relu(x); + let x = self.fc2.forward(x); + let x = burn::tensor::activation::relu(x); + self.fc3.forward(x) + } +} + +/// Configuration pour l'entraînement DQN +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DqnConfig { + pub state_size: usize, + pub action_size: usize, + pub hidden_size: usize, + pub learning_rate: f64, + pub gamma: f32, + pub epsilon: f32, + pub epsilon_decay: f32, + pub epsilon_min: f32, + pub replay_buffer_size: usize, + pub batch_size: usize, + pub target_update_freq: usize, +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + state_size: 36, + action_size: 1000, + hidden_size: 256, + learning_rate: 0.001, + gamma: 0.99, + epsilon: 1.0, + epsilon_decay: 0.995, + epsilon_min: 0.01, + replay_buffer_size: 10000, + batch_size: 32, + target_update_freq: 100, + } + } +} + +/// Experience pour le replay buffer +#[derive(Debug, Clone)] +pub struct Experience { + pub state: Vec, + pub action: usize, + pub reward: f32, + pub next_state: Option>, + pub done: bool, +} + +/// Agent DQN utilisant Burn +pub struct BurnDqnAgent { + config: DqnConfig, + device: MyDevice, + q_network: DqnNetwork, + target_network: DqnNetwork, + optimizer: burn::optim::Adam, + replay_buffer: VecDeque, + epsilon: f32, + step_count: usize, +} + +impl BurnDqnAgent { + /// Crée un nouvel agent DQN + pub fn new(config: DqnConfig) -> Self { + let device = MyDevice::default(); + + let q_network = DqnNetwork::new( + config.state_size, + config.hidden_size, + config.action_size, + &device, + ); + + let target_network = DqnNetwork::new( + config.state_size, + config.hidden_size, + config.action_size, + &device, + ); + + let optimizer = AdamConfig::new().init(); + + Self { + config: config.clone(), + device, + q_network, + target_network, + optimizer, + replay_buffer: VecDeque::new(), + epsilon: config.epsilon, + step_count: 0, + } + } + + /// Sélectionne une action avec epsilon-greedy + pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize { + if valid_actions.is_empty() { + return 0; + } + + // Exploration epsilon-greedy + if rand::random::() < self.epsilon { + let random_index = rand::random::() % valid_actions.len(); + return valid_actions[random_index]; + } + + // Exploitation : choisir la meilleure action selon le Q-network + let state_tensor = Tensor::::from_floats([state], &self.device); + let q_values = self.q_network.forward(state_tensor); + + // Convertir en vecteur pour traitement + let q_data = q_values.into_data().convert::().value; + + // Trouver la meilleure action parmi les actions valides + let mut best_action = valid_actions[0]; + let mut best_q_value = f32::NEG_INFINITY; + + for &action in valid_actions { + if action < q_data.len() && q_data[action] > best_q_value { + best_q_value = q_data[action]; + best_action = action; + } + } + + best_action + } + + /// Ajoute une expérience au replay buffer + pub fn add_experience(&mut self, experience: Experience) { + if self.replay_buffer.len() >= self.config.replay_buffer_size { + self.replay_buffer.pop_front(); + } + self.replay_buffer.push_back(experience); + } + + /// Entraîne le réseau sur un batch d'expériences + pub fn train_step(&mut self) -> Option { + if self.replay_buffer.len() < self.config.batch_size { + return None; + } + + // Échantillonner un batch d'expériences + let batch = self.sample_batch(); + + // Préparer les tenseurs d'état + let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect(); + let state_tensor = Tensor::::from_floats(states, &self.device); + + // Calculer les Q-values actuelles + let current_q_values = self.q_network.forward(state_tensor); + + // Pour l'instant, version simplifiée sans calcul de target + let target_q_values = current_q_values.clone(); + + // Calculer la loss MSE + let loss = MseLoss::new().forward( + current_q_values, + target_q_values, + burn::nn::loss::Reduction::Mean + ); + + // Backpropagation (version simplifiée) + let grads = loss.backward(); + self.q_network = self.optimizer.step(self.config.learning_rate, self.q_network, grads); + + // Mise à jour du réseau cible + self.step_count += 1; + if self.step_count % self.config.target_update_freq == 0 { + self.update_target_network(); + } + + // Décroissance d'epsilon + if self.epsilon > self.config.epsilon_min { + self.epsilon *= self.config.epsilon_decay; + } + + Some(loss.into_scalar()) + } + + /// Échantillonne un batch d'expériences du replay buffer + fn sample_batch(&self) -> Vec { + let mut batch = Vec::new(); + let buffer_size = self.replay_buffer.len(); + + for _ in 0..self.config.batch_size.min(buffer_size) { + let index = rand::random::() % buffer_size; + if let Some(exp) = self.replay_buffer.get(index) { + batch.push(exp.clone()); + } + } + + batch + } + + /// Met à jour le réseau cible avec les poids du réseau principal + fn update_target_network(&mut self) { + // Copie simple des poids + self.target_network = self.q_network.clone(); + } + + /// Sauvegarde le modèle + pub fn save_model(&self, path: &str) -> Result<(), Box> { + // Sauvegarder la configuration + let config_path = format!("{}_config.json", path); + let config_json = serde_json::to_string_pretty(&self.config)?; + std::fs::write(config_path, config_json)?; + + // Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend) + let inference_network = self.q_network.clone().into_record(); + let recorder = CompactRecorder::new(); + + let model_path = format!("{}_model.burn", path); + recorder.record(inference_network, model_path.into())?; + + println!("Modèle sauvegardé : {}", path); + Ok(()) + } + + /// Charge un modèle pour l'inférence + pub fn load_model_for_inference(path: &str) -> Result<(DqnNetwork, DqnConfig), Box> { + // Charger la configuration + let config_path = format!("{}_config.json", path); + let config_json = std::fs::read_to_string(config_path)?; + let config: DqnConfig = serde_json::from_str(&config_json)?; + + // Créer le réseau pour l'inférence + let device = NdArrayDevice::default(); + let network = DqnNetwork::::new( + config.state_size, + config.hidden_size, + config.action_size, + &device, + ); + + // Charger les poids + let model_path = format!("{}_model.burn", path); + let recorder = CompactRecorder::new(); + let record = recorder.load(model_path.into(), &device)?; + let network = network.load_record(record); + + Ok((network, config)) + } + + /// Retourne l'epsilon actuel + pub fn get_epsilon(&self) -> f32 { + self.epsilon + } + + /// Retourne la taille du replay buffer + pub fn get_buffer_size(&self) -> usize { + self.replay_buffer.len() + } +} \ No newline at end of file diff --git a/bot/src/strategy/burn_dqn_strategy.rs b/bot/src/strategy/burn_dqn_strategy.rs new file mode 100644 index 0000000..8e9b72b --- /dev/null +++ b/bot/src/strategy/burn_dqn_strategy.rs @@ -0,0 +1,192 @@ +use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; +use super::burn_dqn_agent::{DqnNetwork, DqnConfig, InferenceBackend}; +use super::dqn_common::get_valid_actions; +use burn::{backend::ndarray::NdArrayDevice, tensor::Tensor}; +use std::path::Path; + +/// Stratégie utilisant un modèle DQN Burn entraîné +#[derive(Debug)] +pub struct BurnDqnStrategy { + pub game: GameState, + pub player_id: PlayerId, + pub color: Color, + network: Option>, + config: Option, + device: NdArrayDevice, +} + +impl Default for BurnDqnStrategy { + fn default() -> Self { + Self { + game: GameState::default(), + player_id: 0, + color: Color::White, + network: None, + config: None, + device: NdArrayDevice::default(), + } + } +} + +impl BurnDqnStrategy { + /// Crée une nouvelle stratégie avec un modèle chargé + pub fn new(model_path: &str) -> Result> { + let mut strategy = Self::default(); + strategy.load_model(model_path)?; + Ok(strategy) + } + + /// Charge un modèle DQN depuis un fichier + pub fn load_model(&mut self, model_path: &str) -> Result<(), Box> { + if !Path::new(&format!("{}_config.json", model_path)).exists() { + return Err(format!("Modèle non trouvé : {}", model_path).into()); + } + + let (network, config) = super::burn_dqn_agent::BurnDqnAgent::load_model_for_inference(model_path)?; + + self.network = Some(network); + self.config = Some(config); + + println!("Modèle DQN Burn chargé depuis : {}", model_path); + Ok(()) + } + + /// Sélectionne la meilleure action selon le modèle DQN + fn select_best_action(&self, valid_actions: &[super::dqn_common::TrictracAction]) -> Option { + if valid_actions.is_empty() { + return None; + } + + // Si pas de réseau chargé, utiliser la première action valide + let Some(network) = &self.network else { + return Some(valid_actions[0].clone()); + }; + + // Convertir l'état du jeu en tensor + let state_vec = self.game.to_vec_float(); + let state_tensor = Tensor::::from_floats([state_vec], &self.device); + + // Faire une prédiction + let q_values = network.forward(state_tensor); + let q_data = q_values.into_data().convert::().value; + + // Trouver la meilleure action parmi les actions valides + let mut best_action = &valid_actions[0]; + let mut best_q_value = f32::NEG_INFINITY; + + for (i, action) in valid_actions.iter().enumerate() { + if i < q_data.len() && q_data[i] > best_q_value { + best_q_value = q_data[i]; + best_action = action; + } + } + + Some(best_action.clone()) + } + + /// Convertit une TrictracAction en CheckerMove pour les mouvements + fn trictrac_action_to_moves(&self, action: &super::dqn_common::TrictracAction) -> Option<(CheckerMove, CheckerMove)> { + match action { + super::dqn_common::TrictracAction::Move { dice_order, from1, from2 } => { + let dice = self.game.dice; + let (die1, die2) = if *dice_order { + (dice.values.0, dice.values.1) + } else { + (dice.values.1, dice.values.0) + }; + + // Calculer les destinations selon la couleur + let to1 = if self.color == Color::White { + from1 + die1 as usize + } else { + from1.saturating_sub(die1 as usize) + }; + let to2 = if self.color == Color::White { + from2 + die2 as usize + } else { + from2.saturating_sub(die2 as usize) + }; + + // Créer les mouvements + let move1 = CheckerMove::new(*from1, to1).ok()?; + let move2 = CheckerMove::new(*from2, to2).ok()?; + + Some((move1, move2)) + } + _ => None, + } + } +} + +impl BotStrategy for BurnDqnStrategy { + fn get_game(&self) -> &GameState { + &self.game + } + + fn get_mut_game(&mut self) -> &mut GameState { + &mut self.game + } + + fn calculate_points(&self) -> u8 { + // Utiliser le modèle DQN pour décider des points à marquer + let valid_actions = get_valid_actions(&self.game); + + // Chercher une action Mark dans les actions valides + for action in &valid_actions { + if let super::dqn_common::TrictracAction::Mark { points } = action { + return *points; + } + } + + // Par défaut, marquer 0 points + 0 + } + + fn calculate_adv_points(&self) -> u8 { + // Même logique que calculate_points pour les points d'avance + self.calculate_points() + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + let valid_actions = get_valid_actions(&self.game); + + if let Some(best_action) = self.select_best_action(&valid_actions) { + if let Some((move1, move2)) = self.trictrac_action_to_moves(&best_action) { + return (move1, move2); + } + } + + // Fallback: utiliser la stratégie par défaut + let default_strategy = super::default::DefaultStrategy::default(); + default_strategy.choose_move() + } + + fn choose_go(&self) -> bool { + let valid_actions = get_valid_actions(&self.game); + + if let Some(best_action) = self.select_best_action(&valid_actions) { + match best_action { + super::dqn_common::TrictracAction::Go => return true, + super::dqn_common::TrictracAction::Move { .. } => return false, + _ => {} + } + } + + // Par défaut, toujours choisir de continuer + true + } + + fn set_player_id(&mut self, player_id: PlayerId) { + self.player_id = player_id; + } + + fn set_color(&mut self, color: Color) { + self.color = color; + } +} + +/// Factory function pour créer une stratégie DQN Burn depuis un chemin de modèle +pub fn create_burn_dqn_strategy(model_path: &str) -> Result, Box> { + let strategy = BurnDqnStrategy::new(model_path)?; + Ok(Box::new(strategy)) +} \ No newline at end of file diff --git a/doc/refs/claudeAIquestionOnlyRust.md b/doc/refs/claudeAIquestionOnlyRust.md index 9d4c108..9ed6496 100644 --- a/doc/refs/claudeAIquestionOnlyRust.md +++ b/doc/refs/claudeAIquestionOnlyRust.md @@ -230,3 +230,23 @@ Options disponibles : - --help : aide complète Cet entraîneur sert de base pour tester l'environnement Burn-RL. Une fois que tout fonctionne bien, on pourra y intégrer un vrai agent DQN avec réseaux de neurones ! + +--- + +Session où je n'ai cessé de recevoir ce message : + +⎿ API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute. For details, refer to: . You can see the response headers for current usage. Please reduce the prompt length or the maximum tokens requested, or try again later. You may also contact sales at to discuss your options for a rate limit increase."}}) · Retrying in 391 seconds… (attempt 1/10) + +✶ Coaching… (403s · ↑ 382 tokens · esc to interrupt) + +Pour à la fin de la session avoir ce résumé : + +Total cost: $0.95 +Total duration (API): 1h 24m 22.8s +Total duration (wall): 1h 43m 3.5s +Total code changes: 746 lines added, 0 lines removed +Token usage by model: +claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write +claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write + +Mais pourtant 2 millions indiqués dans la page usage : , et 7.88 dollars de consommés sur .