claude not tested

This commit is contained in:
parent a06b47628e
commit cf93255f03

bot/Cargo.toml
@@ -13,6 +13,10 @@ path = "src/bin/train_dqn.rs"
 name = "train_burn_rl"
 path = "src/bin/train_burn_rl.rs"
+
+[[bin]]
+name = "train_dqn_full"
+path = "src/bin/train_dqn_full.rs"
 
 [dependencies]
 pretty_assertions = "1.4.0"
 serde = { version = "1.0", features = ["derive"] }

bot/src/bin/train_dqn_full.rs (new file, 253 lines)
@@ -0,0 +1,253 @@
use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
use bot::strategy::burn_environment::{TrictracEnvironment, TrictracAction};
use bot::strategy::dqn_common::get_valid_actions;
use burn_rl::base::Environment;
use std::env;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init();

    let args: Vec<String> = env::args().collect();

    // Default parameters
    let mut episodes = 1000;
    let mut model_path = "models/burn_dqn_model".to_string();
    let mut save_every = 100;
    let mut max_steps_per_episode = 500;

    // Parse the command-line arguments
    let mut i = 1;
    while i < args.len() {
        match args[i].as_str() {
            "--episodes" => {
                if i + 1 < args.len() {
                    episodes = args[i + 1].parse().unwrap_or(1000);
                    i += 2;
                } else {
                    eprintln!("Error: --episodes requires a value");
                    std::process::exit(1);
                }
            }
            "--model-path" => {
                if i + 1 < args.len() {
                    model_path = args[i + 1].clone();
                    i += 2;
                } else {
                    eprintln!("Error: --model-path requires a value");
                    std::process::exit(1);
                }
            }
            "--save-every" => {
                if i + 1 < args.len() {
                    save_every = args[i + 1].parse().unwrap_or(100);
                    i += 2;
                } else {
                    eprintln!("Error: --save-every requires a value");
                    std::process::exit(1);
                }
            }
            "--max-steps" => {
                if i + 1 < args.len() {
                    max_steps_per_episode = args[i + 1].parse().unwrap_or(500);
                    i += 2;
                } else {
                    eprintln!("Error: --max-steps requires a value");
                    std::process::exit(1);
                }
            }
            "--help" | "-h" => {
                print_help();
                std::process::exit(0);
            }
            _ => {
                eprintln!("Unknown argument: {}", args[i]);
                print_help();
                std::process::exit(1);
            }
        }
    }

    // Create the models directory if it does not exist
    std::fs::create_dir_all("models")?;

    println!("=== Full DQN training with Burn ===");
    println!("Episodes: {}", episodes);
    println!("Model: {}", model_path);
    println!("Saving every {} episodes", save_every);
    println!("Max steps per episode: {}", max_steps_per_episode);
    println!();

    // DQN configuration
    let config = DqnConfig {
        state_size: 36,
        action_size: 1000, // Action space reduced via game context
        hidden_size: 256,
        learning_rate: 0.001,
        gamma: 0.99,
        epsilon: 1.0,
        epsilon_decay: 0.995,
        epsilon_min: 0.01,
        replay_buffer_size: 10000,
        batch_size: 32,
        target_update_freq: 100,
    };

    // Create the agent and the environment
    let mut agent = BurnDqnAgent::new(config);
    let mut env = TrictracEnvironment::new(true);

    // Statistics
    let mut total_rewards = Vec::new();
    let mut episode_lengths = Vec::new();
    let mut losses = Vec::new();

    println!("Starting training with the full DQN agent...");
    println!();

    for episode in 1..=episodes {
        // Reset the environment
        let mut snapshot = env.reset();
        let mut episode_reward = 0.0;
        let mut step = 0;
        let mut episode_loss = 0.0;
        let mut loss_count = 0;

        loop {
            step += 1;
            let current_state = snapshot.state;

            // Get the actions that are valid in the current game context
            let valid_actions = get_valid_actions(&env.game);

            if valid_actions.is_empty() {
                break;
            }

            // Convert the Trictrac actions into indices for the agent
            let valid_indices: Vec<usize> = (0..valid_actions.len()).collect();
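            // Note: the indices offered to the agent are positions within `valid_actions`
            // (0..len); the chosen position is wrapped directly as the environment action
            // index below, which assumes TrictracEnvironment uses the same positional
            // convention for `TrictracAction::index`.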

            // Select an action with the DQN agent
            let action_index = agent.select_action(&current_state.data.iter().map(|&x| x as f32).collect::<Vec<_>>(), &valid_indices);
            let action = TrictracAction { index: action_index as u32 };

            // Execute the action
            snapshot = env.step(action);
            episode_reward += snapshot.reward;

            // Build the experience for the agent
            let experience = Experience {
                state: current_state.data.iter().map(|&x| x as f32).collect(),
                action: action_index,
                reward: snapshot.reward,
                next_state: if snapshot.terminated {
                    None
                } else {
                    Some(snapshot.state.data.iter().map(|&x| x as f32).collect())
                },
                done: snapshot.terminated,
            };

            // Add the experience to the replay buffer
            agent.add_experience(experience);

            // Train the agent
            if let Some(loss) = agent.train_step() {
                episode_loss += loss;
                loss_count += 1;
            }

            // Check the end-of-episode conditions
            if snapshot.terminated || step >= max_steps_per_episode {
                break;
            }
        }

        // Average loss over the episode
        let avg_loss = if loss_count > 0 { episode_loss / loss_count as f32 } else { 0.0 };

        // Record the statistics
        total_rewards.push(episode_reward);
        episode_lengths.push(step);
        losses.push(avg_loss);

        // Report the statistics
        if episode % save_every == 0 {
            let avg_reward = total_rewards.iter().rev().take(save_every).sum::<f32>() / save_every as f32;
            let avg_length = episode_lengths.iter().rev().take(save_every).sum::<usize>() / save_every;
            let avg_episode_loss = losses.iter().rev().take(save_every).sum::<f32>() / save_every as f32;

            println!("Episode {} | Avg Reward: {:.3} | Avg Length: {} | Avg Loss: {:.6} | Epsilon: {:.3} | Buffer: {}",
                episode, avg_reward, avg_length, avg_episode_loss, agent.get_epsilon(), agent.get_buffer_size());

            // Save the model
            let checkpoint_path = format!("{}_{}", model_path, episode);
            if let Err(e) = agent.save_model(&checkpoint_path) {
                eprintln!("Error while saving: {}", e);
            } else {
                println!("  → Model saved: {}", checkpoint_path);
            }
        } else if episode % 10 == 0 {
            println!("Episode {} | Reward: {:.3} | Length: {} | Loss: {:.6} | Epsilon: {:.3}",
                episode, episode_reward, step, avg_loss, agent.get_epsilon());
        }
    }

    // Save the final model
    let final_path = format!("{}_final", model_path);
    agent.save_model(&final_path)?;

    // Final statistics
    println!();
    println!("=== Training results ===");
    let final_avg_reward = total_rewards.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;
    let final_avg_length = episode_lengths.iter().rev().take(100.min(episodes)).sum::<usize>() / 100.min(episodes);
    let final_avg_loss = losses.iter().rev().take(100.min(episodes)).sum::<f32>() / 100.min(episodes) as f32;

    println!("Average reward (last 100 episodes): {:.3}", final_avg_reward);
    println!("Average length (last 100 episodes): {}", final_avg_length);
    println!("Average loss (last 100 episodes): {:.6}", final_avg_loss);
    println!("Final epsilon: {:.3}", agent.get_epsilon());
    println!("Final buffer size: {}", agent.get_buffer_size());

    // Overall statistics
    let max_reward = total_rewards.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let min_reward = total_rewards.iter().cloned().fold(f32::INFINITY, f32::min);
    println!("Max reward: {:.3}", max_reward);
    println!("Min reward: {:.3}", min_reward);

    println!();
    println!("Training finished successfully!");
    println!("Final model saved: {}", final_path);
    println!();
    println!("To use the trained model:");
    println!("  cargo run --bin=client_cli -- --bot burn_dqn:{}_final,dummy", model_path);

    Ok(())
}

fn print_help() {
    println!("Full DQN trainer with Burn for Trictrac");
    println!();
    println!("USAGE:");
    println!("  cargo run --bin=train_dqn_full [OPTIONS]");
    println!();
    println!("OPTIONS:");
    println!("  --episodes <NUM>      Number of training episodes (default: 1000)");
    println!("  --model-path <PATH>   Base path for saving the models (default: models/burn_dqn_model)");
    println!("  --save-every <NUM>    Save the model every N episodes (default: 100)");
    println!("  --max-steps <NUM>     Maximum number of steps per episode (default: 500)");
    println!("  -h, --help            Show this help");
    println!();
    println!("EXAMPLES:");
    println!("  cargo run --bin=train_dqn_full");
    println!("  cargo run --bin=train_dqn_full -- --episodes 2000 --save-every 200");
    println!("  cargo run --bin=train_dqn_full -- --model-path models/my_model --episodes 500");
    println!();
    println!("FEATURES:");
    println!("  - Full DQN agent with a Burn neural network");
    println!("  - Experience replay buffer with random sampling");
    println!("  - Epsilon-greedy exploration with automatic decay");
    println!("  - Target network with periodic updates");
    println!("  - Automatic model checkpointing");
    println!("  - Detailed training statistics");
}

bot/src/lib.rs
@@ -1,6 +1,7 @@
 pub mod strategy;
 
 use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
+pub use strategy::burn_dqn_strategy::{BurnDqnStrategy, create_burn_dqn_strategy};
 pub use strategy::default::DefaultStrategy;
 pub use strategy::dqn::DqnStrategy;
 pub use strategy::erroneous_moves::ErroneousStrategy;

bot/src/strategy/mod.rs
@@ -1,3 +1,5 @@
+pub mod burn_dqn_agent;
+pub mod burn_dqn_strategy;
 pub mod burn_environment;
 pub mod client;
 pub mod default;

bot/src/strategy/burn_dqn_agent.rs (new file, 294 lines)
@@ -0,0 +1,294 @@
use burn::{
    backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
    nn::{Linear, LinearConfig, loss::MseLoss},
    module::Module,
    tensor::Tensor,
    optim::{AdamConfig, Optimizer},
    record::{CompactRecorder, Recorder},
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Backend used for training (Autodiff + NdArray)
pub type MyBackend = Autodiff<NdArray>;
/// Backend used for inference (NdArray)
pub type InferenceBackend = NdArray;
pub type MyDevice = NdArrayDevice;

/// Neural network for the DQN
#[derive(Module, Debug)]
pub struct DqnNetwork<B: burn::prelude::Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
}

impl<B: burn::prelude::Backend> DqnNetwork<B> {
    /// Creates a new DQN network
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize, device: &B::Device) -> Self {
        let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
        let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
        let fc3 = LinearConfig::new(hidden_size, output_size).init(device);

        Self { fc1, fc2, fc3 }
    }

    /// Forward pass of the network
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.fc1.forward(input);
        let x = burn::tensor::activation::relu(x);
        let x = self.fc2.forward(x);
        let x = burn::tensor::activation::relu(x);
        self.fc3.forward(x)
    }
}

/// Configuration for DQN training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
    pub state_size: usize,
    pub action_size: usize,
    pub hidden_size: usize,
    pub learning_rate: f64,
    pub gamma: f32,
    pub epsilon: f32,
    pub epsilon_decay: f32,
    pub epsilon_min: f32,
    pub replay_buffer_size: usize,
    pub batch_size: usize,
    pub target_update_freq: usize,
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            state_size: 36,
            action_size: 1000,
            hidden_size: 256,
            learning_rate: 0.001,
            gamma: 0.99,
            epsilon: 1.0,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            target_update_freq: 100,
        }
    }
}

/// Experience for the replay buffer
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Vec<f32>,
    pub action: usize,
    pub reward: f32,
    pub next_state: Option<Vec<f32>>,
    pub done: bool,
}

/// DQN agent built on Burn
pub struct BurnDqnAgent {
    config: DqnConfig,
    device: MyDevice,
    q_network: DqnNetwork<MyBackend>,
    target_network: DqnNetwork<MyBackend>,
    optimizer: burn::optim::Adam<MyBackend>,
    replay_buffer: VecDeque<Experience>,
    epsilon: f32,
    step_count: usize,
}

impl BurnDqnAgent {
    /// Creates a new DQN agent
    pub fn new(config: DqnConfig) -> Self {
        let device = MyDevice::default();

        let q_network = DqnNetwork::new(
            config.state_size,
            config.hidden_size,
            config.action_size,
            &device,
        );

        let target_network = DqnNetwork::new(
            config.state_size,
            config.hidden_size,
            config.action_size,
            &device,
        );

        let optimizer = AdamConfig::new().init();

        Self {
            config: config.clone(),
            device,
            q_network,
            target_network,
            optimizer,
            replay_buffer: VecDeque::new(),
            epsilon: config.epsilon,
            step_count: 0,
        }
    }

    /// Selects an action with an epsilon-greedy policy
    pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize {
        if valid_actions.is_empty() {
            return 0;
        }

        // Epsilon-greedy exploration
        if rand::random::<f32>() < self.epsilon {
            let random_index = rand::random::<usize>() % valid_actions.len();
            return valid_actions[random_index];
        }

        // Exploitation: pick the best action according to the Q-network
        let state_tensor = Tensor::<MyBackend, 2>::from_floats([state], &self.device);
        let q_values = self.q_network.forward(state_tensor);

        // Convert into a vector for processing
        let q_data = q_values.into_data().convert::<f32>().value;

        // Find the best action among the valid ones
        let mut best_action = valid_actions[0];
        let mut best_q_value = f32::NEG_INFINITY;

        for &action in valid_actions {
            if action < q_data.len() && q_data[action] > best_q_value {
                best_q_value = q_data[action];
                best_action = action;
            }
        }

        best_action
    }

    /// Adds an experience to the replay buffer
    pub fn add_experience(&mut self, experience: Experience) {
        if self.replay_buffer.len() >= self.config.replay_buffer_size {
            self.replay_buffer.pop_front();
        }
        self.replay_buffer.push_back(experience);
    }

    /// Trains the network on a batch of experiences
    pub fn train_step(&mut self) -> Option<f32> {
        if self.replay_buffer.len() < self.config.batch_size {
            return None;
        }

        // Sample a batch of experiences
        let batch = self.sample_batch();

        // Prepare the state tensors
        let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect();
        let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);

        // Compute the current Q-values
        let current_q_values = self.q_network.forward(state_tensor);

        // For now, a simplified version without target computation
        let target_q_values = current_q_values.clone();
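        // A full DQN update would instead build the targets from the Bellman equation,
        //   target[i][a_i] = reward_i + gamma * max_a' Q_target(next_state_i, a')
        // (just reward_i when the episode is done), evaluating `self.target_network`
        // on the batch's next states; reusing the current Q-values as their own target
        // makes the loss below effectively a placeholder.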

        // Compute the MSE loss
        let loss = MseLoss::new().forward(
            current_q_values,
            target_q_values,
            burn::nn::loss::Reduction::Mean
        );

        // Backpropagation (simplified version)
        let grads = loss.backward();
        self.q_network = self.optimizer.step(self.config.learning_rate, self.q_network, grads);

        // Update the target network
        self.step_count += 1;
        if self.step_count % self.config.target_update_freq == 0 {
            self.update_target_network();
        }

        // Epsilon decay
        if self.epsilon > self.config.epsilon_min {
            self.epsilon *= self.config.epsilon_decay;
        }

        Some(loss.into_scalar())
    }

    /// Samples a batch of experiences from the replay buffer
    fn sample_batch(&self) -> Vec<Experience> {
        let mut batch = Vec::new();
        let buffer_size = self.replay_buffer.len();

        for _ in 0..self.config.batch_size.min(buffer_size) {
            let index = rand::random::<usize>() % buffer_size;
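            // Indices are drawn independently, so the same experience can appear more
            // than once in a batch (sampling with replacement).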
            if let Some(exp) = self.replay_buffer.get(index) {
                batch.push(exp.clone());
            }
        }

        batch
    }

    /// Updates the target network with the weights of the main network
    fn update_target_network(&mut self) {
        // Simple copy of the weights
        self.target_network = self.q_network.clone();
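        // This is a hard copy every `target_update_freq` training steps; a soft
        // (Polyak-averaged) update is a common alternative in DQN variants.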
    }

    /// Saves the model
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Save the configuration
        let config_path = format!("{}_config.json", path);
        let config_json = serde_json::to_string_pretty(&self.config)?;
        std::fs::write(config_path, config_json)?;

        // Save the network for inference (conversion to the NdArray backend)
        let inference_network = self.q_network.clone().into_record();
        let recorder = CompactRecorder::new();

        let model_path = format!("{}_model.burn", path);
        recorder.record(inference_network, model_path.into())?;

        println!("Model saved: {}", path);
        Ok(())
    }

    /// Loads a model for inference
    pub fn load_model_for_inference(path: &str) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
        // Load the configuration
        let config_path = format!("{}_config.json", path);
        let config_json = std::fs::read_to_string(config_path)?;
        let config: DqnConfig = serde_json::from_str(&config_json)?;

        // Create the network for inference
        let device = NdArrayDevice::default();
        let network = DqnNetwork::<InferenceBackend>::new(
            config.state_size,
            config.hidden_size,
            config.action_size,
            &device,
        );

        // Load the weights
        let model_path = format!("{}_model.burn", path);
        let recorder = CompactRecorder::new();
        let record = recorder.load(model_path.into(), &device)?;
        let network = network.load_record(record);

        Ok((network, config))
    }

    /// Returns the current epsilon
    pub fn get_epsilon(&self) -> f32 {
        self.epsilon
    }

    /// Returns the size of the replay buffer
    pub fn get_buffer_size(&self) -> usize {
        self.replay_buffer.len()
    }
}

bot/src/strategy/burn_dqn_strategy.rs (new file, 192 lines)
@@ -0,0 +1,192 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use super::burn_dqn_agent::{DqnNetwork, DqnConfig, InferenceBackend};
use super::dqn_common::get_valid_actions;
use burn::{backend::ndarray::NdArrayDevice, tensor::Tensor};
use std::path::Path;

/// Strategy driven by a trained Burn DQN model
#[derive(Debug)]
pub struct BurnDqnStrategy {
    pub game: GameState,
    pub player_id: PlayerId,
    pub color: Color,
    network: Option<DqnNetwork<InferenceBackend>>,
    config: Option<DqnConfig>,
    device: NdArrayDevice,
}

impl Default for BurnDqnStrategy {
    fn default() -> Self {
        Self {
            game: GameState::default(),
            player_id: 0,
            color: Color::White,
            network: None,
            config: None,
            device: NdArrayDevice::default(),
        }
    }
}

impl BurnDqnStrategy {
    /// Creates a new strategy with a loaded model
    pub fn new(model_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let mut strategy = Self::default();
        strategy.load_model(model_path)?;
        Ok(strategy)
    }

    /// Loads a DQN model from a file
    pub fn load_model(&mut self, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if !Path::new(&format!("{}_config.json", model_path)).exists() {
            return Err(format!("Model not found: {}", model_path).into());
        }

        let (network, config) = super::burn_dqn_agent::BurnDqnAgent::load_model_for_inference(model_path)?;

        self.network = Some(network);
        self.config = Some(config);

        println!("Burn DQN model loaded from: {}", model_path);
        Ok(())
    }

    /// Selects the best action according to the DQN model
    fn select_best_action(&self, valid_actions: &[super::dqn_common::TrictracAction]) -> Option<super::dqn_common::TrictracAction> {
        if valid_actions.is_empty() {
            return None;
        }

        // If no network is loaded, fall back to the first valid action
        let Some(network) = &self.network else {
            return Some(valid_actions[0].clone());
        };

        // Convert the game state into a tensor
        let state_vec = self.game.to_vec_float();
        let state_tensor = Tensor::<InferenceBackend, 2>::from_floats([state_vec], &self.device);

        // Run a prediction
        let q_values = network.forward(state_tensor);
        let q_data = q_values.into_data().convert::<f32>().value;
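        // Note: the Q-values are read by position `i` in `valid_actions` below, which
        // assumes the network was trained with the same positional action indexing as
        // the training binary (train_dqn_full.rs).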

        // Find the best action among the valid ones
        let mut best_action = &valid_actions[0];
        let mut best_q_value = f32::NEG_INFINITY;

        for (i, action) in valid_actions.iter().enumerate() {
            if i < q_data.len() && q_data[i] > best_q_value {
                best_q_value = q_data[i];
                best_action = action;
            }
        }

        Some(best_action.clone())
    }

    /// Converts a TrictracAction into a pair of CheckerMoves
    fn trictrac_action_to_moves(&self, action: &super::dqn_common::TrictracAction) -> Option<(CheckerMove, CheckerMove)> {
        match action {
            super::dqn_common::TrictracAction::Move { dice_order, from1, from2 } => {
                let dice = self.game.dice;
                let (die1, die2) = if *dice_order {
                    (dice.values.0, dice.values.1)
                } else {
                    (dice.values.1, dice.values.0)
                };

                // Compute the destinations according to the color
                let to1 = if self.color == Color::White {
                    from1 + die1 as usize
                } else {
                    from1.saturating_sub(die1 as usize)
                };
                let to2 = if self.color == Color::White {
                    from2 + die2 as usize
                } else {
                    from2.saturating_sub(die2 as usize)
                };

                // Build the moves
                let move1 = CheckerMove::new(*from1, to1).ok()?;
                let move2 = CheckerMove::new(*from2, to2).ok()?;

                Some((move1, move2))
            }
            _ => None,
        }
    }
}

impl BotStrategy for BurnDqnStrategy {
    fn get_game(&self) -> &GameState {
        &self.game
    }

    fn get_mut_game(&mut self) -> &mut GameState {
        &mut self.game
    }

    fn calculate_points(&self) -> u8 {
        // Use the DQN model to decide which points to score
        let valid_actions = get_valid_actions(&self.game);

        // Look for a Mark action among the valid actions
        for action in &valid_actions {
            if let super::dqn_common::TrictracAction::Mark { points } = action {
                return *points;
            }
        }

        // By default, score 0 points
        0
    }

    fn calculate_adv_points(&self) -> u8 {
        // Same logic as calculate_points, for the opponent's points
        self.calculate_points()
    }

    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
        let valid_actions = get_valid_actions(&self.game);

        if let Some(best_action) = self.select_best_action(&valid_actions) {
            if let Some((move1, move2)) = self.trictrac_action_to_moves(&best_action) {
                return (move1, move2);
            }
        }

        // Fallback: use the default strategy
        let default_strategy = super::default::DefaultStrategy::default();
        default_strategy.choose_move()
    }

    fn choose_go(&self) -> bool {
        let valid_actions = get_valid_actions(&self.game);

        if let Some(best_action) = self.select_best_action(&valid_actions) {
            match best_action {
                super::dqn_common::TrictracAction::Go => return true,
                super::dqn_common::TrictracAction::Move { .. } => return false,
                _ => {}
            }
        }

        // By default, always choose to continue
        true
    }

    fn set_player_id(&mut self, player_id: PlayerId) {
        self.player_id = player_id;
    }

    fn set_color(&mut self, color: Color) {
        self.color = color;
    }
}

/// Factory function that creates a Burn DQN strategy from a model path
pub fn create_burn_dqn_strategy(model_path: &str) -> Result<Box<dyn BotStrategy>, Box<dyn std::error::Error>> {
    let strategy = BurnDqnStrategy::new(model_path)?;
    Ok(Box::new(strategy))
}

@@ -230,3 +230,23 @@ Available options:
 - --help : full help
 
 This trainer serves as a base for testing the Burn-RL environment. Once everything works well, a real DQN agent with neural networks can be plugged into it!
+
+---
+
+A session during which I kept receiving this message:
+
+⎿ API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute. For details, refer to: <https://docs.anthropic.com/en/api/rate-limits>. You can see the response headers for current usage. Please reduce the prompt length or the maximum tokens requested, or try again later. You may also contact sales at <https://www.anthropic.com/contact-sales> to discuss your options for a rate limit increase."}}) · Retrying in 391 seconds… (attempt 1/10)
+
+✶ Coaching… (403s · ↑ 382 tokens · esc to interrupt)
+
+only to end the session with this summary:
+
+Total cost: $0.95
+Total duration (API): 1h 24m 22.8s
+Total duration (wall): 1h 43m 3.5s
+Total code changes: 746 lines added, 0 lines removed
+Token usage by model:
+claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
+claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
+
+And yet the usage page <https://console.anthropic.com/usage> reports 2 million tokens, and <https://console.anthropic.com/cost> shows $7.88 consumed.