diff --git a/README.md b/README.md index e74fb69..e5a0f39 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Training of AI bots is the work in progress. - game rules and game state are implemented in the _store/_ folder. - the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other -- the bots algorithms and the training of their models are implemented in the _bot/_ and _spiel_bot_ folders. +- the bots algorithms and the training of their models are implemented in the _bot/_ folder ### _store_ package diff --git a/bot/Cargo.toml b/bot/Cargo.toml index d24adcc..de957df 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,7 +13,7 @@ path = "src/burnrl/main.rs" pretty_assertions = "1.4.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -trictrac-store = { path = "../store", features = ["python"] } +trictrac-store = { path = "../store" } rand = "0.9" env_logger = "0.10" burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index d85dd8b..52318cb 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -13,7 +13,7 @@ bincode = "1.3.3" pico-args = "0.5.0" pretty_assertions = "1.4.0" renet = "0.0.13" -trictrac-store = { path = "../store", features = ["python"] } +trictrac-store = { path = "../store" } trictrac-bot = { path = "../bot" } spiel_bot = { path = "../spiel_bot" } itertools = "0.13.0" diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 1458d66..b541adc 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -trictrac-store = { path = "../store", features = ["python"] } +trictrac-store = { path = "../store" } trictrac-bot = { path = "../bot" } anyhow = "1" rand = "0.9" diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 4d36acc..1d9750d 100644 --- 
a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -156,13 +156,7 @@ pub(super) fn simulate( let returns = env .returns(&next_state) .expect("terminal node must have returns"); - let v = returns[player_idx]; - // Update child stats so PUCT and mcts_policy count terminal visits. - // Store from player_idx's perspective so child.q() is directly usable - // by the parent's PUCT selection (high = good for the selecting player). - child.n += 1; - child.w += v; - v + returns[player_idx] } else { let child_player = next_cp.index().unwrap(); let v = if crossed_chance { @@ -172,13 +166,12 @@ // previously cached children would be for a different outcome. let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); - // Store from player_idx's (parent's) perspective so PUCT works correctly. - // `value` is from child_player's POV; negate when child is the opponent - // so that child.q() = expected return for the player CHOOSING this child. - // Without the negation, root would maximise the opponent's Q-value and - // systematically pick the worst action. + // Record the visit so that PUCT and mcts_policy use real counts. + // Without this, child.n stays 0 for every simulation in games where + // every player action is immediately followed by a chance node (e.g. + // Trictrac), causing mcts_policy to always return a uniform policy. child.n += 1; child.w += if child_player == player_idx { value } else { -value }; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) diff --git a/store/Cargo.toml b/store/Cargo.toml index fbb4f6d..935a2a0 100644 --- a/store/Cargo.toml +++ b/store/Cargo.toml @@ -12,10 +12,6 @@ name = "trictrac_store" # "staticlib" → used by the C++ OpenSpiel game (cxxengine) crate-type = ["cdylib", "rlib", "staticlib"] -[features] -# Enable Python bindings (required for maturin / AI training). 
Not available on wasm32. -python = ["pyo3"] - [dependencies] anyhow = "1.0" base64 = "0.21.7" @@ -24,7 +20,7 @@ cxx = "1.0" log = "0.4.20" merge = "0.1.0" # generate python lib (with maturin) to be used in AI training -pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"], optional = true } +pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } rand = "0.9" serde = { version = "1.0", features = ["derive"] } transpose = "0.2.2" diff --git a/store/src/lib.rs b/store/src/lib.rs index 25d2dcb..4fc8dff 100644 --- a/store/src/lib.rs +++ b/store/src/lib.rs @@ -20,7 +20,6 @@ pub use dice::{Dice, DiceRoller}; pub mod training_common; // python interface "trictrac_engine" (for AI training..) -#[cfg(feature = "python")] mod pyengine; // C++ interface via cxx.rs (for OpenSpiel C++ integration) diff --git a/store/src/player.rs b/store/src/player.rs index cca02b5..1e48593 100644 --- a/store/src/player.rs +++ b/store/src/player.rs @@ -1,4 +1,3 @@ -#[cfg(feature = "python")] use pyo3::prelude::*; use serde::{Deserialize, Serialize}; use std::fmt; @@ -6,7 +5,7 @@ use std::fmt; // This just makes it easier to dissern between a player id and any ol' u64 pub type PlayerId = u64; -#[cfg_attr(feature = "python", pyclass(eq, eq_int))] +#[pyclass(eq, eq_int)] #[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Color { White,