diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index a5667fa..68ff52d 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -10,8 +10,8 @@ name = "train_dqn_burn"
 path = "src/dqn/burnrl/main.rs"
 
 [[bin]]
-name = "train_dqn"
-path = "src/bin/train_dqn.rs"
+name = "train_dqn_simple"
+path = "src/dqn/simple/main.rs"
 
 [dependencies]
 pretty_assertions = "1.4.0"
diff --git a/bot/src/dqn/dqn_common.rs b/bot/src/dqn/dqn_common.rs
index 3ea0738..2da4aa5 100644
--- a/bot/src/dqn/dqn_common.rs
+++ b/bot/src/dqn/dqn_common.rs
@@ -106,157 +106,6 @@ impl TrictracAction {
     // }
 }
 
-/// Configuration for the DQN agent
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct DqnConfig {
-    pub state_size: usize,
-    pub hidden_size: usize,
-    pub num_actions: usize,
-    pub learning_rate: f64,
-    pub gamma: f64,
-    pub epsilon: f64,
-    pub epsilon_decay: f64,
-    pub epsilon_min: f64,
-    pub replay_buffer_size: usize,
-    pub batch_size: usize,
-}
-
-impl Default for DqnConfig {
-    fn default() -> Self {
-        Self {
-            state_size: 36,
-            hidden_size: 512, // Larger hidden layer to handle the enlarged action space
-            num_actions: TrictracAction::action_space_size(),
-            learning_rate: 0.001,
-            gamma: 0.99,
-            epsilon: 0.1,
-            epsilon_decay: 0.995,
-            epsilon_min: 0.01,
-            replay_buffer_size: 10000,
-            batch_size: 32,
-        }
-    }
-}
-
-/// Simplified DQN neural network (plain weight matrices)
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct SimpleNeuralNetwork {
-    pub weights1: Vec<Vec<f32>>,
-    pub biases1: Vec<f32>,
-    pub weights2: Vec<Vec<f32>>,
-    pub biases2: Vec<f32>,
-    pub weights3: Vec<Vec<f32>>,
-    pub biases3: Vec<f32>,
-}
-
-impl SimpleNeuralNetwork {
-    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
-        use rand::{thread_rng, Rng};
-        let mut rng = thread_rng();
-
-        // Random weight initialization with Xavier/Glorot scaling
-        let scale1 = (2.0 / input_size as f32).sqrt();
-        let weights1 = (0..hidden_size)
-            .map(|_| {
-                (0..input_size)
-                    .map(|_| rng.gen_range(-scale1..scale1))
-                    .collect()
-            })
-            .collect();
-        let biases1 = vec![0.0; hidden_size];
-
-        let scale2 = (2.0 / hidden_size as f32).sqrt();
-        let weights2 = (0..hidden_size)
-            .map(|_| {
-                (0..hidden_size)
-                    .map(|_| rng.gen_range(-scale2..scale2))
-                    .collect()
-            })
-            .collect();
-        let biases2 = vec![0.0; hidden_size];
-
-        let scale3 = (2.0 / hidden_size as f32).sqrt();
-        let weights3 = (0..output_size)
-            .map(|_| {
-                (0..hidden_size)
-                    .map(|_| rng.gen_range(-scale3..scale3))
-                    .collect()
-            })
-            .collect();
-        let biases3 = vec![0.0; output_size];
-
-        Self {
-            weights1,
-            biases1,
-            weights2,
-            biases2,
-            weights3,
-            biases3,
-        }
-    }
-
-    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
-        // First layer
-        let mut layer1: Vec<f32> = self.biases1.clone();
-        for (i, neuron_weights) in self.weights1.iter().enumerate() {
-            for (j, &weight) in neuron_weights.iter().enumerate() {
-                if j < input.len() {
-                    layer1[i] += input[j] * weight;
-                }
-            }
-            layer1[i] = layer1[i].max(0.0); // ReLU
-        }
-
-        // Second layer
-        let mut layer2: Vec<f32> = self.biases2.clone();
-        for (i, neuron_weights) in self.weights2.iter().enumerate() {
-            for (j, &weight) in neuron_weights.iter().enumerate() {
-                if j < layer1.len() {
-                    layer2[i] += layer1[j] * weight;
-                }
-            }
-            layer2[i] = layer2[i].max(0.0); // ReLU
-        }
-
-        // Output layer
-        let mut output: Vec<f32> = self.biases3.clone();
-        for (i, neuron_weights) in self.weights3.iter().enumerate() {
-            for (j, &weight) in neuron_weights.iter().enumerate() {
-                if j < layer2.len() {
-                    output[i] += layer2[j] * weight;
-                }
-            }
-        }
-
-        output
-    }
-
-    pub fn get_best_action(&self, input: &[f32]) -> usize {
-        let q_values = self.forward(input);
-        q_values
-            .iter()
-            .enumerate()
-            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
-            .map(|(index, _)| index)
-            .unwrap_or(0)
-    }
-
-    pub fn save<P: AsRef<std::path::Path>>(
-        &self,
-        path: P,
-    ) -> Result<(), Box<dyn std::error::Error>> {
-        let data = serde_json::to_string_pretty(self)?;
-        std::fs::write(path, data)?;
-        Ok(())
-    }
-
-    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
-        let data = std::fs::read_to_string(path)?;
-        let network = serde_json::from_str(&data)?;
-        Ok(network)
-    }
-}
-
 /// Gets the valid actions for the current game state
 pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
     use store::TurnStage;
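A remark on the initialization in the block being moved: the sqrt(2/fan_in) scale is He/Kaiming-style initialization (the usual pairing for ReLU) rather than classic Xavier/Glorot, which the comment names. With the default DqnConfig the uniform ranges come out as follows (illustration only, not part of the diff):

    fn main() {
        // Initialization scales under the default DqnConfig
        // (state_size = 36, hidden_size = 512).
        let scale1 = (2.0_f32 / 36.0).sqrt(); // first layer: ±0.2357
        let scale2 = (2.0_f32 / 512.0).sqrt(); // hidden and output layers: ±0.0625
        println!("scale1 = {scale1:.4}, scale2 = {scale2:.4}");
    }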
diff --git a/bot/src/dqn/simple/dqn_model.rs b/bot/src/dqn/simple/dqn_model.rs
new file mode 100644
index 0000000..ba46212
--- /dev/null
+++ b/bot/src/dqn/simple/dqn_model.rs
@@ -0,0 +1,154 @@
+use crate::dqn::dqn_common::TrictracAction;
+use serde::{Deserialize, Serialize};
+
+/// Configuration for the DQN agent
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DqnConfig {
+    pub state_size: usize,
+    pub hidden_size: usize,
+    pub num_actions: usize,
+    pub learning_rate: f64,
+    pub gamma: f64,
+    pub epsilon: f64,
+    pub epsilon_decay: f64,
+    pub epsilon_min: f64,
+    pub replay_buffer_size: usize,
+    pub batch_size: usize,
+}
+
+impl Default for DqnConfig {
+    fn default() -> Self {
+        Self {
+            state_size: 36,
+            hidden_size: 512, // Larger hidden layer to handle the enlarged action space
+            num_actions: TrictracAction::action_space_size(),
+            learning_rate: 0.001,
+            gamma: 0.99,
+            epsilon: 0.1,
+            epsilon_decay: 0.995,
+            epsilon_min: 0.01,
+            replay_buffer_size: 10000,
+            batch_size: 32,
+        }
+    }
+}
+
+/// Simplified DQN neural network (plain weight matrices)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SimpleNeuralNetwork {
+    pub weights1: Vec<Vec<f32>>,
+    pub biases1: Vec<f32>,
+    pub weights2: Vec<Vec<f32>>,
+    pub biases2: Vec<f32>,
+    pub weights3: Vec<Vec<f32>>,
+    pub biases3: Vec<f32>,
+}
+
+impl SimpleNeuralNetwork {
+    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
+        use rand::{thread_rng, Rng};
+        let mut rng = thread_rng();
+
+        // Random weight initialization with Xavier/Glorot scaling
+        let scale1 = (2.0 / input_size as f32).sqrt();
+        let weights1 = (0..hidden_size)
+            .map(|_| {
+                (0..input_size)
+                    .map(|_| rng.gen_range(-scale1..scale1))
+                    .collect()
+            })
+            .collect();
+        let biases1 = vec![0.0; hidden_size];
+
+        let scale2 = (2.0 / hidden_size as f32).sqrt();
+        let weights2 = (0..hidden_size)
+            .map(|_| {
+                (0..hidden_size)
+                    .map(|_| rng.gen_range(-scale2..scale2))
+                    .collect()
+            })
+            .collect();
+        let biases2 = vec![0.0; hidden_size];
+
+        let scale3 = (2.0 / hidden_size as f32).sqrt();
+        let weights3 = (0..output_size)
+            .map(|_| {
+                (0..hidden_size)
+                    .map(|_| rng.gen_range(-scale3..scale3))
+                    .collect()
+            })
+            .collect();
+        let biases3 = vec![0.0; output_size];
+
+        Self {
+            weights1,
+            biases1,
+            weights2,
+            biases2,
+            weights3,
+            biases3,
+        }
+    }
+
+    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
+        // First layer
+        let mut layer1: Vec<f32> = self.biases1.clone();
+        for (i, neuron_weights) in self.weights1.iter().enumerate() {
+            for (j, &weight) in neuron_weights.iter().enumerate() {
+                if j < input.len() {
+                    layer1[i] += input[j] * weight;
+                }
+            }
+            layer1[i] = layer1[i].max(0.0); // ReLU
+        }
+
+        // Second layer
+        let mut layer2: Vec<f32> = self.biases2.clone();
+        for (i, neuron_weights) in self.weights2.iter().enumerate() {
+            for (j, &weight) in neuron_weights.iter().enumerate() {
+                if j < layer1.len() {
+                    layer2[i] += layer1[j] * weight;
+                }
+            }
+            layer2[i] = layer2[i].max(0.0); // ReLU
+        }
+
+        // Output layer
+        let mut output: Vec<f32> = self.biases3.clone();
+        for (i, neuron_weights) in self.weights3.iter().enumerate() {
+            for (j, &weight) in neuron_weights.iter().enumerate() {
+                if j < layer2.len() {
+                    output[i] += layer2[j] * weight;
+                }
+            }
+        }
+
+        output
+    }
+
+    pub fn get_best_action(&self, input: &[f32]) -> usize {
+        let q_values = self.forward(input);
+        q_values
+            .iter()
+            .enumerate()
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+            .map(|(index, _)| index)
+            .unwrap_or(0)
+    }
+
+    pub fn save<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let data = serde_json::to_string_pretty(self)?;
+        std::fs::write(path, data)?;
+        Ok(())
+    }
+
+    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
+        let data = std::fs::read_to_string(path)?;
+        let network = serde_json::from_str(&data)?;
+        Ok(network)
+    }
+}
+
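For reference, a minimal sketch of how the relocated types fit together after this change (illustration only; the module path matches the imports this diff introduces, and the zeroed input is a placeholder rather than a real encoded game state):

    use bot::dqn::simple::dqn_model::{DqnConfig, SimpleNeuralNetwork};

    fn main() {
        let config = DqnConfig::default();
        let net =
            SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);

        // Placeholder input; a real caller would encode the current GameState
        // into config.state_size floats here.
        let state = vec![0.0_f32; config.state_size];
        let q_values = net.forward(&state);
        let greedy = net.get_best_action(&state);
        println!("{} Q-values, greedy action index {}", q_values.len(), greedy);
    }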
diff --git a/bot/src/dqn/simple/dqn_trainer.rs b/bot/src/dqn/simple/dqn_trainer.rs
index dedf382..78e6dc7 100644
--- a/bot/src/dqn/simple/dqn_trainer.rs
+++ b/bot/src/dqn/simple/dqn_trainer.rs
@@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize};
 use std::collections::VecDeque;
 use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
 
-use crate::dqn::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};
+use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
+use crate::dqn::dqn_common::{get_valid_actions, TrictracAction};
 
 /// Experience for the replay buffer
 #[derive(Debug, Clone, Serialize, Deserialize)]
diff --git a/bot/src/bin/train_dqn.rs b/bot/src/dqn/simple/main.rs
similarity index 97%
rename from bot/src/bin/train_dqn.rs
rename to bot/src/dqn/simple/main.rs
index e0929fb..30fd933 100644
--- a/bot/src/bin/train_dqn.rs
+++ b/bot/src/dqn/simple/main.rs
@@ -1,4 +1,5 @@
-use bot::dqn::dqn_common::{DqnConfig, TrictracAction};
+use bot::dqn::dqn_common::TrictracAction;
+use bot::dqn::simple::dqn_model::DqnConfig;
 use bot::dqn::simple::dqn_trainer::DqnTrainer;
 use std::env;
diff --git a/bot/src/dqn/simple/mod.rs b/bot/src/dqn/simple/mod.rs
index 114bd10..8090a29 100644
--- a/bot/src/dqn/simple/mod.rs
+++ b/bot/src/dqn/simple/mod.rs
@@ -1 +1,2 @@
+pub mod dqn_model;
 pub mod dqn_trainer;
diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs
index 34fb853..cf24684 100644
--- a/bot/src/strategy/dqn.rs
+++ b/bot/src/strategy/dqn.rs
@@ -3,9 +3,8 @@ use log::info;
 use std::path::Path;
 use store::MoveRules;
 
-use crate::dqn::dqn_common::{
-    get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction,
-};
+use crate::dqn::dqn_common::{get_valid_actions, sample_valid_action, TrictracAction};
+use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
 
 /// DQN strategy for the bot - only loads and runs a pre-trained model
 #[derive(Debug)]
diff --git a/justfile b/justfile
index 0501ded..32193af 100644
--- a/justfile
+++ b/justfile
@@ -22,8 +22,8 @@ pythonlib:
     maturin build -m store/Cargo.toml --release
     pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
 trainsimple:
-    cargo build --release --bin=train_dqn
-    LD_LIBRARY_PATH=./target/release ./target/release/train_dqn | tee /tmp/train.out
+    cargo build --release --bin=train_dqn_simple
+    LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
 trainbot:
     #python ./store/python/trainModel.py
     # cargo run --bin=train_dqn # ok
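One caveat that survives the move: get_best_action takes the argmax over the whole action space, while strategy/dqn.rs works from get_valid_actions, so the caller still has to mask invalid indices itself (presumably what sample_valid_action handles). A hypothetical helper along these lines (not part of this diff) would do the masking; total_cmp also avoids the partial_cmp(...).unwrap() panic path on NaN that get_best_action currently has:

    /// Hypothetical helper: greedy argmax restricted to the indices of the
    /// valid actions; None when no index is usable.
    fn best_valid_action(q_values: &[f32], valid_indices: &[usize]) -> Option<usize> {
        valid_indices
            .iter()
            .copied()
            .filter(|&i| i < q_values.len())
            .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
    }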