refacto: bot directories
This commit is contained in:
parent e66921fcce
commit fcd50bc0f2
@@ -7,19 +7,19 @@ edition = "2021"
 [[bin]]
 name = "train_dqn_burn_valid"
-path = "src/dqn/burnrl_valid/main.rs"
+path = "src/burnrl/dqn_valid/main.rs"

 [[bin]]
 name = "train_dqn_burn_big"
-path = "src/dqn/burnrl_big/main.rs"
+path = "src/burnrl/dqn_big/main.rs"

 [[bin]]
 name = "train_dqn_burn"
-path = "src/dqn/burnrl/main.rs"
+path = "src/burnrl/dqn/main.rs"

 [[bin]]
 name = "train_dqn_simple"
-path = "src/dqn/simple/main.rs"
+path = "src/dqn_simple/main.rs"

 [dependencies]
 pretty_assertions = "1.4.0"
@@ -1,5 +1,5 @@
-use crate::dqn::burnrl_big::environment::TrictracEnvironment;
-use crate::dqn::burnrl_big::utils::soft_update_linear;
+use crate::burnrl::dqn::utils::soft_update_linear;
+use crate::burnrl::environment::TrictracEnvironment;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
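Note: soft_update_linear is only re-imported here; its body is untouched by this commit. For readers unfamiliar with the name, a minimal Polyak-style soft update over plain weight buffers is sketched below — an illustration of the general technique only, not the project's burn-based implementation, which operates on Linear modules.

    // Illustrative sketch: blend target weights toward the online weights.
    // tau is the usual soft-update coefficient (for example 0.005).
    fn soft_update(target: &mut [f32], online: &[f32], tau: f32) {
        for (t, o) in target.iter_mut().zip(online) {
            *t = tau * *o + (1.0 - tau) * *t;
        }
    }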
@@ -126,7 +126,7 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
 ) -> DQN<E, B, Net<B>> {
 // ) -> impl Agent<E> {
     let mut env = E::new(visualized);
-    env.as_mut().min_steps = conf.min_steps;
+    // env.as_mut().min_steps = conf.min_steps;
     env.as_mut().max_steps = conf.max_steps;

     let model = Net::<B>::new(
@@ -193,12 +193,17 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(

         if snapshot.done() || episode_duration >= conf.max_steps {
             let envmut = env.as_mut();
+            let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
+                * 100.0)
+                .round() as u32;
             println!(
-                "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
+                "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
                 envmut.goodmoves_count,
+                goodmoves_ratio,
                 envmut.pointrolls_count,
                 now.elapsed().unwrap().as_secs(),
             );
+            if goodmoves_ratio < 5 && 10 < episode {}
             env.reset();
             episode_done = true;
             now = SystemTime::now();
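With the added ratio field, an episode summary line now renders roughly as follows (all values illustrative):

    {"episode": 12, "reward": 3.4210, "steps count": 1000, "epsilon": 0.412, "goodmoves": 240, "ratio": 24%, "rollpoints":18, "duration": 95}

The bare 24% keeps the line JSON-like rather than strictly parseable JSON.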
@@ -1,7 +1,8 @@
-use bot::dqn::burnrl::{
-    dqn_model, environment,
+use bot::burnrl::dqn::{
+    dqn_model,
     utils::{demo_model, load_model, save_model},
 };
+use bot::burnrl::environment;
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::agent::DQN;
 use burn_rl::base::ElemType;
@@ -15,9 +16,9 @@ fn main() {
     // See also MEMORY_SIZE in dqn_model.rs : 8192
     let conf = dqn_model::DqnConfig {
         // defaults
-        num_episodes: 40, // 40
+        num_episodes: 50, // 40
         min_steps: 1000.0, // 1000 min of max steps per episode (updated by the function)
-        max_steps: 2000, // 1000 max steps per episode
+        max_steps: 1000, // 1000 max steps per episode
         dense_size: 256, // 128 neural network complexity (default 128)
         eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
         eps_end: 0.05, // 0.05
@@ -31,8 +32,8 @@ fn main() {
         // slower, less sensitive to lucky streaks
         learning_rate: 0.001, // 0.001 step size. Low: slower; high: may never
         // converge
-        batch_size: 64, // 32 number of past experiences used to compute the mean error
-        clip_grad: 50.0, // 100 max correction applied to the gradient (default 100)
+        batch_size: 128, // 32 number of past experiences used to compute the mean error
+        clip_grad: 70.0, // 100 max correction applied to the gradient (default 100)
     };
     println!("{conf}----------");
     let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
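eps_start and eps_end bound the exploration rate, and the eps_threshold printed in the episode summary decays between them. The decay law itself lives in dqn_model.rs and is not shown in this diff; the conventional exponential schedule would look like the sketch below, where eps_decay is an assumed, illustrative parameter.

    // Assumed conventional epsilon-greedy schedule; the project's exact formula may differ.
    fn eps_threshold(eps_start: f64, eps_end: f64, eps_decay: f64, steps_done: u32) -> f64 {
        eps_end + (eps_start - eps_end) * (-(steps_done as f64) / eps_decay).exp()
    }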
@@ -1,3 +1,2 @@
 pub mod dqn_model;
-pub mod environment;
 pub mod utils;
@@ -1,8 +1,6 @@
-use crate::dqn::burnrl::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common::get_valid_action_indices;
+use crate::burnrl::dqn::dqn_model;
+use crate::burnrl::environment::{TrictracAction, TrictracEnvironment};
+use crate::training_common::get_valid_action_indices;
 use burn::backend::{ndarray::NdArrayDevice, NdArray};
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
@@ -1,5 +1,5 @@
-use crate::dqn::burnrl::environment::TrictracEnvironment;
-use crate::dqn::burnrl::utils::soft_update_linear;
+use crate::burnrl::dqn_big::utils::soft_update_linear;
+use crate::burnrl::environment_big::TrictracEnvironment;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -1,13 +1,14 @@
-use bot::dqn::burnrl_big::{
-    dqn_model, environment,
+use bot::burnrl::dqn_big::{
+    dqn_model,
     utils::{demo_model, load_model, save_model},
 };
+use bot::burnrl::environment_big;
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::agent::DQN;
 use burn_rl::base::ElemType;

 type Backend = Autodiff<NdArray<ElemType>>;
-type Env = environment::TrictracEnvironment;
+type Env = environment_big::TrictracEnvironment;

 fn main() {
     // println!("> Entraînement");
@@ -1,3 +1,2 @@
 pub mod dqn_model;
-pub mod environment;
 pub mod utils;
@@ -1,8 +1,6 @@
-use crate::dqn::burnrl_valid::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common::get_valid_action_indices;
+use crate::burnrl::dqn_big::dqn_model;
+use crate::burnrl::environment_big::{TrictracAction, TrictracEnvironment};
+use crate::training_common_big::get_valid_action_indices;
 use burn::backend::{ndarray::NdArrayDevice, NdArray};
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
@@ -1,5 +1,5 @@
-use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
-use crate::dqn::burnrl_valid::utils::soft_update_linear;
+use crate::burnrl::dqn_valid::utils::soft_update_linear;
+use crate::burnrl::environment::TrictracEnvironment;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
@@ -1,7 +1,8 @@
-use bot::dqn::burnrl_valid::{
-    dqn_model, environment,
+use bot::burnrl::dqn_valid::{
+    dqn_model,
     utils::{demo_model, load_model, save_model},
 };
+use bot::burnrl::environment;
 use burn::backend::{Autodiff, NdArray};
 use burn_rl::agent::DQN;
 use burn_rl::base::ElemType;
@@ -1,3 +1,2 @@
 pub mod dqn_model;
-pub mod environment;
 pub mod utils;
@@ -1,8 +1,6 @@
-use crate::dqn::burnrl_big::{
-    dqn_model,
-    environment::{TrictracAction, TrictracEnvironment},
-};
-use crate::dqn::dqn_common_big::get_valid_action_indices;
+use crate::burnrl::dqn_valid::dqn_model;
+use crate::burnrl::environment_valid::{TrictracAction, TrictracEnvironment};
+use crate::training_common::get_valid_action_indices;
 use burn::backend::{ndarray::NdArrayDevice, NdArray};
 use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
@@ -1,13 +1,15 @@
-use crate::dqn::dqn_common;
+use std::io::Write;
+
+use crate::training_common;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
 use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};

-const ERROR_REWARD: f32 = -2.12121;
-const REWARD_VALID_MOVE: f32 = 2.12121;
+const ERROR_REWARD: f32 = -1.12121;
+const REWARD_VALID_MOVE: f32 = 1.12121;
 const REWARD_RATIO: f32 = 0.01;
-const WIN_POINTS: f32 = 0.1;
+const WIN_POINTS: f32 = 1.0;

 /// Trictrac game state for burn-rl
 #[derive(Debug, Clone, Copy)]
@@ -89,7 +91,7 @@ pub struct TrictracEnvironment {
     current_state: TrictracState,
     episode_reward: f32,
     pub step_count: usize,
-    pub min_steps: f32,
+    pub best_ratio: f32,
     pub max_steps: usize,
     pub pointrolls_count: usize,
     pub goodmoves_count: usize,
@@ -122,7 +124,7 @@ impl Environment for TrictracEnvironment {
             current_state,
             episode_reward: 0.0,
             step_count: 0,
-            min_steps: 250.0,
+            best_ratio: 0.0,
             max_steps: 2000,
             pointrolls_count: 0,
             goodmoves_count: 0,
@@ -151,10 +153,21 @@ impl Environment for TrictracEnvironment {
         } else {
             self.goodmoves_count as f32 / self.step_count as f32
         };
+        self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
+        let warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
+            let path = "bot/models/logs/debug.log";
+            if let Ok(mut out) = std::fs::File::create(path) {
+                write!(out, "{:?}", self.game.history);
+            }
+            "!!!!"
+        } else {
+            ""
+        };
         println!(
-            "info: correct moves: {} ({}%)",
+            "info: correct moves: {} ({}%) {}",
             self.goodmoves_count,
-            (100.0 * self.goodmoves_ratio).round() as u32
+            (100.0 * self.goodmoves_ratio).round() as u32,
+            warning
         );
         self.step_count = 0;
         self.pointrolls_count = 0;
@@ -195,9 +208,10 @@ impl Environment for TrictracEnvironment {
         }

         // Check whether the game is over
-        let max_steps = self.min_steps
-            + (self.max_steps as f32 - self.min_steps)
-            * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
+        // let max_steps = self.max_steps;
+        // let max_steps = self.min_steps
+        //     + (self.max_steps as f32 - self.min_steps)
+        //     * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
         let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();

         if done {
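For reference, the step-budget schedule disabled above interpolates between min_steps and max_steps according to the good-move ratio. Written as a standalone function, and assuming min_steps = 1000.0 and max_steps = 2000 for illustration, it behaves as follows:

    // The formula commented out above, as a free function.
    fn scheduled_max_steps(min_steps: f32, max_steps: usize, goodmoves_ratio: f32) -> usize {
        let budget =
            min_steps + (max_steps as f32 - min_steps) * f32::exp((goodmoves_ratio - 1.0) / 0.25);
        budget.round() as usize
    }
    // scheduled_max_steps(1000.0, 2000, 0.0) ≈ 1018  (exp(-4.0) ≈ 0.018)
    // scheduled_max_steps(1000.0, 2000, 0.5) ≈ 1135  (exp(-2.0) ≈ 0.135)
    // scheduled_max_steps(1000.0, 2000, 1.0) = 2000  (exp(0.0) = 1.0)

With this commit the fixed self.max_steps is used instead, as the next hunk shows.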
@@ -210,7 +224,8 @@ impl Environment for TrictracEnvironment {
             }
         }
         }
-        let terminated = done || self.step_count >= max_steps.round() as usize;
+        let terminated = done || self.step_count >= self.max_steps;
+        // let terminated = done || self.step_count >= max_steps.round() as usize;

         // Update the state
         self.current_state = TrictracState::from_game_state(&self.game);
@@ -229,8 +244,8 @@ impl TrictracEnvironment {

 impl TrictracEnvironment {
     /// Converts a burn-rl action into a Trictrac action
-    pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
-        dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
+    pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
+        training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
     }

     /// Converts an action index within the valid actions into a Trictrac action
@@ -239,8 +254,8 @@ impl TrictracEnvironment {
         &self,
         action: TrictracAction,
         game_state: &GameState,
-    ) -> Option<dqn_common::TrictracAction> {
-        use dqn_common::get_valid_actions;
+    ) -> Option<training_common::TrictracAction> {
+        use training_common::get_valid_actions;

         // Get the valid actions in the current context
         let valid_actions = get_valid_actions(game_state);
@@ -257,10 +272,10 @@ impl TrictracEnvironment {
     /// Executes a Trictrac action in the game
     // fn execute_action(
    //     &mut self,
-    //     action: dqn_common::TrictracAction,
+    //     action: training_common::TrictracAction,
    // ) -> Result<f32, Box<dyn std::error::Error>> {
-    fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
-        use dqn_common::TrictracAction;
+    fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
+        use training_common::TrictracAction;

         let mut reward = 0.0;
         let mut is_rollpoint = false;
@@ -1,4 +1,4 @@
-use crate::dqn::dqn_common_big;
+use crate::training_common_big;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
@@ -229,8 +229,8 @@ impl Environment for TrictracEnvironment {

 impl TrictracEnvironment {
     /// Converts a burn-rl action into a Trictrac action
-    pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
-        dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
+    pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
+        training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
     }

     /// Converts an action index within the valid actions into a Trictrac action
@@ -239,8 +239,8 @@ impl TrictracEnvironment {
         &self,
         action: TrictracAction,
         game_state: &GameState,
-    ) -> Option<dqn_common_big::TrictracAction> {
-        use dqn_common_big::get_valid_actions;
+    ) -> Option<training_common_big::TrictracAction> {
+        use training_common_big::get_valid_actions;

         // Get the valid actions in the current context
         let valid_actions = get_valid_actions(game_state);
@@ -257,10 +257,10 @@ impl TrictracEnvironment {
     /// Executes a Trictrac action in the game
     // fn execute_action(
    //     &mut self,
-    //     action:dqn_common_big::TrictracAction,
+    //     action:training_common_big::TrictracAction,
    // ) -> Result<f32, Box<dyn std::error::Error>> {
-    fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
-        use dqn_common_big::TrictracAction;
+    fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
+        use training_common_big::TrictracAction;

         let mut reward = 0.0;
         let mut is_rollpoint = false;
@@ -1,4 +1,4 @@
-use crate::dqn::dqn_common_big;
+use crate::training_common_big;
 use burn::{prelude::Backend, tensor::Tensor};
 use burn_rl::base::{Action, Environment, Snapshot, State};
 use rand::{thread_rng, Rng};
@@ -214,16 +214,16 @@ impl TrictracEnvironment {
     const REWARD_RATIO: f32 = 1.0;

     /// Converts a burn-rl action into a Trictrac action
-    pub fn convert_action(action: TrictracAction) -> Option<dqn_common_big::TrictracAction> {
-        dqn_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
+    pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
+        training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
     }

     /// Converts an action index within the valid actions into a Trictrac action
     fn convert_valid_action_index(
         &self,
         action: TrictracAction,
-    ) -> Option<dqn_common_big::TrictracAction> {
-        use dqn_common_big::get_valid_actions;
+    ) -> Option<training_common_big::TrictracAction> {
+        use training_common_big::get_valid_actions;

         // Get the valid actions in the current context
         let valid_actions = get_valid_actions(&self.game);
@@ -240,10 +240,10 @@ impl TrictracEnvironment {
     /// Executes a Trictrac action in the game
     // fn execute_action(
    //     &mut self,
-    //     action: dqn_common_big::TrictracAction,
+    //     action: training_common_big::TrictracAction,
    // ) -> Result<f32, Box<dyn std::error::Error>> {
-    fn execute_action(&mut self, action: dqn_common_big::TrictracAction) -> (f32, bool) {
-        use dqn_common_big::TrictracAction;
+    fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
+        use training_common_big::TrictracAction;

         let mut reward = 0.0;
         let mut is_rollpoint = false;
bot/src/burnrl/mod.rs (new file, 6 lines)
@@ -0,0 +1,6 @@
+pub mod dqn;
+pub mod dqn_big;
+pub mod dqn_valid;
+pub mod environment;
+pub mod environment_big;
+pub mod environment_valid;
@@ -1,7 +0,0 @@
-pub mod burnrl;
-pub mod burnrl_big;
-pub mod dqn_common;
-pub mod dqn_common_big;
-pub mod simple;
-
-pub mod burnrl_valid;
@@ -1,4 +1,4 @@
-use crate::dqn::dqn_common::TrictracAction;
+use crate::training_common_big::TrictracAction;
 use serde::{Deserialize, Serialize};

 /// Configuration for the DQN agent
@@ -151,4 +151,3 @@ impl SimpleNeuralNetwork {
         Ok(network)
     }
 }
-
@@ -6,7 +6,7 @@ use std::collections::VecDeque;
 use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};

 use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
-use crate::dqn::dqn_common_big::{get_valid_actions, TrictracAction};
+use crate::training_common_big::{get_valid_actions, TrictracAction};

 /// Experience for the replay buffer
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -1,6 +1,6 @@
-use bot::dqn::dqn_common::TrictracAction;
-use bot::dqn::simple::dqn_model::DqnConfig;
-use bot::dqn::simple::dqn_trainer::DqnTrainer;
+use bot::dqn_simple::dqn_model::DqnConfig;
+use bot::dqn_simple::dqn_trainer::DqnTrainer;
+use bot::training_common::TrictracAction;
 use std::env;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -1,5 +1,8 @@
-pub mod dqn;
+pub mod burnrl;
+pub mod dqn_simple;
 pub mod strategy;
+pub mod training_common;
+pub mod training_common_big;

 use log::debug;
 use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
@@ -3,8 +3,8 @@ use log::info;
 use std::path::Path;
 use store::MoveRules;

-use crate::dqn::dqn_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
-use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
+use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
+use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};

 /// DQN strategy for the bot - only loads and uses a pre-trained model
 #[derive(Debug)]
@@ -6,8 +6,9 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use log::info;
 use store::MoveRules;

-use crate::dqn::burnrl::{dqn_model, environment, utils};
-use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
+use crate::burnrl::dqn::{dqn_model, utils};
+use crate::burnrl::environment;
+use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};

 type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;