train burn-rl with integers

parent 5b02293221
commit e4b3092018

CLAUDE.md | 26 --------------------------

@@ -1,26 +0,0 @@
-# Trictrac Project Guidelines
-
-## Build & Run Commands
-- Build: `cargo build`
-- Test: `cargo test`
-- Test specific: `cargo test -- test_name`
-- Lint: `cargo clippy`
-- Format: `cargo fmt`
-- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
-- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
-- Build Python lib: `maturin build -m store/Cargo.toml --release`
-
-## Code Style
-- Use Rust 2021 edition idioms
-- Error handling: Use Result<T, Error> pattern with custom Error types
-- Naming: snake_case for functions/variables, CamelCase for types
-- Imports: Group standard lib, external crates, then internal modules
-- Module structure: Prefer small, focused modules with clear responsibilities
-- Documentation: Document public APIs with doc comments
-- Testing: Write unit tests in same file as implementation
-- Python bindings: Use pyo3 for creating Python modules
-
-## Architecture
-- Core game logic in `store` crate
-- Multiple clients: CLI, TUI, Bevy (graphical)
-- Bot interfaces in `bot` crate
@@ -150,11 +150,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             episode_done = true;

             println!(
-                "{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
-                episode,
-                episode_reward,
-                episode_duration,
-                now.elapsed().unwrap().as_secs()
+                "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
+                now.elapsed().unwrap().as_secs(),
             );
             now = SystemTime::now();
         } else {
@@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
 /// Trictrac game state for burn-rl
 #[derive(Debug, Clone, Copy)]
 pub struct TrictracState {
-    pub data: [f32; 36], // Vector representation of the game state
+    pub data: [i8; 36], // Vector representation of the game state
 }

 impl State for TrictracState {
-    type Data = [f32; 36];
+    type Data = [i8; 36];

     fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
         Tensor::from_floats(self.data, &B::Device::default())
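Note that `to_tensor` itself is unchanged, so the `[i8; 36]` array now flows straight into `Tensor::from_floats`, which converts its input to the backend's float type. If a given burn version rejects integer element types there, the cast is a one-liner; a minimal sketch (helper name hypothetical, not part of this commit):

use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical fallback: widen the i8 state to f32 before building
/// the 1-D input tensor, for burn versions where from_floats does not
/// accept integer element types directly.
fn state_to_tensor<B: Backend>(data: [i8; 36]) -> Tensor<B, 1> {
    let floats: [f32; 36] = data.map(|v| v as f32); // explicit i8 -> f32 cast
    Tensor::from_floats(floats, &B::Device::default())
}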
@@ -25,8 +25,8 @@ impl State for TrictracState {
 impl TrictracState {
     /// Converts a GameState into a TrictracState
     pub fn from_game_state(game_state: &GameState) -> Self {
-        let state_vec = game_state.to_vec_float();
-        let mut data = [0.0; 36];
+        let state_vec = game_state.to_vec();
+        let mut data = [0; 36];

         // Copy the data while making sure we stay within the array bounds
         let copy_len = state_vec.len().min(36);
@@ -39,6 +39,7 @@ impl TrictracState {
 /// Possible Trictrac actions for burn-rl
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct TrictracAction {
+    // u32 as required by burn_rl::base::Action type
     pub index: u32,
 }

@@ -82,7 +83,8 @@ pub struct TrictracEnvironment {
     opponent_id: PlayerId,
     current_state: TrictracState,
     episode_reward: f32,
-    step_count: usize,
+    pub step_count: usize,
+    pub goodmoves_count: usize,
     pub visualized: bool,
 }

@@ -91,7 +93,7 @@ impl Environment for TrictracEnvironment {
     type ActionType = TrictracAction;
     type RewardType = f32;

-    const MAX_STEPS: usize = 700; // Hard limit to avoid endless games
+    const MAX_STEPS: usize = 600; // Hard limit to avoid endless games

     fn new(visualized: bool) -> Self {
         let mut game = GameState::new(false);
@@ -113,6 +115,7 @@ impl Environment for TrictracEnvironment {
             current_state,
             episode_reward: 0.0,
             step_count: 0,
+            goodmoves_count: 0,
             visualized,
         }
     }
@@ -132,7 +135,13 @@ impl Environment for TrictracEnvironment {

         self.current_state = TrictracState::from_game_state(&self.game);
         self.episode_reward = 0.0;
+        println!(
+            "correct moves: {} ({}%)",
+            self.goodmoves_count,
+            100 * self.goodmoves_count / self.step_count
+        );
         self.step_count = 0;
+        self.goodmoves_count = 0;

         Snapshot::new(self.current_state, 0.0, false)
     }
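One caveat with the stat line added above: `reset` also runs before the first episode, when `step_count` is still 0, and `100 * self.goodmoves_count / self.step_count` is an integer division that panics on zero. A guarded variant as a standalone sketch (helper name hypothetical):

/// Hypothetical guard around the ratio print: skip it until at least
/// one step has been taken, so the integer division cannot panic.
fn print_goodmoves_ratio(goodmoves_count: usize, step_count: usize) {
    if step_count > 0 {
        println!(
            "correct moves: {} ({}%)",
            goodmoves_count,
            100 * goodmoves_count / step_count
        );
    }
}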
@@ -149,14 +158,9 @@ impl Environment for TrictracEnvironment {
         // Execute the action if it is the DQN agent's turn
         if self.game.active_player_id == self.active_player_id {
             if let Some(action) = trictrac_action {
-                match self.execute_action(action) {
-                    Ok(action_reward) => {
-                        reward = action_reward;
-                    }
-                    Err(_) => {
-                        // Invalid action, penalty
-                        reward = -1.0;
-                    }
+                reward = self.execute_action(action);
+                if reward != Self::ERROR_REWARD {
+                    self.goodmoves_count += 1;
                 }
             } else {
                 // Action not convertible, penalty
@@ -202,6 +206,9 @@ impl Environment for TrictracEnvironment {
 }

 impl TrictracEnvironment {
+    const ERROR_REWARD: f32 = -1.12121;
+    const REWARD_RATIO: f32 = 1.0;
+
     /// Converts a burn-rl action into a Trictrac action
     pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
         dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
@@ -228,10 +235,11 @@ impl TrictracEnvironment {
     }

     /// Executes a Trictrac action in the game
-    fn execute_action(
-        &mut self,
-        action: dqn_common::TrictracAction,
-    ) -> Result<f32, Box<dyn std::error::Error>> {
+    // fn execute_action(
+    //     &mut self,
+    //     action: dqn_common::TrictracAction,
+    // ) -> Result<f32, Box<dyn std::error::Error>> {
+    fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
         use dqn_common::TrictracAction;

         let mut reward = 0.0;
@@ -310,16 +318,22 @@ impl TrictracEnvironment {
                 if self.game.validate(&dice_event) {
                     self.game.consume(&dice_event);
                     let (points, adv_points) = self.game.dice_points;
-                    reward += 0.3 * (points - adv_points) as f32; // Reward proportional to points
+                    reward += Self::REWARD_RATIO * (points - adv_points) as f32;
+                    if points > 0 {
+                        println!("rolled for {reward}");
+                    }
+                    // Reward proportional to points
                 }
             }
         } else {
             // Penalty for invalid action
-            reward -= 2.0;
+            // cancel the rewards accumulated so far
+            // and set a recognizable value for statistics
+            reward = Self::ERROR_REWARD;
         }
     }

-        Ok(reward)
+        reward
     }

     /// Makes the opponent play with a simple strategy
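A design note on that sentinel: `-1.12121` only stays recognizable because it is returned verbatim and compared with `!=` in `step` before anything is added to it. An `Option<f32>` return would carry the same information in the type system; a minimal sketch of the idea, not what the commit does:

/// Hypothetical alternative to the ERROR_REWARD float sentinel:
/// None marks an invalid action, Some carries the shaped reward.
fn reward_or_invalid(valid: bool, shaped_reward: f32) -> Option<f32> {
    valid.then_some(shaped_reward)
}

/// The caller can then count good moves without a magic constant.
fn apply(result: Option<f32>, error_reward: f32, goodmoves_count: &mut usize) -> f32 {
    match result {
        Some(r) => {
            *goodmoves_count += 1;
            r
        }
        None => error_reward,
    }
}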
@@ -329,15 +343,14 @@ impl TrictracEnvironment {
         // If it's the opponent's turn, play automatically
         if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
             // Use the default strategy for the opponent
-            use crate::strategy::default::DefaultStrategy;
             use crate::BotStrategy;

-            let mut default_strategy = DefaultStrategy::default();
-            default_strategy.set_player_id(self.opponent_id);
+            let mut strategy = crate::strategy::random::RandomStrategy::default();
+            strategy.set_player_id(self.opponent_id);
             if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
-                default_strategy.set_color(color);
+                strategy.set_color(color);
             }
-            *default_strategy.get_mut_game() = self.game.clone();
+            *strategy.get_mut_game() = self.game.clone();

             // Execute the action according to the turn_stage
             let event = match self.game.turn_stage {
@@ -365,7 +378,7 @@ impl TrictracEnvironment {
                     let points_rules =
                         PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
                     let (points, adv_points) = points_rules.get_points(dice_roll_count);
-                    reward -= 0.3 * (points - adv_points) as f32; // Reward proportional to points
+                    reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Reward proportional to points

                     GameEvent::Mark {
                         player_id: self.opponent_id,
@@ -397,7 +410,7 @@ impl TrictracEnvironment {
             }
             TurnStage::Move => GameEvent::Move {
                 player_id: self.opponent_id,
-                moves: default_strategy.choose_move(),
+                moves: strategy.choose_move(),
             },
         };

@@ -14,11 +14,11 @@ fn main() {
     let conf = dqn_model::DqnConfig {
         num_episodes: 40,
         // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
-        // max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
+        // max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
         dense_size: 256, // neural network complexity
         eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
         eps_end: 0.05,
-        eps_decay: 3000.0,
+        eps_decay: 1500.0,
     };
     let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);

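For context on what `eps_decay` trades off: DQN training of this kind typically anneals exploration with an exponential schedule, eps_end + (eps_start - eps_end) * exp(-step / eps_decay), so halving the decay from 3000 to 1500 drives epsilon toward its floor roughly twice as fast. A sketch of that schedule (the formula is the standard one; that burn-rl uses exactly this form is an assumption):

/// Standard exponential epsilon-greedy schedule (assumed to match the
/// eps_start / eps_end / eps_decay fields of DqnConfig).
fn eps_threshold(step: usize, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
    eps_end + (eps_start - eps_end) * (-(step as f64) / eps_decay).exp()
}

With eps_decay = 1500.0 and these start/end values, epsilon drops below 0.1 after roughly 4250 steps instead of roughly 8500 with the previous 3000.0.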
justfile | 4 +---

@@ -28,12 +28,10 @@ trainsimple:
 trainbot:
     #python ./store/python/trainModel.py
     # cargo run --bin=train_dqn # ok
-    # cargo run --bin=train_dqn_burn # uses the debug build (why?)
     cargo build --release --bin=train_dqn_burn
     LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
 plottrainbot:
-    cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
-    #tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
+    cat /tmp/train.out | grep -v rolled | grep -v correct | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid --title 'adv = random ; density = 256 ; err_reward = -1 ; reward_ratio = 1 ; decay = 1500 ; max steps = 600' --terminal png > doc/trainbots_stats/train_random_256_1_1_1500_600.png
 debugtrainbot:
     cargo build --bin=train_dqn_burn
     RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
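The two `grep -v` filters are needed because this commit adds the `rolled for ...` and `correct moves: ...` diagnostics to stdout; without them those lines would pollute the plotted column. With `-F '[ ,]'` splitting on spaces and commas, field `$5` of the episode log line `{"episode": N, "reward": R, ...}` is the reward value `R` that feedgnuplot plots.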
@@ -151,6 +151,7 @@ impl GameState {

     /// Get state as a vector (to be used for bot training input) :
     /// length = 36
+    /// i8 for board positions with negative values for blacks
     pub fn to_vec(&self) -> Vec<i8> {
         let state_len = 36;
         let mut state = Vec::with_capacity(state_len);
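The signed encoding the new doc line describes can be illustrated standalone; a minimal sketch assuming 24 board points plus auxiliary slots, with the function and parameter names as hypothetical stand-ins for the store crate's actual types:

/// Hypothetical illustration of the signed board encoding described
/// above: one i8 per point, positive counts for white checkers,
/// negative for black, with the remaining slots of the 36-wide state
/// vector left for dice, scores, and turn information.
fn encode_board(white: &[u8; 24], black: &[u8; 24]) -> Vec<i8> {
    let mut state = Vec::with_capacity(36);
    for i in 0..24 {
        // A point holds checkers of at most one color, so the signed
        // difference is an unambiguous encoding.
        state.push(white[i] as i8 - black[i] as i8);
    }
    // ... the remaining entries (dice, points, stage, etc.) would
    // follow here in the real to_vec().
    state
}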