wip fix action mask
This commit is contained in:
parent 66377f877c
commit ab0eb4be5a
```diff
@@ -1,3 +1,4 @@
+use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -7,9 +8,8 @@ use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
-use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug, Clone)]
+#[derive(Module, Debug)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -18,7 +18,12 @@ pub struct Net<B: Backend> {
 
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        dense_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         Self {
             linear_0: LinearConfig::new(input_size, dense_size).init(device),
             linear_1: LinearConfig::new(dense_size, dense_size).init(device),
```
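Two things change above: `Clone` drops out of the derive (burn's `#[derive(Module)]` already generates a `Clone` impl, so the explicit derive was redundant) and the constructor is reflowed to one parameter per line. For orientation, a hypothetical call site, assuming the `NdArray` backend used elsewhere in this commit; the 4/128/2 sizes are placeholders, not values from this repo:

```rust
use burn::backend::NdArray;
use burn_rl::base::ElemType;

fn build_net() -> Net<NdArray<ElemType>> {
    // Placeholder sizes: a 4-dimensional state, 128 hidden units, 2 actions.
    // The real sizes come from the environment elsewhere in the repo.
    let device = Default::default();
    Net::new(4, 128, 2, &device)
}
```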
```diff
@@ -45,8 +50,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = self.consume();
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
 
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
```
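The receiver changes from `self` to a plain `this: Self` parameter, apparently to match the signature `burn_rl`'s `DQNModel` trait declares. The repo's actual `soft_update_linear` lives in `burnrl::utils` and is not shown in this commit; a minimal sketch of what such a helper conventionally computes, the Polyak blend θ ← (1 − τ)·θ_target + τ·θ_online, assuming a recent burn `Param` API (illustration only, the real implementation may differ):

```rust
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn_rl::base::ElemType;

/// Polyak blend of one parameter tensor: (1 - tau) * this + tau * that.
fn soft_update_tensor<B: Backend, const D: usize>(
    this: Param<Tensor<B, D>>,
    that: &Param<Tensor<B, D>>,
    tau: ElemType,
) -> Param<Tensor<B, D>> {
    let blended = this.val() * (1.0 - tau) + that.val() * tau;
    Param::initialized(ParamId::new(), blended)
}

/// Apply the blend to a whole Linear layer (weight and optional bias).
pub fn soft_update_linear<B: Backend>(
    this: Linear<B>,
    that: &Linear<B>,
    tau: ElemType,
) -> Linear<B> {
    Linear {
        weight: soft_update_tensor(this.weight, &that.weight, tau),
        bias: match (this.bias, &that.bias) {
            (Some(bias), Some(that_bias)) => Some(soft_update_tensor(bias, that_bias, tau)),
            (bias, _) => bias,
        },
    }
}
```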
```diff
@@ -107,8 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
         let snapshot = env.step(action);
 
-        episode_reward +=
-            <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        episode_reward += <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
 
         memory.push(
             state,
@@ -139,4 +143,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         }
     }
     agent.valid()
 }
+
```
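The reward line keeps its verbose `<E::RewardType as Into<ElemType>>::into(...)` cast, now on a single line. Fully qualified syntax like this is how Rust pins a conversion when a bare `.into()` cannot infer the target type. A standalone illustration of the same pattern, with toy types that have nothing to do with this repo:

```rust
// `Meters` converts into both f32 and f64, so a bare `.into()` in an
// ambiguous context cannot tell which impl is meant.
struct Meters(f32);

impl From<Meters> for f32 {
    fn from(m: Meters) -> f32 {
        m.0
    }
}

impl From<Meters> for f64 {
    fn from(m: Meters) -> f64 {
        m.0 as f64
    }
}

fn main() {
    let d = Meters(1.5);
    // Fully qualified syntax selects the f32 impl explicitly,
    // mirroring `<E::RewardType as Into<ElemType>>::into(reward)`.
    let as_f32 = <Meters as Into<f32>>::into(d);
    println!("{as_f32}");
}
```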
```diff
@@ -13,10 +13,8 @@ fn main() {
     let num_episodes = 3;
     let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
 
-    let valid_agent = agent.valid();
-
     println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
+    save_model(agent.model().as_ref().unwrap());
 
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();
@@ -58,4 +56,5 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
         &device,
     )
     .load_record(record)
 }
+
```
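Since `run` now ends with `agent.valid()`, the second `agent.valid()` call in `main` was redundant and is removed. `save_model` itself is not shown in this commit; with burn, persisting a module record is typically done through a `Recorder` such as `CompactRecorder`. A hedged sketch of the save/load pair implied by the `load_record` call above, assuming a recent burn release; the `"dqn_model"` file name and the 4/128/2 sizes are placeholders:

```rust
use burn::backend::NdArray;
use burn::record::{CompactRecorder, Recorder};
use burn_rl::base::ElemType;

type B = NdArray<ElemType>;

fn save_model(model: &dqn_model::Net<B>) {
    // `save_file` consumes the module, hence the clone.
    model
        .clone()
        .save_file("dqn_model", &CompactRecorder::new())
        .expect("failed to save model record");
}

fn load_model() -> dqn_model::Net<B> {
    let device = Default::default();
    let record = CompactRecorder::new()
        .load("dqn_model".into(), &device)
        .expect("failed to load model record");
    // Rebuild the architecture, then hydrate the weights; the sizes
    // must match the ones used at save time.
    dqn_model::Net::new(4, 128, 2, &device).load_record(record)
}
```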
```diff
@@ -1,14 +1,13 @@
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Action, ElemType, Environment, State};
 use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
 
-pub fn demo_model<E, M, B, F>(
-    agent: DQN<E, B, M>,
-    mut get_valid_actions: F,
-) where
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
     E: Environment,
     M: Module<B> + burn_rl::agent::DQNModel<B>,
     B: Backend,
```
```diff
@@ -36,12 +35,14 @@ pub fn demo_model<E, M, B, F>(
 
     for (index, q_value) in q_values_vec.iter().enumerate() {
         if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-            masked_q_values =
-                masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+            masked_q_values = masked_q_values.clone().mask_fill(
+                masked_q_values.clone().equal_elem(*q_value),
+                f32::NEG_INFINITY,
+            );
         }
     }
 
-    let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+    let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
     let action = E::ActionType::from(action_index);
 
     let snapshot = env.step(action);
```
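Note the caveat in this hunk: masking by value (`equal_elem(*q_value)`) fills every position whose Q-value equals the invalid action's value, so a tie between a valid and an invalid action would erase both, which is presumably why the commit is still marked wip. An index-based boolean mask avoids the tie problem. A sketch assuming burn's tensor API; `mask_invalid` and the `valid` flags are hypothetical names, not code from this repo:

```rust
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor, TensorData};

/// Replace the Q-values of invalid actions with -inf before the argmax,
/// selecting positions by index instead of by value equality.
/// `valid` holds one flag per action; `q_values` has shape [1, n].
fn mask_invalid<B: Backend>(q_values: Tensor<B, 2>, valid: &[bool]) -> Tensor<B, 2> {
    let n = valid.len();
    let device = q_values.device();
    // 1 where the action is valid, 0 where it is invalid.
    let flags: Vec<i32> = valid.iter().map(|&v| v as i32).collect();
    let invalid_mask = Tensor::<B, 1, Int>::from_data(TensorData::new(flags, [n]), &device)
        .equal_elem(0)      // Bool tensor: true at the invalid positions
        .unsqueeze::<2>();  // reshape to [1, n] so it lines up with q_values
    q_values.mask_fill(invalid_mask, f32::NEG_INFINITY)
}
```

With something like this, the demo loop would compute `masked_q_values = mask_invalid(q_values, &flags)` once and keep the `argmax` line unchanged; the value-equality version in the diff only behaves correctly while all Q-values are distinct.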