trictrac/bot/src/strategy/dqn.rs

176 lines
5 KiB
Rust
Raw Normal View History

2025-08-01 20:45:57 +02:00
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use std::path::Path;
2025-05-30 20:32:00 +02:00
use store::MoveRules;
2025-08-01 20:45:57 +02:00
use crate::dqn::dqn_common::{
2025-06-08 21:20:04 +02:00
get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction,
};
2025-05-26 20:44:35 +02:00
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
#[derive(Debug)]
pub struct DqnStrategy {
pub game: GameState,
pub player_id: PlayerId,
pub color: Color,
2025-05-26 20:44:35 +02:00
pub model: Option<SimpleNeuralNetwork>,
}
impl Default for DqnStrategy {
fn default() -> Self {
Self {
2025-05-26 20:44:35 +02:00
game: GameState::default(),
player_id: 2,
color: Color::Black,
2025-05-26 20:44:35 +02:00
model: None,
}
}
}
impl DqnStrategy {
pub fn new() -> Self {
Self::default()
}
2025-05-26 20:44:35 +02:00
pub fn new_with_model<P: AsRef<Path>>(model_path: P) -> Self {
let mut strategy = Self::new();
2025-05-26 20:44:35 +02:00
if let Ok(model) = SimpleNeuralNetwork::load(model_path) {
strategy.model = Some(model);
}
strategy
}
2025-06-01 20:00:15 +02:00
/// Utilise le modèle DQN pour choisir une action valide
fn get_dqn_action(&self) -> Option<TrictracAction> {
2025-05-26 20:44:35 +02:00
if let Some(ref model) = self.model {
2025-05-30 20:32:00 +02:00
let state = self.game.to_vec_float();
2025-06-01 20:00:15 +02:00
let valid_actions = get_valid_actions(&self.game);
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
if valid_actions.is_empty() {
return None;
}
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
// Obtenir les Q-values pour toutes les actions
let q_values = model.forward(&state);
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
// Trouver la meilleure action valide
let mut best_action = &valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
for action in &valid_actions {
let action_index = action.to_action_index();
if action_index < q_values.len() {
let q_value = q_values[action_index];
if q_value > best_q_value {
best_q_value = q_value;
best_action = action;
}
}
}
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
Some(best_action.clone())
2025-05-26 20:44:35 +02:00
} else {
2025-06-01 20:00:15 +02:00
// Fallback : action aléatoire valide
sample_valid_action(&self.game)
}
}
}
impl BotStrategy for DqnStrategy {
fn get_game(&self) -> &GameState {
&self.game
}
2025-05-30 20:32:00 +02:00
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
fn set_color(&mut self, color: Color) {
self.color = color;
}
fn set_player_id(&mut self, player_id: PlayerId) {
self.player_id = player_id;
}
fn calculate_points(&self) -> u8 {
2025-06-08 21:20:04 +02:00
self.game.dice_points.0
}
fn calculate_adv_points(&self) -> u8 {
2025-06-08 21:20:04 +02:00
self.game.dice_points.1
}
fn choose_go(&self) -> bool {
2025-06-01 20:00:15 +02:00
// Utiliser le DQN pour décider si on continue
2025-05-26 20:44:35 +02:00
if let Some(action) = self.get_dqn_action() {
2025-06-01 20:00:15 +02:00
matches!(action, TrictracAction::Go)
2025-05-26 20:44:35 +02:00
} else {
// Fallback : toujours continuer
true
}
}
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
2025-06-01 20:00:15 +02:00
// Utiliser le DQN pour choisir le mouvement
if let Some(action) = self.get_dqn_action() {
2025-06-08 21:20:04 +02:00
if let TrictracAction::Move {
dice_order,
from1,
from2,
} = action
{
let dicevals = self.game.dice.values;
let (mut dice1, mut dice2) = if dice_order {
(dicevals.0, dicevals.1)
} else {
(dicevals.1, dicevals.0)
};
if from1 == 0 {
// empty move
dice1 = 0;
}
let mut to1 = from1 + dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
}
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
2025-06-01 20:00:15 +02:00
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
(checker_move1.mirror(), checker_move2.mirror())
};
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
return chosen_move;
}
}
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
// Fallback : utiliser la stratégie par défaut
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
2025-06-08 21:20:04 +02:00
2025-06-01 20:00:15 +02:00
let chosen_move = *possible_moves
.first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
2025-05-30 20:32:00 +02:00
if self.color == Color::White {
chosen_move
} else {
(chosen_move.0.mirror(), chosen_move.1.mirror())
}
}
2025-05-30 20:32:00 +02:00
}