action mask

This commit is contained in:
Henri Bourcereau 2025-07-26 09:37:54 +02:00
parent cb30fd3229
commit 3e1775428d
7 changed files with 111 additions and 554 deletions

View file

@ -1,57 +1,59 @@
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::strategy::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::base::{Agent, ElemType, Environment};
use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State};
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
let mut env = E::new(true);
let mut state = env.state();
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true);
let mut done = false;
while !done {
// // Get q values for current state
// let model = agent.model().as_ref().unwrap();
// let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
// let q_values = model.infer(state_tensor);
//
// // Get valid actions
// let valid_actions = get_valid_actions(&state);
// if valid_actions.is_empty() {
// break; // No valid actions, end of episode
// }
//
// // Set q values of non valid actions to the lowest
// let mut masked_q_values = q_values.clone();
// let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
// for (index, q_value) in q_values_vec.iter().enumerate() {
// if !valid_actions.contains(&E::ActionType::from(index as u32)) {
// masked_q_values = masked_q_values.clone().mask_fill(
// masked_q_values.clone().equal_elem(*q_value),
// f32::NEG_INFINITY,
// );
// }
// }
//
// // Get action with the highest q-value
// let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
// let action = E::ActionType::from(action_index);
//
// // Execute action
// let snapshot = env.step(action);
// state = *snapshot.state();
// // println!("{:?}", state);
// done = snapshot.done();
// let action = match infer_action(&agent, &env, state) {
let action = match infer_action(&agent, &env) {
Some(value) => value,
None => break,
};
// Execute action
let snapshot = env.step(action);
done = snapshot.done();
}
}
if let Some(action) = agent.react(&state) {
// println!("before : {:?}", state);
// println!("action : {:?}", action);
let snapshot = env.step(action);
state = *snapshot.state();
// println!("after : {:?}", state);
// done = true;
done = snapshot.done();
fn infer_action<B: Backend, M: DQNModel<B>>(
agent: &DQN<TrictracEnvironment, B, M>,
env: &TrictracEnvironment,
) -> Option<TrictracAction> {
let state = env.state();
// Get q-values
let q_values = agent
.model()
.as_ref()
.unwrap()
.infer(state.to_tensor().unsqueeze());
// Get valid actions
let valid_actions_indices = get_valid_action_indices(&env.game);
if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode
}
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
let action = TrictracAction::from(action_index);
Some(action)
}
fn soft_update_tensor<const N: usize, B: Backend>(