feat: ai strategy (wip)

This commit is contained in:
Henri Bourcereau 2025-03-02 15:20:24 +01:00
parent 899a690869
commit ab770f3a34
14 changed files with 421 additions and 57 deletions

View file

@ -6,9 +6,10 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "trictrac"
name = "store"
# "cdylib" is necessary to produce a shared library for Python to import from.
crate-type = ["cdylib"]
# "rlib" is needed for other Rust crates to use this library
crate-type = ["cdylib", "rlib"]
[dependencies]
base64 = "0.21.7"

View file

@ -0,0 +1,53 @@
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from trictracEnv import TricTracEnv
import os
import torch
import sys
# Select the compute device: CUDA when torch reports a usable GPU, otherwise
# the CPU. Any error raised while probing the GPU also falls back to the CPU.
try:
    gpu_ready = torch.cuda.is_available()
    device = torch.device("cuda" if gpu_ready else "cpu")
    if gpu_ready:
        print(f"GPU disponible: {torch.cuda.get_device_name(0)}")
        print(f"CUDA version: {torch.version.cuda}")
    else:
        print("GPU non disponible, utilisation du CPU")
    print(f"Using device: {device}")
except Exception as err:
    print(f"Erreur lors de la vérification de la disponibilité du GPU: {err}")
    device = torch.device("cpu")
    print(f"Using device: {device}")

# Build the vectorized environment. DummyVecEnv takes a list of zero-argument
# factories; the environment class itself serves as that factory.
env = DummyVecEnv([TricTracEnv])

try:
    # Create the PPO model on the selected device, then train it.
    model = PPO("MultiInputPolicy", env, verbose=1, device=device)
    print("Démarrage de l'entraînement...")
    # Full training run; for a quick smoke test use a tiny budget such as 50.
    model.learn(total_timesteps=50_000)
    print("Entraînement terminé")
except Exception as err:
    # Report the failure and stop with a non-zero exit status.
    print(f"Erreur lors de l'entraînement: {err}")
    sys.exit(1)

# Persist the trained model, creating the output directory if needed.
os.makedirs("models", exist_ok=True)
model.save("models/trictrac_ppo")

# Roll the trained model out for at most 100 steps as a quick check.
obs = env.reset()
for _step in range(100):
    action, _next_state = model.predict(obs)
    # The vectorized env's step() yields four values here
    # (observations, rewards, dones, infos), not five.
    obs, _rewards, done, _infos = env.step(action)
    if done.any():
        break

View file

@ -1,6 +1,6 @@
import gym
import gymnasium as gym
import numpy as np
from gym import spaces
from gymnasium import spaces
import trictrac # module Rust exposé via PyO3
from typing import Dict, List, Tuple, Optional, Any, Union
@ -43,14 +43,17 @@ class TricTracEnv(gym.Env):
})
# Définition de l'espace d'action
# Format:
# - Action type: 0=move, 1=mark, 2=go
# - Move: (from1, to1, from2, to2) ou zeros
self.action_space = spaces.Dict({
'action_type': spaces.Discrete(3),
'move': spaces.MultiDiscrete([self.MAX_FIELD + 1, self.MAX_FIELD + 1,
self.MAX_FIELD + 1, self.MAX_FIELD + 1])
})
# Format: espace multidiscret avec 5 dimensions
# - Action type: 0=move, 1=mark, 2=go (première dimension)
# - Move: (from1, to1, from2, to2) (4 dernières dimensions)
# Pour un total de 5 dimensions
self.action_space = spaces.MultiDiscrete([
3, # Action type: 0=move, 1=mark, 2=go
self.MAX_FIELD + 1, # from1 (0 signifie non utilisé)
self.MAX_FIELD + 1, # to1
self.MAX_FIELD + 1, # from2
self.MAX_FIELD + 1, # to2
])
# État courant
self.state = self._get_observation()
@ -62,27 +65,30 @@ class TricTracEnv(gym.Env):
self.steps_taken = 0
self.max_steps = 1000 # Limite pour éviter les parties infinies
def reset(self):
def reset(self, seed=None, options=None):
"""Réinitialise l'environnement et renvoie l'état initial"""
super().reset(seed=seed)
self.game.reset()
self.state = self._get_observation()
self.state_history = []
self.steps_taken = 0
return self.state
return self.state, {}
def step(self, action):
"""
Exécute une action et retourne (state, reward, done, info)
Exécute une action et retourne (state, reward, terminated, truncated, info)
Action format:
{
'action_type': 0/1/2, # 0=move, 1=mark, 2=go
'move': [from1, to1, from2, to2] # Utilisé seulement si action_type=0
}
Action format: array de 5 entiers
[action_type, from1, to1, from2, to2]
- action_type: 0=move, 1=mark, 2=go
- from1, to1, from2, to2: utilisés seulement si action_type=0
"""
action_type = action['action_type']
action_type = action[0]
reward = 0
done = False
terminated = False
truncated = False
info = {}
# Vérifie que l'action est valide pour le joueur humain (id=1)
@ -92,7 +98,7 @@ class TricTracEnv(gym.Env):
if is_agent_turn:
# Exécute l'action selon son type
if action_type == 0: # Move
from1, to1, from2, to2 = action['move']
from1, to1, from2, to2 = action[1], action[2], action[3], action[4]
move_made = self.game.play_move(((from1, to1), (from2, to2)))
if not move_made:
# Pénaliser les mouvements invalides
@ -126,7 +132,7 @@ class TricTracEnv(gym.Env):
# Vérifier si la partie est terminée
if self.game.is_done():
done = True
terminated = True
winner = self.game.get_winner()
if winner == 1:
# Bonus si l'agent gagne
@ -156,7 +162,7 @@ class TricTracEnv(gym.Env):
# Limiter la durée des parties
self.steps_taken += 1
if self.steps_taken >= self.max_steps:
done = True
truncated = True
info['timeout'] = True
# Comparer les scores en cas de timeout
@ -168,7 +174,7 @@ class TricTracEnv(gym.Env):
info['winner'] = 'opponent'
self.state = new_state
return self.state, reward, done, info
return self.state, reward, terminated, truncated, info
def _play_opponent_turn(self):
"""Simule le tour de l'adversaire avec la stratégie choisie"""
@ -291,57 +297,51 @@ class TricTracEnv(gym.Env):
turn_stage = state_dict.get('turn_stage')
# Masque par défaut (toutes les actions sont invalides)
mask = {
'action_type': np.zeros(3, dtype=bool),
'move': np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1,
# Pour le nouveau format d'action: [action_type, from1, to1, from2, to2]
action_type_mask = np.zeros(3, dtype=bool)
move_mask = np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1,
self.MAX_FIELD + 1, self.MAX_FIELD + 1), dtype=bool)
}
if self.game.get_active_player_id() != 1:
return mask # Pas au tour de l'agent
return action_type_mask, move_mask # Pas au tour de l'agent
# Activer les types d'actions valides selon l'étape du tour
if turn_stage == 'Move' or turn_stage == 'HoldOrGoChoice':
mask['action_type'][0] = True # Activer l'action de mouvement
action_type_mask[0] = True # Activer l'action de mouvement
# Activer les mouvements valides
valid_moves = self.game.get_available_moves()
for ((from1, to1), (from2, to2)) in valid_moves:
mask['move'][from1, to1, from2, to2] = True
move_mask[from1, to1, from2, to2] = True
if turn_stage == 'MarkPoints' or turn_stage == 'MarkAdvPoints':
mask['action_type'][1] = True # Activer l'action de marquer des points
action_type_mask[1] = True # Activer l'action de marquer des points
if turn_stage == 'HoldOrGoChoice':
mask['action_type'][2] = True # Activer l'action de continuer (Go)
action_type_mask[2] = True # Activer l'action de continuer (Go)
return mask
return action_type_mask, move_mask
def sample_valid_action(self):
"""Échantillonne une action valide selon le masque d'actions"""
mask = self.get_action_mask()
action_type_mask, move_mask = self.get_action_mask()
# Trouver les types d'actions valides
valid_action_types = np.where(mask['action_type'])[0]
valid_action_types = np.where(action_type_mask)[0]
if len(valid_action_types) == 0:
# Aucune action valide (pas le tour de l'agent)
return {
'action_type': 0,
'move': np.zeros(4, dtype=np.int32)
}
return np.array([0, 0, 0, 0, 0], dtype=np.int32)
# Choisir un type d'action
action_type = np.random.choice(valid_action_types)
action = {
'action_type': action_type,
'move': np.zeros(4, dtype=np.int32)
}
# Initialiser l'action
action = np.array([action_type, 0, 0, 0, 0], dtype=np.int32)
# Si c'est un mouvement, sélectionner un mouvement valide
if action_type == 0:
valid_moves = np.where(mask['move'])
valid_moves = np.where(move_mask)
if len(valid_moves[0]) > 0:
# Sélectionner un mouvement valide aléatoirement
idx = np.random.randint(0, len(valid_moves[0]))
@ -349,7 +349,7 @@ class TricTracEnv(gym.Env):
to1 = valid_moves[1][idx]
from2 = valid_moves[2][idx]
to2 = valid_moves[3][idx]
action['move'] = np.array([from1, to1, from2, to2], dtype=np.int32)
action[1:] = [from1, to1, from2, to2]
return action
@ -383,7 +383,7 @@ def example_usage():
if __name__ == "__main__":
# Tester l'environnement
env = TricTracEnv()
obs = env.reset()
obs, _ = env.reset()
print("Environnement initialisé")
env.render()
@ -391,14 +391,16 @@ if __name__ == "__main__":
# Jouer quelques coups aléatoires
for _ in range(10):
action = env.sample_valid_action()
obs, reward, done, info = env.step(action)
obs, reward, terminated, truncated, info = env.step(action)
print(f"\nAction: {action}")
print(f"Reward: {reward}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")
print(f"Info: {info}")
env.render()
if done:
if terminated or truncated:
print("Game over!")
break

View file

@ -330,7 +330,7 @@ impl TricTrac {
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn trictrac(m: &Bound<'_, PyModule>) -> PyResult<()> {
fn store(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<TricTrac>()?;
Ok(())