feat: bot sac & ppo save & load
This commit is contained in:
parent
afeb3561e0
commit
0c58490f87
8 changed files with 127 additions and 103 deletions
|
|
@ -9,26 +9,6 @@ edition = "2021"
|
|||
name = "burn_train"
|
||||
path = "src/burnrl/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn_burn_valid"
|
||||
path = "src/burnrl/dqn_valid/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn_burn_big"
|
||||
path = "src/burnrl/dqn_big/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn_burn"
|
||||
path = "src/burnrl/dqn/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_sac_burn"
|
||||
path = "src/burnrl/sac/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_ppo_burn"
|
||||
path = "src/burnrl/ppo/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn_simple"
|
||||
path = "src/dqn_simple/main.rs"
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
||||
LOGS_DIR="$ROOT/bot/models/logs"
|
||||
|
||||
CFG_SIZE=18
|
||||
ALGO="dqn"
|
||||
CFG_SIZE=17
|
||||
ALGO="sac"
|
||||
BINBOT=burn_train
|
||||
# BINBOT=train_ppo_burn
|
||||
# BINBOT=train_dqn_burn
|
||||
|
|
|
|||
|
|
@ -155,10 +155,10 @@ impl Environment for TrictracEnvironment {
|
|||
self.goodmoves_count as f32 / self.step_count as f32
|
||||
};
|
||||
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
|
||||
let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
|
||||
let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
|
||||
let path = "bot/models/logs/debug.log";
|
||||
if let Ok(mut out) = std::fs::File::create(path) {
|
||||
write!(out, "{:?}", history);
|
||||
write!(out, "{history:?}").expect("could not write history log");
|
||||
}
|
||||
"!!!!"
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -29,8 +29,10 @@ fn main() {
|
|||
batch_size: 128, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
clip_grad: 70.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||
|
||||
// SAC
|
||||
min_probability: 1e-9,
|
||||
|
||||
// DQN
|
||||
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
|
||||
eps_end: 0.05, // 0.05
|
||||
// eps_decay higher = epsilon decrease slower
|
||||
|
|
@ -38,6 +40,7 @@ fn main() {
|
|||
// epsilon is updated at the start of each episode
|
||||
eps_decay: 2000.0, // 1000 ?
|
||||
|
||||
// PPO
|
||||
lambda: 0.95,
|
||||
epsilon_clip: 0.2,
|
||||
critic_weight: 0.5,
|
||||
|
|
@ -48,7 +51,7 @@ fn main() {
|
|||
|
||||
match algo.as_str() {
|
||||
"dqn" => {
|
||||
let agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
let _agent = dqn_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn_model::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
|
||||
|
|
@ -58,23 +61,30 @@ fn main() {
|
|||
demo_model(loaded_agent);
|
||||
}
|
||||
"dqn_big" => {
|
||||
let agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
let _agent = dqn_big_model::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
}
|
||||
"dqn_valid" => {
|
||||
let agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
let _agent = dqn_valid_model::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
}
|
||||
"sac" => {
|
||||
let agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
// println!("> Chargement du modèle pour test");
|
||||
// let loaded_model = sac_model::load_model(conf.dense_size, &path);
|
||||
// let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
|
||||
// burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
//
|
||||
// println!("> Test avec le modèle chargé");
|
||||
// demo_model(loaded_agent);
|
||||
let _agent = sac_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = sac_model::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo" => {
|
||||
let agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
let _agent = ppo_model::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo_model::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
&_ => {
|
||||
dbg!("unknown algo {algo}");
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
use crate::burnrl::environment::TrictracEnvironment;
|
||||
use crate::burnrl::utils::Config;
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Initializer, Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
@ -57,7 +61,10 @@ const MEMORY_SIZE: usize = 512;
|
|||
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
// ) -> PPO<E, B, Net<B>> {
|
||||
|
|
@ -126,9 +133,61 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
|||
memory.clear();
|
||||
}
|
||||
|
||||
let valid_agent = agent.valid(model);
|
||||
if let Some(path) = &conf.save_path {
|
||||
// save_model(???, path);
|
||||
let device = NdArrayDevice::default();
|
||||
let recorder = CompactRecorder::new();
|
||||
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
||||
|
||||
// Save the trained model (backend B) to a temporary file
|
||||
recorder
|
||||
.record(model.clone().into_record(), tmp_path.clone())
|
||||
.expect("Failed to save temporary model");
|
||||
|
||||
// Create a new model instance with the target backend (NdArray)
|
||||
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
|
||||
// Load the record from the temporary file into the new model
|
||||
let record = recorder
|
||||
.load(tmp_path.clone(), &device)
|
||||
.expect("Failed to load temporary model");
|
||||
let model_with_loaded_weights = model_to_save.load_record(record);
|
||||
|
||||
// Clean up the temporary file
|
||||
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
||||
|
||||
save_model(&model_with_loaded_weights, path);
|
||||
}
|
||||
let valid_agent = agent.valid(model);
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -96,7 +96,10 @@ const MEMORY_SIZE: usize = 4096;
|
|||
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
) -> impl Agent<E> {
|
||||
|
|
@ -105,9 +108,9 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
|||
let state_dim = <<E as Environment>::StateType as State>::size();
|
||||
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
||||
|
||||
let mut actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let mut critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let mut critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
||||
|
||||
let mut agent = MyAgent::default();
|
||||
|
|
@ -134,8 +137,6 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
|||
optimizer_config.init(),
|
||||
);
|
||||
|
||||
let mut policy_net = agent.model().clone();
|
||||
|
||||
let mut step = 0_usize;
|
||||
|
||||
for episode in 0..conf.num_episodes {
|
||||
|
|
@ -186,33 +187,35 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
|||
|
||||
let valid_agent = agent.valid(nets.actor);
|
||||
if let Some(path) = &conf.save_path {
|
||||
// save_model(???, path);
|
||||
if let Some(model) = valid_agent.model() {
|
||||
save_model(model, path);
|
||||
}
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
// pub fn save_model(model: ???, path: &String) {
|
||||
// let recorder = CompactRecorder::new();
|
||||
// let model_path = format!("{path}.mpk");
|
||||
// println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
// recorder
|
||||
// .record(model.clone().into_record(), model_path.into())
|
||||
// .unwrap();
|
||||
// }
|
||||
//
|
||||
// pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
||||
// let model_path = format!("{path}.mpk");
|
||||
// // println!("Chargement du modèle depuis : {model_path}");
|
||||
//
|
||||
// CompactRecorder::new()
|
||||
// .load(model_path.into(), &NdArrayDevice::default())
|
||||
// .map(|record| {
|
||||
// Actor::new(
|
||||
// <TrictracEnvironment as Environment>::StateType::size(),
|
||||
// dense_size,
|
||||
// <TrictracEnvironment as Environment>::ActionType::size(),
|
||||
// )
|
||||
// .load_record(record)
|
||||
// })
|
||||
// .ok()
|
||||
// }
|
||||
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Actor::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
|
@ -8,7 +8,6 @@ use store::MoveRules;
|
|||
|
||||
use crate::burnrl::dqn_model;
|
||||
use crate::burnrl::environment;
|
||||
use crate::burnrl::utils;
|
||||
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
|
||||
|
||||
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue