train command

This commit is contained in:
Henri Bourcereau 2025-05-26 20:44:35 +02:00
parent 480b2ff427
commit ab959fa27b
9 changed files with 846 additions and 422 deletions

View file

@ -5,9 +5,14 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[[bin]]
name = "train_dqn"
path = "src/bin/train_dqn.rs"
[dependencies]
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
store = { path = "../store" }
rand = "0.8"
env_logger = "0.10"

108
bot/src/bin/train_dqn.rs Normal file
View file

@ -0,0 +1,108 @@
use bot::strategy::dqn_trainer::{DqnTrainer};
use bot::strategy::dqn_common::DqnConfig;
use std::env;
/// Command-line entry point that trains a DQN model for Trictrac.
///
/// Recognised options are `--episodes <NUM>`, `--model-path <PATH>`,
/// `--save-every <NUM>` and `-h` / `--help`. A missing value, a value that is
/// not a valid number, a `--save-every` of zero or an unknown argument prints
/// an error on stderr and exits with status 1.
///
/// # Errors
/// Returns an error when the output directory cannot be created or when the
/// trainer fails (for instance while saving a model).
fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init();

    // Default parameters.
    let mut episodes: usize = 1000;
    let mut model_path = "models/dqn_model".to_string();
    let mut save_every: usize = 100;

    // Parse the command-line arguments (the program name is skipped).
    let mut args = env::args().skip(1);
    while let Some(flag) = args.next() {
        match flag.as_str() {
            "--episodes" => {
                episodes = parse_count(&flag, &require_value(&flag, args.next()));
            }
            "--model-path" => {
                model_path = require_value(&flag, args.next());
            }
            "--save-every" => {
                save_every = parse_count(&flag, &require_value(&flag, args.next()));
                // The trainer computes `episode % save_every`; zero would panic.
                if save_every == 0 {
                    eprintln!("Erreur : --save-every doit être supérieur à 0");
                    std::process::exit(1);
                }
            }
            "--help" | "-h" => {
                print_help();
                std::process::exit(0);
            }
            _ => {
                eprintln!("Argument inconnu : {}", flag);
                print_help();
                std::process::exit(1);
            }
        }
    }

    // Create the directory that will receive the model files, wherever
    // `--model-path` points (the default path lives under `models/`).
    if let Some(parent) = std::path::Path::new(&model_path).parent() {
        if !parent.as_os_str().is_empty() {
            std::fs::create_dir_all(parent)?;
        }
    }

    println!("Configuration d'entraînement DQN :");
    println!(" Épisodes : {}", episodes);
    println!(" Chemin du modèle : {}", model_path);
    println!(" Sauvegarde tous les {} épisodes", save_every);
    println!();

    // DQN configuration.
    let config = DqnConfig {
        input_size: 32,
        hidden_size: 256,
        num_actions: 3,
        learning_rate: 0.001,
        gamma: 0.99,
        epsilon: 0.9, // Start with more exploration than the default configuration.
        epsilon_decay: 0.995,
        epsilon_min: 0.01,
        replay_buffer_size: 10000,
        batch_size: 32,
    };

    // Build the trainer and run it.
    let mut trainer = DqnTrainer::new(config);
    trainer.train(episodes, save_every, &model_path)?;

    println!("Entraînement terminé avec succès !");
    println!("Pour utiliser le modèle entraîné :");
    println!(" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy", model_path);
    Ok(())
}

/// Returns the value following `flag`, or reports the missing value and exits with status 1.
fn require_value(flag: &str, value: Option<String>) -> String {
    value.unwrap_or_else(|| {
        eprintln!("Erreur : {} nécessite une valeur", flag);
        std::process::exit(1)
    })
}

/// Parses `raw` as an unsigned count, or reports the invalid value and exits with status 1.
fn parse_count(flag: &str, raw: &str) -> usize {
    raw.parse().unwrap_or_else(|_| {
        eprintln!("Erreur : valeur invalide pour {} : {}", flag, raw);
        std::process::exit(1)
    })
}
/// Prints the usage text of the `train_dqn` command on standard output.
fn print_help() {
    // One entry per output line; empty entries produce blank lines.
    const HELP_LINES: &[&str] = &[
        "Entraîneur DQN pour Trictrac",
        "",
        "USAGE:",
        " cargo run --bin=train_dqn [OPTIONS]",
        "",
        "OPTIONS:",
        " --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)",
        " --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)",
        " --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)",
        " -h, --help Afficher cette aide",
        "",
        "EXEMPLES:",
        " cargo run --bin=train_dqn",
        " cargo run --bin=train_dqn -- --episodes 5000 --save-every 500",
        " cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000",
    ];
    for line in HELP_LINES {
        println!("{}", line);
    }
}

View file

@ -1,4 +1,4 @@
mod strategy;
pub mod strategy;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy;

View file

@ -1,5 +1,7 @@
pub mod client;
pub mod default;
pub mod dqn;
pub mod dqn_common;
pub mod dqn_trainer;
pub mod erroneous_moves;
pub mod stable_baselines3;

View file

@ -1,373 +1,25 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use store::MoveRules;
use rand::{thread_rng, Rng};
use std::collections::VecDeque;
use std::path::Path;
use serde::{Deserialize, Serialize};
/// Configuration for the DQN agent.
///
/// Bundles the network dimensions, the exploration schedule and the replay
/// buffer settings. The type is (de)serializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
/// Number of network inputs; `DqnAgent::new` passes it to `SimpleNeuralNetwork::new`.
pub input_size: usize,
/// Number of neurons in each hidden layer of the network.
pub hidden_size: usize,
/// Number of network outputs; also the exclusive upper bound of random actions.
pub num_actions: usize,
/// NOTE(review): not read by the code in this file; presumably the optimizer step size.
pub learning_rate: f64,
/// NOTE(review): not read by the code in this file; presumably the reward discount factor.
pub gamma: f64,
/// Initial probability of picking a random action in `DqnAgent::select_action`.
pub epsilon: f64,
/// Factor the agent's epsilon is multiplied by on each effective `DqnAgent::train` call.
pub epsilon_decay: f64,
/// Lower bound applied to the agent's epsilon after decay.
pub epsilon_min: f64,
/// Capacity handed to `ReplayBuffer::new`.
pub replay_buffer_size: usize,
/// `DqnAgent::train` returns early while the replay buffer holds fewer experiences than this.
pub batch_size: usize,
}
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
impl Default for DqnConfig {
    /// Returns the default configuration: a 32-input, 256-hidden, 3-action
    /// network with an exploration rate starting at 0.1.
    fn default() -> Self {
        // Network dimensions.
        let (input_size, hidden_size, num_actions) = (32, 256, 3);
        // Learning parameters.
        let (learning_rate, gamma) = (0.001, 0.99);
        // Exploration rate: initial value, decay factor, lower bound.
        let (epsilon, epsilon_decay, epsilon_min) = (0.1, 0.995, 0.01);
        // Replay memory settings.
        let (replay_buffer_size, batch_size) = (10000, 32);

        Self {
            input_size,
            hidden_size,
            num_actions,
            learning_rate,
            gamma,
            epsilon,
            epsilon_decay,
            epsilon_min,
            replay_buffer_size,
            batch_size,
        }
    }
}
/// Simplified DQN neural network stored as plain weight matrices.
///
/// The network has two hidden layers followed by an output layer. Each weight
/// matrix holds one row per neuron of its layer and one column per input of
/// that layer (see `new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleNeuralNetwork {
// First hidden layer: `hidden_size` rows of `input_size` weights.
weights1: Vec<Vec<f32>>,
// One bias per neuron of the first hidden layer.
biases1: Vec<f32>,
// Second hidden layer: `hidden_size` rows of `hidden_size` weights.
weights2: Vec<Vec<f32>>,
// One bias per neuron of the second hidden layer.
biases2: Vec<f32>,
// Output layer: `output_size` rows of `hidden_size` weights.
weights3: Vec<Vec<f32>>,
// One bias per output.
biases3: Vec<f32>,
}
impl SimpleNeuralNetwork {
    /// Creates a network with two hidden layers of `hidden_size` neurons and
    /// `output_size` outputs.
    ///
    /// Weights are drawn uniformly from `(-s, s)` with `s = sqrt(2 / fan_in)`,
    /// where `fan_in` is the number of inputs of the layer; biases start at
    /// zero. (The previous comment called this Xavier/Glorot initialisation,
    /// but the bound actually used is `sqrt(2 / fan_in)`.)
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
        let mut rng = thread_rng();

        // Builds a `rows x cols` matrix; matrices are drawn in layer order so
        // the sequence of random draws matches the layer-by-layer construction.
        let mut random_matrix = |rows: usize, cols: usize| -> Vec<Vec<f32>> {
            let scale = (2.0 / cols as f32).sqrt();
            (0..rows)
                .map(|_| (0..cols).map(|_| rng.gen_range(-scale..scale)).collect())
                .collect()
        };

        let weights1 = random_matrix(hidden_size, input_size);
        let weights2 = random_matrix(hidden_size, hidden_size);
        let weights3 = random_matrix(output_size, hidden_size);

        Self {
            weights1,
            biases1: vec![0.0; hidden_size],
            weights2,
            biases2: vec![0.0; hidden_size],
            weights3,
            biases3: vec![0.0; output_size],
        }
    }

    /// Runs a forward pass and returns one value per output neuron.
    ///
    /// Both hidden layers use ReLU; the output layer is linear. Inputs beyond
    /// the number of weights of a neuron (or weights beyond the number of
    /// inputs) are ignored.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let hidden1 = Self::dense_layer(&self.weights1, &self.biases1, input, true);
        let hidden2 = Self::dense_layer(&self.weights2, &self.biases2, &hidden1, true);
        Self::dense_layer(&self.weights3, &self.biases3, &hidden2, false)
    }

    /// Computes `bias + sum(input[j] * weights[i][j])` for every neuron `i`,
    /// optionally followed by ReLU.
    fn dense_layer(weights: &[Vec<f32>], biases: &[f32], input: &[f32], relu: bool) -> Vec<f32> {
        let mut activations = biases.to_vec();
        // `zip` stops at the shorter side, so mismatched lengths cannot index
        // out of bounds.
        for (value, row) in activations.iter_mut().zip(weights) {
            let sum = row
                .iter()
                .zip(input)
                .fold(*value, |acc, (&weight, &x)| acc + x * weight);
            *value = if relu { sum.max(0.0) } else { sum };
        }
        activations
    }

    /// Returns the index of the largest output of `forward(input)`.
    ///
    /// NaN outputs are ignored instead of causing a panic; `0` is returned
    /// when no comparable output exists.
    pub fn get_best_action(&self, input: &[f32]) -> usize {
        self.forward(input)
            .iter()
            .enumerate()
            .filter(|(_, q)| !q.is_nan())
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(index, _)| index)
            .unwrap_or(0)
    }
}
/// One transition stored in the replay buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
/// State vector the action was selected from.
pub state: Vec<f32>,
/// Index of the action that was taken.
pub action: usize,
/// Reward returned by the environment for this step.
pub reward: f32,
/// State vector returned by the environment after the step.
pub next_state: Vec<f32>,
/// Whether the environment reported the episode as finished.
pub done: bool,
}
/// Replay buffer storing past experiences.
#[derive(Debug)]
pub struct ReplayBuffer {
// Stored experiences, oldest at the front.
buffer: VecDeque<Experience>,
// Size at which `push` starts evicting the oldest experience.
capacity: usize,
}
impl ReplayBuffer {
    /// Creates an empty buffer that keeps at most `capacity` experiences.
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Stores `experience`, evicting the oldest one when the buffer is full.
    ///
    /// A buffer created with a capacity of zero stores nothing.
    pub fn push(&mut self, experience: Experience) {
        if self.capacity == 0 {
            return;
        }
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Returns `batch_size` experiences drawn at random (the same experience
    /// may be drawn more than once).
    ///
    /// When fewer than `batch_size` experiences are stored, all of them are
    /// returned in insertion order.
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let len = self.buffer.len();
        if len < batch_size {
            return self.buffer.iter().cloned().collect();
        }
        let mut rng = thread_rng();
        (0..batch_size)
            .map(|_| self.buffer[rng.gen_range(0..len)].clone())
            .collect()
    }

    /// Number of experiences currently stored.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` when no experience is stored.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
/// DQN agent for reinforcement learning.
#[derive(Debug)]
pub struct DqnAgent {
// Settings the agent was created with.
config: DqnConfig,
// Network used to select actions; this is the one saved and loaded.
model: SimpleNeuralNetwork,
// Copy of `model`, refreshed every 100 effective `train` calls.
target_model: SimpleNeuralNetwork,
// Stored experiences.
replay_buffer: ReplayBuffer,
// Current probability of choosing a random action.
epsilon: f64,
// Number of `train` calls that got past the batch-size guard.
step_count: usize,
}
impl DqnAgent {
    /// Builds an agent from `config`: a freshly initialised network, an
    /// identical target network and an empty replay buffer.
    pub fn new(config: DqnConfig) -> Self {
        let online = SimpleNeuralNetwork::new(config.input_size, config.hidden_size, config.num_actions);
        Self {
            target_model: online.clone(),
            model: online,
            replay_buffer: ReplayBuffer::new(config.replay_buffer_size),
            epsilon: config.epsilon,
            step_count: 0,
            config,
        }
    }

    /// Picks an action for `state`: a random one with probability `epsilon`,
    /// otherwise the action the model rates best.
    pub fn select_action(&mut self, state: &[f32]) -> usize {
        let mut rng = thread_rng();
        let explore = rng.gen::<f64>() < self.epsilon;
        if !explore {
            // Exploitation: best action according to the model.
            return self.model.get_best_action(state);
        }
        // Exploration: random action.
        rng.gen_range(0..self.config.num_actions)
    }

    /// Adds `experience` to the replay buffer.
    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.push(experience);
    }

    /// Performs one training step once the buffer holds at least one batch.
    ///
    /// For now training is only simulated: epsilon is decayed and the target
    /// model is refreshed every 100 steps. A complete implementation would run
    /// backpropagation here.
    pub fn train(&mut self) {
        let has_full_batch = self.replay_buffer.len() >= self.config.batch_size;
        if !has_full_batch {
            return;
        }

        let decayed = self.epsilon * self.config.epsilon_decay;
        self.epsilon = decayed.max(self.config.epsilon_min);
        self.step_count += 1;

        // Refresh the target model every 100 steps.
        if self.step_count % 100 != 0 {
            return;
        }
        self.target_model = self.model.clone();
    }

    /// Writes the model to `path` as pretty-printed JSON.
    pub fn save_model<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
        let json = serde_json::to_string_pretty(&self.model)?;
        std::fs::write(path, json)?;
        Ok(())
    }

    /// Replaces the model with the one read from the JSON file at `path` and
    /// copies it into the target model.
    pub fn load_model<P: AsRef<Path>>(&mut self, path: P) -> Result<(), Box<dyn std::error::Error>> {
        let json = std::fs::read_to_string(path)?;
        self.model = serde_json::from_str(&json)?;
        self.target_model = self.model.clone();
        Ok(())
    }
}
/// Trictrac environment used for training.
#[derive(Debug)]
pub struct TrictracEnv {
/// Game state the environment operates on.
pub game_state: GameState,
/// Player id of the agent; `new` sets it to 1.
/// NOTE(review): presumably the id assigned by `init_player("agent")` — confirm.
pub agent_player_id: PlayerId,
/// Player id of the opponent; `new` sets it to 2.
pub opponent_player_id: PlayerId,
/// Color played by the agent; `new` sets it to white.
pub agent_color: Color,
/// Step count at which `step` reports the episode as done.
pub max_steps: usize,
/// Number of `step` calls since the last `reset`.
pub current_step: usize,
}
impl TrictracEnv {
    /// Creates an environment with a new game and two registered players
    /// ("agent" and "opponent").
    pub fn new() -> Self {
        let mut game_state = GameState::new(false);
        game_state.init_player("agent");
        game_state.init_player("opponent");
        Self {
            game_state,
            agent_player_id: 1,
            opponent_player_id: 2,
            agent_color: Color::White,
            max_steps: 1000,
            current_step: 0,
        }
    }

    /// Starts a new game, resets the step counter and returns the initial
    /// state vector.
    pub fn reset(&mut self) -> Vec<f32> {
        self.game_state = GameState::new(false);
        self.game_state.init_player("agent");
        self.game_state.init_player("opponent");
        self.current_step = 0;
        self.get_state_vector()
    }

    /// Advances the environment by one step and returns
    /// `(next_state, reward, done)`.
    ///
    /// Simplified for now: the action is ignored and the reward is always 0.
    /// The episode is reported as done when the game has ended, a winner
    /// exists, or `max_steps` steps have been taken.
    pub fn step(&mut self, _action: usize) -> (Vec<f32>, f32, bool) {
        let reward = 0.0; // Simplified for now.

        // Count this step before testing the limit so that exactly
        // `max_steps` steps are allowed (not `max_steps + 1`).
        self.current_step += 1;
        let done = self.game_state.stage == store::Stage::Ended
            || self.game_state.determine_winner().is_some()
            || self.current_step >= self.max_steps;

        (self.get_state_vector(), reward, done)
    }

    /// Encodes the game state as a vector of exactly 32 values: 24 board
    /// cells (white counts positive, black counts negative), the active
    /// player id, both dice, the white player's points and holes, then zero
    /// padding.
    pub fn get_state_vector(&self) -> Vec<f32> {
        const BOARD_CELLS: usize = 24;
        const STATE_LEN: usize = 32;

        let mut features: Vec<f32> = Vec::with_capacity(STATE_LEN);

        // Board: positions outside 0..24 are ignored.
        // NOTE(review): if fields are numbered from 1, field 24 is dropped — confirm.
        let white_positions = self.game_state.board.get_color_fields(Color::White);
        let black_positions = self.game_state.board.get_color_fields(Color::Black);
        let mut board = vec![0.0_f32; BOARD_CELLS];
        for (pos, count) in white_positions {
            if let Some(cell) = board.get_mut(pos) {
                *cell = count as f32;
            }
        }
        for (pos, count) in black_positions {
            if let Some(cell) = board.get_mut(pos) {
                *cell = -(count as f32);
            }
        }
        features.extend(board);

        // Extra information, kept short to fit the 32-value input.
        features.push(self.game_state.active_player_id as f32);
        features.push(self.game_state.dice.values.0 as f32);
        features.push(self.game_state.dice.values.1 as f32);

        // Points and holes of the white player (zeros when absent).
        match self.game_state.get_white_player() {
            Some(white_player) => {
                features.push(white_player.points as f32);
                features.push(white_player.holes as f32);
            }
            None => features.extend([0.0, 0.0]),
        }

        // Truncate or zero-pad to exactly the input size.
        features.resize(STATE_LEN, 0.0);
        features
    }
}
/// Stratégie DQN pour le bot
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
#[derive(Debug)]
pub struct DqnStrategy {
pub game: GameState,
pub player_id: PlayerId,
pub color: Color,
pub agent: Option<DqnAgent>,
pub env: TrictracEnv,
pub model: Option<SimpleNeuralNetwork>,
}
impl Default for DqnStrategy {
fn default() -> Self {
let game = GameState::default();
let config = DqnConfig::default();
let agent = DqnAgent::new(config);
let env = TrictracEnv::new();
Self {
game,
game: GameState::default(),
player_id: 2,
color: Color::Black,
agent: Some(agent),
env,
model: None,
}
}
}
@ -377,54 +29,22 @@ impl DqnStrategy {
Self::default()
}
pub fn new_with_model(model_path: &str) -> Self {
pub fn new_with_model<P: AsRef<Path>>(model_path: P) -> Self {
let mut strategy = Self::new();
if let Some(ref mut agent) = strategy.agent {
let _ = agent.load_model(model_path);
if let Ok(model) = SimpleNeuralNetwork::load(model_path) {
strategy.model = Some(model);
}
strategy
}
pub fn train_episode(&mut self) -> f32 {
let mut total_reward = 0.0;
let mut state = self.env.reset();
loop {
let action = if let Some(ref mut agent) = self.agent {
agent.select_action(&state)
} else {
0
};
let (next_state, reward, done) = self.env.step(action);
total_reward += reward;
if let Some(ref mut agent) = self.agent {
let experience = Experience {
state: state.clone(),
action,
reward,
next_state: next_state.clone(),
done,
};
agent.store_experience(experience);
agent.train();
}
if done {
break;
}
state = next_state;
/// Utilise le modèle DQN pour choisir une action
fn get_dqn_action(&self) -> Option<usize> {
if let Some(ref model) = self.model {
let state = game_state_to_vector(&self.game);
Some(model.get_best_action(&state))
} else {
None
}
total_reward
}
pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
if let Some(ref agent) = self.agent {
agent.save_model(path)?;
}
Ok(())
}
}
@ -447,6 +67,7 @@ impl BotStrategy for DqnStrategy {
fn calculate_points(&self) -> u8 {
// Pour l'instant, utilisation de la méthode standard
// Plus tard on pourrait utiliser le DQN pour optimiser le calcul de points
let dice_roll_count = self
.get_game()
.players
@ -462,34 +83,33 @@ impl BotStrategy for DqnStrategy {
}
fn choose_go(&self) -> bool {
// Utiliser le DQN pour décider (simplifié pour l'instant)
if let Some(ref agent) = self.agent {
let state = self.env.get_state_vector();
// Action 2 = "go", on vérifie si c'est la meilleure action
let q_values = agent.model.forward(&state);
if q_values.len() > 2 {
return q_values[2] > q_values[0] && q_values[2] > *q_values.get(1).unwrap_or(&0.0);
}
// Utiliser le DQN pour décider si on continue (action 2 = "go")
if let Some(action) = self.get_dqn_action() {
// Si le modèle prédit l'action "go" (2), on continue
action == 2
} else {
// Fallback : toujours continuer
true
}
true // Fallback
}
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Pour l'instant, utiliser la stratégie par défaut
// Plus tard, on pourrait utiliser le DQN pour choisir parmi les mouvements valides
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let chosen_move = if let Some(ref agent) = self.agent {
// Utiliser le DQN pour choisir le meilleur mouvement
let state = self.env.get_state_vector();
let action = agent.model.get_best_action(&state);
// Pour l'instant, on mappe simplement l'action à un mouvement
// Dans une implémentation complète, on aurait un espace d'action plus sophistiqué
let move_index = action.min(possible_moves.len().saturating_sub(1));
let chosen_move = if let Some(action) = self.get_dqn_action() {
// Utiliser l'action DQN pour choisir parmi les mouvements valides
// Action 0 = premier mouvement, action 1 = mouvement moyen, etc.
let move_index = if action == 0 {
0 // Premier mouvement
} else if action == 1 && possible_moves.len() > 1 {
possible_moves.len() / 2 // Mouvement du milieu
} else {
possible_moves.len().saturating_sub(1) // Dernier mouvement
};
*possible_moves.get(move_index).unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
} else {
// Fallback : premier mouvement valide
*possible_moves
.first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()))

View file

@ -0,0 +1,182 @@
use serde::{Deserialize, Serialize};
/// Configuration for the DQN agent.
///
/// Bundles the network dimensions, the exploration schedule and the replay
/// buffer settings. The type is (de)serializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
/// Number of network inputs; `DqnAgent::new` passes it to `SimpleNeuralNetwork::new`.
pub input_size: usize,
/// Number of neurons in each hidden layer of the network.
pub hidden_size: usize,
/// Number of network outputs; also the exclusive upper bound of random actions.
pub num_actions: usize,
/// NOTE(review): not read by the code visible here; presumably the optimizer step size.
pub learning_rate: f64,
/// NOTE(review): not read by the code visible here; presumably the reward discount factor.
pub gamma: f64,
/// Initial probability of picking a random action in `DqnAgent::select_action`.
pub epsilon: f64,
/// Factor the agent's epsilon is multiplied by on each effective `DqnAgent::train` call.
pub epsilon_decay: f64,
/// Lower bound applied to the agent's epsilon after decay.
pub epsilon_min: f64,
/// Capacity handed to `ReplayBuffer::new`.
pub replay_buffer_size: usize,
/// `DqnAgent::train` returns early while the replay buffer holds fewer experiences than this.
pub batch_size: usize,
}
impl Default for DqnConfig {
    /// Returns the default configuration: a 32-input, 256-hidden, 3-action
    /// network with an exploration rate starting at 0.1.
    fn default() -> Self {
        // Network dimensions.
        let (input_size, hidden_size, num_actions) = (32, 256, 3);
        // Learning parameters.
        let (learning_rate, gamma) = (0.001, 0.99);
        // Exploration rate: initial value, decay factor, lower bound.
        let (epsilon, epsilon_decay, epsilon_min) = (0.1, 0.995, 0.01);
        // Replay memory settings.
        let (replay_buffer_size, batch_size) = (10000, 32);

        Self {
            input_size,
            hidden_size,
            num_actions,
            learning_rate,
            gamma,
            epsilon,
            epsilon_decay,
            epsilon_min,
            replay_buffer_size,
            batch_size,
        }
    }
}
/// Simplified DQN neural network stored as plain weight matrices.
///
/// The network has two hidden layers followed by an output layer. Each weight
/// matrix holds one row per neuron of its layer and one column per input of
/// that layer (see `new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleNeuralNetwork {
/// First hidden layer: `hidden_size` rows of `input_size` weights.
pub weights1: Vec<Vec<f32>>,
/// One bias per neuron of the first hidden layer.
pub biases1: Vec<f32>,
/// Second hidden layer: `hidden_size` rows of `hidden_size` weights.
pub weights2: Vec<Vec<f32>>,
/// One bias per neuron of the second hidden layer.
pub biases2: Vec<f32>,
/// Output layer: `output_size` rows of `hidden_size` weights.
pub weights3: Vec<Vec<f32>>,
/// One bias per output.
pub biases3: Vec<f32>,
}
impl SimpleNeuralNetwork {
    /// Creates a network with two hidden layers of `hidden_size` neurons and
    /// `output_size` outputs.
    ///
    /// Weights are drawn uniformly from `(-s, s)` with `s = sqrt(2 / fan_in)`,
    /// where `fan_in` is the number of inputs of the layer; biases start at
    /// zero. (The previous comment called this Xavier/Glorot initialisation,
    /// but the bound actually used is `sqrt(2 / fan_in)`.)
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();

        // Builds a `rows x cols` matrix; matrices are drawn in layer order so
        // the sequence of random draws matches the layer-by-layer construction.
        let mut random_matrix = |rows: usize, cols: usize| -> Vec<Vec<f32>> {
            let scale = (2.0 / cols as f32).sqrt();
            (0..rows)
                .map(|_| (0..cols).map(|_| rng.gen_range(-scale..scale)).collect())
                .collect()
        };

        let weights1 = random_matrix(hidden_size, input_size);
        let weights2 = random_matrix(hidden_size, hidden_size);
        let weights3 = random_matrix(output_size, hidden_size);

        Self {
            weights1,
            biases1: vec![0.0; hidden_size],
            weights2,
            biases2: vec![0.0; hidden_size],
            weights3,
            biases3: vec![0.0; output_size],
        }
    }

    /// Runs a forward pass and returns one value per output neuron.
    ///
    /// Both hidden layers use ReLU; the output layer is linear. Inputs beyond
    /// the number of weights of a neuron (or weights beyond the number of
    /// inputs) are ignored.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let hidden1 = Self::dense_layer(&self.weights1, &self.biases1, input, true);
        let hidden2 = Self::dense_layer(&self.weights2, &self.biases2, &hidden1, true);
        Self::dense_layer(&self.weights3, &self.biases3, &hidden2, false)
    }

    /// Computes `bias + sum(input[j] * weights[i][j])` for every neuron `i`,
    /// optionally followed by ReLU.
    fn dense_layer(weights: &[Vec<f32>], biases: &[f32], input: &[f32], relu: bool) -> Vec<f32> {
        let mut activations = biases.to_vec();
        // `zip` stops at the shorter side, so a model deserialized with
        // mismatched lengths cannot index out of bounds.
        for (value, row) in activations.iter_mut().zip(weights) {
            let sum = row
                .iter()
                .zip(input)
                .fold(*value, |acc, (&weight, &x)| acc + x * weight);
            *value = if relu { sum.max(0.0) } else { sum };
        }
        activations
    }

    /// Returns the index of the largest output of `forward(input)`.
    ///
    /// NaN outputs are ignored instead of causing a panic; `0` is returned
    /// when no comparable output exists.
    pub fn get_best_action(&self, input: &[f32]) -> usize {
        self.forward(input)
            .iter()
            .enumerate()
            .filter(|(_, q)| !q.is_nan())
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(index, _)| index)
            .unwrap_or(0)
    }

    /// Writes the network to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Fails when serialization or the file write fails.
    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
        let data = serde_json::to_string_pretty(self)?;
        std::fs::write(path, data)?;
        Ok(())
    }

    /// Reads a network from the JSON file at `path`.
    ///
    /// # Errors
    /// Fails when the file cannot be read or does not contain a valid network.
    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
        let data = std::fs::read_to_string(path)?;
        let network = serde_json::from_str(&data)?;
        Ok(network)
    }
}
/// Converts the game state into the input vector of the neural network.
///
/// The result always holds exactly 32 values: 24 board cells (white counts
/// positive, black counts negative), the active player id, both dice, the
/// white player's points and holes, then zero padding.
pub fn game_state_to_vector(game_state: &crate::GameState) -> Vec<f32> {
    use crate::Color;

    const BOARD_CELLS: usize = 24;
    const STATE_LEN: usize = 32;

    let mut features: Vec<f32> = Vec::with_capacity(STATE_LEN);

    // Board: positions outside 0..24 are ignored.
    let white_positions = game_state.board.get_color_fields(Color::White);
    let black_positions = game_state.board.get_color_fields(Color::Black);
    let mut board = vec![0.0_f32; BOARD_CELLS];
    for (pos, count) in white_positions {
        if let Some(cell) = board.get_mut(pos) {
            *cell = count as f32;
        }
    }
    for (pos, count) in black_positions {
        if let Some(cell) = board.get_mut(pos) {
            *cell = -(count as f32);
        }
    }
    features.extend(board);

    // Extra information, kept short to fit the 32-value input.
    features.push(game_state.active_player_id as f32);
    features.push(game_state.dice.values.0 as f32);
    features.push(game_state.dice.values.1 as f32);

    // Points and holes of the white player (zeros when absent).
    match game_state.get_white_player() {
        Some(white_player) => {
            features.push(white_player.points as f32);
            features.push(white_player.holes as f32);
        }
        None => features.extend([0.0, 0.0]),
    }

    // Truncate or zero-pad to exactly the input size.
    features.resize(STATE_LEN, 0.0);
    features
}

View file

@ -0,0 +1,438 @@
use crate::{Color, GameState, PlayerId};
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use rand::{thread_rng, Rng};
use std::collections::VecDeque;
use serde::{Deserialize, Serialize};
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
/// One transition stored in the replay buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
/// State vector the action was selected from.
pub state: Vec<f32>,
/// Index of the action that was taken.
pub action: usize,
/// Reward returned by the environment for this step.
pub reward: f32,
/// State vector returned by the environment after the step.
pub next_state: Vec<f32>,
/// Whether the environment reported the episode as finished.
pub done: bool,
}
/// Replay buffer storing past experiences.
#[derive(Debug)]
pub struct ReplayBuffer {
// Stored experiences, oldest at the front.
buffer: VecDeque<Experience>,
// Size at which `push` starts evicting the oldest experience.
capacity: usize,
}
impl ReplayBuffer {
    /// Creates an empty buffer that keeps at most `capacity` experiences.
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Stores `experience`, evicting the oldest one when the buffer is full.
    ///
    /// A buffer created with a capacity of zero stores nothing.
    pub fn push(&mut self, experience: Experience) {
        if self.capacity == 0 {
            return;
        }
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Returns `batch_size` experiences drawn at random (the same experience
    /// may be drawn more than once).
    ///
    /// When fewer than `batch_size` experiences are stored, all of them are
    /// returned in insertion order.
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let len = self.buffer.len();
        if len < batch_size {
            return self.buffer.iter().cloned().collect();
        }
        let mut rng = thread_rng();
        (0..batch_size)
            .map(|_| self.buffer[rng.gen_range(0..len)].clone())
            .collect()
    }

    /// Number of experiences currently stored.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` when no experience is stored.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
/// DQN agent for reinforcement learning.
#[derive(Debug)]
pub struct DqnAgent {
// Settings the agent was created with.
config: DqnConfig,
// Network used to select actions; this is the one written by `save_model`.
model: SimpleNeuralNetwork,
// Copy of `model`, refreshed every 100 effective `train` calls.
target_model: SimpleNeuralNetwork,
// Stored experiences.
replay_buffer: ReplayBuffer,
// Current probability of choosing a random action.
epsilon: f64,
// Number of `train` calls that got past the batch-size guard.
step_count: usize,
}
impl DqnAgent {
    /// Builds an agent from `config`: a freshly initialised network, an
    /// identical target network and an empty replay buffer.
    pub fn new(config: DqnConfig) -> Self {
        let online = SimpleNeuralNetwork::new(config.input_size, config.hidden_size, config.num_actions);
        Self {
            target_model: online.clone(),
            model: online,
            replay_buffer: ReplayBuffer::new(config.replay_buffer_size),
            epsilon: config.epsilon,
            step_count: 0,
            config,
        }
    }

    /// Picks an action for `state`: a random one with probability `epsilon`,
    /// otherwise the action the model rates best.
    pub fn select_action(&mut self, state: &[f32]) -> usize {
        let mut rng = thread_rng();
        let explore = rng.gen::<f64>() < self.epsilon;
        if !explore {
            // Exploitation: best action according to the model.
            return self.model.get_best_action(state);
        }
        // Exploration: random action.
        rng.gen_range(0..self.config.num_actions)
    }

    /// Adds `experience` to the replay buffer.
    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.push(experience);
    }

    /// Performs one training step once the buffer holds at least one batch.
    ///
    /// For now training is only simulated: epsilon is decayed and the target
    /// model is refreshed every 100 steps. A complete implementation would run
    /// backpropagation here.
    pub fn train(&mut self) {
        let has_full_batch = self.replay_buffer.len() >= self.config.batch_size;
        if !has_full_batch {
            return;
        }

        let decayed = self.epsilon * self.config.epsilon_decay;
        self.epsilon = decayed.max(self.config.epsilon_min);
        self.step_count += 1;

        // Refresh the target model every 100 steps.
        if self.step_count % 100 != 0 {
            return;
        }
        self.target_model = self.model.clone();
    }

    /// Saves the model to `path` by delegating to `SimpleNeuralNetwork::save`.
    pub fn save_model<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
        self.model.save(path)
    }

    /// Current exploration rate.
    pub fn get_epsilon(&self) -> f64 {
        self.epsilon
    }

    /// Number of effective training steps performed so far.
    pub fn get_step_count(&self) -> usize {
        self.step_count
    }
}
/// Trictrac environment used for training.
#[derive(Debug)]
pub struct TrictracEnv {
/// Game state the environment operates on.
pub game_state: GameState,
/// Player id of the agent; `new` sets it to 1.
/// NOTE(review): presumably the id assigned by `init_player("agent")` — confirm.
pub agent_player_id: PlayerId,
/// Player id of the opponent; `new` sets it to 2.
pub opponent_player_id: PlayerId,
/// Color played by the agent; `new` sets it to white.
pub agent_color: Color,
/// Step count at which `step` reports the episode as done.
pub max_steps: usize,
/// Number of `step` calls since the last `reset`.
pub current_step: usize,
}
impl TrictracEnv {
pub fn new() -> Self {
let mut game_state = GameState::new(false);
game_state.init_player("agent");
game_state.init_player("opponent");
Self {
game_state,
agent_player_id: 1,
opponent_player_id: 2,
agent_color: Color::White,
max_steps: 1000,
current_step: 0,
}
}
pub fn reset(&mut self) -> Vec<f32> {
self.game_state = GameState::new(false);
self.game_state.init_player("agent");
self.game_state.init_player("opponent");
// Commencer la partie
self.game_state.consume(&GameEvent::BeginGame { goes_first: self.agent_player_id });
self.current_step = 0;
game_state_to_vector(&self.game_state)
}
pub fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool) {
let mut reward = 0.0;
// Appliquer l'action de l'agent
if self.game_state.active_player_id == self.agent_player_id {
reward += self.apply_agent_action(action);
}
// Faire jouer l'adversaire (stratégie simple)
while self.game_state.active_player_id == self.opponent_player_id
&& self.game_state.stage != Stage::Ended {
self.play_opponent_turn();
}
// Vérifier si la partie est terminée
let done = self.game_state.stage == Stage::Ended ||
self.game_state.determine_winner().is_some() ||
self.current_step >= self.max_steps;
// Récompense finale si la partie est terminée
if done {
if let Some(winner) = self.game_state.determine_winner() {
if winner == self.agent_player_id {
reward += 10.0; // Bonus pour gagner
} else {
reward -= 5.0; // Pénalité pour perdre
}
}
}
self.current_step += 1;
let next_state = game_state_to_vector(&self.game_state);
(next_state, reward, done)
}
fn apply_agent_action(&mut self, action: usize) -> f32 {
let mut reward = 0.0;
match self.game_state.turn_stage {
TurnStage::RollDice => {
// Lancer les dés
let event = GameEvent::Roll { player_id: self.agent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
// Simuler le résultat des dés
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.agent_player_id,
dice: store::Dice { values: dice_values },
};
if self.game_state.validate(&dice_event) {
self.game_state.consume(&dice_event);
}
reward += 0.1;
}
}
TurnStage::Move => {
// Choisir un mouvement selon l'action
let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if !possible_moves.is_empty() {
let move_index = if action == 0 {
0
} else if action == 1 && possible_moves.len() > 1 {
possible_moves.len() / 2
} else {
possible_moves.len().saturating_sub(1)
};
let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
let event = GameEvent::Move {
player_id: self.agent_player_id,
moves,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.2;
} else {
reward -= 1.0; // Pénalité pour mouvement invalide
}
}
}
TurnStage::MarkPoints => {
// Calculer et marquer les points
let dice_roll_count = self.game_state.players.get(&self.agent_player_id).unwrap().dice_roll_count;
let points_rules = PointsRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
let points = points_rules.get_points(dice_roll_count).0;
let event = GameEvent::Mark {
player_id: self.agent_player_id,
points,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.1 * points as f32; // Récompense proportionnelle aux points
}
}
TurnStage::HoldOrGoChoice => {
// Décider de continuer ou pas selon l'action
if action == 2 { // Action "go"
let event = GameEvent::Go { player_id: self.agent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.1;
}
} else {
// Passer son tour en jouant un mouvement
let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if !possible_moves.is_empty() {
let moves = possible_moves[0];
let event = GameEvent::Move {
player_id: self.agent_player_id,
moves,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
}
}
}
_ => {}
}
reward
}
fn play_opponent_turn(&mut self) {
match self.game_state.turn_stage {
TurnStage::RollDice => {
let event = GameEvent::Roll { player_id: self.opponent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.opponent_player_id,
dice: store::Dice { values: dice_values },
};
if self.game_state.validate(&dice_event) {
self.game_state.consume(&dice_event);
}
}
}
TurnStage::Move => {
let opponent_color = self.agent_color.opponent_color();
let rules = MoveRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if !possible_moves.is_empty() {
let moves = possible_moves[0]; // Stratégie simple : premier mouvement
let event = GameEvent::Move {
player_id: self.opponent_player_id,
moves,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
}
}
TurnStage::MarkPoints => {
let opponent_color = self.agent_color.opponent_color();
let dice_roll_count = self.game_state.players.get(&self.opponent_player_id).unwrap().dice_roll_count;
let points_rules = PointsRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
let points = points_rules.get_points(dice_roll_count).0;
let event = GameEvent::Mark {
player_id: self.opponent_player_id,
points,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
}
TurnStage::HoldOrGoChoice => {
// Stratégie simple : toujours continuer
let event = GameEvent::Go { player_id: self.opponent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
}
_ => {}
}
}
}
/// Trainer for the DQN model: pairs an agent with the environment it plays in.
pub struct DqnTrainer {
// Agent that selects actions and stores experiences.
agent: DqnAgent,
// Environment the agent is trained on.
env: TrictracEnv,
}
impl DqnTrainer {
    /// Creates a trainer with a new agent built from `config` and a new environment.
    pub fn new(config: DqnConfig) -> Self {
        Self {
            agent: DqnAgent::new(config),
            env: TrictracEnv::new(),
        }
    }

    /// Plays one episode until the environment reports it as done and returns
    /// the sum of the rewards received.
    ///
    /// Every step is stored as an experience and followed by one call to the
    /// agent's `train`.
    pub fn train_episode(&mut self) -> f32 {
        let mut total_reward = 0.0;
        let mut state = self.env.reset();

        loop {
            let action = self.agent.select_action(&state);
            let (next_state, reward, done) = self.env.step(action);
            total_reward += reward;

            // `state` is moved into the experience (it is replaced below),
            // so only `next_state` has to be cloned.
            self.agent.store_experience(Experience {
                state,
                action,
                reward,
                next_state: next_state.clone(),
                done,
            });
            self.agent.train();

            if done {
                break;
            }
            state = next_state;
        }

        total_reward
    }

    /// Runs `episodes` training episodes.
    ///
    /// A checkpoint is written to `{model_path}_episode_{n}.json` every
    /// `save_every` episodes (`0` disables periodic checkpoints) and the final
    /// model to `{model_path}_final.json`. Progress is printed every 100
    /// episodes.
    ///
    /// # Errors
    /// Fails when a model file cannot be saved.
    pub fn train(&mut self, episodes: usize, save_every: usize, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
        println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);

        for episode in 1..=episodes {
            let reward = self.train_episode();

            if episode % 100 == 0 {
                println!(
                    "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
                    episode, episodes, reward, self.agent.get_epsilon(), self.agent.get_step_count()
                );
            }

            // Guard against `save_every == 0`, which would panic on `% 0`.
            if save_every != 0 && episode % save_every == 0 {
                let save_path = format!("{}_episode_{}.json", model_path, episode);
                self.agent.save_model(&save_path)?;
                println!("Modèle sauvegardé : {}", save_path);
            }
        }

        // Save the final model.
        let final_path = format!("{}_final.json", model_path);
        self.agent.save_model(&final_path)?;
        println!("Modèle final sauvegardé : {}", final_path);
        Ok(())
    }
}