refact dqn burn demo

This commit is contained in:
Henri Bourcereau 2025-08-08 17:07:34 +02:00
parent bf820ecc4e
commit 1b58ca4ccc
3 changed files with 83 additions and 82 deletions

View file

@ -1,9 +1,10 @@
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::Module;
use burn::record::{CompactRecorder, Recorder};
use bot::dqn::burnrl::{
dqn_model, environment,
utils::{demo_model, load_model, save_model},
};
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN;
use burn_rl::base::{Action, Agent, ElemType, Environment, State};
use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
@ -25,12 +26,9 @@ fn main() {
println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_50".to_string();
let path = "models/burn_dqn_40".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path);
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model);
@ -38,31 +36,3 @@ fn main() {
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.mpk", path);
println!("Modèle de validation sauvegardé : {}", model_path);
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
let model_path = format!("{}_model.mpk", path);
println!("Chargement du modèle depuis : {}", model_path);
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<environment::TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<environment::TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
}

View file

@ -1,12 +1,45 @@
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::dqn::burnrl::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId};
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State};
use burn_rl::base::{Action, ElemType, Environment, State};
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}_model.mpk");
println!("Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
let model_path = format!("{path}_model.mpk");
println!("Chargement du modèle depuis : {model_path}");
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
}
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true);

View file

@ -114,12 +114,11 @@ impl BotStrategy for DqnStrategy {
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Utiliser le DQN pour choisir le mouvement
if let Some(action) = self.get_dqn_action() {
if let TrictracAction::Move {
if let Some(TrictracAction::Move {
dice_order,
from1,
from2,
} = action
}) = self.get_dqn_action()
{
let dicevals = self.game.dice.values;
let (mut dice1, mut dice2) = if dice_order {
@ -158,7 +157,6 @@ impl BotStrategy for DqnStrategy {
return chosen_move;
}
}
// Fallback : utiliser la stratégie par défaut
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);