feat(spiel_bot): upgrade network
This commit is contained in:
parent
7d37eebe52
commit
41b3fc5dad
3 changed files with 131 additions and 3 deletions
|
|
@ -32,7 +32,7 @@ use spiel_bot::{
|
|||
alphazero::{BurnEvaluator, TrainSample, generate_episode, train_step},
|
||||
env::{GameEnv, Player, TrictracEnv},
|
||||
mcts::{Evaluator, MctsConfig, run_mcts},
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet},
|
||||
network::{MlpConfig, MlpNet, PolicyValueNet, ResNet, ResNetConfig},
|
||||
};
|
||||
|
||||
// ── Shared types ───────────────────────────────────────────────────────────
|
||||
|
|
@ -162,6 +162,38 @@ fn bench_network(c: &mut Criterion) {
|
|||
);
|
||||
}
|
||||
|
||||
// ── ResNet (4 residual blocks) ────────────────────────────────────────
|
||||
for &hidden in &[256usize, 512] {
|
||||
let cfg = ResNetConfig { obs_size: 217, action_size: 514, hidden_size: hidden };
|
||||
let model = ResNet::<InferB>::new(&cfg, &infer_device());
|
||||
let obs: Vec<f32> = vec![0.5; 217];
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("resnet_b1", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs.clone(), [1, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
let obs32: Vec<f32> = vec![0.5; 217 * 32];
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("resnet_b32", hidden),
|
||||
&hidden,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
let data = TensorData::new(obs32.clone(), [32, 217]);
|
||||
let t = Tensor::<InferB, 2>::from_data(data, &infer_device());
|
||||
black_box(model.forward(t))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue