script train bots
This commit is contained in:
parent
e4b3092018
commit
778ac1817b
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -11,6 +11,4 @@ devenv.local.nix
|
||||||
|
|
||||||
# generated by samply rust profiler
|
# generated by samply rust profiler
|
||||||
profile.json
|
profile.json
|
||||||
|
bot/models
|
||||||
# IA modles used by bots
|
|
||||||
/models
|
|
||||||
|
|
|
||||||
38
bot/scripts/train.sh
Executable file
38
bot/scripts/train.sh
Executable file
|
|
@ -0,0 +1,38 @@
|
||||||
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
||||||
|
LOGS_DIR="$ROOT/bot/models/logs"
|
||||||
|
|
||||||
|
CFG_SIZE=12
|
||||||
|
OPPONENT="random"
|
||||||
|
|
||||||
|
PLOT_EXT="png"
|
||||||
|
|
||||||
|
train() {
|
||||||
|
cargo build --release --bin=train_dqn_burn
|
||||||
|
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
|
||||||
|
LOGS="$LOGS_DIR/$NAME.out"
|
||||||
|
mkdir -p "$LOGS_DIR"
|
||||||
|
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
|
||||||
|
}
|
||||||
|
|
||||||
|
plot() {
|
||||||
|
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||||
|
LOGS="$LOGS_DIR/$NAME"
|
||||||
|
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||||
|
for cfg in $cfgs; do
|
||||||
|
eval "$cfg"
|
||||||
|
done
|
||||||
|
|
||||||
|
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
|
||||||
|
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
||||||
|
grep -v "info:" |
|
||||||
|
awk -F '[ ,]' '{print $5}' |
|
||||||
|
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ "$1" = "plot" ]; then
|
||||||
|
plot
|
||||||
|
else
|
||||||
|
train
|
||||||
|
fi
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::dqn::burnrl::environment::TrictracEnvironment;
|
||||||
use crate::dqn::burnrl::utils::soft_update_linear;
|
use crate::dqn::burnrl::utils::soft_update_linear;
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::nn::{Linear, LinearConfig};
|
use burn::nn::{Linear, LinearConfig};
|
||||||
|
|
@ -8,6 +9,7 @@ use burn::tensor::Tensor;
|
||||||
use burn_rl::agent::DQN;
|
use burn_rl::agent::DQN;
|
||||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||||
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::fmt;
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
|
|
@ -61,23 +63,56 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
||||||
const MEMORY_SIZE: usize = 8192;
|
const MEMORY_SIZE: usize = 8192;
|
||||||
|
|
||||||
pub struct DqnConfig {
|
pub struct DqnConfig {
|
||||||
|
pub min_steps: f32,
|
||||||
|
pub max_steps: usize,
|
||||||
pub num_episodes: usize,
|
pub num_episodes: usize,
|
||||||
// pub memory_size: usize,
|
|
||||||
pub dense_size: usize,
|
pub dense_size: usize,
|
||||||
pub eps_start: f64,
|
pub eps_start: f64,
|
||||||
pub eps_end: f64,
|
pub eps_end: f64,
|
||||||
pub eps_decay: f64,
|
pub eps_decay: f64,
|
||||||
|
|
||||||
|
pub gamma: f32,
|
||||||
|
pub tau: f32,
|
||||||
|
pub learning_rate: f32,
|
||||||
|
pub batch_size: usize,
|
||||||
|
pub clip_grad: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for DqnConfig {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
let mut s = String::new();
|
||||||
|
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
|
||||||
|
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||||
|
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||||
|
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||||
|
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||||
|
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||||
|
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||||
|
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||||
|
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||||
|
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||||
|
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||||
|
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||||
|
write!(f, "{s}")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for DqnConfig {
|
impl Default for DqnConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
min_steps: 250.0,
|
||||||
|
max_steps: 2000,
|
||||||
num_episodes: 1000,
|
num_episodes: 1000,
|
||||||
// memory_size: 8192,
|
|
||||||
dense_size: 256,
|
dense_size: 256,
|
||||||
eps_start: 0.9,
|
eps_start: 0.9,
|
||||||
eps_end: 0.05,
|
eps_end: 0.05,
|
||||||
eps_decay: 1000.0,
|
eps_decay: 1000.0,
|
||||||
|
|
||||||
|
gamma: 0.999,
|
||||||
|
tau: 0.005,
|
||||||
|
learning_rate: 0.001,
|
||||||
|
batch_size: 32,
|
||||||
|
clip_grad: 100.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -85,12 +120,14 @@ impl Default for DqnConfig {
|
||||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
pub fn run<E: Environment, B: AutodiffBackend>(
|
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||||
conf: &DqnConfig,
|
conf: &DqnConfig,
|
||||||
visualized: bool,
|
visualized: bool,
|
||||||
) -> DQN<E, B, Net<B>> {
|
) -> DQN<E, B, Net<B>> {
|
||||||
// ) -> impl Agent<E> {
|
// ) -> impl Agent<E> {
|
||||||
let mut env = E::new(visualized);
|
let mut env = E::new(visualized);
|
||||||
|
env.as_mut().min_steps = conf.min_steps;
|
||||||
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
|
||||||
let model = Net::<B>::new(
|
let model = Net::<B>::new(
|
||||||
<<E as Environment>::StateType as State>::size(),
|
<<E as Environment>::StateType as State>::size(),
|
||||||
|
|
@ -100,7 +137,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
|
|
||||||
let mut agent = MyAgent::new(model);
|
let mut agent = MyAgent::new(model);
|
||||||
|
|
||||||
let config = DQNTrainingConfig::default();
|
// let config = DQNTrainingConfig::default();
|
||||||
|
let config = DQNTrainingConfig {
|
||||||
|
gamma: conf.gamma,
|
||||||
|
tau: conf.tau,
|
||||||
|
learning_rate: conf.learning_rate,
|
||||||
|
batch_size: conf.batch_size,
|
||||||
|
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||||
|
conf.clip_grad,
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
|
||||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||||
|
|
||||||
|
|
@ -145,12 +191,12 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
step += 1;
|
step += 1;
|
||||||
episode_duration += 1;
|
episode_duration += 1;
|
||||||
|
|
||||||
if snapshot.done() || episode_duration >= E::MAX_STEPS {
|
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||||
env.reset();
|
env.reset();
|
||||||
episode_done = true;
|
episode_done = true;
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold:.3}, \"duration\": {}}}",
|
||||||
now.elapsed().unwrap().as_secs(),
|
now.elapsed().unwrap().as_secs(),
|
||||||
);
|
);
|
||||||
now = SystemTime::now();
|
now = SystemTime::now();
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,10 @@ pub struct TrictracEnvironment {
|
||||||
current_state: TrictracState,
|
current_state: TrictracState,
|
||||||
episode_reward: f32,
|
episode_reward: f32,
|
||||||
pub step_count: usize,
|
pub step_count: usize,
|
||||||
|
pub min_steps: f32,
|
||||||
|
pub max_steps: usize,
|
||||||
pub goodmoves_count: usize,
|
pub goodmoves_count: usize,
|
||||||
|
pub goodmoves_ratio: f32,
|
||||||
pub visualized: bool,
|
pub visualized: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -93,8 +96,6 @@ impl Environment for TrictracEnvironment {
|
||||||
type ActionType = TrictracAction;
|
type ActionType = TrictracAction;
|
||||||
type RewardType = f32;
|
type RewardType = f32;
|
||||||
|
|
||||||
const MAX_STEPS: usize = 600; // Limite max pour éviter les parties infinies
|
|
||||||
|
|
||||||
fn new(visualized: bool) -> Self {
|
fn new(visualized: bool) -> Self {
|
||||||
let mut game = GameState::new(false);
|
let mut game = GameState::new(false);
|
||||||
|
|
||||||
|
|
@ -115,7 +116,10 @@ impl Environment for TrictracEnvironment {
|
||||||
current_state,
|
current_state,
|
||||||
episode_reward: 0.0,
|
episode_reward: 0.0,
|
||||||
step_count: 0,
|
step_count: 0,
|
||||||
|
min_steps: 250.0,
|
||||||
|
max_steps: 2000,
|
||||||
goodmoves_count: 0,
|
goodmoves_count: 0,
|
||||||
|
goodmoves_ratio: 0.0,
|
||||||
visualized,
|
visualized,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -135,10 +139,15 @@ impl Environment for TrictracEnvironment {
|
||||||
|
|
||||||
self.current_state = TrictracState::from_game_state(&self.game);
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
self.episode_reward = 0.0;
|
self.episode_reward = 0.0;
|
||||||
|
self.goodmoves_ratio = if self.step_count == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
self.goodmoves_count as f32 / self.step_count as f32
|
||||||
|
};
|
||||||
println!(
|
println!(
|
||||||
"correct moves: {} ({}%)",
|
"info: correct moves: {} ({}%)",
|
||||||
self.goodmoves_count,
|
self.goodmoves_count,
|
||||||
100 * self.goodmoves_count / self.step_count
|
(100.0 * self.goodmoves_ratio).round() as u32
|
||||||
);
|
);
|
||||||
self.step_count = 0;
|
self.step_count = 0;
|
||||||
self.goodmoves_count = 0;
|
self.goodmoves_count = 0;
|
||||||
|
|
@ -174,12 +183,12 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vérifier si la partie est terminée
|
// Vérifier si la partie est terminée
|
||||||
let done = self.game.stage == Stage::Ended
|
let max_steps = self.min_steps
|
||||||
|| self.game.determine_winner().is_some()
|
+ (self.max_steps as f32 - self.min_steps)
|
||||||
|| self.step_count >= Self::MAX_STEPS;
|
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||||
|
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||||
|
|
||||||
if done {
|
if done {
|
||||||
terminated = true;
|
|
||||||
// Récompense finale basée sur le résultat
|
// Récompense finale basée sur le résultat
|
||||||
if let Some(winner_id) = self.game.determine_winner() {
|
if let Some(winner_id) = self.game.determine_winner() {
|
||||||
if winner_id == self.active_player_id {
|
if winner_id == self.active_player_id {
|
||||||
|
|
@ -189,6 +198,7 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let terminated = done || self.step_count >= max_steps.round() as usize;
|
||||||
|
|
||||||
// Mettre à jour l'état
|
// Mettre à jour l'état
|
||||||
self.current_state = TrictracState::from_game_state(&self.game);
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
|
|
@ -320,7 +330,7 @@ impl TrictracEnvironment {
|
||||||
let (points, adv_points) = self.game.dice_points;
|
let (points, adv_points) = self.game.dice_points;
|
||||||
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||||
if points > 0 {
|
if points > 0 {
|
||||||
println!("rolled for {reward}");
|
println!("info: rolled for {reward}");
|
||||||
}
|
}
|
||||||
// Récompense proportionnelle aux points
|
// Récompense proportionnelle aux points
|
||||||
}
|
}
|
||||||
|
|
@ -421,3 +431,9 @@ impl TrictracEnvironment {
|
||||||
reward
|
reward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
|
||||||
|
fn as_mut(&mut self) -> &mut Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,15 +11,29 @@ type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// println!("> Entraînement");
|
// println!("> Entraînement");
|
||||||
|
|
||||||
|
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||||
let conf = dqn_model::DqnConfig {
|
let conf = dqn_model::DqnConfig {
|
||||||
num_episodes: 40,
|
num_episodes: 40,
|
||||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
min_steps: 250.0, // min steps by episode (mise à jour par la fonction)
|
||||||
// max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
|
max_steps: 2000, // max steps by episode
|
||||||
dense_size: 256, // neural network complexity
|
dense_size: 256, // neural network complexity
|
||||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||||
eps_end: 0.05,
|
eps_end: 0.05,
|
||||||
eps_decay: 1500.0,
|
// eps_decay higher = epsilon decrease slower
|
||||||
|
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||||
|
// epsilon is updated at the start of each episode
|
||||||
|
eps_decay: 3000.0,
|
||||||
|
|
||||||
|
gamma: 0.999, // discount factor. Plus élevé = encourage stratégies à long terme
|
||||||
|
tau: 0.005, // soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||||
|
// plus lente moins sensible aux coups de chance
|
||||||
|
learning_rate: 0.001, // taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||||
|
// converger
|
||||||
|
batch_size: 32, // nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||||
|
clip_grad: 100.0, // plafonnement du gradient : limite max de correction à apporter
|
||||||
};
|
};
|
||||||
|
println!("{conf}----------");
|
||||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||||
|
|
||||||
let valid_agent = agent.valid();
|
let valid_agent = agent.valid();
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ MEMORY_SIZE
|
||||||
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
|
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
|
||||||
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
|
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
|
||||||
- Pourquoi c'est important :
|
- Pourquoi c'est important :
|
||||||
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
||||||
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
||||||
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
|
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
|
||||||
|
|
||||||
DENSE_SIZE
|
DENSE_SIZE
|
||||||
|
|
@ -54,3 +54,53 @@ epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
|
||||||
|
|
||||||
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
|
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
|
||||||
nouvelles (EPS*\*).
|
nouvelles (EPS*\*).
|
||||||
|
|
||||||
|
## Paramètres DQNTrainingConfig
|
||||||
|
|
||||||
|
1. `gamma` (Facteur d'actualisation / _Discount Factor_)
|
||||||
|
|
||||||
|
- À quoi ça sert ? Ça détermine l'importance des récompenses futures. Une valeur proche de 1 (ex: 0.99)
|
||||||
|
indique à l'agent qu'une récompense obtenue dans le futur est presque aussi importante qu'une
|
||||||
|
récompense immédiate. Il sera donc "patient" et capable de faire des sacrifices à court terme pour un
|
||||||
|
gain plus grand plus tard.
|
||||||
|
- Intuition : Un gamma de 0 rendrait l'agent "myope", ne se souciant que du prochain coup. Un gamma de
|
||||||
|
0.99 l'encourage à élaborer des stratégies à long terme.
|
||||||
|
|
||||||
|
2. `tau` (Taux de mise à jour douce / _Soft Update Rate_)
|
||||||
|
|
||||||
|
- À quoi ça sert ? Pour stabiliser l'apprentissage, les algorithmes DQN utilisent souvent deux réseaux
|
||||||
|
: un réseau principal qui apprend vite et un "réseau cible" (copie du premier) qui évolue lentement.
|
||||||
|
tau contrôle la vitesse à laquelle les connaissances du réseau principal sont transférées vers le
|
||||||
|
réseau cible.
|
||||||
|
- Intuition : Une petite valeur (ex: 0.005) signifie que le réseau cible, qui sert de référence stable,
|
||||||
|
ne se met à jour que très progressivement. C'est comme un "mentor" qui n'adopte pas immédiatement
|
||||||
|
toutes les nouvelles idées de son "élève", ce qui évite de déstabiliser tout l'apprentissage sur un
|
||||||
|
coup de chance (ou de malchance).
|
||||||
|
|
||||||
|
3. `learning_rate` (Taux d'apprentissage)
|
||||||
|
|
||||||
|
- À quoi ça sert ? C'est peut-être le plus classique des hyperparamètres. Il définit la "taille du
|
||||||
|
pas" lors de la correction des erreurs. Après chaque prédiction, l'agent compare le résultat à ce
|
||||||
|
qui s'est passé et ajuste ses poids. Le learning_rate détermine l'ampleur de cet ajustement.
|
||||||
|
- Intuition : Trop élevé, et l'agent risque de sur-corriger et de ne jamais converger (comme chercher
|
||||||
|
le fond d'une vallée en faisant des pas de géant). Trop bas, et l'apprentissage sera extrêmement
|
||||||
|
lent.
|
||||||
|
|
||||||
|
4. `batch_size` (Taille du lot)
|
||||||
|
|
||||||
|
- À quoi ça sert ? L'agent apprend de ses expériences passées, qu'il stocke dans une "mémoire". Pour
|
||||||
|
chaque session d'entraînement, au lieu d'apprendre d'une seule expérience, il en pioche un lot
|
||||||
|
(batch) au hasard (ex: 32 expériences). Il calcule l'erreur moyenne sur ce lot pour mettre à jour
|
||||||
|
ses poids.
|
||||||
|
- Intuition : Apprendre sur un lot plutôt que sur une seule expérience rend l'apprentissage plus
|
||||||
|
stable et plus général. L'agent se base sur une "moyenne" de situations plutôt que sur un cas
|
||||||
|
particulier qui pourrait être une anomalie.
|
||||||
|
|
||||||
|
5. `clip_grad` (Plafonnement du gradient / _Gradient Clipping_)
|
||||||
|
- À quoi ça sert ? C'est une sécurité pour éviter le problème des "gradients qui explosent". Parfois,
|
||||||
|
une expérience très inattendue peut produire une erreur de prédiction énorme, ce qui entraîne une
|
||||||
|
correction (un "gradient") démesurément grande. Une telle correction peut anéantir tout ce que le
|
||||||
|
réseau a appris.
|
||||||
|
- Intuition : clip_grad impose une limite. Si la correction à apporter dépasse un certain seuil, elle
|
||||||
|
est ramenée à cette valeur maximale. C'est un garde-fou qui dit : "OK, on a fait une grosse erreur,
|
||||||
|
mais on va corriger calmement, sans tout casser".
|
||||||
|
|
|
||||||
9
justfile
9
justfile
|
|
@ -9,8 +9,8 @@ shell:
|
||||||
runcli:
|
runcli:
|
||||||
RUST_LOG=info cargo run --bin=client_cli
|
RUST_LOG=info cargo run --bin=client_cli
|
||||||
runclibots:
|
runclibots:
|
||||||
cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk
|
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
|
||||||
#cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
|
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
||||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||||
match:
|
match:
|
||||||
cargo build --release --bin=client_cli
|
cargo build --release --bin=client_cli
|
||||||
|
|
@ -28,10 +28,9 @@ trainsimple:
|
||||||
trainbot:
|
trainbot:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
cargo build --release --bin=train_dqn_burn
|
./bot/scripts/train.sh
|
||||||
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
|
|
||||||
plottrainbot:
|
plottrainbot:
|
||||||
cat /tmp/train.out | grep -v rolled | grep -v correct | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid --title 'adv = random ; density = 256 ; err_reward = -1 ; reward_ratio = 1 ; decay = 1500 ; max steps = 600' --terminal png > doc/trainbots_stats/train_random_256_1_1_1500_600.png
|
./bot/scripts/train.sh plot
|
||||||
debugtrainbot:
|
debugtrainbot:
|
||||||
cargo build --bin=train_dqn_burn
|
cargo build --bin=train_dqn_burn
|
||||||
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue