mirrors for open_spiel

This commit is contained in:
Henri Bourcereau 2026-02-15 12:08:24 +01:00
parent 47142d593f
commit d53b65c947
7 changed files with 164 additions and 30 deletions

View file

@ -50,9 +50,14 @@ impl TricTrac {
self.game_state.active_player_id - 1
}
fn get_legal_actions(&self, player_id: u64) -> Vec<usize> {
if player_id == self.current_player_idx() {
get_valid_action_indices(&self.game_state)
fn get_legal_actions(&self, player_idx: u64) -> Vec<usize> {
if player_idx == self.current_player_idx() {
if player_idx == 0 {
get_valid_action_indices(&self.game_state)
} else {
let mirror = self.game_state.mirror();
get_valid_action_indices(&mirror)
}
} else {
vec![]
}
@ -80,14 +85,18 @@ impl TricTrac {
}
fn apply_action(&mut self, action_idx: usize) -> PyResult<()> {
if let Some(event) =
TrictracAction::from_action_index(action_idx).and_then(|a| a.to_event(&self.game_state))
{
println!("get event {:?}", event);
if let Some(event) = TrictracAction::from_action_index(action_idx).and_then(|a| {
let needs_mirror = self.game_state.active_player_id == 2;
let game_state = if needs_mirror {
&self.game_state.mirror()
} else {
&self.game_state
};
a.to_event(game_state)
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
}) {
if self.game_state.validate(&event) {
println!("valid event");
self.game_state.consume(&event);
println!("state {}", self.game_state);
return Ok(());
} else {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
@ -113,8 +122,20 @@ impl TricTrac {
[self.get_score(1), self.get_score(2)]
}
fn get_tensor(&self, player: PlayerId) -> Vec<i8> {
self.game_state.to_vec()
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
if player_idx == 0 {
self.game_state.to_vec()
} else {
self.game_state.mirror().to_vec()
}
}
fn get_observation_string(&self, player_idx: u64) -> String {
if player_idx == 0 {
format!("{}", self.game_state)
} else {
format!("{}", self.game_state.mirror())
}
}
/// Afficher l'état du jeu (pour le débogage)