Compare commits

...

4 commits

Author SHA1 Message Date
Henri Bourcereau 66377f877c wip action mask 2025-07-23 22:28:59 +02:00
Henri Bourcereau 1e18b784d1 load inference model 2025-07-23 21:52:32 +02:00
Henri Bourcereau f3fc053dbd save inference model 2025-07-23 21:28:29 +02:00
Henri Bourcereau 6fa8a31cc7 refact : save model 2025-07-23 21:16:28 +02:00
7 changed files with 128 additions and 43 deletions

View file

@@ -1,16 +1,15 @@
use crate::burnrl::utils::soft_update_linear;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use crate::burnrl::utils::soft_update_linear;
#[derive(Module, Debug)]
#[derive(Module, Debug, Clone)]
pub struct Net<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
@@ -19,11 +18,11 @@ pub struct Net<B: Backend> {
impl<B: Backend> Net<B> {
#[allow(unused)]
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
linear_0: LinearConfig::new(input_size, dense_size).init(device),
linear_1: LinearConfig::new(dense_size, dense_size).init(device),
linear_2: LinearConfig::new(dense_size, output_size).init(device),
}
}
@@ -34,7 +33,7 @@ impl<B: Backend> Net<B> {
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_0_output = relu(self.linear_0.forward(input.clone()));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
relu(self.linear_2.forward(layer_1_output))
@@ -46,8 +45,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
}
impl<B: Backend> DQNModel<B> for Net<B> {
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = this.consume();
fn soft_update(self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = self.consume();
Self {
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
@@ -72,11 +71,13 @@ pub fn run<E: Environment, B: AutodiffBackend>(
visualized: bool,
) -> impl Agent<E> {
let mut env = E::new(visualized);
let device = Default::default();
let model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
<E::StateType as State>::size(),
DENSE_SIZE,
<<E as Environment>::ActionType as Action>::size(),
<E::ActionType as Action>::size(),
&device,
);
let mut agent = MyAgent::new(model);
@@ -107,7 +108,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
let snapshot = env.step(action);
episode_reward +=
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
<E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
memory.push(
state,
@@ -118,8 +119,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
);
if config.batch_size < memory.len() {
policy_net =
agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
policy_net = agent.train(policy_net, &memory, &mut optimizer, &config);
}
step += 1;
@@ -138,16 +138,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
}
}
}
// Save
let path = "models/burn_dqn".to_string();
let inference_network = agent.model().clone().into_record();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
println!("Modèle sauvegardé : {}", model_path);
recorder
.record(inference_network, model_path.into())
.unwrap();
agent.valid()
}
}
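
Not part of the diff: a minimal sketch of the new constructor call, assuming `Net` is reachable as `bot::burnrl::dqn_model::Net` (the path used by the main.rs changes below); the layer sizes are placeholders. The device is now passed explicitly instead of each `LinearConfig::init` falling back to `Default::default()`.

    use bot::burnrl::dqn_model::Net;
    use burn::backend::{ndarray::NdArrayDevice, NdArray};
    use burn_rl::base::ElemType;

    fn build_net() -> Net<NdArray<ElemType>> {
        // Explicit device, matching the new `new(..., device: &B::Device)` signature.
        let device = NdArrayDevice::default();
        // 32 inputs, 128 hidden units, 16 outputs: placeholder sizes, not from the diff.
        Net::new(32, 128, 16, &device)
    }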

View file

@@ -199,6 +199,15 @@ impl Environment for TrictracEnvironment {
}
impl TrictracEnvironment {
pub fn valid_actions(&self) -> Vec<TrictracAction> {
dqn_common::get_valid_actions(&self.game)
.into_iter()
.map(|a| TrictracAction {
index: a.to_action_index() as u32,
})
.collect()
}
/// Converts a burn-rl action into a Trictrac action
fn convert_action(
&self,
@@ -380,4 +389,4 @@ impl TrictracEnvironment {
}
reward
}
}
}
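
Not part of the diff: a small usage sketch of the new `valid_actions()` helper, under the same trait bounds the diff already relies on (`Action::size()`, `Action::from(u32)`, `PartialEq` on the action type). Visibility of the environment type and method outside the crate is an assumption.

    use bot::burnrl::environment::TrictracEnvironment;
    use burn_rl::base::{Action, Environment};

    // Build a boolean mask over the whole action space; `true` marks an index the
    // current game state allows, mirroring how demo_model filters Q-values below.
    fn action_mask(env: &TrictracEnvironment) -> Vec<bool> {
        type A = <TrictracEnvironment as Environment>::ActionType;
        let valid = env.valid_actions();
        (0..A::size())
            .map(|i| valid.contains(&A::from(i as u32)))
            .collect()
    }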

View file

@@ -1,13 +1,61 @@
use bot::burnrl::{dqn_model, environment, utils::demo_model};
use burn::backend::{Autodiff, NdArray};
use burn_rl::base::ElemType;
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::Module;
use burn::record::{CompactRecorder, Recorder};
use burn_rl::agent::DQN;
use burn_rl::base::{Action, Agent, ElemType, Environment, State};
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
fn main() {
println!("> Entraînement");
let num_episodes = 3;
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
demo_model::<Env>(agent);
let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
save_model(valid_agent.model().as_ref().unwrap());
println!("> Chargement du modèle pour test");
let loaded_model = load_model();
let loaded_agent = DQN::new(loaded_model);
println!("> Test avec le modèle chargé");
demo_model(loaded_agent, |env| env.valid_actions());
}
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
let path = "models/burn_dqn".to_string();
let recorder = CompactRecorder::new();
let model_path = format!("{}_model.burn", path);
println!("Modèle de validation sauvegardé : {}", model_path);
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
// TODO: reuse DENSE_SIZE from dqn_model.rs
const DENSE_SIZE: usize = 128;
let path = "models/burn_dqn".to_string();
let model_path = format!("{}_model.burn", path);
println!("Chargement du modèle depuis : {}", model_path);
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let record = recorder
.load(model_path.into(), &device)
.expect("Impossible de charger le modèle");
dqn_model::Net::new(
<environment::TrictracEnvironment as Environment>::StateType::size(),
DENSE_SIZE,
<environment::TrictracEnvironment as Environment>::ActionType::size(),
&device,
)
.load_record(record)
}
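
Not part of the diff: one way to resolve the TODO above, assuming the project is willing to make the constant in dqn_model.rs public (the diff only shows it being used there; the value follows the local copy in load_model()).

    // In dqn_model.rs (hypothetical change): expose the hidden-layer width.
    pub const DENSE_SIZE: usize = 128;

    // In main.rs, load_model() could then drop its local constant:
    // use bot::burnrl::dqn_model::DENSE_SIZE;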

View file

@@ -1,21 +1,60 @@
use burn::module::{Param, ParamId};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn_rl::base::{Agent, ElemType, Environment};
use burn_rl::base::{Action, ElemType, Environment, State};
use burn_rl::agent::DQN;
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
pub fn demo_model<E, M, B, F>(
agent: DQN<E, B, M>,
mut get_valid_actions: F,
) where
E: Environment,
M: Module<B> + burn_rl::agent::DQNModel<B>,
B: Backend,
F: FnMut(&E) -> Vec<E::ActionType>,
<E as Environment>::ActionType: PartialEq,
{
let mut env = E::new(true);
let mut state = env.state();
let mut done = false;
let mut total_reward = 0.0;
let mut steps = 0;
while !done {
if let Some(action) = agent.react(&state) {
let snapshot = env.step(action);
state = *snapshot.state();
// println!("{:?}", state);
done = snapshot.done();
let model = agent.model().as_ref().unwrap();
let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
let q_values = model.infer(state_tensor);
let valid_actions = get_valid_actions(&env);
if valid_actions.is_empty() {
break; // No valid actions, end of episode
}
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions.contains(&E::ActionType::from(index as u32)) {
masked_q_values =
masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
}
}
let action_index = masked_q_values.argmax(1).into_scalar() as u32;
let action = E::ActionType::from(action_index);
let snapshot = env.step(action);
state = *snapshot.state();
total_reward +=
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
steps += 1;
done = snapshot.done() || steps >= E::MAX_STEPS;
}
println!(
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
total_reward, steps
);
}
fn soft_update_tensor<const N: usize, B: Backend>(
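
Not part of the diff: the loop above masks by Q-value equality (`equal_elem(*q_value)`), which would also hide a valid action that happens to share a Q-value with an invalid one. A sketch of an index-based alternative, under the same bounds demo_model already uses:

    use burn_rl::base::{Action, Environment};

    // Pick the argmax over valid action indices directly from the flattened
    // Q-values, instead of masking the tensor by value equality.
    fn best_valid_action<E>(q_values: &[f32], valid_actions: &[E::ActionType]) -> Option<E::ActionType>
    where
        E: Environment,
        E::ActionType: PartialEq,
    {
        q_values
            .iter()
            .enumerate()
            .filter(|(i, _)| valid_actions.contains(&E::ActionType::from(*i as u32)))
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| E::ActionType::from(i as u32))
    }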

View file

@@ -1,4 +1,4 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use store::MoveRules;
#[derive(Debug)]

View file

@@ -1,4 +1,4 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use std::path::Path;
use store::MoveRules;

View file

@@ -1,7 +1,7 @@
use std::cmp::{max, min};
use serde::{Deserialize, Serialize};
use store::{CheckerMove, Dice, GameEvent, PlayerId};
use store::{CheckerMove, Dice};
/// Possible action types in the game
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -259,7 +259,7 @@ impl SimpleNeuralNetwork {
/// Gets the valid actions for the current game state
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
use crate::PointsRules;
use store::TurnStage;
let mut valid_actions = Vec::new();