feat(spiel_bot): dqn

Henri Bourcereau 2026-03-10 22:12:52 +01:00
parent 7c0f230e3d
commit e7d13c9a02
9 changed files with 1192 additions and 0 deletions


@@ -0,0 +1,247 @@
//! DQN self-play episode generation.
//!
//! Both players share the same Q-network (the [`TrictracEnv`] handles board
//! mirroring so that each player always acts from "White's perspective").
//! Transitions for both players are stored in the returned sample list.
//!
//! # Reward
//!
//! After each full decision (action applied and the state has advanced through
//! any intervening chance nodes back to the same player's next turn), the
//! reward is:
//!
//! ```text
//! r = (my_total_score_now - my_total_score_then)
//!     - (opp_total_score_now - opp_total_score_then)
//! ```
//!
//! where `total_score = holes × 12 + points`.
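//!
//! For example, going from 2 holes + 10 points (total 34) to 3 holes + 2 points
//! (total 38) while the opponent's total is unchanged gives `r = 4`.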
//!
//! # Transition structure
//!
//! We use a "pending transition" per player. When a player acts again, we
//! *complete* the previous pending transition by filling in `next_obs`,
//! `next_legal`, and computing `reward`. Terminal transitions are completed
//! when the game ends.
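//!
//! Schematically (one possible turn order):
//!
//! ```text
//! P1 acts a0     → pending[P1] = (obs0, a0, score0)
//! P2 acts a1     → pending[P2] = (obs1, a1, score1)
//! P1 acts a2     → complete pending[P1] with reward / next_obs / next_legal,
//!                  then pending[P1] = (obs2, a2, score2)
//! game ends      → complete both pending transitions with done = true
//! ```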
use burn::tensor::{backend::Backend, Tensor, TensorData};
use rand::Rng;
use crate::env::{GameEnv, TrictracEnv};
use crate::network::QValueNet;
use super::DqnSample;
// ── Internals ─────────────────────────────────────────────────────────────────
struct PendingTransition {
obs: Vec<f32>,
action: usize,
/// Score snapshot `[p1_total, p2_total]` at the moment of the action.
score_before: [i32; 2],
}
/// Pick an action ε-greedily: random with probability `epsilon`, greedy otherwise.
fn epsilon_greedy<B: Backend, Q: QValueNet<B>>(
q_net: &Q,
obs: &[f32],
legal: &[usize],
epsilon: f32,
rng: &mut impl Rng,
device: &B::Device,
) -> usize {
debug_assert!(!legal.is_empty(), "epsilon_greedy: no legal actions");
if rng.random::<f32>() < epsilon {
legal[rng.random_range(0..legal.len())]
} else {
let obs_tensor = Tensor::<B, 2>::from_data(
TensorData::new(obs.to_vec(), [1, obs.len()]),
device,
);
let q_values: Vec<f32> = q_net.forward(obs_tensor).into_data().to_vec().unwrap();
legal
.iter()
.copied()
.max_by(|&a, &b| {
q_values[a].partial_cmp(&q_values[b]).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap()
}
}
/// Reward for `player_idx` (0 = P1, 1 = P2) given score snapshots before/after.
fn compute_reward(player_idx: usize, score_before: &[i32; 2], score_after: &[i32; 2]) -> f32 {
let opp_idx = 1 - player_idx;
((score_after[player_idx] - score_before[player_idx])
- (score_after[opp_idx] - score_before[opp_idx])) as f32
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Play one full game and return all transitions for both players.
///
/// - `q_net` uses the **inference backend** (no autodiff wrapper).
/// - `epsilon` in `[0, 1]`: probability of taking a random action.
/// - `reward_scale`: reward divisor (e.g. `12.0` to map one hole → `±1`).
pub fn generate_dqn_episode<B: Backend, Q: QValueNet<B>>(
env: &TrictracEnv,
q_net: &Q,
epsilon: f32,
rng: &mut impl Rng,
device: &B::Device,
reward_scale: f32,
) -> Vec<DqnSample> {
let obs_size = env.obs_size();
let mut state = env.new_game();
let mut pending: [Option<PendingTransition>; 2] = [None, None];
let mut samples: Vec<DqnSample> = Vec::new();
loop {
// ── Advance past chance nodes ──────────────────────────────────────
while env.current_player(&state).is_chance() {
env.apply_chance(&mut state, rng);
}
let score_now = TrictracEnv::score_snapshot(&state);
if env.current_player(&state).is_terminal() {
// Complete all pending transitions as terminal.
for player_idx in 0..2 {
if let Some(prev) = pending[player_idx].take() {
let reward =
compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
samples.push(DqnSample {
obs: prev.obs,
action: prev.action,
reward,
next_obs: vec![0.0; obs_size],
next_legal: vec![],
done: true,
});
}
}
break;
}
let player_idx = env.current_player(&state).index().unwrap();
let legal = env.legal_actions(&state);
let obs = env.observation(&state, player_idx);
// ── Complete the previous transition for this player ───────────────
if let Some(prev) = pending[player_idx].take() {
let reward =
compute_reward(player_idx, &prev.score_before, &score_now) / reward_scale;
samples.push(DqnSample {
obs: prev.obs,
action: prev.action,
reward,
next_obs: obs.clone(),
next_legal: legal.clone(),
done: false,
});
}
// ── Pick and apply action ──────────────────────────────────────────
let action = epsilon_greedy(q_net, &obs, &legal, epsilon, rng, device);
env.apply(&mut state, action);
// ── Record new pending transition ──────────────────────────────────
pending[player_idx] = Some(PendingTransition {
obs,
action,
score_before: score_now,
});
}
samples
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use rand::{SeedableRng, rngs::SmallRng};
use crate::network::{QNet, QNetConfig};
type B = NdArray<f32>;
fn device() -> <B as Backend>::Device { Default::default() }
fn rng() -> SmallRng { SmallRng::seed_from_u64(7) }
fn tiny_q() -> QNet<B> {
QNet::new(&QNetConfig::default(), &device())
}
#[test]
fn episode_terminates_and_produces_samples() {
let env = TrictracEnv;
let q = tiny_q();
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
assert!(!samples.is_empty(), "episode must produce at least one sample");
}
#[test]
fn episode_obs_size_correct() {
let env = TrictracEnv;
let q = tiny_q();
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
for s in &samples {
assert_eq!(s.obs.len(), 217, "obs size mismatch");
if s.done {
assert_eq!(s.next_obs.len(), 217, "done next_obs should be zeros of obs_size");
assert!(s.next_legal.is_empty());
} else {
assert_eq!(s.next_obs.len(), 217, "next_obs size mismatch");
assert!(!s.next_legal.is_empty());
}
}
}
#[test]
fn episode_actions_within_action_space() {
let env = TrictracEnv;
let q = tiny_q();
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
for s in &samples {
assert!(s.action < 514, "action {} out of bounds", s.action);
}
}
#[test]
fn greedy_episode_also_terminates() {
let env = TrictracEnv;
let q = tiny_q();
let samples = generate_dqn_episode(&env, &q, 0.0, &mut rng(), &device(), 1.0);
assert!(!samples.is_empty());
}
#[test]
fn at_least_one_done_sample() {
let env = TrictracEnv;
let q = tiny_q();
let samples = generate_dqn_episode(&env, &q, 1.0, &mut rng(), &device(), 1.0);
let n_done = samples.iter().filter(|s| s.done).count();
// Two players, so 1 or 2 terminal transitions.
assert!(n_done >= 1 && n_done <= 2, "expected 1-2 done samples, got {n_done}");
}
#[test]
fn compute_reward_correct() {
// P1 gains 4 points (2 holes + 10 pts → 3 holes + 2 pts), opp unchanged.
let before = [2 * 12 + 10, 0];
let after = [3 * 12 + 2, 0];
let r = compute_reward(0, &before, &after);
assert!((r - 4.0).abs() < 1e-6, "expected 4.0, got {r}");
}
#[test]
fn compute_reward_with_opponent_scoring() {
// P1 gains 2, opp gains 3 → net = -1 from P1's perspective.
let before = [0, 0];
let after = [2, 3];
let r = compute_reward(0, &before, &after);
assert!((r - (-1.0)).abs() < 1e-6, "expected -1.0, got {r}");
}
}

spiel_bot/src/dqn/mod.rs

@@ -0,0 +1,232 @@
//! DQN: self-play data generation, replay buffer, and training step.
//!
//! # Algorithm
//!
//! Deep Q-Network with:
//! - **ε-greedy** exploration (linearly decayed).
//! - **Dense per-turn rewards**: `my_score_delta - opponent_score_delta` where
//! `score = holes × 12 + points`.
//! - **Experience replay** with a fixed-capacity circular buffer.
//! - **Target network**: hard-copied from the online Q-net every
//! `target_update_freq` gradient steps for training stability.
//!
//! # Modules
//!
//! | Module | Contents |
//! |--------|----------|
//! | [`episode`] | [`DqnSample`], [`generate_dqn_episode`] |
//! | [`trainer`] | [`dqn_train_step`], [`compute_target_q`], [`hard_update`] |
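//!
//! # Example
//!
//! A minimal sketch of the outer loop (not part of this commit), assuming the
//! `QNet` / `QNetConfig` network from `crate::network`, the `TrictracEnv` from
//! `crate::env`, and the `NdArray` backend used in the tests; the exact schedule
//! of ε-decay and target updates is illustrative only:
//!
//! ```ignore
//! use burn::{backend::{Autodiff, NdArray}, module::AutodiffModule, optim::AdamConfig};
//! use rand::{SeedableRng, rngs::SmallRng};
//!
//! type TrainB = Autodiff<NdArray<f32>>;
//!
//! let cfg = DqnConfig::default();
//! let net_cfg = QNetConfig::default();
//! let device = Default::default();
//! let env = TrictracEnv;
//! let mut rng = SmallRng::seed_from_u64(0);
//!
//! let mut q_net = QNet::<TrainB>::new(&net_cfg, &device);
//! let mut target = hard_update::<TrainB, _>(&q_net); // frozen copy, inference backend
//! let mut optimizer = AdamConfig::new().init();
//! let mut buffer = DqnReplayBuffer::new(cfg.replay_capacity);
//! let mut grad_step = 0usize;
//!
//! for _ in 0..cfg.n_iterations {
//!     let eps = linear_epsilon(
//!         cfg.epsilon_start, cfg.epsilon_end, grad_step, cfg.epsilon_decay_steps,
//!     );
//!     // Self-play uses the non-autodiff view of the online net.
//!     for _ in 0..cfg.n_games_per_iter {
//!         buffer.extend(generate_dqn_episode(
//!             &env, &q_net.valid(), eps, &mut rng, &device, cfg.reward_scale,
//!         ));
//!     }
//!     for _ in 0..cfg.n_train_steps_per_iter {
//!         let batch: Vec<DqnSample> = buffer
//!             .sample_batch(cfg.batch_size, &mut rng)
//!             .into_iter()
//!             .cloned()
//!             .collect();
//!         let tq = compute_target_q(&target, &batch, net_cfg.action_size, &device);
//!         let (net, _loss) = dqn_train_step(
//!             q_net, &mut optimizer, &batch, &tq, &device, cfg.learning_rate, cfg.gamma,
//!         );
//!         q_net = net;
//!         grad_step += 1;
//!         if grad_step % cfg.target_update_freq == 0 {
//!             target = hard_update::<TrainB, _>(&q_net);
//!         }
//!     }
//! }
//! ```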
pub mod episode;
pub mod trainer;
pub use episode::generate_dqn_episode;
pub use trainer::{compute_target_q, dqn_train_step, hard_update};
use std::collections::VecDeque;
use rand::Rng;
// ── DqnSample ─────────────────────────────────────────────────────────────────
/// One transition `(s, a, r, s', done)` collected during self-play.
#[derive(Clone, Debug)]
pub struct DqnSample {
/// Observation from the acting player's perspective (`obs_size` floats).
pub obs: Vec<f32>,
/// Action index taken.
pub action: usize,
/// Per-turn reward: `my_score_delta - opponent_score_delta`.
pub reward: f32,
/// Next observation from the same player's perspective.
/// All-zeros when `done = true` (ignored by the TD target).
pub next_obs: Vec<f32>,
/// Legal actions at `next_obs`. Empty when `done = true`.
pub next_legal: Vec<usize>,
/// `true` when `next_obs` is a terminal state.
pub done: bool,
}
// ── DqnReplayBuffer ───────────────────────────────────────────────────────────
/// Fixed-capacity circular replay buffer for [`DqnSample`]s.
///
/// When full, the oldest sample is evicted on push.
/// Batches are drawn without replacement via a partial Fisher-Yates shuffle.
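///
/// A minimal usage sketch (assuming `samples: Vec<DqnSample>` and an `rng` from
/// the `rand` crate):
///
/// ```ignore
/// let mut buf = DqnReplayBuffer::new(50_000);
/// buf.extend(samples);                        // oldest entries evicted once full
/// let batch = buf.sample_batch(64, &mut rng); // Vec<&DqnSample>, no duplicates
/// ```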
pub struct DqnReplayBuffer {
data: VecDeque<DqnSample>,
capacity: usize,
}
impl DqnReplayBuffer {
pub fn new(capacity: usize) -> Self {
Self { data: VecDeque::with_capacity(capacity.min(1024)), capacity }
}
pub fn push(&mut self, sample: DqnSample) {
if self.data.len() == self.capacity {
self.data.pop_front();
}
self.data.push_back(sample);
}
pub fn extend(&mut self, samples: impl IntoIterator<Item = DqnSample>) {
for s in samples { self.push(s); }
}
pub fn len(&self) -> usize { self.data.len() }
pub fn is_empty(&self) -> bool { self.data.is_empty() }
/// Sample up to `n` distinct samples without replacement.
pub fn sample_batch(&self, n: usize, rng: &mut impl Rng) -> Vec<&DqnSample> {
let len = self.data.len();
let n = n.min(len);
let mut indices: Vec<usize> = (0..len).collect();
for i in 0..n {
let j = rng.random_range(i..len);
indices.swap(i, j);
}
indices[..n].iter().map(|&i| &self.data[i]).collect()
}
}
// ── DqnConfig ─────────────────────────────────────────────────────────────────
/// Top-level DQN hyperparameters for the training loop.
#[derive(Debug, Clone)]
pub struct DqnConfig {
/// Initial exploration rate (1.0 = fully random).
pub epsilon_start: f32,
/// Final exploration rate after decay.
pub epsilon_end: f32,
/// Number of gradient steps over which ε decays linearly from start to end.
///
/// Should be calibrated to the total number of gradient steps
/// (`n_iterations × n_train_steps_per_iter`). A value larger than that
/// means exploration never reaches `epsilon_end` during the run.
pub epsilon_decay_steps: usize,
/// Discount factor γ for the TD target. Typical: 0.99.
pub gamma: f32,
/// Hard-copy Q → target every this many gradient steps.
///
/// Should be much smaller than the total number of gradient steps
/// (`n_iterations × n_train_steps_per_iter`).
pub target_update_freq: usize,
/// Adam learning rate.
pub learning_rate: f64,
/// Mini-batch size for each gradient step.
pub batch_size: usize,
/// Maximum number of samples in the replay buffer.
pub replay_capacity: usize,
/// Number of outer iterations (self-play + train).
pub n_iterations: usize,
/// Self-play games per iteration.
pub n_games_per_iter: usize,
/// Gradient steps per iteration.
pub n_train_steps_per_iter: usize,
/// Reward normalisation divisor.
///
/// Per-turn rewards (score delta) are divided by this constant before being
/// stored. Without normalisation, rewards can reach ±24 (jan with
/// bredouille = 12 pts × 2), driving Q-values into the hundreds and
/// causing MSE loss to grow unboundedly.
///
/// A value of `12.0` maps one hole (12 points) to `±1.0`, keeping
/// Q-value magnitudes in a stable range. Set to `1.0` to disable.
pub reward_scale: f32,
}
impl Default for DqnConfig {
fn default() -> Self {
// Total gradient steps with these defaults = 500 × 20 = 10_000,
// so epsilon decays fully and the target is updated 100 times.
Self {
epsilon_start: 1.0,
epsilon_end: 0.05,
epsilon_decay_steps: 10_000,
gamma: 0.99,
target_update_freq: 100,
learning_rate: 1e-3,
batch_size: 64,
replay_capacity: 50_000,
n_iterations: 500,
n_games_per_iter: 10,
n_train_steps_per_iter: 20,
reward_scale: 12.0,
}
}
}
/// Linear ε schedule: decays from `start` to `end` over `decay_steps` steps.
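///
/// For example, `linear_epsilon(1.0, 0.05, 5_000, 10_000)` ≈ `0.525`, and any
/// `step >= decay_steps` returns `end`.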
pub fn linear_epsilon(start: f32, end: f32, step: usize, decay_steps: usize) -> f32 {
if decay_steps == 0 || step >= decay_steps {
return end;
}
start + (end - start) * (step as f32 / decay_steps as f32)
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use rand::{SeedableRng, rngs::SmallRng};
fn dummy(reward: f32) -> DqnSample {
DqnSample {
obs: vec![0.0],
action: 0,
reward,
next_obs: vec![0.0],
next_legal: vec![0],
done: false,
}
}
#[test]
fn push_and_len() {
let mut buf = DqnReplayBuffer::new(10);
assert!(buf.is_empty());
buf.push(dummy(1.0));
buf.push(dummy(2.0));
assert_eq!(buf.len(), 2);
}
#[test]
fn evicts_oldest_at_capacity() {
let mut buf = DqnReplayBuffer::new(3);
buf.push(dummy(1.0));
buf.push(dummy(2.0));
buf.push(dummy(3.0));
buf.push(dummy(4.0));
assert_eq!(buf.len(), 3);
assert_eq!(buf.data[0].reward, 2.0);
}
#[test]
fn sample_batch_size() {
let mut buf = DqnReplayBuffer::new(20);
for i in 0..10 { buf.push(dummy(i as f32)); }
let mut rng = SmallRng::seed_from_u64(0);
assert_eq!(buf.sample_batch(5, &mut rng).len(), 5);
}
#[test]
fn linear_epsilon_start() {
assert!((linear_epsilon(1.0, 0.05, 0, 100) - 1.0).abs() < 1e-6);
}
#[test]
fn linear_epsilon_end() {
assert!((linear_epsilon(1.0, 0.05, 100, 100) - 0.05).abs() < 1e-6);
}
#[test]
fn linear_epsilon_monotone() {
let mut prev = f32::INFINITY;
for step in 0..=100 {
let e = linear_epsilon(1.0, 0.05, step, 100);
assert!(e <= prev + 1e-6);
prev = e;
}
}
}


@@ -0,0 +1,278 @@
//! DQN gradient step and target-network management.
//!
//! # TD target
//!
//! ```text
//! y_i = r_i + γ · max_{a ∈ legal_next_i} Q_target(s'_i, a) if not done
//! y_i = r_i if done
//! ```
//!
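//! For example, with `r_i = 1.0`, `γ = 0.99` and `max Q_target(s'_i, ·) = 2.0`,
//! the non-terminal target is `y_i = 1.0 + 0.99 × 2.0 = 2.98`.
//!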
//! # Loss
//!
//! Mean-squared error between `Q(s_i, a_i)` (gathered from the online net)
//! and `y_i` (computed from the frozen target net).
//!
//! # Target network
//!
//! [`hard_update`] copies the online Q-net weights into the target net by
//! stripping the autodiff wrapper via [`AutodiffModule::valid`].
use burn::{
module::AutodiffModule,
optim::{GradientsParams, Optimizer},
prelude::ElementConversion,
tensor::{
Int, Tensor, TensorData,
backend::{AutodiffBackend, Backend},
},
};
use crate::network::QValueNet;
use super::DqnSample;
// ── Target Q computation ─────────────────────────────────────────────────────
/// Compute `max_{a ∈ legal} Q_target(s', a)` for every non-done sample.
///
/// Returns a `Vec<f32>` of length `batch.len()`. Done samples get `0.0`
/// (their bootstrap term is dropped by the TD target anyway).
///
/// The target network runs on the **inference backend** (`InferB`) with no
/// gradient tape, so this function is backend-agnostic (`B: Backend`).
pub fn compute_target_q<B: Backend, Q: QValueNet<B>>(
target_net: &Q,
batch: &[DqnSample],
action_size: usize,
device: &B::Device,
) -> Vec<f32> {
let batch_size = batch.len();
// Collect indices of non-done samples (done samples have no next state).
let non_done: Vec<usize> = batch
.iter()
.enumerate()
.filter(|(_, s)| !s.done)
.map(|(i, _)| i)
.collect();
if non_done.is_empty() {
return vec![0.0; batch_size];
}
let obs_size = batch[0].next_obs.len();
let nd = non_done.len();
// Stack next observations for non-done samples → [nd, obs_size].
let obs_flat: Vec<f32> = non_done
.iter()
.flat_map(|&i| batch[i].next_obs.iter().copied())
.collect();
let obs_tensor = Tensor::<B, 2>::from_data(
TensorData::new(obs_flat, [nd, obs_size]),
device,
);
// Forward target net → [nd, action_size], then to Vec<f32>.
let q_flat: Vec<f32> = target_net.forward(obs_tensor).into_data().to_vec().unwrap();
// For each non-done sample, pick max Q over legal next actions.
let mut result = vec![0.0f32; batch_size];
for (k, &i) in non_done.iter().enumerate() {
let legal = &batch[i].next_legal;
let offset = k * action_size;
let max_q = legal
.iter()
.map(|&a| q_flat[offset + a])
.fold(f32::NEG_INFINITY, f32::max);
// If legal is empty (shouldn't happen for non-done, but be safe):
result[i] = if max_q.is_finite() { max_q } else { 0.0 };
}
result
}
// ── Training step ─────────────────────────────────────────────────────────────
/// Run one gradient step on `q_net` using `batch`.
///
/// `target_max_q` must be pre-computed via [`compute_target_q`] using the
/// frozen target network and passed in here so that this function only
/// needs the **autodiff backend**.
///
/// Returns the updated network and the scalar MSE loss.
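///
/// Typical call order (a sketch; `target_net`, `q_net`, `optimizer`, `batch`,
/// `action_size`, and `device` assumed to be set up as in the tests below):
///
/// ```ignore
/// let tq = compute_target_q(&target_net, &batch, action_size, &device);
/// let (q_net, loss) =
///     dqn_train_step(q_net, &mut optimizer, &batch, &tq, &device, 1e-3, 0.99);
/// ```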
pub fn dqn_train_step<B, Q, O>(
q_net: Q,
optimizer: &mut O,
batch: &[DqnSample],
target_max_q: &[f32],
device: &B::Device,
lr: f64,
gamma: f32,
) -> (Q, f32)
where
B: AutodiffBackend,
Q: QValueNet<B> + AutodiffModule<B>,
O: Optimizer<Q, B>,
{
assert!(!batch.is_empty(), "dqn_train_step: empty batch");
assert_eq!(batch.len(), target_max_q.len(), "batch and target_max_q length mismatch");
let batch_size = batch.len();
let obs_size = batch[0].obs.len();
// ── Build observation tensor [B, obs_size] ────────────────────────────
let obs_flat: Vec<f32> = batch.iter().flat_map(|s| s.obs.iter().copied()).collect();
let obs_tensor = Tensor::<B, 2>::from_data(
TensorData::new(obs_flat, [batch_size, obs_size]),
device,
);
// ── Forward Q-net → [B, action_size] ─────────────────────────────────
let q_all = q_net.forward(obs_tensor);
// ── Gather Q(s, a) for the taken action → [B] ────────────────────────
let actions: Vec<i32> = batch.iter().map(|s| s.action as i32).collect();
let action_tensor: Tensor<B, 2, Int> = Tensor::<B, 1, Int>::from_data(
TensorData::new(actions, [batch_size]),
device,
)
.reshape([batch_size, 1]); // [B] → [B, 1]
let q_pred: Tensor<B, 1> = q_all.gather(1, action_tensor).reshape([batch_size]); // [B, 1] → [B]
// ── TD targets: r + γ · max_next_q · (1 - done) ─────────────────────
let targets: Vec<f32> = batch
.iter()
.zip(target_max_q.iter())
.map(|(s, &max_q)| {
if s.done { s.reward } else { s.reward + gamma * max_q }
})
.collect();
let target_tensor = Tensor::<B, 1>::from_data(
TensorData::new(targets, [batch_size]),
device,
);
// ── MSE loss ──────────────────────────────────────────────────────────
let diff = q_pred - target_tensor.detach();
let loss = (diff.clone() * diff).mean();
let loss_scalar: f32 = loss.clone().into_scalar().elem();
// ── Backward + optimizer step ─────────────────────────────────────────
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &q_net);
let q_net = optimizer.step(lr, q_net, grads);
(q_net, loss_scalar)
}
// ── Target network update ─────────────────────────────────────────────────────
/// Hard-copy the online Q-net weights to a new target network.
///
/// Strips the autodiff wrapper via [`AutodiffModule::valid`], returning an
/// inference-backend module with identical weights.
pub fn hard_update<B: AutodiffBackend, Q: AutodiffModule<B>>(q_net: &Q) -> Q::InnerModule {
q_net.valid()
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use burn::{
backend::{Autodiff, NdArray},
optim::AdamConfig,
};
use crate::network::{QNet, QNetConfig};
type InferB = NdArray<f32>;
type TrainB = Autodiff<NdArray<f32>>;
fn infer_device() -> <InferB as Backend>::Device { Default::default() }
fn train_device() -> <TrainB as Backend>::Device { Default::default() }
fn dummy_batch(n: usize, obs_size: usize, action_size: usize) -> Vec<DqnSample> {
(0..n)
.map(|i| DqnSample {
obs: vec![0.5f32; obs_size],
action: i % action_size,
reward: if i % 2 == 0 { 1.0 } else { -1.0 },
next_obs: vec![0.5f32; obs_size],
next_legal: vec![0, 1],
done: i == n - 1,
})
.collect()
}
#[test]
fn compute_target_q_length() {
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
let target = QNet::<InferB>::new(&cfg, &infer_device());
let batch = dummy_batch(8, 4, 4);
let tq = compute_target_q(&target, &batch, 4, &infer_device());
assert_eq!(tq.len(), 8);
}
#[test]
fn compute_target_q_done_is_zero() {
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
let target = QNet::<InferB>::new(&cfg, &infer_device());
// Single done sample.
let batch = vec![DqnSample {
obs: vec![0.0; 4],
action: 0,
reward: 5.0,
next_obs: vec![0.0; 4],
next_legal: vec![],
done: true,
}];
let tq = compute_target_q(&target, &batch, 4, &infer_device());
assert_eq!(tq.len(), 1);
assert_eq!(tq[0], 0.0);
}
#[test]
fn train_step_returns_finite_loss() {
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 16 };
let q_net = QNet::<TrainB>::new(&cfg, &train_device());
let target = QNet::<InferB>::new(&cfg, &infer_device());
let mut optimizer = AdamConfig::new().init();
let batch = dummy_batch(8, 4, 4);
let tq = compute_target_q(&target, &batch, 4, &infer_device());
let (_, loss) = dqn_train_step(q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-3, 0.99);
assert!(loss.is_finite(), "loss must be finite, got {loss}");
}
#[test]
fn train_step_loss_decreases() {
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 32 };
let mut q_net = QNet::<TrainB>::new(&cfg, &train_device());
let target = QNet::<InferB>::new(&cfg, &infer_device());
let mut optimizer = AdamConfig::new().init();
let batch = dummy_batch(16, 4, 4);
let tq = compute_target_q(&target, &batch, 4, &infer_device());
let mut prev_loss = f32::INFINITY;
for _ in 0..10 {
let (q, loss) = dqn_train_step(
q_net, &mut optimizer, &batch, &tq, &train_device(), 1e-2, 0.99,
);
q_net = q;
assert!(loss.is_finite());
prev_loss = loss;
}
assert!(prev_loss < 5.0, "loss did not decrease: {prev_loss}");
}
#[test]
fn hard_update_copies_weights() {
let cfg = QNetConfig { obs_size: 4, action_size: 4, hidden_size: 8 };
let q_net = QNet::<TrainB>::new(&cfg, &train_device());
let target = hard_update::<TrainB, _>(&q_net);
let obs = burn::tensor::Tensor::<InferB, 2>::zeros([1, 4], &infer_device());
let q_out: Vec<f32> = target.forward(obs).into_data().to_vec().unwrap();
// After hard_update the target produces finite outputs.
assert!(q_out.iter().all(|v| v.is_finite()));
}
}