wip action mask

Henri Bourcereau 2025-07-26 09:37:54 +02:00
parent cb30fd3229
commit fa1a22f73d
2 changed files with 70 additions and 53 deletions

View file

@@ -9,17 +9,17 @@ type Backend = Autodiff<NdArray<ElemType>>;
 type Env = environment::TrictracEnvironment;
 fn main() {
-    println!("> Entraînement");
-    let num_episodes = 10;
-    let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
-
-    let valid_agent = agent.valid();
-
-    println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
-
-    println!("> Test avec le modèle entraîné");
-    demo_model::<Env>(valid_agent);
+    // println!("> Entraînement");
+    // let num_episodes = 20;
+    // let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
+    //
+    // let valid_agent = agent.valid();
+    //
+    // println!("> Sauvegarde du modèle de validation");
+    // save_model(valid_agent.model().as_ref().unwrap());
+    //
+    // println!("> Test avec le modèle entraîné");
+    // demo_model::<Env>(valid_agent);
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();

View file

@@ -1,56 +1,73 @@
+use crate::strategy::dqn_common::{get_valid_actions, TrictracAction};
+use crate::GameState;
 use burn::module::{Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Agent, ElemType, Environment};
+use burn_rl::agent::{DQNModel, DQN};
+use burn_rl::base::{ElemType, Environment, State};

-pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+pub fn demo_model<
+    // E: Environment<StateType = GameState, ActionType = TrictracAction>,
+    E: Environment,
+    B: Backend,
+    M: DQNModel<B>,
+>(
+    agent: DQN<E, B, M>,
+) {
+    // pub fn demo_model<E: Environment, B: Backend, M: DQNModel<B>>(agent: DQN<E, B, M>) {
     let mut env = E::new(true);
     let mut state = env.state();
     let mut done = false;
     while !done {
-        // // Get q values for current state
-        // let model = agent.model().as_ref().unwrap();
-        // let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
-        // let q_values = model.infer(state_tensor);
-        //
-        // // Get valid actions
-        // let valid_actions = get_valid_actions(&state);
-        // if valid_actions.is_empty() {
-        //     break; // No valid actions, end of episode
-        // }
-        //
-        // // Set q values of non valid actions to the lowest
-        // let mut masked_q_values = q_values.clone();
-        // let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
-        // for (index, q_value) in q_values_vec.iter().enumerate() {
-        //     if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-        //         masked_q_values = masked_q_values.clone().mask_fill(
-        //             masked_q_values.clone().equal_elem(*q_value),
-        //             f32::NEG_INFINITY,
-        //         );
-        //     }
-        // }
-        //
-        // // Get action with the highest q-value
-        // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
-        // let action = E::ActionType::from(action_index);
-        //
-        // // Execute action
-        // let snapshot = env.step(action);
-        // state = *snapshot.state();
-        // // println!("{:?}", state);
-        // done = snapshot.done();
-        if let Some(action) = agent.react(&state) {
-            // println!("before : {:?}", state);
-            // println!("action : {:?}", action);
-            let snapshot = env.step(action);
-            state = *snapshot.state();
-            // println!("after : {:?}", state);
-            // done = true;
-            done = snapshot.done();
-        }
+        // Get q values for current state
+        let q_values = agent
+            .model()
+            .as_ref()
+            .unwrap()
+            .infer(state.to_tensor().unsqueeze());
+
+        // let actions_indexes = q_values.argmax(1).to_data().as_slice::<i64>().unwrap();
+        //let react_action = (actions_indexes[0] as u32).into();
+
+        // Get valid actions
+        let valid_actions = get_valid_actions(&state);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
+        }
+
+        // Set q values of non valid actions to the lowest
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
+            }
+        }
+
+        // Get action with the highest q-value
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        let action = E::ActionType::from(action_index);
+
+        // Execute action
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        // println!("{:?}", state);
+        done = snapshot.done();
+
+        // if let Some(action) = agent.react(&state) {
+        //     // println!("before : {:?}", state);
+        //     // println!("action : {:?}", action);
+        //     let snapshot = env.step(action);
+        //     state = *snapshot.state();
+        //     // println!("after : {:?}", state);
+        //     // done = true;
+        //     done = snapshot.done();
+        // }
     }
 }
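
Note on the masking step: the new loop masks by value, filling every q-value equal to an invalid action's q-value with NEG_INFINITY. If a valid action happens to share its q-value with an invalid one, both are masked and the valid action is lost. A minimal CPU-side sketch of index-based masking over the same data, assuming valid actions are given as their u32 indices; the masked_argmax helper and the standalone main are illustrative, not part of the commit:

// Index-based action mask: pick the best q-value among valid action
// indexes directly, instead of rewriting the q tensor by value.
fn masked_argmax(q_values: &[f32], valid_action_indexes: &[u32]) -> Option<u32> {
    q_values
        .iter()
        .copied()
        .enumerate()
        .filter(|(index, _)| valid_action_indexes.contains(&(*index as u32)))
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(index, _)| index as u32)
}

fn main() {
    let q_values = [0.3, 0.9, 0.9, 0.1];
    // Action 1 is invalid but shares its q-value with valid action 2:
    // value-based masking would erase both; index-based masking keeps 2.
    assert_eq!(masked_argmax(&q_values, &[0, 2, 3]), Some(2));
}

In the loop above, this would replace the mask_fill block: the returned index goes straight into E::ActionType::from.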