Compare commits

..

4 commits

8 changed files with 25 additions and 12 deletions

View file

@ -24,7 +24,7 @@ Training of AI bots is the work in progress.
- game rules and game state are implemented in the _store/_ folder. - game rules and game state are implemented in the _store/_ folder.
- the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other - the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other
- the bots algorithms and the training of their models are implemented in the _bot/_ folder - the bots algorithms and the training of their models are implemented in the _bot/_ and _spiel_bot_ folders.
### _store_ package ### _store_ package

View file

@ -13,7 +13,7 @@ path = "src/burnrl/main.rs"
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
trictrac-store = { path = "../store" } trictrac-store = { path = "../store", features = ["python"] }
rand = "0.9" rand = "0.9"
env_logger = "0.10" env_logger = "0.10"
burn = { version = "0.20", features = ["ndarray", "autodiff"] } burn = { version = "0.20", features = ["ndarray", "autodiff"] }

View file

@ -13,7 +13,7 @@ bincode = "1.3.3"
pico-args = "0.5.0" pico-args = "0.5.0"
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
renet = "0.0.13" renet = "0.0.13"
trictrac-store = { path = "../store" } trictrac-store = { path = "../store", features = ["python"] }
trictrac-bot = { path = "../bot" } trictrac-bot = { path = "../bot" }
spiel_bot = { path = "../spiel_bot" } spiel_bot = { path = "../spiel_bot" }
itertools = "0.13.0" itertools = "0.13.0"

View file

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
trictrac-store = { path = "../store" } trictrac-store = { path = "../store", features = ["python"] }
trictrac-bot = { path = "../bot" } trictrac-bot = { path = "../bot" }
anyhow = "1" anyhow = "1"
rand = "0.9" rand = "0.9"

View file

@ -156,7 +156,13 @@ pub(super) fn simulate<E: GameEnv>(
let returns = env let returns = env
.returns(&next_state) .returns(&next_state)
.expect("terminal node must have returns"); .expect("terminal node must have returns");
returns[player_idx] let v = returns[player_idx];
// Update child stats so PUCT and mcts_policy count terminal visits.
// Store from player_idx's perspective so child.q() is directly usable
// by the parent's PUCT selection (high = good for the selecting player).
child.n += 1;
child.w += v;
v
} else { } else {
let child_player = next_cp.index().unwrap(); let child_player = next_cp.index().unwrap();
let v = if crossed_chance { let v = if crossed_chance {
@ -166,12 +172,13 @@ pub(super) fn simulate<E: GameEnv>(
// previously cached children would be for a different outcome. // previously cached children would be for a different outcome.
let obs = env.observation(&next_state, child_player); let obs = env.observation(&next_state, child_player);
let (_, value) = evaluator.evaluate(&obs); let (_, value) = evaluator.evaluate(&obs);
// Record the visit so that PUCT and mcts_policy use real counts. // Store from player_idx's (parent's) perspective so PUCT works correctly.
// Without this, child.n stays 0 for every simulation in games where // `value` is from child_player's POV; negate when child is the opponent
// every player action is immediately followed by a chance node (e.g. // so that child.q() = expected return for the player CHOOSING this child.
// Trictrac), causing mcts_policy to always return a uniform policy. // Without the negation, root would maximise the opponent's Q-value and
// systematically pick the worst action.
child.n += 1; child.n += 1;
child.w += value; child.w += if child_player == player_idx { value } else { -value };
value value
} else if child.expanded { } else if child.expanded {
simulate(child, next_state, env, evaluator, config, rng, child_player) simulate(child, next_state, env, evaluator, config, rng, child_player)

View file

@ -12,6 +12,10 @@ name = "trictrac_store"
# "staticlib" → used by the C++ OpenSpiel game (cxxengine) # "staticlib" → used by the C++ OpenSpiel game (cxxengine)
crate-type = ["cdylib", "rlib", "staticlib"] crate-type = ["cdylib", "rlib", "staticlib"]
[features]
# Enable Python bindings (required for maturin / AI training). Not available on wasm32.
python = ["pyo3"]
[dependencies] [dependencies]
anyhow = "1.0" anyhow = "1.0"
base64 = "0.21.7" base64 = "0.21.7"
@ -20,7 +24,7 @@ cxx = "1.0"
log = "0.4.20" log = "0.4.20"
merge = "0.1.0" merge = "0.1.0"
# generate python lib (with maturin) to be used in AI training # generate python lib (with maturin) to be used in AI training
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"], optional = true }
rand = "0.9" rand = "0.9"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
transpose = "0.2.2" transpose = "0.2.2"

View file

@ -20,6 +20,7 @@ pub use dice::{Dice, DiceRoller};
pub mod training_common; pub mod training_common;
// python interface "trictrac_engine" (for AI training..) // python interface "trictrac_engine" (for AI training..)
#[cfg(feature = "python")]
mod pyengine; mod pyengine;
// C++ interface via cxx.rs (for OpenSpiel C++ integration) // C++ interface via cxx.rs (for OpenSpiel C++ integration)

View file

@ -1,3 +1,4 @@
#[cfg(feature = "python")]
use pyo3::prelude::*; use pyo3::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
@ -5,7 +6,7 @@ use std::fmt;
// This just makes it easier to dissern between a player id and any ol' u64 // This just makes it easier to dissern between a player id and any ol' u64
pub type PlayerId = u64; pub type PlayerId = u64;
#[pyclass(eq, eq_int)] #[cfg_attr(feature = "python", pyclass(eq, eq_int))]
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Color { pub enum Color {
White, White,