diff --git a/Cargo.lock b/Cargo.lock index a71f75a..270eb15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -834,7 +834,7 @@ dependencies = [ "derive-new", "log", "nvml-wrapper", - "ratatui", + "ratatui 0.29.0", "rstest", "serde", "sysinfo", @@ -1066,6 +1066,17 @@ dependencies = [ "store", ] +[[package]] +name = "client_tui" +version = "0.1.0" +dependencies = [ + "anyhow", + "bincode 1.3.3", + "crossterm", + "ratatui 0.28.1", + "store", +] + [[package]] name = "cmake" version = "0.1.54" @@ -4403,6 +4414,27 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" +[[package]] +name = "ratatui" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d" +dependencies = [ + "bitflags 2.9.4", + "cassowary", + "compact_str", + "crossterm", + "instability", + "itertools 0.13.0", + "lru", + "paste", + "strum 0.26.3", + "strum_macros 0.26.4", + "unicode-segmentation", + "unicode-truncate", + "unicode-width 0.1.14", +] + [[package]] name = "ratatui" version = "0.29.0" @@ -5781,6 +5813,18 @@ dependencies = [ "strength_reduce", ] +[[package]] +name = "trictrac-server" +version = "0.1.0" +dependencies = [ + "bincode 1.3.3", + "env_logger 0.10.2", + "log", + "pico-args", + "renet", + "store", +] + [[package]] name = "tungstenite" version = "0.26.2" diff --git a/Cargo.toml b/Cargo.toml index b9e6d45..6068644 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ [workspace] resolver = "2" -members = ["client_cli", "bot", "store"] +members = ["client_tui", "client_cli", "bot", "server", "store"] diff --git a/bot/src/burnrl/algos/dqn_big.rs b/bot/src/burnrl/algos/dqn_big.rs new file mode 100644 index 0000000..7e8951f --- /dev/null +++ b/bot/src/burnrl/algos/dqn_big.rs @@ -0,0 +1,194 @@ +use crate::burnrl::environment_big::TrictracEnvironment; +use crate::burnrl::utils::{soft_update_linear, Config}; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::activation::relu; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::DQN; +use burn_rl::agent::{DQNModel, DQNTrainingConfig}; +use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Net { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + relu(self.linear_2.forward(layer_1_output)) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl DQNModel for Net { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + 
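+        // Soft (Polyak) update of the target network: for each parameter,
+        // target <- tau * policy + (1 - tau) * target. With a small tau the
+        // target net trails the policy net slowly, which is what stabilises
+        // DQN training (see soft_update_linear in burnrl/utils.rs).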
let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 8192; + +type MyAgent = DQN>; + +#[allow(unused)] +// pub fn run, B: AutodiffBackend>( +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( + conf: &Config, + visualized: bool, + // ) -> DQN> { +) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().max_steps = conf.max_steps; + + let model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + let mut agent = MyAgent::new(model); + + // let config = DQNTrainingConfig::default(); + let config = DQNTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + + let mut policy_net = agent.model().as_ref().unwrap().clone(); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward: ElemType = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + let eps_threshold = conf.eps_end + + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay); + let action = + DQN::>::react_with_exploration(&policy_net, state, eps_threshold); + let snapshot = env.step(action); + + episode_reward += + <::RewardType as Into>::into(snapshot.reward().clone()); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + policy_net = + agent.train::(policy_net, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + let envmut = env.as_mut(); + let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32) + * 100.0) + .round() as u32; + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}", + envmut.goodmoves_count, + goodmoves_ratio, + envmut.pointrolls_count, + now.elapsed().unwrap().as_secs(), + ); + env.reset(); + episode_done = true; + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + let valid_agent = agent.valid(); + if let Some(path) = &conf.save_path { + save_model(valid_agent.model().as_ref().unwrap(), path); + } + valid_agent +} + +pub fn save_model(model: &Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Net::new( + ::StateType::size(), + 
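+                // Note: dense_size must match the value used at training
+                // time; the record stores only the weights, not the
+                // architecture, so loading into a net of a different width
+                // will fail.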
dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} diff --git a/bot/src/burnrl/algos/mod.rs b/bot/src/burnrl/algos/mod.rs index 5a67dfc..af13327 100644 --- a/bot/src/burnrl/algos/mod.rs +++ b/bot/src/burnrl/algos/mod.rs @@ -1,6 +1,9 @@ pub mod dqn; +pub mod dqn_big; pub mod dqn_valid; pub mod ppo; +pub mod ppo_big; pub mod ppo_valid; pub mod sac; +pub mod sac_big; pub mod sac_valid; diff --git a/bot/src/burnrl/algos/ppo_big.rs b/bot/src/burnrl/algos/ppo_big.rs new file mode 100644 index 0000000..ab860ee --- /dev/null +++ b/bot/src/burnrl/algos/ppo_big.rs @@ -0,0 +1,191 @@ +use crate::burnrl::environment_big::TrictracEnvironment; +use crate::burnrl::utils::Config; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; +use burn::module::Module; +use burn::nn::{Initializer, Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::activation::{relu, softmax}; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO}; +use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; +use std::env; +use std::fs; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Net { + linear: Linear, + linear_actor: Linear, + linear_critic: Linear, +} + +impl Net { + #[allow(unused)] + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + let initializer = Initializer::XavierUniform { gain: 1.0 }; + Self { + linear: LinearConfig::new(input_size, dense_size) + .with_initializer(initializer.clone()) + .init(&Default::default()), + linear_actor: LinearConfig::new(dense_size, output_size) + .with_initializer(initializer.clone()) + .init(&Default::default()), + linear_critic: LinearConfig::new(dense_size, 1) + .with_initializer(initializer) + .init(&Default::default()), + } + } +} + +impl Model, PPOOutput, Tensor> for Net { + fn forward(&self, input: Tensor) -> PPOOutput { + let layer_0_output = relu(self.linear.forward(input)); + let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1); + let values = self.linear_critic.forward(layer_0_output); + + PPOOutput::::new(policies, values) + } + + fn infer(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear.forward(input)); + softmax(self.linear_actor.forward(layer_0_output.clone()), 1) + } +} + +impl PPOModel for Net {} +#[allow(unused)] +const MEMORY_SIZE: usize = 512; + +type MyAgent = PPO>; + +#[allow(unused)] +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( + conf: &Config, + visualized: bool, + // ) -> PPO> { +) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().max_steps = conf.max_steps; + + let mut model = Net::::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + let agent = MyAgent::default(); + let config = PPOTrainingConfig { + gamma: conf.gamma, + lambda: conf.lambda, + epsilon_clip: conf.epsilon_clip, + critic_weight: conf.critic_weight, + entropy_weight: conf.entropy_weight, + learning_rate: conf.learning_rate, + epochs: conf.epochs, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut optimizer = AdamWConfig::new() + .with_grad_clipping(config.clip_grad.clone()) + .init(); + let mut memory = Memory::::default(); + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut 
episode_reward = 0.0; + let mut episode_duration = 0_usize; + let mut now = SystemTime::now(); + + env.reset(); + while !episode_done { + let state = env.state(); + if let Some(action) = MyAgent::::react_with_model(&state, &model) { + let snapshot = env.step(action); + episode_reward += <::RewardType as Into>::into( + snapshot.reward().clone(), + ); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + episode_duration += 1; + episode_done = snapshot.done() || episode_duration >= conf.max_steps; + } + } + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}", + now.elapsed().unwrap().as_secs(), + ); + + now = SystemTime::now(); + model = MyAgent::train::(model, &memory, &mut optimizer, &config); + memory.clear(); + } + + if let Some(path) = &conf.save_path { + let device = NdArrayDevice::default(); + let recorder = CompactRecorder::new(); + let tmp_path = env::temp_dir().join("tmp_model.mpk"); + + // Save the trained model (backend B) to a temporary file + recorder + .record(model.clone().into_record(), tmp_path.clone()) + .expect("Failed to save temporary model"); + + // Create a new model instance with the target backend (NdArray) + let model_to_save: Net> = Net::new( + <::StateType as State>::size(), + conf.dense_size, + <::ActionType as Action>::size(), + ); + + // Load the record from the temporary file into the new model + let record = recorder + .load(tmp_path.clone(), &device) + .expect("Failed to load temporary model"); + let model_with_loaded_weights = model_to_save.load_record(record); + + // Clean up the temporary file + fs::remove_file(tmp_path).expect("Failed to remove temporary model file"); + + save_model(&model_with_loaded_weights, path); + } + agent.valid(model) +} + +pub fn save_model(model: &Net>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Net::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} diff --git a/bot/src/burnrl/algos/sac_big.rs b/bot/src/burnrl/algos/sac_big.rs new file mode 100644 index 0000000..1361b42 --- /dev/null +++ b/bot/src/burnrl/algos/sac_big.rs @@ -0,0 +1,222 @@ +use crate::burnrl::environment_big::TrictracEnvironment; +use crate::burnrl::utils::{soft_update_linear, Config}; +use burn::backend::{ndarray::NdArrayDevice, NdArray}; +use burn::module::Module; +use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamWConfig; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::activation::{relu, softmax}; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::Tensor; +use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC}; +use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State}; +use std::time::SystemTime; + +#[derive(Module, Debug)] +pub struct Actor { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Actor { + pub fn new(input_size: usize, dense_size: usize, output_size: 
usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } +} + +impl Model, Tensor> for Actor { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + softmax(self.linear_2.forward(layer_1_output), 1) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl SACActor for Actor {} + +#[derive(Module, Debug)] +pub struct Critic { + linear_0: Linear, + linear_1: Linear, + linear_2: Linear, +} + +impl Critic { + pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self { + Self { + linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()), + linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()), + linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()), + } + } + + fn consume(self) -> (Linear, Linear, Linear) { + (self.linear_0, self.linear_1, self.linear_2) + } +} + +impl Model, Tensor> for Critic { + fn forward(&self, input: Tensor) -> Tensor { + let layer_0_output = relu(self.linear_0.forward(input)); + let layer_1_output = relu(self.linear_1.forward(layer_0_output)); + + self.linear_2.forward(layer_1_output) + } + + fn infer(&self, input: Tensor) -> Tensor { + self.forward(input) + } +} + +impl SACCritic for Critic { + fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self { + let (linear_0, linear_1, linear_2) = this.consume(); + + Self { + linear_0: soft_update_linear(linear_0, &that.linear_0, tau), + linear_1: soft_update_linear(linear_1, &that.linear_1, tau), + linear_2: soft_update_linear(linear_2, &that.linear_2, tau), + } + } +} + +#[allow(unused)] +const MEMORY_SIZE: usize = 4096; + +type MyAgent = SAC>; + +#[allow(unused)] +pub fn run< + E: Environment + AsMut, + B: AutodiffBackend, +>( + conf: &Config, + visualized: bool, +) -> impl Agent { + let mut env = E::new(visualized); + env.as_mut().max_steps = conf.max_steps; + let state_dim = <::StateType as State>::size(); + let action_dim = <::ActionType as Action>::size(); + + let actor = Actor::::new(state_dim, conf.dense_size, action_dim); + let critic_1 = Critic::::new(state_dim, conf.dense_size, action_dim); + let critic_2 = Critic::::new(state_dim, conf.dense_size, action_dim); + let mut nets = SACNets::, Critic>::new(actor, critic_1, critic_2); + + let mut agent = MyAgent::default(); + + let config = SACTrainingConfig { + gamma: conf.gamma, + tau: conf.tau, + learning_rate: conf.learning_rate, + min_probability: conf.min_probability, + batch_size: conf.batch_size, + clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value( + conf.clip_grad, + )), + }; + + let mut memory = Memory::::default(); + + let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone()); + + let mut optimizer = SACOptimizer::new( + optimizer_config.clone().init(), + optimizer_config.clone().init(), + optimizer_config.clone().init(), + optimizer_config.init(), + ); + + let mut step = 0_usize; + + for episode in 0..conf.num_episodes { + let mut episode_done = false; + let mut episode_reward = 0.0; + let mut episode_duration = 0_usize; + let mut state = env.state(); + let mut now = SystemTime::now(); + + while !episode_done { + if let Some(action) = 
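+                // react_with_model samples the next action from the actor;
+                // when it yields None, nothing advances and this while loop
+                // simply retries with the same state.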
MyAgent::::react_with_model(&state, &nets.actor) { + let snapshot = env.step(action); + + episode_reward += <::RewardType as Into>::into( + snapshot.reward().clone(), + ); + + memory.push( + state, + *snapshot.state(), + action, + snapshot.reward().clone(), + snapshot.done(), + ); + + if config.batch_size < memory.len() { + nets = agent.train::(nets, &memory, &mut optimizer, &config); + } + + step += 1; + episode_duration += 1; + + if snapshot.done() || episode_duration >= conf.max_steps { + env.reset(); + episode_done = true; + + println!( + "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}", + now.elapsed().unwrap().as_secs() + ); + now = SystemTime::now(); + } else { + state = *snapshot.state(); + } + } + } + } + + let valid_agent = agent.valid(nets.actor); + if let Some(path) = &conf.save_path { + if let Some(model) = valid_agent.model() { + save_model(model, path); + } + } + valid_agent +} + +pub fn save_model(model: &Actor>, path: &String) { + let recorder = CompactRecorder::new(); + let model_path = format!("{path}.mpk"); + println!("info: Modèle de validation sauvegardé : {model_path}"); + recorder + .record(model.clone().into_record(), model_path.into()) + .unwrap(); +} + +pub fn load_model(dense_size: usize, path: &String) -> Option>> { + let model_path = format!("{path}.mpk"); + // println!("Chargement du modèle depuis : {model_path}"); + + CompactRecorder::new() + .load(model_path.into(), &NdArrayDevice::default()) + .map(|record| { + Actor::new( + ::StateType::size(), + dense_size, + ::ActionType::size(), + ) + .load_record(record) + }) + .ok() +} + diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs index 84c8311..50daf11 100644 --- a/bot/src/burnrl/environment.rs +++ b/bot/src/burnrl/environment.rs @@ -6,10 +6,10 @@ use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; -const ERROR_REWARD: f32 = -1.0012121; -const REWARD_VALID_MOVE: f32 = 1.0012121; -const REWARD_RATIO: f32 = 0.1; -const WIN_POINTS: f32 = 100.0; +const ERROR_REWARD: f32 = -1.12121; +const REWARD_VALID_MOVE: f32 = 1.12121; +const REWARD_RATIO: f32 = 0.01; +const WIN_POINTS: f32 = 1.0; /// État du jeu Trictrac pour burn-rl #[derive(Debug, Clone, Copy)] @@ -285,7 +285,7 @@ impl TrictracEnvironment { if let Some(event) = action.to_event(&self.game) { if self.game.validate(&event) { self.game.consume(&event); - // reward += REWARD_VALID_MOVE; + reward += REWARD_VALID_MOVE; // Simuler le résultat des dés après un Roll if matches!(action, TrictracAction::Roll) { let mut rng = thread_rng(); @@ -312,11 +312,9 @@ impl TrictracEnvironment { // on annule les précédents reward // et on indique une valeur reconnaissable pour statistiques reward = ERROR_REWARD; - self.game.mark_points_for_bot_training(self.opponent_id, 1); } } else { reward = ERROR_REWARD; - self.game.mark_points_for_bot_training(self.opponent_id, 1); } (reward, is_rollpoint) diff --git a/bot/src/burnrl/environment_big.rs b/bot/src/burnrl/environment_big.rs new file mode 100644 index 0000000..1bba2bd --- /dev/null +++ b/bot/src/burnrl/environment_big.rs @@ -0,0 +1,469 @@ +use crate::training_common_big; +use burn::{prelude::Backend, tensor::Tensor}; +use burn_rl::base::{Action, Environment, Snapshot, State}; +use rand::{thread_rng, Rng}; +use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; + +const ERROR_REWARD: f32 = -2.12121; +const 
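+// The distinctive decimals make these rewards easy to recognise in the
+// training statistics; ERROR_REWARD also serves as a sentinel for invalid
+// actions (any reward accumulated in the step is discarded and replaced by
+// it in execute_action).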
REWARD_VALID_MOVE: f32 = 2.12121; +const REWARD_RATIO: f32 = 0.01; +const WIN_POINTS: f32 = 0.1; + +/// État du jeu Trictrac pour burn-rl +#[derive(Debug, Clone, Copy)] +pub struct TrictracState { + pub data: [i8; 36], // Représentation vectorielle de l'état du jeu +} + +impl State for TrictracState { + type Data = [i8; 36]; + + fn to_tensor(&self) -> Tensor { + Tensor::from_floats(self.data, &B::Device::default()) + } + + fn size() -> usize { + 36 + } +} + +impl TrictracState { + /// Convertit un GameState en TrictracState + pub fn from_game_state(game_state: &GameState) -> Self { + let state_vec = game_state.to_vec(); + let mut data = [0; 36]; + + // Copier les données en s'assurant qu'on ne dépasse pas la taille + let copy_len = state_vec.len().min(36); + data[..copy_len].copy_from_slice(&state_vec[..copy_len]); + + TrictracState { data } + } +} + +/// Actions possibles dans Trictrac pour burn-rl +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TrictracAction { + // u32 as required by burn_rl::base::Action type + pub index: u32, +} + +impl Action for TrictracAction { + fn random() -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + TrictracAction { + index: rng.gen_range(0..Self::size() as u32), + } + } + + fn enumerate() -> Vec { + (0..Self::size() as u32) + .map(|index| TrictracAction { index }) + .collect() + } + + fn size() -> usize { + 1252 + } +} + +impl From for TrictracAction { + fn from(index: u32) -> Self { + TrictracAction { index } + } +} + +impl From for u32 { + fn from(action: TrictracAction) -> u32 { + action.index + } +} + +/// Environnement Trictrac pour burn-rl +#[derive(Debug)] +pub struct TrictracEnvironment { + pub game: GameState, + active_player_id: PlayerId, + opponent_id: PlayerId, + current_state: TrictracState, + episode_reward: f32, + pub step_count: usize, + pub max_steps: usize, + pub pointrolls_count: usize, + pub goodmoves_count: usize, + pub goodmoves_ratio: f32, + pub visualized: bool, +} + +impl Environment for TrictracEnvironment { + type StateType = TrictracState; + type ActionType = TrictracAction; + type RewardType = f32; + + fn new(visualized: bool) -> Self { + let mut game = GameState::new(false); + + // Ajouter deux joueurs + game.init_player("DQN Agent"); + game.init_player("Opponent"); + let player1_id = 1; + let player2_id = 2; + + // Commencer la partie + game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + let current_state = TrictracState::from_game_state(&game); + TrictracEnvironment { + game, + active_player_id: player1_id, + opponent_id: player2_id, + current_state, + episode_reward: 0.0, + step_count: 0, + max_steps: 2000, + pointrolls_count: 0, + goodmoves_count: 0, + goodmoves_ratio: 0.0, + visualized, + } + } + + fn state(&self) -> Self::StateType { + self.current_state + } + + fn reset(&mut self) -> Snapshot { + // Réinitialiser le jeu + self.game = GameState::new(false); + self.game.init_player("DQN Agent"); + self.game.init_player("Opponent"); + + // Commencer la partie + self.game.consume(&GameEvent::BeginGame { goes_first: 1 }); + + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward = 0.0; + self.goodmoves_ratio = if self.step_count == 0 { + 0.0 + } else { + self.goodmoves_count as f32 / self.step_count as f32 + }; + println!( + "info: correct moves: {} ({}%)", + self.goodmoves_count, + (100.0 * self.goodmoves_ratio).round() as u32 + ); + self.step_count = 0; + self.pointrolls_count = 0; + self.goodmoves_count = 0; + + Snapshot::new(self.current_state, 0.0, 
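+        // a fresh episode: zero reward, not terminal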
false) + } + + fn step(&mut self, action: Self::ActionType) -> Snapshot { + self.step_count += 1; + + // Convertir l'action burn-rl vers une action Trictrac + let trictrac_action = Self::convert_action(action); + + let mut reward = 0.0; + let is_rollpoint; + + // Exécuter l'action si c'est le tour de l'agent DQN + if self.game.active_player_id == self.active_player_id { + if let Some(action) = trictrac_action { + (reward, is_rollpoint) = self.execute_action(action); + if is_rollpoint { + self.pointrolls_count += 1; + } + if reward != ERROR_REWARD { + self.goodmoves_count += 1; + // println!("{str_action}"); + } + } else { + // Action non convertible, pénalité + reward = -0.5; + } + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // print!(":"); + reward += self.play_opponent_if_needed(); + } + + // Vérifier si la partie est terminée + // let max_steps = self.max_steps + // let max_steps = self.min_steps + // + (self.max_steps as f32 - self.min_steps) + // * f32::exp((self.goodmoves_ratio - 1.0) / 0.25); + let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some(); + + if done { + // Récompense finale basée sur le résultat + if let Some(winner_id) = self.game.determine_winner() { + if winner_id == self.active_player_id { + reward += WIN_POINTS; // Victoire + } else { + reward -= WIN_POINTS; // Défaite + } + } + } + let terminated = done || self.step_count >= self.max_steps; + + // Mettre à jour l'état + self.current_state = TrictracState::from_game_state(&self.game); + self.episode_reward += reward; + if self.visualized && terminated { + println!( + "Episode terminé. Récompense totale: {:.2}, Étapes: {}", + self.episode_reward, self.step_count + ); + } + + Snapshot::new(self.current_state, reward, terminated) + } +} + +impl TrictracEnvironment { + /// Convertit une action burn-rl vers une action Trictrac + pub fn convert_action(action: TrictracAction) -> Option { + training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) + } + + /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac + #[allow(dead_code)] + fn convert_valid_action_index( + &self, + action: TrictracAction, + game_state: &GameState, + ) -> Option { + use training_common_big::get_valid_actions; + + // Obtenir les actions valides dans le contexte actuel + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + return None; + } + + // Mapper l'index d'action sur une action valide + let action_index = (action.index as usize) % valid_actions.len(); + Some(valid_actions[action_index].clone()) + } + + /// Exécute une action Trictrac dans le jeu + // fn execute_action( + // &mut self, + // action:training_common_big::TrictracAction, + // ) -> Result> { + fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) { + use training_common_big::TrictracAction; + + let mut reward = 0.0; + let mut is_rollpoint = false; + let mut need_roll = false; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + need_roll = true; + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. 
+ // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game.validate(&event) { + self.game.consume(&event); + reward += REWARD_VALID_MOVE; + // Simuler le résultat des dés après un Roll + // if matches!(action, TrictracAction::Roll) { + if need_roll { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.active_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + // print!("o"); + if self.game.validate(&dice_event) { + self.game.consume(&dice_event); + let (points, adv_points) = self.game.dice_points; + reward += REWARD_RATIO * (points - adv_points) as f32; + if points > 0 { + is_rollpoint = true; + // println!("info: rolled for {reward}"); + } + // Récompense proportionnelle aux points + } + } + } else { + // Pénalité pour action invalide + // on annule les précédents reward + // et on indique une valeur reconnaissable pour statistiques + reward = ERROR_REWARD; + } + } + + (reward, is_rollpoint) + } + + /// Fait jouer l'adversaire avec une stratégie simple + fn play_opponent_if_needed(&mut self) -> f32 { + // print!("z?"); + let mut reward = 0.0; + + // Si c'est le tour de l'adversaire, jouer automatiquement + if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // Utiliser la stratégie default pour l'adversaire + use crate::BotStrategy; + + let mut strategy = crate::strategy::random::RandomStrategy::default(); + strategy.set_player_id(self.opponent_id); + if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { + strategy.set_color(color); + } + *strategy.get_mut_game() = self.game.clone(); + + // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; + let event = match self.game.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + calculate_points = true; // comment to replicate burnrl_before + GameEvent::RollResult { + player_id: self.opponent_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkPoints => { + panic!("in play_opponent_if_needed > TurnStage::MarkPoints"); + // let dice_roll_count = self + // .game + // .players + // .get(&self.opponent_id) + // .unwrap() + // .dice_roll_count; + // let points_rules = + // 
PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // GameEvent::Mark { + // player_id: self.opponent_id, + // points: points_rules.get_points(dice_roll_count).0, + // } + } + TurnStage::MarkAdvPoints => { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + // pas de reward : déjà comptabilisé lors du tour de blanc + GameEvent::Mark { + player_id: self.opponent_id, + points: points_rules.get_points(dice_roll_count).1, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_id, + } + } + TurnStage::Move => GameEvent::Move { + player_id: self.opponent_id, + moves: strategy.choose_move(), + }, + }; + + if self.game.validate(&event) { + self.game.consume(&event); + // print!("."); + if calculate_points { + // print!("x"); + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + let adv_reward = REWARD_RATIO * (points - adv_points) as f32; + reward -= adv_reward; + // if adv_reward != 0.0 { + // println!("info: opponent : {adv_reward} -> {reward}"); + // } + } + } + } + reward + } +} + +impl AsMut for TrictracEnvironment { + fn as_mut(&mut self) -> &mut Self { + self + } +} diff --git a/bot/src/burnrl/environment_valid.rs b/bot/src/burnrl/environment_valid.rs index 9c27af9..346044c 100644 --- a/bot/src/burnrl/environment_valid.rs +++ b/bot/src/burnrl/environment_valid.rs @@ -1,12 +1,9 @@ -use crate::training_common; +use crate::training_common_big; use burn::{prelude::Backend, tensor::Tensor}; use burn_rl::base::{Action, Environment, Snapshot, State}; use rand::{thread_rng, Rng}; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; -const ERROR_REWARD: f32 = -1.0012121; -const REWARD_RATIO: f32 = 0.1; - /// État du jeu Trictrac pour burn-rl #[derive(Debug, Clone, Copy)] pub struct TrictracState { @@ -217,16 +214,16 @@ impl TrictracEnvironment { const REWARD_RATIO: f32 = 1.0; /// Convertit une action burn-rl vers une action Trictrac - pub fn convert_action(action: TrictracAction) -> Option { - training_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) + pub fn convert_action(action: TrictracAction) -> Option { + training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap()) } /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac fn convert_valid_action_index( &self, action: TrictracAction, - ) -> Option { - use training_common::get_valid_actions; + ) -> Option { + use training_common_big::get_valid_actions; // Obtenir les actions valides dans le contexte actuel let valid_actions = get_valid_actions(&self.game); @@ -243,19 +240,72 @@ impl TrictracEnvironment { /// Exécute une action Trictrac dans le jeu // fn execute_action( // &mut self, - // action: training_common::TrictracAction, + // action: training_common_big::TrictracAction, // ) -> Result> { - fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) { - use training_common::TrictracAction; + fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) { + use 
training_common_big::TrictracAction; let mut reward = 0.0; let mut is_rollpoint = false; + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + Some(GameEvent::Roll { + player_id: self.active_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.active_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + Some(GameEvent::Go { + player_id: self.active_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game.dice.values.0, self.game.dice.values.1) + } else { + (self.game.dice.values.1, self.game.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + Some(GameEvent::Move { + player_id: self.active_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + // Appliquer l'événement si valide - if let Some(event) = action.to_event(&self.game) { + if let Some(event) = event { if self.game.validate(&event) { self.game.consume(&event); - // reward += REWARD_VALID_MOVE; + // Simuler le résultat des dés après un Roll if matches!(action, TrictracAction::Roll) { let mut rng = thread_rng(); @@ -269,7 +319,7 @@ impl TrictracEnvironment { if self.game.validate(&dice_event) { self.game.consume(&dice_event); let (points, adv_points) = self.game.dice_points; - reward += REWARD_RATIO * (points as f32 - adv_points as f32); + reward += Self::REWARD_RATIO * (points - adv_points) as f32; if points > 0 { is_rollpoint = true; // println!("info: rolled for {reward}"); @@ -281,12 +331,9 @@ impl TrictracEnvironment { // Pénalité pour action invalide // on annule les précédents reward // et on indique une valeur reconnaissable pour statistiques - reward = ERROR_REWARD; - self.game.mark_points_for_bot_training(self.opponent_id, 1); + println!("info: action invalide -> err_reward"); + reward = Self::ERROR_REWARD; } - } else { - reward = ERROR_REWARD; - self.game.mark_points_for_bot_training(self.opponent_id, 1); } (reward, is_rollpoint) diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index 5230ec0..f7608a3 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,5 +1,8 @@ -use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid}; +use bot::burnrl::algos::{ + dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid, +}; use bot::burnrl::environment::TrictracEnvironment; +use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig; use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid; use bot::burnrl::utils::{demo_model, Config}; use burn::backend::{Autodiff, NdArray}; @@ -33,6 +36,16 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } + "dqn_big" => { + let _agent = dqn_big::run::(&conf, false); + println!("> Chargement du modèle pour test"); + let loaded_model = dqn_big::load_model(conf.dense_size, &path); + let loaded_agent: burn_rl::agent::DQN = + burn_rl::agent::DQN::new(loaded_model.unwrap()); 
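+            // NB: load_model returns None when the .mpk file is missing or
+            // unreadable, so this unwrap panics in that case; the other
+            // algorithm arms below share the same pattern.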
+ + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); + } "dqn_valid" => { let _agent = dqn_valid::run::(&conf, false); println!("> Chargement du modèle pour test"); @@ -53,6 +66,16 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } + "sac_big" => { + let _agent = sac_big::run::(&conf, false); + println!("> Chargement du modèle pour test"); + let loaded_model = sac_big::load_model(conf.dense_size, &path); + let loaded_agent: burn_rl::agent::SAC = + burn_rl::agent::SAC::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); + } "sac_valid" => { let _agent = sac_valid::run::(&conf, false); println!("> Chargement du modèle pour test"); @@ -73,6 +96,16 @@ fn main() { println!("> Test avec le modèle chargé"); demo_model(loaded_agent); } + "ppo_big" => { + let _agent = ppo_big::run::(&conf, false); + println!("> Chargement du modèle pour test"); + let loaded_model = ppo_big::load_model(conf.dense_size, &path); + let loaded_agent: burn_rl::agent::PPO = + burn_rl::agent::PPO::new(loaded_model.unwrap()); + + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); + } "ppo_valid" => { let _agent = ppo_valid::run::(&conf, false); println!("> Chargement du modèle pour test"); diff --git a/bot/src/burnrl/mod.rs b/bot/src/burnrl/mod.rs index 292bbb8..62bebc8 100644 --- a/bot/src/burnrl/mod.rs +++ b/bot/src/burnrl/mod.rs @@ -1,4 +1,5 @@ pub mod algos; pub mod environment; +pub mod environment_big; pub mod environment_valid; pub mod utils; diff --git a/bot/src/dqn_simple/dqn_model.rs b/bot/src/dqn_simple/dqn_model.rs new file mode 100644 index 0000000..9c31f44 --- /dev/null +++ b/bot/src/dqn_simple/dqn_model.rs @@ -0,0 +1,153 @@ +use crate::training_common_big::TrictracAction; +use serde::{Deserialize, Serialize}; + +/// Configuration pour l'agent DQN +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DqnConfig { + pub state_size: usize, + pub hidden_size: usize, + pub num_actions: usize, + pub learning_rate: f64, + pub gamma: f64, + pub epsilon: f64, + pub epsilon_decay: f64, + pub epsilon_min: f64, + pub replay_buffer_size: usize, + pub batch_size: usize, +} + +impl Default for DqnConfig { + fn default() -> Self { + Self { + state_size: 36, + hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi + num_actions: TrictracAction::action_space_size(), + learning_rate: 0.001, + gamma: 0.99, + epsilon: 0.1, + epsilon_decay: 0.995, + epsilon_min: 0.01, + replay_buffer_size: 10000, + batch_size: 32, + } + } +} + +/// Réseau de neurones DQN simplifié (matrice de poids basique) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SimpleNeuralNetwork { + pub weights1: Vec>, + pub biases1: Vec, + pub weights2: Vec>, + pub biases2: Vec, + pub weights3: Vec>, + pub biases3: Vec, +} + +impl SimpleNeuralNetwork { + pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self { + use rand::{thread_rng, Rng}; + let mut rng = thread_rng(); + + // Initialisation aléatoire des poids avec Xavier/Glorot + let scale1 = (2.0 / input_size as f32).sqrt(); + let weights1 = (0..hidden_size) + .map(|_| { + (0..input_size) + .map(|_| rng.gen_range(-scale1..scale1)) + .collect() + }) + .collect(); + let biases1 = vec![0.0; hidden_size]; + + let scale2 = (2.0 / hidden_size as f32).sqrt(); + let weights2 = (0..hidden_size) + .map(|_| { + (0..hidden_size) + .map(|_| rng.gen_range(-scale2..scale2)) + .collect() + }) + .collect(); + let biases2 = vec![0.0; hidden_size]; 
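+        // Note: the sqrt(2/fan_in) scale used here is the He (Kaiming)
+        // heuristic, the usual choice for ReLU layers; Xavier/Glorot, as
+        // mentioned above, would normalise by fan_in + fan_out instead.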
+ + let scale3 = (2.0 / hidden_size as f32).sqrt(); + let weights3 = (0..output_size) + .map(|_| { + (0..hidden_size) + .map(|_| rng.gen_range(-scale3..scale3)) + .collect() + }) + .collect(); + let biases3 = vec![0.0; output_size]; + + Self { + weights1, + biases1, + weights2, + biases2, + weights3, + biases3, + } + } + + pub fn forward(&self, input: &[f32]) -> Vec { + // Première couche + let mut layer1: Vec = self.biases1.clone(); + for (i, neuron_weights) in self.weights1.iter().enumerate() { + for (j, &weight) in neuron_weights.iter().enumerate() { + if j < input.len() { + layer1[i] += input[j] * weight; + } + } + layer1[i] = layer1[i].max(0.0); // ReLU + } + + // Deuxième couche + let mut layer2: Vec = self.biases2.clone(); + for (i, neuron_weights) in self.weights2.iter().enumerate() { + for (j, &weight) in neuron_weights.iter().enumerate() { + if j < layer1.len() { + layer2[i] += layer1[j] * weight; + } + } + layer2[i] = layer2[i].max(0.0); // ReLU + } + + // Couche de sortie + let mut output: Vec = self.biases3.clone(); + for (i, neuron_weights) in self.weights3.iter().enumerate() { + for (j, &weight) in neuron_weights.iter().enumerate() { + if j < layer2.len() { + output[i] += layer2[j] * weight; + } + } + } + + output + } + + pub fn get_best_action(&self, input: &[f32]) -> usize { + let q_values = self.forward(input); + q_values + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(index, _)| index) + .unwrap_or(0) + } + + pub fn save>( + &self, + path: P, + ) -> Result<(), Box> { + let data = serde_json::to_string_pretty(self)?; + std::fs::write(path, data)?; + Ok(()) + } + + pub fn load>(path: P) -> Result> { + let data = std::fs::read_to_string(path)?; + let network = serde_json::from_str(&data)?; + Ok(network) + } +} diff --git a/bot/src/dqn_simple/dqn_trainer.rs b/bot/src/dqn_simple/dqn_trainer.rs new file mode 100644 index 0000000..ed60f5e --- /dev/null +++ b/bot/src/dqn_simple/dqn_trainer.rs @@ -0,0 +1,494 @@ +use crate::{CheckerMove, Color, GameState, PlayerId}; +use rand::prelude::SliceRandom; +use rand::{thread_rng, Rng}; +use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; +use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage}; + +use super::dqn_model::{DqnConfig, SimpleNeuralNetwork}; +use crate::training_common_big::{get_valid_actions, TrictracAction}; + +/// Expérience pour le buffer de replay +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Experience { + pub state: Vec, + pub action: TrictracAction, + pub reward: f32, + pub next_state: Vec, + pub done: bool, +} + +/// Buffer de replay pour stocker les expériences +#[derive(Debug)] +pub struct ReplayBuffer { + buffer: VecDeque, + capacity: usize, +} + +impl ReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: VecDeque::with_capacity(capacity), + capacity, + } + } + + pub fn push(&mut self, experience: Experience) { + if self.buffer.len() >= self.capacity { + self.buffer.pop_front(); + } + self.buffer.push_back(experience); + } + + pub fn sample(&self, batch_size: usize) -> Vec { + let mut rng = thread_rng(); + let len = self.buffer.len(); + if len < batch_size { + return self.buffer.iter().cloned().collect(); + } + + let mut batch = Vec::with_capacity(batch_size); + for _ in 0..batch_size { + let idx = rng.gen_range(0..len); + batch.push(self.buffer[idx].clone()); + } + batch + } + + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + pub fn len(&self) -> usize { + self.buffer.len() + } +} + +/// Agent DQN 
pour l'apprentissage par renforcement +#[derive(Debug)] +pub struct DqnAgent { + config: DqnConfig, + model: SimpleNeuralNetwork, + target_model: SimpleNeuralNetwork, + replay_buffer: ReplayBuffer, + epsilon: f64, + step_count: usize, +} + +impl DqnAgent { + pub fn new(config: DqnConfig) -> Self { + let model = + SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions); + let target_model = model.clone(); + let replay_buffer = ReplayBuffer::new(config.replay_buffer_size); + let epsilon = config.epsilon; + + Self { + config, + model, + target_model, + replay_buffer, + epsilon, + step_count: 0, + } + } + + pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction { + let valid_actions = get_valid_actions(game_state); + + if valid_actions.is_empty() { + // Fallback si aucune action valide + return TrictracAction::Roll; + } + + let mut rng = thread_rng(); + if rng.gen::() < self.epsilon { + // Exploration : action valide aléatoire + valid_actions + .choose(&mut rng) + .cloned() + .unwrap_or(TrictracAction::Roll) + } else { + // Exploitation : meilleure action valide selon le modèle + let q_values = self.model.forward(state); + + let mut best_action = &valid_actions[0]; + let mut best_q_value = f32::NEG_INFINITY; + + for action in &valid_actions { + let action_index = action.to_action_index(); + if action_index < q_values.len() { + let q_value = q_values[action_index]; + if q_value > best_q_value { + best_q_value = q_value; + best_action = action; + } + } + } + + best_action.clone() + } + } + + pub fn store_experience(&mut self, experience: Experience) { + self.replay_buffer.push(experience); + } + + pub fn train(&mut self) { + if self.replay_buffer.len() < self.config.batch_size { + return; + } + + // Pour l'instant, on simule l'entraînement en mettant à jour epsilon + // Dans une implémentation complète, ici on ferait la backpropagation + self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min); + self.step_count += 1; + + // Mise à jour du target model tous les 100 steps + if self.step_count % 100 == 0 { + self.target_model = self.model.clone(); + } + } + + pub fn save_model>( + &self, + path: P, + ) -> Result<(), Box> { + self.model.save(path) + } + + pub fn get_epsilon(&self) -> f64 { + self.epsilon + } + + pub fn get_step_count(&self) -> usize { + self.step_count + } +} + +/// Environnement Trictrac pour l'entraînement +#[derive(Debug)] +pub struct TrictracEnv { + pub game_state: GameState, + pub agent_player_id: PlayerId, + pub opponent_player_id: PlayerId, + pub agent_color: Color, + pub max_steps: usize, + pub current_step: usize, +} + +impl Default for TrictracEnv { + fn default() -> Self { + let mut game_state = GameState::new(false); + game_state.init_player("agent"); + game_state.init_player("opponent"); + + Self { + game_state, + agent_player_id: 1, + opponent_player_id: 2, + agent_color: Color::White, + max_steps: 1000, + current_step: 0, + } + } +} + +impl TrictracEnv { + pub fn reset(&mut self) -> Vec { + self.game_state = GameState::new(false); + self.game_state.init_player("agent"); + self.game_state.init_player("opponent"); + + // Commencer la partie + self.game_state.consume(&GameEvent::BeginGame { + goes_first: self.agent_player_id, + }); + + self.current_step = 0; + self.game_state.to_vec_float() + } + + pub fn step(&mut self, action: TrictracAction) -> (Vec, f32, bool) { + let mut reward = 0.0; + + // Appliquer l'action de l'agent + if self.game_state.active_player_id == 
self.agent_player_id { + reward += self.apply_agent_action(action); + } + + // Faire jouer l'adversaire (stratégie simple) + while self.game_state.active_player_id == self.opponent_player_id + && self.game_state.stage != Stage::Ended + { + reward += self.play_opponent_turn(); + } + + // Vérifier si la partie est terminée + let done = self.game_state.stage == Stage::Ended + || self.game_state.determine_winner().is_some() + || self.current_step >= self.max_steps; + + // Récompense finale si la partie est terminée + if done { + if let Some(winner) = self.game_state.determine_winner() { + if winner == self.agent_player_id { + reward += 100.0; // Bonus pour gagner + } else { + reward -= 50.0; // Pénalité pour perdre + } + } + } + + self.current_step += 1; + let next_state = self.game_state.to_vec_float(); + (next_state, reward, done) + } + + fn apply_agent_action(&mut self, action: TrictracAction) -> f32 { + let mut reward = 0.0; + + let event = match action { + TrictracAction::Roll => { + // Lancer les dés + reward += 0.1; + Some(GameEvent::Roll { + player_id: self.agent_player_id, + }) + } + // TrictracAction::Mark => { + // // Marquer des points + // let points = self.game_state. + // reward += 0.1 * points as f32; + // Some(GameEvent::Mark { + // player_id: self.agent_player_id, + // points, + // }) + // } + TrictracAction::Go => { + // Continuer après avoir gagné un trou + reward += 0.2; + Some(GameEvent::Go { + player_id: self.agent_player_id, + }) + } + TrictracAction::Move { + dice_order, + from1, + from2, + } => { + // Effectuer un mouvement + let (dice1, dice2) = if dice_order { + (self.game_state.dice.values.0, self.game_state.dice.values.1) + } else { + (self.game_state.dice.values.1, self.game_state.dice.values.0) + }; + let mut to1 = from1 + dice1 as usize; + let mut to2 = from2 + dice2 as usize; + + // Gestion prise de coin par puissance + let opp_rest_field = 13; + if to1 == opp_rest_field && to2 == opp_rest_field { + to1 -= 1; + to2 -= 1; + } + + let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); + let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default(); + + reward += 0.2; + Some(GameEvent::Move { + player_id: self.agent_player_id, + moves: (checker_move1, checker_move2), + }) + } + }; + + // Appliquer l'événement si valide + if let Some(event) = event { + if self.game_state.validate(&event) { + self.game_state.consume(&event); + + // Simuler le résultat des dés après un Roll + if matches!(action, TrictracAction::Roll) { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + let dice_event = GameEvent::RollResult { + player_id: self.agent_player_id, + dice: store::Dice { + values: dice_values, + }, + }; + if self.game_state.validate(&dice_event) { + self.game_state.consume(&dice_event); + } + } + } else { + // Pénalité pour action invalide + reward -= 2.0; + } + } + + reward + } + + // TODO : use default bot strategy + fn play_opponent_turn(&mut self) -> f32 { + let mut reward = 0.0; + let event = match self.game_state.turn_stage { + TurnStage::RollDice => GameEvent::Roll { + player_id: self.opponent_player_id, + }, + TurnStage::RollWaiting => { + let mut rng = thread_rng(); + let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + GameEvent::RollResult { + player_id: self.opponent_player_id, + dice: store::Dice { + values: dice_values, + }, + } + } + TurnStage::MarkAdvPoints | TurnStage::MarkPoints => { + let opponent_color = self.agent_color.opponent_color(); + let 
dice_roll_count = self + .game_state + .players + .get(&self.opponent_player_id) + .unwrap() + .dice_roll_count; + let points_rules = PointsRules::new( + &opponent_color, + &self.game_state.board, + self.game_state.dice, + ); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points + + GameEvent::Mark { + player_id: self.opponent_player_id, + points, + } + } + TurnStage::Move => { + let opponent_color = self.agent_color.opponent_color(); + let rules = MoveRules::new( + &opponent_color, + &self.game_state.board, + self.game_state.dice, + ); + let possible_moves = rules.get_possible_moves_sequences(true, vec![]); + + // Stratégie simple : choix aléatoire + let mut rng = thread_rng(); + let choosen_move = *possible_moves + .choose(&mut rng) + .unwrap_or(&(CheckerMove::default(), CheckerMove::default())); + + GameEvent::Move { + player_id: self.opponent_player_id, + moves: if opponent_color == Color::White { + choosen_move + } else { + (choosen_move.0.mirror(), choosen_move.1.mirror()) + }, + } + } + TurnStage::HoldOrGoChoice => { + // Stratégie simple : toujours continuer + GameEvent::Go { + player_id: self.opponent_player_id, + } + } + }; + if self.game_state.validate(&event) { + self.game_state.consume(&event); + } + reward + } +} + +/// Entraîneur pour le modèle DQN +pub struct DqnTrainer { + agent: DqnAgent, + env: TrictracEnv, +} + +impl DqnTrainer { + pub fn new(config: DqnConfig) -> Self { + Self { + agent: DqnAgent::new(config), + env: TrictracEnv::default(), + } + } + + pub fn train_episode(&mut self) -> f32 { + let mut total_reward = 0.0; + let mut state = self.env.reset(); + // let mut step_count = 0; + + loop { + // step_count += 1; + let action = self.agent.select_action(&self.env.game_state, &state); + let (next_state, reward, done) = self.env.step(action.clone()); + total_reward += reward; + + let experience = Experience { + state: state.clone(), + action, + reward, + next_state: next_state.clone(), + done, + }; + self.agent.store_experience(experience); + self.agent.train(); + + if done { + break; + } + // if step_count % 100 == 0 { + // println!("{:?}", next_state); + // } + state = next_state; + } + + total_reward + } + + pub fn train( + &mut self, + episodes: usize, + save_every: usize, + model_path: &str, + ) -> Result<(), Box> { + println!("Démarrage de l'entraînement DQN pour {episodes} épisodes"); + + for episode in 1..=episodes { + let reward = self.train_episode(); + + if episode % 100 == 0 { + println!( + "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}", + episode, + episodes, + reward, + self.agent.get_epsilon(), + self.agent.get_step_count() + ); + } + + if episode % save_every == 0 { + let save_path = format!("{model_path}_episode_{episode}.json"); + self.agent.save_model(&save_path)?; + println!("Modèle sauvegardé : {save_path}"); + } + } + + // Sauvegarder le modèle final + let final_path = format!("{model_path}_final.json"); + self.agent.save_model(&final_path)?; + println!("Modèle final sauvegardé : {final_path}"); + + Ok(()) + } +} diff --git a/bot/src/dqn_simple/main.rs b/bot/src/dqn_simple/main.rs new file mode 100644 index 0000000..024f895 --- /dev/null +++ b/bot/src/dqn_simple/main.rs @@ -0,0 +1,109 @@ +use bot::dqn_simple::dqn_model::DqnConfig; +use bot::dqn_simple::dqn_trainer::DqnTrainer; +use bot::training_common::TrictracAction; +use std::env; + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = 
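+    // Caution: this binary sizes the network from training_common's
+    // TrictracAction (514 actions) while DqnTrainer and DqnAgent work with
+    // training_common_big (1252 actions); the two action spaces must agree
+    // for the learned Q-value indices to be meaningful.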
env::args().collect(); + + // Paramètres par défaut + let mut episodes = 1000; + let mut model_path = "models/dqn_model".to_string(); + let mut save_every = 100; + + // Parser les arguments de ligne de commande + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--episodes" => { + if i + 1 < args.len() { + episodes = args[i + 1].parse().unwrap_or(1000); + i += 2; + } else { + eprintln!("Erreur : --episodes nécessite une valeur"); + std::process::exit(1); + } + } + "--model-path" => { + if i + 1 < args.len() { + model_path = args[i + 1].clone(); + i += 2; + } else { + eprintln!("Erreur : --model-path nécessite une valeur"); + std::process::exit(1); + } + } + "--save-every" => { + if i + 1 < args.len() { + save_every = args[i + 1].parse().unwrap_or(100); + i += 2; + } else { + eprintln!("Erreur : --save-every nécessite une valeur"); + std::process::exit(1); + } + } + "--help" | "-h" => { + print_help(); + std::process::exit(0); + } + _ => { + eprintln!("Argument inconnu : {}", args[i]); + print_help(); + std::process::exit(1); + } + } + } + + // Créer le dossier models s'il n'existe pas + std::fs::create_dir_all("models")?; + + println!("Configuration d'entraînement DQN :"); + println!(" Épisodes : {episodes}"); + println!(" Chemin du modèle : {model_path}"); + println!(" Sauvegarde tous les {save_every} épisodes"); + println!(); + + // Configuration DQN + let config = DqnConfig { + state_size: 36, // state.to_vec size + hidden_size: 256, + num_actions: TrictracAction::action_space_size(), + learning_rate: 0.001, + gamma: 0.99, + epsilon: 0.9, // Commencer avec plus d'exploration + epsilon_decay: 0.995, + epsilon_min: 0.01, + replay_buffer_size: 10000, + batch_size: 32, + }; + + // Créer et lancer l'entraîneur + let mut trainer = DqnTrainer::new(config); + trainer.train(episodes, save_every, &model_path)?; + + println!("Entraînement terminé avec succès !"); + println!("Pour utiliser le modèle entraîné :"); + println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy"); + + Ok(()) +} + +fn print_help() { + println!("Entraîneur DQN pour Trictrac"); + println!(); + println!("USAGE:"); + println!(" cargo run --bin=train_dqn [OPTIONS]"); + println!(); + println!("OPTIONS:"); + println!(" --episodes Nombre d'épisodes d'entraînement (défaut: 1000)"); + println!(" --model-path Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)"); + println!(" --save-every Sauvegarder le modèle tous les N épisodes (défaut: 100)"); + println!(" -h, --help Afficher cette aide"); + println!(); + println!("EXEMPLES:"); + println!(" cargo run --bin=train_dqn"); + println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500"); + println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000"); +} diff --git a/bot/src/dqn_simple/mod.rs b/bot/src/dqn_simple/mod.rs new file mode 100644 index 0000000..8090a29 --- /dev/null +++ b/bot/src/dqn_simple/mod.rs @@ -0,0 +1,2 @@ +pub mod dqn_model; +pub mod dqn_trainer; diff --git a/bot/src/lib.rs b/bot/src/lib.rs index 0fc6fdf..dab36be 100644 --- a/bot/src/lib.rs +++ b/bot/src/lib.rs @@ -1,11 +1,14 @@ pub mod burnrl; +pub mod dqn_simple; pub mod strategy; pub mod training_common; +pub mod training_common_big; pub mod trictrac_board; use log::debug; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; pub use strategy::default::DefaultStrategy; +pub use strategy::dqn::DqnStrategy; pub use strategy::dqnburn::DqnBurnStrategy; pub use 
diff --git a/bot/src/dqn_simple/mod.rs b/bot/src/dqn_simple/mod.rs
new file mode 100644
index 0000000..8090a29
--- /dev/null
+++ b/bot/src/dqn_simple/mod.rs
@@ -0,0 +1,2 @@
+pub mod dqn_model;
+pub mod dqn_trainer;
diff --git a/bot/src/lib.rs b/bot/src/lib.rs
index 0fc6fdf..dab36be 100644
--- a/bot/src/lib.rs
+++ b/bot/src/lib.rs
@@ -1,11 +1,14 @@
 pub mod burnrl;
+pub mod dqn_simple;
 pub mod strategy;
 pub mod training_common;
+pub mod training_common_big;
 pub mod trictrac_board;
 
 use log::debug;
 use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
 pub use strategy::default::DefaultStrategy;
+pub use strategy::dqn::DqnStrategy;
 pub use strategy::dqnburn::DqnBurnStrategy;
 pub use strategy::erroneous_moves::ErroneousStrategy;
 pub use strategy::random::RandomStrategy;
diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs
new file mode 100644
index 0000000..2874195
--- /dev/null
+++ b/bot/src/strategy/dqn.rs
@@ -0,0 +1,174 @@
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
+use log::info;
+use std::path::Path;
+use store::MoveRules;
+
+use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
+use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
+
+/// DQN strategy for the bot: only loads and queries a pre-trained model
+#[derive(Debug)]
+pub struct DqnStrategy {
+    pub game: GameState,
+    pub player_id: PlayerId,
+    pub color: Color,
+    pub model: Option<SimpleNeuralNetwork>,
+}
+
+impl Default for DqnStrategy {
+    fn default() -> Self {
+        Self {
+            game: GameState::default(),
+            player_id: 1,
+            color: Color::White,
+            model: None,
+        }
+    }
+}
+
+impl DqnStrategy {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn new_with_model<P: AsRef<Path> + std::fmt::Debug>(model_path: P) -> Self {
+        let mut strategy = Self::new();
+        if let Ok(model) = SimpleNeuralNetwork::load(&model_path) {
+            info!("Loading model {model_path:?}");
+            strategy.model = Some(model);
+        }
+        strategy
+    }
+
+    /// Uses the DQN model to choose a valid action
+    fn get_dqn_action(&self) -> Option<TrictracAction> {
+        if let Some(ref model) = self.model {
+            let state = self.game.to_vec_float();
+            let valid_actions = get_valid_actions(&self.game);
+
+            if valid_actions.is_empty() {
+                return None;
+            }
+
+            // Get the Q-values for all actions
+            let q_values = model.forward(&state);
+
+            // Find the best valid action
+            let mut best_action = &valid_actions[0];
+            let mut best_q_value = f32::NEG_INFINITY;
+
+            for action in &valid_actions {
+                let action_index = action.to_action_index();
+                if action_index < q_values.len() {
+                    let q_value = q_values[action_index];
+                    if q_value > best_q_value {
+                        best_q_value = q_value;
+                        best_action = action;
+                    }
+                }
+            }
+
+            Some(best_action.clone())
+        } else {
+            // Fallback: random valid action
+            sample_valid_action(&self.game)
+        }
+    }
+}
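As a usage sketch (assuming `BotStrategy` is re-exported at the crate root, as the `use crate::{BotStrategy, ...}` import above suggests), wiring the strategy up outside of `client_cli` could look like this; when the model file cannot be loaded, `model` stays `None` and the strategy degrades to random valid actions:

```rust
// Illustrative only: the model path is hypothetical.
use bot::{BotStrategy, DqnStrategy};
use store::Color;

fn make_bot() -> DqnStrategy {
    let mut bot = DqnStrategy::new_with_model("models/dqn_model_final.json");
    bot.set_color(Color::White);
    bot.set_player_id(1);
    bot
}
```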
+
+impl BotStrategy for DqnStrategy {
+    fn get_game(&self) -> &GameState {
+        &self.game
+    }
+
+    fn get_mut_game(&mut self) -> &mut GameState {
+        &mut self.game
+    }
+
+    fn set_color(&mut self, color: Color) {
+        self.color = color;
+    }
+
+    fn set_player_id(&mut self, player_id: PlayerId) {
+        self.player_id = player_id;
+    }
+
+    fn calculate_points(&self) -> u8 {
+        self.game.dice_points.0
+    }
+
+    fn calculate_adv_points(&self) -> u8 {
+        self.game.dice_points.1
+    }
+
+    fn choose_go(&self) -> bool {
+        // Use the DQN to decide whether to keep going
+        if let Some(action) = self.get_dqn_action() {
+            matches!(action, TrictracAction::Go)
+        } else {
+            // Fallback: always keep going
+            true
+        }
+    }
+
+    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
+        // Use the DQN to choose the move
+        if let Some(TrictracAction::Move {
+            dice_order,
+            from1,
+            from2,
+        }) = self.get_dqn_action()
+        {
+            let dicevals = self.game.dice.values;
+            let (mut dice1, mut dice2) = if dice_order {
+                (dicevals.0, dicevals.1)
+            } else {
+                (dicevals.1, dicevals.0)
+            };
+
+            if from1 == 0 {
+                // empty move
+                dice1 = 0;
+            }
+            let mut to1 = from1 + dice1 as usize;
+            if 24 < to1 {
+                // bearing off
+                to1 = 0;
+            }
+            if from2 == 0 {
+                // empty move
+                dice2 = 0;
+            }
+            let mut to2 = from2 + dice2 as usize;
+            if 24 < to2 {
+                // bearing off
+                to2 = 0;
+            }
+
+            let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
+            let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
+
+            let chosen_move = if self.color == Color::White {
+                (checker_move1, checker_move2)
+            } else {
+                (checker_move1.mirror(), checker_move2.mirror())
+            };
+
+            return chosen_move;
+        }
+
+        // Fallback: use the default strategy
+        let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
+        let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+
+        let chosen_move = *possible_moves
+            .first()
+            .unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
+
+        if self.color == Color::White {
+            chosen_move
+        } else {
+            (chosen_move.0.mirror(), chosen_move.1.mirror())
+        }
+    }
+}
diff --git a/bot/src/strategy/mod.rs b/bot/src/strategy/mod.rs
index 00293cb..b9fa3b2 100644
--- a/bot/src/strategy/mod.rs
+++ b/bot/src/strategy/mod.rs
@@ -1,5 +1,6 @@
 pub mod client;
 pub mod default;
+pub mod dqn;
 pub mod dqnburn;
 pub mod erroneous_moves;
 pub mod random;
diff --git a/bot/src/training_common.rs b/bot/src/training_common.rs
index ee33d0c..5d8e870 100644
--- a/bot/src/training_common.rs
+++ b/bot/src/training_common.rs
@@ -1,5 +1,3 @@
-/// training_common.rs : environnement avec espace d'actions optimisé
-/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' )
 use std::cmp::{max, min};
 
 use std::fmt::{Debug, Display, Formatter};
diff --git a/bot/src/training_common_big.rs b/bot/src/training_common_big.rs
new file mode 100644
index 0000000..9f8bae4
--- /dev/null
+++ b/bot/src/training_common_big.rs
@@ -0,0 +1,266 @@
+use std::cmp::{max, min};
+
+use serde::{Deserialize, Serialize};
+use store::{CheckerMove, Dice};
+
+/// Possible action types in the game
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum TrictracAction {
+    /// Roll the dice
+    Roll,
+    /// Keep playing after winning a hole
+    Go,
+    /// Move two checkers
+    Move {
+        dice_order: bool, // true = use dice[0] first, false = dice[1] first
+        from1: usize,     // starting position of the first checker (0-24)
+        from2: usize,     // starting position of the second checker (0-24)
+    },
+    // Mark the points: enable if "écoles" are supported
+    // Mark,
+}
+
+impl TrictracAction {
+    /// Encodes an action as an index for the neural network
+    pub fn to_action_index(&self) -> usize {
+        match self {
+            TrictracAction::Roll => 0,
+            TrictracAction::Go => 1,
+            TrictracAction::Move {
+                dice_order,
+                from1,
+                from2,
+            } => {
+                // Encode moves at indices 2 and above:
+                // 2 to 1251 (2 to 626 with die 1 first, 627 to 1251 with die 2 first)
+                let mut start = 2;
+                if !dice_order {
+                    // 25 * 25 = 625
+                    start += 625;
+                }
+                start + from1 * 25 + from2
+            } // TrictracAction::Mark => 1252,
+        }
+    }
+
+    /// Decodes an action index into a TrictracAction
+    pub fn from_action_index(index: usize) -> Option<TrictracAction> {
+        match index {
+            0 => Some(TrictracAction::Roll),
+            // 1252 => Some(TrictracAction::Mark),
+            1 => Some(TrictracAction::Go),
+            i if i >= 2 => {
+                let move_code = i - 2;
+                let (dice_order, from1, from2) = Self::decode_move(move_code);
+                Some(TrictracAction::Move {
+                    dice_order,
+                    from1,
+                    from2,
+                })
+            }
+            _ => None,
+        }
+    }
+
+    /// Decodes an encoded integer back into (dice_order, from1, from2)
+    fn decode_move(code: usize) -> (bool, usize, usize) {
+        // Mirror of the encoding above: code = (dice_order ? 0 : 625) + from1 * 25 + from2
+        let dice_order = code < 625;
+        let encoded = if dice_order { code } else { code - 625 };
+        let from1 = encoded / 25;
+        let from2 = encoded % 25;
+        (dice_order, from1, from2)
+    }
+
+    /// Returns the total size of the action space
+    pub fn action_space_size() -> usize {
+        // 1 (Roll) + 1 (Go) + possible moves
+        // For the moves: 2 * 25 * 25 = 1250 (die order + position 0-24 for each from)
+        // This could be tightened by restricting to the valid positions (1-24)
+        2 + (2 * 25 * 25) // = 1252
+    }
+}
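The encoding is easiest to see on a concrete case. With `dice_order = true` there is no 625 offset, so `Move { dice_order: true, from1: 3, from2: 4 }` lands at 2 + 3 × 25 + 4 = 81, and decoding reverses it: (81 − 2) / 25 = 3 and (81 − 2) % 25 = 4. This is the same case the unit tests at the end of this file use:

```rust
#[test]
fn encode_decode_spot_check() {
    // Round trip for one concrete action; mirrors the module's unit tests.
    let action = TrictracAction::Move { dice_order: true, from1: 3, from2: 4 };
    let index = action.to_action_index();
    assert_eq!(index, 81); // 2 + 3 * 25 + 4
    assert_eq!(TrictracAction::from_action_index(index), Some(action));

    // With dice_order = false the same origins land 625 slots later.
    let flipped = TrictracAction::Move { dice_order: false, from1: 3, from2: 4 };
    assert_eq!(flipped.to_action_index(), 706); // 2 + 625 + 79
}
```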
+
+/// Returns the valid actions for the current game state
+pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
+    use store::TurnStage;
+
+    let mut valid_actions = Vec::new();
+
+    let active_player_id = game_state.active_player_id;
+    let player_color = game_state.player_color_by_id(&active_player_id);
+
+    if let Some(color) = player_color {
+        match game_state.turn_stage {
+            TurnStage::RollDice => {
+                valid_actions.push(TrictracAction::Roll);
+            }
+            TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
+                panic!(
+                    "get_valid_actions not implemented for turn stage {:?}",
+                    game_state.turn_stage
+                );
+                // valid_actions.push(TrictracAction::Mark);
+            }
+            TurnStage::HoldOrGoChoice => {
+                valid_actions.push(TrictracAction::Go);
+
+                // Also add the possible move sequences
+                let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
+                let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+
+                // Adapt checker_moves_to_trictrac_action if Black must be handled
+                assert_eq!(color, store::Color::White);
+                for (move1, move2) in possible_moves {
+                    valid_actions.push(checker_moves_to_trictrac_action(
+                        &move1,
+                        &move2,
+                        &game_state.dice,
+                    ));
+                }
+            }
+            TurnStage::Move => {
+                let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
+                let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
+                if possible_moves.is_empty() {
+                    // Empty move
+                    possible_moves.push((CheckerMove::default(), CheckerMove::default()));
+                }
+
+                // Adapt checker_moves_to_trictrac_action if Black must be handled
+                assert_eq!(color, store::Color::White);
+                for (move1, move2) in possible_moves {
+                    valid_actions.push(checker_moves_to_trictrac_action(
+                        &move1,
+                        &move2,
+                        &game_state.dice,
+                    ));
+                }
+            }
+        }
+    }
+
+    if valid_actions.is_empty() {
+        panic!("empty valid_actions for state {game_state}");
+    }
+    valid_actions
+}
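Together with the two helpers defined after `checker_moves_to_trictrac_action` below, this gives everything a trainer needs for masked ε-greedy selection: explore with `sample_valid_action`, exploit by arg-maxing the Q-values over the valid indices only. A sketch under those assumptions (the real selection lives in `DqnAgent::select_action` in `dqn_model.rs`, not shown here):

```rust
// Sketch of masked epsilon-greedy selection; `q_values` would come from the
// network's forward pass, `epsilon` from the decayed schedule.
fn select_masked(
    game_state: &crate::GameState,
    q_values: &[f32],
    epsilon: f32,
) -> Option<TrictracAction> {
    use rand::Rng;
    if rand::thread_rng().gen::<f32>() < epsilon {
        return sample_valid_action(game_state); // explore among valid actions only
    }
    // Exploit: best Q-value restricted to the valid indices.
    get_valid_action_indices(game_state)
        .into_iter()
        .max_by(|a, b| q_values[*a].total_cmp(&q_values[*b]))
        .and_then(TrictracAction::from_action_index)
}
```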
+
+// Valid only for the White player
+fn checker_moves_to_trictrac_action(
+    move1: &CheckerMove,
+    move2: &CheckerMove,
+    dice: &Dice,
+) -> TrictracAction {
+    let to1 = move1.get_to();
+    let to2 = move2.get_to();
+    let from1 = move1.get_from();
+    let from2 = move2.get_from();
+
+    let mut diff_move1 = if to1 > 0 {
+        // Move without bearing off
+        to1 - from1
+    } else {
+        // Bearing off: use the die value
+        if to2 > 0 {
+            // Only move 1 bears off
+            let dice2 = to2 - from2;
+            if dice2 == dice.values.0 as usize {
+                dice.values.1 as usize
+            } else {
+                dice.values.0 as usize
+            }
+        } else {
+            // Both moves bear off
+            if from1 < from2 {
+                max(dice.values.0, dice.values.1) as usize
+            } else {
+                min(dice.values.0, dice.values.1) as usize
+            }
+        }
+    };
+
+    // Adjust diff_move1 for a capture by power ("prise par puissance")
+    let rest_field = 12;
+    if to1 == rest_field
+        && to2 == rest_field
+        && max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
+    {
+        // Capture by power
+        diff_move1 += 1;
+    }
+    TrictracAction::Move {
+        dice_order: diff_move1 == dice.values.0 as usize,
+        from1,
+        from2,
+    }
+}
+
+/// Returns the indices of the valid actions
+pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
+    get_valid_actions(game_state)
+        .into_iter()
+        .map(|action| action.to_action_index())
+        .collect()
+}
+
+/// Picks a random valid action
+pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
+    use rand::{seq::SliceRandom, thread_rng};
+
+    let valid_actions = get_valid_actions(game_state);
+    let mut rng = thread_rng();
+    valid_actions.choose(&mut rng).cloned()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn to_action_index() {
+        let action = TrictracAction::Move {
+            dice_order: true,
+            from1: 3,
+            from2: 4,
+        };
+        let index = action.to_action_index();
+        assert_eq!(Some(action), TrictracAction::from_action_index(index));
+        assert_eq!(81, index);
+    }
+
+    #[test]
+    fn from_action_index() {
+        let action = TrictracAction::Move {
+            dice_order: true,
+            from1: 3,
+            from2: 4,
+        };
+        assert_eq!(Some(action), TrictracAction::from_action_index(81));
+    }
+}
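Since `to_action_index` and `from_action_index` must stay exact inverses, and the spot checks above only cover one index, an exhaustive round trip over the whole Move range is cheap insurance; a possible addition to the test module:

```rust
#[test]
fn action_index_round_trip() {
    // Every encodable Move action must decode back to itself.
    for dice_order in [true, false] {
        for from1 in 0..25 {
            for from2 in 0..25 {
                let action = TrictracAction::Move { dice_order, from1, from2 };
                let index = action.to_action_index();
                assert!(index < TrictracAction::action_space_size());
                assert_eq!(Some(action), TrictracAction::from_action_index(index));
            }
        }
    }
}
```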
diff --git a/client_bevy/.cargo/config.toml b/client_bevy/.cargo/config.toml
new file mode 100644
index 0000000..b6bc0d3
--- /dev/null
+++ b/client_bevy/.cargo/config.toml
@@ -0,0 +1,8 @@
+[target.x86_64-unknown-linux-gnu]
+linker = "clang"
+rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
+
+# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only'
+# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains.
+#[profile.dev]
+#debug = 1
diff --git a/client_bevy/Cargo.toml b/client_bevy/Cargo.toml
new file mode 100644
index 0000000..aaa6b7d
--- /dev/null
+++ b/client_bevy/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "trictrac-client"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+anyhow = "1.0.75"
+bevy = { version = "0.11.3" }
+bevy_renet = "0.0.9"
+bincode = "1.3.3"
+renet = "0.0.13"
+store = { path = "../store" }
diff --git a/client_bevy/assets/Inconsolata.ttf b/client_bevy/assets/Inconsolata.ttf
new file mode 100644
index 0000000..34848ca
Binary files /dev/null and b/client_bevy/assets/Inconsolata.ttf differ
diff --git a/client_bevy/assets/board.png b/client_bevy/assets/board.png
new file mode 100644
index 0000000..5d16ac3
Binary files /dev/null and b/client_bevy/assets/board.png differ
diff --git a/client_bevy/assets/sound/click.wav b/client_bevy/assets/sound/click.wav
new file mode 100644
index 0000000..8b6c99d
Binary files /dev/null and b/client_bevy/assets/sound/click.wav differ
diff --git a/client_bevy/assets/sound/throw.wav b/client_bevy/assets/sound/throw.wav
new file mode 100755
index 0000000..cb5e438
Binary files /dev/null and b/client_bevy/assets/sound/throw.wav differ
diff --git a/client_bevy/assets/tac.png b/client_bevy/assets/tac.png
new file mode 100644
index 0000000..2c18813
Binary files /dev/null and b/client_bevy/assets/tac.png differ
diff --git a/client_bevy/assets/tic.png b/client_bevy/assets/tic.png
new file mode 100644
index 0000000..786e0c7
Binary files /dev/null and b/client_bevy/assets/tic.png differ
diff --git a/client_bevy/src/main.rs b/client_bevy/src/main.rs
new file mode 100644
index 0000000..504602e
--- /dev/null
+++ b/client_bevy/src/main.rs
@@ -0,0 +1,334 @@
+use std::{net::UdpSocket, time::SystemTime};
+
+use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES};
+use store::{CheckerMove, GameEvent, GameState};
+
+use bevy::prelude::*;
+use bevy::window::PrimaryWindow;
+use bevy_renet::{
+    renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient},
+    transport::{client_connected, NetcodeClientPlugin},
+    RenetClientPlugin,
+};
+
+#[derive(Debug, Resource)]
+struct CurrentClientId(u64);
+
+#[derive(Resource, Default)]
+struct BevyGameState(GameState);
+
+#[derive(Resource, Default, Deref, DerefMut)]
+struct GameUIState {
+    selected_tile: Option<usize>,
+}
+
+#[derive(Event)]
+struct BevyGameEvent(GameEvent);
+
+// This id needs to be the same as the server is using
+const PROTOCOL_ID: u64 = 2878;
+
+fn main() {
+    // Get username from stdin args
+    let args = std::env::args().collect::<Vec<String>>();
+    let username = &args[1];
+
+    let (client, transport, client_id) = new_renet_client(&username).unwrap();
+    App::new()
+        // Let's add a nice dark grey background color
+        .insert_resource(ClearColor(Color::hex("282828").unwrap()))
+        .add_plugins(DefaultPlugins.set(WindowPlugin {
+            primary_window: Some(Window {
+                // Adding the username to the window title makes debugging a whole lot easier.
+                title: format!("TricTrac <{}>", username),
+                resolution: (1080.0, 1080.0).into(),
+                ..default()
+            }),
+            ..default()
+        }))
+        // Add our game state and register GameEvent as a bevy event
+        .insert_resource(BevyGameState::default())
+        .insert_resource(GameUIState::default())
+        .add_event::<BevyGameEvent>()
+        // Renet setup
+        .add_plugins(RenetClientPlugin)
+        .add_plugins(NetcodeClientPlugin)
+        .insert_resource(client)
+        .insert_resource(transport)
+        .insert_resource(CurrentClientId(client_id))
+        .add_systems(Startup, setup)
+        .add_systems(
+            Update,
+            (update_waiting_text, input, update_board, panic_on_error_system),
+        )
+        .add_systems(
+            PostUpdate,
+            receive_events_from_server.run_if(client_connected()),
+        )
+        .run();
+}
+
+////////// COMPONENTS //////////
+#[derive(Component)]
+struct UIRoot;
+
+#[derive(Component)]
+struct WaitingText;
+
+#[derive(Component)]
+struct Board {
+    squares: [Square; 26],
+}
+
+impl Default for Board {
+    fn default() -> Self {
+        Self {
+            squares: [Square {
+                count: 0,
+                color: None,
+                position: 0,
+            }; 26],
+        }
+    }
+}
+
+impl Board {
+    fn square_at(&self, position: usize) -> Square {
+        self.squares[position]
+    }
+}
+
+#[derive(Component, Clone, Copy)]
+struct Square {
+    count: usize,
+    color: Option<store::Color>,
+    position: usize,
+}
+
+////////// UPDATE SYSTEMS //////////
+fn update_board(
+    mut commands: Commands,
+    game_state: Res<BevyGameState>,
+    mut game_events: EventReader<BevyGameEvent>,
+    asset_server: Res<AssetServer>,
+) {
+    for event in game_events.iter() {
+        if let GameEvent::Move { player_id, moves } = event.0 {
+            // trictrac positions, TODO: depends on player_id
+            let (x, y) = if moves.0.get_to() < 13 {
+                (13 - moves.0.get_to(), 1)
+            } else {
+                (moves.0.get_to() - 13, 0)
+            };
+            let texture = asset_server.load(match game_state.0.players[&player_id].color {
+                store::Color::Black => "tac.png",
+                store::Color::White => "tic.png",
+            });
+
+            info!("spawning tictac sprite");
+            commands.spawn(SpriteBundle {
+                transform: Transform::from_xyz(
+                    83.0 * (x as f32 - 1.0),
+                    -30.0 + 540.0 * (y as f32 - 1.0),
+                    0.0,
+                ),
+                sprite: Sprite {
+                    custom_size: Some(Vec2::new(83.0, 83.0)),
+                    ..default()
+                },
+                texture,
+                ..default()
+            });
+        }
+    }
+}
+
+fn update_waiting_text(mut text_query: Query<&mut Text, With<WaitingText>>, time: Res