refactor
This commit is contained in:
parent ad58c0ec60
commit 2e0a874879
21 changed files with 23 additions and 1051 deletions
@@ -1,305 +0,0 @@
use burn::{
    backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
    module::Module,
    nn::{loss::MseLoss, Linear, LinearConfig},
    optim::Optimizer,
    record::{CompactRecorder, Recorder},
    tensor::Tensor,
};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Backend used for training (Autodiff + NdArray)
pub type MyBackend = Autodiff<NdArray>;
/// Backend used for inference (NdArray)
pub type InferenceBackend = NdArray;
pub type MyDevice = NdArrayDevice;

/// Neural network for the DQN
#[derive(Module, Debug)]
pub struct DqnNetwork<B: burn::prelude::Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
}

impl<B: burn::prelude::Backend> DqnNetwork<B> {
    /// Creates a new DQN network
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        output_size: usize,
        device: &B::Device,
    ) -> Self {
        let fc1 = LinearConfig::new(input_size, hidden_size).init(device);
        let fc2 = LinearConfig::new(hidden_size, hidden_size).init(device);
        let fc3 = LinearConfig::new(hidden_size, output_size).init(device);

        Self { fc1, fc2, fc3 }
    }

    /// Forward pass of the network
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.fc1.forward(input);
        let x = burn::tensor::activation::relu(x);
        let x = self.fc2.forward(x);
        let x = burn::tensor::activation::relu(x);
        self.fc3.forward(x)
    }
}

/// Configuration for DQN training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
    pub state_size: usize,
    pub action_size: usize,
    pub hidden_size: usize,
    pub learning_rate: f64,
    pub gamma: f32,
    pub epsilon: f32,
    pub epsilon_decay: f32,
    pub epsilon_min: f32,
    pub replay_buffer_size: usize,
    pub batch_size: usize,
    pub target_update_freq: usize,
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            state_size: 36,
            action_size: 1000,
            hidden_size: 256,
            learning_rate: 0.001,
            gamma: 0.99,
            epsilon: 1.0,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            target_update_freq: 100,
        }
    }
}

/// Experience for the replay buffer
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Vec<f32>,
    pub action: usize,
    pub reward: f32,
    pub next_state: Option<Vec<f32>>,
    pub done: bool,
}

/// DQN agent using Burn
pub struct BurnDqnAgent {
    config: DqnConfig,
    device: MyDevice,
    q_network: DqnNetwork<MyBackend>,
    target_network: DqnNetwork<MyBackend>,
    replay_buffer: VecDeque<Experience>,
    epsilon: f32,
    step_count: usize,
}

impl BurnDqnAgent {
    /// Creates a new DQN agent
    pub fn new(config: DqnConfig) -> Self {
        let device = MyDevice::default();

        let q_network = DqnNetwork::new(
            config.state_size,
            config.hidden_size,
            config.action_size,
            &device,
        );

        let target_network = DqnNetwork::new(
            config.state_size,
            config.hidden_size,
            config.action_size,
            &device,
        );

        Self {
            config: config.clone(),
            device,
            q_network,
            target_network,
            replay_buffer: VecDeque::new(),
            epsilon: config.epsilon,
            step_count: 0,
        }
    }

    /// Selects an action using an epsilon-greedy policy
    pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize {
        if valid_actions.is_empty() {
            // Return a default ("null") action when none is valid.
            // In the game context this should not happen if the end-of-game logic is correct.
            return 0;
        }

        // Epsilon-greedy exploration
        if rand::random::<f32>() < self.epsilon {
            let random_index = rand::random::<usize>() % valid_actions.len();
            return valid_actions[random_index];
        }

        // Exploitation: pick the best action according to the Q-network
        let state_tensor = Tensor::<MyBackend, 1>::from_floats(state, &self.device)
            .reshape([1, self.config.state_size]);
        let q_values = self.q_network.forward(state_tensor);

        // Convert to a vector for processing
        let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();

        // Find the best action among the valid ones
        let mut best_action = valid_actions[0];
        let mut best_q_value = f32::NEG_INFINITY;

        for &action in valid_actions {
            if action < q_data.len() && q_data[action] > best_q_value {
                best_q_value = q_data[action];
                best_action = action;
            }
        }

        best_action
    }

    /// Adds an experience to the replay buffer
    pub fn add_experience(&mut self, experience: Experience) {
        if self.replay_buffer.len() >= self.config.replay_buffer_size {
            self.replay_buffer.pop_front();
        }
        self.replay_buffer.push_back(experience);
    }

    /// Trains the network on a batch of experiences
    pub fn train_step(
        &mut self,
        optimizer: &mut impl Optimizer<DqnNetwork<MyBackend>, MyBackend>,
    ) -> Option<f32> {
        if self.replay_buffer.len() < self.config.batch_size {
            return None;
        }

        // Sample a batch of experiences
        let batch = self.sample_batch();

        // Prepare the state tensor
        let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
        let state_tensor = Tensor::<MyBackend, 1>::from_floats(states.as_slice(), &self.device)
            .reshape([self.config.batch_size, self.config.state_size]);

        // Compute the current Q-values
        let current_q_values = self.q_network.forward(state_tensor);

        // For now, a simplified version without target computation
        let target_q_values = current_q_values.clone();

        // Compute the MSE loss
        let loss = MseLoss::new().forward(
            current_q_values,
            target_q_values,
            burn::nn::loss::Reduction::Mean,
        );

        // Backpropagation (simplified version)
        let grads = loss.backward();
        // Gradients linked to each parameter of the model.
        let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network);
        self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads);

        // Target network update
        self.step_count += 1;
        if self.step_count % self.config.target_update_freq == 0 {
            self.update_target_network();
        }

        // Epsilon decay
        if self.epsilon > self.config.epsilon_min {
            self.epsilon *= self.config.epsilon_decay;
        }

        Some(loss.into_scalar())
    }

    /// Samples a batch of experiences from the replay buffer
    fn sample_batch(&self) -> Vec<Experience> {
        let mut batch = Vec::new();
        let buffer_size = self.replay_buffer.len();

        for _ in 0..self.config.batch_size.min(buffer_size) {
            let index = rand::random::<usize>() % buffer_size;
            if let Some(exp) = self.replay_buffer.get(index) {
                batch.push(exp.clone());
            }
        }

        batch
    }

    /// Updates the target network with the main network's weights
    fn update_target_network(&mut self) {
        // Simple weight copy
        self.target_network = self.q_network.clone();
    }

    /// Saves the model
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Save the configuration
        let config_path = format!("{}_config.json", path);
        let config_json = serde_json::to_string_pretty(&self.config)?;
        std::fs::write(config_path, config_json)?;

        // Save the network for inference (conversion to the NdArray backend)
        let inference_network = self.q_network.clone().into_record();
        let recorder = CompactRecorder::new();

        let model_path = format!("{}_model.burn", path);
        recorder.record(inference_network, model_path.into())?;

        println!("Model saved: {}", path);
        Ok(())
    }

    /// Loads a model for inference
    pub fn load_model_for_inference(
        path: &str,
    ) -> Result<(DqnNetwork<InferenceBackend>, DqnConfig), Box<dyn std::error::Error>> {
        // Load the configuration
        let config_path = format!("{}_config.json", path);
        let config_json = std::fs::read_to_string(config_path)?;
        let config: DqnConfig = serde_json::from_str(&config_json)?;

        // Create the inference network
        let device = NdArrayDevice::default();
        let network = DqnNetwork::<InferenceBackend>::new(
            config.state_size,
            config.hidden_size,
            config.action_size,
            &device,
        );

        // Load the weights
        let model_path = format!("{}_model.burn", path);
        let recorder = CompactRecorder::new();
        let record = recorder.load(model_path.into(), &device)?;
        let network = network.load_record(record);

        Ok((network, config))
    }

    /// Returns the current epsilon
    pub fn get_epsilon(&self) -> f32 {
        self.epsilon
    }

    /// Returns the replay buffer size
    pub fn get_buffer_size(&self) -> usize {
        self.replay_buffer.len()
    }
}
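Note: for context, the deleted agent was meant to be driven together with one of burn's optimizers. The sketch below is illustrative only, not part of the commit; it assumes burn's AdamConfig::new().init() optimizer constructor and the default 36-dimensional state.

// Hypothetical driver for BurnDqnAgent (illustration only).
use burn::optim::AdamConfig;

fn run_one_step() {
    let mut agent = BurnDqnAgent::new(DqnConfig::default());
    let mut optimizer = AdamConfig::new().init();

    // One interaction step: act, store the transition, learn.
    let state = vec![0.0_f32; 36]; // 36 = DqnConfig::default().state_size
    let action = agent.select_action(&state, &[0, 1, 2]);
    agent.add_experience(Experience {
        state: state.clone(),
        action,
        reward: 0.0,
        next_state: Some(state),
        done: false,
    });
    // train_step returns Some(loss) once the buffer holds a full batch.
    if let Some(loss) = agent.train_step(&mut optimizer) {
        println!("loss = {loss}");
    }
}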
@@ -1,192 +0,0 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use super::burn_dqn_agent::{DqnNetwork, DqnConfig, InferenceBackend};
use super::dqn_common::get_valid_actions;
use burn::{backend::ndarray::NdArrayDevice, tensor::Tensor};
use std::path::Path;

/// Strategy backed by a trained Burn DQN model
#[derive(Debug)]
pub struct BurnDqnStrategy {
    pub game: GameState,
    pub player_id: PlayerId,
    pub color: Color,
    network: Option<DqnNetwork<InferenceBackend>>,
    config: Option<DqnConfig>,
    device: NdArrayDevice,
}

impl Default for BurnDqnStrategy {
    fn default() -> Self {
        Self {
            game: GameState::default(),
            player_id: 0,
            color: Color::White,
            network: None,
            config: None,
            device: NdArrayDevice::default(),
        }
    }
}

impl BurnDqnStrategy {
    /// Creates a new strategy with a loaded model
    pub fn new(model_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let mut strategy = Self::default();
        strategy.load_model(model_path)?;
        Ok(strategy)
    }

    /// Loads a DQN model from a file
    pub fn load_model(&mut self, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if !Path::new(&format!("{}_config.json", model_path)).exists() {
            return Err(format!("Model not found: {}", model_path).into());
        }

        let (network, config) =
            super::burn_dqn_agent::BurnDqnAgent::load_model_for_inference(model_path)?;

        self.network = Some(network);
        self.config = Some(config);

        println!("Burn DQN model loaded from: {}", model_path);
        Ok(())
    }

    /// Selects the best action according to the DQN model
    fn select_best_action(
        &self,
        valid_actions: &[super::dqn_common::TrictracAction],
    ) -> Option<super::dqn_common::TrictracAction> {
        if valid_actions.is_empty() {
            return None;
        }

        // If no network is loaded, fall back to the first valid action
        let Some(network) = &self.network else {
            return Some(valid_actions[0].clone());
        };

        // Convert the game state into a tensor
        let state_vec = self.game.to_vec_float();
        let state_tensor =
            Tensor::<InferenceBackend, 2>::from_floats(state_vec.as_slice(), &self.device)
                .reshape([1, self.config.as_ref().unwrap().state_size]);

        // Run a prediction
        let q_values = network.forward(state_tensor);
        let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();

        // Find the best action among the valid ones, comparing the Q-value
        // at each action's encoded index (not its position in the list)
        let mut best_action = &valid_actions[0];
        let mut best_q_value = f32::NEG_INFINITY;

        for action in valid_actions.iter() {
            let index = action.to_action_index();
            if index < q_data.len() && q_data[index] > best_q_value {
                best_q_value = q_data[index];
                best_action = action;
            }
        }

        Some(best_action.clone())
    }

    /// Converts a TrictracAction into a pair of CheckerMove for movements
    fn trictrac_action_to_moves(
        &self,
        action: &super::dqn_common::TrictracAction,
    ) -> Option<(CheckerMove, CheckerMove)> {
        match action {
            super::dqn_common::TrictracAction::Move { dice_order, from1, from2 } => {
                let dice = self.game.dice;
                let (die1, die2) = if *dice_order {
                    (dice.values.0, dice.values.1)
                } else {
                    (dice.values.1, dice.values.0)
                };

                // Compute the destinations according to the color
                let to1 = if self.color == Color::White {
                    from1 + die1 as usize
                } else {
                    from1.saturating_sub(die1 as usize)
                };
                let to2 = if self.color == Color::White {
                    from2 + die2 as usize
                } else {
                    from2.saturating_sub(die2 as usize)
                };

                // Build the moves
                let move1 = CheckerMove::new(*from1, to1).ok()?;
                let move2 = CheckerMove::new(*from2, to2).ok()?;

                Some((move1, move2))
            }
            _ => None,
        }
    }
}

impl BotStrategy for BurnDqnStrategy {
    fn get_game(&self) -> &GameState {
        &self.game
    }

    fn get_mut_game(&mut self) -> &mut GameState {
        &mut self.game
    }

    fn calculate_points(&self) -> u8 {
        // Use the DQN model to decide how many points to mark
        // let valid_actions = get_valid_actions(&self.game);

        // Look for a Mark action among the valid actions
        // for action in &valid_actions {
        //     if let super::dqn_common::TrictracAction::Mark { points } = action {
        //         return *points;
        //     }
        // }

        // By default, mark 0 points
        0
    }

    fn calculate_adv_points(&self) -> u8 {
        // Same logic as calculate_points, for the opponent's points
        self.calculate_points()
    }

    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
        let valid_actions = get_valid_actions(&self.game);

        if let Some(best_action) = self.select_best_action(&valid_actions) {
            if let Some((move1, move2)) = self.trictrac_action_to_moves(&best_action) {
                return (move1, move2);
            }
        }

        // Fallback: use the default strategy
        let default_strategy = super::default::DefaultStrategy::default();
        default_strategy.choose_move()
    }

    fn choose_go(&self) -> bool {
        let valid_actions = get_valid_actions(&self.game);

        if let Some(best_action) = self.select_best_action(&valid_actions) {
            match best_action {
                super::dqn_common::TrictracAction::Go => return true,
                super::dqn_common::TrictracAction::Move { .. } => return false,
                _ => {}
            }
        }

        // By default, always choose to continue
        true
    }

    fn set_player_id(&mut self, player_id: PlayerId) {
        self.player_id = player_id;
    }

    fn set_color(&mut self, color: Color) {
        self.color = color;
    }
}

/// Factory function to create a Burn DQN strategy from a model path
pub fn create_burn_dqn_strategy(
    model_path: &str,
) -> Result<Box<dyn BotStrategy>, Box<dyn std::error::Error>> {
    let strategy = BurnDqnStrategy::new(model_path)?;
    Ok(Box::new(strategy))
}
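Note: the deleted strategy was consumed through the BotStrategy trait. A minimal call-site sketch, illustrative only; the "models/dqn" path is made up:

// Hypothetical call site (not part of the commit).
fn load_and_play() -> Result<(), Box<dyn std::error::Error>> {
    // Expects models/dqn_config.json and models/dqn_model.burn to exist.
    let mut strategy = BurnDqnStrategy::new("models/dqn")?;
    strategy.set_player_id(1);
    strategy.set_color(Color::White);
    // The bot loop would then drive it through the BotStrategy trait.
    let (_move1, _move2) = strategy.choose_move();
    Ok(())
}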
@@ -1,4 +1,4 @@
-use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use store::MoveRules;
 
 #[derive(Debug)]
@@ -1,8 +1,8 @@
-use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use std::path::Path;
 use store::MoveRules;
 
-use super::dqn_common::{
+use crate::dqn::dqn_common::{
     get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction,
 };
@@ -1,407 +0,0 @@
use std::cmp::{max, min};

use serde::{Deserialize, Serialize};
use store::{CheckerMove, Dice, GameEvent, PlayerId};

/// Types of actions possible in the game
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TrictracAction {
    /// Roll the dice
    Roll,
    /// Continue after winning a hole
    Go,
    /// Perform a checker move
    Move {
        dice_order: bool, // true = use dice[0] first, false = dice[1] first
        from1: usize,     // starting position of the first checker (0-24)
        from2: usize,     // starting position of the second checker (0-24)
    },
    // Mark the points: enable if schools are supported
    // Mark,
}

impl TrictracAction {
    /// Encodes an action as an index for the neural network
    pub fn to_action_index(&self) -> usize {
        match self {
            TrictracAction::Roll => 0,
            TrictracAction::Go => 1,
            TrictracAction::Move {
                dice_order,
                from1,
                from2,
            } => {
                // Encode the moves into the action space.
                // Indices 2+ are for moves:
                // 2 to 1251 (2 to 626 with die 1 first, 627 to 1251 with die 2 first)
                let mut start = 2;
                if !dice_order {
                    // 25 * 25 = 625
                    start += 625;
                }
                start + from1 * 25 + from2
            } // TrictracAction::Mark => 1252,
        }
    }

    /// Decodes an action index into a TrictracAction
    pub fn from_action_index(index: usize) -> Option<TrictracAction> {
        match index {
            0 => Some(TrictracAction::Roll),
            // 1252 => Some(TrictracAction::Mark),
            1 => Some(TrictracAction::Go),
            i if i >= 2 => {
                let move_code = i - 2;
                let (dice_order, from1, from2) = Self::decode_move(move_code);
                Some(TrictracAction::Move {
                    dice_order,
                    from1,
                    from2,
                })
            }
            _ => None,
        }
    }

    /// Decodes an integer into a move pair (exact inverse of the encoding above)
    fn decode_move(code: usize) -> (bool, usize, usize) {
        let mut encoded = code;
        let dice_order = code < 625;
        if !dice_order {
            encoded -= 625;
        }
        let from1 = encoded / 25;
        let from2 = encoded % 25;
        (dice_order, from1, from2)
    }

    /// Returns the total size of the action space
    pub fn action_space_size() -> usize {
        // 1 (Roll) + 1 (Go) + possible moves
        // For the moves: 2*25*25 = 1250 (die choice + position 0-24 for each from)
        // This could be optimized by restricting to the valid positions (1-24)
        2 + (2 * 25 * 25) // = 1252
    }

    // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
    //     match action {
    //         TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
    //         TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
    //         TrictracAction::Go => Some(GameEvent::Go { player_id }),
    //         TrictracAction::Move {
    //             dice_order,
    //             from1,
    //             from2,
    //         } => {
    //             // Perform a move
    //             let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
    //             let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
    //
    //             Some(GameEvent::Move {
    //                 player_id: self.agent_player_id,
    //                 moves: (checker_move1, checker_move2),
    //             })
    //         }
    //     };
    // }
}

/// Configuration for the DQN agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
    pub state_size: usize,
    pub hidden_size: usize,
    pub num_actions: usize,
    pub learning_rate: f64,
    pub gamma: f64,
    pub epsilon: f64,
    pub epsilon_decay: f64,
    pub epsilon_min: f64,
    pub replay_buffer_size: usize,
    pub batch_size: usize,
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            state_size: 36,
            hidden_size: 512, // larger, to handle the enlarged action space
            num_actions: TrictracAction::action_space_size(),
            learning_rate: 0.001,
            gamma: 0.99,
            epsilon: 0.1,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
        }
    }
}

/// Simplified DQN neural network (basic weight matrices)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleNeuralNetwork {
    pub weights1: Vec<Vec<f32>>,
    pub biases1: Vec<f32>,
    pub weights2: Vec<Vec<f32>>,
    pub biases2: Vec<f32>,
    pub weights3: Vec<Vec<f32>>,
    pub biases3: Vec<f32>,
}

impl SimpleNeuralNetwork {
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();

        // Random weight initialization with He-style scaling, sqrt(2 / fan_in)
        let scale1 = (2.0 / input_size as f32).sqrt();
        let weights1 = (0..hidden_size)
            .map(|_| {
                (0..input_size)
                    .map(|_| rng.gen_range(-scale1..scale1))
                    .collect()
            })
            .collect();
        let biases1 = vec![0.0; hidden_size];

        let scale2 = (2.0 / hidden_size as f32).sqrt();
        let weights2 = (0..hidden_size)
            .map(|_| {
                (0..hidden_size)
                    .map(|_| rng.gen_range(-scale2..scale2))
                    .collect()
            })
            .collect();
        let biases2 = vec![0.0; hidden_size];

        let scale3 = (2.0 / hidden_size as f32).sqrt();
        let weights3 = (0..output_size)
            .map(|_| {
                (0..hidden_size)
                    .map(|_| rng.gen_range(-scale3..scale3))
                    .collect()
            })
            .collect();
        let biases3 = vec![0.0; output_size];

        Self {
            weights1,
            biases1,
            weights2,
            biases2,
            weights3,
            biases3,
        }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        // First layer
        let mut layer1: Vec<f32> = self.biases1.clone();
        for (i, neuron_weights) in self.weights1.iter().enumerate() {
            for (j, &weight) in neuron_weights.iter().enumerate() {
                if j < input.len() {
                    layer1[i] += input[j] * weight;
                }
            }
            layer1[i] = layer1[i].max(0.0); // ReLU
        }

        // Second layer
        let mut layer2: Vec<f32> = self.biases2.clone();
        for (i, neuron_weights) in self.weights2.iter().enumerate() {
            for (j, &weight) in neuron_weights.iter().enumerate() {
                if j < layer1.len() {
                    layer2[i] += layer1[j] * weight;
                }
            }
            layer2[i] = layer2[i].max(0.0); // ReLU
        }

        // Output layer
        let mut output: Vec<f32> = self.biases3.clone();
        for (i, neuron_weights) in self.weights3.iter().enumerate() {
            for (j, &weight) in neuron_weights.iter().enumerate() {
                if j < layer2.len() {
                    output[i] += layer2[j] * weight;
                }
            }
        }

        output
    }

    pub fn get_best_action(&self, input: &[f32]) -> usize {
        let q_values = self.forward(input);
        q_values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(index, _)| index)
            .unwrap_or(0)
    }

    pub fn save<P: AsRef<std::path::Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let data = serde_json::to_string_pretty(self)?;
        std::fs::write(path, data)?;
        Ok(())
    }

    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
        let data = std::fs::read_to_string(path)?;
        let network = serde_json::from_str(&data)?;
        Ok(network)
    }
}

/// Gets the valid actions for the current game state
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
    use crate::PointsRules;
    use store::TurnStage;

    let mut valid_actions = Vec::new();

    let active_player_id = game_state.active_player_id;
    let player_color = game_state.player_color_by_id(&active_player_id);

    if let Some(color) = player_color {
        match game_state.turn_stage {
            TurnStage::RollDice | TurnStage::RollWaiting => {
                valid_actions.push(TrictracAction::Roll);
            }
            TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
                // valid_actions.push(TrictracAction::Mark);
            }
            TurnStage::HoldOrGoChoice => {
                valid_actions.push(TrictracAction::Go);

                // Also add the possible moves
                let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
                let possible_moves = rules.get_possible_moves_sequences(true, vec![]);

                // Adapt checker_moves_to_trictrac_action if Black must be handled
                assert_eq!(color, store::Color::White);
                for (move1, move2) in possible_moves {
                    valid_actions.push(checker_moves_to_trictrac_action(
                        &move1,
                        &move2,
                        &game_state.dice,
                    ));
                }
            }
            TurnStage::Move => {
                let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
                let possible_moves = rules.get_possible_moves_sequences(true, vec![]);

                // Adapt checker_moves_to_trictrac_action if Black must be handled
                assert_eq!(color, store::Color::White);
                for (move1, move2) in possible_moves {
                    valid_actions.push(checker_moves_to_trictrac_action(
                        &move1,
                        &move2,
                        &game_state.dice,
                    ));
                }
            }
        }
    }

    valid_actions
}

// Valid only for the White player
fn checker_moves_to_trictrac_action(
    move1: &CheckerMove,
    move2: &CheckerMove,
    dice: &Dice,
) -> TrictracAction {
    let to1 = move1.get_to();
    let to2 = move2.get_to();
    let from1 = move1.get_from();
    let from2 = move2.get_from();

    let mut diff_move1 = if to1 > 0 {
        // Move without bearing off
        to1 - from1
    } else {
        // Bearing off: use the die value
        if to2 > 0 {
            // only move 1 bears off
            let dice2 = to2 - from2;
            if dice2 == dice.values.0 as usize {
                dice.values.1 as usize
            } else {
                dice.values.0 as usize
            }
        } else {
            // both moves bear off
            if from1 < from2 {
                max(dice.values.0, dice.values.1) as usize
            } else {
                min(dice.values.0, dice.values.1) as usize
            }
        }
    };

    // Adjust diff_move1 when the corner is taken by power ("prise par puissance")
    let rest_field = 12;
    if to1 == rest_field
        && to2 == rest_field
        && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
    {
        // corner taken by power
        diff_move1 += 1;
    }
    TrictracAction::Move {
        dice_order: diff_move1 == dice.values.0 as usize,
        from1: move1.get_from(),
        from2: move2.get_from(),
    }
}

/// Returns the indices of the valid actions
pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
    get_valid_actions(game_state)
        .into_iter()
        .map(|action| action.to_action_index())
        .collect()
}

/// Selects a random valid action
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
    use rand::{seq::SliceRandom, thread_rng};

    let valid_actions = get_valid_actions(game_state);
    let mut rng = thread_rng();
    valid_actions.choose(&mut rng).cloned()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_action_index() {
        let action = TrictracAction::Move {
            dice_order: true,
            from1: 3,
            from2: 4,
        };
        let index = action.to_action_index();
        assert_eq!(Some(action), TrictracAction::from_action_index(index));
        assert_eq!(81, index);
    }

    #[test]
    fn from_action_index() {
        let action = TrictracAction::Move {
            dice_order: true,
            from1: 3,
            from2: 4,
        };
        assert_eq!(Some(action), TrictracAction::from_action_index(81));
    }
}
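Note: a worked example of the encoding above: with die 2 first (dice_order: false), from1 = 3 and from2 = 4 fall in the second 625-wide block, so the index is 2 + 625 + 3 * 25 + 4 = 706. The round-trip sketch below mirrors the unit tests and is illustration only, not part of the commit:

// Hypothetical round-trip check of the action encoding.
fn encoding_round_trip() {
    let action = TrictracAction::Move {
        dice_order: false,
        from1: 3,
        from2: 4,
    };
    let index = action.to_action_index();
    assert_eq!(706, index); // 2 + 625 + 3 * 25 + 4
    assert_eq!(Some(action), TrictracAction::from_action_index(index));
}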
@@ -1,489 +0,0 @@
use crate::{CheckerMove, Color, GameState, PlayerId};
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};

use super::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};

/// Experience for the replay buffer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
    pub state: Vec<f32>,
    pub action: TrictracAction,
    pub reward: f32,
    pub next_state: Vec<f32>,
    pub done: bool,
}

/// Replay buffer storing the experiences
#[derive(Debug)]
pub struct ReplayBuffer {
    buffer: VecDeque<Experience>,
    capacity: usize,
}

impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    pub fn push(&mut self, experience: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let mut rng = thread_rng();
        let len = self.buffer.len();
        if len < batch_size {
            return self.buffer.iter().cloned().collect();
        }

        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let idx = rng.gen_range(0..len);
            batch.push(self.buffer[idx].clone());
        }
        batch
    }

    pub fn len(&self) -> usize {
        self.buffer.len()
    }
}

/// DQN agent for reinforcement learning
#[derive(Debug)]
pub struct DqnAgent {
    config: DqnConfig,
    model: SimpleNeuralNetwork,
    target_model: SimpleNeuralNetwork,
    replay_buffer: ReplayBuffer,
    epsilon: f64,
    step_count: usize,
}

impl DqnAgent {
    pub fn new(config: DqnConfig) -> Self {
        let model =
            SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
        let target_model = model.clone();
        let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
        let epsilon = config.epsilon;

        Self {
            config,
            model,
            target_model,
            replay_buffer,
            epsilon,
            step_count: 0,
        }
    }

    pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
        let valid_actions = get_valid_actions(game_state);

        if valid_actions.is_empty() {
            // Fallback when no action is valid
            return TrictracAction::Roll;
        }

        let mut rng = thread_rng();
        if rng.gen::<f64>() < self.epsilon {
            // Exploration: random valid action
            valid_actions
                .choose(&mut rng)
                .cloned()
                .unwrap_or(TrictracAction::Roll)
        } else {
            // Exploitation: best valid action according to the model
            let q_values = self.model.forward(state);

            let mut best_action = &valid_actions[0];
            let mut best_q_value = f32::NEG_INFINITY;

            for action in &valid_actions {
                let action_index = action.to_action_index();
                if action_index < q_values.len() {
                    let q_value = q_values[action_index];
                    if q_value > best_q_value {
                        best_q_value = q_value;
                        best_action = action;
                    }
                }
            }

            best_action.clone()
        }
    }

    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.push(experience);
    }

    pub fn train(&mut self) {
        if self.replay_buffer.len() < self.config.batch_size {
            return;
        }

        // For now, training is simulated by updating epsilon only.
        // A full implementation would do the backpropagation here.
        self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
        self.step_count += 1;

        // Update the target model every 100 steps
        if self.step_count % 100 == 0 {
            self.target_model = self.model.clone();
        }
    }

    pub fn save_model<P: AsRef<std::path::Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.model.save(path)
    }

    pub fn get_epsilon(&self) -> f64 {
        self.epsilon
    }

    pub fn get_step_count(&self) -> usize {
        self.step_count
    }
}

/// Trictrac environment for training
#[derive(Debug)]
pub struct TrictracEnv {
    pub game_state: GameState,
    pub agent_player_id: PlayerId,
    pub opponent_player_id: PlayerId,
    pub agent_color: Color,
    pub max_steps: usize,
    pub current_step: usize,
}

impl Default for TrictracEnv {
    fn default() -> Self {
        let mut game_state = GameState::new(false);
        game_state.init_player("agent");
        game_state.init_player("opponent");

        Self {
            game_state,
            agent_player_id: 1,
            opponent_player_id: 2,
            agent_color: Color::White,
            max_steps: 1000,
            current_step: 0,
        }
    }
}

impl TrictracEnv {
    pub fn reset(&mut self) -> Vec<f32> {
        self.game_state = GameState::new(false);
        self.game_state.init_player("agent");
        self.game_state.init_player("opponent");

        // Start the game
        self.game_state.consume(&GameEvent::BeginGame {
            goes_first: self.agent_player_id,
        });

        self.current_step = 0;
        self.game_state.to_vec_float()
    }

    pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
        let mut reward = 0.0;

        // Apply the agent's action
        if self.game_state.active_player_id == self.agent_player_id {
            reward += self.apply_agent_action(action);
        }

        // Let the opponent play (simple strategy)
        while self.game_state.active_player_id == self.opponent_player_id
            && self.game_state.stage != Stage::Ended
        {
            reward += self.play_opponent_turn();
        }

        // Check whether the game is over
        let done = self.game_state.stage == Stage::Ended
            || self.game_state.determine_winner().is_some()
            || self.current_step >= self.max_steps;

        // Final reward if the game is over
        if done {
            if let Some(winner) = self.game_state.determine_winner() {
                if winner == self.agent_player_id {
                    reward += 100.0; // bonus for winning
                } else {
                    reward -= 50.0; // penalty for losing
                }
            }
        }

        self.current_step += 1;
        let next_state = self.game_state.to_vec_float();
        (next_state, reward, done)
    }

    fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
        let mut reward = 0.0;

        let event = match action {
            TrictracAction::Roll => {
                // Roll the dice
                reward += 0.1;
                Some(GameEvent::Roll {
                    player_id: self.agent_player_id,
                })
            }
            // TrictracAction::Mark => {
            //     // Mark points
            //     let points = self.game_state.
            //     reward += 0.1 * points as f32;
            //     Some(GameEvent::Mark {
            //         player_id: self.agent_player_id,
            //         points,
            //     })
            // }
            TrictracAction::Go => {
                // Continue after winning a hole
                reward += 0.2;
                Some(GameEvent::Go {
                    player_id: self.agent_player_id,
                })
            }
            TrictracAction::Move {
                dice_order,
                from1,
                from2,
            } => {
                // Perform a move
                let (dice1, dice2) = if dice_order {
                    (self.game_state.dice.values.0, self.game_state.dice.values.1)
                } else {
                    (self.game_state.dice.values.1, self.game_state.dice.values.0)
                };
                let mut to1 = from1 + dice1 as usize;
                let mut to2 = from2 + dice2 as usize;

                // Handle taking the corner by power ("prise par puissance")
                let opp_rest_field = 13;
                if to1 == opp_rest_field && to2 == opp_rest_field {
                    to1 -= 1;
                    to2 -= 1;
                }

                let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
                let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();

                reward += 0.2;
                Some(GameEvent::Move {
                    player_id: self.agent_player_id,
                    moves: (checker_move1, checker_move2),
                })
            }
        };

        // Apply the event if it is valid
        if let Some(event) = event {
            if self.game_state.validate(&event) {
                self.game_state.consume(&event);

                // Simulate the dice result after a Roll
                if matches!(action, TrictracAction::Roll) {
                    let mut rng = thread_rng();
                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
                    let dice_event = GameEvent::RollResult {
                        player_id: self.agent_player_id,
                        dice: store::Dice {
                            values: dice_values,
                        },
                    };
                    if self.game_state.validate(&dice_event) {
                        self.game_state.consume(&dice_event);
                    }
                }
            } else {
                // Penalty for an invalid action
                reward -= 2.0;
            }
        }

        reward
    }

    // TODO: use the default bot strategy
    fn play_opponent_turn(&mut self) -> f32 {
        let mut reward = 0.0;
        let event = match self.game_state.turn_stage {
            TurnStage::RollDice => GameEvent::Roll {
                player_id: self.opponent_player_id,
            },
            TurnStage::RollWaiting => {
                let mut rng = thread_rng();
                let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
                GameEvent::RollResult {
                    player_id: self.opponent_player_id,
                    dice: store::Dice {
                        values: dice_values,
                    },
                }
            }
            TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
                let opponent_color = self.agent_color.opponent_color();
                let dice_roll_count = self
                    .game_state
                    .players
                    .get(&self.opponent_player_id)
                    .unwrap()
                    .dice_roll_count;
                let points_rules = PointsRules::new(
                    &opponent_color,
                    &self.game_state.board,
                    self.game_state.dice,
                );
                let points = points_rules.get_points(dice_roll_count).0;
                reward -= 0.3 * points as f32; // penalty proportional to the opponent's points

                GameEvent::Mark {
                    player_id: self.opponent_player_id,
                    points,
                }
            }
            TurnStage::Move => {
                let opponent_color = self.agent_color.opponent_color();
                let rules = MoveRules::new(
                    &opponent_color,
                    &self.game_state.board,
                    self.game_state.dice,
                );
                let possible_moves = rules.get_possible_moves_sequences(true, vec![]);

                // Simple strategy: random choice
                let mut rng = thread_rng();
                let chosen_move = *possible_moves
                    .choose(&mut rng)
                    .unwrap_or(&(CheckerMove::default(), CheckerMove::default()));

                GameEvent::Move {
                    player_id: self.opponent_player_id,
                    moves: if opponent_color == Color::White {
                        chosen_move
                    } else {
                        (chosen_move.0.mirror(), chosen_move.1.mirror())
                    },
                }
            }
            TurnStage::HoldOrGoChoice => {
                // Simple strategy: always continue
                GameEvent::Go {
                    player_id: self.opponent_player_id,
                }
            }
        };
        if self.game_state.validate(&event) {
            self.game_state.consume(&event);
        }
        reward
    }
}

/// Trainer for the DQN model
pub struct DqnTrainer {
    agent: DqnAgent,
    env: TrictracEnv,
}

impl DqnTrainer {
    pub fn new(config: DqnConfig) -> Self {
        Self {
            agent: DqnAgent::new(config),
            env: TrictracEnv::default(),
        }
    }

    pub fn train_episode(&mut self) -> f32 {
        let mut total_reward = 0.0;
        let mut state = self.env.reset();
        // let mut step_count = 0;

        loop {
            // step_count += 1;
            let action = self.agent.select_action(&self.env.game_state, &state);
            let (next_state, reward, done) = self.env.step(action.clone());
            total_reward += reward;

            let experience = Experience {
                state: state.clone(),
                action,
                reward,
                next_state: next_state.clone(),
                done,
            };
            self.agent.store_experience(experience);
            self.agent.train();

            if done {
                break;
            }
            // if step_count % 100 == 0 {
            //     println!("{:?}", next_state);
            // }
            state = next_state;
        }

        total_reward
    }

    pub fn train(
        &mut self,
        episodes: usize,
        save_every: usize,
        model_path: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        println!("Starting DQN training for {} episodes", episodes);

        for episode in 1..=episodes {
            let reward = self.train_episode();

            if episode % 100 == 0 {
                println!(
                    "Episode {}/{}: reward = {:.2}, epsilon = {:.3}, steps = {}",
                    episode,
                    episodes,
                    reward,
                    self.agent.get_epsilon(),
                    self.agent.get_step_count()
                );
            }

            if episode % save_every == 0 {
                let save_path = format!("{}_episode_{}.json", model_path, episode);
                self.agent.save_model(&save_path)?;
                println!("Model saved: {}", save_path);
            }
        }

        // Save the final model
        let final_path = format!("{}_final.json", model_path);
        self.agent.save_model(&final_path)?;
        println!("Final model saved: {}", final_path);

        Ok(())
    }
}
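Note: the deleted trainer was invoked roughly as below; the episode counts and model path are made-up values for illustration, not part of the commit:

// Hypothetical entry point for DqnTrainer.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut trainer = DqnTrainer::new(DqnConfig::default());
    // 10_000 episodes, checkpoint every 1_000 to models/dqn_episode_N.json
    trainer.train(10_000, 1_000, "models/dqn")
}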
bot/src/strategy/mod.rs (new file, 5 additions)
@@ -0,0 +1,5 @@
pub mod client;
pub mod default;
pub mod dqn;
pub mod erroneous_moves;
pub mod stable_baselines3;
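Note: combined with the import change above (super::dqn_common to crate::dqn::dqn_common), this new mod.rs suggests the refactor moved the DQN training code out of strategy into a separate dqn module. The layout below is an inference from this diff, not something the commit shows directly:

// Inferred module layout after the refactor (assumption):
// bot/src/
//   dqn/            // training code (dqn_common, agents, trainer)
//   strategy/
//     mod.rs        // declares the five submodules listed above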