wip broken
This commit is contained in:
parent 6a7b1cbebc
commit 894a24033c
@@ -1,10 +1,19 @@
use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
use bot::strategy::dqn_common::get_valid_actions;
use burn::optim::AdamConfig;
use burn_rl::base::Environment;
use bot::strategy::burn_dqn_agent::{DqnNetwork, MyBackend};
use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment, TrictracState};
use burn::optim::{AdamWConfig, Optimizer};
use burn_rl::{
agent::{DQN, DQNTrainingConfig},
base::{Action, Agent, ElemType, Environment, Memory, State},
};
use std::env;

const DENSE_SIZE: usize = 128;
const EPS_DECAY: f64 = 1000.0;
const EPS_START: f64 = 0.9;
const EPS_END: f64 = 0.05;

type MyAgent = DQN<TrictracEnvironment, MyBackend, DqnNetwork<MyBackend>>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
@@ -71,193 +80,73 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Créer le dossier models s'il n'existe pas
std::fs::create_dir_all("models")?;

println!("=== Entraînement DQN complet avec Burn ===");
println!("=== Entraînement DQN complet avec Burn-RL ===");
println!("Épisodes : {}", episodes);
println!("Modèle : {}", model_path);
println!("Sauvegarde tous les {} épisodes", save_every);
println!("Max steps par épisode : {}", max_steps_per_episode);
println!();

// Configuration DQN
let config = DqnConfig {
state_size: 36,
action_size: 1252, // Espace d'actions réduit via contexte
hidden_size: 256,
learning_rate: 0.001,
gamma: 0.99,
epsilon: 1.0,
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 10000,
batch_size: 32,
target_update_freq: 100,
};

// Créer l'agent et l'environnement
let mut agent = BurnDqnAgent::new(config);
let mut optimizer = AdamConfig::new().init();

let mut env = TrictracEnvironment::new(true);
let model = DqnNetwork::<MyBackend>::new(
TrictracState::size(),
DENSE_SIZE,
TrictracAction::size(),
&Default::default(),
);
let mut agent = MyAgent::new(model);
let config = DQNTrainingConfig::default();
let mut memory = Memory::<TrictracEnvironment, MyBackend>::default();
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut policy_net = agent.model().as_ref().unwrap().clone();
let mut step = 0_usize;

// Variables pour les statistiques
let mut total_rewards = Vec::new();
let mut episode_lengths = Vec::new();
let mut losses = Vec::new();
for episode in 0..episodes {
let mut episode_done = false;
let mut episode_reward: ElemType = 0.0;
let mut episode_duration = 0_usize;
let mut state = env.state();

println!("Début de l'entraînement avec agent DQN complet...");
println!();
while !episode_done {
let eps_threshold =
EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
let action = MyAgent::react_with_exploration(&policy_net, state, eps_threshold);
let snapshot = env.step(action);

for episode in 1..=episodes {
// Reset de l'environnement
let mut snapshot = env.reset();
let mut episode_reward = 0.0;
let mut step = 0;
let mut episode_loss = 0.0;
let mut loss_count = 0;
episode_reward += <f32 as Into<ElemType>>::into(snapshot.reward().clone());
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);

if config.batch_size < memory.len() {
policy_net = agent.train(&policy_net, &memory, &mut optimizer, &config);
}

loop {
step += 1;
let current_state_data = snapshot.state().data;
episode_duration += 1;

// Obtenir les actions valides selon le contexte du jeu
let valid_actions = get_valid_actions(&env.game);

if valid_actions.is_empty() {
break;
}

// Convertir les actions Trictrac en indices pour l'agent
let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();

// Sélectionner une action avec l'agent DQN
let action_index = agent.select_action(
&current_state_data,
&valid_indices,
);
let action = TrictracAction {
index: action_index as u32,
};

// Exécuter l'action
snapshot = env.step(action);
episode_reward += *snapshot.reward();

// Préparer l'expérience pour l'agent
let experience = Experience {
state: current_state_data.to_vec(),
action: action_index,
reward: *snapshot.reward(),
next_state: if snapshot.done() {
None
} else {
Some(snapshot.state().data.to_vec())
},
done: snapshot.done(),
};

// Ajouter l'expérience au replay buffer
agent.add_experience(experience);

// Entraîner l'agent
if let Some(loss) = agent.train_step(&mut optimizer) {
episode_loss += loss;
loss_count += 1;
}

// Vérifier les conditions de fin
if snapshot.done() || step >= max_steps_per_episode {
break;
}
}

// Calculer la loss moyenne de l'épisode
let avg_loss = if loss_count > 0 {
episode_loss / loss_count as f32
} else {
0.0
};

// Sauvegarder les statistiques
total_rewards.push(episode_reward);
episode_lengths.push(step);
losses.push(avg_loss);

// Affichage des statistiques
if episode % save_every == 0 {
let avg_reward =
total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
let avg_length =
episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
let avg_episode_loss =
losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;

println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());

// Sauvegarder le modèle
let checkpoint_path = format!("{}_{}", model_path, episode);
if let Err(e) = agent.save_model(&checkpoint_path) {
eprintln!("Erreur lors de la sauvegarde : {}", e);
if snapshot.done() || episode_duration >= TrictracEnvironment::MAX_STEPS {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {}, \"reward\": {:.4}, \"duration\": {}}}",
episode, episode_reward, episode_duration
);
} else {
println!(" → Modèle sauvegardé : {}", checkpoint_path);
state = *snapshot.state();
}
} else if episode % 10 == 0 {
println!(
"Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
episode,
episode_reward,
step,
avg_loss,
agent.get_epsilon()
);
}
}

// Sauvegarder le modèle final
let final_path = format!("{}_final", model_path);
agent.save_model(&final_path)?;

// Statistiques finales
println!();
println!("=== Résultats de l'entraînement ===");
let final_avg_reward = total_rewards
.iter()
.rev()
.take(100.min(episodes))
.sum::<f32>()
/ 100.min(episodes) as f32;
let final_avg_length = episode_lengths
.iter()
.rev()
.take(100.min(episodes))
.sum::<usize>()
/ 100.min(episodes);
let final_avg_loss =
losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;

println!(
"Récompense moyenne (100 derniers épisodes) : {:.3}",
final_avg_reward
);
println!(
"Longueur moyenne (100 derniers épisodes) : {}",
final_avg_length
);
println!(
"Loss moyenne (100 derniers épisodes) : {:.6}",
final_avg_loss
);
println!("Epsilon final : {:.3}", agent.get_epsilon());
println!("Taille du buffer final : {}", agent.get_buffer_size());

// Statistiques globales
let max_reward = total_rewards
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
println!("Récompense max : {:.3}", max_reward);
println!("Récompense min : {:.3}", min_reward);
// agent.save_model(&final_path)?;

println!();
println!("Entraînement terminé avec succès !");
@@ -1,13 +1,11 @@
use burn::{
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
module::Module,
nn::{loss::MseLoss, Linear, LinearConfig},
optim::Optimizer,
record::{CompactRecorder, Recorder},
tensor::Tensor,
nn::{Linear, LinearConfig},
tensor::{activation::relu, backend::Backend, Tensor},
};
use burn_rl::agent::DQNModel;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Backend utilisé pour l'entraînement (Autodiff + NdArray)
pub type MyBackend = Autodiff<NdArray>;
@@ -16,14 +14,14 @@ pub type InferenceBackend = NdArray;
pub type MyDevice = NdArrayDevice;

/// Réseau de neurones pour DQN
#[derive(Module, Debug)]
pub struct DqnNetwork<B: burn::prelude::Backend> {
#[derive(Module, Debug, Clone)]
pub struct DqnNetwork<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
fc3: Linear<B>,
}

impl<B: burn::prelude::Backend> DqnNetwork<B> {
impl<B: Backend> DqnNetwork<B> {
/// Crée un nouveau réseau DQN
pub fn new(
input_size: usize,
@@ -38,14 +36,46 @@ impl<B: burn::prelude::Backend> DqnNetwork<B> {
Self { fc1, fc2, fc3 }
}

/// Forward pass du réseau
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
(self.fc1, self.fc2, self.fc3)
}
}

impl<B: Backend> burn_rl::base::Model<Tensor<B, 2>, Tensor<B, 2>> for DqnNetwork<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.fc1.forward(input);
let x = burn::tensor::activation::relu(x);
let x = relu(x);
let x = self.fc2.forward(x);
let x = burn::tensor::activation::relu(x);
let x = relu(x);
self.fc3.forward(x)
}

fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
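
Editor's note: the new Model implementation above maps a [batch, input_size] tensor through fc1/fc2/fc3 with ReLU activations to [batch, output_size] Q-values. A minimal shape-check sketch on the plain NdArray backend, assuming the crate paths shown in this diff (bot::strategy::burn_dqn_agent) and the 36/256/1252 sizes used elsewhere in the commit; not repository code:

use bot::strategy::burn_dqn_agent::DqnNetwork;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::tensor::Tensor;
use burn_rl::base::Model;

fn main() {
    let device = NdArrayDevice::default();
    // 36 state features -> 256 hidden units -> 1252 Q-values.
    let net = DqnNetwork::<NdArray>::new(36, 256, 1252, &device);
    // Build a single-state batch the same way the agent does (from_floats + reshape).
    let state = vec![0.0_f32; 36];
    let input = Tensor::<NdArray, 2>::from_floats(state.as_slice(), &device).reshape([1, 36]);
    let q = net.forward(input);
    assert_eq!(q.dims(), [1, 1252]);
}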

impl<B: Backend> DQNModel<B> for DqnNetwork<B> {
fn soft_update(this: Self, that: &Self, tau: f32) -> Self {
let (fc1, fc2, fc3) = this.consume();
Self {
fc1: soft_update_linear(fc1, &that.fc1, tau),
fc2: soft_update_linear(fc2, &that.fc2, tau),
fc3: soft_update_linear(fc3, &that.fc3, tau),
}
}
}

pub fn soft_update_linear<B: Backend>(
this: Linear<B>,
that: &Linear<B>,
tau: f32,
) -> Linear<B> {
let updated = this.clone();
let that_record = that.clone().into_record();
let updated_record = updated.clone().into_record();
// load_record consumes the module and returns the updated one, so return it directly.
updated.load_record(updated_record.soft_update(tau, that_record))
}
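
Editor's note: soft_update_linear above delegates the per-parameter blend to the record-level soft_update; numerically this is a Polyak average of the two parameter sets. A hypothetical element-wise illustration on plain f32 slices, assuming the convention new = tau * that + (1 - tau) * this (the exact convention is whatever the record API implements):

// Illustrative only: the real code blends Burn parameter records, not raw slices.
fn soft_update_slice(this: &mut [f32], that: &[f32], tau: f32) {
    for (a, b) in this.iter_mut().zip(that) {
        // tau = 0 keeps `this` unchanged, tau = 1 copies `that` entirely.
        *a = tau * *b + (1.0 - tau) * *a;
    }
}

fn main() {
    let mut target = [0.0_f32, 1.0, 2.0];
    let policy = [1.0_f32, 1.0, 1.0];
    soft_update_slice(&mut target, &policy, 0.05);
    println!("{:?}", target); // [0.05, 1.0, 1.95]
}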

/// Configuration pour l'entraînement DQN
@@ -91,215 +121,3 @@ pub struct Experience {
pub next_state: Option<Vec<f32>>,
pub done: bool,
}

/// Agent DQN utilisant Burn
pub struct BurnDqnAgent {
config: DqnConfig,
device: MyDevice,
q_network: DqnNetwork<MyBackend>,
target_network: DqnNetwork<MyBackend>,
replay_buffer: VecDeque<Experience>,
epsilon: f32,
step_count: usize,
}

impl BurnDqnAgent {
/// Crée un nouvel agent DQN
pub fn new(config: DqnConfig) -> Self {
let device = MyDevice::default();

let q_network = DqnNetwork::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);

let target_network = DqnNetwork::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);

Self {
config: config.clone(),
device,
q_network,
target_network,
replay_buffer: VecDeque::new(),
epsilon: config.epsilon,
step_count: 0,
}
}

/// Sélectionne une action avec epsilon-greedy
pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize {
if valid_actions.is_empty() {
// Retourne une action par défaut ou une action "nulle" si aucune n'est valide
// Dans le contexte du jeu, cela ne devrait pas arriver si la logique de fin de partie est correcte
return 0;
}

// Exploration epsilon-greedy
if rand::random::<f32>() < self.epsilon {
let random_index = rand::random::<usize>() % valid_actions.len();
return valid_actions[random_index];
}

// Exploitation : choisir la meilleure action selon le Q-network
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device)
.reshape([1, self.config.state_size]);
let q_values = self.q_network.forward(state_tensor);

// Convertir en vecteur pour traitement
let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();

// Trouver la meilleure action parmi les actions valides
let mut best_action = valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;

for &action in valid_actions {
if action < q_data.len() && q_data[action] > best_q_value {
best_q_value = q_data[action];
best_action = action;
}
}

best_action
}

/// Ajoute une expérience au replay buffer
pub fn add_experience(&mut self, experience: Experience) {
if self.replay_buffer.len() >= self.config.replay_buffer_size {
self.replay_buffer.pop_front();
}
self.replay_buffer.push_back(experience);
}

/// Entraîne le réseau sur un batch d'expériences
pub fn train_step(
&mut self,
optimizer: &mut impl Optimizer<DqnNetwork<MyBackend>, MyBackend>,
) -> Option<f32> {
if self.replay_buffer.len() < self.config.batch_size {
return None;
}

// Échantillonner un batch d'expériences
let batch = self.sample_batch();

// Préparer les tenseurs d'état
let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states.as_slice(), &self.device)
.reshape([self.config.batch_size, self.config.state_size]);

// Calculer les Q-values actuelles
let current_q_values = self.q_network.forward(state_tensor);

// Pour l'instant, version simplifiée sans calcul de target
let target_q_values = current_q_values.clone();

// Calculer la loss MSE
let loss = MseLoss::new().forward(
current_q_values,
target_q_values,
burn::nn::loss::Reduction::Mean,
);

// Backpropagation (version simplifiée)
let grads = loss.backward();
// Gradients linked to each parameter of the model.
let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network);
self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads);

// Mise à jour du réseau cible
self.step_count += 1;
if self.step_count % self.config.target_update_freq == 0 {
self.update_target_network();
}

// Décroissance d'epsilon
if self.epsilon > self.config.epsilon_min {
self.epsilon *= self.config.epsilon_decay;
}

Some(loss.into_scalar())
}

/// Échantillonne un batch d'expériences du replay buffer
fn sample_batch(&self) -> Vec<Experience> {
let mut batch = Vec::new();
let buffer_size = self.replay_buffer.len();

for _ in 0..self.config.batch_size.min(buffer_size) {
let index = rand::random::<usize>() % buffer_size;
if let Some(exp) = self.replay_buffer.get(index) {
batch.push(exp.clone());
}
}

batch
}

/// Met à jour le réseau cible avec les poids du réseau principal
fn update_target_network(&mut self) {
// Copie simple des poids
self.target_network = self.q_network.clone();
}

/// Sauvegarde le modèle
pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
// Sauvegarder la configuration
let config_path = format!("{}_config.json", path);
let config_json = serde_json::to_string_pretty(&self.config)?;
std::fs::write(config_path, config_json)?;

// Sauvegarder le réseau pour l'inférence (conversion vers NdArray backend)
let inference_network = self.q_network.clone().into_record();
let recorder = CompactRecorder::new();

let model_path = format!("{}_model.burn", path);
recorder.record(inference_network, model_path.into())?;

println!("Modèle sauvegardé : {}", path);
Ok(())
}

/// Charge un modèle pour l'inférence
pub fn load_model_for_inference(
path: &str,
) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
// Charger la configuration
let config_path = format!("{}_config.json", path);
let config_json = std::fs::read_to_string(config_path)?;
let config: DqnConfig = serde_json::from_str(&config_json)?;

// Créer le réseau pour l'inférence
let device = NdArrayDevice::default();
let network = DqnNetwork::<InferenceBackend>::new(
config.state_size,
config.hidden_size,
config.action_size,
&device,
);

// Charger les poids
let model_path = format!("{}_model.burn", path);
let recorder = CompactRecorder::new();
let record = recorder.load(model_path.into(), &device)?;
let network = network.load_record(record);

Ok((network, config))
}

/// Retourne l'epsilon actuel
pub fn get_epsilon(&self) -> f32 {
self.epsilon
}

/// Retourne la taille du replay buffer
pub fn get_buffer_size(&self) -> usize {
self.replay_buffer.len()
}
}
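
Editor's note: the exploitation branch of select_action above is an argmax over Q-values restricted to the valid action indices. A standalone sketch of that restriction on a plain f32 slice (hypothetical helper, not part of the commit):

// Pick the legal action with the highest Q-value, ignoring out-of-range indices,
// mirroring the masked argmax in BurnDqnAgent::select_action.
fn best_valid_action(q_values: &[f32], valid_actions: &[usize]) -> Option<usize> {
    valid_actions
        .iter()
        .copied()
        .filter(|&a| a < q_values.len())
        .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
}

fn main() {
    let q = [0.1_f32, 0.7, -0.2, 0.4];
    // Only actions 0, 2 and 3 are legal in this state; action 1 (Q = 0.7) must be skipped.
    assert_eq!(best_valid_action(&q, &[0, 2, 3]), Some(3));
}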