//! Helpers for the PPO agent on `TrictracEnvironment`: saving/loading the
//! `ppo_model::Net` weights with burn's `CompactRecorder`, and a demo loop
//! that steps the environment with actions chosen by a PPO agent.
//!
//! NOTE(review): `Param`, `ParamId` and `Linear` are not referenced anywhere in
//! this file as shown, and `get_valid_action_indices` / `ToElement` are only
//! referenced by the commented-out sketch in `infer_action`; confirm and prune
//! them to silence unused-import warnings.
use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::burnrl::ppo::ppo_model;
use crate::training_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPO};
use burn_rl::base::{Action, ElemType, Environment, State};

/// Saves `model` to `{path}.mpk` using burn's `CompactRecorder`.
///
/// The `.mpk` suffix is appended here, so callers pass `path` without an
/// extension (the same convention `load_model` uses when reading it back).
///
/// # Panics
///
/// Panics if the recorder fails to write the record (the result is
/// `unwrap`ped). Note that the confirmation message is printed *before* the
/// write is attempted, so it is shown even when the write then fails.
pub fn save_model(model: &ppo_model::Net>, path: &String) {
    let recorder = CompactRecorder::new();
    let model_path = format!("{path}.mpk");
    println!("Modèle de validation sauvegardé : {model_path}");
    // `into_record` takes the module by value, hence the clone of the borrowed model.
    recorder
        .record(model.clone().into_record(), model_path.into())
        .unwrap();
}

/// Loads a model previously written by `save_model` from `{path}.mpk`.
///
/// A fresh `ppo_model::Net` is constructed from the environment's state size,
/// `dense_size`, and the environment's action size, and the loaded record is
/// then applied to it with `load_record`. The record is loaded on the default
/// `NdArrayDevice`.
///
/// NOTE(review): `dense_size` presumably has to match the value the saved
/// model was built with; what `load_record` does on a mismatch is not
/// visible here, so confirm before relying on it.
///
/// Returns `None` if the recorder cannot load the file; the underlying error
/// is discarded by `.ok()`.
pub fn load_model(dense_size: usize, path: &String) -> Option>> {
    let model_path = format!("{path}.mpk");
    // println!("Chargement du modèle depuis : {model_path}");
    CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .map(|record| {
            ppo_model::Net::new(
                ::StateType::size(),
                dense_size,
                ::ActionType::size(),
            )
            .load_record(record)
        })
        .ok()
}

/// Runs a fresh `TrictracEnvironment` with `agent` choosing the actions.
///
/// Each iteration asks `infer_action` for an action and applies it with
/// `env.step`; the loop stops when the returned snapshot reports `done`, or
/// when `infer_action` returns `None`.
///
/// # Panics
///
/// Currently always panics on the first iteration, because `infer_action`
/// is still a stub that calls `panic!`.
pub fn demo_model>(agent: PPO) {
    // NOTE(review): the meaning of the `true` argument is defined in the
    // environment module, not here — confirm what it toggles.
    let mut env = TrictracEnvironment::new(true);
    let mut done = false;
    while !done {
        // Alternative call that passed the state explicitly; there is no
        // `state` binding in this loop, so it would not compile as-is.
        // let action = match infer_action(&agent, &env, state) {
        let action = match infer_action(&agent, &env) {
            Some(value) => value,
            // No action available: stop the demo early.
            None => break,
        };
        // Execute action
        let snapshot = env.step(action);
        done = snapshot.done();
    }
}

/// Chooses the agent's next action for the current state of `env`.
///
/// Per the commented-out sketch below, the intent is to return `None` when no
/// valid action remains, which makes `demo_model` stop.
///
/// # Panics
///
/// Not implemented yet: always panics with `"how to do that ?"`.
fn infer_action>(
    agent: &PPO,
    env: &TrictracEnvironment,
) -> Option {
    // Only consumed by the commented-out sketch below; currently unused.
    let state = env.state();
    // TODO: PPO inference is not wired up yet; this stub always diverges.
    panic!("how to do that ?");
    // Unreachable: the `panic!` above always diverges (rustc warns about it).
    None
    // NOTE(review): the sketch below is written in DQN terms (Q-values) and
    // calls `agent.model()`; confirm that burn_rl's `PPO` exposes an
    // equivalent accessor and what `PPOModel::infer` returns before enabling it.
    // NOTE(review): masking with `equal_elem(*q_value)` masks *every* entry
    // whose value equals an invalid action's value, so a valid action with an
    // identical score would be masked too; masking by index would avoid that.
    // Get q-values
    // let q_values = agent
    //     .model()
    //     .as_ref()
    //     .unwrap()
    //     .infer(state.to_tensor().unsqueeze());
    // // Get valid actions
    // let valid_actions_indices = get_valid_action_indices(&env.game);
    // if valid_actions_indices.is_empty() {
    //     return None; // No valid actions, end of episode
    // }
    // // Set non valid actions q-values to lowest
    // let mut masked_q_values = q_values.clone();
    // let q_values_vec: Vec = q_values.into_data().into_vec().unwrap();
    // for (index, q_value) in q_values_vec.iter().enumerate() {
    //     if !valid_actions_indices.contains(&index) {
    //         masked_q_values = masked_q_values.clone().mask_fill(
    //             masked_q_values.clone().equal_elem(*q_value),
    //             f32::NEG_INFINITY,
    //         );
    //     }
    // }
    // // Get best action (highest q-value)
    // let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
    // let action = TrictracAction::from(action_index);
    // Some(action)
}