use burn::module::{Param, ParamId}; use burn::nn::Linear; use burn::tensor::backend::Backend; use burn::tensor::Tensor; use burn_rl::base::{Agent, ElemType, Environment}; pub struct Config { pub save_path: Option, pub max_steps: usize, pub num_episodes: usize, pub dense_size: usize, pub gamma: f32, pub tau: f32, pub learning_rate: f32, pub batch_size: usize, pub clip_grad: f32, // for SAC pub min_probability: f32, // for DQN pub eps_start: f64, pub eps_end: f64, pub eps_decay: f64, // for PPO pub lambda: f32, pub epsilon_clip: f32, pub critic_weight: f32, pub entropy_weight: f32, pub epochs: usize, } impl Default for Config { fn default() -> Self { Self { save_path: None, max_steps: 2000, num_episodes: 1000, dense_size: 256, gamma: 0.999, tau: 0.005, learning_rate: 0.001, batch_size: 32, clip_grad: 100.0, min_probability: 1e-9, eps_start: 0.9, eps_end: 0.05, eps_decay: 1000.0, lambda: 0.95, epsilon_clip: 0.2, critic_weight: 0.5, entropy_weight: 0.01, epochs: 8, } } } impl std::fmt::Display for Config { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let mut s = String::new(); s.push_str(&format!("max_steps={:?}\n", self.max_steps)); s.push_str(&format!("num_episodes={:?}\n", self.num_episodes)); s.push_str(&format!("dense_size={:?}\n", self.dense_size)); s.push_str(&format!("eps_start={:?}\n", self.eps_start)); s.push_str(&format!("eps_end={:?}\n", self.eps_end)); s.push_str(&format!("eps_decay={:?}\n", self.eps_decay)); s.push_str(&format!("gamma={:?}\n", self.gamma)); s.push_str(&format!("tau={:?}\n", self.tau)); s.push_str(&format!("learning_rate={:?}\n", self.learning_rate)); s.push_str(&format!("batch_size={:?}\n", self.batch_size)); s.push_str(&format!("clip_grad={:?}\n", self.clip_grad)); s.push_str(&format!("min_probability={:?}\n", self.min_probability)); s.push_str(&format!("lambda={:?}\n", self.lambda)); s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip)); s.push_str(&format!("critic_weight={:?}\n", self.critic_weight)); s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight)); s.push_str(&format!("epochs={:?}\n", self.epochs)); write!(f, "{s}") } } pub fn demo_model(agent: impl Agent) { let mut env = E::new(true); let mut state = env.state(); let mut done = false; while !done { if let Some(action) = agent.react(&state) { let snapshot = env.step(action); state = *snapshot.state(); done = snapshot.done(); } } } fn soft_update_tensor( this: &Param>, that: &Param>, tau: ElemType, ) -> Param> { let that_weight = that.val(); let this_weight = this.val(); let new_weight = this_weight * (1.0 - tau) + that_weight * tau; Param::initialized(ParamId::new(), new_weight) } pub fn soft_update_linear( this: Linear, that: &Linear, tau: ElemType, ) -> Linear { let weight = soft_update_tensor(&this.weight, &that.weight, tau); let bias = match (&this.bias, &that.bias) { (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), _ => None, }; Linear:: { weight, bias } }