fix: train bot dqn burnrl : extract config
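Extracts the hardcoded DQN training hyperparameters into a new DqnConfig struct (with a Default impl) in dqn_model.rs, threads it through run() and load_model(), and moves the burnrl module under src/dqn/. Also bumps MEMORY_SIZE from 4096 to 8192, lowers MAX_STEPS from 1000 to 700, and doubles the intermediate rewards from 0.2 to 0.4.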
parent ad5ae17168
commit 28c2aa836f
Cargo.toml
@@ -7,7 +7,7 @@ edition = "2021"
 [[bin]]
 name = "train_dqn_burn"
-path = "src/burnrl/main.rs"
+path = "src/dqn/burnrl/main.rs"
 
 [[bin]]
 name = "train_dqn"
 
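The train_dqn_burn entry point follows the module move under src/dqn/, matching the bot::dqn::burnrl import paths updated in main.rs below.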
src/dqn/burnrl/dqn_model.rs
@@ -58,17 +58,35 @@ impl<B: Backend> DQNModel<B> for Net<B> {
 }
 
 #[allow(unused)]
-const MEMORY_SIZE: usize = 4096;
-const DENSE_SIZE: usize = 128;
-const EPS_DECAY: f64 = 1000.0;
-const EPS_START: f64 = 0.9;
-const EPS_END: f64 = 0.05;
+const MEMORY_SIZE: usize = 8192;
+
+pub struct DqnConfig {
+    pub num_episodes: usize,
+    // pub memory_size: usize,
+    pub dense_size: usize,
+    pub eps_start: f64,
+    pub eps_end: f64,
+    pub eps_decay: f64,
+}
+
+impl Default for DqnConfig {
+    fn default() -> Self {
+        Self {
+            num_episodes: 1000,
+            // memory_size: 8192,
+            dense_size: 256,
+            eps_start: 0.9,
+            eps_end: 0.05,
+            eps_decay: 1000.0,
+        }
+    }
+}
 
 type MyAgent<E, B> = DQN<E, B, Net<B>>;
 
 #[allow(unused)]
 pub fn run<E: Environment, B: AutodiffBackend>(
-    num_episodes: usize,
+    conf: &DqnConfig,
     visualized: bool,
 ) -> DQN<E, B, Net<B>> {
 // ) -> impl Agent<E> {
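Note: since all DqnConfig fields are public and a Default impl is provided, call sites can override just the fields they need via struct-update syntax. A minimal sketch (hypothetical call site, not part of this commit):

    use bot::dqn::burnrl::dqn_model;

    fn main() {
        // Override only the episode count; the remaining hyperparameters
        // (dense_size 256, eps_start 0.9, eps_end 0.05, eps_decay 1000.0)
        // come from the Default impl shown above.
        let conf = dqn_model::DqnConfig {
            num_episodes: 50,
            ..Default::default()
        };
        assert_eq!(conf.dense_size, 256);
    }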
@@ -76,7 +94,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
 
     let model = Net::<B>::new(
         <<E as Environment>::StateType as State>::size(),
-        DENSE_SIZE,
+        conf.dense_size,
         <<E as Environment>::ActionType as Action>::size(),
     );
 
@@ -94,7 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
 
     let mut step = 0_usize;
 
-    for episode in 0..num_episodes {
+    for episode in 0..conf.num_episodes {
         let mut episode_done = false;
         let mut episode_reward: ElemType = 0.0;
         let mut episode_duration = 0_usize;
@@ -102,8 +120,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
         let mut now = SystemTime::now();
 
         while !episode_done {
-            let eps_threshold =
-                EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
+            let eps_threshold = conf.eps_end
+                + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
             let action =
                 DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
             let snapshot = env.step(action);
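Note: the exploration schedule itself is unchanged; only its constants now come from the config. The threshold decays exponentially from eps_start toward eps_end with a time constant of eps_decay steps. A standalone sketch of the same formula (illustrative helper, not part of this commit), with sample values under the defaults:

    fn eps_threshold(step: usize, start: f64, end: f64, decay: f64) -> f64 {
        end + (start - end) * f64::exp(-(step as f64) / decay)
    }

    fn main() {
        // Defaults: start = 0.9, end = 0.05, decay = 1000.0
        println!("{:.2}", eps_threshold(0, 0.9, 0.05, 1000.0));    // 0.90
        println!("{:.2}", eps_threshold(1000, 0.9, 0.05, 1000.0)); // ~0.36
        println!("{:.2}", eps_threshold(5000, 0.9, 0.05, 1000.0)); // ~0.06
    }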
src/dqn/burnrl/environment.rs
@@ -91,8 +91,7 @@ impl Environment for TrictracEnvironment {
     type ActionType = TrictracAction;
     type RewardType = f32;
 
-    const MAX_STEPS: usize = 1000; // Max limit to avoid infinite games
-    // const MAX_STEPS: usize = 5; // Max limit to avoid infinite games
+    const MAX_STEPS: usize = 700; // Max limit to avoid infinite games
 
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -260,7 +259,7 @@ impl TrictracEnvironment {
             // }
             TrictracAction::Go => {
                 // Continue after winning a trou (hole)
-                reward += 0.2;
+                reward += 0.4;
                 Some(GameEvent::Go {
                     player_id: self.active_player_id,
                 })
@@ -289,7 +288,7 @@ impl TrictracEnvironment {
         let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
         let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
 
-        reward += 0.2;
+        reward += 0.4;
         Some(GameEvent::Move {
             player_id: self.active_player_id,
             moves: (checker_move1, checker_move2),
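Note: both intermediate-reward branches (Go and Move) double from 0.2 to 0.4; combined with the MAX_STEPS cut from 1000 to 700 above, per-step progress carries more weight within a shorter episode budget.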
src/dqn/burnrl/main.rs
@@ -1,4 +1,4 @@
-use bot::burnrl::{dqn_model, environment, utils::demo_model};
+use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
 use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
 use burn::module::Module;
 use burn::record::{CompactRecorder, Recorder};
@@ -10,8 +10,16 @@ type Env = environment::TrictracEnvironment;
 
 fn main() {
     println!("> Entraînement");
-    let num_episodes = 50;
-    let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
+    let conf = dqn_model::DqnConfig {
+        num_episodes: 50,
+        // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
+        // max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
+        dense_size: 256, // neural network complexity
+        eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
+        eps_end: 0.05,
+        eps_decay: 1000.0,
+    };
+    let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
 
     let valid_agent = agent.valid();
 
@@ -24,7 +32,7 @@ fn main() {
     // demo_model::<Env>(valid_agent);
 
     println!("> Chargement du modèle pour test");
-    let loaded_model = load_model(&path);
+    let loaded_model = load_model(conf.dense_size, &path);
     let loaded_agent = DQN::new(loaded_model);
 
     println!("> Test avec le modèle chargé");
@@ -40,10 +48,7 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
         .unwrap();
 }
 
-fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
-    // TODO: reuse the DENSE_SIZE from dqn_model.rs
-    const DENSE_SIZE: usize = 128;
+fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
 
     let model_path = format!("{}_model.mpk", path);
     println!("Chargement du modèle depuis : {}", model_path);
 
@@ -56,7 +61,7 @@ fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
 
     dqn_model::Net::new(
         <environment::TrictracEnvironment as Environment>::StateType::size(),
-        DENSE_SIZE,
+        dense_size,
         <environment::TrictracEnvironment as Environment>::ActionType::size(),
     )
     .load_record(record)
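Note: load_model previously rebuilt the network with a local DENSE_SIZE constant of 128 (flagged by the removed TODO), which could silently disagree with the width used during training. Passing dense_size as a parameter ties the loaded architecture to the same config, as at the call site above:

    let loaded_model = load_model(conf.dense_size, &path);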