refact dqn burn demo
This commit is contained in:
parent
bf820ecc4e
commit
1b58ca4ccc
|
|
@ -1,9 +1,10 @@
|
||||||
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
|
use bot::dqn::burnrl::{
|
||||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
dqn_model, environment,
|
||||||
use burn::module::Module;
|
utils::{demo_model, load_model, save_model},
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
};
|
||||||
|
use burn::backend::{Autodiff, NdArray};
|
||||||
use burn_rl::agent::DQN;
|
use burn_rl::agent::DQN;
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, State};
|
use burn_rl::base::ElemType;
|
||||||
|
|
||||||
type Backend = Autodiff<NdArray<ElemType>>;
|
type Backend = Autodiff<NdArray<ElemType>>;
|
||||||
type Env = environment::TrictracEnvironment;
|
type Env = environment::TrictracEnvironment;
|
||||||
|
|
@ -25,12 +26,9 @@ fn main() {
|
||||||
|
|
||||||
println!("> Sauvegarde du modèle de validation");
|
println!("> Sauvegarde du modèle de validation");
|
||||||
|
|
||||||
let path = "models/burn_dqn_50".to_string();
|
let path = "models/burn_dqn_40".to_string();
|
||||||
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||||
|
|
||||||
// println!("> Test avec le modèle entraîné");
|
|
||||||
// demo_model::<Env>(valid_agent);
|
|
||||||
|
|
||||||
println!("> Chargement du modèle pour test");
|
println!("> Chargement du modèle pour test");
|
||||||
let loaded_model = load_model(conf.dense_size, &path);
|
let loaded_model = load_model(conf.dense_size, &path);
|
||||||
let loaded_agent = DQN::new(loaded_model);
|
let loaded_agent = DQN::new(loaded_model);
|
||||||
|
|
@ -38,31 +36,3 @@ fn main() {
|
||||||
println!("> Test avec le modèle chargé");
|
println!("> Test avec le modèle chargé");
|
||||||
demo_model(loaded_agent);
|
demo_model(loaded_agent);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{}_model.mpk", path);
|
|
||||||
println!("Modèle de validation sauvegardé : {}", model_path);
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
|
||||||
let model_path = format!("{}_model.mpk", path);
|
|
||||||
println!("Chargement du modèle depuis : {}", model_path);
|
|
||||||
|
|
||||||
let device = NdArrayDevice::default();
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
|
|
||||||
let record = recorder
|
|
||||||
.load(model_path.into(), &device)
|
|
||||||
.expect("Impossible de charger le modèle");
|
|
||||||
|
|
||||||
dqn_model::Net::new(
|
|
||||||
<environment::TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<environment::TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,45 @@
|
||||||
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment};
|
use crate::dqn::burnrl::{
|
||||||
|
dqn_model,
|
||||||
|
environment::{TrictracAction, TrictracEnvironment},
|
||||||
|
};
|
||||||
use crate::dqn::dqn_common::get_valid_action_indices;
|
use crate::dqn::dqn_common::get_valid_action_indices;
|
||||||
use burn::module::{Param, ParamId};
|
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||||
|
use burn::module::{Module, Param, ParamId};
|
||||||
use burn::nn::Linear;
|
use burn::nn::Linear;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
use burn::tensor::backend::Backend;
|
use burn::tensor::backend::Backend;
|
||||||
use burn::tensor::cast::ToElement;
|
use burn::tensor::cast::ToElement;
|
||||||
use burn::tensor::Tensor;
|
use burn::tensor::Tensor;
|
||||||
use burn_rl::agent::{DQNModel, DQN};
|
use burn_rl::agent::{DQNModel, DQN};
|
||||||
use burn_rl::base::{ElemType, Environment, State};
|
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||||
|
|
||||||
|
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{path}_model.mpk");
|
||||||
|
println!("Modèle de validation sauvegardé : {model_path}");
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||||
|
let model_path = format!("{path}_model.mpk");
|
||||||
|
println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
|
||||||
|
let record = recorder
|
||||||
|
.load(model_path.into(), &device)
|
||||||
|
.expect("Impossible de charger le modèle");
|
||||||
|
|
||||||
|
dqn_model::Net::new(
|
||||||
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
|
)
|
||||||
|
.load_record(record)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||||
let mut env = TrictracEnvironment::new(true);
|
let mut env = TrictracEnvironment::new(true);
|
||||||
|
|
|
||||||
|
|
@ -114,12 +114,11 @@ impl BotStrategy for DqnStrategy {
|
||||||
|
|
||||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||||
// Utiliser le DQN pour choisir le mouvement
|
// Utiliser le DQN pour choisir le mouvement
|
||||||
if let Some(action) = self.get_dqn_action() {
|
if let Some(TrictracAction::Move {
|
||||||
if let TrictracAction::Move {
|
|
||||||
dice_order,
|
dice_order,
|
||||||
from1,
|
from1,
|
||||||
from2,
|
from2,
|
||||||
} = action
|
}) = self.get_dqn_action()
|
||||||
{
|
{
|
||||||
let dicevals = self.game.dice.values;
|
let dicevals = self.game.dice.values;
|
||||||
let (mut dice1, mut dice2) = if dice_order {
|
let (mut dice1, mut dice2) = if dice_order {
|
||||||
|
|
@ -158,7 +157,6 @@ impl BotStrategy for DqnStrategy {
|
||||||
|
|
||||||
return chosen_move;
|
return chosen_move;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback : utiliser la stratégie par défaut
|
// Fallback : utiliser la stratégie par défaut
|
||||||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue