train burn-rl with integers

This commit is contained in:
Henri Bourcereau 2025-08-10 08:39:31 +02:00
parent 5b02293221
commit e4b3092018
6 changed files with 47 additions and 64 deletions

View file

@ -1,26 +0,0 @@
# Trictrac Project Guidelines
## Build & Run Commands
- Build: `cargo build`
- Test: `cargo test`
- Test specific: `cargo test -- test_name`
- Lint: `cargo clippy`
- Format: `cargo fmt`
- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
- Build Python lib: `maturin build -m store/Cargo.toml --release`
## Code Style
- Use Rust 2021 edition idioms
- Error handling: Use Result<T, Error> pattern with custom Error types (see the sketch after this list)
- Naming: snake_case for functions/variables, CamelCase for types
- Imports: Group standard lib, external crates, then internal modules
- Module structure: Prefer small, focused modules with clear responsibilities
- Documentation: Document public APIs with doc comments
- Testing: Write unit tests in same file as implementation
- Python bindings: Use pyo3 for creating Python modules
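As a rough illustration of the error-handling guideline above, a minimal sketch (the `Error` variant and the `parse_die` function are hypothetical, not taken from the codebase):

```rust
use std::fmt;

/// Hypothetical custom error type following the Result<T, Error> guideline.
#[derive(Debug)]
pub enum Error {
    InvalidMove(String),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::InvalidMove(msg) => write!(f, "invalid move: {msg}"),
        }
    }
}

impl std::error::Error for Error {}

/// Fallible functions return Result<T, Error> instead of panicking.
pub fn parse_die(value: u8) -> Result<u8, Error> {
    if (1..=6).contains(&value) {
        Ok(value)
    } else {
        Err(Error::InvalidMove(format!("die value {value} out of range")))
    }
}
```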
## Architecture
- Core game logic in `store` crate
- Multiple clients: CLI, TUI, Bevy (graphical)
- Bot interfaces in `bot` crate

View file

@ -150,11 +150,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
episode_done = true;
println!(
"{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
episode,
episode_reward,
episode_duration,
now.elapsed().unwrap().as_secs()
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"threshold\": {eps_threshold}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
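// example output line (illustrative values): {"episode": 12, "reward": 3.5000, "steps count": 600, "threshold": 0.123, "duration": 40}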
now = SystemTime::now();
} else {

View file

@ -7,11 +7,11 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
/// Trictrac game state for burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
pub data: [f32; 36], // Vector representation of the game state
pub data: [i8; 36], // Vector representation of the game state
}
impl State for TrictracState {
type Data = [f32; 36];
type Data = [i8; 36];
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
Tensor::from_floats(self.data, &B::Device::default())
@ -25,8 +25,8 @@ impl State for TrictracState {
impl TrictracState {
/// Converts a GameState into a TrictracState
pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec_float();
let mut data = [0.0; 36];
let state_vec = game_state.to_vec();
let mut data = [0; 36];
// Copy the data, making sure we do not exceed the array size
let copy_len = state_vec.len().min(36);
@ -39,6 +39,7 @@ impl TrictracState {
/// Possible Trictrac actions for burn-rl
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrictracAction {
// u32 as required by burn_rl::base::Action type
pub index: u32,
}
@ -82,7 +83,8 @@ pub struct TrictracEnvironment {
opponent_id: PlayerId,
current_state: TrictracState,
episode_reward: f32,
step_count: usize,
pub step_count: usize,
pub goodmoves_count: usize,
pub visualized: bool,
}
@ -91,7 +93,7 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction;
type RewardType = f32;
const MAX_STEPS: usize = 700; // Max limit to avoid infinite games
const MAX_STEPS: usize = 600; // Max limit to avoid infinite games
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
@ -113,6 +115,7 @@ impl Environment for TrictracEnvironment {
current_state,
episode_reward: 0.0,
step_count: 0,
goodmoves_count: 0,
visualized,
}
}
@ -132,7 +135,13 @@ impl Environment for TrictracEnvironment {
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
println!(
"correct moves: {} ({}%)",
self.goodmoves_count,
100 * self.goodmoves_count / self.step_count
);
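// integer percentage of the agent's actions that were valid during the episode that just ended
// (assumes at least one step was taken, otherwise the division by step_count would panic)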
self.step_count = 0;
self.goodmoves_count = 0;
Snapshot::new(self.current_state, 0.0, false)
}
@ -149,14 +158,9 @@ impl Environment for TrictracEnvironment {
// Execute the action if it is the DQN agent's turn
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
match self.execute_action(action) {
Ok(action_reward) => {
reward = action_reward;
}
Err(_) => {
// Invalid action, penalty
reward = -1.0;
}
reward = self.execute_action(action);
if reward != Self::ERROR_REWARD {
self.goodmoves_count += 1;
}
} else {
// Action could not be converted, penalty
@ -202,6 +206,9 @@ impl Environment for TrictracEnvironment {
}
impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
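// recognizable sentinel returned for invalid actions, so they can be counted and filtered out of the statistics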
const REWARD_RATIO: f32 = 1.0;
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
@ -228,10 +235,11 @@ impl TrictracEnvironment {
}
/// Executes a Trictrac action in the game
fn execute_action(
&mut self,
action: dqn_common::TrictracAction,
) -> Result<f32, Box<dyn std::error::Error>> {
// fn execute_action(
// &mut self,
// action: dqn_common::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> f32 {
use dqn_common::TrictracAction;
let mut reward = 0.0;
@ -310,16 +318,22 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += 0.3 * (points - adv_points) as f32; // Reward proportional to the points scored
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
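// e.g. with REWARD_RATIO = 1.0, scoring 4 points while the opponent would score 2 adds +2.0 to the reward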
if points > 0 {
println!("rolled for {reward}");
}
// Reward proportional to the points scored
}
}
} else {
// Penalty for an invalid action
reward -= 2.0;
// cancel the previously accumulated reward
// and set a recognizable value for the statistics
reward = Self::ERROR_REWARD;
}
}
Ok(reward)
reward
}
/// Has the opponent play using a simple strategy
@ -329,15 +343,14 @@ impl TrictracEnvironment {
// If it is the opponent's turn, play automatically
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// Use the default strategy for the opponent
use crate::strategy::default::DefaultStrategy;
use crate::BotStrategy;
let mut default_strategy = DefaultStrategy::default();
default_strategy.set_player_id(self.opponent_id);
let mut strategy = crate::strategy::random::RandomStrategy::default();
strategy.set_player_id(self.opponent_id);
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
default_strategy.set_color(color);
strategy.set_color(color);
}
*default_strategy.get_mut_game() = self.game.clone();
*strategy.get_mut_game() = self.game.clone();
// Execute the action according to the turn_stage
let event = match self.game.turn_stage {
@ -365,7 +378,7 @@ impl TrictracEnvironment {
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Reward proportional to the points scored
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Reward proportional to the points scored
GameEvent::Mark {
player_id: self.opponent_id,
@ -397,7 +410,7 @@ impl TrictracEnvironment {
}
TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id,
moves: default_strategy.choose_move(),
moves: strategy.choose_move(),
},
};

View file

@ -14,11 +14,11 @@ fn main() {
let conf = dqn_model::DqnConfig {
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
// max_steps: 600, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 3000.0,
eps_decay: 1500.0,
};
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
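For orientation on the tuning above: the `eps_threshold` printed per episode in the training loop comes from DQN's epsilon-greedy schedule. A minimal sketch assuming the usual exponential decay formula (the helper below is illustrative, not burn-rl's actual code):

```rust
/// Assumed schedule: eps = eps_end + (eps_start - eps_end) * exp(-steps_done / eps_decay)
fn eps_threshold(steps_done: usize) -> f64 {
    let (eps_start, eps_end, eps_decay) = (0.9_f64, 0.05_f64, 1500.0_f64);
    eps_end + (eps_start - eps_end) * (-(steps_done as f64) / eps_decay).exp()
}
```

Under that assumption, lowering eps_decay from 3000.0 to 1500.0 means epsilon falls to roughly 0.36 after 1500 agent steps (two to three episodes at 600 steps each), so exploitation starts about twice as early as before.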

View file

@ -28,12 +28,10 @@ trainsimple:
trainbot:
#python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok
# cargo run --bin=train_dqn_burn # uses the debug build (why?)
cargo build --release --bin=train_dqn_burn
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out
plottrainbot:
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
cat /tmp/train.out | grep -v rolled | grep -v correct | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid --title 'adv = random ; density = 256 ; err_reward = -1 ; reward_ratio = 1 ; decay = 1500 ; max steps = 600' --terminal png > doc/trainbots_stats/train_random_256_1_1_1500_600.png
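# the greps drop the "rolled for ..." and "correct moves ..." progress lines; with -F '[ ,]' field 5 of each remaining JSON line is the reward value plotted by feedgnuplot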
debugtrainbot:
cargo build --bin=train_dqn_burn
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn

View file

@ -151,6 +151,7 @@ impl GameState {
/// Get state as a vector (to be used for bot training input):
/// length = 36
/// i8 for board positions, with negative values for black checkers
pub fn to_vec(&self) -> Vec<i8> {
let state_len = 36;
let mut state = Vec::with_capacity(state_len);
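To illustrate the sign convention in the doc comment, a minimal sketch of packing one board point into a single i8 (the helper names are hypothetical; the actual layout of the 36 entries is not shown in this hunk):

```rust
/// Hypothetical encoding of a single board point as one i8:
/// positive = white checkers, negative = black checkers, 0 = empty.
fn encode_point(white: u8, black: u8) -> i8 {
    debug_assert!(white == 0 || black == 0, "a point holds checkers of one color only");
    white as i8 - black as i8
}

/// Decode back into (white, black) checker counts.
fn decode_point(value: i8) -> (u8, u8) {
    if value >= 0 {
        (value as u8, 0)
    } else {
        (0, value.unsigned_abs())
    }
}
```

The training side then copies this Vec<i8> into the fixed [i8; 36] buffer, as `from_game_state` in the environment hunk above shows.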