feat(spiel_bot): cli spiel_bot strategy
This commit is contained in:
parent
e80dade303
commit
b2d66ce41e
7 changed files with 281 additions and 4 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -6009,6 +6009,7 @@ dependencies = [
|
|||
"criterion",
|
||||
"rand 0.9.2",
|
||||
"rand_distr",
|
||||
"trictrac-bot",
|
||||
"trictrac-store",
|
||||
]
|
||||
|
||||
|
|
@ -6854,6 +6855,7 @@ dependencies = [
|
|||
"pico-args",
|
||||
"pretty_assertions",
|
||||
"renet",
|
||||
"spiel_bot",
|
||||
"trictrac-bot",
|
||||
"trictrac-store",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ name = "trictrac-client_cli"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[[bin]]
|
||||
name = "client_cli"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.75"
|
||||
|
|
@ -12,7 +14,8 @@ pico-args = "0.5.0"
|
|||
pretty_assertions = "1.4.0"
|
||||
renet = "0.0.13"
|
||||
trictrac-store = { path = "../store" }
|
||||
trictrac-bot = { path = "../bot" }
|
||||
trictrac-bot = { path = "../bot" }
|
||||
spiel_bot = { path = "../spiel_bot" }
|
||||
itertools = "0.13.0"
|
||||
env_logger = "0.11.6"
|
||||
log = "0.4.20"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use spiel_bot::strategy::{AzBotStrategy, DqnSpielBotStrategy};
|
||||
use trictrac_bot::{
|
||||
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
|
||||
StableBaselines3Strategy,
|
||||
|
|
@ -56,6 +57,27 @@ impl App {
|
|||
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
||||
as Box<dyn BotStrategy>)
|
||||
}
|
||||
"az" => {
|
||||
Some(Box::new(AzBotStrategy::new_mlp(None)) as Box<dyn BotStrategy>)
|
||||
}
|
||||
s if s.starts_with("az:") && !s.starts_with("az-") => {
|
||||
let path = s.trim_start_matches("az:");
|
||||
Some(Box::new(AzBotStrategy::new_mlp(Some(path))) as Box<dyn BotStrategy>)
|
||||
}
|
||||
"az-resnet" => {
|
||||
Some(Box::new(AzBotStrategy::new_resnet(None)) as Box<dyn BotStrategy>)
|
||||
}
|
||||
s if s.starts_with("az-resnet:") => {
|
||||
let path = s.trim_start_matches("az-resnet:");
|
||||
Some(Box::new(AzBotStrategy::new_resnet(Some(path))) as Box<dyn BotStrategy>)
|
||||
}
|
||||
"az-dqn" => {
|
||||
Some(Box::new(DqnSpielBotStrategy::new(None)) as Box<dyn BotStrategy>)
|
||||
}
|
||||
s if s.starts_with("az-dqn:") => {
|
||||
let path = s.trim_start_matches("az-dqn:");
|
||||
Some(Box::new(DqnSpielBotStrategy::new(Some(path))) as Box<dyn BotStrategy>)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
|
|
|
|||
|
|
@ -23,8 +23,14 @@ OPTIONS:
|
|||
- dummy: Default strategy selecting the first valid move
|
||||
- ai: AI strategy using the default model at models/trictrac_ppo.zip
|
||||
- ai:/path/to/model.zip: AI strategy using a custom model
|
||||
- dqn: DQN strategy using native Rust implementation with Burn
|
||||
- dqn:/path/to/model: DQN strategy using a custom model
|
||||
- dqnburn: DQN strategy (burn-rl backend)
|
||||
- dqnburn:/path/to/model: DQN strategy (burn-rl backend) with custom model
|
||||
- az: AlphaZero MlpNet (random weights)
|
||||
- az:/path/to/model.mpk: AlphaZero MlpNet checkpoint
|
||||
- az-resnet: AlphaZero ResNet (random weights)
|
||||
- az-resnet:/path/to/model.mpk: AlphaZero ResNet checkpoint
|
||||
- az-dqn: DQN QNet (random weights, first-legal-move fallback)
|
||||
- az-dqn:/path/to/model.mpk: DQN QNet checkpoint
|
||||
|
||||
ARGS:
|
||||
<INPUT>
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
trictrac-store = { path = "../store" }
|
||||
trictrac-bot = { path = "../bot" }
|
||||
anyhow = "1"
|
||||
rand = "0.9"
|
||||
rand_distr = "0.5"
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ pub mod dqn;
|
|||
pub mod env;
|
||||
pub mod mcts;
|
||||
pub mod network;
|
||||
pub mod strategy;
|
||||
|
|
|
|||
242
spiel_bot/src/strategy.rs
Normal file
242
spiel_bot/src/strategy.rs
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
//! [`BotStrategy`] implementations backed by `spiel_bot` models.
|
||||
//!
|
||||
//! | Strategy struct | Network | CLI token |
|
||||
//! |-----------------|---------|-----------|
|
||||
//! | [`AzBotStrategy`] (mlp) | MlpNet (AlphaZero) | `az` / `az:PATH` |
|
||||
//! | [`AzBotStrategy`] (resnet) | ResNet (AlphaZero) | `az-resnet` / `az-resnet:PATH` |
|
||||
//! | [`DqnSpielBotStrategy`] | QNet (DQN) | `az-dqn` / `az-dqn:PATH` |
|
||||
//!
|
||||
//! All strategies operate from **White's perspective** (player_id = 1) internally;
|
||||
//! the [`Bot`](trictrac_bot::Bot) wrapper handles board mirroring for Black.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
|
||||
use burn::{
|
||||
backend::NdArray,
|
||||
tensor::{Tensor, TensorData},
|
||||
};
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
use trictrac_bot::BotStrategy;
|
||||
use trictrac_store::{
|
||||
training_common::{get_valid_action_indices, TrictracAction},
|
||||
CheckerMove, Color, GameEvent, GameState, MoveRules, PlayerId,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
alphazero::BurnEvaluator,
|
||||
env::{GameEnv, TrictracEnv},
|
||||
mcts::{self, Evaluator, MctsConfig},
|
||||
network::{MlpConfig, MlpNet, QNet, QNetConfig, QValueNet, ResNet, ResNetConfig},
|
||||
};
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
/// Default MCTS simulations per move used by [`AzBotStrategy`].
|
||||
pub const AZ_BOT_N_SIM: usize = 50;
|
||||
|
||||
// ── Shared helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Decode an action index → `(CheckerMove, CheckerMove)` using the game state.
|
||||
fn action_to_moves(action: usize, game: &GameState) -> Option<(CheckerMove, CheckerMove)> {
|
||||
match TrictracAction::from_action_index(action)?.to_event(game)? {
|
||||
GameEvent::Move { moves, .. } => Some(moves),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fallback: return the first legal move from `MoveRules` (always succeeds).
|
||||
fn fallback_move(game: &GameState) -> (CheckerMove, CheckerMove) {
|
||||
let rules = MoveRules::new(&Color::White, &game.board, game.dice);
|
||||
let moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
*moves.first().unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
|
||||
}
|
||||
|
||||
// ── AzBotStrategy ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// AlphaZero bot usable as a [`BotStrategy`].
|
||||
///
|
||||
/// Supports both MlpNet and ResNet checkpoints through separate constructors.
|
||||
/// Uses greedy (temperature = 0) MCTS for action selection.
|
||||
///
|
||||
/// # Construction
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// // MlpNet with random weights
|
||||
/// AzBotStrategy::new_mlp(None);
|
||||
///
|
||||
/// // MlpNet from a checkpoint
|
||||
/// AzBotStrategy::new_mlp(Some("checkpoints/iter_0100.mpk"));
|
||||
///
|
||||
/// // ResNet from a checkpoint
|
||||
/// AzBotStrategy::new_resnet(Some("checkpoints/resnet_0200.mpk"));
|
||||
/// ```
|
||||
pub struct AzBotStrategy {
|
||||
game: GameState,
|
||||
evaluator: Box<dyn Evaluator>,
|
||||
mcts_config: MctsConfig,
|
||||
/// Interior-mutable RNG so `choose_move(&self)` can drive MCTS.
|
||||
rng: RefCell<SmallRng>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AzBotStrategy {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AzBotStrategy")
|
||||
.field("n_sim", &self.mcts_config.n_simulations)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AzBotStrategy {
|
||||
fn from_evaluator(evaluator: Box<dyn Evaluator>) -> Self {
|
||||
Self {
|
||||
game: GameState::default(),
|
||||
evaluator,
|
||||
mcts_config: MctsConfig {
|
||||
n_simulations: AZ_BOT_N_SIM,
|
||||
dirichlet_alpha: 0.0, // no noise during play
|
||||
dirichlet_eps: 0.0,
|
||||
temperature: 0.0, // greedy selection
|
||||
..MctsConfig::default()
|
||||
},
|
||||
rng: RefCell::new(SmallRng::seed_from_u64(42)),
|
||||
}
|
||||
}
|
||||
|
||||
/// MlpNet-backed bot. `path = None` → random weights.
|
||||
pub fn new_mlp(path: Option<&str>) -> Self {
|
||||
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 256 };
|
||||
let model = match path {
|
||||
Some(p) => MlpNet::<B>::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| {
|
||||
eprintln!("az: load failed ({e}), using random weights");
|
||||
MlpNet::<B>::new(&cfg, &device)
|
||||
}),
|
||||
None => MlpNet::<B>::new(&cfg, &device),
|
||||
};
|
||||
Self::from_evaluator(Box::new(BurnEvaluator::<B, MlpNet<B>>::new(model, device)))
|
||||
}
|
||||
|
||||
/// ResNet-backed bot. `path = None` → random weights.
|
||||
pub fn new_resnet(path: Option<&str>) -> Self {
|
||||
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: 512 };
|
||||
let model = match path {
|
||||
Some(p) => ResNet::<B>::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| {
|
||||
eprintln!("az-resnet: load failed ({e}), using random weights");
|
||||
ResNet::<B>::new(&cfg, &device)
|
||||
}),
|
||||
None => ResNet::<B>::new(&cfg, &device),
|
||||
};
|
||||
Self::from_evaluator(Box::new(BurnEvaluator::<B, ResNet<B>>::new(model, device)))
|
||||
}
|
||||
|
||||
/// Run MCTS and return the greedy best action index, or `None` if no legal moves.
|
||||
fn best_action(&self) -> Option<usize> {
|
||||
let env = TrictracEnv;
|
||||
if env.legal_actions(&self.game).is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut rng = self.rng.borrow_mut();
|
||||
let root = mcts::run_mcts(
|
||||
&env,
|
||||
&self.game,
|
||||
self.evaluator.as_ref(),
|
||||
&self.mcts_config,
|
||||
&mut *rng,
|
||||
);
|
||||
Some(mcts::select_action(&root, 0.0, &mut *rng))
|
||||
}
|
||||
}
|
||||
|
||||
impl BotStrategy for AzBotStrategy {
|
||||
fn get_game(&self) -> &GameState { &self.game }
|
||||
fn get_mut_game(&mut self) -> &mut GameState { &mut self.game }
|
||||
fn calculate_points(&self) -> u8 { self.game.dice_points.0 }
|
||||
fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 }
|
||||
fn set_player_id(&mut self, _player_id: PlayerId) {}
|
||||
fn set_color(&mut self, _color: Color) {}
|
||||
|
||||
fn choose_go(&self) -> bool {
|
||||
// Action index 1 == TrictracAction::Go
|
||||
self.best_action().map(|a| a == 1).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||
self.best_action()
|
||||
.and_then(|a| action_to_moves(a, &self.game))
|
||||
.unwrap_or_else(|| fallback_move(&self.game))
|
||||
}
|
||||
}
|
||||
|
||||
// ── DqnSpielBotStrategy ───────────────────────────────────────────────────────
|
||||
|
||||
/// DQN bot (QNet from `spiel_bot`) usable as a [`BotStrategy`].
|
||||
///
|
||||
/// Selects actions by greedy argmax over Q-values, masked to legal moves.
|
||||
/// When no checkpoint is provided the model falls back to the first legal move.
|
||||
///
|
||||
/// # Construction
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// // No model — always picks first legal move
|
||||
/// DqnSpielBotStrategy::new(None);
|
||||
///
|
||||
/// // Trained checkpoint
|
||||
/// DqnSpielBotStrategy::new(Some("checkpoints/dqn_iter_0500.mpk"));
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct DqnSpielBotStrategy {
|
||||
game: GameState,
|
||||
model: Option<QNet<B>>,
|
||||
}
|
||||
|
||||
impl DqnSpielBotStrategy {
|
||||
/// Create a DQN bot. `path = None` → falls back to first legal move.
|
||||
pub fn new(path: Option<&str>) -> Self {
|
||||
let model = path.map(|p| {
|
||||
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||
let cfg = QNetConfig::default();
|
||||
QNet::<B>::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| {
|
||||
eprintln!("az-dqn: load failed ({e}), using random weights");
|
||||
QNet::<B>::new(&cfg, &device)
|
||||
})
|
||||
});
|
||||
Self { game: GameState::default(), model }
|
||||
}
|
||||
|
||||
/// Greedy Q-value selection masked to legal actions, or `None` if no model / no legal moves.
|
||||
fn best_action(&self) -> Option<usize> {
|
||||
let model = self.model.as_ref()?;
|
||||
let legal = get_valid_action_indices(&self.game).unwrap_or_default();
|
||||
if legal.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||
let obs = self.game.to_tensor();
|
||||
let obs_t = Tensor::<B, 2>::from_data(TensorData::new(obs, [1, 217]), &device);
|
||||
let q_vals: Vec<f32> = model.forward(obs_t).into_data().to_vec().unwrap();
|
||||
legal.into_iter().max_by(|&a, &b| {
|
||||
q_vals[a].partial_cmp(&q_vals[b]).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl BotStrategy for DqnSpielBotStrategy {
|
||||
fn get_game(&self) -> &GameState { &self.game }
|
||||
fn get_mut_game(&mut self) -> &mut GameState { &mut self.game }
|
||||
fn calculate_points(&self) -> u8 { self.game.dice_points.0 }
|
||||
fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 }
|
||||
fn set_player_id(&mut self, _player_id: PlayerId) {}
|
||||
fn set_color(&mut self, _color: Color) {}
|
||||
|
||||
fn choose_go(&self) -> bool {
|
||||
self.best_action().map(|a| a == 1).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||
self.best_action()
|
||||
.and_then(|a| action_to_moves(a, &self.game))
|
||||
.unwrap_or_else(|| fallback_move(&self.game))
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue