diff --git a/bot/scripts/trainValid.sh b/bot/scripts/trainValid.sh index 55424a2..546bc01 100755 --- a/bot/scripts/trainValid.sh +++ b/bot/scripts/trainValid.sh @@ -17,7 +17,7 @@ train() { } plot() { - NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do @@ -31,19 +31,8 @@ plot() { feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT" } -avg() { - NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1) - LOGS="$LOGS_DIR/$NAME" - echo $LOGS - tail -n +$((CFG_SIZE + 2)) "$LOGS" | - grep -v "info:" | - awk -F '[ ,]' '{print $5}' | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }' -} - if [ "$1" = "plot" ]; then plot -elif [ "$1" = "avg" ]; then - avg else train fi diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index 82ca118..e634200 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -386,8 +386,6 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage - let mut calculate_points = false; - let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -395,7 +393,6 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - calculate_points = true; GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -404,6 +401,7 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { + let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -412,9 +410,12 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + GameEvent::Mark { player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).0, + points, } } TurnStage::MarkAdvPoints => { @@ -427,10 +428,11 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let points = points_rules.get_points(dice_roll_count).1; // pas de reward : déjà comptabilisé lors du tour de blanc GameEvent::Mark { player_id: self.opponent_id, - points: points_rules.get_points(dice_roll_count).1, + points, } } TurnStage::HoldOrGoChoice => { @@ -447,19 +449,6 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); - if calculate_points { - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - // Récompense proportionnelle aux points - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; - } } } reward diff --git a/bot/src/dqn/burnrl_valid/environment.rs b/bot/src/dqn/burnrl_valid/environment.rs index 08e65f7..200aa49 100644 --- a/bot/src/dqn/burnrl_valid/environment.rs +++ b/bot/src/dqn/burnrl_valid/environment.rs @@ -156,26 +156,17 @@ impl Environment for TrictracEnvironment { if self.game.active_player_id == self.active_player_id { if let Some(action) = trictrac_action { (reward, is_rollpoint) = self.execute_action(action); - // if reward != 0.0 { - // println!("info: self rew {reward}"); - // } if is_rollpoint { self.pointrolls_count += 1; } } else { // Action non convertible, pénalité - println!("info: action non convertible -> -1 {trictrac_action:?}"); reward = -1.0; } } // Faire jouer l'adversaire (stratégie simple) while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { - // let op_rew = self.play_opponent_if_needed(); - // if op_rew != 0.0 { - // println!("info: op rew {op_rew}"); - // } - // reward += op_rew; reward += self.play_opponent_if_needed(); } @@ -331,7 +322,6 @@ impl TrictracEnvironment { // Pénalité pour action invalide // on annule les précédents reward // et on indique une valeur reconnaissable pour statistiques - println!("info: action invalide -> err_reward"); reward = Self::ERROR_REWARD; } } @@ -356,8 +346,6 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage - let mut calculate_points = false; - let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -365,7 +353,6 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); - calculate_points = true; GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -374,6 +361,7 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { + let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -383,12 +371,15 @@ impl TrictracEnvironment { let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points + GameEvent::Mark { player_id: self.opponent_id, points, } } TurnStage::MarkAdvPoints => { + let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -418,19 +409,6 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); - if calculate_points { - let dice_roll_count = self - .game - .players - .get(&self.opponent_id) - .unwrap() - .dice_roll_count; - let points_rules = - PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; - // Récompense proportionnelle aux points - } } } reward diff --git a/doc/store.puml b/doc/store.puml deleted file mode 100644 index dd90df5..0000000 --- a/doc/store.puml +++ /dev/null @@ -1,172 +0,0 @@ -@startuml - -class "CheckerMove" { - - from: Field - - to: Field - + to_display_string() - + new(from: Field, to: Field) - + mirror() - + chain(cmove: Self) - + get_from() - + get_to() - + is_exit() - + doable_with_dice(dice: usize) -} - -class "Board" { - - positions: [i8;24] - + new() - + mirror() - + set_positions(positions: [ i8 ; 24 ]) - + count_checkers(color: Color, from: Field, to: Field) - + to_vec() - + to_gnupg_pos_id() - + to_display_grid(col_size: usize) - + set(color: & Color, field: Field, amount: i8) - + blocked(color: & Color, field: Field) - + passage_blocked(color: & Color, field: Field) - + get_field_checkers(field: Field) - + get_checkers_color(field: Field) - + is_field_in_small_jan(field: Field) - + get_color_fields(color: Color) - + get_color_corner(color: & Color) - + get_possible_moves(color: Color, dice: u8, with_excedants: bool, check_rest_corner_exit: bool, forbid_exits: bool) - + passage_possible(color: & Color, cmove: & CheckerMove) - + move_possible(color: & Color, cmove: & CheckerMove) - + any_quarter_filled(color: Color) - + is_quarter_filled(color: Color, field: Field) - + get_quarter_filling_candidate(color: Color) - + is_quarter_fillable(color: Color, field: Field) - - get_quarter_fields(field: Field) - + move_checker(color: & Color, cmove: CheckerMove) - + remove_checker(color: & Color, field: Field) - + add_checker(color: & Color, field: Field) -} - -class "MoveRules" { - + board: Board - + dice: Dice - + new(color: & Color, board: & Board, dice: Dice) - + set_board(color: & Color, board: & Board) - - get_board_from_color(color: & Color, board: & Board) - + moves_follow_rules(moves: & ( CheckerMove , CheckerMove )) - - moves_possible(moves: & ( CheckerMove , CheckerMove )) - - moves_follows_dices(moves: & ( CheckerMove , CheckerMove )) - - get_move_compatible_dices(cmove: & CheckerMove) - + moves_allowed(moves: & ( CheckerMove , CheckerMove )) - - check_opponent_can_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove )) - - check_must_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove )) - - check_corner_rules(moves: & ( CheckerMove , CheckerMove )) - - has_checkers_outside_last_quarter() - - check_exit_rules(moves: & ( CheckerMove , CheckerMove )) - + get_possible_moves_sequences(with_excedents: bool, ignored_rules: Vec < TricTracRule >) - + get_scoring_quarter_filling_moves_sequences() - - get_sequence_origin_from_destination(sequence: ( CheckerMove , CheckerMove ), destination: Field) - + get_quarter_filling_moves_sequences() - - get_possible_moves_sequences_by_dices(dice1: u8, dice2: u8, with_excedents: bool, ignore_empty: bool, ignored_rules: Vec < TricTracRule >) - - _get_direct_exit_moves(state: & GameState) - - is_move_by_puissance(moves: & ( CheckerMove , CheckerMove )) - - can_take_corner_by_effect() -} - -class "DiceRoller" { - - rng: StdRng - + new(opt_seed: Option < u64 >) - + roll() -} - -class "Dice" { - + values: (u8,u8) - + to_bits_string() - + to_display_string() - + is_double() -} - -class "GameState" { - + stage: Stage - + turn_stage: TurnStage - + board: Board - + active_player_id: PlayerId - + players: HashMap - + history: Vec - + dice: Dice - + dice_points: (u8,u8) - + dice_moves: (CheckerMove,CheckerMove) - + dice_jans: PossibleJans - - roll_first: bool - + schools_enabled: bool - + new(schools_enabled: bool) - - set_schools_enabled(schools_enabled: bool) - - get_active_player() - - get_opponent_id() - + to_vec_float() - + to_vec() - + to_string_id() - + who_plays() - + get_white_player() - + get_black_player() - + player_id_by_color(color: Color) - + player_id(player: & Player) - + player_color_by_id(player_id: & PlayerId) - + validate(event: & GameEvent) - + init_player(player_name: & str) - - add_player(player_id: PlayerId, player: Player) - + switch_active_player() - + consume(valid_event: & GameEvent) - - new_pick_up() - - get_rollresult_jans(dice: & Dice) - + determine_winner() - - inc_roll_count(player_id: PlayerId) - - mark_points(player_id: PlayerId, points: u8) -} - -class "Player" { - + name: String - + color: Color - + points: u8 - + holes: u8 - + can_bredouille: bool - + can_big_bredouille: bool - + dice_roll_count: u8 - + new(name: String, color: Color) - + to_bits_string() - + to_vec() -} - -class "PointsRules" { - + board: Board - + dice: Dice - + move_rules: MoveRules - + new(color: & Color, board: & Board, dice: Dice) - + set_dice(dice: Dice) - + update_positions(positions: [ i8 ; 24 ]) - - get_jans(board_ini: & Board, dice_rolls_count: u8) - + get_jans_points(jans: HashMap < Jan , Vec < ( CheckerMove , CheckerMove ) > >) - + get_points(dice_rolls_count: u8) - + get_result_jans(dice_rolls_count: u8) -} - - - - -"MoveRules" <-- "Board" -"MoveRules" <-- "Dice" - - - - - - -"GameState" <-- "Board" -"HashMap" <-- "Player" -"GameState" <-- "HashMap" -"GameState" <-- "Dice" - - - - -"PointsRules" <-- "Board" -"PointsRules" <-- "Dice" -"PointsRules" <-- "MoveRules" - -@enduml