action mask
This commit is contained in:
parent
cb30fd3229
commit
3e1775428d
7 changed files with 111 additions and 554 deletions
|
|
@ -103,6 +103,9 @@ impl Environment for TrictracEnvironment {
|
|||
let player1_id = 1;
|
||||
let player2_id = 2;
|
||||
|
||||
// Commencer la partie
|
||||
game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
let current_state = TrictracState::from_game_state(&game);
|
||||
TrictracEnvironment {
|
||||
game,
|
||||
|
|
@ -140,7 +143,6 @@ impl Environment for TrictracEnvironment {
|
|||
|
||||
// Convertir l'action burn-rl vers une action Trictrac
|
||||
let trictrac_action = self.convert_action(action, &self.game);
|
||||
// println!("chosen action: {:?} -> {:?}", action, trictrac_action);
|
||||
|
||||
let mut reward = 0.0;
|
||||
let mut terminated = false;
|
||||
|
|
|
|||
|
|
@ -10,27 +10,28 @@ type Env = environment::TrictracEnvironment;
|
|||
|
||||
fn main() {
|
||||
println!("> Entraînement");
|
||||
let num_episodes = 10;
|
||||
let num_episodes = 50;
|
||||
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
||||
|
||||
let valid_agent = agent.valid();
|
||||
|
||||
println!("> Sauvegarde du modèle de validation");
|
||||
save_model(valid_agent.model().as_ref().unwrap());
|
||||
|
||||
println!("> Test avec le modèle entraîné");
|
||||
demo_model::<Env>(valid_agent);
|
||||
let path = "models/burn_dqn_50".to_string();
|
||||
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||
|
||||
// println!("> Test avec le modèle entraîné");
|
||||
// demo_model::<Env>(valid_agent);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = load_model();
|
||||
let loaded_model = load_model(&path);
|
||||
let loaded_agent = DQN::new(loaded_model);
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model::<Env>(loaded_agent);
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
|
||||
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
|
||||
let path = "models/burn_dqn".to_string();
|
||||
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Modèle de validation sauvegardé : {}", model_path);
|
||||
|
|
@ -39,11 +40,10 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
|
||||
fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||
// TODO : reprendre le DENSE_SIZE de dqn_model.rs
|
||||
const DENSE_SIZE: usize = 128;
|
||||
|
||||
let path = "models/burn_dqn".to_string();
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Chargement du modèle depuis : {}", model_path);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,57 +1,59 @@
|
|||
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
|
||||
use crate::strategy::dqn_common::get_valid_action_indices;
|
||||
use burn::module::{Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::base::{Agent, ElemType, Environment};
|
||||
use burn_rl::agent::{DQNModel, DQN};
|
||||
use burn_rl::base::{ElemType, Environment, State};
|
||||
|
||||
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
|
||||
let mut env = E::new(true);
|
||||
let mut state = env.state();
|
||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||
let mut env = TrictracEnvironment::new(true);
|
||||
let mut done = false;
|
||||
while !done {
|
||||
// // Get q values for current state
|
||||
// let model = agent.model().as_ref().unwrap();
|
||||
// let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
|
||||
// let q_values = model.infer(state_tensor);
|
||||
//
|
||||
// // Get valid actions
|
||||
// let valid_actions = get_valid_actions(&state);
|
||||
// if valid_actions.is_empty() {
|
||||
// break; // No valid actions, end of episode
|
||||
// }
|
||||
//
|
||||
// // Set q values of non valid actions to the lowest
|
||||
// let mut masked_q_values = q_values.clone();
|
||||
// let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
// for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
// if !valid_actions.contains(&E::ActionType::from(index as u32)) {
|
||||
// masked_q_values = masked_q_values.clone().mask_fill(
|
||||
// masked_q_values.clone().equal_elem(*q_value),
|
||||
// f32::NEG_INFINITY,
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Get action with the highest q-value
|
||||
// let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
// let action = E::ActionType::from(action_index);
|
||||
//
|
||||
// // Execute action
|
||||
// let snapshot = env.step(action);
|
||||
// state = *snapshot.state();
|
||||
// // println!("{:?}", state);
|
||||
// done = snapshot.done();
|
||||
// let action = match infer_action(&agent, &env, state) {
|
||||
let action = match infer_action(&agent, &env) {
|
||||
Some(value) => value,
|
||||
None => break,
|
||||
};
|
||||
// Execute action
|
||||
let snapshot = env.step(action);
|
||||
done = snapshot.done();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(action) = agent.react(&state) {
|
||||
// println!("before : {:?}", state);
|
||||
// println!("action : {:?}", action);
|
||||
let snapshot = env.step(action);
|
||||
state = *snapshot.state();
|
||||
// println!("after : {:?}", state);
|
||||
// done = true;
|
||||
done = snapshot.done();
|
||||
fn infer_action<B: Backend, M: DQNModel<B>>(
|
||||
agent: &DQN<TrictracEnvironment, B, M>,
|
||||
env: &TrictracEnvironment,
|
||||
) -> Option<TrictracAction> {
|
||||
let state = env.state();
|
||||
// Get q-values
|
||||
let q_values = agent
|
||||
.model()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.infer(state.to_tensor().unsqueeze());
|
||||
// Get valid actions
|
||||
let valid_actions_indices = get_valid_action_indices(&env.game);
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
// Set non valid actions q-values to lowest
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions_indices.contains(&index) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
let action = TrictracAction::from(action_index);
|
||||
Some(action)
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
|
|
|
|||
|
|
@ -1,85 +0,0 @@
|
|||
use burn::module::{Module, Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||
|
||||
pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
|
||||
where
|
||||
E: Environment,
|
||||
M: Module<B> + burn_rl::agent::DQNModel<B>,
|
||||
B: Backend,
|
||||
F: FnMut(&E) -> Vec<E::ActionType>,
|
||||
<E as Environment>::ActionType: PartialEq,
|
||||
{
|
||||
let mut env = E::new(true);
|
||||
let mut state = env.state();
|
||||
let mut done = false;
|
||||
let mut total_reward = 0.0;
|
||||
let mut steps = 0;
|
||||
|
||||
while !done {
|
||||
let model = agent.model().as_ref().unwrap();
|
||||
let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
|
||||
let q_values = model.infer(state_tensor);
|
||||
|
||||
let valid_actions = get_valid_actions(&env);
|
||||
if valid_actions.is_empty() {
|
||||
break; // No valid actions, end of episode
|
||||
}
|
||||
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions.contains(&E::ActionType::from(index as u32)) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
let action = E::ActionType::from(action_index);
|
||||
|
||||
let snapshot = env.step(action);
|
||||
state = *snapshot.state();
|
||||
total_reward +=
|
||||
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
|
||||
steps += 1;
|
||||
done = snapshot.done() || steps >= E::MAX_STEPS;
|
||||
}
|
||||
println!(
|
||||
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
|
||||
total_reward, steps
|
||||
);
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
this: &Param<Tensor<B, N>>,
|
||||
that: &Param<Tensor<B, N>>,
|
||||
tau: ElemType,
|
||||
) -> Param<Tensor<B, N>> {
|
||||
let that_weight = that.val();
|
||||
let this_weight = this.val();
|
||||
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||
|
||||
Param::initialized(ParamId::new(), new_weight)
|
||||
}
|
||||
|
||||
pub fn soft_update_linear<B: Backend>(
|
||||
this: Linear<B>,
|
||||
that: &Linear<B>,
|
||||
tau: ElemType,
|
||||
) -> Linear<B> {
|
||||
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||
let bias = match (&this.bias, &that.bias) {
|
||||
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Linear::<B> { weight, bias }
|
||||
}
|
||||
|
|
@ -71,7 +71,7 @@ impl TrictracAction {
|
|||
encoded -= 625
|
||||
}
|
||||
let from1 = encoded / 25;
|
||||
let from2 = encoded % 25;
|
||||
let from2 = 1 + encoded % 25;
|
||||
(dice_order, from1, from2)
|
||||
}
|
||||
|
||||
|
|
@ -378,3 +378,30 @@ pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracActi
|
|||
let mut rng = thread_rng();
|
||||
valid_actions.choose(&mut rng).cloned()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the move shared by both tests.
    fn sample_move() -> TrictracAction {
        TrictracAction::Move {
            dice_order: true,
            from1: 3,
            from2: 4,
        }
    }

    /// Encoding the sample move yields index 81, and decoding that index gives the move back.
    #[test]
    fn to_action_index() {
        let mv = sample_move();
        let encoded = mv.to_action_index();
        assert_eq!(Some(mv), TrictracAction::from_action_index(encoded));
        assert_eq!(81, encoded);
    }

    /// Decoding index 81 yields the sample move.
    #[test]
    fn from_action_index() {
        assert_eq!(Some(sample_move()), TrictracAction::from_action_index(81));
    }
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue