script train bots
parent e4b3092018
commit 778ac1817b
7 changed files with 191 additions and 30 deletions

@@ -1,3 +1,4 @@
+use crate::dqn::burnrl::environment::TrictracEnvironment;
 use crate::dqn::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};

@@ -8,6 +9,7 @@ use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
+use std::fmt;
 use std::time::SystemTime;

 #[derive(Module, Debug)]

@@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
 const MEMORY_SIZE: usize = 8192;

 pub struct DqnConfig {
+    pub min_steps: f32,
+    pub max_steps: usize,
     pub num_episodes: usize,
     // pub memory_size: usize,
     pub dense_size: usize,
     pub eps_start: f64,
     pub eps_end: f64,
     pub eps_decay: f64,
+
+    pub gamma: f32,
+    pub tau: f32,
+    pub learning_rate: f32,
+    pub batch_size: usize,
+    pub clip_grad: f32,
 }

+impl fmt::Display for DqnConfig {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let mut s = String::new();
+        s.push_str(&format!("min_steps={:?}\n", self.min_steps));
+        s.push_str(&format!("max_steps={:?}\n", self.max_steps));
+        s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
+        s.push_str(&format!("dense_size={:?}\n", self.dense_size));
+        s.push_str(&format!("eps_start={:?}\n", self.eps_start));
+        s.push_str(&format!("eps_end={:?}\n", self.eps_end));
+        s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
+        s.push_str(&format!("gamma={:?}\n", self.gamma));
+        s.push_str(&format!("tau={:?}\n", self.tau));
+        s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
+        s.push_str(&format!("batch_size={:?}\n", self.batch_size));
+        s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
+        write!(f, "{s}")
+    }
+}
+
 impl Default for DqnConfig {
     fn default() -> Self {
         Self {
+            min_steps: 250.0,
+            max_steps: 2000,
             num_episodes: 1000,
             // memory_size: 8192,
             dense_size: 256,
             eps_start: 0.9,
             eps_end: 0.05,
             eps_decay: 1000.0,
+
+            gamma: 0.999,
+            tau: 0.005,
+            learning_rate: 0.001,
+            batch_size: 32,
+            clip_grad: 100.0,
         }
     }
 }
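
Note: with the Display impl above, println!("{conf}") dumps one key=value pair per line, so with the defaults it would print roughly:

    min_steps=250.0
    max_steps=2000
    num_episodes=1000
    dense_size=256
    eps_start=0.9
    eps_end=0.05
    eps_decay=1000.0
    gamma=0.999
    tau=0.005
    learning_rate=0.001
    batch_size=32
    clip_grad=100.0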

@@ -85,12 +120,14 @@ impl Default for DqnConfig {
 type MyAgent<E, B> = DQN<E, B, Net<B>>;

 #[allow(unused)]
-pub fn run<E: Environment, B: AutodiffBackend>(
+pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
+    conf: &DqnConfig,
     visualized: bool,
 ) -> DQN<E, B, Net<B>> {
     // ) -> impl Agent<E> {
     let mut env = E::new(visualized);
+    env.as_mut().min_steps = conf.min_steps;
+    env.as_mut().max_steps = conf.max_steps;

     let model = Net::<B>::new(
         <<E as Environment>::StateType as State>::size(),

@@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
     let mut agent = MyAgent::new(model);

-    let config = DQNTrainingConfig::default();
+    // let config = DQNTrainingConfig::default();
+    let config = DQNTrainingConfig {
+        gamma: conf.gamma,
+        tau: conf.tau,
+        learning_rate: conf.learning_rate,
+        batch_size: conf.batch_size,
+        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
+            conf.clip_grad,
+        )),
+    };

     let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
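
Note: GradientClippingConfig::Value(conf.clip_grad) configures value-based gradient clipping. A minimal sketch of the idea, assuming value clipping clamps each gradient element to [-limit, limit] (a sketch of the concept, not burn's actual implementation):

    // Value-based gradient clipping, element-wise: anything outside
    // [-limit, limit] is clamped to the boundary before the optimizer step.
    fn clip_by_value(grads: &mut [f32], limit: f32) {
        for g in grads.iter_mut() {
            *g = g.clamp(-limit, limit);
        }
    }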

@@ -145,12 +191,12 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             step += 1;
             episode_duration += 1;

-            if snapshot.done() || episode_duration >= E::MAX_STEPS {
+            if snapshot.done() || episode_duration >= conf.max_steps {
                 env.reset();
                 episode_done = true;

                 println!(
-                    "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
+                    "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
                     now.elapsed().unwrap().as_secs(),
                 );
                 now = SystemTime::now();

@@ -84,7 +84,10 @@ pub struct TrictracEnvironment {
     current_state: TrictracState,
     episode_reward: f32,
     pub step_count: usize,
+    pub min_steps: f32,
+    pub max_steps: usize,
     pub goodmoves_count: usize,
+    pub goodmoves_ratio: f32,
     pub visualized: bool,
 }

@@ -93,8 +96,6 @@ impl Environment for TrictracEnvironment {
     type ActionType = TrictracAction;
     type RewardType = f32;

-    const MAX_STEPS: usize = 600; // Hard cap to avoid endless games
-
     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);

@@ -115,7 +116,10 @@ impl Environment for TrictracEnvironment {
             current_state,
             episode_reward: 0.0,
             step_count: 0,
+            min_steps: 250.0,
+            max_steps: 2000,
             goodmoves_count: 0,
+            goodmoves_ratio: 0.0,
             visualized,
         }
     }

@@ -135,10 +139,15 @@ impl Environment for TrictracEnvironment {
         self.current_state = TrictracState::from_game_state(&self.game);
         self.episode_reward = 0.0;
+        self.goodmoves_ratio = if self.step_count == 0 {
+            0.0
+        } else {
+            self.goodmoves_count as f32 / self.step_count as f32
+        };
         println!(
-            "correct moves: {} ({}%)",
+            "info: correct moves: {} ({}%)",
             self.goodmoves_count,
-            100 * self.goodmoves_count / self.step_count
+            (100.0 * self.goodmoves_ratio).round() as u32
         );
         self.step_count = 0;
         self.goodmoves_count = 0;

@@ -174,12 +183,12 @@ impl Environment for TrictracEnvironment {
         }

         // Check whether the game is over
-        let done = self.game.stage == Stage::Ended
-            || self.game.determine_winner().is_some()
-            || self.step_count >= Self::MAX_STEPS;
+        let max_steps = self.min_steps
+            + (self.max_steps as f32 - self.min_steps)
+                * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
+        let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();

         if done {
             terminated = true;
             // Final reward based on the outcome
             if let Some(winner_id) = self.game.determine_winner() {
                 if winner_id == self.active_player_id {
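
Note: the max_steps expression above replaces the old fixed MAX_STEPS cap with a schedule: the episode-length budget grows exponentially with the share of valid moves played in the previous episode, from roughly min_steps when the agent plays mostly invalid moves up to max_steps when every move is valid. A standalone sketch, evaluated with the defaults seen in this commit (min_steps = 250.0, max_steps = 2000):

    // Episode-length cap as a function of the valid-move ratio.
    fn step_cap(min_steps: f32, max_steps: f32, goodmoves_ratio: f32) -> f32 {
        min_steps + (max_steps - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25)
    }
    // step_cap(250.0, 2000.0, 0.0) ~  282  (mostly invalid moves: short episodes)
    // step_cap(250.0, 2000.0, 0.5) ~  487
    // step_cap(250.0, 2000.0, 1.0) = 2000  (all moves valid: full budget)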

@@ -189,6 +198,7 @@ impl Environment for TrictracEnvironment {
                 }
             }
         }
+        let terminated = done || self.step_count >= max_steps.round() as usize;

         // Update the state
         self.current_state = TrictracState::from_game_state(&self.game);

@@ -320,7 +330,7 @@ impl TrictracEnvironment {
         let (points, adv_points) = self.game.dice_points;
         reward += Self::REWARD_RATIO * (points - adv_points) as f32;
         if points > 0 {
-            println!("rolled for {reward}");
+            println!("info: rolled for {reward}");
         }
         // Reward proportional to the points scored
     }

@@ -421,3 +431,9 @@ impl TrictracEnvironment {
         reward
     }
 }
+
+impl AsMut<TrictracEnvironment> for TrictracEnvironment {
+    fn as_mut(&mut self) -> &mut Self {
+        self
+    }
+}
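
Note: the reflexive AsMut impl above is what lets the generic run function, which only knows E: Environment + AsMut<TrictracEnvironment>, reach Trictrac-specific fields such as min_steps and max_steps. A minimal illustration of the pattern (the configure helper is hypothetical, not part of the commit):

    // Generic code can mutate concrete fields through the AsMut bound.
    fn configure<E: AsMut<TrictracEnvironment>>(env: &mut E, min_steps: f32, max_steps: usize) {
        env.as_mut().min_steps = min_steps;
        env.as_mut().max_steps = max_steps;
    }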

@@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
 fn main() {
     // println!("> Training");

+    // See also MEMORY_SIZE in dqn_model.rs: 8192
     let conf = dqn_model::DqnConfig {
         num_episodes: 40,
-        // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
-        // max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
-        dense_size: 256, // neural network complexity
-        eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
+        min_steps: 250.0, // min steps per episode (updated by the schedule function)
+        max_steps: 2000,  // max steps per episode
+        dense_size: 256,  // neural network complexity
+        eps_start: 0.9,   // initial epsilon value (0.9 => more exploration)
         eps_end: 0.05,
-        eps_decay: 1500.0,
+        // higher eps_decay = slower epsilon decrease
+        // used in: epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
+        // epsilon is updated at the start of each episode
+        eps_decay: 3000.0,
+
+        gamma: 0.999, // discount factor; higher = encourages long-term strategies
+        tau: 0.005,   // soft update rate of the target network; lower = slower adaptation,
+        // less sensitive to lucky streaks
+        learning_rate: 0.001, // step size; low = slower training, high = may never
+        // converge
+        batch_size: 32,   // number of past experiences used to compute the mean error
+        clip_grad: 100.0, // gradient clipping: upper bound on the correction applied
     };
+    println!("{conf}----------");
     let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);

     let valid_agent = agent.valid();
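
Note: the eps_decay comments above document the exploration schedule. A standalone sketch of that formula, using the values set here (eps_start = 0.9, eps_end = 0.05, eps_decay = 3000.0):

    // epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay)
    fn eps_threshold(eps_start: f64, eps_end: f64, eps_decay: f64, step: f64) -> f64 {
        eps_end + (eps_start - eps_end) * (-step / eps_decay).exp()
    }
    // step 0     -> 0.900  (mostly random exploration)
    // step 3000  -> 0.363
    // step 9000  -> 0.092
    // step 30000 -> 0.050  (almost pure exploitation)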