refactor DQN burn demo
This commit is contained in:
parent
bf820ecc4e
commit
1b58ca4ccc
|
|
@ -1,9 +1,10 @@
|
|||
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
|
||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use bot::dqn::burnrl::{
|
||||
dqn_model, environment,
|
||||
utils::{demo_model, load_model, save_model},
|
||||
};
|
||||
use burn::backend::{Autodiff, NdArray};
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, State};
|
||||
use burn_rl::base::ElemType;
|
||||
|
||||
type Backend = Autodiff<NdArray<ElemType>>;
|
||||
type Env = environment::TrictracEnvironment;
|
||||
|
|
@ -25,12 +26,9 @@ fn main() {
|
|||
|
||||
println!("> Sauvegarde du modèle de validation");
|
||||
|
||||
let path = "models/burn_dqn_50".to_string();
|
||||
let path = "models/burn_dqn_40".to_string();
|
||||
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||
|
||||
// println!("> Test avec le modèle entraîné");
|
||||
// demo_model::<Env>(valid_agent);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = load_model(conf.dense_size, &path);
|
||||
let loaded_agent = DQN::new(loaded_model);
|
||||
|
|
@ -38,31 +36,3 @@ fn main() {
|
|||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
|
||||
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Modèle de validation sauvegardé : {}", model_path);
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Chargement du modèle depuis : {}", model_path);
|
||||
|
||||
let device = NdArrayDevice::default();
|
||||
let recorder = CompactRecorder::new();
|
||||
|
||||
let record = recorder
|
||||
.load(model_path.into(), &device)
|
||||
.expect("Impossible de charger le modèle");
|
||||
|
||||
dqn_model::Net::new(
|
||||
<environment::TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<environment::TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,45 @@
|
|||
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment};
|
||||
use crate::dqn::burnrl::{
|
||||
dqn_model,
|
||||
environment::{TrictracAction, TrictracEnvironment},
|
||||
};
|
||||
use crate::dqn::dqn_common::get_valid_action_indices;
|
||||
use burn::module::{Param, ParamId};
|
||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||
use burn::module::{Module, Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{DQNModel, DQN};
|
||||
use burn_rl::base::{ElemType, Environment, State};
|
||||
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||
|
||||
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
println!("Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
let device = NdArrayDevice::default();
|
||||
let recorder = CompactRecorder::new();
|
||||
|
||||
let record = recorder
|
||||
.load(model_path.into(), &device)
|
||||
.expect("Impossible de charger le modèle");
|
||||
|
||||
dqn_model::Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
}
|
||||
|
||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||
let mut env = TrictracEnvironment::new(true);
|
||||
|
|
|
|||
|
|
@ -114,12 +114,11 @@ impl BotStrategy for DqnStrategy {
|
|||
|
||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||
// Utiliser le DQN pour choisir le mouvement
|
||||
if let Some(action) = self.get_dqn_action() {
|
||||
if let TrictracAction::Move {
|
||||
if let Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} = action
|
||||
}) = self.get_dqn_action()
|
||||
{
|
||||
let dicevals = self.game.dice.values;
|
||||
let (mut dice1, mut dice2) = if dice_order {
|
||||
|
|
@ -158,7 +157,6 @@ impl BotStrategy for DqnStrategy {
|
|||
|
||||
return chosen_move;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback : utiliser la stratégie par défaut
|
||||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||
|
|
|
|||
Loading…
Reference in a new issue