use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::burnrl::ppo::ppo_model;
use crate::training_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::{PPO, PPOModel};
// `Agent` is needed for the `react` call in `infer_action` below.
use burn_rl::base::{Action, Agent, ElemType, Environment, State};

pub fn save_model(model: &ppo_model::Net<NdArray<ElemType>>, path: &String) {
    let recorder = CompactRecorder::new();
    let model_path = format!("{path}.mpk");
    println!("Validation model saved: {model_path}");
    recorder
        .record(model.clone().into_record(), model_path.into())
        .unwrap();
}

pub fn load_model(dense_size: usize, path: &String) -> Option<ppo_model::Net<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    // println!("Loading model from: {model_path}");

    CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .map(|record| {
            // Rebuild a network with the same dimensions as at training time,
            // then populate it with the recorded weights.
            ppo_model::Net::new(
                <TrictracEnvironment as Environment>::StateType::size(),
                dense_size,
                <TrictracEnvironment as Environment>::ActionType::size(),
            )
            .load_record(record)
        })
        .ok()
}

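// Example (sketch, not part of the original module): a save/load round trip
// using the two functions above. It assumes a trained `ppo_model::Net` is at
// hand and that `dense_size` matches the value used when the network was
// built for training; the path is only illustrative.
#[allow(dead_code)]
fn save_load_roundtrip_example(model: &ppo_model::Net<NdArray<ElemType>>, dense_size: usize) {
    // `save_model`/`load_model` append the ".mpk" extension themselves.
    let path = String::from("models/trictrac_ppo_valid");
    save_model(model, &path);
    // `None` means the file was missing or could not be deserialized;
    // otherwise the reloaded network is ready for inference.
    let reloaded = load_model(dense_size, &path);
    assert!(reloaded.is_some());
}
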
pub fn demo_model<B: Backend, M: PPOModel<B>>(agent: PPO<TrictracEnvironment, B, M>) {
    let mut env = TrictracEnvironment::new(true);
    let mut done = false;
    while !done {
        // Stop the episode as soon as no action can be inferred.
        let action = match infer_action(&agent, &env) {
            Some(value) => value,
            None => break,
        };
        // Execute the chosen action and check for the end of the episode.
        let snapshot = env.step(action);
        done = snapshot.done();
    }
}

fn infer_action<B: Backend, M: PPOModel<B>>(
    agent: &PPO<TrictracEnvironment, B, M>,
    env: &TrictracEnvironment,
) -> Option<TrictracAction> {
    let state = env.state();
    // Simplest option (an assumption, not verified against burn_rl): PPO
    // implements burn_rl's `Agent` trait, so the agent can sample an action
    // directly from its current policy. Unlike the DQN-style code kept below
    // for reference, this does not mask invalid actions, so the environment
    // has to reject or ignore an invalid choice. See `best_valid_action`
    // below for a masking sketch.
    agent.react(&state)
    // Previous DQN-style approach, kept for reference:
    // // Get q-values
    // let q_values = agent
    //     .model()
    //     .as_ref()
    //     .unwrap()
    //     .infer(state.to_tensor().unsqueeze());
    // // Get valid actions
    // let valid_actions_indices = get_valid_action_indices(&env.game);
    // if valid_actions_indices.is_empty() {
    //     return None; // No valid actions, end of episode
    // }
    // // Set q-values of invalid actions to the lowest value
    // let mut masked_q_values = q_values.clone();
    // let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
    // for (index, q_value) in q_values_vec.iter().enumerate() {
    //     if !valid_actions_indices.contains(&index) {
    //         masked_q_values = masked_q_values.clone().mask_fill(
    //             masked_q_values.clone().equal_elem(*q_value),
    //             f32::NEG_INFINITY,
    //         );
    //     }
    // }
    // // Get best action (highest q-value)
    // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
    // let action = TrictracAction::from(action_index);
    // Some(action)
}
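
// Sketch of the invalid-action masking that the commented reference code was
// aiming at, written without any agent/model call: given raw per-action scores
// for one state (policy logits or q-values, however they are obtained) and the
// indices returned by `get_valid_action_indices`, pick the best valid action.
// This helper is illustrative and not part of the original module. Selecting
// by index also avoids the value-equality masking of the reference code, which
// could mask a valid action that happens to share a score with an invalid one.
#[allow(dead_code)]
fn best_valid_action(
    scores: &[ElemType],
    valid_action_indices: &[usize],
) -> Option<TrictracAction> {
    // `max_by` over an empty iterator yields `None`: no valid action, episode over.
    let best_index = valid_action_indices
        .iter()
        .copied()
        .filter(|&index| index < scores.len())
        .max_by(|&a, &b| {
            scores[a]
                .partial_cmp(&scores[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;
    Some(TrictracAction::from(best_index as u32))
}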