diff --git a/bot/src/burnrl/dqn_model.rs b/bot/src/burnrl/dqn_model.rs
index 221b391..9465ec1 100644
--- a/bot/src/burnrl/dqn_model.rs
+++ b/bot/src/burnrl/dqn_model.rs
@@ -1,16 +1,15 @@
-use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
-use burn::record::{CompactRecorder, Recorder};
 use burn::tensor::activation::relu;
 use burn::tensor::backend::{AutodiffBackend, Backend};
 use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug)]
+#[derive(Module, Debug, Clone)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -19,11 +18,11 @@ pub struct Net<B: Backend> {
 
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
         Self {
-            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
-            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
-            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+            linear_0: LinearConfig::new(input_size, dense_size).init(device),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(device),
+            linear_2: LinearConfig::new(dense_size, output_size).init(device),
         }
     }
 
@@ -34,7 +33,7 @@ impl<B: Backend> Net<B> {
 
 impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
     fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
-        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_0_output = relu(self.linear_0.forward(input.clone()));
         let layer_1_output = relu(self.linear_1.forward(layer_0_output));
 
         relu(self.linear_2.forward(layer_1_output))
@@ -46,8 +45,8 @@ impl<B: Backend> Model<Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = this.consume();
+    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = self.consume();
 
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
@@ -70,14 +69,15 @@ type MyAgent<E, B> = DQN<E, B, Net<B>>;
 pub fn run<E: Environment, B: AutodiffBackend>(
     num_episodes: usize,
     visualized: bool,
-) -> DQN<E, B, Net<B>> {
-    // ) -> impl Agent<E> {
+) -> impl Agent<E> {
     let mut env = E::new(visualized);
+    let device = Default::default();
 
     let model = Net::<B>::new(
-        <<E as Environment>::StateType as State>::size(),
+        <E::StateType as State>::size(),
         DENSE_SIZE,
-        <<E as Environment>::ActionType as Action>::size(),
+        <E::ActionType as Action>::size(),
+        &device,
     );
 
     let mut agent = MyAgent::new(model);
@@ -108,7 +108,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         let snapshot = env.step(action);
 
         episode_reward +=
-            <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+            <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
 
         memory.push(
             state,
@@ -119,8 +119,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         );
 
         if config.batch_size < memory.len() {
-            policy_net =
-                agent.train::(policy_net, &memory, &mut optimizer, &config);
+            policy_net = agent.train(policy_net, &memory, &mut optimizer, &config);
         }
 
         step += 1;
@@ -139,5 +138,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             }
         }
     }
-    agent
-}
+    agent.valid()
+}
\ No newline at end of file
diff --git a/bot/src/burnrl/environment.rs b/bot/src/burnrl/environment.rs
index 669d3b4..e6faf80 100644
--- a/bot/src/burnrl/environment.rs
+++ b/bot/src/burnrl/environment.rs
@@ -199,6 +199,15 @@ impl Environment for TrictracEnvironment {
 }
 
 impl TrictracEnvironment {
+    pub fn valid_actions(&self) -> Vec<TrictracAction> {
+        dqn_common::get_valid_actions(&self.game)
+            .into_iter()
+            .map(|a| TrictracAction {
+                index: a.to_action_index() as u32,
+            })
+            .collect()
+    }
+
     /// Convertit une action burn-rl vers une action Trictrac
     fn convert_action(
         &self,
@@ -380,4 +389,4 @@ impl TrictracEnvironment {
         }
         reward
     }
-}
+}
\ No newline at end of file
diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs
index 41a29e2..aa657ac 100644
--- a/bot/src/burnrl/main.rs
+++ b/bot/src/burnrl/main.rs
@@ -23,7 +23,7 @@ fn main() {
     let loaded_agent = DQN::new(loaded_model);
 
     println!("> Test avec le modèle chargé");
-    demo_model::(loaded_agent);
+    demo_model(loaded_agent, |env| env.valid_actions());
 }
 
 fn save_model(model: &dqn_model::Net>) {
@@ -55,6 +55,7 @@ fn load_model() -> dqn_model::Net> {
         <TrictracEnvironment as Environment>::StateType::size(),
         DENSE_SIZE,
         <TrictracEnvironment as Environment>::ActionType::size(),
+        &device,
     )
     .load_record(record)
-}
+}
\ No newline at end of file
diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs
index bc8d836..d17df4a 100644
--- a/bot/src/burnrl/utils.rs
+++ b/bot/src/burnrl/utils.rs
@@ -1,21 +1,60 @@
-use burn::module::{Param, ParamId};
+use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
 use burn::tensor::Tensor;
-use burn_rl::base::{Agent, ElemType, Environment};
+use burn_rl::base::{Action, ElemType, Environment, State};
+use burn_rl::agent::DQN;
 
-pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+pub fn demo_model<E, M, B, F>(
+    agent: DQN<E, B, M>,
+    mut get_valid_actions: F,
+) where
+    E: Environment,
+    M: Module<B> + burn_rl::agent::DQNModel<B>,
+    B: Backend,
+    F: FnMut(&E) -> Vec<E::ActionType>,
+    <E as Environment>::ActionType: PartialEq,
+{
     let mut env = E::new(true);
     let mut state = env.state();
     let mut done = false;
+    let mut total_reward = 0.0;
+    let mut steps = 0;
+
     while !done {
-        if let Some(action) = agent.react(&state) {
-            let snapshot = env.step(action);
-            state = *snapshot.state();
-            // println!("{:?}", state);
-            done = snapshot.done();
+        let model = agent.model().as_ref().unwrap();
+        let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
+        let q_values = model.infer(state_tensor);
+
+        let valid_actions = get_valid_actions(&env);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
         }
+
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
+
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values =
+                    masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+            }
+        }
+
+        let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+        let action = E::ActionType::from(action_index);
+
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        total_reward +=
+            <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        steps += 1;
+        done = snapshot.done() || steps >= E::MAX_STEPS;
     }
+    println!(
+        "Episode terminé. Récompense totale: {:.2}, Étapes: {}",
+        total_reward, steps
+    );
 }
 
 fn soft_update_tensor(
diff --git a/bot/src/strategy/default.rs b/bot/src/strategy/default.rs
index 81aa5f1..e01f406 100644
--- a/bot/src/strategy/default.rs
+++ b/bot/src/strategy/default.rs
@@ -1,4 +1,4 @@
-use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use store::MoveRules;
 
 #[derive(Debug)]
diff --git a/bot/src/strategy/dqn.rs b/bot/src/strategy/dqn.rs
index 779ce3d..6f22fac 100644
--- a/bot/src/strategy/dqn.rs
+++ b/bot/src/strategy/dqn.rs
@@ -1,4 +1,4 @@
-use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use std::path::Path;
 
 use store::MoveRules;
diff --git a/bot/src/strategy/dqn_common.rs b/bot/src/strategy/dqn_common.rs
index 9a24ae6..1993a5d 100644
--- a/bot/src/strategy/dqn_common.rs
+++ b/bot/src/strategy/dqn_common.rs
@@ -1,7 +1,7 @@
 use std::cmp::{max, min};
 
 use serde::{Deserialize, Serialize};
-use store::{CheckerMove, Dice, GameEvent, PlayerId};
+use store::{CheckerMove, Dice};
 
 /// Types d'actions possibles dans le jeu
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -259,7 +259,7 @@ impl SimpleNeuralNetwork {
 
 /// Obtient les actions valides pour l'état de jeu actuel
 pub fn get_valid_actions(game_state: &crate::GameState) -> Vec {
-    use crate::PointsRules;
+    use store::TurnStage;
 
     let mut valid_actions = Vec::new();
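
Note on the masked greedy step added in demo_model: the Q-values of actions that are not currently playable are pushed to f32::NEG_INFINITY before the argmax, so the greedy policy can only ever select a valid move. The standalone sketch below shows the same idea on plain f32 slices, selecting by index rather than by value equality; the argmax_valid helper, the slice types and the main function are illustrative only and are not part of the patch.

// Illustrative sketch only (not part of the patch): greedy action selection
// restricted to a set of valid action indices.
fn argmax_valid(q_values: &[f32], valid_indices: &[u32]) -> Option<u32> {
    valid_indices
        .iter()
        .copied()
        // ignore indices that fall outside the Q-value vector
        .filter(|&i| (i as usize) < q_values.len())
        // keep the valid index with the highest Q-value
        .max_by(|&a, &b| {
            q_values[a as usize]
                .partial_cmp(&q_values[b as usize])
                .unwrap_or(std::cmp::Ordering::Equal)
        })
}

fn main() {
    let q = [0.1, 0.7, 0.3, 0.7];
    // only actions 0 and 2 are playable in the current game state
    assert_eq!(argmax_valid(&q, &[0, 2]), Some(2));
    // no playable action left: the episode ends, as in demo_model
    assert_eq!(argmax_valid(&q, &[]), None);
}

Selecting by index also sidesteps the edge case where a valid and an invalid action happen to share the exact same Q-value: masking with equal_elem hides both positions, while an index-based filter keeps the valid one.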