wip action mask
This commit is contained in:
parent cb30fd3229
commit fa1a22f73d
@@ -9,17 +9,17 @@ type Backend = Autodiff<NdArray<ElemType>>;
 type Env = environment::TrictracEnvironment;
 
 fn main() {
-    println!("> Training");
-    let num_episodes = 10;
-    let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
-
-    let valid_agent = agent.valid();
-
-    println!("> Saving the validation model");
-    save_model(valid_agent.model().as_ref().unwrap());
-
-    println!("> Testing with the trained model");
-    demo_model::<Env>(valid_agent);
+    // println!("> Training");
+    // let num_episodes = 20;
+    // let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
+    //
+    // let valid_agent = agent.valid();
+    //
+    // println!("> Saving the validation model");
+    // save_model(valid_agent.model().as_ref().unwrap());
+    //
+    // println!("> Testing with the trained model");
+    // demo_model::<Env>(valid_agent);
 
     println!("> Loading the model for testing");
    let loaded_model = load_model();
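The save_model and load_model helpers called above are defined elsewhere in this crate and are not shown in this diff. As a rough sketch only, assuming burn's CompactRecorder together with the Module::save_file / Module::load_file API (the bodies, file name, and signatures below are guesses, not the repo's actual code):

use burn::module::Module;
use burn::record::CompactRecorder;
use burn::tensor::backend::Backend;

// Hypothetical save helper: save_file consumes the module, hence the clone.
fn save_model<B: Backend, M: Module<B>>(model: &M) {
    model
        .clone()
        .save_file("dqn_model", &CompactRecorder::new())
        .expect("failed to save model");
}

// Hypothetical load helper: load_file needs an already-constructed module of
// the same architecture and returns it with the recorded weights applied.
fn load_model<B: Backend, M: Module<B>>(model: M, device: &B::Device) -> M {
    model
        .load_file("dqn_model", &CompactRecorder::new(), device)
        .expect("failed to load model")
}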
@@ -1,56 +1,73 @@
 use crate::strategy::dqn_common::{get_valid_actions, TrictracAction};
 use crate::GameState;
 use burn::module::{Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
 use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Agent, ElemType, Environment};
+use burn_rl::agent::{DQNModel, DQN};
+use burn_rl::base::{ElemType, Environment, State};
 
-pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+pub fn demo_model<
+    // E: Environment<StateType = GameState, ActionType = TrictracAction>,
+    E: Environment,
+    B: Backend,
+    M: DQNModel<B>,
+>(
+    agent: DQN<E, B, M>,
+) {
+// pub fn demo_model<E: Environment, B: Backend, M: DQNModel<B>>(agent: DQN<E, B, M>) {
     let mut env = E::new(true);
     let mut state = env.state();
     let mut done = false;
     while !done {
-        // // Get q values for current state
-        // let model = agent.model().as_ref().unwrap();
-        // let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
-        // let q_values = model.infer(state_tensor);
-        //
-        // // Get valid actions
-        // let valid_actions = get_valid_actions(&state);
-        // if valid_actions.is_empty() {
-        //     break; // No valid actions, end of episode
-        // }
-        //
-        // // Set q values of non valid actions to the lowest
-        // let mut masked_q_values = q_values.clone();
-        // let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
-        // for (index, q_value) in q_values_vec.iter().enumerate() {
-        //     if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-        //         masked_q_values = masked_q_values.clone().mask_fill(
-        //             masked_q_values.clone().equal_elem(*q_value),
-        //             f32::NEG_INFINITY,
-        //         );
-        //     }
-        // }
-        //
-        // // Get action with the highest q-value
-        // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
-        // let action = E::ActionType::from(action_index);
-        //
-        // // Execute action
-        // let snapshot = env.step(action);
-        // state = *snapshot.state();
-        // // println!("{:?}", state);
-        // done = snapshot.done();
+        // Get q values for current state
+        let q_values = agent
+            .model()
+            .as_ref()
+            .unwrap()
+            .infer(state.to_tensor().unsqueeze());
 
-        if let Some(action) = agent.react(&state) {
-            // println!("before : {:?}", state);
-            // println!("action : {:?}", action);
-            let snapshot = env.step(action);
-            state = *snapshot.state();
-            // println!("after : {:?}", state);
-            // done = true;
-            done = snapshot.done();
-        }
+        // let actions_indexes = q_values.argmax(1).to_data().as_slice::<i64>().unwrap();
+        // let react_action = (actions_indexes[0] as u32).into();
+
+        // Get valid actions
+        let valid_actions = get_valid_actions(&state);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
+        }
+
+        // Set q values of non valid actions to the lowest
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
+            }
+        }
+
+        // Get action with the highest q-value
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        let action = E::ActionType::from(action_index);
+
+        // Execute action
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        // println!("{:?}", state);
+        done = snapshot.done();
+
+        // if let Some(action) = agent.react(&state) {
+        //     // println!("before : {:?}", state);
+        //     // println!("action : {:?}", action);
+        //     let snapshot = env.step(action);
+        //     state = *snapshot.state();
+        //     // println!("after : {:?}", state);
+        //     // done = true;
+        //     done = snapshot.done();
+        // }
     }
 }
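A note on the masking loop added above: mask_fill is driven by equal_elem(*q_value), so it blanks every entry whose q-value equals the invalid action's, and a valid action that happens to share that q-value would be masked too. A minimal sketch (not part of this commit) of masking by index instead, assuming a [1, num_actions] q-value tensor and burn's Bool tensor kind (names, signature, and the TensorData constructor depend on the burn version and are illustrative):

use burn::tensor::{backend::Backend, Bool, Tensor, TensorData};

// Hypothetical helper: push invalid actions to -inf via a boolean mask built
// from indices, instead of comparing q-values for equality.
fn mask_invalid_actions<B: Backend>(
    q_values: Tensor<B, 2>, // shape [1, num_actions]
    is_valid: &[bool],      // is_valid[i] == true if action i is legal
) -> Tensor<B, 2> {
    let device = q_values.device();
    // The mask is true where the action is invalid; those entries get filled.
    let invalid: Vec<bool> = is_valid.iter().map(|v| !v).collect();
    let mask = Tensor::<B, 2, Bool>::from_data(
        TensorData::new(invalid, [1, is_valid.len()]),
        &device,
    );
    q_values.mask_fill(mask, f32::NEG_INFINITY)
}

argmax over the returned tensor then selects the best valid action directly.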