bot train burnrl reward opponent
This commit is contained in:
parent
883ebf9bc1
commit
1e773671d9
|
|
@ -6,10 +6,10 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
const ERROR_REWARD: f32 = -1.12121;
|
const ERROR_REWARD: f32 = -1.0012121;
|
||||||
const REWARD_VALID_MOVE: f32 = 1.12121;
|
const REWARD_VALID_MOVE: f32 = 1.0012121;
|
||||||
const REWARD_RATIO: f32 = 0.01;
|
const REWARD_RATIO: f32 = 0.1;
|
||||||
const WIN_POINTS: f32 = 1.0;
|
const WIN_POINTS: f32 = 100.0;
|
||||||
|
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
|
@ -285,7 +285,7 @@ impl TrictracEnvironment {
|
||||||
if let Some(event) = action.to_event(&self.game) {
|
if let Some(event) = action.to_event(&self.game) {
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
reward += REWARD_VALID_MOVE;
|
// reward += REWARD_VALID_MOVE;
|
||||||
// Simuler le résultat des dés après un Roll
|
// Simuler le résultat des dés après un Roll
|
||||||
if matches!(action, TrictracAction::Roll) {
|
if matches!(action, TrictracAction::Roll) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
@ -312,9 +312,11 @@ impl TrictracEnvironment {
|
||||||
// on annule les précédents reward
|
// on annule les précédents reward
|
||||||
// et on indique une valeur reconnaissable pour statistiques
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
reward = ERROR_REWARD;
|
reward = ERROR_REWARD;
|
||||||
|
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
reward = ERROR_REWARD;
|
reward = ERROR_REWARD;
|
||||||
|
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
(reward, is_rollpoint)
|
(reward, is_rollpoint)
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,10 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
const ERROR_REWARD: f32 = -2.12121;
|
const ERROR_REWARD: f32 = -1.00012121;
|
||||||
const REWARD_VALID_MOVE: f32 = 2.12121;
|
const REWARD_VALID_MOVE: f32 = 1.00012121;
|
||||||
const REWARD_RATIO: f32 = 0.01;
|
const REWARD_RATIO: f32 = 0.1;
|
||||||
const WIN_POINTS: f32 = 0.1;
|
const WIN_POINTS: f32 = 100.0;
|
||||||
|
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
|
@ -352,6 +352,7 @@ impl TrictracEnvironment {
|
||||||
// on annule les précédents reward
|
// on annule les précédents reward
|
||||||
// et on indique une valeur reconnaissable pour statistiques
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
reward = ERROR_REWARD;
|
reward = ERROR_REWARD;
|
||||||
|
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
/// training_common_big.rs : environnement avec espace d'actions optimisé
|
||||||
|
/// (514 au lieu de 1252 pour training_common_big.rs)
|
||||||
use std::cmp::{max, min};
|
use std::cmp::{max, min};
|
||||||
use std::fmt::{Debug, Display, Formatter};
|
use std::fmt::{Debug, Display, Formatter};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
/// training_common_big.rs : environnement avec espace d'actions non optimisé
|
||||||
|
/// (1252 au lieu de 514 pour training_common.rs)
|
||||||
use std::cmp::{max, min};
|
use std::cmp::{max, min};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
|
||||||
|
|
@ -742,6 +742,10 @@ impl GameState {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn mark_points_for_bot_training(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||||
|
self.mark_points(player_id, points)
|
||||||
|
}
|
||||||
|
|
||||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||||
// Update player points and holes
|
// Update player points and holes
|
||||||
let mut new_hole = false;
|
let mut new_hole = false;
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue