wip action mask
This commit is contained in:
parent
1e18b784d1
commit
66377f877c
|
|
@ -1,16 +1,15 @@
|
|||
use crate::burnrl::utils::soft_update_linear;
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::relu;
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use crate::burnrl::utils::soft_update_linear;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[derive(Module, Debug, Clone)]
|
||||
pub struct Net<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
|
|
@ -19,11 +18,11 @@ pub struct Net<B: Backend> {
|
|||
|
||||
impl<B: Backend> Net<B> {
|
||||
#[allow(unused)]
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize, device: &B::Device) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(device),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(device),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(device),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -34,7 +33,7 @@ impl<B: Backend> Net<B> {
|
|||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_0_output = relu(self.linear_0.forward(input.clone()));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
relu(self.linear_2.forward(layer_1_output))
|
||||
|
|
@ -46,8 +45,8 @@ impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
|
|||
}
|
||||
|
||||
impl<B: Backend> DQNModel<B> for Net<B> {
|
||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||
let (linear_0, linear_1, linear_2) = this.consume();
|
||||
fn soft_update(self, that: &Self, tau: ElemType) -> Self {
|
||||
let (linear_0, linear_1, linear_2) = self.consume();
|
||||
|
||||
Self {
|
||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||
|
|
@ -70,14 +69,15 @@ type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
|||
pub fn run<E: Environment, B: AutodiffBackend>(
|
||||
num_episodes: usize,
|
||||
visualized: bool,
|
||||
) -> DQN<E, B, Net<B>> {
|
||||
// ) -> impl Agent<E> {
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
let device = Default::default();
|
||||
|
||||
let model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
<E::StateType as State>::size(),
|
||||
DENSE_SIZE,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
<E::ActionType as Action>::size(),
|
||||
&device,
|
||||
);
|
||||
|
||||
let mut agent = MyAgent::new(model);
|
||||
|
|
@ -108,7 +108,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
let snapshot = env.step(action);
|
||||
|
||||
episode_reward +=
|
||||
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
|
||||
<E::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
|
|
@ -119,8 +119,7 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
);
|
||||
|
||||
if config.batch_size < memory.len() {
|
||||
policy_net =
|
||||
agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
|
||||
policy_net = agent.train(policy_net, &memory, &mut optimizer, &config);
|
||||
}
|
||||
|
||||
step += 1;
|
||||
|
|
@ -139,5 +138,5 @@ pub fn run<E: Environment, B: AutodiffBackend>(
|
|||
}
|
||||
}
|
||||
}
|
||||
agent
|
||||
}
|
||||
agent.valid()
|
||||
}
|
||||
|
|
@ -199,6 +199,15 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
|
||||
impl TrictracEnvironment {
|
||||
pub fn valid_actions(&self) -> Vec<TrictracAction> {
|
||||
dqn_common::get_valid_actions(&self.game)
|
||||
.into_iter()
|
||||
.map(|a| TrictracAction {
|
||||
index: a.to_action_index() as u32,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convertit une action burn-rl vers une action Trictrac
|
||||
fn convert_action(
|
||||
&self,
|
||||
|
|
@ -380,4 +389,4 @@ impl TrictracEnvironment {
|
|||
}
|
||||
reward
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -23,7 +23,7 @@ fn main() {
|
|||
let loaded_agent = DQN::new(loaded_model);
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model::<Env>(loaded_agent);
|
||||
demo_model(loaded_agent, |env| env.valid_actions());
|
||||
}
|
||||
|
||||
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>) {
|
||||
|
|
@ -55,6 +55,7 @@ fn load_model() -> dqn_model::Net<NdArray<ElemType>> {
|
|||
<environment::TrictracEnvironment as Environment>::StateType::size(),
|
||||
DENSE_SIZE,
|
||||
<environment::TrictracEnvironment as Environment>::ActionType::size(),
|
||||
&device,
|
||||
)
|
||||
.load_record(record)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,21 +1,60 @@
|
|||
use burn::module::{Param, ParamId};
|
||||
use burn::module::{Module, Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::base::{Agent, ElemType, Environment};
|
||||
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||
use burn_rl::agent::DQN;
|
||||
|
||||
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
|
||||
pub fn demo_model<E, M, B, F>(
|
||||
agent: DQN<E, B, M>,
|
||||
mut get_valid_actions: F,
|
||||
) where
|
||||
E: Environment,
|
||||
M: Module<B> + burn_rl::agent::DQNModel<B>,
|
||||
B: Backend,
|
||||
F: FnMut(&E) -> Vec<E::ActionType>,
|
||||
<E as Environment>::ActionType: PartialEq,
|
||||
{
|
||||
let mut env = E::new(true);
|
||||
let mut state = env.state();
|
||||
let mut done = false;
|
||||
let mut total_reward = 0.0;
|
||||
let mut steps = 0;
|
||||
|
||||
while !done {
|
||||
if let Some(action) = agent.react(&state) {
|
||||
let snapshot = env.step(action);
|
||||
state = *snapshot.state();
|
||||
// println!("{:?}", state);
|
||||
done = snapshot.done();
|
||||
let model = agent.model().as_ref().unwrap();
|
||||
let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
|
||||
let q_values = model.infer(state_tensor);
|
||||
|
||||
let valid_actions = get_valid_actions(&env);
|
||||
if valid_actions.is_empty() {
|
||||
break; // No valid actions, end of episode
|
||||
}
|
||||
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions.contains(&E::ActionType::from(index as u32)) {
|
||||
masked_q_values =
|
||||
masked_q_values.mask_fill(masked_q_values.clone().equal_elem(*q_value), f32::NEG_INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
let action_index = masked_q_values.argmax(1).into_scalar() as u32;
|
||||
let action = E::ActionType::from(action_index);
|
||||
|
||||
let snapshot = env.step(action);
|
||||
state = *snapshot.state();
|
||||
total_reward +=
|
||||
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
|
||||
steps += 1;
|
||||
done = snapshot.done() || steps >= E::MAX_STEPS;
|
||||
}
|
||||
println!(
|
||||
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
|
||||
total_reward, steps
|
||||
);
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
||||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use store::MoveRules;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId, PointsRules};
|
||||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use std::path::Path;
|
||||
use store::MoveRules;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::cmp::{max, min};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use store::{CheckerMove, Dice, GameEvent, PlayerId};
|
||||
use store::{CheckerMove, Dice};
|
||||
|
||||
/// Types d'actions possibles dans le jeu
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
|
|
@ -259,7 +259,7 @@ impl SimpleNeuralNetwork {
|
|||
|
||||
/// Obtient les actions valides pour l'état de jeu actuel
|
||||
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||
use crate::PointsRules;
|
||||
|
||||
use store::TurnStage;
|
||||
|
||||
let mut valid_actions = Vec::new();
|
||||
|
|
|
|||
Loading…
Reference in a new issue