fix: convert_action from_action_index
parent 1e18b784d1
commit b92c9eb7ff
@@ -92,6 +92,7 @@ impl Environment for TrictracEnvironment {
     type RewardType = f32;
 
     const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
+    // const MAX_STEPS: usize = 5; // Limite max pour éviter les parties infinies
 
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -139,6 +140,7 @@ impl Environment for TrictracEnvironment {
 
         // Convertir l'action burn-rl vers une action Trictrac
         let trictrac_action = self.convert_action(action, &self.game);
+        // println!("chosen action: {:?} -> {:?}", action, trictrac_action);
 
         let mut reward = 0.0;
         let mut terminated = false;
@@ -204,6 +206,15 @@ impl TrictracEnvironment {
         &self,
         action: TrictracAction,
        game_state: &GameState,
+    ) -> Option<dqn_common::TrictracAction> {
+        dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
+    }
+
+    /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
+    fn convert_valid_action_index(
+        &self,
+        action: TrictracAction,
+        game_state: &GameState,
     ) -> Option<dqn_common::TrictracAction> {
         use dqn_common::get_valid_actions;
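Note on the hunk above: convert_action now decodes the network's raw output index over the full, fixed action space via dqn_common::TrictracAction::from_action_index, while the previous behaviour (treating the index as a position among the currently valid actions) is kept as convert_valid_action_index. A minimal sketch of the two decodings follows; the generic A and the lookup tables are stand-ins, since dqn_common's internals are not part of this diff:

// Sketch only: `A` stands in for the real dqn_common::TrictracAction.
fn decode_global<A: Copy>(index: usize, all_actions: &[A]) -> Option<A> {
    // New behaviour: the index addresses the fixed, global action space.
    all_actions.get(index).copied()
}

fn decode_among_valid<A: Copy>(index: usize, valid_actions: &[A]) -> Option<A> {
    // Old behaviour: the index addresses only the actions legal right now.
    valid_actions.get(index).copied()
}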
@@ -10,7 +10,7 @@ type Env = environment::TrictracEnvironment;
 
 fn main() {
     println!("> Entraînement");
-    let num_episodes = 3;
+    let num_episodes = 10;
     let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
 
     let valid_agent = agent.valid();
@@ -18,6 +18,9 @@ fn main() {
     println!("> Sauvegarde du modèle de validation");
     save_model(valid_agent.model().as_ref().unwrap());
 
+    println!("> Test avec le modèle entraîné");
+    demo_model::<Env>(valid_agent);
+
     println!("> Chargement du modèle pour test");
     let loaded_model = load_model();
     let loaded_agent = DQN::new(loaded_model);
@@ -29,7 +32,7 @@ fn main() {
 fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
     let path = "models/burn_dqn".to_string();
     let recorder = CompactRecorder::new();
-    let model_path = format!("{}_model.burn", path);
+    let model_path = format!("{}_model.mpk", path);
     println!("Modèle de validation sauvegardé : {}", model_path);
     recorder
         .record(model.clone().into_record(), model_path.into())
@@ -41,7 +44,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
     const DENSE_SIZE: usize = 128;
 
     let path = "models/burn_dqn".to_string();
-    let model_path = format!("{}_model.burn", path);
+    let model_path = format!("{}_model.mpk", path);
     println!("Chargement du modèle depuis : {}", model_path);
 
     let device = NdArrayDevice::default();
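Both hunks above rename the checkpoint from _model.burn to _model.mpk, which matches the MessagePack-based format that burn's CompactRecorder writes. Below is a self-contained round-trip sketch of that save/load pattern, not the project's save_model/load_model bodies: it uses a throwaway one-layer module and assumes a recent burn release where Recorder::load takes the target device.

use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::record::{CompactRecorder, Recorder};

// Throwaway module standing in for dqn_model::Net.
#[derive(Module, Debug)]
struct TinyNet<B: burn::tensor::backend::Backend> {
    layer: Linear<B>,
}

fn roundtrip_sketch() {
    type B = NdArray<f32>;
    let device = NdArrayDevice::default();
    let model = TinyNet::<B> {
        layer: LinearConfig::new(4, 2).init(&device),
    };

    // Save: CompactRecorder serialises the record to a named .mpk file.
    let recorder = CompactRecorder::new();
    let path = format!("{}_model.mpk", "models/burn_dqn_demo");
    recorder
        .record(model.clone().into_record(), path.clone().into())
        .expect("échec de la sauvegarde");

    // Load: read the record back and apply it to a model of the same shape.
    let record = recorder
        .load(path.into(), &device)
        .expect("échec du chargement");
    let _reloaded = model.load_record(record);
}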
@@ -9,10 +9,46 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
     let mut state = env.state();
     let mut done = false;
     while !done {
+        // // Get q values for current state
+        // let model = agent.model().as_ref().unwrap();
+        // let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
+        // let q_values = model.infer(state_tensor);
+        //
+        // // Get valid actions
+        // let valid_actions = get_valid_actions(&state);
+        // if valid_actions.is_empty() {
+        //     break; // No valid actions, end of episode
+        // }
+        //
+        // // Set q values of non valid actions to the lowest
+        // let mut masked_q_values = q_values.clone();
+        // let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
+        // for (index, q_value) in q_values_vec.iter().enumerate() {
+        //     if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+        //         masked_q_values = masked_q_values.clone().mask_fill(
+        //             masked_q_values.clone().equal_elem(*q_value),
+        //             f32::NEG_INFINITY,
+        //         );
+        //     }
+        // }
+        //
+        // // Get action with the highest q-value
+        // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        // let action = E::ActionType::from(action_index);
+        //
+        // // Execute action
+        // let snapshot = env.step(action);
+        // state = *snapshot.state();
+        // // println!("{:?}", state);
+        // done = snapshot.done();
+
         if let Some(action) = agent.react(&state) {
+            // println!("before : {:?}", state);
+            // println!("action : {:?}", action);
             let snapshot = env.step(action);
             state = *snapshot.state();
-            // println!("{:?}", state);
+            // println!("after : {:?}", state);
+            // done = true;
             done = snapshot.done();
         }
     }
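The block left commented out in demo_model above is the valid-action masking loop; the same logic reappears in the new bot/src/burnrl/utils_wip.rs helper below. The idea is simply to force invalid actions' Q-values to negative infinity before taking the argmax. An index-based restatement of that idea on plain f32 values, my own illustration rather than the project's tensor code:

// Pick the index of the best Q-value among valid action indices only.
fn masked_argmax(q_values: &[f32], valid: &[usize]) -> Option<usize> {
    q_values
        .iter()
        .enumerate()
        .map(|(i, &q)| (i, if valid.contains(&i) { q } else { f32::NEG_INFINITY }))
        .filter(|(_, q)| q.is_finite()) // drop masked entries; None if no valid action
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(i, _)| i)
}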
bot/src/burnrl/utils_wip.rs | 85 (new file)
@@ -0,0 +1,85 @@
+use burn::module::{Module, Param, ParamId};
+use burn::nn::Linear;
+use burn::tensor::backend::Backend;
+use burn::tensor::cast::ToElement;
+use burn::tensor::Tensor;
+use burn_rl::agent::DQN;
+use burn_rl::base::{Action, ElemType, Environment, State};
+
+pub fn demo_model<E, M, B, F>(agent: DQN<E, B, M>, mut get_valid_actions: F)
+where
+    E: Environment,
+    M: Module<B> + burn_rl::agent::DQNModel<B>,
+    B: Backend,
+    F: FnMut(&E) -> Vec<E::ActionType>,
+    <E as Environment>::ActionType: PartialEq,
+{
+    let mut env = E::new(true);
+    let mut state = env.state();
+    let mut done = false;
+    let mut total_reward = 0.0;
+    let mut steps = 0;
+
+    while !done {
+        let model = agent.model().as_ref().unwrap();
+        let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
+        let q_values = model.infer(state_tensor);
+
+        let valid_actions = get_valid_actions(&env);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
+        }
+
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
+
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values = masked_q_values.clone().mask_fill(
+                    masked_q_values.clone().equal_elem(*q_value),
+                    f32::NEG_INFINITY,
+                );
+            }
+        }
+
+        let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
+        let action = E::ActionType::from(action_index);
+
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        total_reward +=
+            <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        steps += 1;
+        done = snapshot.done() || steps >= E::MAX_STEPS;
+    }
+    println!(
+        "Episode terminé. Récompense totale: {:.2}, Étapes: {}",
+        total_reward, steps
+    );
+}
+
+fn soft_update_tensor<const N: usize, B: Backend>(
+    this: &Param<Tensor<B, N>>,
+    that: &Param<Tensor<B, N>>,
+    tau: ElemType,
+) -> Param<Tensor<B, N>> {
+    let that_weight = that.val();
+    let this_weight = this.val();
+    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
+
+    Param::initialized(ParamId::new(), new_weight)
+}
+
+pub fn soft_update_linear<B: Backend>(
+    this: Linear<B>,
+    that: &Linear<B>,
+    tau: ElemType,
+) -> Linear<B> {
+    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
+    let bias = match (&this.bias, &that.bias) {
+        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
+        _ => None,
+    };
+
+    Linear::<B> { weight, bias }
+}
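The soft_update_tensor / soft_update_linear helpers in the new file implement the usual Polyak (soft) target-network update, theta_target <- (1 - tau) * theta_target + tau * theta_online. A small usage sketch, my own rather than anything wired up by this commit; it assumes the utils_wip.rs helpers above are in scope:

use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::nn::LinearConfig;
use burn_rl::base::ElemType;

fn soft_update_example() {
    type B = NdArray<ElemType>;
    let device = NdArrayDevice::default();

    let online = LinearConfig::new(8, 4).init::<B>(&device);
    let target = LinearConfig::new(8, 4).init::<B>(&device);

    // theta_target <- (1 - tau) * theta_target + tau * theta_online
    // With tau = 1.0 the result copies `online`; with tau = 0.0 it keeps `target`.
    let tau: ElemType = 0.005;
    let _updated = soft_update_linear(target, &online, tau);
}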