refact: remove dqn_simple

This commit is contained in:
Henri Bourcereau 2026-01-05 10:14:58 +01:00
parent 74f692d7ba
commit 7c50a6d07b
2 changed files with 8 additions and 18 deletions

View file

@ -9,10 +9,6 @@ edition = "2021"
name = "burn_train" name = "burn_train"
path = "src/burnrl/main.rs" path = "src/burnrl/main.rs"
[[bin]]
name = "train_dqn_simple"
path = "src/dqn_simple/main.rs"
[dependencies] [dependencies]
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }

View file

@ -1,5 +1,5 @@
use bot::{ use bot::{
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy, BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
StableBaselines3Strategy, StableBaselines3Strategy,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -25,11 +25,11 @@ pub struct App {
impl App { impl App {
// Constructs a new instance of [`App`]. // Constructs a new instance of [`App`].
pub fn new(args: AppArgs) -> Self { pub fn new(args: AppArgs) -> Self {
let bot_strategies: Vec<Box<dyn BotStrategy>> = args let bot_strategies: Vec<Box<dyn BotStrategy>> =
.bot args.bot
.as_deref() .as_deref()
.map(|str_bots| { .map(|str_bots| {
str_bots str_bots
.split(",") .split(",")
.filter_map(|s| match s.trim() { .filter_map(|s| match s.trim() {
"dummy" => { "dummy" => {
@ -43,7 +43,6 @@ impl App {
} }
"ai" => Some(Box::new(StableBaselines3Strategy::default()) "ai" => Some(Box::new(StableBaselines3Strategy::default())
as Box<dyn BotStrategy>), as Box<dyn BotStrategy>),
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
"dqnburn" => { "dqnburn" => {
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>) Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
} }
@ -52,11 +51,6 @@ impl App {
Some(Box::new(StableBaselines3Strategy::new(path)) Some(Box::new(StableBaselines3Strategy::new(path))
as Box<dyn BotStrategy>) as Box<dyn BotStrategy>)
} }
s if s.starts_with("dqn:") => {
let path = s.trim_start_matches("dqn:");
Some(Box::new(DqnStrategy::new_with_model(path))
as Box<dyn BotStrategy>)
}
s if s.starts_with("dqnburn:") => { s if s.starts_with("dqnburn:") => {
let path = s.trim_start_matches("dqnburn:"); let path = s.trim_start_matches("dqnburn:");
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string())) Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
@ -65,8 +59,8 @@ impl App {
_ => None, _ => None,
}) })
.collect() .collect()
}) })
.unwrap_or_default(); .unwrap_or_default();
let schools_enabled = false; let schools_enabled = false;
let should_quit = bot_strategies.len() > 1; let should_quit = bot_strategies.len() > 1;
Self { Self {