wip fix train

This commit is contained in:
Henri Bourcereau 2025-05-30 20:32:00 +02:00
parent ab959fa27b
commit a2e54bc449
9 changed files with 335 additions and 229 deletions

View file

@ -1,17 +1,17 @@
use bot::strategy::dqn_trainer::{DqnTrainer};
use bot::strategy::dqn_common::DqnConfig;
use bot::strategy::dqn_trainer::DqnTrainer;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
let args: Vec<String> = env::args().collect();
// Default parameters
let mut episodes = 1000;
let mut model_path = "models/dqn_model".to_string();
let mut save_every = 100;
// Parse command-line arguments
let mut i = 1;
while i < args.len() {
@ -54,38 +54,41 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}
}
// Create the models directory if it does not exist
std::fs::create_dir_all("models")?;
println!("Configuration d'entraînement DQN :");
println!(" Épisodes : {}", episodes);
println!(" Chemin du modèle : {}", model_path);
println!(" Sauvegarde tous les {} épisodes", save_every);
println!();
// DQN configuration
let config = DqnConfig {
input_size: 32,
state_size: 36, // state.to_vec size
hidden_size: 256,
num_actions: 3,
learning_rate: 0.001,
gamma: 0.99,
epsilon: 0.9, // Start with more exploration
epsilon_decay: 0.995,
epsilon_min: 0.01,
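// With epsilon_decay 0.995, epsilon falls from 0.9 to epsilon_min after roughly 900 decay applications (0.9 * 0.995^900 ≈ 0.01).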
replay_buffer_size: 10000,
batch_size: 32,
};
// Create and run the trainer
let mut trainer = DqnTrainer::new(config);
trainer.train(episodes, save_every, &model_path)?;
println!("Entraînement terminé avec succès !");
println!("Pour utiliser le modèle entraîné :");
println!(" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy", model_path);
println!(
" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
model_path
);
Ok(())
}
@ -105,4 +108,4 @@ fn print_help() {
println!(" cargo run --bin=train_dqn");
println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
}
}

View file

@ -1,8 +1,8 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use store::MoveRules;
use std::path::Path;
use store::MoveRules;
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork};
/// DQN strategy for the bot - only loads and uses a pre-trained model
#[derive(Debug)]
@ -40,7 +40,7 @@ impl DqnStrategy {
/// Uses the DQN model to choose an action
fn get_dqn_action(&self) -> Option<usize> {
if let Some(ref model) = self.model {
let state = game_state_to_vector(&self.game);
let state = self.game.to_vec_float();
Some(model.get_best_action(&state))
} else {
None
@ -52,7 +52,7 @@ impl BotStrategy for DqnStrategy {
fn get_game(&self) -> &GameState {
&self.game
}
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
@ -66,8 +66,6 @@ impl BotStrategy for DqnStrategy {
}
fn calculate_points(&self) -> u8 {
// For now, use the standard method
// Later the DQN could be used to optimize the points calculation
let dice_roll_count = self
.get_game()
.players
@ -96,7 +94,7 @@ impl BotStrategy for DqnStrategy {
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let chosen_move = if let Some(action) = self.get_dqn_action() {
// Use the DQN action to choose among the valid moves
// Action 0 = first move, action 1 = middle move, etc.
@ -107,18 +105,21 @@ impl BotStrategy for DqnStrategy {
} else {
possible_moves.len().saturating_sub(1) // Last move
};
*possible_moves.get(move_index).unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
*possible_moves
.get(move_index)
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
} else {
// Fallback: first valid move
*possible_moves
.first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
};
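// Moves are computed from White's perspective; mirror them when playing Black.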
if self.color == Color::White {
chosen_move
} else {
(chosen_move.0.mirror(), chosen_move.1.mirror())
}
}
}
}

View file

@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
/// Configuration for the DQN agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
pub input_size: usize,
pub state_size: usize,
pub hidden_size: usize,
pub num_actions: usize,
pub learning_rate: f64,
@ -18,7 +18,7 @@ pub struct DqnConfig {
impl Default for DqnConfig {
fn default() -> Self {
Self {
input_size: 32,
state_size: 36,
hidden_size: 256,
num_actions: 3,
learning_rate: 0.001,
@ -47,23 +47,35 @@ impl SimpleNeuralNetwork {
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
use rand::{thread_rng, Rng};
let mut rng = thread_rng();
// Random weight initialization with Xavier/Glorot-style scaling
let scale1 = (2.0 / input_size as f32).sqrt();
let weights1 = (0..hidden_size)
.map(|_| (0..input_size).map(|_| rng.gen_range(-scale1..scale1)).collect())
.map(|_| {
(0..input_size)
.map(|_| rng.gen_range(-scale1..scale1))
.collect()
})
.collect();
let biases1 = vec![0.0; hidden_size];
let scale2 = (2.0 / hidden_size as f32).sqrt();
let weights2 = (0..hidden_size)
.map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale2..scale2)).collect())
.map(|_| {
(0..hidden_size)
.map(|_| rng.gen_range(-scale2..scale2))
.collect()
})
.collect();
let biases2 = vec![0.0; hidden_size];
let scale3 = (2.0 / hidden_size as f32).sqrt();
let weights3 = (0..output_size)
.map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale3..scale3)).collect())
.map(|_| {
(0..hidden_size)
.map(|_| rng.gen_range(-scale3..scale3))
.collect()
})
.collect();
let biases3 = vec![0.0; output_size];
@ -123,7 +135,10 @@ impl SimpleNeuralNetwork {
.unwrap_or(0)
}
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
pub fn save<P: AsRef<std::path::Path>>(
&self,
path: P,
) -> Result<(), Box<dyn std::error::Error>> {
let data = serde_json::to_string_pretty(self)?;
std::fs::write(path, data)?;
Ok(())
@ -136,47 +151,3 @@ impl SimpleNeuralNetwork {
}
}
/// Converts the game state into an input vector for the neural network
pub fn game_state_to_vector(game_state: &crate::GameState) -> Vec<f32> {
use crate::Color;
let mut state = Vec::with_capacity(32);
// Board (24 fields)
let white_positions = game_state.board.get_color_fields(Color::White);
let black_positions = game_state.board.get_color_fields(Color::Black);
let mut board = vec![0.0; 24];
for (pos, count) in white_positions {
if pos < 24 {
board[pos] = count as f32;
}
}
for (pos, count) in black_positions {
if pos < 24 {
board[pos] = -(count as f32);
}
}
state.extend(board);
// Limited extra information to respect input_size = 32
state.push(game_state.active_player_id as f32);
state.push(game_state.dice.values.0 as f32);
state.push(game_state.dice.values.1 as f32);
// Players' points and holes
if let Some(white_player) = game_state.get_white_player() {
state.push(white_player.points as f32);
state.push(white_player.holes as f32);
} else {
state.extend(vec![0.0, 0.0]);
}
// Ensure the size is exactly input_size
state.truncate(32);
while state.len() < 32 {
state.push(0.0);
}
state
}
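For reference, the epsilon-greedy policy that typically sits on top of get_best_action can be sketched as follows. This is an illustrative sketch only: the project's actual selection logic lives in DqnAgent::select_action (not shown in this diff), and it assumes get_best_action accepts a &[f32] slice and that the same rand API used above is available.

use rand::{thread_rng, Rng};

// Illustrative epsilon-greedy choice over the network's Q-values (assumption, not the project's exact code).
fn epsilon_greedy(net: &SimpleNeuralNetwork, state: &[f32], epsilon: f64, num_actions: usize) -> usize {
    let mut rng = thread_rng();
    if rng.gen::<f64>() < epsilon {
        // Explore: pick a random action.
        rng.gen_range(0..num_actions)
    } else {
        // Exploit: pick the action with the highest predicted Q-value.
        net.get_best_action(state)
    }
}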

View file

@ -1,10 +1,11 @@
use crate::{Color, GameState, PlayerId};
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use std::collections::VecDeque;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
use super::dqn_common::{DqnConfig, SimpleNeuralNetwork};
/// Experience for the replay buffer
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -71,7 +72,8 @@ pub struct DqnAgent {
impl DqnAgent {
pub fn new(config: DqnConfig) -> Self {
let model = SimpleNeuralNetwork::new(config.input_size, config.hidden_size, config.num_actions);
let model =
SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
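// Separate target network, initialized as a copy of the online model (standard DQN setup).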
let target_model = model.clone();
let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
let epsilon = config.epsilon;
@ -117,7 +119,10 @@ impl DqnAgent {
}
}
pub fn save_model<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
pub fn save_model<P: AsRef<std::path::Path>>(
&self,
path: P,
) -> Result<(), Box<dyn std::error::Error>> {
self.model.save(path)
}
@ -141,12 +146,12 @@ pub struct TrictracEnv {
pub current_step: usize,
}
impl TrictracEnv {
pub fn new() -> Self {
impl Default for TrictracEnv {
fn default() -> Self {
let mut game_state = GameState::new(false);
game_state.init_player("agent");
game_state.init_player("opponent");
Self {
game_state,
agent_player_id: 1,
@ -156,213 +161,233 @@ impl TrictracEnv {
current_step: 0,
}
}
}
impl TrictracEnv {
pub fn reset(&mut self) -> Vec<f32> {
self.game_state = GameState::new(false);
self.game_state.init_player("agent");
self.game_state.init_player("opponent");
// Start the game
self.game_state.consume(&GameEvent::BeginGame { goes_first: self.agent_player_id });
self.game_state.consume(&GameEvent::BeginGame {
goes_first: self.agent_player_id,
});
self.current_step = 0;
game_state_to_vector(&self.game_state)
self.game_state.to_vec_float()
}
pub fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool) {
let mut reward = 0.0;
// Apply the agent's action
if self.game_state.active_player_id == self.agent_player_id {
reward += self.apply_agent_action(action);
}
// Let the opponent play (simple strategy)
while self.game_state.active_player_id == self.opponent_player_id
&& self.game_state.stage != Stage::Ended {
self.play_opponent_turn();
while self.game_state.active_player_id == self.opponent_player_id
&& self.game_state.stage != Stage::Ended
{
reward += self.play_opponent_turn();
}
// Check whether the game is over
let done = self.game_state.stage == Stage::Ended ||
self.game_state.determine_winner().is_some() ||
self.current_step >= self.max_steps;
let done = self.game_state.stage == Stage::Ended
|| self.game_state.determine_winner().is_some()
|| self.current_step >= self.max_steps;
// Final reward if the game is over
if done {
if let Some(winner) = self.game_state.determine_winner() {
if winner == self.agent_player_id {
reward += 10.0; // Bonus for winning
reward += 100.0; // Bonus for winning
} else {
reward -= 5.0; // Penalty for losing
reward -= 50.0; // Penalty for losing
}
}
}
self.current_step += 1;
let next_state = game_state_to_vector(&self.game_state);
let next_state = self.game_state.to_vec_float();
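// Gym-style transition: (next observation, shaped reward, episode-terminated flag).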
(next_state, reward, done)
}
fn apply_agent_action(&mut self, action: usize) -> f32 {
let mut reward = 0.0;
match self.game_state.turn_stage {
// TODO: determine the event from the action ...
let event = match self.game_state.turn_stage {
TurnStage::RollDice => {
// Roll the dice
let event = GameEvent::Roll { player_id: self.agent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
// Simulate the dice result
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.agent_player_id,
dice: store::Dice { values: dice_values },
};
if self.game_state.validate(&dice_event) {
self.game_state.consume(&dice_event);
}
reward += 0.1;
GameEvent::Roll {
player_id: self.agent_player_id,
}
}
TurnStage::RollWaiting => {
// Simulate the dice result
reward += 0.1;
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
GameEvent::RollResult {
player_id: self.agent_player_id,
dice: store::Dice {
values: dice_values,
},
}
}
TurnStage::Move => {
// Choose a move according to the action
let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
let rules = MoveRules::new(
&self.agent_color,
&self.game_state.board,
self.game_state.dice,
);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if !possible_moves.is_empty() {
let move_index = if action == 0 {
0
} else if action == 1 && possible_moves.len() > 1 {
possible_moves.len() / 2
} else {
possible_moves.len().saturating_sub(1)
};
let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
let event = GameEvent::Move {
player_id: self.agent_player_id,
moves,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.2;
} else {
reward -= 1.0; // Penalty for an invalid move
}
// TODO: action selection
let move_index = if action == 0 {
0
} else if action == 1 && possible_moves.len() > 1 {
possible_moves.len() / 2
} else {
possible_moves.len().saturating_sub(1)
};
let moves = *possible_moves.get(move_index).unwrap_or(&possible_moves[0]);
GameEvent::Move {
player_id: self.agent_player_id,
moves,
}
}
TurnStage::MarkPoints => {
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
// Compute and mark the points
let dice_roll_count = self.game_state.players.get(&self.agent_player_id).unwrap().dice_roll_count;
let points_rules = PointsRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
let dice_roll_count = self
.game_state
.players
.get(&self.agent_player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(
&self.agent_color,
&self.game_state.board,
self.game_state.dice,
);
let points = points_rules.get_points(dice_roll_count).0;
let event = GameEvent::Mark {
reward += 0.3 * points as f32; // Reward proportional to the points scored
GameEvent::Mark {
player_id: self.agent_player_id,
points,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.1 * points as f32; // Reward proportional to the points scored
}
}
TurnStage::HoldOrGoChoice => {
// Decide whether to continue or not, based on the action
if action == 2 { // Action "go"
let event = GameEvent::Go { player_id: self.agent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.1;
if action == 2 {
// Action "go"
GameEvent::Go {
player_id: self.agent_player_id,
}
} else {
// Pass the turn by playing a move
let rules = MoveRules::new(&self.agent_color, &self.game_state.board, self.game_state.dice);
let rules = MoveRules::new(
&self.agent_color,
&self.game_state.board,
self.game_state.dice,
);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if !possible_moves.is_empty() {
let moves = possible_moves[0];
let event = GameEvent::Move {
player_id: self.agent_player_id,
moves,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
let moves = possible_moves[0];
GameEvent::Move {
player_id: self.agent_player_id,
moves,
}
}
}
_ => {}
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
reward += 0.2;
} else {
reward -= 1.0; // Penalty for an invalid action
}
reward
}
fn play_opponent_turn(&mut self) {
match self.game_state.turn_stage {
TurnStage::RollDice => {
let event = GameEvent::Roll { player_id: self.opponent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.opponent_player_id,
dice: store::Dice { values: dice_values },
};
if self.game_state.validate(&dice_event) {
self.game_state.consume(&dice_event);
}
// TODO : use default bot strategy
fn play_opponent_turn(&mut self) -> f32 {
let mut reward = 0.0;
let event = match self.game_state.turn_stage {
TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_player_id,
},
TurnStage::RollWaiting => {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
GameEvent::RollResult {
player_id: self.opponent_player_id,
dice: store::Dice {
values: dice_values,
},
}
}
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
let opponent_color = self.agent_color.opponent_color();
let dice_roll_count = self
.game_state
.players
.get(&self.opponent_player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(
&opponent_color,
&self.game_state.board,
self.game_state.dice,
);
let points = points_rules.get_points(dice_roll_count).0;
reward -= 0.3 * points as f32; // Penalty proportional to the opponent's points
GameEvent::Mark {
player_id: self.opponent_player_id,
points,
}
}
TurnStage::Move => {
let opponent_color = self.agent_color.opponent_color();
let rules = MoveRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
let rules = MoveRules::new(
&opponent_color,
&self.game_state.board,
self.game_state.dice,
);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if !possible_moves.is_empty() {
let moves = possible_moves[0]; // Simple strategy: first move
let event = GameEvent::Move {
player_id: self.opponent_player_id,
moves,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
}
}
TurnStage::MarkPoints => {
let opponent_color = self.agent_color.opponent_color();
let dice_roll_count = self.game_state.players.get(&self.opponent_player_id).unwrap().dice_roll_count;
let points_rules = PointsRules::new(&opponent_color, &self.game_state.board, self.game_state.dice);
let points = points_rules.get_points(dice_roll_count).0;
let event = GameEvent::Mark {
// Simple strategy: random choice
let mut rng = thread_rng();
let chosen_move = *possible_moves.choose(&mut rng).unwrap();
GameEvent::Move {
player_id: self.opponent_player_id,
points,
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
moves: if opponent_color == Color::White {
chosen_move
} else {
(chosen_move.0.mirror(), chosen_move.1.mirror())
},
}
}
TurnStage::HoldOrGoChoice => {
// Simple strategy: always continue
let event = GameEvent::Go { player_id: self.opponent_player_id };
if self.game_state.validate(&event) {
self.game_state.consume(&event);
GameEvent::Go {
player_id: self.opponent_player_id,
}
}
_ => {}
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
reward
}
}
@ -376,14 +401,14 @@ impl DqnTrainer {
pub fn new(config: DqnConfig) -> Self {
Self {
agent: DqnAgent::new(config),
env: TrictracEnv::new(),
env: TrictracEnv::default(),
}
}
pub fn train_episode(&mut self) -> f32 {
let mut total_reward = 0.0;
let mut state = self.env.reset();
loop {
let action = self.agent.select_action(&state);
let (next_state, reward, done) = self.env.step(action);
@ -408,31 +433,40 @@ impl DqnTrainer {
total_reward
}
pub fn train(&mut self, episodes: usize, save_every: usize, model_path: &str) -> Result<(), Box<dyn std::error::Error>> {
pub fn train(
&mut self,
episodes: usize,
save_every: usize,
model_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
for episode in 1..=episodes {
let reward = self.train_episode();
if episode % 100 == 0 {
println!(
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
episode, episodes, reward, self.agent.get_epsilon(), self.agent.get_step_count()
episode,
episodes,
reward,
self.agent.get_epsilon(),
self.agent.get_step_count()
);
}
if episode % save_every == 0 {
let save_path = format!("{}_episode_{}.json", model_path, episode);
self.agent.save_model(&save_path)?;
println!("Modèle sauvegardé : {}", save_path);
}
}
// Save the final model
let final_path = format!("{}_final.json", model_path);
self.agent.save_model(&final_path)?;
println!("Modèle final sauvegardé : {}", final_path);
Ok(())
}
}
}

View file

@ -1,5 +1,4 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use store::MoveRules;
#[derive(Debug)]
pub struct ErroneousStrategy {

View file

@ -18,4 +18,5 @@ pythonlib:
maturin build -m store/Cargo.toml --release
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainbot:
python ./store/python/trainModel.py
#python ./store/python/trainModel.py
cargo run --bin=train_dqn

View file

@ -153,6 +153,10 @@ impl Board {
.unsigned_abs()
}
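/// Board occupancy as a plain vector (one i8 entry per field), used to build the bot training state.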
pub fn to_vec(&self) -> Vec<i8> {
self.positions.to_vec()
}
// maybe todo : operate on bits (cf. https://github.com/bungogood/bkgm/blob/a2fb3f395243bcb0bc9f146df73413f73f5ea1e0/src/position.rs#L217)
pub fn to_gnupg_pos_id(&self) -> String {
// Pieces placement -> 77bits (24 + 23 + 30 max)

View file

@ -32,6 +32,33 @@ pub enum TurnStage {
MarkAdvPoints,
}
impl From<u8> for TurnStage {
fn from(item: u8) -> Self {
match item {
0 => TurnStage::RollWaiting,
1 => TurnStage::RollDice,
2 => TurnStage::MarkPoints,
3 => TurnStage::HoldOrGoChoice,
4 => TurnStage::Move,
5 => TurnStage::MarkAdvPoints,
_ => TurnStage::RollWaiting,
}
}
}
impl From<TurnStage> for u8 {
fn from(stage: TurnStage) -> u8 {
match stage {
TurnStage::RollWaiting => 0,
TurnStage::RollDice => 1,
TurnStage::MarkPoints => 2,
TurnStage::HoldOrGoChoice => 3,
TurnStage::Move => 4,
TurnStage::MarkAdvPoints => 5,
}
}
}
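The two conversions above are intended to be inverses for every variant, with unknown codes falling back to RollWaiting. A minimal round-trip check, not part of this commit, could be placed alongside the impls as a sketch:

#[test]
fn turn_stage_codes_round_trip() {
    // Every code 0..=5 maps to a distinct TurnStage and back to itself;
    // any other code falls back to RollWaiting (code 0).
    for code in 0u8..=5 {
        assert_eq!(u8::from(TurnStage::from(code)), code);
    }
    assert_eq!(u8::from(TurnStage::from(42)), 0);
}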
/// Represents a TricTrac game
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GameState {
@ -117,6 +144,63 @@ impl GameState {
// accessors
// -------------------------------------------------------------------------
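/// Same state vector as to_vec, converted to f32 for use as neural-network input.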
pub fn to_vec_float(&self) -> Vec<f32> {
self.to_vec().iter().map(|&x| x as f32).collect()
}
/// Get state as a vector (to be used as bot training input):
/// length = 36
pub fn to_vec(&self) -> Vec<i8> {
let state_len = 36;
let mut state = Vec::with_capacity(state_len);
// length = 24
state.extend(self.board.to_vec());
// active player -> length = 1
// white : 0 (false)
// black : 1 (true)
state.push(
self.who_plays()
.map(|player| if player.color == Color::Black { 1 } else { 0 })
.unwrap_or(0), // White by default
);
// step -> length = 1
let turn_stage: u8 = self.turn_stage.into();
state.push(turn_stage as i8);
// dice roll -> length = 2
state.push(self.dice.values.0 as i8);
state.push(self.dice.values.1 as i8);
// points: 4 values x 2 players = 8
let white_player: Vec<i8> = self
.get_white_player()
.unwrap()
.to_vec()
.iter()
.map(|&x| x as i8)
.collect();
state.extend(white_player);
let black_player: Vec<i8> = self
.get_black_player()
.unwrap()
.to_vec()
.iter()
.map(|&x| x as i8)
.collect();
state.extend(black_player);
// ensure state has length state_len
state.truncate(state_len);
while state.len() < state_len {
state.push(0);
}
state
}
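The 36 entries break down as 24 board fields + 1 active-player flag + 1 turn stage + 2 dice values + 4 white-player values + 4 black-player values = 36. A quick length check could look like the sketch below, assuming init_player sets up both players as it does in TrictracEnv::default; this is illustrative only, not part of the commit:

let mut game = GameState::new(false);
game.init_player("white");
game.init_player("black");
assert_eq!(game.to_vec().len(), 36);
assert_eq!(game.to_vec_float().len(), 36);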
/// Calculate game state id:
pub fn to_string_id(&self) -> String {
// Pieces placement -> 77 bits (24 + 23 + 30 max)

View file

@ -52,6 +52,15 @@ impl Player {
self.points, self.holes, self.can_bredouille as u8, self.can_big_bredouille as u8
)
}
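/// Per-player features (points, holes, bredouille flags) consumed by GameState::to_vec: 4 entries.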
pub fn to_vec(&self) -> Vec<u8> {
vec![
self.points,
self.holes,
self.can_bredouille as u8,
self.can_big_bredouille as u8,
]
}
}
/// Represents a player in the game.