refacto: bot directories

Henri Bourcereau 2025-08-19 16:27:37 +02:00
parent e66921fcce
commit fcd50bc0f2
27 changed files with 110 additions and 94 deletions

@@ -7,19 +7,19 @@ edition = "2021"
 [[bin]]
 name = "train_dqn_burn_valid"
-path = "src/dqn/burnrl_valid/main.rs"
+path = "src/burnrl/dqn_valid/main.rs"
 [[bin]]
 name = "train_dqn_burn_big"
-path = "src/dqn/burnrl_big/main.rs"
+path = "src/burnrl/dqn_big/main.rs"
 [[bin]]
 name = "train_dqn_burn"
-path = "src/dqn/burnrl/main.rs"
+path = "src/burnrl/dqn/main.rs"
 [[bin]]
 name = "train_dqn_simple"
-path = "src/dqn/simple/main.rs"
+path = "src/dqn_simple/main.rs"
 [dependencies]
 pretty_assertions = "1.4.0"

@@ -1,5 +1,5 @@
-use crate::dqn::burnrl_big::environment::TrictracEnvironment;
-use crate::dqn::burnrl_big::utils::soft_update_linear;
+use crate::burnrl::dqn::utils::soft_update_linear;
+use crate::burnrl::environment::TrictracEnvironment;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -126,7 +126,7 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
 ) -> DQN<E, B, Net<B>> {
 // ) -> impl Agent<E> {
 let mut env = E::new(visualized);
-env.as_mut().min_steps = conf.min_steps;
+// env.as_mut().min_steps = conf.min_steps;
 env.as_mut().max_steps = conf.max_steps;
 let model = Net::<B>::new(
@@ -193,12 +193,17 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
 if snapshot.done() || episode_duration >= conf.max_steps {
 let envmut = env.as_mut();
+let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
+    * 100.0)
+    .round() as u32;
 println!(
-"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
+"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
 envmut.goodmoves_count,
+goodmoves_ratio,
 envmut.pointrolls_count,
 now.elapsed().unwrap().as_secs(),
 );
+if goodmoves_ratio < 5 && 10 < episode {}
 env.reset();
 episode_done = true;
 now = SystemTime::now();
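For illustration, one line of the new per-episode log produced by this println! would look like the following (all values here are invented):

{"episode": 12, "reward": 3.1416, "steps count": 1000, "epsilon": 0.472, "goodmoves": 212, "ratio": 21%, "rollpoints":37, "duration": 58}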

@@ -1,7 +1,8 @@
-use bot::dqn::burnrl::{
-    dqn_model, environment,
+use bot::burnrl::dqn::{
+    dqn_model,
     utils::{demo_model, load_model, save_model},
 };
+use bot::burnrl::environment;
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::agent::DQN;
 use burn_rl::base::ElemType;
@@ -15,9 +16,9 @@ fn main() {
 // See also MEMORY_SIZE in dqn_model.rs : 8192
 let conf = dqn_model::DqnConfig {
 // defaults
-num_episodes: 40, // 40
+num_episodes: 50, // 40
 min_steps: 1000.0, // 1000 min of max steps per episode (updated by the function)
-max_steps: 2000, // 1000 max steps per episode
+max_steps: 1000, // 1000 max steps per episode
 dense_size: 256, // 128 neural network complexity (default 128)
 eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
 eps_end: 0.05, // 0.05
@@ -31,8 +32,8 @@ fn main() {
 // slower, less sensitive to strokes of luck
 learning_rate: 0.001, // 0.001 step size. Low: slower; high: risk of never
 // converging
-batch_size: 64, // 32 number of past experiences used to compute the mean error
-clip_grad: 50.0, // 100 max correction applied to the gradient (default 100)
+batch_size: 128, // 32 number of past experiences used to compute the mean error
+clip_grad: 70.0, // 100 max correction applied to the gradient (default 100)
 };
 println!("{conf}----------");
 let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);

@@ -1,3 +1,2 @@
 pub mod dqn_model;
-pub mod environment;
 pub mod utils;

@@ -1,8 +1,6 @@
-use crate::dqn::burnrl::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common::get_valid_action_indices;
+use crate::burnrl::dqn::dqn_model;
+use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
+use crate::training_common::get_valid_action_indices;
 use burn::backend::{ndarray::NdArrayDevice, NdArray};
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;

@@ -1,5 +1,5 @@
-use crate::dqn::burnrl::environment::TrictracEnvironment;
-use crate::dqn::burnrl::utils::soft_update_linear;
+use crate::burnrl::dqn_big::utils::soft_update_linear;
+use crate::burnrl::environment_big::TrictracEnvironment;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;

@@ -1,13 +1,14 @@
-use bot::dqn::burnrl_big::{
-    dqn_model, environment,
+use bot::burnrl::dqn_big::{
+    dqn_model,
     utils::{demo_model, load_model, save_model},
 };
+use bot::burnrl::environment_big;
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::agent::DQN;
 use burn_rl::base::ElemType;
 type Backend = Autodiff<NdArray<ElemType>>;
-type Env = environment::TrictracEnvironment;
+type Env = environment_big::TrictracEnvironment;
 fn main() {
 // println!("> Training");

@@ -1,3 +1,2 @@
 pub mod dqn_model;
-pub mod environment;
 pub mod utils;

@@ -1,8 +1,6 @@
-use crate::dqn::burnrl_valid::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common::get_valid_action_indices;
+use crate::burnrl::dqn_big::dqn_model;
+use crate::burnrl::environment_big::{TrictracAction, TrictracEnvironment};
+use crate::training_common_big::get_valid_action_indices;
 use burn::backend::{ndarray::NdArrayDevice, NdArray};
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;

@@ -1,5 +1,5 @@
-use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
-use crate::dqn::burnrl_valid::utils::soft_update_linear;
+use crate::burnrl::dqn_valid::utils::soft_update_linear;
+use crate::burnrl::environment::TrictracEnvironment;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;

@@ -1,7 +1,8 @@
-use bot::dqn::burnrl_valid::{
-    dqn_model, environment,
+use bot::burnrl::dqn_valid::{
+    dqn_model,
     utils::{demo_model, load_model, save_model},
 };
+use bot::burnrl::environment;
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::agent::DQN;
 use burn_rl::base::ElemType;

@@ -1,3 +1,2 @@
 pub mod dqn_model;
-pub mod environment;
 pub mod utils;

@@ -1,8 +1,6 @@
-use crate::dqn::burnrl_big::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common_big::get_valid_action_indices;
+use crate::burnrl::dqn_valid::dqn_model;
+use crate::burnrl::environment_valid::{TrictracAction, TrictracEnvironment};
+use crate::training_common::get_valid_action_indices;
 use burn::backend::{ndarray::NdArrayDevice, NdArray};
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;

@@ -1,13 +1,15 @@
-use crate::dqn::dqn_common;
+use std::io::Write;
+use crate::training_common;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
 use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
-const ERROR_REWARD: f32 = -2.12121;
-const REWARD_VALID_MOVE: f32 = 2.12121;
+const ERROR_REWARD: f32 = -1.12121;
+const REWARD_VALID_MOVE: f32 = 1.12121;
 const REWARD_RATIO: f32 = 0.01;
-const WIN_POINTS: f32 = 0.1;
+const WIN_POINTS: f32 = 1.0;
 /// Trictrac game state for burn-rl
 #[derive(Debug, Clone, Copy)]
@@ -89,7 +91,7 @@ pub struct TrictracEnvironment {
 current_state: TrictracState,
 episode_reward: f32,
 pub step_count: usize,
-pub min_steps: f32,
+pub best_ratio: f32,
 pub max_steps: usize,
 pub pointrolls_count: usize,
 pub goodmoves_count: usize,
@@ -122,7 +124,7 @@ impl Environment for TrictracEnvironment {
 current_state,
 episode_reward: 0.0,
 step_count: 0,
-min_steps: 250.0,
+best_ratio: 0.0,
 max_steps: 2000,
 pointrolls_count: 0,
 goodmoves_count: 0,
@@ -151,10 +153,21 @@ impl Environment for TrictracEnvironment {
 } else {
 self.goodmoves_count as f32 / self.step_count as f32
 };
+self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
+let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
+let path = "bot/models/logs/debug.log";
+if let Ok(mut out) = std::fs::File::create(path) {
+write!(out, "{:?}", self.game.history);
+}
+"!!!!"
+} else {
+""
+};
 println!(
-"info: correct moves: {} ({}%)",
+"info: correct moves: {} ({}%) {}",
 self.goodmoves_count,
-(100.0 * self.goodmoves_ratio).round() as u32
+(100.0 * self.goodmoves_ratio).round() as u32,
+warning
 );
 self.step_count = 0;
 self.pointrolls_count = 0;
@@ -195,9 +208,10 @@ impl Environment for TrictracEnvironment {
 }
 // Check whether the game is over
-let max_steps = self.min_steps
-+ (self.max_steps as f32 - self.min_steps)
-* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
+// let max_steps = self.max_steps;
+// let max_steps = self.min_steps
+// + (self.max_steps as f32 - self.min_steps)
+// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
 let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
 if done {
@@ -210,7 +224,8 @@ impl Environment for TrictracEnvironment {
 }
 }
 }
-let terminated = done || self.step_count >= max_steps.round() as usize;
+let terminated = done || self.step_count >= self.max_steps;
+// let terminated = done || self.step_count >= max_steps.round() as usize;
 // Update the state
 self.current_state = TrictracState::from_game_state(&self.game);
@@ -229,8 +244,8 @@ impl TrictracEnvironment {
 impl TrictracEnvironment {
 /// Converts a burn-rl action into a Trictrac action
-pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
-dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
+pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
+training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
 }
 /// Converts an action index within the valid actions into a Trictrac action
@@ -239,8 +254,8 @@ impl TrictracEnvironment {
 &self,
 action: TrictracAction,
 game_state: &GameState,
-) -> Option<dqn_common::TrictracAction> {
-use dqn_common::get_valid_actions;
+) -> Option<training_common::TrictracAction> {
+use training_common::get_valid_actions;
 // Get the valid actions in the current context
 let valid_actions = get_valid_actions(game_state);
@@ -257,10 +272,10 @@ impl TrictracEnvironment {
 /// Executes a Trictrac action in the game
 // fn execute_action(
 // &mut self,
-// action: dqn_common::TrictracAction,
+// action: training_common::TrictracAction,
 // ) -> Result<f32, Box<dyn std::error::Error>> {
-fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
-use dqn_common::TrictracAction;
+fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
+use training_common::TrictracAction;
 let mut reward = 0.0;
 let mut is_rollpoint = false;

@@ -1,4 +1,4 @@
-use crate::dqn::dqn_common_big;
+use crate::training_common_big;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
@@ -229,8 +229,8 @@ impl Environment for TrictracEnvironment {
 impl TrictracEnvironment {
 /// Converts a burn-rl action into a Trictrac action
-pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
-dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
+pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
+training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
 }
 /// Converts an action index within the valid actions into a Trictrac action
@@ -239,8 +239,8 @@ impl TrictracEnvironment {
 &self,
 action: TrictracAction,
 game_state: &GameState,
-) -> Option<dqn_common_big::TrictracAction> {
-use dqn_common_big::get_valid_actions;
+) -> Option<training_common_big::TrictracAction> {
+use training_common_big::get_valid_actions;
 // Get the valid actions in the current context
 let valid_actions = get_valid_actions(game_state);
@@ -257,10 +257,10 @@ impl TrictracEnvironment {
 /// Executes a Trictrac action in the game
 // fn execute_action(
 // &mut self,
-// action:dqn_common_big::TrictracAction,
+// action:training_common_big::TrictracAction,
 // ) -> Result<f32, Box<dyn std::error::Error>> {
-fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
-use dqn_common_big::TrictracAction;
+fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
+use training_common_big::TrictracAction;
 let mut reward = 0.0;
 let mut is_rollpoint = false;

@@ -1,4 +1,4 @@
-use crate::dqn::dqn_common_big;
+use crate::training_common_big;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
@@ -214,16 +214,16 @@ impl TrictracEnvironment {
 const REWARD_RATIO: f32 = 1.0;
 /// Converts a burn-rl action into a Trictrac action
-pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
-dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
+pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
+training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
 }
 /// Converts an action index within the valid actions into a Trictrac action
 fn convert_valid_action_index(
 &self,
 action: TrictracAction,
-) -> Option<dqn_common_big::TrictracAction> {
-use dqn_common_big::get_valid_actions;
+) -> Option<training_common_big::TrictracAction> {
+use training_common_big::get_valid_actions;
 // Get the valid actions in the current context
 let valid_actions = get_valid_actions(&self.game);
@@ -240,10 +240,10 @@ impl TrictracEnvironment {
 /// Executes a Trictrac action in the game
 // fn execute_action(
 // &mut self,
-// action: dqn_common_big::TrictracAction,
+// action: training_common_big::TrictracAction,
 // ) -> Result<f32, Box<dyn std::error::Error>> {
-fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
-use dqn_common_big::TrictracAction;
+fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
+use training_common_big::TrictracAction;
 let mut reward = 0.0;
 let mut is_rollpoint = false;

bot/src/burnrl/mod.rs (new file)

@@ -0,0 +1,6 @@
+pub mod dqn;
+pub mod dqn_big;
+pub mod dqn_valid;
+pub mod environment;
+pub mod environment_big;
+pub mod environment_valid;
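Taken together with the [[bin]] paths in Cargo.toml above and the lib.rs hunk further down, this new module file shows the shape of the refactor: everything burn-rl related now lives under a single burnrl module. A rough sketch of the resulting bot/src layout, inferred only from paths visible in this commit (whether a given module is a single .rs file or a directory is not shown by the diff, and untouched files are omitted):

bot/src/
  lib.rs                  pub mod burnrl, dqn_simple, strategy, training_common, training_common_big
  burnrl/
    mod.rs
    dqn/                  dqn_model, utils, main.rs  (bin train_dqn_burn)
    dqn_big/              dqn_model, utils, main.rs  (bin train_dqn_burn_big)
    dqn_valid/            dqn_model, utils, main.rs  (bin train_dqn_burn_valid)
    environment
    environment_big
    environment_valid
  dqn_simple/             dqn_model, dqn_trainer, main.rs  (bin train_dqn_simple)
  strategy
  training_common
  training_common_big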

@@ -1,7 +0,0 @@
-pub mod burnrl;
-pub mod burnrl_big;
-pub mod dqn_common;
-pub mod dqn_common_big;
-pub mod simple;
-pub mod burnrl_valid;

@@ -1,4 +1,4 @@
-use crate::dqn::dqn_common::TrictracAction;
+use crate::training_common_big::TrictracAction;
 use serde::{Deserialize, Serialize};
 /// Configuration for the DQN agent
@@ -151,4 +151,3 @@ impl SimpleNeuralNetwork {
 Ok(network)
 }
 }

@@ -6,7 +6,7 @@ use std::collections::VecDeque;
 use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
 use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
-use crate::dqn::dqn_common_big::{get_valid_actions, TrictracAction};
+use crate::training_common_big::{get_valid_actions, TrictracAction};
 /// Experience for the replay buffer
 #[derive(Debug, Clone, Serialize, Deserialize)]

@@ -1,6 +1,6 @@
-use bot::dqn::dqn_common::TrictracAction;
-use bot::dqn::simple::dqn_model::DqnConfig;
-use bot::dqn::simple::dqn_trainer::DqnTrainer;
+use bot::dqn_simple::dqn_model::DqnConfig;
+use bot::dqn_simple::dqn_trainer::DqnTrainer;
+use bot::training_common::TrictracAction;
 use std::env;
 fn main() -> Result<(), Box<dyn std::error::Error>> {

@@ -1,5 +1,8 @@
-pub mod dqn;
+pub mod burnrl;
+pub mod dqn_simple;
 pub mod strategy;
+pub mod training_common;
+pub mod training_common_big;
 use log::debug;
 use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};

@@ -3,8 +3,8 @@ use log::info;
 use std::path::Path;
 use store::MoveRules;
-use crate::dqn::dqn_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
-use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
+use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
+use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
 /// DQN strategy for the bot - only loads and uses a pre-trained model
 #[derive(Debug)]

@ -6,8 +6,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use log::info; use log::info;
use store::MoveRules; use store::MoveRules;
use crate::dqn::burnrl::{dqn_model, environment, utils}; use crate::burnrl::dqn::{dqn_model, utils};
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction}; use crate::burnrl::environment;
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>; type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;