feat: ai strategy (wip)
This commit is contained in:
parent
899a690869
commit
e9a4b7ac67
10 changed files with 394 additions and 52 deletions
|
|
@ -7,4 +7,6 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
pretty_assertions = "1.4.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
store = { path = "../store" }
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ mod strategy;
|
|||
|
||||
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
pub use strategy::default::DefaultStrategy;
|
||||
pub use strategy::stable_baselines3::StableBaselines3Strategy;
|
||||
|
||||
pub trait BotStrategy: std::fmt::Debug {
|
||||
fn get_game(&self) -> &GameState;
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod client;
|
||||
pub mod default;
|
||||
pub mod stable_baselines3;
|
||||
|
|
|
|||
270
bot/src/strategy/stable_baselines3.rs
Normal file
270
bot/src/strategy/stable_baselines3.rs
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
||||
use store::MoveRules;
|
||||
use std::process::Command;
|
||||
use std::io::Write;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StableBaselines3Strategy {
|
||||
pub game: GameState,
|
||||
pub player_id: PlayerId,
|
||||
pub color: Color,
|
||||
pub model_path: String,
|
||||
}
|
||||
|
||||
impl Default for StableBaselines3Strategy {
|
||||
fn default() -> Self {
|
||||
let game = GameState::default();
|
||||
Self {
|
||||
game,
|
||||
player_id: 2,
|
||||
color: Color::Black,
|
||||
model_path: "models/trictrac_ppo.zip".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct GameStateJson {
|
||||
board: Vec<i8>,
|
||||
active_player: u8,
|
||||
dice: [u8; 2],
|
||||
white_points: u8,
|
||||
white_holes: u8,
|
||||
black_points: u8,
|
||||
black_holes: u8,
|
||||
turn_stage: u8,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ActionJson {
|
||||
action_type: u8,
|
||||
from1: usize,
|
||||
to1: usize,
|
||||
from2: usize,
|
||||
to2: usize,
|
||||
}
|
||||
|
||||
impl StableBaselines3Strategy {
|
||||
pub fn new(model_path: &str) -> Self {
|
||||
let game = GameState::default();
|
||||
Self {
|
||||
game,
|
||||
player_id: 2,
|
||||
color: Color::Black,
|
||||
model_path: model_path.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_state_as_json(&self) -> GameStateJson {
|
||||
// Convertir l'état du jeu en un format compatible avec notre modèle Python
|
||||
let mut board = vec![0; 24];
|
||||
|
||||
// Remplir les positions des pièces blanches (valeurs positives)
|
||||
for (pos, count) in self.game.board.get_color_fields(Color::White) {
|
||||
if pos < 24 {
|
||||
board[pos] = count as i8;
|
||||
}
|
||||
}
|
||||
|
||||
// Remplir les positions des pièces noires (valeurs négatives)
|
||||
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
|
||||
if pos < 24 {
|
||||
board[pos] = -(count as i8);
|
||||
}
|
||||
}
|
||||
|
||||
// Convertir l'étape du tour en entier
|
||||
let turn_stage = match self.game.turn_stage {
|
||||
store::TurnStage::RollDice => 0,
|
||||
store::TurnStage::RollWaiting => 1,
|
||||
store::TurnStage::MarkPoints => 2,
|
||||
store::TurnStage::HoldOrGoChoice => 3,
|
||||
store::TurnStage::Move => 4,
|
||||
store::TurnStage::MarkAdvPoints => 5,
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
// Créer l'objet JSON
|
||||
GameStateJson {
|
||||
board,
|
||||
active_player: self.game.active_player_id,
|
||||
dice: [self.game.dice.values.0, self.game.dice.values.1],
|
||||
white_points: self.game.get_player_points(1),
|
||||
white_holes: self.game.get_player_score(1),
|
||||
black_points: self.game.get_player_points(2),
|
||||
black_holes: self.game.get_player_score(2),
|
||||
turn_stage,
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_action(&self) -> Option<ActionJson> {
|
||||
// Convertir l'état du jeu en JSON
|
||||
let state_json = self.get_state_as_json();
|
||||
let state_str = serde_json::to_string(&state_json).unwrap();
|
||||
|
||||
// Écrire l'état dans un fichier temporaire
|
||||
let temp_input_path = "temp_state.json";
|
||||
let mut file = File::create(temp_input_path).ok()?;
|
||||
file.write_all(state_str.as_bytes()).ok()?;
|
||||
|
||||
// Exécuter le script Python pour faire une prédiction
|
||||
let output_path = "temp_action.json";
|
||||
let python_script = format!(
|
||||
r#"
|
||||
import sys
|
||||
import json
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
import torch
|
||||
|
||||
# Charger le modèle
|
||||
model = PPO.load("{}")
|
||||
|
||||
# Lire l'état du jeu
|
||||
with open("temp_state.json", "r") as f:
|
||||
state_dict = json.load(f)
|
||||
|
||||
# Convertir en format d'observation attendu par le modèle
|
||||
observation = {{
|
||||
'board': np.array(state_dict['board'], dtype=np.int8),
|
||||
'active_player': state_dict['active_player'],
|
||||
'dice': np.array(state_dict['dice'], dtype=np.int32),
|
||||
'white_points': state_dict['white_points'],
|
||||
'white_holes': state_dict['white_holes'],
|
||||
'black_points': state_dict['black_points'],
|
||||
'black_holes': state_dict['black_holes'],
|
||||
'turn_stage': state_dict['turn_stage'],
|
||||
}}
|
||||
|
||||
# Prédire l'action
|
||||
action, _ = model.predict(observation)
|
||||
|
||||
# Convertir l'action en format lisible
|
||||
action_dict = {{
|
||||
'action_type': int(action[0]),
|
||||
'from1': int(action[1]),
|
||||
'to1': int(action[2]),
|
||||
'from2': int(action[3]),
|
||||
'to2': int(action[4]),
|
||||
}}
|
||||
|
||||
# Écrire l'action dans un fichier
|
||||
with open("{}", "w") as f:
|
||||
json.dump(action_dict, f)
|
||||
"#,
|
||||
self.model_path, output_path
|
||||
);
|
||||
|
||||
let temp_script_path = "temp_predict.py";
|
||||
let mut script_file = File::create(temp_script_path).ok()?;
|
||||
script_file.write_all(python_script.as_bytes()).ok()?;
|
||||
|
||||
// Exécuter le script Python
|
||||
let status = Command::new("python")
|
||||
.arg(temp_script_path)
|
||||
.status()
|
||||
.ok()?;
|
||||
|
||||
if !status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Lire la prédiction
|
||||
if Path::new(output_path).exists() {
|
||||
let mut file = File::open(output_path).ok()?;
|
||||
let mut contents = String::new();
|
||||
file.read_to_string(&mut contents).ok()?;
|
||||
|
||||
// Nettoyer les fichiers temporaires
|
||||
std::fs::remove_file(temp_input_path).ok();
|
||||
std::fs::remove_file(temp_script_path).ok();
|
||||
std::fs::remove_file(output_path).ok();
|
||||
|
||||
// Analyser la prédiction
|
||||
let action: ActionJson = serde_json::from_str(&contents).ok()?;
|
||||
Some(action)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BotStrategy for StableBaselines3Strategy {
|
||||
fn get_game(&self) -> &GameState {
|
||||
&self.game
|
||||
}
|
||||
|
||||
fn get_mut_game(&mut self) -> &mut GameState {
|
||||
&mut self.game
|
||||
}
|
||||
|
||||
fn set_color(&mut self, color: Color) {
|
||||
self.color = color;
|
||||
}
|
||||
|
||||
fn set_player_id(&mut self, player_id: PlayerId) {
|
||||
self.player_id = player_id;
|
||||
}
|
||||
|
||||
fn calculate_points(&self) -> u8 {
|
||||
// Utiliser la prédiction du modèle uniquement si c'est une action de type "mark" (1)
|
||||
if let Some(action) = self.predict_action() {
|
||||
if action.action_type == 1 {
|
||||
// Marquer les points calculés par le modèle (ici on utilise la somme des dés comme proxy)
|
||||
return self.game.dice.values.0 + self.game.dice.values.1;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback vers la méthode standard si la prédiction échoue
|
||||
let dice_roll_count = self
|
||||
.get_game()
|
||||
.players
|
||||
.get(&self.player_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
|
||||
points_rules.get_points(dice_roll_count).0
|
||||
}
|
||||
|
||||
fn calculate_adv_points(&self) -> u8 {
|
||||
self.calculate_points()
|
||||
}
|
||||
|
||||
fn choose_go(&self) -> bool {
|
||||
// Utiliser la prédiction du modèle uniquement si c'est une action de type "go" (2)
|
||||
if let Some(action) = self.predict_action() {
|
||||
return action.action_type == 2;
|
||||
}
|
||||
|
||||
// Fallback vers la méthode standard si la prédiction échoue
|
||||
true
|
||||
}
|
||||
|
||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||
// Utiliser la prédiction du modèle uniquement si c'est une action de type "move" (0)
|
||||
if let Some(action) = self.predict_action() {
|
||||
if action.action_type == 0 {
|
||||
let move1 = CheckerMove::new(action.from1, action.to1).unwrap_or_default();
|
||||
let move2 = CheckerMove::new(action.from2, action.to2).unwrap_or_default();
|
||||
return (move1, move2);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback vers la méthode standard si la prédiction échoue
|
||||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
let choosen_move = *possible_moves
|
||||
.first()
|
||||
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||
|
||||
if self.color == Color::White {
|
||||
choosen_move
|
||||
} else {
|
||||
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue