mirrors for open_spiel
This commit is contained in:
parent
47142d593f
commit
d53b65c947
7 changed files with 164 additions and 30 deletions
|
|
@ -50,9 +50,14 @@ impl TricTrac {
|
|||
self.game_state.active_player_id - 1
|
||||
}
|
||||
|
||||
fn get_legal_actions(&self, player_id: u64) -> Vec<usize> {
|
||||
if player_id == self.current_player_idx() {
|
||||
get_valid_action_indices(&self.game_state)
|
||||
fn get_legal_actions(&self, player_idx: u64) -> Vec<usize> {
|
||||
if player_idx == self.current_player_idx() {
|
||||
if player_idx == 0 {
|
||||
get_valid_action_indices(&self.game_state)
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
|
|
@ -80,14 +85,18 @@ impl TricTrac {
|
|||
}
|
||||
|
||||
fn apply_action(&mut self, action_idx: usize) -> PyResult<()> {
|
||||
if let Some(event) =
|
||||
TrictracAction::from_action_index(action_idx).and_then(|a| a.to_event(&self.game_state))
|
||||
{
|
||||
println!("get event {:?}", event);
|
||||
if let Some(event) = TrictracAction::from_action_index(action_idx).and_then(|a| {
|
||||
let needs_mirror = self.game_state.active_player_id == 2;
|
||||
let game_state = if needs_mirror {
|
||||
&self.game_state.mirror()
|
||||
} else {
|
||||
&self.game_state
|
||||
};
|
||||
a.to_event(game_state)
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
}) {
|
||||
if self.game_state.validate(&event) {
|
||||
println!("valid event");
|
||||
self.game_state.consume(&event);
|
||||
println!("state {}", self.game_state);
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(pyo3::exceptions::PyRuntimeError::new_err(
|
||||
|
|
@ -113,8 +122,20 @@ impl TricTrac {
|
|||
[self.get_score(1), self.get_score(2)]
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player: PlayerId) -> Vec<i8> {
|
||||
self.game_state.to_vec()
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||
if player_idx == 0 {
|
||||
format!("{}", self.game_state)
|
||||
} else {
|
||||
format!("{}", self.game_state.mirror())
|
||||
}
|
||||
}
|
||||
|
||||
/// Afficher l'état du jeu (pour le débogage)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue