fix by gemini
This commit is contained in:
parent
f05094b2d4
commit
6a7b1cbebc
5 changed files with 468 additions and 38 deletions
|
|
@ -1,6 +1,7 @@
|
|||
use bot::strategy::burn_dqn_agent::{BurnDqnAgent, DqnConfig, Experience};
|
||||
use bot::strategy::burn_environment::{TrictracAction, TrictracEnvironment};
|
||||
use bot::strategy::dqn_common::get_valid_actions;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn_rl::base::Environment;
|
||||
use std::env;
|
||||
|
||||
|
|
@ -116,7 +117,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
||||
loop {
|
||||
step += 1;
|
||||
let current_state = snapshot.state();
|
||||
let current_state_data = snapshot.state().data;
|
||||
|
||||
// Obtenir les actions valides selon le contexte du jeu
|
||||
let valid_actions = get_valid_actions(&env.game);
|
||||
|
|
@ -130,11 +131,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
||||
// Sélectionner une action avec l'agent DQN
|
||||
let action_index = agent.select_action(
|
||||
¤t_state
|
||||
.data
|
||||
.iter()
|
||||
.map(|&x| x as f32)
|
||||
.collect::<Vec<_>>(),
|
||||
¤t_state_data,
|
||||
&valid_indices,
|
||||
);
|
||||
let action = TrictracAction {
|
||||
|
|
@ -143,32 +140,32 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
||||
// Exécuter l'action
|
||||
snapshot = env.step(action);
|
||||
episode_reward += snapshot.reward();
|
||||
episode_reward += *snapshot.reward();
|
||||
|
||||
// Préparer l'expérience pour l'agent
|
||||
let experience = Experience {
|
||||
state: current_state.data.iter().map(|&x| x as f32).collect(),
|
||||
state: current_state_data.to_vec(),
|
||||
action: action_index,
|
||||
reward: snapshot.reward(),
|
||||
next_state: if snapshot.terminated {
|
||||
reward: *snapshot.reward(),
|
||||
next_state: if snapshot.done() {
|
||||
None
|
||||
} else {
|
||||
Some(snapshot.state().data.iter().map(|&x| x as f32).collect())
|
||||
Some(snapshot.state().data.to_vec())
|
||||
},
|
||||
done: snapshot.terminated,
|
||||
done: snapshot.done(),
|
||||
};
|
||||
|
||||
// Ajouter l'expérience au replay buffer
|
||||
agent.add_experience(experience);
|
||||
|
||||
// Entraîner l'agent
|
||||
if let Some(loss) = agent.train_step(optimizer) {
|
||||
if let Some(loss) = agent.train_step(&mut optimizer) {
|
||||
episode_loss += loss;
|
||||
loss_count += 1;
|
||||
}
|
||||
|
||||
// Vérifier les conditions de fin
|
||||
if snapshot.terminated || step >= max_steps_per_episode {
|
||||
if snapshot.done() || step >= max_steps_per_episode {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
use burn::module::AutodiffModule;
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
use burn::{
|
||||
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
|
||||
module::Module,
|
||||
nn::{loss::MseLoss, Linear, LinearConfig},
|
||||
optim::{GradientsParams, Optimizer},
|
||||
optim::Optimizer,
|
||||
record::{CompactRecorder, Recorder},
|
||||
tensor::Tensor,
|
||||
};
|
||||
|
|
@ -138,6 +136,8 @@ impl BurnDqnAgent {
|
|||
/// Sélectionne une action avec epsilon-greedy
|
||||
pub fn select_action(&mut self, state: &[f32], valid_actions: &[usize]) -> usize {
|
||||
if valid_actions.is_empty() {
|
||||
// Retourne une action par défaut ou une action "nulle" si aucune n'est valide
|
||||
// Dans le contexte du jeu, cela ne devrait pas arriver si la logique de fin de partie est correcte
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -148,7 +148,8 @@ impl BurnDqnAgent {
|
|||
}
|
||||
|
||||
// Exploitation : choisir la meilleure action selon le Q-network
|
||||
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device);
|
||||
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device)
|
||||
.reshape([1, self.config.state_size]);
|
||||
let q_values = self.q_network.forward(state_tensor);
|
||||
|
||||
// Convertir en vecteur pour traitement
|
||||
|
|
@ -177,9 +178,9 @@ impl BurnDqnAgent {
|
|||
}
|
||||
|
||||
/// Entraîne le réseau sur un batch d'expériences
|
||||
pub fn train_step<B: AutodiffBackend, M: AutodiffModule<B>>(
|
||||
pub fn train_step(
|
||||
&mut self,
|
||||
optimizer: &mut impl Optimizer<M, B>,
|
||||
optimizer: &mut impl Optimizer<DqnNetwork<MyBackend>, MyBackend>,
|
||||
) -> Option<f32> {
|
||||
if self.replay_buffer.len() < self.config.batch_size {
|
||||
return None;
|
||||
|
|
@ -189,8 +190,9 @@ impl BurnDqnAgent {
|
|||
let batch = self.sample_batch();
|
||||
|
||||
// Préparer les tenseurs d'état
|
||||
let states: Vec<&[f32]> = batch.iter().map(|exp| exp.state.as_slice()).collect();
|
||||
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states, &self.device);
|
||||
let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
|
||||
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states.as_slice(), &self.device)
|
||||
.reshape([self.config.batch_size, self.config.state_size]);
|
||||
|
||||
// Calculer les Q-values actuelles
|
||||
let current_q_values = self.q_network.forward(state_tensor);
|
||||
|
|
@ -208,8 +210,8 @@ impl BurnDqnAgent {
|
|||
// Backpropagation (version simplifiée)
|
||||
let grads = loss.backward();
|
||||
// Gradients linked to each parameter of the model.
|
||||
// let grads = GradientsParams::from_grads(grads, &self.q_network);
|
||||
self.q_network = optimizer.step(self.config.learning_rate, self.q_network, grads);
|
||||
let grads = burn::optim::GradientsParams::from_grads(grads, &self.q_network);
|
||||
self.q_network = optimizer.step(self.config.learning_rate, self.q_network.clone(), grads);
|
||||
|
||||
// Mise à jour du réseau cible
|
||||
self.step_count += 1;
|
||||
|
|
|
|||
|
|
@ -64,11 +64,11 @@ impl BurnDqnStrategy {
|
|||
|
||||
// Convertir l'état du jeu en tensor
|
||||
let state_vec = self.game.to_vec_float();
|
||||
let state_tensor = Tensor::<InferenceBackend, 2>::from_floats([state_vec], &self.device);
|
||||
let state_tensor = Tensor::<InferenceBackend, 2>::from_floats(state_vec.as_slice(), &self.device).reshape([1, self.config.as_ref().unwrap().state_size]);
|
||||
|
||||
// Faire une prédiction
|
||||
let q_values = network.forward(state_tensor);
|
||||
let q_data = q_values.into_data().convert::<f32>().value;
|
||||
let q_data = q_values.into_data().convert::<f32>().into_vec().unwrap();
|
||||
|
||||
// Trouver la meilleure action parmi les actions valides
|
||||
let mut best_action = &valid_actions[0];
|
||||
|
|
@ -129,14 +129,14 @@ impl BotStrategy for BurnDqnStrategy {
|
|||
|
||||
fn calculate_points(&self) -> u8 {
|
||||
// Utiliser le modèle DQN pour décider des points à marquer
|
||||
let valid_actions = get_valid_actions(&self.game);
|
||||
// let valid_actions = get_valid_actions(&self.game);
|
||||
|
||||
// Chercher une action Mark dans les actions valides
|
||||
for action in &valid_actions {
|
||||
if let super::dqn_common::TrictracAction::Mark { points } = action {
|
||||
return *points;
|
||||
}
|
||||
}
|
||||
// for action in &valid_actions {
|
||||
// if let super::dqn_common::TrictracAction::Mark { points } = action {
|
||||
// return *points;
|
||||
// }
|
||||
// }
|
||||
|
||||
// Par défaut, marquer 0 points
|
||||
0
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
|||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TrictracState {
|
||||
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
|
||||
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu
|
||||
}
|
||||
|
||||
impl State for TrictracState {
|
||||
|
|
@ -24,14 +24,12 @@ impl State for TrictracState {
|
|||
impl TrictracState {
|
||||
/// Convertit un GameState en TrictracState
|
||||
pub fn from_game_state(game_state: &GameState) -> Self {
|
||||
let state_vec = game_state.to_vec();
|
||||
let mut data = [0; 36];
|
||||
let state_vec = game_state.to_vec_float();
|
||||
let mut data = [0.0; 36];
|
||||
|
||||
// Copier les données en s'assurant qu'on ne dépasse pas la taille
|
||||
let copy_len = state_vec.len().min(36);
|
||||
for i in 0..copy_len {
|
||||
data[i] = state_vec[i];
|
||||
}
|
||||
data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
|
||||
|
||||
TrictracState { data }
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue