action mask
This commit is contained in:
parent
cb30fd3229
commit
3e1775428d
7 changed files with 111 additions and 554 deletions
|
|
@ -1,57 +1,59 @@
|
|||
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
|
||||
use crate::strategy::dqn_common::get_valid_action_indices;
|
||||
use burn::module::{Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::base::{Agent, ElemType, Environment};
|
||||
use burn_rl::agent::{DQNModel, DQN};
|
||||
use burn_rl::base::{ElemType, Environment, State};
|
||||
|
||||
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
|
||||
let mut env = E::new(true);
|
||||
let mut state = env.state();
|
||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||
let mut env = TrictracEnvironment::new(true);
|
||||
let mut done = false;
|
||||
while !done {
|
||||
// // Get q values for current state
|
||||
// let model = agent.model().as_ref().unwrap();
|
||||
// let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
|
||||
// let q_values = model.infer(state_tensor);
|
||||
//
|
||||
// // Get valid actions
|
||||
// let valid_actions = get_valid_actions(&state);
|
||||
// if valid_actions.is_empty() {
|
||||
// break; // No valid actions, end of episode
|
||||
// }
|
||||
//
|
||||
// // Set q values of non valid actions to the lowest
|
||||
// let mut masked_q_values = q_values.clone();
|
||||
// let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
// for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
// if !valid_actions.contains(&E::ActionType::from(index as u32)) {
|
||||
// masked_q_values = masked_q_values.clone().mask_fill(
|
||||
// masked_q_values.clone().equal_elem(*q_value),
|
||||
// f32::NEG_INFINITY,
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Get action with the highest q-value
|
||||
// let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
// let action = E::ActionType::from(action_index);
|
||||
//
|
||||
// // Execute action
|
||||
// let snapshot = env.step(action);
|
||||
// state = *snapshot.state();
|
||||
// // println!("{:?}", state);
|
||||
// done = snapshot.done();
|
||||
// let action = match infer_action(&agent, &env, state) {
|
||||
let action = match infer_action(&agent, &env) {
|
||||
Some(value) => value,
|
||||
None => break,
|
||||
};
|
||||
// Execute action
|
||||
let snapshot = env.step(action);
|
||||
done = snapshot.done();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(action) = agent.react(&state) {
|
||||
// println!("before : {:?}", state);
|
||||
// println!("action : {:?}", action);
|
||||
let snapshot = env.step(action);
|
||||
state = *snapshot.state();
|
||||
// println!("after : {:?}", state);
|
||||
// done = true;
|
||||
done = snapshot.done();
|
||||
fn infer_action<B: Backend, M: DQNModel<B>>(
|
||||
agent: &DQN<TrictracEnvironment, B, M>,
|
||||
env: &TrictracEnvironment,
|
||||
) -> Option<TrictracAction> {
|
||||
let state = env.state();
|
||||
// Get q-values
|
||||
let q_values = agent
|
||||
.model()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.infer(state.to_tensor().unsqueeze());
|
||||
// Get valid actions
|
||||
let valid_actions_indices = get_valid_action_indices(&env.game);
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
// Set non valid actions q-values to lowest
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions_indices.contains(&index) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
let action = TrictracAction::from(action_index);
|
||||
Some(action)
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue