fix: convert_action from_action_index

This commit is contained in:
Henri Bourcereau 2025-07-25 17:26:02 +02:00
parent 1e18b784d1
commit b92c9eb7ff
4 changed files with 139 additions and 4 deletions

View file

@ -92,6 +92,7 @@ impl Environment for TrictracEnvironment {
type RewardType = f32; type RewardType = f32;
const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
// const MAX_STEPS: usize = 5; // Limite max pour éviter les parties infinies
fn new(visualized: bool) -> Self { fn new(visualized: bool) -> Self {
let mut game = GameState::new(false); let mut game = GameState::new(false);
@ -139,6 +140,7 @@ impl Environment for TrictracEnvironment {
// Convertir l'action burn-rl vers une action Trictrac // Convertir l'action burn-rl vers une action Trictrac
let trictrac_action = self.convert_action(action, &self.game); let trictrac_action = self.convert_action(action, &self.game);
// println!("chosen action: {:?} -> {:?}", action, trictrac_action);
let mut reward = 0.0; let mut reward = 0.0;
let mut terminated = false; let mut terminated = false;
@ -204,6 +206,15 @@ impl TrictracEnvironment {
&self, &self,
action: TrictracAction, action: TrictracAction,
game_state: &GameState, game_state: &GameState,
) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
fn convert_valid_action_index(
&self,
action: TrictracAction,
game_state: &GameState,
) -> Option<dqn_common::TrictracAction> { ) -> Option<dqn_common::TrictracAction> {
use dqn_common::get_valid_actions; use dqn_common::get_valid_actions;

View file

@ -10,7 +10,7 @@ type Env = environment::TrictracEnvironment;
fn main() { fn main() {
println!("> Entraînement"); println!("> Entraînement");
let num_episodes = 3; let num_episodes = 10;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true); let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
let valid_agent = agent.valid(); let valid_agent = agent.valid();
@ -18,6 +18,9 @@ fn main() {
println!("> Sauvegarde du modèle de validation"); println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap()); save_model(valid_agent.model().as_ref().unwrap());
println!("> Test avec le modèle entraîné");
demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test"); println!("> Chargement du modèle pour test");
let loaded_model = load_model(); let loaded_model = load_model();
let loaded_agent = DQN::new(loaded_model); let loaded_agent = DQN::new(loaded_model);
@ -29,7 +32,7 @@ fn main() {
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) { fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
let path = "models/burn_dqn".to_string(); let path = "models/burn_dqn".to_string();
let recorder = CompactRecorder::new(); let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path); let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path); println!("Modèle de validation sauvegardé : {}", model_path);
recorder recorder
.record(model.clone().into_record(), model_path.into()) .record(model.clone().into_record(), model_path.into())
@ -41,7 +44,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
const DENSE_SIZE: usize = 128; const DENSE_SIZE: usize = 128;
let path = "models/burn_dqn".to_string(); let path = "models/burn_dqn".to_string();
let model_path = format!("{}_model.burn", path); let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path); println!("Chargement du modèle depuis : {}", model_path);
let device = NdArrayDevice::default(); let device = NdArrayDevice::default();

View file

@ -9,10 +9,46 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
let mut state = env.state(); let mut state = env.state();
let mut done = false; let mut done = false;
while !done { while !done {
// // Get q values for current state
// let model = agent.model().as_ref().unwrap();
// let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
// let q_values = model.infer(state_tensor);
//
// // Get valid actions
// let valid_actions = get_valid_actions(&state);
// if valid_actions.is_empty() {
// break; // No valid actions, end of episode
// }
//
// // Set q values of non valid actions to the lowest
// let mut masked_q_values = q_values.clone();
// let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
// for (index, q_value) in q_values_vec.iter().enumerate() {
// if !valid_actions.contains(&E::ActionType::from(index as u32)) {
// masked_q_values = masked_q_values.clone().mask_fill(
// masked_q_values.clone().equal_elem(*q_value),
// f32::NEG_INFINITY,
// );
// }
// }
//
// // Get action with the highest q-value
// let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
// let action = E::ActionType::from(action_index);
//
// // Execute action
// let snapshot = env.step(action);
// state = *snapshot.state();
// // println!("{:?}", state);
// done = snapshot.done();
if let Some(action) = agent.react(&state) { if let Some(action) = agent.react(&state) {
// println!("before : {:?}", state);
// println!("action : {:?}", action);
let snapshot = env.step(action); let snapshot = env.step(action);
state = *snapshot.state(); state = *snapshot.state();
// println!("{:?}", state); // println!("after : {:?}", state);
// done = true;
done = snapshot.done(); done = snapshot.done();
} }
} }

View file

@ -0,0 +1,85 @@
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::base::{Action, ElemType, Environment, State};
pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
where
E: Environment,
M: Module<B> + burn_rl::agent::DQNModel<B>,
B: Backend,
F: FnMut(&E) -> Vec<E::ActionType>,
<E as Environment>::ActionType: PartialEq,
{
let mut env = E::new(true);
let mut state = env.state();
let mut done = false;
let mut total_reward = 0.0;
let mut steps = 0;
while !done {
let model = agent.model().as_ref().unwrap();
let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
let q_values = model.infer(state_tensor);
let valid_actions = get_valid_actions(&env);
if valid_actions.is_empty() {
break; // No valid actions, end of episode
}
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions.contains(&E::ActionType::from(index as u32)) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
let action = E::ActionType::from(action_index);
let snapshot = env.step(action);
state = *snapshot.state();
total_reward +=
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
steps += 1;
done = snapshot.done() || steps >= E::MAX_STEPS;
}
println!(
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
total_reward, steps
);
}
fn soft_update_tensor<const N: usize, B: Backend>(
this: &Param<Tensor<B, N>>,
that: &Param<Tensor<B, N>>,
tau: ElemType,
) -> Param<Tensor<B, N>> {
let that_weight = that.val();
let this_weight = this.val();
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
Param::initialized(ParamId::new(), new_weight)
}
pub fn soft_update_linear<B: Backend>(
this: Linear<B>,
that: &Linear<B>,
tau: ElemType,
) -> Linear<B> {
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
let bias = match (&this.bias, &that.bias) {
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
_ => None,
};
Linear::<B> { weight, bias }
}