train burn-rl with integers

This commit is contained in:
Henri Bourcereau 2025-08-10 08:39:31 +02:00
parent 5b02293221
commit e4b3092018
6 changed files with 47 additions and 64 deletions

View file

@ -1,26 +0,0 @@
# Trictrac Project Guidelines
## Build & Run Commands
- Build: `cargo build`
- Test: `cargo test`
- Test specific: `cargo test -- test_name`
- Lint: `cargo clippy`
- Format: `cargo fmt`
- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
- Build Python lib: `maturin build -m store/Cargo.toml --release`
## Code Style
- Use Rust 2021 edition idioms
- Error handling: Use Result<T, Error> pattern with custom Error types
- Naming: snake_case for functions/variables, CamelCase for types
- Imports: Group standard lib, external crates, then internal modules
- Module structure: Prefer small, focused modules with clear responsibilities
- Documentation: Document public APIs with doc comments
- Testing: Write unit tests in same file as implementation
- Python bindings: Use pyo3 for creating Python modules
## Architecture
- Core game logic in `store` crate
- Multiple clients: CLI, TUI, Bevy (graphical)
- Bot interfaces in `bot` crate

View file

@ -150,11 +150,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
episode_done = true; episode_done = true;
println!( println!(
"{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}", "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
episode, now.elapsed().unwrap().as_secs(),
episode_reward,
episode_duration,
now.elapsed().unwrap().as_secs()
); );
now = SystemTime::now(); now = SystemTime::now();
} else { } else {

View file

@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
/// État du jeu Trictrac pour burn-rl /// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct TrictracState { pub struct TrictracState {
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
} }
impl State for TrictracState { impl State for TrictracState {
type Data = [f32; 36]; type Data = [i8; 36];
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> { fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
Tensor::from_floats(self.data, &B::Device::default()) Tensor::from_floats(self.data, &B::Device::default())
@ -25,8 +25,8 @@ impl State for TrictracState {
impl TrictracState { impl TrictracState {
/// Convertit un GameState en TrictracState /// Convertit un GameState en TrictracState
pub fn from_game_state(game_state: &GameState) -> Self { pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec_float(); let state_vec = game_state.to_vec();
let mut data = [0.0; 36]; let mut data = [0; 36];
// Copier les données en s'assurant qu'on ne dépasse pas la taille // Copier les données en s'assurant qu'on ne dépasse pas la taille
let copy_len = state_vec.len().min(36); let copy_len = state_vec.len().min(36);
@ -39,6 +39,7 @@ impl TrictracState {
/// Actions possibles dans Trictrac pour burn-rl /// Actions possibles dans Trictrac pour burn-rl
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrictracAction { pub struct TrictracAction {
// u32 as required by burn_rl::base::Action type
pub index: u32, pub index: u32,
} }
@ -82,7 +83,8 @@ pub struct TrictracEnvironment {
opponent_id: PlayerId, opponent_id: PlayerId,
current_state: TrictracState, current_state: TrictracState,
episode_reward: f32, episode_reward: f32,
step_count: usize, pub step_count: usize,
pub goodmoves_count: usize,
pub visualized: bool, pub visualized: bool,
} }
@ -91,7 +93,7 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction; type ActionType = TrictracAction;
type RewardType = f32; type RewardType = f32;
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies const MAX_STEPS: usize = 600; // Limite max pour éviter les parties infinies
fn new(visualized: bool) -> Self { fn new(visualized: bool) -> Self {
let mut game = GameState::new(false); let mut game = GameState::new(false);
@ -113,6 +115,7 @@ impl Environment for TrictracEnvironment {
current_state, current_state,
episode_reward: 0.0, episode_reward: 0.0,
step_count: 0, step_count: 0,
goodmoves_count: 0,
visualized, visualized,
} }
} }
@ -132,7 +135,13 @@ impl Environment for TrictracEnvironment {
self.current_state = TrictracState::from_game_state(&self.game); self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0; self.episode_reward = 0.0;
println!(
"correct moves: {} ({}%)",
self.goodmoves_count,
100 * self.goodmoves_count / self.step_count
);
self.step_count = 0; self.step_count = 0;
self.goodmoves_count = 0;
Snapshot::new(self.current_state, 0.0, false) Snapshot::new(self.current_state, 0.0, false)
} }
@ -149,14 +158,9 @@ impl Environment for TrictracEnvironment {
// Exécuter l'action si c'est le tour de l'agent DQN // Exécuter l'action si c'est le tour de l'agent DQN
if self.game.active_player_id == self.active_player_id { if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action { if let Some(action) = trictrac_action {
match self.execute_action(action) { reward = self.execute_action(action);
Ok(action_reward) => { if reward != Self::ERROR_REWARD {
reward = action_reward; self.goodmoves_count += 1;
}
Err(_) => {
// Action invalide, pénalité
reward = -1.0;
}
} }
} else { } else {
// Action non convertible, pénalité // Action non convertible, pénalité
@ -202,6 +206,9 @@ impl Environment for TrictracEnvironment {
} }
impl TrictracEnvironment { impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac /// Convertit une action burn-rl vers une action Trictrac
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> { pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap()) dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
@ -228,10 +235,11 @@ impl TrictracEnvironment {
} }
/// Exécute une action Trictrac dans le jeu /// Exécute une action Trictrac dans le jeu
fn execute_action( // fn execute_action(
&mut self, // &mut self,
action: dqn_common::TrictracAction, // action: dqn_common::TrictracAction,
) -> Result<f32, Box<dyn std::error::Error>> { // ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
use dqn_common::TrictracAction; use dqn_common::TrictracAction;
let mut reward = 0.0; let mut reward = 0.0;
@ -310,16 +318,22 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) { if self.game.validate(&dice_event) {
self.game.consume(&dice_event); self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points; let (points, adv_points) = self.game.dice_points;
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points reward += Self::REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
println!("rolled for {reward}");
}
// Récompense proportionnelle aux points
} }
} }
} else { } else {
// Pénalité pour action invalide // Pénalité pour action invalide
reward -= 2.0; // on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
reward = Self::ERROR_REWARD;
} }
} }
Ok(reward) reward
} }
/// Fait jouer l'adversaire avec une stratégie simple /// Fait jouer l'adversaire avec une stratégie simple
@ -329,15 +343,14 @@ impl TrictracEnvironment {
// Si c'est le tour de l'adversaire, jouer automatiquement // Si c'est le tour de l'adversaire, jouer automatiquement
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// Utiliser la stratégie default pour l'adversaire // Utiliser la stratégie default pour l'adversaire
use crate::strategy::default::DefaultStrategy;
use crate::BotStrategy; use crate::BotStrategy;
let mut default_strategy = DefaultStrategy::default(); let mut strategy = crate::strategy::random::RandomStrategy::default();
default_strategy.set_player_id(self.opponent_id); strategy.set_player_id(self.opponent_id);
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
default_strategy.set_color(color); strategy.set_color(color);
} }
*default_strategy.get_mut_game() = self.game.clone(); *strategy.get_mut_game() = self.game.clone();
// Exécuter l'action selon le turn_stage // Exécuter l'action selon le turn_stage
let event = match self.game.turn_stage { let event = match self.game.turn_stage {
@ -365,7 +378,7 @@ impl TrictracEnvironment {
let points_rules = let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice); PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count); let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark { GameEvent::Mark {
player_id: self.opponent_id, player_id: self.opponent_id,
@ -397,7 +410,7 @@ impl TrictracEnvironment {
} }
TurnStage::Move => GameEvent::Move { TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id, player_id: self.opponent_id,
moves: default_strategy.choose_move(), moves: strategy.choose_move(),
}, },
}; };

View file

@ -14,11 +14,11 @@ fn main() {
let conf = dqn_model::DqnConfig { let conf = dqn_model::DqnConfig {
num_episodes: 40, num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant // memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant // max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration) eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05, eps_end: 0.05,
eps_decay: 3000.0, eps_decay: 1500.0,
}; };
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true); let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);

View file

@ -28,12 +28,10 @@ trainsimple:
trainbot: trainbot:
#python ./store/python/trainModel.py #python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok # cargo run --bin=train_dqn # ok
# cargo run --bin=train_dqn_burn # utilise debug (why ?)
cargo build --release --bin=train_dqn_burn cargo build --release --bin=train_dqn_burn
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
plottrainbot: plottrainbot:
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid cat /tmp/train.out | grep -v rolled | grep -v correct | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid --title 'adv = random ; density = 256 ; err_reward = -1 ; reward_ratio = 1 ; decay = 1500 ; max steps = 600' --terminal png > doc/trainbots_stats/train_random_256_1_1_1500_600.png
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
debugtrainbot: debugtrainbot:
cargo build --bin=train_dqn_burn cargo build --bin=train_dqn_burn
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn

View file

@ -151,6 +151,7 @@ impl GameState {
/// Get state as a vector (to be used for bot training input) : /// Get state as a vector (to be used for bot training input) :
/// length = 36 /// length = 36
/// i8 for board positions with negative values for blacks
pub fn to_vec(&self) -> Vec<i8> { pub fn to_vec(&self) -> Vec<i8> {
let state_len = 36; let state_len = 36;
let mut state = Vec::with_capacity(state_len); let mut state = Vec::with_capacity(state_len);