feat: add to_tensor on GameState, more explicit for training than the minimal to_vec()
This commit is contained in:
parent
145ab7dcda
commit
aa7f5fe42a
4 changed files with 108 additions and 1000 deletions
|
|
@ -200,6 +200,106 @@ impl GameState {
|
|||
self.to_vec().iter().map(|&x| x as f32).collect()
|
||||
}
|
||||
|
||||
/// Get state as a tensor for neural network training (Option B, TD-Gammon style).
|
||||
/// Returns 217 f32 values, all normalized to [0, 1].
|
||||
///
|
||||
/// Must be called from the active player's perspective: callers should mirror
|
||||
/// the GameState for Black before calling so that "own" always means White.
|
||||
///
|
||||
/// Layout:
|
||||
/// [0..95] own (White) checkers: 4 values per field × 24 fields
|
||||
/// [96..191] opp (Black) checkers: 4 values per field × 24 fields
|
||||
/// [192..193] dice values / 6
|
||||
/// [194] active player color (0=White, 1=Black)
|
||||
/// [195] turn_stage / 5
|
||||
/// [196..199] White player: points/12, holes/12, can_bredouille, can_big_bredouille
|
||||
/// [200..203] Black player: same
|
||||
/// [204..207] own quarter filled (quarters 1-4)
|
||||
/// [208..211] opp quarter filled (quarters 1-4)
|
||||
/// [212] own checkers all in exit zone (fields 19-24)
|
||||
/// [213] opp checkers all in exit zone (fields 1-6)
|
||||
/// [214] own coin de repos taken (field 12 has ≥2 own checkers)
|
||||
/// [215] opp coin de repos taken (field 13 has ≥2 opp checkers)
|
||||
/// [216] own dice_roll_count / 3, clamped to 1
|
||||
pub fn to_tensor(&self) -> Vec<f32> {
|
||||
let mut t = Vec::with_capacity(217);
|
||||
let pos: Vec<i8> = self.board.to_vec(); // 24 elements, positive=White, negative=Black
|
||||
|
||||
// [0..95] own (White) checkers, TD-Gammon encoding
|
||||
for &c in &pos {
|
||||
let own = c.max(0) as u8;
|
||||
t.push((own == 1) as u8 as f32);
|
||||
t.push((own == 2) as u8 as f32);
|
||||
t.push((own == 3) as u8 as f32);
|
||||
t.push(own.saturating_sub(3) as f32);
|
||||
}
|
||||
|
||||
// [96..191] opp (Black) checkers, TD-Gammon encoding
|
||||
for &c in &pos {
|
||||
let opp = (-c).max(0) as u8;
|
||||
t.push((opp == 1) as u8 as f32);
|
||||
t.push((opp == 2) as u8 as f32);
|
||||
t.push((opp == 3) as u8 as f32);
|
||||
t.push(opp.saturating_sub(3) as f32);
|
||||
}
|
||||
|
||||
// [192..193] dice
|
||||
t.push(self.dice.values.0 as f32 / 6.0);
|
||||
t.push(self.dice.values.1 as f32 / 6.0);
|
||||
|
||||
// [194] active player color
|
||||
t.push(
|
||||
self.who_plays()
|
||||
.map(|p| if p.color == Color::Black { 1.0f32 } else { 0.0 })
|
||||
.unwrap_or(0.0),
|
||||
);
|
||||
|
||||
// [195] turn stage
|
||||
t.push(u8::from(self.turn_stage) as f32 / 5.0);
|
||||
|
||||
// [196..199] White player stats
|
||||
let wp = self.get_white_player();
|
||||
t.push(wp.map_or(0.0, |p| p.points as f32 / 12.0));
|
||||
t.push(wp.map_or(0.0, |p| p.holes as f32 / 12.0));
|
||||
t.push(wp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
|
||||
t.push(wp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
|
||||
|
||||
// [200..203] Black player stats
|
||||
let bp = self.get_black_player();
|
||||
t.push(bp.map_or(0.0, |p| p.points as f32 / 12.0));
|
||||
t.push(bp.map_or(0.0, |p| p.holes as f32 / 12.0));
|
||||
t.push(bp.map_or(0.0, |p| p.can_bredouille as u8 as f32));
|
||||
t.push(bp.map_or(0.0, |p| p.can_big_bredouille as u8 as f32));
|
||||
|
||||
// [204..207] own (White) quarter fill status
|
||||
for &start in &[1usize, 7, 13, 19] {
|
||||
t.push(self.board.is_quarter_filled(Color::White, start) as u8 as f32);
|
||||
}
|
||||
|
||||
// [208..211] opp (Black) quarter fill status
|
||||
for &start in &[1usize, 7, 13, 19] {
|
||||
t.push(self.board.is_quarter_filled(Color::Black, start) as u8 as f32);
|
||||
}
|
||||
|
||||
// [212] can_exit_own: no own checker in fields 1-18
|
||||
t.push(pos[0..18].iter().all(|&c| c <= 0) as u8 as f32);
|
||||
|
||||
// [213] can_exit_opp: no opp checker in fields 7-24
|
||||
t.push(pos[6..24].iter().all(|&c| c >= 0) as u8 as f32);
|
||||
|
||||
// [214] own coin de repos taken (field 12 = index 11, ≥2 own checkers)
|
||||
t.push((pos[11] >= 2) as u8 as f32);
|
||||
|
||||
// [215] opp coin de repos taken (field 13 = index 12, ≥2 opp checkers)
|
||||
t.push((pos[12] <= -2) as u8 as f32);
|
||||
|
||||
// [216] own dice_roll_count / 3, clamped to 1
|
||||
t.push((wp.map_or(0, |p| p.dice_roll_count) as f32 / 3.0).min(1.0));
|
||||
|
||||
debug_assert_eq!(t.len(), 217, "to_tensor length mismatch");
|
||||
t
|
||||
}
|
||||
|
||||
/// Get state as a vector (to be used for bot training input) :
|
||||
/// length = 36
|
||||
/// i8 for board positions with negative values for blacks
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue