refact: remove dqn_simple

This commit is contained in:
Henri Bourcereau 2026-01-05 10:14:58 +01:00
parent 74f692d7ba
commit 7c50a6d07b
2 changed files with 8 additions and 18 deletions

View file

@ -9,10 +9,6 @@ edition = "2021"
name = "burn_train"
path = "src/burnrl/main.rs"
[[bin]]
name = "train_dqn_simple"
path = "src/dqn_simple/main.rs"
[dependencies]
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }

View file

@ -1,5 +1,5 @@
use bot::{
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
StableBaselines3Strategy,
};
use itertools::Itertools;
@ -25,8 +25,8 @@ pub struct App {
impl App {
// Constructs a new instance of [`App`].
pub fn new(args: AppArgs) -> Self {
let bot_strategies: Vec<Box<dyn BotStrategy>> = args
.bot
let bot_strategies: Vec<Box<dyn BotStrategy>> =
args.bot
.as_deref()
.map(|str_bots| {
str_bots
@ -43,7 +43,6 @@ impl App {
}
"ai" => Some(Box::new(StableBaselines3Strategy::default())
as Box<dyn BotStrategy>),
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
"dqnburn" => {
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
}
@ -52,11 +51,6 @@ impl App {
Some(Box::new(StableBaselines3Strategy::new(path))
as Box<dyn BotStrategy>)
}
s if s.starts_with("dqn:") => {
let path = s.trim_start_matches("dqn:");
Some(Box::new(DqnStrategy::new_with_model(path))
as Box<dyn BotStrategy>)
}
s if s.starts_with("dqnburn:") => {
let path = s.trim_start_matches("dqnburn:");
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))