fix: convert_action from_action_index
This commit is contained in:
parent
1e18b784d1
commit
b92c9eb7ff
4 changed files with 139 additions and 4 deletions
|
|
@ -10,7 +10,7 @@ type Env = environment::TrictracEnvironment;
|
|||
|
||||
fn main() {
|
||||
println!("> Entraînement");
|
||||
let num_episodes = 3;
|
||||
let num_episodes = 10;
|
||||
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
||||
|
||||
let valid_agent = agent.valid();
|
||||
|
|
@ -18,6 +18,9 @@ fn main() {
|
|||
println!("> Sauvegarde du modèle de validation");
|
||||
save_model(valid_agent.model().as_ref().unwrap());
|
||||
|
||||
println!("> Test avec le modèle entraîné");
|
||||
demo_model::<Env>(valid_agent);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = load_model();
|
||||
let loaded_agent = DQN::new(loaded_model);
|
||||
|
|
@ -29,7 +32,7 @@ fn main() {
|
|||
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
|
||||
let path = "models/burn_dqn".to_string();
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{}_model.burn", path);
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Modèle de validation sauvegardé : {}", model_path);
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
|
|
@ -41,7 +44,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
|
|||
const DENSE_SIZE: usize = 128;
|
||||
|
||||
let path = "models/burn_dqn".to_string();
|
||||
let model_path = format!("{}_model.burn", path);
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Chargement du modèle depuis : {}", model_path);
|
||||
|
||||
let device = NdArrayDevice::default();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue