Compare commits
2 commits
2e0a874879
...
28c2aa836f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
28c2aa836f | ||
|
|
ad5ae17168 |
|
|
@ -7,7 +7,7 @@ edition = "2021"
|
|||
|
||||
[[bin]]
|
||||
name = "train_dqn_burn"
|
||||
path = "src/burnrl/main.rs"
|
||||
path = "src/dqn/burnrl/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn"
|
||||
|
|
|
|||
|
|
@ -58,17 +58,35 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
|||
}
|
||||
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 4096;
|
||||
const DENSE_SIZE: usize = 128;
|
||||
const EPS_DECAY: f64 = 1000.0;
|
||||
const EPS_START: f64 = 0.9;
|
||||
const EPS_END: f64 = 0.05;
|
||||
const MEMORY_SIZE: usize = 8192;
|
||||
|
||||
pub struct DqnConfig {
|
||||
pub num_episodes: usize,
|
||||
// pub memory_size: usize,
|
||||
pub dense_size: usize,
|
||||
pub eps_start: f64,
|
||||
pub eps_end: f64,
|
||||
pub eps_decay: f64,
|
||||
}
|
||||
|
||||
impl Default for DqnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_episodes: 1000,
|
||||
// memory_size: 8192,
|
||||
dense_size: 256,
|
||||
eps_start: 0.9,
|
||||
eps_end: 0.05,
|
||||
eps_decay: 1000.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<E: Environment, B: AutodiffBackend>(
|
||||
num_episodes: usize,
|
||||
conf: &DqnConfig,
|
||||
visualized: bool,
|
||||
) -> DQN<E, B, Net<B>> {
|
||||
// ) -> impl Agent<E> {
|
||||
|
|
@ -76,7 +94,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
|
||||
let model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
DENSE_SIZE,
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
|
||||
|
|
@ -94,7 +112,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
|
||||
let mut step = 0_usize;
|
||||
|
||||
for episode in 0..num_episodes {
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward: ElemType = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
|
|
@ -102,8 +120,8 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
let mut now = SystemTime::now();
|
||||
|
||||
while !episode_done {
|
||||
let eps_threshold =
|
||||
EPS_END + (EPS_START - EPS_END) * f64::exp(-(step as f64) / EPS_DECAY);
|
||||
let eps_threshold = conf.eps_end
|
||||
+ (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
|
||||
let action =
|
||||
DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
|
||||
let snapshot = env.step(action);
|
||||
|
|
|
|||
|
|
@ -91,8 +91,7 @@ impl Environment for TrictracEnvironment {
|
|||
type ActionType = TrictracAction;
|
||||
type RewardType = f32;
|
||||
|
||||
const MAX_STEPS: usize = 1000; // Limite max pour éviter les parties infinies
|
||||
// const MAX_STEPS: usize = 5; // Limite max pour éviter les parties infinies
|
||||
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
|
||||
|
||||
fn new(visualized: bool) -> Self {
|
||||
let mut game = GameState::new(false);
|
||||
|
|
@ -260,7 +259,7 @@ impl TrictracEnvironment {
|
|||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
reward += 0.2;
|
||||
reward += 0.4;
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
|
|
@ -289,7 +288,7 @@ impl TrictracEnvironment {
|
|||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
reward += 0.2;
|
||||
reward += 0.4;
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use bot::burnrl::{dqn_model, environment, utils::demo_model};
|
||||
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
|
||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
|
|
@ -10,8 +10,16 @@ type Env = environment::TrictracEnvironment;
|
|||
|
||||
fn main() {
|
||||
println!("> Entraînement");
|
||||
let num_episodes = 50;
|
||||
let agent = dqn_model::run::<Env, Backend>(num_episodes, false); //true);
|
||||
let conf = dqn_model::DqnConfig {
|
||||
num_episodes: 50,
|
||||
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
|
||||
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
|
||||
dense_size: 256, // neural network complexity
|
||||
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
|
||||
eps_end: 0.05,
|
||||
eps_decay: 1000.0,
|
||||
};
|
||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||
|
||||
let valid_agent = agent.valid();
|
||||
|
||||
|
|
@ -24,7 +32,7 @@ fn main() {
|
|||
// demo_model::<Env>(valid_agent);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = load_model(&path);
|
||||
let loaded_model = load_model(conf.dense_size, &path);
|
||||
let loaded_agent = DQN::new(loaded_model);
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
|
|
@ -40,10 +48,7 @@ fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||
// TODO : reprendre le DENSE_SIZE de dqn_model.rs
|
||||
const DENSE_SIZE: usize = 128;
|
||||
|
||||
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
||||
let model_path = format!("{}_model.mpk", path);
|
||||
println!("Chargement du modèle depuis : {}", model_path);
|
||||
|
||||
|
|
@ -56,7 +61,7 @@ fn load_model(path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
|||
|
||||
dqn_model::Net::new(
|
||||
<environment::TrictracEnvironment as Environment>::StateType::size(),
|
||||
DENSE_SIZE,
|
||||
dense_size,
|
||||
<environment::TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
|
|
|
|||
|
|
@ -93,6 +93,18 @@ impl MoveRules {
|
|||
/// ---- moves_possibles : First of three checks for moves
|
||||
fn moves_possible(&self, moves: &(CheckerMove, CheckerMove)) -> bool {
|
||||
let color = &Color::White;
|
||||
|
||||
let move0_from = moves.0.get_from();
|
||||
if 0 < move0_from && move0_from == moves.1.get_from() {
|
||||
if let Ok((field_count, Some(field_color))) = self.board.get_field_checkers(move0_from)
|
||||
{
|
||||
if color != field_color || field_count < 2 {
|
||||
info!("Move not physically possible");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(chained_move) = moves.0.chain(moves.1) {
|
||||
// Check intermediary move and chained_move : "Tout d'une"
|
||||
if !self.board.passage_possible(color, &moves.0)
|
||||
|
|
@ -1005,7 +1017,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn moves_possible() {
|
||||
let state = MoveRules::default();
|
||||
let mut state = MoveRules::default();
|
||||
|
||||
// Chained moves
|
||||
let moves = (
|
||||
|
|
@ -1021,6 +1033,17 @@ mod tests {
|
|||
);
|
||||
assert!(!state.moves_possible(&moves));
|
||||
|
||||
// Can't move the same checker twice
|
||||
state.board.set_positions([
|
||||
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
]);
|
||||
state.dice.values = (2, 1);
|
||||
let moves = (
|
||||
CheckerMove::new(3, 5).unwrap(),
|
||||
CheckerMove::new(3, 4).unwrap(),
|
||||
);
|
||||
assert!(!state.moves_possible(&moves));
|
||||
|
||||
// black moves
|
||||
let state = MoveRules::new(&Color::Black, &Board::default(), Dice::default());
|
||||
let moves = (
|
||||
|
|
|
|||
Loading…
Reference in a new issue