feat: bot all algos
This commit is contained in:
parent
0c58490f87
commit
8f41cc1412
18 changed files with 929 additions and 39 deletions
|
|
@ -1,8 +1,10 @@
|
|||
use bot::burnrl::algos::{
|
||||
dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid,
|
||||
};
|
||||
use bot::burnrl::environment::TrictracEnvironment;
|
||||
use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig;
|
||||
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
|
||||
use bot::burnrl::utils::{demo_model, Config};
|
||||
use bot::burnrl::{dqn_big_model, dqn_model, dqn_valid_model, ppo_model, sac_model};
|
||||
use burn::backend::{Autodiff, NdArray};
|
||||
use burn_rl::base::ElemType;
|
||||
use std::env;
|
||||
|
|
@ -51,9 +53,9 @@ fn main() {
|
|||
|
||||
match algo.as_str() {
|
||||
"dqn" => {
|
||||
let _agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
let _agent = dqn::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn_model::load_model(conf.dense_size, &path);
|
||||
let loaded_model = dqn::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||
|
||||
|
|
@ -61,33 +63,87 @@ fn main() {
|
|||
demo_model(loaded_agent);
|
||||
}
|
||||
"dqn_big" => {
|
||||
let _agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
let _agent = dqn_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn_big::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentBig, _, _> =
|
||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"dqn_valid" => {
|
||||
let _agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn_valid::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentValid, _, _> =
|
||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"sac" => {
|
||||
let _agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
let _agent = sac::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = sac_model::load_model(conf.dense_size, &path);
|
||||
let loaded_model = sac::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo" => {
|
||||
let _agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
"sac_big" => {
|
||||
let _agent = sac_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo_model::load_model(conf.dense_size, &path);
|
||||
let loaded_model = sac_big::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentBig, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"sac_valid" => {
|
||||
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = sac_valid::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentValid, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo" => {
|
||||
let _agent = ppo::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo_big" => {
|
||||
let _agent = ppo_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo_big::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentBig, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo_valid" => {
|
||||
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo_valid::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentValid, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
&_ => {
|
||||
dbg!("unknown algo {algo}");
|
||||
println!("unknown algo {algo}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue