diff --git a/Cargo.lock b/Cargo.lock index 270eb15..a71f75a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -834,7 +834,7 @@ dependencies = [ "derive-new", "log", "nvml-wrapper", - "ratatui 0.29.0", + "ratatui", "rstest", "serde", "sysinfo", @@ -1066,17 +1066,6 @@ dependencies = [ "store", ] -[[package]] -name = "client_tui" -version = "0.1.0" -dependencies = [ - "anyhow", - "bincode 1.3.3", - "crossterm", - "ratatui 0.28.1", - "store", -] - [[package]] name = "cmake" version = "0.1.54" @@ -4414,27 +4403,6 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" -[[package]] -name = "ratatui" -version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d" -dependencies = [ - "bitflags 2.9.4", - "cassowary", - "compact_str", - "crossterm", - "instability", - "itertools 0.13.0", - "lru", - "paste", - "strum 0.26.3", - "strum_macros 0.26.4", - "unicode-segmentation", - "unicode-truncate", - "unicode-width 0.1.14", -] - [[package]] name = "ratatui" version = "0.29.0" @@ -5813,18 +5781,6 @@ dependencies = [ "strength_reduce", ] -[[package]] -name = "trictrac-server" -version = "0.1.0" -dependencies = [ - "bincode 1.3.3", - "env_logger 0.10.2", - "log", - "pico-args", - "renet", - "store", -] - [[package]] name = "tungstenite" version = "0.26.2" diff --git a/Cargo.toml b/Cargo.toml index 6068644..b9e6d45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_tui", "client_cli", "bot", "server", "store"] +members = ["client_cli", "bot", "store"] diff --git a/bot/src/burnrl/algos/dqn_big.rs b/bot/src/burnrl/algos/dqn_big.rs deleted file mode 100644 index 7e8951f..0000000 --- a/bot/src/burnrl/algos/dqn_big.rs +++ /dev/null @@ -1,194 +0,0 @@ -use crate::burnrl::environment_big::TrictracEnvironment; -use crate::burnrl::utils::{soft_update_linear, Config}; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::Module; -use burn::nn::{Linear, LinearConfig}; -use burn::optim::AdamWConfig; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::activation::relu; -use burn::tensor::backend::{AutodiffBackend, Backend}; -use burn::tensor::Tensor; -use burn_rl::agent::DQN; -use burn_rl::agent::{DQNModel, DQNTrainingConfig}; -use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use std::time::SystemTime; - -#[derive(Module, Debug)] -pub struct Net { - linear_0: Linear, - linear_1: Linear, - linear_2: Linear, -} - -impl Net { - #[allow(unused)] - pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { - Self { - linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), - linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), - linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), - } - } - - fn consume(self) -> (Linear, Linear, Linear) { - (self.linear_0, self.linear_1, self.linear_2) - } -} - -impl Model, Tensor> for Net { - fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input)); - let layer_1_output = relu(self.linear_1.forward(layer_0_output)); - - relu(self.linear_2.forward(layer_1_output)) - } - - fn infer(&self, input: Tensor) -> Tensor { - self.forward(input) - } -} - -impl DQNModel for Net { - fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { 
- let (linear_0, linear_1, linear_2) = this.consume(); - - Self { - linear_0: soft_update_linear(linear_0, &that.linear_0, tau), - linear_1: soft_update_linear(linear_1, &that.linear_1, tau), - linear_2: soft_update_linear(linear_2, &that.linear_2, tau), - } - } -} - -#[allow(unused)] -const MEMORY_SIZE: usize = 8192; - -type MyAgent = DQN>; - -#[allow(unused)] -// pub fn run, B: AutodiffBackend>( -pub fn run< - E: Environment + AsMut, - B: AutodiffBackend, ->( - conf: &Config, - visualized: bool, - // ) -> DQN> { -) -> impl Agent { - let mut env = E::new(visualized); - env.as_mut().max_steps = conf.max_steps; - - let model = Net::::new( - <::StateType as State>::size(), - conf.dense_size, - <::ActionType as Action>::size(), - ); - - let mut agent = MyAgent::new(model); - - // let config = DQNTrainingConfig::default(); - let config = DQNTrainingConfig { - gamma: conf.gamma, - tau: conf.tau, - learning_rate: conf.learning_rate, - batch_size: conf.batch_size, - clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( - conf.clip_grad, - )), - }; - - let mut memory = Memory::::default(); - - let mut optimizer = AdamWConfig::new() - .with_grad_clipping(config.clip_grad.clone()) - .init(); - - let mut policy_net = agent.model().as_ref().unwrap().clone(); - - let mut step = 0_usize; - - for episode in 0..conf.num_episodes { - let mut episode_done = false; - let mut episode_reward: ElemType = 0.0; - let mut episode_duration = 0_usize; - let mut state = env.state(); - let mut now = SystemTime::now(); - - while !episode_done { - let eps_threshold = conf.eps_end - + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); - let action = - DQN::>::react_with_exploration(&policy_net, state, eps_threshold); - let snapshot = env.step(action); - - episode_reward += - <::RewardType as Into>::into(snapshot.reward().clone()); - - memory.push( - state, - *snapshot.state(), - action, - snapshot.reward().clone(), - snapshot.done(), - ); - - if config.batch_size < memory.len() { - policy_net = - agent.train::(policy_net, &memory, &mut optimizer, &config); - } - - step += 1; - episode_duration += 1; - - if snapshot.done() || episode_duration >= conf.max_steps { - let envmut = env.as_mut(); - let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32) - * 100.0) - .round() as u32; - println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}", - envmut.goodmoves_count, - goodmoves_ratio, - envmut.pointrolls_count, - now.elapsed().unwrap().as_secs(), - ); - env.reset(); - episode_done = true; - now = SystemTime::now(); - } else { - state = *snapshot.state(); - } - } - } - let valid_agent = agent.valid(); - if let Some(path) = &conf.save_path { - save_model(valid_agent.model().as_ref().unwrap(), path); - } - valid_agent -} - -pub fn save_model(model: &Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}.mpk"); - println!("info: Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - Net::new( - ::StateType::size(), - 
dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} diff --git a/bot/src/burnrl/algos/mod.rs b/bot/src/burnrl/algos/mod.rs index af13327..5a67dfc 100644 --- a/bot/src/burnrl/algos/mod.rs +++ b/bot/src/burnrl/algos/mod.rs @@ -1,9 +1,6 @@ pub mod dqn; -pub mod dqn_big; pub mod dqn_valid; pub mod ppo; -pub mod ppo_big; pub mod ppo_valid; pub mod sac; -pub mod sac_big; pub mod sac_valid; diff --git a/bot/src/burnrl/algos/ppo_big.rs b/bot/src/burnrl/algos/ppo_big.rs deleted file mode 100644 index ab860ee..0000000 --- a/bot/src/burnrl/algos/ppo_big.rs +++ /dev/null @@ -1,191 +0,0 @@ -use crate::burnrl::environment_big::TrictracEnvironment; -use crate::burnrl::utils::Config; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::Module; -use burn::nn::{Initializer, Linear, LinearConfig}; -use burn::optim::AdamWConfig; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::activation::{relu, softmax}; -use burn::tensor::backend::{AutodiffBackend, Backend}; -use burn::tensor::Tensor; -use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO}; -use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use std::env; -use std::fs; -use std::time::SystemTime; - -#[derive(Module, Debug)] -pub struct Net { - linear: Linear, - linear_actor: Linear, - linear_critic: Linear, -} - -impl Net { - #[allow(unused)] - pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { - let initializer = Initializer::XavierUniform { gain: 1.0 }; - Self { - linear: LinearConfig::new(input_size, dense_size) - .with_initializer(initializer.clone()) - .init(&Default::default()), - linear_actor: LinearConfig::new(dense_size, output_size) - .with_initializer(initializer.clone()) - .init(&Default::default()), - linear_critic: LinearConfig::new(dense_size, 1) - .with_initializer(initializer) - .init(&Default::default()), - } - } -} - -impl Model, PPOOutput, Tensor> for Net { - fn forward(&self, input: Tensor) -> PPOOutput { - let layer_0_output = relu(self.linear.forward(input)); - let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1); - let values = self.linear_critic.forward(layer_0_output); - - PPOOutput::::new(policies, values) - } - - fn infer(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear.forward(input)); - softmax(self.linear_actor.forward(layer_0_output.clone()), 1) - } -} - -impl PPOModel for Net {} -#[allow(unused)] -const MEMORY_SIZE: usize = 512; - -type MyAgent = PPO>; - -#[allow(unused)] -pub fn run< - E: Environment + AsMut, - B: AutodiffBackend, ->( - conf: &Config, - visualized: bool, - // ) -> PPO> { -) -> impl Agent { - let mut env = E::new(visualized); - env.as_mut().max_steps = conf.max_steps; - - let mut model = Net::::new( - <::StateType as State>::size(), - conf.dense_size, - <::ActionType as Action>::size(), - ); - let agent = MyAgent::default(); - let config = PPOTrainingConfig { - gamma: conf.gamma, - lambda: conf.lambda, - epsilon_clip: conf.epsilon_clip, - critic_weight: conf.critic_weight, - entropy_weight: conf.entropy_weight, - learning_rate: conf.learning_rate, - epochs: conf.epochs, - batch_size: conf.batch_size, - clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( - conf.clip_grad, - )), - }; - - let mut optimizer = AdamWConfig::new() - .with_grad_clipping(config.clip_grad.clone()) - .init(); - let mut memory = Memory::::default(); - for episode in 0..conf.num_episodes { - let mut episode_done = false; - let mut 
episode_reward = 0.0; - let mut episode_duration = 0_usize; - let mut now = SystemTime::now(); - - env.reset(); - while !episode_done { - let state = env.state(); - if let Some(action) = MyAgent::::react_with_model(&state, &model) { - let snapshot = env.step(action); - episode_reward += <::RewardType as Into>::into( - snapshot.reward().clone(), - ); - - memory.push( - state, - *snapshot.state(), - action, - snapshot.reward().clone(), - snapshot.done(), - ); - - episode_duration += 1; - episode_done = snapshot.done() || episode_duration >= conf.max_steps; - } - } - println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}", - now.elapsed().unwrap().as_secs(), - ); - - now = SystemTime::now(); - model = MyAgent::train::(model, &memory, &mut optimizer, &config); - memory.clear(); - } - - if let Some(path) = &conf.save_path { - let device = NdArrayDevice::default(); - let recorder = CompactRecorder::new(); - let tmp_path = env::temp_dir().join("tmp_model.mpk"); - - // Save the trained model (backend B) to a temporary file - recorder - .record(model.clone().into_record(), tmp_path.clone()) - .expect("Failed to save temporary model"); - - // Create a new model instance with the target backend (NdArray) - let model_to_save: Net> = Net::new( - <::StateType as State>::size(), - conf.dense_size, - <::ActionType as Action>::size(), - ); - - // Load the record from the temporary file into the new model - let record = recorder - .load(tmp_path.clone(), &device) - .expect("Failed to load temporary model"); - let model_with_loaded_weights = model_to_save.load_record(record); - - // Clean up the temporary file - fs::remove_file(tmp_path).expect("Failed to remove temporary model file"); - - save_model(&model_with_loaded_weights, path); - } - agent.valid(model) -} - -pub fn save_model(model: &Net>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}.mpk"); - println!("info: Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - Net::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} diff --git a/bot/src/burnrl/algos/sac_big.rs b/bot/src/burnrl/algos/sac_big.rs deleted file mode 100644 index 1361b42..0000000 --- a/bot/src/burnrl/algos/sac_big.rs +++ /dev/null @@ -1,222 +0,0 @@ -use crate::burnrl::environment_big::TrictracEnvironment; -use crate::burnrl::utils::{soft_update_linear, Config}; -use burn::backend::{ndarray::NdArrayDevice, NdArray}; -use burn::module::Module; -use burn::nn::{Linear, LinearConfig}; -use burn::optim::AdamWConfig; -use burn::record::{CompactRecorder, Recorder}; -use burn::tensor::activation::{relu, softmax}; -use burn::tensor::backend::{AutodiffBackend, Backend}; -use burn::tensor::Tensor; -use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC}; -use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; -use std::time::SystemTime; - -#[derive(Module, Debug)] -pub struct Actor { - linear_0: Linear, - linear_1: Linear, - linear_2: Linear, -} - -impl Actor { - pub fn new(input_size: usize, dense_size: usize, 
output_size: usize) -> Self { - Self { - linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), - linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), - linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), - } - } -} - -impl Model, Tensor> for Actor { - fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input)); - let layer_1_output = relu(self.linear_1.forward(layer_0_output)); - - softmax(self.linear_2.forward(layer_1_output), 1) - } - - fn infer(&self, input: Tensor) -> Tensor { - self.forward(input) - } -} - -impl SACActor for Actor {} - -#[derive(Module, Debug)] -pub struct Critic { - linear_0: Linear, - linear_1: Linear, - linear_2: Linear, -} - -impl Critic { - pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { - Self { - linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), - linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), - linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), - } - } - - fn consume(self) -> (Linear, Linear, Linear) { - (self.linear_0, self.linear_1, self.linear_2) - } -} - -impl Model, Tensor> for Critic { - fn forward(&self, input: Tensor) -> Tensor { - let layer_0_output = relu(self.linear_0.forward(input)); - let layer_1_output = relu(self.linear_1.forward(layer_0_output)); - - self.linear_2.forward(layer_1_output) - } - - fn infer(&self, input: Tensor) -> Tensor { - self.forward(input) - } -} - -impl SACCritic for Critic { - fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { - let (linear_0, linear_1, linear_2) = this.consume(); - - Self { - linear_0: soft_update_linear(linear_0, &that.linear_0, tau), - linear_1: soft_update_linear(linear_1, &that.linear_1, tau), - linear_2: soft_update_linear(linear_2, &that.linear_2, tau), - } - } -} - -#[allow(unused)] -const MEMORY_SIZE: usize = 4096; - -type MyAgent = SAC>; - -#[allow(unused)] -pub fn run< - E: Environment + AsMut, - B: AutodiffBackend, ->( - conf: &Config, - visualized: bool, -) -> impl Agent { - let mut env = E::new(visualized); - env.as_mut().max_steps = conf.max_steps; - let state_dim = <::StateType as State>::size(); - let action_dim = <::ActionType as Action>::size(); - - let actor = Actor::::new(state_dim, conf.dense_size, action_dim); - let critic_1 = Critic::::new(state_dim, conf.dense_size, action_dim); - let critic_2 = Critic::::new(state_dim, conf.dense_size, action_dim); - let mut nets = SACNets::, Critic>::new(actor, critic_1, critic_2); - - let mut agent = MyAgent::default(); - - let config = SACTrainingConfig { - gamma: conf.gamma, - tau: conf.tau, - learning_rate: conf.learning_rate, - min_probability: conf.min_probability, - batch_size: conf.batch_size, - clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( - conf.clip_grad, - )), - }; - - let mut memory = Memory::::default(); - - let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone()); - - let mut optimizer = SACOptimizer::new( - optimizer_config.clone().init(), - optimizer_config.clone().init(), - optimizer_config.clone().init(), - optimizer_config.init(), - ); - - let mut step = 0_usize; - - for episode in 0..conf.num_episodes { - let mut episode_done = false; - let mut episode_reward = 0.0; - let mut episode_duration = 0_usize; - let mut state = env.state(); - let mut now = SystemTime::now(); - - while !episode_done { - if let Some(action) = 
MyAgent::::react_with_model(&state, &nets.actor) { - let snapshot = env.step(action); - - episode_reward += <::RewardType as Into>::into( - snapshot.reward().clone(), - ); - - memory.push( - state, - *snapshot.state(), - action, - snapshot.reward().clone(), - snapshot.done(), - ); - - if config.batch_size < memory.len() { - nets = agent.train::(nets, &memory, &mut optimizer, &config); - } - - step += 1; - episode_duration += 1; - - if snapshot.done() || episode_duration >= conf.max_steps { - env.reset(); - episode_done = true; - - println!( - "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}", - now.elapsed().unwrap().as_secs() - ); - now = SystemTime::now(); - } else { - state = *snapshot.state(); - } - } - } - } - - let valid_agent = agent.valid(nets.actor); - if let Some(path) = &conf.save_path { - if let Some(model) = valid_agent.model() { - save_model(model, path); - } - } - valid_agent -} - -pub fn save_model(model: &Actor>, path: &String) { - let recorder = CompactRecorder::new(); - let model_path = format!("{path}.mpk"); - println!("info: Modèle de validation sauvegardé : {model_path}"); - recorder - .record(model.clone().into_record(), model_path.into()) - .unwrap(); -} - -pub fn load_model(dense_size: usize, path: &String) -> Option>> { - let model_path = format!("{path}.mpk"); - // println!("Chargement du modèle depuis : {model_path}"); - - CompactRecorder::new() - .load(model_path.into(), &NdArrayDevice::default()) - .map(|record| { - Actor::new( - ::StateType::size(), - dense_size, - ::ActionType::size(), - ) - .load_record(record) - }) - .ok() -} - diff --git a/bot/src/burnrl/environment_big.rs b/bot/src/burnrl/environment_big.rs deleted file mode 100644 index 40d5a74..0000000 --- a/bot/src/burnrl/environment_big.rs +++ /dev/null @@ -1,470 +0,0 @@ -use crate::training_common_big; -use burn::{prelude::Backend, tensor::Tensor}; -use burn_rl::base::{Action, Environment, Snapshot, State}; -use rand::{thread_rng, Rng}; -use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; - -const ERROR_REWARD: f32 = -1.00012121; -const REWARD_VALID_MOVE: f32 = 1.00012121; -const REWARD_RATIO: f32 = 0.1; -const WIN_POINTS: f32 = 100.0; - -/// État du jeu Trictrac pour burn-rl -#[derive(Debug, Clone, Copy)] -pub struct TrictracState { - pub data: [i8; 36], // Représentation vectorielle de l'état du jeu -} - -impl State for TrictracState { - type Data = [i8; 36]; - - fn to_tensor(&self) -> Tensor { - Tensor::from_floats(self.data, &B::Device::default()) - } - - fn size() -> usize { - 36 - } -} - -impl TrictracState { - /// Convertit un GameState en TrictracState - pub fn from_game_state(game_state: &GameState) -> Self { - let state_vec = game_state.to_vec(); - let mut data = [0; 36]; - - // Copier les données en s'assurant qu'on ne dépasse pas la taille - let copy_len = state_vec.len().min(36); - data[..copy_len].copy_from_slice(&state_vec[..copy_len]); - - TrictracState { data } - } -} - -/// Actions possibles dans Trictrac pour burn-rl -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct TrictracAction { - // u32 as required by burn_rl::base::Action type - pub index: u32, -} - -impl Action for TrictracAction { - fn random() -> Self { - use rand::{thread_rng, Rng}; - let mut rng = thread_rng(); - TrictracAction { - index: rng.gen_range(0..Self::size() as u32), - } - } - - fn enumerate() -> Vec { - (0..Self::size() as u32) - .map(|index| TrictracAction { index }) - .collect() - } - - fn size() -> usize { - 1252 
- } -} - -impl From for TrictracAction { - fn from(index: u32) -> Self { - TrictracAction { index } - } -} - -impl From for u32 { - fn from(action: TrictracAction) -> u32 { - action.index - } -} - -/// Environnement Trictrac pour burn-rl -#[derive(Debug)] -pub struct TrictracEnvironment { - pub game: GameState, - active_player_id: PlayerId, - opponent_id: PlayerId, - current_state: TrictracState, - episode_reward: f32, - pub step_count: usize, - pub max_steps: usize, - pub pointrolls_count: usize, - pub goodmoves_count: usize, - pub goodmoves_ratio: f32, - pub visualized: bool, -} - -impl Environment for TrictracEnvironment { - type StateType = TrictracState; - type ActionType = TrictracAction; - type RewardType = f32; - - fn new(visualized: bool) -> Self { - let mut game = GameState::new(false); - - // Ajouter deux joueurs - game.init_player("DQN Agent"); - game.init_player("Opponent"); - let player1_id = 1; - let player2_id = 2; - - // Commencer la partie - game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - let current_state = TrictracState::from_game_state(&game); - TrictracEnvironment { - game, - active_player_id: player1_id, - opponent_id: player2_id, - current_state, - episode_reward: 0.0, - step_count: 0, - max_steps: 2000, - pointrolls_count: 0, - goodmoves_count: 0, - goodmoves_ratio: 0.0, - visualized, - } - } - - fn state(&self) -> Self::StateType { - self.current_state - } - - fn reset(&mut self) -> Snapshot { - // Réinitialiser le jeu - self.game = GameState::new(false); - self.game.init_player("DQN Agent"); - self.game.init_player("Opponent"); - - // Commencer la partie - self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); - - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward = 0.0; - self.goodmoves_ratio = if self.step_count == 0 { - 0.0 - } else { - self.goodmoves_count as f32 / self.step_count as f32 - }; - println!( - "info: correct moves: {} ({}%)", - self.goodmoves_count, - (100.0 * self.goodmoves_ratio).round() as u32 - ); - self.step_count = 0; - self.pointrolls_count = 0; - self.goodmoves_count = 0; - - Snapshot::new(self.current_state, 0.0, false) - } - - fn step(&mut self, action: Self::ActionType) -> Snapshot { - self.step_count += 1; - - // Convertir l'action burn-rl vers une action Trictrac - let trictrac_action = Self::convert_action(action); - - let mut reward = 0.0; - let is_rollpoint; - - // Exécuter l'action si c'est le tour de l'agent DQN - if self.game.active_player_id == self.active_player_id { - if let Some(action) = trictrac_action { - (reward, is_rollpoint) = self.execute_action(action); - if is_rollpoint { - self.pointrolls_count += 1; - } - if reward != ERROR_REWARD { - self.goodmoves_count += 1; - // println!("{str_action}"); - } - } else { - // Action non convertible, pénalité - reward = -0.5; - } - } - - // Faire jouer l'adversaire (stratégie simple) - while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - // print!(":"); - reward += self.play_opponent_if_needed(); - } - - // Vérifier si la partie est terminée - // let max_steps = self.max_steps - // let max_steps = self.min_steps - // + (self.max_steps as f32 - self.min_steps) - // * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); - let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); - - if done { - // Récompense finale basée sur le résultat - if let Some(winner_id) = self.game.determine_winner() { - if winner_id == self.active_player_id { - reward += WIN_POINTS; // Victoire 
- } else { - reward -= WIN_POINTS; // Défaite - } - } - } - let terminated = done || self.step_count >= self.max_steps; - - // Mettre à jour l'état - self.current_state = TrictracState::from_game_state(&self.game); - self.episode_reward += reward; - if self.visualized && terminated { - println!( - "Episode terminé. Récompense totale: {:.2}, Étapes: {}", - self.episode_reward, self.step_count - ); - } - - Snapshot::new(self.current_state, reward, terminated) - } -} - -impl TrictracEnvironment { - /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) - } - - /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac - #[allow(dead_code)] - fn convert_valid_action_index( - &self, - action: TrictracAction, - game_state: &GameState, - ) -> Option { - use training_common_big::get_valid_actions; - - // Obtenir les actions valides dans le contexte actuel - let valid_actions = get_valid_actions(game_state); - - if valid_actions.is_empty() { - return None; - } - - // Mapper l'index d'action sur une action valide - let action_index = (action.index as usize) % valid_actions.len(); - Some(valid_actions[action_index].clone()) - } - - /// Exécute une action Trictrac dans le jeu - // fn execute_action( - // &mut self, - // action:training_common_big::TrictracAction, - // ) -> Result> { - fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) { - use training_common_big::TrictracAction; - - let mut reward = 0.0; - let mut is_rollpoint = false; - let mut need_roll = false; - - let event = match action { - TrictracAction::Roll => { - // Lancer les dés - need_roll = true; - Some(GameEvent::Roll { - player_id: self.active_player_id, - }) - } - // TrictracAction::Mark => { - // // Marquer des points - // let points = self.game. 
- // reward += 0.1 * points as f32; - // Some(GameEvent::Mark { - // player_id: self.active_player_id, - // points, - // }) - // } - TrictracAction::Go => { - // Continuer après avoir gagné un trou - Some(GameEvent::Go { - player_id: self.active_player_id, - }) - } - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let (dice1, dice2) = if dice_order { - (self.game.dice.values.0, self.game.dice.values.1) - } else { - (self.game.dice.values.1, self.game.dice.values.0) - }; - let mut to1 = from1 + dice1 as usize; - let mut to2 = from2 + dice2 as usize; - - // Gestion prise de coin par puissance - let opp_rest_field = 13; - if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; - } - - let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - - Some(GameEvent::Move { - player_id: self.active_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - - // Appliquer l'événement si valide - if let Some(event) = event { - if self.game.validate(&event) { - self.game.consume(&event); - reward += REWARD_VALID_MOVE; - // Simuler le résultat des dés après un Roll - // if matches!(action, TrictracAction::Roll) { - if need_roll { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - let dice_event = GameEvent::RollResult { - player_id: self.active_player_id, - dice: store::Dice { - values: dice_values, - }, - }; - // print!("o"); - if self.game.validate(&dice_event) { - self.game.consume(&dice_event); - let (points, adv_points) = self.game.dice_points; - reward += REWARD_RATIO * (points - adv_points) as f32; - if points > 0 { - is_rollpoint = true; - // println!("info: rolled for {reward}"); - } - // Récompense proportionnelle aux points - } - } - } else { - // Pénalité pour action invalide - // on annule les précédents reward - // et on indique une valeur reconnaissable pour statistiques - reward = ERROR_REWARD; - self.game.mark_points_for_bot_training(self.opponent_id, 1); - } - } - - (reward, is_rollpoint) - } - - /// Fait jouer l'adversaire avec une stratégie simple - fn play_opponent_if_needed(&mut self) -> f32 { - // print!("z?"); - let mut reward = 0.0; - - // Si c'est le tour de l'adversaire, jouer automatiquement - if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - // Utiliser la stratégie default pour l'adversaire - use crate::BotStrategy; - - let mut strategy = crate::strategy::random::RandomStrategy::default(); - strategy.set_player_id(self.opponent_id); - if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { - strategy.set_color(color); - } - *strategy.get_mut_game() = self.game.clone(); - - // Exécuter l'action selon le turn_stage - let mut calculate_points = false; - let opponent_color = store::Color::Black; - let event = match self.game.turn_stage { - TurnStage::RollDice => GameEvent::Roll { - player_id: self.opponent_id, - }, - TurnStage::RollWaiting => { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - calculate_points = true; // comment to replicate burnrl_before - GameEvent::RollResult { - player_id: self.opponent_id, - dice: store::Dice { - values: dice_values, - }, - } - } - TurnStage::MarkPoints => { - panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); - // let dice_roll_count = self - // .game - // .players - // .get(&self.opponent_id) - // .unwrap() - 
// .dice_roll_count; - // let points_rules = - // PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - // GameEvent::Mark { - // player_id: self.opponent_id, - // points: points_rules.get_points(dice_roll_count).0, - // } - } - TurnStage::MarkAdvPoints => { - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - // pas de reward : déjà comptabilisé lors du tour de blanc - GameEvent::Mark { - player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).1, - } - } - TurnStage::HoldOrGoChoice => { - // Stratégie simple : toujours continuer - GameEvent::Go { - player_id: self.opponent_id, - } - } - TurnStage::Move => GameEvent::Move { - player_id: self.opponent_id, - moves: strategy.choose_move(), - }, - }; - - if self.game.validate(&event) { - self.game.consume(&event); - // print!("."); - if calculate_points { - // print!("x"); - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // Récompense proportionnelle aux points - let adv_reward = REWARD_RATIO * (points - adv_points) as f32; - reward -= adv_reward; - // if adv_reward != 0.0 { - // println!("info: opponent : {adv_reward} -> {reward}"); - // } - } - } - } - reward - } -} - -impl AsMut for TrictracEnvironment { - fn as_mut(&mut self) -> &mut Self { - self - } -} diff --git a/bot/src/burnrl/environment_valid.rs b/bot/src/burnrl/environment_valid.rs index 346044c..9c27af9 100644 --- a/bot/src/burnrl/environment_valid.rs +++ b/bot/src/burnrl/environment_valid.rs @@ -1,9 +1,12 @@ -use crate::training_common_big; +use crate::training_common; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; +const ERROR_REWARD: f32 = -1.0012121; +const REWARD_RATIO: f32 = 0.1; + /// État du jeu Trictrac pour burn-rl #[derive(Debug, Clone, Copy)] pub struct TrictracState { @@ -214,16 +217,16 @@ impl TrictracEnvironment { const REWARD_RATIO: f32 = 1.0; /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + training_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac fn convert_valid_action_index( &self, action: TrictracAction, - ) -> Option { - use training_common_big::get_valid_actions; + ) -> Option { + use training_common::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(&self.game); @@ -240,72 +243,19 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: training_common_big::TrictracAction, + // action: training_common::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) { - use training_common_big::TrictracAction; + fn execute_action(&mut self, action: 
training_common::TrictracAction) -> (f32, bool) { + use training_common::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; - let event = match action { - TrictracAction::Roll => { - // Lancer les dés - Some(GameEvent::Roll { - player_id: self.active_player_id, - }) - } - // TrictracAction::Mark => { - // // Marquer des points - // let points = self.game. - // reward += 0.1 * points as f32; - // Some(GameEvent::Mark { - // player_id: self.active_player_id, - // points, - // }) - // } - TrictracAction::Go => { - // Continuer après avoir gagné un trou - Some(GameEvent::Go { - player_id: self.active_player_id, - }) - } - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let (dice1, dice2) = if dice_order { - (self.game.dice.values.0, self.game.dice.values.1) - } else { - (self.game.dice.values.1, self.game.dice.values.0) - }; - let mut to1 = from1 + dice1 as usize; - let mut to2 = from2 + dice2 as usize; - - // Gestion prise de coin par puissance - let opp_rest_field = 13; - if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; - } - - let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - - Some(GameEvent::Move { - player_id: self.active_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - // Appliquer l'événement si valide - if let Some(event) = event { + if let Some(event) = action.to_event(&self.game) { if self.game.validate(&event) { self.game.consume(&event); - + // reward += REWARD_VALID_MOVE; // Simuler le résultat des dés après un Roll if matches!(action, TrictracAction::Roll) { let mut rng = thread_rng(); @@ -319,7 +269,7 @@ impl TrictracEnvironment { if self.game.validate(&dice_event) { self.game.consume(&dice_event); let (points, adv_points) = self.game.dice_points; - reward += Self::REWARD_RATIO * (points - adv_points) as f32; + reward += REWARD_RATIO * (points as f32 - adv_points as f32); if points > 0 { is_rollpoint = true; // println!("info: rolled for {reward}"); @@ -331,9 +281,12 @@ impl TrictracEnvironment { // Pénalité pour action invalide // on annule les précédents reward // et on indique une valeur reconnaissable pour statistiques - println!("info: action invalide -> err_reward"); - reward = Self::ERROR_REWARD; + reward = ERROR_REWARD; + self.game.mark_points_for_bot_training(self.opponent_id, 1); } + } else { + reward = ERROR_REWARD; + self.game.mark_points_for_bot_training(self.opponent_id, 1); } (reward, is_rollpoint) diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index f7608a3..5230ec0 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,8 +1,5 @@ -use bot::burnrl::algos::{ - dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid, -}; +use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid}; use bot::burnrl::environment::TrictracEnvironment; -use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig; use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid; use bot::burnrl::utils::{demo_model, Config}; use burn::backend::{Autodiff, NdArray}; @@ -36,16 +33,6 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } - "dqn_big" => { - let _agent = dqn_big::run::(&conf, false); - println!("> Chargement du modèle pour test"); - let loaded_model = dqn_big::load_model(conf.dense_size, &path); - let loaded_agent: burn_rl::agent::DQN = 
- burn_rl::agent::DQN::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); - } "dqn_valid" => { let _agent = dqn_valid::run::(&conf, false); println!("> Chargement du modèle pour test"); @@ -66,16 +53,6 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } - "sac_big" => { - let _agent = sac_big::run::(&conf, false); - println!("> Chargement du modèle pour test"); - let loaded_model = sac_big::load_model(conf.dense_size, &path); - let loaded_agent: burn_rl::agent::SAC = - burn_rl::agent::SAC::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); - } "sac_valid" => { let _agent = sac_valid::run::(&conf, false); println!("> Chargement du modèle pour test"); @@ -96,16 +73,6 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } - "ppo_big" => { - let _agent = ppo_big::run::(&conf, false); - println!("> Chargement du modèle pour test"); - let loaded_model = ppo_big::load_model(conf.dense_size, &path); - let loaded_agent: burn_rl::agent::PPO = - burn_rl::agent::PPO::new(loaded_model.unwrap()); - - println!("> Test avec le modèle chargé"); - demo_model(loaded_agent); - } "ppo_valid" => { let _agent = ppo_valid::run::(&conf, false); println!("> Chargement du modèle pour test"); diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs index 62bebc8..292bbb8 100644 --- a/bot/src/burnrl/mod.rs +++ b/bot/src/burnrl/mod.rs @@ -1,5 +1,4 @@ pub mod algos; pub mod environment; -pub mod environment_big; pub mod environment_valid; pub mod utils; diff --git a/bot/src/dqn_simple/dqn_model.rs b/bot/src/dqn_simple/dqn_model.rs deleted file mode 100644 index 9c31f44..0000000 --- a/bot/src/dqn_simple/dqn_model.rs +++ /dev/null @@ -1,153 +0,0 @@ -use crate::training_common_big::TrictracAction; -use serde::{Deserialize, Serialize}; - -/// Configuration pour l'agent DQN -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DqnConfig { - pub state_size: usize, - pub hidden_size: usize, - pub num_actions: usize, - pub learning_rate: f64, - pub gamma: f64, - pub epsilon: f64, - pub epsilon_decay: f64, - pub epsilon_min: f64, - pub replay_buffer_size: usize, - pub batch_size: usize, -} - -impl Default for DqnConfig { - fn default() -> Self { - Self { - state_size: 36, - hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi - num_actions: TrictracAction::action_space_size(), - learning_rate: 0.001, - gamma: 0.99, - epsilon: 0.1, - epsilon_decay: 0.995, - epsilon_min: 0.01, - replay_buffer_size: 10000, - batch_size: 32, - } - } -} - -/// Réseau de neurones DQN simplifié (matrice de poids basique) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SimpleNeuralNetwork { - pub weights1: Vec>, - pub biases1: Vec, - pub weights2: Vec>, - pub biases2: Vec, - pub weights3: Vec>, - pub biases3: Vec, -} - -impl SimpleNeuralNetwork { - pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self { - use rand::{thread_rng, Rng}; - let mut rng = thread_rng(); - - // Initialisation aléatoire des poids avec Xavier/Glorot - let scale1 = (2.0 / input_size as f32).sqrt(); - let weights1 = (0..hidden_size) - .map(|_| { - (0..input_size) - .map(|_| rng.gen_range(-scale1..scale1)) - .collect() - }) - .collect(); - let biases1 = vec![0.0; hidden_size]; - - let scale2 = (2.0 / hidden_size as f32).sqrt(); - let weights2 = (0..hidden_size) - .map(|_| { - (0..hidden_size) - .map(|_| rng.gen_range(-scale2..scale2)) - .collect() - 
}) - .collect(); - let biases2 = vec![0.0; hidden_size]; - - let scale3 = (2.0 / hidden_size as f32).sqrt(); - let weights3 = (0..output_size) - .map(|_| { - (0..hidden_size) - .map(|_| rng.gen_range(-scale3..scale3)) - .collect() - }) - .collect(); - let biases3 = vec![0.0; output_size]; - - Self { - weights1, - biases1, - weights2, - biases2, - weights3, - biases3, - } - } - - pub fn forward(&self, input: &[f32]) -> Vec { - // Première couche - let mut layer1: Vec = self.biases1.clone(); - for (i, neuron_weights) in self.weights1.iter().enumerate() { - for (j, &weight) in neuron_weights.iter().enumerate() { - if j < input.len() { - layer1[i] += input[j] * weight; - } - } - layer1[i] = layer1[i].max(0.0); // ReLU - } - - // Deuxième couche - let mut layer2: Vec = self.biases2.clone(); - for (i, neuron_weights) in self.weights2.iter().enumerate() { - for (j, &weight) in neuron_weights.iter().enumerate() { - if j < layer1.len() { - layer2[i] += layer1[j] * weight; - } - } - layer2[i] = layer2[i].max(0.0); // ReLU - } - - // Couche de sortie - let mut output: Vec = self.biases3.clone(); - for (i, neuron_weights) in self.weights3.iter().enumerate() { - for (j, &weight) in neuron_weights.iter().enumerate() { - if j < layer2.len() { - output[i] += layer2[j] * weight; - } - } - } - - output - } - - pub fn get_best_action(&self, input: &[f32]) -> usize { - let q_values = self.forward(input); - q_values - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .map(|(index, _)| index) - .unwrap_or(0) - } - - pub fn save>( - &self, - path: P, - ) -> Result<(), Box> { - let data = serde_json::to_string_pretty(self)?; - std::fs::write(path, data)?; - Ok(()) - } - - pub fn load>(path: P) -> Result> { - let data = std::fs::read_to_string(path)?; - let network = serde_json::from_str(&data)?; - Ok(network) - } -} diff --git a/bot/src/dqn_simple/dqn_trainer.rs b/bot/src/dqn_simple/dqn_trainer.rs deleted file mode 100644 index ed60f5e..0000000 --- a/bot/src/dqn_simple/dqn_trainer.rs +++ /dev/null @@ -1,494 +0,0 @@ -use crate::{CheckerMove, Color, GameState, PlayerId}; -use rand::prelude::SliceRandom; -use rand::{thread_rng, Rng}; -use serde::{Deserialize, Serialize}; -use std::collections::VecDeque; -use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; - -use super::dqn_model::{DqnConfig, SimpleNeuralNetwork}; -use crate::training_common_big::{get_valid_actions, TrictracAction}; - -/// Expérience pour le buffer de replay -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Experience { - pub state: Vec, - pub action: TrictracAction, - pub reward: f32, - pub next_state: Vec, - pub done: bool, -} - -/// Buffer de replay pour stocker les expériences -#[derive(Debug)] -pub struct ReplayBuffer { - buffer: VecDeque, - capacity: usize, -} - -impl ReplayBuffer { - pub fn new(capacity: usize) -> Self { - Self { - buffer: VecDeque::with_capacity(capacity), - capacity, - } - } - - pub fn push(&mut self, experience: Experience) { - if self.buffer.len() >= self.capacity { - self.buffer.pop_front(); - } - self.buffer.push_back(experience); - } - - pub fn sample(&self, batch_size: usize) -> Vec { - let mut rng = thread_rng(); - let len = self.buffer.len(); - if len < batch_size { - return self.buffer.iter().cloned().collect(); - } - - let mut batch = Vec::with_capacity(batch_size); - for _ in 0..batch_size { - let idx = rng.gen_range(0..len); - batch.push(self.buffer[idx].clone()); - } - batch - } - - pub fn is_empty(&self) -> bool { - self.buffer.is_empty() - } - - pub fn 
len(&self) -> usize { - self.buffer.len() - } -} - -/// Agent DQN pour l'apprentissage par renforcement -#[derive(Debug)] -pub struct DqnAgent { - config: DqnConfig, - model: SimpleNeuralNetwork, - target_model: SimpleNeuralNetwork, - replay_buffer: ReplayBuffer, - epsilon: f64, - step_count: usize, -} - -impl DqnAgent { - pub fn new(config: DqnConfig) -> Self { - let model = - SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions); - let target_model = model.clone(); - let replay_buffer = ReplayBuffer::new(config.replay_buffer_size); - let epsilon = config.epsilon; - - Self { - config, - model, - target_model, - replay_buffer, - epsilon, - step_count: 0, - } - } - - pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction { - let valid_actions = get_valid_actions(game_state); - - if valid_actions.is_empty() { - // Fallback si aucune action valide - return TrictracAction::Roll; - } - - let mut rng = thread_rng(); - if rng.gen::() < self.epsilon { - // Exploration : action valide aléatoire - valid_actions - .choose(&mut rng) - .cloned() - .unwrap_or(TrictracAction::Roll) - } else { - // Exploitation : meilleure action valide selon le modèle - let q_values = self.model.forward(state); - - let mut best_action = &valid_actions[0]; - let mut best_q_value = f32::NEG_INFINITY; - - for action in &valid_actions { - let action_index = action.to_action_index(); - if action_index < q_values.len() { - let q_value = q_values[action_index]; - if q_value > best_q_value { - best_q_value = q_value; - best_action = action; - } - } - } - - best_action.clone() - } - } - - pub fn store_experience(&mut self, experience: Experience) { - self.replay_buffer.push(experience); - } - - pub fn train(&mut self) { - if self.replay_buffer.len() < self.config.batch_size { - return; - } - - // Pour l'instant, on simule l'entraînement en mettant à jour epsilon - // Dans une implémentation complète, ici on ferait la backpropagation - self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min); - self.step_count += 1; - - // Mise à jour du target model tous les 100 steps - if self.step_count % 100 == 0 { - self.target_model = self.model.clone(); - } - } - - pub fn save_model>( - &self, - path: P, - ) -> Result<(), Box> { - self.model.save(path) - } - - pub fn get_epsilon(&self) -> f64 { - self.epsilon - } - - pub fn get_step_count(&self) -> usize { - self.step_count - } -} - -/// Environnement Trictrac pour l'entraînement -#[derive(Debug)] -pub struct TrictracEnv { - pub game_state: GameState, - pub agent_player_id: PlayerId, - pub opponent_player_id: PlayerId, - pub agent_color: Color, - pub max_steps: usize, - pub current_step: usize, -} - -impl Default for TrictracEnv { - fn default() -> Self { - let mut game_state = GameState::new(false); - game_state.init_player("agent"); - game_state.init_player("opponent"); - - Self { - game_state, - agent_player_id: 1, - opponent_player_id: 2, - agent_color: Color::White, - max_steps: 1000, - current_step: 0, - } - } -} - -impl TrictracEnv { - pub fn reset(&mut self) -> Vec { - self.game_state = GameState::new(false); - self.game_state.init_player("agent"); - self.game_state.init_player("opponent"); - - // Commencer la partie - self.game_state.consume(&GameEvent::BeginGame { - goes_first: self.agent_player_id, - }); - - self.current_step = 0; - self.game_state.to_vec_float() - } - - pub fn step(&mut self, action: TrictracAction) -> (Vec, f32, bool) { - let mut reward = 0.0; - - // Appliquer l'action 
de l'agent - if self.game_state.active_player_id == self.agent_player_id { - reward += self.apply_agent_action(action); - } - - // Faire jouer l'adversaire (stratégie simple) - while self.game_state.active_player_id == self.opponent_player_id - && self.game_state.stage != Stage::Ended - { - reward += self.play_opponent_turn(); - } - - // Vérifier si la partie est terminée - let done = self.game_state.stage == Stage::Ended - || self.game_state.determine_winner().is_some() - || self.current_step >= self.max_steps; - - // Récompense finale si la partie est terminée - if done { - if let Some(winner) = self.game_state.determine_winner() { - if winner == self.agent_player_id { - reward += 100.0; // Bonus pour gagner - } else { - reward -= 50.0; // Pénalité pour perdre - } - } - } - - self.current_step += 1; - let next_state = self.game_state.to_vec_float(); - (next_state, reward, done) - } - - fn apply_agent_action(&mut self, action: TrictracAction) -> f32 { - let mut reward = 0.0; - - let event = match action { - TrictracAction::Roll => { - // Lancer les dés - reward += 0.1; - Some(GameEvent::Roll { - player_id: self.agent_player_id, - }) - } - // TrictracAction::Mark => { - // // Marquer des points - // let points = self.game_state. - // reward += 0.1 * points as f32; - // Some(GameEvent::Mark { - // player_id: self.agent_player_id, - // points, - // }) - // } - TrictracAction::Go => { - // Continuer après avoir gagné un trou - reward += 0.2; - Some(GameEvent::Go { - player_id: self.agent_player_id, - }) - } - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Effectuer un mouvement - let (dice1, dice2) = if dice_order { - (self.game_state.dice.values.0, self.game_state.dice.values.1) - } else { - (self.game_state.dice.values.1, self.game_state.dice.values.0) - }; - let mut to1 = from1 + dice1 as usize; - let mut to2 = from2 + dice2 as usize; - - // Gestion prise de coin par puissance - let opp_rest_field = 13; - if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; - } - - let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); - - reward += 0.2; - Some(GameEvent::Move { - player_id: self.agent_player_id, - moves: (checker_move1, checker_move2), - }) - } - }; - - // Appliquer l'événement si valide - if let Some(event) = event { - if self.game_state.validate(&event) { - self.game_state.consume(&event); - - // Simuler le résultat des dés après un Roll - if matches!(action, TrictracAction::Roll) { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - let dice_event = GameEvent::RollResult { - player_id: self.agent_player_id, - dice: store::Dice { - values: dice_values, - }, - }; - if self.game_state.validate(&dice_event) { - self.game_state.consume(&dice_event); - } - } - } else { - // Pénalité pour action invalide - reward -= 2.0; - } - } - - reward - } - - // TODO : use default bot strategy - fn play_opponent_turn(&mut self) -> f32 { - let mut reward = 0.0; - let event = match self.game_state.turn_stage { - TurnStage::RollDice => GameEvent::Roll { - player_id: self.opponent_player_id, - }, - TurnStage::RollWaiting => { - let mut rng = thread_rng(); - let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - GameEvent::RollResult { - player_id: self.opponent_player_id, - dice: store::Dice { - values: dice_values, - }, - } - } - TurnStage::MarkAdvPoints | TurnStage::MarkPoints => { - let opponent_color = 
self.agent_color.opponent_color(); - let dice_roll_count = self - .game_state - .players - .get(&self.opponent_player_id) - .unwrap() - .dice_roll_count; - let points_rules = PointsRules::new( - &opponent_color, - &self.game_state.board, - self.game_state.dice, - ); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points - - GameEvent::Mark { - player_id: self.opponent_player_id, - points, - } - } - TurnStage::Move => { - let opponent_color = self.agent_color.opponent_color(); - let rules = MoveRules::new( - &opponent_color, - &self.game_state.board, - self.game_state.dice, - ); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Stratégie simple : choix aléatoire - let mut rng = thread_rng(); - let choosen_move = *possible_moves - .choose(&mut rng) - .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); - - GameEvent::Move { - player_id: self.opponent_player_id, - moves: if opponent_color == Color::White { - choosen_move - } else { - (choosen_move.0.mirror(), choosen_move.1.mirror()) - }, - } - } - TurnStage::HoldOrGoChoice => { - // Stratégie simple : toujours continuer - GameEvent::Go { - player_id: self.opponent_player_id, - } - } - }; - if self.game_state.validate(&event) { - self.game_state.consume(&event); - } - reward - } -} - -/// Entraîneur pour le modèle DQN -pub struct DqnTrainer { - agent: DqnAgent, - env: TrictracEnv, -} - -impl DqnTrainer { - pub fn new(config: DqnConfig) -> Self { - Self { - agent: DqnAgent::new(config), - env: TrictracEnv::default(), - } - } - - pub fn train_episode(&mut self) -> f32 { - let mut total_reward = 0.0; - let mut state = self.env.reset(); - // let mut step_count = 0; - - loop { - // step_count += 1; - let action = self.agent.select_action(&self.env.game_state, &state); - let (next_state, reward, done) = self.env.step(action.clone()); - total_reward += reward; - - let experience = Experience { - state: state.clone(), - action, - reward, - next_state: next_state.clone(), - done, - }; - self.agent.store_experience(experience); - self.agent.train(); - - if done { - break; - } - // if step_count % 100 == 0 { - // println!("{:?}", next_state); - // } - state = next_state; - } - - total_reward - } - - pub fn train( - &mut self, - episodes: usize, - save_every: usize, - model_path: &str, - ) -> Result<(), Box> { - println!("Démarrage de l'entraînement DQN pour {episodes} épisodes"); - - for episode in 1..=episodes { - let reward = self.train_episode(); - - if episode % 100 == 0 { - println!( - "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}", - episode, - episodes, - reward, - self.agent.get_epsilon(), - self.agent.get_step_count() - ); - } - - if episode % save_every == 0 { - let save_path = format!("{model_path}_episode_{episode}.json"); - self.agent.save_model(&save_path)?; - println!("Modèle sauvegardé : {save_path}"); - } - } - - // Sauvegarder le modèle final - let final_path = format!("{model_path}_final.json"); - self.agent.save_model(&final_path)?; - println!("Modèle final sauvegardé : {final_path}"); - - Ok(()) - } -} diff --git a/bot/src/dqn_simple/main.rs b/bot/src/dqn_simple/main.rs deleted file mode 100644 index 024f895..0000000 --- a/bot/src/dqn_simple/main.rs +++ /dev/null @@ -1,109 +0,0 @@ -use bot::dqn_simple::dqn_model::DqnConfig; -use bot::dqn_simple::dqn_trainer::DqnTrainer; -use bot::training_common::TrictracAction; -use std::env; - -fn main() -> Result<(), Box> { - 
env_logger::init(); - - let args: Vec = env::args().collect(); - - // Paramètres par défaut - let mut episodes = 1000; - let mut model_path = "models/dqn_model".to_string(); - let mut save_every = 100; - - // Parser les arguments de ligne de commande - let mut i = 1; - while i < args.len() { - match args[i].as_str() { - "--episodes" => { - if i + 1 < args.len() { - episodes = args[i + 1].parse().unwrap_or(1000); - i += 2; - } else { - eprintln!("Erreur : --episodes nécessite une valeur"); - std::process::exit(1); - } - } - "--model-path" => { - if i + 1 < args.len() { - model_path = args[i + 1].clone(); - i += 2; - } else { - eprintln!("Erreur : --model-path nécessite une valeur"); - std::process::exit(1); - } - } - "--save-every" => { - if i + 1 < args.len() { - save_every = args[i + 1].parse().unwrap_or(100); - i += 2; - } else { - eprintln!("Erreur : --save-every nécessite une valeur"); - std::process::exit(1); - } - } - "--help" | "-h" => { - print_help(); - std::process::exit(0); - } - _ => { - eprintln!("Argument inconnu : {}", args[i]); - print_help(); - std::process::exit(1); - } - } - } - - // Créer le dossier models s'il n'existe pas - std::fs::create_dir_all("models")?; - - println!("Configuration d'entraînement DQN :"); - println!(" Épisodes : {episodes}"); - println!(" Chemin du modèle : {model_path}"); - println!(" Sauvegarde tous les {save_every} épisodes"); - println!(); - - // Configuration DQN - let config = DqnConfig { - state_size: 36, // state.to_vec size - hidden_size: 256, - num_actions: TrictracAction::action_space_size(), - learning_rate: 0.001, - gamma: 0.99, - epsilon: 0.9, // Commencer avec plus d'exploration - epsilon_decay: 0.995, - epsilon_min: 0.01, - replay_buffer_size: 10000, - batch_size: 32, - }; - - // Créer et lancer l'entraîneur - let mut trainer = DqnTrainer::new(config); - trainer.train(episodes, save_every, &model_path)?; - - println!("Entraînement terminé avec succès !"); - println!("Pour utiliser le modèle entraîné :"); - println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy"); - - Ok(()) -} - -fn print_help() { - println!("Entraîneur DQN pour Trictrac"); - println!(); - println!("USAGE:"); - println!(" cargo run --bin=train_dqn [OPTIONS]"); - println!(); - println!("OPTIONS:"); - println!(" --episodes Nombre d'épisodes d'entraînement (défaut: 1000)"); - println!(" --model-path Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)"); - println!(" --save-every Sauvegarder le modèle tous les N épisodes (défaut: 100)"); - println!(" -h, --help Afficher cette aide"); - println!(); - println!("EXEMPLES:"); - println!(" cargo run --bin=train_dqn"); - println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500"); - println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000"); -} diff --git a/bot/src/dqn_simple/mod.rs b/bot/src/dqn_simple/mod.rs deleted file mode 100644 index 8090a29..0000000 --- a/bot/src/dqn_simple/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod dqn_model; -pub mod dqn_trainer; diff --git a/bot/src/lib.rs b/bot/src/lib.rs index dab36be..0fc6fdf 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,14 +1,11 @@ pub mod burnrl; -pub mod dqn_simple; pub mod strategy; pub mod training_common; -pub mod training_common_big; pub mod trictrac_board; use log::debug; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; -pub use strategy::dqn::DqnStrategy; pub use 
strategy::dqnburn::DqnBurnStrategy; pub use strategy::erroneous_moves::ErroneousStrategy; pub use strategy::random::RandomStrategy; diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs deleted file mode 100644 index 2874195..0000000 --- a/bot/src/strategy/dqn.rs +++ /dev/null @@ -1,174 +0,0 @@ -use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId}; -use log::info; -use std::path::Path; -use store::MoveRules; - -use crate::dqn_simple::dqn_model::SimpleNeuralNetwork; -use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction}; - -/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné -#[derive(Debug)] -pub struct DqnStrategy { - pub game: GameState, - pub player_id: PlayerId, - pub color: Color, - pub model: Option, -} - -impl Default for DqnStrategy { - fn default() -> Self { - Self { - game: GameState::default(), - player_id: 1, - color: Color::White, - model: None, - } - } -} - -impl DqnStrategy { - pub fn new() -> Self { - Self::default() - } - - pub fn new_with_model + std::fmt::Debug>(model_path: P) -> Self { - let mut strategy = Self::new(); - if let Ok(model) = SimpleNeuralNetwork::load(&model_path) { - info!("Loading model {model_path:?}"); - strategy.model = Some(model); - } - strategy - } - - /// Utilise le modèle DQN pour choisir une action valide - fn get_dqn_action(&self) -> Option { - if let Some(ref model) = self.model { - let state = self.game.to_vec_float(); - let valid_actions = get_valid_actions(&self.game); - - if valid_actions.is_empty() { - return None; - } - - // Obtenir les Q-values pour toutes les actions - let q_values = model.forward(&state); - - // Trouver la meilleure action valide - let mut best_action = &valid_actions[0]; - let mut best_q_value = f32::NEG_INFINITY; - - for action in &valid_actions { - let action_index = action.to_action_index(); - if action_index < q_values.len() { - let q_value = q_values[action_index]; - if q_value > best_q_value { - best_q_value = q_value; - best_action = action; - } - } - } - - Some(best_action.clone()) - } else { - // Fallback : action aléatoire valide - sample_valid_action(&self.game) - } - } -} - -impl BotStrategy for DqnStrategy { - fn get_game(&self) -> &GameState { - &self.game - } - - fn get_mut_game(&mut self) -> &mut GameState { - &mut self.game - } - - fn set_color(&mut self, color: Color) { - self.color = color; - } - - fn set_player_id(&mut self, player_id: PlayerId) { - self.player_id = player_id; - } - - fn calculate_points(&self) -> u8 { - self.game.dice_points.0 - } - - fn calculate_adv_points(&self) -> u8 { - self.game.dice_points.1 - } - - fn choose_go(&self) -> bool { - // Utiliser le DQN pour décider si on continue - if let Some(action) = self.get_dqn_action() { - matches!(action, TrictracAction::Go) - } else { - // Fallback : toujours continuer - true - } - } - - fn choose_move(&self) -> (CheckerMove, CheckerMove) { - // Utiliser le DQN pour choisir le mouvement - if let Some(TrictracAction::Move { - dice_order, - from1, - from2, - }) = self.get_dqn_action() - { - let dicevals = self.game.dice.values; - let (mut dice1, mut dice2) = if dice_order { - (dicevals.0, dicevals.1) - } else { - (dicevals.1, dicevals.0) - }; - - if from1 == 0 { - // empty move - dice1 = 0; - } - let mut to1 = from1 + dice1 as usize; - if 24 < to1 { - // sortie - to1 = 0; - } - if from2 == 0 { - // empty move - dice2 = 0; - } - let mut to2 = from2 + dice2 as usize; - if 24 < to2 { - // sortie - to2 = 0; - } - - let checker_move1 = 
CheckerMove::new(from1, to1).unwrap_or_default(); - let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default(); - - let chosen_move = if self.color == Color::White { - (checker_move1, checker_move2) - } else { - (checker_move1.mirror(), checker_move2.mirror()) - }; - - return chosen_move; - } - - // Fallback : utiliser la stratégie par défaut - let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - let chosen_move = *possible_moves - .first() - .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); - - if self.color == Color::White { - chosen_move - } else { - (chosen_move.0.mirror(), chosen_move.1.mirror()) - } - } -} diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs index b9fa3b2..00293cb 100644 --- a/bot/src/strategy/mod.rs +++ b/bot/src/strategy/mod.rs @@ -1,6 +1,5 @@ pub mod client; pub mod default; -pub mod dqn; pub mod dqnburn; pub mod erroneous_moves; pub mod random; diff --git a/bot/src/training_common.rs b/bot/src/training_common.rs index 750b2ae..ee33d0c 100644 --- a/bot/src/training_common.rs +++ b/bot/src/training_common.rs @@ -1,5 +1,5 @@ -/// training_common_big.rs : environnement avec espace d'actions optimisé -/// (514 au lieu de 1252 pour training_common_big.rs) +/// training_common.rs : environnement avec espace d'actions optimisé +/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' ) use std::cmp::{max, min}; use std::fmt::{Debug, Display, Formatter}; diff --git a/bot/src/training_common_big.rs b/bot/src/training_common_big.rs deleted file mode 100644 index d7e5bf1..0000000 --- a/bot/src/training_common_big.rs +++ /dev/null @@ -1,268 +0,0 @@ -/// training_common_big.rs : environnement avec espace d'actions non optimisé -/// (1252 au lieu de 514 pour training_common.rs) -use std::cmp::{max, min}; - -use serde::{Deserialize, Serialize}; -use store::{CheckerMove, Dice}; - -/// Types d'actions possibles dans le jeu -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum TrictracAction { - /// Lancer les dés - Roll, - /// Continuer après avoir gagné un trou - Go, - /// Effectuer un mouvement de pions - Move { - dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier - from1: usize, // position de départ du premier pion (0-24) - from2: usize, // position de départ du deuxième pion (0-24) - }, - // Marquer les points : à activer si support des écoles - // Mark, -} - -impl TrictracAction { - /// Encode une action en index pour le réseau de neurones - pub fn to_action_index(&self) -> usize { - match self { - TrictracAction::Roll => 0, - TrictracAction::Go => 1, - TrictracAction::Move { - dice_order, - from1, - from2, - } => { - // Encoder les mouvements dans l'espace d'actions - // Indices 2+ pour les mouvements - // de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) - let mut start = 2; - if !dice_order { - // 25 * 25 = 625 - start += 625; - } - start + from1 * 25 + from2 - } // TrictracAction::Mark => 1252, - } - } - - /// Décode un index d'action en TrictracAction - pub fn from_action_index(index: usize) -> Option { - match index { - 0 => Some(TrictracAction::Roll), - // 1252 => Some(TrictracAction::Mark), - 1 => Some(TrictracAction::Go), - i if i >= 3 => { - let move_code = i - 3; - let (dice_order, from1, from2) = Self::decode_move(move_code); - Some(TrictracAction::Move { - dice_order, - from1, - from2, - }) - } - _ => None, - } - } - - /// Décode un 
entier en paire de mouvements - fn decode_move(code: usize) -> (bool, usize, usize) { - let mut encoded = code; - let dice_order = code < 626; - if !dice_order { - encoded -= 625 - } - let from1 = encoded / 25; - let from2 = 1 + encoded % 25; - (dice_order, from1, from2) - } - - /// Retourne la taille de l'espace d'actions total - pub fn action_space_size() -> usize { - // 1 (Roll) + 1 (Go) + mouvements possibles - // Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from) - // Mais on peut optimiser en limitant aux positions valides (1-24) - 2 + (2 * 25 * 25) // = 1252 - } - - // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { - // match action { - // TrictracAction::Roll => Some(GameEvent::Roll { player_id }), - // TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }), - // TrictracAction::Go => Some(GameEvent::Go { player_id }), - // TrictracAction::Move { - // dice_order, - // from1, - // from2, - // } => { - // // Effectuer un mouvement - // let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default(); - // let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default(); - // - // Some(GameEvent::Move { - // player_id: self.agent_player_id, - // moves: (checker_move1, checker_move2), - // }) - // } - // }; - // } -} - -/// Obtient les actions valides pour l'état de jeu actuel -pub fn get_valid_actions(game_state: &crate::GameState) -> Vec { - use store::TurnStage; - - let mut valid_actions = Vec::new(); - - let active_player_id = game_state.active_player_id; - let player_color = game_state.player_color_by_id(&active_player_id); - - if let Some(color) = player_color { - match game_state.turn_stage { - TurnStage::RollDice => { - valid_actions.push(TrictracAction::Roll); - } - TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => { - panic!( - "get_valid_actions not implemented for turn stage {:?}", - game_state.turn_stage - ); - // valid_actions.push(TrictracAction::Mark); - } - TurnStage::HoldOrGoChoice => { - valid_actions.push(TrictracAction::Go); - - // Ajoute aussi les mouvements possibles - let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); - let possible_moves = rules.get_possible_moves_sequences(true, vec![]); - - // Modififier checker_moves_to_trictrac_action si on doit gérer Black - assert_eq!(color, store::Color::White); - for (move1, move2) in possible_moves { - valid_actions.push(checker_moves_to_trictrac_action( - &move1, - &move2, - &game_state.dice, - )); - } - } - TurnStage::Move => { - let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); - let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]); - if possible_moves.is_empty() { - // Empty move - possible_moves.push((CheckerMove::default(), CheckerMove::default())); - } - - // Modififier checker_moves_to_trictrac_action si on doit gérer Black - assert_eq!(color, store::Color::White); - for (move1, move2) in possible_moves { - valid_actions.push(checker_moves_to_trictrac_action( - &move1, - &move2, - &game_state.dice, - )); - } - } - } - } - - if valid_actions.is_empty() { - panic!("empty valid_actions for state {game_state}"); - } - valid_actions -} - -// Valid only for White player -fn checker_moves_to_trictrac_action( - move1: &CheckerMove, - move2: &CheckerMove, - dice: &Dice, -) -> TrictracAction { - let to1 = move1.get_to(); - let to2 = move2.get_to(); - let from1 = move1.get_from(); - let from2 = move2.get_from(); 
- - let mut diff_move1 = if to1 > 0 { - // Mouvement sans sortie - to1 - from1 - } else { - // sortie, on utilise la valeur du dé - if to2 > 0 { - // sortie pour le mouvement 1 uniquement - let dice2 = to2 - from2; - if dice2 == dice.values.0 as usize { - dice.values.1 as usize - } else { - dice.values.0 as usize - } - } else { - // double sortie - if from1 < from2 { - max(dice.values.0, dice.values.1) as usize - } else { - min(dice.values.0, dice.values.1) as usize - } - } - }; - - // modification de diff_move1 si on est dans le cas d'un mouvement par puissance - let rest_field = 12; - if to1 == rest_field - && to2 == rest_field - && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field - { - // prise par puissance - diff_move1 += 1; - } - TrictracAction::Move { - dice_order: diff_move1 == dice.values.0 as usize, - from1: move1.get_from(), - from2: move2.get_from(), - } -} - -/// Retourne les indices des actions valides -pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec { - get_valid_actions(game_state) - .into_iter() - .map(|action| action.to_action_index()) - .collect() -} - -/// Sélectionne une action valide aléatoire -pub fn sample_valid_action(game_state: &crate::GameState) -> Option { - use rand::{seq::SliceRandom, thread_rng}; - - let valid_actions = get_valid_actions(game_state); - let mut rng = thread_rng(); - valid_actions.choose(&mut rng).cloned() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn to_action_index() { - let action = TrictracAction::Move { - dice_order: true, - from1: 3, - from2: 4, - }; - let index = action.to_action_index(); - assert_eq!(Some(action), TrictracAction::from_action_index(index)); - assert_eq!(81, index); - } - - #[test] - fn from_action_index() { - let action = TrictracAction::Move { - dice_order: true, - from1: 3, - from2: 4, - }; - assert_eq!(Some(action), TrictracAction::from_action_index(81)); - } -} diff --git a/client_bevy/.cargo/config.toml b/client_bevy/.cargo/config.toml deleted file mode 100644 index b6bc0d3..0000000 --- a/client_bevy/.cargo/config.toml +++ /dev/null @@ -1,8 +0,0 @@ -[target.x86_64-unknown-linux-gnu] -linker = "clang" -rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"] - -# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only' -# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains. 
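A note on the deleted training_common_big.rs just above: its encoder maps a Move to `2 + from1 * 25 + from2` (plus 625 when the dice are played in the other order), while its decoder subtracts 3 and computes `1 + encoded % 25`, so encoder and decoder only agree on part of the index range (the unit test's `Move { dice_order: true, from1: 3, from2: 4 }` happens to round-trip at index 81, but any move with `from2 == 0` does not). The following is a minimal sketch of a symmetric encoding for the same 1252-entry flat action space, for comparison only; `FlatAction`, `encode` and `decode` are illustrative names and are not part of the crate, whose kept training_common.rs uses the smaller 514-entry space.

// Illustrative sketch: symmetric flat encoding for 2 + 2*25*25 = 1252 actions.
#[derive(Debug, Clone, PartialEq)]
enum FlatAction {
    Roll,
    Go,
    // dice_order: true = play dice[0] first; from1/from2 in 0..=24 (0 = empty move)
    Move { dice_order: bool, from1: usize, from2: usize },
}

const MOVE_BASE: usize = 2;
const MOVES_PER_ORDER: usize = 25 * 25;

fn encode(a: &FlatAction) -> usize {
    match a {
        FlatAction::Roll => 0,
        FlatAction::Go => 1,
        FlatAction::Move { dice_order, from1, from2 } => {
            let order_offset = if *dice_order { 0 } else { MOVES_PER_ORDER };
            MOVE_BASE + order_offset + *from1 * 25 + *from2
        }
    }
}

fn decode(i: usize) -> Option<FlatAction> {
    match i {
        0 => Some(FlatAction::Roll),
        1 => Some(FlatAction::Go),
        i if i < MOVE_BASE + 2 * MOVES_PER_ORDER => {
            let code = i - MOVE_BASE;
            let dice_order = code < MOVES_PER_ORDER;
            let code = code % MOVES_PER_ORDER;
            Some(FlatAction::Move { dice_order, from1: code / 25, from2: code % 25 })
        }
        _ => None,
    }
}

fn main() {
    // Every index round-trips, including from2 == 0 and both dice orders.
    for i in 0..(MOVE_BASE + 2 * MOVES_PER_ORDER) {
        let a = decode(i).unwrap();
        assert_eq!(encode(&a), i);
    }
    println!("1252 actions round-trip");
}

With this layout, `Move { dice_order: true, from1: 3, from2: 4 }` still encodes to 81, matching the expectation in the deleted test.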
-#[profile.dev] -#debug = 1 diff --git a/client_bevy/Cargo.toml b/client_bevy/Cargo.toml deleted file mode 100644 index aaa6b7d..0000000 --- a/client_bevy/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "trictrac-client" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -anyhow = "1.0.75" -bevy = { version = "0.11.3" } -bevy_renet = "0.0.9" -bincode = "1.3.3" -renet = "0.0.13" -store = { path = "../store" } diff --git a/client_bevy/assets/Inconsolata.ttf b/client_bevy/assets/Inconsolata.ttf deleted file mode 100644 index 34848ca..0000000 Binary files a/client_bevy/assets/Inconsolata.ttf and /dev/null differ diff --git a/client_bevy/assets/board.png b/client_bevy/assets/board.png deleted file mode 100644 index 5d16ac3..0000000 Binary files a/client_bevy/assets/board.png and /dev/null differ diff --git a/client_bevy/assets/sound/click.wav b/client_bevy/assets/sound/click.wav deleted file mode 100644 index 8b6c99d..0000000 Binary files a/client_bevy/assets/sound/click.wav and /dev/null differ diff --git a/client_bevy/assets/sound/throw.wav b/client_bevy/assets/sound/throw.wav deleted file mode 100755 index cb5e438..0000000 Binary files a/client_bevy/assets/sound/throw.wav and /dev/null differ diff --git a/client_bevy/assets/tac.png b/client_bevy/assets/tac.png deleted file mode 100644 index 2c18813..0000000 Binary files a/client_bevy/assets/tac.png and /dev/null differ diff --git a/client_bevy/assets/tic.png b/client_bevy/assets/tic.png deleted file mode 100644 index 786e0c7..0000000 Binary files a/client_bevy/assets/tic.png and /dev/null differ diff --git a/client_bevy/src/main.rs b/client_bevy/src/main.rs deleted file mode 100644 index 504602e..0000000 --- a/client_bevy/src/main.rs +++ /dev/null @@ -1,334 +0,0 @@ -use std::{net::UdpSocket, time::SystemTime}; - -use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES}; -use store::{GameEvent, GameState, CheckerMove}; - -use bevy::prelude::*; -use bevy::window::PrimaryWindow; -use bevy_renet::{ - renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient}, - transport::{client_connected, NetcodeClientPlugin}, - RenetClientPlugin, -}; - -#[derive(Debug, Resource)] -struct CurrentClientId(u64); - -#[derive(Resource)] -struct BevyGameState(GameState); - -impl Default for BevyGameState { - fn default() -> Self { - Self { - 0: GameState::default(), - } - } -} - -#[derive(Resource, Deref, DerefMut)] -struct GameUIState { - selected_tile: Option, -} - -impl Default for GameUIState { - fn default() -> Self { - Self { - selected_tile: None, - } - } -} - -#[derive(Event)] -struct BevyGameEvent(GameEvent); - -// This id needs to be the same as the server is using -const PROTOCOL_ID: u64 = 2878; - -fn main() { - // Get username from stdin args - let args = std::env::args().collect::>(); - let username = &args[1]; - - let (client, transport, client_id) = new_renet_client(&username).unwrap(); - App::new() - // Lets add a nice dark grey background color - .insert_resource(ClearColor(Color::hex("282828").unwrap())) - .add_plugins(DefaultPlugins.set(WindowPlugin { - primary_window: Some(Window { - // Adding the username to the window title makes debugging a whole lot easier. 
- title: format!("TricTrac <{}>", username), - resolution: (1080.0, 1080.0).into(), - ..default() - }), - ..default() - })) - // Add our game state and register GameEvent as a bevy event - .insert_resource(BevyGameState::default()) - .insert_resource(GameUIState::default()) - .add_event::() - // Renet setup - .add_plugins(RenetClientPlugin) - .add_plugins(NetcodeClientPlugin) - .insert_resource(client) - .insert_resource(transport) - .insert_resource(CurrentClientId(client_id)) - .add_systems(Startup, setup) - .add_systems(Update, (update_waiting_text, input, update_board, panic_on_error_system)) - .add_systems( - PostUpdate, - receive_events_from_server.run_if(client_connected()), - ) - .run(); -} - -////////// COMPONENTS ////////// -#[derive(Component)] -struct UIRoot; - -#[derive(Component)] -struct WaitingText; - -#[derive(Component)] -struct Board { - squares: [Square; 26] -} - -impl Default for Board { - fn default() -> Self { - Self { - squares: [Square { count: 0, color: None, position: 0}; 26] - } - } -} - -impl Board { - fn square_at(&self, position: usize) -> Square { - self.squares[position] - } -} - -#[derive(Component, Clone, Copy)] -struct Square { - count: usize, - color: Option, - position: usize, -} - -////////// UPDATE SYSTEMS ////////// -fn update_board( - mut commands: Commands, - game_state: Res, - mut game_events: EventReader, - asset_server: Res, -) { - for event in game_events.iter() { - match event.0 { - GameEvent::Move { player_id, moves } => { - // trictrac positions, TODO : dépend de player_id - let (x, y) = if moves.0.get_to() < 13 { (13 - moves.0.get_to(), 1) } else { (moves.0.get_to() - 13, 0)}; - let texture = - asset_server.load(match game_state.0.players[&player_id].color { - store::Color::Black => "tac.png", - store::Color::White => "tic.png", - }); - - info!("spawning tictac sprite"); - commands.spawn(SpriteBundle { - transform: Transform::from_xyz( - 83.0 * (x as f32 - 1.0), - -30.0 + 540.0 * (y as f32 - 1.0), - 0.0, - ), - sprite: Sprite { - custom_size: Some(Vec2::new(83.0, 83.0)), - ..default() - }, - texture: texture.into(), - ..default() - }); - } - _ => {} - } - } -} - -fn update_waiting_text(mut text_query: Query<&mut Text, With>, time: Res