From ab770f3a342abc05fa410444d23cf3793986744e Mon Sep 17 00:00:00 2001
From: Henri Bourcereau
Date: Sun, 2 Mar 2025 15:20:24 +0100
Subject: [PATCH] feat: ai strategy (wip)

---
 Cargo.lock                            |  16 +-
 bot/Cargo.toml                        |   2 +
 bot/src/lib.rs                        |   1 +
 bot/src/strategy.rs                   |   1 +
 bot/src/strategy/stable_baselines3.rs | 276 ++++++++++++++++++++++++++
 client_cli/src/app.rs                 |   9 +-
 client_cli/src/main.rs                |   4 +
 devenv.lock                           |   4 +-
 devenv.nix                            |   3 +-
 justfile                              |   2 +
 store/Cargo.toml                      |   5 +-
 store/python/trainModel.py            |  53 +++++
 store/python/trictracEnv.py           | 100 +++++-----
 store/src/engine.rs                   |   2 +-
 14 files changed, 421 insertions(+), 57 deletions(-)
 create mode 100644 bot/src/strategy/stable_baselines3.rs
 create mode 100644 store/python/trainModel.py

diff --git a/Cargo.lock b/Cargo.lock
index 54d57f5..f637fe5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 3
+version = 4
 
 [[package]]
 name = "aead"
@@ -120,6 +120,8 @@ name = "bot"
 version = "0.1.0"
 dependencies = [
  "pretty_assertions",
+ "serde",
+ "serde_json",
  "store",
 ]
 
@@ -912,6 +914,18 @@ dependencies = [
  "syn 2.0.79",
 ]
 
+[[package]]
+name = "serde_json"
+version = "1.0.139"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
+dependencies = [
+ "itoa",
+ "memchr",
+ "ryu",
+ "serde",
+]
+
 [[package]]
 name = "signal-hook"
 version = "0.3.17"
diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index ca8f005..e99e807 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -7,4 +7,6 @@ edition = "2021"
 
 [dependencies]
 pretty_assertions = "1.4.0"
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
 store = { path = "../store" }
diff --git a/bot/src/lib.rs b/bot/src/lib.rs
index 927fbc6..f3e1258 100644
--- a/bot/src/lib.rs
+++ b/bot/src/lib.rs
@@ -2,6 +2,7 @@ mod strategy;
 use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
 
 pub use strategy::default::DefaultStrategy;
+pub use strategy::stable_baselines3::StableBaselines3Strategy;
 
 pub trait BotStrategy: std::fmt::Debug {
     fn get_game(&self) -> &GameState;
diff --git a/bot/src/strategy.rs b/bot/src/strategy.rs
index d1e88f8..6d144fb 100644
--- a/bot/src/strategy.rs
+++ b/bot/src/strategy.rs
@@ -1,2 +1,3 @@
 pub mod client;
 pub mod default;
+pub mod stable_baselines3;
diff --git a/bot/src/strategy/stable_baselines3.rs b/bot/src/strategy/stable_baselines3.rs
new file mode 100644
index 0000000..124e95d
--- /dev/null
+++ b/bot/src/strategy/stable_baselines3.rs
@@ -0,0 +1,276 @@
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use store::MoveRules;
+use std::process::Command;
+use std::io::Write;
+use std::fs::File;
+use std::io::Read;
+use std::path::Path;
+use serde::{Serialize, Deserialize};
+
+#[derive(Debug)]
+pub struct StableBaselines3Strategy {
+    pub game: GameState,
+    pub player_id: PlayerId,
+    pub color: Color,
+    pub model_path: String,
+}
+
+impl Default for StableBaselines3Strategy {
+    fn default() -> Self {
+        let game = GameState::default();
+        Self {
+            game,
+            player_id: 2,
+            color: Color::Black,
+            model_path: "models/trictrac_ppo.zip".to_string(),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+struct GameStateJson {
+    board: Vec<i8>,
+    active_player: u8,
+    dice: [u8; 2],
+    white_points: u8,
+    white_holes: u8,
+    black_points: u8,
+    black_holes: u8,
+    turn_stage: u8,
+}
+
+#[derive(Deserialize)]
+struct ActionJson {
+    action_type: u8,
+    from1: usize,
+    to1: usize,
+    from2: usize,
+    to2: usize,
+}
+
+impl StableBaselines3Strategy {
+    pub fn new(model_path: &str) -> Self {
+        let game = GameState::default();
+        Self {
+            game,
+            player_id: 2,
+            color: Color::Black,
+            model_path: model_path.to_string(),
+        }
+    }
+
+    fn get_state_as_json(&self) -> GameStateJson {
+        // Convert the game state into a format our Python model understands
+        let mut board = vec![0; 24];
+
+        // Fill in the positions of the white checkers (positive values)
+        for (pos, count) in self.game.board.get_color_fields(Color::White) {
+            if pos < 24 {
+                board[pos] = count as i8;
+            }
+        }
+
+        // Fill in the positions of the black checkers (negative values)
+        for (pos, count) in self.game.board.get_color_fields(Color::Black) {
+            if pos < 24 {
+                board[pos] = -(count as i8);
+            }
+        }
+
+        // Convert the turn stage to an integer
+        let turn_stage = match self.game.turn_stage {
+            store::TurnStage::RollDice => 0,
+            store::TurnStage::RollWaiting => 1,
+            store::TurnStage::MarkPoints => 2,
+            store::TurnStage::HoldOrGoChoice => 3,
+            store::TurnStage::Move => 4,
+            store::TurnStage::MarkAdvPoints => 5,
+            _ => 0,
+        };
+
+        // Collect the players' points and holes
+        let white_points = self.game.players.get(&1).map_or(0, |p| p.points);
+        let white_holes = self.game.players.get(&1).map_or(0, |p| p.holes);
+        let black_points = self.game.players.get(&2).map_or(0, |p| p.points);
+        let black_holes = self.game.players.get(&2).map_or(0, |p| p.holes);
+
+        // Build the JSON object
+        GameStateJson {
+            board,
+            active_player: self.game.active_player_id as u8,
+            dice: [self.game.dice.values.0, self.game.dice.values.1],
+            white_points,
+            white_holes,
+            black_points,
+            black_holes,
+            turn_stage,
+        }
+    }
+
+    fn predict_action(&self) -> Option<ActionJson> {
+        // Serialize the game state to JSON
+        let state_json = self.get_state_as_json();
+        let state_str = serde_json::to_string(&state_json).unwrap();
+
+        // Write the state to a temporary file
+        let temp_input_path = "temp_state.json";
+        let mut file = File::create(temp_input_path).ok()?;
+        file.write_all(state_str.as_bytes()).ok()?;
+
+        // Run the Python script that produces a prediction
+        let output_path = "temp_action.json";
+        let python_script = format!(
+            r#"
+import sys
+import json
+import numpy as np
+from stable_baselines3 import PPO
+import torch
+
+# Load the model
+model = PPO.load("{}")
+
+# Read the game state
+with open("temp_state.json", "r") as f:
+    state_dict = json.load(f)
+
+# Convert to the observation format expected by the model
+observation = {{
+    'board': np.array(state_dict['board'], dtype=np.int8),
+    'active_player': state_dict['active_player'],
+    'dice': np.array(state_dict['dice'], dtype=np.int32),
+    'white_points': state_dict['white_points'],
+    'white_holes': state_dict['white_holes'],
+    'black_points': state_dict['black_points'],
+    'black_holes': state_dict['black_holes'],
+    'turn_stage': state_dict['turn_stage'],
+}}
+
+# Predict the action
+action, _ = model.predict(observation)
+
+# Convert the action to a readable format
+action_dict = {{
+    'action_type': int(action[0]),
+    'from1': int(action[1]),
+    'to1': int(action[2]),
+    'from2': int(action[3]),
+    'to2': int(action[4]),
+}}
+
+# Write the action to a file
+with open("{}", "w") as f:
+    json.dump(action_dict, f)
+"#,
+            self.model_path, output_path
+        );
+
+        let temp_script_path = "temp_predict.py";
+        let mut script_file = File::create(temp_script_path).ok()?;
+        script_file.write_all(python_script.as_bytes()).ok()?;
+
+        // Run the Python script
+        let status = Command::new("python")
+            .arg(temp_script_path)
+            .status()
+            .ok()?;
+
+        if !status.success() {
+            return None;
+        }
+
+        // Read the prediction
+        if Path::new(output_path).exists() {
+            let mut file = File::open(output_path).ok()?;
+            let mut contents = String::new();
+            file.read_to_string(&mut contents).ok()?;
+
+            // Clean up the temporary files
+            std::fs::remove_file(temp_input_path).ok();
+            std::fs::remove_file(temp_script_path).ok();
+            std::fs::remove_file(output_path).ok();
+
+            // Parse the prediction
+            let action: ActionJson = serde_json::from_str(&contents).ok()?;
+            Some(action)
+        } else {
+            None
+        }
+    }
+}
+
+impl BotStrategy for StableBaselines3Strategy {
+    fn get_game(&self) -> &GameState {
+        &self.game
+    }
+
+    fn get_mut_game(&mut self) -> &mut GameState {
+        &mut self.game
+    }
+
+    fn set_color(&mut self, color: Color) {
+        self.color = color;
+    }
+
+    fn set_player_id(&mut self, player_id: PlayerId) {
+        self.player_id = player_id;
+    }
+
+    fn calculate_points(&self) -> u8 {
+        // Use the model prediction only if it is a "mark" action (1)
+        if let Some(action) = self.predict_action() {
+            if action.action_type == 1 {
+                // Mark the points computed by the model (the dice sum is used as a proxy here)
+                return self.game.dice.values.0 + self.game.dice.values.1;
+            }
+        }
+
+        // Fall back to the standard method if the prediction fails
+        let dice_roll_count = self
+            .get_game()
+            .players
+            .get(&self.player_id)
+            .unwrap()
+            .dice_roll_count;
+        let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
+        points_rules.get_points(dice_roll_count).0
+    }
+
+    fn calculate_adv_points(&self) -> u8 {
+        self.calculate_points()
+    }
+
+    fn choose_go(&self) -> bool {
+        // Use the model prediction only if it is a "go" action (2)
+        if let Some(action) = self.predict_action() {
+            return action.action_type == 2;
+        }
+
+        // Fall back to the standard method if the prediction fails
+        true
+    }
+
+    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
+        // Use the model prediction only if it is a "move" action (0)
+        if let Some(action) = self.predict_action() {
+            if action.action_type == 0 {
+                let move1 = CheckerMove::new(action.from1, action.to1).unwrap_or_default();
+                let move2 = CheckerMove::new(action.from2, action.to2).unwrap_or_default();
+                return (move1, move2);
+            }
+        }
+
+        // Fall back to the standard method if the prediction fails
+        let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
+        let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+        let choosen_move = *possible_moves
+            .first()
+            .unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
+
+        if self.color == Color::White {
+            choosen_move
+        } else {
+            (choosen_move.0.mirror(), choosen_move.1.mirror())
+        }
+    }
+}
\ No newline at end of file
diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs
index 4f617d3..a2f5244 100644
--- a/client_cli/src/app.rs
+++ b/client_cli/src/app.rs
@@ -1,4 +1,4 @@
-use bot::{BotStrategy, DefaultStrategy};
+use bot::{BotStrategy, DefaultStrategy, StableBaselines3Strategy};
 use itertools::Itertools;
 
 use crate::game_runner::GameRunner;
@@ -32,6 +32,13 @@ impl App {
                 "dummy" => {
                     Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>)
                 }
+                "ai" => {
+                    Some(Box::new(StableBaselines3Strategy::default()) as Box<dyn BotStrategy>)
+                }
+                s if s.starts_with("ai:") => {
+                    let path = s.trim_start_matches("ai:");
+                    Some(Box::new(StableBaselines3Strategy::new(path)) as Box<dyn BotStrategy>)
+                }
                 _ => None,
             })
             .collect()
diff --git a/client_cli/src/main.rs b/client_cli/src/main.rs
index 0e1bcb9..064ae70 100644
--- a/client_cli/src/main.rs
+++ b/client_cli/src/main.rs
@@ -19,6 +19,10 @@ FLAGS:
 
 OPTIONS:
   --seed SEED          Sets the random generator seed
  --bot STRATEGY_BOT    Add a bot player with strategy STRATEGY, a second bot may be added to play against the first : --bot STRATEGY_BOT1,STRATEGY_BOT2
+                       Available strategies:
+                       - dummy: Default strategy selecting the first valid move
+                       - ai: AI strategy using the default model at models/trictrac_ppo.zip
+                       - ai:/path/to/model.zip: AI strategy using a custom model
 
 ARGS:
diff --git a/devenv.lock b/devenv.lock
index 1bc5867..7ad7913 100644
--- a/devenv.lock
+++ b/devenv.lock
@@ -75,10 +75,10 @@
       ]
     },
     "locked": {
-      "lastModified": 1740870877,
+      "lastModified": 1740915799,
       "owner": "cachix",
       "repo": "pre-commit-hooks.nix",
-      "rev": "25d4946dfc2021584f5bde1fbd2aa97353384a95",
+      "rev": "42b1ba089d2034d910566bf6b40830af6b8ec732",
       "type": "github"
     },
     "original": {
diff --git a/devenv.nix b/devenv.nix
index b0a6ce1..b1d2d00 100644
--- a/devenv.nix
+++ b/devenv.nix
@@ -57,9 +57,10 @@
     venv.enable = true;
     venv.requirements = "
       pip
-      gym
+      gymnasium
      numpy
      stable-baselines3
+      shimmy
     ";
   };
 
diff --git a/justfile b/justfile
index caf5ef5..7c2b61a 100644
--- a/justfile
+++ b/justfile
@@ -17,3 +17,5 @@ profile:
 pythonlib:
     maturin build -m store/Cargo.toml --release
     pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
+trainbot:
+    python ./store/python/trainModel.py
diff --git a/store/Cargo.toml b/store/Cargo.toml
index 9951a03..6d88c56 100644
--- a/store/Cargo.toml
+++ b/store/Cargo.toml
@@ -6,9 +6,10 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [lib]
-name = "trictrac"
+name = "store"
 # "cdylib" is necessary to produce a shared library for Python to import from.
-crate-type = ["cdylib"]
+# "rlib" is needed for other Rust crates to use this library
+crate-type = ["cdylib", "rlib"]
 
 [dependencies]
 base64 = "0.21.7"
diff --git a/store/python/trainModel.py b/store/python/trainModel.py
new file mode 100644
index 0000000..c75f1e0
--- /dev/null
+++ b/store/python/trainModel.py
@@ -0,0 +1,53 @@
+from stable_baselines3 import PPO
+from stable_baselines3.common.vec_env import DummyVecEnv
+from trictracEnv import TricTracEnv
+import os
+import torch
+import sys
+
+# Check whether a GPU is available
+try:
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        print(f"GPU available: {torch.cuda.get_device_name(0)}")
+        print(f"CUDA version: {torch.version.cuda}")
+        print(f"Using device: {device}")
+    else:
+        device = torch.device("cpu")
+        print("GPU not available, using the CPU")
+        print(f"Using device: {device}")
+except Exception as e:
+    print(f"Error while checking GPU availability: {e}")
+    device = torch.device("cpu")
+    print(f"Using device: {device}")
+
+# Create the vectorized environment
+env = DummyVecEnv([lambda: TricTracEnv()])
+
+try:
+    # Create and train the model, with GPU support when available
+    model = PPO("MultiInputPolicy", env, verbose=1, device=device)
+
+    print("Starting training...")
+    # Short training run for testing
+    # model.learn(total_timesteps=50)
+    # Full training run
+    model.learn(total_timesteps=50000)
+    print("Training finished")
+
+except Exception as e:
+    print(f"Error during training: {e}")
+    sys.exit(1)
+
+# Save the model
+os.makedirs("models", exist_ok=True)
+model.save("models/trictrac_ppo")
+
+# Test the trained model
+obs = env.reset()
+for _ in range(100):
+    action, _ = model.predict(obs)
+    # The DummyVecEnv interface only returns 4 values
+    obs, _, done, _ = env.step(action)
+    if done.any():
+        break
diff --git a/store/python/trictracEnv.py b/store/python/trictracEnv.py
index 2f80147..4e40e33 100644
--- a/store/python/trictracEnv.py
+++ b/store/python/trictracEnv.py
@@ -1,6 +1,6 @@
-import gym
+import gymnasium as gym
 import numpy as np
-from gym import spaces
+from gymnasium import spaces
 import trictrac  # module Rust exposé via PyO3
 from typing import Dict, List, Tuple, Optional, Any, Union
 
@@ -43,14 +43,17 @@ class TricTracEnv(gym.Env):
         })
 
         # Définition de l'espace d'action
-        # Format:
-        # - Action type: 0=move, 1=mark, 2=go
-        # - Move: (from1, to1, from2, to2) ou zeros
-        self.action_space = spaces.Dict({
-            'action_type': spaces.Discrete(3),
-            'move': spaces.MultiDiscrete([self.MAX_FIELD + 1, self.MAX_FIELD + 1,
-                                          self.MAX_FIELD + 1, self.MAX_FIELD + 1])
-        })
+        # Format: a MultiDiscrete space with 5 dimensions
+        # - Action type: 0=move, 1=mark, 2=go (first dimension)
+        # - Move: (from1, to1, from2, to2) (last 4 dimensions)
+        # For a total of 5 dimensions
+        self.action_space = spaces.MultiDiscrete([
+            3,                   # Action type: 0=move, 1=mark, 2=go
+            self.MAX_FIELD + 1,  # from1 (0 means unused)
+            self.MAX_FIELD + 1,  # to1
+            self.MAX_FIELD + 1,  # from2
+            self.MAX_FIELD + 1,  # to2
+        ])
 
         # État courant
         self.state = self._get_observation()
@@ -62,27 +65,30 @@
         self.steps_taken = 0
         self.max_steps = 1000  # Limite pour éviter les parties infinies
 
-    def reset(self):
+    def reset(self, seed=None, options=None):
         """Réinitialise l'environnement et renvoie l'état initial"""
+        super().reset(seed=seed)
+
         self.game.reset()
         self.state = self._get_observation()
         self.state_history = []
         self.steps_taken = 0
-        return self.state
+
+        return self.state, {}
 
     def step(self, action):
         """
-        Exécute une action et retourne (state, reward, done, info)
+        Executes an action and returns (state, reward, terminated, truncated, info)
 
-        Action format:
-        {
-            'action_type': 0/1/2,  # 0=move, 1=mark, 2=go
-            'move': [from1, to1, from2, to2]  # Utilisé seulement si action_type=0
-        }
+        Action format: an array of 5 integers
+        [action_type, from1, to1, from2, to2]
+        - action_type: 0=move, 1=mark, 2=go
+        - from1, to1, from2, to2: only used when action_type=0
         """
-        action_type = action['action_type']
+        action_type = action[0]
         reward = 0
-        done = False
+        terminated = False
+        truncated = False
         info = {}
 
         # Vérifie que l'action est valide pour le joueur humain (id=1)
@@ -92,7 +98,7 @@
         if is_agent_turn:
             # Exécute l'action selon son type
             if action_type == 0:  # Move
-                from1, to1, from2, to2 = action['move']
+                from1, to1, from2, to2 = action[1], action[2], action[3], action[4]
                 move_made = self.game.play_move(((from1, to1), (from2, to2)))
                 if not move_made:
                     # Pénaliser les mouvements invalides
@@ -126,7 +132,7 @@
 
         # Vérifier si la partie est terminée
         if self.game.is_done():
-            done = True
+            terminated = True
             winner = self.game.get_winner()
             if winner == 1:
                 # Bonus si l'agent gagne
@@ -156,7 +162,7 @@
         # Limiter la durée des parties
         self.steps_taken += 1
         if self.steps_taken >= self.max_steps:
-            done = True
+            truncated = True
             info['timeout'] = True
 
             # Comparer les scores en cas de timeout
@@ -168,7 +174,7 @@
                 info['winner'] = 'opponent'
 
         self.state = new_state
-        return self.state, reward, done, info
+        return self.state, reward, terminated, truncated, info
 
     def _play_opponent_turn(self):
         """Simule le tour de l'adversaire avec la stratégie choisie"""
@@ -291,57 +297,51 @@
         turn_stage = state_dict.get('turn_stage')
 
         # Masque par défaut (toutes les actions sont invalides)
-        mask = {
-            'action_type': np.zeros(3, dtype=bool),
-            'move': np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1,
+        # For the new action format: [action_type, from1, to1, from2, to2]
+        action_type_mask = np.zeros(3, dtype=bool)
+        move_mask = np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1,
                               self.MAX_FIELD + 1, self.MAX_FIELD + 1), dtype=bool)
-        }
 
         if self.game.get_active_player_id() != 1:
-            return mask  # Pas au tour de l'agent
+            return action_type_mask, move_mask  # Not the agent's turn
 
         # Activer les types d'actions valides selon l'étape du tour
         if turn_stage == 'Move' or turn_stage == 'HoldOrGoChoice':
-            mask['action_type'][0] = True  # Activer l'action de mouvement
+            action_type_mask[0] = True  # Enable the move action
 
             # Activer les mouvements valides
             valid_moves = self.game.get_available_moves()
             for ((from1, to1), (from2, to2)) in valid_moves:
-                mask['move'][from1, to1, from2, to2] = True
+                move_mask[from1, to1, from2, to2] = True
 
         if turn_stage == 'MarkPoints' or turn_stage == 'MarkAdvPoints':
-            mask['action_type'][1] = True  # Activer l'action de marquer des points
+            action_type_mask[1] = True  # Enable the mark-points action
 
         if turn_stage == 'HoldOrGoChoice':
-            mask['action_type'][2] = True  # Activer l'action de continuer (Go)
+            action_type_mask[2] = True  # Enable the continue (Go) action
 
-        return mask
+        return action_type_mask, move_mask
 
     def sample_valid_action(self):
         """Échantillonne une action valide selon le masque d'actions"""
-        mask = self.get_action_mask()
+        action_type_mask, move_mask = self.get_action_mask()
 
         # Trouver les types d'actions valides
-        valid_action_types = np.where(mask['action_type'])[0]
+        valid_action_types = np.where(action_type_mask)[0]
 
         if len(valid_action_types) == 0:
             # Aucune action valide (pas le tour de l'agent)
-            return {
-                'action_type': 0,
-                'move': np.zeros(4, dtype=np.int32)
-            }
+            return np.array([0, 0, 0, 0, 0], dtype=np.int32)
 
         # Choisir un type d'action
         action_type = np.random.choice(valid_action_types)
 
-        action = {
-            'action_type': action_type,
-            'move': np.zeros(4, dtype=np.int32)
-        }
+        # Initialize the action
+        action = np.array([action_type, 0, 0, 0, 0], dtype=np.int32)
 
         # Si c'est un mouvement, sélectionner un mouvement valide
         if action_type == 0:
-            valid_moves = np.where(mask['move'])
+            valid_moves = np.where(move_mask)
             if len(valid_moves[0]) > 0:
                 # Sélectionner un mouvement valide aléatoirement
                 idx = np.random.randint(0, len(valid_moves[0]))
@@ -349,7 +349,7 @@
                 to1 = valid_moves[1][idx]
                 from2 = valid_moves[2][idx]
                 to2 = valid_moves[3][idx]
-                action['move'] = np.array([from1, to1, from2, to2], dtype=np.int32)
+                action[1:] = [from1, to1, from2, to2]
 
         return action
 
@@ -383,7 +383,7 @@ def example_usage():
 if __name__ == "__main__":
     # Tester l'environnement
     env = TricTracEnv()
-    obs = env.reset()
+    obs, _ = env.reset()
     print("Environnement initialisé")
     env.render()
 
@@ -391,14 +391,16 @@
     # Jouer quelques coups aléatoires
     for _ in range(10):
         action = env.sample_valid_action()
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
 
         print(f"\nAction: {action}")
         print(f"Reward: {reward}")
+        print(f"Terminated: {terminated}")
+        print(f"Truncated: {truncated}")
         print(f"Info: {info}")
 
         env.render()
 
-        if done:
+        if terminated or truncated:
             print("Game over!")
             break
diff --git a/store/src/engine.rs b/store/src/engine.rs
index bf94559..845e22c 100644
--- a/store/src/engine.rs
+++ b/store/src/engine.rs
@@ -330,7 +330,7 @@ impl TricTrac {
 /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
 /// import the module.
 #[pymodule]
-fn trictrac(m: &Bound<'_, PyModule>) -> PyResult<()> {
+fn store(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<TricTrac>()?;
     Ok(())