wip fix action mask
commit ab0eb4be5a
parent 66377f877c
@@ -1,3 +1,4 @@
+use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -7,9 +8,8 @@ use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
-use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug, Clone)]
+#[derive(Module, Debug)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -18,7 +18,12 @@ pub struct Net<B: Backend> {
 
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        dense_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         Self {
             linear_0: LinearConfig::new(input_size, dense_size).init(device),
             linear_1: LinearConfig::new(dense_size, dense_size).init(device),
@@ -45,8 +50,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = self.consume();
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
 
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
@@ -107,8 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
         let snapshot = env.step(action);
 
-        episode_reward +=
-            <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        episode_reward += <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
 
         memory.push(
             state,
@@ -139,4 +143,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         }
     }
+    agent.valid()
 }
 }
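Context for the soft_update hunk above: soft_update_linear, now imported at the top of the file, is assumed to apply the usual Polyak averaging rule to each Linear layer's weight and bias. Below is a minimal std-only sketch of that rule over flat parameter slices, assuming the update new = (1 - tau) * current + tau * other; the helper name is illustrative and not part of this repository.

// Hedged sketch: Polyak (soft) update over flat parameter slices.
// The real soft_update_linear presumably performs the same element-wise
// arithmetic on the weight and bias tensors of a burn Linear layer.
fn polyak_update(current: &mut [f32], other: &[f32], tau: f32) {
    for (c, o) in current.iter_mut().zip(other) {
        *c = *c * (1.0 - tau) + *o * tau;
    }
}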
@@ -13,10 +13,8 @@ fn main() {
     let num_episodes = 3;
     let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
 
-    let valid_agent = agent.valid();
-
     println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
+    save_model(agent.model().as_ref().unwrap());
 
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();
@@ -58,4 +56,5 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
         &device,
     )
+    .load_record(record)
 }
 }
@@ -1,14 +1,13 @@
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Action, ElemType, Environment, State};
 use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
 
-pub fn demo_model<E, M, B, F>(
-    agent: DQN<E, B, M>,
-    mut get_valid_actions: F,
-) where
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
     E: Environment,
     M: Module<B> + burn_rl::agent::DQNModel<B>,
     B: Backend,
@@ -36,12 +35,14 @@ pub fn demo_model<E, M, B, F>(
 
     for (index, q_value) in q_values_vec.iter().enumerate() {
         if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-            masked_q_values =
-                masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+            masked_q_values = masked_q_values.clone().mask_fill(
+                masked_q_values.clone().equal_elem(*q_value),
+                f32::NEG_INFINITY,
+            );
         }
     }
 
-    let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+    let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
     let action = E::ActionType::from(action_index);
 
     let snapshot = env.step(action);
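The demo_model hunk above masks invalid actions by overwriting any Q-value equal to theirs with f32::NEG_INFINITY and then taking the argmax. For comparison, here is a hedged std-only sketch of the same selection keyed on the action index over the flattened Q-values (q_values_vec in the code above); the function is illustrative and not part of this repository.

// Illustration only: pick the best action among the valid indices.
// Same intent as masking invalid entries with NEG_INFINITY before argmax,
// but the mask is keyed on the index rather than on the Q-value itself.
fn best_valid_action(q_values: &[f32], is_valid: impl Fn(usize) -> bool) -> Option<usize> {
    q_values
        .iter()
        .enumerate()
        .filter(|(index, _)| is_valid(*index))
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(index, _)| index)
}

Unlike the NEG_INFINITY mask, this returns None when no action is valid instead of silently returning whichever index argmax picks when every entry is negative infinity.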