claude (dqn_rs trainer simplified, compilation still fails)

This commit is contained in:
Henri Bourcereau 2025-06-22 16:23:38 +02:00
parent 16dd4fbf68
commit 80734990eb
3 changed files with 93 additions and 17 deletions

Cargo.toml

@@ -13,6 +13,10 @@ path = "src/bin/train_dqn.rs"
name = "train_burn_dqn"
path = "src/bin/train_burn_dqn.rs"
[[bin]]
name = "simple_burn_train"
path = "src/bin/simple_burn_train.rs"
[dependencies]
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
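Once the crate compiles, the binary registered above can be run with `cargo run --bin simple_burn_train`, as Cargo does for any `[[bin]]` target.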

src/bin/simple_burn_train.rs (new file)

@@ -0,0 +1,83 @@
use bot::strategy::burn_dqn::{BurnDqnAgent, BurnDqnConfig, Experience};
use rand::Rng;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
println!("Entraînement DQN simplifié avec Burn");
// Configuration DQN simple
let config = BurnDqnConfig {
state_size: 10,
action_size: 4,
hidden_size: 64,
learning_rate: 0.001,
gamma: 0.99,
epsilon: 1.0,
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 1000,
batch_size: 16,
target_update_freq: 50,
};
let mut agent = BurnDqnAgent::new(config);
let mut rng = rand::thread_rng();
println!("Début de l'entraînement simple...");
for episode in 1..=100 {
let mut total_reward = 0.0;
for step in 1..=50 {
// Simple random state
let state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();
// Valid actions (all actions, to keep it simple)
let valid_actions: Vec<usize> = vec![0, 1, 2, 3];
// Select an action
let action = agent.select_action(&state, &valid_actions);
// Simulated reward
let reward = rng.gen::<f32>() - 0.5; // Reward in [-0.5, 0.5)
// Random next state
let next_state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();
// Random episode termination
let done = step >= 50 || rng.gen::<f32>() < 0.1;
// Store the experience
let experience = Experience {
state: state.clone(),
action,
reward,
next_state: if done { None } else { Some(next_state) },
done,
};
agent.add_experience(experience);
// Run one training step
if let Some(loss) = agent.train_step() {
if step % 25 == 0 {
println!("Episode {}, Step {}, Loss: {:.4}, Epsilon: {:.3}",
episode, step, loss, agent.get_epsilon());
}
}
total_reward += reward;
if done {
break;
}
}
if episode % 10 == 0 {
println!("Episode {} terminé. Récompense totale: {:.2}", episode, total_reward);
}
}
println!("Entraînement terminé !");
Ok(())
}
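Note that states, rewards, and episode termination in this binary are all drawn at random, so the loop only exercises the agent's replay buffer and training path; the reported loss and reward values are not expected to show any actual learning.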

src/strategy/burn_dqn.rs

@@ -117,9 +117,7 @@ impl BurnDqnAgent {
&device,
);
let optimizer = AdamConfig::new()
.with_learning_rate(config.learning_rate)
.init();
let optimizer = AdamConfig::new().init();
Self {
config: config.clone(),
@@ -147,10 +145,9 @@ impl BurnDqnAgent {
}
// Exploitation: pick the best action according to the Q-network
let state_tensor = Tensor::<MyBackend, 2>::from_data(
burn::tensor::Data::new(state.to_vec(), burn::tensor::Shape::new([1, state.len()])),
&self.device
);
// Use from_floats with a 2D vector for Burn 0.17
let state_2d = vec![state.to_vec()];
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state_2d, &self.device);
let q_values = self.q_network.forward(state_tensor);
let q_data = q_values.into_data().to_vec::<f32>().unwrap();
@@ -195,17 +192,9 @@
.collect();
// Convert to a Burn-compatible format
let state_data: Vec<f32> = states.into_iter().flatten().collect();
let state_tensor = Tensor::<MyBackend, 2>::from_data(
burn::tensor::Data::new(state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
&self.device
);
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);
let next_state_tensor = if !next_states.is_empty() {
let next_state_data: Vec<f32> = next_states.into_iter().flatten().collect();
Some(Tensor::<MyBackend, 2>::from_data(
burn::tensor::Data::new(next_state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
&self.device
))
Some(Tensor::<MyBackend, 2>::from_floats(next_states, &self.device))
} else {
None
};
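For readers comparing the two tensor-construction paths in these hunks: the removed `from_data` calls fed a row-major flattened buffer plus an explicit `[batch, state_size]` shape, while the replacement passes the nested `Vec<Vec<f32>>` straight to `from_floats`. A dependency-free sketch of the row-major flattening the old path relied on (illustrative only, not part of the commit, no Burn types):

// Illustrative sketch: row-major flattening of a batch of states.
// Element (row, col) of a [batch, state_size] tensor ends up at index row * state_size + col.
fn flatten_batch(states: Vec<Vec<f32>>) -> Vec<f32> {
    states.into_iter().flatten().collect()
}

fn main() {
    let batch = vec![vec![1.0_f32, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let flat = flatten_batch(batch);
    assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    // With shape [2, 3], element (1, 0) sits at flat[1 * 3 + 0] == 4.0.
}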