2025-08-20 13:09:57 +02:00
|
|
|
use burn::module::{Param, ParamId};
|
|
|
|
|
use burn::nn::Linear;
|
|
|
|
|
use burn::tensor::backend::Backend;
|
|
|
|
|
use burn::tensor::Tensor;
|
|
|
|
|
use burn_rl::base::{Agent, ElemType, Environment};
|
2025-08-22 09:24:01 +02:00
|
|
|
use serde::{Deserialize, Serialize};
|
2025-08-20 13:09:57 +02:00
|
|
|
|
2025-08-22 09:24:01 +02:00
|
|
|
#[derive(Serialize, Deserialize)]
|
2025-08-20 13:09:57 +02:00
|
|
|
pub struct Config {
|
|
|
|
|
pub save_path: Option<String>,
|
2025-08-22 09:24:01 +02:00
|
|
|
pub max_steps: usize, // max steps by episode
|
2025-08-20 13:09:57 +02:00
|
|
|
pub num_episodes: usize,
|
2025-08-22 09:24:01 +02:00
|
|
|
pub dense_size: usize, // neural network complexity
|
2025-08-20 13:09:57 +02:00
|
|
|
|
2025-08-22 09:24:01 +02:00
|
|
|
// discount factor. Plus élevé = encourage stratégies à long terme
|
2025-08-20 13:09:57 +02:00
|
|
|
pub gamma: f32,
|
2025-08-22 09:24:01 +02:00
|
|
|
// soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation plus lente moins sensible aux coups de chance
|
2025-08-20 13:09:57 +02:00
|
|
|
pub tau: f32,
|
2025-08-22 09:24:01 +02:00
|
|
|
// taille du pas. Bas : plus lent, haut : risque de ne jamais
|
2025-08-20 13:09:57 +02:00
|
|
|
pub learning_rate: f32,
|
2025-08-22 09:24:01 +02:00
|
|
|
// nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
2025-08-20 13:09:57 +02:00
|
|
|
pub batch_size: usize,
|
2025-08-22 09:24:01 +02:00
|
|
|
// limite max de correction à apporter au gradient (default 100)
|
2025-08-20 13:09:57 +02:00
|
|
|
pub clip_grad: f32,
|
|
|
|
|
|
2025-08-22 09:24:01 +02:00
|
|
|
// ---- for SAC
|
2025-08-20 13:09:57 +02:00
|
|
|
pub min_probability: f32,
|
|
|
|
|
|
2025-08-22 09:24:01 +02:00
|
|
|
// ---- for DQN
|
|
|
|
|
// epsilon initial value (0.9 => more exploration)
|
2025-08-20 13:09:57 +02:00
|
|
|
pub eps_start: f64,
|
|
|
|
|
pub eps_end: f64,
|
2025-08-22 09:24:01 +02:00
|
|
|
// eps_decay higher = epsilon decrease slower
|
|
|
|
|
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
|
|
|
|
// epsilon is updated at the start of each episode
|
2025-08-20 13:09:57 +02:00
|
|
|
pub eps_decay: f64,
|
|
|
|
|
|
2025-08-22 09:24:01 +02:00
|
|
|
// ---- for PPO
|
2025-08-20 13:09:57 +02:00
|
|
|
pub lambda: f32,
|
|
|
|
|
pub epsilon_clip: f32,
|
|
|
|
|
pub critic_weight: f32,
|
|
|
|
|
pub entropy_weight: f32,
|
|
|
|
|
pub epochs: usize,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Default for Config {
|
|
|
|
|
fn default() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
save_path: None,
|
|
|
|
|
max_steps: 2000,
|
|
|
|
|
num_episodes: 1000,
|
|
|
|
|
dense_size: 256,
|
|
|
|
|
gamma: 0.999,
|
|
|
|
|
tau: 0.005,
|
|
|
|
|
learning_rate: 0.001,
|
|
|
|
|
batch_size: 32,
|
|
|
|
|
clip_grad: 100.0,
|
|
|
|
|
min_probability: 1e-9,
|
|
|
|
|
eps_start: 0.9,
|
|
|
|
|
eps_end: 0.05,
|
|
|
|
|
eps_decay: 1000.0,
|
|
|
|
|
lambda: 0.95,
|
|
|
|
|
epsilon_clip: 0.2,
|
|
|
|
|
critic_weight: 0.5,
|
|
|
|
|
entropy_weight: 0.01,
|
|
|
|
|
epochs: 8,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl std::fmt::Display for Config {
|
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
|
|
|
let mut s = String::new();
|
|
|
|
|
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
|
|
|
|
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
|
|
|
|
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
|
|
|
|
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
|
|
|
|
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
|
|
|
|
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
|
|
|
|
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
|
|
|
|
s.push_str(&format!("tau={:?}\n", self.tau));
|
|
|
|
|
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
|
|
|
|
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
|
|
|
|
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
|
|
|
|
s.push_str(&format!("min_probability={:?}\n", self.min_probability));
|
|
|
|
|
s.push_str(&format!("lambda={:?}\n", self.lambda));
|
|
|
|
|
s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip));
|
|
|
|
|
s.push_str(&format!("critic_weight={:?}\n", self.critic_weight));
|
|
|
|
|
s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight));
|
|
|
|
|
s.push_str(&format!("epochs={:?}\n", self.epochs));
|
|
|
|
|
write!(f, "{s}")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Plays a single episode of environment `E`, letting `agent` choose every
/// action.
///
/// The environment is created with `E::new(true)` (the flag presumably enables
/// visualization — confirm against the `Environment` impl). The episode ends
/// when a step reports `done`, or as soon as the agent yields no action
/// (`react` returns `None`).
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
    let mut env = E::new(true);
    let mut state = env.state();
    loop {
        // Without this exit the loop would spin forever: `done` can only change
        // after a step, and no step happens when the agent has no action.
        let Some(action) = agent.react(&state) else {
            break;
        };
        let snapshot = env.step(action);
        state = *snapshot.state();
        if snapshot.done() {
            break;
        }
    }
}
|
|
|
|
|
|
|
|
|
|
fn soft_update_tensor<const N: usize, B: Backend>(
|
|
|
|
|
this: &Param<Tensor<B, N>>,
|
|
|
|
|
that: &Param<Tensor<B, N>>,
|
|
|
|
|
tau: ElemType,
|
|
|
|
|
) -> Param<Tensor<B, N>> {
|
|
|
|
|
let that_weight = that.val();
|
|
|
|
|
let this_weight = this.val();
|
|
|
|
|
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
|
|
|
|
|
|
|
|
|
Param::initialized(ParamId::new(), new_weight)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn soft_update_linear<B: Backend>(
|
|
|
|
|
this: Linear<B>,
|
|
|
|
|
that: &Linear<B>,
|
|
|
|
|
tau: ElemType,
|
|
|
|
|
) -> Linear<B> {
|
|
|
|
|
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
|
|
|
|
let bias = match (&this.bias, &that.bias) {
|
|
|
|
|
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
|
|
|
|
_ => None,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Linear::<B> { weight, bias }
|
|
|
|
|
}
|