refact dqn burn demo

This commit is contained in:
Henri Bourcereau 2025-08-08 17:07:34 +02:00
parent bf820ecc4e
commit 1b58ca4ccc
3 changed files with 83 additions and 82 deletions

View file

@ -1,9 +1,10 @@
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; use bot::dqn::burnrl::{
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; dqn_model, environment,
use burn::module::Module; utils::{demo_model, load_model, save_model},
use burn::record::{CompactRecorder, Recorder}; };
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN; use burn_rl::agent::DQN;
use burn_rl::base::{Action, Agent, ElemType, Environment, State}; use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>; type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment; type Env = environment::TrictracEnvironment;
@ -25,12 +26,9 @@ fn main() {
println!("> Sauvegarde du modèle de validation"); println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_50".to_string(); let path = "models/burn_dqn_40".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path); save_model(valid_agent.model().as_ref().unwrap(), &path);
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test"); println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path); let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model); let loaded_agent = DQN::new(loaded_model);
@ -38,31 +36,3 @@ fn main() {
println!("> Test avec le modèle chargé"); println!("> Test avec le modèle chargé");
demo_model(loaded_agent); demo_model(loaded_agent);
} }
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path);
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path);
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<environment::TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<environment::TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
}

View file

@ -1,12 +1,45 @@
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; use crate::dqn::burnrl::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices; use crate::dqn::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId}; use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear; use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend; use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement; use burn::tensor::cast::ToElement;
use burn::tensor::Tensor; use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN}; use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State}; use burn_rl::base::{Action, ElemType, Environment, State};
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}_model.mpk");
println!("Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
let model_path = format!("{path}_model.mpk");
println!("Chargement du modèle depuis : {model_path}");
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
}
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) { pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true); let mut env = TrictracEnvironment::new(true);

View file

@ -114,12 +114,11 @@ impl BotStrategy for DqnStrategy {
fn choose_move(&self) -> (CheckerMove, CheckerMove) { fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser le DQN pour choisir le mouvement // Utiliser le DQN pour choisir le mouvement
if let Some(action) = self.get_dqn_action() { if let Some(TrictracAction::Move {
if let TrictracAction::Move {
dice_order, dice_order,
from1, from1,
from2, from2,
} = action }) = self.get_dqn_action()
{ {
let dicevals = self.game.dice.values; let dicevals = self.game.dice.values;
let (mut dice1, mut dice2) = if dice_order { let (mut dice1, mut dice2) = if dice_order {
@ -158,7 +157,6 @@ impl BotStrategy for DqnStrategy {
return chosen_move; return chosen_move;
} }
}
// Fallback : utiliser la stratégie par défaut // Fallback : utiliser la stratégie par défaut
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice); let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);