diff --git a/bot/scripts/trainValid.sh b/bot/scripts/trainValid.sh index 546bc01..55424a2 100755 --- a/bot/scripts/trainValid.sh +++ b/bot/scripts/trainValid.sh @@ -17,7 +17,7 @@ train() { } plot() { - NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do @@ -31,8 +31,19 @@ plot() { feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT" } +avg() { + NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1) + LOGS="$LOGS_DIR/$NAME" + echo $LOGS + tail -n +$((CFG_SIZE + 2)) "$LOGS" | + grep -v "info:" | + awk -F '[ ,]' '{print $5}' | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }' +} + if [ "$1" = "plot" ]; then plot +elif [ "$1" = "avg" ]; then + avg else train fi diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index e634200..82ca118 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -386,6 +386,8 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -393,6 +395,7 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + calculate_points = true; GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -401,7 +404,6 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { - let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -410,12 +412,9 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points - GameEvent::Mark { player_id: self.opponent_id, - points, + points: points_rules.get_points(dice_roll_count).0, } } TurnStage::MarkAdvPoints => { @@ -428,11 +427,10 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let points = points_rules.get_points(dice_roll_count).1; // pas de reward : déjà comptabilisé lors du tour de blanc GameEvent::Mark { player_id: self.opponent_id, - points, + points: points_rules.get_points(dice_roll_count).1, } } TurnStage::HoldOrGoChoice => { @@ -449,6 +447,19 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } } } reward diff --git a/bot/src/dqn/burnrl_valid/environment.rs b/bot/src/dqn/burnrl_valid/environment.rs index 200aa49..08e65f7 100644 --- a/bot/src/dqn/burnrl_valid/environment.rs +++ b/bot/src/dqn/burnrl_valid/environment.rs @@ -156,17 +156,26 @@ impl Environment for TrictracEnvironment { if self.game.active_player_id == self.active_player_id { if let Some(action) = trictrac_action { (reward, is_rollpoint) = self.execute_action(action); + // if reward != 0.0 { + // println!("info: self rew {reward}"); + // } if is_rollpoint { self.pointrolls_count += 1; } } else { // Action non convertible, pénalité + println!("info: action non convertible -> -1 {trictrac_action:?}"); reward = -1.0; } } // Faire jouer l'adversaire (stratégie simple) while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // let op_rew = self.play_opponent_if_needed(); + // if op_rew != 0.0 { + // println!("info: op rew {op_rew}"); + // } + // reward += op_rew; reward += self.play_opponent_if_needed(); } @@ -322,6 +331,7 @@ impl TrictracEnvironment { // Pénalité pour action invalide // on annule les précédents reward // et on indique une valeur reconnaissable pour statistiques + println!("info: action invalide -> err_reward"); reward = Self::ERROR_REWARD; } } @@ -346,6 +356,8 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -353,6 +365,7 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + calculate_points = true; GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -361,7 +374,6 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { - let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -371,15 +383,12 @@ impl TrictracEnvironment { let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points - GameEvent::Mark { player_id: self.opponent_id, points, } } TurnStage::MarkAdvPoints => { - let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -409,6 +418,19 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + // Récompense proportionnelle aux points + } } } reward