chore: refactor for clippy lints

This commit is contained in:
Henri Bourcereau 2025-08-17 15:59:53 +02:00
parent db9560dfac
commit 1dc29d0ff0
19 changed files with 98 additions and 95 deletions

View file

@ -4,6 +4,11 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -2.12121;
const REWARD_VALID_MOVE: f32 = 2.12121;
const REWARD_RATIO: f32 = 0.01;
const WIN_POINTS: f32 = 0.1;
/// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
@ -165,8 +170,7 @@ impl Environment for TrictracEnvironment {
let trictrac_action = Self::convert_action(action);
let mut reward = 0.0;
let mut is_rollpoint = false;
let mut terminated = false;
let is_rollpoint;
// Exécuter l'action si c'est le tour de l'agent DQN
if self.game.active_player_id == self.active_player_id {
@ -175,7 +179,7 @@ impl Environment for TrictracEnvironment {
if is_rollpoint {
self.pointrolls_count += 1;
}
if reward != Self::ERROR_REWARD {
if reward != ERROR_REWARD {
self.goodmoves_count += 1;
}
} else {
@ -199,9 +203,9 @@ impl Environment for TrictracEnvironment {
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
reward += 50.0; // Victoire
reward += WIN_POINTS; // Victoire
} else {
reward -= 25.0; // Défaite
reward -= WIN_POINTS; // Défaite
}
}
}
@ -223,15 +227,13 @@ impl Environment for TrictracEnvironment {
}
impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
@ -265,7 +267,6 @@ impl TrictracEnvironment {
let event = match action {
TrictracAction::Roll => {
// Lancer les dés
reward += 0.1;
Some(GameEvent::Roll {
player_id: self.active_player_id,
})
@ -273,7 +274,6 @@ impl TrictracEnvironment {
// TrictracAction::Mark => {
// // Marquer des points
// let points = self.game.
// reward += 0.1 * points as f32;
// Some(GameEvent::Mark {
// player_id: self.active_player_id,
// points,
@ -281,7 +281,6 @@ impl TrictracEnvironment {
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
@ -308,7 +307,10 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = self.game.board.clone();
tmp_board.move_checker(color, checker_move1);
let move_result = tmp_board.move_checker(color, checker_move1);
if move_result.is_err() {
panic!("Error while moving checker {move_result:?}")
}
let from2 = tmp_board
.get_checker_field(color, checker2 as u8)
.unwrap_or(0);
@ -324,7 +326,6 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
@ -336,7 +337,7 @@ impl TrictracEnvironment {
if let Some(event) = event {
if self.game.validate(&event) {
self.game.consume(&event);
reward += REWARD_VALID_MOVE;
// Simuler le résultat des dés après un Roll
if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng();
@ -350,7 +351,7 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
reward += REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
@ -362,7 +363,7 @@ impl TrictracEnvironment {
// Pénalité pour action invalide
// on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
reward = Self::ERROR_REWARD;
reward = ERROR_REWARD;
}
}
@ -458,7 +459,7 @@ impl TrictracEnvironment {
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Récompense proportionnelle aux points
reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
reward -= REWARD_RATIO * (points - adv_points) as f32;
}
}
}

View file

@ -15,12 +15,12 @@ fn main() {
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 40, // 40
min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction)
max_steps: 1000, // 1000 max steps by episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
num_episodes: 40, // 40
min_steps: 1000.0, // 1000 min of max steps by episode (mise à jour par la fonction)
max_steps: 2000, // 1000 max steps by episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode

View file

@ -3,7 +3,7 @@ use crate::dqn::burnrl::{
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};

View file

@ -4,6 +4,11 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -2.12121;
const REWARD_VALID_MOVE: f32 = 2.12121;
const REWARD_RATIO: f32 = 0.01;
const WIN_POINTS: f32 = 0.1;
/// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
@ -168,16 +173,13 @@ impl Environment for TrictracEnvironment {
let is_rollpoint;
// Exécuter l'action si c'est le tour de l'agent DQN
let mut has_played = false;
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
let str_action = format!("{action:?}");
(reward, is_rollpoint) = self.execute_action(action);
if is_rollpoint {
self.pointrolls_count += 1;
}
if reward != Self::ERROR_REWARD {
has_played = true;
if reward != ERROR_REWARD {
self.goodmoves_count += 1;
// println!("{str_action}");
}
@ -203,9 +205,9 @@ impl Environment for TrictracEnvironment {
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
reward += 50.0; // Victoire
reward += WIN_POINTS; // Victoire
} else {
reward -= 25.0; // Défaite
reward -= WIN_POINTS; // Défaite
}
}
}
@ -226,15 +228,13 @@ impl Environment for TrictracEnvironment {
}
impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac
pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
@ -269,7 +269,6 @@ impl TrictracEnvironment {
let event = match action {
TrictracAction::Roll => {
// Lancer les dés
reward += 0.1;
need_roll = true;
Some(GameEvent::Roll {
player_id: self.active_player_id,
@ -286,7 +285,6 @@ impl TrictracEnvironment {
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
@ -315,7 +313,6 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
@ -327,7 +324,7 @@ impl TrictracEnvironment {
if let Some(event) = event {
if self.game.validate(&event) {
self.game.consume(&event);
reward += REWARD_VALID_MOVE;
// Simuler le résultat des dés après un Roll
// if matches!(action, TrictracAction::Roll) {
if need_roll {
@ -343,7 +340,7 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
reward += REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
@ -355,7 +352,7 @@ impl TrictracEnvironment {
// Pénalité pour action invalide
// on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
reward = Self::ERROR_REWARD;
reward = ERROR_REWARD;
}
}
@ -399,18 +396,18 @@ impl TrictracEnvironment {
}
TurnStage::MarkPoints => {
panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
GameEvent::Mark {
player_id: self.opponent_id,
points: points_rules.get_points(dice_roll_count).0,
}
// let dice_roll_count = self
// .game
// .players
// .get(&self.opponent_id)
// .unwrap()
// .dice_roll_count;
// let points_rules =
// PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
// GameEvent::Mark {
// player_id: self.opponent_id,
// points: points_rules.get_points(dice_roll_count).0,
// }
}
TurnStage::MarkAdvPoints => {
let dice_roll_count = self
@ -454,7 +451,7 @@ impl TrictracEnvironment {
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Récompense proportionnelle aux points
let adv_reward = Self::REWARD_RATIO * (points - adv_points) as f32;
let adv_reward = REWARD_RATIO * (points - adv_points) as f32;
reward -= adv_reward;
// if adv_reward != 0.0 {
// println!("info: opponent : {adv_reward} -> {reward}");

View file

@ -15,16 +15,16 @@ fn main() {
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 40, // 40
min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction)
max_steps: 3000, // 1000 max steps by episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
num_episodes: 40, // 40
min_steps: 2000.0, // 1000 min of max steps by episode (mise à jour par la fonction)
max_steps: 4000, // 1000 max steps by episode
dense_size: 128, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
eps_decay: 2000.0, // 1000 ?
eps_decay: 1000.0, // 1000 ?
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation

View file

@ -3,7 +3,7 @@ use crate::dqn::burnrl_big::{
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common_big::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};

View file

@ -382,10 +382,9 @@ impl TrictracEnvironment {
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
GameEvent::Mark {
player_id: self.opponent_id,
points,
points: points_rules.get_points(dice_roll_count).0,
}
}
TurnStage::MarkAdvPoints => {

View file

@ -3,7 +3,7 @@ use crate::dqn::burnrl_valid::{
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};

View file

@ -1,7 +1,7 @@
use std::cmp::{max, min};
use serde::{Deserialize, Serialize};
use store::{CheckerMove, Dice};
use store::CheckerMove;
/// Types d'actions possibles dans le jeu
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@ -210,7 +210,10 @@ fn checker_moves_to_trictrac_action(
let checker1 = state.board.get_field_checker(color, from1) as usize;
let mut tmp_board = state.board.clone();
// should not raise an error for a valid action
tmp_board.move_checker(color, *move1);
let move_res = tmp_board.move_checker(color, *move1);
if move_res.is_err() {
panic!("error while moving checker {move_res:?}");
}
let checker2 = tmp_board.get_field_checker(color, from2) as usize;
TrictracAction::Move {
dice_order,

View file

@ -55,6 +55,10 @@ impl ReplayBuffer {
batch
}
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
pub fn len(&self) -> usize {
self.buffer.len()
}
@ -457,7 +461,7 @@ impl DqnTrainer {
save_every: usize,
model_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
println!("Démarrage de l'entraînement DQN pour {episodes} épisodes");
for episode in 1..=episodes {
let reward = self.train_episode();
@ -474,16 +478,16 @@ impl DqnTrainer {
}
if episode % save_every == 0 {
let save_path = format!("{}_episode_{}.json", model_path, episode);
let save_path = format!("{model_path}_episode_{episode}.json");
self.agent.save_model(&save_path)?;
println!("Modèle sauvegardé : {}", save_path);
println!("Modèle sauvegardé : {save_path}");
}
}
// Sauvegarder le modèle final
let final_path = format!("{}_final.json", model_path);
let final_path = format!("{model_path}_final.json");
self.agent.save_model(&final_path)?;
println!("Modèle final sauvegardé : {}", final_path);
println!("Modèle final sauvegardé : {final_path}");
Ok(())
}

View file

@ -60,9 +60,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
std::fs::create_dir_all("models")?;
println!("Configuration d'entraînement DQN :");
println!(" Épisodes : {}", episodes);
println!(" Chemin du modèle : {}", model_path);
println!(" Sauvegarde tous les {} épisodes", save_every);
println!(" Épisodes : {episodes}");
println!(" Chemin du modèle : {model_path}");
println!(" Sauvegarde tous les {save_every} épisodes");
println!();
// Configuration DQN
@ -85,10 +85,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Entraînement terminé avec succès !");
println!("Pour utiliser le modèle entraîné :");
println!(
" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
model_path
);
println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");
Ok(())
}

View file

@ -1,7 +1,7 @@
pub mod dqn;
pub mod strategy;
use log::{debug, error};
use log::debug;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy;
pub use strategy::dqn::DqnStrategy;

View file

@ -154,7 +154,10 @@ impl BotStrategy for DqnBurnStrategy {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = self.game.board.clone();
tmp_board.move_checker(&self.color, checker_move1);
let move_res = tmp_board.move_checker(&self.color, checker_move1);
if move_res.is_err() {
panic!("could not move {move_res:?}");
}
let from2 = tmp_board
.get_checker_field(&self.color, checker2 as u8)
.unwrap_or(0);

View file

@ -66,14 +66,14 @@ impl StableBaselines3Strategy {
// Remplir les positions des pièces blanches (valeurs positives)
for (pos, count) in self.game.board.get_color_fields(Color::White) {
if pos < 24 {
board[pos] = count as i8;
board[pos] = count;
}
}
// Remplir les positions des pièces noires (valeurs négatives)
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
if pos < 24 {
board[pos] = -(count as i8);
board[pos] = -count;
}
}
@ -270,4 +270,3 @@ impl BotStrategy for StableBaselines3Strategy {
}
}
}