Compare commits
No commits in common. "28c2aa836ff1a0626466d13f06f37d4ed6156865" and "2e0a874879876ab159cb7f78f2977b0663692f03" have entirely different histories.
28c2aa836f
...
2e0a874879
|
|
@ -7,7 +7,7 @@ edition = "2021"
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "train_dqn_burn"
|
name = "train_dqn_burn"
|
||||||
path = "src/dqn/burnrl/main.rs"
|
path = "src/burnrl/main.rs"
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "train_dqn"
|
name = "train_dqn"
|
||||||
|
|
|
||||||
|
|
@ -58,35 +58,17 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
const MEMORY_SIZE: usize = 8192;
|
const MEMORY_SIZE: usize = 4096;
|
||||||
|
const DENSE_SIZE: usize = 128;
|
||||||
pub struct DqnConfig {
|
const EPS_DECAY: f64 = 1000.0;
|
||||||
pub num_episodes: usize,
|
const EPS_START: f64 = 0.9;
|
||||||
// pub memory_size: usize,
|
const EPS_END: f64 = 0.05;
|
||||||
pub dense_size: usize,
|
|
||||||
pub eps_start: f64,
|
|
||||||
pub eps_end: f64,
|
|
||||||
pub eps_decay: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for DqnConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
num_episodes: 1000,
|
|
||||||
// memory_size: 8192,
|
|
||||||
dense_size: 256,
|
|
||||||
eps_start: 0.9,
|
|
||||||
eps_end: 0.05,
|
|
||||||
eps_decay: 1000.0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
pub fn run<E: Environment, B: AutodiffBackend>(
|
pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
conf: &DqnConfig,
|
num_episodes: usize,
|
||||||
visualized: bool,
|
visualized: bool,
|
||||||
) -> DQN<E, B, Net<B>> {
|
) -> DQN<E, B, Net<B>> {
|
||||||
// ) -> impl Agent<E> {
|
// ) -> impl Agent<E> {
|
||||||
|
|
@ -94,7 +76,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
|
|
||||||
let model = Net::<B>::new(
|
let model = Net::<B>::new(
|
||||||
<<E as Environment>::StateType as State>::size(),
|
<<E as Environment>::StateType as State>::size(),
|
||||||
conf.dense_size,
|
DENSE_SIZE,
|
||||||
<<E as Environment>::ActionType as Action>::size(),
|
<<E as Environment>::ActionType as Action>::size(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -112,7 +94,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
|
|
||||||
let mut step = 0_usize;
|
let mut step = 0_usize;
|
||||||
|
|
||||||
for episode in 0..conf.num_episodes {
|
for episode in 0..num_episodes {
|
||||||
let mut episode_done = false;
|
let mut episode_done = false;
|
||||||
let mut episode_reward: ElemType = 0.0;
|
let mut episode_reward: ElemType = 0.0;
|
||||||
let mut episode_duration = 0_usize;
|
let mut episode_duration = 0_usize;
|
||||||
|
|
@ -120,8 +102,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
||||||
let mut now = SystemTime::now();
|
let mut now = SystemTime::now();
|
||||||
|
|
||||||
while !episode_done {
|
while !episode_done {
|
||||||
let eps_threshold = conf.eps_end
|
let eps_threshold =
|
||||||
+ (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
|
EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
|
||||||
let action =
|
let action =
|
||||||
DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
|
DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
|
||||||
let snapshot = env.step(action);
|
let snapshot = env.step(action);
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,8 @@ impl Environment for TrictracEnvironment {
|
||||||
type ActionType = TrictracAction;
|
type ActionType = TrictracAction;
|
||||||
type RewardType = f32;
|
type RewardType = f32;
|
||||||
|
|
||||||
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
|
const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
|
||||||
|
// const MAX_STEPS: usize = 5; // Limite max pour éviter les parties infinies
|
||||||
|
|
||||||
fn new(visualized: bool) -> Self {
|
fn new(visualized: bool) -> Self {
|
||||||
let mut game = GameState::new(false);
|
let mut game = GameState::new(false);
|
||||||
|
|
@ -259,7 +260,7 @@ impl TrictracEnvironment {
|
||||||
// }
|
// }
|
||||||
TrictracAction::Go => {
|
TrictracAction::Go => {
|
||||||
// Continuer après avoir gagné un trou
|
// Continuer après avoir gagné un trou
|
||||||
reward += 0.4;
|
reward += 0.2;
|
||||||
Some(GameEvent::Go {
|
Some(GameEvent::Go {
|
||||||
player_id: self.active_player_id,
|
player_id: self.active_player_id,
|
||||||
})
|
})
|
||||||
|
|
@ -288,7 +289,7 @@ impl TrictracEnvironment {
|
||||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
reward += 0.4;
|
reward += 0.2;
|
||||||
Some(GameEvent::Move {
|
Some(GameEvent::Move {
|
||||||
player_id: self.active_player_id,
|
player_id: self.active_player_id,
|
||||||
moves: (checker_move1, checker_move2),
|
moves: (checker_move1, checker_move2),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
|
use bot::burnrl::{dqn_model, environment, utils::demo_model};
|
||||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
|
@ -10,16 +10,8 @@ type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
println!("> Entraînement");
|
println!("> Entraînement");
|
||||||
let conf = dqn_model::DqnConfig {
|
let num_episodes = 50;
|
||||||
num_episodes: 50,
|
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
||||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
|
||||||
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
|
||||||
dense_size: 256, // neural network complexity
|
|
||||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
|
||||||
eps_end: 0.05,
|
|
||||||
eps_decay: 1000.0,
|
|
||||||
};
|
|
||||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
|
||||||
|
|
||||||
let valid_agent = agent.valid();
|
let valid_agent = agent.valid();
|
||||||
|
|
||||||
|
|
@ -32,7 +24,7 @@ fn main() {
|
||||||
// demo_model::<Env>(valid_agent);
|
// demo_model::<Env>(valid_agent);
|
||||||
|
|
||||||
println!("> Chargement du modèle pour test");
|
println!("> Chargement du modèle pour test");
|
||||||
let loaded_model = load_model(conf.dense_size, &path);
|
let loaded_model = load_model(&path);
|
||||||
let loaded_agent = DQN::new(loaded_model);
|
let loaded_agent = DQN::new(loaded_model);
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
println!("> Test avec le modèle chargé");
|
||||||
|
|
@ -48,7 +40,10 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||||
|
// TODO : reprendre le DENSE_SIZE de dqn_model.rs
|
||||||
|
const DENSE_SIZE: usize = 128;
|
||||||
|
|
||||||
let model_path = format!("{}_model.mpk", path);
|
let model_path = format!("{}_model.mpk", path);
|
||||||
println!("Chargement du modèle depuis : {}", model_path);
|
println!("Chargement du modèle depuis : {}", model_path);
|
||||||
|
|
||||||
|
|
@ -61,7 +56,7 @@ fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemTy
|
||||||
|
|
||||||
dqn_model::Net::new(
|
dqn_model::Net::new(
|
||||||
<environment::TrictracEnvironment as Environment>::StateType::size(),
|
<environment::TrictracEnvironment as Environment>::StateType::size(),
|
||||||
dense_size,
|
DENSE_SIZE,
|
||||||
<environment::TrictracEnvironment as Environment>::ActionType::size(),
|
<environment::TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
)
|
)
|
||||||
.load_record(record)
|
.load_record(record)
|
||||||
|
|
|
||||||
|
|
@ -93,18 +93,6 @@ impl MoveRules {
|
||||||
/// ---- moves_possibles : First of three checks for moves
|
/// ---- moves_possibles : First of three checks for moves
|
||||||
fn moves_possible(&self, moves: &(CheckerMove, CheckerMove)) -> bool {
|
fn moves_possible(&self, moves: &(CheckerMove, CheckerMove)) -> bool {
|
||||||
let color = &Color::White;
|
let color = &Color::White;
|
||||||
|
|
||||||
let move0_from = moves.0.get_from();
|
|
||||||
if 0 < move0_from && move0_from == moves.1.get_from() {
|
|
||||||
if let Ok((field_count, Some(field_color))) = self.board.get_field_checkers(move0_from)
|
|
||||||
{
|
|
||||||
if color != field_color || field_count < 2 {
|
|
||||||
info!("Move not physically possible");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(chained_move) = moves.0.chain(moves.1) {
|
if let Ok(chained_move) = moves.0.chain(moves.1) {
|
||||||
// Check intermediary move and chained_move : "Tout d'une"
|
// Check intermediary move and chained_move : "Tout d'une"
|
||||||
if !self.board.passage_possible(color, &moves.0)
|
if !self.board.passage_possible(color, &moves.0)
|
||||||
|
|
@ -1017,7 +1005,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn moves_possible() {
|
fn moves_possible() {
|
||||||
let mut state = MoveRules::default();
|
let state = MoveRules::default();
|
||||||
|
|
||||||
// Chained moves
|
// Chained moves
|
||||||
let moves = (
|
let moves = (
|
||||||
|
|
@ -1033,17 +1021,6 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert!(!state.moves_possible(&moves));
|
assert!(!state.moves_possible(&moves));
|
||||||
|
|
||||||
// Can't move the same checker twice
|
|
||||||
state.board.set_positions([
|
|
||||||
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
]);
|
|
||||||
state.dice.values = (2, 1);
|
|
||||||
let moves = (
|
|
||||||
CheckerMove::new(3, 5).unwrap(),
|
|
||||||
CheckerMove::new(3, 4).unwrap(),
|
|
||||||
);
|
|
||||||
assert!(!state.moves_possible(&moves));
|
|
||||||
|
|
||||||
// black moves
|
// black moves
|
||||||
let state = MoveRules::new(&Color::Black, &Board::default(), Dice::default());
|
let state = MoveRules::new(&Color::Black, &Board::default(), Dice::default());
|
||||||
let moves = (
|
let moves = (
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue