Henri Bourcereau 2025-06-28 21:34:44 +02:00
parent cf93255f03
commit f05094b2d4
3 changed files with 150 additions and 75 deletions

View file

@@ -1,5 +1,5 @@
 use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
-use bot::strategy::burn_environment::{TrictracEnvironment, TrictracAction};
+use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
 use bot::strategy::dqn_common::get_valid_actions;
 use burn_rl::base::Environment;
 use std::env;
@@ -80,7 +80,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Configuration DQN
     let config = DqnConfig {
         state_size: 36,
-        action_size: 1000, // Espace d'actions réduit via contexte
+        action_size: 1252, // Espace d'actions réduit via contexte
         hidden_size: 256,
         learning_rate: 0.001,
         gamma: 0.99,
@@ -94,6 +94,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Créer l'agent et l'environnement
     let mut agent = BurnDqnAgent::new(config);
+    let mut optimizer = AdamConfig::new().init();
+
     let mut env = TrictracEnvironment::new(true);

     // Variables pour les statistiques
@@ -114,35 +116,44 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         loop {
             step += 1;
-            let current_state = snapshot.state;
+            let current_state = snapshot.state();

             // Obtenir les actions valides selon le contexte du jeu
             let valid_actions = get_valid_actions(&env.game);
             if valid_actions.is_empty() {
                 break;
             }

             // Convertir les actions Trictrac en indices pour l'agent
             let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();

             // Sélectionner une action avec l'agent DQN
-            let action_index = agent.select_action(&current_state.data.iter().map(|&x| x as f32).collect::<Vec<_>>(), &valid_indices);
-            let action = TrictracAction { index: action_index as u32 };
+            let action_index = agent.select_action(
+                &current_state
+                    .data
+                    .iter()
+                    .map(|&x| x as f32)
+                    .collect::<Vec<_>>(),
+                &valid_indices,
+            );
+            let action = TrictracAction {
+                index: action_index as u32,
+            };

             // Exécuter l'action
             snapshot = env.step(action);
-            episode_reward += snapshot.reward;
+            episode_reward += snapshot.reward();

             // Préparer l'expérience pour l'agent
             let experience = Experience {
                 state: current_state.data.iter().map(|&x| x as f32).collect(),
                 action: action_index,
-                reward: snapshot.reward,
+                reward: snapshot.reward(),
                 next_state: if snapshot.terminated {
                     None
                 } else {
-                    Some(snapshot.state.data.iter().map(|&x| x as f32).collect())
+                    Some(snapshot.state().data.iter().map(|&x| x as f32).collect())
                 },
                 done: snapshot.terminated,
             };
@@ -151,7 +162,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             agent.add_experience(experience);

             // Entraîner l'agent
-            if let Some(loss) = agent.train_step() {
+            if let Some(loss) = agent.train_step(optimizer) {
                 episode_loss += loss;
                 loss_count += 1;
             }
@@ -163,7 +174,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         }

         // Calculer la loss moyenne de l'épisode
-        let avg_loss = if loss_count > 0 { episode_loss / loss_count as f32 } else { 0.0 };
+        let avg_loss = if loss_count > 0 {
+            episode_loss / loss_count as f32
+        } else {
+            0.0
+        };

         // Sauvegarder les statistiques
         total_rewards.push(episode_reward);
@@ -172,13 +187,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         // Affichage des statistiques
         if episode % save_every == 0 {
-            let avg_reward = total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
-            let avg_length = episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
-            let avg_episode_loss = losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
+            let avg_reward =
+                total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
+            let avg_length =
+                episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
+            let avg_episode_loss =
+                losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;

             println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
                 episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());

             // Sauvegarder le modèle
             let checkpoint_path = format!("{}_{}", model_path, episode);
             if let Err(e) = agent.save_model(&checkpoint_path) {
@@ -187,8 +205,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 println!(" → Modèle sauvegardé : {}", checkpoint_path);
             }
         } else if episode % 10 == 0 {
-            println!("Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
-                episode, episode_reward, step, avg_loss, agent.get_epsilon());
+            println!(
+                "Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
+                episode,
+                episode_reward,
+                step,
+                avg_loss,
+                agent.get_epsilon()
+            );
         }
     }
@@ -199,28 +223,54 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Statistiques finales
     println!();
     println!("=== Résultats de l'entraînement ===");
-    let final_avg_reward = total_rewards.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
-    let final_avg_length = episode_lengths.iter().rev().take(100.min(episodes)).sum::<usize>() / 100.min(episodes);
-    let final_avg_loss = losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
-
-    println!("Récompense moyenne (100 derniers épisodes) : {:.3}", final_avg_reward);
-    println!("Longueur moyenne (100 derniers épisodes) : {}", final_avg_length);
-    println!("Loss moyenne (100 derniers épisodes) : {:.6}", final_avg_loss);
+    let final_avg_reward = total_rewards
+        .iter()
+        .rev()
+        .take(100.min(episodes))
+        .sum::<f32>()
+        / 100.min(episodes) as f32;
+    let final_avg_length = episode_lengths
+        .iter()
+        .rev()
+        .take(100.min(episodes))
+        .sum::<usize>()
+        / 100.min(episodes);
+    let final_avg_loss =
+        losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
+
+    println!(
+        "Récompense moyenne (100 derniers épisodes) : {:.3}",
+        final_avg_reward
+    );
+    println!(
+        "Longueur moyenne (100 derniers épisodes) : {}",
+        final_avg_length
+    );
+    println!(
+        "Loss moyenne (100 derniers épisodes) : {:.6}",
+        final_avg_loss
+    );
     println!("Epsilon final : {:.3}", agent.get_epsilon());
     println!("Taille du buffer final : {}", agent.get_buffer_size());

     // Statistiques globales
-    let max_reward = total_rewards.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
+    let max_reward = total_rewards
+        .iter()
+        .cloned()
+        .fold(f32::NEG_INFINITY, f32::max);
     let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
     println!("Récompense max : {:.3}", max_reward);
     println!("Récompense min : {:.3}", min_reward);

     println!();
     println!("Entraînement terminé avec succès !");
     println!("Modèle final sauvegardé : {}", final_path);
     println!();
     println!("Pour utiliser le modèle entraîné :");
-    println!(" cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy", model_path);
+    println!(
+        " cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy",
+        model_path
+    );

     Ok(())
 }
@@ -250,4 +300,4 @@ fn print_help() {
     println!(" - Target network avec mise à jour périodique");
     println!(" - Sauvegarde automatique des modèles");
     println!(" - Statistiques d'entraînement détaillées");
 }

View file

@@ -1,12 +1,13 @@
+use burn::module::AutodiffModule;
+use burn::tensor::backend::AutodiffBackend;
 use burn::{
     backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
-    nn::{Linear, LinearConfig, loss::MseLoss},
     module::Module,
-    tensor::Tensor,
-    optim::{AdamConfig, Optimizer},
+    nn::{loss::MseLoss, Linear, LinearConfig},
+    optim::{GradientsParams, Optimizer},
     record::{CompactRecorder, Recorder},
+    tensor::Tensor,
 };
-use rand::Rng;
 use serde::{Deserialize, Serialize};
 use std::collections::VecDeque;
@@ -26,11 +27,16 @@ pub struct DqnNetwork<B: burn::prelude::Backend> {
 impl<B: burn::prelude::Backend> DqnNetwork<B> {
     /// Crée un nouveau réseau DQN
-    pub fn new(input_size: usize, hidden_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        hidden_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
         let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
         let fc3 = LinearConfig::new(hidden_size, output_size).init(device);

         Self { fc1, fc2, fc3 }
     }
@@ -94,7 +100,6 @@ pub struct BurnDqnAgent {
     device: MyDevice,
     q_network: DqnNetwork<MyBackend>,
     target_network: DqnNetwork<MyBackend>,
-    optimizer: burn::optim::Adam<MyBackend>,
     replay_buffer: VecDeque<Experience>,
     epsilon: f32,
     step_count: usize,
@@ -104,29 +109,26 @@ impl BurnDqnAgent {
     /// Crée un nouvel agent DQN
     pub fn new(config: DqnConfig) -> Self {
         let device = MyDevice::default();

         let q_network = DqnNetwork::new(
             config.state_size,
             config.hidden_size,
             config.action_size,
             &device,
         );

         let target_network = DqnNetwork::new(
             config.state_size,
             config.hidden_size,
             config.action_size,
             &device,
         );

-        let optimizer = AdamConfig::new().init();
-
         Self {
             config: config.clone(),
             device,
             q_network,
             target_network,
-            optimizer,
             replay_buffer: VecDeque::new(),
             epsilon: config.epsilon,
             step_count: 0,
@@ -146,23 +148,23 @@ impl BurnDqnAgent {
         }

         // Exploitation : choisir la meilleure action selon le Q-network
-        let state_tensor = Tensor::<MyBackend, 2>::from_floats([state], &self.device);
+        let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device);
         let q_values = self.q_network.forward(state_tensor);

         // Convertir en vecteur pour traitement
-        let q_data = q_values.into_data().convert::<f32>().value;
+        let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();

         // Trouver la meilleure action parmi les actions valides
         let mut best_action = valid_actions[0];
         let mut best_q_value = f32::NEG_INFINITY;

         for &action in valid_actions {
             if action < q_data.len() && q_data[action] > best_q_value {
                 best_q_value = q_data[action];
                 best_action = action;
             }
         }

         best_action
     }
@@ -175,46 +177,51 @@ impl BurnDqnAgent {
     }

     /// Entraîne le réseau sur un batch d'expériences
-    pub fn train_step(&mut self) -> Option<f32> {
+    pub fn train_step<B: AutodiffBackend, M: AutodiffModule<B>>(
+        &mut self,
+        optimizer: &mut impl Optimizer<M, B>,
+    ) -> Option<f32> {
         if self.replay_buffer.len() < self.config.batch_size {
             return None;
         }

         // Échantillonner un batch d'expériences
         let batch = self.sample_batch();

         // Préparer les tenseurs d'état
         let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect();
         let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);

         // Calculer les Q-values actuelles
         let current_q_values = self.q_network.forward(state_tensor);

         // Pour l'instant, version simplifiée sans calcul de target
         let target_q_values = current_q_values.clone();

         // Calculer la loss MSE
         let loss = MseLoss::new().forward(
             current_q_values,
             target_q_values,
-            burn::nn::loss::Reduction::Mean
+            burn::nn::loss::Reduction::Mean,
         );

         // Backpropagation (version simplifiée)
         let grads = loss.backward();
-        self.q_network = self.optimizer.step(self.config.learning_rate, self.q_network, grads);
+        // Gradients linked to each parameter of the model.
+        // let grads = GradientsParams::from_grads(grads, &self.q_network);
+        self.q_network = optimizer.step(self.config.learning_rate, self.q_network, grads);

         // Mise à jour du réseau cible
         self.step_count += 1;
         if self.step_count % self.config.target_update_freq == 0 {
             self.update_target_network();
         }

         // Décroissance d'epsilon
         if self.epsilon > self.config.epsilon_min {
             self.epsilon *= self.config.epsilon_decay;
         }

         Some(loss.into_scalar())
     }
@@ -222,14 +229,14 @@ impl BurnDqnAgent {
     fn sample_batch(&self) -> Vec<Experience> {
         let mut batch = Vec::new();
         let buffer_size = self.replay_buffer.len();

         for _ in 0..self.config.batch_size.min(buffer_size) {
             let index = rand::random::<usize>() % buffer_size;
             if let Some(exp) = self.replay_buffer.get(index) {
                 batch.push(exp.clone());
             }
         }

         batch
     }
@@ -245,25 +252,27 @@ impl BurnDqnAgent {
         let config_path = format!("{}_config.json", path);
         let config_json = serde_json::to_string_pretty(&self.config)?;
         std::fs::write(config_path, config_json)?;

         // Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend)
         let inference_network = self.q_network.clone().into_record();
         let recorder = CompactRecorder::new();
         let model_path = format!("{}_model.burn", path);
         recorder.record(inference_network, model_path.into())?;

         println!("Modèle sauvegardé : {}", path);
         Ok(())
     }

     /// Charge un modèle pour l'inférence
-    pub fn load_model_for_inference(path: &str) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
+    pub fn load_model_for_inference(
+        path: &str,
+    ) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
         // Charger la configuration
         let config_path = format!("{}_config.json", path);
         let config_json = std::fs::read_to_string(config_path)?;
         let config: DqnConfig = serde_json::from_str(&config_json)?;

         // Créer le réseau pour l'inférence
         let device = NdArrayDevice::default();
         let network = DqnNetwork::<InferenceBackend>::new(
@@ -272,13 +281,13 @@ impl BurnDqnAgent {
             config.action_size,
             &device,
         );

         // Charger les poids
         let model_path = format!("{}_model.burn", path);
         let recorder = CompactRecorder::new();
         let record = recorder.load(model_path.into(), &device)?;
         let network = network.load_record(record);

         Ok((network, config))
     }
@@ -291,4 +300,4 @@ impl BurnDqnAgent {
     pub fn get_buffer_size(&self) -> usize {
         self.replay_buffer.len()
     }
 }

View file

@@ -250,3 +250,19 @@
 claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
 claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
 Mais pourtant 2 millions indiqués dans la page usage : <https://console.anthropic.com/usage>, et 7.88 dollars de consommés sur <https://console.anthropic.com/cost>.
+
+I just had a claude code session in which I kept having this error, even if the agent didn't seem to read a lot of files: API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute.
+
+At the end of the session, the token usage and cost indicated was:
+
+Total cost: $0.95
+Total duration (API): 1h 24m 22.8s
+Total duration (wall): 1h 43m 3.5s
+Total code changes: 746 lines added, 0 lines removed
+Token usage by model:
+    claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
+    claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
+
+But the usage on the /usage page was 2,073,698 tokens in, and the cost on the /cost page was $7.90.
+
+When looking at the costs CSV file, it seems that the "input cache write 5m" item consumed nearly all of the tokens ($7.71). Is it a bug?
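
To double-check that reading of the export, a small script that aggregates the costs CSV by line item can help. This is only a sketch under assumptions: the file name (anthropic_costs.csv) and the column positions (item label in the 4th field, USD amount in the last) are hypothetical, since the exact export layout isn't shown here; adjust them to match the real CSV from console.anthropic.com.

    // Sum the exported cost CSV per line item to see which one
    // (e.g. "input cache write 5m") dominates the bill.
    use std::collections::HashMap;
    use std::fs;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Hypothetical file name for the export downloaded from the console.
        let csv = fs::read_to_string("anthropic_costs.csv")?;
        let mut totals: HashMap<String, f64> = HashMap::new();

        for line in csv.lines().skip(1) {
            // Naive comma split: fine as long as the fields themselves contain no commas.
            let fields: Vec<&str> = line.split(',').collect();
            if fields.len() < 4 {
                continue;
            }
            let item = fields[3].trim().to_string(); // assumed "cost type" column
            let usd: f64 = fields[fields.len() - 1].trim().parse().unwrap_or(0.0);
            *totals.entry(item).or_insert(0.0) += usd;
        }

        // Print the line items from most to least expensive.
        let mut sorted: Vec<_> = totals.into_iter().collect();
        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        for (item, usd) in sorted {
            println!("{item}: ${usd:.2}");
        }
        Ok(())
    }

If the "input cache write 5m" row really accounts for ~$7.71 of the ~$7.90 total, that would confirm the gap comes from prompt-cache writes rather than plain input tokens, which the session summary reports separately.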