claude (simplified dqn_rs trainer, compilation still fails)

This commit is contained in:
parent 16dd4fbf68
commit 80734990eb
bot/Cargo.toml
@@ -13,6 +13,10 @@ path = "src/bin/train_dqn.rs"
 name = "train_burn_dqn"
 path = "src/bin/train_burn_dqn.rs"
 
+[[bin]]
+name = "simple_burn_train"
+path = "src/bin/simple_burn_train.rs"
+
 [dependencies]
 pretty_assertions = "1.4.0"
 serde = { version = "1.0", features = ["derive"] }
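With the new [[bin]] target declared above, the trainer can be launched from the bot crate once the crate builds (the commit title notes compilation still fails):

    cargo run --bin simple_burn_train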
bot/src/bin/simple_burn_train.rs (new file, 83 lines)
@@ -0,0 +1,83 @@
+use bot::strategy::burn_dqn::{BurnDqnAgent, BurnDqnConfig, Experience};
+use rand::Rng;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    env_logger::init();
+
+    println!("Simplified DQN training with Burn");
+
+    // Simple DQN configuration
+    let config = BurnDqnConfig {
+        state_size: 10,
+        action_size: 4,
+        hidden_size: 64,
+        learning_rate: 0.001,
+        gamma: 0.99,
+        epsilon: 1.0,
+        epsilon_decay: 0.995,
+        epsilon_min: 0.01,
+        replay_buffer_size: 1000,
+        batch_size: 16,
+        target_update_freq: 50,
+    };
+
+    let mut agent = BurnDqnAgent::new(config);
+    let mut rng = rand::thread_rng();
+
+    println!("Starting simple training...");
+
+    for episode in 1..=100 {
+        let mut total_reward = 0.0;
+
+        for step in 1..=50 {
+            // Simple random state
+            let state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();
+
+            // Valid actions (all of them, to keep things simple)
+            let valid_actions: Vec<usize> = vec![0, 1, 2, 3];
+
+            // Select an action
+            let action = agent.select_action(&state, &valid_actions);
+
+            // Simulated reward
+            let reward = rng.gen::<f32>() - 0.5; // Reward in [-0.5, 0.5)
+
+            // Random next state
+            let next_state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();
+
+            // Random episode termination
+            let done = step >= 50 || rng.gen::<f32>() < 0.1;
+
+            // Add the experience
+            let experience = Experience {
+                state: state.clone(),
+                action,
+                reward,
+                next_state: if done { None } else { Some(next_state) },
+                done,
+            };
+            agent.add_experience(experience);
+
+            // Train
+            if let Some(loss) = agent.train_step() {
+                if step % 25 == 0 {
+                    println!("Episode {}, Step {}, Loss: {:.4}, Epsilon: {:.3}",
+                        episode, step, loss, agent.get_epsilon());
+                }
+            }
+
+            total_reward += reward;
+
+            if done {
+                break;
+            }
+        }
+
+        if episode % 10 == 0 {
+            println!("Episode {} finished. Total reward: {:.2}", episode, total_reward);
+        }
+    }
+
+    println!("Training finished!");
+    Ok(())
+}
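The trainer above builds Experience values field by field; for readers without the burn_dqn module at hand, the shape it implies is sketched below. This is inferred purely from the usage in simple_burn_train.rs, not copied from bot/src/strategy/burn_dqn.rs, whose real definition may differ.

    // Inferred from the initializer in simple_burn_train.rs; illustrative only.
    pub struct Experience {
        pub state: Vec<f32>,              // observation, length = state_size
        pub action: usize,                // index into the valid-action set
        pub reward: f32,                  // immediate reward for the transition
        pub next_state: Option<Vec<f32>>, // None on terminal transitions
        pub done: bool,                   // whether the episode ended on this step
    }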
bot/src/strategy/burn_dqn.rs
@@ -117,9 +117,7 @@ impl BurnDqnAgent {
             &device,
         );
 
-        let optimizer = AdamConfig::new()
-            .with_learning_rate(config.learning_rate)
-            .init();
+        let optimizer = AdamConfig::new().init();
 
         Self {
             config: config.clone(),
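The dropped .with_learning_rate(...) call matches Burn's optimizer API: AdamConfig carries no learning rate, and the rate is instead passed to Optimizer::step on every update, so config.learning_rate now has to be threaded into the training step. Below is a minimal, self-contained sketch of that pattern, assuming Burn 0.17 with the ndarray and autodiff features; it is illustrative only and not code from this repository.

    use burn::backend::{Autodiff, NdArray};
    use burn::nn::loss::{MseLoss, Reduction};
    use burn::nn::{Linear, LinearConfig};
    use burn::optim::{AdamConfig, GradientsParams, Optimizer};
    use burn::tensor::Tensor;

    type B = Autodiff<NdArray>;

    fn adam_step_sketch() {
        let device = Default::default();
        // A stand-in model; the real project uses its own Q-network module.
        let mut model: Linear<B> = LinearConfig::new(10, 4).init(&device);
        let mut optim = AdamConfig::new().init();

        let input = Tensor::<B, 2>::zeros([1, 10], &device);
        let target = Tensor::<B, 2>::zeros([1, 4], &device);

        let loss = MseLoss::new().forward(model.forward(input), target, Reduction::Mean);
        let grads = GradientsParams::from_grads(loss.backward(), &model);

        // The learning rate is supplied here, once per optimisation step.
        model = optim.step(1e-3, model, grads);
        let _ = model;
    }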
@@ -147,10 +145,9 @@ impl BurnDqnAgent {
         }
 
         // Exploitation: pick the best action according to the Q-network
-        let state_tensor = Tensor::<MyBackend, 2>::from_data(
-            burn::tensor::Data::new(state.to_vec(), burn::tensor::Shape::new([1, state.len()])),
-            &self.device
-        );
+        // Use from_floats with a 2D vector for Burn 0.17
+        let state_2d = vec![state.to_vec()];
+        let state_tensor = Tensor::<MyBackend, 2>::from_floats(state_2d, &self.device);
 
         let q_values = self.q_network.forward(state_tensor);
         let q_data = q_values.into_data().to_vec::<f32>().unwrap();
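The commit title notes compilation still fails; if from_floats does not accept a Vec<Vec<f32>> on this Burn version, one hypothetical alternative that keeps the intent is to go through TensorData (the 0.17 successor of the removed burn::tensor::Data) with an explicit [1, state_size] shape. A sketch under that assumption, reusing MyBackend and self.device from the hunk above:

    // Hypothetical alternative for the exploitation path (Burn 0.17 TensorData API).
    let data = burn::tensor::TensorData::new(state.to_vec(), [1, state.len()]);
    let state_tensor = Tensor::<MyBackend, 2>::from_data(data, &self.device);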
@@ -195,17 +192,9 @@ impl BurnDqnAgent {
             .collect();
 
         // Convert into a Burn-compatible format
-        let state_data: Vec<f32> = states.into_iter().flatten().collect();
-        let state_tensor = Tensor::<MyBackend, 2>::from_data(
-            burn::tensor::Data::new(state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
-            &self.device
-        );
+        let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);
 
         let next_state_tensor = if !next_states.is_empty() {
-            let next_state_data: Vec<f32> = next_states.into_iter().flatten().collect();
-            Some(Tensor::<MyBackend, 2>::from_data(
-                burn::tensor::Data::new(next_state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
-                &self.device
-            ))
+            Some(Tensor::<MyBackend, 2>::from_floats(next_states, &self.device))
         } else {
             None
         };
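The same caveat applies to the batched tensors in the hunk above: if Vec<Vec<f32>> is not accepted directly by from_floats, flattening the samples and attaching an explicit [batch_size, state_size] shape through TensorData mirrors what the removed lines did with the old Data type. A sketch under that assumption:

    // Hypothetical replacement for the batch path (Burn 0.17 TensorData API).
    let state_data: Vec<f32> = states.into_iter().flatten().collect();
    let state_tensor = Tensor::<MyBackend, 2>::from_data(
        burn::tensor::TensorData::new(state_data, [batch.len(), self.config.state_size]),
        &self.device,
    );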