claude (dqn_rs trainer simplified, compilation still fails)
This commit is contained in:
parent 16dd4fbf68
commit 80734990eb
@@ -13,6 +13,10 @@ path = "src/bin/train_dqn.rs"
 name = "train_burn_dqn"
 path = "src/bin/train_burn_dqn.rs"
 
+[[bin]]
+name = "simple_burn_train"
+path = "src/bin/simple_burn_train.rs"
+
 [dependencies]
 pretty_assertions = "1.4.0"
 serde = { version = "1.0", features = ["derive"] }
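With the new [[bin]] target in place, the trainer can be started like any Cargo binary, e.g. cargo run --bin simple_burn_train from the bot crate; this only works once the crate actually compiles, which the commit message notes is not yet the case.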
bot/src/bin/simple_burn_train.rs (new file, 83 lines)
@@ -0,0 +1,83 @@
use bot::strategy::burn_dqn::{BurnDqnAgent, BurnDqnConfig, Experience};
use rand::Rng;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init();

    println!("Simplified DQN training with Burn");

    // Simple DQN configuration
    let config = BurnDqnConfig {
        state_size: 10,
        action_size: 4,
        hidden_size: 64,
        learning_rate: 0.001,
        gamma: 0.99,
        epsilon: 1.0,
        epsilon_decay: 0.995,
        epsilon_min: 0.01,
        replay_buffer_size: 1000,
        batch_size: 16,
        target_update_freq: 50,
    };

    let mut agent = BurnDqnAgent::new(config);
    let mut rng = rand::thread_rng();

    println!("Starting simple training...");

    for episode in 1..=100 {
        let mut total_reward = 0.0;

        for step in 1..=50 {
            // Simple random state
            let state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();

            // Valid actions (all actions, to keep things simple)
            let valid_actions: Vec<usize> = vec![0, 1, 2, 3];

            // Select an action
            let action = agent.select_action(&state, &valid_actions);

            // Simulated reward
            let reward = rng.gen::<f32>() - 0.5; // reward between -0.5 and 0.5

            // Random next state
            let next_state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();

            // Random end of episode
            let done = step >= 50 || rng.gen::<f32>() < 0.1;

            // Store the experience
            let experience = Experience {
                state: state.clone(),
                action,
                reward,
                next_state: if done { None } else { Some(next_state) },
                done,
            };
            agent.add_experience(experience);

            // Train
            if let Some(loss) = agent.train_step() {
                if step % 25 == 0 {
                    println!("Episode {}, Step {}, Loss: {:.4}, Epsilon: {:.3}",
                        episode, step, loss, agent.get_epsilon());
                }
            }

            total_reward += reward;

            if done {
                break;
            }
        }

        if episode % 10 == 0 {
            println!("Episode {} finished. Total reward: {:.2}", episode, total_reward);
        }
    }

    println!("Training finished!");
    Ok(())
}
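For readers without the rest of the crate at hand, the Experience fields used above imply roughly the following shape. This is a hypothetical reconstruction inferred only from the usage in simple_burn_train.rs; the real definition in bot::strategy::burn_dqn may differ.

// Hypothetical reconstruction of Experience, inferred from its usage above;
// not taken from burn_dqn.rs itself.
pub struct Experience {
    pub state: Vec<f32>,              // observation before the action
    pub action: usize,                // index of the chosen action
    pub reward: f32,                  // immediate reward
    pub next_state: Option<Vec<f32>>, // None on terminal transitions
    pub done: bool,                   // episode-termination flag
}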
@@ -117,9 +117,7 @@ impl BurnDqnAgent {
             &device,
         );
 
-        let optimizer = AdamConfig::new()
-            .with_learning_rate(config.learning_rate)
-            .init();
+        let optimizer = AdamConfig::new().init();
 
         Self {
             config: config.clone(),
@@ -147,10 +145,9 @@ impl BurnDqnAgent {
         }
 
         // Exploitation: pick the best action according to the Q-network
-        let state_tensor = Tensor::<MyBackend, 2>::from_data(
-            burn::tensor::Data::new(state.to_vec(), burn::tensor::Shape::new([1, state.len()])),
-            &self.device
-        );
+        // Use from_floats with a 2D vector for Burn 0.17
+        let state_2d = vec![state.to_vec()];
+        let state_tensor = Tensor::<MyBackend, 2>::from_floats(state_2d, &self.device);
 
         let q_values = self.q_network.forward(state_tensor);
         let q_data = q_values.into_data().to_vec::<f32>().unwrap();
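The hunk ends before an action is actually chosen from q_data. A minimal sketch of that step, assuming select_action receives the valid action indices as a slice (the real continuation in burn_dqn.rs is not shown here):

// Pick the valid action with the highest predicted Q-value.
// unwrap() on partial_cmp assumes the network never outputs NaN;
// the final unwrap_or(0) only triggers if no valid action was supplied.
let best_action = valid_actions
    .iter()
    .copied()
    .max_by(|&a, &b| q_data[a].partial_cmp(&q_data[b]).unwrap())
    .unwrap_or(0);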
@@ -195,17 +192,9 @@ impl BurnDqnAgent {
             .collect();
 
         // Convert into a Burn-compatible format
-        let state_data: Vec<f32> = states.into_iter().flatten().collect();
-        let state_tensor = Tensor::<MyBackend, 2>::from_data(
-            burn::tensor::Data::new(state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
-            &self.device
-        );
+        let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);
         let next_state_tensor = if !next_states.is_empty() {
-            let next_state_data: Vec<f32> = next_states.into_iter().flatten().collect();
-            Some(Tensor::<MyBackend, 2>::from_data(
-                burn::tensor::Data::new(next_state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
-                &self.device
-            ))
+            Some(Tensor::<MyBackend, 2>::from_floats(next_states, &self.device))
         } else {
             None
         };
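One plausible reason compilation still fails is that from_floats may not accept a Vec<Vec<f32>> in this Burn version. A sketch of an alternative, assuming Burn's TensorData / from_data API (an assumption, not code from this repository): flatten the rows and pass the shape explicitly.

use burn::tensor::{backend::Backend, Tensor, TensorData};

// Build a [batch, state_size] tensor from row-major per-sample vectors.
fn rows_to_tensor<B: Backend>(
    rows: Vec<Vec<f32>>,
    state_size: usize,
    device: &B::Device,
) -> Tensor<B, 2> {
    let batch = rows.len();
    let flat: Vec<f32> = rows.into_iter().flatten().collect();
    Tensor::from_data(TensorData::new(flat, [batch, state_size]), device)
}

The same helper would cover both the states and next_states batches built in this hunk.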