réglages train bot dqn burnrl
This commit is contained in:
parent
28c2aa836f
commit
c0d42a0c45
7 changed files with 101 additions and 16 deletions
|
|
@ -91,7 +91,7 @@ impl Environment for TrictracEnvironment {
|
|||
type ActionType = TrictracAction;
|
||||
type RewardType = f32;
|
||||
|
||||
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
|
||||
const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
|
||||
|
||||
fn new(visualized: bool) -> Self {
|
||||
let mut game = GameState::new(false);
|
||||
|
|
@ -179,9 +179,9 @@ impl Environment for TrictracEnvironment {
|
|||
// Récompense finale basée sur le résultat
|
||||
if let Some(winner_id) = self.game.determine_winner() {
|
||||
if winner_id == self.active_player_id {
|
||||
reward += 100.0; // Victoire
|
||||
reward += 50.0; // Victoire
|
||||
} else {
|
||||
reward -= 50.0; // Défaite
|
||||
reward -= 25.0; // Défaite
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -259,7 +259,7 @@ impl TrictracEnvironment {
|
|||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
reward += 0.4;
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
|
|
@ -288,7 +288,7 @@ impl TrictracEnvironment {
|
|||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
reward += 0.4;
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
|
|
@ -313,6 +313,8 @@ impl TrictracEnvironment {
|
|||
};
|
||||
if self.game.validate(&dice_event) {
|
||||
self.game.consume(&dice_event);
|
||||
let (points, adv_points) = self.game.dice_points;
|
||||
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -356,7 +358,7 @@ impl TrictracEnvironment {
|
|||
},
|
||||
}
|
||||
}
|
||||
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
|
||||
TurnStage::MarkPoints => {
|
||||
let opponent_color = store::Color::Black;
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
|
|
@ -366,14 +368,31 @@ impl TrictracEnvironment {
|
|||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let points = points_rules.get_points(dice_roll_count).0;
|
||||
reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::MarkAdvPoints => {
|
||||
let opponent_color = store::Color::Black;
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
.get(&self.opponent_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let points = points_rules.get_points(dice_roll_count).1;
|
||||
// pas de reward : déjà comptabilisé lors du tour de blanc
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
// Stratégie simple : toujours continuer
|
||||
GameEvent::Go {
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ type Env = environment::TrictracEnvironment;
|
|||
fn main() {
|
||||
println!("> Entraînement");
|
||||
let conf = dqn_model::DqnConfig {
|
||||
num_episodes: 50,
|
||||
num_episodes: 40,
|
||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
||||
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
||||
// max_steps: 1000, // must be set in environment.rs with the MAX_STEPS constant
|
||||
dense_size: 256, // neural network complexity
|
||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||
eps_end: 0.05,
|
||||
eps_decay: 1000.0,
|
||||
eps_decay: 3000.0,
|
||||
};
|
||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||
|
||||
|
|
|
|||
|
|
@ -357,8 +357,8 @@ impl TrictracEnv {
|
|||
&self.game_state.board,
|
||||
self.game_state.dice,
|
||||
);
|
||||
let points = points_rules.get_points(dice_roll_count).0;
|
||||
reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_player_id,
|
||||
|
|
|
|||
|
|
@ -46,7 +46,14 @@ impl BotStrategy for ClientStrategy {
|
|||
}
|
||||
|
||||
fn calculate_adv_points(&self) -> u8 {
|
||||
self.calculate_points()
|
||||
let dice_roll_count = self
|
||||
.get_game()
|
||||
.players
|
||||
.get(&self.player_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
|
||||
points_rules.get_points(dice_roll_count).1
|
||||
}
|
||||
|
||||
fn choose_go(&self) -> bool {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue