diff --git a/bot/src/bin/train_dqn_full.rs b/bot/src/bin/train_dqn_full.rs
index 357ce90..82eb502 100644
--- a/bot/src/bin/train_dqn_full.rs
+++ b/bot/src/bin/train_dqn_full.rs
@@ -1,5 +1,5 @@
 use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
-use bot::strategy::burn_environment::{TrictracEnvironment, TrictracAction};
+use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
 use bot::strategy::dqn_common::get_valid_actions;
 use burn_rl::base::Environment;
 use std::env;
@@ -80,7 +80,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Configuration DQN
     let config = DqnConfig {
         state_size: 36,
-        action_size: 1000, // Espace d'actions réduit via contexte
+        action_size: 1252, // Espace d'actions réduit via contexte
         hidden_size: 256,
         learning_rate: 0.001,
         gamma: 0.99,
@@ -94,6 +94,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

     // Créer l'agent et l'environnement
     let mut agent = BurnDqnAgent::new(config);
+    let mut optimizer = AdamConfig::new().init();
+
     let mut env = TrictracEnvironment::new(true);

     // Variables pour les statistiques
@@ -114,35 +116,44 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         loop {
             step += 1;

-            let current_state = snapshot.state;
+            let current_state = snapshot.state();

             // Obtenir les actions valides selon le contexte du jeu
             let valid_actions = get_valid_actions(&env.game);
-
+
             if valid_actions.is_empty() {
                 break;
             }

             // Convertir les actions Trictrac en indices pour l'agent
             let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();
-
+
             // Sélectionner une action avec l'agent DQN
-            let action_index = agent.select_action(&current_state.data.iter().map(|&x| x as f32).collect::<Vec<f32>>(), &valid_indices);
-            let action = TrictracAction { index: action_index as u32 };
+            let action_index = agent.select_action(
+                &current_state
+                    .data
+                    .iter()
+                    .map(|&x| x as f32)
+                    .collect::<Vec<f32>>(),
+                &valid_indices,
+            );
+            let action = TrictracAction {
+                index: action_index as u32,
+            };

             // Exécuter l'action
             snapshot = env.step(action);
-            episode_reward += snapshot.reward;
+            episode_reward += snapshot.reward();

             // Préparer l'expérience pour l'agent
             let experience = Experience {
                 state: current_state.data.iter().map(|&x| x as f32).collect(),
                 action: action_index,
-                reward: snapshot.reward,
-                next_state: if snapshot.terminated {
-                    None
-                } else {
-                    Some(snapshot.state.data.iter().map(|&x| x as f32).collect())
+                reward: snapshot.reward(),
+                next_state: if snapshot.terminated {
+                    None
+                } else {
+                    Some(snapshot.state().data.iter().map(|&x| x as f32).collect())
                 },
                 done: snapshot.terminated,
             };
@@ -151,7 +162,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             agent.add_experience(experience);

             // Entraîner l'agent
-            if let Some(loss) = agent.train_step() {
+            if let Some(loss) = agent.train_step(optimizer) {
                 episode_loss += loss;
                 loss_count += 1;
             }
@@ -163,7 +174,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         }

         // Calculer la loss moyenne de l'épisode
-        let avg_loss = if loss_count > 0 { episode_loss / loss_count as f32 } else { 0.0 };
+        let avg_loss = if loss_count > 0 {
+            episode_loss / loss_count as f32
+        } else {
+            0.0
+        };

         // Sauvegarder les statistiques
         total_rewards.push(episode_reward);
@@ -172,13 +187,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

         // Affichage des statistiques
         if episode % save_every == 0 {
-            let avg_reward = total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
-            let avg_length = episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
-            let avg_episode_loss = losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
-
+            let avg_reward =
+                total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
+            let avg_length =
+                episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
+            let avg_episode_loss =
+                losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
+
             println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
                 episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());
-
+
             // Sauvegarder le modèle
             let checkpoint_path = format!("{}_{}", model_path, episode);
             if let Err(e) = agent.save_model(&checkpoint_path) {
@@ -187,8 +205,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 println!(" → Modèle sauvegardé : {}", checkpoint_path);
             }
         } else if episode % 10 == 0 {
-            println!("Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
-                episode, episode_reward, step, avg_loss, agent.get_epsilon());
+            println!(
+                "Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
+                episode,
+                episode_reward,
+                step,
+                avg_loss,
+                agent.get_epsilon()
+            );
         }
     }

@@ -199,28 +223,54 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Statistiques finales
     println!();
     println!("=== Résultats de l'entraînement ===");
-    let final_avg_reward = total_rewards.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
-    let final_avg_length = episode_lengths.iter().rev().take(100.min(episodes)).sum::<usize>() / 100.min(episodes);
-    let final_avg_loss = losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
-
-    println!("Récompense moyenne (100 derniers épisodes) : {:.3}", final_avg_reward);
-    println!("Longueur moyenne (100 derniers épisodes) : {}", final_avg_length);
-    println!("Loss moyenne (100 derniers épisodes) : {:.6}", final_avg_loss);
+    let final_avg_reward = total_rewards
+        .iter()
+        .rev()
+        .take(100.min(episodes))
+        .sum::<f32>()
+        / 100.min(episodes) as f32;
+    let final_avg_length = episode_lengths
+        .iter()
+        .rev()
+        .take(100.min(episodes))
+        .sum::<usize>()
+        / 100.min(episodes);
+    let final_avg_loss =
+        losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
+
+    println!(
+        "Récompense moyenne (100 derniers épisodes) : {:.3}",
+        final_avg_reward
+    );
+    println!(
+        "Longueur moyenne (100 derniers épisodes) : {}",
+        final_avg_length
+    );
+    println!(
+        "Loss moyenne (100 derniers épisodes) : {:.6}",
+        final_avg_loss
+    );
     println!("Epsilon final : {:.3}", agent.get_epsilon());
     println!("Taille du buffer final : {}", agent.get_buffer_size());
-
+
     // Statistiques globales
-    let max_reward = total_rewards.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
+    let max_reward = total_rewards
+        .iter()
+        .cloned()
+        .fold(f32::NEG_INFINITY, f32::max);
     let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
     println!("Récompense max : {:.3}", max_reward);
     println!("Récompense min : {:.3}", min_reward);
-
+
     println!();
     println!("Entraînement terminé avec succès !");
     println!("Modèle final sauvegardé : {}", final_path);
     println!();
     println!("Pour utiliser le modèle entraîné :");
-    println!(" cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy", model_path);
+    println!(
+        " cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy",
+        model_path
+    );

     Ok(())
 }
@@ -250,4 +300,4 @@ fn print_help() {
     println!(" - Target network avec mise à jour périodique");
     println!(" - Sauvegarde automatique des modèles");
     println!(" - Statistiques d'entraînement détaillées");
-}
\ No newline at end of file
+}
diff --git a/bot/src/strategy/burn_dqn_agent.rs b/bot/src/strategy/burn_dqn_agent.rs
index 785e834..36ad5d6 100644
--- a/bot/src/strategy/burn_dqn_agent.rs
+++ b/bot/src/strategy/burn_dqn_agent.rs
@@ -1,12 +1,13 @@
+use burn::module::AutodiffModule;
+use burn::tensor::backend::AutodiffBackend;
 use burn::{
     backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
-    nn::{Linear, LinearConfig, loss::MseLoss},
     module::Module,
-    tensor::Tensor,
-    optim::{AdamConfig, Optimizer},
+    nn::{loss::MseLoss, Linear, LinearConfig},
+    optim::{GradientsParams, Optimizer},
     record::{CompactRecorder, Recorder},
+    tensor::Tensor,
 };
-use rand::Rng;
 use serde::{Deserialize, Serialize};
 use std::collections::VecDeque;

@@ -26,11 +27,16 @@ pub struct DqnNetwork {
 impl DqnNetwork {
     /// Crée un nouveau réseau DQN
-    pub fn new(input_size: usize, hidden_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        hidden_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
         let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
         let fc3 = LinearConfig::new(hidden_size, output_size).init(device);
-
+
         Self { fc1, fc2, fc3 }
     }

@@ -94,7 +100,6 @@ pub struct BurnDqnAgent {
     device: MyDevice,
     q_network: DqnNetwork,
     target_network: DqnNetwork,
-    optimizer: burn::optim::Adam,
     replay_buffer: VecDeque<Experience>,
     epsilon: f32,
     step_count: usize,
@@ -104,29 +109,26 @@ impl BurnDqnAgent {
     /// Crée un nouvel agent DQN
     pub fn new(config: DqnConfig) -> Self {
         let device = MyDevice::default();
-
+
         let q_network = DqnNetwork::new(
             config.state_size,
             config.hidden_size,
             config.action_size,
             &device,
         );
-
+
         let target_network = DqnNetwork::new(
             config.state_size,
             config.hidden_size,
             config.action_size,
             &device,
         );
-
-        let optimizer = AdamConfig::new().init();

         Self {
             config: config.clone(),
             device,
             q_network,
             target_network,
-            optimizer,
             replay_buffer: VecDeque::new(),
             epsilon: config.epsilon,
             step_count: 0,
@@ -146,23 +148,23 @@ impl BurnDqnAgent {
         }

         // Exploitation : choisir la meilleure action selon le Q-network
-        let state_tensor = Tensor::::from_floats([state], &self.device);
+        let state_tensor = Tensor::::from_floats(state, &self.device);
         let q_values = self.q_network.forward(state_tensor);
-
+
         // Convertir en vecteur pour traitement
-        let q_data = q_values.into_data().convert::<f32>().value;
-
+        let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();
+
         // Trouver la meilleure action parmi les actions valides
         let mut best_action = valid_actions[0];
         let mut best_q_value = f32::NEG_INFINITY;
-
+
         for &action in valid_actions {
             if action < q_data.len() && q_data[action] > best_q_value {
                 best_q_value = q_data[action];
                 best_action = action;
             }
         }
-
+
         best_action
     }

@@ -175,46 +177,51 @@ impl BurnDqnAgent {
     }

     /// Entraîne le réseau sur un batch d'expériences
-    pub fn train_step(&mut self) -> Option<f32> {
+    pub fn train_step>(
+        &mut self,
+        optimizer: &mut impl Optimizer,
+    ) -> Option<f32> {
         if self.replay_buffer.len() < self.config.batch_size {
             return None;
         }

         // Échantillonner un batch d'expériences
         let batch = self.sample_batch();
-
+
         // Préparer les tenseurs d'état
         let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect();
         let state_tensor = Tensor::::from_floats(states, &self.device);
-
+
         // Calculer les Q-values actuelles
         let current_q_values = self.q_network.forward(state_tensor);
-
+
         // Pour l'instant, version simplifiée sans calcul de target
         let target_q_values = current_q_values.clone();
-
+
         // Calculer la loss MSE
         let loss = MseLoss::new().forward(
-            current_q_values,
-            target_q_values,
-            burn::nn::loss::Reduction::Mean
+            current_q_values,
+            target_q_values,
+            burn::nn::loss::Reduction::Mean,
         );
-
+
         // Backpropagation (version simplifiée)
         let grads = loss.backward();
-        self.q_network = self.optimizer.step(self.config.learning_rate, self.q_network, grads);
-
+        // Gradients linked to each parameter of the model.
+        // let grads = GradientsParams::from_grads(grads, &self.q_network);
+        self.q_network = optimizer.step(self.config.learning_rate, self.q_network, grads);
+
         // Mise à jour du réseau cible
         self.step_count += 1;
         if self.step_count % self.config.target_update_freq == 0 {
             self.update_target_network();
         }
-
+
         // Décroissance d'epsilon
         if self.epsilon > self.config.epsilon_min {
             self.epsilon *= self.config.epsilon_decay;
         }
-
+
         Some(loss.into_scalar())
     }

@@ -222,14 +229,14 @@ impl BurnDqnAgent {
     fn sample_batch(&self) -> Vec<Experience> {
         let mut batch = Vec::new();
         let buffer_size = self.replay_buffer.len();
-
+
         for _ in 0..self.config.batch_size.min(buffer_size) {
             let index = rand::random::<usize>() % buffer_size;
             if let Some(exp) = self.replay_buffer.get(index) {
                 batch.push(exp.clone());
             }
         }
-
+
         batch
     }

@@ -245,25 +252,27 @@ impl BurnDqnAgent {
         let config_path = format!("{}_config.json", path);
         let config_json = serde_json::to_string_pretty(&self.config)?;
         std::fs::write(config_path, config_json)?;
-
+
         // Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend)
         let inference_network = self.q_network.clone().into_record();
         let recorder = CompactRecorder::new();
-
+
         let model_path = format!("{}_model.burn", path);
         recorder.record(inference_network, model_path.into())?;
-
+
         println!("Modèle sauvegardé : {}", path);
         Ok(())
     }

     /// Charge un modèle pour l'inférence
-    pub fn load_model_for_inference(path: &str) -> Result<(DqnNetwork, DqnConfig), Box<dyn std::error::Error>> {
+    pub fn load_model_for_inference(
+        path: &str,
+    ) -> Result<(DqnNetwork, DqnConfig), Box<dyn std::error::Error>> {
         // Charger la configuration
         let config_path = format!("{}_config.json", path);
         let config_json = std::fs::read_to_string(config_path)?;
         let config: DqnConfig = serde_json::from_str(&config_json)?;
-
+
         // Créer le réseau pour l'inférence
         let device = NdArrayDevice::default();
         let network = DqnNetwork::::new(
             config.state_size,
             config.hidden_size,
             config.action_size,
             &device,
         );
-
+
         // Charger les poids
         let model_path = format!("{}_model.burn", path);
         let recorder = CompactRecorder::new();
         let record = recorder.load(model_path.into(), &device)?;
         let network = network.load_record(record);
-
+
         Ok((network, config))
     }

@@ -291,4 +300,4 @@ impl BurnDqnAgent {
     pub fn get_buffer_size(&self) -> usize {
         self.replay_buffer.len()
     }
-}
\ No newline at end of file
+}
diff --git a/doc/refs/claudeAIquestionOnlyRust.md b/doc/refs/claudeAIquestionOnlyRust.md
index 9ed6496..ac81f7a 100644
--- a/doc/refs/claudeAIquestionOnlyRust.md
+++ b/doc/refs/claudeAIquestionOnlyRust.md
@@ -250,3 +250,19 @@
 claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
 claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
 Mais pourtant 2 millions indiqués dans la page usage : , et 7.88 dollars de consommés sur .
+
+I just had a claude code session in which I kept having this error, even if the agent didn't seem to read a lot of files : API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute.
+
+at the end of the session the token usage and cost indicated was this :
+
+Total cost: $0.95
+Total duration (API): 1h 24m 22.8s
+Total duration (wall): 1h 43m 3.5s
+Total code changes: 746 lines added, 0 lines removed
+Token usage by model:
+claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
+claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
+
+but the usage on the /usage page was 2,073,698 token in, and the cost on the /cost page was $7.90.
+
+When looking at the costs csv file, it seems that it is the "input cache write 5m" that consumed nearly all the tokens ( $7,71 ). Is it a bug ?