Compare commits

..

51 commits

Author SHA1 Message Date
Henri Bourcereau 228dc5d50c Merge branch 'release/v0.1.1' 2026-01-15 17:14:40 +01:00
Henri Bourcereau de303ad574 refact: docs 2026-01-15 17:13:07 +01:00
Henri Bourcereau 7c50a6d07b refact: remove dqn_simple 2026-01-05 10:14:58 +01:00
Henri Bourcereau 74f692d7ba refact:remove server & bevy client ; remove _big bot algs 2026-01-04 12:43:21 +01:00
Henri Bourcereau 1e773671d9 bot train burnrl reward opponent 2026-01-04 10:12:18 +01:00
Henri Bourcereau 883ebf9bc1 chore: update dependencies 2025-09-13 17:42:50 +02:00
Henri Bourcereau c8d6712f09 fix: bot training : empty move if no other move allowed 2025-09-13 17:42:08 +02:00
Henri Bourcereau e66d8b6624 feat: TrictracBoard for kZero 2025-09-01 18:50:55 +02:00
Henri Bourcereau 4e299b04e2 feat: TrictracBoard::to_fen / from_fen 2025-08-30 16:14:21 +02:00
Henri Bourcereau 41383eddf6 feat: add GameState::from_string_id 2025-08-30 16:04:43 +02:00
Henri Bourcereau 7a501c90ea Merge tag 'v0.1.0' into develop
v0.1.0
2025-08-30 13:29:12 +02:00
Henri Bourcereau 2ef1f7ee50 Merge branch 'release/v0.1.0' 2025-08-30 13:29:07 +02:00
Henri Bourcereau 73cc6ee67e doc 2025-08-30 13:28:00 +02:00
Henri Bourcereau f2a89f60bc feat: Karel Peeters board game implementation 2025-08-28 19:20:06 +02:00
Henri Bourcereau 866ba611a6 fix: train.sh parsing 2025-08-26 17:12:19 +02:00
Henri Bourcereau e1b8d7e679 feat: bot training configuration file 2025-08-22 09:24:01 +02:00
Henri Bourcereau 8f41cc1412 feat: bot all algos 2025-08-21 17:39:45 +02:00
Henri Bourcereau 0c58490f87 feat: bot sac & ppo save & load 2025-08-21 14:35:25 +02:00
Henri Bourcereau afeb3561e0 refacto: bot one exec 2025-08-21 11:30:25 +02:00
Henri Bourcereau 18e85744d6 refacto: burnrl 2025-08-20 14:08:04 +02:00
Henri Bourcereau 97167ff389 feat: wip bot burn sac 2025-08-19 21:40:02 +02:00
Henri Bourcereau 088124fad1 feat: wip bot burn ppo 2025-08-19 17:46:22 +02:00
Henri Bourcereau fcd50bc0f2 refacto: bot directories 2025-08-19 16:27:37 +02:00
Henri Bourcereau e66921fcce refact models paths 2025-08-18 17:44:01 +02:00
Henri Bourcereau 2499c3377f refact script train bot 2025-08-17 17:42:59 +02:00
Henri Bourcereau a7aa087b18 fix: train bad move 2025-08-17 16:14:06 +02:00
Henri Bourcereau 1dc29d0ff0 chore:refacto clippy 2025-08-17 15:59:53 +02:00
Henri Bourcereau db9560dfac fix dqn burn small 2025-08-16 21:47:12 +02:00
Henri Bourcereau 47a8502b63 fix validations & client_cli 2025-08-16 17:59:00 +02:00
Henri Bourcereau c1e99a5f35 wip (tests fails) 2025-08-16 16:39:25 +02:00
Henri Bourcereau 56d155b911 wip debug 2025-08-16 11:13:31 +02:00
Henri Bourcereau d313cb6151 burnrl_big like before 2025-08-15 21:08:23 +02:00
Henri Bourcereau 93624c425d wip burnrl_big 2025-08-15 18:39:09 +02:00
Henri Bourcereau 86a67ae66a fix: train bot opponent rewards 2025-08-13 18:08:35 +02:00
Henri Bourcereau ac14341cf9 doc: schema store 2025-08-13 15:29:04 +02:00
Henri Bourcereau cfc19e6064 compile ok but diverge 2025-08-12 21:56:52 +02:00
Henri Bourcereau ec6ae26d38 wip reduction TrictracAction 2025-08-12 17:56:41 +02:00
Henri Bourcereau 5370eb4307 Merge branch 'feature/botTrainValidMoves' into develop 2025-08-11 18:56:17 +02:00
Henri Bourcereau bfd2a4ed47 burn-rl with valid moves 2025-08-11 18:53:45 +02:00
Henri Bourcereau 4353ba2bd1 doc params train bot 2025-08-10 21:49:15 +02:00
Henri Bourcereau 1fb04209f5 doc params train bot 2025-08-10 17:46:09 +02:00
Henri Bourcereau 778ac1817b script train bots 2025-08-10 15:35:12 +02:00
Henri Bourcereau e4b3092018 train burn-rl with integers 2025-08-10 08:39:31 +02:00
Henri Bourcereau 5b02293221 Merge branch 'feature/botBlackStrategy' into develop 2025-08-08 21:32:09 +02:00
Henri Bourcereau 17d29b8633 runcli with bot dqn burn-rl 2025-08-08 21:31:56 +02:00
Henri Bourcereau a19c5d8596 refact dqn simple 2025-08-08 18:58:21 +02:00
Henri Bourcereau 1b58ca4ccc refact dqn burn demo 2025-08-08 17:07:34 +02:00
Henri Bourcereau bf820ecc4e feat: bot random strategy 2025-08-08 16:24:40 +02:00
Henri Bourcereau b02ce8d185 fix dqn strategy color 2025-08-07 21:03:53 +02:00
Henri Bourcereau dc80243a1a fix black moves 2025-08-07 21:03:53 +02:00
Henri Bourcereau 12004ec4f3 wip bot mirror 2025-08-07 21:03:53 +02:00
76 changed files with 5755 additions and 3383 deletions

4
.gitignore vendored
View file

@ -11,6 +11,4 @@ devenv.local.nix
# generated by samply rust profiler # generated by samply rust profiler
profile.json profile.json
bot/models
# AI models used by bots
/models

View file

@ -1,26 +0,0 @@
# Trictrac Project Guidelines
## Build & Run Commands
- Build: `cargo build`
- Test: `cargo test`
- Test specific: `cargo test -- test_name`
- Lint: `cargo clippy`
- Format: `cargo fmt`
- Run CLI: `RUST_LOG=info cargo run --bin=client_cli`
- Run CLI with bots: `RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dummy`
- Build Python lib: `maturin build -m store/Cargo.toml --release`
## Code Style
- Use Rust 2021 edition idioms
- Error handling: Use Result<T, Error> pattern with custom Error types
- Naming: snake_case for functions/variables, CamelCase for types
- Imports: Group standard lib, external crates, then internal modules
- Module structure: Prefer small, focused modules with clear responsibilities
- Documentation: Document public APIs with doc comments
- Testing: Write unit tests in same file as implementation
- Python bindings: Use pyo3 for creating Python modules
## Architecture
- Core game logic in `store` crate
- Multiple clients: CLI, TUI, Bevy (graphical)
- Bot interfaces in `bot` crate

2696
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
[workspace] [workspace]
resolver = "2" resolver = "2"
members = ["client_tui", "client_cli", "bot", "server", "store"] members = ["client_cli", "bot", "store"]

View file

@ -1,7 +1,41 @@
# Trictrac # Trictrac
Game of [Trictrac](https://en.wikipedia.org/wiki/Trictrac) in rust. This is a Rust implementation of the game of [Trictrac](https://en.wikipedia.org/wiki/Trictrac).
wip The project is in its early stages.
Rules (without "schools") are implemented, as well as a rudimentary terminal interface which allows you to play against a bot which plays randomly.
Training of AI bots is a work in progress.
## Usage
`cargo run --bin=client_cli -- --bot random`
## Roadmap
- [x] rules
- [x] command line interface
- [x] basic bot (random play)
- [ ] AI bot
- [ ] network game
- [ ] web client
## Code structure
- game rules and game state are implemented in the _store/_ folder.
- the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other
- the bots algorithms and the training of their models are implemented in the _bot/_ folder
### _store_ package
The game state is defined by the `GameState` struct in _store/src/game.rs_. The `to_string_id()` method allows this state to be encoded compactly in a string (without the played moves history). For a more readable textual representation, the `fmt::Display` trait is implemented.
### _client_cli_ package
`client_cli/src/game_runner.rs` contains the logic to make two bots play against each other.
### _bot_ package
- `bot/src/strategy/default.rs` contains the code for a basic bot strategy: it determines the list of valid moves (using the `get_possible_moves_sequences` method of `store::MoveRules`) and simply executes the first move in the list.
- `bot/src/strategy/dqnburn.rs` is another bot strategy that uses a reinforcement learning trained model with the DQN algorithm via the burn library (<https://burn.dev/>).
- `bot/scripts/train.sh` allows you to train agents using different algorithms (DQN, PPO, SAC).

View file

@ -6,12 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[[bin]] [[bin]]
name = "train_dqn_burn" name = "burn_train"
path = "src/dqn/burnrl/main.rs" path = "src/burnrl/main.rs"
[[bin]]
name = "train_dqn"
path = "src/bin/train_dqn.rs"
[dependencies] [dependencies]
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
@ -20,5 +16,9 @@ serde_json = "1.0"
store = { path = "../store" } store = { path = "../store" }
rand = "0.8" rand = "0.8"
env_logger = "0.10" env_logger = "0.10"
burn = { version = "0.17", features = ["ndarray", "autodiff"] } burn = { version = "0.18", features = ["ndarray", "autodiff"] }
burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" } burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
log = "0.4.20"
confy = "1.0.0"
board-game = "0.8.2"
internal-iterator = "0.2.3"

50
bot/scripts/train.sh Executable file
View file

@ -0,0 +1,50 @@
#!/usr/bin/env bash
# Train a bot with the $BINBOT binary, or plot the values logged by the latest run.
# Repository root, resolved from this script's own location (two levels up).
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
# train() and plot() read/write logs in a per-algorithm subdirectory of this one.
LOGS_DIR="$ROOT/bot/models/logs"
# Number of leading non-"info:" log lines that plot() evaluates as config assignments.
CFG_SIZE=17
# Name of the cargo binary that train() builds and runs.
BINBOT=burn_train
# BINBOT=train_ppo_burn
# BINBOT=train_dqn_burn
# BINBOT=train_dqn_burn_big
# BINBOT=train_dqn_burn_before
# Only used in the plot title.
OPPONENT="random"
# Passed to feedgnuplot as the terminal and used as the plot file extension.
PLOT_EXT="png"
# Build the training binary in release mode, then run it for one algorithm,
# copying its output to a timestamped log file.
#   $1 - algorithm name, forwarded unchanged to the training binary
# Returns non-zero without running anything when the build fails.
train() {
  ALGO=$1
  # Do not fall through to a stale (or missing) binary if the build fails.
  cargo build --release --bin="$BINBOT" || return 1
  NAME="$(date +%Y-%m-%d_%H:%M:%S)"
  LOGS="$LOGS_DIR/$ALGO/$NAME.out"
  mkdir -p "$LOGS_DIR/$ALGO"
  LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" "$ALGO" | tee "$LOGS"
}
# Plot the per-episode values of the most recent training log of an algorithm.
#   $1 - algorithm name (subdirectory of $LOGS_DIR)
# The first $CFG_SIZE non-"info:" lines of the log are evaluated as shell
# assignments so that their values can be shown in the plot title.
plot() {
  ALGO=$1
  # train() names its logs "<timestamp>.out"; plots are "<log>.$PLOT_EXT", so
  # matching on the .out suffix selects logs only, whatever the plot extension.
  NAME=$(ls -rt "$LOGS_DIR/$ALGO" 2>/dev/null | grep '\.out$' | tail -n 1)
  if [[ -z "$NAME" ]]; then
    echo "No training log found in $LOGS_DIR/$ALGO" >&2
    return 1
  fi
  LOGS="$LOGS_DIR/$ALGO/$NAME"
  cfgs=$(grep -v "info:" "$LOGS" | head -n "$CFG_SIZE")
  # Word splitting is intentional here: each word is eval'ed on its own.
  # NOTE: eval executes log content; only plot logs produced by train().
  for cfg in $cfgs; do
    eval "$cfg"
  done
  tail -n +$((CFG_SIZE + 2)) "$LOGS" |
    grep -v "info:" |
    awk -F '[ ,]' '{print $5}' |
    feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal "$PLOT_EXT" >"$LOGS_DIR/$ALGO/$NAME.$PLOT_EXT"
}
# Entry point: "<algo>" trains, "plot <algo>" plots the latest log of <algo>.
# Arguments are quoted so an algorithm name is never word-split or globbed.
case "${1:-}" in
"")
  echo "Usage : train [plot] <algo>"
  ;;
plot)
  if [[ -z "${2:-}" ]]; then
    echo "Usage : train [plot] <algo>"
  else
    plot "$2"
  fi
  ;;
*)
  train "$1"
  ;;
esac

49
bot/scripts/trainValid.sh Executable file
View file

@ -0,0 +1,49 @@
#!/usr/bin/env sh
# Train with the train_dqn_burn_valid binary, or plot / average its latest log.
# Repository root, resolved from this script's own location (two levels up).
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
# Directory where train() writes its logs and plot() writes its images.
LOGS_DIR="$ROOT/bot/models/logs"
# Number of leading log lines that plot() evaluates as config assignments.
CFG_SIZE=11
# Used in the plot title and in the plot file name.
OPPONENT="random"
# Passed to feedgnuplot as the terminal and used as the plot file extension.
PLOT_EXT="png"
# Build train_dqn_burn_valid in release mode, then run it, copying its output
# to a timestamped log file. Returns non-zero without running when the build fails.
train() {
  # Do not fall through to a stale (or missing) binary if the build fails.
  cargo build --release --bin=train_dqn_burn_valid || return 1
  NAME="trainValid_$(date +%Y-%m-%d_%H:%M:%S)"
  LOGS="$LOGS_DIR/$NAME.out"
  mkdir -p "$LOGS_DIR"
  LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn_valid" | tee "$LOGS"
}
# Plot the per-episode values of the most recent log found in $LOGS_DIR.
# The first $CFG_SIZE lines of the log are evaluated as shell assignments so
# that their values can be used in the plot title and file name.
plot() {
  # Logs are written as "*.out"; matching that suffix skips plot images and
  # any subdirectory that lives in the same logs directory.
  NAME=$(ls -rt "$LOGS_DIR" 2>/dev/null | grep '\.out$' | tail -n 1)
  if [ -z "$NAME" ]; then
    echo "No training log found in $LOGS_DIR" >&2
    return 1
  fi
  LOGS="$LOGS_DIR/$NAME"
  cfgs=$(head -n "$CFG_SIZE" "$LOGS")
  # Word splitting is intentional here: each word is eval'ed on its own.
  # NOTE: eval executes log content; only plot logs produced by train().
  for cfg in $cfgs; do
    eval "$cfg"
  done
  tail -n +$((CFG_SIZE + 2)) "$LOGS" |
    grep -v "info:" |
    awk -F '[ ,]' '{print $5}' |
    feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal "$PLOT_EXT" >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
}
# Print the path of the most recent log in $LOGS_DIR, then the mean of the
# fifth field of its episode lines ("info:" lines and the config header are skipped).
avg() {
  # Logs are written as "*.out"; matching that suffix skips plot images and
  # any subdirectory that lives in the same logs directory.
  NAME=$(ls -rt "$LOGS_DIR" 2>/dev/null | grep '\.out$' | tail -n 1)
  if [ -z "$NAME" ]; then
    echo "No training log found in $LOGS_DIR" >&2
    return 1
  fi
  LOGS="$LOGS_DIR/$NAME"
  echo "$LOGS"
  tail -n +$((CFG_SIZE + 2)) "$LOGS" |
    grep -v "info:" |
    awk -F '[ ,]' '{print $5}' | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }'
}
# Entry point: "plot" and "avg" work on the latest log; anything else trains.
case "$1" in
plot)
  plot
  ;;
avg)
  avg
  ;;
*)
  train
  ;;
esac

View file

@ -1,111 +0,0 @@
use bot::dqn::dqn_common::{DqnConfig, TrictracAction};
use bot::dqn::simple::dqn_trainer::DqnTrainer;
use std::env;
/// Returns the value that follows the flag located at `args[i]`, parsed as `T`.
///
/// Prints an error message and exits with status 1 when the value is missing
/// or cannot be parsed, instead of silently falling back to a default.
fn flag_value<T: std::str::FromStr>(args: &[String], i: usize, flag: &str) -> T {
    let raw = match args.get(i + 1) {
        Some(raw) => raw,
        None => {
            eprintln!("Erreur : {flag} nécessite une valeur");
            std::process::exit(1);
        }
    };
    match raw.parse() {
        Ok(value) => value,
        Err(_) => {
            eprintln!("Erreur : valeur invalide pour {flag} : {raw}");
            std::process::exit(1);
        }
    }
}

/// Command-line entry point of the DQN trainer.
///
/// Parses `--episodes`, `--model-path` and `--save-every`, prints the resulting
/// configuration, then builds a `DqnTrainer` and runs the training.
///
/// # Errors
/// Returns an error when the `models` directory cannot be created or when
/// `DqnTrainer::train` fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init();

    let args: Vec<String> = env::args().collect();

    // Default parameters
    let mut episodes = 1000;
    let mut model_path = "models/dqn_model".to_string();
    let mut save_every = 100;

    // Parse the command-line arguments; every value flag consumes two entries.
    let mut i = 1;
    while i < args.len() {
        match args[i].as_str() {
            "--episodes" => {
                episodes = flag_value(&args, i, "--episodes");
                i += 2;
            }
            "--model-path" => {
                model_path = flag_value(&args, i, "--model-path");
                i += 2;
            }
            "--save-every" => {
                save_every = flag_value(&args, i, "--save-every");
                i += 2;
            }
            "--help" | "-h" => {
                print_help();
                std::process::exit(0);
            }
            _ => {
                eprintln!("Argument inconnu : {}", args[i]);
                print_help();
                std::process::exit(1);
            }
        }
    }

    // Create the models directory if it does not exist yet
    std::fs::create_dir_all("models")?;

    println!("Configuration d'entraînement DQN :");
    println!(" Épisodes : {}", episodes);
    println!(" Chemin du modèle : {}", model_path);
    println!(" Sauvegarde tous les {} épisodes", save_every);
    println!();

    // DQN configuration
    let config = DqnConfig {
        state_size: 36, // state.to_vec size
        hidden_size: 256,
        num_actions: TrictracAction::action_space_size(),
        learning_rate: 0.001,
        gamma: 0.99,
        epsilon: 0.9, // Start with more exploration
        epsilon_decay: 0.995,
        epsilon_min: 0.01,
        replay_buffer_size: 10000,
        batch_size: 32,
    };

    // Create the trainer and run it
    let mut trainer = DqnTrainer::new(config);
    trainer.train(episodes, save_every, &model_path)?;

    println!("Entraînement terminé avec succès !");
    println!("Pour utiliser le modèle entraîné :");
    println!(
        " cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
        model_path
    );

    Ok(())
}
/// Prints the command-line usage of the DQN trainer to standard output.
fn print_help() {
    // One entry per output line; empty entries produce blank lines.
    const HELP_LINES: &[&str] = &[
        "Entraîneur DQN pour Trictrac",
        "",
        "USAGE:",
        " cargo run --bin=train_dqn [OPTIONS]",
        "",
        "OPTIONS:",
        " --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)",
        " --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)",
        " --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)",
        " -h, --help Afficher cette aide",
        "",
        "EXEMPLES:",
        " cargo run --bin=train_dqn",
        " cargo run --bin=train_dqn -- --episodes 5000 --save-every 500",
        " cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000",
    ];
    for line in HELP_LINES {
        println!("{line}");
    }
}

195
bot/src/burnrl/algos/dqn.rs Normal file
View file

@ -0,0 +1,195 @@
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
/// Fully connected network made of three linear layers, used as the DQN model.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// The layers are applied in the order 0 -> 1 -> 2 by `forward`.
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Net<B> {
    /// Builds the network on the backend's default device.
    ///
    /// Layer sizes are `input_size -> dense_size -> dense_size -> output_size`.
    #[allow(unused)]
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let device = B::Device::default();
        let linear_0 = LinearConfig::new(input_size, dense_size).init(&device);
        let linear_1 = LinearConfig::new(dense_size, dense_size).init(&device);
        let linear_2 = LinearConfig::new(dense_size, output_size).init(&device);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }

    /// Takes the network apart and returns its three layers, in order.
    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
        let Self {
            linear_0,
            linear_1,
            linear_2,
        } = self;
        (linear_0, linear_1, linear_2)
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
    /// Runs the input through the three layers, applying ReLU after each one.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = relu(self.linear_0.forward(input));
        let hidden = relu(self.linear_1.forward(hidden));
        let output = self.linear_2.forward(hidden);
        // NOTE(review): ReLU is also applied to the output layer, so the
        // network can never produce a negative value — confirm this is intended.
        relu(output)
    }

    /// Inference uses exactly the same computation as `forward`.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.forward(input)
    }
}
impl<B: Backend> DQNModel<B> for Net<B> {
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = this.consume();
Self {
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
}
}
}
// Const-generic size parameter of the `Memory` buffer used by `run`.
#[allow(unused)]
const MEMORY_SIZE: usize = 8192;
// Agent type used in this module: burn-rl's `DQN` driven by this module's `Net`.
type MyAgent<E, B> = DQN<E, B, Net<B>>;
/// Trains a DQN agent on environment `E` and returns it in validation mode.
///
/// Hyper-parameters (network width, discount, exploration schedule, batch
/// size, gradient clipping, number of episodes, step limit) come from `conf`.
/// One JSON-like line is printed per episode. When `conf.save_path` is set,
/// the validation model is written to disk through [`save_model`].
///
/// * `conf` - training configuration.
/// * `visualized` - forwarded to `E::new`.
#[allow(unused)]
pub fn run<
    E: Environment + AsMut<TrictracEnvironment>,
    B: AutodiffBackend<InnerBackend = NdArray>,
>(
    conf: &Config,
    visualized: bool,
) -> impl Agent<E> {
    let mut env = E::new(visualized);
    env.as_mut().max_steps = conf.max_steps;

    let model = Net::<B>::new(
        <<E as Environment>::StateType as State>::size(),
        conf.dense_size,
        <<E as Environment>::ActionType as Action>::size(),
    );

    let mut agent = MyAgent::new(model);

    let config = DQNTrainingConfig {
        gamma: conf.gamma,
        tau: conf.tau,
        learning_rate: conf.learning_rate,
        batch_size: conf.batch_size,
        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
            conf.clip_grad,
        )),
    };

    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();

    let mut optimizer = AdamWConfig::new()
        .with_grad_clipping(config.clip_grad.clone())
        .init();

    let mut policy_net = agent.model().as_ref().unwrap().clone();

    // Total number of steps over all episodes; drives the exploration decay.
    let mut step = 0_usize;

    for episode in 0..conf.num_episodes {
        let mut episode_done = false;
        let mut episode_reward: ElemType = 0.0;
        let mut episode_duration = 0_usize;
        let mut state = env.state();
        let now = SystemTime::now();

        while !episode_done {
            // Exploration rate decays exponentially from eps_start to eps_end.
            let eps_threshold = conf.eps_end
                + (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
            let action =
                DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
            let snapshot = env.step(action);

            episode_reward +=
                <<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());

            memory.push(
                state,
                *snapshot.state(),
                action,
                snapshot.reward().clone(),
                snapshot.done(),
            );

            // Only train once the memory holds more than one batch.
            if config.batch_size < memory.len() {
                policy_net =
                    agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
            }

            step += 1;
            episode_duration += 1;

            if snapshot.done() || episode_duration >= conf.max_steps {
                let envmut = env.as_mut();
                let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
                    * 100.0)
                    .round() as u32;
                println!(
                    "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
                    envmut.goodmoves_count,
                    goodmoves_ratio,
                    envmut.pointrolls_count,
                    // `SystemTime::elapsed` fails if the wall clock went backwards;
                    // report 0 seconds rather than aborting the whole training.
                    now.elapsed().unwrap_or_default().as_secs(),
                );
                env.reset();
                episode_done = true;
            } else {
                state = *snapshot.state();
            }
        }
    }

    let valid_agent = agent.valid();
    if let Some(path) = &conf.save_path {
        save_model(valid_agent.model().as_ref().unwrap(), path);
    }
    valid_agent
}
/// Writes `model` to `"{path}.mpk"` using a `CompactRecorder`.
///
/// `path` is the destination without the `.mpk` extension. The confirmation
/// line is printed only once the file has actually been written.
///
/// # Panics
/// Panics when the recorder fails to write the file.
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &str) {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .unwrap_or_else(|err| panic!("failed to save model to {model_path}: {err:?}"));
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads a network previously written by [`save_model`] from `"{path}.mpk"`.
///
/// * `dense_size` - hidden layer width; it must match the saved model.
/// * `path` - file path without the `.mpk` extension.
///
/// Returns `None` when the record cannot be loaded (the underlying error is
/// deliberately discarded).
pub fn load_model(dense_size: usize, path: &str) -> Option<Net<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    let record = CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .ok()?;
    // The network is only built once the record has been read successfully.
    let net = Net::new(
        <TrictracEnvironment as Environment>::StateType::size(),
        dense_size,
        <TrictracEnvironment as Environment>::ActionType::size(),
    );
    Some(net.load_record(record))
}

View file

@ -1,13 +1,16 @@
use crate::dqn::burnrl::utils::soft_update_linear; use crate::burnrl::environment_valid::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module; use burn::module::Module;
use burn::nn::{Linear, LinearConfig}; use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig; use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu; use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend}; use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor; use burn::tensor::Tensor;
use burn_rl::agent::DQN; use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig}; use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State}; use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime; use std::time::SystemTime;
#[derive(Module, Debug)] #[derive(Module, Debug)]
@ -60,37 +63,20 @@ impl<B: Backend> DQNModel<B> for Net<B> {
#[allow(unused)] #[allow(unused)]
const MEMORY_SIZE: usize = 8192; const MEMORY_SIZE: usize = 8192;
pub struct DqnConfig {
pub num_episodes: usize,
// pub memory_size: usize,
pub dense_size: usize,
pub eps_start: f64,
pub eps_end: f64,
pub eps_decay: f64,
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
num_episodes: 1000,
// memory_size: 8192,
dense_size: 256,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
}
}
}
type MyAgent<E, B> = DQN<E, B, Net<B>>; type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)] #[allow(unused)]
pub fn run<E: Environment, B: AutodiffBackend>( // pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
conf: &DqnConfig, pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool, visualized: bool,
) -> DQN<E, B, Net<B>> { // ) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> { ) -> impl Agent<E> {
let mut env = E::new(visualized); let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new( let model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(), <<E as Environment>::StateType as State>::size(),
@ -100,7 +86,16 @@ pub fn run<E: Environment, B: AutodiffBackend>(
let mut agent = MyAgent::new(model); let mut agent = MyAgent::new(model);
let config = DQNTrainingConfig::default(); // let config = DQNTrainingConfig::default();
let config = DQNTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default(); let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
@ -145,22 +140,50 @@ pub fn run<E: Environment, B: AutodiffBackend>(
step += 1; step += 1;
episode_duration += 1; episode_duration += 1;
if snapshot.done() || episode_duration >= E::MAX_STEPS { if snapshot.done() || episode_duration >= conf.max_steps {
let envmut = env.as_mut();
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
envmut.pointrolls_count,
now.elapsed().unwrap().as_secs(),
);
env.reset(); env.reset();
episode_done = true; episode_done = true;
println!(
"{{\"episode\": {}, \"reward\": {:.4}, \"steps count\": {}, \"duration\": {}}}",
episode,
episode_reward,
episode_duration,
now.elapsed().unwrap().as_secs()
);
now = SystemTime::now(); now = SystemTime::now();
} else { } else {
state = *snapshot.state(); state = *snapshot.state();
} }
} }
} }
agent let valid_agent = agent.valid();
if let Some(path) = &conf.save_path {
save_model(valid_agent.model().as_ref().unwrap(), path);
}
valid_agent
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
} }

View file

@ -0,0 +1,6 @@
pub mod dqn;
pub mod dqn_valid;
pub mod ppo;
pub mod ppo_valid;
pub mod sac;
pub mod sac_valid;

191
bot/src/burnrl/algos/ppo.rs Normal file
View file

@ -0,0 +1,191 @@
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime;
/// Network with one shared layer feeding separate actor and critic heads.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Shared layer applied to the input before either head.
linear: Linear<B>,
// Head whose output is turned into a softmax distribution by `forward`/`infer`.
linear_actor: Linear<B>,
// Head with a single output per input (built with output size 1 in `new`).
linear_critic: Linear<B>,
}
impl<B: Backend> Net<B> {
    /// Builds the network on the backend's default device.
    ///
    /// Every layer uses Xavier-uniform initialisation with gain 1.0. The shared
    /// layer maps `input_size -> dense_size`, the actor head
    /// `dense_size -> output_size` and the critic head `dense_size -> 1`.
    #[allow(unused)]
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let xavier = Initializer::XavierUniform { gain: 1.0 };
        let device = B::Device::default();
        let layer = |d_in: usize, d_out: usize| {
            LinearConfig::new(d_in, d_out)
                .with_initializer(xavier.clone())
                .init(&device)
        };
        Self {
            linear: layer(input_size, dense_size),
            linear_actor: layer(dense_size, output_size),
            linear_critic: layer(dense_size, 1),
        }
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
    /// Computes both heads: softmax policies (along dimension 1) and values.
    fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
        let layer_0_output = relu(self.linear.forward(input));
        // The shared activation feeds both heads, hence the clone.
        let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
        let values = self.linear_critic.forward(layer_0_output);

        PPOOutput::<B>::new(policies, values)
    }

    /// Computes the policies only; the critic head is not evaluated.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let layer_0_output = relu(self.linear.forward(input));
        // Last use of the activation: it is moved, no clone needed.
        softmax(self.linear_actor.forward(layer_0_output), 1)
    }
}
// Marks `Net` as usable by burn-rl's `PPO` agent; no methods are overridden here.
impl<B: Backend> PPOModel<B> for Net<B> {}
// Const-generic size parameter of the `Memory` buffer used by `run`.
#[allow(unused)]
const MEMORY_SIZE: usize = 512;
// Agent type used in this module: burn-rl's `PPO` driven by this module's `Net`.
type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> PPO<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let mut model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
let agent = MyAgent::default();
let config = PPOTrainingConfig {
gamma: conf.gamma,
lambda: conf.lambda,
epsilon_clip: conf.epsilon_clip,
critic_weight: conf.critic_weight,
entropy_weight: conf.entropy_weight,
learning_rate: conf.learning_rate,
epochs: conf.epochs,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut now = SystemTime::now();
env.reset();
while !episode_done {
let state = env.state();
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
episode_duration += 1;
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
}
}
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
memory.clear();
}
if let Some(path) = &conf.save_path {
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
}
agent.valid(model)
}
/// Writes `model` to `"{path}.mpk"` using a `CompactRecorder`.
///
/// `path` is the destination without the `.mpk` extension. The confirmation
/// line is printed only once the file has actually been written.
///
/// # Panics
/// Panics when the recorder fails to write the file.
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &str) {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .unwrap_or_else(|err| panic!("failed to save model to {model_path}: {err:?}"));
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads a network previously written by [`save_model`] from `"{path}.mpk"`.
///
/// * `dense_size` - hidden layer width; it must match the saved model.
/// * `path` - file path without the `.mpk` extension.
///
/// Returns `None` when the record cannot be loaded (the underlying error is
/// deliberately discarded).
pub fn load_model(dense_size: usize, path: &str) -> Option<Net<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    let record = CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .ok()?;
    // The network is only built once the record has been read successfully.
    let net = Net::new(
        <TrictracEnvironment as Environment>::StateType::size(),
        dense_size,
        <TrictracEnvironment as Environment>::ActionType::size(),
    );
    Some(net.load_record(record))
}

View file

@ -0,0 +1,191 @@
use crate::burnrl::environment_valid::TrictracEnvironment;
use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime;
/// Network with one shared layer feeding separate actor and critic heads.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Shared layer applied to the input before either head.
linear: Linear<B>,
// Head whose output is turned into a softmax distribution by `forward`/`infer`.
linear_actor: Linear<B>,
// Head with a single output per input (built with output size 1 in `new`).
linear_critic: Linear<B>,
}
impl<B: Backend> Net<B> {
    /// Builds the network on the backend's default device.
    ///
    /// Every layer uses Xavier-uniform initialisation with gain 1.0. The shared
    /// layer maps `input_size -> dense_size`, the actor head
    /// `dense_size -> output_size` and the critic head `dense_size -> 1`.
    #[allow(unused)]
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let xavier = Initializer::XavierUniform { gain: 1.0 };
        let device = B::Device::default();
        let layer = |d_in: usize, d_out: usize| {
            LinearConfig::new(d_in, d_out)
                .with_initializer(xavier.clone())
                .init(&device)
        };
        Self {
            linear: layer(input_size, dense_size),
            linear_actor: layer(dense_size, output_size),
            linear_critic: layer(dense_size, 1),
        }
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
    /// Computes both heads: softmax policies (along dimension 1) and values.
    fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
        let layer_0_output = relu(self.linear.forward(input));
        // The shared activation feeds both heads, hence the clone.
        let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
        let values = self.linear_critic.forward(layer_0_output);

        PPOOutput::<B>::new(policies, values)
    }

    /// Computes the policies only; the critic head is not evaluated.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let layer_0_output = relu(self.linear.forward(input));
        // Last use of the activation: it is moved, no clone needed.
        softmax(self.linear_actor.forward(layer_0_output), 1)
    }
}
// Marker implementation: the body is empty, so no trait item is overridden here.
impl<B: Backend> PPOModel<B> for Net<B> {}
/// Capacity used as the const parameter of the `Memory` buffer and of `train` in `run`.
#[allow(unused)]
const MEMORY_SIZE: usize = 512;
/// Shorthand for the PPO agent specialised to this module's `Net`.
type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> PPO<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let mut model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
let agent = MyAgent::default();
let config = PPOTrainingConfig {
gamma: conf.gamma,
lambda: conf.lambda,
epsilon_clip: conf.epsilon_clip,
critic_weight: conf.critic_weight,
entropy_weight: conf.entropy_weight,
learning_rate: conf.learning_rate,
epochs: conf.epochs,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut now = SystemTime::now();
env.reset();
while !episode_done {
let state = env.state();
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
episode_duration += 1;
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
}
}
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
memory.clear();
}
if let Some(path) = &conf.save_path {
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
}
agent.valid(model)
}
/// Writes `model` to `{path}.mpk` with a [`CompactRecorder`].
///
/// The confirmation line is printed only once the record has been written.
///
/// # Panics
///
/// Panics with the target path and the recorder error when the record cannot
/// be written.
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .unwrap_or_else(|err| panic!("Failed to save model to {model_path}: {err:?}"));
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads the weights stored in `{path}.mpk` into a fresh `NdArray` network.
///
/// `dense_size` must match the hidden width used when the model was saved.
/// Returns `None` when the record cannot be loaded.
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    let loaded = CompactRecorder::new().load(model_path.into(), &NdArrayDevice::default());
    match loaded {
        Ok(record) => {
            // The network is only built once a record is available.
            let net: Net<NdArray<ElemType>> = Net::new(
                <TrictracEnvironment as Environment>::StateType::size(),
                dense_size,
                <TrictracEnvironment as Environment>::ActionType::size(),
            );
            Some(net.load_record(record))
        }
        Err(_) => None,
    }
}

221
bot/src/burnrl/algos/sac.rs Normal file
View file

@ -0,0 +1,221 @@
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
/// Policy network of the SAC agent: three dense layers applied in sequence.
#[derive(Module, Debug)]
pub struct Actor<B: Backend> {
// `input_size -> dense_size`, followed by ReLU in `forward`.
linear_0: Linear<B>,
// `dense_size -> dense_size`, followed by ReLU in `forward`.
linear_1: Linear<B>,
// `dense_size -> output_size`, followed by a softmax in `forward`.
linear_2: Linear<B>,
}
impl<B: Backend> Actor<B> {
    /// Builds the three layers on the backend's default device.
    ///
    /// * `input_size` - width of the state vector.
    /// * `dense_size` - width of both hidden layers.
    /// * `output_size` - number of outputs of the last layer.
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let device = B::Device::default();
        let linear_0 = LinearConfig::new(input_size, dense_size).init(&device);
        let linear_1 = LinearConfig::new(dense_size, dense_size).init(&device);
        let linear_2 = LinearConfig::new(dense_size, output_size).init(&device);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
    /// Returns the softmax (along dimension 1) of the last layer's output.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = relu(self.linear_0.forward(input));
        let hidden = relu(self.linear_1.forward(hidden));
        let logits = self.linear_2.forward(hidden);
        softmax(logits, 1)
    }

    /// Inference is the same computation as `forward`.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.forward(input)
    }
}
// Marker implementation: the body is empty, so no trait item is overridden here.
impl<B: Backend> SACActor<B> for Actor<B> {}
/// Critic network of the SAC agent: three dense layers applied in sequence,
/// with no activation on the last one (see `forward`).
#[derive(Module, Debug)]
pub struct Critic<B: Backend> {
// `input_size -> dense_size`, followed by ReLU in `forward`.
linear_0: Linear<B>,
// `dense_size -> dense_size`, followed by ReLU in `forward`.
linear_1: Linear<B>,
// `dense_size -> output_size`, returned as-is by `forward`.
linear_2: Linear<B>,
}
impl<B: Backend> Critic<B> {
    /// Builds the three layers on the backend's default device.
    ///
    /// * `input_size` - width of the state vector.
    /// * `dense_size` - width of both hidden layers.
    /// * `output_size` - number of outputs of the last layer.
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let device = B::Device::default();
        let linear_0 = LinearConfig::new(input_size, dense_size).init(&device);
        let linear_1 = LinearConfig::new(dense_size, dense_size).init(&device);
        let linear_2 = LinearConfig::new(dense_size, output_size).init(&device);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }

    /// Takes the critic apart and hands back its layers in declaration order.
    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
        let Self {
            linear_0,
            linear_1,
            linear_2,
        } = self;
        (linear_0, linear_1, linear_2)
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
    /// Returns the raw output of the last layer (no activation applied).
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = relu(self.linear_0.forward(input));
        let hidden = relu(self.linear_1.forward(hidden));
        self.linear_2.forward(hidden)
    }

    /// Inference is the same computation as `forward`.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.forward(input)
    }
}
impl<B: Backend> SACCritic<B> for Critic<B> {
    /// Applies `soft_update_linear` layer by layer, consuming `this` and
    /// reading the matching layers of `that`.
    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
        let (own_0, own_1, own_2) = this.consume();
        let linear_0 = soft_update_linear(own_0, &that.linear_0, tau);
        let linear_1 = soft_update_linear(own_1, &that.linear_1, tau);
        let linear_2 = soft_update_linear(own_2, &that.linear_2, tau);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }
}
/// Capacity used as the const parameter of the `Memory` buffer and of `train` in `run`.
#[allow(unused)]
const MEMORY_SIZE: usize = 4096;
/// Shorthand for the SAC agent specialised to this module's `Actor`.
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
/// Trains a SAC agent on environment `E` and returns the validation agent.
///
/// Every environment step is pushed to the replay memory; as soon as the
/// memory holds more than `conf.batch_size` transitions the networks are
/// trained after each step. An episode ends when the environment reports
/// `done` or after `conf.max_steps` steps, at which point the environment is
/// reset and one JSON line with the episode statistics is printed.
///
/// When `conf.save_path` is set, the validation actor is saved with
/// [`save_model`].
#[allow(unused)]
pub fn run<
    E: Environment + AsMut<TrictracEnvironment>,
    B: AutodiffBackend<InnerBackend = NdArray>,
>(
    conf: &Config,
    visualized: bool,
) -> impl Agent<E> {
    let mut env = E::new(visualized);
    env.as_mut().max_steps = conf.max_steps;

    let state_dim = <<E as Environment>::StateType as State>::size();
    let action_dim = <<E as Environment>::ActionType as Action>::size();
    let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
    let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
    let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
    let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
    let mut agent = MyAgent::default();

    let config = SACTrainingConfig {
        gamma: conf.gamma,
        tau: conf.tau,
        learning_rate: conf.learning_rate,
        min_probability: conf.min_probability,
        batch_size: conf.batch_size,
        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
            conf.clip_grad,
        )),
    };
    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
    let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
    let mut optimizer = SACOptimizer::new(
        optimizer_config.clone().init(),
        optimizer_config.clone().init(),
        optimizer_config.clone().init(),
        optimizer_config.init(),
    );

    for episode in 0..conf.num_episodes {
        let mut episode_done = false;
        let mut episode_reward = 0.0;
        let mut episode_duration = 0_usize;
        // The environment was reset at the end of the previous episode.
        let mut state = env.state();
        let now = SystemTime::now();

        // NOTE(review): when `react_with_model` yields `None` nothing in this
        // loop advances the episode — confirm that this cannot happen.
        while !episode_done {
            if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
                let snapshot = env.step(action);
                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
                    snapshot.reward().clone(),
                );
                memory.push(
                    state,
                    *snapshot.state(),
                    action,
                    snapshot.reward().clone(),
                    snapshot.done(),
                );
                if config.batch_size < memory.len() {
                    nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
                }
                episode_duration += 1;

                if snapshot.done() || episode_duration >= conf.max_steps {
                    env.reset();
                    episode_done = true;
                    // A clock that went backwards is reported as 0 instead of panicking.
                    println!(
                        "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
                        now.elapsed().unwrap_or_default().as_secs()
                    );
                } else {
                    state = *snapshot.state();
                }
            }
        }
    }

    let valid_agent = agent.valid(nets.actor);
    if let Some(path) = &conf.save_path {
        if let Some(model) = valid_agent.model() {
            save_model(model, path);
        }
    }
    valid_agent
}
/// Writes `model` to `{path}.mpk` with a [`CompactRecorder`].
///
/// The confirmation line is printed only once the record has been written.
///
/// # Panics
///
/// Panics with the target path and the recorder error when the record cannot
/// be written.
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .unwrap_or_else(|err| panic!("Failed to save model to {model_path}: {err:?}"));
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads the weights stored in `{path}.mpk` into a fresh `NdArray` actor.
///
/// `dense_size` must match the hidden width used when the model was saved.
/// Returns `None` when the record cannot be loaded.
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    let loaded = CompactRecorder::new().load(model_path.into(), &NdArrayDevice::default());
    match loaded {
        Ok(record) => {
            // The actor is only built once a record is available.
            let actor: Actor<NdArray<ElemType>> = Actor::new(
                <TrictracEnvironment as Environment>::StateType::size(),
                dense_size,
                <TrictracEnvironment as Environment>::ActionType::size(),
            );
            Some(actor.load_record(record))
        }
        Err(_) => None,
    }
}

View file

@ -0,0 +1,222 @@
use crate::burnrl::environment_valid::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
/// Policy network of the SAC agent: three dense layers applied in sequence.
#[derive(Module, Debug)]
pub struct Actor<B: Backend> {
// `input_size -> dense_size`, followed by ReLU in `forward`.
linear_0: Linear<B>,
// `dense_size -> dense_size`, followed by ReLU in `forward`.
linear_1: Linear<B>,
// `dense_size -> output_size`, followed by a softmax in `forward`.
linear_2: Linear<B>,
}
impl<B: Backend> Actor<B> {
    /// Builds the three layers on the backend's default device.
    ///
    /// * `input_size` - width of the state vector.
    /// * `dense_size` - width of both hidden layers.
    /// * `output_size` - number of outputs of the last layer.
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let device = B::Device::default();
        let linear_0 = LinearConfig::new(input_size, dense_size).init(&device);
        let linear_1 = LinearConfig::new(dense_size, dense_size).init(&device);
        let linear_2 = LinearConfig::new(dense_size, output_size).init(&device);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
    /// Returns the softmax (along dimension 1) of the last layer's output.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = relu(self.linear_0.forward(input));
        let hidden = relu(self.linear_1.forward(hidden));
        let logits = self.linear_2.forward(hidden);
        softmax(logits, 1)
    }

    /// Inference is the same computation as `forward`.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.forward(input)
    }
}
// Marker implementation: the body is empty, so no trait item is overridden here.
impl<B: Backend> SACActor<B> for Actor<B> {}
/// Critic network of the SAC agent: three dense layers applied in sequence,
/// with no activation on the last one (see `forward`).
#[derive(Module, Debug)]
pub struct Critic<B: Backend> {
// `input_size -> dense_size`, followed by ReLU in `forward`.
linear_0: Linear<B>,
// `dense_size -> dense_size`, followed by ReLU in `forward`.
linear_1: Linear<B>,
// `dense_size -> output_size`, returned as-is by `forward`.
linear_2: Linear<B>,
}
impl<B: Backend> Critic<B> {
    /// Builds the three layers on the backend's default device.
    ///
    /// * `input_size` - width of the state vector.
    /// * `dense_size` - width of both hidden layers.
    /// * `output_size` - number of outputs of the last layer.
    pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
        let device = B::Device::default();
        let linear_0 = LinearConfig::new(input_size, dense_size).init(&device);
        let linear_1 = LinearConfig::new(dense_size, dense_size).init(&device);
        let linear_2 = LinearConfig::new(dense_size, output_size).init(&device);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }

    /// Takes the critic apart and hands back its layers in declaration order.
    fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
        let Self {
            linear_0,
            linear_1,
            linear_2,
        } = self;
        (linear_0, linear_1, linear_2)
    }
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
    /// Returns the raw output of the last layer (no activation applied).
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = relu(self.linear_0.forward(input));
        let hidden = relu(self.linear_1.forward(hidden));
        self.linear_2.forward(hidden)
    }

    /// Inference is the same computation as `forward`.
    fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.forward(input)
    }
}
impl<B: Backend> SACCritic<B> for Critic<B> {
    /// Applies `soft_update_linear` layer by layer, consuming `this` and
    /// reading the matching layers of `that`.
    fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
        let (own_0, own_1, own_2) = this.consume();
        let linear_0 = soft_update_linear(own_0, &that.linear_0, tau);
        let linear_1 = soft_update_linear(own_1, &that.linear_1, tau);
        let linear_2 = soft_update_linear(own_2, &that.linear_2, tau);
        Self {
            linear_0,
            linear_1,
            linear_2,
        }
    }
}
/// Capacity used as the const parameter of the `Memory` buffer and of `train` in `run`.
#[allow(unused)]
const MEMORY_SIZE: usize = 4096;
/// Shorthand for the SAC agent specialised to this module's `Actor`.
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
/// Trains a SAC agent on environment `E` and returns the validation agent.
///
/// Every environment step is pushed to the replay memory; as soon as the
/// memory holds more than `conf.batch_size` transitions the networks are
/// trained after each step. An episode ends when the environment reports
/// `done` or after `conf.max_steps` steps, at which point the environment is
/// reset and one JSON line with the episode statistics is printed.
///
/// When `conf.save_path` is set, the validation actor is saved with
/// [`save_model`].
#[allow(unused)]
pub fn run<
    E: Environment + AsMut<TrictracEnvironment>,
    B: AutodiffBackend<InnerBackend = NdArray>,
>(
    conf: &Config,
    visualized: bool,
) -> impl Agent<E> {
    let mut env = E::new(visualized);
    env.as_mut().max_steps = conf.max_steps;

    let state_dim = <<E as Environment>::StateType as State>::size();
    let action_dim = <<E as Environment>::ActionType as Action>::size();
    let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
    let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
    let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
    let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
    let mut agent = MyAgent::default();

    let config = SACTrainingConfig {
        gamma: conf.gamma,
        tau: conf.tau,
        learning_rate: conf.learning_rate,
        min_probability: conf.min_probability,
        batch_size: conf.batch_size,
        clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
            conf.clip_grad,
        )),
    };
    let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
    let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
    let mut optimizer = SACOptimizer::new(
        optimizer_config.clone().init(),
        optimizer_config.clone().init(),
        optimizer_config.clone().init(),
        optimizer_config.init(),
    );

    for episode in 0..conf.num_episodes {
        let mut episode_done = false;
        let mut episode_reward = 0.0;
        let mut episode_duration = 0_usize;
        // The environment was reset at the end of the previous episode.
        let mut state = env.state();
        let now = SystemTime::now();

        // NOTE(review): when `react_with_model` yields `None` nothing in this
        // loop advances the episode — confirm that this cannot happen.
        while !episode_done {
            if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
                let snapshot = env.step(action);
                episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
                    snapshot.reward().clone(),
                );
                memory.push(
                    state,
                    *snapshot.state(),
                    action,
                    snapshot.reward().clone(),
                    snapshot.done(),
                );
                if config.batch_size < memory.len() {
                    nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
                }
                episode_duration += 1;

                if snapshot.done() || episode_duration >= conf.max_steps {
                    env.reset();
                    episode_done = true;
                    // A clock that went backwards is reported as 0 instead of panicking.
                    println!(
                        "{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
                        now.elapsed().unwrap_or_default().as_secs()
                    );
                } else {
                    state = *snapshot.state();
                }
            }
        }
    }

    let valid_agent = agent.valid(nets.actor);
    if let Some(path) = &conf.save_path {
        if let Some(model) = valid_agent.model() {
            save_model(model, path);
        }
    }
    valid_agent
}
/// Writes `model` to `{path}.mpk` with a [`CompactRecorder`].
///
/// The confirmation line is printed only once the record has been written.
///
/// # Panics
///
/// Panics with the target path and the recorder error when the record cannot
/// be written.
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
    let model_path = format!("{path}.mpk");
    CompactRecorder::new()
        .record(model.clone().into_record(), model_path.clone().into())
        .unwrap_or_else(|err| panic!("Failed to save model to {model_path}: {err:?}"));
    println!("info: Modèle de validation sauvegardé : {model_path}");
}
/// Loads the weights stored in `{path}.mpk` into a fresh `NdArray` actor.
///
/// `dense_size` must match the hidden width used when the model was saved.
/// Returns `None` when the record cannot be loaded.
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
    let model_path = format!("{path}.mpk");
    let loaded = CompactRecorder::new().load(model_path.into(), &NdArrayDevice::default());
    match loaded {
        Ok(record) => {
            // The actor is only built once a record is available.
            let actor: Actor<NdArray<ElemType>> = Actor::new(
                <TrictracEnvironment as Environment>::StateType::size(),
                dense_size,
                <TrictracEnvironment as Environment>::ActionType::size(),
            );
            Some(actor.load_record(record))
        }
        Err(_) => None,
    }
}

View file

@ -0,0 +1,426 @@
use std::io::Write;
use crate::training_common;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
// Sentinel reward returned by `execute_action` for an invalid action; `step`
// compares against this exact value to count the valid moves.
const ERROR_REWARD: f32 = -1.0012121;
// NOTE(review): only referenced from a commented-out line in `execute_action`
// in this file — confirm whether it is still needed.
const REWARD_VALID_MOVE: f32 = 1.0012121;
// Factor applied to the difference between own and adversary dice points.
const REWARD_RATIO: f32 = 0.1;
// Added to (win) or subtracted from (loss) the reward when `step` finds a winner.
const WIN_POINTS: f32 = 100.0;
/// Trictrac game state for burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
pub data: [i8; 36], // Vector representation of the game state
}
impl State for TrictracState {
    type Data = [i8; 36];

    /// Converts the 36 state values into a rank-1 tensor on the default device.
    fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
        let device = B::Device::default();
        Tensor::from_floats(self.data, &device)
    }

    /// Number of values in the state vector.
    fn size() -> usize {
        36
    }
}
impl TrictracState {
/// Convertit un GameState en TrictracState
pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec();
let mut data = [0; 36];
// Copier les données en s'assurant qu'on ne dépasse pas la taille
let copy_len = state_vec.len().min(36);
data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
TrictracState { data }
}
}
/// Possible actions in Trictrac for burn-rl
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrictracAction {
// u32 as required by burn_rl::base::Action type
pub index: u32,
}
impl Action for TrictracAction {
    /// Draws an action index uniformly from `0..size()`.
    fn random() -> Self {
        use rand::{thread_rng, Rng};
        let upper = Self::size() as u32;
        let index = thread_rng().gen_range(0..upper);
        Self { index }
    }

    /// Lists every action, in increasing index order.
    fn enumerate() -> Vec<Self> {
        let upper = Self::size() as u32;
        let mut actions = Vec::with_capacity(Self::size());
        for index in 0..upper {
            actions.push(Self { index });
        }
        actions
    }

    /// Number of distinct action indices.
    fn size() -> usize {
        514
    }
}
impl From<u32> for TrictracAction {
fn from(index: u32) -> Self {
TrictracAction { index }
}
}
impl From<TrictracAction> for u32 {
fn from(action: TrictracAction) -> u32 {
action.index
}
}
/// Trictrac environment for burn-rl
#[derive(Debug)]
pub struct TrictracEnvironment {
// Underlying game being played.
pub game: GameState,
// Id of the trained agent (set to 1 in `new`).
active_player_id: PlayerId,
// Id of the automated opponent (set to 2 in `new`).
opponent_id: PlayerId,
// State snapshot refreshed at the end of `step` and in `reset`.
current_state: TrictracState,
// Sum of the rewards returned by `step` since the last reset.
episode_reward: f32,
// Number of `step` calls since the last reset.
pub step_count: usize,
// Highest `goodmoves_ratio` seen across resets.
pub best_ratio: f32,
// `step` reports the episode as terminated once `step_count` reaches this value.
pub max_steps: usize,
// Number of agent rolls that scored points since the last reset.
pub pointrolls_count: usize,
// Number of agent actions whose reward was not `ERROR_REWARD` since the last reset.
pub goodmoves_count: usize,
// `goodmoves_count / step_count` of the previous episode, computed in `reset`.
pub goodmoves_ratio: f32,
// When true, `step` prints a summary line when the episode terminates.
pub visualized: bool,
}
impl Environment for TrictracEnvironment {
type StateType = TrictracState;
type ActionType = TrictracAction;
type RewardType = f32;
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
// Ajouter deux joueurs
game.init_player("DQN Agent");
game.init_player("Opponent");
let player1_id = 1;
let player2_id = 2;
// Commencer la partie
game.consume(&GameEvent::BeginGame { goes_first: 1 });
let current_state = TrictracState::from_game_state(&game);
TrictracEnvironment {
game,
active_player_id: player1_id,
opponent_id: player2_id,
current_state,
episode_reward: 0.0,
step_count: 0,
best_ratio: 0.0,
max_steps: 2000,
pointrolls_count: 0,
goodmoves_count: 0,
goodmoves_ratio: 0.0,
visualized,
}
}
fn state(&self) -> Self::StateType {
self.current_state
}
fn reset(&mut self) -> Snapshot<Self> {
// Réinitialiser le jeu
let history = self.game.history.clone();
self.game = GameState::new(false);
self.game.init_player("DQN Agent");
self.game.init_player("Opponent");
// Commencer la partie
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.goodmoves_ratio = if self.step_count == 0 {
0.0
} else {
self.goodmoves_count as f32 / self.step_count as f32
};
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
let path = "bot/models/logs/debug.log";
if let Ok(mut out) = std::fs::File::create(path) {
write!(out, "{history:?}").expect("could not write history log");
}
"!!!!"
} else {
""
};
// println!(
// "info: correct moves: {} ({}%) {}",
// self.goodmoves_count,
// (100.0 * self.goodmoves_ratio).round() as u32,
// warning
// );
self.step_count = 0;
self.pointrolls_count = 0;
self.goodmoves_count = 0;
Snapshot::new(self.current_state, 0.0, false)
}
fn step(&mut self, action: Self::ActionType) -> Snapshot<Self> {
self.step_count += 1;
// Convertir l'action burn-rl vers une action Trictrac
let trictrac_action = Self::convert_action(action);
let mut reward = 0.0;
let is_rollpoint;
// Exécuter l'action si c'est le tour de l'agent DQN
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
(reward, is_rollpoint) = self.execute_action(action);
if is_rollpoint {
self.pointrolls_count += 1;
}
if reward != ERROR_REWARD {
self.goodmoves_count += 1;
}
} else {
// Action non convertible, pénalité
panic!("action non convertible");
//reward = -0.5;
}
}
// Faire jouer l'adversaire (stratégie simple)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
reward += self.play_opponent_if_needed();
}
// Vérifier si la partie est terminée
// let max_steps = self.max_steps;
// let max_steps = self.min_steps
// + (self.max_steps as f32 - self.min_steps)
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
reward += WIN_POINTS; // Victoire
} else {
reward -= WIN_POINTS; // Défaite
}
}
}
let terminated = done || self.step_count >= self.max_steps;
// let terminated = done || self.step_count >= max_steps.round() as usize;
// Mettre à jour l'état
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward += reward;
if self.visualized && terminated {
println!(
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
self.episode_reward, self.step_count
);
}
Snapshot::new(self.current_state, reward, terminated)
}
}
impl TrictracEnvironment {
/// Converts a burn-rl action into a Trictrac action.
///
/// Delegates to `training_common::TrictracAction::from_action_index`; the
/// `unwrap` panics if the index conversion fails.
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Converts an index within the currently valid actions into a Trictrac action.
///
/// Returns `None` when there is no valid action; otherwise the index is
/// wrapped with a modulo over the number of valid actions.
#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
game_state: &GameState,
) -> Option<training_common::TrictracAction> {
use training_common::get_valid_actions;
// Get the valid actions in the current context
let valid_actions = get_valid_actions(game_state);
if valid_actions.is_empty() {
return None;
}
// Map the action index onto a valid action
let action_index = (action.index as usize) % valid_actions.len();
Some(valid_actions[action_index].clone())
}
/// Executes a Trictrac action in the game.
///
/// Returns `(reward, is_rollpoint)`: the reward earned by the action and
/// whether it was a roll that scored points for the agent. An action that
/// has no event or does not validate yields `ERROR_REWARD` and marks one
/// point for the opponent.
// fn execute_action(
// &mut self,
// action: training_common::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
use training_common::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;
// Apply the event if it is valid
if let Some(event) = action.to_event(&self.game) {
if self.game.validate(&event) {
self.game.consume(&event);
// reward += REWARD_VALID_MOVE;
// Simulate the dice result after a Roll
if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.active_player_id,
dice: store::Dice {
values: dice_values,
},
};
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
}
// Reward proportional to the points
}
}
} else {
// Penalty for an invalid action:
// previous rewards are discarded
// and a recognisable value is returned for statistics
reward = ERROR_REWARD;
self.game.mark_points_for_bot_training(self.opponent_id, 1);
}
} else {
// No event could be built for this action: same penalty as above.
reward = ERROR_REWARD;
self.game.mark_points_for_bot_training(self.opponent_id, 1);
}
(reward, is_rollpoint)
}
/// Makes the opponent play one event with a simple strategy.
///
/// Returns the reward contribution for the agent: after an opponent dice
/// roll, the scaled difference of the opponent's points is subtracted;
/// otherwise 0.
fn play_opponent_if_needed(&mut self) -> f32 {
let mut reward = 0.0;
// If it is the opponent's turn, play automatically
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// Use the random strategy for the opponent
use crate::BotStrategy;
let mut strategy = crate::strategy::random::RandomStrategy::default();
strategy.set_player_id(self.opponent_id);
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
strategy.set_color(color);
}
*strategy.get_mut_game() = self.game.clone();
// Pick the event according to the turn stage
let mut calculate_points = false;
// NOTE(review): the colour is hard-coded here while the strategy above
// uses `player_color_by_id` — confirm the opponent is always Black.
let opponent_color = store::Color::Black;
let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id,
},
TurnStage::RollWaiting => {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true;
GameEvent::RollResult {
player_id: self.opponent_id,
dice: store::Dice {
values: dice_values,
},
}
}
TurnStage::MarkPoints => {
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
GameEvent::Mark {
player_id: self.opponent_id,
points: points_rules.get_points(dice_roll_count).0,
}
}
TurnStage::MarkAdvPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
// no reward: already accounted for during white's turn
GameEvent::Mark {
player_id: self.opponent_id,
points: points_rules.get_points(dice_roll_count).1,
}
}
TurnStage::HoldOrGoChoice => {
// Simple strategy: always keep going
GameEvent::Go {
player_id: self.opponent_id,
}
}
TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id,
moves: strategy.choose_move(),
},
};
if self.game.validate(&event) {
self.game.consume(&event);
if calculate_points {
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Reward proportional to the points
reward -= REWARD_RATIO * (points as f32 - adv_points as f32);
}
}
}
reward
}
}
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
    /// Identity conversion, so generic training code can reach the concrete
    /// environment through `AsMut`.
    fn as_mut(&mut self) -> &mut TrictracEnvironment {
        self
    }
}

View file

@ -1,17 +1,20 @@
use crate::dqn::dqn_common; use crate::training_common;
use burn::{prelude::Backend, tensor::Tensor}; use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State}; use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -1.0012121;
const REWARD_RATIO: f32 = 0.1;
/// État du jeu Trictrac pour burn-rl /// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct TrictracState { pub struct TrictracState {
pub data: [f32; 36], // Représentation vectorielle de l'état du jeu pub data: [i8; 36], // Représentation vectorielle de l'état du jeu
} }
impl State for TrictracState { impl State for TrictracState {
type Data = [f32; 36]; type Data = [i8; 36];
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> { fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
Tensor::from_floats(self.data, &B::Device::default()) Tensor::from_floats(self.data, &B::Device::default())
@ -25,8 +28,8 @@ impl State for TrictracState {
impl TrictracState { impl TrictracState {
/// Convertit un GameState en TrictracState /// Convertit un GameState en TrictracState
pub fn from_game_state(game_state: &GameState) -> Self { pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec_float(); let state_vec = game_state.to_vec();
let mut data = [0.0; 36]; let mut data = [0; 36];
// Copier les données en s'assurant qu'on ne dépasse pas la taille // Copier les données en s'assurant qu'on ne dépasse pas la taille
let copy_len = state_vec.len().min(36); let copy_len = state_vec.len().min(36);
@ -39,6 +42,7 @@ impl TrictracState {
/// Actions possibles dans Trictrac pour burn-rl /// Actions possibles dans Trictrac pour burn-rl
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrictracAction { pub struct TrictracAction {
// u32 as required by burn_rl::base::Action type
pub index: u32, pub index: u32,
} }
@ -58,7 +62,9 @@ impl Action for TrictracAction {
} }
fn size() -> usize { fn size() -> usize {
1252 // état avec le plus de choix : mouvement
// choix premier dé : 16 (15 pions + aucun pion), choix deuxième dé 16, x2 ordre dé
64
} }
} }
@ -82,7 +88,9 @@ pub struct TrictracEnvironment {
opponent_id: PlayerId, opponent_id: PlayerId,
current_state: TrictracState, current_state: TrictracState,
episode_reward: f32, episode_reward: f32,
step_count: usize, pub step_count: usize,
pub max_steps: usize,
pub pointrolls_count: usize,
pub visualized: bool, pub visualized: bool,
} }
@ -91,8 +99,6 @@ impl Environment for TrictracEnvironment {
type ActionType = TrictracAction; type ActionType = TrictracAction;
type RewardType = f32; type RewardType = f32;
const MAX_STEPS: usize = 700; // Limite max pour éviter les parties infinies
fn new(visualized: bool) -> Self { fn new(visualized: bool) -> Self {
let mut game = GameState::new(false); let mut game = GameState::new(false);
@ -113,6 +119,8 @@ impl Environment for TrictracEnvironment {
current_state, current_state,
episode_reward: 0.0, episode_reward: 0.0,
step_count: 0, step_count: 0,
max_steps: 2000,
pointrolls_count: 0,
visualized, visualized,
} }
} }
@ -133,6 +141,7 @@ impl Environment for TrictracEnvironment {
self.current_state = TrictracState::from_game_state(&self.game); self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0; self.episode_reward = 0.0;
self.step_count = 0; self.step_count = 0;
self.pointrolls_count = 0;
Snapshot::new(self.current_state, 0.0, false) Snapshot::new(self.current_state, 0.0, false)
} }
@ -141,50 +150,52 @@ impl Environment for TrictracEnvironment {
self.step_count += 1; self.step_count += 1;
// Convertir l'action burn-rl vers une action Trictrac // Convertir l'action burn-rl vers une action Trictrac
let trictrac_action = self.convert_action(action, &self.game); // let trictrac_action = Self::convert_action(action);
let trictrac_action = self.convert_valid_action_index(action);
let mut reward = 0.0; let mut reward = 0.0;
let mut terminated = false; let is_rollpoint: bool;
// Exécuter l'action si c'est le tour de l'agent DQN // Exécuter l'action si c'est le tour de l'agent DQN
if self.game.active_player_id == self.active_player_id { if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action { if let Some(action) = trictrac_action {
match self.execute_action(action) { (reward, is_rollpoint) = self.execute_action(action);
Ok(action_reward) => { // if reward != 0.0 {
reward = action_reward; // println!("info: self rew {reward}");
} // }
Err(_) => { if is_rollpoint {
// Action invalide, pénalité self.pointrolls_count += 1;
reward = -1.0;
}
} }
} else { } else {
// Action non convertible, pénalité // Action non convertible, pénalité
reward = -0.5; println!("info: action non convertible -> -1 {trictrac_action:?}");
reward = -1.0;
} }
} }
// Faire jouer l'adversaire (stratégie simple) // Faire jouer l'adversaire (stratégie simple)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// let op_rew = self.play_opponent_if_needed();
// if op_rew != 0.0 {
// println!("info: op rew {op_rew}");
// }
// reward += op_rew;
reward += self.play_opponent_if_needed(); reward += self.play_opponent_if_needed();
} }
// Vérifier si la partie est terminée // Vérifier si la partie est terminée
let done = self.game.stage == Stage::Ended let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|| self.game.determine_winner().is_some()
|| self.step_count >= Self::MAX_STEPS;
if done { if done {
terminated = true;
// Récompense finale basée sur le résultat // Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() { if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id { if winner_id == self.active_player_id {
reward += 50.0; // Victoire reward += 100.0; // Victoire
} else { } else {
reward -= 25.0; // Défaite reward -= 100.0; // Défaite
} }
} }
} }
let terminated = done || self.step_count >= self.max_steps;
// Mettre à jour l'état // Mettre à jour l'état
self.current_state = TrictracState::from_game_state(&self.game); self.current_state = TrictracState::from_game_state(&self.game);
@ -202,25 +213,23 @@ impl Environment for TrictracEnvironment {
} }
impl TrictracEnvironment { impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac /// Convertit une action burn-rl vers une action Trictrac
fn convert_action( pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
&self, training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
action: TrictracAction,
game_state: &GameState,
) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
} }
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac /// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
fn convert_valid_action_index( fn convert_valid_action_index(
&self, &self,
action: TrictracAction, action: TrictracAction,
game_state: &GameState, ) -> Option<training_common::TrictracAction> {
) -> Option<dqn_common::TrictracAction> { use training_common::get_valid_actions;
use dqn_common::get_valid_actions;
// Obtenir les actions valides dans le contexte actuel // Obtenir les actions valides dans le contexte actuel
let valid_actions = get_valid_actions(game_state); let valid_actions = get_valid_actions(&self.game);
if valid_actions.is_empty() { if valid_actions.is_empty() {
return None; return None;
@ -232,75 +241,21 @@ impl TrictracEnvironment {
} }
/// Exécute une action Trictrac dans le jeu /// Exécute une action Trictrac dans le jeu
fn execute_action( // fn execute_action(
&mut self, // &mut self,
action: dqn_common::TrictracAction, // action: training_common::TrictracAction,
) -> Result<f32, Box<dyn std::error::Error>> { // ) -> Result<f32, Box<dyn std::error::Error>> {
use dqn_common::TrictracAction; fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
use training_common::TrictracAction;
let mut reward = 0.0; let mut reward = 0.0;
let mut is_rollpoint = false;
let event = match action {
TrictracAction::Roll => {
// Lancer les dés
reward += 0.1;
Some(GameEvent::Roll {
player_id: self.active_player_id,
})
}
// TrictracAction::Mark => {
// // Marquer des points
// let points = self.game.
// reward += 0.1 * points as f32;
// Some(GameEvent::Mark {
// player_id: self.active_player_id,
// points,
// })
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
}
TrictracAction::Move {
dice_order,
from1,
from2,
} => {
// Effectuer un mouvement
let (dice1, dice2) = if dice_order {
(self.game.dice.values.0, self.game.dice.values.1)
} else {
(self.game.dice.values.1, self.game.dice.values.0)
};
let mut to1 = from1 + dice1 as usize;
let mut to2 = from2 + dice2 as usize;
// Gestion prise de coin par puissance
let opp_rest_field = 13;
if to1 == opp_rest_field && to2 == opp_rest_field {
to1 -= 1;
to2 -= 1;
}
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
})
}
};
// Appliquer l'événement si valide // Appliquer l'événement si valide
if let Some(event) = event { if let Some(event) = action.to_event(&self.game) {
if self.game.validate(&event) { if self.game.validate(&event) {
self.game.consume(&event); self.game.consume(&event);
// reward += REWARD_VALID_MOVE;
// Simuler le résultat des dés après un Roll // Simuler le résultat des dés après un Roll
if matches!(action, TrictracAction::Roll) { if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng(); let mut rng = thread_rng();
@ -314,16 +269,27 @@ impl TrictracEnvironment {
if self.game.validate(&dice_event) { if self.game.validate(&dice_event) {
self.game.consume(&dice_event); self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points; let (points, adv_points) = self.game.dice_points;
reward += 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points reward += REWARD_RATIO * (points as f32 - adv_points as f32);
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
}
// Récompense proportionnelle aux points
} }
} }
} else { } else {
// Pénalité pour action invalide // Pénalité pour action invalide
reward -= 2.0; // on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
reward = ERROR_REWARD;
self.game.mark_points_for_bot_training(self.opponent_id, 1);
} }
} else {
reward = ERROR_REWARD;
self.game.mark_points_for_bot_training(self.opponent_id, 1);
} }
Ok(reward) (reward, is_rollpoint)
} }
/// Fait jouer l'adversaire avec une stratégie simple /// Fait jouer l'adversaire avec une stratégie simple
@ -333,17 +299,18 @@ impl TrictracEnvironment {
// Si c'est le tour de l'adversaire, jouer automatiquement // Si c'est le tour de l'adversaire, jouer automatiquement
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended { if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// Utiliser la stratégie default pour l'adversaire // Utiliser la stratégie default pour l'adversaire
use crate::strategy::default::DefaultStrategy;
use crate::BotStrategy; use crate::BotStrategy;
let mut default_strategy = DefaultStrategy::default(); let mut strategy = crate::strategy::random::RandomStrategy::default();
default_strategy.set_player_id(self.opponent_id); strategy.set_player_id(self.opponent_id);
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) { if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
default_strategy.set_color(color); strategy.set_color(color);
} }
*default_strategy.get_mut_game() = self.game.clone(); *strategy.get_mut_game() = self.game.clone();
// Exécuter l'action selon le turn_stage // Exécuter l'action selon le turn_stage
let mut calculate_points = false;
let opponent_color = store::Color::Black;
let event = match self.game.turn_stage { let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll { TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id, player_id: self.opponent_id,
@ -351,6 +318,7 @@ impl TrictracEnvironment {
TurnStage::RollWaiting => { TurnStage::RollWaiting => {
let mut rng = thread_rng(); let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6)); let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true;
GameEvent::RollResult { GameEvent::RollResult {
player_id: self.opponent_id, player_id: self.opponent_id,
dice: store::Dice { dice: store::Dice {
@ -359,7 +327,6 @@ impl TrictracEnvironment {
} }
} }
TurnStage::MarkPoints => { TurnStage::MarkPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self let dice_roll_count = self
.game .game
.players .players
@ -368,16 +335,12 @@ impl TrictracEnvironment {
.dice_roll_count; .dice_roll_count;
let points_rules = let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice); PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark { GameEvent::Mark {
player_id: self.opponent_id, player_id: self.opponent_id,
points, points: points_rules.get_points(dice_roll_count).0,
} }
} }
TurnStage::MarkAdvPoints => { TurnStage::MarkAdvPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self let dice_roll_count = self
.game .game
.players .players
@ -401,14 +364,33 @@ impl TrictracEnvironment {
} }
TurnStage::Move => GameEvent::Move { TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id, player_id: self.opponent_id,
moves: default_strategy.choose_move(), moves: strategy.choose_move(),
}, },
}; };
if self.game.validate(&event) { if self.game.validate(&event) {
self.game.consume(&event); self.game.consume(&event);
if calculate_points {
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
// Récompense proportionnelle aux points
}
} }
} }
reward reward
} }
} }
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
fn as_mut(&mut self) -> &mut Self {
self
}
}

90
bot/src/burnrl/main.rs Normal file
View file

@ -0,0 +1,90 @@
use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid};
use bot::burnrl::environment::TrictracEnvironment;
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
use bot::burnrl::utils::{demo_model, Config};
use burn::backend::{Autodiff, NdArray};
use burn_rl::base::ElemType;
use std::env;
type Backend = Autodiff<NdArray<ElemType>>;
fn main() {
let args: Vec<String> = env::args().collect();
let algo = &args[1];
// let dir_path = &args[2];
let path = format!("bot/models/burnrl_{algo}");
println!(
"info: loading configuration from file {:?}",
confy::get_configuration_file_path("trictrac_bot", None).unwrap()
);
let mut conf: Config = confy::load("trictrac_bot", None).expect("Could not load config");
conf.save_path = Some(path.clone());
println!("{conf}----------");
match algo.as_str() {
"dqn" => {
let _agent = dqn::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = dqn::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
burn_rl::agent::DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"dqn_valid" => {
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = dqn_valid::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentValid, _, _> =
burn_rl::agent::DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"sac" => {
let _agent = sac::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = sac::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
burn_rl::agent::SAC::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"sac_valid" => {
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = sac_valid::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentValid, _, _> =
burn_rl::agent::SAC::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"ppo" => {
let _agent = ppo::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"ppo_valid" => {
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo_valid::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentValid, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
&_ => {
println!("unknown algo {algo}");
}
}
}

4
bot/src/burnrl/mod.rs Normal file
View file

@ -0,0 +1,4 @@
pub mod algos;
pub mod environment;
pub mod environment_valid;
pub mod utils;

132
bot/src/burnrl/utils.rs Normal file
View file

@ -0,0 +1,132 @@
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn_rl::base::{Agent, ElemType, Environment};
use serde::{Deserialize, Serialize};
/// Training hyper-parameters shared by the DQN, SAC and PPO runners.
/// (De)serializable through serde so it can be stored in a configuration file.
#[derive(Serialize, Deserialize)]
pub struct Config {
// Optional path for the saved model (presumably a path prefix; confirm against the runners).
pub save_path: Option<String>,
pub max_steps: usize, // max steps by episode
pub num_episodes: usize,
pub dense_size: usize, // neural network complexity
// discount factor. Higher = favours long-term strategies
pub gamma: f32,
// soft update rate of the target network. Lower = slower adaptation, less sensitive to lucky outcomes
pub tau: f32,
// step size. Low: slower learning; high: risk of never ... (original note is cut off here, presumably "converging")
pub learning_rate: f32,
// number of past experiences used to compute the mean error
pub batch_size: usize,
// upper bound on the correction applied to the gradient (default 100)
pub clip_grad: f32,
// ---- for SAC
pub min_probability: f32,
// ---- for DQN
// epsilon initial value (0.9 => more exploration)
pub eps_start: f64,
pub eps_end: f64,
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
pub eps_decay: f64,
// ---- for PPO
pub lambda: f32,
pub epsilon_clip: f32,
pub critic_weight: f32,
pub entropy_weight: f32,
pub epochs: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
save_path: None,
max_steps: 2000,
num_episodes: 1000,
dense_size: 256,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
min_probability: 1e-9,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
lambda: 0.95,
epsilon_clip: 0.2,
critic_weight: 0.5,
entropy_weight: 0.01,
epochs: 8,
}
}
}
impl std::fmt::Display for Config {
    /// Writes one `name=value` line per hyper-parameter, each terminated by a
    /// newline. `save_path` is not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f, "max_steps={:?}", self.max_steps)?;
        writeln!(f, "num_episodes={:?}", self.num_episodes)?;
        writeln!(f, "dense_size={:?}", self.dense_size)?;
        writeln!(f, "eps_start={:?}", self.eps_start)?;
        writeln!(f, "eps_end={:?}", self.eps_end)?;
        writeln!(f, "eps_decay={:?}", self.eps_decay)?;
        writeln!(f, "gamma={:?}", self.gamma)?;
        writeln!(f, "tau={:?}", self.tau)?;
        writeln!(f, "learning_rate={:?}", self.learning_rate)?;
        writeln!(f, "batch_size={:?}", self.batch_size)?;
        writeln!(f, "clip_grad={:?}", self.clip_grad)?;
        writeln!(f, "min_probability={:?}", self.min_probability)?;
        writeln!(f, "lambda={:?}", self.lambda)?;
        writeln!(f, "epsilon_clip={:?}", self.epsilon_clip)?;
        writeln!(f, "critic_weight={:?}", self.critic_weight)?;
        writeln!(f, "entropy_weight={:?}", self.entropy_weight)?;
        writeln!(f, "epochs={:?}", self.epochs)
    }
}
/// Plays a single episode of environment `E` with `agent`, in visualized mode.
///
/// The environment is created with `E::new(true)` and stepped with the action
/// returned by `agent.react` until the returned snapshot reports `done`.
///
/// The episode also stops as soon as the agent returns no action: in that case
/// neither the state nor the `done` flag can change, so continuing to poll the
/// agent could loop forever.
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
    let mut env = E::new(true);
    let mut state = env.state();
    loop {
        let action = match agent.react(&state) {
            Some(action) => action,
            // No action for this state: nothing would ever change, so stop here.
            None => break,
        };
        let snapshot = env.step(action);
        if snapshot.done() {
            break;
        }
        state = *snapshot.state();
    }
}
/// Blends two parameters: the result holds `this * (1 - tau) + that * tau`,
/// wrapped in a parameter with a freshly generated id.
fn soft_update_tensor<const N: usize, B: Backend>(
    this: &Param<Tensor<B, N>>,
    that: &Param<Tensor<B, N>>,
    tau: ElemType,
) -> Param<Tensor<B, N>> {
    let keep = 1.0 - tau;
    let blended = this.val() * keep + that.val() * tau;
    Param::initialized(ParamId::new(), blended)
}
/// Soft-updates a linear layer: weight and bias of `this` are each blended with
/// those of `that` using rate `tau`.
///
/// The resulting layer has a bias only when both layers have one.
pub fn soft_update_linear<B: Backend>(
    this: Linear<B>,
    that: &Linear<B>,
    tau: ElemType,
) -> Linear<B> {
    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
    let bias = this
        .bias
        .as_ref()
        .zip(that.bias.as_ref())
        .map(|(own_bias, other_bias)| soft_update_tensor(own_bias, other_bias, tau));
    Linear::<B> { weight, bias }
}

View file

@ -1,68 +0,0 @@
use bot::dqn::burnrl::{dqn_model, environment, utils::demo_model};
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::Module;
use burn::record::{CompactRecorder, Recorder};
use burn_rl::agent::DQN;
use burn_rl::base::{Action, Agent, ElemType, Environment, State};
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
// Trains a DQN agent with a hard-coded configuration, saves the validation
// model under `models/burn_dqn_50`, reloads it and plays one demo episode.
fn main() {
// println!("> Training");
let conf = dqn_model::DqnConfig {
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 3000.0,
};
// The second argument is a boolean flag, off here (`true` left as a disabled alternative).
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_50".to_string();
// Panics if the validation agent holds no model.
save_model(valid_agent.model().as_ref().unwrap(), &path);
// Disabled: demo with the in-memory trained agent.
// println!("> Test avec le modèle entraîné");
// demo_model::<Env>(valid_agent);
// The demo runs with the model reloaded from disk.
println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model);
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
/// Records `model` with a `CompactRecorder` to `<path>_model.mpk`.
///
/// # Panics
///
/// Panics if the recorder fails to write the record.
fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
    let recorder = CompactRecorder::new();
    let model_path = format!("{}_model.mpk", path);
    recorder
        .record(model.clone().into_record(), model_path.clone().into())
        .unwrap();
    // Announce the save only once the record has actually been written, so the
    // message is never printed for a write that then fails.
    println!("Modèle de validation sauvegardé : {}", model_path);
}
/// Rebuilds a network from the record stored at `<path>_model.mpk`.
///
/// The network is sized from the environment's state size, `dense_size` and the
/// environment's action size before the record is loaded into it.
///
/// # Panics
///
/// Panics when the record cannot be loaded.
fn load_model(dense_size: usize, path: &String) -> dqn_model::Net<NdArray<ElemType>> {
    let model_path = format!("{}_model.mpk", path);
    println!("Chargement du modèle depuis : {}", model_path);

    let record = CompactRecorder::new()
        .load(model_path.into(), &NdArrayDevice::default())
        .expect("Impossible de charger le modèle");

    let input_size = <environment::TrictracEnvironment as Environment>::StateType::size();
    let output_size = <environment::TrictracEnvironment as Environment>::ActionType::size();
    dqn_model::Net::new(input_size, dense_size, output_size).load_record(record)
}

View file

@ -1,3 +0,0 @@
pub mod dqn_model;
pub mod environment;
pub mod utils;

View file

@ -1,83 +0,0 @@
use crate::dqn::burnrl::environment::{TrictracAction, TrictracEnvironment};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{ElemType, Environment, State};
/// Plays one visualized episode with `agent`, choosing each action through
/// `infer_action`. Stops when no action can be inferred or when the
/// environment reports the episode as done.
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
    let mut env = TrictracEnvironment::new(true);
    while let Some(action) = infer_action(&agent, &env) {
        // Execute the action and stop on the terminal snapshot.
        if env.step(action).done() {
            break;
        }
    }
}
/// Picks the valid action with the highest q-value for the environment's
/// current state.
///
/// Returns `None` when the game offers no valid action, or when none of the
/// valid indices falls inside the q-value vector.
///
/// The best action is searched directly among the valid indices. Masking the
/// q-value tensor by comparing values would also discard a valid action whose
/// q-value happens to equal that of an invalid one.
///
/// # Panics
///
/// Panics if the agent has no model, or if the q-values cannot be read as `f32`.
fn infer_action<B: Backend, M: DQNModel<B>>(
    agent: &DQN<TrictracEnvironment, B, M>,
    env: &TrictracEnvironment,
) -> Option<TrictracAction> {
    let state = env.state();
    // Get q-values
    let q_values = agent
        .model()
        .as_ref()
        .unwrap()
        .infer(state.to_tensor().unsqueeze());
    // Get valid actions
    let valid_actions_indices = get_valid_action_indices(&env.game);
    if valid_actions_indices.is_empty() {
        return None; // No valid actions, end of episode
    }
    let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();

    // Keep the first valid index holding the highest q-value; a NaN candidate
    // is replaced by any later comparable value.
    let mut best: Option<(usize, f32)> = None;
    for &index in valid_actions_indices.iter() {
        if let Some(&q_value) = q_values_vec.get(index) {
            let is_better = match best {
                Some((_, best_q)) => q_value > best_q || best_q.is_nan(),
                None => true,
            };
            if is_better {
                best = Some((index, q_value));
            }
        }
    }

    let (action_index, _) = best?;
    let action_index = u32::try_from(action_index).ok()?;
    Some(TrictracAction::from(action_index))
}
/// Blends two parameters: the result holds `this * (1 - tau) + that * tau`,
/// wrapped in a parameter with a freshly generated id.
fn soft_update_tensor<const N: usize, B: Backend>(
    this: &Param<Tensor<B, N>>,
    that: &Param<Tensor<B, N>>,
    tau: ElemType,
) -> Param<Tensor<B, N>> {
    let keep = 1.0 - tau;
    let blended = this.val() * keep + that.val() * tau;
    Param::initialized(ParamId::new(), blended)
}
/// Soft-updates a linear layer: weight and bias of `this` are each blended with
/// those of `that` using rate `tau`.
///
/// The resulting layer has a bias only when both layers have one.
pub fn soft_update_linear<B: Backend>(
    this: Linear<B>,
    that: &Linear<B>,
    tau: ElemType,
) -> Linear<B> {
    let weight = soft_update_tensor(&this.weight, &that.weight, tau);
    let bias = this
        .bias
        .as_ref()
        .zip(that.bias.as_ref())
        .map(|(own_bias, other_bias)| soft_update_tensor(own_bias, other_bias, tau));
    Linear::<B> { weight, bias }
}

View file

@ -1,3 +0,0 @@
pub mod dqn_common;
pub mod simple;
pub mod burnrl;

View file

@ -1,489 +0,0 @@
use crate::{CheckerMove, Color, GameState, PlayerId};
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use crate::dqn::dqn_common::{get_valid_actions, DqnConfig, SimpleNeuralNetwork, TrictracAction};
/// One transition stored in the replay buffer: the state before the action,
/// the action taken, the reward received, the resulting state and whether the
/// episode ended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
pub state: Vec<f32>,
pub action: TrictracAction,
pub reward: f32,
pub next_state: Vec<f32>,
pub done: bool,
}
/// Replay buffer storing experiences, bounded by `capacity`.
#[derive(Debug)]
pub struct ReplayBuffer {
// Stored experiences, oldest at the front.
buffer: VecDeque<Experience>,
// Size threshold at which the oldest experience is dropped before a push.
capacity: usize,
}
impl ReplayBuffer {
    /// Creates an empty buffer pre-allocated for `capacity` experiences.
    pub fn new(capacity: usize) -> Self {
        let buffer = VecDeque::with_capacity(capacity);
        Self { buffer, capacity }
    }

    /// Appends `experience`, dropping the oldest entry first when the buffer
    /// already holds `capacity` experiences or more.
    pub fn push(&mut self, experience: Experience) {
        let is_full = self.buffer.len() >= self.capacity;
        if is_full {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Draws `batch_size` experiences uniformly at random, with replacement.
    ///
    /// When fewer than `batch_size` experiences are stored, all of them are
    /// returned in insertion order instead.
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let stored = self.buffer.len();
        if stored < batch_size {
            return self.buffer.iter().cloned().collect();
        }
        let mut rng = thread_rng();
        (0..batch_size)
            .map(|_| self.buffer[rng.gen_range(0..stored)].clone())
            .collect()
    }

    /// Number of experiences currently stored.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }
}
/// DQN agent for reinforcement learning.
#[derive(Debug)]
pub struct DqnAgent {
config: DqnConfig,
// Network queried to choose actions.
model: SimpleNeuralNetwork,
// Copy of `model`, refreshed periodically by `train`.
target_model: SimpleNeuralNetwork,
replay_buffer: ReplayBuffer,
// Current exploration probability.
epsilon: f64,
// Number of `train` calls that performed an update.
step_count: usize,
}
impl DqnAgent {
// Builds an agent from `config`: a fresh network, a target network cloned
// from it, an empty replay buffer and the initial epsilon.
pub fn new(config: DqnConfig) -> Self {
let model =
SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
let target_model = model.clone();
let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
let epsilon = config.epsilon;
Self {
config,
model,
target_model,
replay_buffer,
epsilon,
step_count: 0,
}
}
// Epsilon-greedy choice restricted to the actions valid in `game_state`.
// `state` is the input vector given to the model.
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
let valid_actions = get_valid_actions(game_state);
if valid_actions.is_empty() {
// Fallback when there is no valid action
return TrictracAction::Roll;
}
let mut rng = thread_rng();
if rng.gen::<f64>() < self.epsilon {
// Exploration: random valid action
valid_actions
.choose(&mut rng)
.cloned()
.unwrap_or(TrictracAction::Roll)
} else {
// Exploitation: best valid action according to the model
let q_values = self.model.forward(state);
let mut best_action = &valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for action in &valid_actions {
let action_index = action.to_action_index();
// Actions whose index is outside the model output are ignored
if action_index < q_values.len() {
let q_value = q_values[action_index];
if q_value > best_q_value {
best_q_value = q_value;
best_action = action;
}
}
}
best_action.clone()
}
}
// Adds `experience` to the replay buffer.
pub fn store_experience(&mut self, experience: Experience) {
self.replay_buffer.push(experience);
}
// One training step; does nothing until the buffer holds at least
// `batch_size` experiences.
pub fn train(&mut self) {
if self.replay_buffer.len() < self.config.batch_size {
return;
}
// For now, training is only simulated by updating epsilon
// A complete implementation would run backpropagation here
self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
self.step_count += 1;
// Refresh the target model every 100 steps
if self.step_count % 100 == 0 {
self.target_model = self.model.clone();
}
}
// Saves the model to `path` by delegating to the network's `save`.
pub fn save_model<P: AsRef<std::path::Path>>(
&self,
path: P,
) -> Result<(), Box<dyn std::error::Error>> {
self.model.save(path)
}
// Current exploration probability.
pub fn get_epsilon(&self) -> f64 {
self.epsilon
}
// Number of training steps performed so far.
pub fn get_step_count(&self) -> usize {
self.step_count
}
}
/// Trictrac environment used for training.
#[derive(Debug)]
pub struct TrictracEnv {
pub game_state: GameState,
pub agent_player_id: PlayerId,
pub opponent_player_id: PlayerId,
pub agent_color: Color,
// Step count at which `step` reports the episode as done.
pub max_steps: usize,
// Steps played since the last `reset`.
pub current_step: usize,
}
impl Default for TrictracEnv {
    /// Builds an environment around a new game with the two players "agent"
    /// and "opponent" registered in that order, the agent being player 1 and
    /// white, with a 1000-step limit.
    fn default() -> Self {
        let mut game = GameState::new(false);
        for name in ["agent", "opponent"] {
            game.init_player(name);
        }
        Self {
            game_state: game,
            agent_color: Color::White,
            agent_player_id: 1,
            opponent_player_id: 2,
            current_step: 0,
            max_steps: 1000,
        }
    }
}
impl TrictracEnv {
pub fn reset(&mut self) -> Vec<f32> {
self.game_state = GameState::new(false);
self.game_state.init_player("agent");
self.game_state.init_player("opponent");
// Commencer la partie
self.game_state.consume(&GameEvent::BeginGame {
goes_first: self.agent_player_id,
});
self.current_step = 0;
self.game_state.to_vec_float()
}
pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
let mut reward = 0.0;
// Appliquer l'action de l'agent
if self.game_state.active_player_id == self.agent_player_id {
reward += self.apply_agent_action(action);
}
// Faire jouer l'adversaire (stratégie simple)
while self.game_state.active_player_id == self.opponent_player_id
&& self.game_state.stage != Stage::Ended
{
reward += self.play_opponent_turn();
}
// Vérifier si la partie est terminée
let done = self.game_state.stage == Stage::Ended
|| self.game_state.determine_winner().is_some()
|| self.current_step >= self.max_steps;
// Récompense finale si la partie est terminée
if done {
if let Some(winner) = self.game_state.determine_winner() {
if winner == self.agent_player_id {
reward += 100.0; // Bonus pour gagner
} else {
reward -= 50.0; // Pénalité pour perdre
}
}
}
self.current_step += 1;
let next_state = self.game_state.to_vec_float();
(next_state, reward, done)
}
fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
let mut reward = 0.0;
let event = match action {
TrictracAction::Roll => {
// Lancer les dés
reward += 0.1;
Some(GameEvent::Roll {
player_id: self.agent_player_id,
})
}
// TrictracAction::Mark => {
// // Marquer des points
// let points = self.game_state.
// reward += 0.1 * points as f32;
// Some(GameEvent::Mark {
// player_id: self.agent_player_id,
// points,
// })
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
reward += 0.2;
Some(GameEvent::Go {
player_id: self.agent_player_id,
})
}
TrictracAction::Move {
dice_order,
from1,
from2,
} => {
// Effectuer un mouvement
let (dice1, dice2) = if dice_order {
(self.game_state.dice.values.0, self.game_state.dice.values.1)
} else {
(self.game_state.dice.values.1, self.game_state.dice.values.0)
};
let mut to1 = from1 + dice1 as usize;
let mut to2 = from2 + dice2 as usize;
// Gestion prise de coin par puissance
let opp_rest_field = 13;
if to1 == opp_rest_field && to2 == opp_rest_field {
to1 -= 1;
to2 -= 1;
}
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
reward += 0.2;
Some(GameEvent::Move {
player_id: self.agent_player_id,
moves: (checker_move1, checker_move2),
})
}
};
// Appliquer l'événement si valide
if let Some(event) = event {
if self.game_state.validate(&event) {
self.game_state.consume(&event);
// Simuler le résultat des dés après un Roll
if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.agent_player_id,
dice: store::Dice {
values: dice_values,
},
};
if self.game_state.validate(&dice_event) {
self.game_state.consume(&dice_event);
}
}
} else {
// Pénalité pour action invalide
reward -= 2.0;
}
}
reward
}
// TODO : use default bot strategy
fn play_opponent_turn(&mut self) -> f32 {
let mut reward = 0.0;
let event = match self.game_state.turn_stage {
TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_player_id,
},
TurnStage::RollWaiting => {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
GameEvent::RollResult {
player_id: self.opponent_player_id,
dice: store::Dice {
values: dice_values,
},
}
}
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
let opponent_color = self.agent_color.opponent_color();
let dice_roll_count = self
.game_state
.players
.get(&self.opponent_player_id)
.unwrap()
.dice_roll_count;
let points_rules = PointsRules::new(
&opponent_color,
&self.game_state.board,
self.game_state.dice,
);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.opponent_player_id,
points,
}
}
TurnStage::Move => {
let opponent_color = self.agent_color.opponent_color();
let rules = MoveRules::new(
&opponent_color,
&self.game_state.board,
self.game_state.dice,
);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
// Stratégie simple : choix aléatoire
let mut rng = thread_rng();
let choosen_move = *possible_moves
.choose(&mut rng)
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
GameEvent::Move {
player_id: self.opponent_player_id,
moves: if opponent_color == Color::White {
choosen_move
} else {
(choosen_move.0.mirror(), choosen_move.1.mirror())
},
}
}
TurnStage::HoldOrGoChoice => {
// Stratégie simple : toujours continuer
GameEvent::Go {
player_id: self.opponent_player_id,
}
}
};
if self.game_state.validate(&event) {
self.game_state.consume(&event);
}
reward
}
}
/// Trainer for the DQN model: couples an agent with the environment it plays in.
pub struct DqnTrainer {
agent: DqnAgent,
env: TrictracEnv,
}
impl DqnTrainer {
    /// Creates a trainer with a fresh agent built from `config` and a default
    /// environment.
    pub fn new(config: DqnConfig) -> Self {
        Self {
            agent: DqnAgent::new(config),
            env: TrictracEnv::default(),
        }
    }

    /// Runs one full episode and returns the sum of the rewards received.
    ///
    /// After every environment step the transition is stored in the agent and
    /// one training pass is triggered.
    pub fn train_episode(&mut self) -> f32 {
        let mut total_reward = 0.0;
        let mut state = self.env.reset();
        loop {
            let action = self.agent.select_action(&self.env.game_state, &state);
            let (next_state, reward, done) = self.env.step(action.clone());
            total_reward += reward;
            // `state` is not read again before it is reassigned below, so it
            // can be moved into the experience instead of being cloned.
            self.agent.store_experience(Experience {
                state,
                action,
                reward,
                next_state: next_state.clone(),
                done,
            });
            self.agent.train();
            if done {
                break;
            }
            state = next_state;
        }
        total_reward
    }

    /// Trains for `episodes` episodes, printing progress every 100 episodes.
    ///
    /// A checkpoint is written to `{model_path}_episode_{n}.json` every
    /// `save_every` episodes (`0` disables periodic checkpoints), and the final
    /// model is always written to `{model_path}_final.json`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the agent when a model cannot be saved.
    pub fn train(
        &mut self,
        episodes: usize,
        save_every: usize,
        model_path: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
        for episode in 1..=episodes {
            let reward = self.train_episode();
            if episode % 100 == 0 {
                println!(
                    "Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
                    episode,
                    episodes,
                    reward,
                    self.agent.get_epsilon(),
                    self.agent.get_step_count()
                );
            }
            // Guard against `save_every == 0`, which would otherwise panic on
            // the remainder-by-zero below.
            if save_every != 0 && episode % save_every == 0 {
                let save_path = format!("{}_episode_{}.json", model_path, episode);
                self.agent.save_model(&save_path)?;
                println!("Modèle sauvegardé : {}", save_path);
            }
        }
        // Save the final model.
        let final_path = format!("{}_final.json", model_path);
        self.agent.save_model(&final_path)?;
        println!("Modèle final sauvegardé : {}", final_path);
        Ok(())
    }
}

View file

@ -1 +0,0 @@
pub mod dqn_trainer;

View file

@ -1,10 +1,14 @@
pub mod dqn; pub mod burnrl;
pub mod strategy; pub mod strategy;
pub mod training_common;
pub mod trictrac_board;
use log::debug;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage}; use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy; pub use strategy::default::DefaultStrategy;
pub use strategy::dqn::DqnStrategy; pub use strategy::dqnburn::DqnBurnStrategy;
pub use strategy::erroneous_moves::ErroneousStrategy; pub use strategy::erroneous_moves::ErroneousStrategy;
pub use strategy::random::RandomStrategy;
pub use strategy::stable_baselines3::StableBaselines3Strategy; pub use strategy::stable_baselines3::StableBaselines3Strategy;
pub trait BotStrategy: std::fmt::Debug { pub trait BotStrategy: std::fmt::Debug {
@ -26,7 +30,7 @@ pub trait BotStrategy: std::fmt::Debug {
pub struct Bot { pub struct Bot {
pub player_id: PlayerId, pub player_id: PlayerId,
strategy: Box<dyn BotStrategy>, strategy: Box<dyn BotStrategy>,
// color: Color, color: Color,
// schools_enabled: bool, // schools_enabled: bool,
} }
@ -34,9 +38,9 @@ impl Default for Bot {
fn default() -> Self { fn default() -> Self {
let strategy = DefaultStrategy::default(); let strategy = DefaultStrategy::default();
Self { Self {
player_id: 2, player_id: 1,
strategy: Box::new(strategy), strategy: Box::new(strategy),
// color: Color::Black, color: Color::White,
// schools_enabled: false, // schools_enabled: false,
} }
} }
@ -52,57 +56,86 @@ impl Bot {
Color::White => 1, Color::White => 1,
Color::Black => 2, Color::Black => 2,
}; };
strategy.set_player_id(player_id); // strategy.set_player_id(player_id);
strategy.set_color(color); // strategy.set_color(color);
Self { Self {
player_id, player_id,
strategy, strategy,
// color, color,
// schools_enabled: false, // schools_enabled: false,
} }
} }
pub fn handle_event(&mut self, event: &GameEvent) -> Option<GameEvent> { pub fn handle_event(&mut self, event: &GameEvent) -> Option<GameEvent> {
debug!(">>>> {:?} BOT handle", self.color);
let game = self.strategy.get_mut_game(); let game = self.strategy.get_mut_game();
game.consume(event); let internal_event = if self.color == Color::Black {
&event.get_mirror()
} else {
event
};
let init_player_points = game.who_plays().map(|p| (p.points, p.holes));
let turn_stage = game.turn_stage;
game.consume(internal_event);
if game.stage == Stage::Ended { if game.stage == Stage::Ended {
debug!("<<<< end {:?} BOT handle", self.color);
return None; return None;
} }
if game.active_player_id == self.player_id { let active_player_id = if self.color == Color::Black {
return match game.turn_stage { if game.active_player_id == 1 {
2
} else {
1
}
} else {
game.active_player_id
};
if active_player_id == self.player_id {
let player_points = game.who_plays().map(|p| (p.points, p.holes));
if self.color == Color::Black {
debug!( " input (internal) evt : {internal_event:?}, points : {init_player_points:?}, stage : {turn_stage:?}");
}
let internal_event = match game.turn_stage {
TurnStage::MarkAdvPoints => Some(GameEvent::Mark { TurnStage::MarkAdvPoints => Some(GameEvent::Mark {
player_id: self.player_id, player_id: 1,
points: self.strategy.calculate_adv_points(), points: self.strategy.calculate_adv_points(),
}), }),
TurnStage::RollDice => Some(GameEvent::Roll { TurnStage::RollDice => Some(GameEvent::Roll { player_id: 1 }),
player_id: self.player_id,
}),
TurnStage::MarkPoints => Some(GameEvent::Mark { TurnStage::MarkPoints => Some(GameEvent::Mark {
player_id: self.player_id, player_id: 1,
points: self.strategy.calculate_points(), points: self.strategy.calculate_points(),
}), }),
TurnStage::Move => Some(GameEvent::Move { TurnStage::Move => Some(GameEvent::Move {
player_id: self.player_id, player_id: 1,
moves: self.strategy.choose_move(), moves: self.strategy.choose_move(),
}), }),
TurnStage::HoldOrGoChoice => { TurnStage::HoldOrGoChoice => {
if self.strategy.choose_go() { if self.strategy.choose_go() {
Some(GameEvent::Go { Some(GameEvent::Go { player_id: 1 })
player_id: self.player_id,
})
} else { } else {
Some(GameEvent::Move { Some(GameEvent::Move {
player_id: self.player_id, player_id: 1,
moves: self.strategy.choose_move(), moves: self.strategy.choose_move(),
}) })
} }
} }
_ => None, _ => None,
}; };
return if self.color == Color::Black {
debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}");
debug!("<<<< end {:?} BOT handle", self.color);
internal_event.map(|evt| evt.get_mirror())
} else {
debug!("<<<< end {:?} BOT handle", self.color);
internal_event
};
} }
debug!("<<<< end {:?} BOT handle", self.color);
None None
} }
// Only used in tests below
pub fn get_state(&self) -> &GameState { pub fn get_state(&self) -> &GameState {
self.strategy.get_game() self.strategy.get_game()
} }
@ -121,17 +154,31 @@ mod tests {
} }
#[test] #[test]
fn test_consume() { fn test_handle_event() {
let mut bot = Bot::new(Box::new(DefaultStrategy::default()), Color::Black); let mut bot = Bot::new(Box::new(DefaultStrategy::default()), Color::Black);
// let mut bot = Bot::new(Box::new(DefaultStrategy::default()), Color::Black, false); // let mut bot = Bot::new(Box::new(DefaultStrategy::default()), Color::Black, false);
let mut event = bot.handle_event(&GameEvent::BeginGame { goes_first: 2 }); let mut event = bot.handle_event(&GameEvent::BeginGame { goes_first: 2 });
assert_eq!(event, Some(GameEvent::Roll { player_id: 2 })); assert_eq!(event, Some(GameEvent::Roll { player_id: 2 }));
assert_eq!(bot.get_state().active_player_id, 2); assert_eq!(bot.get_state().active_player_id, 1); // bot internal active_player_id for black
event = bot.handle_event(&GameEvent::RollResult {
player_id: 2,
dice: Dice { values: (2, 3) },
});
assert_eq!(
event,
Some(GameEvent::Move {
player_id: 2,
moves: (
CheckerMove::new(24, 21).unwrap(),
CheckerMove::new(24, 22).unwrap()
)
})
);
event = bot.handle_event(&GameEvent::BeginGame { goes_first: 1 }); event = bot.handle_event(&GameEvent::BeginGame { goes_first: 1 });
assert_eq!(event, None); assert_eq!(event, None);
assert_eq!(bot.get_state().active_player_id, 1); assert_eq!(bot.get_state().active_player_id, 2); //internal active_player_id
bot.handle_event(&GameEvent::RollResult { bot.handle_event(&GameEvent::RollResult {
player_id: 1, player_id: 1,
dice: Dice { values: (2, 3) }, dice: Dice { values: (2, 3) },

View file

@ -13,8 +13,8 @@ impl Default for DefaultStrategy {
let game = GameState::default(); let game = GameState::default();
Self { Self {
game, game,
player_id: 2, player_id: 1,
color: Color::Black, color: Color::White,
} }
} }
} }

View file

@ -1,175 +0,0 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use std::path::Path;
use store::MoveRules;
use crate::dqn::dqn_common::{
get_valid_actions, sample_valid_action, SimpleNeuralNetwork, TrictracAction,
};
/// DQN strategy for the bot: it only loads and uses a pre-trained model.
#[derive(Debug)]
pub struct DqnStrategy {
/// Game state owned by the strategy (exposed through `get_game` / `get_mut_game`).
pub game: GameState,
/// Player id, updated through `set_player_id`.
pub player_id: PlayerId,
/// Color played; chosen moves are mirrored when it is not `Color::White`.
pub color: Color,
/// Loaded network; when `None`, a random valid action is sampled instead.
pub model: Option<SimpleNeuralNetwork>,
}
impl Default for DqnStrategy {
fn default() -> Self {
Self {
game: GameState::default(),
player_id: 2,
color: Color::Black,
model: None,
}
}
}
impl DqnStrategy {
pub fn new() -> Self {
Self::default()
}
pub fn new_with_model<P: AsRef<Path>>(model_path: P) -> Self {
let mut strategy = Self::new();
if let Ok(model) = SimpleNeuralNetwork::load(model_path) {
strategy.model = Some(model);
}
strategy
}
/// Utilise le modèle DQN pour choisir une action valide
fn get_dqn_action(&self) -> Option<TrictracAction> {
if let Some(ref model) = self.model {
let state = self.game.to_vec_float();
let valid_actions = get_valid_actions(&self.game);
if valid_actions.is_empty() {
return None;
}
// Obtenir les Q-values pour toutes les actions
let q_values = model.forward(&state);
// Trouver la meilleure action valide
let mut best_action = &valid_actions[0];
let mut best_q_value = f32::NEG_INFINITY;
for action in &valid_actions {
let action_index = action.to_action_index();
if action_index < q_values.len() {
let q_value = q_values[action_index];
if q_value > best_q_value {
best_q_value = q_value;
best_action = action;
}
}
}
Some(best_action.clone())
} else {
// Fallback : action aléatoire valide
sample_valid_action(&self.game)
}
}
}
impl BotStrategy for DqnStrategy {
    /// Read-only access to the strategy's game state.
    fn get_game(&self) -> &GameState {
        &self.game
    }

    /// Mutable access to the strategy's game state.
    fn get_mut_game(&mut self) -> &mut GameState {
        &mut self.game
    }

    fn set_color(&mut self, color: Color) {
        self.color = color;
    }

    fn set_player_id(&mut self, player_id: PlayerId) {
        self.player_id = player_id;
    }

    /// Points to mark: first component of the game's `dice_points`.
    fn calculate_points(&self) -> u8 {
        self.game.dice_points.0
    }

    /// Points for the adversary: second component of the game's `dice_points`.
    fn calculate_adv_points(&self) -> u8 {
        self.game.dice_points.1
    }

    /// Asks the model whether to keep going; continues by default when no
    /// action is available.
    fn choose_go(&self) -> bool {
        self.get_dqn_action()
            .map_or(true, |action| matches!(action, TrictracAction::Go))
    }

    /// Converts the model's `Move` action into a pair of checker moves, or
    /// falls back to the first legal sequence when the model gives no move.
    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
        // Moves are computed from White's point of view and mirrored otherwise.
        let orient = |pair: (CheckerMove, CheckerMove)| {
            if self.color == Color::White {
                pair
            } else {
                (pair.0.mirror(), pair.1.mirror())
            }
        };

        if let Some(TrictracAction::Move {
            dice_order,
            from1,
            from2,
        }) = self.get_dqn_action()
        {
            let (first_die, second_die) = self.game.dice.values;
            let (die1, die2) = if dice_order {
                (first_die, second_die)
            } else {
                (second_die, first_die)
            };
            // Destination of a checker: field 0 stands for an empty move, and
            // anything past field 24 is an exit, encoded as 0.
            let land = |from: usize, pips: usize| -> usize {
                if from == 0 {
                    return 0;
                }
                let to = from + pips;
                if to > 24 {
                    0
                } else {
                    to
                }
            };
            let to1 = land(from1, die1 as usize);
            let to2 = land(from2, die2 as usize);
            return orient((
                CheckerMove::new(from1, to1).unwrap_or_default(),
                CheckerMove::new(from2, to2).unwrap_or_default(),
            ));
        }

        // Fallback: first legal sequence, or the empty move.
        let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
        let candidates = rules.get_possible_moves_sequences(true, vec![]);
        let fallback = match candidates.first() {
            Some(&pair) => pair,
            None => (CheckerMove::default(), CheckerMove::default()),
        };
        orient(fallback)
    }
}

220
bot/src/strategy/dqnburn.rs Normal file
View file

@ -0,0 +1,220 @@
use burn::backend::NdArray;
use burn::tensor::cast::ToElement;
use burn_rl::base::{ElemType, Model, State};
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use log::info;
use store::MoveRules;
use crate::burnrl::algos::dqn;
use crate::burnrl::environment;
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>;
/// DQN strategy for the bot: it only loads and uses a pre-trained model.
#[derive(Debug)]
pub struct DqnBurnStrategy {
/// Game state owned by the strategy (exposed through `get_game` / `get_mut_game`).
pub game: GameState,
/// Player id, updated through `set_player_id`.
pub player_id: PlayerId,
/// Color played, updated through `set_color`.
pub color: Color,
/// Loaded network; when `None`, a random valid action is sampled instead.
pub model: Option<DqnBurnNetwork>,
}
impl Default for DqnBurnStrategy {
fn default() -> Self {
Self {
game: GameState::default(),
player_id: 1,
color: Color::White,
model: None,
}
}
}
impl DqnBurnStrategy {
pub fn new() -> Self {
Self::default()
}
pub fn new_with_model(model_path: &String) -> Self {
info!("Loading model {model_path:?}");
let mut strategy = Self::new();
strategy.model = dqn::load_model(256, model_path);
strategy
}
/// Utilise le modèle DQN pour choisir une action valide
fn get_dqn_action(&self) -> Option<TrictracAction> {
if let Some(ref model) = self.model {
let state = environment::TrictracState::from_game_state(&self.game);
let valid_actions_indices = get_valid_action_indices(&self.game);
if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode
}
// Obtenir les Q-values pour toutes les actions
let q_values = model.infer(state.to_tensor().unsqueeze());
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
action_index,
))
} else {
// Fallback : action aléatoire valide
sample_valid_action(&self.game)
}
}
}
/// `BotStrategy` implementation backed by the burn DQN network.
impl BotStrategy for DqnBurnStrategy {
/// Read-only access to the strategy's game state.
fn get_game(&self) -> &GameState {
&self.game
}
/// Mutable access to the strategy's game state.
fn get_mut_game(&mut self) -> &mut GameState {
&mut self.game
}
fn set_color(&mut self, color: Color) {
self.color = color;
}
fn set_player_id(&mut self, player_id: PlayerId) {
self.player_id = player_id;
}
/// Points to mark: first component of the game's `dice_points`.
fn calculate_points(&self) -> u8 {
self.game.dice_points.0
}
/// Points for the adversary: second component of the game's `dice_points`.
fn calculate_adv_points(&self) -> u8 {
self.game.dice_points.1
}
/// Returns `true` when the model's action is `Go`, and also when no action
/// is available.
fn choose_go(&self) -> bool {
// Use the DQN to decide whether to keep going
if let Some(action) = self.get_dqn_action() {
matches!(action, TrictracAction::Go)
} else {
// Fallback: always keep going
true
}
}
/// Converts the model's `Move` action (a dice order plus two checker
/// numbers) into a pair of checker moves; falls back to the first legal
/// sequence when the model does not return a `Move`.
///
/// # Panics
///
/// Panics when the model returns a `Move` while the strategy's color is not
/// White, or when the first move cannot be applied to a copy of the board.
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
// Use the DQN to choose the move
if let Some(TrictracAction::Move {
dice_order,
checker1,
checker2,
}) = self.get_dqn_action()
{
let dicevals = self.game.dice.values;
let (mut dice1, mut dice2) = if dice_order {
(dicevals.0, dicevals.1)
} else {
(dicevals.1, dicevals.0)
};
// Only White is handled: this assertion panics for any other color, so
// the non-White branches further down are not reached from here.
assert_eq!(self.color, Color::White);
// Field of the first checker; 0 when the board returns no field for it.
let from1 = self
.game
.board
.get_checker_field(&self.color, checker1 as u8)
.unwrap_or(0);
if from1 == 0 {
// empty move
dice1 = 0;
}
let mut to1 = from1;
if self.color == Color::White {
to1 += dice1 as usize;
if 24 < to1 {
// exit (bearing off), encoded as destination 0
to1 = 0;
}
} else {
let fto1 = to1 as i16 - dice1 as i16;
to1 = if fto1 < 0 { 0 } else { fto1 as usize };
}
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
// The second checker is looked up on a copy of the board on which the
// first move has already been played.
let mut tmp_board = self.game.board.clone();
let move_res = tmp_board.move_checker(&self.color, checker_move1);
if move_res.is_err() {
panic!("could not move {move_res:?}");
}
let from2 = tmp_board
.get_checker_field(&self.color, checker2 as u8)
.unwrap_or(0);
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2;
if self.color == Color::White {
to2 += dice2 as usize;
if 24 < to2 {
// exit (bearing off), encoded as destination 0
to2 = 0;
}
} else {
let fto2 = to2 as i16 - dice2 as i16;
to2 = if fto2 < 0 { 0 } else { fto2 as usize };
}
// Handling of taking the corner "by power": when both checkers would land
// on the opponent's rest corner, both destinations are shifted by one field.
// NOTE(review): the copy of the board above was updated with the unshifted
// first move; confirm `move_checker` accepts that destination.
let opp_rest_field = if self.color == Color::White { 13 } else { 12 };
if to1 == opp_rest_field && to2 == opp_rest_field {
if self.color == Color::White {
to1 -= 1;
to2 -= 1;
} else {
to1 += 1;
to2 += 1;
}
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
// XXX: is mirroring really needed here? (open question from the author)
(checker_move1.mirror(), checker_move2.mirror())
};
return chosen_move;
}
// Fallback: use the default strategy (first legal sequence, or the empty move)
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
let chosen_move = *possible_moves
.first()
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
if self.color == Color::White {
chosen_move
} else {
(chosen_move.0.mirror(), chosen_move.1.mirror())
}
}
}

View file

@ -1,5 +1,6 @@
pub mod client; pub mod client;
pub mod default; pub mod default;
pub mod dqn; pub mod dqnburn;
pub mod erroneous_moves; pub mod erroneous_moves;
pub mod random;
pub mod stable_baselines3; pub mod stable_baselines3;

View file

@ -0,0 +1,67 @@
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use store::MoveRules;
/// Bot strategy that plays a randomly chosen legal move sequence.
#[derive(Debug)]
pub struct RandomStrategy {
/// Game state owned by the strategy (exposed through `get_game` / `get_mut_game`).
pub game: GameState,
/// Player id, updated through `set_player_id`.
pub player_id: PlayerId,
/// Color played; chosen moves are mirrored when it is not `Color::White`.
pub color: Color,
}
impl Default for RandomStrategy {
fn default() -> Self {
let game = GameState::default();
Self {
game,
player_id: 1,
color: Color::White,
}
}
}
impl BotStrategy for RandomStrategy {
    /// Read-only access to the strategy's game state.
    fn get_game(&self) -> &GameState {
        &self.game
    }

    /// Mutable access to the strategy's game state.
    fn get_mut_game(&mut self) -> &mut GameState {
        &mut self.game
    }

    fn set_color(&mut self, color: Color) {
        self.color = color;
    }

    fn set_player_id(&mut self, player_id: PlayerId) {
        self.player_id = player_id;
    }

    /// Points to mark: first component of the game's `dice_points`.
    fn calculate_points(&self) -> u8 {
        self.game.dice_points.0
    }

    /// Points for the adversary: second component of the game's `dice_points`.
    fn calculate_adv_points(&self) -> u8 {
        self.game.dice_points.1
    }

    /// Always chooses to keep going.
    fn choose_go(&self) -> bool {
        true
    }

    /// Picks one of the legal move sequences at random, or the empty move
    /// when there is none; the result is mirrored when not playing White.
    fn choose_move(&self) -> (CheckerMove, CheckerMove) {
        use rand::{seq::SliceRandom, thread_rng};

        let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
        let candidates = rules.get_possible_moves_sequences(true, vec![]);
        let mut rng = thread_rng();
        let (first, second) = match candidates.choose(&mut rng) {
            Some(pair) => pair.clone(),
            None => (CheckerMove::default(), CheckerMove::default()),
        };
        if self.color == Color::White {
            (first, second)
        } else {
            (first.mirror(), second.mirror())
        }
    }
}

View file

@ -66,14 +66,14 @@ impl StableBaselines3Strategy {
// Remplir les positions des pièces blanches (valeurs positives) // Remplir les positions des pièces blanches (valeurs positives)
for (pos, count) in self.game.board.get_color_fields(Color::White) { for (pos, count) in self.game.board.get_color_fields(Color::White) {
if pos < 24 { if pos < 24 {
board[pos] = count as i8; board[pos] = count;
} }
} }
// Remplir les positions des pièces noires (valeurs négatives) // Remplir les positions des pièces noires (valeurs négatives)
for (pos, count) in self.game.board.get_color_fields(Color::Black) { for (pos, count) in self.game.board.get_color_fields(Color::Black) {
if pos < 24 { if pos < 24 {
board[pos] = -(count as i8); board[pos] = -count;
} }
} }
@ -270,4 +270,3 @@ impl BotStrategy for StableBaselines3Strategy {
} }
} }
} }

View file

@ -1,10 +1,17 @@
/// training_common.rs : environnement avec espace d'actions optimisé
/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' )
use std::cmp::{max, min}; use std::cmp::{max, min};
use std::fmt::{Debug, Display, Formatter};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use store::{CheckerMove, Dice}; use store::{CheckerMove, GameEvent, GameState};
// 1 (Roll) + 1 (Go) + mouvements possibles
// Pour les mouvements : 2*16*16 = 514 (choix du dé + choix de la dame 0-15 pour chaque from)
pub const ACTION_SPACE_SIZE: usize = 514;
/// Types d'actions possibles dans le jeu /// Types d'actions possibles dans le jeu
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Copy, Clone, Eq, Serialize, Deserialize, PartialEq)]
pub enum TrictracAction { pub enum TrictracAction {
/// Lancer les dés /// Lancer les dés
Roll, Roll,
@ -13,13 +20,21 @@ pub enum TrictracAction {
/// Effectuer un mouvement de pions /// Effectuer un mouvement de pions
Move { Move {
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
from1: usize, // position de départ du premier pion (0-24) checker1: usize, // premier pion à déplacer en numérotant depuis la colonne de départ (0-15) 0 : aucun pion
from2: usize, // position de départ du deuxième pion (0-24) checker2: usize, // deuxième pion (0-15)
}, },
// Marquer les points : à activer si support des écoles // Marquer les points : à activer si support des écoles
// Mark, // Mark,
} }
impl Display for TrictracAction {
    /// Formats the action using its `Debug` representation.
    ///
    /// The text is written as-is: it is neither reversed nor followed by a
    /// newline, so the value composes normally inside `format!`/`println!`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}
impl TrictracAction { impl TrictracAction {
/// Encode une action en index pour le réseau de neurones /// Encode une action en index pour le réseau de neurones
pub fn to_action_index(&self) -> usize { pub fn to_action_index(&self) -> usize {
@ -28,19 +43,91 @@ impl TrictracAction {
TrictracAction::Go => 1, TrictracAction::Go => 1,
TrictracAction::Move { TrictracAction::Move {
dice_order, dice_order,
from1, checker1,
from2, checker2,
} => { } => {
// Encoder les mouvements dans l'espace d'actions // Encoder les mouvements dans l'espace d'actions
// Indices 2+ pour les mouvements // Indices 2+ pour les mouvements
// de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier) // de 2 à 513 (2 à 257 pour dé 1 en premier, 258 à 513 pour dé 2 en premier)
let mut start = 2; let mut start = 2;
if !dice_order { if !dice_order {
// 25 * 25 = 625 // 16 * 16 = 256
start += 625; start += 256;
} }
start + from1 * 25 + from2 start + checker1 * 16 + checker2
} // TrictracAction::Mark => 1252, } // TrictracAction::Mark => 514,
}
}
/// Converts this action into the game event to send for the active player
/// of `state`.
///
/// `Roll` and `Go` map directly to their events. For `Move`, the two checker
/// numbers are resolved to board fields for White, the second one on a copy
/// of the board where the first move has been played.
///
/// Returns `None` when the first move cannot be applied to that board copy.
pub fn to_event(&self, state: &GameState) -> Option<GameEvent> {
match self {
TrictracAction::Roll => {
// Roll the dice
Some(GameEvent::Roll {
player_id: state.active_player_id,
})
}
// TrictracAction::Mark => {
// // Mark points
// let points = self.game.
// Some(GameEvent::Mark {
// player_id: self.active_player_id,
// points,
// })
// }
TrictracAction::Go => {
// Keep going after winning a hole
Some(GameEvent::Go {
player_id: state.active_player_id,
})
}
TrictracAction::Move {
dice_order,
checker1,
checker2,
} => {
// Perform a move
let (dice1, dice2) = if *dice_order {
(state.dice.values.0, state.dice.values.1)
} else {
(state.dice.values.1, state.dice.values.0)
};
// The conversion is always done from White's point of view.
let color = &store::Color::White;
// Field of the first checker; 0 when the board returns no field for it.
let from1 = state
.board
.get_checker_field(color, *checker1 as u8)
.unwrap_or(0);
// NOTE(review): the destination is not clamped when it goes past the last
// field, and the die is still added when `from1` is 0; this relies on
// `CheckerMove::new` rejecting such values so that `unwrap_or_default`
// yields the default move. Confirm against `CheckerMove::new`.
let mut to1 = from1 + dice1 as usize;
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = state.board.clone();
let move_result = tmp_board.move_checker(color, checker_move1);
if move_result.is_err() {
None
// panic!("Error while moving checker {move_result:?}")
} else {
let from2 = tmp_board
.get_checker_field(color, *checker2 as u8)
.unwrap_or(0);
let mut to2 = from2 + dice2 as usize;
// Handling of taking the corner "by power": when both checkers would land
// on field 13, both destinations are shifted back by one field.
let opp_rest_field = 13;
if to1 == opp_rest_field && to2 == opp_rest_field {
to1 -= 1;
to2 -= 1;
}
// The first move is rebuilt because `to1` may have been adjusted above.
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
Some(GameEvent::Move {
player_id: state.active_player_id,
moves: (checker_move1, checker_move2),
})
}
}
}
}
} }
@ -48,15 +135,15 @@ impl TrictracAction {
pub fn from_action_index(index: usize) -> Option<TrictracAction> { pub fn from_action_index(index: usize) -> Option<TrictracAction> {
match index { match index {
0 => Some(TrictracAction::Roll), 0 => Some(TrictracAction::Roll),
// 1252 => Some(TrictracAction::Mark),
1 => Some(TrictracAction::Go), 1 => Some(TrictracAction::Go),
i if i >= 3 => { // 514 => Some(TrictracAction::Mark),
let move_code = i - 3; i if i >= 2 => {
let (dice_order, from1, from2) = Self::decode_move(move_code); let move_code = i - 2;
let (dice_order, checker1, checker2) = Self::decode_move(move_code);
Some(TrictracAction::Move { Some(TrictracAction::Move {
dice_order, dice_order,
from1, checker1,
from2, checker2,
}) })
} }
_ => None, _ => None,
@ -66,21 +153,18 @@ impl TrictracAction {
/// Décode un entier en paire de mouvements /// Décode un entier en paire de mouvements
fn decode_move(code: usize) -> (bool, usize, usize) { fn decode_move(code: usize) -> (bool, usize, usize) {
let mut encoded = code; let mut encoded = code;
let dice_order = code < 626; let dice_order = code < 256;
if !dice_order { if !dice_order {
encoded -= 625 encoded -= 256
} }
let from1 = encoded / 25; let checker1 = encoded / 16;
let from2 = 1 + encoded % 25; let checker2 = encoded % 16;
(dice_order, from1, from2) (dice_order, checker1, checker2)
} }
/// Retourne la taille de l'espace d'actions total /// Retourne la taille de l'espace d'actions total
pub fn action_space_size() -> usize { pub fn action_space_size() -> usize {
// 1 (Roll) + 1 (Go) + mouvements possibles ACTION_SPACE_SIZE
// Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from)
// Mais on peut optimiser en limitant aux positions valides (1-24)
2 + (2 * 25 * 25) // = 1252
} }
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent { // pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
@ -106,157 +190,6 @@ impl TrictracAction {
// } // }
} }
/// Configuration for the DQN agent.
///
/// NOTE(review): the code that consumes these fields is not visible here;
/// the per-field descriptions are limited to what `Default` shows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
pub state_size: usize,
pub hidden_size: usize,
/// Defaults to `TrictracAction::action_space_size()`.
pub num_actions: usize,
pub learning_rate: f64,
pub gamma: f64,
pub epsilon: f64,
pub epsilon_decay: f64,
pub epsilon_min: f64,
pub replay_buffer_size: usize,
pub batch_size: usize,
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
state_size: 36,
hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
num_actions: TrictracAction::action_space_size(),
learning_rate: 0.001,
gamma: 0.99,
epsilon: 0.1,
epsilon_decay: 0.995,
epsilon_min: 0.01,
replay_buffer_size: 10000,
batch_size: 32,
}
}
}
/// Simplified DQN neural network (plain weight matrices), with two hidden
/// layers followed by an output layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleNeuralNetwork {
/// First layer weights: one row per hidden unit, one column per input.
pub weights1: Vec<Vec<f32>>,
/// First layer biases, one per hidden unit.
pub biases1: Vec<f32>,
/// Second layer weights: one row per hidden unit, one column per hidden unit.
pub weights2: Vec<Vec<f32>>,
/// Second layer biases, one per hidden unit.
pub biases2: Vec<f32>,
/// Output layer weights: one row per output, one column per hidden unit.
pub weights3: Vec<Vec<f32>>,
/// Output layer biases, one per output.
pub biases3: Vec<f32>,
}
impl SimpleNeuralNetwork {
    /// Creates a network with randomly initialised weights and zero biases.
    ///
    /// Each weight is drawn uniformly from `(-s, s)` with
    /// `s = sqrt(2 / fan_in)`, where `fan_in` is the number of inputs of the
    /// layer.
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();

        // Builds a `fan_out` x `fan_in` weight matrix; layers are drawn in
        // order (first, second, output).
        let mut random_layer = |fan_in: usize, fan_out: usize| -> Vec<Vec<f32>> {
            let scale = (2.0 / fan_in as f32).sqrt();
            (0..fan_out)
                .map(|_| {
                    (0..fan_in)
                        .map(|_| rng.gen_range(-scale..scale))
                        .collect()
                })
                .collect()
        };

        let weights1 = random_layer(input_size, hidden_size);
        let weights2 = random_layer(hidden_size, hidden_size);
        let weights3 = random_layer(hidden_size, output_size);

        Self {
            weights1,
            biases1: vec![0.0; hidden_size],
            weights2,
            biases2: vec![0.0; hidden_size],
            weights3,
            biases3: vec![0.0; output_size],
        }
    }

    /// Computes one dense layer: `bias + weights * input`, optionally followed
    /// by ReLU. Inputs longer or shorter than a weight row are truncated to
    /// the shorter of the two.
    fn dense_layer(weights: &[Vec<f32>], biases: &[f32], input: &[f32], relu: bool) -> Vec<f32> {
        let mut output = biases.to_vec();
        for (acc, row) in output.iter_mut().zip(weights) {
            for (&weight, &value) in row.iter().zip(input) {
                *acc += value * weight;
            }
            if relu {
                *acc = acc.max(0.0);
            }
        }
        output
    }

    /// Runs the network on `input` and returns one Q-value per output unit.
    ///
    /// The two hidden layers use ReLU; the output layer is linear.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let layer1 = Self::dense_layer(&self.weights1, &self.biases1, input, true);
        let layer2 = Self::dense_layer(&self.weights2, &self.biases2, &layer1, true);
        Self::dense_layer(&self.weights3, &self.biases3, &layer2, false)
    }

    /// Returns the index of the largest Q-value for `input`.
    ///
    /// NaN outputs are ignored instead of causing a panic; `0` is returned
    /// when no comparable value is available.
    pub fn get_best_action(&self, input: &[f32]) -> usize {
        let q_values = self.forward(input);
        q_values
            .iter()
            .enumerate()
            .filter(|(_, q)| !q.is_nan())
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(index, _)| index)
            .unwrap_or(0)
    }

    /// Serialises the network as pretty-printed JSON and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns the serialisation or I/O error encountered.
    pub fn save<P: AsRef<std::path::Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let data = serde_json::to_string_pretty(self)?;
        std::fs::write(path, data)?;
        Ok(())
    }

    /// Reads a network from the JSON file at `path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O or deserialisation error encountered.
    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
        let data = std::fs::read_to_string(path)?;
        let network = serde_json::from_str(&data)?;
        Ok(network)
    }
}
/// Obtient les actions valides pour l'état de jeu actuel /// Obtient les actions valides pour l'état de jeu actuel
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> { pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
use store::TurnStage; use store::TurnStage;
@ -268,11 +201,15 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
if let Some(color) = player_color { if let Some(color) = player_color {
match game_state.turn_stage { match game_state.turn_stage {
TurnStage::RollDice | TurnStage::RollWaiting => { TurnStage::RollDice => {
valid_actions.push(TrictracAction::Roll); valid_actions.push(TrictracAction::Roll);
} }
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => { TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
// valid_actions.push(TrictracAction::Mark); // valid_actions.push(TrictracAction::Mark);
panic!(
"get_valid_actions not implemented for turn stage {:?}",
game_state.turn_stage
);
} }
TurnStage::HoldOrGoChoice => { TurnStage::HoldOrGoChoice => {
valid_actions.push(TrictracAction::Go); valid_actions.push(TrictracAction::Go);
@ -285,29 +222,32 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
assert_eq!(color, store::Color::White); assert_eq!(color, store::Color::White);
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move1, &move2, &color, game_state,
&move2,
&game_state.dice,
)); ));
} }
} }
TurnStage::Move => { TurnStage::Move => {
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice); let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]); let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
if possible_moves.is_empty() {
// Empty move
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
}
// Modififier checker_moves_to_trictrac_action si on doit gérer Black // Modififier checker_moves_to_trictrac_action si on doit gérer Black
assert_eq!(color, store::Color::White); assert_eq!(color, store::Color::White);
for (move1, move2) in possible_moves { for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action( valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move1, &move2, &color, game_state,
&move2,
&game_state.dice,
)); ));
} }
} }
} }
} }
if valid_actions.is_empty() {
panic!("empty valid_actions for state {game_state}");
}
valid_actions valid_actions
} }
@ -315,12 +255,14 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
fn checker_moves_to_trictrac_action( fn checker_moves_to_trictrac_action(
move1: &CheckerMove, move1: &CheckerMove,
move2: &CheckerMove, move2: &CheckerMove,
dice: &Dice, color: &store::Color,
state: &crate::GameState,
) -> TrictracAction { ) -> TrictracAction {
let to1 = move1.get_to(); let to1 = move1.get_to();
let to2 = move2.get_to(); let to2 = move2.get_to();
let from1 = move1.get_from(); let from1 = move1.get_from();
let from2 = move2.get_from(); let from2 = move2.get_from();
let dice = state.dice;
let mut diff_move1 = if to1 > 0 { let mut diff_move1 = if to1 > 0 {
// Mouvement sans sortie // Mouvement sans sortie
@ -354,10 +296,20 @@ fn checker_moves_to_trictrac_action(
// prise par puissance // prise par puissance
diff_move1 += 1; diff_move1 += 1;
} }
let dice_order = diff_move1 == dice.values.0 as usize;
let checker1 = state.board.get_field_checker(color, from1) as usize;
let mut tmp_board = state.board.clone();
// should not raise an error for a valid action
let move_res = tmp_board.move_checker(color, *move1);
if move_res.is_err() {
panic!("error while moving checker {move_res:?}");
}
let checker2 = tmp_board.get_field_checker(color, from2) as usize;
TrictracAction::Move { TrictracAction::Move {
dice_order: diff_move1 == dice.values.0 as usize, dice_order,
from1: move1.get_from(), checker1,
from2: move2.get_from(), checker2,
} }
} }
@ -386,21 +338,21 @@ mod tests {
fn to_action_index() { fn to_action_index() {
let action = TrictracAction::Move { let action = TrictracAction::Move {
dice_order: true, dice_order: true,
from1: 3, checker1: 3,
from2: 4, checker2: 4,
}; };
let index = action.to_action_index(); let index = action.to_action_index();
assert_eq!(Some(action), TrictracAction::from_action_index(index)); assert_eq!(Some(action), TrictracAction::from_action_index(index));
assert_eq!(81, index); assert_eq!(54, index);
} }
#[test] #[test]
fn from_action_index() { fn from_action_index() {
let action = TrictracAction::Move { let action = TrictracAction::Move {
dice_order: true, dice_order: true,
from1: 3, checker1: 3,
from2: 4, checker2: 4,
}; };
assert_eq!(Some(action), TrictracAction::from_action_index(81)); assert_eq!(Some(action), TrictracAction::from_action_index(54));
} }
} }

164
bot/src/trictrac_board.rs Normal file
View file

@ -0,0 +1,164 @@
// https://docs.rs/board-game/ implementation
use crate::training_common::{get_valid_actions, TrictracAction};
use board_game::board::{
Board as BoardGameBoard, BoardDone, BoardMoves, Outcome, PlayError, Player as BoardGamePlayer,
};
use board_game::impl_unit_symmetry_board;
use internal_iterator::InternalIterator;
use std::fmt;
use std::hash::Hash;
use std::ops::ControlFlow;
use store::Color;
/// Newtype around the crate's game state so that it can implement the
/// `board-game` crate's traits.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct TrictracBoard(crate::GameState);

impl Default for TrictracBoard {
    /// Builds a board from a game state created with the players "white" and "black".
    fn default() -> Self {
        let state = crate::GameState::new_with_players("white", "black");
        Self(state)
    }
}

impl fmt::Display for TrictracBoard {
    /// Formats the board by delegating to the wrapped game state.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        self.0.fmt(formatter)
    }
}
impl_unit_symmetry_board!(TrictracBoard);
impl BoardGameBoard for TrictracBoard {
    type Move = TrictracAction;

    /// Returns the side to play: `B` when the active player is Black, `A`
    /// otherwise (including when no player is currently active).
    fn next_player(&self) -> BoardGamePlayer {
        match self.0.who_plays() {
            Some(player) if player.color == Color::Black => BoardGamePlayer::B,
            _ => BoardGamePlayer::A,
        }
    }

    /// Tells whether `mv` converts to an event that the game state validates.
    ///
    /// # Errors
    /// Returns `BoardDone` when the game is already finished.
    fn is_available_move(&self, mv: Self::Move) -> Result<bool, BoardDone> {
        self.check_done()?;
        // An action that cannot be turned into an event is simply not available.
        let available = mv
            .to_event(&self.0)
            .map_or(false, |event| self.0.validate(&event));
        Ok(available)
    }

    /// Applies `mv` to the wrapped game state.
    ///
    /// # Errors
    /// Propagates the error of `check_can_play` when the move is rejected.
    fn play(&mut self, mv: Self::Move) -> Result<(), PlayError> {
        self.check_can_play(mv)?;
        // `check_can_play` succeeded just above, so the conversion is expected to succeed.
        let event = mv.to_event(&self.0).unwrap();
        self.0.consume(&event);
        Ok(())
    }

    /// Returns the winner once the game stage is `Ended`; player id 1 maps to
    /// `A`, any other id to `B`.
    ///
    /// NOTE(review): an ended game for which `determine_winner` yields nothing
    /// is reported as "no outcome" (i.e. not done) — confirm this is intended.
    fn outcome(&self) -> Option<Outcome> {
        if self.0.stage != crate::Stage::Ended {
            return None;
        }
        let winner_id = self.0.determine_winner()?;
        let winner = if winner_id == 1 {
            BoardGamePlayer::A
        } else {
            BoardGamePlayer::B
        };
        Some(Outcome::WonBy(winner))
    }

    /// Always `true` for this board.
    fn can_lose_after_move() -> bool {
        true
    }
}
impl TrictracBoard {
pub fn inner(&self) -> &crate::GameState {
&self.0
}
pub fn to_fen(&self) -> String {
self.0.to_string_id()
}
pub fn from_fen(fen: &str) -> Result<TrictracBoard, String> {
crate::GameState::from_string_id(fen).map(TrictracBoard)
}
}
impl<'a> BoardMoves<'a, TrictracBoard> for TrictracBoard {
type AllMovesIterator = TrictracAllMovesIterator;
type AvailableMovesIterator = TrictracAvailableMovesIterator<'a>;
fn all_possible_moves() -> Self::AllMovesIterator {
TrictracAllMovesIterator::default()
}
fn available_moves(&'a self) -> Result<Self::AvailableMovesIterator, BoardDone> {
TrictracAvailableMovesIterator::new(self)
}
}
/// Stateless internal iterator over the whole action space.
#[derive(Debug, Clone, Default)]
pub struct TrictracAllMovesIterator;

impl InternalIterator for TrictracAllMovesIterator {
    type Item = TrictracAction;

    /// Feeds `f` with `Roll`, then `Go`, then every `Move` combination
    /// (`dice_order` false before true, `checker1` and `checker2` each in `0..16`),
    /// stopping as soon as `f` breaks.
    fn try_for_each<R, F: FnMut(Self::Item) -> ControlFlow<R>>(self, mut f: F) -> ControlFlow<R> {
        for simple_action in [TrictracAction::Roll, TrictracAction::Go] {
            f(simple_action)?;
        }
        for dice_order in [false, true] {
            for first in 0..16 {
                for second in 0..16 {
                    let action = TrictracAction::Move {
                        dice_order,
                        checker1: first,
                        checker2: second,
                    };
                    f(action)?;
                }
            }
        }
        ControlFlow::Continue(())
    }
}
/// Internal iterator over the actions valid for a borrowed board.
#[derive(Debug, Clone)]
pub struct TrictracAvailableMovesIterator<'a> {
    board: &'a TrictracBoard,
}

impl<'a> TrictracAvailableMovesIterator<'a> {
    /// Creates the iterator for `board`.
    ///
    /// # Errors
    /// Returns `BoardDone` when the game is already finished.
    pub fn new(board: &'a TrictracBoard) -> Result<Self, BoardDone> {
        board.check_done().map(|()| Self { board })
    }

    /// The board this iterator was created from.
    pub fn board(&self) -> &'a TrictracBoard {
        self.board
    }
}

impl InternalIterator for TrictracAvailableMovesIterator<'_> {
    type Item = TrictracAction;

    /// Feeds `f` with each action returned by `get_valid_actions`, in order,
    /// stopping as soon as `f` breaks.
    fn try_for_each<R, F>(self, mut f: F) -> ControlFlow<R>
    where
        F: FnMut(Self::Item) -> ControlFlow<R>,
    {
        for action in get_valid_actions(&self.board.0) {
            f(action)?;
        }
        ControlFlow::Continue(())
    }
}

View file

@ -1,8 +0,0 @@
[target.x86_64-unknown-linux-gnu]
linker = "clang"
rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only'
# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains.
#[profile.dev]
#debug = 1

View file

@ -1,14 +0,0 @@
[package]
name = "trictrac-client"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.75"
bevy = { version = "0.11.3" }
bevy_renet = "0.0.9"
bincode = "1.3.3"
renet = "0.0.13"
store = { path = "../store" }

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.9 MiB

Binary file not shown.

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.4 KiB

View file

@ -1,334 +0,0 @@
use std::{net::UdpSocket, time::SystemTime};
use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES};
use store::{GameEvent, GameState, CheckerMove};
use bevy::prelude::*;
use bevy::window::PrimaryWindow;
use bevy_renet::{
renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient},
transport::{client_connected, NetcodeClientPlugin},
RenetClientPlugin,
};
/// Id of this client, stored as a Bevy resource.
#[derive(Debug, Resource)]
struct CurrentClientId(u64);

/// Bevy resource wrapping the shared game state.
#[derive(Resource)]
struct BevyGameState(GameState);

impl Default for BevyGameState {
    /// Wraps a default game state.
    fn default() -> Self {
        BevyGameState(GameState::default())
    }
}

/// UI-only state: the tile selected by the first click of a move, if any.
#[derive(Resource, Deref, DerefMut, Default)]
struct GameUIState {
    selected_tile: Option<usize>,
}

/// Bevy event wrapping a game event.
#[derive(Event)]
struct BevyGameEvent(GameEvent);
// This id needs to be the same as the server is using
const PROTOCOL_ID: u64 = 2878;
/// Entry point: reads the username from the first command-line argument,
/// connects to the server and runs the Bevy application.
///
/// Exits with status 1 (instead of panicking) when the username is missing or
/// when the network client cannot be created.
fn main() {
    // The username is the first positional argument.
    let username = match std::env::args().nth(1) {
        Some(name) => name,
        None => {
            eprintln!("Usage: trictrac-client <username>");
            std::process::exit(1);
        }
    };

    let (client, transport, client_id) = match new_renet_client(&username) {
        Ok(connection) => connection,
        Err(err) => {
            eprintln!("Error: unable to create the network client: {err}");
            std::process::exit(1);
        }
    };

    App::new()
        // Lets add a nice dark grey background color
        .insert_resource(ClearColor(Color::hex("282828").unwrap()))
        .add_plugins(DefaultPlugins.set(WindowPlugin {
            primary_window: Some(Window {
                // Adding the username to the window title makes debugging a whole lot easier.
                title: format!("TricTrac <{}>", username),
                resolution: (1080.0, 1080.0).into(),
                ..default()
            }),
            ..default()
        }))
        // Add our game state and register GameEvent as a bevy event
        .insert_resource(BevyGameState::default())
        .insert_resource(GameUIState::default())
        .add_event::<BevyGameEvent>()
        // Renet setup
        .add_plugins(RenetClientPlugin)
        .add_plugins(NetcodeClientPlugin)
        .insert_resource(client)
        .insert_resource(transport)
        .insert_resource(CurrentClientId(client_id))
        .add_systems(Startup, setup)
        .add_systems(
            Update,
            (update_waiting_text, input, update_board, panic_on_error_system),
        )
        .add_systems(
            PostUpdate,
            receive_events_from_server.run_if(client_connected()),
        )
        .run();
}
////////// COMPONENTS //////////
/// Marker component for the root UI node.
#[derive(Component)]
struct UIRoot;

/// Marker component for the "waiting for an opponent" text.
#[derive(Component)]
struct WaitingText;

/// Board component made of 26 squares.
#[derive(Component)]
struct Board {
    squares: [Square; 26],
}

impl Default for Board {
    /// Every square starts empty, colorless and at position 0.
    fn default() -> Self {
        let empty = Square {
            count: 0,
            color: None,
            position: 0,
        };
        Board {
            squares: [empty; 26],
        }
    }
}

impl Board {
    /// Returns a copy of the square stored at index `position`.
    ///
    /// # Panics
    /// Panics when `position` is not a valid index into the 26 squares.
    fn square_at(&self, position: usize) -> Square {
        let Board { squares } = self;
        squares[position]
    }
}

/// One square of the board.
#[derive(Component, Clone, Copy)]
struct Square {
    count: usize,
    color: Option<bool>,
    position: usize,
}
////////// UPDATE SYSTEMS //////////
/// Spawns a checker sprite for every `GameEvent::Move` read from the event
/// queue; every other event is ignored.
fn update_board(
    mut commands: Commands,
    game_state: Res<BevyGameState>,
    mut game_events: EventReader<BevyGameEvent>,
    asset_server: Res<AssetServer>,
) {
    for event in game_events.iter() {
        if let GameEvent::Move { player_id, moves } = event.0 {
            // Grid coordinates of the destination of the first checker move.
            // TODO: should depend on player_id
            let target = moves.0.get_to();
            let (x, y) = if target < 13 {
                (13 - target, 1)
            } else {
                (target - 13, 0)
            };

            // Sprite file chosen from the color of the player who moved.
            let texture = asset_server.load(match game_state.0.players[&player_id].color {
                store::Color::Black => "tac.png",
                store::Color::White => "tic.png",
            });

            info!("spawning tictac sprite");
            let position = Transform::from_xyz(
                83.0 * (x as f32 - 1.0),
                -30.0 + 540.0 * (y as f32 - 1.0),
                0.0,
            );
            commands.spawn(SpriteBundle {
                transform: position,
                sprite: Sprite {
                    custom_size: Some(Vec2::new(83.0, 83.0)),
                    ..default()
                },
                texture: texture.into(),
                ..default()
            });
        }
    }
}
/// Animates the waiting text with one to three trailing dots, cycling on the
/// elapsed time in seconds. Does nothing unless exactly one waiting text exists.
fn update_waiting_text(mut text_query: Query<&mut Text, With<WaitingText>>, time: Res<Time>) {
    let mut text = match text_query.get_single_mut() {
        Ok(text) => text,
        Err(_) => return,
    };

    let dot_count = (time.elapsed_seconds() as usize % 3) + 1;
    let dots = ".".repeat(dot_count);
    // Pad with spaces to avoid text changing width and dancing all around the screen
    let padding = " ".repeat(3 - dot_count);
    text.sections[0].value = format!("Waiting for an opponent{}{}", dots, padding);
}
/// Handles mouse clicks on the board once the game is running.
///
/// The first left click selects a tile; the second one sends a
/// `GameEvent::Move` from the selected tile to the clicked tile to the server
/// and clears the selection.
///
/// # Panics
/// Panics when there is not exactly one primary window, when
/// `CheckerMove::new` rejects the selected tiles, or when the event cannot be
/// serialized.
fn input(
    primary_query: Query<&Window, With<PrimaryWindow>>,
    // windows: Res<Windows>,
    input: Res<Input<MouseButton>>,
    game_state: Res<BevyGameState>,
    mut game_ui_state: ResMut<GameUIState>,
    mut client: ResMut<RenetClient>,
    client_id: Res<CurrentClientId>,
) {
    // We only want to handle inputs once we are ingame
    if game_state.0.stage != store::Stage::InGame {
        return;
    }

    let window = primary_query.get_single().unwrap();
    let mouse_position = match window.cursor_position() {
        Some(position) => position,
        None => return,
    };

    // Determine the index of the tile that the mouse is currently over
    // NOTE: This calculation assumes a fixed window size.
    // That's fine for now, but consider using the windows size instead.
    let mut tile_x: usize = (mouse_position.x / 83.0).floor() as usize;
    let tile_y: usize = (mouse_position.y / 540.0).floor() as usize;
    if tile_x > 5 {
        // remove the middle bar offset
        tile_x -= 1;
    }

    // Translate the grid cell into a backgammon-style position.
    let tile = if tile_y == 0 {
        13 + tile_x
    } else {
        // A plain `12 - tile_x` underflows (and panics in debug builds) when the
        // cursor is to the right of the board, e.g. after the window was resized.
        match 12usize.checked_sub(tile_x) {
            Some(tile) => tile,
            None => return,
        }
    };

    // If mouse is outside of board we do nothing
    // NOTE(review): this bound rejects tile 24 and accepts tile 0 — confirm the intended range.
    if 23 < tile {
        return;
    }

    // Only a fresh left click selects a tile or sends a move.
    if !input.just_pressed(MouseButton::Left) {
        return;
    }
    info!("select piece at tile {:?}", tile);

    // `take` clears the selection; it is set again only when nothing was selected.
    match game_ui_state.selected_tile.take() {
        Some(from_tile) => {
            info!("sending movement from: {:?} to: {:?} ", from_tile, tile);
            // NOTE(review): the same checker move is sent twice, and an invalid
            // pair of tiles panics here — confirm both are acceptable.
            let event = GameEvent::Move {
                player_id: client_id.0,
                moves: (
                    CheckerMove::new(from_tile, tile).unwrap(),
                    CheckerMove::new(from_tile, tile).unwrap(),
                ),
            };
            client.send_message(0, bincode::serialize(&event).unwrap());
        }
        None => game_ui_state.selected_tile = Some(tile),
    }
}
////////// SETUP //////////
/// Startup system: spawns the 2D camera, the board background sprite and the
/// pre-game "waiting" UI.
fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {
    // Tric Trac is a 2D game
    // To show 2D sprites we need a 2D camera
    commands.spawn(Camera2dBundle::default());

    // Spawn board background
    let board_background = SpriteBundle {
        transform: Transform::from_xyz(0.0, -30.0, 0.0),
        sprite: Sprite {
            custom_size: Some(Vec2::new(1080.0, 927.0)),
            ..default()
        },
        texture: asset_server.load("board.png").into(),
        ..default()
    };
    commands.spawn(board_background);

    // A container that centers its children on the screen
    let centered_fullscreen = Style {
        position_type: PositionType::Absolute,
        left: Val::Px(0.0),
        top: Val::Px(0.0),
        width: Val::Percent(100.0),
        height: Val::Percent(100.0),
        align_items: AlignItems::Center,
        justify_content: JustifyContent::Center,
        ..default()
    };

    // Spawn pregame ui
    let mut ui_root = commands.spawn(NodeBundle {
        style: centered_fullscreen,
        ..default()
    });
    ui_root.insert(UIRoot);
    ui_root.with_children(|parent| {
        // The original author noted that spawning `Board::default()` here panics.
        let waiting_style = TextStyle {
            font: asset_server.load("Inconsolata.ttf"),
            font_size: 24.0,
            color: Color::hex("ebdbb2").unwrap(),
        };
        parent
            .spawn(TextBundle::from_section(
                "Waiting for an opponent...",
                waiting_style,
            ))
            .insert(WaitingText);
    });
}
////////// RENET NETWORKING //////////
/// Creates a `RenetClient` and its netcode transport targeting the server at
/// `127.0.0.1:5000`.
///
/// The username is embedded in the netcode user data as an 8-byte
/// little-endian length followed by the UTF-8 bytes. The returned `u64` is the
/// client id, derived from the current time in milliseconds.
///
/// # Errors
/// Returns an error when the username does not fit in the user data, when the
/// local UDP socket cannot be bound, when the system clock is before the Unix
/// epoch, or when the transport cannot be created.
fn new_renet_client(
    username: &str,
) -> anyhow::Result<(RenetClient, NetcodeClientTransport, u64)> {
    // Reject over-long names up front, as an error rather than a panic, since
    // the caller already handles a `Result`.
    if username.len() > NETCODE_USER_DATA_BYTES - 8 {
        anyhow::bail!("Username is too big");
    }

    let client = RenetClient::new(ConnectionConfig::default());

    let server_addr = "127.0.0.1:5000".parse()?;
    // Port 0 lets the OS pick a free local port.
    let socket = UdpSocket::bind("127.0.0.1:0")?;
    let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?;
    let client_id = current_time.as_millis() as u64;

    // Place username in user data
    let mut user_data = [0u8; NETCODE_USER_DATA_BYTES];
    user_data[0..8].copy_from_slice(&(username.len() as u64).to_le_bytes());
    user_data[8..username.len() + 8].copy_from_slice(username.as_bytes());

    let authentication = ClientAuthentication::Unsecure {
        server_addr,
        client_id,
        user_data: Some(user_data),
        protocol_id: PROTOCOL_ID,
    };
    let transport = NetcodeClientTransport::new(current_time, authentication, socket)
        .map_err(|err| anyhow::anyhow!("failed to create the netcode transport: {err}"))?;

    Ok((client, transport, client_id))
}
fn receive_events_from_server(
mut client: ResMut<RenetClient>,
mut game_state: ResMut<BevyGameState>,
mut game_events: EventWriter<BevyGameEvent>,
) {
while let Some(message) = client.receive_message(0) {
// Whenever the server sends a message we know that it must be a game event
let event: GameEvent = bincode::deserialize(&message).unwrap();
trace!("{:#?}", event);
// We trust the server - It's always been good to us!
// No need to validate the events it is sending us
game_state.0.consume(&event);
// Send the event into the bevy event system so systems can react to it
game_events.send(BevyGameEvent(event));
}
}
/// Panics with the message of the first transport error found, if any.
fn panic_on_error_system(mut renet_error: EventReader<NetcodeTransportError>) {
    // The first error aborts the program, so only the first one is ever looked at.
    if let Some(first_error) = renet_error.iter().next() {
        panic!("{}", first_error);
    }
}

View file

@ -15,3 +15,4 @@ store = { path = "../store" }
bot = { path = "../bot" } bot = { path = "../bot" }
itertools = "0.13.0" itertools = "0.13.0"
env_logger = "0.11.6" env_logger = "0.11.6"
log = "0.4.20"

View file

@ -1,4 +1,7 @@
use bot::{BotStrategy, DefaultStrategy, DqnStrategy, ErroneousStrategy, StableBaselines3Strategy}; use bot::{
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
StableBaselines3Strategy,
};
use itertools::Itertools; use itertools::Itertools;
use crate::game_runner::GameRunner; use crate::game_runner::GameRunner;
@ -32,21 +35,25 @@ impl App {
"dummy" => { "dummy" => {
Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>) Some(Box::new(DefaultStrategy::default()) as Box<dyn BotStrategy>)
} }
"random" => {
Some(Box::new(RandomStrategy::default()) as Box<dyn BotStrategy>)
}
"erroneous" => { "erroneous" => {
Some(Box::new(ErroneousStrategy::default()) as Box<dyn BotStrategy>) Some(Box::new(ErroneousStrategy::default()) as Box<dyn BotStrategy>)
} }
"ai" => Some(Box::new(StableBaselines3Strategy::default()) "ai" => Some(Box::new(StableBaselines3Strategy::default())
as Box<dyn BotStrategy>), as Box<dyn BotStrategy>),
"dqn" => Some(Box::new(DqnStrategy::default()) "dqnburn" => {
as Box<dyn BotStrategy>), Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
}
s if s.starts_with("ai:") => { s if s.starts_with("ai:") => {
let path = s.trim_start_matches("ai:"); let path = s.trim_start_matches("ai:");
Some(Box::new(StableBaselines3Strategy::new(path)) Some(Box::new(StableBaselines3Strategy::new(path))
as Box<dyn BotStrategy>) as Box<dyn BotStrategy>)
} }
s if s.starts_with("dqn:") => { s if s.starts_with("dqnburn:") => {
let path = s.trim_start_matches("dqn:"); let path = s.trim_start_matches("dqnburn:");
Some(Box::new(DqnStrategy::new_with_model(path)) Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
as Box<dyn BotStrategy>) as Box<dyn BotStrategy>)
} }
_ => None, _ => None,
@ -101,7 +108,7 @@ impl App {
pub fn show_history(&self) { pub fn show_history(&self) {
for hist in self.game.state.history.iter() { for hist in self.game.state.history.iter() {
println!("{:?}\n", hist); println!("{hist:?}\n");
} }
} }
@ -126,6 +133,9 @@ impl App {
// &self.game.state.board, // &self.game.state.board,
// dice, // dice,
// ); // );
self.game.handle_event(&GameEvent::Roll {
player_id: self.game.player_id.unwrap(),
});
self.game.handle_event(&GameEvent::RollResult { self.game.handle_event(&GameEvent::RollResult {
player_id: self.game.player_id.unwrap(), player_id: self.game.player_id.unwrap(),
dice, dice,
@ -176,7 +186,7 @@ impl App {
return; return;
} }
} }
println!("invalid move : {}", input); println!("invalid move : {input}");
} }
pub fn display(&mut self) -> String { pub fn display(&mut self) -> String {
@ -316,6 +326,7 @@ Player :: holes :: points
seed: Some(1327), seed: Some(1327),
bot: Some("dummy".into()), bot: Some("dummy".into()),
}); });
println!("avant : {}", app.display());
app.input("roll"); app.input("roll");
app.input("1 3"); app.input("1 3");
app.input("1 4"); app.input("1 4");

View file

@ -1,4 +1,5 @@
use bot::{Bot, BotStrategy}; use bot::{Bot, BotStrategy};
use log::{debug, error};
use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage}; use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage};
// Application Game // Application Game
@ -62,11 +63,21 @@ impl GameRunner {
return None; return None;
} }
let valid_event = if self.state.validate(event) { let valid_event = if self.state.validate(event) {
debug!(
"--------------- new valid event {event:?} (stage {:?}) -----------",
self.state.turn_stage
);
self.state.consume(event); self.state.consume(event);
debug!(
" --> stage {:?} ; active player points {:?}",
self.state.turn_stage,
self.state.who_plays().map(|p| p.points)
);
event event
} else { } else {
println!("{}", self.state); debug!("{}", self.state);
println!("event not valid : {:?}", event); error!("event not valid : {event:?}");
// panic!("crash and burn {} \nevt not valid {event:?}", self.state);
&GameEvent::PlayError &GameEvent::PlayError
}; };

View file

@ -35,7 +35,7 @@ fn main() -> Result<()> {
let args = match parse_args() { let args = match parse_args() {
Ok(v) => v, Ok(v) => v,
Err(e) => { Err(e) => {
eprintln!("Error: {}.", e); eprintln!("Error: {e}.");
std::process::exit(1); std::process::exit(1);
} }
}; };
@ -63,7 +63,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
// Help has a higher priority and should be handled separately. // Help has a higher priority and should be handled separately.
if pargs.contains(["-h", "--help"]) { if pargs.contains(["-h", "--help"]) {
print!("{}", HELP); print!("{HELP}");
std::process::exit(0); std::process::exit(0);
} }
@ -78,7 +78,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
// It's up to the caller what to do with the remaining arguments. // It's up to the caller what to do with the remaining arguments.
let remaining = pargs.finish(); let remaining = pargs.finish();
if !remaining.is_empty() { if !remaining.is_empty() {
eprintln!("Warning: unused arguments left: {:?}.", remaining); eprintln!("Warning: unused arguments left: {remaining:?}.");
} }
Ok(args) Ok(args)

View file

@ -1,14 +0,0 @@
[package]
name = "client_tui"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.89"
bincode = "1.3.3"
crossterm = "0.28.1"
ratatui = "0.28.1"
# renet = "0.0.13"
store = { path = "../store" }

View file

@ -1,53 +0,0 @@
/// Application state.
#[derive(Debug, Default)]
pub struct App {
    /// Should the application exit?
    pub should_quit: bool,
    /// Counter value, kept within the `u8` range.
    pub counter: u8,
}

impl App {
    /// Constructs a new instance of [`App`] with the counter at zero.
    pub fn new() -> Self {
        App {
            should_quit: false,
            counter: 0,
        }
    }

    /// Handles the tick event of the terminal (currently a no-op).
    pub fn tick(&self) {}

    /// Requests the application to quit.
    pub fn quit(&mut self) {
        self.should_quit = true;
    }

    /// Adds one to the counter, leaving it unchanged at `u8::MAX`.
    pub fn increment_counter(&mut self) {
        self.counter = self.counter.saturating_add(1);
    }

    /// Subtracts one from the counter, leaving it unchanged at zero.
    pub fn decrement_counter(&mut self) {
        self.counter = self.counter.saturating_sub(1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Incrementing a fresh counter yields 1.
    #[test]
    fn test_app_increment_counter() {
        let mut state = App::default();
        state.increment_counter();
        assert_eq!(1, state.counter);
    }

    /// Decrementing a fresh counter leaves it at 0.
    #[test]
    fn test_app_decrement_counter() {
        let mut state = App::default();
        state.decrement_counter();
        assert_eq!(0, state.counter);
    }
}

View file

@ -1,87 +0,0 @@
use std::{
sync::mpsc,
thread,
time::{Duration, Instant},
};
use anyhow::Result;
use crossterm::event::{self, Event as CrosstermEvent, KeyEvent, MouseEvent};
/// Terminal events forwarded from the event-handler thread to the main loop.
#[derive(Clone, Copy, Debug)]
pub enum Event {
/// Terminal tick, emitted once the tick interval has elapsed.
Tick,
/// Key press.
Key(KeyEvent),
/// Mouse click/scroll.
Mouse(MouseEvent),
/// Terminal resize, carrying the two `u16` values reported by crossterm.
Resize(u16, u16),
}
/// Terminal event handler: owns the background thread that polls crossterm
/// and the channel used to deliver [`Event`]s.
#[derive(Debug)]
pub struct EventHandler {
    /// Event sender channel.
    #[allow(dead_code)]
    sender: mpsc::Sender<Event>,
    /// Event receiver channel.
    receiver: mpsc::Receiver<Event>,
    /// Event handler thread.
    #[allow(dead_code)]
    handler: thread::JoinHandle<()>,
}

impl EventHandler {
    /// Constructs a new instance of [`EventHandler`] and spawns its polling thread.
    ///
    /// `tick_rate` is the tick interval in milliseconds.
    ///
    /// # Panics
    /// The spawned thread panics when polling or reading a terminal event
    /// fails, or when the receiving side of the channel has been dropped.
    pub fn new(tick_rate: u64) -> Self {
        let tick_rate = Duration::from_millis(tick_rate);
        let (sender, receiver) = mpsc::channel();
        let handler = {
            let sender = sender.clone();
            thread::spawn(move || {
                let mut last_tick = Instant::now();
                loop {
                    // Wait for what remains of the current tick interval.
                    let timeout = tick_rate
                        .checked_sub(last_tick.elapsed())
                        .unwrap_or(tick_rate);

                    if event::poll(timeout).expect("no events available") {
                        match event::read().expect("unable to read event") {
                            CrosstermEvent::Key(e) => {
                                if e.kind == event::KeyEventKind::Press {
                                    sender.send(Event::Key(e))
                                } else {
                                    Ok(()) // ignore KeyEventKind::Release on windows
                                }
                            }
                            CrosstermEvent::Mouse(e) => sender.send(Event::Mouse(e)),
                            CrosstermEvent::Resize(w, h) => sender.send(Event::Resize(w, h)),
                            // Other terminal events (focus changes, paste, ...) are not
                            // used by the application: ignore them instead of
                            // panicking the thread.
                            _ => Ok(()),
                        }
                        .expect("failed to send terminal event")
                    }

                    if last_tick.elapsed() >= tick_rate {
                        sender.send(Event::Tick).expect("failed to send tick event");
                        last_tick = Instant::now();
                    }
                }
            })
        };
        Self {
            sender,
            receiver,
            handler,
        }
    }

    /// Receives the next event from the handler thread.
    ///
    /// This function will always block the current thread if
    /// there is no data available and it's possible for more data to be sent.
    ///
    /// # Errors
    /// Returns an error when the sending side of the channel is disconnected.
    pub fn next(&self) -> Result<Event> {
        Ok(self.receiver.recv()?)
    }
}

View file

@ -1,50 +0,0 @@
// Application.
pub mod app;
// Terminal events handler.
pub mod event;
// Widget renderer.
pub mod ui;
// Terminal user interface.
pub mod tui;
// Application updater.
pub mod update;
use anyhow::Result;
use app::App;
use event::{Event, EventHandler};
use ratatui::{backend::CrosstermBackend, Terminal};
use tui::Tui;
use update::update;
/// Runs the terminal application: sets up the terminal, draws and processes
/// events until the application asks to quit, then restores the terminal.
fn main() -> Result<()> {
    // Create an application.
    let mut app = App::new();

    // Initialize the terminal user interface (rendered on stderr, 250 ms tick).
    let terminal = Terminal::new(CrosstermBackend::new(std::io::stderr()))?;
    let mut tui = Tui::new(terminal, EventHandler::new(250));
    tui.enter()?;

    // Start the main loop.
    loop {
        if app.should_quit {
            break;
        }
        // Render the user interface.
        tui.draw(&mut app)?;
        // Handle events: only key events change the state; tick, mouse and
        // resize events are ignored.
        if let Event::Key(key_event) = tui.events.next()? {
            update(&mut app, key_event);
        }
    }

    // Exit the user interface.
    tui.exit()?;
    Ok(())
}

View file

@ -1,77 +0,0 @@
use std::{io, panic};
use anyhow::Result;
use crossterm::{
event::{DisableMouseCapture, EnableMouseCapture},
terminal::{self, EnterAlternateScreen, LeaveAlternateScreen},
};
/// Ratatui terminal using the crossterm backend and writing to stderr.
pub type CrosstermTerminal = ratatui::Terminal<ratatui::backend::CrosstermBackend<std::io::Stderr>>;
use crate::{app::App, event::EventHandler, ui};
/// Representation of a terminal user interface.
///
/// It is responsible for setting up the terminal,
/// initializing the interface and handling the draw events.
pub struct Tui {
/// Interface to the Terminal.
terminal: CrosstermTerminal,
/// Terminal event handler.
pub events: EventHandler,
}
impl Tui {
/// Constructs a new instance of [`Tui`].
pub fn new(terminal: CrosstermTerminal, events: EventHandler) -> Self {
Self { terminal, events }
}
/// Initializes the terminal interface.
///
/// It enables the raw mode and sets terminal properties (alternate screen,
/// mouse capture, hidden cursor, cleared screen), and installs a panic hook
/// that resets the terminal before running the previous hook.
///
/// # Errors
/// Returns the first error reported by any of these terminal operations.
pub fn enter(&mut self) -> Result<()> {
terminal::enable_raw_mode()?;
crossterm::execute!(io::stderr(), EnterAlternateScreen, EnableMouseCapture)?;
// Define a custom panic hook to reset the terminal properties.
// This way, you won't have your terminal messed up if an unexpected error happens.
let panic_hook = panic::take_hook();
panic::set_hook(Box::new(move |panic| {
Self::reset().expect("failed to reset the terminal");
panic_hook(panic);
}));
self.terminal.hide_cursor()?;
self.terminal.clear()?;
Ok(())
}
/// [`Draw`] the terminal interface by [`rendering`] the widgets.
///
/// [`Draw`]: ratatui::Terminal::draw
/// [`rendering`]: crate::ui::render
///
/// # Errors
/// Returns the error reported by the terminal when drawing fails.
pub fn draw(&mut self, app: &mut App) -> Result<()> {
self.terminal.draw(|frame| ui::render(app, frame))?;
Ok(())
}
/// Resets the terminal interface.
///
/// This function is also used for the panic hook to revert
/// the terminal properties if unexpected errors occur.
fn reset() -> Result<()> {
terminal::disable_raw_mode()?;
crossterm::execute!(io::stderr(), LeaveAlternateScreen, DisableMouseCapture)?;
Ok(())
}
/// Exits the terminal interface.
///
/// It disables the raw mode and reverts back the terminal properties,
/// then shows the cursor again.
pub fn exit(&mut self) -> Result<()> {
Self::reset()?;
self.terminal.show_cursor()?;
Ok(())
}
}

View file

@ -1,30 +0,0 @@
use ratatui::{
prelude::{Alignment, Frame},
style::{Color, Style},
widgets::{Block, BorderType, Borders, Paragraph},
};
use crate::app::App;
/// Renders the whole UI: a single centered paragraph showing the key
/// bindings and the current counter value, inside a rounded, titled border,
/// drawn over the full frame area.
pub fn render(app: &mut App, f: &mut Frame) {
f.render_widget(
// The literal below spans several lines on purpose; its exact whitespace is
// part of the rendered text.
Paragraph::new(format!(
"
Press `Esc`, `Ctrl-C` or `q` to stop running.\n\
Press `j` and `k` to increment and decrement the counter respectively.\n\
Counter: {}
",
app.counter
))
.block(
Block::default()
.title("Counter App")
.title_alignment(Alignment::Center)
.borders(Borders::ALL)
.border_type(BorderType::Rounded),
)
.style(Style::default().fg(Color::Yellow))
.alignment(Alignment::Center),
f.area(),
)
}

View file

@ -1,17 +0,0 @@
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use crate::app::App;
/// Applies a key event to the application state.
///
/// - `Esc`, `q`, or `c`/`C` with the Control modifier held: quit.
/// - `Right` or `j`: increment the counter.
/// - `Left` or `k`: decrement the counter.
/// - Any other key is ignored.
pub fn update(app: &mut App, key_event: KeyEvent) {
    match key_event.code {
        KeyCode::Esc | KeyCode::Char('q') => app.quit(),
        KeyCode::Char('c') | KeyCode::Char('C') => {
            // `contains` rather than `==`: an upper-case `C` normally arrives with
            // an additional modifier (Shift), which an exact comparison would reject.
            if key_event.modifiers.contains(KeyModifiers::CONTROL) {
                app.quit()
            }
        }
        KeyCode::Right | KeyCode::Char('j') => app.increment_counter(),
        KeyCode::Left | KeyCode::Char('k') => app.decrement_counter(),
        _ => {}
    };
}

View file

@ -10,8 +10,8 @@ MEMORY_SIZE
- À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au - À quoi ça sert : L'agent interagit avec l'environnement (le jeu de TricTrac) et stocke ses expériences (un état, l'action prise, la récompense obtenue, et l'état suivant) dans cette mémoire. Pour s'entraîner, au
lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire. lieu d'utiliser uniquement la dernière expérience, il pioche un lot (batch) d'expériences aléatoires dans cette mémoire.
- Pourquoi c'est important : - Pourquoi c'est important :
1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace. 1. Décorrélation : Ça casse la corrélation entre les expériences successives, ce qui rend l'entraînement plus stable et efficace.
2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données. 2. Réutilisation : Une même expérience peut être utilisée plusieurs fois pour l'entraînement, ce qui améliore l'efficacité des données.
- Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions. - Dans votre code : const MEMORY_SIZE: usize = 4096; signifie que l'agent gardera en mémoire les 4096 dernières transitions.
DENSE_SIZE DENSE_SIZE
@ -54,3 +54,53 @@ epsilon (ε) est la probabilité de faire un choix aléatoire (explorer).
En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de En résumé, ces constantes définissent l'architecture du "cerveau" de votre bot (DENSE*SIZE), sa mémoire à court terme (MEMORY_SIZE), et comment il apprend à équilibrer entre suivre sa stratégie et en découvrir de
nouvelles (EPS*\*). nouvelles (EPS*\*).
## Paramètres DQNTrainingConfig
1. `gamma` (Facteur d'actualisation / _Discount Factor_)
- À quoi ça sert ? Ça détermine l'importance des récompenses futures. Une valeur proche de 1 (ex: 0.99)
indique à l'agent qu'une récompense obtenue dans le futur est presque aussi importante qu'une
récompense immédiate. Il sera donc "patient" et capable de faire des sacrifices à court terme pour un
gain plus grand plus tard.
- Intuition : Un gamma de 0 rendrait l'agent "myope", ne se souciant que du prochain coup. Un gamma de
0.99 l'encourage à élaborer des stratégies à long terme.
2. `tau` (Taux de mise à jour douce / _Soft Update Rate_)
- À quoi ça sert ? Pour stabiliser l'apprentissage, les algorithmes DQN utilisent souvent deux réseaux
: un réseau principal qui apprend vite et un "réseau cible" (copie du premier) qui évolue lentement.
tau contrôle la vitesse à laquelle les connaissances du réseau principal sont transférées vers le
réseau cible.
- Intuition : Une petite valeur (ex: 0.005) signifie que le réseau cible, qui sert de référence stable,
ne se met à jour que très progressivement. C'est comme un "mentor" qui n'adopte pas immédiatement
toutes les nouvelles idées de son "élève", ce qui évite de déstabiliser tout l'apprentissage sur un
coup de chance (ou de malchance).
3. `learning_rate` (Taux d'apprentissage)
- À quoi ça sert ? C'est peut-être le plus classique des hyperparamètres. Il définit la "taille du
pas" lors de la correction des erreurs. Après chaque prédiction, l'agent compare le résultat à ce
qui s'est passé et ajuste ses poids. Le learning_rate détermine l'ampleur de cet ajustement.
- Intuition : Trop élevé, et l'agent risque de sur-corriger et de ne jamais converger (comme chercher
le fond d'une vallée en faisant des pas de géant). Trop bas, et l'apprentissage sera extrêmement
lent.
4. `batch_size` (Taille du lot)
- À quoi ça sert ? L'agent apprend de ses expériences passées, qu'il stocke dans une "mémoire". Pour
chaque session d'entraînement, au lieu d'apprendre d'une seule expérience, il en pioche un lot
(batch) au hasard (ex: 32 expériences). Il calcule l'erreur moyenne sur ce lot pour mettre à jour
ses poids.
- Intuition : Apprendre sur un lot plutôt que sur une seule expérience rend l'apprentissage plus
stable et plus général. L'agent se base sur une "moyenne" de situations plutôt que sur un cas
particulier qui pourrait être une anomalie.
5. `clip_grad` (Plafonnement du gradient / _Gradient Clipping_)
- À quoi ça sert ? C'est une sécurité pour éviter le problème des "gradients qui explosent". Parfois,
une expérience très inattendue peut produire une erreur de prédiction énorme, ce qui entraîne une
correction (un "gradient") démesurément grande. Une telle correction peut anéantir tout ce que le
réseau a appris.
- Intuition : clip_grad impose une limite. Si la correction à apporter dépasse un certain seuil, elle
est ramenée à cette valeur maximale. C'est un garde-fou qui dit : "OK, on a fait une grosse erreur,
mais on va corriger calmement, sans tout casser".

View file

@ -1,46 +0,0 @@
# Description du projet et question
Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust.
Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions.
Les règles du jeu et l'état d'une partie sont implémentées dans 'store', l'application ligne de commande est implémentée dans 'client_cli', elle permet déjà de jouer contre un bot, ou de faire jouer deux bots l'un contre l'autre.
Les stratégies de bots sont implémentées dans le dossier 'bot'.
Plus précisément, l'état du jeu est défini par le struct GameState dans store/src/game.rs, la méthode to_string_id() permet de coder cet état de manière compacte dans une chaîne de caractères, mais il n'y a pas l'historique des coups joués. Il y a aussi fmt::Display d'implémenté pour une représentation textuelle plus lisible.
'client_cli/src/game_runner.rs' contient la logique permettant de faire jouer deux bots l'un contre l'autre.
'bot/src/strategy/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste.
Je cherche maintenant à ajouter des stratégies de bot plus fortes en entraînant un agent/bot par reinforcement learning.
Une première version avec DQN fonctionne (entraînement avec `cargo run --bin=train_dqn`)
Il gagne systématiquement contre le bot par défaut 'dummy' : `cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy`.
Une version, toujours DQN, mais en utilisant la bibliothèque burn (<https://burn.dev/>) est en cours de développement.
L'entraînement du modèle se passe dans la fonction "main" du fichier bot/src/burnrl/main.rs. On peut lancer l'exécution avec 'just trainbot'.
Voici la sortie de l'entraînement lancé avec 'just trainbot' :
```
> Entraînement
> {"episode": 0, "reward": -1692.3148, "duration": 1000}
> {"episode": 1, "reward": -361.6962, "duration": 1000}
> {"episode": 2, "reward": -126.1013, "duration": 1000}
> {"episode": 3, "reward": -36.8000, "duration": 1000}
> {"episode": 4, "reward": -21.4997, "duration": 1000}
> {"episode": 5, "reward": -8.3000, "duration": 1000}
> {"episode": 6, "reward": 3.1000, "duration": 1000}
> {"episode": 7, "reward": -21.5998, "duration": 1000}
> {"episode": 8, "reward": -10.1999, "duration": 1000}
> {"episode": 9, "reward": 3.1000, "duration": 1000}
> {"episode": 10, "reward": 14.5002, "duration": 1000}
> {"episode": 11, "reward": 10.7000, "duration": 1000}
> {"episode": 12, "reward": -0.7000, "duration": 1000}
thread 'main' has overflowed its stack
fatal runtime error: stack overflow
error: Recipe `trainbot` was terminated on line 25 by signal 6
```
Au bout du 12ème épisode (plus de 6 heures sur ma machine), l'entraînement s'arrête avec une erreur stack overflow. Peux-tu m'aider à diagnostiquer d'où peut provenir le problème ? Y a-t-il des outils qui permettent de détecter les zones de code qui utilisent le plus la stack ? Pour information j'ai vu ce rapport de bug <https://github.com/yunjhongwu/burn-rl-examples/issues/40> , donc peut-être que le problème vient du paquet 'burn-rl'.

View file

@ -1,46 +1,54 @@
# Inspirations # Inspirations
tools tools
- config clippy ?
- bacon : tests runner (ou loom ?) - config clippy ?
- bacon : tests runner (ou loom ?)
## Rust libs ## Rust libs
cf. https://blessed.rs/crates cf. <https://blessed.rs/crates>
nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-repeat-resume/ nombres aléatoires avec seed : <https://richard.dallaway.com/posts/2021-01-04-repeat-resume/>
- cli : https://lib.rs/crates/pico-args ( ou clap ) - cli : <https://lib.rs/crates/pico-args> ( ou clap )
- reseau async : tokio - reseau async : tokio
- web serveur : axum (uses tokio) - web serveur : axum (uses tokio)
- https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling - <https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling>
- db : sqlx - db : sqlx
- eyre, color-eyre (Results) - eyre, color-eyre (Results)
- tracing (logging) - tracing (logging)
- rayon ( sync <-> parallel ) - rayon ( sync <-> parallel )
- front : yew + tauri - front : yew + tauri
- egui - egui
- https://docs.rs/board-game/latest/board_game/ - <https://docs.rs/board-game/latest/board_game/>
## network games
- <https://www.mattkeeter.com/projects/pont/>
- <https://github.com/jackadamson/onitama> (wasm, rooms)
- <https://github.com/UkoeHB/renet2>
- <https://github.com/UkoeHB/bevy_simplenet>
## Others ## Others
- plugins avec https://github.com/extism/extism
- plugins avec <https://github.com/extism/extism>
## Backgammon existing projects ## Backgammon existing projects
* go : https://bgammon.org/blog/20240101-hello-world/ - go : <https://bgammon.org/blog/20240101-hello-world/>
- protocole de communication : https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md - protocole de communication : <https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md>
* ocaml : https://github.com/jacobhilton/backgammon?tab=readme-ov-file - ocaml : <https://github.com/jacobhilton/backgammon?tab=readme-ov-file>
cli example : https://www.jacobh.co.uk/backgammon/ cli example : <https://www.jacobh.co.uk/backgammon/>
* lib rust backgammon - lib rust backgammon
- https://github.com/carlostrub/backgammon - <https://github.com/carlostrub/backgammon>
- https://github.com/marktani/backgammon - <https://github.com/marktani/backgammon>
* network webtarot - network webtarot
* front ? - front ?
## cli examples ## cli examples
@ -48,7 +56,7 @@ nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-re
(No game) new game (No game) new game
gnubg rolls 3, anthon rolls 1. gnubg rolls 3, anthon rolls 1.
GNU Backgammon Positions ID: 4HPwATDgc/ABMA GNU Backgammon Positions ID: 4HPwATDgc/ABMA
Match ID : MIEFAAAAAAAA Match ID : MIEFAAAAAAAA
+12-11-10--9--8--7-------6--5--4--3--2--1-+ O: gnubg +12-11-10--9--8--7-------6--5--4--3--2--1-+ O: gnubg
@ -64,7 +72,7 @@ nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-re
| O X | | X O | | O X | | X O |
| O X | | X O | 0 points | O X | | X O | 0 points
+13-14-15-16-17-18------19-20-21-22-23-24-+ X: anthon +13-14-15-16-17-18------19-20-21-22-23-24-+ X: anthon
gnubg moves 8/5 6/5. gnubg moves 8/5 6/5.
### jacobh ### jacobh
@ -72,33 +80,37 @@ nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-re
Move 11: player O rolls a 6-2. Move 11: player O rolls a 6-2.
Player O estimates that they have a 90.6111% chance of winning. Player O estimates that they have a 90.6111% chance of winning.
Os borne off: none Os borne off: none
24 23 22 21 20 19 18 17 16 15 14 13 24 23 22 21 20 19 18 17 16 15 14 13
-------------------------------------------------------------------
| v v v v v v | | v v v v v v | ---
| | | |
| X O O O | | O O O | | v v v v v v | | v v v v v v |
| X O O O | | O O | | | | |
| O | | | | X O O O | | O O O |
| | X | | | X O O O | | O O |
| | | | | O | | |
| | | | | | X | |
| | | | | | | |
| | | | | | | |
|------------------------------| |------------------------------| | | | |
| | | | | | | |
| | | | |------------------------------| |------------------------------|
| | | | | | | |
| | | | | | | |
| X | | | | | | |
| X X | | X | | | | |
| X X X | | X O | | X | | |
| X X X | | X O O | | X X | | X |
| | | | | X X X | | X O |
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ | | X X X | | X O O |
------------------------------------------------------------------- | | | |
1 2 3 4 5 6 7 8 9 10 11 12 | ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
Xs borne off: none
---
1 2 3 4 5 6 7 8 9 10 11 12
Xs borne off: none
Move 12: player X rolls a 6-3. Move 12: player X rolls a 6-3.
Your move (? for help): bar/22 Your move (? for help): bar/22
@ -107,13 +119,12 @@ Your move (? for help): ?
Enter the start and end positions, separated by a forward slash (or any non-numeric character), of each counter you want to move. Enter the start and end positions, separated by a forward slash (or any non-numeric character), of each counter you want to move.
Each position should be number from 1 to 24, "bar" or "off". Each position should be number from 1 to 24, "bar" or "off".
Unlike in standard notation, you should enter each counter movement individually. For example: Unlike in standard notation, you should enter each counter movement individually. For example:
24/18 18/13 24/18 18/13
bar/3 13/10 13/10 8/5 bar/3 13/10 13/10 8/5
2/off 1/off 2/off 1/off
You can also enter these commands: You can also enter these commands:
p - show the previous move p - show the previous move
n - show the next move n - show the next move
<enter> - toggle between showing the current and last moves <enter> - toggle between showing the current and last moves
help - show this help text help - show this help text
quit - abandon game quit - abandon game

172
doc/specs/store.puml Normal file
View file

@ -0,0 +1,172 @@
@startuml
class "CheckerMove" {
- from: Field
- to: Field
+ to_display_string()
+ new(from: Field, to: Field)
+ mirror()
+ chain(cmove: Self)
+ get_from()
+ get_to()
+ is_exit()
+ doable_with_dice(dice: usize)
}
class "Board" {
- positions: [i8;24]
+ new()
+ mirror()
+ set_positions(positions: [ i8 ; 24 ])
+ count_checkers(color: Color, from: Field, to: Field)
+ to_vec()
+ to_gnupg_pos_id()
+ to_display_grid(col_size: usize)
+ set(color: & Color, field: Field, amount: i8)
+ blocked(color: & Color, field: Field)
+ passage_blocked(color: & Color, field: Field)
+ get_field_checkers(field: Field)
+ get_checkers_color(field: Field)
+ is_field_in_small_jan(field: Field)
+ get_color_fields(color: Color)
+ get_color_corner(color: & Color)
+ get_possible_moves(color: Color, dice: u8, with_excedants: bool, check_rest_corner_exit: bool, forbid_exits: bool)
+ passage_possible(color: & Color, cmove: & CheckerMove)
+ move_possible(color: & Color, cmove: & CheckerMove)
+ any_quarter_filled(color: Color)
+ is_quarter_filled(color: Color, field: Field)
+ get_quarter_filling_candidate(color: Color)
+ is_quarter_fillable(color: Color, field: Field)
- get_quarter_fields(field: Field)
+ move_checker(color: & Color, cmove: CheckerMove)
+ remove_checker(color: & Color, field: Field)
+ add_checker(color: & Color, field: Field)
}
class "MoveRules" {
+ board: Board
+ dice: Dice
+ new(color: & Color, board: & Board, dice: Dice)
+ set_board(color: & Color, board: & Board)
- get_board_from_color(color: & Color, board: & Board)
+ moves_follow_rules(moves: & ( CheckerMove , CheckerMove ))
- moves_possible(moves: & ( CheckerMove , CheckerMove ))
- moves_follows_dices(moves: & ( CheckerMove , CheckerMove ))
- get_move_compatible_dices(cmove: & CheckerMove)
+ moves_allowed(moves: & ( CheckerMove , CheckerMove ))
- check_opponent_can_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
- check_must_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
- check_corner_rules(moves: & ( CheckerMove , CheckerMove ))
- has_checkers_outside_last_quarter()
- check_exit_rules(moves: & ( CheckerMove , CheckerMove ))
+ get_possible_moves_sequences(with_excedents: bool, ignored_rules: Vec < TricTracRule >)
+ get_scoring_quarter_filling_moves_sequences()
- get_sequence_origin_from_destination(sequence: ( CheckerMove , CheckerMove ), destination: Field)
+ get_quarter_filling_moves_sequences()
- get_possible_moves_sequences_by_dices(dice1: u8, dice2: u8, with_excedents: bool, ignore_empty: bool, ignored_rules: Vec < TricTracRule >)
- _get_direct_exit_moves(state: & GameState)
- is_move_by_puissance(moves: & ( CheckerMove , CheckerMove ))
- can_take_corner_by_effect()
}
class "DiceRoller" {
- rng: StdRng
+ new(opt_seed: Option < u64 >)
+ roll()
}
class "Dice" {
+ values: (u8,u8)
+ to_bits_string()
+ to_display_string()
+ is_double()
}
class "GameState" {
+ stage: Stage
+ turn_stage: TurnStage
+ board: Board
+ active_player_id: PlayerId
+ players: HashMap<PlayerId,Player>
+ history: Vec<GameEvent>
+ dice: Dice
+ dice_points: (u8,u8)
+ dice_moves: (CheckerMove,CheckerMove)
+ dice_jans: PossibleJans
- roll_first: bool
+ schools_enabled: bool
+ new(schools_enabled: bool)
- set_schools_enabled(schools_enabled: bool)
- get_active_player()
- get_opponent_id()
+ to_vec_float()
+ to_vec()
+ to_string_id()
+ who_plays()
+ get_white_player()
+ get_black_player()
+ player_id_by_color(color: Color)
+ player_id(player: & Player)
+ player_color_by_id(player_id: & PlayerId)
+ validate(event: & GameEvent)
+ init_player(player_name: & str)
- add_player(player_id: PlayerId, player: Player)
+ switch_active_player()
+ consume(valid_event: & GameEvent)
- new_pick_up()
- get_rollresult_jans(dice: & Dice)
+ determine_winner()
- inc_roll_count(player_id: PlayerId)
- mark_points(player_id: PlayerId, points: u8)
}
class "Player" {
+ name: String
+ color: Color
+ points: u8
+ holes: u8
+ can_bredouille: bool
+ can_big_bredouille: bool
+ dice_roll_count: u8
+ new(name: String, color: Color)
+ to_bits_string()
+ to_vec()
}
class "PointsRules" {
+ board: Board
+ dice: Dice
+ move_rules: MoveRules
+ new(color: & Color, board: & Board, dice: Dice)
+ set_dice(dice: Dice)
+ update_positions(positions: [ i8 ; 24 ])
- get_jans(board_ini: & Board, dice_rolls_count: u8)
+ get_jans_points(jans: HashMap < Jan , Vec < ( CheckerMove , CheckerMove ) > >)
+ get_points(dice_rolls_count: u8)
+ get_result_jans(dice_rolls_count: u8)
}
"MoveRules" <-- "Board"
"MoveRules" <-- "Dice"
"GameState" <-- "Board"
"HashMap<PlayerId,Player>" <-- "Player"
"GameState" <-- "HashMap<PlayerId,Player>"
"GameState" <-- "Dice"
"PointsRules" <-- "Board"
"PointsRules" <-- "Dice"
"PointsRules" <-- "MoveRules"
@enduml

View file

@ -9,8 +9,9 @@ shell:
runcli: runcli:
RUST_LOG=info cargo run --bin=client_cli RUST_LOG=info cargo run --bin=client_cli
runclibots: runclibots:
#RUST_LOG=info cargo run --bin=client_cli -- --bot dqn,dummy cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk
RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn #cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
match: match:
cargo build --release --bin=client_cli cargo build --release --bin=client_cli
LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn LD_LIBRARY_PATH=./target/release ./target/release/client_cli -- --bot dummy,dqn
@ -21,15 +22,13 @@ profile:
pythonlib: pythonlib:
maturin build -m store/Cargo.toml --release maturin build -m store/Cargo.toml --release
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainbot: trainbot algo:
#python ./store/python/trainModel.py #python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok # cargo run --bin=train_dqn # ok
# cargo run --bin=train_dqn_burn # utilise debug (why ?) # ./bot/scripts/trainValid.sh
cargo build --release --bin=train_dqn_burn ./bot/scripts/train.sh {{algo}}
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_burn | tee /tmp/train.out plottrainbot algo:
plottrainbot: ./bot/scripts/train.sh plot {{algo}}
cat /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
#tail -f /tmp/train.out | awk -F '[ ,]' '{print $5}' | feedgnuplot --lines --points --unset grid
debugtrainbot: debugtrainbot:
cargo build --bin=train_dqn_burn cargo build --bin=train_dqn_burn
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn

View file

@ -1,14 +0,0 @@
[package]
name = "trictrac-server"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
store = { path = "../store" }
env_logger = "0.10.0"
log = "0.4.20"
pico-args = "0.5.0"
renet = "0.0.13"
bincode = "1.3.3"

View file

@ -1,147 +0,0 @@
use log::{info, trace, warn};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
use std::thread;
use std::time::{Duration, Instant, SystemTime};
use renet::{
transport::{
NetcodeServerTransport, ServerAuthentication, ServerConfig, NETCODE_USER_DATA_BYTES,
},
ConnectionConfig, RenetServer, ServerEvent,
};
// Only clients that can provide the same PROTOCOL_ID that the server is using will be able to connect.
// This can be used to make sure players use the most recent version of the client for instance.
pub const PROTOCOL_ID: u64 = 2878;
/// Utility function for extracting a players name from renet user data
fn name_from_user_data(user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> String {
let mut buffer = [0u8; 8];
buffer.copy_from_slice(&user_data[0..8]);
let mut len = u64::from_le_bytes(buffer) as usize;
len = len.min(NETCODE_USER_DATA_BYTES - 8);
let data = user_data[8..len + 8].to_vec();
String::from_utf8(data).unwrap()
}
// Entry point of the TricTrac server: sets up a renet server over UDP, then
// runs an endless loop that relays game events between at most two clients
// while keeping an authoritative `store::GameState`.
fn main() {
env_logger::init();
let mut server = RenetServer::new(ConnectionConfig::default());
// Setup transport layer: UDP socket bound to 127.0.0.1:5000, two clients
// at most, no authentication (`Unsecure`).
const SERVER_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5000);
let socket: UdpSocket = UdpSocket::bind(SERVER_ADDR).unwrap();
let server_config = ServerConfig {
max_clients: 2,
protocol_id: PROTOCOL_ID,
public_addr: SERVER_ADDR,
authentication: ServerAuthentication::Unsecure,
};
// The transport is initialised with the current time since the Unix epoch.
let current_time = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();
trace!("❂ TricTrac server listening on {}", SERVER_ADDR);
// Server-side copy of the game state; every accepted event is applied to it.
let mut game_state = store::GameState::default();
let mut last_updated = Instant::now();
loop {
// Update server time
let now = Instant::now();
let delta_time = now - last_updated;
server.update(delta_time);
transport.update(delta_time, &mut server).unwrap();
last_updated = now;
// Receive connection events from clients
while let Some(event) = server.get_event() {
match event {
ServerEvent::ClientConnected { client_id } => {
let user_data = transport.user_data(client_id).unwrap();
// Tell the recently joined player about the other player
for (player_id, player) in game_state.players.iter() {
let event = store::GameEvent::PlayerJoined {
player_id: *player_id,
name: player.name.clone(),
};
server.send_message(client_id, 0, bincode::serialize(&event).unwrap());
}
// Add the new player to the game
let event = store::GameEvent::PlayerJoined {
player_id: client_id,
name: name_from_user_data(&user_data),
};
game_state.consume(&event);
// Tell all players that a new player has joined
server.broadcast_message(0, bincode::serialize(&event).unwrap());
info!("🎉 Client {} connected.", client_id);
// The game can begin once two players have joined; the client that
// connected last is the one announced as going first.
if game_state.players.len() == 2 {
let event = store::GameEvent::BeginGame {
goes_first: client_id,
};
game_state.consume(&event);
server.broadcast_message(0, bincode::serialize(&event).unwrap());
trace!("The game gas begun");
}
}
ServerEvent::ClientDisconnected {
client_id,
reason: _,
} => {
// First consume a disconnect event
let event = store::GameEvent::PlayerDisconnected {
player_id: client_id,
};
game_state.consume(&event);
server.broadcast_message(0, bincode::serialize(&event).unwrap());
info!("Client {} disconnected", client_id);
// Then end the game, since it can't go on with a single player
let event = store::GameEvent::EndGame {
reason: store::EndGameReason::PlayerLeft {
player_id: client_id,
},
};
game_state.consume(&event);
server.broadcast_message(0, bincode::serialize(&event).unwrap());
// NOTE: Since we don't authenticate users we can't do any reconnection attempts.
// We simply have no way to know if the next user is the same as the one that disconnected.
}
}
}
// Receive GameEvents from clients. Broadcast valid events.
// Messages that fail to deserialize are silently dropped.
for client_id in server.clients_id().into_iter() {
while let Some(message) = server.receive_message(client_id, 0) {
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
if game_state.validate(&event) {
game_state.consume(&event);
trace!("Player {} sent:\n\t{:#?}", client_id, event);
server.broadcast_message(0, bincode::serialize(&event).unwrap());
// Determine if a player has won the game
// NOTE(review): unlike the disconnect path above, this EndGame event is
// broadcast but not consumed by `game_state` — confirm this is intended.
if let Some(winner) = game_state.determine_winner() {
let event = store::GameEvent::EndGame {
reason: store::EndGameReason::PlayerWon { winner },
};
server.broadcast_message(0, bincode::serialize(&event).unwrap());
}
} else {
warn!("Player {} sent invalid event:\n\t{:#?}", client_id, event);
}
}
}
}
// Flush queued packets, then pause 50 ms before the next iteration.
transport.send_packets(&mut server);
thread::sleep(Duration::from_millis(50));
}
}

View file

@ -8,7 +8,7 @@ use std::fmt;
pub type Field = usize; pub type Field = usize;
pub type FieldWithCount = (Field, i8); pub type FieldWithCount = (Field, i8);
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Deserialize)] #[derive(Debug, Copy, Clone, Serialize, PartialEq, Eq, Deserialize)]
pub struct CheckerMove { pub struct CheckerMove {
from: Field, from: Field,
to: Field, to: Field,
@ -37,7 +37,7 @@ impl Default for CheckerMove {
impl CheckerMove { impl CheckerMove {
pub fn to_display_string(self) -> String { pub fn to_display_string(self) -> String {
format!("{:?} ", self) format!("{self:?} ")
} }
pub fn new(from: Field, to: Field) -> Result<Self, Error> { pub fn new(from: Field, to: Field) -> Result<Self, Error> {
@ -94,7 +94,7 @@ impl CheckerMove {
} }
/// Represents the Tric Trac board /// Represents the Tric Trac board
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Board { pub struct Board {
positions: [i8; 24], positions: [i8; 24],
} }
@ -114,7 +114,7 @@ impl fmt::Display for Board {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = String::new(); let mut s = String::new();
s.push_str(&format!("{:?}", self.positions)); s.push_str(&format!("{:?}", self.positions));
write!(f, "{}", s) write!(f, "{s}")
} }
} }
@ -132,8 +132,13 @@ impl Board {
} }
/// Globally set pieces on board ( for tests ) /// Globally set pieces on board ( for tests )
pub fn set_positions(&mut self, positions: [i8; 24]) { pub fn set_positions(&mut self, color: &Color, positions: [i8; 24]) {
self.positions = positions; let mut new_positions = positions;
if color == &Color::Black {
new_positions = new_positions.map(|c| 0 - c);
new_positions.reverse();
}
self.positions = new_positions;
} }
pub fn count_checkers(&self, color: Color, from: Field, to: Field) -> u8 { pub fn count_checkers(&self, color: Color, from: Field, to: Field) -> u8 {
@ -153,6 +158,42 @@ impl Board {
.unsigned_abs() .unsigned_abs()
} }
/// Returns the rank of the last white checker standing on `field`.
///
/// White checkers (positive counts) are accumulated from field 1 upward;
/// the running total reached on `field` is returned. The result is 0 when
/// `field` holds no white checker or is never reached.
///
/// # Panics
///
/// Panics if `color` is not `Color::White`.
pub fn get_field_checker(&self, color: &Color, field: Field) -> u8 {
    // Only white is supported; black would need to be handled through mirror.
    assert_eq!(color, &Color::White);
    let mut running: u8 = 0;
    for (idx, &count) in self.positions.iter().enumerate() {
        // Skip empty fields and fields occupied by black (count <= 0).
        if count <= 0 {
            continue;
        }
        running += count as u8;
        if idx + 1 == field {
            return running;
        }
    }
    0
}
/// Returns the field on which the `checker_pos`-th white checker stands.
///
/// White checkers (positive counts) are numbered from field 1 upward.
/// Returns `None` when `checker_pos` is 0 or exceeds the number of white
/// checkers on the board.
///
/// # Panics
///
/// Panics if `color` is not `Color::White`.
pub fn get_checker_field(&self, color: &Color, checker_pos: u8) -> Option<Field> {
    // Only white is supported; black would need to be handled through mirror.
    assert_eq!(color, &Color::White);
    if checker_pos == 0 {
        return None;
    }
    let mut seen: u8 = 0;
    self.positions
        .iter()
        .position(|&count| {
            // Accumulate white checkers only, then test whether the wanted
            // checker lies on the current field.
            if count > 0 {
                seen += count as u8;
            }
            checker_pos <= seen
        })
        .map(|idx| idx + 1)
}
pub fn to_vec(&self) -> Vec<i8> { pub fn to_vec(&self) -> Vec<i8> {
self.positions.to_vec() self.positions.to_vec()
} }
@ -230,7 +271,7 @@ impl Board {
.map(|cells| { .map(|cells| {
cells cells
.into_iter() .into_iter()
.map(|cell| format!("{:>5}", cell)) .map(|cell| format!("{cell:>5}"))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("") .join("")
}) })
@ -241,7 +282,7 @@ impl Board {
.map(|cells| { .map(|cells| {
cells cells
.into_iter() .into_iter()
.map(|cell| format!("{:>5}", cell)) .map(|cell| format!("{cell:>5}"))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("") .join("")
}) })
@ -564,7 +605,7 @@ impl Board {
} }
let checker_color = self.get_checkers_color(field)?; let checker_color = self.get_checkers_color(field)?;
if Some(color) != checker_color { if Some(color) != checker_color {
println!("field invalid : {:?}, {:?}, {:?}", color, field, self); println!("field invalid : {color:?}, {field:?}, {self:?}");
return Err(Error::FieldInvalid); return Err(Error::FieldInvalid);
} }
let unit = match color { let unit = match color {
@ -598,6 +639,55 @@ impl Board {
self.positions[field - 1] += unit; self.positions[field - 1] += unit;
Ok(()) Ok(())
} }
/// Builds a [`Board`] from a string of `'0'` / `'1'` characters.
///
/// Each point is encoded as a run of `'1'` (one per checker) followed by a
/// `'0'` separator. White checkers come first, for points 1 to 24, then
/// black checkers for points 24 down to 1 (stored as negative counts).
/// Presumably the inverse of `to_gnupg_pos_id` — confirm against it.
///
/// Parsing is lenient: when the input ends early, the remaining points keep
/// their current value.
///
/// # Errors
///
/// Returns an error when a point would hold checkers of both colors, or when
/// a single run of `'1'` is longer than an `i8` can count (previously this
/// overflowed the counter).
pub fn from_gnupg_pos_id(bits: &str) -> Result<Board, String> {
    /// Reads one run of '1' starting at `*idx`, then consumes the '0'
    /// separator if there is one.
    fn read_count(chars: &[char], idx: &mut usize) -> Result<i8, String> {
        let mut count: i8 = 0;
        while chars.get(*idx) == Some(&'1') {
            count = count.checked_add(1).ok_or_else(|| {
                "Invalid board: too many checkers on a single point".to_string()
            })?;
            *idx += 1;
        }
        if chars.get(*idx) == Some(&'0') {
            *idx += 1; // Consume the '0' separator
        }
        Ok(count)
    }

    let mut positions = [0i8; 24];
    let mut bit_idx = 0;
    let bit_chars: Vec<char> = bits.chars().collect();

    // White checkers (points 1 to 24)
    for position in positions.iter_mut() {
        if bit_idx >= bit_chars.len() {
            break;
        }
        *position = read_count(&bit_chars, &mut bit_idx)?;
    }

    // Black checkers (points 24 down to 1)
    for (i, position) in positions.iter_mut().enumerate().rev() {
        if bit_idx >= bit_chars.len() {
            break;
        }
        let count = read_count(&bit_chars, &mut bit_idx)?;
        if *position == 0 {
            *position = -count;
        } else if count > 0 {
            return Err(format!(
                "Invalid board: checkers of both colors on point {}",
                i + 1
            ));
        }
    }

    Ok(Board { positions })
}
} }
// Unit Tests // Unit Tests
@ -672,9 +762,12 @@ mod tests {
#[test] #[test]
fn is_quarter_fillable() { fn is_quarter_fillable() {
let mut board = Board::new(); let mut board = Board::new();
board.set_positions([ board.set_positions(
15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, &Color::White,
]); [
15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15,
],
);
assert!(board.is_quarter_fillable(Color::Black, 1)); assert!(board.is_quarter_fillable(Color::Black, 1));
assert!(!board.is_quarter_fillable(Color::Black, 12)); assert!(!board.is_quarter_fillable(Color::Black, 12));
assert!(board.is_quarter_fillable(Color::Black, 13)); assert!(board.is_quarter_fillable(Color::Black, 13));
@ -683,25 +776,62 @@ mod tests {
assert!(board.is_quarter_fillable(Color::White, 12)); assert!(board.is_quarter_fillable(Color::White, 12));
assert!(!board.is_quarter_fillable(Color::White, 13)); assert!(!board.is_quarter_fillable(Color::White, 13));
assert!(board.is_quarter_fillable(Color::White, 24)); assert!(board.is_quarter_fillable(Color::White, 24));
board.set_positions([ board.set_positions(
5, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -8, 0, 0, 0, 0, 0, -5, &Color::White,
]); [
5, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -8, 0, 0, 0, 0, 0, -5,
],
);
assert!(board.is_quarter_fillable(Color::Black, 13)); assert!(board.is_quarter_fillable(Color::Black, 13));
assert!(!board.is_quarter_fillable(Color::Black, 24)); assert!(!board.is_quarter_fillable(Color::Black, 24));
assert!(!board.is_quarter_fillable(Color::White, 1)); assert!(!board.is_quarter_fillable(Color::White, 1));
assert!(board.is_quarter_fillable(Color::White, 12)); assert!(board.is_quarter_fillable(Color::White, 12));
board.set_positions([ board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0,
],
);
assert!(board.is_quarter_fillable(Color::Black, 16)); assert!(board.is_quarter_fillable(Color::Black, 16));
} }
#[test] #[test]
fn get_quarter_filling_candidate() { fn get_quarter_filling_candidate() {
let mut board = Board::new(); let mut board = Board::new();
board.set_positions([ board.set_positions(
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White)); assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White));
} }
/// With 13 white checkers laid out on the first six fields, checker numbers 0
/// and 14 resolve to no field, while checkers 5 and 6 both resolve to field 3.
#[test]
fn get_checker_field() {
    let white_layout = [
        3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ];
    let mut board = Board::new();
    board.set_positions(&Color::White, white_layout);

    // (checker number, expected field)
    let expectations = [(0, None), (5, Some(3)), (6, Some(3)), (14, None)];
    for (checker, expected_field) in expectations {
        assert_eq!(
            expected_field,
            board.get_checker_field(&Color::White, checker)
        );
    }
}
/// On the same six-field white layout, field 2 maps to checker 4 and field 3
/// maps to checker 6.
#[test]
fn get_field_checker() {
    let white_layout = [
        3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ];
    let mut board = Board::new();
    board.set_positions(&Color::White, white_layout);

    // (field, expected checker number)
    for (field, expected_checker) in [(2, 4), (3, 6)] {
        assert_eq!(
            expected_checker,
            board.get_field_checker(&Color::White, field)
        );
    }
}
} }

View file

@ -44,7 +44,7 @@ impl DiceRoller {
/// Represents the two dice /// Represents the two dice
/// ///
/// Trictrac is always played with two dice. /// Trictrac is always played with two dice.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize, Default)] #[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize, Default)]
pub struct Dice { pub struct Dice {
/// The two dice values /// The two dice values
pub values: (u8, u8), pub values: (u8, u8),
@ -55,6 +55,17 @@ impl Dice {
format!("{:0>3b}{:0>3b}", self.values.0, self.values.1) format!("{:0>3b}{:0>3b}", self.values.0, self.values.1)
} }
/// Parses a pair of dice from a 6-character binary string.
///
/// The first three characters encode the first die and the last three the
/// second one, matching the `{:0>3b}{:0>3b}` layout used when the dice are
/// written out as bits.
///
/// # Errors
///
/// Returns an error message when `bits` is not exactly six characters long
/// or contains anything other than `0` and `1`.
pub fn from_bits_string(bits: &str) -> Result<Self, String> {
    if bits.len() != 6 {
        return Err("Invalid bit string length for dice".to_string());
    }
    // `len()` counts bytes: a six-byte string can still hold multi-byte
    // characters (slicing it at 3 would panic) or a sign that
    // `from_str_radix` tolerates. Only plain binary digits are accepted.
    if !bits.bytes().all(|b| matches!(b, b'0' | b'1')) {
        return Err(format!("Invalid bit string for dice : {bits}"));
    }
    let (d1_str, d2_str) = bits.split_at(3);
    let d1 = u8::from_str_radix(d1_str, 2).map_err(|e| e.to_string())?;
    let d2 = u8::from_str_radix(d2_str, 2).map_err(|e| e.to_string())?;
    Ok(Dice { values: (d1, d2) })
}
pub fn to_display_string(self) -> String { pub fn to_display_string(self) -> String {
format!("{} & {}", self.values.0, self.values.1) format!("{} & {}", self.values.0, self.values.1)
} }

View file

@ -4,17 +4,18 @@ use crate::dice::Dice;
use crate::game_rules_moves::MoveRules; use crate::game_rules_moves::MoveRules;
use crate::game_rules_points::{PointsRules, PossibleJans}; use crate::game_rules_points::{PointsRules, PossibleJans};
use crate::player::{Color, Player, PlayerId}; use crate::player::{Color, Player, PlayerId};
use log::{error, info}; use log::{debug, error};
// use itertools::Itertools; // use itertools::Itertools;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::{fmt, str}; use std::{fmt, str};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
/// The different stages a game can be in. (not to be confused with the entire "GameState") /// The different stages a game can be in. (not to be confused with the entire "GameState")
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Stage { pub enum Stage {
PreGame, PreGame,
InGame, InGame,
@ -22,7 +23,7 @@ pub enum Stage {
} }
/// The different stages a game turn can be in. /// The different stages a game turn can be in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TurnStage { pub enum TurnStage {
RollDice, RollDice,
RollWaiting, RollWaiting,
@ -60,7 +61,7 @@ impl From<TurnStage> for u8 {
} }
/// Represents a TricTrac game /// Represents a TricTrac game
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GameState { pub struct GameState {
pub stage: Stage, pub stage: Stage,
pub turn_stage: TurnStage, pub turn_stage: TurnStage,
@ -91,7 +92,8 @@ impl fmt::Display for GameState {
s.push_str(&format!("Dice: {:?}\n", self.dice)); s.push_str(&format!("Dice: {:?}\n", self.dice));
// s.push_str(&format!("Who plays: {}\n", self.who_plays().map(|player| &player.name ).unwrap_or(""))); // s.push_str(&format!("Who plays: {}\n", self.who_plays().map(|player| &player.name ).unwrap_or("")));
s.push_str(&format!("Board: {:?}\n", self.board)); s.push_str(&format!("Board: {:?}\n", self.board));
write!(f, "{}", s) // s.push_str(&format!("History: {:?}\n", self.history));
write!(f, "{s}")
} }
} }
@ -113,6 +115,11 @@ impl Default for GameState {
} }
} }
} }
/// Hashes a game state through the identifier returned by `to_string_id`.
impl Hash for GameState {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let string_id = self.to_string_id();
        string_id.hash(state);
    }
}
impl GameState { impl GameState {
/// Create a new default game /// Create a new default game
@ -122,6 +129,15 @@ impl GameState {
gs gs
} }
/// Creates a default game, registers two players and starts the game.
///
/// If the first player cannot be registered, the untouched default game is
/// returned: the second player is not added and no `BeginGame` event is
/// consumed. Otherwise the first player goes first.
pub fn new_with_players(p1_name: &str, p2_name: &str) -> Self {
    let mut game = Self::default();
    let Some(first_player_id) = game.init_player(p1_name) else {
        return game;
    };
    game.init_player(p2_name);
    let begin = GameEvent::BeginGame {
        goes_first: first_player_id,
    };
    game.consume(&begin);
    game
}
fn set_schools_enabled(&mut self, schools_enabled: bool) { fn set_schools_enabled(&mut self, schools_enabled: bool) {
self.schools_enabled = schools_enabled; self.schools_enabled = schools_enabled;
} }
@ -150,6 +166,7 @@ impl GameState {
/// Get state as a vector (to be used for bot training input) : /// Get state as a vector (to be used for bot training input) :
/// length = 36 /// length = 36
/// i8 for board positions with negative values for blacks
pub fn to_vec(&self) -> Vec<i8> { pub fn to_vec(&self) -> Vec<i8> {
let state_len = 36; let state_len = 36;
let mut state = Vec::with_capacity(state_len); let mut state = Vec::with_capacity(state_len);
@ -242,7 +259,7 @@ impl GameState {
pos_bits.push_str(&white_bits); pos_bits.push_str(&white_bits);
pos_bits.push_str(&black_bits); pos_bits.push_str(&black_bits);
pos_bits = format!("{:0>108}", pos_bits); pos_bits = format!("{pos_bits:0<108}");
// println!("{}", pos_bits); // println!("{}", pos_bits);
let pos_u8 = pos_bits let pos_u8 = pos_bits
.as_bytes() .as_bytes()
@ -253,6 +270,81 @@ impl GameState {
general_purpose::STANDARD.encode(pos_u8) general_purpose::STANDARD.encode(pos_u8)
} }
/// Rebuilds a game state from an identifier produced by `to_string_id`.
///
/// The identifier is the base64 encoding of bytes that each carry six bits
/// of a 108-bit string, which is read as follows:
/// - bits `0..77`: board position,
/// - bit `77`: active player colour (`1` is black, `0` is white),
/// - bits `78..81`: turn stage,
/// - bits `81..87`: dice,
/// - bits `87..97` and `97..107`: white and black player data.
///
/// Everything that is not part of the identifier (game stage, history, dice
/// points, dice moves, jans, first-roll flag, schools) is set to a fixed
/// default. The white player is registered under id 1 and the black player
/// under id 2.
///
/// # Errors
///
/// Returns an error message when the base64 is invalid, when a decoded byte
/// does not fit in six bits, when fewer than 108 bits are available, or when
/// one of the fields cannot be parsed.
pub fn from_string_id(id: &str) -> Result<Self, String> {
    let bytes = general_purpose::STANDARD
        .decode(id)
        .map_err(|e| e.to_string())?;

    let mut bits_str = String::with_capacity(bytes.len() * 6);
    for byte in &bytes {
        // `{:06b}` is a minimum width: a value above 63 would render as 7 or
        // 8 digits and silently shift every field that follows.
        if *byte > 0b11_1111 {
            return Err(format!("Invalid 6-bit value in decoded string : {byte}"));
        }
        bits_str.push_str(&format!("{byte:06b}"));
    }

    // The original string was padded to 108 bits.
    let bits = bits_str
        .get(..108)
        .ok_or_else(|| "Invalid decoded string length".to_string())?;

    let board = Board::from_gnupg_pos_id(&bits[0..77])?;

    // `bits` only holds ASCII '0' / '1', so byte indexing is char indexing.
    let active_player_color = if bits.as_bytes()[77] == b'1' {
        Color::Black
    } else {
        Color::White
    };

    let turn_stage_bits = &bits[78..81];
    let turn_stage = match turn_stage_bits {
        "000" => TurnStage::RollWaiting,
        "001" => TurnStage::RollDice,
        "010" => TurnStage::MarkPoints,
        "011" => TurnStage::HoldOrGoChoice,
        "100" => TurnStage::Move,
        "101" => TurnStage::MarkAdvPoints,
        _ => return Err(format!("Invalid bits for turn stage : {turn_stage_bits}")),
    };

    let dice = Dice::from_bits_string(&bits[81..87])?;

    let white_player =
        Player::from_bits_string(&bits[87..97], "Player 1".to_string(), Color::White)
            .map_err(|e| e.to_string())?;
    let black_player =
        Player::from_bits_string(&bits[97..107], "Player 2".to_string(), Color::Black)
            .map_err(|e| e.to_string())?;

    let mut players = HashMap::new();
    players.insert(1, white_player);
    players.insert(2, black_player);

    let active_player_id = if active_player_color == Color::White {
        1
    } else {
        2
    };

    // Some fields are not in the ID, so we use defaults.
    Ok(GameState {
        stage: Stage::InGame, // Assume InGame from ID
        turn_stage,
        board,
        active_player_id,
        players,
        history: Vec::new(),
        dice,
        dice_points: (0, 0),
        dice_moves: (CheckerMove::default(), CheckerMove::default()),
        dice_jans: PossibleJans::default(),
        roll_first: false,      // Assume not first roll
        schools_enabled: false, // Assume disabled
    })
}
pub fn who_plays(&self) -> Option<&Player> { pub fn who_plays(&self) -> Option<&Player> {
self.get_active_player() self.get_active_player()
} }
@ -336,7 +428,7 @@ impl GameState {
return false; return false;
} }
} }
Roll { player_id } | RollResult { player_id, dice: _ } => { Roll { player_id } => {
// Check player exists // Check player exists
if !self.players.contains_key(player_id) { if !self.players.contains_key(player_id) {
return false; return false;
@ -345,6 +437,26 @@ impl GameState {
if self.active_player_id != *player_id { if self.active_player_id != *player_id {
return false; return false;
} }
// Check the turn stage
if self.turn_stage != TurnStage::RollDice {
error!("bad stage {:?}", self.turn_stage);
return false;
}
}
RollResult { player_id, dice: _ } => {
// Check player exists
if !self.players.contains_key(player_id) {
return false;
}
// Check player is currently the one making their move
if self.active_player_id != *player_id {
return false;
}
// Check the turn stage
if self.turn_stage != TurnStage::RollWaiting {
error!("bad stage {:?}", self.turn_stage);
return false;
}
} }
Mark { Mark {
player_id, player_id,
@ -372,22 +484,30 @@ impl GameState {
} }
Go { player_id } => { Go { player_id } => {
if !self.players.contains_key(player_id) { if !self.players.contains_key(player_id) {
error!("Player {} unknown", player_id); error!("Player {player_id} unknown");
return false; return false;
} }
// Check player is currently the one making their move // Check player is currently the one making their move
if self.active_player_id != *player_id { if self.active_player_id != *player_id {
error!("Player not active : {}", self.active_player_id);
return false; return false;
} }
// Check the player can leave (ie the game is in the KeepOrLeaveChoice stage) // Check the player can leave (ie the game is in the KeepOrLeaveChoice stage)
if self.turn_stage != TurnStage::HoldOrGoChoice { if self.turn_stage != TurnStage::HoldOrGoChoice {
error!("bad stage {:?}", self.turn_stage);
error!(
"black player points : {:?}",
self.get_black_player()
.map(|player| (player.points, player.holes))
);
// error!("history {:?}", self.history);
return false; return false;
} }
} }
Move { player_id, moves } => { Move { player_id, moves } => {
// Check player exists // Check player exists
if !self.players.contains_key(player_id) { if !self.players.contains_key(player_id) {
error!("Player {} unknown", player_id); error!("Player {player_id} unknown");
return false; return false;
} }
// Check player is currently the one making their move // Check player is currently the one making their move
@ -512,12 +632,15 @@ impl GameState {
self.inc_roll_count(self.active_player_id); self.inc_roll_count(self.active_player_id);
self.turn_stage = TurnStage::MarkPoints; self.turn_stage = TurnStage::MarkPoints;
(self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice); (self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice);
debug!("points from result : {:?}", self.dice_points);
if !self.schools_enabled { if !self.schools_enabled {
// Schools are not enabled. We mark points automatically // Schools are not enabled. We mark points automatically
// the points earned by the opponent will be marked on its turn // the points earned by the opponent will be marked on its turn
let new_hole = self.mark_points(self.active_player_id, self.dice_points.0); let new_hole = self.mark_points(self.active_player_id, self.dice_points.0);
if new_hole { if new_hole {
if self.get_active_player().unwrap().holes > 12 { let holes_count = self.get_active_player().unwrap().holes;
debug!("new hole -> {holes_count:?}");
if holes_count > 12 {
self.stage = Stage::Ended; self.stage = Stage::Ended;
} else { } else {
self.turn_stage = TurnStage::HoldOrGoChoice; self.turn_stage = TurnStage::HoldOrGoChoice;
@ -594,6 +717,10 @@ impl GameState {
fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) { fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) {
let player = &self.players.get(&self.active_player_id).unwrap(); let player = &self.players.get(&self.active_player_id).unwrap();
debug!(
"get rollresult for {:?} {:?} {:?} (roll count {:?})",
player.color, self.board, dice, player.dice_roll_count
);
let points_rules = PointsRules::new(&player.color, &self.board, *dice); let points_rules = PointsRules::new(&player.color, &self.board, *dice);
points_rules.get_result_jans(player.dice_roll_count) points_rules.get_result_jans(player.dice_roll_count)
} }
@ -610,13 +737,15 @@ impl GameState {
fn inc_roll_count(&mut self, player_id: PlayerId) { fn inc_roll_count(&mut self, player_id: PlayerId) {
self.players.get_mut(&player_id).map(|p| { self.players.get_mut(&player_id).map(|p| {
if p.dice_roll_count < u8::MAX { p.dice_roll_count = p.dice_roll_count.saturating_add(1);
p.dice_roll_count += 1;
}
p p
}); });
} }
/// Public wrapper around the private `mark_points`, named for use by bot
/// training code.
///
/// Forwards `player_id` and `points` unchanged and returns the flag computed
/// by `mark_points` (presumably `true` when a new hole is scored — confirm
/// against `mark_points`).
pub fn mark_points_for_bot_training(&mut self, player_id: PlayerId, points: u8) -> bool {
self.mark_points(player_id, points)
}
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool { fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
// Update player points and holes // Update player points and holes
let mut new_hole = false; let mut new_hole = false;
@ -636,10 +765,11 @@ impl GameState {
p.points = sum_points % 12; p.points = sum_points % 12;
p.holes += holes; p.holes += holes;
if points > 0 && p.holes > 15 { // if points > 0 && p.holes > 15 {
info!( if points > 0 {
"player {:?} holes : {:?} added points : {:?}", debug!(
player_id, p.holes, points "player {player_id:?} holes : {:?} (+{holes:?}) points : {:?} (+{points:?} - {jeux:?})",
p.holes, p.points
) )
} }
p p
@ -671,14 +801,14 @@ impl GameState {
} }
/// The reasons why a game could end /// The reasons why a game could end
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)] #[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize)]
pub enum EndGameReason { pub enum EndGameReason {
PlayerLeft { player_id: PlayerId }, PlayerLeft { player_id: PlayerId },
PlayerWon { winner: PlayerId }, PlayerWon { winner: PlayerId },
} }
/// An event that progresses the GameState forward /// An event that progresses the GameState forward
#[derive(Debug, Clone, Serialize, PartialEq, Deserialize)] #[derive(Debug, Clone, Serialize, PartialEq, Eq, Deserialize)]
pub enum GameEvent { pub enum GameEvent {
BeginGame { BeginGame {
goes_first: PlayerId, goes_first: PlayerId,
@ -733,6 +863,58 @@ impl GameEvent {
_ => None, _ => None,
} }
} }
pub fn get_mirror(&self) -> Self {
// let mut mirror = self.clone();
let mirror_player_id = if let Some(player_id) = self.player_id() {
if player_id == 1 {
2
} else {
1
}
} else {
0
};
match self {
Self::PlayerJoined { player_id: _, name } => Self::PlayerJoined {
player_id: mirror_player_id,
name: name.clone(),
},
Self::PlayerDisconnected { player_id: _ } => GameEvent::PlayerDisconnected {
player_id: mirror_player_id,
},
Self::Roll { player_id: _ } => GameEvent::Roll {
player_id: mirror_player_id,
},
Self::RollResult { player_id: _, dice } => GameEvent::RollResult {
player_id: mirror_player_id,
dice: *dice,
},
Self::Mark {
player_id: _,
points,
} => GameEvent::Mark {
player_id: mirror_player_id,
points: *points,
},
Self::Go { player_id: _ } => GameEvent::Go {
player_id: mirror_player_id,
},
Self::Move {
player_id: _,
moves: (move1, move2),
} => Self::Move {
player_id: mirror_player_id,
moves: (move1.mirror(), move2.mirror()),
},
Self::BeginGame { goes_first } => GameEvent::BeginGame {
goes_first: (if *goes_first == 1 { 2 } else { 1 }),
},
Self::EndGame { reason } => GameEvent::EndGame { reason: *reason },
Self::PlayError => GameEvent::PlayError,
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -753,7 +935,16 @@ mod tests {
let state = init_test_gamestate(TurnStage::RollDice); let state = init_test_gamestate(TurnStage::RollDice);
let string_id = state.to_string_id(); let string_id = state.to_string_id();
// println!("string_id : {}", string_id); // println!("string_id : {}", string_id);
assert_eq!(string_id, "Hz88AAAAAz8/IAAAAAQAADAD"); assert_eq!(string_id, "Pz84AAAABz8/AAAAAAgAASAG");
let new_state = GameState::from_string_id(&string_id).unwrap();
assert_eq!(state.board, new_state.board);
assert_eq!(state.active_player_id, new_state.active_player_id);
assert_eq!(state.turn_stage, new_state.turn_stage);
assert_eq!(state.dice, new_state.dice);
assert_eq!(
state.get_white_player().unwrap().points,
new_state.get_white_player().unwrap().points
);
} }
#[test] #[test]

View file

@ -625,18 +625,24 @@ mod tests {
#[test] #[test]
fn can_take_corner_by_effect() { fn can_take_corner_by_effect() {
let mut rules = MoveRules::default(); let mut rules = MoveRules::default();
rules.board.set_positions([ rules.board.set_positions(
10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, &Color::White,
]); [
10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15,
],
);
rules.dice.values = (4, 4); rules.dice.values = (4, 4);
assert!(rules.can_take_corner_by_effect()); assert!(rules.can_take_corner_by_effect());
rules.dice.values = (5, 5); rules.dice.values = (5, 5);
assert!(!rules.can_take_corner_by_effect()); assert!(!rules.can_take_corner_by_effect());
rules.board.set_positions([ rules.board.set_positions(
10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, &Color::White,
]); [
10, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15,
],
);
rules.dice.values = (4, 4); rules.dice.values = (4, 4);
assert!(!rules.can_take_corner_by_effect()); assert!(!rules.can_take_corner_by_effect());
} }
@ -645,9 +651,12 @@ mod tests {
fn prise_en_puissance() { fn prise_en_puissance() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
// prise par puissance ok // prise par puissance ok
state.board.set_positions([ state.board.set_positions(
10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, &Color::White,
]); [
10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(8, 12).unwrap(), CheckerMove::new(8, 12).unwrap(),
@ -658,25 +667,34 @@ mod tests {
assert!(state.moves_allowed(&moves).is_ok()); assert!(state.moves_allowed(&moves).is_ok());
// opponent corner must be empty // opponent corner must be empty
state.board.set_positions([ state.board.set_positions(
10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13, &Color::White,
]); [
10, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13,
],
);
assert!(!state.is_move_by_puissance(&moves)); assert!(!state.is_move_by_puissance(&moves));
assert!(!state.moves_follows_dices(&moves)); assert!(!state.moves_follows_dices(&moves));
// Si on a la possibilité de prendre son coin à la fois par effet, c'est à dire naturellement, et aussi par puissance, on doit le prendre par effet // Si on a la possibilité de prendre son coin à la fois par effet, c'est à dire naturellement, et aussi par puissance, on doit le prendre par effet
state.board.set_positions([ state.board.set_positions(
5, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, &Color::White,
]); [
5, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15,
],
);
assert_eq!( assert_eq!(
Err(MoveError::CornerByEffectPossible), Err(MoveError::CornerByEffectPossible),
state.moves_allowed(&moves) state.moves_allowed(&moves)
); );
// on a déjà pris son coin : on ne peux plus y deplacer des dames par puissance // on a déjà pris son coin : on ne peux plus y deplacer des dames par puissance
state.board.set_positions([ state.board.set_positions(
8, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15, &Color::White,
]); [
8, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -15,
],
);
assert!(!state.is_move_by_puissance(&moves)); assert!(!state.is_move_by_puissance(&moves));
assert!(!state.moves_follows_dices(&moves)); assert!(!state.moves_follows_dices(&moves));
} }
@ -685,9 +703,12 @@ mod tests {
fn exit() { fn exit() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
// exit ok // exit ok
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(20, 0).unwrap(), CheckerMove::new(20, 0).unwrap(),
@ -697,9 +718,12 @@ mod tests {
assert!(state.moves_allowed(&moves).is_ok()); assert!(state.moves_allowed(&moves).is_ok());
// toutes les dames doivent être dans le jan de retour // toutes les dames doivent être dans le jan de retour
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(20, 0).unwrap(), CheckerMove::new(20, 0).unwrap(),
@ -711,9 +735,12 @@ mod tests {
); );
// on ne peut pas sortir une dame avec un nombre excédant si on peut en jouer une avec un nombre défaillant // on ne peut pas sortir une dame avec un nombre excédant si on peut en jouer une avec un nombre défaillant
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 2, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 2, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(20, 0).unwrap(), CheckerMove::new(20, 0).unwrap(),
@ -725,9 +752,12 @@ mod tests {
); );
// on doit jouer le nombre excédant le plus éloigné // on doit jouer le nombre excédant le plus éloigné
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(20, 0).unwrap(), CheckerMove::new(20, 0).unwrap(),
@ -741,9 +771,12 @@ mod tests {
assert!(state.moves_allowed(&moves).is_ok()); assert!(state.moves_allowed(&moves).is_ok());
// Cas de la dernière dame // Cas de la dernière dame
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(23, 0).unwrap(), CheckerMove::new(23, 0).unwrap(),
@ -756,9 +789,12 @@ mod tests {
#[test] #[test]
fn move_check_opponent_fillable_quarter() { fn move_check_opponent_fillable_quarter() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(11, 16).unwrap(), CheckerMove::new(11, 16).unwrap(),
@ -766,9 +802,12 @@ mod tests {
); );
assert!(state.moves_allowed(&moves).is_ok()); assert!(state.moves_allowed(&moves).is_ok());
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(11, 16).unwrap(), CheckerMove::new(11, 16).unwrap(),
@ -779,9 +818,12 @@ mod tests {
state.moves_allowed(&moves) state.moves_allowed(&moves)
); );
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -12, 0, 0, 0, 0, 1, 0,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(11, 16).unwrap(), CheckerMove::new(11, 16).unwrap(),
@ -789,9 +831,12 @@ mod tests {
); );
assert!(state.moves_allowed(&moves).is_ok()); assert!(state.moves_allowed(&moves).is_ok());
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 1, -12, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 1, -12,
],
);
state.dice.values = (5, 5); state.dice.values = (5, 5);
let moves = ( let moves = (
CheckerMove::new(11, 16).unwrap(), CheckerMove::new(11, 16).unwrap(),
@ -806,9 +851,12 @@ mod tests {
#[test] #[test]
fn move_check_fillable_quarter() { fn move_check_fillable_quarter() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
3, 3, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
3, 3, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0,
],
);
state.dice.values = (5, 4); state.dice.values = (5, 4);
let moves = ( let moves = (
CheckerMove::new(1, 6).unwrap(), CheckerMove::new(1, 6).unwrap(),
@ -821,9 +869,12 @@ mod tests {
); );
assert_eq!(Err(MoveError::MustFillQuarter), state.moves_allowed(&moves)); assert_eq!(Err(MoveError::MustFillQuarter), state.moves_allowed(&moves));
state.board.set_positions([ state.board.set_positions(
2, 3, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 3, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 3); state.dice.values = (2, 3);
let moves = ( let moves = (
CheckerMove::new(6, 8).unwrap(), CheckerMove::new(6, 8).unwrap(),
@ -840,9 +891,12 @@ mod tests {
#[test] #[test]
fn move_play_all_dice() { fn move_play_all_dice() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
],
);
state.dice.values = (1, 3); state.dice.values = (1, 3);
let moves = ( let moves = (
CheckerMove::new(22, 0).unwrap(), CheckerMove::new(22, 0).unwrap(),
@ -861,9 +915,12 @@ mod tests {
fn move_opponent_rest_corner_rules() { fn move_opponent_rest_corner_rules() {
// fill with 2 checkers : forbidden // fill with 2 checkers : forbidden
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (1, 1); state.dice.values = (1, 1);
let moves = ( let moves = (
CheckerMove::new(12, 13).unwrap(), CheckerMove::new(12, 13).unwrap(),
@ -891,9 +948,12 @@ mod tests {
fn move_rest_corner_enter() { fn move_rest_corner_enter() {
// direct // direct
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let moves = ( let moves = (
CheckerMove::new(10, 12).unwrap(), CheckerMove::new(10, 12).unwrap(),
@ -915,9 +975,12 @@ mod tests {
#[test] #[test]
fn move_rest_corner_blocked() { fn move_rest_corner_blocked() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let moves = ( let moves = (
CheckerMove::new(0, 0).unwrap(), CheckerMove::new(0, 0).unwrap(),
@ -926,9 +989,12 @@ mod tests {
assert!(state.moves_follows_dices(&moves)); assert!(state.moves_follows_dices(&moves));
assert!(state.moves_allowed(&moves).is_ok()); assert!(state.moves_allowed(&moves).is_ok());
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let moves = ( let moves = (
CheckerMove::new(23, 24).unwrap(), CheckerMove::new(23, 24).unwrap(),
@ -949,9 +1015,12 @@ mod tests {
#[test] #[test]
fn move_rest_corner_exit() { fn move_rest_corner_exit() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 3); state.dice.values = (2, 3);
let moves = ( let moves = (
CheckerMove::new(12, 14).unwrap(), CheckerMove::new(12, 14).unwrap(),
@ -967,9 +1036,12 @@ mod tests {
fn move_rest_corner_toutdune() { fn move_rest_corner_toutdune() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
// We can't go to the occupied rest corner as an intermediary step // We can't go to the occupied rest corner as an intermediary step
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let moves = ( let moves = (
CheckerMove::new(11, 13).unwrap(), CheckerMove::new(11, 13).unwrap(),
@ -978,9 +1050,12 @@ mod tests {
assert!(!state.moves_possible(&moves)); assert!(!state.moves_possible(&moves));
// We can use the empty rest corner as an intermediary step // We can use the empty rest corner as an intermediary step
state.board.set_positions([ state.board.set_positions(
2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, &Color::White,
]); [
2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3,
],
);
state.dice.values = (6, 5); state.dice.values = (6, 5);
let moves = ( let moves = (
CheckerMove::new(8, 13).unwrap(), CheckerMove::new(8, 13).unwrap(),
@ -994,9 +1069,12 @@ mod tests {
#[test] #[test]
fn move_play_stronger_dice() { fn move_play_stronger_dice() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, -1, -1, -1, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, -1, -1, -1, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 3); state.dice.values = (2, 3);
let moves = ( let moves = (
CheckerMove::new(12, 14).unwrap(), CheckerMove::new(12, 14).unwrap(),
@ -1034,9 +1112,12 @@ mod tests {
assert!(!state.moves_possible(&moves)); assert!(!state.moves_possible(&moves));
// Can't move the same checker twice // Can't move the same checker twice
state.board.set_positions([ state.board.set_positions(
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let moves = ( let moves = (
CheckerMove::new(3, 5).unwrap(), CheckerMove::new(3, 5).unwrap(),
@ -1056,9 +1137,12 @@ mod tests {
#[test] #[test]
fn filling_moves_sequences() { fn filling_moves_sequences() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); let filling_moves_sequences = state.get_quarter_filling_moves_sequences();
// println!( // println!(
@ -1067,17 +1151,23 @@ mod tests {
// ); // );
assert_eq!(2, filling_moves_sequences.len()); assert_eq!(2, filling_moves_sequences.len());
state.board.set_positions([ state.board.set_positions(
3, 2, 3, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 2, 3, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 2); state.dice.values = (2, 2);
let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); let filling_moves_sequences = state.get_quarter_filling_moves_sequences();
// println!("{:?}", filling_moves_sequences); // println!("{:?}", filling_moves_sequences);
assert_eq!(2, filling_moves_sequences.len()); assert_eq!(2, filling_moves_sequences.len());
state.board.set_positions([ state.board.set_positions(
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); let filling_moves_sequences = state.get_quarter_filling_moves_sequences();
// println!( // println!(
@ -1087,9 +1177,12 @@ mod tests {
assert_eq!(2, filling_moves_sequences.len()); assert_eq!(2, filling_moves_sequences.len());
// positions // positions
state.board.set_positions([ state.board.set_positions(
2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, &Color::White,
]); [
2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3,
],
);
state.dice.values = (6, 5); state.dice.values = (6, 5);
let filling_moves_sequences = state.get_quarter_filling_moves_sequences(); let filling_moves_sequences = state.get_quarter_filling_moves_sequences();
assert_eq!(1, filling_moves_sequences.len()); assert_eq!(1, filling_moves_sequences.len());
@ -1099,19 +1192,46 @@ mod tests {
fn scoring_filling_moves_sequences() { fn scoring_filling_moves_sequences() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len()); assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len());
state.board.set_positions([ state.board.set_positions(
2, 3, 3, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 3, 3, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 1); state.dice.values = (2, 1);
let filling_moves_sequences = state.get_scoring_quarter_filling_moves_sequences(); let filling_moves_sequences = state.get_scoring_quarter_filling_moves_sequences();
// println!("{:?}", filling_moves_sequences); // println!("{:?}", filling_moves_sequences);
assert_eq!(3, filling_moves_sequences.len()); assert_eq!(3, filling_moves_sequences.len());
// preserve filling
state.board.set_positions(
&Color::White,
[
2, 2, 2, 2, 2, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -1, -2, -3, -5, 0, -1,
],
);
state.dice.values = (3, 1);
assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len());
// preserve filling (black)
let mut state = MoveRules::new(&Color::Black, &Board::default(), Dice::default());
state.board.set_positions(
&Color::Black,
[
1, 0, 5, 3, 2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -4, -2, -2, -2, -2, -2,
],
);
state.dice.values = (3, 1);
assert_eq!(1, state.get_scoring_quarter_filling_moves_sequences().len());
} }
// prise de coin par puissance et conservation de jan #18 // prise de coin par puissance et conservation de jan #18
@ -1120,9 +1240,12 @@ mod tests {
fn corner_by_effect_and_filled_corner() { fn corner_by_effect_and_filled_corner() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3, &Color::White,
]); [
2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, -2, 0, 0, 0, -2, 0, -2, -2, -2, -2, -3,
],
);
state.dice.values = (6, 5); state.dice.values = (6, 5);
let moves = ( let moves = (
@ -1155,9 +1278,12 @@ mod tests {
fn get_possible_moves_sequences() { fn get_possible_moves_sequences() {
let mut state = MoveRules::default(); let mut state = MoveRules::default();
state.board.set_positions([ state.board.set_positions(
2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
state.dice.values = (2, 3); state.dice.values = (2, 3);
let moves = ( let moves = (
CheckerMove::new(9, 11).unwrap(), CheckerMove::new(9, 11).unwrap(),

View file

@ -5,6 +5,7 @@ use crate::player::Color;
use crate::CheckerMove; use crate::CheckerMove;
use crate::Error; use crate::Error;
use log::debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp; use std::cmp;
use std::collections::HashMap; use std::collections::HashMap;
@ -143,7 +144,9 @@ impl PointsRules {
} else { } else {
board.clone() board.clone()
}; };
let move_rules = MoveRules::new(color, &board, dice); // the board is already reverted for black, so we pretend color is white
let move_rules = MoveRules::new(&Color::White, &board, dice);
// let move_rules = MoveRules::new(color, &board, dice);
// let move_rules = MoveRules::new(color, &self.board, dice, moves); // let move_rules = MoveRules::new(color, &self.board, dice, moves);
Self { Self {
@ -158,9 +161,9 @@ impl PointsRules {
self.move_rules.dice = dice; self.move_rules.dice = dice;
} }
pub fn update_positions(&mut self, positions: [i8; 24]) { pub fn update_positions(&mut self, color: &Color, positions: [i8; 24]) {
self.board.set_positions(positions); self.board.set_positions(color, positions);
self.move_rules.board.set_positions(positions); self.move_rules.board.set_positions(color, positions);
} }
fn get_jans(&self, board_ini: &Board, dice_rolls_count: u8) -> PossibleJans { fn get_jans(&self, board_ini: &Board, dice_rolls_count: u8) -> PossibleJans {
@ -381,6 +384,7 @@ impl PointsRules {
pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) { pub fn get_result_jans(&self, dice_rolls_count: u8) -> (PossibleJans, (u8, u8)) {
let jans = self.get_jans(&self.board, dice_rolls_count); let jans = self.get_jans(&self.board, dice_rolls_count);
debug!("jans : {jans:?}");
let points_jans = jans.clone(); let points_jans = jans.clone();
(jans, self.get_jans_points(points_jans)) (jans, self.get_jans_points(points_jans))
} }
@ -481,9 +485,12 @@ mod tests {
#[test] #[test]
fn get_jans_by_dice_order() { fn get_jans_by_dice_order() {
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.board.set_positions([ rules.board.set_positions(
2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false);
assert_eq!(1, jans.len()); assert_eq!(1, jans.len());
@ -495,9 +502,12 @@ mod tests {
// On peut passer par une dame battue pour battre une autre dame // On peut passer par une dame battue pour battre une autre dame
// mais pas par une case remplie par l'adversaire // mais pas par une case remplie par l'adversaire
rules.board.set_positions([ rules.board.set_positions(
2, 0, -1, -2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -1, -2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let mut jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); let mut jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false);
let jans_revert_dices = get_jans_by_ordered_dice(&rules.board, &[3, 2], None, false); let jans_revert_dices = get_jans_by_ordered_dice(&rules.board, &[3, 2], None, false);
@ -506,25 +516,34 @@ mod tests {
jans.merge(jans_revert_dices); jans.merge(jans_revert_dices);
assert_eq!(2, jans.get(&Jan::TrueHitSmallJan).unwrap().len()); assert_eq!(2, jans.get(&Jan::TrueHitSmallJan).unwrap().len());
rules.board.set_positions([ rules.board.set_positions(
2, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false);
assert_eq!(1, jans.len()); assert_eq!(1, jans.len());
assert_eq!(2, jans.get(&Jan::TrueHitSmallJan).unwrap().len()); assert_eq!(2, jans.get(&Jan::TrueHitSmallJan).unwrap().len());
rules.board.set_positions([ rules.board.set_positions(
2, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false);
assert_eq!(1, jans.len()); assert_eq!(1, jans.len());
assert_eq!(1, jans.get(&Jan::TrueHitSmallJan).unwrap().len()); assert_eq!(1, jans.get(&Jan::TrueHitSmallJan).unwrap().len());
rules.board.set_positions([ rules.board.set_positions(
2, 0, 1, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, 1, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 3], None, false);
assert_eq!(1, jans.len()); assert_eq!(1, jans.len());
@ -533,25 +552,34 @@ mod tests {
// corners handling // corners handling
// deux dés bloqués (coin de repos et coin de l'adversaire) // deux dés bloqués (coin de repos et coin de l'adversaire)
rules.board.set_positions([ rules.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
// le premier dé traité est le dernier du vecteur : 1 // le premier dé traité est le dernier du vecteur : 1
let jans = get_jans_by_ordered_dice(&rules.board, &[2, 1], None, false); let jans = get_jans_by_ordered_dice(&rules.board, &[2, 1], None, false);
// println!("jans (dés bloqués) : {:?}", jans.get(&Jan::TrueHit)); // println!("jans (dés bloqués) : {:?}", jans.get(&Jan::TrueHit));
assert_eq!(0, jans.len()); assert_eq!(0, jans.len());
// dé dans son coin de repos : peut tout de même battre à vrai // dé dans son coin de repos : peut tout de même battre à vrai
rules.board.set_positions([ rules.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let jans = get_jans_by_ordered_dice(&rules.board, &[3, 3], None, false); let jans = get_jans_by_ordered_dice(&rules.board, &[3, 3], None, false);
assert_eq!(1, jans.len()); assert_eq!(1, jans.len());
// premier dé bloqué, mais tout d'une possible en commençant par le second // premier dé bloqué, mais tout d'une possible en commençant par le second
rules.board.set_positions([ rules.board.set_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
let mut jans = get_jans_by_ordered_dice(&rules.board, &[3, 1], None, false); let mut jans = get_jans_by_ordered_dice(&rules.board, &[3, 1], None, false);
let jans_revert_dices = get_jans_by_ordered_dice(&rules.board, &[1, 3], None, false); let jans_revert_dices = get_jans_by_ordered_dice(&rules.board, &[1, 3], None, false);
assert_eq!(1, jans_revert_dices.len()); assert_eq!(1, jans_revert_dices.len());
@ -564,174 +592,293 @@ mod tests {
// à vrai // à vrai
} }
#[test]
fn get_result_jans() {
    // Builds PointsRules for Black from a board whose positions were set with
    // Color::White, then checks that at least one jan is reported.
    let positions = [
        0, 0, 5, 2, 4, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -2, -2, -2, -2, -2, -2,
    ];
    let mut board = Board::new();
    board.set_positions(&Color::White, positions);

    let dice = Dice { values: (2, 4) };
    let points_rules = PointsRules::new(&Color::Black, &board, dice);

    // Only the jans map matters here; the points pair is ignored.
    let (possible_jans, _points) = points_rules.get_result_jans(8);
    assert!(!possible_jans.is_empty());
}
#[test] #[test]
fn get_points() { fn get_points() {
// ----- Jan de récompense // ----- Jan de récompense
// Battre à vrai une dame située dans la table des petits jans : 4 + 4 + 4 = 12 // Battre à vrai une dame située dans la table des petits jans : 4 + 4 + 4 = 12
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.update_positions([ rules.update_positions(
2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 3) }); rules.set_dice(Dice { values: (2, 3) });
assert_eq!(12, rules.get_points(5).0); assert_eq!(12, rules.get_points(5).0);
// Calcul des points pour noir
let mut board = Board::new();
board.set_positions(
&Color::White,
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2,
],
);
let rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
assert_eq!(12, rules.get_points(5).0);
// Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4 // Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.update_positions([ rules.update_positions(
2, 0, 0, -1, 2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, 0, -1, 2, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 4) }); rules.set_dice(Dice { values: (2, 4) });
assert_eq!(4, rules.get_points(5).0); assert_eq!(4, rules.get_points(5).0);
// Battre à vrai une dame située dans la table des grands jans : 2 // Battre à vrai une dame située dans la table des grands jans : 2
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.update_positions([ rules.update_positions(
2, 0, -2, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -2, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 4) }); rules.set_dice(Dice { values: (2, 4) });
assert_eq!((2, 2), rules.get_points(5)); assert_eq!((2, 2), rules.get_points(5));
// Battre à vrai le coin adverse par doublet : 6 // Battre à vrai le coin adverse par doublet : 6
rules.update_positions([ rules.update_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 2) }); rules.set_dice(Dice { values: (2, 2) });
assert_eq!(6, rules.get_points(5).0); assert_eq!(6, rules.get_points(5).0);
// Cas de battage du coin de repos adverse impossible // Cas de battage du coin de repos adverse impossible
rules.update_positions([ rules.update_positions(
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (1, 1) }); rules.set_dice(Dice { values: (1, 1) });
assert_eq!(0, rules.get_points(5).0); assert_eq!(0, rules.get_points(5).0);
// ---- Jan de remplissage // ---- Jan de remplissage
// Faire un petit jan : 4 // Faire un petit jan : 4
rules.update_positions([ rules.update_positions(
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 1) }); rules.set_dice(Dice { values: (2, 1) });
assert_eq!(1, rules.get_jans(&rules.board, 5).len()); assert_eq!(1, rules.get_jans(&rules.board, 5).len());
assert_eq!(4, rules.get_points(5).0); assert_eq!(4, rules.get_points(5).0);
// Faire un petit jan avec un doublet : 6 // Faire un petit jan avec un doublet : 6
rules.update_positions([ rules.update_positions(
2, 3, 1, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 3, 1, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (1, 1) }); rules.set_dice(Dice { values: (1, 1) });
assert_eq!(6, rules.get_points(5).0); assert_eq!(6, rules.get_points(5).0);
// Faire un petit jan avec 2 moyens : 6 + 6 = 12 // Faire un petit jan avec 2 moyens : 6 + 6 = 12
rules.update_positions([ rules.update_positions(
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 3, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (1, 1) }); rules.set_dice(Dice { values: (1, 1) });
assert_eq!(12, rules.get_points(5).0); assert_eq!(12, rules.get_points(5).0);
// Conserver un jan avec un doublet : 6 // Conserver un jan avec un doublet : 6
rules.update_positions([ rules.update_positions(
3, 3, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
3, 3, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (1, 1) }); rules.set_dice(Dice { values: (1, 1) });
assert_eq!(6, rules.get_points(5).0); assert_eq!(6, rules.get_points(5).0);
// Conserver un jan
rules.update_positions(
&Color::White,
[
2, 2, 2, 2, 2, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -1, -2, -3, -5, 0, -1,
],
);
rules.set_dice(Dice { values: (3, 1) });
assert_eq!((4, 0), rules.get_points(8));
// Conserver un jan (black)
let mut board = Board::new();
board.set_positions(
&Color::White,
[
1, 0, 5, 3, 2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -4, -2, -2, -2, -2, -2,
],
);
let rules = PointsRules::new(&Color::Black, &board, Dice { values: (3, 1) });
assert_eq!((4, 0), rules.get_points(8));
// ---- Sorties // ---- Sorties
// Sortir toutes ses dames avant l'adversaire (simple) // Sortir toutes ses dames avant l'adversaire (simple)
rules.update_positions([ let mut rules = PointsRules::default();
0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, rules.update_positions(
]); &Color::White,
[
0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
],
);
rules.set_dice(Dice { values: (3, 1) }); rules.set_dice(Dice { values: (3, 1) });
assert_eq!(4, rules.get_points(5).0); assert_eq!(4, rules.get_points(5).0);
// Sortir toutes ses dames avant l'adversaire (doublet) // Sortir toutes ses dames avant l'adversaire (doublet)
rules.update_positions([ rules.update_positions(
0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, &Color::White,
]); [
0, 0, -2, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
],
);
rules.set_dice(Dice { values: (2, 2) }); rules.set_dice(Dice { values: (2, 2) });
assert_eq!(6, rules.get_points(5).0); assert_eq!(6, rules.get_points(5).0);
// ---- JANS RARES // ---- JANS RARES
// Jan de six tables // Jan de six tables
rules.update_positions([ rules.update_positions(
10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 3) }); rules.set_dice(Dice { values: (2, 3) });
assert_eq!(0, rules.get_points(5).0); assert_eq!(0, rules.get_points(5).0);
rules.update_positions([ rules.update_positions(
10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
10, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 3) }); rules.set_dice(Dice { values: (2, 3) });
assert_eq!(4, rules.get_points(3).0); assert_eq!(4, rules.get_points(3).0);
rules.update_positions([ rules.update_positions(
10, 1, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
10, 1, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 3) }); rules.set_dice(Dice { values: (2, 3) });
assert_eq!(0, rules.get_points(3).0); assert_eq!(0, rules.get_points(3).0);
rules.update_positions([ rules.update_positions(
10, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
10, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 3) }); rules.set_dice(Dice { values: (2, 3) });
assert_eq!(0, rules.get_points(3).0); assert_eq!(0, rules.get_points(3).0);
// Jan de deux tables // Jan de deux tables
rules.update_positions([ rules.update_positions(
13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 2) }); rules.set_dice(Dice { values: (2, 2) });
assert_eq!(6, rules.get_points(5).0); assert_eq!(6, rules.get_points(5).0);
rules.update_positions([ rules.update_positions(
12, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
12, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 2) }); rules.set_dice(Dice { values: (2, 2) });
assert_eq!(0, rules.get_points(5).0); assert_eq!(0, rules.get_points(5).0);
// Contre jan de deux tables // Contre jan de deux tables
rules.update_positions([ rules.update_positions(
13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
13, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (2, 2) }); rules.set_dice(Dice { values: (2, 2) });
assert_eq!((0, 6), rules.get_points(5)); assert_eq!((0, 6), rules.get_points(5));
// Jan de mézéas // Jan de mézéas
rules.update_positions([ rules.update_positions(
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (1, 1) }); rules.set_dice(Dice { values: (1, 1) });
assert_eq!(6, rules.get_points(5).0); assert_eq!(6, rules.get_points(5).0);
rules.update_positions([ rules.update_positions(
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (1, 2) }); rules.set_dice(Dice { values: (1, 2) });
assert_eq!(4, rules.get_points(5).0); assert_eq!(4, rules.get_points(5).0);
// Contre jan de mézéas // Contre jan de mézéas
rules.update_positions([ rules.update_positions(
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, &Color::White,
]); [
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
],
);
rules.set_dice(Dice { values: (1, 1) }); rules.set_dice(Dice { values: (1, 1) });
assert_eq!((0, 6), rules.get_points(5)); assert_eq!((0, 6), rules.get_points(5));
// ---- JANS QUI NE PEUT // ---- JANS QUI NE PEUT
// Battre à faux une dame située dans la table des petits jans // Battre à faux une dame située dans la table des petits jans
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.update_positions([ rules.update_positions(
2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -2, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 3) }); rules.set_dice(Dice { values: (2, 3) });
assert_eq!((0, 4), rules.get_points(5)); assert_eq!((0, 4), rules.get_points(5));
// Battre à faux une dame située dans la table des grands jans // Battre à faux une dame située dans la table des grands jans
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.update_positions([ rules.update_positions(
2, 0, -2, -1, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -2, -1, -2, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 4) }); rules.set_dice(Dice { values: (2, 4) });
assert_eq!((0, 2), rules.get_points(5)); assert_eq!((0, 2), rules.get_points(5));
// Pour chaque dé non jouable (dame impuissante) // Pour chaque dé non jouable (dame impuissante)
let mut rules = PointsRules::default(); let mut rules = PointsRules::default();
rules.update_positions([ rules.update_positions(
2, 0, -2, -2, -2, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, &Color::White,
]); [
2, 0, -2, -2, -2, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
rules.set_dice(Dice { values: (2, 4) }); rules.set_dice(Dice { values: (2, 4) });
assert_eq!((0, 4), rules.get_points(5)); assert_eq!((0, 4), rules.get_points(5));
} }

View file

@ -4,7 +4,7 @@ use std::fmt;
// This just makes it easier to discern between a player id and any ol' u64 // This just makes it easier to discern between a player id and any ol' u64
pub type PlayerId = u64; pub type PlayerId = u64;
#[derive(Copy, Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Color { pub enum Color {
White, White,
Black, Black,
@ -20,7 +20,7 @@ impl Color {
} }
/// Struct for storing player related data. /// Struct for storing player related data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Player { pub struct Player {
pub name: String, pub name: String,
pub color: Color, pub color: Color,
@ -53,6 +53,26 @@ impl Player {
) )
} }
/// Rebuilds a [`Player`] from its 10-character bit-string encoding.
///
/// Layout read by this function:
/// - characters `0..4`: `points`, base 2
/// - characters `4..8`: `holes`, base 2
/// - character `8`: `can_bredouille` (`'1'` means true)
/// - character `9`: `can_big_bredouille` (`'1'` means true)
///
/// `name` and `color` are taken from the arguments. `dice_roll_count` is not
/// part of the encoding and is always set to 0.
///
/// # Errors
/// Returns `Err` when `bits` is not exactly 10 bytes long, or when it contains
/// any character other than `'0'` or `'1'`.
pub fn from_bits_string(bits: &str, name: String, color: Color) -> Result<Self, String> {
    if bits.len() != 10 {
        return Err("Invalid bit string length for player".to_string());
    }
    // `len()` counts bytes, not characters. Requiring pure ASCII '0'/'1' up
    // front guarantees the byte-range slices below sit on char boundaries
    // (no panic on multi-byte input) and stops `from_str_radix` from
    // accepting a leading '+' sign as part of a field.
    if !bits.bytes().all(|b| matches!(b, b'0' | b'1')) {
        return Err("Invalid character in bit string for player".to_string());
    }

    let raw = bits.as_bytes();
    let points = u8::from_str_radix(&bits[0..4], 2).map_err(|e| e.to_string())?;
    let holes = u8::from_str_radix(&bits[4..8], 2).map_err(|e| e.to_string())?;
    // Length was checked above, so indices 8 and 9 are always in bounds.
    let can_bredouille = raw[8] == b'1';
    let can_big_bredouille = raw[9] == b'1';

    Ok(Player {
        name,
        color,
        points,
        holes,
        can_bredouille,
        can_big_bredouille,
        dice_roll_count: 0, // This info is not in the string id
    })
}
pub fn to_vec(&self) -> Vec<u8> { pub fn to_vec(&self) -> Vec<u8> {
vec![ vec![
self.points, self.points,