Compare commits
No commits in common. "744a70cf1d8ce324bee423336be6338aab0bf46c" and "28c2aa836ff1a0626466d13f06f37d4ed6156865" have entirely different histories.
744a70cf1d
...
28c2aa836f
|
|
@ -179,9 +179,9 @@ impl Environment for TrictracEnvironment {
|
||||||
// Récompense finale basée sur le résultat
|
// Récompense finale basée sur le résultat
|
||||||
if let Some(winner_id) = self.game.determine_winner() {
|
if let Some(winner_id) = self.game.determine_winner() {
|
||||||
if winner_id == self.active_player_id {
|
if winner_id == self.active_player_id {
|
||||||
reward += 50.0; // Victoire
|
reward += 100.0; // Victoire
|
||||||
} else {
|
} else {
|
||||||
reward -= 25.0; // Défaite
|
reward -= 50.0; // Défaite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -259,7 +259,7 @@ impl TrictracEnvironment {
|
||||||
// }
|
// }
|
||||||
TrictracAction::Go => {
|
TrictracAction::Go => {
|
||||||
// Continuer après avoir gagné un trou
|
// Continuer après avoir gagné un trou
|
||||||
reward += 0.2;
|
reward += 0.4;
|
||||||
Some(GameEvent::Go {
|
Some(GameEvent::Go {
|
||||||
player_id: self.active_player_id,
|
player_id: self.active_player_id,
|
||||||
})
|
})
|
||||||
|
|
@ -288,7 +288,7 @@ impl TrictracEnvironment {
|
||||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
reward += 0.2;
|
reward += 0.4;
|
||||||
Some(GameEvent::Move {
|
Some(GameEvent::Move {
|
||||||
player_id: self.active_player_id,
|
player_id: self.active_player_id,
|
||||||
moves: (checker_move1, checker_move2),
|
moves: (checker_move1, checker_move2),
|
||||||
|
|
@ -313,8 +313,6 @@ impl TrictracEnvironment {
|
||||||
};
|
};
|
||||||
if self.game.validate(&dice_event) {
|
if self.game.validate(&dice_event) {
|
||||||
self.game.consume(&dice_event);
|
self.game.consume(&dice_event);
|
||||||
let (points, adv_points) = self.game.dice_points;
|
|
||||||
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -358,7 +356,7 @@ impl TrictracEnvironment {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::MarkPoints => {
|
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
|
||||||
let opponent_color = store::Color::Black;
|
let opponent_color = store::Color::Black;
|
||||||
let dice_roll_count = self
|
let dice_roll_count = self
|
||||||
.game
|
.game
|
||||||
|
|
@ -368,31 +366,14 @@ impl TrictracEnvironment {
|
||||||
.dice_roll_count;
|
.dice_roll_count;
|
||||||
let points_rules =
|
let points_rules =
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
let points = points_rules.get_points(dice_roll_count).0;
|
||||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
GameEvent::Mark {
|
GameEvent::Mark {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
points,
|
points,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::MarkAdvPoints => {
|
|
||||||
let opponent_color = store::Color::Black;
|
|
||||||
let dice_roll_count = self
|
|
||||||
.game
|
|
||||||
.players
|
|
||||||
.get(&self.opponent_id)
|
|
||||||
.unwrap()
|
|
||||||
.dice_roll_count;
|
|
||||||
let points_rules =
|
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
|
||||||
let points = points_rules.get_points(dice_roll_count).1;
|
|
||||||
// pas de reward : déjà comptabilisé lors du tour de blanc
|
|
||||||
GameEvent::Mark {
|
|
||||||
player_id: self.opponent_id,
|
|
||||||
points,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TurnStage::HoldOrGoChoice => {
|
TurnStage::HoldOrGoChoice => {
|
||||||
// Stratégie simple : toujours continuer
|
// Stratégie simple : toujours continuer
|
||||||
GameEvent::Go {
|
GameEvent::Go {
|
||||||
|
|
|
||||||
|
|
@ -9,15 +9,15 @@ type Backend = Autodiff<NdArray<ElemType>>;
|
||||||
type Env = environment::TrictracEnvironment;
|
type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// println!("> Entraînement");
|
println!("> Entraînement");
|
||||||
let conf = dqn_model::DqnConfig {
|
let conf = dqn_model::DqnConfig {
|
||||||
num_episodes: 40,
|
num_episodes: 50,
|
||||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
||||||
// max_steps: 1000, // must be set in environment.rs with the MAX_STEPS constant
|
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
||||||
dense_size: 256, // neural network complexity
|
dense_size: 256, // neural network complexity
|
||||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||||
eps_end: 0.05,
|
eps_end: 0.05,
|
||||||
eps_decay: 3000.0,
|
eps_decay: 1000.0,
|
||||||
};
|
};
|
||||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -357,8 +357,8 @@ impl TrictracEnv {
|
||||||
&self.game_state.board,
|
&self.game_state.board,
|
||||||
self.game_state.dice,
|
self.game_state.dice,
|
||||||
);
|
);
|
||||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
let points = points_rules.get_points(dice_roll_count).0;
|
||||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
reward -= 0.3 * points as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
GameEvent::Mark {
|
GameEvent::Mark {
|
||||||
player_id: self.opponent_player_id,
|
player_id: self.opponent_player_id,
|
||||||
|
|
|
||||||
|
|
@ -46,14 +46,7 @@ impl BotStrategy for ClientStrategy {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_adv_points(&self) -> u8 {
|
fn calculate_adv_points(&self) -> u8 {
|
||||||
let dice_roll_count = self
|
self.calculate_points()
|
||||||
.get_game()
|
|
||||||
.players
|
|
||||||
.get(&self.player_id)
|
|
||||||
.unwrap()
|
|
||||||
.dice_roll_count;
|
|
||||||
let points_rules = PointsRules::new(&Color::White, &self.game.board, self.game.dice);
|
|
||||||
points_rules.get_points(dice_roll_count).1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn choose_go(&self) -> bool {
|
fn choose_go(&self) -> bool {
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@
|
||||||
|
|
||||||
# dev tools
|
# dev tools
|
||||||
pkgs.samply # code profiler
|
pkgs.samply # code profiler
|
||||||
pkgs.feedgnuplot # to visualize bots training results
|
|
||||||
|
|
||||||
# for bevy
|
# for bevy
|
||||||
pkgs.alsa-lib
|
pkgs.alsa-lib
|
||||||
|
|
|
||||||
|
|
@ -1,56 +0,0 @@
|
||||||
# DQN avec burn-rl
|
|
||||||
|
|
||||||
## Paramètre d'entraînement dans dqn/burnrl/dqn_model.rs
|
|
||||||
|
|
||||||
Ces constantes sont des hyperparamètres, c'est-à-dire des réglages que l'on fixe avant l'entraînement et qui conditionnent la manière dont le modèle va apprendre.
|
|
||||||
|
|
||||||
MEMORY_SIZE
|
|
||||||
|
|
||||||
- Ce que c'est : La taille de la "mémoire de rejeu" (Replay Memory/Buffer).
|
|
||||||
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
|
|
||||||
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
|
|
||||||
- Pourquoi c'est important :
|
|
||||||
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
|
|
||||||
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
|
|
||||||
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
|
|
||||||
|
|
||||||
DENSE_SIZE
|
|
||||||
|
|
||||||
- Ce que c'est : La taille des couches cachées du réseau de neurones. "Dense" signifie que chaque neurone d'une couche est connecté à tous les neurones de la couche suivante.
|
|
||||||
- À quoi ça sert : C'est la "capacité de réflexion" de votre agent. Le réseau de neurones (ici, Net) prend l'état du jeu en entrée, le fait passer à travers des couches de calcul (de taille DENSE_SIZE), et sort une
|
|
||||||
estimation de la qualité de chaque action possible.
|
|
||||||
- Pourquoi c'est important :
|
|
||||||
- Une valeur trop petite : le modèle ne sera pas assez "intelligent" pour apprendre les stratégies complexes du TricTrac.
|
|
||||||
- Une valeur trop grande : l'entraînement sera plus lent et le modèle pourrait "sur-apprendre" (overfitting), c'est-à-dire devenir très bon sur les situations vues en entraînement mais incapable de généraliser
|
|
||||||
sur de nouvelles situations.
|
|
||||||
- Dans votre code : const DENSE_SIZE: usize = 128; définit que les couches cachées du réseau auront 128 neurones.
|
|
||||||
|
|
||||||
EPS_START, EPS_END et EPS_DECAY
|
|
||||||
|
|
||||||
Ces trois constantes gèrent la stratégie d'exploration de l'agent, appelée "epsilon-greedy". Le but est de trouver un équilibre entre :
|
|
||||||
|
|
||||||
- L'Exploitation : Jouer le coup que le modèle pense être le meilleur.
|
|
||||||
- L'Exploration : Jouer un coup au hasard pour découvrir de nouvelles stratégies, potentiellement meilleures.
|
|
||||||
|
|
||||||
epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
|
|
||||||
|
|
||||||
- `EPS_START` (Epsilon de départ) :
|
|
||||||
|
|
||||||
- Ce que c'est : La valeur d'epsilon au tout début de l'entraînement.
|
|
||||||
- Rôle : Au début, le modèle ne sait rien. Il est donc crucial qu'il explore beaucoup pour accumuler des expériences variées. Une valeur élevée (proche de 1.0) est typique.
|
|
||||||
- Dans votre code : const EPS_START: f64 = 0.9; signifie qu'au début, l'agent a 90% de chances de jouer un coup au hasard.
|
|
||||||
|
|
||||||
- `EPS_END` (Epsilon final) :
|
|
||||||
|
|
||||||
- Ce que c'est : La valeur minimale d'epsilon, atteinte après un certain nombre d'étapes.
|
|
||||||
- Rôle : Même après un long entraînement, on veut conserver une petite part d'exploration pour éviter que l'agent ne se fige dans une stratégie sous-optimale.
|
|
||||||
- Dans votre code : const EPS_END: f64 = 0.05; signifie qu'à la fin, l'agent explorera encore avec 5% de probabilité.
|
|
||||||
|
|
||||||
- `EPS_DECAY` (Décroissance d'epsilon) :
|
|
||||||
- Ce que c'est : Contrôle la vitesse à laquelle epsilon passe de EPS_START à EPS_END.
|
|
||||||
- Rôle : C'est un facteur de "lissage" dans la formule de décroissance exponentielle. Plus cette valeur est élevée, plus la décroissance est lente, et donc plus l'agent passera de temps à explorer.
|
|
||||||
- Dans votre code : const EPS_DECAY: f64 = 1000.0; est utilisé dans la formule EPS_END + (EPS_START - EPS_END) \* f64::exp(-(step as f64) / EPS_DECAY); pour faire diminuer progressivement la valeur d'epsilon à
|
|
||||||
chaque étape (step) de l'entraînement.
|
|
||||||
|
|
||||||
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
|
|
||||||
nouvelles (EPS*\*).
|
|
||||||
10
justfile
10
justfile
|
|
@ -9,8 +9,7 @@ shell:
|
||||||
runcli:
|
runcli:
|
||||||
RUST_LOG=info cargo run --bin=client_cli
|
RUST_LOG=info cargo run --bin=client_cli
|
||||||
runclibots:
|
runclibots:
|
||||||
#RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy
|
RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy
|
||||||
RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
|
||||||
match:
|
match:
|
||||||
cargo build --release --bin=client_cli
|
cargo build --release --bin=client_cli
|
||||||
LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn
|
LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn
|
||||||
|
|
@ -24,12 +23,9 @@ pythonlib:
|
||||||
trainbot:
|
trainbot:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
# cargo run --bin=train_dqn_burn # utilise debug (why ?)
|
|
||||||
cargo build --release --bin=train_dqn_burn
|
cargo build --release --bin=train_dqn_burn
|
||||||
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
|
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn
|
||||||
plottrainbot:
|
# cargo run --bin=train_dqn_burn # utilise debug (why ?)
|
||||||
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
|
|
||||||
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
|
|
||||||
debugtrainbot:
|
debugtrainbot:
|
||||||
cargo build --bin=train_dqn_burn
|
cargo build --bin=train_dqn_burn
|
||||||
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue