action mask

This commit is contained in:
Henri Bourcereau 2025-07-26 09:37:54 +02:00
parent cb30fd3229
commit 01090eb403
5 changed files with 94 additions and 469 deletions

View file

@ -103,6 +103,9 @@ impl Environment for TrictracEnvironment {
let player1_id = 1;
let player2_id = 2;
// Commencer la partie
game.consume(&GameEvent::BeginGame { goes_first: 1 });
let current_state = TrictracState::from_game_state(&game);
TrictracEnvironment {
game,
@ -140,7 +143,7 @@ impl Environment for TrictracEnvironment {
// Convertir l'action burn-rl vers une action Trictrac
let trictrac_action = self.convert_action(action, &self.game);
// println!("chosen action: {:?} -> {:?}", action, trictrac_action);
println!("chosen action: {:?} -> {:?}", action, trictrac_action);
let mut reward = 0.0;
let mut terminated = false;

View file

@ -10,27 +10,28 @@ type Env = environment::TrictracEnvironment;
fn main() {
println!("> Entraînement");
let num_episodes = 10;
let num_episodes = 50;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap());
println!("> Test avec le modèle entraîné");
demo_model::<Env>(valid_agent);
let path = "models/burn_dqn_50".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path);
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test");
let loaded_model = load_model();
let loaded_model = load_model(&path);
let loaded_agent = DQN::new(loaded_model);
println!("> Test avec le modèle chargé");
demo_model::<Env>(loaded_agent);
demo_model(loaded_agent);
}
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
let path = "models/burn_dqn".to_string();
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path);
@ -39,11 +40,10 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
.unwrap();
}
fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
// TODO : reprendre le DENSE_SIZE de dqn_model.rs
const DENSE_SIZE: usize = 128;
let path = "models/burn_dqn".to_string();
let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path);

View file

@ -1,57 +1,59 @@
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::strategy::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::base::{Agent, ElemType, Environment};
use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State};
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
let mut env = E::new(true);
let mut state = env.state();
/// Plays one full episode in a fresh `TrictracEnvironment`, repeatedly
/// letting the trained DQN `agent` choose the best *valid* action until
/// the environment reports the episode as done or no valid action remains.
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
    let mut env = TrictracEnvironment::new(true);
    let mut done = false;
    while !done {
        // No valid action left means the episode is over.
        let Some(action) = infer_action(&agent, &env) else {
            break;
        };
        // Execute the chosen action and check for termination.
        let snapshot = env.step(action);
        done = snapshot.done();
    }
}
if let Some(action) = agent.react(&state) {
// println!("before : {:?}", state);
// println!("action : {:?}", action);
let snapshot = env.step(action);
state = *snapshot.state();
// println!("after : {:?}", state);
// done = true;
done = snapshot.done();
/// Returns the valid action with the highest q-value for the current
/// environment state, or `None` when no valid action is available
/// (end of episode).
fn infer_action<B: Backend, M: DQNModel<B>>(
    agent: &DQN<TrictracEnvironment, B, M>,
    env: &TrictracEnvironment,
) -> Option<TrictracAction> {
    let state = env.state();

    // Get the q-values for the current state from the trained model.
    let q_values = agent
        .model()
        .as_ref()
        .unwrap()
        .infer(state.to_tensor().unsqueeze());

    // Get the indices of the currently valid actions.
    let valid_actions_indices = get_valid_action_indices(&env.game);
    if valid_actions_indices.is_empty() {
        return None; // No valid actions, end of episode
    }

    // Pick the valid action with the highest q-value directly by index.
    // NOTE: the previous implementation masked invalid actions on the tensor
    // with `equal_elem(*q_value)`, i.e. by *value* equality — any valid
    // action sharing a q-value with an invalid one was masked too, and the
    // extra tensor clones per action were wasted work.
    let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
    let action_index = valid_actions_indices
        .iter()
        .copied()
        .max_by(|&a, &b| {
            q_values_vec[a]
                .partial_cmp(&q_values_vec[b])
                // NaN q-values are treated as equal rather than panicking.
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;

    Some(TrictracAction::from(action_index as u32))
}
fn soft_update_tensor<const N: usize, B: Backend>(

View file

@ -71,7 +71,7 @@ impl TrictracAction {
encoded -= 625
}
let from1 = encoded / 25;
let from2 = encoded % 25;
let from2 = 1 + encoded % 25;
(dice_order, from1, from2)
}
@ -378,3 +378,30 @@ pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracActi
let mut rng = thread_rng();
valid_actions.choose(&mut rng).cloned()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference action shared by the round-trip tests below.
    fn sample_move() -> TrictracAction {
        TrictracAction::Move {
            dice_order: true,
            from1: 3,
            from2: 4,
        }
    }

    #[test]
    fn to_action_index() {
        let action = sample_move();
        let index = action.to_action_index();
        // Encoding then decoding must round-trip to the original action,
        // and the sample move must encode to the known index 81.
        assert_eq!(Some(action), TrictracAction::from_action_index(index));
        assert_eq!(81, index);
    }

    #[test]
    fn from_action_index() {
        // Index 81 is the known encoding of the sample move.
        assert_eq!(Some(sample_move()), TrictracAction::from_action_index(81));
    }
}