Refactor the DQN burn demo

This commit is contained in:
Henri Bourcereau 2025-08-08 17:07:34 +02:00
parent bf820ecc4e
commit 1b58ca4ccc
3 changed files with 83 additions and 82 deletions

View file

@ -1,9 +1,10 @@
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model}; use bot::dqn::burnrl::{
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; dqn_model, environment,
use burn::module::Module; utils::{demo_model, load_model, save_model},
use burn::record::{CompactRecorder, Recorder}; };
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN; use burn_rl::agent::DQN;
use burn_rl::base::{Action, Agent, ElemType, Environment, State}; use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>; type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment; type Env = environment::TrictracEnvironment;
@ -25,12 +26,9 @@ fn main() {
println!("> Sauvegarde du modèle de validation"); println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_50".to_string(); let path = "models/burn_dqn_40".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path); save_model(valid_agent.model().as_ref().unwrap(), &path);
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test"); println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path); let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model); let loaded_agent = DQN::new(loaded_model);
@ -38,31 +36,3 @@ fn main() {
println!("> Test avec le modèle chargé"); println!("> Test avec le modèle chargé");
demo_model(loaded_agent); demo_model(loaded_agent);
} }
/// Writes `model` to disk through a [`CompactRecorder`].
///
/// `path` is a file-name prefix: the target handed to the recorder is
/// `"{path}_model.mpk"`. The confirmation message is printed only once the
/// recorder has reported success, so it is never shown for a failed save.
///
/// # Panics
///
/// Panics if the recorder fails to write the record.
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &str) {
    let model_path = format!("{}_model.mpk", path);
    // Only a borrow of the model is available, so it is cloned before being
    // turned into a record.
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .expect("Impossible de sauvegarder le modèle");
    println!("Modèle de validation sauvegardé : {}", model_path);
}
/// Rebuilds a [`dqn_model::Net`] from a record previously written to disk.
///
/// `path` is the same file-name prefix that was used when saving: the record
/// is read from `"{path}_model.mpk"` on the default `NdArray` device.
/// `dense_size` is forwarded to `Net::new` together with the state and action
/// sizes reported by [`environment::TrictracEnvironment`]; it presumably has
/// to match the value used when the model was saved (confirm against the
/// training configuration).
///
/// # Panics
///
/// Panics if the record cannot be loaded from `"{path}_model.mpk"`.
fn load_model(dense_size: usize, path: &str) -> dqn_model::Net<NdArray<ElemType>> {
    let model_path = format!("{}_model.mpk", path);
    println!("Chargement du modèle depuis : {}", model_path);

    let device = NdArrayDevice::default();
    let record = CompactRecorder::new()
        .load(model_path.into(), &device)
        .expect("Impossible de charger le modèle");

    // Build a fresh network with the environment's dimensions, then overwrite
    // its parameters with the loaded record.
    dqn_model::Net::new(
        <environment::TrictracEnvironment as Environment>::StateType::size(),
        dense_size,
        <environment::TrictracEnvironment as Environment>::ActionType::size(),
    )
    .load_record(record)
}

View file

@ -1,12 +1,45 @@
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment}; use crate::dqn::burnrl::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices; use crate::dqn::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId}; use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear; use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend; use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement; use burn::tensor::cast::ToElement;
use burn::tensor::Tensor; use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN}; use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State}; use burn_rl::base::{Action, ElemType, Environment, State};
/// Writes `model` to disk through a [`CompactRecorder`].
///
/// `path` is a file-name prefix: the target handed to the recorder is
/// `"{path}_model.mpk"`. The confirmation message is printed only once the
/// recorder has reported success, so it is never shown for a failed save.
///
/// # Panics
///
/// Panics if the recorder fails to write the record.
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &str) {
    let model_path = format!("{path}_model.mpk");
    // Only a borrow of the model is available, so it is cloned before being
    // turned into a record.
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .expect("Impossible de sauvegarder le modèle");
    println!("Modèle de validation sauvegardé : {model_path}");
}
/// Rebuilds a [`dqn_model::Net`] from a record previously written to disk.
///
/// `path` is the same file-name prefix that was passed to [`save_model`]: the
/// record is read from `"{path}_model.mpk"` on the default `NdArray` device.
/// `dense_size` is forwarded to `Net::new` together with the state and action
/// sizes reported by [`TrictracEnvironment`]; it presumably has to match the
/// value used when the model was saved (confirm against the training
/// configuration).
///
/// # Panics
///
/// Panics if the record cannot be loaded from `"{path}_model.mpk"`.
pub fn load_model(dense_size: usize, path: &str) -> dqn_model::Net<NdArray<ElemType>> {
    let model_path = format!("{path}_model.mpk");
    println!("Chargement du modèle depuis : {model_path}");

    let device = NdArrayDevice::default();
    let record = CompactRecorder::new()
        .load(model_path.into(), &device)
        .expect("Impossible de charger le modèle");

    // Build a fresh network with the environment's dimensions, then overwrite
    // its parameters with the loaded record.
    dqn_model::Net::new(
        <TrictracEnvironment as Environment>::StateType::size(),
        dense_size,
        <TrictracEnvironment as Environment>::ActionType::size(),
    )
    .load_record(record)
}
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) { pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true); let mut env = TrictracEnvironment::new(true);

View file

@ -114,50 +114,48 @@ impl BotStrategy for DqnStrategy {
fn choose_move(&self) -> (CheckerMove, CheckerMove) { fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser le DQN pour choisir le mouvement // Utiliser le DQN pour choisir le mouvement
if let Some(action) = self.get_dqn_action() { if let Some(TrictracAction::Move {
if let TrictracAction::Move { dice_order,
dice_order, from1,
from1, from2,
from2, }) = self.get_dqn_action()
} = action {
{ let dicevals = self.game.dice.values;
let dicevals = self.game.dice.values; let (mut dice1, mut dice2) = if dice_order {
let (mut dice1, mut dice2) = if dice_order { (dicevals.0, dicevals.1)
(dicevals.0, dicevals.1) } else {
} else { (dicevals.1, dicevals.0)
(dicevals.1, dicevals.0) };
};
if from1 == 0 { if from1 == 0 {
// empty move // empty move
dice1 = 0; dice1 = 0;
}
let mut to1 = from1 + dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
}
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
(checker_move1.mirror(), checker_move2.mirror())
};
return chosen_move;
} }
let mut to1 = from1 + dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
}
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
(checker_move1.mirror(), checker_move2.mirror())
};
return chosen_move;
} }
// Fallback : utiliser la stratégie par défaut // Fallback : utiliser la stratégie par défaut