burn dqn trainer

This commit is contained in:
Henri Bourcereau 2025-06-22 21:25:45 +02:00
parent cf1175e497
commit a06b47628e
5 changed files with 276 additions and 13 deletions

View file

@ -80,13 +80,13 @@ impl From<TrictracAction> for u32 {
/// Trictrac environment for burn-rl.
///
/// Holds the underlying game together with the per-episode bookkeeping
/// (player ids, cached state, accumulated reward, step counter).
#[derive(Debug)]
pub struct TrictracEnvironment {
// NOTE(review): this listing is a diff view, so `game` and `visualized` each
// appear twice (private form, then `pub` form). Presumably the `pub` line is
// the post-commit version — confirm against the actual file.
game: GameState,
pub game: GameState,
active_player_id: PlayerId,
/// Compared against `game.active_player_id` to decide whether the opponent
/// has to play.
opponent_id: PlayerId,
/// Rebuilt from the game via `TrictracState::from_game_state(&self.game)`.
current_state: TrictracState,
/// Set back to 0.0 right after the `BeginGame` event is consumed.
episode_reward: f32,
/// Set back to 0 right after the `BeginGame` event is consumed.
step_count: usize,
visualized: bool,
pub visualized: bool,
}
impl Environment for TrictracEnvironment {
@ -127,6 +127,9 @@ impl Environment for TrictracEnvironment {
self.game.init_player("DQN Agent");
self.game.init_player("Opponent");
// Commencer la partie
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.step_count = 0;
@ -161,8 +164,10 @@ impl Environment for TrictracEnvironment {
}
}
// Jouer l'adversaire si c'est son tour
reward += self.play_opponent_if_needed();
// Faire jouer l'adversaire (stratégie simple)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
reward += self.play_opponent_if_needed();
}
// Vérifier si la partie est terminée
let done = self.game.stage == Stage::Ended
@ -366,13 +371,10 @@ impl TrictracEnvironment {
player_id: self.opponent_id,
}
}
TurnStage::Move => {
let (move1, move2) = default_strategy.choose_move();
GameEvent::Move {
player_id: self.opponent_id,
moves: (move1.mirror(), move2.mirror()),
}
}
TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id,
moves: default_strategy.choose_move(),
},
};
if self.game.validate(&event) {
@ -382,4 +384,3 @@ impl TrictracEnvironment {
reward
}
}