chore:refacto clippy

Henri Bourcereau 2025-08-17 15:59:53 +02:00
parent db9560dfac
commit 1dc29d0ff0
19 changed files with 98 additions and 95 deletions

View file

@@ -4,6 +4,11 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};

+const ERROR_REWARD: f32 = -2.12121;
+const REWARD_VALID_MOVE: f32 = 2.12121;
+const REWARD_RATIO: f32 = 0.01;
+const WIN_POINTS: f32 = 0.1;
+
/// Trictrac game state for burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
@@ -165,8 +170,7 @@ impl Environment for TrictracEnvironment {
let trictrac_action = Self::convert_action(action);
let mut reward = 0.0;
-let mut is_rollpoint = false;
-let mut terminated = false;
+let is_rollpoint;

// Execute the action if it is the DQN agent's turn
if self.game.active_player_id == self.active_player_id {
@@ -175,7 +179,7 @@ impl Environment for TrictracEnvironment {
if is_rollpoint {
self.pointrolls_count += 1;
}
-if reward != Self::ERROR_REWARD {
+if reward != ERROR_REWARD {
self.goodmoves_count += 1;
}
} else {
@@ -199,9 +203,9 @@ impl Environment for TrictracEnvironment {
// Final reward based on the game result
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
-reward += 50.0; // Win
+reward += WIN_POINTS; // Win
} else {
-reward -= 25.0; // Loss
+reward -= WIN_POINTS; // Loss
}
}
}
@@ -223,15 +227,13 @@
}

impl TrictracEnvironment {
-const ERROR_REWARD: f32 = -1.12121;
-const REWARD_RATIO: f32 = 1.0;
-
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
}

/// Converts the index of an action among the valid actions into a Trictrac action
+#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
@@ -265,7 +267,6 @@ impl TrictracEnvironment {
let event = match action {
TrictracAction::Roll => {
// Roll the dice
-reward += 0.1;
Some(GameEvent::Roll {
player_id: self.active_player_id,
})
@@ -273,7 +274,6 @@ impl TrictracEnvironment {
// TrictracAction::Mark => {
// // Mark points
// let points = self.game.
-// reward += 0.1 * points as f32;
// Some(GameEvent::Mark {
// player_id: self.active_player_id,
// points,
@@ -281,7 +281,6 @@ impl TrictracEnvironment {
// }
TrictracAction::Go => {
// Continue after winning a hole
-reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
@@ -308,7 +307,10 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();

let mut tmp_board = self.game.board.clone();
-tmp_board.move_checker(color, checker_move1);
+let move_result = tmp_board.move_checker(color, checker_move1);
+if move_result.is_err() {
+panic!("Error while moving checker {move_result:?}")
+}
let from2 = tmp_board
.get_checker_field(color, checker2 as u8)
.unwrap_or(0);
@@ -324,7 +326,6 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
-reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
@@ -336,7 +337,7 @@ impl TrictracEnvironment {
if let Some(event) = event {
if self.game.validate(&event) {
self.game.consume(&event);
+reward += REWARD_VALID_MOVE;

// Simulate the dice result after a Roll
if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng();
@@ -350,7 +351,7 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
-reward += Self::REWARD_RATIO * (points - adv_points) as f32;
+reward += REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
@@ -362,7 +363,7 @@ impl TrictracEnvironment {
// Penalty for an invalid action
// cancel the previous rewards
// and set a recognizable value for the statistics
-reward = Self::ERROR_REWARD;
+reward = ERROR_REWARD;
}
}
@@ -458,7 +459,7 @@ impl TrictracEnvironment {
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Reward proportional to the points
-reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
+reward -= REWARD_RATIO * (points - adv_points) as f32;
}
}
}

View file

@@ -16,8 +16,8 @@ fn main() {
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 40, // 40
-min_steps: 500.0, // 1000 min of max steps by episode (updated by the function)
-max_steps: 1000, // 1000 max steps by episode
+min_steps: 1000.0, // 1000 min of max steps by episode (updated by the function)
+max_steps: 2000, // 1000 max steps by episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05

View file

@@ -3,7 +3,7 @@ use crate::dqn::burnrl::{
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
-use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};

View file

@@ -4,6 +4,11 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};

+const ERROR_REWARD: f32 = -2.12121;
+const REWARD_VALID_MOVE: f32 = 2.12121;
+const REWARD_RATIO: f32 = 0.01;
+const WIN_POINTS: f32 = 0.1;
+
/// Trictrac game state for burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
@@ -168,16 +173,13 @@ impl Environment for TrictracEnvironment {
let is_rollpoint;

// Execute the action if it is the DQN agent's turn
-let mut has_played = false;
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
-let str_action = format!("{action:?}");
(reward, is_rollpoint) = self.execute_action(action);
if is_rollpoint {
self.pointrolls_count += 1;
}
-if reward != Self::ERROR_REWARD {
-has_played = true;
+if reward != ERROR_REWARD {
self.goodmoves_count += 1;
// println!("{str_action}");
}
@@ -203,9 +205,9 @@ impl Environment for TrictracEnvironment {
// Final reward based on the game result
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
-reward += 50.0; // Win
+reward += WIN_POINTS; // Win
} else {
-reward -= 25.0; // Loss
+reward -= WIN_POINTS; // Loss
}
}
}
@@ -226,15 +228,13 @@
}

impl TrictracEnvironment {
-const ERROR_REWARD: f32 = -1.12121;
-const REWARD_RATIO: f32 = 1.0;
-
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
}

/// Converts the index of an action among the valid actions into a Trictrac action
+#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
@@ -269,7 +269,6 @@ impl TrictracEnvironment {
let event = match action {
TrictracAction::Roll => {
// Roll the dice
-reward += 0.1;
need_roll = true;
Some(GameEvent::Roll {
player_id: self.active_player_id,
@@ -286,7 +285,6 @@ impl TrictracEnvironment {
// }
TrictracAction::Go => {
// Continue after winning a hole
-reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
@@ -315,7 +313,6 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
-reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
@@ -327,7 +324,7 @@ impl TrictracEnvironment {
if let Some(event) = event {
if self.game.validate(&event) {
self.game.consume(&event);
+reward += REWARD_VALID_MOVE;

// Simulate the dice result after a Roll
// if matches!(action, TrictracAction::Roll) {
if need_roll {
@@ -343,7 +340,7 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
-reward += Self::REWARD_RATIO * (points - adv_points) as f32;
+reward += REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
@@ -355,7 +352,7 @@ impl TrictracEnvironment {
// Penalty for an invalid action
// cancel the previous rewards
// and set a recognizable value for the statistics
-reward = Self::ERROR_REWARD;
+reward = ERROR_REWARD;
}
}
@@ -399,18 +396,18 @@ impl TrictracEnvironment {
}
TurnStage::MarkPoints => {
panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
-let dice_roll_count = self
-.game
-.players
-.get(&self.opponent_id)
-.unwrap()
-.dice_roll_count;
-let points_rules =
-PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-GameEvent::Mark {
-player_id: self.opponent_id,
-points: points_rules.get_points(dice_roll_count).0,
-}
+// let dice_roll_count = self
+// .game
+// .players
+// .get(&self.opponent_id)
+// .unwrap()
+// .dice_roll_count;
+// let points_rules =
+// PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
+// GameEvent::Mark {
+// player_id: self.opponent_id,
+// points: points_rules.get_points(dice_roll_count).0,
+// }
}
TurnStage::MarkAdvPoints => {
let dice_roll_count = self
@@ -454,7 +451,7 @@ impl TrictracEnvironment {
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Reward proportional to the points
-let adv_reward = Self::REWARD_RATIO * (points - adv_points) as f32;
+let adv_reward = REWARD_RATIO * (points - adv_points) as f32;
reward -= adv_reward;
// if adv_reward != 0.0 {
// println!("info: opponent : {adv_reward} -> {reward}");

View file

@@ -16,15 +16,15 @@ fn main() {
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 40, // 40
-min_steps: 500.0, // 1000 min of max steps by episode (updated by the function)
-max_steps: 3000, // 1000 max steps by episode
-dense_size: 256, // 128 neural network complexity (default 128)
+min_steps: 2000.0, // 1000 min of max steps by episode (updated by the function)
+max_steps: 4000, // 1000 max steps by episode
+dense_size: 128, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower
// used in: epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
// epsilon is updated at the start of each episode
-eps_decay: 2000.0, // 1000 ?
+eps_decay: 1000.0, // 1000 ?
gamma: 0.999, // 0.999 discount factor. Higher = encourages long-term strategies
tau: 0.005, // 0.005 soft update rate. Update rate of the target network. Lower = adaptation
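For reference, the epsilon schedule described in the eps_decay comment above can be sketched as a free-standing snippet; the function and variable names here are illustrative only, not part of this commit or of the dqn_model API.

// Sketch of the epsilon-greedy decay from the comment above:
// epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
fn epsilon(step: f64, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
    eps_end + (eps_start - eps_end) * (-step / eps_decay).exp()
}

fn main() {
    // With eps_start = 0.9, eps_end = 0.05, eps_decay = 1000.0:
    // step 0 -> 0.90, step 1000 -> ~0.36, step 5000 -> ~0.06
    for step in [0.0_f64, 1000.0, 5000.0] {
        println!("step {step}: epsilon = {:.3}", epsilon(step, 0.9, 0.05, 1000.0));
    }
}

A larger eps_decay therefore keeps exploration high for longer, which is why the value is tuned together with max_steps.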

View file

@@ -3,7 +3,7 @@ use crate::dqn::burnrl_big::{
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common_big::get_valid_action_indices;
-use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};

View file

@@ -382,10 +382,9 @@ impl TrictracEnvironment {
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
-let (points, adv_points) = points_rules.get_points(dice_roll_count);
GameEvent::Mark {
player_id: self.opponent_id,
-points,
+points: points_rules.get_points(dice_roll_count).0,
}
}
TurnStage::MarkAdvPoints => {

View file

@@ -3,7 +3,7 @@ use crate::dqn::burnrl_valid::{
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
-use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};

View file

@@ -1,7 +1,7 @@
use std::cmp::{max, min};

use serde::{Deserialize, Serialize};
-use store::{CheckerMove, Dice};
+use store::CheckerMove;

/// Possible action types in the game
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -210,7 +210,10 @@ fn checker_moves_to_trictrac_action(
let checker1 = state.board.get_field_checker(color, from1) as usize;
let mut tmp_board = state.board.clone();
// should not raise an error for a valid action
-tmp_board.move_checker(color, *move1);
+let move_res = tmp_board.move_checker(color, *move1);
+if move_res.is_err() {
+panic!("error while moving checker {move_res:?}");
+}
let checker2 = tmp_board.get_field_checker(color, from2) as usize;
TrictracAction::Move {
dice_order,

View file

@@ -55,6 +55,10 @@ impl ReplayBuffer {
batch
}

+pub fn is_empty(&self) -> bool {
+self.buffer.is_empty()
+}
+
pub fn len(&self) -> usize {
self.buffer.len()
}
@@ -457,7 +461,7 @@ impl DqnTrainer {
save_every: usize,
model_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
-println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
+println!("Démarrage de l'entraînement DQN pour {episodes} épisodes");

for episode in 1..=episodes {
let reward = self.train_episode();
@@ -474,16 +478,16 @@
}

if episode % save_every == 0 {
-let save_path = format!("{}_episode_{}.json", model_path, episode);
+let save_path = format!("{model_path}_episode_{episode}.json");
self.agent.save_model(&save_path)?;
-println!("Modèle sauvegardé : {}", save_path);
+println!("Modèle sauvegardé : {save_path}");
}
}

// Save the final model
-let final_path = format!("{}_final.json", model_path);
+let final_path = format!("{model_path}_final.json");
self.agent.save_model(&final_path)?;
-println!("Modèle final sauvegardé : {}", final_path);
+println!("Modèle final sauvegardé : {final_path}");

Ok(())
}

View file

@@ -60,9 +60,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
std::fs::create_dir_all("models")?;

println!("Configuration d'entraînement DQN :");
-println!(" Épisodes : {}", episodes);
-println!(" Chemin du modèle : {}", model_path);
-println!(" Sauvegarde tous les {} épisodes", save_every);
+println!(" Épisodes : {episodes}");
+println!(" Chemin du modèle : {model_path}");
+println!(" Sauvegarde tous les {save_every} épisodes");
println!();

// DQN configuration
@@ -85,10 +85,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Entraînement terminé avec succès !");
println!("Pour utiliser le modèle entraîné :");
-println!(
-" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
-model_path
-);
+println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");

Ok(())
}

View file

@@ -1,7 +1,7 @@
pub mod dqn;
pub mod strategy;

-use log::{debug, error};
+use log::debug;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};

pub use strategy::default::DefaultStrategy;
pub use strategy::dqn::DqnStrategy;

View file

@@ -154,7 +154,10 @@ impl BotStrategy for DqnBurnStrategy {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();

let mut tmp_board = self.game.board.clone();
-tmp_board.move_checker(&self.color, checker_move1);
+let move_res = tmp_board.move_checker(&self.color, checker_move1);
+if move_res.is_err() {
+panic!("could not move {move_res:?}");
+}
let from2 = tmp_board
.get_checker_field(&self.color, checker2 as u8)
.unwrap_or(0);

View file

@@ -66,14 +66,14 @@ impl StableBaselines3Strategy {
// Fill in the positions of the white checkers (positive values)
for (pos, count) in self.game.board.get_color_fields(Color::White) {
if pos < 24 {
-board[pos] = count as i8;
+board[pos] = count;
}
}

// Fill in the positions of the black checkers (negative values)
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
if pos < 24 {
-board[pos] = -(count as i8);
+board[pos] = -count;
}
}
@@ -270,4 +270,3 @@
}
}
}

View file

@@ -59,7 +59,7 @@ impl App {
}
s if s.starts_with("dqnburn:") => {
let path = s.trim_start_matches("dqnburn:");
-Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}")))
+Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
as Box<dyn BotStrategy>)
}
_ => None,
@@ -114,7 +114,7 @@ impl App {
pub fn show_history(&self) {
for hist in self.game.state.history.iter() {
-println!("{:?}\n", hist);
+println!("{hist:?}\n");
}
}
@@ -192,7 +192,7 @@ impl App {
return;
}
}
-println!("invalid move : {}", input);
+println!("invalid move : {input}");
}

pub fn display(&mut self) -> String {

View file

@@ -77,7 +77,7 @@ impl GameRunner {
} else {
debug!("{}", self.state);
error!("event not valid : {event:?}");
-panic!("crash and burn {} \nevt not valid {event:?}", self.state);
+// panic!("crash and burn {} \nevt not valid {event:?}", self.state);
&GameEvent::PlayError
};

View file

@@ -35,7 +35,7 @@ fn main() -> Result<()> {
let args = match parse_args() {
Ok(v) => v,
Err(e) => {
-eprintln!("Error: {}.", e);
+eprintln!("Error: {e}.");
std::process::exit(1);
}
};
@@ -63,7 +63,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
// Help has a higher priority and should be handled separately.
if pargs.contains(["-h", "--help"]) {
-print!("{}", HELP);
+print!("{HELP}");
std::process::exit(0);
}
@@ -78,7 +78,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
// It's up to the caller what to do with the remaining arguments.
let remaining = pargs.finish();
if !remaining.is_empty() {
-eprintln!("Warning: unused arguments left: {:?}.", remaining);
+eprintln!("Warning: unused arguments left: {remaining:?}.");
}

Ok(args)

View file

@@ -43,7 +43,7 @@ fn main() {
.unwrap();
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();

-trace!("❂ TricTrac server listening on {}", SERVER_ADDR);
+trace!("❂ TricTrac server listening on {SERVER_ADDR}");

let mut game_state = store::GameState::default();
let mut last_updated = Instant::now();
@@ -80,7 +80,7 @@ fn main() {
// Tell all players that a new player has joined
server.broadcast_message(0, bincode::serialize(&event).unwrap());

-info!("🎉 Client {} connected.", client_id);
+info!("🎉 Client {client_id} connected.");

// In TicTacTussle the game can begin once two players have joined
if game_state.players.len() == 2 {
let event = store::GameEvent::BeginGame {
@@ -101,7 +101,7 @@ fn main() {
};
game_state.consume(&event);
server.broadcast_message(0, bincode::serialize(&event).unwrap());
-info!("Client {} disconnected", client_id);
+info!("Client {client_id} disconnected");

// Then end the game, since tic tac toe can't go on with a single player
let event = store::GameEvent::EndGame {
@@ -124,7 +124,7 @@ fn main() {
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
if game_state.validate(&event) {
game_state.consume(&event);
-trace!("Player {} sent:\n\t{:#?}", client_id, event);
+trace!("Player {client_id} sent:\n\t{event:#?}");
server.broadcast_message(0, bincode::serialize(&event).unwrap());

// Determine if a player has won the game
@@ -135,7 +135,7 @@ fn main() {
server.broadcast_message(0, bincode::serialize(&event).unwrap());
}
} else {
-warn!("Player {} sent invalid event:\n\t{:#?}", client_id, event);
+warn!("Player {client_id} sent invalid event:\n\t{event:#?}");
}
}
}

View file

@@ -4,7 +4,7 @@ use crate::dice::Dice;
use crate::game_rules_moves::MoveRules;
use crate::game_rules_points::{PointsRules, PossibleJans};
use crate::player::{Color, Player, PlayerId};
-use log::{debug, error, info};
+use log::{debug, error};
// use itertools::Itertools;
use serde::{Deserialize, Serialize};