feat(spiel_bot): cli spiel_bot strategy
This commit is contained in:
parent
e80dade303
commit
b2d66ce41e
7 changed files with 281 additions and 4 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -6009,6 +6009,7 @@ dependencies = [
|
||||||
"criterion",
|
"criterion",
|
||||||
"rand 0.9.2",
|
"rand 0.9.2",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
|
"trictrac-bot",
|
||||||
"trictrac-store",
|
"trictrac-store",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -6854,6 +6855,7 @@ dependencies = [
|
||||||
"pico-args",
|
"pico-args",
|
||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
"renet",
|
"renet",
|
||||||
|
"spiel_bot",
|
||||||
"trictrac-bot",
|
"trictrac-bot",
|
||||||
"trictrac-store",
|
"trictrac-store",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ name = "trictrac-client_cli"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
[[bin]]
|
||||||
|
name = "client_cli"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.75"
|
anyhow = "1.0.75"
|
||||||
|
|
@ -12,7 +14,8 @@ pico-args = "0.5.0"
|
||||||
pretty_assertions = "1.4.0"
|
pretty_assertions = "1.4.0"
|
||||||
renet = "0.0.13"
|
renet = "0.0.13"
|
||||||
trictrac-store = { path = "../store" }
|
trictrac-store = { path = "../store" }
|
||||||
trictrac-bot = { path = "../bot" }
|
trictrac-bot = { path = "../bot" }
|
||||||
|
spiel_bot = { path = "../spiel_bot" }
|
||||||
itertools = "0.13.0"
|
itertools = "0.13.0"
|
||||||
env_logger = "0.11.6"
|
env_logger = "0.11.6"
|
||||||
log = "0.4.20"
|
log = "0.4.20"
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use spiel_bot::strategy::{AzBotStrategy, DqnSpielBotStrategy};
|
||||||
use trictrac_bot::{
|
use trictrac_bot::{
|
||||||
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
|
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
|
||||||
StableBaselines3Strategy,
|
StableBaselines3Strategy,
|
||||||
|
|
@ -56,6 +57,27 @@ impl App {
|
||||||
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
||||||
as Box<dyn BotStrategy>)
|
as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
|
// AlphaZero MlpNet without a checkpoint (`None` path).
"az" => {
    Some(Box::new(AzBotStrategy::new_mlp(None)) as Box<dyn BotStrategy>)
}
// AlphaZero MlpNet from a checkpoint. `strip_prefix` removes the token
// exactly once, so a path that itself begins with "az:" is kept intact
// (`trim_start_matches` would strip every repetition of the prefix).
s if s.starts_with("az:") => {
    let path = s.strip_prefix("az:").unwrap_or(s);
    Some(Box::new(AzBotStrategy::new_mlp(Some(path))) as Box<dyn BotStrategy>)
}
// AlphaZero ResNet without a checkpoint (`None` path).
"az-resnet" => {
    Some(Box::new(AzBotStrategy::new_resnet(None)) as Box<dyn BotStrategy>)
}
// AlphaZero ResNet from a checkpoint.
s if s.starts_with("az-resnet:") => {
    let path = s.strip_prefix("az-resnet:").unwrap_or(s);
    Some(Box::new(AzBotStrategy::new_resnet(Some(path))) as Box<dyn BotStrategy>)
}
// spiel_bot DQN without a checkpoint (`None` path).
"az-dqn" => {
    Some(Box::new(DqnSpielBotStrategy::new(None)) as Box<dyn BotStrategy>)
}
// spiel_bot DQN from a checkpoint.
s if s.starts_with("az-dqn:") => {
    let path = s.strip_prefix("az-dqn:").unwrap_or(s);
    Some(Box::new(DqnSpielBotStrategy::new(Some(path))) as Box<dyn BotStrategy>)
}
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,14 @@ OPTIONS:
|
||||||
- dummy: Default strategy selecting the first valid move
|
- dummy: Default strategy selecting the first valid move
|
||||||
- ai: AI strategy using the default model at models/trictrac_ppo.zip
|
- ai: AI strategy using the default model at models/trictrac_ppo.zip
|
||||||
- ai:/path/to/model.zip: AI strategy using a custom model
|
- ai:/path/to/model.zip: AI strategy using a custom model
|
||||||
- dqn: DQN strategy using native Rust implementation with Burn
|
- dqnburn: DQN strategy (burn-rl backend)
|
||||||
- dqn:/path/to/model: DQN strategy using a custom model
|
- dqnburn:/path/to/model: DQN strategy (burn-rl backend) with custom model
|
||||||
|
- az: AlphaZero MlpNet (random weights)
|
||||||
|
- az:/path/to/model.mpk: AlphaZero MlpNet checkpoint
|
||||||
|
- az-resnet: AlphaZero ResNet (random weights)
|
||||||
|
- az-resnet:/path/to/model.mpk: AlphaZero ResNet checkpoint
|
||||||
|
- az-dqn: DQN QNet (random weights, first-legal-move fallback)
|
||||||
|
- az-dqn:/path/to/model.mpk: DQN QNet checkpoint
|
||||||
|
|
||||||
ARGS:
|
ARGS:
|
||||||
<INPUT>
|
<INPUT>
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
trictrac-store = { path = "../store" }
|
trictrac-store = { path = "../store" }
|
||||||
|
trictrac-bot = { path = "../bot" }
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
rand_distr = "0.5"
|
rand_distr = "0.5"
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,4 @@ pub mod dqn;
|
||||||
pub mod env;
|
pub mod env;
|
||||||
pub mod mcts;
|
pub mod mcts;
|
||||||
pub mod network;
|
pub mod network;
|
||||||
|
pub mod strategy;
|
||||||
|
|
|
||||||
242
spiel_bot/src/strategy.rs
Normal file
242
spiel_bot/src/strategy.rs
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
//! [`BotStrategy`] implementations backed by `spiel_bot` models.
|
||||||
|
//!
|
||||||
|
//! | Strategy struct | Network | CLI token |
|
||||||
|
//! |-----------------|---------|-----------|
|
||||||
|
//! | [`AzBotStrategy`] (mlp) | MlpNet (AlphaZero) | `az` / `az:PATH` |
|
||||||
|
//! | [`AzBotStrategy`] (resnet) | ResNet (AlphaZero) | `az-resnet` / `az-resnet:PATH` |
|
||||||
|
//! | [`DqnSpielBotStrategy`] | QNet (DQN) | `az-dqn` / `az-dqn:PATH` |
|
||||||
|
//!
|
||||||
|
//! All strategies operate from **White's perspective** (player_id = 1) internally;
|
||||||
|
//! the [`Bot`](trictrac_bot::Bot) wrapper handles board mirroring for Black.
|
||||||
|
|
||||||
|
use std::cell::RefCell;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
backend::NdArray,
|
||||||
|
tensor::{Tensor, TensorData},
|
||||||
|
};
|
||||||
|
use rand::{SeedableRng, rngs::SmallRng};
|
||||||
|
use trictrac_bot::BotStrategy;
|
||||||
|
use trictrac_store::{
|
||||||
|
training_common::{get_valid_action_indices, TrictracAction},
|
||||||
|
CheckerMove, Color, GameEvent, GameState, MoveRules, PlayerId,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
alphazero::BurnEvaluator,
|
||||||
|
env::{GameEnv, TrictracEnv},
|
||||||
|
mcts::{self, Evaluator, MctsConfig},
|
||||||
|
network::{MlpConfig, MlpNet, QNet, QNetConfig, QValueNet, ResNet, ResNetConfig},
|
||||||
|
};
|
||||||
|
|
||||||
|
type B = NdArray<f32>;
|
||||||
|
|
||||||
|
/// Default MCTS simulations per move used by [`AzBotStrategy`].
|
||||||
|
pub const AZ_BOT_N_SIM: usize = 50;
|
||||||
|
|
||||||
|
// ── Shared helpers ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Decode an action index → `(CheckerMove, CheckerMove)` using the game state.
|
||||||
|
fn action_to_moves(action: usize, game: &GameState) -> Option<(CheckerMove, CheckerMove)> {
|
||||||
|
match TrictracAction::from_action_index(action)?.to_event(game)? {
|
||||||
|
GameEvent::Move { moves, .. } => Some(moves),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fallback: return the first legal move from `MoveRules` (always succeeds).
|
||||||
|
fn fallback_move(game: &GameState) -> (CheckerMove, CheckerMove) {
|
||||||
|
let rules = MoveRules::new(&Color::White, &game.board, game.dice);
|
||||||
|
let moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
*moves.first().unwrap_or(&(CheckerMove::default(), CheckerMove::default()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── AzBotStrategy ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// AlphaZero bot usable as a [`BotStrategy`].
|
||||||
|
///
|
||||||
|
/// Supports both MlpNet and ResNet checkpoints through separate constructors.
|
||||||
|
/// Uses greedy (temperature = 0) MCTS for action selection.
|
||||||
|
///
|
||||||
|
/// # Construction
|
||||||
|
///
|
||||||
|
/// ```rust,ignore
|
||||||
|
/// // MlpNet with random weights
|
||||||
|
/// AzBotStrategy::new_mlp(None);
|
||||||
|
///
|
||||||
|
/// // MlpNet from a checkpoint
|
||||||
|
/// AzBotStrategy::new_mlp(Some("checkpoints/iter_0100.mpk"));
|
||||||
|
///
|
||||||
|
/// // ResNet from a checkpoint
|
||||||
|
/// AzBotStrategy::new_resnet(Some("checkpoints/resnet_0200.mpk"));
|
||||||
|
/// ```
|
||||||
|
pub struct AzBotStrategy {
|
||||||
|
game: GameState,
|
||||||
|
evaluator: Box<dyn Evaluator>,
|
||||||
|
mcts_config: MctsConfig,
|
||||||
|
/// Interior-mutable RNG so `choose_move(&self)` can drive MCTS.
|
||||||
|
rng: RefCell<SmallRng>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for AzBotStrategy {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("AzBotStrategy")
|
||||||
|
.field("n_sim", &self.mcts_config.n_simulations)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AzBotStrategy {
|
||||||
|
fn from_evaluator(evaluator: Box<dyn Evaluator>) -> Self {
|
||||||
|
Self {
|
||||||
|
game: GameState::default(),
|
||||||
|
evaluator,
|
||||||
|
mcts_config: MctsConfig {
|
||||||
|
n_simulations: AZ_BOT_N_SIM,
|
||||||
|
dirichlet_alpha: 0.0, // no noise during play
|
||||||
|
dirichlet_eps: 0.0,
|
||||||
|
temperature: 0.0, // greedy selection
|
||||||
|
..MctsConfig::default()
|
||||||
|
},
|
||||||
|
rng: RefCell::new(SmallRng::seed_from_u64(42)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MlpNet-backed bot. `path = None` → random weights.
|
||||||
|
pub fn new_mlp(path: Option<&str>) -> Self {
|
||||||
|
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||||
|
let cfg = MlpConfig { obs_size: 217, action_size: 514, hidden_size: 256 };
|
||||||
|
let model = match path {
|
||||||
|
Some(p) => MlpNet::<B>::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| {
|
||||||
|
eprintln!("az: load failed ({e}), using random weights");
|
||||||
|
MlpNet::<B>::new(&cfg, &device)
|
||||||
|
}),
|
||||||
|
None => MlpNet::<B>::new(&cfg, &device),
|
||||||
|
};
|
||||||
|
Self::from_evaluator(Box::new(BurnEvaluator::<B, MlpNet<B>>::new(model, device)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ResNet-backed bot. `path = None` → random weights.
|
||||||
|
pub fn new_resnet(path: Option<&str>) -> Self {
|
||||||
|
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||||
|
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: 512 };
|
||||||
|
let model = match path {
|
||||||
|
Some(p) => ResNet::<B>::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| {
|
||||||
|
eprintln!("az-resnet: load failed ({e}), using random weights");
|
||||||
|
ResNet::<B>::new(&cfg, &device)
|
||||||
|
}),
|
||||||
|
None => ResNet::<B>::new(&cfg, &device),
|
||||||
|
};
|
||||||
|
Self::from_evaluator(Box::new(BurnEvaluator::<B, ResNet<B>>::new(model, device)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run MCTS and return the greedy best action index, or `None` if no legal moves.
|
||||||
|
fn best_action(&self) -> Option<usize> {
|
||||||
|
let env = TrictracEnv;
|
||||||
|
if env.legal_actions(&self.game).is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mut rng = self.rng.borrow_mut();
|
||||||
|
let root = mcts::run_mcts(
|
||||||
|
&env,
|
||||||
|
&self.game,
|
||||||
|
self.evaluator.as_ref(),
|
||||||
|
&self.mcts_config,
|
||||||
|
&mut *rng,
|
||||||
|
);
|
||||||
|
Some(mcts::select_action(&root, 0.0, &mut *rng))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BotStrategy for AzBotStrategy {
|
||||||
|
fn get_game(&self) -> &GameState { &self.game }
|
||||||
|
fn get_mut_game(&mut self) -> &mut GameState { &mut self.game }
|
||||||
|
fn calculate_points(&self) -> u8 { self.game.dice_points.0 }
|
||||||
|
fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 }
|
||||||
|
fn set_player_id(&mut self, _player_id: PlayerId) {}
|
||||||
|
fn set_color(&mut self, _color: Color) {}
|
||||||
|
|
||||||
|
fn choose_go(&self) -> bool {
|
||||||
|
// Action index 1 == TrictracAction::Go
|
||||||
|
self.best_action().map(|a| a == 1).unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||||
|
self.best_action()
|
||||||
|
.and_then(|a| action_to_moves(a, &self.game))
|
||||||
|
.unwrap_or_else(|| fallback_move(&self.game))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── DqnSpielBotStrategy ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// DQN bot (QNet from `spiel_bot`) usable as a [`BotStrategy`].
|
||||||
|
///
|
||||||
|
/// Selects actions by greedy argmax over Q-values, masked to legal moves.
|
||||||
|
/// When no checkpoint is provided the model falls back to the first legal move.
|
||||||
|
///
|
||||||
|
/// # Construction
|
||||||
|
///
|
||||||
|
/// ```rust,ignore
|
||||||
|
/// // No model — always picks first legal move
|
||||||
|
/// DqnSpielBotStrategy::new(None);
|
||||||
|
///
|
||||||
|
/// // Trained checkpoint
|
||||||
|
/// DqnSpielBotStrategy::new(Some("checkpoints/dqn_iter_0500.mpk"));
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DqnSpielBotStrategy {
|
||||||
|
game: GameState,
|
||||||
|
model: Option<QNet<B>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnSpielBotStrategy {
|
||||||
|
/// Create a DQN bot. `path = None` → falls back to first legal move.
|
||||||
|
pub fn new(path: Option<&str>) -> Self {
|
||||||
|
let model = path.map(|p| {
|
||||||
|
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||||
|
let cfg = QNetConfig::default();
|
||||||
|
QNet::<B>::load(&cfg, Path::new(p), &device).unwrap_or_else(|e| {
|
||||||
|
eprintln!("az-dqn: load failed ({e}), using random weights");
|
||||||
|
QNet::<B>::new(&cfg, &device)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
Self { game: GameState::default(), model }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Greedy Q-value selection masked to legal actions, or `None` if no model / no legal moves.
|
||||||
|
fn best_action(&self) -> Option<usize> {
|
||||||
|
let model = self.model.as_ref()?;
|
||||||
|
let legal = get_valid_action_indices(&self.game).unwrap_or_default();
|
||||||
|
if legal.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
|
||||||
|
let obs = self.game.to_tensor();
|
||||||
|
let obs_t = Tensor::<B, 2>::from_data(TensorData::new(obs, [1, 217]), &device);
|
||||||
|
let q_vals: Vec<f32> = model.forward(obs_t).into_data().to_vec().unwrap();
|
||||||
|
legal.into_iter().max_by(|&a, &b| {
|
||||||
|
q_vals[a].partial_cmp(&q_vals[b]).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BotStrategy for DqnSpielBotStrategy {
|
||||||
|
fn get_game(&self) -> &GameState { &self.game }
|
||||||
|
fn get_mut_game(&mut self) -> &mut GameState { &mut self.game }
|
||||||
|
fn calculate_points(&self) -> u8 { self.game.dice_points.0 }
|
||||||
|
fn calculate_adv_points(&self) -> u8 { self.game.dice_points.1 }
|
||||||
|
fn set_player_id(&mut self, _player_id: PlayerId) {}
|
||||||
|
fn set_color(&mut self, _color: Color) {}
|
||||||
|
|
||||||
|
fn choose_go(&self) -> bool {
|
||||||
|
self.best_action().map(|a| a == 1).unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||||
|
self.best_action()
|
||||||
|
.and_then(|a| action_to_moves(a, &self.game))
|
||||||
|
.unwrap_or_else(|| fallback_move(&self.game))
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue