diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs
index 9755ce1..c1f71fb 100644
--- a/bot/src/burnrl/main.rs
+++ b/bot/src/burnrl/main.rs
@@ -9,17 +9,17 @@
 type Backend = Autodiff>;
 type Env = environment::TrictracEnvironment;
 
 fn main() {
-    println!("> Entraînement");
-    let num_episodes = 10;
-    let agent = dqn_model::run::(num_episodes, false); //true);
-
-    let valid_agent = agent.valid();
-
-    println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
-
-    println!("> Test avec le modèle entraîné");
-    demo_model::(valid_agent);
+    // println!("> Entraînement");
+    // let num_episodes = 20;
+    // let agent = dqn_model::run::(num_episodes, false); //true);
+    //
+    // let valid_agent = agent.valid();
+    //
+    // println!("> Sauvegarde du modèle de validation");
+    // save_model(valid_agent.model().as_ref().unwrap());
+    //
+    // println!("> Test avec le modèle entraîné");
+    // demo_model::(valid_agent);
 
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();
diff --git a/bot/src/burnrl/utils.rs b/bot/src/burnrl/utils.rs
index 1815c08..35620fe 100644
--- a/bot/src/burnrl/utils.rs
+++ b/bot/src/burnrl/utils.rs
@@ -1,56 +1,73 @@
+use crate::strategy::dqn_common::{get_valid_actions, TrictracAction};
+use crate::GameState;
 use burn::module::{Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Agent, ElemType, Environment};
+use burn_rl::agent::{DQNModel, DQN};
+use burn_rl::base::{ElemType, Environment, State};
 
-pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+pub fn demo_model<
+    // E: Environment,
+    E: Environment,
+    B: Backend,
+    M: DQNModel<B>,
+>(
+    agent: DQN<E, B, M>,
+) {
+    // pub fn demo_model<E: Environment, B: Backend, M: DQNModel<B>>(agent: DQN<E, B, M>) {
     let mut env = E::new(true);
     let mut state = env.state();
     let mut done = false;
     while !done {
-        // // Get q values for current state
-        // let model = agent.model().as_ref().unwrap();
-        // let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
-        // let q_values = model.infer(state_tensor);
-        //
-        // // Get valid actions
-        // let valid_actions = get_valid_actions(&state);
-        // if valid_actions.is_empty() {
-        //     break; // No valid actions, end of episode
-        // }
-        //
-        // // Set q values of non valid actions to the lowest
-        // let mut masked_q_values = q_values.clone();
-        // let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
-        // for (index, q_value) in q_values_vec.iter().enumerate() {
-        //     if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-        //         masked_q_values = masked_q_values.clone().mask_fill(
-        //             masked_q_values.clone().equal_elem(*q_value),
-        //             f32::NEG_INFINITY,
-        //         );
-        //     }
-        // }
-        //
-        // // Get action with the highest q-value
-        // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
-        // let action = E::ActionType::from(action_index);
-        //
-        // // Execute action
-        // let snapshot = env.step(action);
-        // state = *snapshot.state();
-        // // println!("{:?}", state);
-        // done = snapshot.done();
+        // Get q values for current state
+        let q_values = agent
+            .model()
+            .as_ref()
+            .unwrap()
+            .infer(state.to_tensor().unsqueeze());
 
-        if let Some(action) = agent.react(&state) {
-            // println!("before : {:?}", state);
-            // println!("action : {:?}", action);
-            let snapshot = env.step(action);
-            state = *snapshot.state();
-            // println!("after : {:?}", state);
-            // done = true;
-            done = snapshot.done();
+        // let actions_indexes = q_values.argmax(1).to_data().as_slice::().unwrap();
+        //let react_action = (actions_indexes[0] as u32).into();
+
+        // Get valid actions
+        let valid_actions = get_valid_actions(&state);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
         }
+
+        // Set q values of non valid actions to the lowest
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<ElemType> = q_values.into_data().into_vec().unwrap();
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
+            }
+        }
+
+        // Get action with the highest q-value
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        let action = E::ActionType::from(action_index);
+
+        // Execute action
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        // println!("{:?}", state);
+        done = snapshot.done();
+
+        // if let Some(action) = agent.react(&state) {
+        //     // println!("before : {:?}", state);
+        //     // println!("action : {:?}", action);
+        //     let snapshot = env.step(action);
+        //     state = *snapshot.state();
+        //     // println!("after : {:?}", state);
+        //     // done = true;
+        //     done = snapshot.done();
+        // }
     }
 }
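Note on the masking step above: mask_fill keyed on equal_elem(*q_value) masks every position holding that value, so a valid action that happens to share a Q-value with an invalid one would be masked as well. A minimal index-based sketch of the same selection in plain Rust (no burn tensors; masked_argmax is a hypothetical helper, not part of this patch):

    /// Pick the index of the highest Q-value among the valid action
    /// indices only, ignoring Q-values of invalid actions entirely.
    fn masked_argmax(q_values: &[f32], valid_indices: &[usize]) -> Option<usize> {
        valid_indices
            .iter()
            .copied()
            .filter(|&i| i < q_values.len())
            .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
    }

    fn main() {
        let q = [0.3, 0.9, 0.9, -1.2];
        // Actions 1 and 2 tie at 0.9; if only action 2 is valid, value-based
        // masking on 0.9 would hide both, while index-based masking keeps it.
        let valid = [0usize, 2];
        assert_eq!(masked_argmax(&q, &valid), Some(2));
    }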