feat(spiel_bot): upgrade network

This commit is contained in:
Henri Bourcereau 2026-03-07 23:05:53 +01:00
parent 7d37eebe52
commit 41b3fc5dad
3 changed files with 131 additions and 3 deletions

View file

@ -32,7 +32,7 @@ use spiel_bot::{
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
env::{GameEnv, Player, TrictracEnv},
mcts::{Evaluator, MctsConfig, run_mcts},
network::{MlpConfig, MlpNet, PolicyValueNet},
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
};
// ── Shared types ───────────────────────────────────────────────────────────
@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) {
);
}
// ── ResNet (4 residual blocks) ────────────────────────────────────────
for &hidden in &[256usize, 512] {
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
let model = ResNet::<InferB>::new(&cfg, &infer_device());
let obs: Vec<f32> = vec![0.5; 217];
group.bench_with_input(
BenchmarkId::new("resnet_b1", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs.clone(), [1, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
let obs32: Vec<f32> = vec![0.5; 217 * 32];
group.bench_with_input(
BenchmarkId::new("resnet_b32", hidden),
&hidden,
|b, _| {
b.iter(|| {
let data = TensorData::new(obs32.clone(), [32, 217]);
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
black_box(model.forward(t))
})
},
);
}
group.finish();
}