diff --git a/bot/Cargo.toml b/bot/Cargo.toml index 20c4e93..c775179 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [[bin]] -name = "burn_demo" +name = "burn_train" path = "src/burnrl/main.rs" [[bin]] diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 4c02189..b9f7f2a 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -3,8 +3,9 @@ ROOT="$(cd "$(dirname "$0")" && pwd)/../.." LOGS_DIR="$ROOT/bot/models/logs" -CFG_SIZE=12 -BINBOT=train_sac_burn +CFG_SIZE=18 +ALGO="dqn" +BINBOT=burn_train # BINBOT=train_ppo_burn # BINBOT=train_dqn_burn # BINBOT=train_dqn_burn_big @@ -16,14 +17,14 @@ PLOT_EXT="png" train() { cargo build --release --bin=$BINBOT NAME="$(date +%Y-%m-%d_%H:%M:%S)" - LOGS="$LOGS_DIR/$BINBOT/$NAME.out" - mkdir -p "$LOGS_DIR/$BINBOT" - LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" | tee "$LOGS" + LOGS="$LOGS_DIR/$ALGO/$NAME.out" + mkdir -p "$LOGS_DIR/$ALGO" + LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" $ALGO | tee "$LOGS" } plot() { - NAME=$(ls -rt "$LOGS_DIR/$BINBOT" | tail -n 1) - LOGS="$LOGS_DIR/$BINBOT/$NAME" + NAME=$(ls -rt "$LOGS_DIR/$ALGO" | tail -n 1) + LOGS="$LOGS_DIR/$ALGO/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do eval "$cfg" @@ -33,7 +34,7 @@ plot() { tail -n +$((CFG_SIZE + 2)) "$LOGS" | grep -v "info:" | awk -F '[ ,]' '{print $5}' | - feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$BINBOT/$NAME.$PLOT_EXT" + feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$ALGO/$NAME.$PLOT_EXT" } if [ "$1" = "plot" ]; then diff --git a/bot/src/burnrl/main.rs b/bot/src/burnrl/main.rs index 24759f0..a911e06 100644 --- a/bot/src/burnrl/main.rs +++ b/bot/src/burnrl/main.rs @@ -1,19 +1,20 @@ -use bot::burnrl::sac_model as burn_model; -// use bot::burnrl::dqn_big_model as burn_model; -// use bot::burnrl::dqn_model as burn_model; -// use bot::burnrl::environment_big::TrictracEnvironment; use bot::burnrl::environment::TrictracEnvironment; +use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig; +use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid; use bot::burnrl::utils::{demo_model, Config}; +use bot::burnrl::{dqn_big_model, dqn_model, dqn_valid_model, ppo_model, sac_model}; use burn::backend::{Autodiff, NdArray}; -use burn_rl::agent::SAC as MyAgent; -// use burn_rl::agent::DQN as MyAgent; use burn_rl::base::ElemType; +use std::env; type Backend = Autodiff>; -type Env = TrictracEnvironment; fn main() { - let path = "bot/models/burnrl_dqn".to_string(); + let args: Vec = env::args().collect(); + let algo = &args[1]; + // let dir_path = &args[2]; + + let path = format!("bot/models/burnrl_{algo}"); let conf = Config { save_path: Some(path.clone()), num_episodes: 30, // 40 @@ -45,14 +46,38 @@ fn main() { }; println!("{conf}----------"); - let agent = burn_model::run::(&conf, false); //true); + match algo.as_str() { + "dqn" => { + let agent = dqn_model::run::(&conf, false); + println!("> Chargement du modèle pour test"); + let loaded_model = dqn_model::load_model(conf.dense_size, &path); + let loaded_agent: burn_rl::agent::DQN = + burn_rl::agent::DQN::new(loaded_model.unwrap()); - // println!("> Chargement du modèle pour test"); - // let loaded_model = burn_model::load_model(conf.dense_size, &path); - // let loaded_agent: MyAgent = MyAgent::new(loaded_model.unwrap()); - // - // println!("> Test avec le modèle chargé"); - // demo_model(loaded_agent); - - // demo_model::(agent); + println!("> Test avec le modèle chargé"); + demo_model(loaded_agent); + } + "dqn_big" => { + let agent = dqn_big_model::run::(&conf, false); + } + "dqn_valid" => { + let agent = dqn_valid_model::run::(&conf, false); + } + "sac" => { + let agent = sac_model::run::(&conf, false); + // println!("> Chargement du modèle pour test"); + // let loaded_model = sac_model::load_model(conf.dense_size, &path); + // let loaded_agent: burn_rl::agent::SAC = + // burn_rl::agent::SAC::new(loaded_model.unwrap()); + // + // println!("> Test avec le modèle chargé"); + // demo_model(loaded_agent); + } + "ppo" => { + let agent = ppo_model::run::(&conf, false); + } + &_ => { + dbg!("unknown algo {algo}"); + } + } }