use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use store::MoveRules;
use rand::{thread_rng, Rng};
use std::collections::VecDeque;
use std::path::Path;
use serde::{Deserialize, Serialize};

/// Configuration for the DQN agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
    pub input_size: usize,
    pub hidden_size: usize,
    pub num_actions: usize,
    pub learning_rate: f64,
    pub gamma: f64,
    pub epsilon: f64,
    pub epsilon_decay: f64,
    pub epsilon_min: f64,
    pub replay_buffer_size: usize,
    pub batch_size: usize,
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            input_size: 32,
            hidden_size: 256,
            num_actions: 3,
            learning_rate: 0.001,
            gamma: 0.99,
            epsilon: 0.1,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
        }
    }
}
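
// Usage sketch: a caller can override individual hyper-parameters and keep the
// remaining defaults through struct update syntax, e.g.
//     let config = DqnConfig { epsilon: 0.5, ..DqnConfig::default() };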

/// Simplified DQN neural network (plain weight matrices)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleNeuralNetwork {
    weights1: Vec<Vec<f32>>,
    biases1: Vec<f32>,
    weights2: Vec<Vec<f32>>,
    biases2: Vec<f32>,
    weights3: Vec<Vec<f32>>,
    biases3: Vec<f32>,
}

impl SimpleNeuralNetwork {
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
        let mut rng = thread_rng();

        // Random weight initialization with sqrt(2 / fan_in) scaling
        // (He-style, suited to the ReLU activations used below)
        let scale1 = (2.0 / input_size as f32).sqrt();
        let weights1 = (0..hidden_size)
            .map(|_| (0..input_size).map(|_| rng.gen_range(-scale1..scale1)).collect())
            .collect();
        let biases1 = vec![0.0; hidden_size];

        let scale2 = (2.0 / hidden_size as f32).sqrt();
        let weights2 = (0..hidden_size)
            .map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale2..scale2)).collect())
            .collect();
        let biases2 = vec![0.0; hidden_size];

        let scale3 = (2.0 / hidden_size as f32).sqrt();
        let weights3 = (0..output_size)
            .map(|_| (0..hidden_size).map(|_| rng.gen_range(-scale3..scale3)).collect())
            .collect();
        let biases3 = vec![0.0; output_size];

        Self {
            weights1,
            biases1,
            weights2,
            biases2,
            weights3,
            biases3,
        }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        // First hidden layer
        let mut layer1: Vec<f32> = self.biases1.clone();
        for (i, neuron_weights) in self.weights1.iter().enumerate() {
            for (j, &weight) in neuron_weights.iter().enumerate() {
                if j < input.len() {
                    layer1[i] += input[j] * weight;
                }
            }
            layer1[i] = layer1[i].max(0.0); // ReLU
        }

        // Second hidden layer
        let mut layer2: Vec<f32> = self.biases2.clone();
        for (i, neuron_weights) in self.weights2.iter().enumerate() {
            for (j, &weight) in neuron_weights.iter().enumerate() {
                if j < layer1.len() {
                    layer2[i] += layer1[j] * weight;
                }
            }
            layer2[i] = layer2[i].max(0.0); // ReLU
        }

        // Output layer (linear: raw Q-values, one per action)
        let mut output: Vec<f32> = self.biases3.clone();
        for (i, neuron_weights) in self.weights3.iter().enumerate() {
            for (j, &weight) in neuron_weights.iter().enumerate() {
                if j < layer2.len() {
                    output[i] += layer2[j] * weight;
                }
            }
        }

        output
    }

    pub fn get_best_action(&self, input: &[f32]) -> usize {
        let q_values = self.forward(input);
        q_values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(index, _)| index)
            .unwrap_or(0)
    }
}
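
// Minimal usage sketch (sizes illustrative, matching the defaults in `DqnConfig`):
// the forward pass yields one Q-value per action, and `get_best_action`
// returns the index of the largest one.
//
//     let net = SimpleNeuralNetwork::new(32, 256, 3);
//     let q_values = net.forward(&[0.0; 32]); // q_values.len() == 3
//     let action = net.get_best_action(&[0.0; 32]); // 0, 1 or 2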

/// Experience (transition) stored in the replay buffer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
    pub state: Vec<f32>,
    pub action: usize,
    pub reward: f32,
    pub next_state: Vec<f32>,
    pub done: bool,
}

/// Replay buffer storing past experiences
#[derive(Debug)]
pub struct ReplayBuffer {
    buffer: VecDeque<Experience>,
    capacity: usize,
}

impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    pub fn push(&mut self, experience: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Uniform sampling with replacement; returns the whole buffer if it
    /// holds fewer than `batch_size` experiences.
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let mut rng = thread_rng();
        let len = self.buffer.len();
        if len < batch_size {
            return self.buffer.iter().cloned().collect();
        }

        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let idx = rng.gen_range(0..len);
            batch.push(self.buffer[idx].clone());
        }
        batch
    }

    pub fn len(&self) -> usize {
        self.buffer.len()
    }
}

/// DQN agent for reinforcement learning
#[derive(Debug)]
pub struct DqnAgent {
    config: DqnConfig,
    model: SimpleNeuralNetwork,
    target_model: SimpleNeuralNetwork,
    replay_buffer: ReplayBuffer,
    epsilon: f64,
    step_count: usize,
}

impl DqnAgent {
    pub fn new(config: DqnConfig) -> Self {
        let model =
            SimpleNeuralNetwork::new(config.input_size, config.hidden_size, config.num_actions);
        let target_model = model.clone();
        let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
        let epsilon = config.epsilon;

        Self {
            config,
            model,
            target_model,
            replay_buffer,
            epsilon,
            step_count: 0,
        }
    }

    /// Epsilon-greedy action selection
    pub fn select_action(&mut self, state: &[f32]) -> usize {
        let mut rng = thread_rng();
        if rng.gen::<f64>() < self.epsilon {
            // Exploration: random action
            rng.gen_range(0..self.config.num_actions)
        } else {
            // Exploitation: best action according to the model
            self.model.get_best_action(state)
        }
    }

    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.push(experience);
    }

    pub fn train(&mut self) {
        if self.replay_buffer.len() < self.config.batch_size {
            return;
        }

        // For now, training is only simulated by decaying epsilon;
        // a complete implementation would sample a batch and run backpropagation here.
        self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
        self.step_count += 1;

        // Sync the target model every 100 steps
        if self.step_count % 100 == 0 {
            self.target_model = self.model.clone();
        }
    }

    pub fn save_model<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
        let data = serde_json::to_string_pretty(&self.model)?;
        std::fs::write(path, data)?;
        Ok(())
    }

    pub fn load_model<P: AsRef<Path>>(&mut self, path: P) -> Result<(), Box<dyn std::error::Error>> {
        let data = std::fs::read_to_string(path)?;
        self.model = serde_json::from_str(&data)?;
        self.target_model = self.model.clone();
        Ok(())
    }
}

/// Trictrac environment used for training
#[derive(Debug)]
pub struct TrictracEnv {
    pub game_state: GameState,
    pub agent_player_id: PlayerId,
    pub opponent_player_id: PlayerId,
    pub agent_color: Color,
    pub max_steps: usize,
    pub current_step: usize,
}

impl TrictracEnv {
    pub fn new() -> Self {
        let mut game_state = GameState::new(false);
        game_state.init_player("agent");
        game_state.init_player("opponent");

        Self {
            game_state,
            agent_player_id: 1,
            opponent_player_id: 2,
            agent_color: Color::White,
            max_steps: 1000,
            current_step: 0,
        }
    }

    pub fn reset(&mut self) -> Vec<f32> {
        self.game_state = GameState::new(false);
        self.game_state.init_player("agent");
        self.game_state.init_player("opponent");
        self.current_step = 0;
        self.get_state_vector()
    }

    pub fn step(&mut self, _action: usize) -> (Vec<f32>, f32, bool) {
        let reward = 0.0; // Simplified for now
        let done = self.game_state.stage == store::Stage::Ended
            || self.game_state.determine_winner().is_some()
            || self.current_step >= self.max_steps;

        self.current_step += 1;

        // Return the next state
        let next_state = self.get_state_vector();

        (next_state, reward, done)
    }

    /// Encode the game as a fixed-length vector of 32 values
    /// (24 board fields, +n for white checkers and -n for black,
    /// plus a few scalar features, padded/truncated to `input_size`).
    pub fn get_state_vector(&self) -> Vec<f32> {
        let mut state = Vec::with_capacity(32);

        // Board (24 fields)
        let white_positions = self.game_state.board.get_color_fields(Color::White);
        let black_positions = self.game_state.board.get_color_fields(Color::Black);

        let mut board = vec![0.0; 24];
        for (pos, count) in white_positions {
            if pos < 24 {
                board[pos] = count as f32;
            }
        }
        for (pos, count) in black_positions {
            if pos < 24 {
                board[pos] = -(count as f32);
            }
        }
        state.extend(board);

        // A few extra features, kept small so that input_size = 32 still holds
        state.push(self.game_state.active_player_id as f32);
        state.push(self.game_state.dice.values.0 as f32);
        state.push(self.game_state.dice.values.1 as f32);

        // White player's points and holes
        if let Some(white_player) = self.game_state.get_white_player() {
            state.push(white_player.points as f32);
            state.push(white_player.holes as f32);
        } else {
            state.extend(vec![0.0, 0.0]);
        }

        // Ensure the length is exactly input_size
        state.truncate(32);
        while state.len() < 32 {
            state.push(0.0);
        }

        state
    }
}

/// DQN strategy for the bot
#[derive(Debug)]
pub struct DqnStrategy {
    pub game: GameState,
    pub player_id: PlayerId,
    pub color: Color,
    pub agent: Option<DqnAgent>,
    pub env: TrictracEnv,
}

impl Default for DqnStrategy {
    fn default() -> Self {
        let game = GameState::default();
        let config = DqnConfig::default();
        let agent = DqnAgent::new(config);
        let env = TrictracEnv::new();

        Self {
            game,
            player_id: 2,
            color: Color::Black,
            agent: Some(agent),
            env,
        }
    }
}

impl DqnStrategy {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn new_with_model(model_path: &str) -> Self {
        let mut strategy = Self::new();
        if let Some(ref mut agent) = strategy.agent {
            // If loading fails, keep the freshly initialised model
            let _ = agent.load_model(model_path);
        }
        strategy
    }

    /// Run one episode in the internal environment and return the accumulated reward.
    pub fn train_episode(&mut self) -> f32 {
        let mut total_reward = 0.0;
        let mut state = self.env.reset();

        loop {
            let action = if let Some(ref mut agent) = self.agent {
                agent.select_action(&state)
            } else {
                0
            };

            let (next_state, reward, done) = self.env.step(action);
            total_reward += reward;

            if let Some(ref mut agent) = self.agent {
                let experience = Experience {
                    state: state.clone(),
                    action,
                    reward,
                    next_state: next_state.clone(),
                    done,
                };
                agent.store_experience(experience);
                agent.train();
            }

            if done {
                break;
            }
            state = next_state;
        }

        total_reward
    }

    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(ref agent) = self.agent {
            agent.save_model(path)?;
        }
        Ok(())
    }
}

impl BotStrategy for DqnStrategy {
    fn get_game(&self) -> &GameState {
        &self.game
    }

    fn get_mut_game(&mut self) -> &mut GameState {
        &mut self.game
    }

    fn set_color(&mut self, color: Color) {
        self.color = color;
    }

    fn set_player_id(&mut self, player_id: PlayerId) {
        self.player_id = player_id;
    }

    fn calculate_points(&self) -> u8 {
        // For now, use the standard rules-based computation
        let dice_roll_count = self
            .get_game()
            .players
            .get(&self.player_id)
            .unwrap()
            .dice_roll_count;
        let points_rules = PointsRules::new(&self.color, &self.game.board, self.game.dice);
        points_rules.get_points(dice_roll_count).0
    }

    fn calculate_adv_points(&self) -> u8 {
        self.calculate_points()
    }

    fn choose_go(&self) -> bool {
        // Use the DQN to decide (simplified for now)
        if let Some(ref agent) = self.agent {
            let state = self.env.get_state_vector();
            // Action 2 = "go": check whether it is the best action
            let q_values = agent.model.forward(&state);
            if q_values.len() > 2 {
                return q_values[2] > q_values[0] && q_values[2] > *q_values.get(1).unwrap_or(&0.0);
            }
        }
        true // Fallback
    }

    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
        // For now, rely on the default move generation;
        // later the DQN could choose among the valid moves directly.
        let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
        let possible_moves = rules.get_possible_moves_sequences(true, vec![]);

        let chosen_move = if let Some(ref agent) = self.agent {
            // Use the DQN to pick the best move
            let state = self.env.get_state_vector();
            let action = agent.model.get_best_action(&state);

            // For now the action index is simply mapped onto a move;
            // a complete implementation would use a richer action space.
            let move_index = action.min(possible_moves.len().saturating_sub(1));
            *possible_moves
                .get(move_index)
                .unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
        } else {
            *possible_moves
                .first()
                .unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
        };

        if self.color == Color::White {
            chosen_move
        } else {
            (chosen_move.0.mirror(), chosen_move.1.mirror())
        }
    }
}
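
// A few smoke tests sketching the behaviour assumed above. They only exercise
// the plain-Rust parts of this module (network shapes, replay-buffer eviction),
// not the game-dependent types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_returns_one_q_value_per_action() {
        let net = SimpleNeuralNetwork::new(8, 16, 3);
        let q_values = net.forward(&[0.0; 8]);
        assert_eq!(q_values.len(), 3);
        assert!(net.get_best_action(&[0.0; 8]) < 3);
    }

    #[test]
    fn replay_buffer_evicts_oldest_when_full() {
        let mut buffer = ReplayBuffer::new(2);
        for i in 0..3 {
            buffer.push(Experience {
                state: vec![i as f32],
                action: 0,
                reward: 0.0,
                next_state: vec![],
                done: false,
            });
        }
        // Capacity is respected after pushing more items than it can hold.
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.sample(2).len(), 2);
    }
}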