trictrac/bot/src/strategy/stable_baselines3.rs

270 lines
8.2 KiB
Rust

use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use store::MoveRules;
use std::process::Command;
use std::io::Write;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use serde::{Serialize, Deserialize};
/// Bot strategy whose decisions come from a Stable-Baselines3 PPO model.
///
/// The model is not embedded: it is queried through an external Python
/// process (see `predict_action`).
#[derive(Debug)]
pub struct StableBaselines3Strategy {
    /// Game state the strategy reasons about.
    pub game: GameState,
    /// Identifier of the player driven by this strategy.
    pub player_id: PlayerId,
    /// Checker color played by this strategy.
    pub color: Color,
    /// Path of the saved model, handed to `PPO.load` on the Python side.
    pub model_path: String,
}
impl Default for StableBaselines3Strategy {
    /// Creates a strategy on a default game, playing Black as player 2 and
    /// loading its model from `models/trictrac_ppo.zip`.
    fn default() -> Self {
        Self {
            game: GameState::default(),
            player_id: 2,
            color: Color::Black,
            model_path: String::from("models/trictrac_ppo.zip"),
        }
    }
}
/// Snapshot of the game serialized to JSON and handed to the Python model.
///
/// Field names are the JSON keys read by the prediction script, so they must
/// stay in sync with it.
#[derive(Serialize, Deserialize)]
struct GameStateJson {
    /// 24 board slots: White checker counts are stored as positive values,
    /// Black checker counts as negative values, empty slots as 0.
    board: Vec<i8>,
    /// Identifier of the player whose turn it is.
    active_player: u8,
    /// The two dice values.
    dice: [u8; 2],
    /// Result of `get_player_points(1)` (presumably player 1 is White).
    white_points: u8,
    /// Result of `get_player_score(1)`.
    white_holes: u8,
    /// Result of `get_player_points(2)` (presumably player 2 is Black).
    black_points: u8,
    /// Result of `get_player_score(2)`.
    black_holes: u8,
    /// Turn stage encoded as an integer: 0 = roll dice, 1 = roll waiting,
    /// 2 = mark points, 3 = hold-or-go choice, 4 = move,
    /// 5 = mark adversary points; any other stage is also sent as 0.
    turn_stage: u8,
}
/// Action predicted by the Python model, as read back from its JSON output.
#[derive(Deserialize)]
struct ActionJson {
    /// Kind of action: 0 = move checkers, 1 = mark points, 2 = go.
    action_type: u8,
    /// Origin of the first checker move.
    from1: usize,
    /// Destination of the first checker move.
    to1: usize,
    /// Origin of the second checker move.
    from2: usize,
    /// Destination of the second checker move.
    to2: usize,
}
impl StableBaselines3Strategy {
    /// Number of slots in the board vector sent to the model.
    const BOARD_SIZE: usize = 24;
    /// Temporary file holding the serialized game state.
    const STATE_FILE: &'static str = "temp_state.json";
    /// Temporary file holding the Python prediction script.
    const SCRIPT_FILE: &'static str = "temp_predict.py";
    /// Temporary file in which the Python script writes the predicted action.
    const ACTION_FILE: &'static str = "temp_action.json";

    /// Python script run for every prediction.
    ///
    /// It receives the model path, the state file and the action file as
    /// command-line arguments, so no path is ever spliced into Python source
    /// code (a path containing quotes or backslashes would otherwise corrupt
    /// the script).
    const PREDICT_SCRIPT: &'static str = r#"
import sys
import json
import numpy as np
from stable_baselines3 import PPO
import torch

# Paths are received as arguments: model, input state, output action
model_path, state_path, action_path = sys.argv[1:4]

# Charger le modèle
model = PPO.load(model_path)

# Lire l'état du jeu
with open(state_path, "r") as f:
    state_dict = json.load(f)

# Convertir en format d'observation attendu par le modèle
observation = {
    'board': np.array(state_dict['board'], dtype=np.int8),
    'active_player': state_dict['active_player'],
    'dice': np.array(state_dict['dice'], dtype=np.int32),
    'white_points': state_dict['white_points'],
    'white_holes': state_dict['white_holes'],
    'black_points': state_dict['black_points'],
    'black_holes': state_dict['black_holes'],
    'turn_stage': state_dict['turn_stage'],
}

# Prédire l'action
action, _ = model.predict(observation)

# Convertir l'action en format lisible
action_dict = {
    'action_type': int(action[0]),
    'from1': int(action[1]),
    'to1': int(action[2]),
    'from2': int(action[3]),
    'to2': int(action[4]),
}

# Écrire l'action dans un fichier
with open(action_path, "w") as f:
    json.dump(action_dict, f)
"#;

    /// Creates a strategy on a default game, playing Black as player 2, that
    /// loads its model from `model_path`.
    pub fn new(model_path: &str) -> Self {
        let game = GameState::default();
        Self {
            game,
            player_id: 2,
            color: Color::Black,
            model_path: model_path.to_string(),
        }
    }

    /// Converts the current game into the flat structure sent to the Python
    /// model.
    fn get_state_as_json(&self) -> GameStateJson {
        let mut board = vec![0; Self::BOARD_SIZE];

        // NOTE(review): this assumes `get_color_fields` yields positions in
        // `0..24` and non-negative counts — confirm against the board API.
        // Positions outside the vector are silently skipped.

        // White checkers are stored as positive counts.
        for (pos, count) in self.game.board.get_color_fields(Color::White) {
            if let Some(slot) = board.get_mut(pos) {
                *slot = count as i8;
            }
        }
        // Black checkers are stored as negative counts.
        for (pos, count) in self.game.board.get_color_fields(Color::Black) {
            if let Some(slot) = board.get_mut(pos) {
                *slot = -(count as i8);
            }
        }

        // Encode the turn stage as an integer; unlisted stages map to 0.
        let turn_stage = match self.game.turn_stage {
            store::TurnStage::RollDice => 0,
            store::TurnStage::RollWaiting => 1,
            store::TurnStage::MarkPoints => 2,
            store::TurnStage::HoldOrGoChoice => 3,
            store::TurnStage::Move => 4,
            store::TurnStage::MarkAdvPoints => 5,
            _ => 0,
        };

        GameStateJson {
            board,
            active_player: self.game.active_player_id,
            dice: [self.game.dice.values.0, self.game.dice.values.1],
            white_points: self.game.get_player_points(1),
            white_holes: self.game.get_player_score(1),
            black_points: self.game.get_player_points(2),
            black_holes: self.game.get_player_score(2),
            turn_stage,
        }
    }

    /// Asks the Python model which action to take in the current state.
    ///
    /// Returns `None` when any step fails: serialization, file I/O, launching
    /// `python`, a non-zero exit status, or parsing the script output. The
    /// temporary files are removed on every path, success or failure.
    fn predict_action(&self) -> Option<ActionJson> {
        let prediction = self.run_prediction();
        for path in [Self::STATE_FILE, Self::SCRIPT_FILE, Self::ACTION_FILE] {
            // Best effort: a file may be missing if an earlier step failed.
            std::fs::remove_file(path).ok();
        }
        prediction
    }

    /// Writes the state and script files, runs Python and parses its answer.
    /// Leaves the temporary files in place; `predict_action` removes them.
    fn run_prediction(&self) -> Option<ActionJson> {
        let state_str = serde_json::to_string(&self.get_state_as_json()).ok()?;
        // Both files are closed before the child process starts reading them.
        Self::write_text(Path::new(Self::STATE_FILE), &state_str)?;
        Self::write_text(Path::new(Self::SCRIPT_FILE), Self::PREDICT_SCRIPT)?;

        let status = Command::new("python")
            .arg(Self::SCRIPT_FILE)
            .arg(&self.model_path)
            .arg(Self::STATE_FILE)
            .arg(Self::ACTION_FILE)
            .status()
            .ok()?;
        if !status.success() {
            return None;
        }

        let contents = Self::read_text(Path::new(Self::ACTION_FILE))?;
        serde_json::from_str(&contents).ok()
    }

    /// Creates (or truncates) `path` and writes `contents` into it.
    fn write_text(path: &Path, contents: &str) -> Option<()> {
        let mut file = File::create(path).ok()?;
        file.write_all(contents.as_bytes()).ok()
    }

    /// Reads the whole of `path` as UTF-8 text.
    fn read_text(path: &Path) -> Option<String> {
        let mut file = File::open(path).ok()?;
        let mut contents = String::new();
        file.read_to_string(&mut contents).ok()?;
        Some(contents)
    }
}
impl BotStrategy for StableBaselines3Strategy {
fn get_game(&self) -> &GameState {
&self.game
}
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
fn set_color(&mut self, color: Color) {
self.color = color;
}
fn set_player_id(&mut self, player_id: PlayerId) {
self.player_id = player_id;
}
fn calculate_points(&self) -> u8 {
// Utiliser la prédiction du modèle uniquement si c'est une action de type "mark" (1)
if let Some(action) = self.predict_action() {
if action.action_type == 1 {
// Marquer les points calculés par le modèle (ici on utilise la somme des dés comme proxy)
return self.game.dice.values.0 + self.game.dice.values.1;
}
}
// Fallback vers la méthode standard si la prédiction échoue
let dice_roll_count = self
.get_game()
.players
.get(&self.player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
points_rules.get_points(dice_roll_count).0
}
fn calculate_adv_points(&self) -> u8 {
self.calculate_points()
}
fn choose_go(&self) -> bool {
// Utiliser la prédiction du modèle uniquement si c'est une action de type "go" (2)
if let Some(action) = self.predict_action() {
return action.action_type == 2;
}
// Fallback vers la méthode standard si la prédiction échoue
true
}
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser la prédiction du modèle uniquement si c'est une action de type "move" (0)
if let Some(action) = self.predict_action() {
if action.action_type == 0 {
let move1 = CheckerMove::new(action.from1, action.to1).unwrap_or_default();
let move2 = CheckerMove::new(action.from2, action.to2).unwrap_or_default();
return (move1, move2);
}
}
// Fallback vers la méthode standard si la prédiction échoue
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let choosen_move = *possible_moves
.first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
if self.color == Color::White {
choosen_move
} else {
(choosen_move.0.mirror(), choosen_move.1.mirror())
}
}
}