refact : save model
This commit is contained in:
parent
c6d33555ec
commit
6fa8a31cc7
|
|
@ -70,7 +70,8 @@ type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
pub fn run<E: Environment, B: AutodiffBackend>(
|
pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
num_episodes: usize,
|
num_episodes: usize,
|
||||||
visualized: bool,
|
visualized: bool,
|
||||||
) -> impl Agent<E> {
|
) -> DQN<E, B, Net<B>> {
|
||||||
|
// ) -> impl Agent<E> {
|
||||||
let mut env = E::new(visualized);
|
let mut env = E::new(visualized);
|
||||||
|
|
||||||
let model = Net::<B>::new(
|
let model = Net::<B>::new(
|
||||||
|
|
@ -138,16 +139,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
agent
|
||||||
// Save
|
|
||||||
let path = "models/burn_dqn".to_string();
|
|
||||||
let inference_network = agent.model().clone().into_record();
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{}_model.burn", path);
|
|
||||||
println!("Modèle sauvegardé : {}", model_path);
|
|
||||||
recorder
|
|
||||||
.record(inference_network, model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
agent.valid()
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,35 @@
|
||||||
use bot::burnrl::{dqn_model, environment, utils::demo_model};
|
use bot::burnrl::{dqn_model, environment, utils::demo_model};
|
||||||
use burn::backend::{Autodiff, NdArray};
|
use burn::backend::{Autodiff, NdArray};
|
||||||
|
use burn::module::Module;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
use burn_rl::agent::DQN;
|
||||||
use burn_rl::base::ElemType;
|
use burn_rl::base::ElemType;
|
||||||
|
|
||||||
type Backend = Autodiff<NdArray<ElemType>>;
|
type Backend = Autodiff<NdArray<ElemType>>;
|
||||||
type Env = environment::TrictracEnvironment;
|
type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
println!("> Entraînement");
|
||||||
let num_episodes = 3;
|
let num_episodes = 3;
|
||||||
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
||||||
|
println!("> Sauvegarde");
|
||||||
|
save(&agent);
|
||||||
|
|
||||||
demo_model::<Env>(agent);
|
// cette ligne sert à extraire le "cerveau" de l'agent entraîné,
|
||||||
|
// sans les données nécessaires à l'entraînement
|
||||||
|
let valid_agent = agent.valid();
|
||||||
|
|
||||||
|
println!("> Test");
|
||||||
|
demo_model::<Env>(valid_agent);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save(agent: &DQN<Env, Backend, dqn_model::Net<Backend>>) {
|
||||||
|
let path = "models/burn_dqn".to_string();
|
||||||
|
let inference_network = agent.model().clone().into_record();
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{}_model.burn", path);
|
||||||
|
println!("Modèle sauvegardé : {}", model_path);
|
||||||
|
recorder
|
||||||
|
.record(inference_network, model_path.into())
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue