diff --git a/bot/Cargo.toml b/bot/Cargo.toml
index 1dea531..5df0623 100644
--- a/bot/Cargo.toml
+++ b/bot/Cargo.toml
@@ -17,6 +17,10 @@ path = "src/burnrl/dqn_big/main.rs"
 name = "train_dqn_burn"
 path = "src/burnrl/dqn/main.rs"
 
+[[bin]]
+name = "train_ppo_burn"
+path = "src/burnrl/ppo/main.rs"
+
 [[bin]]
 name = "train_dqn_simple"
 path = "src/dqn_simple/main.rs"
diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh
index cc98db5..d310bbe 100755
--- a/bot/scripts/train.sh
+++ b/bot/scripts/train.sh
@@ -4,7 +4,8 @@
 ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
 LOGS_DIR="$ROOT/bot/models/logs"
 CFG_SIZE=12
-BINBOT=train_dqn_burn
+BINBOT=train_ppo_burn
+# BINBOT=train_dqn_burn
 # BINBOT=train_dqn_burn_big
 # BINBOT=train_dqn_burn_before
 OPPONENT="random"
diff --git a/bot/src/burnrl/ppo/main.rs b/bot/src/burnrl/ppo/main.rs
new file mode 100644
index 0000000..3633e29
--- /dev/null
+++ b/bot/src/burnrl/ppo/main.rs
@@ -0,0 +1,52 @@
+use bot::burnrl::environment;
+use bot::burnrl::ppo::{
+    ppo_model,
+    utils::{demo_model, load_model, save_model},
+};
+use burn::backend::{Autodiff, NdArray};
+use burn_rl::agent::PPO;
+use burn_rl::base::ElemType;
+
+type Backend = Autodiff<NdArray<ElemType>>;
+type Env = environment::TrictracEnvironment;
+
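+// Entry point: train a PPO agent (burn_rl) on the Trictrac environment,
+// mirroring the DQN trainer in src/burnrl/dqn/main.rs.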
+fn main() {
+    // println!("> Training");
+
+    // See also MEMORY_SIZE in dqn_model.rs: 8192
+    let conf = ppo_model::PpoConfig {
+        // defaults
+        num_episodes: 50, // 40
+        max_steps: 1000,  // 1000 max steps per episode
+        dense_size: 256,  // 128 neural network complexity (default 128)
+        gamma: 0.9999,    // 0.999 discount factor. Higher = encourages long-term strategies,
+        // slower but less sensitive to lucky rolls
+        learning_rate: 0.001, // 0.001 step size. Too low: slow training; too high: may never
+        // converge
+        batch_size: 128, // 32 number of past experiences used to compute the mean error
+        clip_grad: 70.0, // 100 maximum correction applied to the gradient (default 100)
+
+        lambda: 0.95,
+        epsilon_clip: 0.2,
+        critic_weight: 0.5,
+        entropy_weight: 0.01,
+        epochs: 8,
+    };
+    println!("{conf}----------");
+    let valid_agent = ppo_model::run::<Env, Backend>(&conf, false); //true);
+
+    // let valid_agent = agent.valid(model);
+
+    println!("> Saving the validation model");
+
+    let path = "bot/models/burnrl_ppo".to_string();
+    panic!("how to do that: save model");
+    // save_model(valid_agent.model().as_ref().unwrap(), &path);
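+    // NOTE (assumption, not verified against the burn_rl PPO API): `run()` only exposes the
+    // validation agent as `impl Agent<Env>`, which does not give access to the underlying
+    // `Net` expected by `save_model`. One option would be for `ppo_model::run()` to also
+    // return the trained `Net<Backend>` (e.g. as an (agent, model) pair) before calling
+    // `agent.valid(model)`, so that `save_model(&model, &path)` could be called here.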
+
+    // println!("> Loading the model for testing");
+    // let loaded_model = load_model(conf.dense_size, &path);
+    // let loaded_agent = PPO::new(loaded_model.unwrap());
+    //
+    // println!("> Testing with the loaded model");
+    // demo_model(loaded_agent);
+}
diff --git a/bot/src/burnrl/ppo/mod.rs b/bot/src/burnrl/ppo/mod.rs
new file mode 100644
index 0000000..1b442d8
--- /dev/null
+++ b/bot/src/burnrl/ppo/mod.rs
@@ -0,0 +1,2 @@
+pub mod ppo_model;
+pub mod utils;
diff --git a/bot/src/burnrl/ppo/ppo_model.rs b/bot/src/burnrl/ppo/ppo_model.rs
new file mode 100644
index 0000000..dc0b5ca
--- /dev/null
+++ b/bot/src/burnrl/ppo/ppo_model.rs
@@ -0,0 +1,184 @@
+use crate::burnrl::environment::TrictracEnvironment;
+use burn::module::Module;
+use burn::nn::{Initializer, Linear, LinearConfig};
+use burn::optim::AdamWConfig;
+use burn::tensor::activation::{relu, softmax};
+use burn::tensor::backend::{AutodiffBackend, Backend};
+use burn::tensor::Tensor;
+use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
+use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use std::fmt;
+use std::time::SystemTime;
+
+#[derive(Module, Debug)]
+pub struct Net<B: Backend> {
+    linear: Linear<B>,
+    linear_actor: Linear<B>,
+    linear_critic: Linear<B>,
+}
+
+impl<B: Backend> Net<B> {
+    #[allow(unused)]
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+        let initializer = Initializer::XavierUniform { gain: 1.0 };
+        Self {
+            linear: LinearConfig::new(input_size, dense_size)
+                .with_initializer(initializer.clone())
+                .init(&Default::default()),
+            linear_actor: LinearConfig::new(dense_size, output_size)
+                .with_initializer(initializer.clone())
+                .init(&Default::default()),
+            linear_critic: LinearConfig::new(dense_size, 1)
+                .with_initializer(initializer)
+                .init(&Default::default()),
+        }
+    }
+}
+
+impl<B: Backend> Model<Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
+    fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
+        let layer_0_output = relu(self.linear.forward(input));
+        let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
+        let values = self.linear_critic.forward(layer_0_output);
+
+        PPOOutput::<B>::new(policies, values)
+    }
+
+    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let layer_0_output = relu(self.linear.forward(input));
+        softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
+    }
+}
+
+impl<B: Backend> PPOModel<B> for Net<B> {}
+#[allow(unused)]
+const MEMORY_SIZE: usize = 512;
+
+pub struct PpoConfig {
+    pub max_steps: usize,
+    pub num_episodes: usize,
+    pub dense_size: usize,
+
+    pub gamma: f32,
+    pub lambda: f32,
+    pub epsilon_clip: f32,
+    pub critic_weight: f32,
+    pub entropy_weight: f32,
+    pub learning_rate: f32,
+    pub epochs: usize,
+    pub batch_size: usize,
+    pub clip_grad: f32,
+}
+
+impl fmt::Display for PpoConfig {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let mut s = String::new();
+        s.push_str(&format!("max_steps={:?}\n", self.max_steps));
+        s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
+        s.push_str(&format!("dense_size={:?}\n", self.dense_size));
+        s.push_str(&format!("gamma={:?}\n", self.gamma));
+        s.push_str(&format!("lambda={:?}\n", self.lambda));
+        s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip));
+        s.push_str(&format!("critic_weight={:?}\n", self.critic_weight));
+        s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight));
+        s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
+        s.push_str(&format!("epochs={:?}\n", self.epochs));
+        s.push_str(&format!("batch_size={:?}\n", self.batch_size));
+        write!(f, "{s}")
+    }
+}
+
+impl Default for PpoConfig {
+    fn default() -> Self {
+        Self {
+            max_steps: 2000,
+            num_episodes: 1000,
+            dense_size: 256,
+
+            gamma: 0.99,
+            lambda: 0.95,
+            epsilon_clip: 0.2,
+            critic_weight: 0.5,
+            entropy_weight: 0.01,
+            learning_rate: 0.001,
+            epochs: 8,
+            batch_size: 8,
+            clip_grad: 100.0,
+        }
+    }
+}
+type MyAgent<E, B> = PPO<E, B, Net<B>>;
+
+#[allow(unused)]
+pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
+    conf: &PpoConfig,
+    visualized: bool,
+    // ) -> PPO<E, B, Net<B>> {
+) -> impl Agent<E> {
+    let mut env = E::new(visualized);
+    env.as_mut().max_steps = conf.max_steps;
+
+    let mut model = Net::<B>::new(
+        <<E as Environment>::StateType as State>::size(),
+        conf.dense_size,
+        <<E as Environment>::ActionType as Action>::size(),
+    );
+    let agent = MyAgent::default();
+    let config = PPOTrainingConfig {
+        gamma: conf.gamma,
+        lambda: conf.lambda,
+        epsilon_clip: conf.epsilon_clip,
+        critic_weight: conf.critic_weight,
+        entropy_weight: conf.entropy_weight,
+        learning_rate: conf.learning_rate,
+        epochs: conf.epochs,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };
+
+    let mut optimizer = AdamWConfig::new()
+        .with_grad_clipping(config.clip_grad.clone())
+        .init();
+    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
+    for episode in 0..conf.num_episodes {
+        let mut episode_done = false;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
+        let mut now = SystemTime::now();
+
+        env.reset();
+        while !episode_done {
+            let state = env.state();
+            if let Some(action) = MyAgent::<E, B>::react_with_model(&state, &model) {
+                let snapshot = env.step(action);
+                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
+                    snapshot.reward().clone(),
+                );
+
+                memory.push(
+                    state,
+                    *snapshot.state(),
+                    action,
+                    snapshot.reward().clone(),
+                    snapshot.done(),
+                );
+
+                episode_duration += 1;
+                episode_done = snapshot.done() || episode_duration >= conf.max_steps;
+            }
+        }
+        println!(
+            "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
+            now.elapsed().unwrap().as_secs(),
+        );
+
+        now = SystemTime::now();
+        model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
+        memory.clear();
+    }
+
+    agent.valid(model)
+    // agent
+}
diff --git a/bot/src/burnrl/ppo/utils.rs b/bot/src/burnrl/ppo/utils.rs
new file mode 100644
index 0000000..9457217
--- /dev/null
+++ b/bot/src/burnrl/ppo/utils.rs
@@ -0,0 +1,88 @@
+use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
+use crate::burnrl::ppo::ppo_model;
+use crate::training_common::get_valid_action_indices;
+use burn::backend::{ndarray::NdArrayDevice, NdArray};
+use burn::module::{Module, Param, ParamId};
+use burn::nn::Linear;
+use burn::record::{CompactRecorder, Recorder};
+use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
+use burn::tensor::Tensor;
+use burn_rl::agent::{PPOModel, PPO};
+use burn_rl::base::{Action, ElemType, Environment, State};
+
+pub fn save_model(model: &ppo_model::Net<NdArray<ElemType>>, path: &String) {
+    let recorder = CompactRecorder::new();
+    let model_path = format!("{path}.mpk");
+    println!("Validation model saved: {model_path}");
+    recorder
+        .record(model.clone().into_record(), model_path.into())
+        .unwrap();
+}
+
+pub fn load_model(dense_size: usize, path: &String) -> Option<ppo_model::Net<NdArray<ElemType>>> {
+    let model_path = format!("{path}.mpk");
+    // println!("Loading model from: {model_path}");
+
+    CompactRecorder::new()
+        .load(model_path.into(), &NdArrayDevice::default())
+        .map(|record| {
+            ppo_model::Net::new(
+                <TrictracEnvironment as Environment>::StateType::size(),
+                dense_size,
+                <TrictracEnvironment as Environment>::ActionType::size(),
+            )
+            .load_record(record)
+        })
+        .ok()
+}
+
+pub fn demo_model<B: Backend, M: PPOModel<B>>(agent: PPO<TrictracEnvironment, B, M>) {
+    let mut env = TrictracEnvironment::new(true);
+    let mut done = false;
+    while !done {
+        // let action = match infer_action(&agent, &env, state) {
+        let action = match infer_action(&agent, &env) {
+            Some(value) => value,
+            None => break,
+        };
+        // Execute action
+        let snapshot = env.step(action);
+        done = snapshot.done();
+    }
+}
+
+fn infer_action<B: Backend, M: PPOModel<B>>(
+    agent: &PPO<TrictracEnvironment, B, M>,
+    env: &TrictracEnvironment,
+) -> Option<TrictracAction> {
+    let state = env.state();
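+    // Sketch of a possible implementation (assumption: the policy network is made reachable
+    // here, e.g. by passing a `&ppo_model::Net<NdArray<ElemType>>` into this function instead
+    // of the agent): get action probabilities with `net.infer(state.to_tensor().unsqueeze())`,
+    // mask the indices that are not in `get_valid_action_indices(&env.game)`, then take the
+    // argmax and convert it with `TrictracAction::from`, following the commented DQN-style
+    // code below.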
+    panic!("how to do that?");
+    None
+    // Get q-values
+    // let q_values = agent
+    //     .model()
+    //     .as_ref()
+    //     .unwrap()
+    //     .infer(state.to_tensor().unsqueeze());
+    // // Get valid actions
+    // let valid_actions_indices = get_valid_action_indices(&env.game);
+    // if valid_actions_indices.is_empty() {
+    //     return None; // No valid actions, end of episode
+    // }
+    // // Set non-valid actions' q-values to lowest
+    // let mut masked_q_values = q_values.clone();
+    // let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
+    // for (index, q_value) in q_values_vec.iter().enumerate() {
+    //     if !valid_actions_indices.contains(&index) {
+    //         masked_q_values = masked_q_values.clone().mask_fill(
+    //             masked_q_values.clone().equal_elem(*q_value),
+    //             f32::NEG_INFINITY,
+    //         );
+    //     }
+    // }
+    // // Get best action (highest q-value)
+    // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+    // let action = TrictracAction::from(action_index);
+    // Some(action)
+}