fix by gemini

This commit is contained in:
Henri Bourcereau 2025-06-28 22:18:39 +02:00
parent f05094b2d4
commit 6a7b1cbebc
5 changed files with 468 additions and 38 deletions

View file

@ -1,6 +1,7 @@
use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
use bot::strategy::dqn_common::get_valid_actions;
use burn::optim::AdamConfig;
use burn_rl::base::Environment;
use std::env;
@ -116,7 +117,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
loop {
step += 1;
let current_state = snapshot.state();
let current_state_data = snapshot.state().data;
// Obtenir les actions valides selon le contexte du jeu
let valid_actions = get_valid_actions(&env.game);
@ -130,11 +131,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Sélectionner une action avec l'agent DQN
let action_index = agent.select_action(
&current_state
.data
.iter()
.map(|&x| x as f32)
.collect::<Vec<_>>(),
&current_state_data,
&valid_indices,
);
let action = TrictracAction {
@ -143,32 +140,32 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Exécuter l'action
snapshot = env.step(action);
episode_reward += snapshot.reward();
episode_reward += *snapshot.reward();
// Préparer l'expérience pour l'agent
let experience = Experience {
state: current_state.data.iter().map(|&x| x as f32).collect(),
state: current_state_data.to_vec(),
action: action_index,
reward: snapshot.reward(),
next_state: if snapshot.terminated {
reward: *snapshot.reward(),
next_state: if snapshot.done() {
None
} else {
Some(snapshot.state().data.iter().map(|&x| x as f32).collect())
Some(snapshot.state().data.to_vec())
},
done: snapshot.terminated,
done: snapshot.done(),
};
// Ajouter l'expérience au replay buffer
agent.add_experience(experience);
// Entraîner l'agent
if let Some(loss) = agent.train_step(optimizer) {
if let Some(loss) = agent.train_step(&mut optimizer) {
episode_loss += loss;
loss_count += 1;
}
// Vérifier les conditions de fin
if snapshot.terminated || step >= max_steps_per_episode {
if snapshot.done() || step >= max_steps_per_episode {
break;
}
}

View file

@ -1,10 +1,8 @@
use burn::module::AutodiffModule;
use burn::tensor::backend::AutodiffBackend;
use burn::{
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
module::Module,
nn::{loss::MseLoss, Linear, LinearConfig},
optim::{GradientsParams, Optimizer},
optim::Optimizer,
record::{CompactRecorder, Recorder},
tensor::Tensor,
};
@ -138,6 +136,8 @@ impl BurnDqnAgent {
/// Sélectionne une action avec epsilon-greedy
pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize {
if valid_actions.is_empty() {
// Retourne une action par défaut ou une action "nulle" si aucune n'est valide
// Dans le contexte du jeu, cela ne devrait pas arriver si la logique de fin de partie est correcte
return 0;
}
@ -148,7 +148,8 @@ impl BurnDqnAgent {
}
// Exploitation : choisir la meilleure action selon le Q-network
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device);
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device)
.reshape([1, self.config.state_size]);
let q_values = self.q_network.forward(state_tensor);
// Convertir en vecteur pour traitement
@ -177,9 +178,9 @@ impl BurnDqnAgent {
}
/// Entraîne le réseau sur un batch d'expériences
pub fn train_step<B: AutodiffBackend, M: AutodiffModule<B>>(
pub fn train_step(
&mut self,
optimizer: &mut impl Optimizer<M, B>,
optimizer: &mut impl Optimizer<DqnNetwork<MyBackend>, MyBackend>,
) -> Option<f32> {
if self.replay_buffer.len() < self.config.batch_size {
return None;
@ -189,8 +190,9 @@ impl BurnDqnAgent {
let batch = self.sample_batch();
// Préparer les tenseurs d'état
let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect();
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);
let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states.as_slice(), &self.device)
.reshape([self.config.batch_size, self.config.state_size]);
// Calculer les Q-values actuelles
let current_q_values = self.q_network.forward(state_tensor);
@ -208,8 +210,8 @@ impl BurnDqnAgent {
// Backpropagation (version simplifiée)
let grads = loss.backward();
// Gradients linked to each parameter of the model.
// let grads = GradientsParams::from_grads(grads, &self.q_network);
self.q_network = optimizer.step(self.config.learning_rate, self.q_network, grads);
let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network);
self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads);
// Mise à jour du réseau cible
self.step_count += 1;

View file

@ -64,11 +64,11 @@ impl BurnDqnStrategy {
// Convertir l'état du jeu en tensor
let state_vec = self.game.to_vec_float();
let state_tensor = Tensor::<InferenceBackend, 2>::from_floats([state_vec], &self.device);
let state_tensor = Tensor::<InferenceBackend, 2>::from_floats(state_vec.as_slice(), &self.device).reshape([1, self.config.as_ref().unwrap().state_size]);
// Faire une prédiction
let q_values = network.forward(state_tensor);
let q_data = q_values.into_data().convert::<f32>().value;
let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();
// Trouver la meilleure action parmi les actions valides
let mut best_action = &valid_actions[0];
@ -129,14 +129,14 @@ impl BotStrategy for BurnDqnStrategy {
fn calculate_points(&self) -> u8 {
// Utiliser le modèle DQN pour décider des points à marquer
let valid_actions = get_valid_actions(&self.game);
// let valid_actions = get_valid_actions(&self.game);
// Chercher une action Mark dans les actions valides
for action in &valid_actions {
if let super::dqn_common::TrictracAction::Mark { points } = action {
return *points;
}
}
// for action in &valid_actions {
// if let super::dqn_common::TrictracAction::Mark { points } = action {
// return *points;
// }
// }
// Par défaut, marquer 0 points
0

View file

@ -6,7 +6,7 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
/// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu
}
impl State for TrictracState {
@ -24,14 +24,12 @@ impl State for TrictracState {
impl TrictracState {
/// Convertit un GameState en TrictracState
pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec();
let mut data = [0; 36];
let state_vec = game_state.to_vec_float();
let mut data = [0.0; 36];
// Copier les données en s'assurant qu'on ne dépasse pas la taille
let copy_len = state_vec.len().min(36);
for i in 0..copy_len {
data[i] = state_vec[i];
}
data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
TrictracState { data }
}