wip action mask

Henri Bourcereau 2025-07-26 09:37:54 +02:00
parent cb30fd3229
commit fa1a22f73d
2 changed files with 70 additions and 53 deletions

View file

@@ -9,17 +9,17 @@ type Backend = Autodiff<NdArray<ElemType>>;
 type Env = environment::TrictracEnvironment;
 
 fn main() {
-    // println!("> Training");
-    // let num_episodes = 20;
-    // let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
-    //
-    // let valid_agent = agent.valid();
-    //
-    // println!("> Saving the validation model");
-    // save_model(valid_agent.model().as_ref().unwrap());
-    //
-    // println!("> Testing with the trained model");
-    // demo_model::<Env>(valid_agent);
+    println!("> Training");
+    let num_episodes = 10;
+    let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
+
+    let valid_agent = agent.valid();
+
+    println!("> Saving the validation model");
+    save_model(valid_agent.model().as_ref().unwrap());
+
+    println!("> Testing with the trained model");
+    demo_model::<Env>(valid_agent);
     println!("> Loading the model for testing");
     let loaded_model = load_model();

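Note: save_model and load_model are called above but defined outside this commit. As a rough sketch only — assuming burn's CompactRecorder and an artifacts/dqn_model path, neither of which is confirmed by this diff — such helpers typically look like the following:

    use burn::module::Module;
    use burn::record::CompactRecorder;
    use burn::tensor::backend::Backend;

    // Hypothetical path; the real save_model/load_model bodies are not in this diff.
    const MODEL_PATH: &str = "artifacts/dqn_model";

    /// Persist a trained module to disk with burn's CompactRecorder.
    fn save_model<B: Backend, M: Module<B>>(model: &M) {
        model
            .clone() // save_file consumes the module by value
            .save_file(MODEL_PATH, &CompactRecorder::new())
            .expect("failed to save model");
    }

    /// Restore parameters into a freshly initialized module of the same shape.
    fn load_model<B: Backend, M: Module<B>>(model: M, device: &B::Device) -> M {
        model
            .load_file(MODEL_PATH, &CompactRecorder::new(), device)
            .expect("failed to load model")
    }

Loading starts from an initialized module and swaps in the recorded parameters, which is why load_model takes a module and a device rather than constructing one itself.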
View file

@@ -1,56 +1,73 @@
 use crate::strategy::dqn_common::{get_valid_actions, TrictracAction};
 use crate::GameState;
 use burn::module::{Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
 use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Agent, ElemType, Environment};
+use burn_rl::agent::{DQNModel, DQN};
+use burn_rl::base::{ElemType, Environment, State};
 
-pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+pub fn demo_model<
+    // E: Environment<StateType = GameState, ActionType = TrictracAction>,
+    E: Environment,
+    B: Backend,
+    M: DQNModel<B>,
+>(
+    agent: DQN<E, B, M>,
+) {
+    // pub fn demo_model<E: Environment, B: Backend, M: DQNModel<B>>(agent: DQN<E, B, M>) {
     let mut env = E::new(true);
     let mut state = env.state();
     let mut done = false;
     while !done {
-        // // Get q values for current state
-        // let model = agent.model().as_ref().unwrap();
-        // let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
-        // let q_values = model.infer(state_tensor);
-        //
-        // // Get valid actions
-        // let valid_actions = get_valid_actions(&state);
-        // if valid_actions.is_empty() {
-        //     break; // No valid actions, end of episode
-        // }
-        //
-        // // Set q values of non valid actions to the lowest
-        // let mut masked_q_values = q_values.clone();
-        // let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
-        // for (index, q_value) in q_values_vec.iter().enumerate() {
-        //     if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-        //         masked_q_values = masked_q_values.clone().mask_fill(
-        //             masked_q_values.clone().equal_elem(*q_value),
-        //             f32::NEG_INFINITY,
-        //         );
-        //     }
-        // }
-        //
-        // // Get action with the highest q-value
-        // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
-        // let action = E::ActionType::from(action_index);
-        //
-        // // Execute action
-        // let snapshot = env.step(action);
-        // state = *snapshot.state();
-        // // println!("{:?}", state);
-        // done = snapshot.done();
-        if let Some(action) = agent.react(&state) {
-            // println!("before : {:?}", state);
-            // println!("action : {:?}", action);
-            let snapshot = env.step(action);
-            state = *snapshot.state();
-            // println!("after : {:?}", state);
-            // done = true;
-            done = snapshot.done();
-        }
+        // Get q values for current state
+        let q_values = agent
+            .model()
+            .as_ref()
+            .unwrap()
+            .infer(state.to_tensor().unsqueeze());
+        // let actions_indexes = q_values.argmax(1).to_data().as_slice::<i64>().unwrap();
+        //let react_action = (actions_indexes[0] as u32).into();
+
+        // Get valid actions
+        let valid_actions = get_valid_actions(&state);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
+        }
+
+        // Set q values of non valid actions to the lowest
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
+            }
+        }
+
+        // Get action with the highest q-value
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        let action = E::ActionType::from(action_index);
+
+        // Execute action
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        // println!("{:?}", state);
+        done = snapshot.done();
+        // if let Some(action) = agent.react(&state) {
+        //     // println!("before : {:?}", state);
+        //     // println!("action : {:?}", action);
+        //     let snapshot = env.step(action);
+        //     state = *snapshot.state();
+        //     // println!("after : {:?}", state);
+        //     // done = true;
+        //     done = snapshot.done();
+        // }
     }
 }
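One caveat in the masking loop above: mask_fill is keyed on equal_elem(*q_value), which matches every position holding that value, so a valid action that happens to share a Q-value with an invalid one gets masked as well. Since q_values_vec is already extracted, the argmax can instead be taken index-wise on the plain vector; a minimal sketch (function and names illustrative, not part of the commit):

    /// Index of the highest q-value among the valid action indices.
    /// Returns None when no action is valid (end of episode).
    fn masked_argmax(q_values: &[f32], valid_indices: &[usize]) -> Option<usize> {
        valid_indices
            .iter()
            .copied()
            .filter(|&i| i < q_values.len())
            .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
    }

    fn main() {
        let q = [0.3, 0.9, 0.1, 0.9];
        // Index 1 is invalid; index 3 shares its q-value yet stays selectable,
        // whereas value-based masking with equal_elem would exclude both.
        assert_eq!(masked_argmax(&q, &[0, 2, 3]), Some(3));
    }

total_cmp gives a total order on f32, so duplicate values are handled per index and the clone/mask_fill tensor round-trips disappear.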