wip
This commit is contained in:
parent
354dcfd341
commit
c6d33555ec
|
|
@ -2,6 +2,7 @@ use crate::burnrl::utils::soft_update_linear;
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::nn::{Linear, LinearConfig};
|
use burn::nn::{Linear, LinearConfig};
|
||||||
use burn::optim::AdamWConfig;
|
use burn::optim::AdamWConfig;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
use burn::tensor::activation::relu;
|
use burn::tensor::activation::relu;
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||||
use burn::tensor::Tensor;
|
use burn::tensor::Tensor;
|
||||||
|
|
@ -138,5 +139,15 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save
|
||||||
|
let path = "models/burn_dqn".to_string();
|
||||||
|
let inference_network = agent.model().clone().into_record();
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{}_model.burn", path);
|
||||||
|
println!("Modèle sauvegardé : {}", model_path);
|
||||||
|
recorder
|
||||||
|
.record(inference_network, model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
agent.valid()
|
agent.valid()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,13 @@
|
||||||
|
use bot::burnrl::{dqn_model, environment, utils::demo_model};
|
||||||
use burn::backend::{Autodiff, NdArray};
|
use burn::backend::{Autodiff, NdArray};
|
||||||
use burn_rl::base::ElemType;
|
use burn_rl::base::ElemType;
|
||||||
use bot::burnrl::{
|
|
||||||
dqn_model,
|
|
||||||
environment,
|
|
||||||
utils::demo_model,
|
|
||||||
};
|
|
||||||
|
|
||||||
type Backend = Autodiff<NdArray<ElemType>>;
|
type Backend = Autodiff<NdArray<ElemType>>;
|
||||||
type Env = environment::TrictracEnvironment;
|
type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let agent = dqn_model::run::<Env, Backend>(512, false); //true);
|
let num_episodes = 3;
|
||||||
|
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
||||||
|
|
||||||
demo_model::<Env>(agent);
|
demo_model::<Env>(agent);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
|
||||||
if let Some(action) = agent.react(&state) {
|
if let Some(action) = agent.react(&state) {
|
||||||
let snapshot = env.step(action);
|
let snapshot = env.step(action);
|
||||||
state = *snapshot.state();
|
state = *snapshot.state();
|
||||||
|
// println!("{:?}", state);
|
||||||
done = snapshot.done();
|
done = snapshot.done();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,17 @@
|
||||||
# Backlog
|
# Backlog
|
||||||
|
|
||||||
position dans tutoriel :
|
|
||||||
|
|
||||||
## DONE
|
## DONE
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
|
|
||||||
|
- bot burn
|
||||||
|
- train = `just trainbot`
|
||||||
|
- durée d'entrainement selon params ?
|
||||||
|
- save
|
||||||
|
- load and run against default bot
|
||||||
|
- many configs, save models selon config
|
||||||
|
- retrain against itself ?
|
||||||
|
|
||||||
### Doc
|
### Doc
|
||||||
|
|
||||||
Cheatsheet : arbre des situations et priorité des règles
|
Cheatsheet : arbre des situations et priorité des règles
|
||||||
|
|
|
||||||
3
justfile
3
justfile
|
|
@ -21,4 +21,5 @@ trainbot:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
# cargo run --bin=train_burn_rl # doesn't save model
|
# cargo run --bin=train_burn_rl # doesn't save model
|
||||||
cargo run --bin=train_dqn_full
|
# cargo run --bin=train_dqn_full
|
||||||
|
cargo run --bin=train_dqn_burn
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue