diff --git a/bot/src/bin/train_dqn.rs b/bot/src/bin/train_dqn.rs
index 05267f3..abff8d0 100644
--- a/bot/src/bin/train_dqn.rs
+++ b/bot/src/bin/train_dqn.rs
@@ -1,17 +1,17 @@
-use bot::strategy::dqn_trainer::{DqnTrainer};
 use bot::strategy::dqn_common::DqnConfig;
+use bot::strategy::dqn_trainer::DqnTrainer;
 use std::env;
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
     env_logger::init();
-    
+
     let args: Vec<String> = env::args().collect();
-    
+
     // Paramètres par défaut
     let mut episodes = 1000;
     let mut model_path = "models/dqn_model".to_string();
     let mut save_every = 100;
-    
+
     // Parser les arguments de ligne de commande
     let mut i = 1;
     while i < args.len() {
@@ -54,38 +54,41 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             }
         }
     }
-    
+
     // Créer le dossier models s'il n'existe pas
     std::fs::create_dir_all("models")?;
-    
+
     println!("Configuration d'entraînement DQN :");
     println!(" Épisodes : {}", episodes);
     println!(" Chemin du modèle : {}", model_path);
     println!(" Sauvegarde tous les {} épisodes", save_every);
     println!();
-    
+
     // Configuration DQN
     let config = DqnConfig {
-        input_size: 32,
+        state_size: 36, // state.to_vec size
         hidden_size: 256,
         num_actions: 3,
         learning_rate: 0.001,
         gamma: 0.99,
-        epsilon: 0.9,       // Commencer avec plus d'exploration
+        epsilon: 0.9, // Commencer avec plus d'exploration
        epsilon_decay: 0.995,
         epsilon_min: 0.01,
         replay_buffer_size: 10000,
         batch_size: 32,
     };
-    
+
     // Créer et lancer l'entraîneur
     let mut trainer = DqnTrainer::new(config);
     trainer.train(episodes, save_every, &model_path)?;
-    
+
     println!("Entraînement terminé avec succès !");
     println!("Pour utiliser le modèle entraîné :");
-    println!(" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy", model_path);
-    
+    println!(
+        " cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
+        model_path
+    );
+
     Ok(())
 }
 
@@ -105,4 +108,4 @@ fn print_help() {
     println!(" cargo run --bin=train_dqn");
     println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
     println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
-}
\ No newline at end of file
+}
diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs
index 71f9863..bd4e233 100644
--- a/bot/src/strategy/dqn.rs
+++ b/bot/src/strategy/dqn.rs
@@ -1,8 +1,8 @@
 use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
-use store::MoveRules;
 use std::path::Path;
+use store::MoveRules;
 
-use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
+use super::dqn_common::{DqnConfig, SimpleNeuralNetwork};
 
 /// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
 #[derive(Debug)]
@@ -40,7 +40,7 @@ impl DqnStrategy {
     /// Utilise le modèle DQN pour choisir une action
     fn get_dqn_action(&self) -> Option<usize> {
         if let Some(ref model) = self.model {
-            let state = game_state_to_vector(&self.game);
+            let state = self.game.to_vec_float();
             Some(model.get_best_action(&state))
         } else {
             None
         }
     }
 }
 
 impl BotStrategy for DqnStrategy {
     fn get_game(&self) -> &GameState {
         &self.game
     }
-    
+
     fn get_mut_game(&mut self) -> &mut GameState {
         &mut self.game
     }
@@ -66,8 +66,6 @@
     }
 
     fn calculate_points(&self) -> u8 {
-        // Pour l'instant, utilisation de la méthode standard
-        // Plus tard on pourrait utiliser le DQN pour optimiser le calcul de points
         let dice_roll_count = self
             .get_game()
             .players
@@ -96,7 +94,7 @@ impl BotStrategy for DqnStrategy {
     fn choose_move(&self) -> (CheckerMove, CheckerMove) {
         let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
         let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
-        
+
         let chosen_move = if let Some(action) = self.get_dqn_action() {
             // Utiliser l'action DQN pour choisir parmi les mouvements valides
             // Action 0 = premier mouvement, action 1 = mouvement moyen, etc.
@@ -107,18 +105,21 @@
             } else {
                 possible_moves.len().saturating_sub(1) // Dernier mouvement
             };
-            *possible_moves.get(move_index).unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
+            *possible_moves
+                .get(move_index)
+                .unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
         } else {
             // Fallback : premier mouvement valide
             *possible_moves
                 .first()
                 .unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
         };
-        
+
         if self.color == Color::White {
             chosen_move
         } else {
             (chosen_move.0.mirror(), chosen_move.1.mirror())
         }
     }
-}
\ No newline at end of file
+}
+
diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs
index 6dfe991..ec53912 100644
--- a/bot/src/strategy/dqn_common.rs
+++ b/bot/src/strategy/dqn_common.rs
@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
 /// Configuration pour l'agent DQN
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct DqnConfig {
-    pub input_size: usize,
+    pub state_size: usize,
     pub hidden_size: usize,
     pub num_actions: usize,
     pub learning_rate: f64,
@@ -18,7 +18,7 @@ pub struct DqnConfig {
 impl Default for DqnConfig {
     fn default() -> Self {
         Self {
-            input_size: 32,
+            state_size: 36,
             hidden_size: 256,
             num_actions: 3,
             learning_rate: 0.001,
@@ -47,23 +47,35 @@ impl SimpleNeuralNetwork {
     pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
         use rand::{thread_rng, Rng};
         let mut rng = thread_rng();
-        
+
         // Initialisation aléatoire des poids avec Xavier/Glorot
         let scale1 = (2.0 / input_size as f32).sqrt();
         let weights1 = (0..hidden_size)
-            .map(|_| (0..input_size).map(|_| rng.gen_range(-scale1..scale1)).collect())
+            .map(|_| {
+                (0..input_size)
+                    .map(|_| rng.gen_range(-scale1..scale1))
+                    .collect()
+            })
             .collect();
         let biases1 = vec![0.0; hidden_size];
-        
+
         let scale2 = (2.0 / hidden_size as f32).sqrt();
         let weights2 = (0..hidden_size)
-            .map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale2..scale2)).collect())
+            .map(|_| {
+                (0..hidden_size)
+                    .map(|_| rng.gen_range(-scale2..scale2))
+                    .collect()
+            })
             .collect();
         let biases2 = vec![0.0; hidden_size];
-        
+
         let scale3 = (2.0 / hidden_size as f32).sqrt();
         let weights3 = (0..output_size)
-            .map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale3..scale3)).collect())
+            .map(|_| {
+                (0..hidden_size)
+                    .map(|_| rng.gen_range(-scale3..scale3))
+                    .collect()
+            })
             .collect();
         let biases3 = vec![0.0; output_size];
@@ -123,7 +135,10 @@ impl SimpleNeuralNetwork {
             .unwrap_or(0)
     }
 
-    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn save<P: AsRef<Path>>(
+        &self,
+        path: P,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         let data = serde_json::to_string_pretty(self)?;
         std::fs::write(path, data)?;
         Ok(())
@@ -136,47 +151,3 @@ impl SimpleNeuralNetwork {
     }
 }
 
-/// Convertit l'état du jeu en vecteur d'entrée pour le réseau de neurones
-pub fn game_state_to_vector(game_state: &crate::GameState) -> Vec<f32> {
-    use crate::Color;
-
-    let mut state = Vec::with_capacity(32);
-
-    // Plateau (24 cases)
-    let white_positions = game_state.board.get_color_fields(Color::White);
-    let black_positions = game_state.board.get_color_fields(Color::Black);
-
-    let mut board = vec![0.0; 24];
-    for (pos, count) in white_positions {
-        if pos < 24 {
-            board[pos] = count as f32;
-        }
-    }
-    for (pos, count) in black_positions {
-        if pos < 24 {
-            board[pos] = -(count as f32);
-        }
-    }
-    state.extend(board);
-
-    // Informations supplémentaires limitées pour respecter input_size = 32
-    state.push(game_state.active_player_id as f32);
-    state.push(game_state.dice.values.0 as f32);
-    state.push(game_state.dice.values.1 as f32);
-
-    // Points et trous des joueurs
-    if let Some(white_player) = game_state.get_white_player() {
-        state.push(white_player.points as f32);
-        state.push(white_player.holes as f32);
-    } else {
-        state.extend(vec![0.0, 0.0]);
-    }
-
-    // Assurer que la taille est exactement input_size
-    state.truncate(32);
-    while state.len() < 32 {
-        state.push(0.0);
-    }
-
-    state
-}
\ No newline at end of file
diff --git a/bot/src/strategy/dqn_trainer.rs b/bot/src/strategy/dqn_trainer.rs
index abdbbe7..53092eb 100644
--- a/bot/src/strategy/dqn_trainer.rs
+++ b/bot/src/strategy/dqn_trainer.rs
@@ -1,10 +1,11 @@
 use crate::{Color, GameState, PlayerId};
-use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
+use rand::prelude::SliceRandom;
 use rand::{thread_rng, Rng};
-use std::collections::VecDeque;
 use serde::{Deserialize, Serialize};
+use std::collections::VecDeque;
+use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
 
-use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
+use super::dqn_common::{DqnConfig, SimpleNeuralNetwork};
 
 /// Expérience pour le buffer de replay
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -71,7 +72,8 @@ pub struct DqnAgent {
 
 impl DqnAgent {
     pub fn new(config: DqnConfig) -> Self {
-        let model = SimpleNeuralNetwork::new(config.input_size, config.hidden_size, config.num_actions);
+        let model =
+            SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
         let target_model = model.clone();
         let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
         let epsilon = config.epsilon;
@@ -117,7 +119,10 @@ impl DqnAgent {
         }
     }
 
-    pub fn save_model<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn save_model<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         self.model.save(path)
     }
 
@@ -141,12 +146,12 @@ pub struct TrictracEnv {
     pub current_step: usize,
 }
 
-impl TrictracEnv {
-    pub fn new() -> Self {
+impl Default for TrictracEnv {
+    fn default() -> Self {
         let mut game_state = GameState::new(false);
         game_state.init_player("agent");
         game_state.init_player("opponent");
-        
+
         Self {
             game_state,
             agent_player_id: 1,
@@ -156,213 +161,233 @@ impl TrictracEnv {
             current_step: 0,
         }
     }
+}
 
+impl TrictracEnv {
     pub fn reset(&mut self) -> Vec<f32> {
         self.game_state = GameState::new(false);
         self.game_state.init_player("agent");
         self.game_state.init_player("opponent");
-        
+
         // Commencer la partie
-        self.game_state.consume(&GameEvent::BeginGame { goes_first: self.agent_player_id });
-        
+        self.game_state.consume(&GameEvent::BeginGame {
+            goes_first: self.agent_player_id,
+        });
+
         self.current_step = 0;
-        game_state_to_vector(&self.game_state)
+        self.game_state.to_vec_float()
     }
 
     pub fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool) {
         let mut reward = 0.0;
-        
+
         // Appliquer l'action de l'agent
         if self.game_state.active_player_id == self.agent_player_id {
             reward += self.apply_agent_action(action);
         }
-        
+
         // Faire jouer l'adversaire (stratégie simple)
-        while self.game_state.active_player_id == self.opponent_player_id
-            && self.game_state.stage != Stage::Ended {
-            self.play_opponent_turn();
+        while self.game_state.active_player_id == self.opponent_player_id
+            && self.game_state.stage != Stage::Ended
+        {
+            reward += self.play_opponent_turn();
         }
-        
+
         // Vérifier si la partie est terminée
-        let done = self.game_state.stage == Stage::Ended ||
-            self.game_state.determine_winner().is_some() ||
-            self.current_step >= self.max_steps;
+        let done = self.game_state.stage == Stage::Ended
+            || self.game_state.determine_winner().is_some()
+            || self.current_step >= self.max_steps;
 
         // Récompense finale si la partie est terminée
         if done {
             if let Some(winner) = self.game_state.determine_winner() {
                 if winner == self.agent_player_id {
-                    reward += 10.0; // Bonus pour gagner
+                    reward += 100.0; // Bonus pour gagner
                 } else {
-                    reward -= 5.0; // Pénalité pour perdre
+                    reward -= 50.0; // Pénalité pour perdre
                 }
             }
         }
 
         self.current_step += 1;
-        let next_state = game_state_to_vector(&self.game_state);
-        
+        let next_state = self.game_state.to_vec_float();
         (next_state, reward, done)
     }
 
     fn apply_agent_action(&mut self, action: usize) -> f32 {
         let mut reward = 0.0;
-        
-        match self.game_state.turn_stage {
+
+        // TODO : déterminer event selon action ...
+
+        let event = match self.game_state.turn_stage {
             TurnStage::RollDice => {
                 // Lancer les dés
-                let event = GameEvent::Roll { player_id: self.agent_player_id };
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
-
-                    // Simuler le résultat des dés
-                    let mut rng = thread_rng();
-                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                    let dice_event = GameEvent::RollResult {
-                        player_id: self.agent_player_id,
-                        dice: store::Dice { values: dice_values },
-                    };
-                    if self.game_state.validate(&dice_event) {
-                        self.game_state.consume(&dice_event);
-                    }
-                    reward += 0.1
+                GameEvent::Roll {
+                    player_id: self.agent_player_id,
+                }
+            }
+            TurnStage::RollWaiting => {
+                // Simuler le résultat des dés
+                reward += 0.1;
+                let mut rng = thread_rng();
+                let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
+                GameEvent::RollResult {
+                    player_id: self.agent_player_id,
+                    dice: store::Dice {
+                        values: dice_values,
+                    },
                 }
             }
             TurnStage::Move => {
                 // Choisir un mouvement selon l'action
-                let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
+                let rules = MoveRules::new(
+                    &self.agent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
                 let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
-                
-                if !possible_moves.is_empty() {
-                    let move_index = if action == 0 {
-                        0
-                    } else if action == 1 && possible_moves.len() > 1 {
-                        possible_moves.len() / 2
-                    } else {
-                        possible_moves.len().saturating_sub(1)
-                    };
-                    
-                    let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
-                    let event = GameEvent::Move {
-                        player_id: self.agent_player_id,
-                        moves,
-                    };
-                    
-                    if self.game_state.validate(&event) {
-                        self.game_state.consume(&event);
-                        reward += 0.2;
-                    } else {
-                        reward -= 1.0; // Pénalité pour mouvement invalide
-                    }
+
+                // TODO : choix d'action
+                let move_index = if action == 0 {
+                    0
+                } else if action == 1 && possible_moves.len() > 1 {
+                    possible_moves.len() / 2
+                } else {
+                    possible_moves.len().saturating_sub(1)
+                };
+
+                let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
+                GameEvent::Move {
+                    player_id: self.agent_player_id,
+                    moves,
                 }
             }
-            TurnStage::MarkPoints => {
+            TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
                 // Calculer et marquer les points
-                let dice_roll_count = self.game_state.players.get(&self.agent_player_id).unwrap().dice_roll_count;
-                let points_rules = PointsRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
+                let dice_roll_count = self
+                    .game_state
+                    .players
+                    .get(&self.agent_player_id)
+                    .unwrap()
+                    .dice_roll_count;
+                let points_rules = PointsRules::new(
+                    &self.agent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
                 let points = points_rules.get_points(dice_roll_count).0;
-                
-                let event = GameEvent::Mark {
+
+                reward += 0.3 * points as f32; // Récompense proportionnelle aux points
+                GameEvent::Mark {
                     player_id: self.agent_player_id,
                     points,
-                };
-                
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
-                    reward += 0.1 * points as f32; // Récompense proportionnelle aux points
                 }
             }
             TurnStage::HoldOrGoChoice => {
                 // Décider de continuer ou pas selon l'action
-                if action == 2 { // Action "go"
-                    let event = GameEvent::Go { player_id: self.agent_player_id };
-                    if self.game_state.validate(&event) {
-                        self.game_state.consume(&event);
-                        reward += 0.1;
+                if action == 2 {
+                    // Action "go"
+                    GameEvent::Go {
+                        player_id: self.agent_player_id,
                     }
                 } else {
                     // Passer son tour en jouant un mouvement
-                    let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
+                    let rules = MoveRules::new(
+                        &self.agent_color,
+                        &self.game_state.board,
+                        self.game_state.dice,
+                    );
                     let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
-                    
-                    if !possible_moves.is_empty() {
-                        let moves = possible_moves[0];
-                        let event = GameEvent::Move {
-                            player_id: self.agent_player_id,
-                            moves,
-                        };
-                        
-                        if self.game_state.validate(&event) {
-                            self.game_state.consume(&event);
-                        }
+
+                    let moves = possible_moves[0];
+                    GameEvent::Move {
+                        player_id: self.agent_player_id,
+                        moves,
                     }
                 }
             }
-            _ => {}
+        };
+
+        if self.game_state.validate(&event) {
+            self.game_state.consume(&event);
+            reward += 0.2;
+        } else {
+            reward -= 1.0; // Pénalité pour action invalide
         }
-        
         reward
     }
 
-    fn play_opponent_turn(&mut self) {
-        match self.game_state.turn_stage {
-            TurnStage::RollDice => {
-                let event = GameEvent::Roll { player_id: self.opponent_player_id };
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
-
-                    let mut rng = thread_rng();
-                    let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
-                    let dice_event = GameEvent::RollResult {
-                        player_id: self.opponent_player_id,
-                        dice: store::Dice { values: dice_values },
-                    };
-                    if self.game_state.validate(&dice_event) {
-                        self.game_state.consume(&dice_event);
-                    }
+    // TODO : use default bot strategy
+    fn play_opponent_turn(&mut self) -> f32 {
+        let mut reward = 0.0;
+        let event = match self.game_state.turn_stage {
+            TurnStage::RollDice => GameEvent::Roll {
+                player_id: self.opponent_player_id,
+            },
+            TurnStage::RollWaiting => {
+                let mut rng = thread_rng();
+                let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
+                GameEvent::RollResult {
+                    player_id: self.opponent_player_id,
+                    dice: store::Dice {
+                        values: dice_values,
+                    },
+                }
+            }
+            TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
+                let opponent_color = self.agent_color.opponent_color();
+                let dice_roll_count = self
+                    .game_state
+                    .players
+                    .get(&self.opponent_player_id)
+                    .unwrap()
+                    .dice_roll_count;
+                let points_rules = PointsRules::new(
+                    &opponent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
+                let points = points_rules.get_points(dice_roll_count).0;
+                reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
+
+                GameEvent::Mark {
+                    player_id: self.opponent_player_id,
+                    points,
                 }
             }
             TurnStage::Move => {
                 let opponent_color = self.agent_color.opponent_color();
-                let rules = MoveRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
+                let rules = MoveRules::new(
+                    &opponent_color,
+                    &self.game_state.board,
+                    self.game_state.dice,
+                );
                 let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
-                
-                if !possible_moves.is_empty() {
-                    let moves = possible_moves[0]; // Stratégie simple : premier mouvement
-                    let event = GameEvent::Move {
-                        player_id: self.opponent_player_id,
-                        moves,
-                    };
-                    
-                    if self.game_state.validate(&event) {
-                        self.game_state.consume(&event);
-                    }
-                }
-            }
-            TurnStage::MarkPoints => {
-                let opponent_color = self.agent_color.opponent_color();
-                let dice_roll_count = self.game_state.players.get(&self.opponent_player_id).unwrap().dice_roll_count;
-                let points_rules = PointsRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
-                let points = points_rules.get_points(dice_roll_count).0;
-                
-                let event = GameEvent::Mark {
+
+                // Stratégie simple : choix aléatoire
+                let mut rng = thread_rng();
+                let choosen_move = *possible_moves.choose(&mut rng).unwrap();
+
+                GameEvent::Move {
                     player_id: self.opponent_player_id,
-                    points,
-                };
-                
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
+                    moves: if opponent_color == Color::White {
+                        choosen_move
+                    } else {
+                        (choosen_move.0.mirror(), choosen_move.1.mirror())
+                    },
                 }
             }
             TurnStage::HoldOrGoChoice => {
                 // Stratégie simple : toujours continuer
-                let event = GameEvent::Go { player_id: self.opponent_player_id };
-                if self.game_state.validate(&event) {
-                    self.game_state.consume(&event);
+                GameEvent::Go {
+                    player_id: self.opponent_player_id,
                 }
             }
-            _ => {}
+        };
+        if self.game_state.validate(&event) {
+            self.game_state.consume(&event);
         }
+        reward
     }
 }
 
@@ -376,14 +401,14 @@ impl DqnTrainer {
     pub fn new(config: DqnConfig) -> Self {
         Self {
             agent: DqnAgent::new(config),
-            env: TrictracEnv::new(),
+            env: TrictracEnv::default(),
         }
     }
 
     pub fn train_episode(&mut self) -> f32 {
         let mut total_reward = 0.0;
         let mut state = self.env.reset();
-        
+
         loop {
             let action = self.agent.select_action(&state);
             let (next_state, reward, done) = self.env.step(action);
@@ -408,31 +433,40 @@ impl DqnTrainer {
         total_reward
     }
 
-    pub fn train(&mut self, episodes: usize, save_every: usize, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn train(
+        &mut self,
+        episodes: usize,
+        save_every: usize,
+        model_path: &str,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
-        
+
         for episode in 1..=episodes {
             let reward = self.train_episode();
-            
+
             if episode % 100 == 0 {
                 println!(
                     "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
-                    episode, episodes, reward, self.agent.get_epsilon(), self.agent.get_step_count()
+                    episode,
+                    episodes,
+                    reward,
+                    self.agent.get_epsilon(),
+                    self.agent.get_step_count()
                 );
             }
-            
+
             if episode % save_every == 0 {
                 let save_path = format!("{}_episode_{}.json", model_path, episode);
                 self.agent.save_model(&save_path)?;
                 println!("Modèle sauvegardé : {}", save_path);
             }
         }
-        
+
         // Sauvegarder le modèle final
         let final_path = format!("{}_final.json", model_path);
         self.agent.save_model(&final_path)?;
         println!("Modèle final sauvegardé : {}", final_path);
-        
+
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/bot/src/strategy/erroneous_moves.rs b/bot/src/strategy/erroneous_moves.rs
index 3f26f28..f57ec6c 100644
--- a/bot/src/strategy/erroneous_moves.rs
+++ b/bot/src/strategy/erroneous_moves.rs
@@ -1,5 +1,4 @@
 use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
-use store::MoveRules;
 
 #[derive(Debug)]
 pub struct ErroneousStrategy {
diff --git a/justfile b/justfile
index d7b450c..4d75790 100644
--- a/justfile
+++ b/justfile
@@ -18,4 +18,5 @@ pythonlib:
     maturin build -m store/Cargo.toml --release
     pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
 trainbot:
-    python ./store/python/trainModel.py
+    #python ./store/python/trainModel.py
+    cargo run --bin=train_dqn
diff --git a/store/src/board.rs b/store/src/board.rs
index ced30e4..ada22c9 100644
--- a/store/src/board.rs
+++ b/store/src/board.rs
@@ -153,6 +153,10 @@ impl Board {
             .unsigned_abs()
     }
 
+    pub fn to_vec(&self) -> Vec<i8> {
+        self.positions.to_vec()
+    }
+
     // maybe todo : operate on bits (cf. https://github.com/bungogood/bkgm/blob/a2fb3f395243bcb0bc9f146df73413f73f5ea1e0/src/position.rs#L217)
     pub fn to_gnupg_pos_id(&self) -> String {
         // Pieces placement -> 77bits (24 + 23 + 30 max)
diff --git a/store/src/game.rs b/store/src/game.rs
index 65a23e3..1ef8a39 100644
--- a/store/src/game.rs
+++ b/store/src/game.rs
@@ -32,6 +32,33 @@ pub enum TurnStage {
     MarkAdvPoints,
 }
 
+impl From<u8> for TurnStage {
+    fn from(item: u8) -> Self {
+        match item {
+            0 => TurnStage::RollWaiting,
+            1 => TurnStage::RollDice,
+            2 => TurnStage::MarkPoints,
+            3 => TurnStage::HoldOrGoChoice,
+            4 => TurnStage::Move,
+            5 => TurnStage::MarkAdvPoints,
+            _ => TurnStage::RollWaiting,
+        }
+    }
+}
+
+impl From<TurnStage> for u8 {
+    fn from(stage: TurnStage) -> u8 {
+        match stage {
+            TurnStage::RollWaiting => 0,
+            TurnStage::RollDice => 1,
+            TurnStage::MarkPoints => 2,
+            TurnStage::HoldOrGoChoice => 3,
+            TurnStage::Move => 4,
+            TurnStage::MarkAdvPoints => 5,
+        }
+    }
+}
+
 /// Represents a TricTrac game
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
 pub struct GameState {
@@ -117,6 +144,63 @@ impl GameState {
     // accessors
     // -------------------------------------------------------------------------
 
+    pub fn to_vec_float(&self) -> Vec<f32> {
+        self.to_vec().iter().map(|&x| x as f32).collect()
+    }
+
+    /// Get state as a vector (to be used for bot training input) :
+    /// length = 36
+    pub fn to_vec(&self) -> Vec<i8> {
+        let state_len = 36;
+        let mut state = Vec::with_capacity(state_len);
+
+        // length = 24
+        state.extend(self.board.to_vec());
+
+        // active player -> length = 1
+        // white : 0 (false)
+        // black : 1 (true)
+        state.push(
+            self.who_plays()
+                .map(|player| if player.color == Color::Black { 1 } else { 0 })
+                .unwrap_or(0), // White by default
+        );
+
+        // step -> length = 1
+        let turn_stage: u8 = self.turn_stage.into();
+        state.push(turn_stage as i8);
+
+        // dice roll -> length = 2
+        state.push(self.dice.values.0 as i8);
+        state.push(self.dice.values.1 as i8);
+
+        // points length=4 x2 joueurs = 8
+        let white_player: Vec<i8> = self
+            .get_white_player()
+            .unwrap()
+            .to_vec()
+            .iter()
+            .map(|&x| x as i8)
+            .collect();
+        state.extend(white_player);
+        let black_player: Vec<i8> = self
+            .get_black_player()
+            .unwrap()
+            .to_vec()
+            .iter()
+            .map(|&x| x as i8)
+            .collect();
+        // .iter().map(|&x| x as i8) .collect()
+        state.extend(black_player);
+
+        // ensure state has length state_len
+        state.truncate(state_len);
+        while state.len() < state_len {
+            state.push(0);
+        }
+        state
+    }
+
     /// Calculate game state id :
     pub fn to_string_id(&self) -> String {
         // Pieces placement -> 77 bits (24 + 23 + 30 max)
diff --git a/store/src/player.rs b/store/src/player.rs
index 54f8cf6..cf31953 100644
--- a/store/src/player.rs
+++ b/store/src/player.rs
@@ -52,6 +52,15 @@ impl Player {
             self.points, self.holes, self.can_bredouille as u8, self.can_big_bredouille as u8
         )
     }
+
+    pub fn to_vec(&self) -> Vec<u8> {
+        vec![
+            self.points,
+            self.holes,
+            self.can_bredouille as u8,
+            self.can_big_bredouille as u8,
+        ]
+    }
 }
 
 /// Represents a player in the game.