extend actions space

This commit is contained in:
Henri Bourcereau 2025-06-01 20:00:15 +02:00
parent a2e54bc449
commit f7eea0ed02
5 changed files with 348 additions and 129 deletions

View file

@ -1,4 +1,4 @@
use bot::strategy::dqn_common::DqnConfig; use bot::strategy::dqn_common::{DqnConfig, TrictracAction};
use bot::strategy::dqn_trainer::DqnTrainer; use bot::strategy::dqn_trainer::DqnTrainer;
use std::env; use std::env;
@ -68,7 +68,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = DqnConfig { let config = DqnConfig {
state_size: 36, // state.to_vec size state_size: 36, // state.to_vec size
hidden_size: 256, hidden_size: 256,
num_actions: 3, num_actions: TrictracAction::action_space_size(),
learning_rate: 0.001, learning_rate: 0.001,
gamma: 0.99, gamma: 0.99,
epsilon: 0.9, // Commencer avec plus d'exploration epsilon: 0.9, // Commencer avec plus d'exploration

View file

@ -2,7 +2,7 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use std::path::Path; use std::path::Path;
use store::MoveRules; use store::MoveRules;
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork}; use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action};
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
#[derive(Debug)] #[derive(Debug)]
@ -37,13 +37,38 @@ impl DqnStrategy {
strategy strategy
} }
/// Utilise le modèle DQN pour choisir une action /// Utilise le modèle DQN pour choisir une action valide
fn get_dqn_action(&self) -> Option<usize> { fn get_dqn_action(&self) -> Option<TrictracAction> {
if let Some(ref model) = self.model { if let Some(ref model) = self.model {
let state = self.game.to_vec_float(); let state = self.game.to_vec_float();
Some(model.get_best_action(&state)) let valid_actions = get_valid_actions(&self.game);
if valid_actions.is_empty() {
return None;
}
// Obtenir les Q-values pour toutes les actions
let q_values = model.forward(&state);
// Trouver la meilleure action valide
let mut best_action = &valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for action in &valid_actions {
let action_index = action.to_action_index();
if action_index < q_values.len() {
let q_value = q_values[action_index];
if q_value > best_q_value {
best_q_value = q_value;
best_action = action;
}
}
}
Some(best_action.clone())
} else { } else {
None // Fallback : action aléatoire valide
sample_valid_action(&self.game)
} }
} }
} }
@ -66,6 +91,14 @@ impl BotStrategy for DqnStrategy {
} }
fn calculate_points(&self) -> u8 { fn calculate_points(&self) -> u8 {
// Utiliser le DQN pour choisir le nombre de points à marquer
if let Some(action) = self.get_dqn_action() {
if let TrictracAction::Mark { points } = action {
return points;
}
}
// Fallback : utiliser la méthode standard
let dice_roll_count = self let dice_roll_count = self
.get_game() .get_game()
.players .players
@ -81,10 +114,9 @@ impl BotStrategy for DqnStrategy {
} }
fn choose_go(&self) -> bool { fn choose_go(&self) -> bool {
// Utiliser le DQN pour décider si on continue (action 2 = "go") // Utiliser le DQN pour décider si on continue
if let Some(action) = self.get_dqn_action() { if let Some(action) = self.get_dqn_action() {
// Si le modèle prédit l'action "go" (2), on continue matches!(action, TrictracAction::Go)
action == 2
} else { } else {
// Fallback : toujours continuer // Fallback : toujours continuer
true true
@ -92,28 +124,29 @@ impl BotStrategy for DqnStrategy {
} }
fn choose_move(&self) -> (CheckerMove, CheckerMove) { fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser le DQN pour choisir le mouvement
if let Some(action) = self.get_dqn_action() {
if let TrictracAction::Move { move1, move2 } = action {
let checker_move1 = CheckerMove::new(move1.0, move1.1).unwrap_or_default();
let checker_move2 = CheckerMove::new(move2.0, move2.1).unwrap_or_default();
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
(checker_move1.mirror(), checker_move2.mirror())
};
return chosen_move;
}
}
// Fallback : utiliser la stratégie par défaut
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]); let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let chosen_move = if let Some(action) = self.get_dqn_action() { let chosen_move = *possible_moves
// Utiliser l'action DQN pour choisir parmi les mouvements valides
// Action 0 = premier mouvement, action 1 = mouvement moyen, etc.
let move_index = if action == 0 {
0 // Premier mouvement
} else if action == 1 && possible_moves.len() > 1 {
possible_moves.len() / 2 // Mouvement du milieu
} else {
possible_moves.len().saturating_sub(1) // Dernier mouvement
};
*possible_moves
.get(move_index)
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
} else {
// Fallback : premier mouvement valide
*possible_moves
.first() .first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default())) .unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
};
if self.color == Color::White { if self.color == Color::White {
chosen_move chosen_move

View file

@ -1,4 +1,87 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{CheckerMove};
/// Types d'actions possibles dans le jeu
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TrictracAction {
/// Lancer les dés
Roll,
/// Marquer des points
Mark { points: u8 },
/// Continuer après avoir gagné un trou
Go,
/// Effectuer un mouvement de pions
Move {
move1: (usize, usize), // (from, to) pour le premier pion
move2: (usize, usize), // (from, to) pour le deuxième pion
},
}
impl TrictracAction {
/// Encode une action en index pour le réseau de neurones
pub fn to_action_index(&self) -> usize {
match self {
TrictracAction::Roll => 0,
TrictracAction::Mark { points } => {
1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points
},
TrictracAction::Go => 14,
TrictracAction::Move { move1, move2 } => {
// Encoder les mouvements dans l'espace d'actions
// Indices 15+ pour les mouvements
15 + encode_move_pair(*move1, *move2)
}
}
}
/// Décode un index d'action en TrictracAction
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
match index {
0 => Some(TrictracAction::Roll),
1..=13 => Some(TrictracAction::Mark { points: (index - 1) as u8 }),
14 => Some(TrictracAction::Go),
i if i >= 15 => {
let move_code = i - 15;
let (move1, move2) = decode_move_pair(move_code);
Some(TrictracAction::Move { move1, move2 })
},
_ => None,
}
}
/// Retourne la taille de l'espace d'actions total
pub fn action_space_size() -> usize {
// 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles
// Pour les mouvements : 25*25*25*25 = 390625 (position 0-24 pour chaque from/to)
// Mais on peut optimiser en limitant aux positions valides (1-24)
15 + (24 * 24 * 24 * 24) // = 331791
}
}
/// Encode une paire de mouvements en un seul entier
fn encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize {
let (from1, to1) = move1;
let (from2, to2) = move2;
// Assurer que les positions sont dans la plage 0-24
let from1 = from1.min(24);
let to1 = to1.min(24);
let from2 = from2.min(24);
let to2 = to2.min(24);
from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2
}
/// Décode un entier en paire de mouvements
fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) {
let from1 = code / (25 * 25 * 25);
let remainder = code % (25 * 25 * 25);
let to1 = remainder / (25 * 25);
let remainder = remainder % (25 * 25);
let from2 = remainder / 25;
let to2 = remainder % 25;
((from1, to1), (from2, to2))
}
/// Configuration pour l'agent DQN /// Configuration pour l'agent DQN
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -19,8 +102,8 @@ impl Default for DqnConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
state_size: 36, state_size: 36,
hidden_size: 256, hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
num_actions: 3, num_actions: TrictracAction::action_space_size(),
learning_rate: 0.001, learning_rate: 0.001,
gamma: 0.99, gamma: 0.99,
epsilon: 0.1, epsilon: 0.1,
@ -151,3 +234,80 @@ impl SimpleNeuralNetwork {
} }
} }
/// Obtient les actions valides pour l'état de jeu actuel
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
use crate::{Color, PointsRules};
use store::{MoveRules, TurnStage};
let mut valid_actions = Vec::new();
let active_player_id = game_state.active_player_id;
let player_color = game_state.player_color_by_id(&active_player_id);
if let Some(color) = player_color {
match game_state.turn_stage {
TurnStage::RollDice | TurnStage::RollWaiting => {
valid_actions.push(TrictracAction::Roll);
}
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
// Calculer les points possibles
if let Some(player) = game_state.players.get(&active_player_id) {
let dice_roll_count = player.dice_roll_count;
let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice);
let (max_points, _) = points_rules.get_points(dice_roll_count);
// Permettre de marquer entre 0 et max_points
for points in 0..=max_points {
valid_actions.push(TrictracAction::Mark { points });
}
}
}
TurnStage::HoldOrGoChoice => {
valid_actions.push(TrictracAction::Go);
// Ajouter aussi les mouvements possibles
let rules = MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
for (move1, move2) in possible_moves {
valid_actions.push(TrictracAction::Move {
move1: (move1.get_from(), move1.get_to()),
move2: (move2.get_from(), move2.get_to()),
});
}
}
TurnStage::Move => {
let rules = MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
for (move1, move2) in possible_moves {
valid_actions.push(TrictracAction::Move {
move1: (move1.get_from(), move1.get_to()),
move2: (move2.get_from(), move2.get_to()),
});
}
}
_ => {}
}
}
valid_actions
}
/// Retourne les indices des actions valides
pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
get_valid_actions(game_state)
.into_iter()
.map(|action| action.to_action_index())
.collect()
}
/// Sélectionne une action valide aléatoire
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
use rand::{thread_rng, seq::SliceRandom};
let valid_actions = get_valid_actions(game_state);
let mut rng = thread_rng();
valid_actions.choose(&mut rng).cloned()
}

View file

@ -5,13 +5,13 @@ use serde::{Deserialize, Serialize};
use std::collections::VecDeque; use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork}; use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, get_valid_action_indices, sample_valid_action};
/// Expérience pour le buffer de replay /// Expérience pour le buffer de replay
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience { pub struct Experience {
pub state: Vec<f32>, pub state: Vec<f32>,
pub action: usize, pub action: TrictracAction,
pub reward: f32, pub reward: f32,
pub next_state: Vec<f32>, pub next_state: Vec<f32>,
pub done: bool, pub done: bool,
@ -88,14 +88,37 @@ impl DqnAgent {
} }
} }
pub fn select_action(&mut self, state: &[f32]) -> usize { pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
let valid_actions = get_valid_actions(game_state);
if valid_actions.is_empty() {
// Fallback si aucune action valide
return TrictracAction::Roll;
}
let mut rng = thread_rng(); let mut rng = thread_rng();
if rng.gen::<f64>() < self.epsilon { if rng.gen::<f64>() < self.epsilon {
// Exploration : action aléatoire // Exploration : action valide aléatoire
rng.gen_range(0..self.config.num_actions) valid_actions.choose(&mut rng).cloned().unwrap_or(TrictracAction::Roll)
} else { } else {
// Exploitation : meilleure action selon le modèle // Exploitation : meilleure action valide selon le modèle
self.model.get_best_action(state) let q_values = self.model.forward(state);
let mut best_action = &valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for action in &valid_actions {
let action_index = action.to_action_index();
if action_index < q_values.len() {
let q_value = q_values[action_index];
if q_value > best_q_value {
best_q_value = q_value;
best_action = action;
}
}
}
best_action.clone()
} }
} }
@ -178,7 +201,7 @@ impl TrictracEnv {
self.game_state.to_vec_float() self.game_state.to_vec_float()
} }
pub fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool) { pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
let mut reward = 0.0; let mut reward = 0.0;
// Appliquer l'action de l'agent // Appliquer l'action de l'agent
@ -214,106 +237,68 @@ impl TrictracEnv {
(next_state, reward, done) (next_state, reward, done)
} }
fn apply_agent_action(&mut self, action: usize) -> f32 { fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
let mut reward = 0.0; let mut reward = 0.0;
// TODO : déterminer event selon action ... let event = match action {
TrictracAction::Roll => {
let event = match self.game_state.turn_stage {
TurnStage::RollDice => {
// Lancer les dés // Lancer les dés
GameEvent::Roll {
player_id: self.agent_player_id,
}
}
TurnStage::RollWaiting => {
// Simuler le résultat des dés
reward += 0.1; reward += 0.1;
let mut rng = thread_rng(); Some(GameEvent::Roll {
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
GameEvent::RollResult {
player_id: self.agent_player_id, player_id: self.agent_player_id,
dice: store::Dice { })
values: dice_values,
},
} }
} TrictracAction::Mark { points } => {
TurnStage::Move => { // Marquer des points
// Choisir un mouvement selon l'action reward += 0.1 * points as f32;
let rules = MoveRules::new( Some(GameEvent::Mark {
&self.agent_color,
&self.game_state.board,
self.game_state.dice,
);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
// TODO : choix d'action
let move_index = if action == 0 {
0
} else if action == 1 && possible_moves.len() > 1 {
possible_moves.len() / 2
} else {
possible_moves.len().saturating_sub(1)
};
let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
GameEvent::Move {
player_id: self.agent_player_id,
moves,
}
}
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
// Calculer et marquer les points
let dice_roll_count = self
.game_state
.players
.get(&self.agent_player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(
&self.agent_color,
&self.game_state.board,
self.game_state.dice,
);
let points = points_rules.get_points(dice_roll_count).0;
reward += 0.3 * points as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.agent_player_id, player_id: self.agent_player_id,
points, points,
})
} }
} TrictracAction::Go => {
TurnStage::HoldOrGoChoice => { // Continuer après avoir gagné un trou
// Décider de continuer ou pas selon l'action reward += 0.2;
if action == 2 { Some(GameEvent::Go {
// Action "go"
GameEvent::Go {
player_id: self.agent_player_id, player_id: self.agent_player_id,
})
} }
} else { TrictracAction::Move { move1, move2 } => {
// Passer son tour en jouant un mouvement // Effectuer un mouvement
let rules = MoveRules::new( let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
&self.agent_color, let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
&self.game_state.board,
self.game_state.dice,
);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let moves = possible_moves[0]; reward += 0.2;
GameEvent::Move { Some(GameEvent::Move {
player_id: self.agent_player_id, player_id: self.agent_player_id,
moves, moves: (checker_move1, checker_move2),
} })
}
} }
}; };
// Appliquer l'événement si valide
if let Some(event) = event {
if self.game_state.validate(&event) { if self.game_state.validate(&event) {
self.game_state.consume(&event); self.game_state.consume(&event);
reward += 0.2;
} else { // Simuler le résultat des dés après un Roll
reward -= 1.0; // Pénalité pour action invalide if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.agent_player_id,
dice: store::Dice { values: dice_values },
};
if self.game_state.validate(&dice_event) {
self.game_state.consume(&dice_event);
} }
}
} else {
// Pénalité pour action invalide
reward -= 2.0;
}
}
reward reward
} }
@ -410,8 +395,8 @@ impl DqnTrainer {
let mut state = self.env.reset(); let mut state = self.env.reset();
loop { loop {
let action = self.agent.select_action(&state); let action = self.agent.select_action(&self.env.game_state, &state);
let (next_state, reward, done) = self.env.step(action); let (next_state, reward, done) = self.env.step(action.clone());
total_reward += reward; total_reward += reward;
let experience = Experience { let experience = Experience {

View file

@ -123,3 +123,44 @@ Le code est maintenant parfaitement séparé entre :
- Utilisation (stratégie simple qui charge un modèle) - Utilisation (stratégie simple qui charge un modèle)
- Entraînement (environnement complet avec logique d'apprentissage) - Entraînement (environnement complet avec logique d'apprentissage)
---
● ✅ Espace d'actions étendu implémenté avec succès !
J'ai complètement étendu l'espace d'actions du DQN pour couvrir tous les choix possibles dans le jeu Trictrac :
🎯 Nouvel espace d'actions :
TrictracAction enum comprend :
- Roll - Lancer les dés (index 0)
- Mark { points } - Marquer des points (indices 1-13 pour 0-12 points)
- Go - Continuer après avoir gagné un trou (index 14)
- Move { move1, move2 } - Mouvement de pions (indices 15+)
- Encode (from, to) pour les deux pions
- Espace total : 331,791 actions possibles
🧠 Améliorations intelligentes :
1. Masquage d'actions valides - Le DQN ne considère que les actions légales dans l'état actuel
2. Encodage compact - Mouvements encodés en un seul entier pour efficacité
3. Décodage automatique - Conversion transparente entre indices et actions
4. Stratégie adaptative - La stratégie DQN utilise maintenant le bon type d'action selon le contexte
🔧 Fonctionnalités ajoutées :
- get_valid_actions() - Obtient les actions valides pour l'état actuel
- TrictracAction::to_action_index() / from_action_index() - Conversion bidirectionnelle
- sample_valid_action() - Sélection aléatoire d'une action valide
- Réseau de neurones élargi (512 neurones cachés) pour gérer l'espace étendu
📊 Utilisation dans le jeu :
Le bot DQN peut maintenant :
- Choisir le nombre exact de points à marquer (pas seulement le maximum)
- Sélectionner des mouvements spécifiques parmi toutes les combinaisons possibles
- Décider intelligemment entre "Go" et les mouvements alternatifs
L'espace d'actions est maintenant parfaitement aligné avec la complexité réelle du jeu Trictrac ! 🎲