runcli with bot dqn burn-rl

This commit is contained in:
Henri Bourcereau 2025-08-08 21:31:38 +02:00
parent a19c5d8596
commit 17d29b8633
8 changed files with 212 additions and 31 deletions

View file

@ -1,5 +1,5 @@
use bot::{
BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
StableBaselines3Strategy,
};
use itertools::Itertools;
@ -25,11 +25,11 @@ pub struct App {
impl App {
// Constructs a new instance of [`App`].
pub fn new(args: AppArgs) -> Self {
let bot_strategies: Vec<Box<dyn BotStrategy>> =
args.bot
.as_deref()
.map(|str_bots| {
str_bots
let bot_strategies: Vec<Box<dyn BotStrategy>> = args
.bot
.as_deref()
.map(|str_bots| {
str_bots
.split(",")
.filter_map(|s| match s.trim() {
"dummy" => {
@ -44,6 +44,9 @@ impl App {
"ai" => Some(Box::new(StableBaselines3Strategy::default())
as Box<dyn BotStrategy>),
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
"dqnburn" => {
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
}
s if s.starts_with("ai:") => {
let path = s.trim_start_matches("ai:");
Some(Box::new(StableBaselines3Strategy::new(path))
@ -54,11 +57,16 @@ impl App {
Some(Box::new(DqnStrategy::new_with_model(path))
as Box<dyn BotStrategy>)
}
s if s.starts_with("dqnburn:") => {
let path = s.trim_start_matches("dqnburn:");
Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}")))
as Box<dyn BotStrategy>)
}
_ => None,
})
.collect()
})
.unwrap_or_default();
})
.unwrap_or_default();
let schools_enabled = false;
let should_quit = bot_strategies.len() > 1;
Self {