diff --git a/bot/scripts/trainValid.sh b/bot/scripts/trainValid.sh index 546bc01..55424a2 100755 --- a/bot/scripts/trainValid.sh +++ b/bot/scripts/trainValid.sh @@ -17,7 +17,7 @@ train() { } plot() { - NAME=$(ls -rt "$LOGS_DIR" | tail -n 1) + NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1) LOGS="$LOGS_DIR/$NAME" cfgs=$(head -n $CFG_SIZE "$LOGS") for cfg in $cfgs; do @@ -31,8 +31,19 @@ plot() { feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT" } +avg() { + NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1) + LOGS="$LOGS_DIR/$NAME" + echo $LOGS + tail -n +$((CFG_SIZE + 2)) "$LOGS" | + grep -v "info:" | + awk -F '[ ,]' '{print $5}' | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }' +} + if [ "$1" = "plot" ]; then plot +elif [ "$1" = "avg" ]; then + avg else train fi diff --git a/bot/src/dqn/burnrl/environment.rs b/bot/src/dqn/burnrl/environment.rs index e634200..82ca118 100644 --- a/bot/src/dqn/burnrl/environment.rs +++ b/bot/src/dqn/burnrl/environment.rs @@ -386,6 +386,8 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -393,6 +395,7 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + calculate_points = true; GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -401,7 +404,6 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { - let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -410,12 +412,9 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points - GameEvent::Mark { player_id: self.opponent_id, - points, + points: points_rules.get_points(dice_roll_count).0, } } TurnStage::MarkAdvPoints => { @@ -428,11 +427,10 @@ impl TrictracEnvironment { .dice_roll_count; let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); - let points = points_rules.get_points(dice_roll_count).1; // pas de reward : déjà comptabilisé lors du tour de blanc GameEvent::Mark { player_id: self.opponent_id, - points, + points: points_rules.get_points(dice_roll_count).1, } } TurnStage::HoldOrGoChoice => { @@ -449,6 +447,19 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + // Récompense proportionnelle aux points + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + } } } reward diff --git a/bot/src/dqn/burnrl_valid/environment.rs b/bot/src/dqn/burnrl_valid/environment.rs index 200aa49..08e65f7 100644 --- a/bot/src/dqn/burnrl_valid/environment.rs +++ b/bot/src/dqn/burnrl_valid/environment.rs @@ -156,17 +156,26 @@ impl Environment for TrictracEnvironment { if self.game.active_player_id == self.active_player_id { if let Some(action) = trictrac_action { (reward, is_rollpoint) = self.execute_action(action); + // if reward != 0.0 { + // println!("info: self rew {reward}"); + // } if is_rollpoint { self.pointrolls_count += 1; } } else { // Action non convertible, pénalité + println!("info: action non convertible -> -1 {trictrac_action:?}"); reward = -1.0; } } // Faire jouer l'adversaire (stratégie simple) while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { + // let op_rew = self.play_opponent_if_needed(); + // if op_rew != 0.0 { + // println!("info: op rew {op_rew}"); + // } + // reward += op_rew; reward += self.play_opponent_if_needed(); } @@ -322,6 +331,7 @@ impl TrictracEnvironment { // Pénalité pour action invalide // on annule les précédents reward // et on indique une valeur reconnaissable pour statistiques + println!("info: action invalide -> err_reward"); reward = Self::ERROR_REWARD; } } @@ -346,6 +356,8 @@ impl TrictracEnvironment { *strategy.get_mut_game() = self.game.clone(); // Exécuter l'action selon le turn_stage + let mut calculate_points = false; + let opponent_color = store::Color::Black; let event = match self.game.turn_stage { TurnStage::RollDice => GameEvent::Roll { player_id: self.opponent_id, @@ -353,6 +365,7 @@ impl TrictracEnvironment { TurnStage::RollWaiting => { let mut rng = thread_rng(); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); + calculate_points = true; GameEvent::RollResult { player_id: self.opponent_id, dice: store::Dice { @@ -361,7 +374,6 @@ impl TrictracEnvironment { } } TurnStage::MarkPoints => { - let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -371,15 +383,12 @@ impl TrictracEnvironment { let points_rules = PointsRules::new(&opponent_color, &self.game.board, self.game.dice); let (points, adv_points) = points_rules.get_points(dice_roll_count); - reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points - GameEvent::Mark { player_id: self.opponent_id, points, } } TurnStage::MarkAdvPoints => { - let opponent_color = store::Color::Black; let dice_roll_count = self .game .players @@ -409,6 +418,19 @@ impl TrictracEnvironment { if self.game.validate(&event) { self.game.consume(&event); + if calculate_points { + let dice_roll_count = self + .game + .players + .get(&self.opponent_id) + .unwrap() + .dice_roll_count; + let points_rules = + PointsRules::new(&opponent_color, &self.game.board, self.game.dice); + let (points, adv_points) = points_rules.get_points(dice_roll_count); + reward -= Self::REWARD_RATIO * (points - adv_points) as f32; + // Récompense proportionnelle aux points + } } } reward diff --git a/doc/store.puml b/doc/store.puml new file mode 100644 index 0000000..dd90df5 --- /dev/null +++ b/doc/store.puml @@ -0,0 +1,172 @@ +@startuml + +class "CheckerMove" { + - from: Field + - to: Field + + to_display_string() + + new(from: Field, to: Field) + + mirror() + + chain(cmove: Self) + + get_from() + + get_to() + + is_exit() + + doable_with_dice(dice: usize) +} + +class "Board" { + - positions: [i8;24] + + new() + + mirror() + + set_positions(positions: [ i8 ; 24 ]) + + count_checkers(color: Color, from: Field, to: Field) + + to_vec() + + to_gnupg_pos_id() + + to_display_grid(col_size: usize) + + set(color: & Color, field: Field, amount: i8) + + blocked(color: & Color, field: Field) + + passage_blocked(color: & Color, field: Field) + + get_field_checkers(field: Field) + + get_checkers_color(field: Field) + + is_field_in_small_jan(field: Field) + + get_color_fields(color: Color) + + get_color_corner(color: & Color) + + get_possible_moves(color: Color, dice: u8, with_excedants: bool, check_rest_corner_exit: bool, forbid_exits: bool) + + passage_possible(color: & Color, cmove: & CheckerMove) + + move_possible(color: & Color, cmove: & CheckerMove) + + any_quarter_filled(color: Color) + + is_quarter_filled(color: Color, field: Field) + + get_quarter_filling_candidate(color: Color) + + is_quarter_fillable(color: Color, field: Field) + - get_quarter_fields(field: Field) + + move_checker(color: & Color, cmove: CheckerMove) + + remove_checker(color: & Color, field: Field) + + add_checker(color: & Color, field: Field) +} + +class "MoveRules" { + + board: Board + + dice: Dice + + new(color: & Color, board: & Board, dice: Dice) + + set_board(color: & Color, board: & Board) + - get_board_from_color(color: & Color, board: & Board) + + moves_follow_rules(moves: & ( CheckerMove , CheckerMove )) + - moves_possible(moves: & ( CheckerMove , CheckerMove )) + - moves_follows_dices(moves: & ( CheckerMove , CheckerMove )) + - get_move_compatible_dices(cmove: & CheckerMove) + + moves_allowed(moves: & ( CheckerMove , CheckerMove )) + - check_opponent_can_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove )) + - check_must_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove )) + - check_corner_rules(moves: & ( CheckerMove , CheckerMove )) + - has_checkers_outside_last_quarter() + - check_exit_rules(moves: & ( CheckerMove , CheckerMove )) + + get_possible_moves_sequences(with_excedents: bool, ignored_rules: Vec < TricTracRule >) + + get_scoring_quarter_filling_moves_sequences() + - get_sequence_origin_from_destination(sequence: ( CheckerMove , CheckerMove ), destination: Field) + + get_quarter_filling_moves_sequences() + - get_possible_moves_sequences_by_dices(dice1: u8, dice2: u8, with_excedents: bool, ignore_empty: bool, ignored_rules: Vec < TricTracRule >) + - _get_direct_exit_moves(state: & GameState) + - is_move_by_puissance(moves: & ( CheckerMove , CheckerMove )) + - can_take_corner_by_effect() +} + +class "DiceRoller" { + - rng: StdRng + + new(opt_seed: Option < u64 >) + + roll() +} + +class "Dice" { + + values: (u8,u8) + + to_bits_string() + + to_display_string() + + is_double() +} + +class "GameState" { + + stage: Stage + + turn_stage: TurnStage + + board: Board + + active_player_id: PlayerId + + players: HashMap + + history: Vec + + dice: Dice + + dice_points: (u8,u8) + + dice_moves: (CheckerMove,CheckerMove) + + dice_jans: PossibleJans + - roll_first: bool + + schools_enabled: bool + + new(schools_enabled: bool) + - set_schools_enabled(schools_enabled: bool) + - get_active_player() + - get_opponent_id() + + to_vec_float() + + to_vec() + + to_string_id() + + who_plays() + + get_white_player() + + get_black_player() + + player_id_by_color(color: Color) + + player_id(player: & Player) + + player_color_by_id(player_id: & PlayerId) + + validate(event: & GameEvent) + + init_player(player_name: & str) + - add_player(player_id: PlayerId, player: Player) + + switch_active_player() + + consume(valid_event: & GameEvent) + - new_pick_up() + - get_rollresult_jans(dice: & Dice) + + determine_winner() + - inc_roll_count(player_id: PlayerId) + - mark_points(player_id: PlayerId, points: u8) +} + +class "Player" { + + name: String + + color: Color + + points: u8 + + holes: u8 + + can_bredouille: bool + + can_big_bredouille: bool + + dice_roll_count: u8 + + new(name: String, color: Color) + + to_bits_string() + + to_vec() +} + +class "PointsRules" { + + board: Board + + dice: Dice + + move_rules: MoveRules + + new(color: & Color, board: & Board, dice: Dice) + + set_dice(dice: Dice) + + update_positions(positions: [ i8 ; 24 ]) + - get_jans(board_ini: & Board, dice_rolls_count: u8) + + get_jans_points(jans: HashMap < Jan , Vec < ( CheckerMove , CheckerMove ) > >) + + get_points(dice_rolls_count: u8) + + get_result_jans(dice_rolls_count: u8) +} + + + + +"MoveRules" <-- "Board" +"MoveRules" <-- "Dice" + + + + + + +"GameState" <-- "Board" +"HashMap" <-- "Player" +"GameState" <-- "HashMap" +"GameState" <-- "Dice" + + + + +"PointsRules" <-- "Board" +"PointsRules" <-- "Dice" +"PointsRules" <-- "MoveRules" + +@enduml