diff --git a/bot/scripts/train.sh b/bot/scripts/train.sh index 78e7e3f..9e54c7a 100755 --- a/bot/scripts/train.sh +++ b/bot/scripts/train.sh @@ -17,7 +17,7 @@ train() { } plot() { - NAME=$(ls "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do diff --git a/bot/scripts/trainValid.sh b/bot/scripts/trainValid.sh index 349517d..546bc01 100755 --- a/bot/scripts/trainValid.sh +++ b/bot/scripts/trainValid.sh @@ -17,7 +17,7 @@ train() { } plot() { - NAME=$(ls "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do diff --git a/bot/src/dqn/burnrl/main.rs b/bot/src/dqn/burnrl/main.rs index d8b200f..097a27b 100644 --- a/bot/src/dqn/burnrl/main.rs +++ b/bot/src/dqn/burnrl/main.rs @@ -17,7 +17,7 @@ fn main() { // defaults num_episodes: 40, // 40 min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction) - max_steps: 3000, // 1000 max steps by episode + max_steps: 1000, // 1000 max steps by episode dense_size: 256, // 128 neural network complexity (default 128) eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration) eps_end: 0.05, // 0.05 diff --git a/bot/src/dqn/dqn_common.rs b/bot/src/dqn/dqn_common.rs index a5661a0..d3e3c4e 100644 --- a/bot/src/dqn/dqn_common.rs +++ b/bot/src/dqn/dqn_common.rs @@ -71,7 +71,7 @@ impl TrictracAction { encoded -= 256 } let checker1 = encoded / 16; - let checker2 = 1 + encoded % 16; + let checker2 = encoded % 16; (dice_order, checker1, checker2) } @@ -251,7 +251,7 @@ mod tests { }; let index = action.to_action_index(); assert_eq!(Some(action), TrictracAction::from_action_index(index)); - assert_eq!(81, index); + assert_eq!(54, index); } #[test] @@ -261,6 +261,6 @@ mod tests { checker1: 3, checker2: 4, }; - assert_eq!(Some(action), TrictracAction::from_action_index(81)); + assert_eq!(Some(action), TrictracAction::from_action_index(54)); } } diff --git a/bot/src/strategy/dqnburn.rs b/bot/src/strategy/dqnburn.rs index 6532adb..2b37e88 100644 --- a/bot/src/strategy/dqnburn.rs +++ b/bot/src/strategy/dqnburn.rs @@ -128,6 +128,7 @@ impl BotStrategy for DqnBurnStrategy { (dicevals.1, dicevals.0) }; + assert_eq!(self.color, Color::White); let from1 = self .game .board @@ -138,14 +139,16 @@ impl BotStrategy for DqnBurnStrategy { // empty move dice1 = 0; } - let mut to1 = if self.color == Color::White { - from1 + dice1 as usize + let mut to1 = from1; + if self.color == Color::White { + to1 += dice1 as usize; + if 24 < to1 { + // sortie + to1 = 0; + } } else { - from1 - dice1 as usize - }; - if 24 < to1 || to1 < 0 { - // sortie - to1 = 0; + let fto1 = to1 as i16 - dice1 as i16; + to1 = if fto1 < 0 { 0 } else { fto1 as usize }; } let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default(); @@ -159,17 +162,28 @@ impl BotStrategy for DqnBurnStrategy { // empty move dice2 = 0; } - let mut to2 = from2 + dice2 as usize; - if 24 < to2 { - // sortie - to2 = 0; + let mut to2 = from2; + if self.color == Color::White { + to2 += dice2 as usize; + if 24 < to2 { + // sortie + to2 = 0; + } + } else { + let fto2 = to2 as i16 - dice2 as i16; + to2 = if fto2 < 0 { 0 } else { fto2 as usize }; } // Gestion prise de coin par puissance - let opp_rest_field = 13; + let opp_rest_field = if self.color == Color::White { 13 } else { 12 }; if to1 == opp_rest_field && to2 == opp_rest_field { - to1 -= 1; - to2 -= 1; + if self.color == Color::White { + to1 -= 1; + to2 -= 1; + } else { + to1 += 1; + to2 += 1; + } } let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default(); @@ -178,6 +192,7 @@ impl BotStrategy for DqnBurnStrategy { let chosen_move = if self.color == Color::White { (checker_move1, checker_move2) } else { + // XXX : really ? (checker_move1.mirror(), checker_move2.mirror()) }; diff --git a/justfile b/justfile index c35d494..ffa3229 100644 --- a/justfile +++ b/justfile @@ -28,9 +28,10 @@ trainsimple: trainbot: #python ./store/python/trainModel.py # cargo run --bin=train_dqn # ok - ./bot/scripts/trainValid.sh + # ./bot/scripts/trainValid.sh + ./bot/scripts/train.sh plottrainbot: - ./bot/scripts/trainValid.sh plot + ./bot/scripts/train.sh plot debugtrainbot: cargo build --bin=train_dqn_burn RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn