debug
This commit is contained in:
parent
f7eea0ed02
commit
ebe98ca229
|
|
@ -2,7 +2,7 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use store::MoveRules;
|
use store::MoveRules;
|
||||||
|
|
||||||
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action};
|
use super::dqn_common::{SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action};
|
||||||
|
|
||||||
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use crate::{CheckerMove};
|
|
||||||
|
|
||||||
/// Types d'actions possibles dans le jeu
|
/// Types d'actions possibles dans le jeu
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
|
@ -11,9 +10,9 @@ pub enum TrictracAction {
|
||||||
/// Continuer après avoir gagné un trou
|
/// Continuer après avoir gagné un trou
|
||||||
Go,
|
Go,
|
||||||
/// Effectuer un mouvement de pions
|
/// Effectuer un mouvement de pions
|
||||||
Move {
|
Move {
|
||||||
move1: (usize, usize), // (from, to) pour le premier pion
|
move1: (usize, usize), // (from, to) pour le premier pion
|
||||||
move2: (usize, usize), // (from, to) pour le deuxième pion
|
move2: (usize, usize), // (from, to) pour le deuxième pion
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -23,8 +22,8 @@ impl TrictracAction {
|
||||||
match self {
|
match self {
|
||||||
TrictracAction::Roll => 0,
|
TrictracAction::Roll => 0,
|
||||||
TrictracAction::Mark { points } => {
|
TrictracAction::Mark { points } => {
|
||||||
1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points
|
1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points
|
||||||
},
|
}
|
||||||
TrictracAction::Go => 14,
|
TrictracAction::Go => 14,
|
||||||
TrictracAction::Move { move1, move2 } => {
|
TrictracAction::Move { move1, move2 } => {
|
||||||
// Encoder les mouvements dans l'espace d'actions
|
// Encoder les mouvements dans l'espace d'actions
|
||||||
|
|
@ -33,22 +32,24 @@ impl TrictracAction {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Décode un index d'action en TrictracAction
|
/// Décode un index d'action en TrictracAction
|
||||||
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||||
match index {
|
match index {
|
||||||
0 => Some(TrictracAction::Roll),
|
0 => Some(TrictracAction::Roll),
|
||||||
1..=13 => Some(TrictracAction::Mark { points: (index - 1) as u8 }),
|
1..=13 => Some(TrictracAction::Mark {
|
||||||
|
points: (index - 1) as u8,
|
||||||
|
}),
|
||||||
14 => Some(TrictracAction::Go),
|
14 => Some(TrictracAction::Go),
|
||||||
i if i >= 15 => {
|
i if i >= 15 => {
|
||||||
let move_code = i - 15;
|
let move_code = i - 15;
|
||||||
let (move1, move2) = decode_move_pair(move_code);
|
let (move1, move2) = decode_move_pair(move_code);
|
||||||
Some(TrictracAction::Move { move1, move2 })
|
Some(TrictracAction::Move { move1, move2 })
|
||||||
},
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retourne la taille de l'espace d'actions total
|
/// Retourne la taille de l'espace d'actions total
|
||||||
pub fn action_space_size() -> usize {
|
pub fn action_space_size() -> usize {
|
||||||
// 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles
|
// 1 (Roll) + 13 (Mark 0-12) + 1 (Go) + mouvements possibles
|
||||||
|
|
@ -67,7 +68,7 @@ fn encode_move_pair(move1: (usize, usize), move2: (usize, usize)) -> usize {
|
||||||
let to1 = to1.min(24);
|
let to1 = to1.min(24);
|
||||||
let from2 = from2.min(24);
|
let from2 = from2.min(24);
|
||||||
let to2 = to2.min(24);
|
let to2 = to2.min(24);
|
||||||
|
|
||||||
from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2
|
from1 * (25 * 25 * 25) + to1 * (25 * 25) + from2 * 25 + to2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -79,7 +80,7 @@ fn decode_move_pair(code: usize) -> ((usize, usize), (usize, usize)) {
|
||||||
let remainder = remainder % (25 * 25);
|
let remainder = remainder % (25 * 25);
|
||||||
let from2 = remainder / 25;
|
let from2 = remainder / 25;
|
||||||
let to2 = remainder % 25;
|
let to2 = remainder % 25;
|
||||||
|
|
||||||
((from1, to1), (from2, to2))
|
((from1, to1), (from2, to2))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -102,7 +103,7 @@ impl Default for DqnConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
state_size: 36,
|
state_size: 36,
|
||||||
hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
|
hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
|
||||||
num_actions: TrictracAction::action_space_size(),
|
num_actions: TrictracAction::action_space_size(),
|
||||||
learning_rate: 0.001,
|
learning_rate: 0.001,
|
||||||
gamma: 0.99,
|
gamma: 0.99,
|
||||||
|
|
@ -236,14 +237,14 @@ impl SimpleNeuralNetwork {
|
||||||
|
|
||||||
/// Obtient les actions valides pour l'état de jeu actuel
|
/// Obtient les actions valides pour l'état de jeu actuel
|
||||||
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
use crate::{Color, PointsRules};
|
use crate::PointsRules;
|
||||||
use store::{MoveRules, TurnStage};
|
use store::{MoveRules, TurnStage};
|
||||||
|
|
||||||
let mut valid_actions = Vec::new();
|
let mut valid_actions = Vec::new();
|
||||||
|
|
||||||
let active_player_id = game_state.active_player_id;
|
let active_player_id = game_state.active_player_id;
|
||||||
let player_color = game_state.player_color_by_id(&active_player_id);
|
let player_color = game_state.player_color_by_id(&active_player_id);
|
||||||
|
|
||||||
if let Some(color) = player_color {
|
if let Some(color) = player_color {
|
||||||
match game_state.turn_stage {
|
match game_state.turn_stage {
|
||||||
TurnStage::RollDice | TurnStage::RollWaiting => {
|
TurnStage::RollDice | TurnStage::RollWaiting => {
|
||||||
|
|
@ -255,7 +256,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
let dice_roll_count = player.dice_roll_count;
|
let dice_roll_count = player.dice_roll_count;
|
||||||
let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice);
|
let points_rules = PointsRules::new(&color, &game_state.board, game_state.dice);
|
||||||
let (max_points, _) = points_rules.get_points(dice_roll_count);
|
let (max_points, _) = points_rules.get_points(dice_roll_count);
|
||||||
|
|
||||||
// Permettre de marquer entre 0 et max_points
|
// Permettre de marquer entre 0 et max_points
|
||||||
for points in 0..=max_points {
|
for points in 0..=max_points {
|
||||||
valid_actions.push(TrictracAction::Mark { points });
|
valid_actions.push(TrictracAction::Mark { points });
|
||||||
|
|
@ -264,11 +265,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
}
|
}
|
||||||
TurnStage::HoldOrGoChoice => {
|
TurnStage::HoldOrGoChoice => {
|
||||||
valid_actions.push(TrictracAction::Go);
|
valid_actions.push(TrictracAction::Go);
|
||||||
|
|
||||||
// Ajouter aussi les mouvements possibles
|
// Ajouter aussi les mouvements possibles
|
||||||
let rules = MoveRules::new(&color, &game_state.board, game_state.dice);
|
let rules = MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
for (move1, move2) in possible_moves {
|
for (move1, move2) in possible_moves {
|
||||||
valid_actions.push(TrictracAction::Move {
|
valid_actions.push(TrictracAction::Move {
|
||||||
move1: (move1.get_from(), move1.get_to()),
|
move1: (move1.get_from(), move1.get_to()),
|
||||||
|
|
@ -279,7 +280,7 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
TurnStage::Move => {
|
TurnStage::Move => {
|
||||||
let rules = MoveRules::new(&color, &game_state.board, game_state.dice);
|
let rules = MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
for (move1, move2) in possible_moves {
|
for (move1, move2) in possible_moves {
|
||||||
valid_actions.push(TrictracAction::Move {
|
valid_actions.push(TrictracAction::Move {
|
||||||
move1: (move1.get_from(), move1.get_to()),
|
move1: (move1.get_from(), move1.get_to()),
|
||||||
|
|
@ -287,10 +288,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
valid_actions
|
valid_actions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -304,10 +304,9 @@ pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
|
||||||
|
|
||||||
/// Sélectionne une action valide aléatoire
|
/// Sélectionne une action valide aléatoire
|
||||||
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
|
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
|
||||||
use rand::{thread_rng, seq::SliceRandom};
|
use rand::{seq::SliceRandom, thread_rng};
|
||||||
|
|
||||||
let valid_actions = get_valid_actions(game_state);
|
let valid_actions = get_valid_actions(game_state);
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
valid_actions.choose(&mut rng).cloned()
|
valid_actions.choose(&mut rng).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, get_valid_action_indices, sample_valid_action};
|
use super::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};
|
||||||
|
|
||||||
/// Expérience pour le buffer de replay
|
/// Expérience pour le buffer de replay
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -90,23 +90,26 @@ impl DqnAgent {
|
||||||
|
|
||||||
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
|
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
|
||||||
let valid_actions = get_valid_actions(game_state);
|
let valid_actions = get_valid_actions(game_state);
|
||||||
|
|
||||||
if valid_actions.is_empty() {
|
if valid_actions.is_empty() {
|
||||||
// Fallback si aucune action valide
|
// Fallback si aucune action valide
|
||||||
return TrictracAction::Roll;
|
return TrictracAction::Roll;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
if rng.gen::<f64>() < self.epsilon {
|
if rng.gen::<f64>() < self.epsilon {
|
||||||
// Exploration : action valide aléatoire
|
// Exploration : action valide aléatoire
|
||||||
valid_actions.choose(&mut rng).cloned().unwrap_or(TrictracAction::Roll)
|
valid_actions
|
||||||
|
.choose(&mut rng)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(TrictracAction::Roll)
|
||||||
} else {
|
} else {
|
||||||
// Exploitation : meilleure action valide selon le modèle
|
// Exploitation : meilleure action valide selon le modèle
|
||||||
let q_values = self.model.forward(state);
|
let q_values = self.model.forward(state);
|
||||||
|
|
||||||
let mut best_action = &valid_actions[0];
|
let mut best_action = &valid_actions[0];
|
||||||
let mut best_q_value = f32::NEG_INFINITY;
|
let mut best_q_value = f32::NEG_INFINITY;
|
||||||
|
|
||||||
for action in &valid_actions {
|
for action in &valid_actions {
|
||||||
let action_index = action.to_action_index();
|
let action_index = action.to_action_index();
|
||||||
if action_index < q_values.len() {
|
if action_index < q_values.len() {
|
||||||
|
|
@ -117,7 +120,7 @@ impl DqnAgent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
best_action.clone()
|
best_action.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -267,7 +270,7 @@ impl TrictracEnv {
|
||||||
// Effectuer un mouvement
|
// Effectuer un mouvement
|
||||||
let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
|
let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
|
||||||
let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
|
let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
|
||||||
|
|
||||||
reward += 0.2;
|
reward += 0.2;
|
||||||
Some(GameEvent::Move {
|
Some(GameEvent::Move {
|
||||||
player_id: self.agent_player_id,
|
player_id: self.agent_player_id,
|
||||||
|
|
@ -280,14 +283,16 @@ impl TrictracEnv {
|
||||||
if let Some(event) = event {
|
if let Some(event) = event {
|
||||||
if self.game_state.validate(&event) {
|
if self.game_state.validate(&event) {
|
||||||
self.game_state.consume(&event);
|
self.game_state.consume(&event);
|
||||||
|
|
||||||
// Simuler le résultat des dés après un Roll
|
// Simuler le résultat des dés après un Roll
|
||||||
if matches!(action, TrictracAction::Roll) {
|
if matches!(action, TrictracAction::Roll) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
let dice_event = GameEvent::RollResult {
|
let dice_event = GameEvent::RollResult {
|
||||||
player_id: self.agent_player_id,
|
player_id: self.agent_player_id,
|
||||||
dice: store::Dice { values: dice_values },
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
if self.game_state.validate(&dice_event) {
|
if self.game_state.validate(&dice_event) {
|
||||||
self.game_state.consume(&dice_event);
|
self.game_state.consume(&dice_event);
|
||||||
|
|
@ -393,8 +398,10 @@ impl DqnTrainer {
|
||||||
pub fn train_episode(&mut self) -> f32 {
|
pub fn train_episode(&mut self) -> f32 {
|
||||||
let mut total_reward = 0.0;
|
let mut total_reward = 0.0;
|
||||||
let mut state = self.env.reset();
|
let mut state = self.env.reset();
|
||||||
|
// let mut step_count = 0;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
// step_count += 1;
|
||||||
let action = self.agent.select_action(&self.env.game_state, &state);
|
let action = self.agent.select_action(&self.env.game_state, &state);
|
||||||
let (next_state, reward, done) = self.env.step(action.clone());
|
let (next_state, reward, done) = self.env.step(action.clone());
|
||||||
total_reward += reward;
|
total_reward += reward;
|
||||||
|
|
@ -412,6 +419,9 @@ impl DqnTrainer {
|
||||||
if done {
|
if done {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
// if step_count % 100 == 0 {
|
||||||
|
// println!("{:?}", next_state);
|
||||||
|
// }
|
||||||
state = next_state;
|
state = next_state;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -429,6 +439,7 @@ impl DqnTrainer {
|
||||||
for episode in 1..=episodes {
|
for episode in 1..=episodes {
|
||||||
let reward = self.train_episode();
|
let reward = self.train_episode();
|
||||||
|
|
||||||
|
print!(".");
|
||||||
if episode % 100 == 0 {
|
if episode % 100 == 0 {
|
||||||
println!(
|
println!(
|
||||||
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
||||||
use store::MoveRules;
|
use serde::{Deserialize, Serialize};
|
||||||
use std::process::Command;
|
|
||||||
use std::io::Write;
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
use std::io::Write;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use serde::{Serialize, Deserialize};
|
use std::process::Command;
|
||||||
|
use store::MoveRules;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct StableBaselines3Strategy {
|
pub struct StableBaselines3Strategy {
|
||||||
|
|
@ -62,21 +62,21 @@ impl StableBaselines3Strategy {
|
||||||
fn get_state_as_json(&self) -> GameStateJson {
|
fn get_state_as_json(&self) -> GameStateJson {
|
||||||
// Convertir l'état du jeu en un format compatible avec notre modèle Python
|
// Convertir l'état du jeu en un format compatible avec notre modèle Python
|
||||||
let mut board = vec![0; 24];
|
let mut board = vec![0; 24];
|
||||||
|
|
||||||
// Remplir les positions des pièces blanches (valeurs positives)
|
// Remplir les positions des pièces blanches (valeurs positives)
|
||||||
for (pos, count) in self.game.board.get_color_fields(Color::White) {
|
for (pos, count) in self.game.board.get_color_fields(Color::White) {
|
||||||
if pos < 24 {
|
if pos < 24 {
|
||||||
board[pos] = count as i8;
|
board[pos] = count as i8;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remplir les positions des pièces noires (valeurs négatives)
|
// Remplir les positions des pièces noires (valeurs négatives)
|
||||||
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
|
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
|
||||||
if pos < 24 {
|
if pos < 24 {
|
||||||
board[pos] = -(count as i8);
|
board[pos] = -(count as i8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convertir l'étape du tour en entier
|
// Convertir l'étape du tour en entier
|
||||||
let turn_stage = match self.game.turn_stage {
|
let turn_stage = match self.game.turn_stage {
|
||||||
store::TurnStage::RollDice => 0,
|
store::TurnStage::RollDice => 0,
|
||||||
|
|
@ -85,15 +85,14 @@ impl StableBaselines3Strategy {
|
||||||
store::TurnStage::HoldOrGoChoice => 3,
|
store::TurnStage::HoldOrGoChoice => 3,
|
||||||
store::TurnStage::Move => 4,
|
store::TurnStage::Move => 4,
|
||||||
store::TurnStage::MarkAdvPoints => 5,
|
store::TurnStage::MarkAdvPoints => 5,
|
||||||
_ => 0,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Récupérer les points et trous des joueurs
|
// Récupérer les points et trous des joueurs
|
||||||
let white_points = self.game.players.get(&1).map_or(0, |p| p.points);
|
let white_points = self.game.players.get(&1).map_or(0, |p| p.points);
|
||||||
let white_holes = self.game.players.get(&1).map_or(0, |p| p.holes);
|
let white_holes = self.game.players.get(&1).map_or(0, |p| p.holes);
|
||||||
let black_points = self.game.players.get(&2).map_or(0, |p| p.points);
|
let black_points = self.game.players.get(&2).map_or(0, |p| p.points);
|
||||||
let black_holes = self.game.players.get(&2).map_or(0, |p| p.holes);
|
let black_holes = self.game.players.get(&2).map_or(0, |p| p.holes);
|
||||||
|
|
||||||
// Créer l'objet JSON
|
// Créer l'objet JSON
|
||||||
GameStateJson {
|
GameStateJson {
|
||||||
board,
|
board,
|
||||||
|
|
@ -111,12 +110,12 @@ impl StableBaselines3Strategy {
|
||||||
// Convertir l'état du jeu en JSON
|
// Convertir l'état du jeu en JSON
|
||||||
let state_json = self.get_state_as_json();
|
let state_json = self.get_state_as_json();
|
||||||
let state_str = serde_json::to_string(&state_json).unwrap();
|
let state_str = serde_json::to_string(&state_json).unwrap();
|
||||||
|
|
||||||
// Écrire l'état dans un fichier temporaire
|
// Écrire l'état dans un fichier temporaire
|
||||||
let temp_input_path = "temp_state.json";
|
let temp_input_path = "temp_state.json";
|
||||||
let mut file = File::create(temp_input_path).ok()?;
|
let mut file = File::create(temp_input_path).ok()?;
|
||||||
file.write_all(state_str.as_bytes()).ok()?;
|
file.write_all(state_str.as_bytes()).ok()?;
|
||||||
|
|
||||||
// Exécuter le script Python pour faire une prédiction
|
// Exécuter le script Python pour faire une prédiction
|
||||||
let output_path = "temp_action.json";
|
let output_path = "temp_action.json";
|
||||||
let python_script = format!(
|
let python_script = format!(
|
||||||
|
|
@ -164,32 +163,29 @@ with open("{}", "w") as f:
|
||||||
"#,
|
"#,
|
||||||
self.model_path, output_path
|
self.model_path, output_path
|
||||||
);
|
);
|
||||||
|
|
||||||
let temp_script_path = "temp_predict.py";
|
let temp_script_path = "temp_predict.py";
|
||||||
let mut script_file = File::create(temp_script_path).ok()?;
|
let mut script_file = File::create(temp_script_path).ok()?;
|
||||||
script_file.write_all(python_script.as_bytes()).ok()?;
|
script_file.write_all(python_script.as_bytes()).ok()?;
|
||||||
|
|
||||||
// Exécuter le script Python
|
// Exécuter le script Python
|
||||||
let status = Command::new("python")
|
let status = Command::new("python").arg(temp_script_path).status().ok()?;
|
||||||
.arg(temp_script_path)
|
|
||||||
.status()
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
if !status.success() {
|
if !status.success() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lire la prédiction
|
// Lire la prédiction
|
||||||
if Path::new(output_path).exists() {
|
if Path::new(output_path).exists() {
|
||||||
let mut file = File::open(output_path).ok()?;
|
let mut file = File::open(output_path).ok()?;
|
||||||
let mut contents = String::new();
|
let mut contents = String::new();
|
||||||
file.read_to_string(&mut contents).ok()?;
|
file.read_to_string(&mut contents).ok()?;
|
||||||
|
|
||||||
// Nettoyer les fichiers temporaires
|
// Nettoyer les fichiers temporaires
|
||||||
std::fs::remove_file(temp_input_path).ok();
|
std::fs::remove_file(temp_input_path).ok();
|
||||||
std::fs::remove_file(temp_script_path).ok();
|
std::fs::remove_file(temp_script_path).ok();
|
||||||
std::fs::remove_file(output_path).ok();
|
std::fs::remove_file(output_path).ok();
|
||||||
|
|
||||||
// Analyser la prédiction
|
// Analyser la prédiction
|
||||||
let action: ActionJson = serde_json::from_str(&contents).ok()?;
|
let action: ActionJson = serde_json::from_str(&contents).ok()?;
|
||||||
Some(action)
|
Some(action)
|
||||||
|
|
@ -203,7 +199,7 @@ impl BotStrategy for StableBaselines3Strategy {
|
||||||
fn get_game(&self) -> &GameState {
|
fn get_game(&self) -> &GameState {
|
||||||
&self.game
|
&self.game
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mut_game(&mut self) -> &mut GameState {
|
fn get_mut_game(&mut self) -> &mut GameState {
|
||||||
&mut self.game
|
&mut self.game
|
||||||
}
|
}
|
||||||
|
|
@ -224,7 +220,7 @@ impl BotStrategy for StableBaselines3Strategy {
|
||||||
return self.game.dice.values.0 + self.game.dice.values.1;
|
return self.game.dice.values.0 + self.game.dice.values.1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback vers la méthode standard si la prédiction échoue
|
// Fallback vers la méthode standard si la prédiction échoue
|
||||||
let dice_roll_count = self
|
let dice_roll_count = self
|
||||||
.get_game()
|
.get_game()
|
||||||
|
|
@ -245,7 +241,7 @@ impl BotStrategy for StableBaselines3Strategy {
|
||||||
if let Some(action) = self.predict_action() {
|
if let Some(action) = self.predict_action() {
|
||||||
return action.action_type == 2;
|
return action.action_type == 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback vers la méthode standard si la prédiction échoue
|
// Fallback vers la méthode standard si la prédiction échoue
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
@ -259,18 +255,19 @@ impl BotStrategy for StableBaselines3Strategy {
|
||||||
return (move1, move2);
|
return (move1, move2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback vers la méthode standard si la prédiction échoue
|
// Fallback vers la méthode standard si la prédiction échoue
|
||||||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
let choosen_move = *possible_moves
|
let choosen_move = *possible_moves
|
||||||
.first()
|
.first()
|
||||||
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||||
|
|
||||||
if self.color == Color::White {
|
if self.color == Color::White {
|
||||||
choosen_move
|
choosen_move
|
||||||
} else {
|
} else {
|
||||||
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ impl GameState {
|
||||||
state.push(self.dice.values.0 as i8);
|
state.push(self.dice.values.0 as i8);
|
||||||
state.push(self.dice.values.1 as i8);
|
state.push(self.dice.values.1 as i8);
|
||||||
|
|
||||||
// points length=4 x2 joueurs = 8
|
// points, trous, bredouille, grande bredouille length=4 x2 joueurs = 8
|
||||||
let white_player: Vec<i8> = self
|
let white_player: Vec<i8> = self
|
||||||
.get_white_player()
|
.get_white_player()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue