wip

parent cf93255f03
commit f05094b2d4
@@ -1,5 +1,5 @@
 use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
-use bot::strategy::burn_environment::{TrictracEnvironment, TrictracAction};
+use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
 use bot::strategy::dqn_common::get_valid_actions;
 use burn_rl::base::Environment;
 use std::env;
@@ -80,7 +80,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Configuration DQN
     let config = DqnConfig {
         state_size: 36,
-        action_size: 1000, // Espace d'actions réduit via contexte
+        action_size: 1252, // Espace d'actions réduit via contexte
         hidden_size: 256,
         learning_rate: 0.001,
         gamma: 0.99,
@@ -94,6 +94,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

     // Créer l'agent et l'environnement
     let mut agent = BurnDqnAgent::new(config);
+    let mut optimizer = AdamConfig::new().init();
+
     let mut env = TrictracEnvironment::new(true);

     // Variables pour les statistiques
@@ -114,7 +116,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

         loop {
             step += 1;
-            let current_state = snapshot.state;
+            let current_state = snapshot.state();

             // Obtenir les actions valides selon le contexte du jeu
             let valid_actions = get_valid_actions(&env.game);
@@ -127,22 +129,31 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();

             // Sélectionner une action avec l'agent DQN
-            let action_index = agent.select_action(&current_state.data.iter().map(|&x| x as f32).collect::<Vec<_>>(), &valid_indices);
-            let action = TrictracAction { index: action_index as u32 };
+            let action_index = agent.select_action(
+                &current_state
+                    .data
+                    .iter()
+                    .map(|&x| x as f32)
+                    .collect::<Vec<_>>(),
+                &valid_indices,
+            );
+            let action = TrictracAction {
+                index: action_index as u32,
+            };

             // Exécuter l'action
             snapshot = env.step(action);
-            episode_reward += snapshot.reward;
+            episode_reward += snapshot.reward();

             // Préparer l'expérience pour l'agent
             let experience = Experience {
                 state: current_state.data.iter().map(|&x| x as f32).collect(),
                 action: action_index,
-                reward: snapshot.reward,
+                reward: snapshot.reward(),
                 next_state: if snapshot.terminated {
                     None
                 } else {
-                    Some(snapshot.state.data.iter().map(|&x| x as f32).collect())
+                    Some(snapshot.state().data.iter().map(|&x| x as f32).collect())
                 },
                 done: snapshot.terminated,
             };
@@ -151,7 +162,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             agent.add_experience(experience);

             // Entraîner l'agent
-            if let Some(loss) = agent.train_step() {
+            if let Some(loss) = agent.train_step(optimizer) {
                 episode_loss += loss;
                 loss_count += 1;
             }
@@ -163,7 +174,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         }

         // Calculer la loss moyenne de l'épisode
-        let avg_loss = if loss_count > 0 { episode_loss / loss_count as f32 } else { 0.0 };
+        let avg_loss = if loss_count > 0 {
+            episode_loss / loss_count as f32
+        } else {
+            0.0
+        };

         // Sauvegarder les statistiques
         total_rewards.push(episode_reward);
@@ -172,9 +187,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

         // Affichage des statistiques
         if episode % save_every == 0 {
-            let avg_reward = total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
-            let avg_length = episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
-            let avg_episode_loss = losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
+            let avg_reward =
+                total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
+            let avg_length =
+                episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
+            let avg_episode_loss =
+                losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;

             println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
                 episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());
@@ -187,8 +205,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 println!(" → Modèle sauvegardé : {}", checkpoint_path);
             }
         } else if episode % 10 == 0 {
-            println!("Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
-                episode, episode_reward, step, avg_loss, agent.get_epsilon());
+            println!(
+                "Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
+                episode,
+                episode_reward,
+                step,
+                avg_loss,
+                agent.get_epsilon()
+            );
         }
     }

@@ -199,18 +223,41 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Statistiques finales
     println!();
     println!("=== Résultats de l'entraînement ===");
-    let final_avg_reward = total_rewards.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
-    let final_avg_length = episode_lengths.iter().rev().take(100.min(episodes)).sum::<usize>() / 100.min(episodes);
-    let final_avg_loss = losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
+    let final_avg_reward = total_rewards
+        .iter()
+        .rev()
+        .take(100.min(episodes))
+        .sum::<f32>()
+        / 100.min(episodes) as f32;
+    let final_avg_length = episode_lengths
+        .iter()
+        .rev()
+        .take(100.min(episodes))
+        .sum::<usize>()
+        / 100.min(episodes);
+    let final_avg_loss =
+        losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;

-    println!("Récompense moyenne (100 derniers épisodes) : {:.3}", final_avg_reward);
-    println!("Longueur moyenne (100 derniers épisodes) : {}", final_avg_length);
-    println!("Loss moyenne (100 derniers épisodes) : {:.6}", final_avg_loss);
+    println!(
+        "Récompense moyenne (100 derniers épisodes) : {:.3}",
+        final_avg_reward
+    );
+    println!(
+        "Longueur moyenne (100 derniers épisodes) : {}",
+        final_avg_length
+    );
+    println!(
+        "Loss moyenne (100 derniers épisodes) : {:.6}",
+        final_avg_loss
+    );
     println!("Epsilon final : {:.3}", agent.get_epsilon());
     println!("Taille du buffer final : {}", agent.get_buffer_size());

     // Statistiques globales
-    let max_reward = total_rewards.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
+    let max_reward = total_rewards
+        .iter()
+        .cloned()
+        .fold(f32::NEG_INFINITY, f32::max);
     let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
     println!("Récompense max : {:.3}", max_reward);
     println!("Récompense min : {:.3}", min_reward);
@@ -220,7 +267,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     println!("Modèle final sauvegardé : {}", final_path);
     println!();
    println!("Pour utiliser le modèle entraîné :");
-    println!(" cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy", model_path);
+    println!(
+        " cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy",
+        model_path
+    );

     Ok(())
 }
@@ -1,12 +1,13 @@
+use burn::module::AutodiffModule;
+use burn::tensor::backend::AutodiffBackend;
 use burn::{
     backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
-    nn::{Linear, LinearConfig, loss::MseLoss},
     module::Module,
-    tensor::Tensor,
-    optim::{AdamConfig, Optimizer},
+    nn::{loss::MseLoss, Linear, LinearConfig},
+    optim::{GradientsParams, Optimizer},
     record::{CompactRecorder, Recorder},
+    tensor::Tensor,
 };
-use rand::Rng;
 use serde::{Deserialize, Serialize};
 use std::collections::VecDeque;

@@ -26,7 +27,12 @@ pub struct DqnNetwork<B: burn::prelude::Backend> {

 impl<B: burn::prelude::Backend> DqnNetwork<B> {
     /// Crée un nouveau réseau DQN
-    pub fn new(input_size: usize, hidden_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        hidden_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
         let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
         let fc3 = LinearConfig::new(hidden_size, output_size).init(device);
@@ -94,7 +100,6 @@ pub struct BurnDqnAgent {
     device: MyDevice,
     q_network: DqnNetwork<MyBackend>,
     target_network: DqnNetwork<MyBackend>,
-    optimizer: burn::optim::Adam<MyBackend>,
     replay_buffer: VecDeque<Experience>,
     epsilon: f32,
     step_count: usize,
@@ -119,14 +124,11 @@ impl BurnDqnAgent {
             &device,
         );

-        let optimizer = AdamConfig::new().init();
-
         Self {
             config: config.clone(),
             device,
             q_network,
             target_network,
-            optimizer,
             replay_buffer: VecDeque::new(),
             epsilon: config.epsilon,
             step_count: 0,
@@ -146,11 +148,11 @@ impl BurnDqnAgent {
         }

         // Exploitation : choisir la meilleure action selon le Q-network
-        let state_tensor = Tensor::<MyBackend, 2>::from_floats([state], &self.device);
+        let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device);
         let q_values = self.q_network.forward(state_tensor);

         // Convertir en vecteur pour traitement
-        let q_data = q_values.into_data().convert::<f32>().value;
+        let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();

         // Trouver la meilleure action parmi les actions valides
         let mut best_action = valid_actions[0];
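Aside: a minimal, self-contained sketch of the conversion this hunk switches to (`into_data().convert::<f32>().into_vec()`) together with the masked argmax that `select_action` then performs over the valid action indices. The Q-values and indices below are made up for illustration, and the exact `TensorData` helpers can vary between burn releases:

```rust
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::tensor::Tensor;

fn main() {
    let device = NdArrayDevice::default();

    // A fake [1, action_size] row of Q-values standing in for q_network.forward(state).
    let q_values = Tensor::<NdArray, 2>::from_floats([[0.1f32, 0.7, -0.3, 0.2]], &device);

    // TensorData -> Vec<f32>, the conversion used by the updated select_action.
    let q_data: Vec<f32> = q_values.into_data().convert::<f32>().into_vec().unwrap();

    // Greedy choice restricted to the valid action indices (here: 0, 2 and 3).
    let valid_indices = [0usize, 2, 3];
    let best_action = valid_indices
        .iter()
        .copied()
        .max_by(|&a, &b| q_data[a].partial_cmp(&q_data[b]).unwrap())
        .unwrap();

    // Index 1 has the highest Q-value overall but is not valid, so 3 wins.
    assert_eq!(best_action, 3);
    println!("best valid action: {best_action}");
}
```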
@@ -175,7 +177,10 @@ impl BurnDqnAgent {
     }

     /// Entraîne le réseau sur un batch d'expériences
-    pub fn train_step(&mut self) -> Option<f32> {
+    pub fn train_step<B: AutodiffBackend, M: AutodiffModule<B>>(
+        &mut self,
+        optimizer: &mut impl Optimizer<M, B>,
+    ) -> Option<f32> {
         if self.replay_buffer.len() < self.config.batch_size {
             return None;
         }
@@ -197,12 +202,14 @@ impl BurnDqnAgent {
         let loss = MseLoss::new().forward(
             current_q_values,
             target_q_values,
-            burn::nn::loss::Reduction::Mean
+            burn::nn::loss::Reduction::Mean,
         );

         // Backpropagation (version simplifiée)
         let grads = loss.backward();
-        self.q_network = self.optimizer.step(self.config.learning_rate, self.q_network, grads);
+        // Gradients linked to each parameter of the model.
+        // let grads = GradientsParams::from_grads(grads, &self.q_network);
+        self.q_network = optimizer.step(self.config.learning_rate, self.q_network, grads);

         // Mise à jour du réseau cible
         self.step_count += 1;
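For context on the commented-out `GradientsParams` line above: in burn, the raw gradients returned by `backward()` are normally wrapped into `GradientsParams` before being handed to `Optimizer::step`, which consumes the module and returns the updated one. A minimal, self-contained sketch of that pattern with an Adam optimizer follows; `TinyNet` and its single `fc` layer are made up for illustration, and exact signatures may differ between burn versions:

```rust
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::Module;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::tensor::{backend::Backend, Tensor};

type MyAutodiffBackend = Autodiff<NdArray>;

#[derive(Module, Debug)]
struct TinyNet<B: Backend> {
    fc: Linear<B>,
}

impl<B: Backend> TinyNet<B> {
    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        self.fc.forward(x)
    }
}

fn main() {
    let device = NdArrayDevice::default();
    let mut model = TinyNet::<MyAutodiffBackend> {
        fc: LinearConfig::new(4, 2).init(&device),
    };
    // The optimizer's module/backend types are inferred from the `step` call below.
    let mut optimizer = AdamConfig::new().init();

    let x = Tensor::<MyAutodiffBackend, 2>::zeros([1, 4], &device);
    let target = Tensor::<MyAutodiffBackend, 2>::ones([1, 2], &device);

    let output = model.forward(x);
    let loss = MseLoss::new().forward(output, target, Reduction::Mean);

    // Backward pass, then tie the raw gradients to the model's parameters.
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);

    // One Adam update; `step` consumes the module and returns the updated one.
    model = optimizer.step(1.0e-3, model, grads);
    println!("performed one step on {} parameters", model.num_params());
}
```

The `1.0e-3` learning rate here simply mirrors the `learning_rate: 0.001` set in the training binary above.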
@@ -258,7 +265,9 @@ impl BurnDqnAgent {
     }

     /// Charge un modèle pour l'inférence
-    pub fn load_model_for_inference(path: &str) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
+    pub fn load_model_for_inference(
+        path: &str,
+    ) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
         // Charger la configuration
         let config_path = format!("{}_config.json", path);
         let config_json = std::fs::read_to_string(config_path)?;
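Side note on the configuration half of this loader: the `{path}_config.json` read just above is presumably turned back into a `DqnConfig` with serde. A minimal sketch of that step, using a stand-in struct limited to the fields visible in this diff (field types are assumptions) and `serde_json` as the assumed JSON backend:

```rust
use serde::{Deserialize, Serialize};

// Hypothetical stand-in for DqnConfig: only the fields shown in this diff.
#[derive(Serialize, Deserialize, Debug)]
struct DqnConfig {
    state_size: usize,
    action_size: usize,
    hidden_size: usize,
    learning_rate: f64,
    gamma: f64,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Stands in for std::fs::read_to_string(format!("{}_config.json", path))?.
    let config_json = r#"{
        "state_size": 36,
        "action_size": 1252,
        "hidden_size": 256,
        "learning_rate": 0.001,
        "gamma": 0.99
    }"#;

    // Deserialize the saved configuration before rebuilding the network.
    let config: DqnConfig = serde_json::from_str(config_json)?;
    println!("{config:?}");
    Ok(())
}
```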
@@ -250,3 +250,19 @@ claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
 claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write

 Mais pourtant 2 millions indiqués dans la page usage : <https://console.anthropic.com/usage>, et 7.88 dollars de consommés sur <https://console.anthropic.com/cost>.
+
+I just had a Claude Code session in which I kept getting this error, even though the agent didn't seem to read a lot of files: API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute.
+
+At the end of the session, the token usage and cost indicated were:
+
+Total cost: $0.95
+Total duration (API): 1h 24m 22.8s
+Total duration (wall): 1h 43m 3.5s
+Total code changes: 746 lines added, 0 lines removed
+Token usage by model:
+claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
+claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
+
+But the usage on the /usage page was 2,073,698 tokens in, and the cost on the /cost page was $7.90.
+
+When looking at the costs CSV file, it seems that it is the "input cache write 5m" that consumed nearly all the tokens ($7.71). Is it a bug?