runcli with bot dqn burn-rl
This commit is contained in:
parent
a19c5d8596
commit
17d29b8633
|
|
@ -141,7 +141,7 @@ impl Environment for TrictracEnvironment {
|
||||||
self.step_count += 1;
|
self.step_count += 1;
|
||||||
|
|
||||||
// Convertir l'action burn-rl vers une action Trictrac
|
// Convertir l'action burn-rl vers une action Trictrac
|
||||||
let trictrac_action = self.convert_action(action, &self.game);
|
let trictrac_action = Self::convert_action(action);
|
||||||
|
|
||||||
let mut reward = 0.0;
|
let mut reward = 0.0;
|
||||||
let mut terminated = false;
|
let mut terminated = false;
|
||||||
|
|
@ -203,11 +203,7 @@ impl Environment for TrictracEnvironment {
|
||||||
|
|
||||||
impl TrictracEnvironment {
|
impl TrictracEnvironment {
|
||||||
/// Convertit une action burn-rl vers une action Trictrac
|
/// Convertit une action burn-rl vers une action Trictrac
|
||||||
fn convert_action(
|
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||||
&self,
|
|
||||||
action: TrictracAction,
|
|
||||||
game_state: &GameState,
|
|
||||||
) -> Option<dqn_common::TrictracAction> {
|
|
||||||
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ fn main() {
|
||||||
|
|
||||||
println!("> Chargement du modèle pour test");
|
println!("> Chargement du modèle pour test");
|
||||||
let loaded_model = load_model(conf.dense_size, &path);
|
let loaded_model = load_model(conf.dense_size, &path);
|
||||||
let loaded_agent = DQN::new(loaded_model);
|
let loaded_agent = DQN::new(loaded_model.unwrap());
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
println!("> Test avec le modèle chargé");
|
||||||
demo_model(loaded_agent);
|
demo_model(loaded_agent);
|
||||||
|
|
|
||||||
|
|
@ -22,23 +22,21 @@ pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
|
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
|
||||||
let model_path = format!("{path}_model.mpk");
|
let model_path = format!("{path}_model.mpk");
|
||||||
println!("Chargement du modèle depuis : {model_path}");
|
// println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
let device = NdArrayDevice::default();
|
CompactRecorder::new()
|
||||||
let recorder = CompactRecorder::new();
|
.load(model_path.into(), &NdArrayDevice::default())
|
||||||
|
.map(|record| {
|
||||||
let record = recorder
|
dqn_model::Net::new(
|
||||||
.load(model_path.into(), &device)
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
.expect("Impossible de charger le modèle");
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
dqn_model::Net::new(
|
)
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
.load_record(record)
|
||||||
dense_size,
|
})
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
.ok()
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ use log::{debug, error};
|
||||||
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
pub use strategy::default::DefaultStrategy;
|
pub use strategy::default::DefaultStrategy;
|
||||||
pub use strategy::dqn::DqnStrategy;
|
pub use strategy::dqn::DqnStrategy;
|
||||||
|
pub use strategy::dqnburn::DqnBurnStrategy;
|
||||||
pub use strategy::erroneous_moves::ErroneousStrategy;
|
pub use strategy::erroneous_moves::ErroneousStrategy;
|
||||||
pub use strategy::random::RandomStrategy;
|
pub use strategy::random::RandomStrategy;
|
||||||
pub use strategy::stable_baselines3::StableBaselines3Strategy;
|
pub use strategy::stable_baselines3::StableBaselines3Strategy;
|
||||||
|
|
|
||||||
176
bot/src/strategy/dqnburn.rs
Normal file
176
bot/src/strategy/dqnburn.rs
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
use burn::backend::NdArray;
|
||||||
|
use burn::tensor::cast::ToElement;
|
||||||
|
use burn_rl::base::{ElemType, Model, State};
|
||||||
|
|
||||||
|
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||||
|
use log::info;
|
||||||
|
use store::MoveRules;
|
||||||
|
|
||||||
|
use crate::dqn::burnrl::{dqn_model, environment, utils};
|
||||||
|
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
|
||||||
|
|
||||||
|
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
|
||||||
|
|
||||||
|
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DqnBurnStrategy {
|
||||||
|
pub game: GameState,
|
||||||
|
pub player_id: PlayerId,
|
||||||
|
pub color: Color,
|
||||||
|
pub model: Option<DqnBurnNetwork>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnBurnStrategy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
game: GameState::default(),
|
||||||
|
player_id: 1,
|
||||||
|
color: Color::White,
|
||||||
|
model: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnBurnStrategy {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_model(model_path: &String) -> Self {
|
||||||
|
info!("Loading model {model_path:?}");
|
||||||
|
let mut strategy = Self::new();
|
||||||
|
strategy.model = utils::load_model(256, model_path);
|
||||||
|
strategy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilise le modèle DQN pour choisir une action valide
|
||||||
|
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||||
|
if let Some(ref model) = self.model {
|
||||||
|
let state = environment::TrictracState::from_game_state(&self.game);
|
||||||
|
let valid_actions_indices = get_valid_action_indices(&self.game);
|
||||||
|
if valid_actions_indices.is_empty() {
|
||||||
|
return None; // No valid actions, end of episode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtenir les Q-values pour toutes les actions
|
||||||
|
let q_values = model.infer(state.to_tensor().unsqueeze());
|
||||||
|
|
||||||
|
// Set non valid actions q-values to lowest
|
||||||
|
let mut masked_q_values = q_values.clone();
|
||||||
|
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||||
|
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||||
|
if !valid_actions_indices.contains(&index) {
|
||||||
|
masked_q_values = masked_q_values.clone().mask_fill(
|
||||||
|
masked_q_values.clone().equal_elem(*q_value),
|
||||||
|
f32::NEG_INFINITY,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get best action (highest q-value)
|
||||||
|
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||||
|
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
|
||||||
|
action_index,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
// Fallback : action aléatoire valide
|
||||||
|
sample_valid_action(&self.game)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BotStrategy for DqnBurnStrategy {
|
||||||
|
fn get_game(&self) -> &GameState {
|
||||||
|
&self.game
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mut_game(&mut self) -> &mut GameState {
|
||||||
|
&mut self.game
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_color(&mut self, color: Color) {
|
||||||
|
self.color = color;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_player_id(&mut self, player_id: PlayerId) {
|
||||||
|
self.player_id = player_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_points(&self) -> u8 {
|
||||||
|
self.game.dice_points.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_adv_points(&self) -> u8 {
|
||||||
|
self.game.dice_points.1
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_go(&self) -> bool {
|
||||||
|
// Utiliser le DQN pour décider si on continue
|
||||||
|
if let Some(action) = self.get_dqn_action() {
|
||||||
|
matches!(action, TrictracAction::Go)
|
||||||
|
} else {
|
||||||
|
// Fallback : toujours continuer
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||||
|
// Utiliser le DQN pour choisir le mouvement
|
||||||
|
if let Some(TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
}) = self.get_dqn_action()
|
||||||
|
{
|
||||||
|
let dicevals = self.game.dice.values;
|
||||||
|
let (mut dice1, mut dice2) = if dice_order {
|
||||||
|
(dicevals.0, dicevals.1)
|
||||||
|
} else {
|
||||||
|
(dicevals.1, dicevals.0)
|
||||||
|
};
|
||||||
|
|
||||||
|
if from1 == 0 {
|
||||||
|
// empty move
|
||||||
|
dice1 = 0;
|
||||||
|
}
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
if 24 < to1 {
|
||||||
|
// sortie
|
||||||
|
to1 = 0;
|
||||||
|
}
|
||||||
|
if from2 == 0 {
|
||||||
|
// empty move
|
||||||
|
dice2 = 0;
|
||||||
|
}
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
if 24 < to2 {
|
||||||
|
// sortie
|
||||||
|
to2 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
let chosen_move = if self.color == Color::White {
|
||||||
|
(checker_move1, checker_move2)
|
||||||
|
} else {
|
||||||
|
(checker_move1.mirror(), checker_move2.mirror())
|
||||||
|
};
|
||||||
|
|
||||||
|
return chosen_move;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback : utiliser la stratégie par défaut
|
||||||
|
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||||
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
|
let chosen_move = *possible_moves
|
||||||
|
.first()
|
||||||
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||||
|
|
||||||
|
if self.color == Color::White {
|
||||||
|
chosen_move
|
||||||
|
} else {
|
||||||
|
(chosen_move.0.mirror(), chosen_move.1.mirror())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod default;
|
pub mod default;
|
||||||
pub mod dqn;
|
pub mod dqn;
|
||||||
|
pub mod dqnburn;
|
||||||
pub mod erroneous_moves;
|
pub mod erroneous_moves;
|
||||||
pub mod random;
|
pub mod random;
|
||||||
pub mod stable_baselines3;
|
pub mod stable_baselines3;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bot::{
|
use bot::{
|
||||||
BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
|
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
|
||||||
StableBaselines3Strategy,
|
StableBaselines3Strategy,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
@ -25,11 +25,11 @@ pub struct App {
|
||||||
impl App {
|
impl App {
|
||||||
// Constructs a new instance of [`App`].
|
// Constructs a new instance of [`App`].
|
||||||
pub fn new(args: AppArgs) -> Self {
|
pub fn new(args: AppArgs) -> Self {
|
||||||
let bot_strategies: Vec<Box<dyn BotStrategy>> =
|
let bot_strategies: Vec<Box<dyn BotStrategy>> = args
|
||||||
args.bot
|
.bot
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.map(|str_bots| {
|
.map(|str_bots| {
|
||||||
str_bots
|
str_bots
|
||||||
.split(",")
|
.split(",")
|
||||||
.filter_map(|s| match s.trim() {
|
.filter_map(|s| match s.trim() {
|
||||||
"dummy" => {
|
"dummy" => {
|
||||||
|
|
@ -44,6 +44,9 @@ impl App {
|
||||||
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
||||||
as Box<dyn BotStrategy>),
|
as Box<dyn BotStrategy>),
|
||||||
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
|
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
|
||||||
|
"dqnburn" => {
|
||||||
|
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
|
||||||
|
}
|
||||||
s if s.starts_with("ai:") => {
|
s if s.starts_with("ai:") => {
|
||||||
let path = s.trim_start_matches("ai:");
|
let path = s.trim_start_matches("ai:");
|
||||||
Some(Box::new(StableBaselines3Strategy::new(path))
|
Some(Box::new(StableBaselines3Strategy::new(path))
|
||||||
|
|
@ -54,11 +57,16 @@ impl App {
|
||||||
Some(Box::new(DqnStrategy::new_with_model(path))
|
Some(Box::new(DqnStrategy::new_with_model(path))
|
||||||
as Box<dyn BotStrategy>)
|
as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
|
s if s.starts_with("dqnburn:") => {
|
||||||
|
let path = s.trim_start_matches("dqnburn:");
|
||||||
|
Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}")))
|
||||||
|
as Box<dyn BotStrategy>)
|
||||||
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
let schools_enabled = false;
|
let schools_enabled = false;
|
||||||
let should_quit = bot_strategies.len() > 1;
|
let should_quit = bot_strategies.len() > 1;
|
||||||
Self {
|
Self {
|
||||||
|
|
|
||||||
3
justfile
3
justfile
|
|
@ -9,7 +9,8 @@ shell:
|
||||||
runcli:
|
runcli:
|
||||||
RUST_LOG=info cargo run --bin=client_cli
|
RUST_LOG=info cargo run --bin=client_cli
|
||||||
runclibots:
|
runclibots:
|
||||||
cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
|
cargo run --bin=client_cli -- --bot random,dqnburn:./models/burn_dqn_model.mpk
|
||||||
|
#cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy
|
||||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||||
match:
|
match:
|
||||||
cargo build --release --bin=client_cli
|
cargo build --release --bin=client_cli
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue