use burn::module::{Param, ParamId}; use burn::nn::Linear; use burn::tensor::backend::Backend; use burn::tensor::Tensor; use burn_rl::base::{Agent, ElemType, Environment}; pub fn demo_model(agent: impl Agent) { let mut env = E::new(true); let mut state = env.state(); let mut done = false; while !done { if let Some(action) = agent.react(&state) { let snapshot = env.step(action); state = *snapshot.state(); // println!("{:?}", state); done = snapshot.done(); } } } fn soft_update_tensor( this: &Param>, that: &Param>, tau: ElemType, ) -> Param> { let that_weight = that.val(); let this_weight = this.val(); let new_weight = this_weight * (1.0 - tau) + that_weight * tau; Param::initialized(ParamId::new(), new_weight) } pub fn soft_update_linear( this: Linear, that: &Linear, tau: ElemType, ) -> Linear { let weight = soft_update_tensor(&this.weight, &that.weight, tau); let bias = match (&this.bias, &that.bias) { (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)), _ => None, }; Linear:: { weight, bias } }