46 lines
1.3 KiB
Rust
46 lines
1.3 KiB
Rust
use burn::module::{Param, ParamId};
|
|
use burn::nn::Linear;
|
|
use burn::tensor::backend::Backend;
|
|
use burn::tensor::Tensor;
|
|
use burn_rl::base::{Agent, ElemType, Environment};
|
|
|
|
/// Runs a single episode of environment `E`, letting `agent` choose every action.
///
/// The environment is created with `E::new(true)` (the flag presumably enables
/// rendering/visualization — confirm against the `Environment` impl). Starting
/// from `env.state()`, the agent is asked for an action via `react`, the action
/// is applied with `env.step`, and the returned snapshot's state becomes the
/// next observation. The episode ends when a snapshot reports `done()`.
///
/// If the agent returns `None` (no action available for the current state), the
/// episode is stopped early. Previously this case left the loop condition
/// unchanged forever, hanging the caller.
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
    let mut env = E::new(true);
    let mut state = env.state();

    loop {
        // Without an action the state can never advance, so retrying would spin forever.
        let Some(action) = agent.react(&state) else {
            break;
        };

        let snapshot = env.step(action);
        state = *snapshot.state();

        if snapshot.done() {
            break;
        }
    }
}
|
|
|
|
fn soft_update_tensor<const N: usize, B: Backend>(
|
|
this: &Param<Tensor<B, N>>,
|
|
that: &Param<Tensor<B, N>>,
|
|
tau: ElemType,
|
|
) -> Param<Tensor<B, N>> {
|
|
let that_weight = that.val();
|
|
let this_weight = this.val();
|
|
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
|
|
|
Param::initialized(ParamId::new(), new_weight)
|
|
}
|
|
|
|
pub fn soft_update_linear<B: Backend>(
|
|
this: Linear<B>,
|
|
that: &Linear<B>,
|
|
tau: ElemType,
|
|
) -> Linear<B> {
|
|
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
|
let bias = match (&this.bias, &that.bias) {
|
|
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
|
_ => None,
|
|
};
|
|
|
|
Linear::<B> { weight, bias }
|
|
}
|