trictrac/store/python/trainModel.py

54 lines
1.6 KiB
Python

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from trictracEnv import TricTracEnv
import os
import torch
import sys
# Vérifier si le GPU est disponible
try:
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"GPU disponible: {torch.cuda.get_device_name(0)}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Using device: {device}")
else:
device = torch.device("cpu")
print("GPU non disponible, utilisation du CPU")
print(f"Using device: {device}")
except Exception as e:
print(f"Erreur lors de la vérification de la disponibilité du GPU: {e}")
device = torch.device("cpu")
print(f"Using device: {device}")
# Créer l'environnement vectorisé
env = DummyVecEnv([lambda: TricTracEnv()])
try:
# Créer et entraîner le modèle avec support GPU si disponible
model = PPO("MultiInputPolicy", env, verbose=1, device=device)
print("Démarrage de l'entraînement...")
# Petit entraînement pour tester
# model.learn(total_timesteps=50)
# Entraînement complet
model.learn(total_timesteps=50000)
print("Entraînement terminé")
except Exception as e:
print(f"Erreur lors de l'entraînement: {e}")
sys.exit(1)
# Sauvegarder le modèle
os.makedirs("models", exist_ok=True)
model.save("models/trictrac_ppo")
# Test du modèle entraîné
obs = env.reset()
for _ in range(100):
action, _ = model.predict(obs)
# L'interface de DummyVecEnv ne retourne que 4 valeurs
obs, _, done, _ = env.step(action)
if done.any():
break