This commit is contained in:
Henri Bourcereau 2025-06-28 21:34:44 +02:00
parent cf93255f03
commit f05094b2d4
3 changed files with 150 additions and 75 deletions

View file

@ -1,12 +1,13 @@
use burn::module::AutodiffModule;
use burn::tensor::backend::AutodiffBackend;
use burn::{
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
nn::{Linear, LinearConfig, loss::MseLoss},
module::Module,
tensor::Tensor,
optim::{AdamConfig, Optimizer},
nn::{loss::MseLoss, Linear, LinearConfig},
optim::{GradientsParams, Optimizer},
record::{CompactRecorder, Recorder},
tensor::Tensor,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
@ -26,11 +27,16 @@ pub struct DqnNetwork<B: burn::prelude::Backend> {
impl<B: burn::prelude::Backend> DqnNetwork<B> {
/// Crée un nouveau réseau DQN
pub fn new(input_size: usize, hidden_size: usize, output_size: usize, device: &B::Device) -> Self {
pub fn new(
input_size: usize,
hidden_size: usize,
output_size: usize,
device: &B::Device,
) -> Self {
let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
let fc3 = LinearConfig::new(hidden_size, output_size).init(device);
Self { fc1, fc2, fc3 }
}
@ -94,7 +100,6 @@ pub struct BurnDqnAgent {
device: MyDevice,
q_network: DqnNetwork<MyBackend>,
target_network: DqnNetwork<MyBackend>,
optimizer: burn::optim::Adam<MyBackend>,
replay_buffer: VecDeque<Experience>,
epsilon: f32,
step_count: usize,
@ -104,29 +109,26 @@ impl BurnDqnAgent {
/// Crée un nouvel agent DQN
pub fn new(config: DqnConfig) -> Self {
let device = MyDevice::default();
let q_network = DqnNetwork::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);
let target_network = DqnNetwork::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);
let optimizer = AdamConfig::new().init();
Self {
config: config.clone(),
device,
q_network,
target_network,
optimizer,
replay_buffer: VecDeque::new(),
epsilon: config.epsilon,
step_count: 0,
@ -146,23 +148,23 @@ impl BurnDqnAgent {
}
// Exploitation : choisir la meilleure action selon le Q-network
let state_tensor = Tensor::<MyBackend, 2>::from_floats([state], &self.device);
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device);
let q_values = self.q_network.forward(state_tensor);
// Convertir en vecteur pour traitement
let q_data = q_values.into_data().convert::<f32>().value;
let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();
// Trouver la meilleure action parmi les actions valides
let mut best_action = valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for &action in valid_actions {
if action < q_data.len() && q_data[action] > best_q_value {
best_q_value = q_data[action];
best_action = action;
}
}
best_action
}
@ -175,46 +177,51 @@ impl BurnDqnAgent {
}
/// Entraîne le réseau sur un batch d'expériences
pub fn train_step(&mut self) -> Option<f32> {
pub fn train_step<B: AutodiffBackend, M: AutodiffModule<B>>(
&mut self,
optimizer: &mut impl Optimizer<M, B>,
) -> Option<f32> {
if self.replay_buffer.len() < self.config.batch_size {
return None;
}
// Échantillonner un batch d'expériences
let batch = self.sample_batch();
// Préparer les tenseurs d'état
let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect();
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);
// Calculer les Q-values actuelles
let current_q_values = self.q_network.forward(state_tensor);
// Pour l'instant, version simplifiée sans calcul de target
let target_q_values = current_q_values.clone();
// Calculer la loss MSE
let loss = MseLoss::new().forward(
current_q_values,
target_q_values,
burn::nn::loss::Reduction::Mean
current_q_values,
target_q_values,
burn::nn::loss::Reduction::Mean,
);
// Backpropagation (version simplifiée)
let grads = loss.backward();
self.q_network = self.optimizer.step(self.config.learning_rate, self.q_network, grads);
// Gradients linked to each parameter of the model.
// let grads = GradientsParams::from_grads(grads, &self.q_network);
self.q_network = optimizer.step(self.config.learning_rate, self.q_network, grads);
// Mise à jour du réseau cible
self.step_count += 1;
if self.step_count % self.config.target_update_freq == 0 {
self.update_target_network();
}
// Décroissance d'epsilon
if self.epsilon > self.config.epsilon_min {
self.epsilon *= self.config.epsilon_decay;
}
Some(loss.into_scalar())
}
@ -222,14 +229,14 @@ impl BurnDqnAgent {
fn sample_batch(&self) -> Vec<Experience> {
let mut batch = Vec::new();
let buffer_size = self.replay_buffer.len();
for _ in 0..self.config.batch_size.min(buffer_size) {
let index = rand::random::<usize>() % buffer_size;
if let Some(exp) = self.replay_buffer.get(index) {
batch.push(exp.clone());
}
}
batch
}
@ -245,25 +252,27 @@ impl BurnDqnAgent {
let config_path = format!("{}_config.json", path);
let config_json = serde_json::to_string_pretty(&self.config)?;
std::fs::write(config_path, config_json)?;
// Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend)
let inference_network = self.q_network.clone().into_record();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
recorder.record(inference_network, model_path.into())?;
println!("Modèle sauvegardé : {}", path);
Ok(())
}
/// Charge un modèle pour l'inférence
pub fn load_model_for_inference(path: &str) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
pub fn load_model_for_inference(
path: &str,
) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
// Charger la configuration
let config_path = format!("{}_config.json", path);
let config_json = std::fs::read_to_string(config_path)?;
let config: DqnConfig = serde_json::from_str(&config_json)?;
// Créer le réseau pour l'inférence
let device = NdArrayDevice::default();
let network = DqnNetwork::<InferenceBackend>::new(
@ -272,13 +281,13 @@ impl BurnDqnAgent {
config.action_size,
&device,
);
// Charger les poids
let model_path = format!("{}_model.burn", path);
let recorder = CompactRecorder::new();
let record = recorder.load(model_path.into(), &device)?;
let network = network.load_record(record);
Ok((network, config))
}
@ -291,4 +300,4 @@ impl BurnDqnAgent {
pub fn get_buffer_size(&self) -> usize {
self.replay_buffer.len()
}
}
}