commit a2e54bc449
parent ab959fa27b

    wip fix train
@@ -1,5 +1,5 @@
-use bot::strategy::dqn_trainer::{DqnTrainer};
+use bot::strategy::dqn_common::DqnConfig;
+use bot::strategy::dqn_trainer::DqnTrainer;
 use std::env;
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -66,7 +66,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // DQN configuration
     let config = DqnConfig {
-        input_size: 32,
+        state_size: 36, // state.to_vec size
        hidden_size: 256,
         num_actions: 3,
         learning_rate: 0.001,
@@ -84,7 +84,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     println!("Entraînement terminé avec succès !");
     println!("Pour utiliser le modèle entraîné :");
-    println!(" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy", model_path);
+    println!(
+        " cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
+        model_path
+    );
 
     Ok(())
 }
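A note on this change (not part of the commit itself): the renamed state_size field, now 36 instead of the former input_size of 32, matches the length of the new GameState::to_vec encoding introduced further down in this commit, so the network's input width follows the shared state vector. The printed command shows how the trained model is meant to be loaded, with the model_path argument substituted for the {} placeholder.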
@@ -1,8 +1,8 @@
 use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
-use store::MoveRules;
 use std::path::Path;
+use store::MoveRules;
 
-use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
+use super::dqn_common::{DqnConfig, SimpleNeuralNetwork};
 
 /// DQN strategy for the bot - only loads and uses a pre-trained model
 #[derive(Debug)]
@@ -40,7 +40,7 @@ impl DqnStrategy {
     /// Uses the DQN model to choose an action
     fn get_dqn_action(&self) -> Option<usize> {
         if let Some(ref model) = self.model {
-            let state = game_state_to_vector(&self.game);
+            let state = self.game.to_vec_float();
             Some(model.get_best_action(&state))
         } else {
             None
@@ -66,8 +66,6 @@ impl BotStrategy for DqnStrategy {
     }
 
     fn calculate_points(&self) -> u8 {
-        // For now, use the standard method
-        // Later, the DQN could be used to optimize the points calculation
         let dice_roll_count = self
             .get_game()
             .players
@@ -107,7 +105,9 @@ impl BotStrategy for DqnStrategy {
             } else {
                 possible_moves.len().saturating_sub(1) // Last move
             };
-            *possible_moves.get(move_index).unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
+            *possible_moves
+                .get(move_index)
+                .unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
         } else {
             // Fallback: first valid move
             *possible_moves
@@ -122,3 +122,4 @@ impl BotStrategy for DqnStrategy {
         }
     }
 }
+
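With this change the strategy consumes the same encoding as the trainer, GameState::to_vec_float, instead of the now-removed game_state_to_vector helper, so inference sees exactly the state layout the model was trained on.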
@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
 /// Configuration for the DQN agent
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct DqnConfig {
-    pub input_size: usize,
+    pub state_size: usize,
     pub hidden_size: usize,
     pub num_actions: usize,
     pub learning_rate: f64,
@@ -18,7 +18,7 @@ pub struct DqnConfig {
 impl Default for DqnConfig {
     fn default() -> Self {
         Self {
-            input_size: 32,
+            state_size: 36,
             hidden_size: 256,
             num_actions: 3,
             learning_rate: 0.001,
@@ -51,19 +51,31 @@ impl SimpleNeuralNetwork {
         // Random weight initialization with Xavier/Glorot
         let scale1 = (2.0 / input_size as f32).sqrt();
         let weights1 = (0..hidden_size)
-            .map(|_| (0..input_size).map(|_| rng.gen_range(-scale1..scale1)).collect())
+            .map(|_| {
+                (0..input_size)
+                    .map(|_| rng.gen_range(-scale1..scale1))
+                    .collect()
+            })
             .collect();
         let biases1 = vec![0.0; hidden_size];
 
         let scale2 = (2.0 / hidden_size as f32).sqrt();
         let weights2 = (0..hidden_size)
-            .map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale2..scale2)).collect())
+            .map(|_| {
+                (0..hidden_size)
+                    .map(|_| rng.gen_range(-scale2..scale2))
+                    .collect()
+            })
            .collect();
         let biases2 = vec![0.0; hidden_size];
 
         let scale3 = (2.0 / hidden_size as f32).sqrt();
         let weights3 = (0..output_size)
-            .map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale3..scale3)).collect())
+            .map(|_| {
+                (0..hidden_size)
+                    .map(|_| rng.gen_range(-scale3..scale3))
+                    .collect()
+            })
             .collect();
         let biases3 = vec![0.0; output_size];
 
@@ -123,7 +135,10 @@ impl SimpleNeuralNetwork {
             .unwrap_or(0)
     }
 
-    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn save<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         let data = serde_json::to_string_pretty(self)?;
         std::fs::write(path, data)?;
         Ok(())
@@ -136,47 +151,3 @@ impl SimpleNeuralNetwork {
     }
 }
 
-/// Converts the game state into an input vector for the neural network
-pub fn game_state_to_vector(game_state: &crate::GameState) -> Vec<f32> {
-    use crate::Color;
-
-    let mut state = Vec::with_capacity(32);
-
-    // Board (24 points)
-    let white_positions = game_state.board.get_color_fields(Color::White);
-    let black_positions = game_state.board.get_color_fields(Color::Black);
-
-    let mut board = vec![0.0; 24];
-    for (pos, count) in white_positions {
-        if pos < 24 {
-            board[pos] = count as f32;
-        }
-    }
-    for (pos, count) in black_positions {
-        if pos < 24 {
-            board[pos] = -(count as f32);
-        }
-    }
-    state.extend(board);
-
-    // Limited extra information, to respect input_size = 32
-    state.push(game_state.active_player_id as f32);
-    state.push(game_state.dice.values.0 as f32);
-    state.push(game_state.dice.values.1 as f32);
-
-    // Players' points and holes
-    if let Some(white_player) = game_state.get_white_player() {
-        state.push(white_player.points as f32);
-        state.push(white_player.holes as f32);
-    } else {
-        state.extend(vec![0.0, 0.0]);
-    }
-
-    // Ensure the size is exactly input_size
-    state.truncate(32);
-    while state.len() < 32 {
-        state.push(0.0);
-    }
-
-    state
-}
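To illustrate the intended wiring after the rename, here is a minimal sketch assuming the signatures shown in this diff (SimpleNeuralNetwork::new, get_best_action, GameState::to_vec_float); this is not code from the commit, and the import paths are assumptions:

    // Hypothetical usage sketch; type and method names are taken from the diff above.
    use bot::strategy::dqn_common::{DqnConfig, SimpleNeuralNetwork};
    use store::GameState; // assumed location of GameState

    fn pick_action(game_state: &GameState) -> usize {
        let config = DqnConfig::default(); // state_size = 36 after this commit
        let network =
            SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
        let state = game_state.to_vec_float(); // Vec<f32>, length 36
        network.get_best_action(&state) // index in 0..config.num_actions
    }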
@@ -1,10 +1,11 @@
 use crate::{Color, GameState, PlayerId};
-use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
+use rand::prelude::SliceRandom;
 use rand::{thread_rng, Rng};
-use std::collections::VecDeque;
 use serde::{Deserialize, Serialize};
+use std::collections::VecDeque;
+use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
 
-use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
+use super::dqn_common::{DqnConfig, SimpleNeuralNetwork};
 
 /// Experience for the replay buffer
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -71,7 +72,8 @@ pub struct DqnAgent {
 
 impl DqnAgent {
     pub fn new(config: DqnConfig) -> Self {
-        let model = SimpleNeuralNetwork::new(config.input_size, config.hidden_size, config.num_actions);
+        let model =
+            SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
         let target_model = model.clone();
         let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
         let epsilon = config.epsilon;
@@ -117,7 +119,10 @@ impl DqnAgent {
         }
     }
 
-    pub fn save_model<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn save_model<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         self.model.save(path)
     }
 
@@ -141,8 +146,8 @@ pub struct TrictracEnv {
     pub current_step: usize,
 }
 
-impl TrictracEnv {
-    pub fn new() -> Self {
+impl Default for TrictracEnv {
+    fn default() -> Self {
         let mut game_state = GameState::new(false);
         game_state.init_player("agent");
         game_state.init_player("opponent");
@@ -156,17 +161,21 @@ impl TrictracEnv {
             current_step: 0,
         }
     }
+}
+
+impl TrictracEnv {
     pub fn reset(&mut self) -> Vec<f32> {
         self.game_state = GameState::new(false);
         self.game_state.init_player("agent");
         self.game_state.init_player("opponent");
 
         // Start the game
-        self.game_state.consume(&GameEvent::BeginGame { goes_first: self.agent_player_id });
+        self.game_state.consume(&GameEvent::BeginGame {
+            goes_first: self.agent_player_id,
+        });
 
         self.current_step = 0;
-        game_state_to_vector(&self.game_state)
+        self.game_state.to_vec_float()
     }
 
     pub fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool) {
@@ -179,61 +188,66 @@ impl TrictracEnv {
 
         // Let the opponent play (simple strategy)
         while self.game_state.active_player_id == self.opponent_player_id
-            && self.game_state.stage != Stage::Ended {
-            self.play_opponent_turn();
+            && self.game_state.stage != Stage::Ended
+        {
+            reward += self.play_opponent_turn();
         }
 
         // Check whether the game is over
-        let done = self.game_state.stage == Stage::Ended ||
-            self.game_state.determine_winner().is_some() ||
-            self.current_step >= self.max_steps;
+        let done = self.game_state.stage == Stage::Ended
+            || self.game_state.determine_winner().is_some()
+            || self.current_step >= self.max_steps;
 
         // Final reward if the game is over
         if done {
             if let Some(winner) = self.game_state.determine_winner() {
                 if winner == self.agent_player_id {
-                    reward += 10.0; // Bonus for winning
+                    reward += 100.0; // Bonus for winning
                 } else {
-                    reward -= 5.0; // Penalty for losing
+                    reward -= 50.0; // Penalty for losing
                 }
             }
         }
 
         self.current_step += 1;
-        let next_state = game_state_to_vector(&self.game_state);
+        let next_state = self.game_state.to_vec_float();
         (next_state, reward, done)
     }
 
     fn apply_agent_action(&mut self, action: usize) -> f32 {
         let mut reward = 0.0;
 
-        match self.game_state.turn_stage {
+        // TODO: determine the event from the action ...
+        let event = match self.game_state.turn_stage {
             TurnStage::RollDice => {
                 // Roll the dice
-                let event = GameEvent::Roll { player_id: self.agent_player_id };
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
+                GameEvent::Roll {
+                    player_id: self.agent_player_id,
                 }
             }
             TurnStage::RollWaiting => {
                 // Simulate the dice result
+                reward += 0.1;
                 let mut rng = thread_rng();
                 let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                let dice_event = GameEvent::RollResult {
+                GameEvent::RollResult {
                     player_id: self.agent_player_id,
-                    dice: store::Dice { values: dice_values },
-                };
-                if self.game_state.validate(&dice_event) {
-                    self.game_state.consume(&dice_event);
-                }
-                reward += 0.1;
+                    dice: store::Dice {
+                        values: dice_values,
+                    },
+                }
             }
             TurnStage::Move => {
                 // Choose a move according to the action
-                let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
+                let rules = MoveRules::new(
+                    &self.agent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
                 let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
 
                 if !possible_moves.is_empty() {
+                    // TODO: action choice
                     let move_index = if action == 0 {
                         0
                     } else if action == 1 && possible_moves.len() > 1 {
@@ -243,126 +257,137 @@ impl TrictracEnv {
                     };
 
                     let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
-                    let event = GameEvent::Move {
+                    GameEvent::Move {
                         player_id: self.agent_player_id,
                         moves,
                     }
                 }
+            }
+            TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
+                // Compute and mark the points
+                let dice_roll_count = self
+                    .game_state
+                    .players
+                    .get(&self.agent_player_id)
+                    .unwrap()
+                    .dice_roll_count;
+                let points_rules = PointsRules::new(
+                    &self.agent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
+                let points = points_rules.get_points(dice_roll_count).0;
+
+                reward += 0.3 * points as f32; // Reward proportional to the points
+                GameEvent::Mark {
+                    player_id: self.agent_player_id,
+                    points,
+                }
+            }
+            TurnStage::HoldOrGoChoice => {
+                // Decide whether to keep going, depending on the action
+                if action == 2 {
+                    // "go" action
+                    GameEvent::Go {
+                        player_id: self.agent_player_id,
+                    }
+                } else {
+                    // Pass the turn by playing a move
+                    let rules = MoveRules::new(
+                        &self.agent_color,
+                        &self.game_state.board,
+                        self.game_state.dice,
+                    );
+                    let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+
+                    let moves = possible_moves[0];
+                    GameEvent::Move {
+                        player_id: self.agent_player_id,
+                        moves,
+                    }
+                }
+            }
+        };
+
         if self.game_state.validate(&event) {
             self.game_state.consume(&event);
             reward += 0.2;
         } else {
-            reward -= 1.0; // Penalty for an invalid move
+            reward -= 1.0; // Penalty for an invalid action
         }
-                }
-            }
-            TurnStage::MarkPoints => {
-                // Compute and mark the points
-                let dice_roll_count = self.game_state.players.get(&self.agent_player_id).unwrap().dice_roll_count;
-                let points_rules = PointsRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
-                let points = points_rules.get_points(dice_roll_count).0;
-
-                let event = GameEvent::Mark {
-                    player_id: self.agent_player_id,
-                    points,
-                };
-
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
-                    reward += 0.1 * points as f32; // Reward proportional to the points
-                }
-            }
-            TurnStage::HoldOrGoChoice => {
-                // Decide whether to keep going, depending on the action
-                if action == 2 { // "go" action
-                    let event = GameEvent::Go { player_id: self.agent_player_id };
-                    if self.game_state.validate(&event) {
-                        self.game_state.consume(&event);
-                        reward += 0.1;
-                    }
-                } else {
-                    // Pass the turn by playing a move
-                    let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
-                    let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
-
-                    if !possible_moves.is_empty() {
-                        let moves = possible_moves[0];
-                        let event = GameEvent::Move {
-                            player_id: self.agent_player_id,
-                            moves,
-                        };
-
-                        if self.game_state.validate(&event) {
-                            self.game_state.consume(&event);
-                        }
-                    }
-                }
-            }
-            _ => {}
-        }
 
         reward
     }
 
-    fn play_opponent_turn(&mut self) {
-        match self.game_state.turn_stage {
-            TurnStage::RollDice => {
-                let event = GameEvent::Roll { player_id: self.opponent_player_id };
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
-                }
-            }
+    // TODO: use default bot strategy
+    fn play_opponent_turn(&mut self) -> f32 {
+        let mut reward = 0.0;
+        let event = match self.game_state.turn_stage {
+            TurnStage::RollDice => GameEvent::Roll {
+                player_id: self.opponent_player_id,
+            },
             TurnStage::RollWaiting => {
                 let mut rng = thread_rng();
                 let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                let dice_event = GameEvent::RollResult {
+                GameEvent::RollResult {
                     player_id: self.opponent_player_id,
-                    dice: store::Dice { values: dice_values },
-                };
-                if self.game_state.validate(&dice_event) {
-                    self.game_state.consume(&dice_event);
+                    dice: store::Dice {
+                        values: dice_values,
+                    },
+                }
             }
+            TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
+                let opponent_color = self.agent_color.opponent_color();
+                let dice_roll_count = self
+                    .game_state
+                    .players
+                    .get(&self.opponent_player_id)
+                    .unwrap()
+                    .dice_roll_count;
+                let points_rules = PointsRules::new(
+                    &opponent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
+                let points = points_rules.get_points(dice_roll_count).0;
+                reward -= 0.3 * points as f32; // Reward proportional to the points
+
+                GameEvent::Mark {
+                    player_id: self.opponent_player_id,
+                    points,
+                }
+            }
             TurnStage::Move => {
                 let opponent_color = self.agent_color.opponent_color();
-                let rules = MoveRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
+                let rules = MoveRules::new(
+                    &opponent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
                 let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
 
-                if !possible_moves.is_empty() {
-                    let moves = possible_moves[0]; // Simple strategy: first move
-                    let event = GameEvent::Move {
+                // Simple strategy: random choice
+                let mut rng = thread_rng();
+                let choosen_move = *possible_moves.choose(&mut rng).unwrap();
+
+                GameEvent::Move {
                     player_id: self.opponent_player_id,
-                        moves,
-                    };
-
-                    if self.game_state.validate(&event) {
-                        self.game_state.consume(&event);
-                    }
-                }
-            }
-            TurnStage::MarkPoints => {
-                let opponent_color = self.agent_color.opponent_color();
-                let dice_roll_count = self.game_state.players.get(&self.opponent_player_id).unwrap().dice_roll_count;
-                let points_rules = PointsRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
-                let points = points_rules.get_points(dice_roll_count).0;
-
-                let event = GameEvent::Mark {
-                    player_id: self.opponent_player_id,
-                    points,
-                };
-
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
+                    moves: if opponent_color == Color::White {
+                        choosen_move
+                    } else {
+                        (choosen_move.0.mirror(), choosen_move.1.mirror())
+                    },
                 }
             }
             TurnStage::HoldOrGoChoice => {
                 // Simple strategy: always keep going
-                let event = GameEvent::Go { player_id: self.opponent_player_id };
+                GameEvent::Go {
+                    player_id: self.opponent_player_id,
+                }
             }
+        };
         if self.game_state.validate(&event) {
             self.game_state.consume(&event);
         }
-            }
-            _ => {}
-        }
+        reward
     }
 }
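In summary, the reward shaping after this hunk: +0.2 for each validated agent event, -1.0 for an invalid action, +0.1 when simulating a dice roll, +0.3 x points when the agent marks, -0.3 x points when the opponent marks (play_opponent_turn now returns a reward folded into the agent's), and +100 / -50 at game end for a win / loss, up from the previous +10 / -5.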
@@ -376,7 +401,7 @@ impl DqnTrainer {
     pub fn new(config: DqnConfig) -> Self {
         Self {
            agent: DqnAgent::new(config),
-            env: TrictracEnv::new(),
+            env: TrictracEnv::default(),
         }
     }
 
@@ -408,7 +433,12 @@ impl DqnTrainer {
         total_reward
     }
 
-    pub fn train(&mut self, episodes: usize, save_every: usize, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn train(
+        &mut self,
+        episodes: usize,
+        save_every: usize,
+        model_path: &str,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
 
         for episode in 1..=episodes {
@@ -417,7 +447,11 @@ impl DqnTrainer {
             if episode % 100 == 0 {
                 println!(
                     "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
-                    episode, episodes, reward, self.agent.get_epsilon(), self.agent.get_step_count()
+                    episode,
+                    episodes,
+                    reward,
+                    self.agent.get_epsilon(),
+                    self.agent.get_step_count()
                 );
             }
 
@@ -1,5 +1,4 @@
 use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
-use store::MoveRules;
 
 #[derive(Debug)]
 pub struct ErroneousStrategy {
justfile (3 changed lines)
@@ -18,4 +18,5 @@ pythonlib:
     maturin build -m store/Cargo.toml --release
     pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
 trainbot:
-    python ./store/python/trainModel.py
+    #python ./store/python/trainModel.py
+    cargo run --bin=train_dqn
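With the updated recipe, training now runs through the Rust binary; assuming just is installed, the two invocations below should be equivalent:

    just trainbot
    # or directly:
    cargo run --bin=train_dqn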
@@ -153,6 +153,10 @@ impl Board {
         .unsigned_abs()
     }
 
+    pub fn to_vec(&self) -> Vec<i8> {
+        self.positions.to_vec()
+    }
+
     // maybe todo : operate on bits (cf. https://github.com/bungogood/bkgm/blob/a2fb3f395243bcb0bc9f146df73413f73f5ea1e0/src/position.rs#L217)
     pub fn to_gnupg_pos_id(&self) -> String {
         // Pieces placement -> 77bits (24 + 23 + 30 max)
@@ -32,6 +32,33 @@ pub enum TurnStage {
     MarkAdvPoints,
 }
 
+impl From<u8> for TurnStage {
+    fn from(item: u8) -> Self {
+        match item {
+            0 => TurnStage::RollWaiting,
+            1 => TurnStage::RollDice,
+            2 => TurnStage::MarkPoints,
+            3 => TurnStage::HoldOrGoChoice,
+            4 => TurnStage::Move,
+            5 => TurnStage::MarkAdvPoints,
+            _ => TurnStage::RollWaiting,
+        }
+    }
+}
+
+impl From<TurnStage> for u8 {
+    fn from(stage: TurnStage) -> u8 {
+        match stage {
+            TurnStage::RollWaiting => 0,
+            TurnStage::RollDice => 1,
+            TurnStage::MarkPoints => 2,
+            TurnStage::HoldOrGoChoice => 3,
+            TurnStage::Move => 4,
+            TurnStage::MarkAdvPoints => 5,
+        }
+    }
+}
+
 /// Represents a TricTrac game
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
 pub struct GameState {
@@ -117,6 +144,63 @@ impl GameState {
     // accessors
     // -------------------------------------------------------------------------
 
+    pub fn to_vec_float(&self) -> Vec<f32> {
+        self.to_vec().iter().map(|&x| x as f32).collect()
+    }
+
+    /// Get state as a vector (to be used for bot training input):
+    /// length = 36
+    pub fn to_vec(&self) -> Vec<i8> {
+        let state_len = 36;
+        let mut state = Vec::with_capacity(state_len);
+
+        // length = 24
+        state.extend(self.board.to_vec());
+
+        // active player -> length = 1
+        // white: 0 (false)
+        // black: 1 (true)
+        state.push(
+            self.who_plays()
+                .map(|player| if player.color == Color::Black { 1 } else { 0 })
+                .unwrap_or(0), // White by default
+        );
+
+        // step -> length = 1
+        let turn_stage: u8 = self.turn_stage.into();
+        state.push(turn_stage as i8);
+
+        // dice roll -> length = 2
+        state.push(self.dice.values.0 as i8);
+        state.push(self.dice.values.1 as i8);
+
+        // points: length = 4 x 2 players = 8
+        let white_player: Vec<i8> = self
+            .get_white_player()
+            .unwrap()
+            .to_vec()
+            .iter()
+            .map(|&x| x as i8)
+            .collect();
+        state.extend(white_player);
+        let black_player: Vec<i8> = self
+            .get_black_player()
+            .unwrap()
+            .to_vec()
+            .iter()
+            .map(|&x| x as i8)
+            .collect();
+        state.extend(black_player);
+
+        // ensure state has length state_len
+        state.truncate(state_len);
+        while state.len() < state_len {
+            state.push(0);
+        }
+        state
+    }
+
     /// Calculate game state id :
     pub fn to_string_id(&self) -> String {
         // Pieces placement -> 77 bits (24 + 23 + 30 max)
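The 36 entries of the new to_vec encoding decompose as 24 board positions + 1 active player + 1 turn stage + 2 dice values + 2 x 4 player values (points, holes, can_bredouille, can_big_bredouille) = 36. A minimal sanity-check sketch, assuming GameState::new(false) and init_player behave as used by TrictracEnv above; this is a hypothetical test, not part of the commit:

    #[test]
    fn state_vector_has_fixed_length() {
        // Mirrors the setup in TrictracEnv::default(); player names are arbitrary.
        let mut game_state = GameState::new(false);
        game_state.init_player("white");
        game_state.init_player("black");
        // 24 + 1 + 1 + 2 + 4 + 4 = 36
        assert_eq!(game_state.to_vec().len(), 36);
        assert_eq!(game_state.to_vec_float().len(), 36);
    }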
@@ -52,6 +52,15 @@ impl Player {
             self.points, self.holes, self.can_bredouille as u8, self.can_big_bredouille as u8
         )
     }
+
+    pub fn to_vec(&self) -> Vec<u8> {
+        vec![
+            self.points,
+            self.holes,
+            self.can_bredouille as u8,
+            self.can_big_bredouille as u8,
+        ]
+    }
 }
 
 /// Represents a player in the game.