wip fix action mask
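
This work-in-progress pass touches the DQN model, the training entry
point, and the demo utilities:

- move the `soft_update_linear` import to the top of the model module,
  and drop the explicit `Clone` derive on `Net` (burn's `Module` derive
  is expected to supply `Clone` already)
- rename the `soft_update` receiver from `self` to `this`, matching the
  `DQNModel` trait signature
- in `main`, save the agent returned by `run` directly instead of
  calling `valid()` on it a second time
- clone `masked_q_values` before `mask_fill` so the receiver is not
  moved while its clone builds the mask, and convert the argmax scalar
  with `ToElement::to_u32` rather than an `as` cast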

Henri Bourcereau 2025-07-24 17:46:47 +02:00
parent 66377f877c
commit ab0eb4be5a
3 changed files with 25 additions and 20 deletions


@@ -1,3 +1,4 @@
+use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -7,9 +8,8 @@ use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
-use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug, Clone)]
+#[derive(Module, Debug)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -18,7 +18,12 @@ pub struct Net<B: Backend> {
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        dense_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         Self {
             linear_0: LinearConfig::new(input_size, dense_size).init(device),
             linear_1: LinearConfig::new(dense_size, dense_size).init(device),
             linear_2: LinearConfig::new(dense_size, output_size).init(device),
@@ -45,8 +50,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = self.consume();
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
             linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
@@ -107,8 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         let action =
             DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
         let snapshot = env.step(action);
-        episode_reward +=
-            <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        episode_reward += <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
 
         memory.push(
             state,
@@ -139,4 +143,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         }
     }
+
     agent.valid()
 }
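
The `soft_update_linear` helper imported above lives in `crate::burnrl::utils`,
which this diff does not show. For reference, a minimal sketch of such a helper,
assuming Polyak averaging (target <- (1 - tau) * target + tau * policy) and
burn's `Param::initialized` API; this is an illustration, not the commit's
actual code:

    use burn::module::{Param, ParamId};
    use burn::nn::Linear;
    use burn::tensor::backend::Backend;
    use burn::tensor::Tensor;
    use burn_rl::base::ElemType;

    // Blend one parameter tensor toward another: (1 - tau) * this + tau * that.
    fn soft_update_tensor<B: Backend, const D: usize>(
        this: Param<Tensor<B, D>>,
        that: &Param<Tensor<B, D>>,
        tau: ElemType,
    ) -> Param<Tensor<B, D>> {
        let updated = this.val() * (1.0 - tau) + that.val() * tau;
        Param::initialized(ParamId::new(), updated)
    }

    // Apply the blend to the weight and the optional bias of a Linear layer.
    pub fn soft_update_linear<B: Backend>(
        this: Linear<B>,
        that: &Linear<B>,
        tau: ElemType,
    ) -> Linear<B> {
        Linear {
            weight: soft_update_tensor(this.weight, &that.weight, tau),
            bias: match (this.bias, &that.bias) {
                (Some(bias), Some(other)) => Some(soft_update_tensor(bias, other, tau)),
                _ => None,
            },
        }
    }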


@@ -13,10 +13,8 @@ fn main() {
     let num_episodes = 3;
     let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
 
-    let valid_agent = agent.valid();
-
     println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
+    save_model(agent.model().as_ref().unwrap());
 
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();
@@ -58,4 +56,5 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
         &device,
     )
     .load_record(record)
 }
+
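
Only the tail of `load_model` appears in the hunk above. For orientation, a
sketch of a matching `save_model`, under the assumption that the project uses
burn's `CompactRecorder`; the recorder choice and the path are placeholders:

    use burn::backend::NdArray;
    use burn::module::Module;
    use burn::record::CompactRecorder;
    use burn_rl::base::ElemType;

    // `save_file` consumes the module, hence the clone; the record format
    // must match the recorder behind `load_model`'s `load_record` call.
    fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
        model
            .clone()
            .save_file("model/dqn", &CompactRecorder::new())
            .expect("failed to save model");
    }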


@@ -1,14 +1,13 @@
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Action, ElemType, Environment, State};
 use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
 
-pub fn demo_model<E, M, B, F>(
-    agent: DQN<E, B, M>,
-    mut get_valid_actions: F,
-) where
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
     E: Environment,
     M: Module<B> + burn_rl::agent::DQNModel<B>,
     B: Backend,
@@ -36,12 +35,14 @@ pub fn demo_model<E, M, B, F>(
     for (index, q_value) in q_values_vec.iter().enumerate() {
         if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-            masked_q_values =
-                masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+            masked_q_values = masked_q_values.clone().mask_fill(
+                masked_q_values.clone().equal_elem(*q_value),
+                f32::NEG_INFINITY,
+            );
         }
     }
 
-    let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+    let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
     let action = E::ActionType::from(action_index);
     let snapshot = env.step(action);
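
Note that the mask is still selected by value: `equal_elem(*q_value)` marks
every position holding that Q-value, so a valid action that happens to tie
with an invalid one gets masked as well, which is presumably why the commit
is still tagged wip. A sketch of an index-based mask that avoids the
collision; the helper below is an illustration, not code from this repository:

    use burn::tensor::backend::Backend;
    use burn::tensor::{Bool, Tensor, TensorData};

    // Overwrite the Q-values of invalid actions with -inf, selecting them
    // by index so that duplicate Q-values can never collide.
    fn mask_invalid_actions<B: Backend>(
        q_values: Tensor<B, 2>, // shape [1, num_actions]
        is_valid: &[bool],      // is_valid[i] is true when action i is allowed
    ) -> Tensor<B, 2> {
        let device = q_values.device();
        let invalid: Vec<bool> = is_valid.iter().map(|v| !v).collect();
        let mask = Tensor::<B, 2, Bool>::from_data(
            TensorData::new(invalid, [1, is_valid.len()]),
            &device,
        );
        q_values.mask_fill(mask, f32::NEG_INFINITY)
    }

Taking `argmax` over the masked tensor then yields the best valid action directly.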