feat: add get_tensor on GameState — a richer training encoding than the minimal get_vec()

This commit is contained in:
Henri Bourcereau 2026-03-07 12:56:03 +01:00
parent 145ab7dcda
commit aa7f5fe42a
4 changed files with 108 additions and 1000 deletions

View file

@ -83,8 +83,8 @@ pub mod ffi {
/// Both players' scores.
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
/// 36-element state vector (i8). Mirrored for player_idx == 1.
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
/// 217-element state tensor (f32), normalized to [0,1]. Mirrored for player_idx == 1.
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<f32>;
/// Human-readable state description for `player_idx`.
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
@ -180,11 +180,11 @@ impl TricTracEngine {
.unwrap_or(-1)
}
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
if player_idx == 0 {
self.game_state.to_vec()
self.game_state.to_tensor()
} else {
self.game_state.mirror().to_vec()
self.game_state.mirror().to_tensor()
}
}

View file

@ -200,6 +200,106 @@ impl GameState {
self.to_vec().iter().map(|&x| x as f32).collect()
}
/// Get state as a tensor for neural network training (Option B, TD-Gammon style).
/// Returns 217 f32 values. Most are normalized to [0, 1]; the exception is the
/// 4th unit of each field's checker encoding, which is the raw count above 3
/// (up to 12 with all 15 checkers stacked on one field) and is NOT rescaled.
///
/// Must be called from the active player's perspective: callers should mirror
/// the GameState for Black before calling so that "own" always means White.
///
/// Layout:
/// [0..95]    own (White) checkers: 4 values per field × 24 fields
/// [96..191]  opp (Black) checkers: 4 values per field × 24 fields
/// [192..193] dice values / 6
/// [194]      active player color (0=White, 1=Black)
/// [195]      turn_stage / 5
/// [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille
/// [200..203] Black player: same
/// [204..207] own quarter filled (quarters 1-4)
/// [208..211] opp quarter filled (quarters 1-4)
/// [212]      no own checker outside the exit zone (none on fields 1-18)
/// [213]      no opp checker outside the exit zone (none on fields 7-24)
/// [214]      own coin de repos taken (field 12 has ≥2 own checkers)
/// [215]      opp coin de repos taken (field 13 has ≥2 opp checkers)
/// [216]      own dice_roll_count / 3, clamped to 1
pub fn to_tensor(&self) -> Vec<f32> {
    // Exact size known up front; avoids reallocation across the 217 pushes.
    let mut t = Vec::with_capacity(217);
    let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
    // [0..95] own (White) checkers, TD-Gammon encoding:
    // three one-hot flags for exactly 1/2/3 checkers, then the raw overflow
    // count beyond 3 (unscaled — see doc comment above).
    for &c in &pos {
        let own = c.max(0) as u8;
        t.push((own == 1) as u8 as f32);
        t.push((own == 2) as u8 as f32);
        t.push((own == 3) as u8 as f32);
        t.push(own.saturating_sub(3) as f32);
    }
    // [96..191] opp (Black) checkers, same encoding (Black counts are the
    // negated board values).
    for &c in &pos {
        let opp = (-c).max(0) as u8;
        t.push((opp == 1) as u8 as f32);
        t.push((opp == 2) as u8 as f32);
        t.push((opp == 3) as u8 as f32);
        t.push(opp.saturating_sub(3) as f32);
    }
    // [192..193] dice, scaled by the max face value (values in (0, 1] once rolled).
    t.push(self.dice.values.0 as f32 / 6.0);
    t.push(self.dice.values.1 as f32 / 6.0);
    // [194] active player color; defaults to 0.0 (White) when no player is
    // active — presumably only between turns; confirm against game flow.
    t.push(
        self.who_plays()
            .map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 })
            .unwrap_or(0.0),
    );
    // [195] turn stage, scaled by its assumed max of 5 — TODO confirm the
    // TurnStage discriminant range stays ≤ 5.
    t.push(u8::from(self.turn_stage) as f32 / 5.0);
    // [196..199] White player stats; 0.0 fallbacks if the player is absent.
    let wp = self.get_white_player();
    t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0));
    t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0));
    t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
    t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
    // [200..203] Black player stats, same four features.
    let bp = self.get_black_player();
    t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0));
    t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0));
    t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
    t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
    // [204..207] own (White) quarter fill status, quarters starting at fields
    // 1, 7, 13, 19.
    for &start in &[1usize, 7, 13, 19] {
        t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32);
    }
    // [208..211] opp (Black) quarter fill status, same quarter starts.
    for &start in &[1usize, 7, 13, 19] {
        t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32);
    }
    // [212] no own (positive) checker in fields 1-18, i.e. any own checkers
    // still on the board are in the exit zone (fields 19-24).
    t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32);
    // [213] no opp (negative) checker in fields 7-24, i.e. any opp checkers
    // still on the board are in their exit zone (fields 1-6).
    t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32);
    // [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers)
    t.push((pos[11] >= 2) as u8 as f32);
    // [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers)
    t.push((pos[12] <= -2) as u8 as f32);
    // [216] own dice_roll_count / 3, clamped to 1. Uses the White player's
    // count — consistent with the mirror-for-Black calling convention above.
    t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0));
    // Guard the layout contract: any added/removed push breaks feature indices.
    debug_assert_eq!(t.len(), 217, "to_tensor length mismatch");
    t
}
/// Get state as a vector (to be used for bot training input) :
/// length = 36
/// i8 for board positions with negative values for blacks

View file

@ -113,11 +113,11 @@ impl TricTrac {
[self.get_score(1), self.get_score(2)]
}
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
fn get_tensor(&self, player_idx: u64) -> Vec<f32> {
if player_idx == 0 {
self.game_state.to_vec()
self.game_state.to_tensor()
} else {
self.game_state.mirror().to_vec()
self.game_state.mirror().to_tensor()
}
}