Compare commits
No commits in common. "74f692d7babeaa442d99c7cf9294b13d18e7b198" and "883ebf9bc166f3d601ff8d58457e53523c4cbc38" have entirely different histories.
74f692d7ba
...
883ebf9bc1
46
Cargo.lock
generated
46
Cargo.lock
generated
|
|
@ -834,7 +834,7 @@ dependencies = [
|
|||
"derive-new",
|
||||
"log",
|
||||
"nvml-wrapper",
|
||||
"ratatui",
|
||||
"ratatui 0.29.0",
|
||||
"rstest",
|
||||
"serde",
|
||||
"sysinfo",
|
||||
|
|
@ -1066,6 +1066,17 @@ dependencies = [
|
|||
"store",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "client_tui"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode 1.3.3",
|
||||
"crossterm",
|
||||
"ratatui 0.28.1",
|
||||
"store",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.54"
|
||||
|
|
@ -4403,6 +4414,27 @@ version = "0.1.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde"
|
||||
|
||||
[[package]]
|
||||
name = "ratatui"
|
||||
version = "0.28.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d"
|
||||
dependencies = [
|
||||
"bitflags 2.9.4",
|
||||
"cassowary",
|
||||
"compact_str",
|
||||
"crossterm",
|
||||
"instability",
|
||||
"itertools 0.13.0",
|
||||
"lru",
|
||||
"paste",
|
||||
"strum 0.26.3",
|
||||
"strum_macros 0.26.4",
|
||||
"unicode-segmentation",
|
||||
"unicode-truncate",
|
||||
"unicode-width 0.1.14",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ratatui"
|
||||
version = "0.29.0"
|
||||
|
|
@ -5781,6 +5813,18 @@ dependencies = [
|
|||
"strength_reduce",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "trictrac-server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bincode 1.3.3",
|
||||
"env_logger 0.10.2",
|
||||
"log",
|
||||
"pico-args",
|
||||
"renet",
|
||||
"store",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.26.2"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
|
||||
members = ["client_cli", "bot", "store"]
|
||||
members = ["client_tui", "client_cli", "bot", "server", "store"]
|
||||
|
|
|
|||
194
bot/src/burnrl/algos/dqn_big.rs
Normal file
194
bot/src/burnrl/algos/dqn_big.rs
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
use crate::burnrl::environment_big::TrictracEnvironment;
|
||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::relu;
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
#[allow(unused)]
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
||||
(self.linear_0, self.linear_1, self.linear_2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
relu(self.linear_2.forward(layer_1_output))
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> DQNModel<B> for Net<B> {
|
||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||
let (linear_0, linear_1, linear_2) = this.consume();
|
||||
|
||||
Self {
|
||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
||||
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 8192;
|
||||
|
||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
// ) -> DQN<E, B, Net<B>> {
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
|
||||
let mut agent = MyAgent::new(model);
|
||||
|
||||
// let config = DQNTrainingConfig::default();
|
||||
let config = DQNTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
tau: conf.tau,
|
||||
learning_rate: conf.learning_rate,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
|
||||
let mut optimizer = AdamWConfig::new()
|
||||
.with_grad_clipping(config.clip_grad.clone())
|
||||
.init();
|
||||
|
||||
let mut policy_net = agent.model().as_ref().unwrap().clone();
|
||||
|
||||
let mut step = 0_usize;
|
||||
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward: ElemType = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut state = env.state();
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
while !episode_done {
|
||||
let eps_threshold = conf.eps_end
|
||||
+ (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
|
||||
let action =
|
||||
DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
|
||||
let snapshot = env.step(action);
|
||||
|
||||
episode_reward +=
|
||||
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
if config.batch_size < memory.len() {
|
||||
policy_net =
|
||||
agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
|
||||
}
|
||||
|
||||
step += 1;
|
||||
episode_duration += 1;
|
||||
|
||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
let envmut = env.as_mut();
|
||||
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
|
||||
* 100.0)
|
||||
.round() as u32;
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
|
||||
envmut.goodmoves_count,
|
||||
goodmoves_ratio,
|
||||
envmut.pointrolls_count,
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
env.reset();
|
||||
episode_done = true;
|
||||
now = SystemTime::now();
|
||||
} else {
|
||||
state = *snapshot.state();
|
||||
}
|
||||
}
|
||||
}
|
||||
let valid_agent = agent.valid();
|
||||
if let Some(path) = &conf.save_path {
|
||||
save_model(valid_agent.model().as_ref().unwrap(), path);
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
|
@ -1,6 +1,9 @@
|
|||
pub mod dqn;
|
||||
pub mod dqn_big;
|
||||
pub mod dqn_valid;
|
||||
pub mod ppo;
|
||||
pub mod ppo_big;
|
||||
pub mod ppo_valid;
|
||||
pub mod sac;
|
||||
pub mod sac_big;
|
||||
pub mod sac_valid;
|
||||
|
|
|
|||
191
bot/src/burnrl/algos/ppo_big.rs
Normal file
191
bot/src/burnrl/algos/ppo_big.rs
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
use crate::burnrl::environment_big::TrictracEnvironment;
|
||||
use crate::burnrl::utils::Config;
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Initializer, Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
linear: Linear<B>,
|
||||
linear_actor: Linear<B>,
|
||||
linear_critic: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
#[allow(unused)]
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
let initializer = Initializer::XavierUniform { gain: 1.0 };
|
||||
Self {
|
||||
linear: LinearConfig::new(input_size, dense_size)
|
||||
.with_initializer(initializer.clone())
|
||||
.init(&Default::default()),
|
||||
linear_actor: LinearConfig::new(dense_size, output_size)
|
||||
.with_initializer(initializer.clone())
|
||||
.init(&Default::default()),
|
||||
linear_critic: LinearConfig::new(dense_size, 1)
|
||||
.with_initializer(initializer)
|
||||
.init(&Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
|
||||
let layer_0_output = relu(self.linear.forward(input));
|
||||
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
|
||||
let values = self.linear_critic.forward(layer_0_output);
|
||||
|
||||
PPOOutput::<B>::new(policies, values)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear.forward(input));
|
||||
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PPOModel<B> for Net<B> {}
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 512;
|
||||
|
||||
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
// ) -> PPO<E, B, Net<B>> {
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let mut model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
let agent = MyAgent::default();
|
||||
let config = PPOTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
lambda: conf.lambda,
|
||||
epsilon_clip: conf.epsilon_clip,
|
||||
critic_weight: conf.critic_weight,
|
||||
entropy_weight: conf.entropy_weight,
|
||||
learning_rate: conf.learning_rate,
|
||||
epochs: conf.epochs,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut optimizer = AdamWConfig::new()
|
||||
.with_grad_clipping(config.clip_grad.clone())
|
||||
.init();
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
env.reset();
|
||||
while !episode_done {
|
||||
let state = env.state();
|
||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
|
||||
let snapshot = env.step(action);
|
||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||
snapshot.reward().clone(),
|
||||
);
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
episode_duration += 1;
|
||||
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
|
||||
now = SystemTime::now();
|
||||
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
|
||||
memory.clear();
|
||||
}
|
||||
|
||||
if let Some(path) = &conf.save_path {
|
||||
let device = NdArrayDevice::default();
|
||||
let recorder = CompactRecorder::new();
|
||||
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
||||
|
||||
// Save the trained model (backend B) to a temporary file
|
||||
recorder
|
||||
.record(model.clone().into_record(), tmp_path.clone())
|
||||
.expect("Failed to save temporary model");
|
||||
|
||||
// Create a new model instance with the target backend (NdArray)
|
||||
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
|
||||
// Load the record from the temporary file into the new model
|
||||
let record = recorder
|
||||
.load(tmp_path.clone(), &device)
|
||||
.expect("Failed to load temporary model");
|
||||
let model_with_loaded_weights = model_to_save.load_record(record);
|
||||
|
||||
// Clean up the temporary file
|
||||
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
||||
|
||||
save_model(&model_with_loaded_weights, path);
|
||||
}
|
||||
agent.valid(model)
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
222
bot/src/burnrl/algos/sac_big.rs
Normal file
222
bot/src/burnrl/algos/sac_big.rs
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
use crate::burnrl::environment_big::TrictracEnvironment;
|
||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Actor<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Actor<B> {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
softmax(self.linear_2.forward(layer_1_output), 1)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SACActor<B> for Actor<B> {}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Critic<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Critic<B> {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
||||
(self.linear_0, self.linear_1, self.linear_2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
self.linear_2.forward(layer_1_output)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SACCritic<B> for Critic<B> {
|
||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||
let (linear_0, linear_1, linear_2) = this.consume();
|
||||
|
||||
Self {
|
||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
||||
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 4096;
|
||||
|
||||
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
let state_dim = <<E as Environment>::StateType as State>::size();
|
||||
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
||||
|
||||
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
||||
|
||||
let mut agent = MyAgent::default();
|
||||
|
||||
let config = SACTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
tau: conf.tau,
|
||||
learning_rate: conf.learning_rate,
|
||||
min_probability: conf.min_probability,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
|
||||
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
|
||||
|
||||
let mut optimizer = SACOptimizer::new(
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.init(),
|
||||
);
|
||||
|
||||
let mut step = 0_usize;
|
||||
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut state = env.state();
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
while !episode_done {
|
||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
|
||||
let snapshot = env.step(action);
|
||||
|
||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||
snapshot.reward().clone(),
|
||||
);
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
if config.batch_size < memory.len() {
|
||||
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
|
||||
}
|
||||
|
||||
step += 1;
|
||||
episode_duration += 1;
|
||||
|
||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
env.reset();
|
||||
episode_done = true;
|
||||
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs()
|
||||
);
|
||||
now = SystemTime::now();
|
||||
} else {
|
||||
state = *snapshot.state();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let valid_agent = agent.valid(nets.actor);
|
||||
if let Some(path) = &conf.save_path {
|
||||
if let Some(model) = valid_agent.model() {
|
||||
save_model(model, path);
|
||||
}
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Actor::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
|
|
@ -6,10 +6,10 @@ use burn_rl::base::{Action, Environment, Snapshot, State};
|
|||
use rand::{thread_rng, Rng};
|
||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
const ERROR_REWARD: f32 = -1.0012121;
|
||||
const REWARD_VALID_MOVE: f32 = 1.0012121;
|
||||
const REWARD_RATIO: f32 = 0.1;
|
||||
const WIN_POINTS: f32 = 100.0;
|
||||
const ERROR_REWARD: f32 = -1.12121;
|
||||
const REWARD_VALID_MOVE: f32 = 1.12121;
|
||||
const REWARD_RATIO: f32 = 0.01;
|
||||
const WIN_POINTS: f32 = 1.0;
|
||||
|
||||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
|
@ -285,7 +285,7 @@ impl TrictracEnvironment {
|
|||
if let Some(event) = action.to_event(&self.game) {
|
||||
if self.game.validate(&event) {
|
||||
self.game.consume(&event);
|
||||
// reward += REWARD_VALID_MOVE;
|
||||
reward += REWARD_VALID_MOVE;
|
||||
// Simuler le résultat des dés après un Roll
|
||||
if matches!(action, TrictracAction::Roll) {
|
||||
let mut rng = thread_rng();
|
||||
|
|
@ -312,11 +312,9 @@ impl TrictracEnvironment {
|
|||
// on annule les précédents reward
|
||||
// et on indique une valeur reconnaissable pour statistiques
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
}
|
||||
} else {
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
}
|
||||
|
||||
(reward, is_rollpoint)
|
||||
|
|
|
|||
469
bot/src/burnrl/environment_big.rs
Normal file
469
bot/src/burnrl/environment_big.rs
Normal file
|
|
@ -0,0 +1,469 @@
|
|||
use crate::training_common_big;
|
||||
use burn::{prelude::Backend, tensor::Tensor};
|
||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||
use rand::{thread_rng, Rng};
|
||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
const ERROR_REWARD: f32 = -2.12121;
|
||||
const REWARD_VALID_MOVE: f32 = 2.12121;
|
||||
const REWARD_RATIO: f32 = 0.01;
|
||||
const WIN_POINTS: f32 = 0.1;
|
||||
|
||||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TrictracState {
|
||||
pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
|
||||
}
|
||||
|
||||
impl State for TrictracState {
|
||||
type Data = [i8; 36];
|
||||
|
||||
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
|
||||
Tensor::from_floats(self.data, &B::Device::default())
|
||||
}
|
||||
|
||||
fn size() -> usize {
|
||||
36
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracState {
|
||||
/// Convertit un GameState en TrictracState
|
||||
pub fn from_game_state(game_state: &GameState) -> Self {
|
||||
let state_vec = game_state.to_vec();
|
||||
let mut data = [0; 36];
|
||||
|
||||
// Copier les données en s'assurant qu'on ne dépasse pas la taille
|
||||
let copy_len = state_vec.len().min(36);
|
||||
data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
|
||||
|
||||
TrictracState { data }
|
||||
}
|
||||
}
|
||||
|
||||
/// Actions possibles dans Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct TrictracAction {
|
||||
// u32 as required by burn_rl::base::Action type
|
||||
pub index: u32,
|
||||
}
|
||||
|
||||
impl Action for TrictracAction {
|
||||
fn random() -> Self {
|
||||
use rand::{thread_rng, Rng};
|
||||
let mut rng = thread_rng();
|
||||
TrictracAction {
|
||||
index: rng.gen_range(0..Self::size() as u32),
|
||||
}
|
||||
}
|
||||
|
||||
fn enumerate() -> Vec<Self> {
|
||||
(0..Self::size() as u32)
|
||||
.map(|index| TrictracAction { index })
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn size() -> usize {
|
||||
1252
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for TrictracAction {
|
||||
fn from(index: u32) -> Self {
|
||||
TrictracAction { index }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TrictracAction> for u32 {
|
||||
fn from(action: TrictracAction) -> u32 {
|
||||
action.index
|
||||
}
|
||||
}
|
||||
|
||||
/// Environnement Trictrac pour burn-rl
|
||||
#[derive(Debug)]
|
||||
pub struct TrictracEnvironment {
|
||||
pub game: GameState,
|
||||
active_player_id: PlayerId,
|
||||
opponent_id: PlayerId,
|
||||
current_state: TrictracState,
|
||||
episode_reward: f32,
|
||||
pub step_count: usize,
|
||||
pub max_steps: usize,
|
||||
pub pointrolls_count: usize,
|
||||
pub goodmoves_count: usize,
|
||||
pub goodmoves_ratio: f32,
|
||||
pub visualized: bool,
|
||||
}
|
||||
|
||||
impl Environment for TrictracEnvironment {
|
||||
type StateType = TrictracState;
|
||||
type ActionType = TrictracAction;
|
||||
type RewardType = f32;
|
||||
|
||||
fn new(visualized: bool) -> Self {
|
||||
let mut game = GameState::new(false);
|
||||
|
||||
// Ajouter deux joueurs
|
||||
game.init_player("DQN Agent");
|
||||
game.init_player("Opponent");
|
||||
let player1_id = 1;
|
||||
let player2_id = 2;
|
||||
|
||||
// Commencer la partie
|
||||
game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
let current_state = TrictracState::from_game_state(&game);
|
||||
TrictracEnvironment {
|
||||
game,
|
||||
active_player_id: player1_id,
|
||||
opponent_id: player2_id,
|
||||
current_state,
|
||||
episode_reward: 0.0,
|
||||
step_count: 0,
|
||||
max_steps: 2000,
|
||||
pointrolls_count: 0,
|
||||
goodmoves_count: 0,
|
||||
goodmoves_ratio: 0.0,
|
||||
visualized,
|
||||
}
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::StateType {
|
||||
self.current_state
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Snapshot<Self> {
|
||||
// Réinitialiser le jeu
|
||||
self.game = GameState::new(false);
|
||||
self.game.init_player("DQN Agent");
|
||||
self.game.init_player("Opponent");
|
||||
|
||||
// Commencer la partie
|
||||
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
self.episode_reward = 0.0;
|
||||
self.goodmoves_ratio = if self.step_count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.goodmoves_count as f32 / self.step_count as f32
|
||||
};
|
||||
println!(
|
||||
"info: correct moves: {} ({}%)",
|
||||
self.goodmoves_count,
|
||||
(100.0 * self.goodmoves_ratio).round() as u32
|
||||
);
|
||||
self.step_count = 0;
|
||||
self.pointrolls_count = 0;
|
||||
self.goodmoves_count = 0;
|
||||
|
||||
Snapshot::new(self.current_state, 0.0, false)
|
||||
}
|
||||
|
||||
fn step(&mut self, action: Self::ActionType) -> Snapshot<Self> {
|
||||
self.step_count += 1;
|
||||
|
||||
// Convertir l'action burn-rl vers une action Trictrac
|
||||
let trictrac_action = Self::convert_action(action);
|
||||
|
||||
let mut reward = 0.0;
|
||||
let is_rollpoint;
|
||||
|
||||
// Exécuter l'action si c'est le tour de l'agent DQN
|
||||
if self.game.active_player_id == self.active_player_id {
|
||||
if let Some(action) = trictrac_action {
|
||||
(reward, is_rollpoint) = self.execute_action(action);
|
||||
if is_rollpoint {
|
||||
self.pointrolls_count += 1;
|
||||
}
|
||||
if reward != ERROR_REWARD {
|
||||
self.goodmoves_count += 1;
|
||||
// println!("{str_action}");
|
||||
}
|
||||
} else {
|
||||
// Action non convertible, pénalité
|
||||
reward = -0.5;
|
||||
}
|
||||
}
|
||||
|
||||
// Faire jouer l'adversaire (stratégie simple)
|
||||
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||
// print!(":");
|
||||
reward += self.play_opponent_if_needed();
|
||||
}
|
||||
|
||||
// Vérifier si la partie est terminée
|
||||
// let max_steps = self.max_steps
|
||||
// let max_steps = self.min_steps
|
||||
// + (self.max_steps as f32 - self.min_steps)
|
||||
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||
|
||||
if done {
|
||||
// Récompense finale basée sur le résultat
|
||||
if let Some(winner_id) = self.game.determine_winner() {
|
||||
if winner_id == self.active_player_id {
|
||||
reward += WIN_POINTS; // Victoire
|
||||
} else {
|
||||
reward -= WIN_POINTS; // Défaite
|
||||
}
|
||||
}
|
||||
}
|
||||
let terminated = done || self.step_count >= self.max_steps;
|
||||
|
||||
// Mettre à jour l'état
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
self.episode_reward += reward;
|
||||
if self.visualized && terminated {
|
||||
println!(
|
||||
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
|
||||
self.episode_reward, self.step_count
|
||||
);
|
||||
}
|
||||
|
||||
Snapshot::new(self.current_state, reward, terminated)
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracEnvironment {
|
||||
/// Convertit une action burn-rl vers une action Trictrac
|
||||
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
|
||||
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
}
|
||||
|
||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||
#[allow(dead_code)]
|
||||
fn convert_valid_action_index(
|
||||
&self,
|
||||
action: TrictracAction,
|
||||
game_state: &GameState,
|
||||
) -> Option<training_common_big::TrictracAction> {
|
||||
use training_common_big::get_valid_actions;
|
||||
|
||||
// Obtenir les actions valides dans le contexte actuel
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index].clone())
|
||||
}
|
||||
|
||||
/// Exécute une action Trictrac dans le jeu
|
||||
// fn execute_action(
|
||||
// &mut self,
|
||||
// action:training_common_big::TrictracAction,
|
||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
|
||||
use training_common_big::TrictracAction;
|
||||
|
||||
let mut reward = 0.0;
|
||||
let mut is_rollpoint = false;
|
||||
let mut need_roll = false;
|
||||
|
||||
let event = match action {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
need_roll = true;
|
||||
Some(GameEvent::Roll {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game.
|
||||
// reward += 0.1 * points as f32;
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.active_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if dice_order {
|
||||
(self.game.dice.values.0, self.game.dice.values.1)
|
||||
} else {
|
||||
(self.game.dice.values.1, self.game.dice.values.0)
|
||||
};
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Appliquer l'événement si valide
|
||||
if let Some(event) = event {
|
||||
if self.game.validate(&event) {
|
||||
self.game.consume(&event);
|
||||
reward += REWARD_VALID_MOVE;
|
||||
// Simuler le résultat des dés après un Roll
|
||||
// if matches!(action, TrictracAction::Roll) {
|
||||
if need_roll {
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
let dice_event = GameEvent::RollResult {
|
||||
player_id: self.active_player_id,
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
};
|
||||
// print!("o");
|
||||
if self.game.validate(&dice_event) {
|
||||
self.game.consume(&dice_event);
|
||||
let (points, adv_points) = self.game.dice_points;
|
||||
reward += REWARD_RATIO * (points - adv_points) as f32;
|
||||
if points > 0 {
|
||||
is_rollpoint = true;
|
||||
// println!("info: rolled for {reward}");
|
||||
}
|
||||
// Récompense proportionnelle aux points
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Pénalité pour action invalide
|
||||
// on annule les précédents reward
|
||||
// et on indique une valeur reconnaissable pour statistiques
|
||||
reward = ERROR_REWARD;
|
||||
}
|
||||
}
|
||||
|
||||
(reward, is_rollpoint)
|
||||
}
|
||||
|
||||
/// Fait jouer l'adversaire avec une stratégie simple
|
||||
fn play_opponent_if_needed(&mut self) -> f32 {
|
||||
// print!("z?");
|
||||
let mut reward = 0.0;
|
||||
|
||||
// Si c'est le tour de l'adversaire, jouer automatiquement
|
||||
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||
// Utiliser la stratégie default pour l'adversaire
|
||||
use crate::BotStrategy;
|
||||
|
||||
let mut strategy = crate::strategy::random::RandomStrategy::default();
|
||||
strategy.set_player_id(self.opponent_id);
|
||||
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
|
||||
strategy.set_color(color);
|
||||
}
|
||||
*strategy.get_mut_game() = self.game.clone();
|
||||
|
||||
// Exécuter l'action selon le turn_stage
|
||||
let mut calculate_points = false;
|
||||
let opponent_color = store::Color::Black;
|
||||
let event = match self.game.turn_stage {
|
||||
TurnStage::RollDice => GameEvent::Roll {
|
||||
player_id: self.opponent_id,
|
||||
},
|
||||
TurnStage::RollWaiting => {
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
calculate_points = true; // comment to replicate burnrl_before
|
||||
GameEvent::RollResult {
|
||||
player_id: self.opponent_id,
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::MarkPoints => {
|
||||
panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
|
||||
// let dice_roll_count = self
|
||||
// .game
|
||||
// .players
|
||||
// .get(&self.opponent_id)
|
||||
// .unwrap()
|
||||
// .dice_roll_count;
|
||||
// let points_rules =
|
||||
// PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
// GameEvent::Mark {
|
||||
// player_id: self.opponent_id,
|
||||
// points: points_rules.get_points(dice_roll_count).0,
|
||||
// }
|
||||
}
|
||||
TurnStage::MarkAdvPoints => {
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
.get(&self.opponent_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
// pas de reward : déjà comptabilisé lors du tour de blanc
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
points: points_rules.get_points(dice_roll_count).1,
|
||||
}
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
// Stratégie simple : toujours continuer
|
||||
GameEvent::Go {
|
||||
player_id: self.opponent_id,
|
||||
}
|
||||
}
|
||||
TurnStage::Move => GameEvent::Move {
|
||||
player_id: self.opponent_id,
|
||||
moves: strategy.choose_move(),
|
||||
},
|
||||
};
|
||||
|
||||
if self.game.validate(&event) {
|
||||
self.game.consume(&event);
|
||||
// print!(".");
|
||||
if calculate_points {
|
||||
// print!("x");
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
.get(&self.opponent_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
// Récompense proportionnelle aux points
|
||||
let adv_reward = REWARD_RATIO * (points - adv_points) as f32;
|
||||
reward -= adv_reward;
|
||||
// if adv_reward != 0.0 {
|
||||
// println!("info: opponent : {adv_reward} -> {reward}");
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
reward
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
|
||||
fn as_mut(&mut self) -> &mut Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,9 @@
|
|||
use crate::training_common;
|
||||
use crate::training_common_big;
|
||||
use burn::{prelude::Backend, tensor::Tensor};
|
||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||
use rand::{thread_rng, Rng};
|
||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
const ERROR_REWARD: f32 = -1.0012121;
|
||||
const REWARD_RATIO: f32 = 0.1;
|
||||
|
||||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TrictracState {
|
||||
|
|
@ -217,16 +214,16 @@ impl TrictracEnvironment {
|
|||
const REWARD_RATIO: f32 = 1.0;
|
||||
|
||||
/// Convertit une action burn-rl vers une action Trictrac
|
||||
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
|
||||
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
|
||||
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
}
|
||||
|
||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||
fn convert_valid_action_index(
|
||||
&self,
|
||||
action: TrictracAction,
|
||||
) -> Option<training_common::TrictracAction> {
|
||||
use training_common::get_valid_actions;
|
||||
) -> Option<training_common_big::TrictracAction> {
|
||||
use training_common_big::get_valid_actions;
|
||||
|
||||
// Obtenir les actions valides dans le contexte actuel
|
||||
let valid_actions = get_valid_actions(&self.game);
|
||||
|
|
@ -243,19 +240,72 @@ impl TrictracEnvironment {
|
|||
/// Exécute une action Trictrac dans le jeu
|
||||
// fn execute_action(
|
||||
// &mut self,
|
||||
// action: training_common::TrictracAction,
|
||||
// action: training_common_big::TrictracAction,
|
||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
|
||||
use training_common::TrictracAction;
|
||||
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
|
||||
use training_common_big::TrictracAction;
|
||||
|
||||
let mut reward = 0.0;
|
||||
let mut is_rollpoint = false;
|
||||
|
||||
let event = match action {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
Some(GameEvent::Roll {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game.
|
||||
// reward += 0.1 * points as f32;
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.active_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if dice_order {
|
||||
(self.game.dice.values.0, self.game.dice.values.1)
|
||||
} else {
|
||||
(self.game.dice.values.1, self.game.dice.values.0)
|
||||
};
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Appliquer l'événement si valide
|
||||
if let Some(event) = action.to_event(&self.game) {
|
||||
if let Some(event) = event {
|
||||
if self.game.validate(&event) {
|
||||
self.game.consume(&event);
|
||||
// reward += REWARD_VALID_MOVE;
|
||||
|
||||
// Simuler le résultat des dés après un Roll
|
||||
if matches!(action, TrictracAction::Roll) {
|
||||
let mut rng = thread_rng();
|
||||
|
|
@ -269,7 +319,7 @@ impl TrictracEnvironment {
|
|||
if self.game.validate(&dice_event) {
|
||||
self.game.consume(&dice_event);
|
||||
let (points, adv_points) = self.game.dice_points;
|
||||
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
|
||||
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||
if points > 0 {
|
||||
is_rollpoint = true;
|
||||
// println!("info: rolled for {reward}");
|
||||
|
|
@ -281,12 +331,9 @@ impl TrictracEnvironment {
|
|||
// Pénalité pour action invalide
|
||||
// on annule les précédents reward
|
||||
// et on indique une valeur reconnaissable pour statistiques
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
println!("info: action invalide -> err_reward");
|
||||
reward = Self::ERROR_REWARD;
|
||||
}
|
||||
} else {
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
}
|
||||
|
||||
(reward, is_rollpoint)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid};
|
||||
use bot::burnrl::algos::{
|
||||
dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid,
|
||||
};
|
||||
use bot::burnrl::environment::TrictracEnvironment;
|
||||
use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig;
|
||||
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
|
||||
use bot::burnrl::utils::{demo_model, Config};
|
||||
use burn::backend::{Autodiff, NdArray};
|
||||
|
|
@ -33,6 +36,16 @@ fn main() {
|
|||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"dqn_big" => {
|
||||
let _agent = dqn_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn_big::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentBig, _, _> =
|
||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"dqn_valid" => {
|
||||
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
|
|
@ -53,6 +66,16 @@ fn main() {
|
|||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"sac_big" => {
|
||||
let _agent = sac_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = sac_big::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentBig, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"sac_valid" => {
|
||||
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
|
|
@ -73,6 +96,16 @@ fn main() {
|
|||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo_big" => {
|
||||
let _agent = ppo_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo_big::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentBig, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo_valid" => {
|
||||
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod algos;
|
||||
pub mod environment;
|
||||
pub mod environment_big;
|
||||
pub mod environment_valid;
|
||||
pub mod utils;
|
||||
|
|
|
|||
153
bot/src/dqn_simple/dqn_model.rs
Normal file
153
bot/src/dqn_simple/dqn_model.rs
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
use crate::training_common_big::TrictracAction;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration pour l'agent DQN
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DqnConfig {
|
||||
pub state_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_actions: usize,
|
||||
pub learning_rate: f64,
|
||||
pub gamma: f64,
|
||||
pub epsilon: f64,
|
||||
pub epsilon_decay: f64,
|
||||
pub epsilon_min: f64,
|
||||
pub replay_buffer_size: usize,
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
impl Default for DqnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
state_size: 36,
|
||||
hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
|
||||
num_actions: TrictracAction::action_space_size(),
|
||||
learning_rate: 0.001,
|
||||
gamma: 0.99,
|
||||
epsilon: 0.1,
|
||||
epsilon_decay: 0.995,
|
||||
epsilon_min: 0.01,
|
||||
replay_buffer_size: 10000,
|
||||
batch_size: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Réseau de neurones DQN simplifié (matrice de poids basique)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SimpleNeuralNetwork {
|
||||
pub weights1: Vec<Vec<f32>>,
|
||||
pub biases1: Vec<f32>,
|
||||
pub weights2: Vec<Vec<f32>>,
|
||||
pub biases2: Vec<f32>,
|
||||
pub weights3: Vec<Vec<f32>>,
|
||||
pub biases3: Vec<f32>,
|
||||
}
|
||||
|
||||
impl SimpleNeuralNetwork {
|
||||
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
|
||||
use rand::{thread_rng, Rng};
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// Initialisation aléatoire des poids avec Xavier/Glorot
|
||||
let scale1 = (2.0 / input_size as f32).sqrt();
|
||||
let weights1 = (0..hidden_size)
|
||||
.map(|_| {
|
||||
(0..input_size)
|
||||
.map(|_| rng.gen_range(-scale1..scale1))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let biases1 = vec![0.0; hidden_size];
|
||||
|
||||
let scale2 = (2.0 / hidden_size as f32).sqrt();
|
||||
let weights2 = (0..hidden_size)
|
||||
.map(|_| {
|
||||
(0..hidden_size)
|
||||
.map(|_| rng.gen_range(-scale2..scale2))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let biases2 = vec![0.0; hidden_size];
|
||||
|
||||
let scale3 = (2.0 / hidden_size as f32).sqrt();
|
||||
let weights3 = (0..output_size)
|
||||
.map(|_| {
|
||||
(0..hidden_size)
|
||||
.map(|_| rng.gen_range(-scale3..scale3))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let biases3 = vec![0.0; output_size];
|
||||
|
||||
Self {
|
||||
weights1,
|
||||
biases1,
|
||||
weights2,
|
||||
biases2,
|
||||
weights3,
|
||||
biases3,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
// Première couche
|
||||
let mut layer1: Vec<f32> = self.biases1.clone();
|
||||
for (i, neuron_weights) in self.weights1.iter().enumerate() {
|
||||
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||
if j < input.len() {
|
||||
layer1[i] += input[j] * weight;
|
||||
}
|
||||
}
|
||||
layer1[i] = layer1[i].max(0.0); // ReLU
|
||||
}
|
||||
|
||||
// Deuxième couche
|
||||
let mut layer2: Vec<f32> = self.biases2.clone();
|
||||
for (i, neuron_weights) in self.weights2.iter().enumerate() {
|
||||
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||
if j < layer1.len() {
|
||||
layer2[i] += layer1[j] * weight;
|
||||
}
|
||||
}
|
||||
layer2[i] = layer2[i].max(0.0); // ReLU
|
||||
}
|
||||
|
||||
// Couche de sortie
|
||||
let mut output: Vec<f32> = self.biases3.clone();
|
||||
for (i, neuron_weights) in self.weights3.iter().enumerate() {
|
||||
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||
if j < layer2.len() {
|
||||
output[i] += layer2[j] * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn get_best_action(&self, input: &[f32]) -> usize {
|
||||
let q_values = self.forward(input);
|
||||
q_values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(index, _)| index)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn save<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
path: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let data = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(path, data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let data = std::fs::read_to_string(path)?;
|
||||
let network = serde_json::from_str(&data)?;
|
||||
Ok(network)
|
||||
}
|
||||
}
|
||||
494
bot/src/dqn_simple/dqn_trainer.rs
Normal file
494
bot/src/dqn_simple/dqn_trainer.rs
Normal file
|
|
@ -0,0 +1,494 @@
|
|||
use crate::{CheckerMove, Color, GameState, PlayerId};
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::VecDeque;
|
||||
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
||||
|
||||
use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
|
||||
use crate::training_common_big::{get_valid_actions, TrictracAction};
|
||||
|
||||
/// Expérience pour le buffer de replay
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Experience {
|
||||
pub state: Vec<f32>,
|
||||
pub action: TrictracAction,
|
||||
pub reward: f32,
|
||||
pub next_state: Vec<f32>,
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
/// Buffer de replay pour stocker les expériences
|
||||
#[derive(Debug)]
|
||||
pub struct ReplayBuffer {
|
||||
buffer: VecDeque<Experience>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
buffer: VecDeque::with_capacity(capacity),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, experience: Experience) {
|
||||
if self.buffer.len() >= self.capacity {
|
||||
self.buffer.pop_front();
|
||||
}
|
||||
self.buffer.push_back(experience);
|
||||
}
|
||||
|
||||
pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
|
||||
let mut rng = thread_rng();
|
||||
let len = self.buffer.len();
|
||||
if len < batch_size {
|
||||
return self.buffer.iter().cloned().collect();
|
||||
}
|
||||
|
||||
let mut batch = Vec::with_capacity(batch_size);
|
||||
for _ in 0..batch_size {
|
||||
let idx = rng.gen_range(0..len);
|
||||
batch.push(self.buffer[idx].clone());
|
||||
}
|
||||
batch
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.is_empty()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent DQN pour l'apprentissage par renforcement
|
||||
#[derive(Debug)]
|
||||
pub struct DqnAgent {
|
||||
config: DqnConfig,
|
||||
model: SimpleNeuralNetwork,
|
||||
target_model: SimpleNeuralNetwork,
|
||||
replay_buffer: ReplayBuffer,
|
||||
epsilon: f64,
|
||||
step_count: usize,
|
||||
}
|
||||
|
||||
impl DqnAgent {
|
||||
pub fn new(config: DqnConfig) -> Self {
|
||||
let model =
|
||||
SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
|
||||
let target_model = model.clone();
|
||||
let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
|
||||
let epsilon = config.epsilon;
|
||||
|
||||
Self {
|
||||
config,
|
||||
model,
|
||||
target_model,
|
||||
replay_buffer,
|
||||
epsilon,
|
||||
step_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
// Fallback si aucune action valide
|
||||
return TrictracAction::Roll;
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
if rng.gen::<f64>() < self.epsilon {
|
||||
// Exploration : action valide aléatoire
|
||||
valid_actions
|
||||
.choose(&mut rng)
|
||||
.cloned()
|
||||
.unwrap_or(TrictracAction::Roll)
|
||||
} else {
|
||||
// Exploitation : meilleure action valide selon le modèle
|
||||
let q_values = self.model.forward(state);
|
||||
|
||||
let mut best_action = &valid_actions[0];
|
||||
let mut best_q_value = f32::NEG_INFINITY;
|
||||
|
||||
for action in &valid_actions {
|
||||
let action_index = action.to_action_index();
|
||||
if action_index < q_values.len() {
|
||||
let q_value = q_values[action_index];
|
||||
if q_value > best_q_value {
|
||||
best_q_value = q_value;
|
||||
best_action = action;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best_action.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_experience(&mut self, experience: Experience) {
|
||||
self.replay_buffer.push(experience);
|
||||
}
|
||||
|
||||
pub fn train(&mut self) {
|
||||
if self.replay_buffer.len() < self.config.batch_size {
|
||||
return;
|
||||
}
|
||||
|
||||
// Pour l'instant, on simule l'entraînement en mettant à jour epsilon
|
||||
// Dans une implémentation complète, ici on ferait la backpropagation
|
||||
self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
|
||||
self.step_count += 1;
|
||||
|
||||
// Mise à jour du target model tous les 100 steps
|
||||
if self.step_count % 100 == 0 {
|
||||
self.target_model = self.model.clone();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_model<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
path: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.model.save(path)
|
||||
}
|
||||
|
||||
pub fn get_epsilon(&self) -> f64 {
|
||||
self.epsilon
|
||||
}
|
||||
|
||||
pub fn get_step_count(&self) -> usize {
|
||||
self.step_count
|
||||
}
|
||||
}
|
||||
|
||||
/// Environnement Trictrac pour l'entraînement
|
||||
#[derive(Debug)]
|
||||
pub struct TrictracEnv {
|
||||
pub game_state: GameState,
|
||||
pub agent_player_id: PlayerId,
|
||||
pub opponent_player_id: PlayerId,
|
||||
pub agent_color: Color,
|
||||
pub max_steps: usize,
|
||||
pub current_step: usize,
|
||||
}
|
||||
|
||||
impl Default for TrictracEnv {
|
||||
fn default() -> Self {
|
||||
let mut game_state = GameState::new(false);
|
||||
game_state.init_player("agent");
|
||||
game_state.init_player("opponent");
|
||||
|
||||
Self {
|
||||
game_state,
|
||||
agent_player_id: 1,
|
||||
opponent_player_id: 2,
|
||||
agent_color: Color::White,
|
||||
max_steps: 1000,
|
||||
current_step: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracEnv {
|
||||
pub fn reset(&mut self) -> Vec<f32> {
|
||||
self.game_state = GameState::new(false);
|
||||
self.game_state.init_player("agent");
|
||||
self.game_state.init_player("opponent");
|
||||
|
||||
// Commencer la partie
|
||||
self.game_state.consume(&GameEvent::BeginGame {
|
||||
goes_first: self.agent_player_id,
|
||||
});
|
||||
|
||||
self.current_step = 0;
|
||||
self.game_state.to_vec_float()
|
||||
}
|
||||
|
||||
pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
|
||||
let mut reward = 0.0;
|
||||
|
||||
// Appliquer l'action de l'agent
|
||||
if self.game_state.active_player_id == self.agent_player_id {
|
||||
reward += self.apply_agent_action(action);
|
||||
}
|
||||
|
||||
// Faire jouer l'adversaire (stratégie simple)
|
||||
while self.game_state.active_player_id == self.opponent_player_id
|
||||
&& self.game_state.stage != Stage::Ended
|
||||
{
|
||||
reward += self.play_opponent_turn();
|
||||
}
|
||||
|
||||
// Vérifier si la partie est terminée
|
||||
let done = self.game_state.stage == Stage::Ended
|
||||
|| self.game_state.determine_winner().is_some()
|
||||
|| self.current_step >= self.max_steps;
|
||||
|
||||
// Récompense finale si la partie est terminée
|
||||
if done {
|
||||
if let Some(winner) = self.game_state.determine_winner() {
|
||||
if winner == self.agent_player_id {
|
||||
reward += 100.0; // Bonus pour gagner
|
||||
} else {
|
||||
reward -= 50.0; // Pénalité pour perdre
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.current_step += 1;
|
||||
let next_state = self.game_state.to_vec_float();
|
||||
(next_state, reward, done)
|
||||
}
|
||||
|
||||
fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
|
||||
let mut reward = 0.0;
|
||||
|
||||
let event = match action {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
reward += 0.1;
|
||||
Some(GameEvent::Roll {
|
||||
player_id: self.agent_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game_state.
|
||||
// reward += 0.1 * points as f32;
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.agent_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.agent_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if dice_order {
|
||||
(self.game_state.dice.values.0, self.game_state.dice.values.1)
|
||||
} else {
|
||||
(self.game_state.dice.values.1, self.game_state.dice.values.0)
|
||||
};
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.agent_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Appliquer l'événement si valide
|
||||
if let Some(event) = event {
|
||||
if self.game_state.validate(&event) {
|
||||
self.game_state.consume(&event);
|
||||
|
||||
// Simuler le résultat des dés après un Roll
|
||||
if matches!(action, TrictracAction::Roll) {
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
let dice_event = GameEvent::RollResult {
|
||||
player_id: self.agent_player_id,
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
};
|
||||
if self.game_state.validate(&dice_event) {
|
||||
self.game_state.consume(&dice_event);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Pénalité pour action invalide
|
||||
reward -= 2.0;
|
||||
}
|
||||
}
|
||||
|
||||
reward
|
||||
}
|
||||
|
||||
// TODO : use default bot strategy
|
||||
fn play_opponent_turn(&mut self) -> f32 {
|
||||
let mut reward = 0.0;
|
||||
let event = match self.game_state.turn_stage {
|
||||
TurnStage::RollDice => GameEvent::Roll {
|
||||
player_id: self.opponent_player_id,
|
||||
},
|
||||
TurnStage::RollWaiting => {
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
GameEvent::RollResult {
|
||||
player_id: self.opponent_player_id,
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
|
||||
let opponent_color = self.agent_color.opponent_color();
|
||||
let dice_roll_count = self
|
||||
.game_state
|
||||
.players
|
||||
.get(&self.opponent_player_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules = PointsRules::new(
|
||||
&opponent_color,
|
||||
&self.game_state.board,
|
||||
self.game_state.dice,
|
||||
);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_player_id,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::Move => {
|
||||
let opponent_color = self.agent_color.opponent_color();
|
||||
let rules = MoveRules::new(
|
||||
&opponent_color,
|
||||
&self.game_state.board,
|
||||
self.game_state.dice,
|
||||
);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// Stratégie simple : choix aléatoire
|
||||
let mut rng = thread_rng();
|
||||
let choosen_move = *possible_moves
|
||||
.choose(&mut rng)
|
||||
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||
|
||||
GameEvent::Move {
|
||||
player_id: self.opponent_player_id,
|
||||
moves: if opponent_color == Color::White {
|
||||
choosen_move
|
||||
} else {
|
||||
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
// Stratégie simple : toujours continuer
|
||||
GameEvent::Go {
|
||||
player_id: self.opponent_player_id,
|
||||
}
|
||||
}
|
||||
};
|
||||
if self.game_state.validate(&event) {
|
||||
self.game_state.consume(&event);
|
||||
}
|
||||
reward
|
||||
}
|
||||
}
|
||||
|
||||
/// Entraîneur pour le modèle DQN
|
||||
pub struct DqnTrainer {
|
||||
agent: DqnAgent,
|
||||
env: TrictracEnv,
|
||||
}
|
||||
|
||||
impl DqnTrainer {
|
||||
pub fn new(config: DqnConfig) -> Self {
|
||||
Self {
|
||||
agent: DqnAgent::new(config),
|
||||
env: TrictracEnv::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn train_episode(&mut self) -> f32 {
|
||||
let mut total_reward = 0.0;
|
||||
let mut state = self.env.reset();
|
||||
// let mut step_count = 0;
|
||||
|
||||
loop {
|
||||
// step_count += 1;
|
||||
let action = self.agent.select_action(&self.env.game_state, &state);
|
||||
let (next_state, reward, done) = self.env.step(action.clone());
|
||||
total_reward += reward;
|
||||
|
||||
let experience = Experience {
|
||||
state: state.clone(),
|
||||
action,
|
||||
reward,
|
||||
next_state: next_state.clone(),
|
||||
done,
|
||||
};
|
||||
self.agent.store_experience(experience);
|
||||
self.agent.train();
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
// if step_count % 100 == 0 {
|
||||
// println!("{:?}", next_state);
|
||||
// }
|
||||
state = next_state;
|
||||
}
|
||||
|
||||
total_reward
|
||||
}
|
||||
|
||||
pub fn train(
|
||||
&mut self,
|
||||
episodes: usize,
|
||||
save_every: usize,
|
||||
model_path: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("Démarrage de l'entraînement DQN pour {episodes} épisodes");
|
||||
|
||||
for episode in 1..=episodes {
|
||||
let reward = self.train_episode();
|
||||
|
||||
if episode % 100 == 0 {
|
||||
println!(
|
||||
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
||||
episode,
|
||||
episodes,
|
||||
reward,
|
||||
self.agent.get_epsilon(),
|
||||
self.agent.get_step_count()
|
||||
);
|
||||
}
|
||||
|
||||
if episode % save_every == 0 {
|
||||
let save_path = format!("{model_path}_episode_{episode}.json");
|
||||
self.agent.save_model(&save_path)?;
|
||||
println!("Modèle sauvegardé : {save_path}");
|
||||
}
|
||||
}
|
||||
|
||||
// Sauvegarder le modèle final
|
||||
let final_path = format!("{model_path}_final.json");
|
||||
self.agent.save_model(&final_path)?;
|
||||
println!("Modèle final sauvegardé : {final_path}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
109
bot/src/dqn_simple/main.rs
Normal file
109
bot/src/dqn_simple/main.rs
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
use bot::dqn_simple::dqn_model::DqnConfig;
|
||||
use bot::dqn_simple::dqn_trainer::DqnTrainer;
|
||||
use bot::training_common::TrictracAction;
|
||||
use std::env;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
env_logger::init();
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
// Paramètres par défaut
|
||||
let mut episodes = 1000;
|
||||
let mut model_path = "models/dqn_model".to_string();
|
||||
let mut save_every = 100;
|
||||
|
||||
// Parser les arguments de ligne de commande
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--episodes" => {
|
||||
if i + 1 < args.len() {
|
||||
episodes = args[i + 1].parse().unwrap_or(1000);
|
||||
i += 2;
|
||||
} else {
|
||||
eprintln!("Erreur : --episodes nécessite une valeur");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
"--model-path" => {
|
||||
if i + 1 < args.len() {
|
||||
model_path = args[i + 1].clone();
|
||||
i += 2;
|
||||
} else {
|
||||
eprintln!("Erreur : --model-path nécessite une valeur");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
"--save-every" => {
|
||||
if i + 1 < args.len() {
|
||||
save_every = args[i + 1].parse().unwrap_or(100);
|
||||
i += 2;
|
||||
} else {
|
||||
eprintln!("Erreur : --save-every nécessite une valeur");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
print_help();
|
||||
std::process::exit(0);
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Argument inconnu : {}", args[i]);
|
||||
print_help();
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Créer le dossier models s'il n'existe pas
|
||||
std::fs::create_dir_all("models")?;
|
||||
|
||||
println!("Configuration d'entraînement DQN :");
|
||||
println!(" Épisodes : {episodes}");
|
||||
println!(" Chemin du modèle : {model_path}");
|
||||
println!(" Sauvegarde tous les {save_every} épisodes");
|
||||
println!();
|
||||
|
||||
// Configuration DQN
|
||||
let config = DqnConfig {
|
||||
state_size: 36, // state.to_vec size
|
||||
hidden_size: 256,
|
||||
num_actions: TrictracAction::action_space_size(),
|
||||
learning_rate: 0.001,
|
||||
gamma: 0.99,
|
||||
epsilon: 0.9, // Commencer avec plus d'exploration
|
||||
epsilon_decay: 0.995,
|
||||
epsilon_min: 0.01,
|
||||
replay_buffer_size: 10000,
|
||||
batch_size: 32,
|
||||
};
|
||||
|
||||
// Créer et lancer l'entraîneur
|
||||
let mut trainer = DqnTrainer::new(config);
|
||||
trainer.train(episodes, save_every, &model_path)?;
|
||||
|
||||
println!("Entraînement terminé avec succès !");
|
||||
println!("Pour utiliser le modèle entraîné :");
|
||||
println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!("Entraîneur DQN pour Trictrac");
|
||||
println!();
|
||||
println!("USAGE:");
|
||||
println!(" cargo run --bin=train_dqn [OPTIONS]");
|
||||
println!();
|
||||
println!("OPTIONS:");
|
||||
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
|
||||
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)");
|
||||
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
|
||||
println!(" -h, --help Afficher cette aide");
|
||||
println!();
|
||||
println!("EXEMPLES:");
|
||||
println!(" cargo run --bin=train_dqn");
|
||||
println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
|
||||
println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
|
||||
}
|
||||
2
bot/src/dqn_simple/mod.rs
Normal file
2
bot/src/dqn_simple/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod dqn_model;
|
||||
pub mod dqn_trainer;
|
||||
|
|
@ -1,11 +1,14 @@
|
|||
pub mod burnrl;
|
||||
pub mod dqn_simple;
|
||||
pub mod strategy;
|
||||
pub mod training_common;
|
||||
pub mod training_common_big;
|
||||
pub mod trictrac_board;
|
||||
|
||||
use log::debug;
|
||||
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
pub use strategy::default::DefaultStrategy;
|
||||
pub use strategy::dqn::DqnStrategy;
|
||||
pub use strategy::dqnburn::DqnBurnStrategy;
|
||||
pub use strategy::erroneous_moves::ErroneousStrategy;
|
||||
pub use strategy::random::RandomStrategy;
|
||||
|
|
|
|||
174
bot/src/strategy/dqn.rs
Normal file
174
bot/src/strategy/dqn.rs
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use log::info;
|
||||
use std::path::Path;
|
||||
use store::MoveRules;
|
||||
|
||||
use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
|
||||
use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
|
||||
|
||||
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||
#[derive(Debug)]
|
||||
pub struct DqnStrategy {
|
||||
pub game: GameState,
|
||||
pub player_id: PlayerId,
|
||||
pub color: Color,
|
||||
pub model: Option<SimpleNeuralNetwork>,
|
||||
}
|
||||
|
||||
impl Default for DqnStrategy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
game: GameState::default(),
|
||||
player_id: 1,
|
||||
color: Color::White,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DqnStrategy {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn new_with_model<P: AsRef<Path> + std::fmt::Debug>(model_path: P) -> Self {
|
||||
let mut strategy = Self::new();
|
||||
if let Ok(model) = SimpleNeuralNetwork::load(&model_path) {
|
||||
info!("Loading model {model_path:?}");
|
||||
strategy.model = Some(model);
|
||||
}
|
||||
strategy
|
||||
}
|
||||
|
||||
/// Utilise le modèle DQN pour choisir une action valide
|
||||
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||
if let Some(ref model) = self.model {
|
||||
let state = self.game.to_vec_float();
|
||||
let valid_actions = get_valid_actions(&self.game);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Obtenir les Q-values pour toutes les actions
|
||||
let q_values = model.forward(&state);
|
||||
|
||||
// Trouver la meilleure action valide
|
||||
let mut best_action = &valid_actions[0];
|
||||
let mut best_q_value = f32::NEG_INFINITY;
|
||||
|
||||
for action in &valid_actions {
|
||||
let action_index = action.to_action_index();
|
||||
if action_index < q_values.len() {
|
||||
let q_value = q_values[action_index];
|
||||
if q_value > best_q_value {
|
||||
best_q_value = q_value;
|
||||
best_action = action;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(best_action.clone())
|
||||
} else {
|
||||
// Fallback : action aléatoire valide
|
||||
sample_valid_action(&self.game)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BotStrategy for DqnStrategy {
|
||||
fn get_game(&self) -> &GameState {
|
||||
&self.game
|
||||
}
|
||||
|
||||
fn get_mut_game(&mut self) -> &mut GameState {
|
||||
&mut self.game
|
||||
}
|
||||
|
||||
fn set_color(&mut self, color: Color) {
|
||||
self.color = color;
|
||||
}
|
||||
|
||||
fn set_player_id(&mut self, player_id: PlayerId) {
|
||||
self.player_id = player_id;
|
||||
}
|
||||
|
||||
fn calculate_points(&self) -> u8 {
|
||||
self.game.dice_points.0
|
||||
}
|
||||
|
||||
fn calculate_adv_points(&self) -> u8 {
|
||||
self.game.dice_points.1
|
||||
}
|
||||
|
||||
fn choose_go(&self) -> bool {
|
||||
// Utiliser le DQN pour décider si on continue
|
||||
if let Some(action) = self.get_dqn_action() {
|
||||
matches!(action, TrictracAction::Go)
|
||||
} else {
|
||||
// Fallback : toujours continuer
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||
// Utiliser le DQN pour choisir le mouvement
|
||||
if let Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
}) = self.get_dqn_action()
|
||||
{
|
||||
let dicevals = self.game.dice.values;
|
||||
let (mut dice1, mut dice2) = if dice_order {
|
||||
(dicevals.0, dicevals.1)
|
||||
} else {
|
||||
(dicevals.1, dicevals.0)
|
||||
};
|
||||
|
||||
if from1 == 0 {
|
||||
// empty move
|
||||
dice1 = 0;
|
||||
}
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
if 24 < to1 {
|
||||
// sortie
|
||||
to1 = 0;
|
||||
}
|
||||
if from2 == 0 {
|
||||
// empty move
|
||||
dice2 = 0;
|
||||
}
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
if 24 < to2 {
|
||||
// sortie
|
||||
to2 = 0;
|
||||
}
|
||||
|
||||
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
let chosen_move = if self.color == Color::White {
|
||||
(checker_move1, checker_move2)
|
||||
} else {
|
||||
(checker_move1.mirror(), checker_move2.mirror())
|
||||
};
|
||||
|
||||
return chosen_move;
|
||||
}
|
||||
|
||||
// Fallback : utiliser la stratégie par défaut
|
||||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
let chosen_move = *possible_moves
|
||||
.first()
|
||||
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||
|
||||
if self.color == Color::White {
|
||||
chosen_move
|
||||
} else {
|
||||
(chosen_move.0.mirror(), chosen_move.1.mirror())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod client;
|
||||
pub mod default;
|
||||
pub mod dqn;
|
||||
pub mod dqnburn;
|
||||
pub mod erroneous_moves;
|
||||
pub mod random;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
/// training_common.rs : environnement avec espace d'actions optimisé
|
||||
/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' )
|
||||
use std::cmp::{max, min};
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
|
||||
|
|
|
|||
266
bot/src/training_common_big.rs
Normal file
266
bot/src/training_common_big.rs
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
use std::cmp::{max, min};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use store::{CheckerMove, Dice};
|
||||
|
||||
/// Types d'actions possibles dans le jeu
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum TrictracAction {
|
||||
/// Lancer les dés
|
||||
Roll,
|
||||
/// Continuer après avoir gagné un trou
|
||||
Go,
|
||||
/// Effectuer un mouvement de pions
|
||||
Move {
|
||||
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
|
||||
from1: usize, // position de départ du premier pion (0-24)
|
||||
from2: usize, // position de départ du deuxième pion (0-24)
|
||||
},
|
||||
// Marquer les points : à activer si support des écoles
|
||||
// Mark,
|
||||
}
|
||||
|
||||
impl TrictracAction {
|
||||
/// Encode une action en index pour le réseau de neurones
|
||||
pub fn to_action_index(&self) -> usize {
|
||||
match self {
|
||||
TrictracAction::Roll => 0,
|
||||
TrictracAction::Go => 1,
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Encoder les mouvements dans l'espace d'actions
|
||||
// Indices 2+ pour les mouvements
|
||||
// de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier)
|
||||
let mut start = 2;
|
||||
if !dice_order {
|
||||
// 25 * 25 = 625
|
||||
start += 625;
|
||||
}
|
||||
start + from1 * 25 + from2
|
||||
} // TrictracAction::Mark => 1252,
|
||||
}
|
||||
}
|
||||
|
||||
/// Décode un index d'action en TrictracAction
|
||||
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||
match index {
|
||||
0 => Some(TrictracAction::Roll),
|
||||
// 1252 => Some(TrictracAction::Mark),
|
||||
1 => Some(TrictracAction::Go),
|
||||
i if i >= 3 => {
|
||||
let move_code = i - 3;
|
||||
let (dice_order, from1, from2) = Self::decode_move(move_code);
|
||||
Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Décode un entier en paire de mouvements
|
||||
fn decode_move(code: usize) -> (bool, usize, usize) {
|
||||
let mut encoded = code;
|
||||
let dice_order = code < 626;
|
||||
if !dice_order {
|
||||
encoded -= 625
|
||||
}
|
||||
let from1 = encoded / 25;
|
||||
let from2 = 1 + encoded % 25;
|
||||
(dice_order, from1, from2)
|
||||
}
|
||||
|
||||
/// Retourne la taille de l'espace d'actions total
|
||||
pub fn action_space_size() -> usize {
|
||||
// 1 (Roll) + 1 (Go) + mouvements possibles
|
||||
// Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from)
|
||||
// Mais on peut optimiser en limitant aux positions valides (1-24)
|
||||
2 + (2 * 25 * 25) // = 1252
|
||||
}
|
||||
|
||||
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
|
||||
// match action {
|
||||
// TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
|
||||
// TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
|
||||
// TrictracAction::Go => Some(GameEvent::Go { player_id }),
|
||||
// TrictracAction::Move {
|
||||
// dice_order,
|
||||
// from1,
|
||||
// from2,
|
||||
// } => {
|
||||
// // Effectuer un mouvement
|
||||
// let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
|
||||
// let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
|
||||
//
|
||||
// Some(GameEvent::Move {
|
||||
// player_id: self.agent_player_id,
|
||||
// moves: (checker_move1, checker_move2),
|
||||
// })
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
}
|
||||
|
||||
/// Obtient les actions valides pour l'état de jeu actuel
|
||||
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||
use store::TurnStage;
|
||||
|
||||
let mut valid_actions = Vec::new();
|
||||
|
||||
let active_player_id = game_state.active_player_id;
|
||||
let player_color = game_state.player_color_by_id(&active_player_id);
|
||||
|
||||
if let Some(color) = player_color {
|
||||
match game_state.turn_stage {
|
||||
TurnStage::RollDice => {
|
||||
valid_actions.push(TrictracAction::Roll);
|
||||
}
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
||||
panic!(
|
||||
"get_valid_actions not implemented for turn stage {:?}",
|
||||
game_state.turn_stage
|
||||
);
|
||||
// valid_actions.push(TrictracAction::Mark);
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
valid_actions.push(TrictracAction::Go);
|
||||
|
||||
// Ajoute aussi les mouvements possibles
|
||||
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||
assert_eq!(color, store::Color::White);
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
));
|
||||
}
|
||||
}
|
||||
TurnStage::Move => {
|
||||
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
if possible_moves.is_empty() {
|
||||
// Empty move
|
||||
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
|
||||
}
|
||||
|
||||
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||
assert_eq!(color, store::Color::White);
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
panic!("empty valid_actions for state {game_state}");
|
||||
}
|
||||
valid_actions
|
||||
}
|
||||
|
||||
// Valid only for White player
|
||||
fn checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
dice: &Dice,
|
||||
) -> TrictracAction {
|
||||
let to1 = move1.get_to();
|
||||
let to2 = move2.get_to();
|
||||
let from1 = move1.get_from();
|
||||
let from2 = move2.get_from();
|
||||
|
||||
let mut diff_move1 = if to1 > 0 {
|
||||
// Mouvement sans sortie
|
||||
to1 - from1
|
||||
} else {
|
||||
// sortie, on utilise la valeur du dé
|
||||
if to2 > 0 {
|
||||
// sortie pour le mouvement 1 uniquement
|
||||
let dice2 = to2 - from2;
|
||||
if dice2 == dice.values.0 as usize {
|
||||
dice.values.1 as usize
|
||||
} else {
|
||||
dice.values.0 as usize
|
||||
}
|
||||
} else {
|
||||
// double sortie
|
||||
if from1 < from2 {
|
||||
max(dice.values.0, dice.values.1) as usize
|
||||
} else {
|
||||
min(dice.values.0, dice.values.1) as usize
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// modification de diff_move1 si on est dans le cas d'un mouvement par puissance
|
||||
let rest_field = 12;
|
||||
if to1 == rest_field
|
||||
&& to2 == rest_field
|
||||
&& max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
|
||||
{
|
||||
// prise par puissance
|
||||
diff_move1 += 1;
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order: diff_move1 == dice.values.0 as usize,
|
||||
from1: move1.get_from(),
|
||||
from2: move2.get_from(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retourne les indices des actions valides
|
||||
pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
|
||||
get_valid_actions(game_state)
|
||||
.into_iter()
|
||||
.map(|action| action.to_action_index())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Sélectionne une action valide aléatoire
|
||||
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
let mut rng = thread_rng();
|
||||
valid_actions.choose(&mut rng).cloned()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn to_action_index() {
|
||||
let action = TrictracAction::Move {
|
||||
dice_order: true,
|
||||
from1: 3,
|
||||
from2: 4,
|
||||
};
|
||||
let index = action.to_action_index();
|
||||
assert_eq!(Some(action), TrictracAction::from_action_index(index));
|
||||
assert_eq!(81, index);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_action_index() {
|
||||
let action = TrictracAction::Move {
|
||||
dice_order: true,
|
||||
from1: 3,
|
||||
from2: 4,
|
||||
};
|
||||
assert_eq!(Some(action), TrictracAction::from_action_index(81));
|
||||
}
|
||||
}
|
||||
8
client_bevy/.cargo/config.toml
Normal file
8
client_bevy/.cargo/config.toml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
[target.x86_64-unknown-linux-gnu]
|
||||
linker = "clang"
|
||||
rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
|
||||
|
||||
# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only'
|
||||
# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains.
|
||||
#[profile.dev]
|
||||
#debug = 1
|
||||
14
client_bevy/Cargo.toml
Normal file
14
client_bevy/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "trictrac-client"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.75"
|
||||
bevy = { version = "0.11.3" }
|
||||
bevy_renet = "0.0.9"
|
||||
bincode = "1.3.3"
|
||||
renet = "0.0.13"
|
||||
store = { path = "../store" }
|
||||
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
Binary file not shown.
BIN
client_bevy/assets/board.png
Normal file
BIN
client_bevy/assets/board.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.9 MiB |
BIN
client_bevy/assets/sound/click.wav
Normal file
BIN
client_bevy/assets/sound/click.wav
Normal file
Binary file not shown.
BIN
client_bevy/assets/sound/throw.wav
Executable file
BIN
client_bevy/assets/sound/throw.wav
Executable file
Binary file not shown.
BIN
client_bevy/assets/tac.png
Normal file
BIN
client_bevy/assets/tac.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.6 KiB |
BIN
client_bevy/assets/tic.png
Normal file
BIN
client_bevy/assets/tic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.4 KiB |
334
client_bevy/src/main.rs
Normal file
334
client_bevy/src/main.rs
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
use std::{net::UdpSocket, time::SystemTime};
|
||||
|
||||
use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES};
|
||||
use store::{GameEvent, GameState, CheckerMove};
|
||||
|
||||
use bevy::prelude::*;
|
||||
use bevy::window::PrimaryWindow;
|
||||
use bevy_renet::{
|
||||
renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient},
|
||||
transport::{client_connected, NetcodeClientPlugin},
|
||||
RenetClientPlugin,
|
||||
};
|
||||
|
||||
#[derive(Debug, Resource)]
|
||||
struct CurrentClientId(u64);
|
||||
|
||||
#[derive(Resource)]
|
||||
struct BevyGameState(GameState);
|
||||
|
||||
impl Default for BevyGameState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
0: GameState::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Resource, Deref, DerefMut)]
|
||||
struct GameUIState {
|
||||
selected_tile: Option<usize>,
|
||||
}
|
||||
|
||||
impl Default for GameUIState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
selected_tile: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Event)]
|
||||
struct BevyGameEvent(GameEvent);
|
||||
|
||||
// This id needs to be the same as the server is using
|
||||
const PROTOCOL_ID: u64 = 2878;
|
||||
|
||||
fn main() {
|
||||
// Get username from stdin args
|
||||
let args = std::env::args().collect::<Vec<String>>();
|
||||
let username = &args[1];
|
||||
|
||||
let (client, transport, client_id) = new_renet_client(&username).unwrap();
|
||||
App::new()
|
||||
// Lets add a nice dark grey background color
|
||||
.insert_resource(ClearColor(Color::hex("282828").unwrap()))
|
||||
.add_plugins(DefaultPlugins.set(WindowPlugin {
|
||||
primary_window: Some(Window {
|
||||
// Adding the username to the window title makes debugging a whole lot easier.
|
||||
title: format!("TricTrac <{}>", username),
|
||||
resolution: (1080.0, 1080.0).into(),
|
||||
..default()
|
||||
}),
|
||||
..default()
|
||||
}))
|
||||
// Add our game state and register GameEvent as a bevy event
|
||||
.insert_resource(BevyGameState::default())
|
||||
.insert_resource(GameUIState::default())
|
||||
.add_event::<BevyGameEvent>()
|
||||
// Renet setup
|
||||
.add_plugins(RenetClientPlugin)
|
||||
.add_plugins(NetcodeClientPlugin)
|
||||
.insert_resource(client)
|
||||
.insert_resource(transport)
|
||||
.insert_resource(CurrentClientId(client_id))
|
||||
.add_systems(Startup, setup)
|
||||
.add_systems(Update, (update_waiting_text, input, update_board, panic_on_error_system))
|
||||
.add_systems(
|
||||
PostUpdate,
|
||||
receive_events_from_server.run_if(client_connected()),
|
||||
)
|
||||
.run();
|
||||
}
|
||||
|
||||
////////// COMPONENTS //////////
|
||||
#[derive(Component)]
|
||||
struct UIRoot;
|
||||
|
||||
#[derive(Component)]
|
||||
struct WaitingText;
|
||||
|
||||
#[derive(Component)]
|
||||
struct Board {
|
||||
squares: [Square; 26]
|
||||
}
|
||||
|
||||
impl Default for Board {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
squares: [Square { count: 0, color: None, position: 0}; 26]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Board {
|
||||
fn square_at(&self, position: usize) -> Square {
|
||||
self.squares[position]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Component, Clone, Copy)]
|
||||
struct Square {
|
||||
count: usize,
|
||||
color: Option<bool>,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
////////// UPDATE SYSTEMS //////////
|
||||
fn update_board(
|
||||
mut commands: Commands,
|
||||
game_state: Res<BevyGameState>,
|
||||
mut game_events: EventReader<BevyGameEvent>,
|
||||
asset_server: Res<AssetServer>,
|
||||
) {
|
||||
for event in game_events.iter() {
|
||||
match event.0 {
|
||||
GameEvent::Move { player_id, moves } => {
|
||||
// trictrac positions, TODO : dépend de player_id
|
||||
let (x, y) = if moves.0.get_to() < 13 { (13 - moves.0.get_to(), 1) } else { (moves.0.get_to() - 13, 0)};
|
||||
let texture =
|
||||
asset_server.load(match game_state.0.players[&player_id].color {
|
||||
store::Color::Black => "tac.png",
|
||||
store::Color::White => "tic.png",
|
||||
});
|
||||
|
||||
info!("spawning tictac sprite");
|
||||
commands.spawn(SpriteBundle {
|
||||
transform: Transform::from_xyz(
|
||||
83.0 * (x as f32 - 1.0),
|
||||
-30.0 + 540.0 * (y as f32 - 1.0),
|
||||
0.0,
|
||||
),
|
||||
sprite: Sprite {
|
||||
custom_size: Some(Vec2::new(83.0, 83.0)),
|
||||
..default()
|
||||
},
|
||||
texture: texture.into(),
|
||||
..default()
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_waiting_text(mut text_query: Query<&mut Text, With<WaitingText>>, time: Res<Time>) {
|
||||
if let Ok(mut text) = text_query.get_single_mut() {
|
||||
let num_dots = (time.elapsed_seconds() as usize % 3) + 1;
|
||||
text.sections[0].value = format!(
|
||||
"Waiting for an opponent{}{}",
|
||||
".".repeat(num_dots as usize),
|
||||
// Pad with spaces to avoid text changing width and dancing all around the screen 🕺
|
||||
" ".repeat(3 - num_dots as usize)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn input(
|
||||
primary_query: Query<&Window, With<PrimaryWindow>>,
|
||||
// windows: Res<Windows>,
|
||||
input: Res<Input<MouseButton>>,
|
||||
game_state: Res<BevyGameState>,
|
||||
mut game_ui_state: ResMut<GameUIState>,
|
||||
mut client: ResMut<RenetClient>,
|
||||
client_id: Res<CurrentClientId>,
|
||||
) {
|
||||
// We only want to handle inputs once we are ingame
|
||||
if game_state.0.stage != store::Stage::InGame {
|
||||
return;
|
||||
}
|
||||
|
||||
let window = primary_query.get_single().unwrap();
|
||||
if let Some(mouse_position) = window.cursor_position() {
|
||||
// Determine the index of the tile that the mouse is currently over
|
||||
// NOTE: This calculation assumes a fixed window size.
|
||||
// That's fine for now, but consider using the windows size instead.
|
||||
let mut tile_x: usize = (mouse_position.x / 83.0).floor() as usize;
|
||||
let tile_y: usize = (mouse_position.y / 540.0).floor() as usize;
|
||||
if tile_x > 5 {
|
||||
// remove the middle bar offset
|
||||
tile_x = tile_x - 1
|
||||
}
|
||||
// let tile = tile_x + tile_y * 12;
|
||||
|
||||
// traduction en position backgammon
|
||||
let tile = if tile_y == 0 {
|
||||
13 + tile_x
|
||||
} else {
|
||||
12 - tile_x
|
||||
};
|
||||
|
||||
// If mouse is outside of board we do nothing
|
||||
if 23 < tile {
|
||||
return;
|
||||
}
|
||||
|
||||
// If left mouse button is pressed, send a place tile event to the server
|
||||
if input.just_pressed(MouseButton::Left) {
|
||||
info!("select piece at tile {:?}", tile);
|
||||
if game_ui_state.selected_tile.is_some() {
|
||||
let from_tile = game_ui_state.selected_tile.unwrap();
|
||||
info!("sending movement from: {:?} to: {:?} ", from_tile, tile);
|
||||
let event = GameEvent::Move {
|
||||
player_id: client_id.0,
|
||||
moves: (
|
||||
CheckerMove::new(from_tile, tile).unwrap(),
|
||||
CheckerMove::new(from_tile, tile).unwrap()
|
||||
)
|
||||
};
|
||||
client.send_message(0, bincode::serialize(&event).unwrap());
|
||||
}
|
||||
game_ui_state.selected_tile = if game_ui_state.selected_tile.is_some() {
|
||||
None
|
||||
} else {
|
||||
Some(tile)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////// SETUP //////////
|
||||
fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {
|
||||
// Tric Trac is a 2D game
|
||||
// To show 2D sprites we need a 2D camera
|
||||
commands.spawn(Camera2dBundle::default());
|
||||
|
||||
// Spawn board background
|
||||
commands.spawn(SpriteBundle {
|
||||
transform: Transform::from_xyz(0.0, -30.0, 0.0),
|
||||
sprite: Sprite {
|
||||
custom_size: Some(Vec2::new(1080.0, 927.0)),
|
||||
..default()
|
||||
},
|
||||
texture: asset_server.load("board.png").into(),
|
||||
..default()
|
||||
});
|
||||
|
||||
// Spawn pregame ui
|
||||
commands
|
||||
// A container that centers its children on the screen
|
||||
.spawn(NodeBundle {
|
||||
style: Style {
|
||||
position_type: PositionType::Absolute,
|
||||
left: Val::Px(0.0),
|
||||
top: Val::Px(0.0),
|
||||
width: Val::Percent(100.0),
|
||||
height: Val::Percent(100.0),
|
||||
align_items: AlignItems::Center,
|
||||
justify_content: JustifyContent::Center,
|
||||
..default()
|
||||
},
|
||||
..default()
|
||||
})
|
||||
.insert(UIRoot)
|
||||
.with_children(|parent| {
|
||||
// parent.spawn(Board::default()); // panic
|
||||
parent
|
||||
.spawn(TextBundle::from_section(
|
||||
"Waiting for an opponent...",
|
||||
TextStyle {
|
||||
font: asset_server.load("Inconsolata.ttf"),
|
||||
font_size: 24.0,
|
||||
color: Color::hex("ebdbb2").unwrap(),
|
||||
},
|
||||
))
|
||||
.insert(WaitingText);
|
||||
});
|
||||
}
|
||||
|
||||
////////// RENET NETWORKING //////////
|
||||
// Creates a RenetClient thats already connected to a server.
|
||||
// Returns an Err if connection fails
|
||||
fn new_renet_client(
|
||||
username: &String,
|
||||
) -> anyhow::Result<(RenetClient, NetcodeClientTransport, u64)> {
|
||||
let client = RenetClient::new(ConnectionConfig::default());
|
||||
let server_addr = "127.0.0.1:5000".parse()?;
|
||||
let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?;
|
||||
let client_id = current_time.as_millis() as u64;
|
||||
|
||||
// Place username in user data
|
||||
let mut user_data = [0u8; NETCODE_USER_DATA_BYTES];
|
||||
if username.len() > NETCODE_USER_DATA_BYTES - 8 {
|
||||
panic!("Username is too big");
|
||||
}
|
||||
user_data[0..8].copy_from_slice(&(username.len() as u64).to_le_bytes());
|
||||
user_data[8..username.len() + 8].copy_from_slice(username.as_bytes());
|
||||
|
||||
let authentication = ClientAuthentication::Unsecure {
|
||||
server_addr,
|
||||
client_id,
|
||||
user_data: Some(user_data),
|
||||
protocol_id: PROTOCOL_ID,
|
||||
};
|
||||
let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap();
|
||||
|
||||
Ok((client, transport, client_id))
|
||||
}
|
||||
|
||||
fn receive_events_from_server(
|
||||
mut client: ResMut<RenetClient>,
|
||||
mut game_state: ResMut<BevyGameState>,
|
||||
mut game_events: EventWriter<BevyGameEvent>,
|
||||
) {
|
||||
while let Some(message) = client.receive_message(0) {
|
||||
// Whenever the server sends a message we know that it must be a game event
|
||||
let event: GameEvent = bincode::deserialize(&message).unwrap();
|
||||
trace!("{:#?}", event);
|
||||
|
||||
// We trust the server - It's always been good to us!
|
||||
// No need to validate the events it is sending us
|
||||
game_state.0.consume(&event);
|
||||
|
||||
// Send the event into the bevy event system so systems can react to it
|
||||
game_events.send(BevyGameEvent(event));
|
||||
}
|
||||
}
|
||||
|
||||
// If any error is found we just panic
|
||||
fn panic_on_error_system(mut renet_error: EventReader<NetcodeTransportError>) {
|
||||
for e in renet_error.iter() {
|
||||
panic!("{}", e);
|
||||
}
|
||||
}
|
||||
14
client_tui/Cargo.toml
Normal file
14
client_tui/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "client_tui"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.89"
|
||||
bincode = "1.3.3"
|
||||
crossterm = "0.28.1"
|
||||
ratatui = "0.28.1"
|
||||
# renet = "0.0.13"
|
||||
store = { path = "../store" }
|
||||
53
client_tui/src/app.rs
Normal file
53
client_tui/src/app.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Application.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct App {
|
||||
// should the application exit?
|
||||
pub should_quit: bool,
|
||||
// counter
|
||||
pub counter: u8,
|
||||
}
|
||||
|
||||
impl App {
|
||||
// Constructs a new instance of [`App`].
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
// Handles the tick event of the terminal.
|
||||
pub fn tick(&self) {}
|
||||
|
||||
// Set running to false to quit the application.
|
||||
pub fn quit(&mut self) {
|
||||
self.should_quit = true;
|
||||
}
|
||||
|
||||
pub fn increment_counter(&mut self) {
|
||||
if let Some(res) = self.counter.checked_add(1) {
|
||||
self.counter = res;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrement_counter(&mut self) {
|
||||
if let Some(res) = self.counter.checked_sub(1) {
|
||||
self.counter = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_app_increment_counter() {
|
||||
let mut app = App::default();
|
||||
app.increment_counter();
|
||||
assert_eq!(app.counter, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_app_decrement_counter() {
|
||||
let mut app = App::default();
|
||||
app.decrement_counter();
|
||||
assert_eq!(app.counter, 0);
|
||||
}
|
||||
}
|
||||
87
client_tui/src/event.rs
Normal file
87
client_tui/src/event.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
use std::{
|
||||
sync::mpsc,
|
||||
thread,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use crossterm::event::{self, Event as CrosstermEvent, KeyEvent, MouseEvent};
|
||||
|
||||
// Terminal events.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum Event {
|
||||
// Terminal tick.
|
||||
Tick,
|
||||
// Key press.
|
||||
Key(KeyEvent),
|
||||
// Mouse click/scroll.
|
||||
Mouse(MouseEvent),
|
||||
// Terminal resize.
|
||||
Resize(u16, u16),
|
||||
}
|
||||
|
||||
// Terminal event handler.
|
||||
#[derive(Debug)]
|
||||
pub struct EventHandler {
|
||||
// Event sender channel.
|
||||
#[allow(dead_code)]
|
||||
sender: mpsc::Sender<Event>,
|
||||
// Event receiver channel.
|
||||
receiver: mpsc::Receiver<Event>,
|
||||
// Event handler thread.
|
||||
#[allow(dead_code)]
|
||||
handler: thread::JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl EventHandler {
|
||||
// Constructs a new instance of [`EventHandler`].
|
||||
pub fn new(tick_rate: u64) -> Self {
|
||||
let tick_rate = Duration::from_millis(tick_rate);
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let handler = {
|
||||
let sender = sender.clone();
|
||||
thread::spawn(move || {
|
||||
let mut last_tick = Instant::now();
|
||||
loop {
|
||||
let timeout = tick_rate
|
||||
.checked_sub(last_tick.elapsed())
|
||||
.unwrap_or(tick_rate);
|
||||
|
||||
if event::poll(timeout).expect("no events available") {
|
||||
match event::read().expect("unable to read event") {
|
||||
CrosstermEvent::Key(e) => {
|
||||
if e.kind == event::KeyEventKind::Press {
|
||||
sender.send(Event::Key(e))
|
||||
} else {
|
||||
Ok(()) // ignore KeyEventKind::Release on windows
|
||||
}
|
||||
}
|
||||
CrosstermEvent::Mouse(e) => sender.send(Event::Mouse(e)),
|
||||
CrosstermEvent::Resize(w, h) => sender.send(Event::Resize(w, h)),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
.expect("failed to send terminal event")
|
||||
}
|
||||
|
||||
if last_tick.elapsed() >= tick_rate {
|
||||
sender.send(Event::Tick).expect("failed to send tick event");
|
||||
last_tick = Instant::now();
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
Self {
|
||||
sender,
|
||||
receiver,
|
||||
handler,
|
||||
}
|
||||
}
|
||||
|
||||
// Receive the next event from the handler thread.
|
||||
//
|
||||
// This function will always block the current thread if
|
||||
// there is no data available and it's possible for more data to be sent.
|
||||
pub fn next(&self) -> Result<Event> {
|
||||
Ok(self.receiver.recv()?)
|
||||
}
|
||||
}
|
||||
50
client_tui/src/main.rs
Normal file
50
client_tui/src/main.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
// Application.
|
||||
pub mod app;
|
||||
|
||||
// Terminal events handler.
|
||||
pub mod event;
|
||||
|
||||
// Widget renderer.
|
||||
pub mod ui;
|
||||
|
||||
// Terminal user interface.
|
||||
pub mod tui;
|
||||
|
||||
// Application updater.
|
||||
pub mod update;
|
||||
|
||||
use anyhow::Result;
|
||||
use app::App;
|
||||
use event::{Event, EventHandler};
|
||||
use ratatui::{backend::CrosstermBackend, Terminal};
|
||||
use tui::Tui;
|
||||
use update::update;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Create an application.
|
||||
let mut app = App::new();
|
||||
|
||||
// Initialize the terminal user interface.
|
||||
let backend = CrosstermBackend::new(std::io::stderr());
|
||||
let terminal = Terminal::new(backend)?;
|
||||
let events = EventHandler::new(250);
|
||||
let mut tui = Tui::new(terminal, events);
|
||||
tui.enter()?;
|
||||
|
||||
// Start the main loop.
|
||||
while !app.should_quit {
|
||||
// Render the user interface.
|
||||
tui.draw(&mut app)?;
|
||||
// Handle events.
|
||||
match tui.events.next()? {
|
||||
Event::Tick => {}
|
||||
Event::Key(key_event) => update(&mut app, key_event),
|
||||
Event::Mouse(_) => {}
|
||||
Event::Resize(_, _) => {}
|
||||
};
|
||||
}
|
||||
|
||||
// Exit the user interface.
|
||||
tui.exit()?;
|
||||
Ok(())
|
||||
}
|
||||
77
client_tui/src/tui.rs
Normal file
77
client_tui/src/tui.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
use std::{io, panic};
|
||||
|
||||
use anyhow::Result;
|
||||
use crossterm::{
|
||||
event::{DisableMouseCapture, EnableMouseCapture},
|
||||
terminal::{self, EnterAlternateScreen, LeaveAlternateScreen},
|
||||
};
|
||||
|
||||
pub type CrosstermTerminal = ratatui::Terminal<ratatui::backend::CrosstermBackend<std::io::Stderr>>;
|
||||
|
||||
use crate::{app::App, event::EventHandler, ui};
|
||||
|
||||
// Representation of a terminal user interface.
|
||||
//
|
||||
// It is responsible for setting up the terminal,
|
||||
// initializing the interface and handling the draw events.
|
||||
pub struct Tui {
|
||||
// Interface to the Terminal.
|
||||
terminal: CrosstermTerminal,
|
||||
// Terminal event handler.
|
||||
pub events: EventHandler,
|
||||
}
|
||||
|
||||
impl Tui {
|
||||
// Constructs a new instance of [`Tui`].
|
||||
pub fn new(terminal: CrosstermTerminal, events: EventHandler) -> Self {
|
||||
Self { terminal, events }
|
||||
}
|
||||
|
||||
// Initializes the terminal interface.
|
||||
//
|
||||
// It enables the raw mode and sets terminal properties.
|
||||
pub fn enter(&mut self) -> Result<()> {
|
||||
terminal::enable_raw_mode()?;
|
||||
crossterm::execute!(io::stderr(), EnterAlternateScreen, EnableMouseCapture)?;
|
||||
|
||||
// Define a custom panic hook to reset the terminal properties.
|
||||
// This way, you won't have your terminal messed up if an unexpected error happens.
|
||||
let panic_hook = panic::take_hook();
|
||||
panic::set_hook(Box::new(move |panic| {
|
||||
Self::reset().expect("failed to reset the terminal");
|
||||
panic_hook(panic);
|
||||
}));
|
||||
|
||||
self.terminal.hide_cursor()?;
|
||||
self.terminal.clear()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// [`Draw`] the terminal interface by [`rendering`] the widgets.
|
||||
//
|
||||
// [`Draw`]: tui::Terminal::draw
|
||||
// [`rendering`]: crate::ui:render
|
||||
pub fn draw(&mut self, app: &mut App) -> Result<()> {
|
||||
self.terminal.draw(|frame| ui::render(app, frame))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Resets the terminal interface.
|
||||
//
|
||||
// This function is also used for the panic hook to revert
|
||||
// the terminal properties if unexpected errors occur.
|
||||
fn reset() -> Result<()> {
|
||||
terminal::disable_raw_mode()?;
|
||||
crossterm::execute!(io::stderr(), LeaveAlternateScreen, DisableMouseCapture)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Exits the terminal interface.
|
||||
//
|
||||
// It disables the raw mode and reverts back the terminal properties.
|
||||
pub fn exit(&mut self) -> Result<()> {
|
||||
Self::reset()?;
|
||||
self.terminal.show_cursor()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
30
client_tui/src/ui.rs
Normal file
30
client_tui/src/ui.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use ratatui::{
|
||||
prelude::{Alignment, Frame},
|
||||
style::{Color, Style},
|
||||
widgets::{Block, BorderType, Borders, Paragraph},
|
||||
};
|
||||
|
||||
use crate::app::App;
|
||||
|
||||
pub fn render(app: &mut App, f: &mut Frame) {
|
||||
f.render_widget(
|
||||
Paragraph::new(format!(
|
||||
"
|
||||
Press `Esc`, `Ctrl-C` or `q` to stop running.\n\
|
||||
Press `j` and `k` to increment and decrement the counter respectively.\n\
|
||||
Counter: {}
|
||||
",
|
||||
app.counter
|
||||
))
|
||||
.block(
|
||||
Block::default()
|
||||
.title("Counter App")
|
||||
.title_alignment(Alignment::Center)
|
||||
.borders(Borders::ALL)
|
||||
.border_type(BorderType::Rounded),
|
||||
)
|
||||
.style(Style::default().fg(Color::Yellow))
|
||||
.alignment(Alignment::Center),
|
||||
f.area(),
|
||||
)
|
||||
}
|
||||
17
client_tui/src/update.rs
Normal file
17
client_tui/src/update.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
|
||||
use crate::app::App;
|
||||
|
||||
pub fn update(app: &mut App, key_event: KeyEvent) {
|
||||
match key_event.code {
|
||||
KeyCode::Esc | KeyCode::Char('q') => app.quit(),
|
||||
KeyCode::Char('c') | KeyCode::Char('C') => {
|
||||
if key_event.modifiers == KeyModifiers::CONTROL {
|
||||
app.quit()
|
||||
}
|
||||
}
|
||||
KeyCode::Right | KeyCode::Char('j') => app.increment_counter(),
|
||||
KeyCode::Left | KeyCode::Char('k') => app.decrement_counter(),
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
3
justfile
3
justfile
|
|
@ -22,6 +22,9 @@ profile:
|
|||
pythonlib:
|
||||
maturin build -m store/Cargo.toml --release
|
||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||
trainsimple:
|
||||
cargo build --release --bin=train_dqn_simple
|
||||
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
|
||||
trainbot algo:
|
||||
#python ./store/python/trainModel.py
|
||||
# cargo run --bin=train_dqn # ok
|
||||
|
|
|
|||
14
server/Cargo.toml
Normal file
14
server/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "trictrac-server"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
store = { path = "../store" }
|
||||
env_logger = "0.10.0"
|
||||
log = "0.4.20"
|
||||
pico-args = "0.5.0"
|
||||
renet = "0.0.13"
|
||||
bincode = "1.3.3"
|
||||
147
server/src/main.rs
Normal file
147
server/src/main.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
use log::{info, trace, warn};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
|
||||
use renet::{
|
||||
transport::{
|
||||
NetcodeServerTransport, ServerAuthentication, ServerConfig, NETCODE_USER_DATA_BYTES,
|
||||
},
|
||||
ConnectionConfig, RenetServer, ServerEvent,
|
||||
};
|
||||
|
||||
// Only clients that can provide the same PROTOCOL_ID that the server is using will be able to connect.
|
||||
// This can be used to make sure players use the most recent version of the client for instance.
|
||||
pub const PROTOCOL_ID: u64 = 2878;
|
||||
|
||||
/// Utility function for extracting a players name from renet user data
|
||||
fn name_from_user_data(user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> String {
|
||||
let mut buffer = [0u8; 8];
|
||||
buffer.copy_from_slice(&user_data[0..8]);
|
||||
let mut len = u64::from_le_bytes(buffer) as usize;
|
||||
len = len.min(NETCODE_USER_DATA_BYTES - 8);
|
||||
let data = user_data[8..len + 8].to_vec();
|
||||
String::from_utf8(data).unwrap()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
||||
let mut server = RenetServer::new(ConnectionConfig::default());
|
||||
|
||||
// Setup transport layer
|
||||
const SERVER_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5000);
|
||||
let socket: UdpSocket = UdpSocket::bind(SERVER_ADDR).unwrap();
|
||||
let server_config = ServerConfig {
|
||||
max_clients: 2,
|
||||
protocol_id: PROTOCOL_ID,
|
||||
public_addr: SERVER_ADDR,
|
||||
authentication: ServerAuthentication::Unsecure,
|
||||
};
|
||||
let current_time = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap();
|
||||
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();
|
||||
|
||||
trace!("❂ TricTrac server listening on {SERVER_ADDR}");
|
||||
|
||||
let mut game_state = store::GameState::default();
|
||||
let mut last_updated = Instant::now();
|
||||
loop {
|
||||
// Update server time
|
||||
let now = Instant::now();
|
||||
let delta_time = now - last_updated;
|
||||
server.update(delta_time);
|
||||
transport.update(delta_time, &mut server).unwrap();
|
||||
last_updated = now;
|
||||
|
||||
// Receive connection events from clients
|
||||
while let Some(event) = server.get_event() {
|
||||
match event {
|
||||
ServerEvent::ClientConnected { client_id } => {
|
||||
let user_data = transport.user_data(client_id).unwrap();
|
||||
|
||||
// Tell the recently joined player about the other player
|
||||
for (player_id, player) in game_state.players.iter() {
|
||||
let event = store::GameEvent::PlayerJoined {
|
||||
player_id: *player_id,
|
||||
name: player.name.clone(),
|
||||
};
|
||||
server.send_message(client_id, 0, bincode::serialize(&event).unwrap());
|
||||
}
|
||||
|
||||
// Add the new player to the game
|
||||
let event = store::GameEvent::PlayerJoined {
|
||||
player_id: client_id,
|
||||
name: name_from_user_data(&user_data),
|
||||
};
|
||||
game_state.consume(&event);
|
||||
|
||||
// Tell all players that a new player has joined
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
|
||||
info!("🎉 Client {client_id} connected.");
|
||||
// In TicTacTussle the game can begin once two players has joined
|
||||
if game_state.players.len() == 2 {
|
||||
let event = store::GameEvent::BeginGame {
|
||||
goes_first: client_id,
|
||||
};
|
||||
game_state.consume(&event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
trace!("The game gas begun");
|
||||
}
|
||||
}
|
||||
ServerEvent::ClientDisconnected {
|
||||
client_id,
|
||||
reason: _,
|
||||
} => {
|
||||
// First consume a disconnect event
|
||||
let event = store::GameEvent::PlayerDisconnected {
|
||||
player_id: client_id,
|
||||
};
|
||||
game_state.consume(&event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
info!("Client {client_id} disconnected");
|
||||
|
||||
// Then end the game, since tic tac toe can't go on with a single player
|
||||
let event = store::GameEvent::EndGame {
|
||||
reason: store::EndGameReason::PlayerLeft {
|
||||
player_id: client_id,
|
||||
},
|
||||
};
|
||||
game_state.consume(&event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
|
||||
// NOTE: Since we don't authenticate users we can't do any reconnection attempts.
|
||||
// We simply have no way to know if the next user is the same as the one that disconnected.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Receive GameEvents from clients. Broadcast valid events.
|
||||
for client_id in server.clients_id().into_iter() {
|
||||
while let Some(message) = server.receive_message(client_id, 0) {
|
||||
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
|
||||
if game_state.validate(&event) {
|
||||
game_state.consume(&event);
|
||||
trace!("Player {client_id} sent:\n\t{event:#?}");
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
|
||||
// Determine if a player has won the game
|
||||
if let Some(winner) = game_state.determine_winner() {
|
||||
let event = store::GameEvent::EndGame {
|
||||
reason: store::EndGameReason::PlayerWon { winner },
|
||||
};
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
}
|
||||
} else {
|
||||
warn!("Player {client_id} sent invalid event:\n\t{event:#?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transport.send_packets(&mut server);
|
||||
thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
|
|
@ -742,10 +742,6 @@ impl GameState {
|
|||
});
|
||||
}
|
||||
|
||||
pub fn mark_points_for_bot_training(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
self.mark_points(player_id, points)
|
||||
}
|
||||
|
||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
// Update player points and holes
|
||||
let mut new_hole = false;
|
||||
|
|
|
|||
Loading…
Reference in a new issue