From b2d66ce41e04390108aca40284ffccf0817ce8b6 Mon Sep 17 00:00:00 2001 From: Henri Bourcereau Date: Thu, 12 Mar 2026 21:17:14 +0100 Subject: [PATCH] feat(spiel_bot): cli spiel_bot strategy --- Cargo.lock | 2 + client_cli/Cargo.toml | 7 +- client_cli/src/app.rs | 22 ++++ client_cli/src/main.rs | 10 +- spiel_bot/Cargo.toml | 1 + spiel_bot/src/lib.rs | 1 + spiel_bot/src/strategy.rs | 242 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 281 insertions(+), 4 deletions(-) create mode 100644 spiel_bot/src/strategy.rs diff --git a/Cargo.lock b/Cargo.lock index 34bfe80..fa260cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6009,6 +6009,7 @@ dependencies = [ "criterion", "rand 0.9.2", "rand_distr", + "trictrac-bot", "trictrac-store", ] @@ -6854,6 +6855,7 @@ dependencies = [ "pico-args", "pretty_assertions", "renet", + "spiel_bot", "trictrac-bot", "trictrac-store", ] diff --git a/client_cli/Cargo.toml b/client_cli/Cargo.toml index e48a249..52318cb 100644 --- a/client_cli/Cargo.toml +++ b/client_cli/Cargo.toml @@ -3,7 +3,9 @@ name = "trictrac-client_cli" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "client_cli" +path = "src/main.rs" [dependencies] anyhow = "1.0.75" @@ -12,7 +14,8 @@ pico-args = "0.5.0" pretty_assertions = "1.4.0" renet = "0.0.13" trictrac-store = { path = "../store" } -trictrac-bot = { path = "../bot" } +trictrac-bot = { path = "../bot" } +spiel_bot = { path = "../spiel_bot" } itertools = "0.13.0" env_logger = "0.11.6" log = "0.4.20" diff --git a/client_cli/src/app.rs b/client_cli/src/app.rs index b803efe..ab61451 100644 --- a/client_cli/src/app.rs +++ b/client_cli/src/app.rs @@ -1,3 +1,4 @@ +use spiel_bot::strategy::{AzBotStrategy, DqnSpielBotStrategy}; use trictrac_bot::{ BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy, StableBaselines3Strategy, @@ -56,6 +57,27 @@ impl App { Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string())) as Box) } + "az" => { + Some(Box::new(AzBotStrategy::new_mlp(None)) as Box) + } + s if s.starts_with("az:") && !s.starts_with("az-") => { + let path = s.trim_start_matches("az:"); + Some(Box::new(AzBotStrategy::new_mlp(Some(path))) as Box) + } + "az-resnet" => { + Some(Box::new(AzBotStrategy::new_resnet(None)) as Box) + } + s if s.starts_with("az-resnet:") => { + let path = s.trim_start_matches("az-resnet:"); + Some(Box::new(AzBotStrategy::new_resnet(Some(path))) as Box) + } + "az-dqn" => { + Some(Box::new(DqnSpielBotStrategy::new(None)) as Box) + } + s if s.starts_with("az-dqn:") => { + let path = s.trim_start_matches("az-dqn:"); + Some(Box::new(DqnSpielBotStrategy::new(Some(path))) as Box) + } _ => None, }) .collect() diff --git a/client_cli/src/main.rs b/client_cli/src/main.rs index 0107b43..e06299b 100644 --- a/client_cli/src/main.rs +++ b/client_cli/src/main.rs @@ -23,8 +23,14 @@ OPTIONS: - dummy: Default strategy selecting the first valid move - ai: AI strategy using the default model at models/trictrac_ppo.zip - ai:/path/to/model.zip: AI strategy using a custom model - - dqn: DQN strategy using native Rust implementation with Burn - - dqn:/path/to/model: DQN strategy using a custom model + - dqnburn: DQN strategy (burn-rl backend) + - dqnburn:/path/to/model: DQN strategy (burn-rl backend) with custom model + - az: AlphaZero MlpNet (random weights) + - az:/path/to/model.mpk: AlphaZero MlpNet checkpoint + - az-resnet: AlphaZero ResNet (random weights) + - az-resnet:/path/to/model.mpk: AlphaZero ResNet checkpoint + - az-dqn: DQN QNet (random weights, first-legal-move fallback) + - az-dqn:/path/to/model.mpk: DQN QNet checkpoint ARGS: diff --git a/spiel_bot/Cargo.toml b/spiel_bot/Cargo.toml index 3848dce..b541adc 100644 --- a/spiel_bot/Cargo.toml +++ b/spiel_bot/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] trictrac-store = { path = "../store" } +trictrac-bot = { path = "../bot" } anyhow = "1" rand = "0.9" rand_distr = "0.5" diff --git a/spiel_bot/src/lib.rs b/spiel_bot/src/lib.rs index 9dfb4de..cf6d865 100644 --- a/spiel_bot/src/lib.rs +++ b/spiel_bot/src/lib.rs @@ -3,3 +3,4 @@ pub mod dqn; pub mod env; pub mod mcts; pub mod network; +pub mod strategy; diff --git a/spiel_bot/src/strategy.rs b/spiel_bot/src/strategy.rs new file mode 100644 index 0000000..8309bf3 --- /dev/null +++ b/spiel_bot/src/strategy.rs @@ -0,0 +1,242 @@ +//! [`BotStrategy`] implementations backed by `spiel_bot` models. +//! +//! | Strategy struct | Network | CLI token | +//! |-----------------|---------|-----------| +//! | [`AzBotStrategy`] (mlp) | MlpNet (AlphaZero) | `az` / `az:PATH` | +//! | [`AzBotStrategy`] (resnet) | ResNet (AlphaZero) | `az-resnet` / `az-resnet:PATH` | +//! | [`DqnSpielBotStrategy`] | QNet (DQN) | `az-dqn` / `az-dqn:PATH` | +//! +//! All strategies operate from **White's perspective** (player_id = 1) internally; +//! the [`Bot`](trictrac_bot::Bot) wrapper handles board mirroring for Black. + +use std::cell::RefCell; +use std::path::Path; + +use burn::{ + backend::NdArray, + tensor::{Tensor, TensorData}, +}; +use rand::{SeedableRng, rngs::SmallRng}; +use trictrac_bot::BotStrategy; +use trictrac_store::{ + training_common::{get_valid_action_indices, TrictracAction}, + CheckerMove, Color, GameEvent, GameState, MoveRules, PlayerId, +}; + +use crate::{ + alphazero::BurnEvaluator, + env::{GameEnv, TrictracEnv}, + mcts::{self, Evaluator, MctsConfig}, + network::{MlpConfig, MlpNet, QNet, QNetConfig, QValueNet, ResNet, ResNetConfig}, +}; + +type B = NdArray; + +/// Default MCTS simulations per move used by [`AzBotStrategy`]. +pub const AZ_BOT_N_SIM: usize = 50; + +// ── Shared helpers ───────────────────────────────────────────────────────────── + +/// Decode an action index → `(CheckerMove, CheckerMove)` using the game state. +fn action_to_moves(action: usize, game: &GameState) -> Option<(CheckerMove, CheckerMove)> { + match TrictracAction::from_action_index(action)?.to_event(game)? { + GameEvent::Move { moves, .. } => Some(moves), + _ => None, + } +} + +/// Fallback: return the first legal move from `MoveRules` (always succeeds). +fn fallback_move(game: &GameState) -> (CheckerMove, CheckerMove) { + let rules = MoveRules::new(&Color::White, &game.board, game.dice); + let moves = rules.get_possible_moves_sequences(true, vec![]); + *moves.first().unwrap_or(&(CheckerMove::default(), CheckerMove::default())) +} + +// ── AzBotStrategy ───────────────────────────────────────────────────────────── + +/// AlphaZero bot usable as a [`BotStrategy`]. +/// +/// Supports both MlpNet and ResNet checkpoints through separate constructors. +/// Uses greedy (temperature = 0) MCTS for action selection. +/// +/// # Construction +/// +/// ```rust,ignore +/// // MlpNet with random weights +/// AzBotStrategy::new_mlp(None); +/// +/// // MlpNet from a checkpoint +/// AzBotStrategy::new_mlp(Some("checkpoints/iter_0100.mpk")); +/// +/// // ResNet from a checkpoint +/// AzBotStrategy::new_resnet(Some("checkpoints/resnet_0200.mpk")); +/// ``` +pub struct AzBotStrategy { + game: GameState, + evaluator: Box, + mcts_config: MctsConfig, + /// Interior-mutable RNG so `choose_move(&self)` can drive MCTS. + rng: RefCell, +} + +impl std::fmt::Debug for AzBotStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AzBotStrategy") + .field("n_sim", &self.mcts_config.n_simulations) + .finish() + } +} + +impl AzBotStrategy { + fn from_evaluator(evaluator: Box) -> Self { + Self { + game: GameState::default(), + evaluator, + mcts_config: MctsConfig { + n_simulations: AZ_BOT_N_SIM, + dirichlet_alpha: 0.0, // no noise during play + dirichlet_eps: 0.0, + temperature: 0.0, // greedy selection + ..MctsConfig::default() + }, + rng: RefCell::new(SmallRng::seed_from_u64(42)), + } + } + + /// MlpNet-backed bot. `path = None` → random weights. + pub fn new_mlp(path: Option<&str>) -> Self { + let device: ::Device = Default::default(); + let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 256 }; + let model = match path { + Some(p) => MlpNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az: load failed ({e}), using random weights"); + MlpNet::::new(&cfg, &device) + }), + None => MlpNet::::new(&cfg, &device), + }; + Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) + } + + /// ResNet-backed bot. `path = None` → random weights. + pub fn new_resnet(path: Option<&str>) -> Self { + let device: ::Device = Default::default(); + let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: 512 }; + let model = match path { + Some(p) => ResNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az-resnet: load failed ({e}), using random weights"); + ResNet::::new(&cfg, &device) + }), + None => ResNet::::new(&cfg, &device), + }; + Self::from_evaluator(Box::new(BurnEvaluator::>::new(model, device))) + } + + /// Run MCTS and return the greedy best action index, or `None` if no legal moves. + fn best_action(&self) -> Option { + let env = TrictracEnv; + if env.legal_actions(&self.game).is_empty() { + return None; + } + let mut rng = self.rng.borrow_mut(); + let root = mcts::run_mcts( + &env, + &self.game, + self.evaluator.as_ref(), + &self.mcts_config, + &mut *rng, + ); + Some(mcts::select_action(&root, 0.0, &mut *rng)) + } +} + +impl BotStrategy for AzBotStrategy { + fn get_game(&self) -> &GameState { &self.game } + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } + fn calculate_points(&self) -> u8 { self.game.dice_points.0 } + fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } + fn set_player_id(&mut self, _player_id: PlayerId) {} + fn set_color(&mut self, _color: Color) {} + + fn choose_go(&self) -> bool { + // Action index 1 == TrictracAction::Go + self.best_action().map(|a| a == 1).unwrap_or(false) + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + self.best_action() + .and_then(|a| action_to_moves(a, &self.game)) + .unwrap_or_else(|| fallback_move(&self.game)) + } +} + +// ── DqnSpielBotStrategy ─────────────────────────────────────────────────────── + +/// DQN bot (QNet from `spiel_bot`) usable as a [`BotStrategy`]. +/// +/// Selects actions by greedy argmax over Q-values, masked to legal moves. +/// When no checkpoint is provided the model falls back to the first legal move. +/// +/// # Construction +/// +/// ```rust,ignore +/// // No model — always picks first legal move +/// DqnSpielBotStrategy::new(None); +/// +/// // Trained checkpoint +/// DqnSpielBotStrategy::new(Some("checkpoints/dqn_iter_0500.mpk")); +/// ``` +#[derive(Debug)] +pub struct DqnSpielBotStrategy { + game: GameState, + model: Option>, +} + +impl DqnSpielBotStrategy { + /// Create a DQN bot. `path = None` → falls back to first legal move. + pub fn new(path: Option<&str>) -> Self { + let model = path.map(|p| { + let device: ::Device = Default::default(); + let cfg = QNetConfig::default(); + QNet::::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| { + eprintln!("az-dqn: load failed ({e}), using random weights"); + QNet::::new(&cfg, &device) + }) + }); + Self { game: GameState::default(), model } + } + + /// Greedy Q-value selection masked to legal actions, or `None` if no model / no legal moves. + fn best_action(&self) -> Option { + let model = self.model.as_ref()?; + let legal = get_valid_action_indices(&self.game).unwrap_or_default(); + if legal.is_empty() { + return None; + } + let device: ::Device = Default::default(); + let obs = self.game.to_tensor(); + let obs_t = Tensor::::from_data(TensorData::new(obs, [1, 217]), &device); + let q_vals: Vec = model.forward(obs_t).into_data().to_vec().unwrap(); + legal.into_iter().max_by(|&a, &b| { + q_vals[a].partial_cmp(&q_vals[b]).unwrap_or(std::cmp::Ordering::Equal) + }) + } +} + +impl BotStrategy for DqnSpielBotStrategy { + fn get_game(&self) -> &GameState { &self.game } + fn get_mut_game(&mut self) -> &mut GameState { &mut self.game } + fn calculate_points(&self) -> u8 { self.game.dice_points.0 } + fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 } + fn set_player_id(&mut self, _player_id: PlayerId) {} + fn set_color(&mut self, _color: Color) {} + + fn choose_go(&self) -> bool { + self.best_action().map(|a| a == 1).unwrap_or(false) + } + + fn choose_move(&self) -> (CheckerMove, CheckerMove) { + self.best_action() + .and_then(|a| action_to_moves(a, &self.game)) + .unwrap_or_else(|| fallback_move(&self.game)) + } +}