debug
This commit is contained in:
parent
f7eea0ed02
commit
ebe98ca229
|
|
@ -2,7 +2,7 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
|||
use std::path::Path;
|
||||
use store::MoveRules;
|
||||
|
||||
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action};
|
||||
use super::dqn_common::{SimpleNeuralNetwork, TrictracAction, get_valid_actions, sample_valid_action};
|
||||
|
||||
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||
#[derive(Debug)]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use crate::{CheckerMove};
|
||||
|
||||
/// Types d'actions possibles dans le jeu
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
|
|
@ -24,7 +23,7 @@ impl TrictracAction {
|
|||
TrictracAction::Roll => 0,
|
||||
TrictracAction::Mark { points } => {
|
||||
1 + (*points as usize).min(12) // Indices 1-13 pour 0-12 points
|
||||
},
|
||||
}
|
||||
TrictracAction::Go => 14,
|
||||
TrictracAction::Move { move1, move2 } => {
|
||||
// Encoder les mouvements dans l'espace d'actions
|
||||
|
|
@ -38,13 +37,15 @@ impl TrictracAction {
|
|||
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||
match index {
|
||||
0 => Some(TrictracAction::Roll),
|
||||
1..=13 => Some(TrictracAction::Mark { points: (index - 1) as u8 }),
|
||||
1..=13 => Some(TrictracAction::Mark {
|
||||
points: (index - 1) as u8,
|
||||
}),
|
||||
14 => Some(TrictracAction::Go),
|
||||
i if i >= 15 => {
|
||||
let move_code = i - 15;
|
||||
let (move1, move2) = decode_move_pair(move_code);
|
||||
Some(TrictracAction::Move { move1, move2 })
|
||||
},
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -236,7 +237,7 @@ impl SimpleNeuralNetwork {
|
|||
|
||||
/// Obtient les actions valides pour l'état de jeu actuel
|
||||
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||
use crate::{Color, PointsRules};
|
||||
use crate::PointsRules;
|
||||
use store::{MoveRules, TurnStage};
|
||||
|
||||
let mut valid_actions = Vec::new();
|
||||
|
|
@ -287,7 +288,6 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
|||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -304,10 +304,9 @@ pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
|
|||
|
||||
/// Sélectionne une action valide aléatoire
|
||||
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
|
||||
use rand::{thread_rng, seq::SliceRandom};
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
let mut rng = thread_rng();
|
||||
valid_actions.choose(&mut rng).cloned()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::VecDeque;
|
||||
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
||||
|
||||
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, TrictracAction, get_valid_actions, get_valid_action_indices, sample_valid_action};
|
||||
use super::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};
|
||||
|
||||
/// Expérience pour le buffer de replay
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -99,7 +99,10 @@ impl DqnAgent {
|
|||
let mut rng = thread_rng();
|
||||
if rng.gen::<f64>() < self.epsilon {
|
||||
// Exploration : action valide aléatoire
|
||||
valid_actions.choose(&mut rng).cloned().unwrap_or(TrictracAction::Roll)
|
||||
valid_actions
|
||||
.choose(&mut rng)
|
||||
.cloned()
|
||||
.unwrap_or(TrictracAction::Roll)
|
||||
} else {
|
||||
// Exploitation : meilleure action valide selon le modèle
|
||||
let q_values = self.model.forward(state);
|
||||
|
|
@ -287,7 +290,9 @@ impl TrictracEnv {
|
|||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
let dice_event = GameEvent::RollResult {
|
||||
player_id: self.agent_player_id,
|
||||
dice: store::Dice { values: dice_values },
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
};
|
||||
if self.game_state.validate(&dice_event) {
|
||||
self.game_state.consume(&dice_event);
|
||||
|
|
@ -393,8 +398,10 @@ impl DqnTrainer {
|
|||
pub fn train_episode(&mut self) -> f32 {
|
||||
let mut total_reward = 0.0;
|
||||
let mut state = self.env.reset();
|
||||
// let mut step_count = 0;
|
||||
|
||||
loop {
|
||||
// step_count += 1;
|
||||
let action = self.agent.select_action(&self.env.game_state, &state);
|
||||
let (next_state, reward, done) = self.env.step(action.clone());
|
||||
total_reward += reward;
|
||||
|
|
@ -412,6 +419,9 @@ impl DqnTrainer {
|
|||
if done {
|
||||
break;
|
||||
}
|
||||
// if step_count % 100 == 0 {
|
||||
// println!("{:?}", next_state);
|
||||
// }
|
||||
state = next_state;
|
||||
}
|
||||
|
||||
|
|
@ -429,6 +439,7 @@ impl DqnTrainer {
|
|||
for episode in 1..=episodes {
|
||||
let reward = self.train_episode();
|
||||
|
||||
print!(".");
|
||||
if episode % 100 == 0 {
|
||||
println!(
|
||||
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
||||
use store::MoveRules;
|
||||
use std::process::Command;
|
||||
use std::io::Write;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use std::process::Command;
|
||||
use store::MoveRules;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StableBaselines3Strategy {
|
||||
|
|
@ -85,7 +85,6 @@ impl StableBaselines3Strategy {
|
|||
store::TurnStage::HoldOrGoChoice => 3,
|
||||
store::TurnStage::Move => 4,
|
||||
store::TurnStage::MarkAdvPoints => 5,
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
// Récupérer les points et trous des joueurs
|
||||
|
|
@ -170,10 +169,7 @@ with open("{}", "w") as f:
|
|||
script_file.write_all(python_script.as_bytes()).ok()?;
|
||||
|
||||
// Exécuter le script Python
|
||||
let status = Command::new("python")
|
||||
.arg(temp_script_path)
|
||||
.status()
|
||||
.ok()?;
|
||||
let status = Command::new("python").arg(temp_script_path).status().ok()?;
|
||||
|
||||
if !status.success() {
|
||||
return None;
|
||||
|
|
@ -274,3 +270,4 @@ impl BotStrategy for StableBaselines3Strategy {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ impl GameState {
|
|||
state.push(self.dice.values.0 as i8);
|
||||
state.push(self.dice.values.1 as i8);
|
||||
|
||||
// points length=4 x2 joueurs = 8
|
||||
// points, trous, bredouille, grande bredouille length=4 x2 joueurs = 8
|
||||
let white_player: Vec<i8> = self
|
||||
.get_white_player()
|
||||
.unwrap()
|
||||
|
|
|
|||
Loading…
Reference in a new issue