diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 5d4f32d..99cba90 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -13,6 +13,10 @@ path = "src/bin/train_dqn.rs"
 name = "train_burn_dqn"
 path = "src/bin/train_burn_dqn.rs"

+[[bin]]
+name = "simple_burn_train"
+path = "src/bin/simple_burn_train.rs"
+
 [dependencies]
 pretty_assertions = "1.4.0"
 serde = { version = "1.0", features = ["derive"] }
diff --git a/bot/src/bin/simple_burn_train.rs b/bot/src/bin/simple_burn_train.rs
new file mode 100644
index 0000000..8946cc9
--- /dev/null
+++ b/bot/src/bin/simple_burn_train.rs
@@ -0,0 +1,83 @@
+use bot::strategy::burn_dqn::{BurnDqnAgent, BurnDqnConfig, Experience};
+use rand::Rng;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    env_logger::init();
+
+    println!("Simplified DQN training with Burn");
+
+    // Simple DQN configuration
+    let config = BurnDqnConfig {
+        state_size: 10,
+        action_size: 4,
+        hidden_size: 64,
+        learning_rate: 0.001,
+        gamma: 0.99,
+        epsilon: 1.0,
+        epsilon_decay: 0.995,
+        epsilon_min: 0.01,
+        replay_buffer_size: 1000,
+        batch_size: 16,
+        target_update_freq: 50,
+    };
+
+    let mut agent = BurnDqnAgent::new(config);
+    let mut rng = rand::thread_rng();
+
+    println!("Starting simple training...");
+
+    for episode in 1..=100 {
+        let mut total_reward = 0.0;
+
+        for step in 1..=50 {
+            // Simple random state
+            let state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();
+
+            // Valid actions (all of them, to keep things simple)
+            let valid_actions: Vec<usize> = vec![0, 1, 2, 3];
+
+            // Select an action
+            let action = agent.select_action(&state, &valid_actions);
+
+            // Simulated reward
+            let reward = rng.gen::<f32>() - 0.5; // reward between -0.5 and 0.5
+
+            // Random next state
+            let next_state: Vec<f32> = (0..10).map(|_| rng.gen::<f32>()).collect();
+
+            // Random end of episode
+            let done = step >= 50 || rng.gen::<f32>() < 0.1;
+
+            // Add the experience
+            let experience = Experience {
+                state: state.clone(),
+                action,
+                reward,
+                next_state: if done { None } else { Some(next_state) },
+                done,
+            };
+            agent.add_experience(experience);
+
+            // Train
+            if let Some(loss) = agent.train_step() {
+                if step % 25 == 0 {
+                    println!("Episode {}, Step {}, Loss: {:.4}, Epsilon: {:.3}",
+                        episode, step, loss, agent.get_epsilon());
+                }
+            }
+
+            total_reward += reward;
+
+            if done {
+                break;
+            }
+        }
+
+        if episode % 10 == 0 {
+            println!("Episode {} finished. Total reward: {:.2}", episode, total_reward);
+        }
+    }
+
+    println!("Training complete!");
+    Ok(())
+}
\ No newline at end of file
diff --git a/bot/src/strategy/burn_dqn.rs b/bot/src/strategy/burn_dqn.rs
index 1b83410..883af70 100644
--- a/bot/src/strategy/burn_dqn.rs
+++ b/bot/src/strategy/burn_dqn.rs
@@ -117,9 +117,7 @@ impl BurnDqnAgent {
             &device,
         );

-        let optimizer = AdamConfig::new()
-            .with_learning_rate(config.learning_rate)
-            .init();
+        let optimizer = AdamConfig::new().init();

         Self {
             config: config.clone(),
@@ -147,10 +145,9 @@ impl BurnDqnAgent {
         }

         // Exploitation: choose the best action according to the Q-network
-        let state_tensor = Tensor::<_, 2>::from_data(
-            burn::tensor::Data::new(state.to_vec(), burn::tensor::Shape::new([1, state.len()])),
-            &self.device
-        );
+        // Use from_floats with a 2D vector for Burn 0.17
+        let state_2d = vec![state.to_vec()];
+        let state_tensor = Tensor::<_, 2>::from_floats(state_2d, &self.device);
         let q_values = self.q_network.forward(state_tensor);
         let q_data = q_values.into_data().to_vec::<f32>().unwrap();

@@ -195,17 +192,9 @@ impl BurnDqnAgent {
             .collect();

         // Convert to a format compatible with Burn
-        let state_data: Vec<f32> = states.into_iter().flatten().collect();
-        let state_tensor = Tensor::<_, 2>::from_data(
-            burn::tensor::Data::new(state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
-            &self.device
-        );
+        let state_tensor = Tensor::<_, 2>::from_floats(states, &self.device);
         let next_state_tensor = if !next_states.is_empty() {
-            let next_state_data: Vec<f32> = next_states.into_iter().flatten().collect();
-            Some(Tensor::<_, 2>::from_data(
-                burn::tensor::Data::new(next_state_data, burn::tensor::Shape::new([batch.len(), self.config.state_size])),
-                &self.device
-            ))
+            Some(Tensor::<_, 2>::from_floats(next_states, &self.device))
         } else {
             None
         };
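
Side note on the exploration schedule: the epsilon, epsilon_decay and epsilon_min fields in BurnDqnConfig, together with the "exploitation" branch visible in select_action, point to a standard epsilon-greedy policy with multiplicative decay. The sketch below illustrates that policy in isolation; it is not the actual BurnDqnAgent implementation, and the EpsilonGreedy type and its select method are hypothetical names used only for this example (assuming rand 0.8).

// Hypothetical illustration only -- not the BurnDqnAgent internals.
use rand::Rng;

struct EpsilonGreedy {
    epsilon: f32,
    epsilon_decay: f32,
    epsilon_min: f32,
}

impl EpsilonGreedy {
    /// With probability `epsilon`, explore (random valid action);
    /// otherwise exploit (valid action with the highest Q-value).
    fn select(&mut self, q_values: &[f32], valid_actions: &[usize]) -> usize {
        let mut rng = rand::thread_rng();
        let action = if rng.gen::<f32>() < self.epsilon {
            valid_actions[rng.gen_range(0..valid_actions.len())]
        } else {
            *valid_actions
                .iter()
                .max_by(|&&a, &&b| q_values[a].partial_cmp(&q_values[b]).unwrap())
                .unwrap()
        };
        // Multiplicative decay towards the epsilon_min floor.
        self.epsilon = (self.epsilon * self.epsilon_decay).max(self.epsilon_min);
        action
    }
}

fn main() {
    let mut policy = EpsilonGreedy { epsilon: 1.0, epsilon_decay: 0.995, epsilon_min: 0.01 };
    let q_values = [0.1_f32, 0.4, -0.2, 0.3];
    let action = policy.select(&q_values, &[0, 1, 2, 3]);
    println!("chosen action: {}, epsilon now {:.3}", action, policy.epsilon);
}

Under those assumptions, the epsilon printed by simple_burn_train.rs would be expected to shrink from 1.0 towards the 0.01 floor as training progresses.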