Compare commits
No commits in common. "74f692d7babeaa442d99c7cf9294b13d18e7b198" and "883ebf9bc166f3d601ff8d58457e53523c4cbc38" have entirely different histories.
74f692d7ba
...
883ebf9bc1
40 changed files with 3281 additions and 35 deletions
46
Cargo.lock
generated
46
Cargo.lock
generated
|
|
@ -834,7 +834,7 @@ dependencies = [
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"log",
|
"log",
|
||||||
"nvml-wrapper",
|
"nvml-wrapper",
|
||||||
"ratatui",
|
"ratatui 0.29.0",
|
||||||
"rstest",
|
"rstest",
|
||||||
"serde",
|
"serde",
|
||||||
"sysinfo",
|
"sysinfo",
|
||||||
|
|
@ -1066,6 +1066,17 @@ dependencies = [
|
||||||
"store",
|
"store",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "client_tui"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"bincode 1.3.3",
|
||||||
|
"crossterm",
|
||||||
|
"ratatui 0.28.1",
|
||||||
|
"store",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cmake"
|
name = "cmake"
|
||||||
version = "0.1.54"
|
version = "0.1.54"
|
||||||
|
|
@ -4403,6 +4414,27 @@ version = "0.1.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde"
|
checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ratatui"
|
||||||
|
version = "0.28.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.9.4",
|
||||||
|
"cassowary",
|
||||||
|
"compact_str",
|
||||||
|
"crossterm",
|
||||||
|
"instability",
|
||||||
|
"itertools 0.13.0",
|
||||||
|
"lru",
|
||||||
|
"paste",
|
||||||
|
"strum 0.26.3",
|
||||||
|
"strum_macros 0.26.4",
|
||||||
|
"unicode-segmentation",
|
||||||
|
"unicode-truncate",
|
||||||
|
"unicode-width 0.1.14",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ratatui"
|
name = "ratatui"
|
||||||
version = "0.29.0"
|
version = "0.29.0"
|
||||||
|
|
@ -5781,6 +5813,18 @@ dependencies = [
|
||||||
"strength_reduce",
|
"strength_reduce",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "trictrac-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bincode 1.3.3",
|
||||||
|
"env_logger 0.10.2",
|
||||||
|
"log",
|
||||||
|
"pico-args",
|
||||||
|
"renet",
|
||||||
|
"store",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tungstenite"
|
name = "tungstenite"
|
||||||
version = "0.26.2"
|
version = "0.26.2"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
members = ["client_cli", "bot", "store"]
|
members = ["client_tui", "client_cli", "bot", "server", "store"]
|
||||||
|
|
|
||||||
194
bot/src/burnrl/algos/dqn_big.rs
Normal file
194
bot/src/burnrl/algos/dqn_big.rs
Normal file
|
|
@ -0,0 +1,194 @@
|
||||||
|
use crate::burnrl::environment_big::TrictracEnvironment;
|
||||||
|
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||||
|
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||||
|
use burn::module::Module;
|
||||||
|
use burn::nn::{Linear, LinearConfig};
|
||||||
|
use burn::optim::AdamWConfig;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
use burn::tensor::activation::relu;
|
||||||
|
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||||
|
use burn::tensor::Tensor;
|
||||||
|
use burn_rl::agent::DQN;
|
||||||
|
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||||
|
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct Net<B: Backend> {
|
||||||
|
linear_0: Linear<B>,
|
||||||
|
linear_1: Linear<B>,
|
||||||
|
linear_2: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Net<B> {
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||||
|
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||||
|
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
||||||
|
(self.linear_0, self.linear_1, self.linear_2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
|
||||||
|
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let layer_0_output = relu(self.linear_0.forward(input));
|
||||||
|
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||||
|
|
||||||
|
relu(self.linear_2.forward(layer_1_output))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
self.forward(input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> DQNModel<B> for Net<B> {
|
||||||
|
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||||
|
let (linear_0, linear_1, linear_2) = this.consume();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||||
|
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
||||||
|
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
const MEMORY_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||||
|
pub fn run<
|
||||||
|
E: Environment + AsMut<TrictracEnvironment>,
|
||||||
|
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||||
|
>(
|
||||||
|
conf: &Config,
|
||||||
|
visualized: bool,
|
||||||
|
// ) -> DQN<E, B, Net<B>> {
|
||||||
|
) -> impl Agent<E> {
|
||||||
|
let mut env = E::new(visualized);
|
||||||
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
|
||||||
|
let model = Net::<B>::new(
|
||||||
|
<<E as Environment>::StateType as State>::size(),
|
||||||
|
conf.dense_size,
|
||||||
|
<<E as Environment>::ActionType as Action>::size(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut agent = MyAgent::new(model);
|
||||||
|
|
||||||
|
// let config = DQNTrainingConfig::default();
|
||||||
|
let config = DQNTrainingConfig {
|
||||||
|
gamma: conf.gamma,
|
||||||
|
tau: conf.tau,
|
||||||
|
learning_rate: conf.learning_rate,
|
||||||
|
batch_size: conf.batch_size,
|
||||||
|
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||||
|
conf.clip_grad,
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||||
|
|
||||||
|
let mut optimizer = AdamWConfig::new()
|
||||||
|
.with_grad_clipping(config.clip_grad.clone())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let mut policy_net = agent.model().as_ref().unwrap().clone();
|
||||||
|
|
||||||
|
let mut step = 0_usize;
|
||||||
|
|
||||||
|
for episode in 0..conf.num_episodes {
|
||||||
|
let mut episode_done = false;
|
||||||
|
let mut episode_reward: ElemType = 0.0;
|
||||||
|
let mut episode_duration = 0_usize;
|
||||||
|
let mut state = env.state();
|
||||||
|
let mut now = SystemTime::now();
|
||||||
|
|
||||||
|
while !episode_done {
|
||||||
|
let eps_threshold = conf.eps_end
|
||||||
|
+ (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
|
||||||
|
let action =
|
||||||
|
DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
|
||||||
|
let snapshot = env.step(action);
|
||||||
|
|
||||||
|
episode_reward +=
|
||||||
|
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
|
||||||
|
|
||||||
|
memory.push(
|
||||||
|
state,
|
||||||
|
*snapshot.state(),
|
||||||
|
action,
|
||||||
|
snapshot.reward().clone(),
|
||||||
|
snapshot.done(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if config.batch_size < memory.len() {
|
||||||
|
policy_net =
|
||||||
|
agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
|
||||||
|
}
|
||||||
|
|
||||||
|
step += 1;
|
||||||
|
episode_duration += 1;
|
||||||
|
|
||||||
|
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||||
|
let envmut = env.as_mut();
|
||||||
|
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
|
||||||
|
* 100.0)
|
||||||
|
.round() as u32;
|
||||||
|
println!(
|
||||||
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
|
||||||
|
envmut.goodmoves_count,
|
||||||
|
goodmoves_ratio,
|
||||||
|
envmut.pointrolls_count,
|
||||||
|
now.elapsed().unwrap().as_secs(),
|
||||||
|
);
|
||||||
|
env.reset();
|
||||||
|
episode_done = true;
|
||||||
|
now = SystemTime::now();
|
||||||
|
} else {
|
||||||
|
state = *snapshot.state();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let valid_agent = agent.valid();
|
||||||
|
if let Some(path) = &conf.save_path {
|
||||||
|
save_model(valid_agent.model().as_ref().unwrap(), path);
|
||||||
|
}
|
||||||
|
valid_agent
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{path}.mpk");
|
||||||
|
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||||
|
let model_path = format!("{path}.mpk");
|
||||||
|
// println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
|
CompactRecorder::new()
|
||||||
|
.load(model_path.into(), &NdArrayDevice::default())
|
||||||
|
.map(|record| {
|
||||||
|
Net::new(
|
||||||
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
|
)
|
||||||
|
.load_record(record)
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
pub mod dqn;
|
pub mod dqn;
|
||||||
|
pub mod dqn_big;
|
||||||
pub mod dqn_valid;
|
pub mod dqn_valid;
|
||||||
pub mod ppo;
|
pub mod ppo;
|
||||||
|
pub mod ppo_big;
|
||||||
pub mod ppo_valid;
|
pub mod ppo_valid;
|
||||||
pub mod sac;
|
pub mod sac;
|
||||||
|
pub mod sac_big;
|
||||||
pub mod sac_valid;
|
pub mod sac_valid;
|
||||||
|
|
|
||||||
191
bot/src/burnrl/algos/ppo_big.rs
Normal file
191
bot/src/burnrl/algos/ppo_big.rs
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
use crate::burnrl::environment_big::TrictracEnvironment;
|
||||||
|
use crate::burnrl::utils::Config;
|
||||||
|
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||||
|
use burn::module::Module;
|
||||||
|
use burn::nn::{Initializer, Linear, LinearConfig};
|
||||||
|
use burn::optim::AdamWConfig;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
use burn::tensor::activation::{relu, softmax};
|
||||||
|
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||||
|
use burn::tensor::Tensor;
|
||||||
|
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
||||||
|
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct Net<B: Backend> {
|
||||||
|
linear: Linear<B>,
|
||||||
|
linear_actor: Linear<B>,
|
||||||
|
linear_critic: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Net<B> {
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||||
|
let initializer = Initializer::XavierUniform { gain: 1.0 };
|
||||||
|
Self {
|
||||||
|
linear: LinearConfig::new(input_size, dense_size)
|
||||||
|
.with_initializer(initializer.clone())
|
||||||
|
.init(&Default::default()),
|
||||||
|
linear_actor: LinearConfig::new(dense_size, output_size)
|
||||||
|
.with_initializer(initializer.clone())
|
||||||
|
.init(&Default::default()),
|
||||||
|
linear_critic: LinearConfig::new(dense_size, 1)
|
||||||
|
.with_initializer(initializer)
|
||||||
|
.init(&Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
|
||||||
|
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
|
||||||
|
let layer_0_output = relu(self.linear.forward(input));
|
||||||
|
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
|
||||||
|
let values = self.linear_critic.forward(layer_0_output);
|
||||||
|
|
||||||
|
PPOOutput::<B>::new(policies, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let layer_0_output = relu(self.linear.forward(input));
|
||||||
|
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> PPOModel<B> for Net<B> {}
|
||||||
|
#[allow(unused)]
|
||||||
|
const MEMORY_SIZE: usize = 512;
|
||||||
|
|
||||||
|
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn run<
|
||||||
|
E: Environment + AsMut<TrictracEnvironment>,
|
||||||
|
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||||
|
>(
|
||||||
|
conf: &Config,
|
||||||
|
visualized: bool,
|
||||||
|
// ) -> PPO<E, B, Net<B>> {
|
||||||
|
) -> impl Agent<E> {
|
||||||
|
let mut env = E::new(visualized);
|
||||||
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
|
||||||
|
let mut model = Net::<B>::new(
|
||||||
|
<<E as Environment>::StateType as State>::size(),
|
||||||
|
conf.dense_size,
|
||||||
|
<<E as Environment>::ActionType as Action>::size(),
|
||||||
|
);
|
||||||
|
let agent = MyAgent::default();
|
||||||
|
let config = PPOTrainingConfig {
|
||||||
|
gamma: conf.gamma,
|
||||||
|
lambda: conf.lambda,
|
||||||
|
epsilon_clip: conf.epsilon_clip,
|
||||||
|
critic_weight: conf.critic_weight,
|
||||||
|
entropy_weight: conf.entropy_weight,
|
||||||
|
learning_rate: conf.learning_rate,
|
||||||
|
epochs: conf.epochs,
|
||||||
|
batch_size: conf.batch_size,
|
||||||
|
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||||
|
conf.clip_grad,
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut optimizer = AdamWConfig::new()
|
||||||
|
.with_grad_clipping(config.clip_grad.clone())
|
||||||
|
.init();
|
||||||
|
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||||
|
for episode in 0..conf.num_episodes {
|
||||||
|
let mut episode_done = false;
|
||||||
|
let mut episode_reward = 0.0;
|
||||||
|
let mut episode_duration = 0_usize;
|
||||||
|
let mut now = SystemTime::now();
|
||||||
|
|
||||||
|
env.reset();
|
||||||
|
while !episode_done {
|
||||||
|
let state = env.state();
|
||||||
|
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
|
||||||
|
let snapshot = env.step(action);
|
||||||
|
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||||
|
snapshot.reward().clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
memory.push(
|
||||||
|
state,
|
||||||
|
*snapshot.state(),
|
||||||
|
action,
|
||||||
|
snapshot.reward().clone(),
|
||||||
|
snapshot.done(),
|
||||||
|
);
|
||||||
|
|
||||||
|
episode_duration += 1;
|
||||||
|
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||||
|
now.elapsed().unwrap().as_secs(),
|
||||||
|
);
|
||||||
|
|
||||||
|
now = SystemTime::now();
|
||||||
|
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
|
||||||
|
memory.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path) = &conf.save_path {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
||||||
|
|
||||||
|
// Save the trained model (backend B) to a temporary file
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), tmp_path.clone())
|
||||||
|
.expect("Failed to save temporary model");
|
||||||
|
|
||||||
|
// Create a new model instance with the target backend (NdArray)
|
||||||
|
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
||||||
|
<<E as Environment>::StateType as State>::size(),
|
||||||
|
conf.dense_size,
|
||||||
|
<<E as Environment>::ActionType as Action>::size(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Load the record from the temporary file into the new model
|
||||||
|
let record = recorder
|
||||||
|
.load(tmp_path.clone(), &device)
|
||||||
|
.expect("Failed to load temporary model");
|
||||||
|
let model_with_loaded_weights = model_to_save.load_record(record);
|
||||||
|
|
||||||
|
// Clean up the temporary file
|
||||||
|
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
||||||
|
|
||||||
|
save_model(&model_with_loaded_weights, path);
|
||||||
|
}
|
||||||
|
agent.valid(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{path}.mpk");
|
||||||
|
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||||
|
let model_path = format!("{path}.mpk");
|
||||||
|
// println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
|
CompactRecorder::new()
|
||||||
|
.load(model_path.into(), &NdArrayDevice::default())
|
||||||
|
.map(|record| {
|
||||||
|
Net::new(
|
||||||
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
|
)
|
||||||
|
.load_record(record)
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
222
bot/src/burnrl/algos/sac_big.rs
Normal file
222
bot/src/burnrl/algos/sac_big.rs
Normal file
|
|
@ -0,0 +1,222 @@
|
||||||
|
use crate::burnrl::environment_big::TrictracEnvironment;
|
||||||
|
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||||
|
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||||
|
use burn::module::Module;
|
||||||
|
use burn::nn::{Linear, LinearConfig};
|
||||||
|
use burn::optim::AdamWConfig;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
use burn::tensor::activation::{relu, softmax};
|
||||||
|
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||||
|
use burn::tensor::Tensor;
|
||||||
|
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
|
||||||
|
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct Actor<B: Backend> {
|
||||||
|
linear_0: Linear<B>,
|
||||||
|
linear_1: Linear<B>,
|
||||||
|
linear_2: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Actor<B> {
|
||||||
|
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||||
|
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||||
|
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
|
||||||
|
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let layer_0_output = relu(self.linear_0.forward(input));
|
||||||
|
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||||
|
|
||||||
|
softmax(self.linear_2.forward(layer_1_output), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
self.forward(input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> SACActor<B> for Actor<B> {}
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct Critic<B: Backend> {
|
||||||
|
linear_0: Linear<B>,
|
||||||
|
linear_1: Linear<B>,
|
||||||
|
linear_2: Linear<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Critic<B> {
|
||||||
|
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||||
|
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||||
|
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
||||||
|
(self.linear_0, self.linear_1, self.linear_2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
|
||||||
|
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let layer_0_output = relu(self.linear_0.forward(input));
|
||||||
|
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||||
|
|
||||||
|
self.linear_2.forward(layer_1_output)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
self.forward(input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> SACCritic<B> for Critic<B> {
|
||||||
|
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||||
|
let (linear_0, linear_1, linear_2) = this.consume();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||||
|
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
||||||
|
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
const MEMORY_SIZE: usize = 4096;
|
||||||
|
|
||||||
|
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn run<
|
||||||
|
E: Environment + AsMut<TrictracEnvironment>,
|
||||||
|
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||||
|
>(
|
||||||
|
conf: &Config,
|
||||||
|
visualized: bool,
|
||||||
|
) -> impl Agent<E> {
|
||||||
|
let mut env = E::new(visualized);
|
||||||
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
let state_dim = <<E as Environment>::StateType as State>::size();
|
||||||
|
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
||||||
|
|
||||||
|
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||||
|
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||||
|
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||||
|
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
||||||
|
|
||||||
|
let mut agent = MyAgent::default();
|
||||||
|
|
||||||
|
let config = SACTrainingConfig {
|
||||||
|
gamma: conf.gamma,
|
||||||
|
tau: conf.tau,
|
||||||
|
learning_rate: conf.learning_rate,
|
||||||
|
min_probability: conf.min_probability,
|
||||||
|
batch_size: conf.batch_size,
|
||||||
|
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||||
|
conf.clip_grad,
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||||
|
|
||||||
|
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
|
||||||
|
|
||||||
|
let mut optimizer = SACOptimizer::new(
|
||||||
|
optimizer_config.clone().init(),
|
||||||
|
optimizer_config.clone().init(),
|
||||||
|
optimizer_config.clone().init(),
|
||||||
|
optimizer_config.init(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut step = 0_usize;
|
||||||
|
|
||||||
|
for episode in 0..conf.num_episodes {
|
||||||
|
let mut episode_done = false;
|
||||||
|
let mut episode_reward = 0.0;
|
||||||
|
let mut episode_duration = 0_usize;
|
||||||
|
let mut state = env.state();
|
||||||
|
let mut now = SystemTime::now();
|
||||||
|
|
||||||
|
while !episode_done {
|
||||||
|
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
|
||||||
|
let snapshot = env.step(action);
|
||||||
|
|
||||||
|
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||||
|
snapshot.reward().clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
memory.push(
|
||||||
|
state,
|
||||||
|
*snapshot.state(),
|
||||||
|
action,
|
||||||
|
snapshot.reward().clone(),
|
||||||
|
snapshot.done(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if config.batch_size < memory.len() {
|
||||||
|
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
|
||||||
|
}
|
||||||
|
|
||||||
|
step += 1;
|
||||||
|
episode_duration += 1;
|
||||||
|
|
||||||
|
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||||
|
env.reset();
|
||||||
|
episode_done = true;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||||
|
now.elapsed().unwrap().as_secs()
|
||||||
|
);
|
||||||
|
now = SystemTime::now();
|
||||||
|
} else {
|
||||||
|
state = *snapshot.state();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let valid_agent = agent.valid(nets.actor);
|
||||||
|
if let Some(path) = &conf.save_path {
|
||||||
|
if let Some(model) = valid_agent.model() {
|
||||||
|
save_model(model, path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
valid_agent
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{path}.mpk");
|
||||||
|
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
||||||
|
let model_path = format!("{path}.mpk");
|
||||||
|
// println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
|
CompactRecorder::new()
|
||||||
|
.load(model_path.into(), &NdArrayDevice::default())
|
||||||
|
.map(|record| {
|
||||||
|
Actor::new(
|
||||||
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
|
)
|
||||||
|
.load_record(record)
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -6,10 +6,10 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
const ERROR_REWARD: f32 = -1.0012121;
|
const ERROR_REWARD: f32 = -1.12121;
|
||||||
const REWARD_VALID_MOVE: f32 = 1.0012121;
|
const REWARD_VALID_MOVE: f32 = 1.12121;
|
||||||
const REWARD_RATIO: f32 = 0.1;
|
const REWARD_RATIO: f32 = 0.01;
|
||||||
const WIN_POINTS: f32 = 100.0;
|
const WIN_POINTS: f32 = 1.0;
|
||||||
|
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
|
@ -285,7 +285,7 @@ impl TrictracEnvironment {
|
||||||
if let Some(event) = action.to_event(&self.game) {
|
if let Some(event) = action.to_event(&self.game) {
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
// reward += REWARD_VALID_MOVE;
|
reward += REWARD_VALID_MOVE;
|
||||||
// Simuler le résultat des dés après un Roll
|
// Simuler le résultat des dés après un Roll
|
||||||
if matches!(action, TrictracAction::Roll) {
|
if matches!(action, TrictracAction::Roll) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
@ -312,11 +312,9 @@ impl TrictracEnvironment {
|
||||||
// on annule les précédents reward
|
// on annule les précédents reward
|
||||||
// et on indique une valeur reconnaissable pour statistiques
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
reward = ERROR_REWARD;
|
reward = ERROR_REWARD;
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
reward = ERROR_REWARD;
|
reward = ERROR_REWARD;
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(reward, is_rollpoint)
|
(reward, is_rollpoint)
|
||||||
|
|
|
||||||
469
bot/src/burnrl/environment_big.rs
Normal file
469
bot/src/burnrl/environment_big.rs
Normal file
|
|
@ -0,0 +1,469 @@
|
||||||
|
use crate::training_common_big;
|
||||||
|
use burn::{prelude::Backend, tensor::Tensor};
|
||||||
|
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
|
const ERROR_REWARD: f32 = -2.12121;
|
||||||
|
const REWARD_VALID_MOVE: f32 = 2.12121;
|
||||||
|
const REWARD_RATIO: f32 = 0.01;
|
||||||
|
const WIN_POINTS: f32 = 0.1;
|
||||||
|
|
||||||
|
/// État du jeu Trictrac pour burn-rl
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct TrictracState {
|
||||||
|
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State for TrictracState {
|
||||||
|
type Data = [i8; 36];
|
||||||
|
|
||||||
|
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
|
||||||
|
Tensor::from_floats(self.data, &B::Device::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size() -> usize {
|
||||||
|
36
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrictracState {
|
||||||
|
/// Convertit un GameState en TrictracState
|
||||||
|
pub fn from_game_state(game_state: &GameState) -> Self {
|
||||||
|
let state_vec = game_state.to_vec();
|
||||||
|
let mut data = [0; 36];
|
||||||
|
|
||||||
|
// Copier les données en s'assurant qu'on ne dépasse pas la taille
|
||||||
|
let copy_len = state_vec.len().min(36);
|
||||||
|
data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
|
||||||
|
|
||||||
|
TrictracState { data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Actions possibles dans Trictrac pour burn-rl
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub struct TrictracAction {
|
||||||
|
// u32 as required by burn_rl::base::Action type
|
||||||
|
pub index: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Action for TrictracAction {
|
||||||
|
fn random() -> Self {
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
TrictracAction {
|
||||||
|
index: rng.gen_range(0..Self::size() as u32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn enumerate() -> Vec<Self> {
|
||||||
|
(0..Self::size() as u32)
|
||||||
|
.map(|index| TrictracAction { index })
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size() -> usize {
|
||||||
|
1252
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<u32> for TrictracAction {
|
||||||
|
fn from(index: u32) -> Self {
|
||||||
|
TrictracAction { index }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TrictracAction> for u32 {
|
||||||
|
fn from(action: TrictracAction) -> u32 {
|
||||||
|
action.index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Environnement Trictrac pour burn-rl
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TrictracEnvironment {
|
||||||
|
pub game: GameState,
|
||||||
|
active_player_id: PlayerId,
|
||||||
|
opponent_id: PlayerId,
|
||||||
|
current_state: TrictracState,
|
||||||
|
episode_reward: f32,
|
||||||
|
pub step_count: usize,
|
||||||
|
pub max_steps: usize,
|
||||||
|
pub pointrolls_count: usize,
|
||||||
|
pub goodmoves_count: usize,
|
||||||
|
pub goodmoves_ratio: f32,
|
||||||
|
pub visualized: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Environment for TrictracEnvironment {
|
||||||
|
type StateType = TrictracState;
|
||||||
|
type ActionType = TrictracAction;
|
||||||
|
type RewardType = f32;
|
||||||
|
|
||||||
|
fn new(visualized: bool) -> Self {
|
||||||
|
let mut game = GameState::new(false);
|
||||||
|
|
||||||
|
// Ajouter deux joueurs
|
||||||
|
game.init_player("DQN Agent");
|
||||||
|
game.init_player("Opponent");
|
||||||
|
let player1_id = 1;
|
||||||
|
let player2_id = 2;
|
||||||
|
|
||||||
|
// Commencer la partie
|
||||||
|
game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||||
|
|
||||||
|
let current_state = TrictracState::from_game_state(&game);
|
||||||
|
TrictracEnvironment {
|
||||||
|
game,
|
||||||
|
active_player_id: player1_id,
|
||||||
|
opponent_id: player2_id,
|
||||||
|
current_state,
|
||||||
|
episode_reward: 0.0,
|
||||||
|
step_count: 0,
|
||||||
|
max_steps: 2000,
|
||||||
|
pointrolls_count: 0,
|
||||||
|
goodmoves_count: 0,
|
||||||
|
goodmoves_ratio: 0.0,
|
||||||
|
visualized,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn state(&self) -> Self::StateType {
|
||||||
|
self.current_state
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset(&mut self) -> Snapshot<Self> {
|
||||||
|
// Réinitialiser le jeu
|
||||||
|
self.game = GameState::new(false);
|
||||||
|
self.game.init_player("DQN Agent");
|
||||||
|
self.game.init_player("Opponent");
|
||||||
|
|
||||||
|
// Commencer la partie
|
||||||
|
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||||
|
|
||||||
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
|
self.episode_reward = 0.0;
|
||||||
|
self.goodmoves_ratio = if self.step_count == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
self.goodmoves_count as f32 / self.step_count as f32
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"info: correct moves: {} ({}%)",
|
||||||
|
self.goodmoves_count,
|
||||||
|
(100.0 * self.goodmoves_ratio).round() as u32
|
||||||
|
);
|
||||||
|
self.step_count = 0;
|
||||||
|
self.pointrolls_count = 0;
|
||||||
|
self.goodmoves_count = 0;
|
||||||
|
|
||||||
|
Snapshot::new(self.current_state, 0.0, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn step(&mut self, action: Self::ActionType) -> Snapshot<Self> {
|
||||||
|
self.step_count += 1;
|
||||||
|
|
||||||
|
// Convertir l'action burn-rl vers une action Trictrac
|
||||||
|
let trictrac_action = Self::convert_action(action);
|
||||||
|
|
||||||
|
let mut reward = 0.0;
|
||||||
|
let is_rollpoint;
|
||||||
|
|
||||||
|
// Exécuter l'action si c'est le tour de l'agent DQN
|
||||||
|
if self.game.active_player_id == self.active_player_id {
|
||||||
|
if let Some(action) = trictrac_action {
|
||||||
|
(reward, is_rollpoint) = self.execute_action(action);
|
||||||
|
if is_rollpoint {
|
||||||
|
self.pointrolls_count += 1;
|
||||||
|
}
|
||||||
|
if reward != ERROR_REWARD {
|
||||||
|
self.goodmoves_count += 1;
|
||||||
|
// println!("{str_action}");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Action non convertible, pénalité
|
||||||
|
reward = -0.5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Faire jouer l'adversaire (stratégie simple)
|
||||||
|
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||||
|
// print!(":");
|
||||||
|
reward += self.play_opponent_if_needed();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vérifier si la partie est terminée
|
||||||
|
// let max_steps = self.max_steps
|
||||||
|
// let max_steps = self.min_steps
|
||||||
|
// + (self.max_steps as f32 - self.min_steps)
|
||||||
|
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||||
|
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||||
|
|
||||||
|
if done {
|
||||||
|
// Récompense finale basée sur le résultat
|
||||||
|
if let Some(winner_id) = self.game.determine_winner() {
|
||||||
|
if winner_id == self.active_player_id {
|
||||||
|
reward += WIN_POINTS; // Victoire
|
||||||
|
} else {
|
||||||
|
reward -= WIN_POINTS; // Défaite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let terminated = done || self.step_count >= self.max_steps;
|
||||||
|
|
||||||
|
// Mettre à jour l'état
|
||||||
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
|
self.episode_reward += reward;
|
||||||
|
if self.visualized && terminated {
|
||||||
|
println!(
|
||||||
|
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
|
||||||
|
self.episode_reward, self.step_count
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Snapshot::new(self.current_state, reward, terminated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrictracEnvironment {
|
||||||
|
/// Convertit une action burn-rl vers une action Trictrac
|
||||||
|
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
|
||||||
|
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn convert_valid_action_index(
|
||||||
|
&self,
|
||||||
|
action: TrictracAction,
|
||||||
|
game_state: &GameState,
|
||||||
|
) -> Option<training_common_big::TrictracAction> {
|
||||||
|
use training_common_big::get_valid_actions;
|
||||||
|
|
||||||
|
// Obtenir les actions valides dans le contexte actuel
|
||||||
|
let valid_actions = get_valid_actions(game_state);
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mapper l'index d'action sur une action valide
|
||||||
|
let action_index = (action.index as usize) % valid_actions.len();
|
||||||
|
Some(valid_actions[action_index].clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exécute une action Trictrac dans le jeu
|
||||||
|
// fn execute_action(
|
||||||
|
// &mut self,
|
||||||
|
// action:training_common_big::TrictracAction,
|
||||||
|
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||||
|
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
|
||||||
|
use training_common_big::TrictracAction;
|
||||||
|
|
||||||
|
let mut reward = 0.0;
|
||||||
|
let mut is_rollpoint = false;
|
||||||
|
let mut need_roll = false;
|
||||||
|
|
||||||
|
let event = match action {
|
||||||
|
TrictracAction::Roll => {
|
||||||
|
// Lancer les dés
|
||||||
|
need_roll = true;
|
||||||
|
Some(GameEvent::Roll {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// TrictracAction::Mark => {
|
||||||
|
// // Marquer des points
|
||||||
|
// let points = self.game.
|
||||||
|
// reward += 0.1 * points as f32;
|
||||||
|
// Some(GameEvent::Mark {
|
||||||
|
// player_id: self.active_player_id,
|
||||||
|
// points,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
TrictracAction::Go => {
|
||||||
|
// Continuer après avoir gagné un trou
|
||||||
|
Some(GameEvent::Go {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Effectuer un mouvement
|
||||||
|
let (dice1, dice2) = if dice_order {
|
||||||
|
(self.game.dice.values.0, self.game.dice.values.1)
|
||||||
|
} else {
|
||||||
|
(self.game.dice.values.1, self.game.dice.values.0)
|
||||||
|
};
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
|
||||||
|
// Gestion prise de coin par puissance
|
||||||
|
let opp_rest_field = 13;
|
||||||
|
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||||
|
to1 -= 1;
|
||||||
|
to2 -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
Some(GameEvent::Move {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
moves: (checker_move1, checker_move2),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Appliquer l'événement si valide
|
||||||
|
if let Some(event) = event {
|
||||||
|
if self.game.validate(&event) {
|
||||||
|
self.game.consume(&event);
|
||||||
|
reward += REWARD_VALID_MOVE;
|
||||||
|
// Simuler le résultat des dés après un Roll
|
||||||
|
// if matches!(action, TrictracAction::Roll) {
|
||||||
|
if need_roll {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
|
let dice_event = GameEvent::RollResult {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
// print!("o");
|
||||||
|
if self.game.validate(&dice_event) {
|
||||||
|
self.game.consume(&dice_event);
|
||||||
|
let (points, adv_points) = self.game.dice_points;
|
||||||
|
reward += REWARD_RATIO * (points - adv_points) as f32;
|
||||||
|
if points > 0 {
|
||||||
|
is_rollpoint = true;
|
||||||
|
// println!("info: rolled for {reward}");
|
||||||
|
}
|
||||||
|
// Récompense proportionnelle aux points
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Pénalité pour action invalide
|
||||||
|
// on annule les précédents reward
|
||||||
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
|
reward = ERROR_REWARD;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(reward, is_rollpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fait jouer l'adversaire avec une stratégie simple
|
||||||
|
fn play_opponent_if_needed(&mut self) -> f32 {
|
||||||
|
// print!("z?");
|
||||||
|
let mut reward = 0.0;
|
||||||
|
|
||||||
|
// Si c'est le tour de l'adversaire, jouer automatiquement
|
||||||
|
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||||
|
// Utiliser la stratégie default pour l'adversaire
|
||||||
|
use crate::BotStrategy;
|
||||||
|
|
||||||
|
let mut strategy = crate::strategy::random::RandomStrategy::default();
|
||||||
|
strategy.set_player_id(self.opponent_id);
|
||||||
|
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
|
||||||
|
strategy.set_color(color);
|
||||||
|
}
|
||||||
|
*strategy.get_mut_game() = self.game.clone();
|
||||||
|
|
||||||
|
// Exécuter l'action selon le turn_stage
|
||||||
|
let mut calculate_points = false;
|
||||||
|
let opponent_color = store::Color::Black;
|
||||||
|
let event = match self.game.turn_stage {
|
||||||
|
TurnStage::RollDice => GameEvent::Roll {
|
||||||
|
player_id: self.opponent_id,
|
||||||
|
},
|
||||||
|
TurnStage::RollWaiting => {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
|
calculate_points = true; // comment to replicate burnrl_before
|
||||||
|
GameEvent::RollResult {
|
||||||
|
player_id: self.opponent_id,
|
||||||
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::MarkPoints => {
|
||||||
|
panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
|
||||||
|
// let dice_roll_count = self
|
||||||
|
// .game
|
||||||
|
// .players
|
||||||
|
// .get(&self.opponent_id)
|
||||||
|
// .unwrap()
|
||||||
|
// .dice_roll_count;
|
||||||
|
// let points_rules =
|
||||||
|
// PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
|
// GameEvent::Mark {
|
||||||
|
// player_id: self.opponent_id,
|
||||||
|
// points: points_rules.get_points(dice_roll_count).0,
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
TurnStage::MarkAdvPoints => {
|
||||||
|
let dice_roll_count = self
|
||||||
|
.game
|
||||||
|
.players
|
||||||
|
.get(&self.opponent_id)
|
||||||
|
.unwrap()
|
||||||
|
.dice_roll_count;
|
||||||
|
let points_rules =
|
||||||
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
|
// pas de reward : déjà comptabilisé lors du tour de blanc
|
||||||
|
GameEvent::Mark {
|
||||||
|
player_id: self.opponent_id,
|
||||||
|
points: points_rules.get_points(dice_roll_count).1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::HoldOrGoChoice => {
|
||||||
|
// Stratégie simple : toujours continuer
|
||||||
|
GameEvent::Go {
|
||||||
|
player_id: self.opponent_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::Move => GameEvent::Move {
|
||||||
|
player_id: self.opponent_id,
|
||||||
|
moves: strategy.choose_move(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.game.validate(&event) {
|
||||||
|
self.game.consume(&event);
|
||||||
|
// print!(".");
|
||||||
|
if calculate_points {
|
||||||
|
// print!("x");
|
||||||
|
let dice_roll_count = self
|
||||||
|
.game
|
||||||
|
.players
|
||||||
|
.get(&self.opponent_id)
|
||||||
|
.unwrap()
|
||||||
|
.dice_roll_count;
|
||||||
|
let points_rules =
|
||||||
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
|
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||||
|
// Récompense proportionnelle aux points
|
||||||
|
let adv_reward = REWARD_RATIO * (points - adv_points) as f32;
|
||||||
|
reward -= adv_reward;
|
||||||
|
// if adv_reward != 0.0 {
|
||||||
|
// println!("info: opponent : {adv_reward} -> {reward}");
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reward
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
|
||||||
|
fn as_mut(&mut self) -> &mut Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,12 +1,9 @@
|
||||||
use crate::training_common;
|
use crate::training_common_big;
|
||||||
use burn::{prelude::Backend, tensor::Tensor};
|
use burn::{prelude::Backend, tensor::Tensor};
|
||||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
const ERROR_REWARD: f32 = -1.0012121;
|
|
||||||
const REWARD_RATIO: f32 = 0.1;
|
|
||||||
|
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct TrictracState {
|
pub struct TrictracState {
|
||||||
|
|
@ -217,16 +214,16 @@ impl TrictracEnvironment {
|
||||||
const REWARD_RATIO: f32 = 1.0;
|
const REWARD_RATIO: f32 = 1.0;
|
||||||
|
|
||||||
/// Convertit une action burn-rl vers une action Trictrac
|
/// Convertit une action burn-rl vers une action Trictrac
|
||||||
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
|
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
|
||||||
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||||
fn convert_valid_action_index(
|
fn convert_valid_action_index(
|
||||||
&self,
|
&self,
|
||||||
action: TrictracAction,
|
action: TrictracAction,
|
||||||
) -> Option<training_common::TrictracAction> {
|
) -> Option<training_common_big::TrictracAction> {
|
||||||
use training_common::get_valid_actions;
|
use training_common_big::get_valid_actions;
|
||||||
|
|
||||||
// Obtenir les actions valides dans le contexte actuel
|
// Obtenir les actions valides dans le contexte actuel
|
||||||
let valid_actions = get_valid_actions(&self.game);
|
let valid_actions = get_valid_actions(&self.game);
|
||||||
|
|
@ -243,19 +240,72 @@ impl TrictracEnvironment {
|
||||||
/// Exécute une action Trictrac dans le jeu
|
/// Exécute une action Trictrac dans le jeu
|
||||||
// fn execute_action(
|
// fn execute_action(
|
||||||
// &mut self,
|
// &mut self,
|
||||||
// action: training_common::TrictracAction,
|
// action: training_common_big::TrictracAction,
|
||||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||||
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
|
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
|
||||||
use training_common::TrictracAction;
|
use training_common_big::TrictracAction;
|
||||||
|
|
||||||
let mut reward = 0.0;
|
let mut reward = 0.0;
|
||||||
let mut is_rollpoint = false;
|
let mut is_rollpoint = false;
|
||||||
|
|
||||||
|
let event = match action {
|
||||||
|
TrictracAction::Roll => {
|
||||||
|
// Lancer les dés
|
||||||
|
Some(GameEvent::Roll {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// TrictracAction::Mark => {
|
||||||
|
// // Marquer des points
|
||||||
|
// let points = self.game.
|
||||||
|
// reward += 0.1 * points as f32;
|
||||||
|
// Some(GameEvent::Mark {
|
||||||
|
// player_id: self.active_player_id,
|
||||||
|
// points,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
TrictracAction::Go => {
|
||||||
|
// Continuer après avoir gagné un trou
|
||||||
|
Some(GameEvent::Go {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Effectuer un mouvement
|
||||||
|
let (dice1, dice2) = if dice_order {
|
||||||
|
(self.game.dice.values.0, self.game.dice.values.1)
|
||||||
|
} else {
|
||||||
|
(self.game.dice.values.1, self.game.dice.values.0)
|
||||||
|
};
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
|
||||||
|
// Gestion prise de coin par puissance
|
||||||
|
let opp_rest_field = 13;
|
||||||
|
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||||
|
to1 -= 1;
|
||||||
|
to2 -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
Some(GameEvent::Move {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
moves: (checker_move1, checker_move2),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Appliquer l'événement si valide
|
// Appliquer l'événement si valide
|
||||||
if let Some(event) = action.to_event(&self.game) {
|
if let Some(event) = event {
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
// reward += REWARD_VALID_MOVE;
|
|
||||||
// Simuler le résultat des dés après un Roll
|
// Simuler le résultat des dés après un Roll
|
||||||
if matches!(action, TrictracAction::Roll) {
|
if matches!(action, TrictracAction::Roll) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
@ -269,7 +319,7 @@ impl TrictracEnvironment {
|
||||||
if self.game.validate(&dice_event) {
|
if self.game.validate(&dice_event) {
|
||||||
self.game.consume(&dice_event);
|
self.game.consume(&dice_event);
|
||||||
let (points, adv_points) = self.game.dice_points;
|
let (points, adv_points) = self.game.dice_points;
|
||||||
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
|
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||||
if points > 0 {
|
if points > 0 {
|
||||||
is_rollpoint = true;
|
is_rollpoint = true;
|
||||||
// println!("info: rolled for {reward}");
|
// println!("info: rolled for {reward}");
|
||||||
|
|
@ -281,12 +331,9 @@ impl TrictracEnvironment {
|
||||||
// Pénalité pour action invalide
|
// Pénalité pour action invalide
|
||||||
// on annule les précédents reward
|
// on annule les précédents reward
|
||||||
// et on indique une valeur reconnaissable pour statistiques
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
reward = ERROR_REWARD;
|
println!("info: action invalide -> err_reward");
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
reward = Self::ERROR_REWARD;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
reward = ERROR_REWARD;
|
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(reward, is_rollpoint)
|
(reward, is_rollpoint)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid};
|
use bot::burnrl::algos::{
|
||||||
|
dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid,
|
||||||
|
};
|
||||||
use bot::burnrl::environment::TrictracEnvironment;
|
use bot::burnrl::environment::TrictracEnvironment;
|
||||||
|
use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig;
|
||||||
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
|
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
|
||||||
use bot::burnrl::utils::{demo_model, Config};
|
use bot::burnrl::utils::{demo_model, Config};
|
||||||
use burn::backend::{Autodiff, NdArray};
|
use burn::backend::{Autodiff, NdArray};
|
||||||
|
|
@ -33,6 +36,16 @@ fn main() {
|
||||||
println!("> Test avec le modèle chargé");
|
println!("> Test avec le modèle chargé");
|
||||||
demo_model(loaded_agent);
|
demo_model(loaded_agent);
|
||||||
}
|
}
|
||||||
|
"dqn_big" => {
|
||||||
|
let _agent = dqn_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||||
|
println!("> Chargement du modèle pour test");
|
||||||
|
let loaded_model = dqn_big::load_model(conf.dense_size, &path);
|
||||||
|
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentBig, _, _> =
|
||||||
|
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||||
|
|
||||||
|
println!("> Test avec le modèle chargé");
|
||||||
|
demo_model(loaded_agent);
|
||||||
|
}
|
||||||
"dqn_valid" => {
|
"dqn_valid" => {
|
||||||
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||||
println!("> Chargement du modèle pour test");
|
println!("> Chargement du modèle pour test");
|
||||||
|
|
@ -53,6 +66,16 @@ fn main() {
|
||||||
println!("> Test avec le modèle chargé");
|
println!("> Test avec le modèle chargé");
|
||||||
demo_model(loaded_agent);
|
demo_model(loaded_agent);
|
||||||
}
|
}
|
||||||
|
"sac_big" => {
|
||||||
|
let _agent = sac_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||||
|
println!("> Chargement du modèle pour test");
|
||||||
|
let loaded_model = sac_big::load_model(conf.dense_size, &path);
|
||||||
|
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentBig, _, _> =
|
||||||
|
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||||
|
|
||||||
|
println!("> Test avec le modèle chargé");
|
||||||
|
demo_model(loaded_agent);
|
||||||
|
}
|
||||||
"sac_valid" => {
|
"sac_valid" => {
|
||||||
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||||
println!("> Chargement du modèle pour test");
|
println!("> Chargement du modèle pour test");
|
||||||
|
|
@ -73,6 +96,16 @@ fn main() {
|
||||||
println!("> Test avec le modèle chargé");
|
println!("> Test avec le modèle chargé");
|
||||||
demo_model(loaded_agent);
|
demo_model(loaded_agent);
|
||||||
}
|
}
|
||||||
|
"ppo_big" => {
|
||||||
|
let _agent = ppo_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||||
|
println!("> Chargement du modèle pour test");
|
||||||
|
let loaded_model = ppo_big::load_model(conf.dense_size, &path);
|
||||||
|
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentBig, _, _> =
|
||||||
|
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||||
|
|
||||||
|
println!("> Test avec le modèle chargé");
|
||||||
|
demo_model(loaded_agent);
|
||||||
|
}
|
||||||
"ppo_valid" => {
|
"ppo_valid" => {
|
||||||
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||||
println!("> Chargement du modèle pour test");
|
println!("> Chargement du modèle pour test");
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
pub mod algos;
|
pub mod algos;
|
||||||
pub mod environment;
|
pub mod environment;
|
||||||
|
pub mod environment_big;
|
||||||
pub mod environment_valid;
|
pub mod environment_valid;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
|
||||||
153
bot/src/dqn_simple/dqn_model.rs
Normal file
153
bot/src/dqn_simple/dqn_model.rs
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
use crate::training_common_big::TrictracAction;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Configuration pour l'agent DQN
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DqnConfig {
|
||||||
|
pub state_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub num_actions: usize,
|
||||||
|
pub learning_rate: f64,
|
||||||
|
pub gamma: f64,
|
||||||
|
pub epsilon: f64,
|
||||||
|
pub epsilon_decay: f64,
|
||||||
|
pub epsilon_min: f64,
|
||||||
|
pub replay_buffer_size: usize,
|
||||||
|
pub batch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
state_size: 36,
|
||||||
|
hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
|
||||||
|
num_actions: TrictracAction::action_space_size(),
|
||||||
|
learning_rate: 0.001,
|
||||||
|
gamma: 0.99,
|
||||||
|
epsilon: 0.1,
|
||||||
|
epsilon_decay: 0.995,
|
||||||
|
epsilon_min: 0.01,
|
||||||
|
replay_buffer_size: 10000,
|
||||||
|
batch_size: 32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Réseau de neurones DQN simplifié (matrice de poids basique)
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SimpleNeuralNetwork {
|
||||||
|
pub weights1: Vec<Vec<f32>>,
|
||||||
|
pub biases1: Vec<f32>,
|
||||||
|
pub weights2: Vec<Vec<f32>>,
|
||||||
|
pub biases2: Vec<f32>,
|
||||||
|
pub weights3: Vec<Vec<f32>>,
|
||||||
|
pub biases3: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SimpleNeuralNetwork {
|
||||||
|
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
|
||||||
|
// Initialisation aléatoire des poids avec Xavier/Glorot
|
||||||
|
let scale1 = (2.0 / input_size as f32).sqrt();
|
||||||
|
let weights1 = (0..hidden_size)
|
||||||
|
.map(|_| {
|
||||||
|
(0..input_size)
|
||||||
|
.map(|_| rng.gen_range(-scale1..scale1))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let biases1 = vec![0.0; hidden_size];
|
||||||
|
|
||||||
|
let scale2 = (2.0 / hidden_size as f32).sqrt();
|
||||||
|
let weights2 = (0..hidden_size)
|
||||||
|
.map(|_| {
|
||||||
|
(0..hidden_size)
|
||||||
|
.map(|_| rng.gen_range(-scale2..scale2))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let biases2 = vec![0.0; hidden_size];
|
||||||
|
|
||||||
|
let scale3 = (2.0 / hidden_size as f32).sqrt();
|
||||||
|
let weights3 = (0..output_size)
|
||||||
|
.map(|_| {
|
||||||
|
(0..hidden_size)
|
||||||
|
.map(|_| rng.gen_range(-scale3..scale3))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let biases3 = vec![0.0; output_size];
|
||||||
|
|
||||||
|
Self {
|
||||||
|
weights1,
|
||||||
|
biases1,
|
||||||
|
weights2,
|
||||||
|
biases2,
|
||||||
|
weights3,
|
||||||
|
biases3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||||
|
// Première couche
|
||||||
|
let mut layer1: Vec<f32> = self.biases1.clone();
|
||||||
|
for (i, neuron_weights) in self.weights1.iter().enumerate() {
|
||||||
|
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||||
|
if j < input.len() {
|
||||||
|
layer1[i] += input[j] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
layer1[i] = layer1[i].max(0.0); // ReLU
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deuxième couche
|
||||||
|
let mut layer2: Vec<f32> = self.biases2.clone();
|
||||||
|
for (i, neuron_weights) in self.weights2.iter().enumerate() {
|
||||||
|
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||||
|
if j < layer1.len() {
|
||||||
|
layer2[i] += layer1[j] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
layer2[i] = layer2[i].max(0.0); // ReLU
|
||||||
|
}
|
||||||
|
|
||||||
|
// Couche de sortie
|
||||||
|
let mut output: Vec<f32> = self.biases3.clone();
|
||||||
|
for (i, neuron_weights) in self.weights3.iter().enumerate() {
|
||||||
|
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||||
|
if j < layer2.len() {
|
||||||
|
output[i] += layer2[j] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_best_action(&self, input: &[f32]) -> usize {
|
||||||
|
let q_values = self.forward(input);
|
||||||
|
q_values
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||||
|
.map(|(index, _)| index)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save<P: AsRef<std::path::Path>>(
|
||||||
|
&self,
|
||||||
|
path: P,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let data = serde_json::to_string_pretty(self)?;
|
||||||
|
std::fs::write(path, data)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
let data = std::fs::read_to_string(path)?;
|
||||||
|
let network = serde_json::from_str(&data)?;
|
||||||
|
Ok(network)
|
||||||
|
}
|
||||||
|
}
|
||||||
494
bot/src/dqn_simple/dqn_trainer.rs
Normal file
494
bot/src/dqn_simple/dqn_trainer.rs
Normal file
|
|
@ -0,0 +1,494 @@
|
||||||
|
use crate::{CheckerMove, Color, GameState, PlayerId};
|
||||||
|
use rand::prelude::SliceRandom;
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
|
use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
|
||||||
|
use crate::training_common_big::{get_valid_actions, TrictracAction};
|
||||||
|
|
||||||
|
/// Expérience pour le buffer de replay
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Experience {
|
||||||
|
pub state: Vec<f32>,
|
||||||
|
pub action: TrictracAction,
|
||||||
|
pub reward: f32,
|
||||||
|
pub next_state: Vec<f32>,
|
||||||
|
pub done: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer de replay pour stocker les expériences
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ReplayBuffer {
|
||||||
|
buffer: VecDeque<Experience>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayBuffer {
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: VecDeque::with_capacity(capacity),
|
||||||
|
capacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, experience: Experience) {
|
||||||
|
if self.buffer.len() >= self.capacity {
|
||||||
|
self.buffer.pop_front();
|
||||||
|
}
|
||||||
|
self.buffer.push_back(experience);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let len = self.buffer.len();
|
||||||
|
if len < batch_size {
|
||||||
|
return self.buffer.iter().cloned().collect();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut batch = Vec::with_capacity(batch_size);
|
||||||
|
for _ in 0..batch_size {
|
||||||
|
let idx = rng.gen_range(0..len);
|
||||||
|
batch.push(self.buffer[idx].clone());
|
||||||
|
}
|
||||||
|
batch
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.buffer.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent DQN pour l'apprentissage par renforcement
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DqnAgent {
|
||||||
|
config: DqnConfig,
|
||||||
|
model: SimpleNeuralNetwork,
|
||||||
|
target_model: SimpleNeuralNetwork,
|
||||||
|
replay_buffer: ReplayBuffer,
|
||||||
|
epsilon: f64,
|
||||||
|
step_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnAgent {
|
||||||
|
pub fn new(config: DqnConfig) -> Self {
|
||||||
|
let model =
|
||||||
|
SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
|
||||||
|
let target_model = model.clone();
|
||||||
|
let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
|
||||||
|
let epsilon = config.epsilon;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
model,
|
||||||
|
target_model,
|
||||||
|
replay_buffer,
|
||||||
|
epsilon,
|
||||||
|
step_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
|
||||||
|
let valid_actions = get_valid_actions(game_state);
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
// Fallback si aucune action valide
|
||||||
|
return TrictracAction::Roll;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
if rng.gen::<f64>() < self.epsilon {
|
||||||
|
// Exploration : action valide aléatoire
|
||||||
|
valid_actions
|
||||||
|
.choose(&mut rng)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(TrictracAction::Roll)
|
||||||
|
} else {
|
||||||
|
// Exploitation : meilleure action valide selon le modèle
|
||||||
|
let q_values = self.model.forward(state);
|
||||||
|
|
||||||
|
let mut best_action = &valid_actions[0];
|
||||||
|
let mut best_q_value = f32::NEG_INFINITY;
|
||||||
|
|
||||||
|
for action in &valid_actions {
|
||||||
|
let action_index = action.to_action_index();
|
||||||
|
if action_index < q_values.len() {
|
||||||
|
let q_value = q_values[action_index];
|
||||||
|
if q_value > best_q_value {
|
||||||
|
best_q_value = q_value;
|
||||||
|
best_action = action;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
best_action.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_experience(&mut self, experience: Experience) {
|
||||||
|
self.replay_buffer.push(experience);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train(&mut self) {
|
||||||
|
if self.replay_buffer.len() < self.config.batch_size {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pour l'instant, on simule l'entraînement en mettant à jour epsilon
|
||||||
|
// Dans une implémentation complète, ici on ferait la backpropagation
|
||||||
|
self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
|
||||||
|
self.step_count += 1;
|
||||||
|
|
||||||
|
// Mise à jour du target model tous les 100 steps
|
||||||
|
if self.step_count % 100 == 0 {
|
||||||
|
self.target_model = self.model.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_model<P: AsRef<std::path::Path>>(
|
||||||
|
&self,
|
||||||
|
path: P,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
self.model.save(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_epsilon(&self) -> f64 {
|
||||||
|
self.epsilon
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_step_count(&self) -> usize {
|
||||||
|
self.step_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Environnement Trictrac pour l'entraînement
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TrictracEnv {
|
||||||
|
pub game_state: GameState,
|
||||||
|
pub agent_player_id: PlayerId,
|
||||||
|
pub opponent_player_id: PlayerId,
|
||||||
|
pub agent_color: Color,
|
||||||
|
pub max_steps: usize,
|
||||||
|
pub current_step: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TrictracEnv {
|
||||||
|
fn default() -> Self {
|
||||||
|
let mut game_state = GameState::new(false);
|
||||||
|
game_state.init_player("agent");
|
||||||
|
game_state.init_player("opponent");
|
||||||
|
|
||||||
|
Self {
|
||||||
|
game_state,
|
||||||
|
agent_player_id: 1,
|
||||||
|
opponent_player_id: 2,
|
||||||
|
agent_color: Color::White,
|
||||||
|
max_steps: 1000,
|
||||||
|
current_step: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrictracEnv {
|
||||||
|
pub fn reset(&mut self) -> Vec<f32> {
|
||||||
|
self.game_state = GameState::new(false);
|
||||||
|
self.game_state.init_player("agent");
|
||||||
|
self.game_state.init_player("opponent");
|
||||||
|
|
||||||
|
// Commencer la partie
|
||||||
|
self.game_state.consume(&GameEvent::BeginGame {
|
||||||
|
goes_first: self.agent_player_id,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.current_step = 0;
|
||||||
|
self.game_state.to_vec_float()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
|
||||||
|
let mut reward = 0.0;
|
||||||
|
|
||||||
|
// Appliquer l'action de l'agent
|
||||||
|
if self.game_state.active_player_id == self.agent_player_id {
|
||||||
|
reward += self.apply_agent_action(action);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Faire jouer l'adversaire (stratégie simple)
|
||||||
|
while self.game_state.active_player_id == self.opponent_player_id
|
||||||
|
&& self.game_state.stage != Stage::Ended
|
||||||
|
{
|
||||||
|
reward += self.play_opponent_turn();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vérifier si la partie est terminée
|
||||||
|
let done = self.game_state.stage == Stage::Ended
|
||||||
|
|| self.game_state.determine_winner().is_some()
|
||||||
|
|| self.current_step >= self.max_steps;
|
||||||
|
|
||||||
|
// Récompense finale si la partie est terminée
|
||||||
|
if done {
|
||||||
|
if let Some(winner) = self.game_state.determine_winner() {
|
||||||
|
if winner == self.agent_player_id {
|
||||||
|
reward += 100.0; // Bonus pour gagner
|
||||||
|
} else {
|
||||||
|
reward -= 50.0; // Pénalité pour perdre
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.current_step += 1;
|
||||||
|
let next_state = self.game_state.to_vec_float();
|
||||||
|
(next_state, reward, done)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
|
||||||
|
let mut reward = 0.0;
|
||||||
|
|
||||||
|
let event = match action {
|
||||||
|
TrictracAction::Roll => {
|
||||||
|
// Lancer les dés
|
||||||
|
reward += 0.1;
|
||||||
|
Some(GameEvent::Roll {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// TrictracAction::Mark => {
|
||||||
|
// // Marquer des points
|
||||||
|
// let points = self.game_state.
|
||||||
|
// reward += 0.1 * points as f32;
|
||||||
|
// Some(GameEvent::Mark {
|
||||||
|
// player_id: self.agent_player_id,
|
||||||
|
// points,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
TrictracAction::Go => {
|
||||||
|
// Continuer après avoir gagné un trou
|
||||||
|
reward += 0.2;
|
||||||
|
Some(GameEvent::Go {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Effectuer un mouvement
|
||||||
|
let (dice1, dice2) = if dice_order {
|
||||||
|
(self.game_state.dice.values.0, self.game_state.dice.values.1)
|
||||||
|
} else {
|
||||||
|
(self.game_state.dice.values.1, self.game_state.dice.values.0)
|
||||||
|
};
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
|
||||||
|
// Gestion prise de coin par puissance
|
||||||
|
let opp_rest_field = 13;
|
||||||
|
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||||
|
to1 -= 1;
|
||||||
|
to2 -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
reward += 0.2;
|
||||||
|
Some(GameEvent::Move {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
moves: (checker_move1, checker_move2),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Appliquer l'événement si valide
|
||||||
|
if let Some(event) = event {
|
||||||
|
if self.game_state.validate(&event) {
|
||||||
|
self.game_state.consume(&event);
|
||||||
|
|
||||||
|
// Simuler le résultat des dés après un Roll
|
||||||
|
if matches!(action, TrictracAction::Roll) {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
|
let dice_event = GameEvent::RollResult {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
if self.game_state.validate(&dice_event) {
|
||||||
|
self.game_state.consume(&dice_event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Pénalité pour action invalide
|
||||||
|
reward -= 2.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reward
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO : use default bot strategy
|
||||||
|
fn play_opponent_turn(&mut self) -> f32 {
|
||||||
|
let mut reward = 0.0;
|
||||||
|
let event = match self.game_state.turn_stage {
|
||||||
|
TurnStage::RollDice => GameEvent::Roll {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
},
|
||||||
|
TurnStage::RollWaiting => {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
|
GameEvent::RollResult {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
|
||||||
|
let opponent_color = self.agent_color.opponent_color();
|
||||||
|
let dice_roll_count = self
|
||||||
|
.game_state
|
||||||
|
.players
|
||||||
|
.get(&self.opponent_player_id)
|
||||||
|
.unwrap()
|
||||||
|
.dice_roll_count;
|
||||||
|
let points_rules = PointsRules::new(
|
||||||
|
&opponent_color,
|
||||||
|
&self.game_state.board,
|
||||||
|
self.game_state.dice,
|
||||||
|
);
|
||||||
|
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||||
|
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
|
GameEvent::Mark {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
points,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::Move => {
|
||||||
|
let opponent_color = self.agent_color.opponent_color();
|
||||||
|
let rules = MoveRules::new(
|
||||||
|
&opponent_color,
|
||||||
|
&self.game_state.board,
|
||||||
|
self.game_state.dice,
|
||||||
|
);
|
||||||
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
|
// Stratégie simple : choix aléatoire
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let choosen_move = *possible_moves
|
||||||
|
.choose(&mut rng)
|
||||||
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||||
|
|
||||||
|
GameEvent::Move {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
moves: if opponent_color == Color::White {
|
||||||
|
choosen_move
|
||||||
|
} else {
|
||||||
|
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::HoldOrGoChoice => {
|
||||||
|
// Stratégie simple : toujours continuer
|
||||||
|
GameEvent::Go {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if self.game_state.validate(&event) {
|
||||||
|
self.game_state.consume(&event);
|
||||||
|
}
|
||||||
|
reward
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Entraîneur pour le modèle DQN
|
||||||
|
pub struct DqnTrainer {
|
||||||
|
agent: DqnAgent,
|
||||||
|
env: TrictracEnv,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnTrainer {
|
||||||
|
pub fn new(config: DqnConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
agent: DqnAgent::new(config),
|
||||||
|
env: TrictracEnv::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train_episode(&mut self) -> f32 {
|
||||||
|
let mut total_reward = 0.0;
|
||||||
|
let mut state = self.env.reset();
|
||||||
|
// let mut step_count = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// step_count += 1;
|
||||||
|
let action = self.agent.select_action(&self.env.game_state, &state);
|
||||||
|
let (next_state, reward, done) = self.env.step(action.clone());
|
||||||
|
total_reward += reward;
|
||||||
|
|
||||||
|
let experience = Experience {
|
||||||
|
state: state.clone(),
|
||||||
|
action,
|
||||||
|
reward,
|
||||||
|
next_state: next_state.clone(),
|
||||||
|
done,
|
||||||
|
};
|
||||||
|
self.agent.store_experience(experience);
|
||||||
|
self.agent.train();
|
||||||
|
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// if step_count % 100 == 0 {
|
||||||
|
// println!("{:?}", next_state);
|
||||||
|
// }
|
||||||
|
state = next_state;
|
||||||
|
}
|
||||||
|
|
||||||
|
total_reward
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train(
|
||||||
|
&mut self,
|
||||||
|
episodes: usize,
|
||||||
|
save_every: usize,
|
||||||
|
model_path: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
println!("Démarrage de l'entraînement DQN pour {episodes} épisodes");
|
||||||
|
|
||||||
|
for episode in 1..=episodes {
|
||||||
|
let reward = self.train_episode();
|
||||||
|
|
||||||
|
if episode % 100 == 0 {
|
||||||
|
println!(
|
||||||
|
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
||||||
|
episode,
|
||||||
|
episodes,
|
||||||
|
reward,
|
||||||
|
self.agent.get_epsilon(),
|
||||||
|
self.agent.get_step_count()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if episode % save_every == 0 {
|
||||||
|
let save_path = format!("{model_path}_episode_{episode}.json");
|
||||||
|
self.agent.save_model(&save_path)?;
|
||||||
|
println!("Modèle sauvegardé : {save_path}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sauvegarder le modèle final
|
||||||
|
let final_path = format!("{model_path}_final.json");
|
||||||
|
self.agent.save_model(&final_path)?;
|
||||||
|
println!("Modèle final sauvegardé : {final_path}");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
109
bot/src/dqn_simple/main.rs
Normal file
109
bot/src/dqn_simple/main.rs
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
use bot::dqn_simple::dqn_model::DqnConfig;
|
||||||
|
use bot::dqn_simple::dqn_trainer::DqnTrainer;
|
||||||
|
use bot::training_common::TrictracAction;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
env_logger::init();
|
||||||
|
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
|
||||||
|
// Paramètres par défaut
|
||||||
|
let mut episodes = 1000;
|
||||||
|
let mut model_path = "models/dqn_model".to_string();
|
||||||
|
let mut save_every = 100;
|
||||||
|
|
||||||
|
// Parser les arguments de ligne de commande
|
||||||
|
let mut i = 1;
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--episodes" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
episodes = args[i + 1].parse().unwrap_or(1000);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --episodes nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--model-path" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
model_path = args[i + 1].clone();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --model-path nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--save-every" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
save_every = args[i + 1].parse().unwrap_or(100);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --save-every nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--help" | "-h" => {
|
||||||
|
print_help();
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
eprintln!("Argument inconnu : {}", args[i]);
|
||||||
|
print_help();
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Créer le dossier models s'il n'existe pas
|
||||||
|
std::fs::create_dir_all("models")?;
|
||||||
|
|
||||||
|
println!("Configuration d'entraînement DQN :");
|
||||||
|
println!(" Épisodes : {episodes}");
|
||||||
|
println!(" Chemin du modèle : {model_path}");
|
||||||
|
println!(" Sauvegarde tous les {save_every} épisodes");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Configuration DQN
|
||||||
|
let config = DqnConfig {
|
||||||
|
state_size: 36, // state.to_vec size
|
||||||
|
hidden_size: 256,
|
||||||
|
num_actions: TrictracAction::action_space_size(),
|
||||||
|
learning_rate: 0.001,
|
||||||
|
gamma: 0.99,
|
||||||
|
epsilon: 0.9, // Commencer avec plus d'exploration
|
||||||
|
epsilon_decay: 0.995,
|
||||||
|
epsilon_min: 0.01,
|
||||||
|
replay_buffer_size: 10000,
|
||||||
|
batch_size: 32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Créer et lancer l'entraîneur
|
||||||
|
let mut trainer = DqnTrainer::new(config);
|
||||||
|
trainer.train(episodes, save_every, &model_path)?;
|
||||||
|
|
||||||
|
println!("Entraînement terminé avec succès !");
|
||||||
|
println!("Pour utiliser le modèle entraîné :");
|
||||||
|
println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_help() {
|
||||||
|
println!("Entraîneur DQN pour Trictrac");
|
||||||
|
println!();
|
||||||
|
println!("USAGE:");
|
||||||
|
println!(" cargo run --bin=train_dqn [OPTIONS]");
|
||||||
|
println!();
|
||||||
|
println!("OPTIONS:");
|
||||||
|
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
|
||||||
|
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)");
|
||||||
|
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
|
||||||
|
println!(" -h, --help Afficher cette aide");
|
||||||
|
println!();
|
||||||
|
println!("EXEMPLES:");
|
||||||
|
println!(" cargo run --bin=train_dqn");
|
||||||
|
println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
|
||||||
|
println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
|
||||||
|
}
|
||||||
2
bot/src/dqn_simple/mod.rs
Normal file
2
bot/src/dqn_simple/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod dqn_model;
|
||||||
|
pub mod dqn_trainer;
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
pub mod burnrl;
|
pub mod burnrl;
|
||||||
|
pub mod dqn_simple;
|
||||||
pub mod strategy;
|
pub mod strategy;
|
||||||
pub mod training_common;
|
pub mod training_common;
|
||||||
|
pub mod training_common_big;
|
||||||
pub mod trictrac_board;
|
pub mod trictrac_board;
|
||||||
|
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
pub use strategy::default::DefaultStrategy;
|
pub use strategy::default::DefaultStrategy;
|
||||||
|
pub use strategy::dqn::DqnStrategy;
|
||||||
pub use strategy::dqnburn::DqnBurnStrategy;
|
pub use strategy::dqnburn::DqnBurnStrategy;
|
||||||
pub use strategy::erroneous_moves::ErroneousStrategy;
|
pub use strategy::erroneous_moves::ErroneousStrategy;
|
||||||
pub use strategy::random::RandomStrategy;
|
pub use strategy::random::RandomStrategy;
|
||||||
|
|
|
||||||
174
bot/src/strategy/dqn.rs
Normal file
174
bot/src/strategy/dqn.rs
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||||
|
use log::info;
|
||||||
|
use std::path::Path;
|
||||||
|
use store::MoveRules;
|
||||||
|
|
||||||
|
use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
|
||||||
|
use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
|
||||||
|
|
||||||
|
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DqnStrategy {
|
||||||
|
pub game: GameState,
|
||||||
|
pub player_id: PlayerId,
|
||||||
|
pub color: Color,
|
||||||
|
pub model: Option<SimpleNeuralNetwork>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnStrategy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
game: GameState::default(),
|
||||||
|
player_id: 1,
|
||||||
|
color: Color::White,
|
||||||
|
model: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnStrategy {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_model<P: AsRef<Path> + std::fmt::Debug>(model_path: P) -> Self {
|
||||||
|
let mut strategy = Self::new();
|
||||||
|
if let Ok(model) = SimpleNeuralNetwork::load(&model_path) {
|
||||||
|
info!("Loading model {model_path:?}");
|
||||||
|
strategy.model = Some(model);
|
||||||
|
}
|
||||||
|
strategy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilise le modèle DQN pour choisir une action valide
|
||||||
|
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||||
|
if let Some(ref model) = self.model {
|
||||||
|
let state = self.game.to_vec_float();
|
||||||
|
let valid_actions = get_valid_actions(&self.game);
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtenir les Q-values pour toutes les actions
|
||||||
|
let q_values = model.forward(&state);
|
||||||
|
|
||||||
|
// Trouver la meilleure action valide
|
||||||
|
let mut best_action = &valid_actions[0];
|
||||||
|
let mut best_q_value = f32::NEG_INFINITY;
|
||||||
|
|
||||||
|
for action in &valid_actions {
|
||||||
|
let action_index = action.to_action_index();
|
||||||
|
if action_index < q_values.len() {
|
||||||
|
let q_value = q_values[action_index];
|
||||||
|
if q_value > best_q_value {
|
||||||
|
best_q_value = q_value;
|
||||||
|
best_action = action;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(best_action.clone())
|
||||||
|
} else {
|
||||||
|
// Fallback : action aléatoire valide
|
||||||
|
sample_valid_action(&self.game)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BotStrategy for DqnStrategy {
|
||||||
|
fn get_game(&self) -> &GameState {
|
||||||
|
&self.game
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mut_game(&mut self) -> &mut GameState {
|
||||||
|
&mut self.game
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_color(&mut self, color: Color) {
|
||||||
|
self.color = color;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_player_id(&mut self, player_id: PlayerId) {
|
||||||
|
self.player_id = player_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_points(&self) -> u8 {
|
||||||
|
self.game.dice_points.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_adv_points(&self) -> u8 {
|
||||||
|
self.game.dice_points.1
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_go(&self) -> bool {
|
||||||
|
// Utiliser le DQN pour décider si on continue
|
||||||
|
if let Some(action) = self.get_dqn_action() {
|
||||||
|
matches!(action, TrictracAction::Go)
|
||||||
|
} else {
|
||||||
|
// Fallback : toujours continuer
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||||
|
// Utiliser le DQN pour choisir le mouvement
|
||||||
|
if let Some(TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
}) = self.get_dqn_action()
|
||||||
|
{
|
||||||
|
let dicevals = self.game.dice.values;
|
||||||
|
let (mut dice1, mut dice2) = if dice_order {
|
||||||
|
(dicevals.0, dicevals.1)
|
||||||
|
} else {
|
||||||
|
(dicevals.1, dicevals.0)
|
||||||
|
};
|
||||||
|
|
||||||
|
if from1 == 0 {
|
||||||
|
// empty move
|
||||||
|
dice1 = 0;
|
||||||
|
}
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
if 24 < to1 {
|
||||||
|
// sortie
|
||||||
|
to1 = 0;
|
||||||
|
}
|
||||||
|
if from2 == 0 {
|
||||||
|
// empty move
|
||||||
|
dice2 = 0;
|
||||||
|
}
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
if 24 < to2 {
|
||||||
|
// sortie
|
||||||
|
to2 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
let chosen_move = if self.color == Color::White {
|
||||||
|
(checker_move1, checker_move2)
|
||||||
|
} else {
|
||||||
|
(checker_move1.mirror(), checker_move2.mirror())
|
||||||
|
};
|
||||||
|
|
||||||
|
return chosen_move;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback : utiliser la stratégie par défaut
|
||||||
|
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||||
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
|
let chosen_move = *possible_moves
|
||||||
|
.first()
|
||||||
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||||
|
|
||||||
|
if self.color == Color::White {
|
||||||
|
chosen_move
|
||||||
|
} else {
|
||||||
|
(chosen_move.0.mirror(), chosen_move.1.mirror())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod default;
|
pub mod default;
|
||||||
|
pub mod dqn;
|
||||||
pub mod dqnburn;
|
pub mod dqnburn;
|
||||||
pub mod erroneous_moves;
|
pub mod erroneous_moves;
|
||||||
pub mod random;
|
pub mod random;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
/// training_common.rs : environnement avec espace d'actions optimisé
|
|
||||||
/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' )
|
|
||||||
use std::cmp::{max, min};
|
use std::cmp::{max, min};
|
||||||
use std::fmt::{Debug, Display, Formatter};
|
use std::fmt::{Debug, Display, Formatter};
|
||||||
|
|
||||||
|
|
|
||||||
266
bot/src/training_common_big.rs
Normal file
266
bot/src/training_common_big.rs
Normal file
|
|
@ -0,0 +1,266 @@
|
||||||
|
use std::cmp::{max, min};
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use store::{CheckerMove, Dice};
|
||||||
|
|
||||||
|
/// Types d'actions possibles dans le jeu
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub enum TrictracAction {
|
||||||
|
/// Lancer les dés
|
||||||
|
Roll,
|
||||||
|
/// Continuer après avoir gagné un trou
|
||||||
|
Go,
|
||||||
|
/// Effectuer un mouvement de pions
|
||||||
|
Move {
|
||||||
|
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
|
||||||
|
from1: usize, // position de départ du premier pion (0-24)
|
||||||
|
from2: usize, // position de départ du deuxième pion (0-24)
|
||||||
|
},
|
||||||
|
// Marquer les points : à activer si support des écoles
|
||||||
|
// Mark,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrictracAction {
|
||||||
|
/// Encode une action en index pour le réseau de neurones
|
||||||
|
pub fn to_action_index(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
TrictracAction::Roll => 0,
|
||||||
|
TrictracAction::Go => 1,
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Encoder les mouvements dans l'espace d'actions
|
||||||
|
// Indices 2+ pour les mouvements
|
||||||
|
// de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier)
|
||||||
|
let mut start = 2;
|
||||||
|
if !dice_order {
|
||||||
|
// 25 * 25 = 625
|
||||||
|
start += 625;
|
||||||
|
}
|
||||||
|
start + from1 * 25 + from2
|
||||||
|
} // TrictracAction::Mark => 1252,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Décode un index d'action en TrictracAction
|
||||||
|
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||||
|
match index {
|
||||||
|
0 => Some(TrictracAction::Roll),
|
||||||
|
// 1252 => Some(TrictracAction::Mark),
|
||||||
|
1 => Some(TrictracAction::Go),
|
||||||
|
i if i >= 3 => {
|
||||||
|
let move_code = i - 3;
|
||||||
|
let (dice_order, from1, from2) = Self::decode_move(move_code);
|
||||||
|
Some(TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Décode un entier en paire de mouvements
|
||||||
|
fn decode_move(code: usize) -> (bool, usize, usize) {
|
||||||
|
let mut encoded = code;
|
||||||
|
let dice_order = code < 626;
|
||||||
|
if !dice_order {
|
||||||
|
encoded -= 625
|
||||||
|
}
|
||||||
|
let from1 = encoded / 25;
|
||||||
|
let from2 = 1 + encoded % 25;
|
||||||
|
(dice_order, from1, from2)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retourne la taille de l'espace d'actions total
|
||||||
|
pub fn action_space_size() -> usize {
|
||||||
|
// 1 (Roll) + 1 (Go) + mouvements possibles
|
||||||
|
// Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from)
|
||||||
|
// Mais on peut optimiser en limitant aux positions valides (1-24)
|
||||||
|
2 + (2 * 25 * 25) // = 1252
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
|
||||||
|
// match action {
|
||||||
|
// TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
|
||||||
|
// TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
|
||||||
|
// TrictracAction::Go => Some(GameEvent::Go { player_id }),
|
||||||
|
// TrictracAction::Move {
|
||||||
|
// dice_order,
|
||||||
|
// from1,
|
||||||
|
// from2,
|
||||||
|
// } => {
|
||||||
|
// // Effectuer un mouvement
|
||||||
|
// let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
|
||||||
|
// let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
|
||||||
|
//
|
||||||
|
// Some(GameEvent::Move {
|
||||||
|
// player_id: self.agent_player_id,
|
||||||
|
// moves: (checker_move1, checker_move2),
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Obtient les actions valides pour l'état de jeu actuel
|
||||||
|
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
|
use store::TurnStage;
|
||||||
|
|
||||||
|
let mut valid_actions = Vec::new();
|
||||||
|
|
||||||
|
let active_player_id = game_state.active_player_id;
|
||||||
|
let player_color = game_state.player_color_by_id(&active_player_id);
|
||||||
|
|
||||||
|
if let Some(color) = player_color {
|
||||||
|
match game_state.turn_stage {
|
||||||
|
TurnStage::RollDice => {
|
||||||
|
valid_actions.push(TrictracAction::Roll);
|
||||||
|
}
|
||||||
|
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
||||||
|
panic!(
|
||||||
|
"get_valid_actions not implemented for turn stage {:?}",
|
||||||
|
game_state.turn_stage
|
||||||
|
);
|
||||||
|
// valid_actions.push(TrictracAction::Mark);
|
||||||
|
}
|
||||||
|
TurnStage::HoldOrGoChoice => {
|
||||||
|
valid_actions.push(TrictracAction::Go);
|
||||||
|
|
||||||
|
// Ajoute aussi les mouvements possibles
|
||||||
|
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||||
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
|
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||||
|
assert_eq!(color, store::Color::White);
|
||||||
|
for (move1, move2) in possible_moves {
|
||||||
|
valid_actions.push(checker_moves_to_trictrac_action(
|
||||||
|
&move1,
|
||||||
|
&move2,
|
||||||
|
&game_state.dice,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::Move => {
|
||||||
|
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||||
|
let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
if possible_moves.is_empty() {
|
||||||
|
// Empty move
|
||||||
|
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||||
|
assert_eq!(color, store::Color::White);
|
||||||
|
for (move1, move2) in possible_moves {
|
||||||
|
valid_actions.push(checker_moves_to_trictrac_action(
|
||||||
|
&move1,
|
||||||
|
&move2,
|
||||||
|
&game_state.dice,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
panic!("empty valid_actions for state {game_state}");
|
||||||
|
}
|
||||||
|
valid_actions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid only for White player
|
||||||
|
fn checker_moves_to_trictrac_action(
|
||||||
|
move1: &CheckerMove,
|
||||||
|
move2: &CheckerMove,
|
||||||
|
dice: &Dice,
|
||||||
|
) -> TrictracAction {
|
||||||
|
let to1 = move1.get_to();
|
||||||
|
let to2 = move2.get_to();
|
||||||
|
let from1 = move1.get_from();
|
||||||
|
let from2 = move2.get_from();
|
||||||
|
|
||||||
|
let mut diff_move1 = if to1 > 0 {
|
||||||
|
// Mouvement sans sortie
|
||||||
|
to1 - from1
|
||||||
|
} else {
|
||||||
|
// sortie, on utilise la valeur du dé
|
||||||
|
if to2 > 0 {
|
||||||
|
// sortie pour le mouvement 1 uniquement
|
||||||
|
let dice2 = to2 - from2;
|
||||||
|
if dice2 == dice.values.0 as usize {
|
||||||
|
dice.values.1 as usize
|
||||||
|
} else {
|
||||||
|
dice.values.0 as usize
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// double sortie
|
||||||
|
if from1 < from2 {
|
||||||
|
max(dice.values.0, dice.values.1) as usize
|
||||||
|
} else {
|
||||||
|
min(dice.values.0, dice.values.1) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// modification de diff_move1 si on est dans le cas d'un mouvement par puissance
|
||||||
|
let rest_field = 12;
|
||||||
|
if to1 == rest_field
|
||||||
|
&& to2 == rest_field
|
||||||
|
&& max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
|
||||||
|
{
|
||||||
|
// prise par puissance
|
||||||
|
diff_move1 += 1;
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order: diff_move1 == dice.values.0 as usize,
|
||||||
|
from1: move1.get_from(),
|
||||||
|
from2: move2.get_from(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retourne les indices des actions valides
|
||||||
|
pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
|
||||||
|
get_valid_actions(game_state)
|
||||||
|
.into_iter()
|
||||||
|
.map(|action| action.to_action_index())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sélectionne une action valide aléatoire
|
||||||
|
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
|
||||||
|
use rand::{seq::SliceRandom, thread_rng};
|
||||||
|
|
||||||
|
let valid_actions = get_valid_actions(game_state);
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
valid_actions.choose(&mut rng).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn to_action_index() {
|
||||||
|
let action = TrictracAction::Move {
|
||||||
|
dice_order: true,
|
||||||
|
from1: 3,
|
||||||
|
from2: 4,
|
||||||
|
};
|
||||||
|
let index = action.to_action_index();
|
||||||
|
assert_eq!(Some(action), TrictracAction::from_action_index(index));
|
||||||
|
assert_eq!(81, index);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_action_index() {
|
||||||
|
let action = TrictracAction::Move {
|
||||||
|
dice_order: true,
|
||||||
|
from1: 3,
|
||||||
|
from2: 4,
|
||||||
|
};
|
||||||
|
assert_eq!(Some(action), TrictracAction::from_action_index(81));
|
||||||
|
}
|
||||||
|
}
|
||||||
8
client_bevy/.cargo/config.toml
Normal file
8
client_bevy/.cargo/config.toml
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
[target.x86_64-unknown-linux-gnu]
|
||||||
|
linker = "clang"
|
||||||
|
rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
|
||||||
|
|
||||||
|
# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only'
|
||||||
|
# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains.
|
||||||
|
#[profile.dev]
|
||||||
|
#debug = 1
|
||||||
14
client_bevy/Cargo.toml
Normal file
14
client_bevy/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "trictrac-client"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.75"
|
||||||
|
bevy = { version = "0.11.3" }
|
||||||
|
bevy_renet = "0.0.9"
|
||||||
|
bincode = "1.3.3"
|
||||||
|
renet = "0.0.13"
|
||||||
|
store = { path = "../store" }
|
||||||
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
Binary file not shown.
BIN
client_bevy/assets/board.png
Normal file
BIN
client_bevy/assets/board.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.9 MiB |
BIN
client_bevy/assets/sound/click.wav
Normal file
BIN
client_bevy/assets/sound/click.wav
Normal file
Binary file not shown.
BIN
client_bevy/assets/sound/throw.wav
Executable file
BIN
client_bevy/assets/sound/throw.wav
Executable file
Binary file not shown.
BIN
client_bevy/assets/tac.png
Normal file
BIN
client_bevy/assets/tac.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.6 KiB |
BIN
client_bevy/assets/tic.png
Normal file
BIN
client_bevy/assets/tic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.4 KiB |
334
client_bevy/src/main.rs
Normal file
334
client_bevy/src/main.rs
Normal file
|
|
@ -0,0 +1,334 @@
|
||||||
|
use std::{net::UdpSocket, time::SystemTime};
|
||||||
|
|
||||||
|
use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES};
|
||||||
|
use store::{GameEvent, GameState, CheckerMove};
|
||||||
|
|
||||||
|
use bevy::prelude::*;
|
||||||
|
use bevy::window::PrimaryWindow;
|
||||||
|
use bevy_renet::{
|
||||||
|
renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient},
|
||||||
|
transport::{client_connected, NetcodeClientPlugin},
|
||||||
|
RenetClientPlugin,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Resource)]
|
||||||
|
struct CurrentClientId(u64);
|
||||||
|
|
||||||
|
#[derive(Resource)]
|
||||||
|
struct BevyGameState(GameState);
|
||||||
|
|
||||||
|
impl Default for BevyGameState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
0: GameState::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Resource, Deref, DerefMut)]
|
||||||
|
struct GameUIState {
|
||||||
|
selected_tile: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for GameUIState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
selected_tile: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Event)]
|
||||||
|
struct BevyGameEvent(GameEvent);
|
||||||
|
|
||||||
|
// This id needs to be the same as the server is using
|
||||||
|
const PROTOCOL_ID: u64 = 2878;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// Get username from stdin args
|
||||||
|
let args = std::env::args().collect::<Vec<String>>();
|
||||||
|
let username = &args[1];
|
||||||
|
|
||||||
|
let (client, transport, client_id) = new_renet_client(&username).unwrap();
|
||||||
|
App::new()
|
||||||
|
// Lets add a nice dark grey background color
|
||||||
|
.insert_resource(ClearColor(Color::hex("282828").unwrap()))
|
||||||
|
.add_plugins(DefaultPlugins.set(WindowPlugin {
|
||||||
|
primary_window: Some(Window {
|
||||||
|
// Adding the username to the window title makes debugging a whole lot easier.
|
||||||
|
title: format!("TricTrac <{}>", username),
|
||||||
|
resolution: (1080.0, 1080.0).into(),
|
||||||
|
..default()
|
||||||
|
}),
|
||||||
|
..default()
|
||||||
|
}))
|
||||||
|
// Add our game state and register GameEvent as a bevy event
|
||||||
|
.insert_resource(BevyGameState::default())
|
||||||
|
.insert_resource(GameUIState::default())
|
||||||
|
.add_event::<BevyGameEvent>()
|
||||||
|
// Renet setup
|
||||||
|
.add_plugins(RenetClientPlugin)
|
||||||
|
.add_plugins(NetcodeClientPlugin)
|
||||||
|
.insert_resource(client)
|
||||||
|
.insert_resource(transport)
|
||||||
|
.insert_resource(CurrentClientId(client_id))
|
||||||
|
.add_systems(Startup, setup)
|
||||||
|
.add_systems(Update, (update_waiting_text, input, update_board, panic_on_error_system))
|
||||||
|
.add_systems(
|
||||||
|
PostUpdate,
|
||||||
|
receive_events_from_server.run_if(client_connected()),
|
||||||
|
)
|
||||||
|
.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// COMPONENTS //////////
|
||||||
|
#[derive(Component)]
|
||||||
|
struct UIRoot;
|
||||||
|
|
||||||
|
#[derive(Component)]
|
||||||
|
struct WaitingText;
|
||||||
|
|
||||||
|
#[derive(Component)]
|
||||||
|
struct Board {
|
||||||
|
squares: [Square; 26]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Board {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
squares: [Square { count: 0, color: None, position: 0}; 26]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Board {
|
||||||
|
fn square_at(&self, position: usize) -> Square {
|
||||||
|
self.squares[position]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Component, Clone, Copy)]
|
||||||
|
struct Square {
|
||||||
|
count: usize,
|
||||||
|
color: Option<bool>,
|
||||||
|
position: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// UPDATE SYSTEMS //////////
|
||||||
|
fn update_board(
|
||||||
|
mut commands: Commands,
|
||||||
|
game_state: Res<BevyGameState>,
|
||||||
|
mut game_events: EventReader<BevyGameEvent>,
|
||||||
|
asset_server: Res<AssetServer>,
|
||||||
|
) {
|
||||||
|
for event in game_events.iter() {
|
||||||
|
match event.0 {
|
||||||
|
GameEvent::Move { player_id, moves } => {
|
||||||
|
// trictrac positions, TODO : dépend de player_id
|
||||||
|
let (x, y) = if moves.0.get_to() < 13 { (13 - moves.0.get_to(), 1) } else { (moves.0.get_to() - 13, 0)};
|
||||||
|
let texture =
|
||||||
|
asset_server.load(match game_state.0.players[&player_id].color {
|
||||||
|
store::Color::Black => "tac.png",
|
||||||
|
store::Color::White => "tic.png",
|
||||||
|
});
|
||||||
|
|
||||||
|
info!("spawning tictac sprite");
|
||||||
|
commands.spawn(SpriteBundle {
|
||||||
|
transform: Transform::from_xyz(
|
||||||
|
83.0 * (x as f32 - 1.0),
|
||||||
|
-30.0 + 540.0 * (y as f32 - 1.0),
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
sprite: Sprite {
|
||||||
|
custom_size: Some(Vec2::new(83.0, 83.0)),
|
||||||
|
..default()
|
||||||
|
},
|
||||||
|
texture: texture.into(),
|
||||||
|
..default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_waiting_text(mut text_query: Query<&mut Text, With<WaitingText>>, time: Res<Time>) {
|
||||||
|
if let Ok(mut text) = text_query.get_single_mut() {
|
||||||
|
let num_dots = (time.elapsed_seconds() as usize % 3) + 1;
|
||||||
|
text.sections[0].value = format!(
|
||||||
|
"Waiting for an opponent{}{}",
|
||||||
|
".".repeat(num_dots as usize),
|
||||||
|
// Pad with spaces to avoid text changing width and dancing all around the screen 🕺
|
||||||
|
" ".repeat(3 - num_dots as usize)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input(
|
||||||
|
primary_query: Query<&Window, With<PrimaryWindow>>,
|
||||||
|
// windows: Res<Windows>,
|
||||||
|
input: Res<Input<MouseButton>>,
|
||||||
|
game_state: Res<BevyGameState>,
|
||||||
|
mut game_ui_state: ResMut<GameUIState>,
|
||||||
|
mut client: ResMut<RenetClient>,
|
||||||
|
client_id: Res<CurrentClientId>,
|
||||||
|
) {
|
||||||
|
// We only want to handle inputs once we are ingame
|
||||||
|
if game_state.0.stage != store::Stage::InGame {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let window = primary_query.get_single().unwrap();
|
||||||
|
if let Some(mouse_position) = window.cursor_position() {
|
||||||
|
// Determine the index of the tile that the mouse is currently over
|
||||||
|
// NOTE: This calculation assumes a fixed window size.
|
||||||
|
// That's fine for now, but consider using the windows size instead.
|
||||||
|
let mut tile_x: usize = (mouse_position.x / 83.0).floor() as usize;
|
||||||
|
let tile_y: usize = (mouse_position.y / 540.0).floor() as usize;
|
||||||
|
if tile_x > 5 {
|
||||||
|
// remove the middle bar offset
|
||||||
|
tile_x = tile_x - 1
|
||||||
|
}
|
||||||
|
// let tile = tile_x + tile_y * 12;
|
||||||
|
|
||||||
|
// traduction en position backgammon
|
||||||
|
let tile = if tile_y == 0 {
|
||||||
|
13 + tile_x
|
||||||
|
} else {
|
||||||
|
12 - tile_x
|
||||||
|
};
|
||||||
|
|
||||||
|
// If mouse is outside of board we do nothing
|
||||||
|
if 23 < tile {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If left mouse button is pressed, send a place tile event to the server
|
||||||
|
if input.just_pressed(MouseButton::Left) {
|
||||||
|
info!("select piece at tile {:?}", tile);
|
||||||
|
if game_ui_state.selected_tile.is_some() {
|
||||||
|
let from_tile = game_ui_state.selected_tile.unwrap();
|
||||||
|
info!("sending movement from: {:?} to: {:?} ", from_tile, tile);
|
||||||
|
let event = GameEvent::Move {
|
||||||
|
player_id: client_id.0,
|
||||||
|
moves: (
|
||||||
|
CheckerMove::new(from_tile, tile).unwrap(),
|
||||||
|
CheckerMove::new(from_tile, tile).unwrap()
|
||||||
|
)
|
||||||
|
};
|
||||||
|
client.send_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
}
|
||||||
|
game_ui_state.selected_tile = if game_ui_state.selected_tile.is_some() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// SETUP //////////
|
||||||
|
fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {
|
||||||
|
// Tric Trac is a 2D game
|
||||||
|
// To show 2D sprites we need a 2D camera
|
||||||
|
commands.spawn(Camera2dBundle::default());
|
||||||
|
|
||||||
|
// Spawn board background
|
||||||
|
commands.spawn(SpriteBundle {
|
||||||
|
transform: Transform::from_xyz(0.0, -30.0, 0.0),
|
||||||
|
sprite: Sprite {
|
||||||
|
custom_size: Some(Vec2::new(1080.0, 927.0)),
|
||||||
|
..default()
|
||||||
|
},
|
||||||
|
texture: asset_server.load("board.png").into(),
|
||||||
|
..default()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Spawn pregame ui
|
||||||
|
commands
|
||||||
|
// A container that centers its children on the screen
|
||||||
|
.spawn(NodeBundle {
|
||||||
|
style: Style {
|
||||||
|
position_type: PositionType::Absolute,
|
||||||
|
left: Val::Px(0.0),
|
||||||
|
top: Val::Px(0.0),
|
||||||
|
width: Val::Percent(100.0),
|
||||||
|
height: Val::Percent(100.0),
|
||||||
|
align_items: AlignItems::Center,
|
||||||
|
justify_content: JustifyContent::Center,
|
||||||
|
..default()
|
||||||
|
},
|
||||||
|
..default()
|
||||||
|
})
|
||||||
|
.insert(UIRoot)
|
||||||
|
.with_children(|parent| {
|
||||||
|
// parent.spawn(Board::default()); // panic
|
||||||
|
parent
|
||||||
|
.spawn(TextBundle::from_section(
|
||||||
|
"Waiting for an opponent...",
|
||||||
|
TextStyle {
|
||||||
|
font: asset_server.load("Inconsolata.ttf"),
|
||||||
|
font_size: 24.0,
|
||||||
|
color: Color::hex("ebdbb2").unwrap(),
|
||||||
|
},
|
||||||
|
))
|
||||||
|
.insert(WaitingText);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// RENET NETWORKING //////////
|
||||||
|
// Creates a RenetClient thats already connected to a server.
|
||||||
|
// Returns an Err if connection fails
|
||||||
|
fn new_renet_client(
|
||||||
|
username: &String,
|
||||||
|
) -> anyhow::Result<(RenetClient, NetcodeClientTransport, u64)> {
|
||||||
|
let client = RenetClient::new(ConnectionConfig::default());
|
||||||
|
let server_addr = "127.0.0.1:5000".parse()?;
|
||||||
|
let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||||
|
let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?;
|
||||||
|
let client_id = current_time.as_millis() as u64;
|
||||||
|
|
||||||
|
// Place username in user data
|
||||||
|
let mut user_data = [0u8; NETCODE_USER_DATA_BYTES];
|
||||||
|
if username.len() > NETCODE_USER_DATA_BYTES - 8 {
|
||||||
|
panic!("Username is too big");
|
||||||
|
}
|
||||||
|
user_data[0..8].copy_from_slice(&(username.len() as u64).to_le_bytes());
|
||||||
|
user_data[8..username.len() + 8].copy_from_slice(username.as_bytes());
|
||||||
|
|
||||||
|
let authentication = ClientAuthentication::Unsecure {
|
||||||
|
server_addr,
|
||||||
|
client_id,
|
||||||
|
user_data: Some(user_data),
|
||||||
|
protocol_id: PROTOCOL_ID,
|
||||||
|
};
|
||||||
|
let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap();
|
||||||
|
|
||||||
|
Ok((client, transport, client_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn receive_events_from_server(
|
||||||
|
mut client: ResMut<RenetClient>,
|
||||||
|
mut game_state: ResMut<BevyGameState>,
|
||||||
|
mut game_events: EventWriter<BevyGameEvent>,
|
||||||
|
) {
|
||||||
|
while let Some(message) = client.receive_message(0) {
|
||||||
|
// Whenever the server sends a message we know that it must be a game event
|
||||||
|
let event: GameEvent = bincode::deserialize(&message).unwrap();
|
||||||
|
trace!("{:#?}", event);
|
||||||
|
|
||||||
|
// We trust the server - It's always been good to us!
|
||||||
|
// No need to validate the events it is sending us
|
||||||
|
game_state.0.consume(&event);
|
||||||
|
|
||||||
|
// Send the event into the bevy event system so systems can react to it
|
||||||
|
game_events.send(BevyGameEvent(event));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If any error is found we just panic
|
||||||
|
fn panic_on_error_system(mut renet_error: EventReader<NetcodeTransportError>) {
|
||||||
|
for e in renet_error.iter() {
|
||||||
|
panic!("{}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
14
client_tui/Cargo.toml
Normal file
14
client_tui/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "client_tui"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.89"
|
||||||
|
bincode = "1.3.3"
|
||||||
|
crossterm = "0.28.1"
|
||||||
|
ratatui = "0.28.1"
|
||||||
|
# renet = "0.0.13"
|
||||||
|
store = { path = "../store" }
|
||||||
53
client_tui/src/app.rs
Normal file
53
client_tui/src/app.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
// Application.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct App {
|
||||||
|
// should the application exit?
|
||||||
|
pub should_quit: bool,
|
||||||
|
// counter
|
||||||
|
pub counter: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl App {
|
||||||
|
// Constructs a new instance of [`App`].
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles the tick event of the terminal.
|
||||||
|
pub fn tick(&self) {}
|
||||||
|
|
||||||
|
// Set running to false to quit the application.
|
||||||
|
pub fn quit(&mut self) {
|
||||||
|
self.should_quit = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_counter(&mut self) {
|
||||||
|
if let Some(res) = self.counter.checked_add(1) {
|
||||||
|
self.counter = res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decrement_counter(&mut self) {
|
||||||
|
if let Some(res) = self.counter.checked_sub(1) {
|
||||||
|
self.counter = res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
#[test]
|
||||||
|
fn test_app_increment_counter() {
|
||||||
|
let mut app = App::default();
|
||||||
|
app.increment_counter();
|
||||||
|
assert_eq!(app.counter, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_app_decrement_counter() {
|
||||||
|
let mut app = App::default();
|
||||||
|
app.decrement_counter();
|
||||||
|
assert_eq!(app.counter, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
87
client_tui/src/event.rs
Normal file
87
client_tui/src/event.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
use std::{
|
||||||
|
sync::mpsc,
|
||||||
|
thread,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use crossterm::event::{self, Event as CrosstermEvent, KeyEvent, MouseEvent};
|
||||||
|
|
||||||
|
// Terminal events.
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub enum Event {
|
||||||
|
// Terminal tick.
|
||||||
|
Tick,
|
||||||
|
// Key press.
|
||||||
|
Key(KeyEvent),
|
||||||
|
// Mouse click/scroll.
|
||||||
|
Mouse(MouseEvent),
|
||||||
|
// Terminal resize.
|
||||||
|
Resize(u16, u16),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminal event handler.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EventHandler {
|
||||||
|
// Event sender channel.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
sender: mpsc::Sender<Event>,
|
||||||
|
// Event receiver channel.
|
||||||
|
receiver: mpsc::Receiver<Event>,
|
||||||
|
// Event handler thread.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
handler: thread::JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventHandler {
|
||||||
|
// Constructs a new instance of [`EventHandler`].
|
||||||
|
pub fn new(tick_rate: u64) -> Self {
|
||||||
|
let tick_rate = Duration::from_millis(tick_rate);
|
||||||
|
let (sender, receiver) = mpsc::channel();
|
||||||
|
let handler = {
|
||||||
|
let sender = sender.clone();
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut last_tick = Instant::now();
|
||||||
|
loop {
|
||||||
|
let timeout = tick_rate
|
||||||
|
.checked_sub(last_tick.elapsed())
|
||||||
|
.unwrap_or(tick_rate);
|
||||||
|
|
||||||
|
if event::poll(timeout).expect("no events available") {
|
||||||
|
match event::read().expect("unable to read event") {
|
||||||
|
CrosstermEvent::Key(e) => {
|
||||||
|
if e.kind == event::KeyEventKind::Press {
|
||||||
|
sender.send(Event::Key(e))
|
||||||
|
} else {
|
||||||
|
Ok(()) // ignore KeyEventKind::Release on windows
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CrosstermEvent::Mouse(e) => sender.send(Event::Mouse(e)),
|
||||||
|
CrosstermEvent::Resize(w, h) => sender.send(Event::Resize(w, h)),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
.expect("failed to send terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
|
if last_tick.elapsed() >= tick_rate {
|
||||||
|
sender.send(Event::Tick).expect("failed to send tick event");
|
||||||
|
last_tick = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
Self {
|
||||||
|
sender,
|
||||||
|
receiver,
|
||||||
|
handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive the next event from the handler thread.
|
||||||
|
//
|
||||||
|
// This function will always block the current thread if
|
||||||
|
// there is no data available and it's possible for more data to be sent.
|
||||||
|
pub fn next(&self) -> Result<Event> {
|
||||||
|
Ok(self.receiver.recv()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
50
client_tui/src/main.rs
Normal file
50
client_tui/src/main.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Application.
|
||||||
|
pub mod app;
|
||||||
|
|
||||||
|
// Terminal events handler.
|
||||||
|
pub mod event;
|
||||||
|
|
||||||
|
// Widget renderer.
|
||||||
|
pub mod ui;
|
||||||
|
|
||||||
|
// Terminal user interface.
|
||||||
|
pub mod tui;
|
||||||
|
|
||||||
|
// Application updater.
|
||||||
|
pub mod update;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use app::App;
|
||||||
|
use event::{Event, EventHandler};
|
||||||
|
use ratatui::{backend::CrosstermBackend, Terminal};
|
||||||
|
use tui::Tui;
|
||||||
|
use update::update;
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
// Create an application.
|
||||||
|
let mut app = App::new();
|
||||||
|
|
||||||
|
// Initialize the terminal user interface.
|
||||||
|
let backend = CrosstermBackend::new(std::io::stderr());
|
||||||
|
let terminal = Terminal::new(backend)?;
|
||||||
|
let events = EventHandler::new(250);
|
||||||
|
let mut tui = Tui::new(terminal, events);
|
||||||
|
tui.enter()?;
|
||||||
|
|
||||||
|
// Start the main loop.
|
||||||
|
while !app.should_quit {
|
||||||
|
// Render the user interface.
|
||||||
|
tui.draw(&mut app)?;
|
||||||
|
// Handle events.
|
||||||
|
match tui.events.next()? {
|
||||||
|
Event::Tick => {}
|
||||||
|
Event::Key(key_event) => update(&mut app, key_event),
|
||||||
|
Event::Mouse(_) => {}
|
||||||
|
Event::Resize(_, _) => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exit the user interface.
|
||||||
|
tui.exit()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
77
client_tui/src/tui.rs
Normal file
77
client_tui/src/tui.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
use std::{io, panic};
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use crossterm::{
|
||||||
|
event::{DisableMouseCapture, EnableMouseCapture},
|
||||||
|
terminal::{self, EnterAlternateScreen, LeaveAlternateScreen},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub type CrosstermTerminal = ratatui::Terminal<ratatui::backend::CrosstermBackend<std::io::Stderr>>;
|
||||||
|
|
||||||
|
use crate::{app::App, event::EventHandler, ui};
|
||||||
|
|
||||||
|
// Representation of a terminal user interface.
|
||||||
|
//
|
||||||
|
// It is responsible for setting up the terminal,
|
||||||
|
// initializing the interface and handling the draw events.
|
||||||
|
pub struct Tui {
|
||||||
|
// Interface to the Terminal.
|
||||||
|
terminal: CrosstermTerminal,
|
||||||
|
// Terminal event handler.
|
||||||
|
pub events: EventHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tui {
|
||||||
|
// Constructs a new instance of [`Tui`].
|
||||||
|
pub fn new(terminal: CrosstermTerminal, events: EventHandler) -> Self {
|
||||||
|
Self { terminal, events }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initializes the terminal interface.
|
||||||
|
//
|
||||||
|
// It enables the raw mode and sets terminal properties.
|
||||||
|
pub fn enter(&mut self) -> Result<()> {
|
||||||
|
terminal::enable_raw_mode()?;
|
||||||
|
crossterm::execute!(io::stderr(), EnterAlternateScreen, EnableMouseCapture)?;
|
||||||
|
|
||||||
|
// Define a custom panic hook to reset the terminal properties.
|
||||||
|
// This way, you won't have your terminal messed up if an unexpected error happens.
|
||||||
|
let panic_hook = panic::take_hook();
|
||||||
|
panic::set_hook(Box::new(move |panic| {
|
||||||
|
Self::reset().expect("failed to reset the terminal");
|
||||||
|
panic_hook(panic);
|
||||||
|
}));
|
||||||
|
|
||||||
|
self.terminal.hide_cursor()?;
|
||||||
|
self.terminal.clear()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// [`Draw`] the terminal interface by [`rendering`] the widgets.
|
||||||
|
//
|
||||||
|
// [`Draw`]: tui::Terminal::draw
|
||||||
|
// [`rendering`]: crate::ui:render
|
||||||
|
pub fn draw(&mut self, app: &mut App) -> Result<()> {
|
||||||
|
self.terminal.draw(|frame| ui::render(app, frame))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resets the terminal interface.
|
||||||
|
//
|
||||||
|
// This function is also used for the panic hook to revert
|
||||||
|
// the terminal properties if unexpected errors occur.
|
||||||
|
fn reset() -> Result<()> {
|
||||||
|
terminal::disable_raw_mode()?;
|
||||||
|
crossterm::execute!(io::stderr(), LeaveAlternateScreen, DisableMouseCapture)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exits the terminal interface.
|
||||||
|
//
|
||||||
|
// It disables the raw mode and reverts back the terminal properties.
|
||||||
|
pub fn exit(&mut self) -> Result<()> {
|
||||||
|
Self::reset()?;
|
||||||
|
self.terminal.show_cursor()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
30
client_tui/src/ui.rs
Normal file
30
client_tui/src/ui.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
use ratatui::{
|
||||||
|
prelude::{Alignment, Frame},
|
||||||
|
style::{Color, Style},
|
||||||
|
widgets::{Block, BorderType, Borders, Paragraph},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::app::App;
|
||||||
|
|
||||||
|
pub fn render(app: &mut App, f: &mut Frame) {
|
||||||
|
f.render_widget(
|
||||||
|
Paragraph::new(format!(
|
||||||
|
"
|
||||||
|
Press `Esc`, `Ctrl-C` or `q` to stop running.\n\
|
||||||
|
Press `j` and `k` to increment and decrement the counter respectively.\n\
|
||||||
|
Counter: {}
|
||||||
|
",
|
||||||
|
app.counter
|
||||||
|
))
|
||||||
|
.block(
|
||||||
|
Block::default()
|
||||||
|
.title("Counter App")
|
||||||
|
.title_alignment(Alignment::Center)
|
||||||
|
.borders(Borders::ALL)
|
||||||
|
.border_type(BorderType::Rounded),
|
||||||
|
)
|
||||||
|
.style(Style::default().fg(Color::Yellow))
|
||||||
|
.alignment(Alignment::Center),
|
||||||
|
f.area(),
|
||||||
|
)
|
||||||
|
}
|
||||||
17
client_tui/src/update.rs
Normal file
17
client_tui/src/update.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
|
|
||||||
|
use crate::app::App;
|
||||||
|
|
||||||
|
pub fn update(app: &mut App, key_event: KeyEvent) {
|
||||||
|
match key_event.code {
|
||||||
|
KeyCode::Esc | KeyCode::Char('q') => app.quit(),
|
||||||
|
KeyCode::Char('c') | KeyCode::Char('C') => {
|
||||||
|
if key_event.modifiers == KeyModifiers::CONTROL {
|
||||||
|
app.quit()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KeyCode::Right | KeyCode::Char('j') => app.increment_counter(),
|
||||||
|
KeyCode::Left | KeyCode::Char('k') => app.decrement_counter(),
|
||||||
|
_ => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
3
justfile
3
justfile
|
|
@ -22,6 +22,9 @@ profile:
|
||||||
pythonlib:
|
pythonlib:
|
||||||
maturin build -m store/Cargo.toml --release
|
maturin build -m store/Cargo.toml --release
|
||||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||||
|
trainsimple:
|
||||||
|
cargo build --release --bin=train_dqn_simple
|
||||||
|
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
|
||||||
trainbot algo:
|
trainbot algo:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
|
|
|
||||||
14
server/Cargo.toml
Normal file
14
server/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "trictrac-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
store = { path = "../store" }
|
||||||
|
env_logger = "0.10.0"
|
||||||
|
log = "0.4.20"
|
||||||
|
pico-args = "0.5.0"
|
||||||
|
renet = "0.0.13"
|
||||||
|
bincode = "1.3.3"
|
||||||
147
server/src/main.rs
Normal file
147
server/src/main.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
use log::{info, trace, warn};
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::{Duration, Instant, SystemTime};
|
||||||
|
|
||||||
|
use renet::{
|
||||||
|
transport::{
|
||||||
|
NetcodeServerTransport, ServerAuthentication, ServerConfig, NETCODE_USER_DATA_BYTES,
|
||||||
|
},
|
||||||
|
ConnectionConfig, RenetServer, ServerEvent,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Only clients that can provide the same PROTOCOL_ID that the server is using will be able to connect.
|
||||||
|
// This can be used to make sure players use the most recent version of the client for instance.
|
||||||
|
pub const PROTOCOL_ID: u64 = 2878;
|
||||||
|
|
||||||
|
/// Utility function for extracting a players name from renet user data
|
||||||
|
fn name_from_user_data(user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> String {
|
||||||
|
let mut buffer = [0u8; 8];
|
||||||
|
buffer.copy_from_slice(&user_data[0..8]);
|
||||||
|
let mut len = u64::from_le_bytes(buffer) as usize;
|
||||||
|
len = len.min(NETCODE_USER_DATA_BYTES - 8);
|
||||||
|
let data = user_data[8..len + 8].to_vec();
|
||||||
|
String::from_utf8(data).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
env_logger::init();
|
||||||
|
|
||||||
|
let mut server = RenetServer::new(ConnectionConfig::default());
|
||||||
|
|
||||||
|
// Setup transport layer
|
||||||
|
const SERVER_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5000);
|
||||||
|
let socket: UdpSocket = UdpSocket::bind(SERVER_ADDR).unwrap();
|
||||||
|
let server_config = ServerConfig {
|
||||||
|
max_clients: 2,
|
||||||
|
protocol_id: PROTOCOL_ID,
|
||||||
|
public_addr: SERVER_ADDR,
|
||||||
|
authentication: ServerAuthentication::Unsecure,
|
||||||
|
};
|
||||||
|
let current_time = SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap();
|
||||||
|
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();
|
||||||
|
|
||||||
|
trace!("❂ TricTrac server listening on {SERVER_ADDR}");
|
||||||
|
|
||||||
|
let mut game_state = store::GameState::default();
|
||||||
|
let mut last_updated = Instant::now();
|
||||||
|
loop {
|
||||||
|
// Update server time
|
||||||
|
let now = Instant::now();
|
||||||
|
let delta_time = now - last_updated;
|
||||||
|
server.update(delta_time);
|
||||||
|
transport.update(delta_time, &mut server).unwrap();
|
||||||
|
last_updated = now;
|
||||||
|
|
||||||
|
// Receive connection events from clients
|
||||||
|
while let Some(event) = server.get_event() {
|
||||||
|
match event {
|
||||||
|
ServerEvent::ClientConnected { client_id } => {
|
||||||
|
let user_data = transport.user_data(client_id).unwrap();
|
||||||
|
|
||||||
|
// Tell the recently joined player about the other player
|
||||||
|
for (player_id, player) in game_state.players.iter() {
|
||||||
|
let event = store::GameEvent::PlayerJoined {
|
||||||
|
player_id: *player_id,
|
||||||
|
name: player.name.clone(),
|
||||||
|
};
|
||||||
|
server.send_message(client_id, 0, bincode::serialize(&event).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the new player to the game
|
||||||
|
let event = store::GameEvent::PlayerJoined {
|
||||||
|
player_id: client_id,
|
||||||
|
name: name_from_user_data(&user_data),
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
|
||||||
|
// Tell all players that a new player has joined
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
|
||||||
|
info!("🎉 Client {client_id} connected.");
|
||||||
|
// In TicTacTussle the game can begin once two players has joined
|
||||||
|
if game_state.players.len() == 2 {
|
||||||
|
let event = store::GameEvent::BeginGame {
|
||||||
|
goes_first: client_id,
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
trace!("The game gas begun");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ServerEvent::ClientDisconnected {
|
||||||
|
client_id,
|
||||||
|
reason: _,
|
||||||
|
} => {
|
||||||
|
// First consume a disconnect event
|
||||||
|
let event = store::GameEvent::PlayerDisconnected {
|
||||||
|
player_id: client_id,
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
info!("Client {client_id} disconnected");
|
||||||
|
|
||||||
|
// Then end the game, since tic tac toe can't go on with a single player
|
||||||
|
let event = store::GameEvent::EndGame {
|
||||||
|
reason: store::EndGameReason::PlayerLeft {
|
||||||
|
player_id: client_id,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
|
||||||
|
// NOTE: Since we don't authenticate users we can't do any reconnection attempts.
|
||||||
|
// We simply have no way to know if the next user is the same as the one that disconnected.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive GameEvents from clients. Broadcast valid events.
|
||||||
|
for client_id in server.clients_id().into_iter() {
|
||||||
|
while let Some(message) = server.receive_message(client_id, 0) {
|
||||||
|
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
|
||||||
|
if game_state.validate(&event) {
|
||||||
|
game_state.consume(&event);
|
||||||
|
trace!("Player {client_id} sent:\n\t{event:#?}");
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
|
||||||
|
// Determine if a player has won the game
|
||||||
|
if let Some(winner) = game_state.determine_winner() {
|
||||||
|
let event = store::GameEvent::EndGame {
|
||||||
|
reason: store::EndGameReason::PlayerWon { winner },
|
||||||
|
};
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Player {client_id} sent invalid event:\n\t{event:#?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
transport.send_packets(&mut server);
|
||||||
|
thread::sleep(Duration::from_millis(50));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -742,10 +742,6 @@ impl GameState {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mark_points_for_bot_training(&mut self, player_id: PlayerId, points: u8) -> bool {
|
|
||||||
self.mark_points(player_id, points)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||||
// Update player points and holes
|
// Update player points and holes
|
||||||
let mut new_hole = false;
|
let mut new_hole = false;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue