wip action mask

parent 1e18b784d1
commit 66377f877c
@@ -1,16 +1,15 @@
-use crate::burnrl::utils::soft_update_linear;
 use burn::module::Module;
 use burn::nn::{Linear, LinearConfig};
 use burn::optim::AdamWConfig;
-use burn::record::{CompactRecorder, Recorder};
 use burn::tensor::activation::relu;
 use burn::tensor::backend::{AutodiffBackend, Backend};
 use burn::tensor::Tensor;
 use burn_rl::agent::DQN;
 use burn_rl::agent::{DQNModel, DQNTrainingConfig};
 use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
+use crate::burnrl::utils::soft_update_linear;
 
-#[derive(Module, Debug)]
+#[derive(Module, Debug, Clone)]
 pub struct Net<B: Backend> {
     linear_0: Linear<B>,
     linear_1: Linear<B>,
@@ -19,11 +18,11 @@ pub struct Net<B: Backend> {
 
 impl<B: Backend> Net<B> {
     #[allow(unused)]
-    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
+    pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
         Self {
-            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
-            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
-            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
+            linear_0: LinearConfig::new(input_size, dense_size).init(device),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(device),
+            linear_2: LinearConfig::new(dense_size, output_size).init(device),
         }
     }
 
@@ -34,7 +33,7 @@ impl<B: Backend> Net<B> {
 
 impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
     fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
-        let layer_0_output = relu(self.linear_0.forward(input));
+        let layer_0_output = relu(self.linear_0.forward(input.clone()));
         let layer_1_output = relu(self.linear_1.forward(layer_0_output));
 
         relu(self.linear_2.forward(layer_1_output))
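
Taken together, the two hunks above change `Net` to be built on an explicit device and to clone the input before the first layer. A minimal smoke-test sketch of how the constructor and the forward pass would be used, assuming the `NdArray` backend used elsewhere in this repository and placeholder layer sizes:

use burn::backend::NdArray;
use burn::tensor::Tensor;
use burn_rl::base::Model; // brings `forward` into scope

fn net_smoke_test() {
    // Placeholder sizes; the real ones come from the environment's State/Action sizes.
    let device = Default::default();
    let net: Net<NdArray> = Net::new(4, 16, 2, &device);

    // A dummy batch of one state with 4 features.
    let input = Tensor::<NdArray, 2>::zeros([1, 4], &device);
    let _q_values = net.forward(input); // shape [1, 2], one Q-value per action
}
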
@@ -46,8 +45,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
 }
 
 impl<B: Backend> DQNModel<B> for Net<B> {
-    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
-        let (linear_0, linear_1, linear_2) = this.consume();
+    fn soft_update(self, that: &Self, tau: ElemType) -> Self {
+        let (linear_0, linear_1, linear_2) = self.consume();
 
         Self {
             linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
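
`soft_update_linear` comes from `crate::burnrl::utils` and is not shown in this hunk. Conceptually it performs Polyak averaging on each parameter of the target layer; a hedged tensor-level sketch of that idea (an illustrative helper, not the repository's exact code, which also has to rebuild burn's `Param` wrappers):

use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn_rl::base::ElemType;

/// Polyak averaging: new_target = (1 - tau) * target + tau * online.
fn soft_update_tensor_sketch<const D: usize, B: Backend>(
    target: Tensor<B, D>,
    online: &Tensor<B, D>,
    tau: ElemType,
) -> Tensor<B, D> {
    target.mul_scalar(1.0 - tau) + online.clone().mul_scalar(tau)
}
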
@@ -70,14 +69,15 @@ type MyAgent<E, B> = DQN<E, B, Net<B>>;
 pub fn run<E: Environment, B: AutodiffBackend>(
     num_episodes: usize,
     visualized: bool,
-) -> DQN<E, B, Net<B>> {
-// ) -> impl Agent<E> {
+) -> impl Agent<E> {
     let mut env = E::new(visualized);
+    let device = Default::default();
 
     let model = Net::<B>::new(
-        <<E as Environment>::StateType as State>::size(),
+        <E::StateType as State>::size(),
         DENSE_SIZE,
-        <<E as Environment>::ActionType as Action>::size(),
+        <E::ActionType as Action>::size(),
+        &device,
     );
 
     let mut agent = MyAgent::new(model);
@@ -108,7 +108,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             let snapshot = env.step(action);
 
             episode_reward +=
-                <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+                <E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
 
             memory.push(
                 state,
@@ -119,8 +119,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             );
 
             if config.batch_size < memory.len() {
-                policy_net =
-                    agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
+                policy_net = agent.train(policy_net, &memory, &mut optimizer, &config);
             }
 
             step += 1;
@@ -139,5 +138,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
             }
         }
     }
-    agent
+    agent.valid()
 }
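
With the return type now `impl Agent<E>` and the final `agent.valid()`, callers get an inference-only agent back. A hypothetical call site, assuming the Trictrac environment and the `NdArray` autodiff backend used elsewhere in this repository (the episode count is a placeholder):

use burn::backend::{Autodiff, NdArray};
use burn_rl::base::{Agent, ElemType, Environment};

fn train_and_probe() {
    // 500 episodes is a placeholder value.
    let trained = run::<TrictracEnvironment, Autodiff<NdArray<ElemType>>>(500, false);

    // The returned value is already the inference-only (`valid`) agent.
    let env = TrictracEnvironment::new(false);
    let _maybe_action = trained.react(&env.state());
}
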
@@ -199,6 +199,15 @@ impl Environment for TrictracEnvironment {
 }
 
 impl TrictracEnvironment {
+    pub fn valid_actions(&self) -> Vec<TrictracAction> {
+        dqn_common::get_valid_actions(&self.game)
+            .into_iter()
+            .map(|a| TrictracAction {
+                index: a.to_action_index() as u32,
+            })
+            .collect()
+    }
+
     /// Converts a burn-rl action into a Trictrac action
     fn convert_action(
         &self,
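
A small, hypothetical round-trip check for the new `valid_actions` helper, assuming `TrictracAction` implements `From<u32>` (it is used that way in the demo loop below) as well as `PartialEq` and `Debug`:

use burn_rl::base::Environment;

fn check_valid_actions_round_trip() {
    let env = TrictracEnvironment::new(false);
    for action in env.valid_actions() {
        // The stored index should map back to the same action.
        assert_eq!(TrictracAction::from(action.index), action);
    }
}
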
@@ -380,4 +389,4 @@ impl TrictracEnvironment {
         }
         reward
     }
 }
@@ -23,7 +23,7 @@ fn main() {
     let loaded_agent = DQN::new(loaded_model);
 
     println!("> Test avec le modèle chargé");
-    demo_model::<Env>(loaded_agent);
+    demo_model(loaded_agent, |env| env.valid_actions());
 }
 
 fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
@@ -55,6 +55,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
         <environment::TrictracEnvironment as Environment>::StateType::size(),
         DENSE_SIZE,
         <environment::TrictracEnvironment as Environment>::ActionType::size(),
+        &device,
     )
     .load_record(record)
 }
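
For context, a hedged sketch of the whole load path this hunk touches, assuming burn's file-recorder API where `Recorder::load` takes a path and a device; the path, error handling, and surrounding module names are illustrative:

use burn::backend::NdArray;
use burn::module::Module;
use burn::record::{CompactRecorder, Recorder};
use burn_rl::base::{Action, ElemType, Environment, State};

fn load_model_sketch() -> dqn_model::Net<NdArray<ElemType>> {
    let device = Default::default();

    // Illustrative path; the real binary derives it from its own configuration.
    let record = CompactRecorder::new()
        .load("trained_model".into(), &device)
        .expect("failed to load the model record");

    dqn_model::Net::new(
        <environment::TrictracEnvironment as Environment>::StateType::size(),
        DENSE_SIZE,
        <environment::TrictracEnvironment as Environment>::ActionType::size(),
        &device,
    )
    .load_record(record)
}
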
@@ -1,21 +1,60 @@
-use burn::module::{Param, ParamId};
+use burn::module::{Module, Param, ParamId};
 use burn::nn::Linear;
 use burn::tensor::backend::Backend;
 use burn::tensor::Tensor;
-use burn_rl::base::{Agent, ElemType, Environment};
+use burn_rl::base::{Action, ElemType, Environment, State};
+use burn_rl::agent::DQN;
 
-pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
+pub fn demo_model<E, M, B, F>(
+    agent: DQN<E, B, M>,
+    mut get_valid_actions: F,
+) where
+    E: Environment,
+    M: Module<B> + burn_rl::agent::DQNModel<B>,
+    B: Backend,
+    F: FnMut(&E) -> Vec<E::ActionType>,
+    <E as Environment>::ActionType: PartialEq,
+{
     let mut env = E::new(true);
     let mut state = env.state();
     let mut done = false;
+    let mut total_reward = 0.0;
+    let mut steps = 0;
+
     while !done {
-        if let Some(action) = agent.react(&state) {
-            let snapshot = env.step(action);
-            state = *snapshot.state();
-            // println!("{:?}", state);
-            done = snapshot.done();
+        let model = agent.model().as_ref().unwrap();
+        let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
+        let q_values = model.infer(state_tensor);
+        let valid_actions = get_valid_actions(&env);
+        if valid_actions.is_empty() {
+            break; // No valid actions, end of episode
         }
+
+        let mut masked_q_values = q_values.clone();
+        let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
+
+        for (index, q_value) in q_values_vec.iter().enumerate() {
+            if !valid_actions.contains(&E::ActionType::from(index as u32)) {
+                masked_q_values =
+                    masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
+            }
+        }
+
+        let action_index = masked_q_values.argmax(1).into_scalar() as u32;
+        let action = E::ActionType::from(action_index);
+
+        let snapshot = env.step(action);
+        state = *snapshot.state();
+        total_reward +=
+            <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
+        steps += 1;
+        done = snapshot.done() || steps >= E::MAX_STEPS;
     }
+    println!(
+        "Episode terminé. Récompense totale: {:.2}, Étapes: {}",
+        total_reward, steps
+    );
 }
 
 fn soft_update_tensor<const N: usize, B: Backend>(
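
One caveat in the masking loop added above: `mask_fill` keyed on `equal_elem(*q_value)` masks every position whose Q-value equals the invalid action's value, so distinct actions that happen to share a value get masked together. A hedged, host-side alternative that selects purely by index (an illustrative helper, not part of this commit):

/// Picks the highest-scoring action among the valid indices,
/// or returns None when no valid index falls inside the Q-value vector.
fn best_valid_action(q_values: &[f32], valid_indices: &[usize]) -> Option<usize> {
    valid_indices
        .iter()
        .copied()
        .filter(|&i| i < q_values.len())
        .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
}

The demo loop could then feed `q_values_vec` plus the valid indices into such a helper and skip the tensor-side `mask_fill` entirely.
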
@@ -1,4 +1,4 @@
-use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use store::MoveRules;
 
 #[derive(Debug)]
@@ -1,4 +1,4 @@
-use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
+use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
 use std::path::Path;
 use store::MoveRules;
 
@@ -1,7 +1,7 @@
 use std::cmp::{max, min};
 
 use serde::{Deserialize, Serialize};
-use store::{CheckerMove, Dice, GameEvent, PlayerId};
+use store::{CheckerMove, Dice};
 
 /// Possible action types in the game
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -259,7 +259,7 @@ impl SimpleNeuralNetwork {
 
 /// Gets the valid actions for the current game state
 pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
-    use crate::PointsRules;
     use store::TurnStage;
 
     let mut valid_actions = Vec::new();
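
Given the commit title, "wip action mask", a hedged sketch of how `get_valid_actions` could be folded into a boolean mask over the network's output vector; the helper name and the `action_space_size` parameter are illustrative, and `to_action_index` is assumed to return an integer index, as the environment code above casts it:

/// Builds a selectable/blocked flag per network output.
fn action_mask(game_state: &crate::GameState, action_space_size: usize) -> Vec<bool> {
    let mut mask = vec![false; action_space_size];
    for action in get_valid_actions(game_state) {
        let index = action.to_action_index() as usize;
        if index < action_space_size {
            mask[index] = true; // valid action stays selectable
        }
    }
    mask
}
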