refact: remove dqn_simple
This commit is contained in:
parent
74f692d7ba
commit
7c50a6d07b
|
|
@ -9,10 +9,6 @@ edition = "2021"
|
||||||
name = "burn_train"
|
name = "burn_train"
|
||||||
path = "src/burnrl/main.rs"
|
path = "src/burnrl/main.rs"
|
||||||
|
|
||||||
[[bin]]
|
|
||||||
name = "train_dqn_simple"
|
|
||||||
path = "src/dqn_simple/main.rs"
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
pretty_assertions = "1.4.0"
|
pretty_assertions = "1.4.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bot::{
|
use bot::{
|
||||||
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
|
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
|
||||||
StableBaselines3Strategy,
|
StableBaselines3Strategy,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
@ -25,8 +25,8 @@ pub struct App {
|
||||||
impl App {
|
impl App {
|
||||||
// Constructs a new instance of [`App`].
|
// Constructs a new instance of [`App`].
|
||||||
pub fn new(args: AppArgs) -> Self {
|
pub fn new(args: AppArgs) -> Self {
|
||||||
let bot_strategies: Vec<Box<dyn BotStrategy>> = args
|
let bot_strategies: Vec<Box<dyn BotStrategy>> =
|
||||||
.bot
|
args.bot
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.map(|str_bots| {
|
.map(|str_bots| {
|
||||||
str_bots
|
str_bots
|
||||||
|
|
@ -43,7 +43,6 @@ impl App {
|
||||||
}
|
}
|
||||||
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
||||||
as Box<dyn BotStrategy>),
|
as Box<dyn BotStrategy>),
|
||||||
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
|
|
||||||
"dqnburn" => {
|
"dqnburn" => {
|
||||||
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
|
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
|
|
@ -52,11 +51,6 @@ impl App {
|
||||||
Some(Box::new(StableBaselines3Strategy::new(path))
|
Some(Box::new(StableBaselines3Strategy::new(path))
|
||||||
as Box<dyn BotStrategy>)
|
as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
s if s.starts_with("dqn:") => {
|
|
||||||
let path = s.trim_start_matches("dqn:");
|
|
||||||
Some(Box::new(DqnStrategy::new_with_model(path))
|
|
||||||
as Box<dyn BotStrategy>)
|
|
||||||
}
|
|
||||||
s if s.starts_with("dqnburn:") => {
|
s if s.starts_with("dqnburn:") => {
|
||||||
let path = s.trim_start_matches("dqnburn:");
|
let path = s.trim_start_matches("dqnburn:");
|
||||||
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue