From ad30d09311aa34e3c5a48b3be4ac6b89ffa7fd8b Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 12 Mar 2026 21:17:14 +0100 Subject: [PATCH 1/4] feat(spiel_bot): cli spiel_bot strategy --- Cargo.lock | 2 + client_cli/Cargo.toml | 7 +- client_cli/src/app.rs | 22 ++++ client_cli/src/main.rs | 10 +- spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/strategy.rs | 242 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 281 insertions(+), 4 deletions(-) create mode 100644 spiel_bot/src/strategy.rs diff --git a/Cargo.lock b/Cargo.lock index 34bfe80..fa260cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,6 +6009,7 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", + "trictrac-bot", "trictrac-store", ] @@ -6854,6 +6855,7 @@ dependencies = [ "pico-args", "pretty_assertions", "renet", + "spiel_bot", "trictrac-bot", "trictrac-store", ] diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index e48a249..52318cb 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -3,7 +3,9 @@ name = "trictrac-client_cli" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "client_cli" +path = "src/main.rs" [dependencies] anyhow = "1.0.75" @@ -12,7 +14,8 @@ pico-args = "0.5.0" pretty_assertions = "1.4.0" renet = "0.0.13" trictrac-store = { path = "../store" } -trictrac-bot = { path = "../bot" } +trictrac-bot = { path = "../bot" } +spiel_bot = { path = "../spiel_bot" } itertools = "0.13.0" env_logger = "0.11.6" log = "0.4.20" diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index b803efe..ab61451 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,3 +1,4 @@ +use spiel_bot::strategy::{AzBotStrategy, DqnSpielBotStrategy}; use trictrac_bot::{ BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy, StableBaselines3Strategy, @@ -56,6 +57,27 @@ impl App { Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string())) as Box) } + "az" => { + Some(Box::new(AzBotStrategy::new_mlp(None)) as Box) + } + s if s.starts_with("az:") && !s.starts_with("az-") => { + let path = s.trim_start_matches("az:"); + Some(Box::new(AzBotStrategy::new_mlp(Some(path))) as Box) + } + "az-resnet" => { + Some(Box::new(AzBotStrategy::new_resnet(None)) as Box) + } + s if s.starts_with("az-resnet:") => { + let path = s.trim_start_matches("az-resnet:"); + Some(Box::new(AzBotStrategy::new_resnet(Some(path))) as Box) + } + "az-dqn" => { + Some(Box::new(DqnSpielBotStrategy::new(None)) as Box) + } + s if s.starts_with("az-dqn:") => { + let path = s.trim_start_matches("az-dqn:"); + Some(Box::new(DqnSpielBotStrategy::new(Some(path))) as Box) + } _ => None, }) .collect() diff --git a/client_cli/src/main.rs b/client_cli/src/main.rs index 0107b43..e06299b 100644 --- a/client_cli/src/main.rs +++ b/client_cli/src/main.rs @@ -23,8 +23,14 @@ OPTIONS: - dummy: Default strategy selecting the first valid move - ai: AI strategy using the default model at models/trictrac_ppo.zip - ai:/path/to/model.zip: AI strategy using a custom model - - dqn: DQN strategy using native Rust implementation with Burn - - dqn:/path/to/model: DQN strategy using a custom model + - dqnburn: DQN strategy (burn-rl backend) + - dqnburn:/path/to/model: DQN strategy (burn-rl backend) with custom model + - az: AlphaZero MlpNet (random weights) + - az:/path/to/model.mpk: AlphaZero MlpNet checkpoint + - az-resnet: AlphaZero ResNet (random weights) + - az-resnet:/path/to/model.mpk: AlphaZero ResNet checkpoint + - az-dqn: DQN QNet (random weights, first-legal-move fallback) + - az-dqn:/path/to/model.mpk: DQN QNet checkpoint ARGS: diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 3848dce..b541adc 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] trictrac-store = { path = "../store" } +trictrac-bot = { path = "../bot" } anyhow = "1" rand = "0.9" rand_distr = "0.5" diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 9dfb4de..cf6d865 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -3,3 +3,4 @@ pub mod dqn; pub mod env; pub mod mcts; pub mod network; +pub mod strategy; diff --git a/spiel_bot/src/strategy.rs b/spiel_bot/src/strategy.rs new file mode 100644 index 0000000..8309bf3 --- /dev/null +++ b/spiel_bot/src/strategy.rs @@ -0,0 +1,242 @@ +//! [`BotStrategy`] implementations backed by `spiel_bot` models. +//! +//! | Strategy struct | Network | CLI token | +//! |-----------------|---------|-----------| +//! | [`AzBotStrategy`] (mlp) | MlpNet (AlphaZero) | `az` / `az:PATH` | +//! | [`AzBotStrategy`] (resnet) | ResNet (AlphaZero) | `az-resnet` / `az-resnet:PATH` | +//! | [`DqnSpielBotStrategy`] | QNet (DQN) | `az-dqn` / `az-dqn:PATH` | +//! +//! All strategies operate from **White's perspective** (player_id = 1) internally; +//! the [`Bot`](trictrac_bot::Bot) wrapper handles board mirroring for Black. + +use std::cell::RefCell; +use std::path::Path; + +use burn::{ + backend::NdArray, + tensor::{Tensor, TensorData}, +}; +use rand::{SeedableRng, rngs::SmallRng}; +use trictrac_bot::BotStrategy; +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction}, + CheckerMove, Color, GameEvent, GameState, MoveRules, PlayerId, +}; + +use crate::{ + alphazero::BurnEvaluator, + env::{GameEnv, TrictracEnv}, + mcts::{self, Evaluator, MctsConfig}, + network::{MlpConfig, MlpNet, QNet, QNetConfig, QValueNet, ResNet, ResNetConfig}, +}; + +type B = NdArray; + +/// Default MCTS simulations per move used by [`AzBotStrategy`]. +pub const AZ_BOT_N_SIM: usize = 50; + +// ── Shared helpers ───────────────────────────────────────────────────────────── + +/// Decode an action index → `(CheckerMove, CheckerMove)` using the game state. +fn action_to_moves(action: usize, game: &GameState) -> Option<(CheckerMove, CheckerMove)> { + match TrictracAction::from_action_index(action)?.to_event(game)? { + GameEvent::Move { moves, .. } => Some(moves), + _ => None, + } +} + +/// Fallback: return the first legal move from `MoveRules` (always succeeds). +fn fallback_move(game: &GameState) -> (CheckerMove, CheckerMove) { + let rules = MoveRules::new(&Color::White, &game.board, game.dice); + let moves = rules.get_possible_moves_sequences(true, vec![]); + *moves.first().unwrap_or(&(CheckerMove::default(), CheckerMove::default())) +} + +// ── AzBotStrategy ───────────────────────────────────────────────────────────── + +/// AlphaZero bot usable as a [`BotStrategy`]. +/// +/// Supports both MlpNet and ResNet checkpoints through separate constructors. +/// Uses greedy (temperature = 0) MCTS for action selection. +/// +/// # Construction +/// +/// ```rust,ignore +/// // MlpNet with random weights +/// AzBotStrategy::new_mlp(None); +/// +/// // MlpNet from a checkpoint +/// AzBotStrategy::new_mlp(Some("checkpoints/iter_0100.mpk")); +/// +/// // ResNet from a checkpoint +/// AzBotStrategy::new_resnet(Some("checkpoints/resnet_0200.mpk")); +/// ``` +pub struct AzBotStrategy { + game: GameState, + evaluator: Box, + mcts_config: MctsConfig, + /// Interior-mutable RNG so `choose_move(&self)` can drive MCTS. + rng: RefCell, +} + +impl std::fmt::Debug for AzBotStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AzBotStrategy") + .field("n_sim", &self.mcts_config.n_simulations) + .finish() + } +} + +impl AzBotStrategy { + fn from_evaluator(evaluator: Box) -> Self { + Self { + game: GameState::default(), + evaluator, + mcts_config: MctsConfig { + n_simulations: AZ_BOT_N_SIM, + dirichlet_alpha: 0.0, // no noise during play + dirichlet_eps: 0.0, + temperature: 0.0, // greedy selection + ..MctsConfig::default() + }, + rng: RefCell::new(SmallRng::seed_from_u64(42)), + } + } + + /// MlpNet-backed bot. `path = None` → random weights. + pub fn new_mlp(path: Option<&str>) -> Self { + let device: ::Device = Default::default(); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 256 }; + let model = match path { + Some(p) => MlpNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az: load failed ({e}), using random weights"); + MlpNet::::new(&cfg, &device) + }), + None => MlpNet::::new(&cfg, &device), + }; + Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) + } + + /// ResNet-backed bot. `path = None` → random weights. + pub fn new_resnet(path: Option<&str>) -> Self { + let device: ::Device = Default::default(); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: 512 }; + let model = match path { + Some(p) => ResNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az-resnet: load failed ({e}), using random weights"); + ResNet::::new(&cfg, &device) + }), + None => ResNet::::new(&cfg, &device), + }; + Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) + } + + /// Run MCTS and return the greedy best action index, or `None` if no legal moves. + fn best_action(&self) -> Option { + let env = TrictracEnv; + if env.legal_actions(&self.game).is_empty() { + return None; + } + let mut rng = self.rng.borrow_mut(); + let root = mcts::run_mcts( + &env, + &self.game, + self.evaluator.as_ref(), + &self.mcts_config, + &mut *rng, + ); + Some(mcts::select_action(&root, 0.0, &mut *rng)) + } +} + +impl BotStrategy for AzBotStrategy { + fn get_game(&self) -> &GameState { &self.game } + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } + fn calculate_points(&self) -> u8 { self.game.dice_points.0 } + fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } + fn set_player_id(&mut self, _player_id: PlayerId) {} + fn set_color(&mut self, _color: Color) {} + + fn choose_go(&self) -> bool { + // Action index 1 == TrictracAction::Go + self.best_action().map(|a| a == 1).unwrap_or(false) + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + self.best_action() + .and_then(|a| action_to_moves(a, &self.game)) + .unwrap_or_else(|| fallback_move(&self.game)) + } +} + +// ── DqnSpielBotStrategy ─────────────────────────────────────────────────────── + +/// DQN bot (QNet from `spiel_bot`) usable as a [`BotStrategy`]. +/// +/// Selects actions by greedy argmax over Q-values, masked to legal moves. +/// When no checkpoint is provided the model falls back to the first legal move. +/// +/// # Construction +/// +/// ```rust,ignore +/// // No model — always picks first legal move +/// DqnSpielBotStrategy::new(None); +/// +/// // Trained checkpoint +/// DqnSpielBotStrategy::new(Some("checkpoints/dqn_iter_0500.mpk")); +/// ``` +#[derive(Debug)] +pub struct DqnSpielBotStrategy { + game: GameState, + model: Option>, +} + +impl DqnSpielBotStrategy { + /// Create a DQN bot. `path = None` → falls back to first legal move. + pub fn new(path: Option<&str>) -> Self { + let model = path.map(|p| { + let device: ::Device = Default::default(); + let cfg = QNetConfig::default(); + QNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az-dqn: load failed ({e}), using random weights"); + QNet::::new(&cfg, &device) + }) + }); + Self { game: GameState::default(), model } + } + + /// Greedy Q-value selection masked to legal actions, or `None` if no model / no legal moves. + fn best_action(&self) -> Option { + let model = self.model.as_ref()?; + let legal = get_valid_action_indices(&self.game).unwrap_or_default(); + if legal.is_empty() { + return None; + } + let device: ::Device = Default::default(); + let obs = self.game.to_tensor(); + let obs_t = Tensor::::from_data(TensorData::new(obs, [1, 217]), &device); + let q_vals: Vec = model.forward(obs_t).into_data().to_vec().unwrap(); + legal.into_iter().max_by(|&a, &b| { + q_vals[a].partial_cmp(&q_vals[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + } +} + +impl BotStrategy for DqnSpielBotStrategy { + fn get_game(&self) -> &GameState { &self.game } + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } + fn calculate_points(&self) -> u8 { self.game.dice_points.0 } + fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } + fn set_player_id(&mut self, _player_id: PlayerId) {} + fn set_color(&mut self, _color: Color) {} + + fn choose_go(&self) -> bool { + self.best_action().map(|a| a == 1).unwrap_or(false) + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + self.best_action() + .and_then(|a| action_to_moves(a, &self.game)) + .unwrap_or_else(|| fallback_move(&self.game)) + } +} From cf50784a2387b976f87a1eb7795537016993fde9 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 11 Mar 2026 22:17:03 +0100 Subject: [PATCH 2/4] fix: --n-sim training parameter --- spiel_bot/src/mcts/mod.rs | 8 ++++---- spiel_bot/src/mcts/search.rs | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/spiel_bot/src/mcts/mod.rs b/spiel_bot/src/mcts/mod.rs index a0a690d..eead171 100644 --- a/spiel_bot/src/mcts/mod.rs +++ b/spiel_bot/src/mcts/mod.rs @@ -403,10 +403,10 @@ mod tests { let root = run_mcts(&env, &state, &ZeroEval(514), &config, &mut r); // root.n = 1 (expansion) + n_simulations (one backup per simulation). assert_eq!(root.n, 1 + config.n_simulations as u32); - // Children visit counts may sum to less than n_simulations when some - // simulations cross a chance node at depth 1 (turn ends after one move) - // and evaluate with the network directly without updating child.n. + // Every simulation crosses a chance node at depth 1 (dice roll after + // the player's move). Since the fix now updates child.n in that case, + // children visit counts must sum to exactly n_simulations. let total: u32 = root.children.iter().map(|(_, c)| c.n).sum(); - assert!(total <= config.n_simulations as u32); + assert_eq!(total, config.n_simulations as u32); } } diff --git a/spiel_bot/src/mcts/search.rs b/spiel_bot/src/mcts/search.rs index 55db701..4d36acc 100644 --- a/spiel_bot/src/mcts/search.rs +++ b/spiel_bot/src/mcts/search.rs @@ -156,7 +156,13 @@ pub(super) fn simulate( let returns = env .returns(&next_state) .expect("terminal node must have returns"); - returns[player_idx] + let v = returns[player_idx]; + // Update child stats so PUCT and mcts_policy count terminal visits. + // Store from player_idx's perspective so child.q() is directly usable + // by the parent's PUCT selection (high = good for the selecting player). + child.n += 1; + child.w += v; + v } else { let child_player = next_cp.index().unwrap(); let v = if crossed_chance { @@ -166,6 +172,13 @@ pub(super) fn simulate( // previously cached children would be for a different outcome. let obs = env.observation(&next_state, child_player); let (_, value) = evaluator.evaluate(&obs); + // Store from player_idx's (parent's) perspective so PUCT works correctly. + // `value` is from child_player's POV; negate when child is the opponent + // so that child.q() = expected return for the player CHOOSING this child. + // Without the negation, root would maximise the opponent's Q-value and + // systematically pick the worst action. + child.n += 1; + child.w += if child_player == player_idx { value } else { -value }; value } else if child.expanded { simulate(child, next_state, env, evaluator, config, rng, child_player) From 00f23543a5abe7a8a1918134797c38c423c21d13 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Fri, 20 Mar 2026 17:37:44 +0100 Subject: [PATCH 3/4] todo --- doc/todo.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 doc/todo.md diff --git a/doc/todo.md b/doc/todo.md new file mode 100644 index 0000000..a91c0d0 --- /dev/null +++ b/doc/todo.md @@ -0,0 +1,31 @@ +# TODO + +## webgame server + +- axum +- postgresql + +Should be able to serve many games concurrently + +## web client + +Dioxus or Leptos ? + +## User stories + +As a user I want to be able to + +- create an account. +- connect to the application with my account +- play (anonymously if not connected) + - enter a game code to participate an existing game + - start a new game + - choose to invite a friend to the game by giving him the game code + - ask a bot to play + +As an administrator I want to be able to + +- see the registered users +- see the list of games currently active +- see the past games +- see the server statistics From 0b06c62fd9984c08938e1092a5555bbfb1d88101 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Wed, 25 Mar 2026 16:04:06 +0100 Subject: [PATCH 4/4] refact: add cargo "python" feature for pyo3 --- README.md | 2 +- bot/Cargo.toml | 2 +- client_cli/Cargo.toml | 2 +- spiel_bot/Cargo.toml | 2 +- store/Cargo.toml | 6 +++++- store/src/lib.rs | 1 + store/src/player.rs | 3 ++- 7 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index e5a0f39..e74fb69 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Training of AI bots is the work in progress. - game rules and game state are implemented in the _store/_ folder. - the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other -- the bots algorithms and the training of their models are implemented in the _bot/_ folder +- the bots algorithms and the training of their models are implemented in the _bot/_ and _spiel_bot_ folders. ### _store_ package diff --git a/bot/Cargo.toml b/bot/Cargo.toml index de957df..d24adcc 100644 --- a/bot/Cargo.toml +++ b/bot/Cargo.toml @@ -13,7 +13,7 @@ path = "src/burnrl/main.rs" pretty_assertions = "1.4.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -trictrac-store = { path = "../store" } +trictrac-store = { path = "../store", features = ["python"] } rand = "0.9" env_logger = "0.10" burn = { version = "0.20", features = ["ndarray", "autodiff"] } diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index 52318cb..d85dd8b 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -13,7 +13,7 @@ bincode = "1.3.3" pico-args = "0.5.0" pretty_assertions = "1.4.0" renet = "0.0.13" -trictrac-store = { path = "../store" } +trictrac-store = { path = "../store", features = ["python"] } trictrac-bot = { path = "../bot" } spiel_bot = { path = "../spiel_bot" } itertools = "0.13.0" diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index b541adc..1458d66 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -trictrac-store = { path = "../store" } +trictrac-store = { path = "../store", features = ["python"] } trictrac-bot = { path = "../bot" } anyhow = "1" rand = "0.9" diff --git a/store/Cargo.toml b/store/Cargo.toml index 935a2a0..fbb4f6d 100644 --- a/store/Cargo.toml +++ b/store/Cargo.toml @@ -12,6 +12,10 @@ name = "trictrac_store" # "staticlib" → used by the C++ OpenSpiel game (cxxengine) crate-type = ["cdylib", "rlib", "staticlib"] +[features] +# Enable Python bindings (required for maturin / AI training). Not available on wasm32. +python = ["pyo3"] + [dependencies] anyhow = "1.0" base64 = "0.21.7" @@ -20,7 +24,7 @@ cxx = "1.0" log = "0.4.20" merge = "0.1.0" # generate python lib (with maturin) to be used in AI training -pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"], optional = true } rand = "0.9" serde = { version = "1.0", features = ["derive"] } transpose = "0.2.2" diff --git a/store/src/lib.rs b/store/src/lib.rs index 4fc8dff..25d2dcb 100644 --- a/store/src/lib.rs +++ b/store/src/lib.rs @@ -20,6 +20,7 @@ pub use dice::{Dice, DiceRoller}; pub mod training_common; // python interface "trictrac_engine" (for AI training..) +#[cfg(feature = "python")] mod pyengine; // C++ interface via cxx.rs (for OpenSpiel C++ integration) diff --git a/store/src/player.rs b/store/src/player.rs index 1e48593..cca02b5 100644 --- a/store/src/player.rs +++ b/store/src/player.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "python")] use pyo3::prelude::*; use serde::{Deserialize, Serialize}; use std::fmt; @@ -5,7 +6,7 @@ use std::fmt; // This just makes it easier to dissern between a player id and any ol' u64 pub type PlayerId = u64; -#[pyclass(eq, eq_int)] +#[cfg_attr(feature = "python", pyclass(eq, eq_int))] #[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Color { White,