wip fix action mask

Henri Bourcereau 2025-07-24 17:46:47 +02:00
parent 66377f877c
commit ab0eb4be5a
3 changed files with 25 additions and 20 deletions

@@ -1,3 +1,4 @@
+use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -7,9 +8,8 @@ use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
-use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug, Clone)]
+#[derive(Module, Debug)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -18,7 +18,12 @@ pub struct Net<B: Backend> {
 
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
+    pub fn new(
+        input_size: usize,
+        dense_size: usize,
+        output_size: usize,
+        device: &B::Device,
+    ) -> Self {
         Self {
             linear_0: LinearConfig::new(input_size, dense_size).init(device),
             linear_1: LinearConfig::new(dense_size, dense_size).init(device),
@@ -45,8 +50,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = self.consume();
+    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = this.consume();
 
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
@@ -107,8 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         let action =
             DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
         let snapshot = env.step(action);
-        episode_reward +=
-            <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        episode_reward += <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
 
         memory.push(
             state,
@@ -139,4 +143,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         }
     }
+    agent.valid()
 }
 }
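Note on the soft_update change above: the per-layer blending is delegated to soft_update_linear from crate::burnrl::utils, which is not part of this diff. Presumably it applies the usual Polyak averaging, target = tau * online + (1 - tau) * target, to each weight and bias tensor. A minimal sketch of that rule on plain slices, for illustration only and not the project's burn-based implementation:

    /// Polyak / soft update: target <- tau * online + (1 - tau) * target.
    /// Sketch on plain f32 slices; soft_update_linear presumably applies the
    /// same formula to the weight and bias tensors of each Linear layer.
    fn soft_update_slice(online: &[f32], target: &mut [f32], tau: f32) {
        for (t, o) in target.iter_mut().zip(online) {
            *t = tau * o + (1.0 - tau) * *t;
        }
    }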

@@ -13,10 +13,8 @@ fn main() {
 
     let num_episodes = 3;
     let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
 
-    let valid_agent = agent.valid();
-    println!("> Sauvegarde du modèle de validation");
-    save_model(valid_agent.model().as_ref().unwrap());
+    save_model(agent.model().as_ref().unwrap());
 
     println!("> Chargement du modèle pour test");
    let loaded_model = load_model();
@@ -58,4 +56,5 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
         &device,
     )
+    .load_record(record)
 }
 }
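For context, save_model and load_model from this second file are only partially visible here. A minimal sketch of what a matching save_model could look like, assuming burn's CompactRecorder and an arbitrary "model" file path (both are assumptions, not taken from this commit):

    use burn::module::Module;
    use burn::record::CompactRecorder;
    use burn::tensor::backend::Backend;
    use std::path::PathBuf;

    // Hypothetical sketch: persist the trained network so that load_model()
    // can rebuild it via Net::new(..).load_record(..) as in the hunk above.
    fn save_model<B: Backend>(model: &dqn_model::Net<B>) {
        model
            .clone()
            .save_file(PathBuf::from("model"), &CompactRecorder::new())
            .expect("failed to save the model record");
    }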

@@ -1,14 +1,13 @@
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
 use burn::tensor::Tensor;
-use burn_rl::base::{Action, ElemType, Environment, State};
 use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
 
-pub fn demo_model<E, M, B, F>(
-    agent: DQN<E, B, M>,
-    mut get_valid_actions: F,
-) where
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
     E: Environment,
     M: Module<B> + burn_rl::agent::DQNModel<B>,
     B: Backend,
@@ -36,12 +35,14 @@ pub fn demo_model<E, M, B, F>(
         for (index, q_value) in q_values_vec.iter().enumerate() {
             if !valid_actions.contains(&E::ActionType::from(index as u32)) {
-                masked_q_values =
-                    masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
             }
         }
-        let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
         let action = E::ActionType::from(action_index);
         let snapshot = env.step(action);
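A remark on the masking strategy in this last hunk (the commit is explicitly marked wip): invalid actions are masked by matching their Q-value with equal_elem, so every entry of the tensor that happens to share that value is pushed to negative infinity, including valid actions with tied Q-values. A sketch of index-based masking on a plain Vec of Q-values, shown only to illustrate the idea and not using the burn tensor API from the diff:

    /// Pick the argmax over Q-values while ignoring invalid action indices.
    /// Masking by index cannot accidentally hide a valid action whose Q-value
    /// is tied with an invalid one.
    fn masked_argmax(q_values: &[f32], is_valid: impl Fn(usize) -> bool) -> usize {
        q_values
            .iter()
            .enumerate()
            .map(|(i, &q)| (i, if is_valid(i) { q } else { f32::NEG_INFINITY }))
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }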