réglages train bot dqn burnrl

This commit is contained in:
Henri Bourcereau 2025-08-03 16:11:45 +02:00
parent 28c2aa836f
commit c0d42a0c45
7 changed files with 101 additions and 16 deletions

View file

@ -91,7 +91,7 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction;
type RewardType = f32;
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
@ -179,9 +179,9 @@ impl Environment for TrictracEnvironment {
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
reward += 100.0; // Victoire
reward += 50.0; // Victoire
} else {
reward -= 50.0; // Défaite
reward -= 25.0; // Défaite
}
}
}
@ -259,7 +259,7 @@ impl TrictracEnvironment {
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
reward += 0.4;
reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
@ -288,7 +288,7 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
reward += 0.4;
reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
@ -313,6 +313,8 @@ impl TrictracEnvironment {
};
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
}
}
} else {
@ -356,7 +358,7 @@ impl TrictracEnvironment {
},
}
}
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
TurnStage::MarkPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
@ -366,14 +368,31 @@ impl TrictracEnvironment {
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let points = points_rules.get_points(dice_roll_count).0;
reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.opponent_id,
points,
}
}
TurnStage::MarkAdvPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let points = points_rules.get_points(dice_roll_count).1;
// pas de reward : déjà comptabilisé lors du tour de blanc
GameEvent::Mark {
player_id: self.opponent_id,
points,
}
}
TurnStage::HoldOrGoChoice => {
// Stratégie simple : toujours continuer
GameEvent::Go {

View file

@ -11,13 +11,13 @@ type Env = environment::TrictracEnvironment;
fn main() {
println!("> Entraînement");
let conf = dqn_model::DqnConfig {
num_episodes: 50,
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
// max_steps: 1000, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 1000.0,
eps_decay: 3000.0,
};
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);

View file

@ -357,8 +357,8 @@ impl TrictracEnv {
&self.game_state.board,
self.game_state.dice,
);
let points = points_rules.get_points(dice_roll_count).0;
reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.opponent_player_id,

View file

@ -46,7 +46,14 @@ impl BotStrategy for ClientStrategy {
}
fn calculate_adv_points(&self) -> u8 {
self.calculate_points()
let dice_roll_count = self
.get_game()
.players
.get(&self.player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
points_rules.get_points(dice_roll_count).1
}
fn choose_go(&self) -> bool {