177 lines
5.4 KiB
Rust
177 lines
5.4 KiB
Rust
|
|
use burn::backend::NdArray;
|
||
|
|
use burn::tensor::cast::ToElement;
|
||
|
|
use burn_rl::base::{ElemType, Model, State};
|
||
|
|
|
||
|
|
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||
|
|
use log::info;
|
||
|
|
use store::MoveRules;
|
||
|
|
|
||
|
|
use crate::dqn::burnrl::{dqn_model, environment, utils};
|
||
|
|
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
|
||
|
|
|
||
|
|
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
|
||
|
|
|
||
|
|
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||
|
|
#[derive(Debug)]
|
||
|
|
pub struct DqnBurnStrategy {
|
||
|
|
pub game: GameState,
|
||
|
|
pub player_id: PlayerId,
|
||
|
|
pub color: Color,
|
||
|
|
pub model: Option<DqnBurnNetwork>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for DqnBurnStrategy {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
game: GameState::default(),
|
||
|
|
player_id: 1,
|
||
|
|
color: Color::White,
|
||
|
|
model: None,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl DqnBurnStrategy {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self::default()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn new_with_model(model_path: &String) -> Self {
|
||
|
|
info!("Loading model {model_path:?}");
|
||
|
|
let mut strategy = Self::new();
|
||
|
|
strategy.model = utils::load_model(256, model_path);
|
||
|
|
strategy
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Utilise le modèle DQN pour choisir une action valide
|
||
|
|
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||
|
|
if let Some(ref model) = self.model {
|
||
|
|
let state = environment::TrictracState::from_game_state(&self.game);
|
||
|
|
let valid_actions_indices = get_valid_action_indices(&self.game);
|
||
|
|
if valid_actions_indices.is_empty() {
|
||
|
|
return None; // No valid actions, end of episode
|
||
|
|
}
|
||
|
|
|
||
|
|
// Obtenir les Q-values pour toutes les actions
|
||
|
|
let q_values = model.infer(state.to_tensor().unsqueeze());
|
||
|
|
|
||
|
|
// Set non valid actions q-values to lowest
|
||
|
|
let mut masked_q_values = q_values.clone();
|
||
|
|
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||
|
|
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||
|
|
if !valid_actions_indices.contains(&index) {
|
||
|
|
masked_q_values = masked_q_values.clone().mask_fill(
|
||
|
|
masked_q_values.clone().equal_elem(*q_value),
|
||
|
|
f32::NEG_INFINITY,
|
||
|
|
);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
// Get best action (highest q-value)
|
||
|
|
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||
|
|
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
|
||
|
|
action_index,
|
||
|
|
))
|
||
|
|
} else {
|
||
|
|
// Fallback : action aléatoire valide
|
||
|
|
sample_valid_action(&self.game)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl BotStrategy for DqnBurnStrategy {
|
||
|
|
fn get_game(&self) -> &GameState {
|
||
|
|
&self.game
|
||
|
|
}
|
||
|
|
|
||
|
|
fn get_mut_game(&mut self) -> &mut GameState {
|
||
|
|
&mut self.game
|
||
|
|
}
|
||
|
|
|
||
|
|
fn set_color(&mut self, color: Color) {
|
||
|
|
self.color = color;
|
||
|
|
}
|
||
|
|
|
||
|
|
fn set_player_id(&mut self, player_id: PlayerId) {
|
||
|
|
self.player_id = player_id;
|
||
|
|
}
|
||
|
|
|
||
|
|
fn calculate_points(&self) -> u8 {
|
||
|
|
self.game.dice_points.0
|
||
|
|
}
|
||
|
|
|
||
|
|
fn calculate_adv_points(&self) -> u8 {
|
||
|
|
self.game.dice_points.1
|
||
|
|
}
|
||
|
|
|
||
|
|
fn choose_go(&self) -> bool {
|
||
|
|
// Utiliser le DQN pour décider si on continue
|
||
|
|
if let Some(action) = self.get_dqn_action() {
|
||
|
|
matches!(action, TrictracAction::Go)
|
||
|
|
} else {
|
||
|
|
// Fallback : toujours continuer
|
||
|
|
true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||
|
|
// Utiliser le DQN pour choisir le mouvement
|
||
|
|
if let Some(TrictracAction::Move {
|
||
|
|
dice_order,
|
||
|
|
from1,
|
||
|
|
from2,
|
||
|
|
}) = self.get_dqn_action()
|
||
|
|
{
|
||
|
|
let dicevals = self.game.dice.values;
|
||
|
|
let (mut dice1, mut dice2) = if dice_order {
|
||
|
|
(dicevals.0, dicevals.1)
|
||
|
|
} else {
|
||
|
|
(dicevals.1, dicevals.0)
|
||
|
|
};
|
||
|
|
|
||
|
|
if from1 == 0 {
|
||
|
|
// empty move
|
||
|
|
dice1 = 0;
|
||
|
|
}
|
||
|
|
let mut to1 = from1 + dice1 as usize;
|
||
|
|
if 24 < to1 {
|
||
|
|
// sortie
|
||
|
|
to1 = 0;
|
||
|
|
}
|
||
|
|
if from2 == 0 {
|
||
|
|
// empty move
|
||
|
|
dice2 = 0;
|
||
|
|
}
|
||
|
|
let mut to2 = from2 + dice2 as usize;
|
||
|
|
if 24 < to2 {
|
||
|
|
// sortie
|
||
|
|
to2 = 0;
|
||
|
|
}
|
||
|
|
|
||
|
|
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||
|
|
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||
|
|
|
||
|
|
let chosen_move = if self.color == Color::White {
|
||
|
|
(checker_move1, checker_move2)
|
||
|
|
} else {
|
||
|
|
(checker_move1.mirror(), checker_move2.mirror())
|
||
|
|
};
|
||
|
|
|
||
|
|
return chosen_move;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Fallback : utiliser la stratégie par défaut
|
||
|
|
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||
|
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||
|
|
|
||
|
|
let chosen_move = *possible_moves
|
||
|
|
.first()
|
||
|
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||
|
|
|
||
|
|
if self.color == Color::White {
|
||
|
|
chosen_move
|
||
|
|
} else {
|
||
|
|
(chosen_move.0.mirror(), chosen_move.1.mirror())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|