feat: ai strategy (wip)

This commit is contained in:
Henri Bourcereau 2025-03-02 15:20:24 +01:00
parent 899a690869
commit ab770f3a34
14 changed files with 421 additions and 57 deletions

16
Cargo.lock generated
View file

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "aead" name = "aead"
@ -120,6 +120,8 @@ name = "bot"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"pretty_assertions", "pretty_assertions",
"serde",
"serde_json",
"store", "store",
] ]
@ -912,6 +914,18 @@ dependencies = [
"syn 2.0.79", "syn 2.0.79",
] ]
[[package]]
name = "serde_json"
version = "1.0.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]] [[package]]
name = "signal-hook" name = "signal-hook"
version = "0.3.17" version = "0.3.17"

View file

@ -7,4 +7,6 @@ edition = "2021"
[dependencies] [dependencies]
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
store = { path = "../store" } store = { path = "../store" }

View file

@ -2,6 +2,7 @@ mod strategy;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy; pub use strategy::default::DefaultStrategy;
pub use strategy::stable_baselines3::StableBaselines3Strategy;
pub trait BotStrategy: std::fmt::Debug { pub trait BotStrategy: std::fmt::Debug {
fn get_game(&self) -> &GameState; fn get_game(&self) -> &GameState;

View file

@ -1,2 +1,3 @@
pub mod client; pub mod client;
pub mod default; pub mod default;
pub mod stable_baselines3;

View file

@ -0,0 +1,276 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use store::MoveRules;
use std::process::Command;
use std::io::Write;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use serde::{Serialize, Deserialize};
#[derive(Debug)]
pub struct StableBaselines3Strategy {
pub game: GameState,
pub player_id: PlayerId,
pub color: Color,
pub model_path: String,
}
impl Default for StableBaselines3Strategy {
fn default() -> Self {
let game = GameState::default();
Self {
game,
player_id: 2,
color: Color::Black,
model_path: "models/trictrac_ppo.zip".to_string(),
}
}
}
#[derive(Serialize, Deserialize)]
struct GameStateJson {
board: Vec<i8>,
active_player: u8,
dice: [u8; 2],
white_points: u8,
white_holes: u8,
black_points: u8,
black_holes: u8,
turn_stage: u8,
}
#[derive(Deserialize)]
struct ActionJson {
action_type: u8,
from1: usize,
to1: usize,
from2: usize,
to2: usize,
}
impl StableBaselines3Strategy {
pub fn new(model_path: &str) -> Self {
let game = GameState::default();
Self {
game,
player_id: 2,
color: Color::Black,
model_path: model_path.to_string(),
}
}
fn get_state_as_json(&self) -> GameStateJson {
// Convertir l'état du jeu en un format compatible avec notre modèle Python
let mut board = vec![0; 24];
// Remplir les positions des pièces blanches (valeurs positives)
for (pos, count) in self.game.board.get_color_fields(Color::White) {
if pos < 24 {
board[pos] = count as i8;
}
}
// Remplir les positions des pièces noires (valeurs négatives)
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
if pos < 24 {
board[pos] = -(count as i8);
}
}
// Convertir l'étape du tour en entier
let turn_stage = match self.game.turn_stage {
store::TurnStage::RollDice => 0,
store::TurnStage::RollWaiting => 1,
store::TurnStage::MarkPoints => 2,
store::TurnStage::HoldOrGoChoice => 3,
store::TurnStage::Move => 4,
store::TurnStage::MarkAdvPoints => 5,
_ => 0,
};
// Récupérer les points et trous des joueurs
let white_points = self.game.players.get(&1).map_or(0, |p| p.points);
let white_holes = self.game.players.get(&1).map_or(0, |p| p.holes);
let black_points = self.game.players.get(&2).map_or(0, |p| p.points);
let black_holes = self.game.players.get(&2).map_or(0, |p| p.holes);
// Créer l'objet JSON
GameStateJson {
board,
active_player: self.game.active_player_id as u8,
dice: [self.game.dice.values.0, self.game.dice.values.1],
white_points,
white_holes,
black_points,
black_holes,
turn_stage,
}
}
fn predict_action(&self) -> Option<ActionJson> {
// Convertir l'état du jeu en JSON
let state_json = self.get_state_as_json();
let state_str = serde_json::to_string(&state_json).unwrap();
// Écrire l'état dans un fichier temporaire
let temp_input_path = "temp_state.json";
let mut file = File::create(temp_input_path).ok()?;
file.write_all(state_str.as_bytes()).ok()?;
// Exécuter le script Python pour faire une prédiction
let output_path = "temp_action.json";
let python_script = format!(
r#"
import sys
import json
import numpy as np
from stable_baselines3 import PPO
import torch
# Charger le modèle
model = PPO.load("{}")
# Lire l'état du jeu
with open("temp_state.json", "r") as f:
state_dict = json.load(f)
# Convertir en format d'observation attendu par le modèle
observation = {{
'board': np.array(state_dict['board'], dtype=np.int8),
'active_player': state_dict['active_player'],
'dice': np.array(state_dict['dice'], dtype=np.int32),
'white_points': state_dict['white_points'],
'white_holes': state_dict['white_holes'],
'black_points': state_dict['black_points'],
'black_holes': state_dict['black_holes'],
'turn_stage': state_dict['turn_stage'],
}}
# Prédire l'action
action, _ = model.predict(observation)
# Convertir l'action en format lisible
action_dict = {{
'action_type': int(action[0]),
'from1': int(action[1]),
'to1': int(action[2]),
'from2': int(action[3]),
'to2': int(action[4]),
}}
# Écrire l'action dans un fichier
with open("{}", "w") as f:
json.dump(action_dict, f)
"#,
self.model_path, output_path
);
let temp_script_path = "temp_predict.py";
let mut script_file = File::create(temp_script_path).ok()?;
script_file.write_all(python_script.as_bytes()).ok()?;
// Exécuter le script Python
let status = Command::new("python")
.arg(temp_script_path)
.status()
.ok()?;
if !status.success() {
return None;
}
// Lire la prédiction
if Path::new(output_path).exists() {
let mut file = File::open(output_path).ok()?;
let mut contents = String::new();
file.read_to_string(&mut contents).ok()?;
// Nettoyer les fichiers temporaires
std::fs::remove_file(temp_input_path).ok();
std::fs::remove_file(temp_script_path).ok();
std::fs::remove_file(output_path).ok();
// Analyser la prédiction
let action: ActionJson = serde_json::from_str(&contents).ok()?;
Some(action)
} else {
None
}
}
}
impl BotStrategy for StableBaselines3Strategy {
fn get_game(&self) -> &GameState {
&self.game
}
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
fn set_color(&mut self, color: Color) {
self.color = color;
}
fn set_player_id(&mut self, player_id: PlayerId) {
self.player_id = player_id;
}
fn calculate_points(&self) -> u8 {
// Utiliser la prédiction du modèle uniquement si c'est une action de type "mark" (1)
if let Some(action) = self.predict_action() {
if action.action_type == 1 {
// Marquer les points calculés par le modèle (ici on utilise la somme des dés comme proxy)
return self.game.dice.values.0 + self.game.dice.values.1;
}
}
// Fallback vers la méthode standard si la prédiction échoue
let dice_roll_count = self
.get_game()
.players
.get(&self.player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
points_rules.get_points(dice_roll_count).0
}
fn calculate_adv_points(&self) -> u8 {
self.calculate_points()
}
fn choose_go(&self) -> bool {
// Utiliser la prédiction du modèle uniquement si c'est une action de type "go" (2)
if let Some(action) = self.predict_action() {
return action.action_type == 2;
}
// Fallback vers la méthode standard si la prédiction échoue
true
}
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser la prédiction du modèle uniquement si c'est une action de type "move" (0)
if let Some(action) = self.predict_action() {
if action.action_type == 0 {
let move1 = CheckerMove::new(action.from1, action.to1).unwrap_or_default();
let move2 = CheckerMove::new(action.from2, action.to2).unwrap_or_default();
return (move1, move2);
}
}
// Fallback vers la méthode standard si la prédiction échoue
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let choosen_move = *possible_moves
.first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
if self.color == Color::White {
choosen_move
} else {
(choosen_move.0.mirror(), choosen_move.1.mirror())
}
}
}

View file

@ -1,4 +1,4 @@
use bot::{BotStrategy, DefaultStrategy}; use bot::{BotStrategy, DefaultStrategy, StableBaselines3Strategy};
use itertools::Itertools; use itertools::Itertools;
use crate::game_runner::GameRunner; use crate::game_runner::GameRunner;
@ -32,6 +32,13 @@ impl App {
"dummy" => { "dummy" => {
Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>) Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>)
} }
"ai" => {
Some(Box::new(StableBaselines3Strategy::default()) as Box<dyn BotStrategy>)
}
s if s.starts_with("ai:") => {
let path = s.trim_start_matches("ai:");
Some(Box::new(StableBaselines3Strategy::new(path)) as Box<dyn BotStrategy>)
}
_ => None, _ => None,
}) })
.collect() .collect()

View file

@ -19,6 +19,10 @@ FLAGS:
OPTIONS: OPTIONS:
--seed SEED Sets the random generator seed --seed SEED Sets the random generator seed
--bot STRATEGY_BOT Add a bot player with strategy STRATEGY, a second bot may be added to play against the first : --bot STRATEGY_BOT1,STRATEGY_BOT2 --bot STRATEGY_BOT Add a bot player with strategy STRATEGY, a second bot may be added to play against the first : --bot STRATEGY_BOT1,STRATEGY_BOT2
Available strategies:
- dummy: Default strategy selecting the first valid move
- ai: AI strategy using the default model at models/trictrac_ppo.zip
- ai:/path/to/model.zip: AI strategy using a custom model
ARGS: ARGS:
<INPUT> <INPUT>

View file

@ -75,10 +75,10 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1740870877, "lastModified": 1740915799,
"owner": "cachix", "owner": "cachix",
"repo": "pre-commit-hooks.nix", "repo": "pre-commit-hooks.nix",
"rev": "25d4946dfc2021584f5bde1fbd2aa97353384a95", "rev": "42b1ba089d2034d910566bf6b40830af6b8ec732",
"type": "github" "type": "github"
}, },
"original": { "original": {

View file

@ -57,9 +57,10 @@
venv.enable = true; venv.enable = true;
venv.requirements = " venv.requirements = "
pip pip
gym gymnasium
numpy numpy
stable-baselines3 stable-baselines3
shimmy
"; ";
}; };

View file

@ -17,3 +17,5 @@ profile:
pythonlib: pythonlib:
maturin build -m store/Cargo.toml --release maturin build -m store/Cargo.toml --release
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainbot:
python ./store/python/trainModel.py

View file

@ -6,9 +6,10 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib] [lib]
name = "trictrac" name = "store"
# "cdylib" is necessary to produce a shared library for Python to import from. # "cdylib" is necessary to produce a shared library for Python to import from.
crate-type = ["cdylib"] # "rlib" is needed for other Rust crates to use this library
crate-type = ["cdylib", "rlib"]
[dependencies] [dependencies]
base64 = "0.21.7" base64 = "0.21.7"

View file

@ -0,0 +1,53 @@
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from trictracEnv import TricTracEnv
import os
import torch
import sys
# Vérifier si le GPU est disponible
try:
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"GPU disponible: {torch.cuda.get_device_name(0)}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Using device: {device}")
else:
device = torch.device("cpu")
print("GPU non disponible, utilisation du CPU")
print(f"Using device: {device}")
except Exception as e:
print(f"Erreur lors de la vérification de la disponibilité du GPU: {e}")
device = torch.device("cpu")
print(f"Using device: {device}")
# Créer l'environnement vectorisé
env = DummyVecEnv([lambda: TricTracEnv()])
try:
# Créer et entraîner le modèle avec support GPU si disponible
model = PPO("MultiInputPolicy", env, verbose=1, device=device)
print("Démarrage de l'entraînement...")
# Petit entraînement pour tester
# model.learn(total_timesteps=50)
# Entraînement complet
model.learn(total_timesteps=50000)
print("Entraînement terminé")
except Exception as e:
print(f"Erreur lors de l'entraînement: {e}")
sys.exit(1)
# Sauvegarder le modèle
os.makedirs("models", exist_ok=True)
model.save("models/trictrac_ppo")
# Test du modèle entraîné
obs = env.reset()
for _ in range(100):
action, _ = model.predict(obs)
# L'interface de DummyVecEnv ne retourne que 4 valeurs
obs, _, done, _ = env.step(action)
if done.any():
break

View file

@ -1,6 +1,6 @@
import gym import gymnasium as gym
import numpy as np import numpy as np
from gym import spaces from gymnasium import spaces
import trictrac # module Rust exposé via PyO3 import trictrac # module Rust exposé via PyO3
from typing import Dict, List, Tuple, Optional, Any, Union from typing import Dict, List, Tuple, Optional, Any, Union
@ -43,14 +43,17 @@ class TricTracEnv(gym.Env):
}) })
# Définition de l'espace d'action # Définition de l'espace d'action
# Format: # Format: espace multidiscret avec 5 dimensions
# - Action type: 0=move, 1=mark, 2=go # - Action type: 0=move, 1=mark, 2=go (première dimension)
# - Move: (from1, to1, from2, to2) ou zeros # - Move: (from1, to1, from2, to2) (4 dernières dimensions)
self.action_space = spaces.Dict({ # Pour un total de 5 dimensions
'action_type': spaces.Discrete(3), self.action_space = spaces.MultiDiscrete([
'move': spaces.MultiDiscrete([self.MAX_FIELD + 1, self.MAX_FIELD + 1, 3, # Action type: 0=move, 1=mark, 2=go
self.MAX_FIELD + 1, self.MAX_FIELD + 1]) self.MAX_FIELD + 1, # from1 (0 signifie non utilisé)
}) self.MAX_FIELD + 1, # to1
self.MAX_FIELD + 1, # from2
self.MAX_FIELD + 1, # to2
])
# État courant # État courant
self.state = self._get_observation() self.state = self._get_observation()
@ -62,27 +65,30 @@ class TricTracEnv(gym.Env):
self.steps_taken = 0 self.steps_taken = 0
self.max_steps = 1000 # Limite pour éviter les parties infinies self.max_steps = 1000 # Limite pour éviter les parties infinies
def reset(self): def reset(self, seed=None, options=None):
"""Réinitialise l'environnement et renvoie l'état initial""" """Réinitialise l'environnement et renvoie l'état initial"""
super().reset(seed=seed)
self.game.reset() self.game.reset()
self.state = self._get_observation() self.state = self._get_observation()
self.state_history = [] self.state_history = []
self.steps_taken = 0 self.steps_taken = 0
return self.state
return self.state, {}
def step(self, action): def step(self, action):
""" """
Exécute une action et retourne (state, reward, done, info) Exécute une action et retourne (state, reward, terminated, truncated, info)
Action format: Action format: array de 5 entiers
{ [action_type, from1, to1, from2, to2]
'action_type': 0/1/2, # 0=move, 1=mark, 2=go - action_type: 0=move, 1=mark, 2=go
'move': [from1, to1, from2, to2] # Utilisé seulement si action_type=0 - from1, to1, from2, to2: utilisés seulement si action_type=0
}
""" """
action_type = action['action_type'] action_type = action[0]
reward = 0 reward = 0
done = False terminated = False
truncated = False
info = {} info = {}
# Vérifie que l'action est valide pour le joueur humain (id=1) # Vérifie que l'action est valide pour le joueur humain (id=1)
@ -92,7 +98,7 @@ class TricTracEnv(gym.Env):
if is_agent_turn: if is_agent_turn:
# Exécute l'action selon son type # Exécute l'action selon son type
if action_type == 0: # Move if action_type == 0: # Move
from1, to1, from2, to2 = action['move'] from1, to1, from2, to2 = action[1], action[2], action[3], action[4]
move_made = self.game.play_move(((from1, to1), (from2, to2))) move_made = self.game.play_move(((from1, to1), (from2, to2)))
if not move_made: if not move_made:
# Pénaliser les mouvements invalides # Pénaliser les mouvements invalides
@ -126,7 +132,7 @@ class TricTracEnv(gym.Env):
# Vérifier si la partie est terminée # Vérifier si la partie est terminée
if self.game.is_done(): if self.game.is_done():
done = True terminated = True
winner = self.game.get_winner() winner = self.game.get_winner()
if winner == 1: if winner == 1:
# Bonus si l'agent gagne # Bonus si l'agent gagne
@ -156,7 +162,7 @@ class TricTracEnv(gym.Env):
# Limiter la durée des parties # Limiter la durée des parties
self.steps_taken += 1 self.steps_taken += 1
if self.steps_taken >= self.max_steps: if self.steps_taken >= self.max_steps:
done = True truncated = True
info['timeout'] = True info['timeout'] = True
# Comparer les scores en cas de timeout # Comparer les scores en cas de timeout
@ -168,7 +174,7 @@ class TricTracEnv(gym.Env):
info['winner'] = 'opponent' info['winner'] = 'opponent'
self.state = new_state self.state = new_state
return self.state, reward, done, info return self.state, reward, terminated, truncated, info
def _play_opponent_turn(self): def _play_opponent_turn(self):
"""Simule le tour de l'adversaire avec la stratégie choisie""" """Simule le tour de l'adversaire avec la stratégie choisie"""
@ -291,57 +297,51 @@ class TricTracEnv(gym.Env):
turn_stage = state_dict.get('turn_stage') turn_stage = state_dict.get('turn_stage')
# Masque par défaut (toutes les actions sont invalides) # Masque par défaut (toutes les actions sont invalides)
mask = { # Pour le nouveau format d'action: [action_type, from1, to1, from2, to2]
'action_type': np.zeros(3, dtype=bool), action_type_mask = np.zeros(3, dtype=bool)
'move': np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1, move_mask = np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1,
self.MAX_FIELD + 1, self.MAX_FIELD + 1), dtype=bool) self.MAX_FIELD + 1, self.MAX_FIELD + 1), dtype=bool)
}
if self.game.get_active_player_id() != 1: if self.game.get_active_player_id() != 1:
return mask # Pas au tour de l'agent return action_type_mask, move_mask # Pas au tour de l'agent
# Activer les types d'actions valides selon l'étape du tour # Activer les types d'actions valides selon l'étape du tour
if turn_stage == 'Move' or turn_stage == 'HoldOrGoChoice': if turn_stage == 'Move' or turn_stage == 'HoldOrGoChoice':
mask['action_type'][0] = True # Activer l'action de mouvement action_type_mask[0] = True # Activer l'action de mouvement
# Activer les mouvements valides # Activer les mouvements valides
valid_moves = self.game.get_available_moves() valid_moves = self.game.get_available_moves()
for ((from1, to1), (from2, to2)) in valid_moves: for ((from1, to1), (from2, to2)) in valid_moves:
mask['move'][from1, to1, from2, to2] = True move_mask[from1, to1, from2, to2] = True
if turn_stage == 'MarkPoints' or turn_stage == 'MarkAdvPoints': if turn_stage == 'MarkPoints' or turn_stage == 'MarkAdvPoints':
mask['action_type'][1] = True # Activer l'action de marquer des points action_type_mask[1] = True # Activer l'action de marquer des points
if turn_stage == 'HoldOrGoChoice': if turn_stage == 'HoldOrGoChoice':
mask['action_type'][2] = True # Activer l'action de continuer (Go) action_type_mask[2] = True # Activer l'action de continuer (Go)
return mask return action_type_mask, move_mask
def sample_valid_action(self): def sample_valid_action(self):
"""Échantillonne une action valide selon le masque d'actions""" """Échantillonne une action valide selon le masque d'actions"""
mask = self.get_action_mask() action_type_mask, move_mask = self.get_action_mask()
# Trouver les types d'actions valides # Trouver les types d'actions valides
valid_action_types = np.where(mask['action_type'])[0] valid_action_types = np.where(action_type_mask)[0]
if len(valid_action_types) == 0: if len(valid_action_types) == 0:
# Aucune action valide (pas le tour de l'agent) # Aucune action valide (pas le tour de l'agent)
return { return np.array([0, 0, 0, 0, 0], dtype=np.int32)
'action_type': 0,
'move': np.zeros(4, dtype=np.int32)
}
# Choisir un type d'action # Choisir un type d'action
action_type = np.random.choice(valid_action_types) action_type = np.random.choice(valid_action_types)
action = { # Initialiser l'action
'action_type': action_type, action = np.array([action_type, 0, 0, 0, 0], dtype=np.int32)
'move': np.zeros(4, dtype=np.int32)
}
# Si c'est un mouvement, sélectionner un mouvement valide # Si c'est un mouvement, sélectionner un mouvement valide
if action_type == 0: if action_type == 0:
valid_moves = np.where(mask['move']) valid_moves = np.where(move_mask)
if len(valid_moves[0]) > 0: if len(valid_moves[0]) > 0:
# Sélectionner un mouvement valide aléatoirement # Sélectionner un mouvement valide aléatoirement
idx = np.random.randint(0, len(valid_moves[0])) idx = np.random.randint(0, len(valid_moves[0]))
@ -349,7 +349,7 @@ class TricTracEnv(gym.Env):
to1 = valid_moves[1][idx] to1 = valid_moves[1][idx]
from2 = valid_moves[2][idx] from2 = valid_moves[2][idx]
to2 = valid_moves[3][idx] to2 = valid_moves[3][idx]
action['move'] = np.array([from1, to1, from2, to2], dtype=np.int32) action[1:] = [from1, to1, from2, to2]
return action return action
@ -383,7 +383,7 @@ def example_usage():
if __name__ == "__main__": if __name__ == "__main__":
# Tester l'environnement # Tester l'environnement
env = TricTracEnv() env = TricTracEnv()
obs = env.reset() obs, _ = env.reset()
print("Environnement initialisé") print("Environnement initialisé")
env.render() env.render()
@ -391,14 +391,16 @@ if __name__ == "__main__":
# Jouer quelques coups aléatoires # Jouer quelques coups aléatoires
for _ in range(10): for _ in range(10):
action = env.sample_valid_action() action = env.sample_valid_action()
obs, reward, done, info = env.step(action) obs, reward, terminated, truncated, info = env.step(action)
print(f"\nAction: {action}") print(f"\nAction: {action}")
print(f"Reward: {reward}") print(f"Reward: {reward}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")
print(f"Info: {info}") print(f"Info: {info}")
env.render() env.render()
if done: if terminated or truncated:
print("Game over!") print("Game over!")
break break

View file

@ -330,7 +330,7 @@ impl TricTrac {
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module. /// import the module.
#[pymodule] #[pymodule]
fn trictrac(m: &Bound<'_, PyModule>) -> PyResult<()> { fn store(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<TricTrac>()?; m.add_class::<TricTrac>()?;
Ok(()) Ok(())