diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs
index 9465ec1..c9b249d 100644
--- a/bot/src/burnrl/dqn_model.rs
+++ b/bot/src/burnrl/dqn_model.rs
@@ -1,3 +1,4 @@
+use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -7,9 +8,8 @@ use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
-use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug, Clone)]
+#[derive(Module, Debug)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -18,7 +18,12 @@ pub struct Net<B: Backend> {
 
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        dense_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         Self {
             linear_0: LinearConfig::new(input_size, dense_size).init(device),
             linear_1: LinearConfig::new(dense_size, dense_size).init(device),
@@ -45,8 +50,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = self.consume();
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
 
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
@@ -107,8 +112,7 @@ pub fn run(
             DQN::>::react_with_exploration(&policy_net, state, eps_threshold);
         let snapshot = env.step(action);
 
-        episode_reward +=
-            >::into(snapshot.reward().clone());
+        episode_reward += >::into(snapshot.reward().clone());
 
         memory.push(
             state,
@@ -139,4 +143,5 @@ pub fn run(
         }
     }
     agent.valid()
-}
\ No newline at end of file
+}
+
diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs
index aa657ac..f9f511d 100644
--- a/bot/src/burnrl/main.rs
+++ b/bot/src/burnrl/main.rs
@@ -13,10 +13,8 @@ fn main() {
     let num_episodes = 3;
     let agent = dqn_model::run::(num_episodes, false); //true);
 
-    let valid_agent = agent.valid();
-
     println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
+    save_model(agent.model().as_ref().unwrap());
 
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();
@@ -58,4 +56,5 @@ fn load_model() -> dqn_model::Net> {
         &device,
     )
     .load_record(record)
-}
\ No newline at end of file
+}
+
diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs
index d17df4a..dcf08a2 100644
--- a/bot/src/burnrl/utils.rs
+++ b/bot/src/burnrl/utils.rs
@@ -1,14 +1,13 @@
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Action, ElemType, Environment, State};
 use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
 
-pub fn demo_model<E, M, B, F>(
-    agent: DQN<E, B, M>,
-    mut get_valid_actions: F,
-) where
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
     E: Environment,
     M: Module<B> + burn_rl::agent::DQNModel<B>,
     B: Backend,
@@ -36,12 +35,14 @@ pub fn demo_model<E, M, B, F>(
 
         for (index, q_value) in q_values_vec.iter().enumerate() {
             if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-                masked_q_values =
-                    masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
             }
         }
 
-        let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
         let action = E::ActionType::from(action_index);
 
         let snapshot = env.step(action);