Compare commits
No commits in common. "main" and "feature/botTrainValidMoves" have entirely different histories.
main
...
feature/bo
93 changed files with 5502 additions and 9102 deletions
4776
Cargo.lock
generated
4776
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,4 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
|
||||
members = ["client_cli", "bot", "store"]
|
||||
members = ["client_tui", "client_cli", "bot", "server", "store"]
|
||||
|
|
|
|||
38
README.md
38
README.md
|
|
@ -1,41 +1,7 @@
|
|||
# Trictrac
|
||||
|
||||
This is a Rust implementation of the game of [Trictrac](https://en.wikipedia.org/wiki/Trictrac).
|
||||
Game of [Trictrac](https://en.wikipedia.org/wiki/Trictrac) in rust.
|
||||
|
||||
The project is in its early stages.
|
||||
Rules (without "schools") are implemented, as well as a rudimentary terminal interface which allows you to play against a bot that plays randomly.
|
||||
wip
|
||||
|
||||
Training of AI bots is a work in progress.
|
||||
|
||||
## Usage
|
||||
|
||||
`cargo run --bin=client_cli -- --bot random`
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] rules
|
||||
- [x] command line interface
|
||||
- [x] basic bot (random play)
|
||||
- [ ] AI bot
|
||||
- [ ] network game
|
||||
- [ ] web client
|
||||
|
||||
## Code structure
|
||||
|
||||
- game rules and game state are implemented in the _store/_ folder.
|
||||
- the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other
|
||||
- the bot algorithms and the training of their models are implemented in the _bot/_ folder
|
||||
|
||||
### _store_ package
|
||||
|
||||
The game state is defined by the `GameState` struct in _store/src/game.rs_. The `to_string_id()` method allows this state to be encoded compactly in a string (without the played moves history). For a more readable textual representation, the `fmt::Display` trait is implemented.
|
||||
|
||||
### _client_cli_ package
|
||||
|
||||
`client_cli/src/game_runner.rs` contains the logic to make two bots play against each other.
|
||||
|
||||
### _bot_ package
|
||||
|
||||
- `bot/src/strategy/default.rs` contains the code for a basic bot strategy: it determines the list of valid moves (using the `get_possible_moves_sequences` method of `store::MoveRules`) and simply executes the first move in the list.
|
||||
- `bot/src/strategy/dqnburn.rs` is another bot strategy that uses a reinforcement learning trained model with the DQN algorithm via the burn library (<https://burn.dev/>).
|
||||
- `bot/scripts/trains.sh` allows you to train agents using different algorithms (DQN, PPO, SAC).
|
||||
|
|
|
|||
|
|
@ -1,24 +1,29 @@
|
|||
[package]
|
||||
name = "trictrac-bot"
|
||||
name = "bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[[bin]]
|
||||
name = "burn_train"
|
||||
path = "src/burnrl/main.rs"
|
||||
name = "train_dqn_burn_valid"
|
||||
path = "src/dqn/burnrl_valid/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn_burn"
|
||||
path = "src/dqn/burnrl/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "train_dqn_simple"
|
||||
path = "src/dqn/simple/main.rs"
|
||||
|
||||
[dependencies]
|
||||
pretty_assertions = "1.4.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
trictrac-store = { path = "../store" }
|
||||
rand = "0.9"
|
||||
store = { path = "../store" }
|
||||
rand = "0.8"
|
||||
env_logger = "0.10"
|
||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||
burn = { version = "0.17", features = ["ndarray", "autodiff"] }
|
||||
burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
|
||||
log = "0.4.20"
|
||||
confy = "1.0.0"
|
||||
board-game = "0.8.2"
|
||||
internal-iterator = "0.2.3"
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
import trictrac_store
|
||||
|
||||
game = trictrac_store.TricTrac()
|
||||
print(game.current_player_idx())
|
||||
print(game.get_legal_actions(game.current_player_idx()))
|
||||
|
|
@ -1,50 +1,38 @@
|
|||
#!/usr/bin/env bash
|
||||
#!/usr/bin/env sh
|
||||
|
||||
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
||||
LOGS_DIR="$ROOT/bot/models/logs"
|
||||
|
||||
CFG_SIZE=17
|
||||
BINBOT=burn_train
|
||||
# BINBOT=train_ppo_burn
|
||||
# BINBOT=train_dqn_burn
|
||||
# BINBOT=train_dqn_burn_big
|
||||
# BINBOT=train_dqn_burn_before
|
||||
CFG_SIZE=12
|
||||
OPPONENT="random"
|
||||
|
||||
PLOT_EXT="png"
|
||||
|
||||
train() {
|
||||
ALGO=$1
|
||||
cargo build --release --bin=$BINBOT
|
||||
NAME="$(date +%Y-%m-%d_%H:%M:%S)"
|
||||
LOGS="$LOGS_DIR/$ALGO/$NAME.out"
|
||||
mkdir -p "$LOGS_DIR/$ALGO"
|
||||
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" $ALGO | tee "$LOGS"
|
||||
cargo build --release --bin=train_dqn_burn
|
||||
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
|
||||
LOGS="$LOGS_DIR/$NAME.out"
|
||||
mkdir -p "$LOGS_DIR"
|
||||
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
|
||||
}
|
||||
|
||||
plot() {
|
||||
ALGO=$1
|
||||
NAME=$(ls -rt "$LOGS_DIR/$ALGO" | grep -v png | tail -n 1)
|
||||
LOGS="$LOGS_DIR/$ALGO/$NAME"
|
||||
cfgs=$(grep -v "info:" "$LOGS" | head -n $CFG_SIZE)
|
||||
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||
LOGS="$LOGS_DIR/$NAME"
|
||||
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||
for cfg in $cfgs; do
|
||||
eval "$cfg"
|
||||
done
|
||||
|
||||
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
|
||||
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
||||
grep -v "info:" |
|
||||
awk -F '[ ,]' '{print $5}' |
|
||||
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$ALGO/$NAME.$PLOT_EXT"
|
||||
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||
}
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage : train [plot] <algo>"
|
||||
elif [ "$1" = "plot" ]; then
|
||||
if [[ -z "$2" ]]; then
|
||||
echo "Usage : train [plot] <algo>"
|
||||
if [ "$1" = "plot" ]; then
|
||||
plot
|
||||
else
|
||||
plot $2
|
||||
fi
|
||||
else
|
||||
train $1
|
||||
train
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ train() {
|
|||
}
|
||||
|
||||
plot() {
|
||||
NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1)
|
||||
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||
LOGS="$LOGS_DIR/$NAME"
|
||||
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||
for cfg in $cfgs; do
|
||||
|
|
@ -31,19 +31,8 @@ plot() {
|
|||
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||
}
|
||||
|
||||
# Print the path of the most recently modified non-png entry of $LOGS_DIR,
# then the arithmetic mean of the 5th field (fields split on spaces and
# commas) of its data lines.
# The first CFG_SIZE + 1 lines of the log are skipped and lines containing
# "info:" are ignored. Nothing is printed for the mean when no data line
# remains.
# NOTE(review): the 5th field is presumably the episode reward - confirm
# against the log format written by the trainer.
avg() {
  NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1)
  LOGS="$LOGS_DIR/$NAME"
  # Quoted so that a path containing spaces or glob characters is printed as-is.
  echo "$LOGS"
  tail -n +$((CFG_SIZE + 2)) "$LOGS" |
    grep -v "info:" |
    awk -F '[ ,]' '{ sum += $5; n++ } END { if (n > 0) print sum / n; }'
}
|
||||
|
||||
if [ "$1" = "plot" ]; then
|
||||
plot
|
||||
elif [ "$1" = "avg" ]; then
|
||||
avg
|
||||
else
|
||||
train
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
pub mod dqn;
|
||||
pub mod dqn_valid;
|
||||
pub mod ppo;
|
||||
pub mod ppo_valid;
|
||||
pub mod sac;
|
||||
pub mod sac_valid;
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
use crate::burnrl::environment::TrictracEnvironment;
|
||||
use crate::burnrl::utils::Config;
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Initializer, Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
linear: Linear<B>,
|
||||
linear_actor: Linear<B>,
|
||||
linear_critic: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
#[allow(unused)]
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
let initializer = Initializer::XavierUniform { gain: 1.0 };
|
||||
Self {
|
||||
linear: LinearConfig::new(input_size, dense_size)
|
||||
.with_initializer(initializer.clone())
|
||||
.init(&Default::default()),
|
||||
linear_actor: LinearConfig::new(dense_size, output_size)
|
||||
.with_initializer(initializer.clone())
|
||||
.init(&Default::default()),
|
||||
linear_critic: LinearConfig::new(dense_size, 1)
|
||||
.with_initializer(initializer)
|
||||
.init(&Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
|
||||
let layer_0_output = relu(self.linear.forward(input));
|
||||
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
|
||||
let values = self.linear_critic.forward(layer_0_output);
|
||||
|
||||
PPOOutput::<B>::new(policies, values)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear.forward(input));
|
||||
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PPOModel<B> for Net<B> {}
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 512;
|
||||
|
||||
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
// ) -> PPO<E, B, Net<B>> {
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let mut model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
let agent = MyAgent::default();
|
||||
let config = PPOTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
lambda: conf.lambda,
|
||||
epsilon_clip: conf.epsilon_clip,
|
||||
critic_weight: conf.critic_weight,
|
||||
entropy_weight: conf.entropy_weight,
|
||||
learning_rate: conf.learning_rate,
|
||||
epochs: conf.epochs,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut optimizer = AdamWConfig::new()
|
||||
.with_grad_clipping(config.clip_grad.clone())
|
||||
.init();
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
env.reset();
|
||||
while !episode_done {
|
||||
let state = env.state();
|
||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
|
||||
let snapshot = env.step(action);
|
||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||
snapshot.reward().clone(),
|
||||
);
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
episode_duration += 1;
|
||||
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
|
||||
now = SystemTime::now();
|
||||
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
|
||||
memory.clear();
|
||||
}
|
||||
|
||||
if let Some(path) = &conf.save_path {
|
||||
let device = NdArrayDevice::default();
|
||||
let recorder = CompactRecorder::new();
|
||||
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
||||
|
||||
// Save the trained model (backend B) to a temporary file
|
||||
recorder
|
||||
.record(model.clone().into_record(), tmp_path.clone())
|
||||
.expect("Failed to save temporary model");
|
||||
|
||||
// Create a new model instance with the target backend (NdArray)
|
||||
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
|
||||
// Load the record from the temporary file into the new model
|
||||
let record = recorder
|
||||
.load(tmp_path.clone(), &device)
|
||||
.expect("Failed to load temporary model");
|
||||
let model_with_loaded_weights = model_to_save.load_record(record);
|
||||
|
||||
// Clean up the temporary file
|
||||
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
||||
|
||||
save_model(&model_with_loaded_weights, path);
|
||||
}
|
||||
agent.valid(model)
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
use crate::burnrl::environment_valid::TrictracEnvironment;
|
||||
use crate::burnrl::utils::Config;
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Initializer, Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
linear: Linear<B>,
|
||||
linear_actor: Linear<B>,
|
||||
linear_critic: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
#[allow(unused)]
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
let initializer = Initializer::XavierUniform { gain: 1.0 };
|
||||
Self {
|
||||
linear: LinearConfig::new(input_size, dense_size)
|
||||
.with_initializer(initializer.clone())
|
||||
.init(&Default::default()),
|
||||
linear_actor: LinearConfig::new(dense_size, output_size)
|
||||
.with_initializer(initializer.clone())
|
||||
.init(&Default::default()),
|
||||
linear_critic: LinearConfig::new(dense_size, 1)
|
||||
.with_initializer(initializer)
|
||||
.init(&Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
|
||||
let layer_0_output = relu(self.linear.forward(input));
|
||||
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
|
||||
let values = self.linear_critic.forward(layer_0_output);
|
||||
|
||||
PPOOutput::<B>::new(policies, values)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear.forward(input));
|
||||
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> PPOModel<B> for Net<B> {}
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 512;
|
||||
|
||||
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
// ) -> PPO<E, B, Net<B>> {
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let mut model = Net::<B>::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
let agent = MyAgent::default();
|
||||
let config = PPOTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
lambda: conf.lambda,
|
||||
epsilon_clip: conf.epsilon_clip,
|
||||
critic_weight: conf.critic_weight,
|
||||
entropy_weight: conf.entropy_weight,
|
||||
learning_rate: conf.learning_rate,
|
||||
epochs: conf.epochs,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut optimizer = AdamWConfig::new()
|
||||
.with_grad_clipping(config.clip_grad.clone())
|
||||
.init();
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
env.reset();
|
||||
while !episode_done {
|
||||
let state = env.state();
|
||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
|
||||
let snapshot = env.step(action);
|
||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||
snapshot.reward().clone(),
|
||||
);
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
episode_duration += 1;
|
||||
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
|
||||
now = SystemTime::now();
|
||||
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
|
||||
memory.clear();
|
||||
}
|
||||
|
||||
if let Some(path) = &conf.save_path {
|
||||
let device = NdArrayDevice::default();
|
||||
let recorder = CompactRecorder::new();
|
||||
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
||||
|
||||
// Save the trained model (backend B) to a temporary file
|
||||
recorder
|
||||
.record(model.clone().into_record(), tmp_path.clone())
|
||||
.expect("Failed to save temporary model");
|
||||
|
||||
// Create a new model instance with the target backend (NdArray)
|
||||
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
||||
<<E as Environment>::StateType as State>::size(),
|
||||
conf.dense_size,
|
||||
<<E as Environment>::ActionType as Action>::size(),
|
||||
);
|
||||
|
||||
// Load the record from the temporary file into the new model
|
||||
let record = recorder
|
||||
.load(tmp_path.clone(), &device)
|
||||
.expect("Failed to load temporary model");
|
||||
let model_with_loaded_weights = model_to_save.load_record(record);
|
||||
|
||||
// Clean up the temporary file
|
||||
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
||||
|
||||
save_model(&model_with_loaded_weights, path);
|
||||
}
|
||||
agent.valid(model)
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
use crate::burnrl::environment::TrictracEnvironment;
|
||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Actor<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Actor<B> {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
softmax(self.linear_2.forward(layer_1_output), 1)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SACActor<B> for Actor<B> {}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Critic<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Critic<B> {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
||||
(self.linear_0, self.linear_1, self.linear_2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
self.linear_2.forward(layer_1_output)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SACCritic<B> for Critic<B> {
|
||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||
let (linear_0, linear_1, linear_2) = this.consume();
|
||||
|
||||
Self {
|
||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
||||
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 4096;
|
||||
|
||||
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
let state_dim = <<E as Environment>::StateType as State>::size();
|
||||
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
||||
|
||||
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
||||
|
||||
let mut agent = MyAgent::default();
|
||||
|
||||
let config = SACTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
tau: conf.tau,
|
||||
learning_rate: conf.learning_rate,
|
||||
min_probability: conf.min_probability,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
|
||||
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
|
||||
|
||||
let mut optimizer = SACOptimizer::new(
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.init(),
|
||||
);
|
||||
|
||||
let mut step = 0_usize;
|
||||
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut state = env.state();
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
while !episode_done {
|
||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
|
||||
let snapshot = env.step(action);
|
||||
|
||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||
snapshot.reward().clone(),
|
||||
);
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
if config.batch_size < memory.len() {
|
||||
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
|
||||
}
|
||||
|
||||
step += 1;
|
||||
episode_duration += 1;
|
||||
|
||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
env.reset();
|
||||
episode_done = true;
|
||||
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs()
|
||||
);
|
||||
now = SystemTime::now();
|
||||
} else {
|
||||
state = *snapshot.state();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let valid_agent = agent.valid(nets.actor);
|
||||
if let Some(path) = &conf.save_path {
|
||||
if let Some(model) = valid_agent.model() {
|
||||
save_model(model, path);
|
||||
}
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Actor::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
|
@ -1,222 +0,0 @@
|
|||
use crate::burnrl::environment_valid::TrictracEnvironment;
|
||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::{relu, softmax};
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Actor<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Actor<B> {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
softmax(self.linear_2.forward(layer_1_output), 1)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SACActor<B> for Actor<B> {}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Critic<B: Backend> {
|
||||
linear_0: Linear<B>,
|
||||
linear_1: Linear<B>,
|
||||
linear_2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Critic<B> {
|
||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
||||
Self {
|
||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
||||
(self.linear_0, self.linear_1, self.linear_2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
|
||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let layer_0_output = relu(self.linear_0.forward(input));
|
||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
||||
|
||||
self.linear_2.forward(layer_1_output)
|
||||
}
|
||||
|
||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.forward(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SACCritic<B> for Critic<B> {
|
||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
||||
let (linear_0, linear_1, linear_2) = this.consume();
|
||||
|
||||
Self {
|
||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
||||
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
||||
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Capacity parameter passed to the replay `Memory` in `run` (4096 entries).
// NOTE(review): the constant is used by `run`, so `allow(unused)` looks unnecessary — confirm.
#[allow(unused)]
const MEMORY_SIZE: usize = 1 << 12;
|
||||
|
||||
// SAC agent specialised to this file's `Actor` network.
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
visualized: bool,
|
||||
) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
let state_dim = <<E as Environment>::StateType as State>::size();
|
||||
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
||||
|
||||
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
||||
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
||||
|
||||
let mut agent = MyAgent::default();
|
||||
|
||||
let config = SACTrainingConfig {
|
||||
gamma: conf.gamma,
|
||||
tau: conf.tau,
|
||||
learning_rate: conf.learning_rate,
|
||||
min_probability: conf.min_probability,
|
||||
batch_size: conf.batch_size,
|
||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
||||
conf.clip_grad,
|
||||
)),
|
||||
};
|
||||
|
||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
||||
|
||||
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
|
||||
|
||||
let mut optimizer = SACOptimizer::new(
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.clone().init(),
|
||||
optimizer_config.init(),
|
||||
);
|
||||
|
||||
let mut step = 0_usize;
|
||||
|
||||
for episode in 0..conf.num_episodes {
|
||||
let mut episode_done = false;
|
||||
let mut episode_reward = 0.0;
|
||||
let mut episode_duration = 0_usize;
|
||||
let mut state = env.state();
|
||||
let mut now = SystemTime::now();
|
||||
|
||||
while !episode_done {
|
||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
|
||||
let snapshot = env.step(action);
|
||||
|
||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
||||
snapshot.reward().clone(),
|
||||
);
|
||||
|
||||
memory.push(
|
||||
state,
|
||||
*snapshot.state(),
|
||||
action,
|
||||
snapshot.reward().clone(),
|
||||
snapshot.done(),
|
||||
);
|
||||
|
||||
if config.batch_size < memory.len() {
|
||||
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
|
||||
}
|
||||
|
||||
step += 1;
|
||||
episode_duration += 1;
|
||||
|
||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
env.reset();
|
||||
episode_done = true;
|
||||
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
||||
now.elapsed().unwrap().as_secs()
|
||||
);
|
||||
now = SystemTime::now();
|
||||
} else {
|
||||
state = *snapshot.state();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let valid_agent = agent.valid(nets.actor);
|
||||
if let Some(path) = &conf.save_path {
|
||||
if let Some(model) = valid_agent.model() {
|
||||
save_model(model, path);
|
||||
}
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Actor::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
use trictrac_bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid};
|
||||
use trictrac_bot::burnrl::environment::TrictracEnvironment;
|
||||
use trictrac_bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
|
||||
use trictrac_bot::burnrl::utils::{demo_model, Config};
|
||||
use burn::backend::{Autodiff, NdArray};
|
||||
use burn_rl::base::ElemType;
|
||||
use std::env;
|
||||
|
||||
// Training backend used by every algorithm in `main`: autodiff on top of NdArray.
type Backend = Autodiff<NdArray<ElemType>>;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let algo = &args[1];
|
||||
// let dir_path = &args[2];
|
||||
|
||||
let path = format!("bot/models/burnrl_{algo}");
|
||||
println!(
|
||||
"info: loading configuration from file {:?}",
|
||||
confy::get_configuration_file_path("trictrac_bot", None).unwrap()
|
||||
);
|
||||
let mut conf: Config = confy::load("trictrac_bot", None).expect("Could not load config");
|
||||
conf.save_path = Some(path.clone());
|
||||
println!("{conf}----------");
|
||||
|
||||
match algo.as_str() {
|
||||
"dqn" => {
|
||||
let _agent = dqn::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"dqn_valid" => {
|
||||
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = dqn_valid::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentValid, _, _> =
|
||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"sac" => {
|
||||
let _agent = sac::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = sac::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"sac_valid" => {
|
||||
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = sac_valid::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentValid, _, _> =
|
||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo" => {
|
||||
let _agent = ppo::run::<TrictracEnvironment, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
"ppo_valid" => {
|
||||
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = ppo_valid::load_model(conf.dense_size, &path);
|
||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentValid, _, _> =
|
||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
&_ => {
|
||||
println!("unknown algo {algo}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
pub mod algos;
|
||||
pub mod environment;
|
||||
pub mod environment_valid;
|
||||
pub mod utils;
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
use burn::module::{Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::base::{Agent, ElemType, Environment};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub save_path: Option<String>,
|
||||
pub max_steps: usize, // max steps by episode
|
||||
pub num_episodes: usize,
|
||||
pub dense_size: usize, // neural network complexity
|
||||
|
||||
// discount factor. Plus élevé = encourage stratégies à long terme
|
||||
pub gamma: f32,
|
||||
// soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation plus lente moins sensible aux coups de chance
|
||||
pub tau: f32,
|
||||
// taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||
pub learning_rate: f32,
|
||||
// nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
pub batch_size: usize,
|
||||
// limite max de correction à apporter au gradient (default 100)
|
||||
pub clip_grad: f32,
|
||||
|
||||
// ---- for SAC
|
||||
pub min_probability: f32,
|
||||
|
||||
// ---- for DQN
|
||||
// epsilon initial value (0.9 => more exploration)
|
||||
pub eps_start: f64,
|
||||
pub eps_end: f64,
|
||||
// eps_decay higher = epsilon decrease slower
|
||||
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||
// epsilon is updated at the start of each episode
|
||||
pub eps_decay: f64,
|
||||
|
||||
// ---- for PPO
|
||||
pub lambda: f32,
|
||||
pub epsilon_clip: f32,
|
||||
pub critic_weight: f32,
|
||||
pub entropy_weight: f32,
|
||||
pub epochs: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
save_path: None,
|
||||
max_steps: 2000,
|
||||
num_episodes: 1000,
|
||||
dense_size: 256,
|
||||
gamma: 0.999,
|
||||
tau: 0.005,
|
||||
learning_rate: 0.001,
|
||||
batch_size: 32,
|
||||
clip_grad: 100.0,
|
||||
min_probability: 1e-9,
|
||||
eps_start: 0.9,
|
||||
eps_end: 0.05,
|
||||
eps_decay: 1000.0,
|
||||
lambda: 0.95,
|
||||
epsilon_clip: 0.2,
|
||||
critic_weight: 0.5,
|
||||
entropy_weight: 0.01,
|
||||
epochs: 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Config {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let mut s = String::new();
|
||||
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||
s.push_str(&format!("min_probability={:?}\n", self.min_probability));
|
||||
s.push_str(&format!("lambda={:?}\n", self.lambda));
|
||||
s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip));
|
||||
s.push_str(&format!("critic_weight={:?}\n", self.critic_weight));
|
||||
s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight));
|
||||
s.push_str(&format!("epochs={:?}\n", self.epochs));
|
||||
write!(f, "{s}")
|
||||
}
|
||||
}
|
||||
|
||||
/// Plays one episode in a fresh environment (created with `E::new(true)`),
/// letting `agent` choose every action until the environment reports `done`.
// NOTE(review): when `react` returns `None` the loop retries with the same
// state; confirm the agents passed here always produce an action.
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
    let mut env = E::new(true);
    let mut state = env.state();

    loop {
        if let Some(action) = agent.react(&state) {
            let snapshot = env.step(action);
            state = *snapshot.state();
            if snapshot.done() {
                break;
            }
        }
    }
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
this: &Param<Tensor<B, N>>,
|
||||
that: &Param<Tensor<B, N>>,
|
||||
tau: ElemType,
|
||||
) -> Param<Tensor<B, N>> {
|
||||
let that_weight = that.val();
|
||||
let this_weight = this.val();
|
||||
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||
|
||||
Param::initialized(ParamId::new(), new_weight)
|
||||
}
|
||||
|
||||
pub fn soft_update_linear<B: Backend>(
|
||||
this: Linear<B>,
|
||||
that: &Linear<B>,
|
||||
tau: ElemType,
|
||||
) -> Linear<B> {
|
||||
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||
let bias = match (&this.bias, &that.bias) {
|
||||
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Linear::<B> { weight, bias }
|
||||
}
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
use crate::burnrl::environment_valid::TrictracEnvironment;
|
||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use crate::dqn::burnrl::environment::TrictracEnvironment;
|
||||
use crate::dqn::burnrl::utils::soft_update_linear;
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::relu;
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||
use std::fmt;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
@ -63,19 +62,71 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
|||
/// Replay memory capacity constant for the DQN trainer (8192).
// NOTE(review): no use of this constant is visible in this hunk; confirm whether `allow(unused)` is still needed.
#[allow(unused)]
const MEMORY_SIZE: usize = 1 << 13;
|
||||
|
||||
/// Hyper-parameters of the DQN trainer.
pub struct DqnConfig {
    /// Lower bound copied into the environment's `min_steps` by `run`.
    pub min_steps: f32,
    /// Upper bound on the number of steps per episode.
    pub max_steps: usize,
    /// Number of training episodes.
    pub num_episodes: usize,
    /// Width of the dense layers.
    pub dense_size: usize,
    /// Initial epsilon value.
    pub eps_start: f64,
    /// Final epsilon value.
    pub eps_end: f64,
    /// Epsilon decay constant.
    pub eps_decay: f64,

    /// Discount factor.
    pub gamma: f32,
    /// Soft update rate.
    pub tau: f32,
    /// Optimiser step size.
    pub learning_rate: f32,
    /// Number of experiences per training batch.
    pub batch_size: usize,
    /// Gradient clipping value.
    pub clip_grad: f32,
}
|
||||
|
||||
impl fmt::Display for DqnConfig {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut s = String::new();
|
||||
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
|
||||
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||
write!(f, "{s}")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DqnConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_steps: 250.0,
|
||||
max_steps: 2000,
|
||||
num_episodes: 1000,
|
||||
dense_size: 256,
|
||||
eps_start: 0.9,
|
||||
eps_end: 0.05,
|
||||
eps_decay: 1000.0,
|
||||
|
||||
gamma: 0.999,
|
||||
tau: 0.005,
|
||||
learning_rate: 0.001,
|
||||
batch_size: 32,
|
||||
clip_grad: 100.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DQN agent specialised to this file's `Net` network.
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
conf: &DqnConfig,
|
||||
visualized: bool,
|
||||
// ) -> DQN<E, B, Net<B>> {
|
||||
) -> impl Agent<E> {
|
||||
) -> DQN<E, B, Net<B>> {
|
||||
// ) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
env.as_mut().min_steps = conf.min_steps;
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let model = Net::<B>::new(
|
||||
|
|
@ -143,7 +194,8 @@ pub fn run<
|
|||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
let envmut = env.as_mut();
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
|
||||
envmut.goodmoves_count,
|
||||
envmut.pointrolls_count,
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
|
|
@ -155,35 +207,5 @@ pub fn run<
|
|||
}
|
||||
}
|
||||
}
|
||||
let valid_agent = agent.valid();
|
||||
if let Some(path) = &conf.save_path {
|
||||
save_model(valid_agent.model().as_ref().unwrap(), path);
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
agent
|
||||
}
|
||||
|
|
@ -1,15 +1,8 @@
|
|||
use std::io::Write;
|
||||
|
||||
use crate::dqn::dqn_common;
|
||||
use burn::{prelude::Backend, tensor::Tensor};
|
||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||
use rand::{rng, Rng};
|
||||
use trictrac_store::training_common;
|
||||
use trictrac_store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
const ERROR_REWARD: f32 = -1.0012121;
|
||||
const REWARD_VALID_MOVE: f32 = 1.0012121;
|
||||
const REWARD_RATIO: f32 = 0.1;
|
||||
const WIN_POINTS: f32 = 100.0;
|
||||
use rand::{thread_rng, Rng};
|
||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
|
@ -52,10 +45,10 @@ pub struct TrictracAction {
|
|||
|
||||
impl Action for TrictracAction {
|
||||
fn random() -> Self {
|
||||
use rand::{rng, Rng};
|
||||
let mut rng = rng();
|
||||
use rand::{thread_rng, Rng};
|
||||
let mut rng = thread_rng();
|
||||
TrictracAction {
|
||||
index: rng.random_range(0..Self::size() as u32),
|
||||
index: rng.gen_range(0..Self::size() as u32),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +59,7 @@ impl Action for TrictracAction {
|
|||
}
|
||||
|
||||
fn size() -> usize {
|
||||
514
|
||||
1252
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,7 +84,7 @@ pub struct TrictracEnvironment {
|
|||
current_state: TrictracState,
|
||||
episode_reward: f32,
|
||||
pub step_count: usize,
|
||||
pub best_ratio: f32,
|
||||
pub min_steps: f32,
|
||||
pub max_steps: usize,
|
||||
pub pointrolls_count: usize,
|
||||
pub goodmoves_count: usize,
|
||||
|
|
@ -114,7 +107,7 @@ impl Environment for TrictracEnvironment {
|
|||
let player2_id = 2;
|
||||
|
||||
// Commencer la partie
|
||||
let _ = game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
let current_state = TrictracState::from_game_state(&game);
|
||||
TrictracEnvironment {
|
||||
|
|
@ -124,7 +117,7 @@ impl Environment for TrictracEnvironment {
|
|||
current_state,
|
||||
episode_reward: 0.0,
|
||||
step_count: 0,
|
||||
best_ratio: 0.0,
|
||||
min_steps: 250.0,
|
||||
max_steps: 2000,
|
||||
pointrolls_count: 0,
|
||||
goodmoves_count: 0,
|
||||
|
|
@ -139,13 +132,12 @@ impl Environment for TrictracEnvironment {
|
|||
|
||||
fn reset(&mut self) -> Snapshot<Self> {
|
||||
// Réinitialiser le jeu
|
||||
let history = self.game.history.clone();
|
||||
self.game = GameState::new(false);
|
||||
self.game.init_player("DQN Agent");
|
||||
self.game.init_player("Opponent");
|
||||
|
||||
// Commencer la partie
|
||||
let _ = self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
self.episode_reward = 0.0;
|
||||
|
|
@ -154,22 +146,11 @@ impl Environment for TrictracEnvironment {
|
|||
} else {
|
||||
self.goodmoves_count as f32 / self.step_count as f32
|
||||
};
|
||||
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
|
||||
let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
|
||||
let path = "bot/models/logs/debug.log";
|
||||
if let Ok(mut out) = std::fs::File::create(path) {
|
||||
write!(out, "{history:?}").expect("could not write history log");
|
||||
}
|
||||
"!!!!"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
// println!(
|
||||
// "info: correct moves: {} ({}%) {}",
|
||||
// self.goodmoves_count,
|
||||
// (100.0 * self.goodmoves_ratio).round() as u32,
|
||||
// warning
|
||||
// );
|
||||
println!(
|
||||
"info: correct moves: {} ({}%)",
|
||||
self.goodmoves_count,
|
||||
(100.0 * self.goodmoves_ratio).round() as u32
|
||||
);
|
||||
self.step_count = 0;
|
||||
self.pointrolls_count = 0;
|
||||
self.goodmoves_count = 0;
|
||||
|
|
@ -184,7 +165,8 @@ impl Environment for TrictracEnvironment {
|
|||
let trictrac_action = Self::convert_action(action);
|
||||
|
||||
let mut reward = 0.0;
|
||||
let is_rollpoint;
|
||||
let mut is_rollpoint = false;
|
||||
let mut terminated = false;
|
||||
|
||||
// Exécuter l'action si c'est le tour de l'agent DQN
|
||||
if self.game.active_player_id == self.active_player_id {
|
||||
|
|
@ -193,13 +175,12 @@ impl Environment for TrictracEnvironment {
|
|||
if is_rollpoint {
|
||||
self.pointrolls_count += 1;
|
||||
}
|
||||
if reward != ERROR_REWARD {
|
||||
if reward != Self::ERROR_REWARD {
|
||||
self.goodmoves_count += 1;
|
||||
}
|
||||
} else {
|
||||
// Action non convertible, pénalité
|
||||
panic!("action non convertible");
|
||||
//reward = -0.5;
|
||||
reward = -0.5;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -209,24 +190,22 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
|
||||
// Vérifier si la partie est terminée
|
||||
// let max_steps = self.max_steps;
|
||||
// let max_steps = self.min_steps
|
||||
// + (self.max_steps as f32 - self.min_steps)
|
||||
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||
let max_steps = self.min_steps
|
||||
+ (self.max_steps as f32 - self.min_steps)
|
||||
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||
|
||||
if done {
|
||||
// Récompense finale basée sur le résultat
|
||||
if let Some(winner_id) = self.game.determine_winner() {
|
||||
if winner_id == self.active_player_id {
|
||||
reward += WIN_POINTS; // Victoire
|
||||
reward += 50.0; // Victoire
|
||||
} else {
|
||||
reward -= WIN_POINTS; // Défaite
|
||||
reward -= 25.0; // Défaite
|
||||
}
|
||||
}
|
||||
}
|
||||
let terminated = done || self.step_count >= self.max_steps;
|
||||
// let terminated = done || self.step_count >= max_steps.round() as usize;
|
||||
let terminated = done || self.step_count >= max_steps.round() as usize;
|
||||
|
||||
// Mettre à jour l'état
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
|
|
@ -244,60 +223,120 @@ impl Environment for TrictracEnvironment {
|
|||
}
|
||||
|
||||
impl TrictracEnvironment {
|
||||
const ERROR_REWARD: f32 = -1.12121;
|
||||
const REWARD_RATIO: f32 = 1.0;
|
||||
|
||||
/// Convertit une action burn-rl vers une action Trictrac
|
||||
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
|
||||
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
}
|
||||
|
||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||
#[allow(dead_code)]
|
||||
fn convert_valid_action_index(
|
||||
&self,
|
||||
action: TrictracAction,
|
||||
game_state: &GameState,
|
||||
) -> Option<training_common::TrictracAction> {
|
||||
use training_common::get_valid_actions;
|
||||
) -> Option<dqn_common::TrictracAction> {
|
||||
use dqn_common::get_valid_actions;
|
||||
|
||||
// Obtenir les actions valides dans le contexte actuel
|
||||
if let Ok(valid_actions) = get_valid_actions(game_state) {
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index].clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Exécute une action Trictrac dans le jeu
|
||||
// fn execute_action(
|
||||
// &mut self,
|
||||
// action: training_common::TrictracAction,
|
||||
// action: dqn_common::TrictracAction,
|
||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
|
||||
use training_common::TrictracAction;
|
||||
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
|
||||
use dqn_common::TrictracAction;
|
||||
|
||||
let mut reward = 0.0;
|
||||
let mut is_rollpoint = false;
|
||||
|
||||
let event = match action {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
reward += 0.1;
|
||||
Some(GameEvent::Roll {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game.
|
||||
// reward += 0.1 * points as f32;
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.active_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if dice_order {
|
||||
(self.game.dice.values.0, self.game.dice.values.1)
|
||||
} else {
|
||||
(self.game.dice.values.1, self.game.dice.values.0)
|
||||
};
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Appliquer l'événement si valide
|
||||
if let Some(event) = action.to_event(&self.game) {
|
||||
if let Some(event) = event {
|
||||
if self.game.validate(&event) {
|
||||
let _ = self.game.consume(&event);
|
||||
// reward += REWARD_VALID_MOVE;
|
||||
self.game.consume(&event);
|
||||
|
||||
// Simuler le résultat des dés après un Roll
|
||||
if matches!(action, TrictracAction::Roll) {
|
||||
let mut rng = rng();
|
||||
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6));
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
let dice_event = GameEvent::RollResult {
|
||||
player_id: self.active_player_id,
|
||||
dice: trictrac_store::Dice {
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
};
|
||||
if self.game.validate(&dice_event) {
|
||||
let _ = self.game.consume(&dice_event);
|
||||
self.game.consume(&dice_event);
|
||||
let (points, adv_points) = self.game.dice_points;
|
||||
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
|
||||
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||
if points > 0 {
|
||||
is_rollpoint = true;
|
||||
// println!("info: rolled for {reward}");
|
||||
|
|
@ -309,12 +348,8 @@ impl TrictracEnvironment {
|
|||
// Pénalité pour action invalide
|
||||
// on annule les précédents reward
|
||||
// et on indique une valeur reconnaissable pour statistiques
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
reward = Self::ERROR_REWARD;
|
||||
}
|
||||
} else {
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
}
|
||||
|
||||
(reward, is_rollpoint)
|
||||
|
|
@ -337,24 +372,22 @@ impl TrictracEnvironment {
|
|||
*strategy.get_mut_game() = self.game.clone();
|
||||
|
||||
// Exécuter l'action selon le turn_stage
|
||||
let mut calculate_points = false;
|
||||
let opponent_color = trictrac_store::Color::Black;
|
||||
let event = match self.game.turn_stage {
|
||||
TurnStage::RollDice => GameEvent::Roll {
|
||||
player_id: self.opponent_id,
|
||||
},
|
||||
TurnStage::RollWaiting => {
|
||||
let mut rng = rng();
|
||||
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6));
|
||||
calculate_points = true;
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
GameEvent::RollResult {
|
||||
player_id: self.opponent_id,
|
||||
dice: trictrac_store::Dice {
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::MarkPoints => {
|
||||
let opponent_color = store::Color::Black;
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
|
|
@ -363,13 +396,16 @@ impl TrictracEnvironment {
|
|||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
points: points_rules.get_points(dice_roll_count).0,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::MarkAdvPoints => {
|
||||
let opponent_color = trictrac_store::Color::Black;
|
||||
let opponent_color = store::Color::Black;
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
|
|
@ -378,10 +414,11 @@ impl TrictracEnvironment {
|
|||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let points = points_rules.get_points(dice_roll_count).1;
|
||||
// pas de reward : déjà comptabilisé lors du tour de blanc
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
points: points_rules.get_points(dice_roll_count).1,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
|
|
@ -397,20 +434,7 @@ impl TrictracEnvironment {
|
|||
};
|
||||
|
||||
if self.game.validate(&event) {
|
||||
let _ = self.game.consume(&event);
|
||||
if calculate_points {
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
.get(&self.opponent_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
// Récompense proportionnelle aux points
|
||||
reward -= REWARD_RATIO * (points as f32 - adv_points as f32);
|
||||
}
|
||||
self.game.consume(&event);
|
||||
}
|
||||
}
|
||||
reward
|
||||
53
bot/src/dqn/burnrl/main.rs
Normal file
53
bot/src/dqn/burnrl/main.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
use bot::dqn::burnrl::{
|
||||
dqn_model, environment,
|
||||
utils::{demo_model, load_model, save_model},
|
||||
};
|
||||
use burn::backend::{Autodiff, NdArray};
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::base::ElemType;
|
||||
|
||||
type Backend = Autodiff<NdArray<ElemType>>;
|
||||
type Env = environment::TrictracEnvironment;
|
||||
|
||||
fn main() {
|
||||
// println!("> Entraînement");
|
||||
|
||||
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||
let conf = dqn_model::DqnConfig {
|
||||
// defaults
|
||||
num_episodes: 40, // 40
|
||||
min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction)
|
||||
max_steps: 3000, // 1000 max steps by episode
|
||||
dense_size: 256, // 128 neural network complexity (default 128)
|
||||
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
|
||||
eps_end: 0.05, // 0.05
|
||||
// eps_decay higher = epsilon decrease slower
|
||||
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||
// epsilon is updated at the start of each episode
|
||||
eps_decay: 2000.0, // 1000 ?
|
||||
|
||||
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
|
||||
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||
// plus lente moins sensible aux coups de chance
|
||||
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||
// converger
|
||||
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||
};
|
||||
println!("{conf}----------");
|
||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||
|
||||
let valid_agent = agent.valid();
|
||||
|
||||
println!("> Sauvegarde du modèle de validation");
|
||||
|
||||
let path = "models/burn_dqn_40".to_string();
|
||||
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = load_model(conf.dense_size, &path);
|
||||
let loaded_agent = DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
3
bot/src/dqn/burnrl/mod.rs
Normal file
3
bot/src/dqn/burnrl/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod dqn_model;
|
||||
pub mod environment;
|
||||
pub mod utils;
|
||||
114
bot/src/dqn/burnrl/utils.rs
Normal file
114
bot/src/dqn/burnrl/utils.rs
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
use crate::dqn::burnrl::{
|
||||
dqn_model,
|
||||
environment::{TrictracAction, TrictracEnvironment},
|
||||
};
|
||||
use crate::dqn::dqn_common::get_valid_action_indices;
|
||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||
use burn::module::{Module, Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{DQNModel, DQN};
|
||||
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||
|
||||
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
println!("Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
dqn_model::Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||
let mut env = TrictracEnvironment::new(true);
|
||||
let mut done = false;
|
||||
while !done {
|
||||
// let action = match infer_action(&agent, &env, state) {
|
||||
let action = match infer_action(&agent, &env) {
|
||||
Some(value) => value,
|
||||
None => break,
|
||||
};
|
||||
// Execute action
|
||||
let snapshot = env.step(action);
|
||||
done = snapshot.done();
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_action<B: Backend, M: DQNModel<B>>(
|
||||
agent: &DQN<TrictracEnvironment, B, M>,
|
||||
env: &TrictracEnvironment,
|
||||
) -> Option<TrictracAction> {
|
||||
let state = env.state();
|
||||
// Get q-values
|
||||
let q_values = agent
|
||||
.model()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.infer(state.to_tensor().unsqueeze());
|
||||
// Get valid actions
|
||||
let valid_actions_indices = get_valid_action_indices(&env.game);
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
// Set non valid actions q-values to lowest
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions_indices.contains(&index) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
let action = TrictracAction::from(action_index);
|
||||
Some(action)
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
this: &Param<Tensor<B, N>>,
|
||||
that: &Param<Tensor<B, N>>,
|
||||
tau: ElemType,
|
||||
) -> Param<Tensor<B, N>> {
|
||||
let that_weight = that.val();
|
||||
let this_weight = this.val();
|
||||
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||
|
||||
Param::initialized(ParamId::new(), new_weight)
|
||||
}
|
||||
|
||||
pub fn soft_update_linear<B: Backend>(
|
||||
this: Linear<B>,
|
||||
that: &Linear<B>,
|
||||
tau: ElemType,
|
||||
) -> Linear<B> {
|
||||
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||
let bias = match (&this.bias, &that.bias) {
|
||||
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Linear::<B> { weight, bias }
|
||||
}
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
use crate::burnrl::environment::TrictracEnvironment;
|
||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
||||
use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
|
||||
use crate::dqn::burnrl_valid::utils::soft_update_linear;
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::optim::AdamWConfig;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::activation::relu;
|
||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
||||
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||
use std::fmt;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
|
|
@ -63,20 +62,67 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
|||
#[allow(unused)]
|
||||
const MEMORY_SIZE: usize = 8192;
|
||||
|
||||
/// Hyper-parameters of a DQN training run.
#[derive(Debug, Clone, PartialEq)]
pub struct DqnConfig {
    /// Maximum number of steps in one episode.
    pub max_steps: usize,
    /// Number of episodes to run.
    pub num_episodes: usize,
    /// Size passed to the network constructor for its dense layers.
    pub dense_size: usize,
    /// Initial value of epsilon.
    pub eps_start: f64,
    /// Final value of epsilon.
    pub eps_end: f64,
    /// Decay constant of epsilon.
    pub eps_decay: f64,

    /// Discount factor.
    pub gamma: f32,
    /// Soft update rate.
    pub tau: f32,
    /// Learning rate.
    pub learning_rate: f32,
    /// Batch size.
    pub batch_size: usize,
    /// Gradient clipping value.
    pub clip_grad: f32,
}

impl fmt::Display for DqnConfig {
    /// Writes one `name=value` line per field, in declaration order, each
    /// terminated by a newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Write straight into the formatter rather than building
        // intermediate strings.
        writeln!(f, "max_steps={:?}", self.max_steps)?;
        writeln!(f, "num_episodes={:?}", self.num_episodes)?;
        writeln!(f, "dense_size={:?}", self.dense_size)?;
        writeln!(f, "eps_start={:?}", self.eps_start)?;
        writeln!(f, "eps_end={:?}", self.eps_end)?;
        writeln!(f, "eps_decay={:?}", self.eps_decay)?;
        writeln!(f, "gamma={:?}", self.gamma)?;
        writeln!(f, "tau={:?}", self.tau)?;
        writeln!(f, "learning_rate={:?}", self.learning_rate)?;
        writeln!(f, "batch_size={:?}", self.batch_size)?;
        writeln!(f, "clip_grad={:?}", self.clip_grad)
    }
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            max_steps: 2000,
            num_episodes: 1000,
            dense_size: 256,
            eps_start: 0.9,
            eps_end: 0.05,
            eps_decay: 1000.0,

            gamma: 0.999,
            tau: 0.005,
            learning_rate: 0.001,
            batch_size: 32,
            clip_grad: 100.0,
        }
    }
}
|
||||
|
||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||
|
||||
#[allow(unused)]
|
||||
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
pub fn run<
|
||||
E: Environment + AsMut<TrictracEnvironment>,
|
||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
||||
>(
|
||||
conf: &Config,
|
||||
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||
conf: &DqnConfig,
|
||||
visualized: bool,
|
||||
// ) -> DQN<E, B, Net<B>> {
|
||||
) -> impl Agent<E> {
|
||||
) -> DQN<E, B, Net<B>> {
|
||||
// ) -> impl Agent<E> {
|
||||
let mut env = E::new(visualized);
|
||||
// env.as_mut().min_steps = conf.min_steps;
|
||||
env.as_mut().max_steps = conf.max_steps;
|
||||
|
||||
let model = Net::<B>::new(
|
||||
|
|
@ -143,13 +189,8 @@ pub fn run<
|
|||
|
||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||
let envmut = env.as_mut();
|
||||
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
|
||||
* 100.0)
|
||||
.round() as u32;
|
||||
println!(
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
|
||||
envmut.goodmoves_count,
|
||||
goodmoves_ratio,
|
||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
|
||||
envmut.pointrolls_count,
|
||||
now.elapsed().unwrap().as_secs(),
|
||||
);
|
||||
|
|
@ -161,35 +202,5 @@ pub fn run<
|
|||
}
|
||||
}
|
||||
}
|
||||
let valid_agent = agent.valid();
|
||||
if let Some(path) = &conf.save_path {
|
||||
save_model(valid_agent.model().as_ref().unwrap(), path);
|
||||
}
|
||||
valid_agent
|
||||
}
|
||||
|
||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}.mpk");
|
||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
agent
|
||||
}
|
||||
|
|
@ -1,11 +1,8 @@
|
|||
use crate::dqn::dqn_common;
|
||||
use burn::{prelude::Backend, tensor::Tensor};
|
||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||
use rand::{rng, Rng};
|
||||
use trictrac_store::training_common;
|
||||
use trictrac_store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
const ERROR_REWARD: f32 = -1.0012121;
|
||||
const REWARD_RATIO: f32 = 0.1;
|
||||
use rand::{thread_rng, Rng};
|
||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
|
||||
/// État du jeu Trictrac pour burn-rl
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
|
@ -48,10 +45,10 @@ pub struct TrictracAction {
|
|||
|
||||
impl Action for TrictracAction {
|
||||
fn random() -> Self {
|
||||
use rand::{rng, Rng};
|
||||
let mut rng = rng();
|
||||
use rand::{thread_rng, Rng};
|
||||
let mut rng = thread_rng();
|
||||
TrictracAction {
|
||||
index: rng.random_range(0..Self::size() as u32),
|
||||
index: rng.gen_range(0..Self::size() as u32),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +106,7 @@ impl Environment for TrictracEnvironment {
|
|||
let player2_id = 2;
|
||||
|
||||
// Commencer la partie
|
||||
let _ = game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
let current_state = TrictracState::from_game_state(&game);
|
||||
TrictracEnvironment {
|
||||
|
|
@ -136,7 +133,7 @@ impl Environment for TrictracEnvironment {
|
|||
self.game.init_player("Opponent");
|
||||
|
||||
// Commencer la partie
|
||||
let _ = self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
self.current_state = TrictracState::from_game_state(&self.game);
|
||||
self.episode_reward = 0.0;
|
||||
|
|
@ -159,26 +156,17 @@ impl Environment for TrictracEnvironment {
|
|||
if self.game.active_player_id == self.active_player_id {
|
||||
if let Some(action) = trictrac_action {
|
||||
(reward, is_rollpoint) = self.execute_action(action);
|
||||
// if reward != 0.0 {
|
||||
// println!("info: self rew {reward}");
|
||||
// }
|
||||
if is_rollpoint {
|
||||
self.pointrolls_count += 1;
|
||||
}
|
||||
} else {
|
||||
// Action non convertible, pénalité
|
||||
println!("info: action non convertible -> -1 {trictrac_action:?}");
|
||||
reward = -1.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Faire jouer l'adversaire (stratégie simple)
|
||||
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||
// let op_rew = self.play_opponent_if_needed();
|
||||
// if op_rew != 0.0 {
|
||||
// println!("info: op rew {op_rew}");
|
||||
// }
|
||||
// reward += op_rew;
|
||||
reward += self.play_opponent_if_needed();
|
||||
}
|
||||
|
||||
|
|
@ -217,57 +205,112 @@ impl TrictracEnvironment {
|
|||
const REWARD_RATIO: f32 = 1.0;
|
||||
|
||||
/// Convertit une action burn-rl vers une action Trictrac
|
||||
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
|
||||
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||
}
|
||||
|
||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||
fn convert_valid_action_index(
|
||||
&self,
|
||||
action: TrictracAction,
|
||||
) -> Option<training_common::TrictracAction> {
|
||||
use training_common::get_valid_actions;
|
||||
) -> Option<dqn_common::TrictracAction> {
|
||||
use dqn_common::get_valid_actions;
|
||||
|
||||
// Obtenir les actions valides dans le contexte actuel
|
||||
if let Ok(valid_actions) = get_valid_actions(&self.game) {
|
||||
let valid_actions = get_valid_actions(&self.game);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Mapper l'index d'action sur une action valide
|
||||
let action_index = (action.index as usize) % valid_actions.len();
|
||||
Some(valid_actions[action_index])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Some(valid_actions[action_index].clone())
|
||||
}
|
||||
|
||||
/// Exécute une action Trictrac dans le jeu
|
||||
// fn execute_action(
|
||||
// &mut self,
|
||||
// action: training_common::TrictracAction,
|
||||
// action: dqn_common::TrictracAction,
|
||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
|
||||
use training_common::TrictracAction;
|
||||
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
|
||||
use dqn_common::TrictracAction;
|
||||
|
||||
let mut reward = 0.0;
|
||||
let mut is_rollpoint = false;
|
||||
|
||||
let event = match action {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
Some(GameEvent::Roll {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game.
|
||||
// reward += 0.1 * points as f32;
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.active_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.active_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if dice_order {
|
||||
(self.game.dice.values.0, self.game.dice.values.1)
|
||||
} else {
|
||||
(self.game.dice.values.1, self.game.dice.values.0)
|
||||
};
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Appliquer l'événement si valide
|
||||
if let Some(event) = action.to_event(&self.game) {
|
||||
if let Some(event) = event {
|
||||
if self.game.validate(&event) {
|
||||
let _ = self.game.consume(&event);
|
||||
// reward += REWARD_VALID_MOVE;
|
||||
self.game.consume(&event);
|
||||
|
||||
// Simuler le résultat des dés après un Roll
|
||||
if matches!(action, TrictracAction::Roll) {
|
||||
let mut rng = rng();
|
||||
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6));
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
let dice_event = GameEvent::RollResult {
|
||||
player_id: self.active_player_id,
|
||||
dice: trictrac_store::Dice {
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
};
|
||||
if self.game.validate(&dice_event) {
|
||||
let _ = self.game.consume(&dice_event);
|
||||
self.game.consume(&dice_event);
|
||||
let (points, adv_points) = self.game.dice_points;
|
||||
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
|
||||
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||
if points > 0 {
|
||||
is_rollpoint = true;
|
||||
// println!("info: rolled for {reward}");
|
||||
|
|
@ -279,12 +322,8 @@ impl TrictracEnvironment {
|
|||
// Pénalité pour action invalide
|
||||
// on annule les précédents reward
|
||||
// et on indique une valeur reconnaissable pour statistiques
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
reward = Self::ERROR_REWARD;
|
||||
}
|
||||
} else {
|
||||
reward = ERROR_REWARD;
|
||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
||||
}
|
||||
|
||||
(reward, is_rollpoint)
|
||||
|
|
@ -307,24 +346,22 @@ impl TrictracEnvironment {
|
|||
*strategy.get_mut_game() = self.game.clone();
|
||||
|
||||
// Exécuter l'action selon le turn_stage
|
||||
let mut calculate_points = false;
|
||||
let opponent_color = trictrac_store::Color::Black;
|
||||
let event = match self.game.turn_stage {
|
||||
TurnStage::RollDice => GameEvent::Roll {
|
||||
player_id: self.opponent_id,
|
||||
},
|
||||
TurnStage::RollWaiting => {
|
||||
let mut rng = rng();
|
||||
let dice_values = (rng.random_range(1..=6), rng.random_range(1..=6));
|
||||
calculate_points = true;
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
GameEvent::RollResult {
|
||||
player_id: self.opponent_id,
|
||||
dice: trictrac_store::Dice {
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::MarkPoints => {
|
||||
let opponent_color = store::Color::Black;
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
|
|
@ -333,12 +370,16 @@ impl TrictracEnvironment {
|
|||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_id,
|
||||
points: points_rules.get_points(dice_roll_count).0,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::MarkAdvPoints => {
|
||||
let opponent_color = store::Color::Black;
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
|
|
@ -367,20 +408,7 @@ impl TrictracEnvironment {
|
|||
};
|
||||
|
||||
if self.game.validate(&event) {
|
||||
let _ = self.game.consume(&event);
|
||||
if calculate_points {
|
||||
let dice_roll_count = self
|
||||
.game
|
||||
.players
|
||||
.get(&self.opponent_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules =
|
||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||
// Récompense proportionnelle aux points
|
||||
}
|
||||
self.game.consume(&event);
|
||||
}
|
||||
}
|
||||
reward
|
||||
52
bot/src/dqn/burnrl_valid/main.rs
Normal file
52
bot/src/dqn/burnrl_valid/main.rs
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
use bot::dqn::burnrl_valid::{
|
||||
dqn_model, environment,
|
||||
utils::{demo_model, load_model, save_model},
|
||||
};
|
||||
use burn::backend::{Autodiff, NdArray};
|
||||
use burn_rl::agent::DQN;
|
||||
use burn_rl::base::ElemType;
|
||||
|
||||
type Backend = Autodiff<NdArray<ElemType>>;
|
||||
type Env = environment::TrictracEnvironment;
|
||||
|
||||
fn main() {
|
||||
// println!("> Entraînement");
|
||||
|
||||
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||
let conf = dqn_model::DqnConfig {
|
||||
// defaults
|
||||
num_episodes: 100, // 40
|
||||
max_steps: 1000, // 1000 max steps by episode
|
||||
dense_size: 256, // 128 neural network complexity (default 128)
|
||||
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
|
||||
eps_end: 0.05, // 0.05
|
||||
// eps_decay higher = epsilon decrease slower
|
||||
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||
// epsilon is updated at the start of each episode
|
||||
eps_decay: 2000.0, // 1000 ?
|
||||
|
||||
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
|
||||
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||
// plus lente moins sensible aux coups de chance
|
||||
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||
// converger
|
||||
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||
};
|
||||
println!("{conf}----------");
|
||||
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||
|
||||
let valid_agent = agent.valid();
|
||||
|
||||
println!("> Sauvegarde du modèle de validation");
|
||||
|
||||
let path = "bot/models/burn_dqn_valid_40".to_string();
|
||||
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||
|
||||
println!("> Chargement du modèle pour test");
|
||||
let loaded_model = load_model(conf.dense_size, &path);
|
||||
let loaded_agent = DQN::new(loaded_model.unwrap());
|
||||
|
||||
println!("> Test avec le modèle chargé");
|
||||
demo_model(loaded_agent);
|
||||
}
|
||||
3
bot/src/dqn/burnrl_valid/mod.rs
Normal file
3
bot/src/dqn/burnrl_valid/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod dqn_model;
|
||||
pub mod environment;
|
||||
pub mod utils;
|
||||
114
bot/src/dqn/burnrl_valid/utils.rs
Normal file
114
bot/src/dqn/burnrl_valid/utils.rs
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
use crate::dqn::burnrl_valid::{
|
||||
dqn_model,
|
||||
environment::{TrictracAction, TrictracEnvironment},
|
||||
};
|
||||
use crate::dqn::dqn_common::get_valid_action_indices;
|
||||
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||
use burn::module::{Module, Param, ParamId};
|
||||
use burn::nn::Linear;
|
||||
use burn::record::{CompactRecorder, Recorder};
|
||||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::tensor::Tensor;
|
||||
use burn_rl::agent::{DQNModel, DQN};
|
||||
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||
|
||||
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||
let recorder = CompactRecorder::new();
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
println!("Modèle de validation sauvegardé : {model_path}");
|
||||
recorder
|
||||
.record(model.clone().into_record(), model_path.into())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
|
||||
let model_path = format!("{path}_model.mpk");
|
||||
// println!("Chargement du modèle depuis : {model_path}");
|
||||
|
||||
CompactRecorder::new()
|
||||
.load(model_path.into(), &NdArrayDevice::default())
|
||||
.map(|record| {
|
||||
dqn_model::Net::new(
|
||||
<TrictracEnvironment as Environment>::StateType::size(),
|
||||
dense_size,
|
||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||
)
|
||||
.load_record(record)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||
let mut env = TrictracEnvironment::new(true);
|
||||
let mut done = false;
|
||||
while !done {
|
||||
// let action = match infer_action(&agent, &env, state) {
|
||||
let action = match infer_action(&agent, &env) {
|
||||
Some(value) => value,
|
||||
None => break,
|
||||
};
|
||||
// Execute action
|
||||
let snapshot = env.step(action);
|
||||
done = snapshot.done();
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_action<B: Backend, M: DQNModel<B>>(
|
||||
agent: &DQN<TrictracEnvironment, B, M>,
|
||||
env: &TrictracEnvironment,
|
||||
) -> Option<TrictracAction> {
|
||||
let state = env.state();
|
||||
// Get q-values
|
||||
let q_values = agent
|
||||
.model()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.infer(state.to_tensor().unsqueeze());
|
||||
// Get valid actions
|
||||
let valid_actions_indices = get_valid_action_indices(&env.game);
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
// Set non valid actions q-values to lowest
|
||||
let mut masked_q_values = q_values.clone();
|
||||
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
if !valid_actions_indices.contains(&index) {
|
||||
masked_q_values = masked_q_values.clone().mask_fill(
|
||||
masked_q_values.clone().equal_elem(*q_value),
|
||||
f32::NEG_INFINITY,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
let action = TrictracAction::from(action_index);
|
||||
Some(action)
|
||||
}
|
||||
|
||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||
this: &Param<Tensor<B, N>>,
|
||||
that: &Param<Tensor<B, N>>,
|
||||
tau: ElemType,
|
||||
) -> Param<Tensor<B, N>> {
|
||||
let that_weight = that.val();
|
||||
let this_weight = this.val();
|
||||
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||
|
||||
Param::initialized(ParamId::new(), new_weight)
|
||||
}
|
||||
|
||||
pub fn soft_update_linear<B: Backend>(
|
||||
this: Linear<B>,
|
||||
that: &Linear<B>,
|
||||
tau: ElemType,
|
||||
) -> Linear<B> {
|
||||
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||
let bias = match (&this.bias, &that.bias) {
|
||||
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Linear::<B> { weight, bias }
|
||||
}
|
||||
255
bot/src/dqn/dqn_common.rs
Normal file
255
bot/src/dqn/dqn_common.rs
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
use std::cmp::{max, min};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use store::{CheckerMove, Dice};
|
||||
|
||||
/// Kinds of actions available in the game.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TrictracAction {
    /// Roll the dice.
    Roll,
    /// Carry on after winning a hole.
    Go,
    /// Move checkers.
    Move {
        dice_order: bool, // true = use dice[0] first, false = dice[1] first
        from1: usize,     // starting position of the first checker (0-24)
        from2: usize,     // starting position of the second checker (0-24)
    },
    // Mark points: to be enabled if "schools" get supported
    // Mark,
}
|
||||
|
||||
impl TrictracAction {
|
||||
/// Encode une action en index pour le réseau de neurones
|
||||
pub fn to_action_index(&self) -> usize {
|
||||
match self {
|
||||
TrictracAction::Roll => 0,
|
||||
TrictracAction::Go => 1,
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Encoder les mouvements dans l'espace d'actions
|
||||
// Indices 2+ pour les mouvements
|
||||
// de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier)
|
||||
let mut start = 2;
|
||||
if !dice_order {
|
||||
// 25 * 25 = 625
|
||||
start += 625;
|
||||
}
|
||||
start + from1 * 25 + from2
|
||||
} // TrictracAction::Mark => 1252,
|
||||
}
|
||||
}
|
||||
|
||||
/// Décode un index d'action en TrictracAction
|
||||
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||
match index {
|
||||
0 => Some(TrictracAction::Roll),
|
||||
// 1252 => Some(TrictracAction::Mark),
|
||||
1 => Some(TrictracAction::Go),
|
||||
i if i >= 3 => {
|
||||
let move_code = i - 3;
|
||||
let (dice_order, from1, from2) = Self::decode_move(move_code);
|
||||
Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Décode un entier en paire de mouvements
|
||||
fn decode_move(code: usize) -> (bool, usize, usize) {
|
||||
let mut encoded = code;
|
||||
let dice_order = code < 626;
|
||||
if !dice_order {
|
||||
encoded -= 625
|
||||
}
|
||||
let from1 = encoded / 25;
|
||||
let from2 = 1 + encoded % 25;
|
||||
(dice_order, from1, from2)
|
||||
}
|
||||
|
||||
/// Retourne la taille de l'espace d'actions total
|
||||
pub fn action_space_size() -> usize {
|
||||
// 1 (Roll) + 1 (Go) + mouvements possibles
|
||||
// Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from)
|
||||
// Mais on peut optimiser en limitant aux positions valides (1-24)
|
||||
2 + (2 * 25 * 25) // = 1252
|
||||
}
|
||||
|
||||
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
|
||||
// match action {
|
||||
// TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
|
||||
// TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
|
||||
// TrictracAction::Go => Some(GameEvent::Go { player_id }),
|
||||
// TrictracAction::Move {
|
||||
// dice_order,
|
||||
// from1,
|
||||
// from2,
|
||||
// } => {
|
||||
// // Effectuer un mouvement
|
||||
// let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
|
||||
// let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
|
||||
//
|
||||
// Some(GameEvent::Move {
|
||||
// player_id: self.agent_player_id,
|
||||
// moves: (checker_move1, checker_move2),
|
||||
// })
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
}
|
||||
|
||||
/// Obtient les actions valides pour l'état de jeu actuel
|
||||
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||
use store::TurnStage;
|
||||
|
||||
let mut valid_actions = Vec::new();
|
||||
|
||||
let active_player_id = game_state.active_player_id;
|
||||
let player_color = game_state.player_color_by_id(&active_player_id);
|
||||
|
||||
if let Some(color) = player_color {
|
||||
match game_state.turn_stage {
|
||||
TurnStage::RollDice | TurnStage::RollWaiting => {
|
||||
valid_actions.push(TrictracAction::Roll);
|
||||
}
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
|
||||
// valid_actions.push(TrictracAction::Mark);
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
valid_actions.push(TrictracAction::Go);
|
||||
|
||||
// Ajoute aussi les mouvements possibles
|
||||
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||
assert_eq!(color, store::Color::White);
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
));
|
||||
}
|
||||
}
|
||||
TurnStage::Move => {
|
||||
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||
assert_eq!(color, store::Color::White);
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1,
|
||||
&move2,
|
||||
&game_state.dice,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
valid_actions
|
||||
}
|
||||
|
||||
// Valid only for White player
|
||||
fn checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
dice: &Dice,
|
||||
) -> TrictracAction {
|
||||
let to1 = move1.get_to();
|
||||
let to2 = move2.get_to();
|
||||
let from1 = move1.get_from();
|
||||
let from2 = move2.get_from();
|
||||
|
||||
let mut diff_move1 = if to1 > 0 {
|
||||
// Mouvement sans sortie
|
||||
to1 - from1
|
||||
} else {
|
||||
// sortie, on utilise la valeur du dé
|
||||
if to2 > 0 {
|
||||
// sortie pour le mouvement 1 uniquement
|
||||
let dice2 = to2 - from2;
|
||||
if dice2 == dice.values.0 as usize {
|
||||
dice.values.1 as usize
|
||||
} else {
|
||||
dice.values.0 as usize
|
||||
}
|
||||
} else {
|
||||
// double sortie
|
||||
if from1 < from2 {
|
||||
max(dice.values.0, dice.values.1) as usize
|
||||
} else {
|
||||
min(dice.values.0, dice.values.1) as usize
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// modification de diff_move1 si on est dans le cas d'un mouvement par puissance
|
||||
let rest_field = 12;
|
||||
if to1 == rest_field
|
||||
&& to2 == rest_field
|
||||
&& max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
|
||||
{
|
||||
// prise par puissance
|
||||
diff_move1 += 1;
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order: diff_move1 == dice.values.0 as usize,
|
||||
from1: move1.get_from(),
|
||||
from2: move2.get_from(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retourne les indices des actions valides
|
||||
pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
|
||||
get_valid_actions(game_state)
|
||||
.into_iter()
|
||||
.map(|action| action.to_action_index())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Sélectionne une action valide aléatoire
|
||||
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
let mut rng = thread_rng();
|
||||
valid_actions.choose(&mut rng).cloned()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Move shared by both tests; it is expected to sit at index 81.
    fn sample_move() -> TrictracAction {
        TrictracAction::Move {
            dice_order: true,
            from1: 3,
            from2: 4,
        }
    }

    #[test]
    fn to_action_index() {
        let action = sample_move();
        let index = action.to_action_index();
        assert_eq!(Some(action), TrictracAction::from_action_index(index));
        assert_eq!(81, index);
    }

    #[test]
    fn from_action_index() {
        assert_eq!(
            Some(sample_move()),
            TrictracAction::from_action_index(81)
        );
    }
}
|
||||
5
bot/src/dqn/mod.rs
Normal file
5
bot/src/dqn/mod.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
pub mod burnrl;
|
||||
pub mod dqn_common;
|
||||
pub mod simple;
|
||||
|
||||
pub mod burnrl_valid;
|
||||
154
bot/src/dqn/simple/dqn_model.rs
Normal file
154
bot/src/dqn/simple/dqn_model.rs
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
use crate::dqn::dqn_common::TrictracAction;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for the DQN agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DqnConfig {
    // Network input size (handed to `SimpleNeuralNetwork::new` by `DqnAgent::new`).
    pub state_size: usize,
    // Width of the hidden layers.
    pub hidden_size: usize,
    // Network output size, one q-value per action index.
    pub num_actions: usize,
    // Not read by the simple trainer; presumably the optimiser step size — TODO confirm.
    pub learning_rate: f64,
    // Not read by the simple trainer; presumably the discount factor — TODO confirm.
    pub gamma: f64,
    // Initial probability of picking a random valid action (exploration).
    pub epsilon: f64,
    // Factor applied to epsilon at every training step.
    pub epsilon_decay: f64,
    // Lower bound for epsilon.
    pub epsilon_min: f64,
    // Capacity of the replay buffer.
    pub replay_buffer_size: usize,
    // Minimum number of stored experiences before `DqnAgent::train` does anything.
    pub batch_size: usize,
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            state_size: 36,
            hidden_size: 512, // larger size to cope with the widened action space
            num_actions: TrictracAction::action_space_size(),
            learning_rate: 0.001,
            gamma: 0.99,
            epsilon: 0.1,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
        }
    }
}
|
||||
|
||||
/// Simplified DQN neural network (plain weight matrices).
///
/// Each `weightsN` holds one row per neuron of layer N, with one weight per
/// input of that layer; `biasesN` holds one bias per neuron.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleNeuralNetwork {
    // First hidden layer
    pub weights1: Vec<Vec<f32>>,
    pub biases1: Vec<f32>,
    // Second hidden layer
    pub weights2: Vec<Vec<f32>>,
    pub biases2: Vec<f32>,
    // Output layer
    pub weights3: Vec<Vec<f32>>,
    pub biases3: Vec<f32>,
}
|
||||
|
||||
impl SimpleNeuralNetwork {
|
||||
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
|
||||
use rand::{thread_rng, Rng};
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// Initialisation aléatoire des poids avec Xavier/Glorot
|
||||
let scale1 = (2.0 / input_size as f32).sqrt();
|
||||
let weights1 = (0..hidden_size)
|
||||
.map(|_| {
|
||||
(0..input_size)
|
||||
.map(|_| rng.gen_range(-scale1..scale1))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let biases1 = vec![0.0; hidden_size];
|
||||
|
||||
let scale2 = (2.0 / hidden_size as f32).sqrt();
|
||||
let weights2 = (0..hidden_size)
|
||||
.map(|_| {
|
||||
(0..hidden_size)
|
||||
.map(|_| rng.gen_range(-scale2..scale2))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let biases2 = vec![0.0; hidden_size];
|
||||
|
||||
let scale3 = (2.0 / hidden_size as f32).sqrt();
|
||||
let weights3 = (0..output_size)
|
||||
.map(|_| {
|
||||
(0..hidden_size)
|
||||
.map(|_| rng.gen_range(-scale3..scale3))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let biases3 = vec![0.0; output_size];
|
||||
|
||||
Self {
|
||||
weights1,
|
||||
biases1,
|
||||
weights2,
|
||||
biases2,
|
||||
weights3,
|
||||
biases3,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
// Première couche
|
||||
let mut layer1: Vec<f32> = self.biases1.clone();
|
||||
for (i, neuron_weights) in self.weights1.iter().enumerate() {
|
||||
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||
if j < input.len() {
|
||||
layer1[i] += input[j] * weight;
|
||||
}
|
||||
}
|
||||
layer1[i] = layer1[i].max(0.0); // ReLU
|
||||
}
|
||||
|
||||
// Deuxième couche
|
||||
let mut layer2: Vec<f32> = self.biases2.clone();
|
||||
for (i, neuron_weights) in self.weights2.iter().enumerate() {
|
||||
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||
if j < layer1.len() {
|
||||
layer2[i] += layer1[j] * weight;
|
||||
}
|
||||
}
|
||||
layer2[i] = layer2[i].max(0.0); // ReLU
|
||||
}
|
||||
|
||||
// Couche de sortie
|
||||
let mut output: Vec<f32> = self.biases3.clone();
|
||||
for (i, neuron_weights) in self.weights3.iter().enumerate() {
|
||||
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||
if j < layer2.len() {
|
||||
output[i] += layer2[j] * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn get_best_action(&self, input: &[f32]) -> usize {
|
||||
let q_values = self.forward(input);
|
||||
q_values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(index, _)| index)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn save<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
path: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let data = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(path, data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let data = std::fs::read_to_string(path)?;
|
||||
let network = serde_json::from_str(&data)?;
|
||||
Ok(network)
|
||||
}
|
||||
}
|
||||
|
||||
490
bot/src/dqn/simple/dqn_trainer.rs
Normal file
490
bot/src/dqn/simple/dqn_trainer.rs
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
use crate::{CheckerMove, Color, GameState, PlayerId};
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::VecDeque;
|
||||
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
||||
|
||||
use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
|
||||
use crate::dqn::dqn_common::{get_valid_actions, TrictracAction};
|
||||
|
||||
/// Experience stored in the replay buffer: one environment transition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
    // State vector the action was chosen from.
    pub state: Vec<f32>,
    pub action: TrictracAction,
    // Reward returned by the environment step.
    pub reward: f32,
    // State vector after the step.
    pub next_state: Vec<f32>,
    // True when the step ended the episode.
    pub done: bool,
}
|
||||
|
||||
/// Buffer de replay pour stocker les expériences
|
||||
#[derive(Debug)]
|
||||
pub struct ReplayBuffer {
|
||||
buffer: VecDeque<Experience>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
buffer: VecDeque::with_capacity(capacity),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, experience: Experience) {
|
||||
if self.buffer.len() >= self.capacity {
|
||||
self.buffer.pop_front();
|
||||
}
|
||||
self.buffer.push_back(experience);
|
||||
}
|
||||
|
||||
pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
|
||||
let mut rng = thread_rng();
|
||||
let len = self.buffer.len();
|
||||
if len < batch_size {
|
||||
return self.buffer.iter().cloned().collect();
|
||||
}
|
||||
|
||||
let mut batch = Vec::with_capacity(batch_size);
|
||||
for _ in 0..batch_size {
|
||||
let idx = rng.gen_range(0..len);
|
||||
batch.push(self.buffer[idx].clone());
|
||||
}
|
||||
batch
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent DQN pour l'apprentissage par renforcement
|
||||
#[derive(Debug)]
|
||||
pub struct DqnAgent {
|
||||
config: DqnConfig,
|
||||
model: SimpleNeuralNetwork,
|
||||
target_model: SimpleNeuralNetwork,
|
||||
replay_buffer: ReplayBuffer,
|
||||
epsilon: f64,
|
||||
step_count: usize,
|
||||
}
|
||||
|
||||
impl DqnAgent {
|
||||
pub fn new(config: DqnConfig) -> Self {
|
||||
let model =
|
||||
SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
|
||||
let target_model = model.clone();
|
||||
let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
|
||||
let epsilon = config.epsilon;
|
||||
|
||||
Self {
|
||||
config,
|
||||
model,
|
||||
target_model,
|
||||
replay_buffer,
|
||||
epsilon,
|
||||
step_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
// Fallback si aucune action valide
|
||||
return TrictracAction::Roll;
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
if rng.gen::<f64>() < self.epsilon {
|
||||
// Exploration : action valide aléatoire
|
||||
valid_actions
|
||||
.choose(&mut rng)
|
||||
.cloned()
|
||||
.unwrap_or(TrictracAction::Roll)
|
||||
} else {
|
||||
// Exploitation : meilleure action valide selon le modèle
|
||||
let q_values = self.model.forward(state);
|
||||
|
||||
let mut best_action = &valid_actions[0];
|
||||
let mut best_q_value = f32::NEG_INFINITY;
|
||||
|
||||
for action in &valid_actions {
|
||||
let action_index = action.to_action_index();
|
||||
if action_index < q_values.len() {
|
||||
let q_value = q_values[action_index];
|
||||
if q_value > best_q_value {
|
||||
best_q_value = q_value;
|
||||
best_action = action;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best_action.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_experience(&mut self, experience: Experience) {
|
||||
self.replay_buffer.push(experience);
|
||||
}
|
||||
|
||||
pub fn train(&mut self) {
|
||||
if self.replay_buffer.len() < self.config.batch_size {
|
||||
return;
|
||||
}
|
||||
|
||||
// Pour l'instant, on simule l'entraînement en mettant à jour epsilon
|
||||
// Dans une implémentation complète, ici on ferait la backpropagation
|
||||
self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
|
||||
self.step_count += 1;
|
||||
|
||||
// Mise à jour du target model tous les 100 steps
|
||||
if self.step_count % 100 == 0 {
|
||||
self.target_model = self.model.clone();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_model<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
path: P,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.model.save(path)
|
||||
}
|
||||
|
||||
pub fn get_epsilon(&self) -> f64 {
|
||||
self.epsilon
|
||||
}
|
||||
|
||||
pub fn get_step_count(&self) -> usize {
|
||||
self.step_count
|
||||
}
|
||||
}
|
||||
|
||||
/// Environnement Trictrac pour l'entraînement
|
||||
#[derive(Debug)]
|
||||
pub struct TrictracEnv {
|
||||
pub game_state: GameState,
|
||||
pub agent_player_id: PlayerId,
|
||||
pub opponent_player_id: PlayerId,
|
||||
pub agent_color: Color,
|
||||
pub max_steps: usize,
|
||||
pub current_step: usize,
|
||||
}
|
||||
|
||||
impl Default for TrictracEnv {
|
||||
fn default() -> Self {
|
||||
let mut game_state = GameState::new(false);
|
||||
game_state.init_player("agent");
|
||||
game_state.init_player("opponent");
|
||||
|
||||
Self {
|
||||
game_state,
|
||||
agent_player_id: 1,
|
||||
opponent_player_id: 2,
|
||||
agent_color: Color::White,
|
||||
max_steps: 1000,
|
||||
current_step: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracEnv {
|
||||
pub fn reset(&mut self) -> Vec<f32> {
|
||||
self.game_state = GameState::new(false);
|
||||
self.game_state.init_player("agent");
|
||||
self.game_state.init_player("opponent");
|
||||
|
||||
// Commencer la partie
|
||||
self.game_state.consume(&GameEvent::BeginGame {
|
||||
goes_first: self.agent_player_id,
|
||||
});
|
||||
|
||||
self.current_step = 0;
|
||||
self.game_state.to_vec_float()
|
||||
}
|
||||
|
||||
pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
|
||||
let mut reward = 0.0;
|
||||
|
||||
// Appliquer l'action de l'agent
|
||||
if self.game_state.active_player_id == self.agent_player_id {
|
||||
reward += self.apply_agent_action(action);
|
||||
}
|
||||
|
||||
// Faire jouer l'adversaire (stratégie simple)
|
||||
while self.game_state.active_player_id == self.opponent_player_id
|
||||
&& self.game_state.stage != Stage::Ended
|
||||
{
|
||||
reward += self.play_opponent_turn();
|
||||
}
|
||||
|
||||
// Vérifier si la partie est terminée
|
||||
let done = self.game_state.stage == Stage::Ended
|
||||
|| self.game_state.determine_winner().is_some()
|
||||
|| self.current_step >= self.max_steps;
|
||||
|
||||
// Récompense finale si la partie est terminée
|
||||
if done {
|
||||
if let Some(winner) = self.game_state.determine_winner() {
|
||||
if winner == self.agent_player_id {
|
||||
reward += 100.0; // Bonus pour gagner
|
||||
} else {
|
||||
reward -= 50.0; // Pénalité pour perdre
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.current_step += 1;
|
||||
let next_state = self.game_state.to_vec_float();
|
||||
(next_state, reward, done)
|
||||
}
|
||||
|
||||
fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
|
||||
let mut reward = 0.0;
|
||||
|
||||
let event = match action {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
reward += 0.1;
|
||||
Some(GameEvent::Roll {
|
||||
player_id: self.agent_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game_state.
|
||||
// reward += 0.1 * points as f32;
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.agent_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Go {
|
||||
player_id: self.agent_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if dice_order {
|
||||
(self.game_state.dice.values.0, self.game_state.dice.values.1)
|
||||
} else {
|
||||
(self.game_state.dice.values.1, self.game_state.dice.values.0)
|
||||
};
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
reward += 0.2;
|
||||
Some(GameEvent::Move {
|
||||
player_id: self.agent_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Appliquer l'événement si valide
|
||||
if let Some(event) = event {
|
||||
if self.game_state.validate(&event) {
|
||||
self.game_state.consume(&event);
|
||||
|
||||
// Simuler le résultat des dés après un Roll
|
||||
if matches!(action, TrictracAction::Roll) {
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
let dice_event = GameEvent::RollResult {
|
||||
player_id: self.agent_player_id,
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
};
|
||||
if self.game_state.validate(&dice_event) {
|
||||
self.game_state.consume(&dice_event);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Pénalité pour action invalide
|
||||
reward -= 2.0;
|
||||
}
|
||||
}
|
||||
|
||||
reward
|
||||
}
|
||||
|
||||
// TODO : use default bot strategy
|
||||
fn play_opponent_turn(&mut self) -> f32 {
|
||||
let mut reward = 0.0;
|
||||
let event = match self.game_state.turn_stage {
|
||||
TurnStage::RollDice => GameEvent::Roll {
|
||||
player_id: self.opponent_player_id,
|
||||
},
|
||||
TurnStage::RollWaiting => {
|
||||
let mut rng = thread_rng();
|
||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||
GameEvent::RollResult {
|
||||
player_id: self.opponent_player_id,
|
||||
dice: store::Dice {
|
||||
values: dice_values,
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
|
||||
let opponent_color = self.agent_color.opponent_color();
|
||||
let dice_roll_count = self
|
||||
.game_state
|
||||
.players
|
||||
.get(&self.opponent_player_id)
|
||||
.unwrap()
|
||||
.dice_roll_count;
|
||||
let points_rules = PointsRules::new(
|
||||
&opponent_color,
|
||||
&self.game_state.board,
|
||||
self.game_state.dice,
|
||||
);
|
||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||
|
||||
GameEvent::Mark {
|
||||
player_id: self.opponent_player_id,
|
||||
points,
|
||||
}
|
||||
}
|
||||
TurnStage::Move => {
|
||||
let opponent_color = self.agent_color.opponent_color();
|
||||
let rules = MoveRules::new(
|
||||
&opponent_color,
|
||||
&self.game_state.board,
|
||||
self.game_state.dice,
|
||||
);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// Stratégie simple : choix aléatoire
|
||||
let mut rng = thread_rng();
|
||||
let choosen_move = *possible_moves
|
||||
.choose(&mut rng)
|
||||
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||
|
||||
GameEvent::Move {
|
||||
player_id: self.opponent_player_id,
|
||||
moves: if opponent_color == Color::White {
|
||||
choosen_move
|
||||
} else {
|
||||
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
||||
},
|
||||
}
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
// Stratégie simple : toujours continuer
|
||||
GameEvent::Go {
|
||||
player_id: self.opponent_player_id,
|
||||
}
|
||||
}
|
||||
};
|
||||
if self.game_state.validate(&event) {
|
||||
self.game_state.consume(&event);
|
||||
}
|
||||
reward
|
||||
}
|
||||
}
|
||||
|
||||
/// Entraîneur pour le modèle DQN
|
||||
pub struct DqnTrainer {
|
||||
agent: DqnAgent,
|
||||
env: TrictracEnv,
|
||||
}
|
||||
|
||||
impl DqnTrainer {
|
||||
pub fn new(config: DqnConfig) -> Self {
|
||||
Self {
|
||||
agent: DqnAgent::new(config),
|
||||
env: TrictracEnv::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn train_episode(&mut self) -> f32 {
|
||||
let mut total_reward = 0.0;
|
||||
let mut state = self.env.reset();
|
||||
// let mut step_count = 0;
|
||||
|
||||
loop {
|
||||
// step_count += 1;
|
||||
let action = self.agent.select_action(&self.env.game_state, &state);
|
||||
let (next_state, reward, done) = self.env.step(action.clone());
|
||||
total_reward += reward;
|
||||
|
||||
let experience = Experience {
|
||||
state: state.clone(),
|
||||
action,
|
||||
reward,
|
||||
next_state: next_state.clone(),
|
||||
done,
|
||||
};
|
||||
self.agent.store_experience(experience);
|
||||
self.agent.train();
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
// if step_count % 100 == 0 {
|
||||
// println!("{:?}", next_state);
|
||||
// }
|
||||
state = next_state;
|
||||
}
|
||||
|
||||
total_reward
|
||||
}
|
||||
|
||||
pub fn train(
|
||||
&mut self,
|
||||
episodes: usize,
|
||||
save_every: usize,
|
||||
model_path: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
|
||||
|
||||
for episode in 1..=episodes {
|
||||
let reward = self.train_episode();
|
||||
|
||||
if episode % 100 == 0 {
|
||||
println!(
|
||||
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
||||
episode,
|
||||
episodes,
|
||||
reward,
|
||||
self.agent.get_epsilon(),
|
||||
self.agent.get_step_count()
|
||||
);
|
||||
}
|
||||
|
||||
if episode % save_every == 0 {
|
||||
let save_path = format!("{}_episode_{}.json", model_path, episode);
|
||||
self.agent.save_model(&save_path)?;
|
||||
println!("Modèle sauvegardé : {}", save_path);
|
||||
}
|
||||
}
|
||||
|
||||
// Sauvegarder le modèle final
|
||||
let final_path = format!("{}_final.json", model_path);
|
||||
self.agent.save_model(&final_path)?;
|
||||
println!("Modèle final sauvegardé : {}", final_path);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
112
bot/src/dqn/simple/main.rs
Normal file
112
bot/src/dqn/simple/main.rs
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
use bot::dqn::dqn_common::TrictracAction;
|
||||
use bot::dqn::simple::dqn_model::DqnConfig;
|
||||
use bot::dqn::simple::dqn_trainer::DqnTrainer;
|
||||
use std::env;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
env_logger::init();
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
// Paramètres par défaut
|
||||
let mut episodes = 1000;
|
||||
let mut model_path = "models/dqn_model".to_string();
|
||||
let mut save_every = 100;
|
||||
|
||||
// Parser les arguments de ligne de commande
|
||||
let mut i = 1;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--episodes" => {
|
||||
if i + 1 < args.len() {
|
||||
episodes = args[i + 1].parse().unwrap_or(1000);
|
||||
i += 2;
|
||||
} else {
|
||||
eprintln!("Erreur : --episodes nécessite une valeur");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
"--model-path" => {
|
||||
if i + 1 < args.len() {
|
||||
model_path = args[i + 1].clone();
|
||||
i += 2;
|
||||
} else {
|
||||
eprintln!("Erreur : --model-path nécessite une valeur");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
"--save-every" => {
|
||||
if i + 1 < args.len() {
|
||||
save_every = args[i + 1].parse().unwrap_or(100);
|
||||
i += 2;
|
||||
} else {
|
||||
eprintln!("Erreur : --save-every nécessite une valeur");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
print_help();
|
||||
std::process::exit(0);
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Argument inconnu : {}", args[i]);
|
||||
print_help();
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Créer le dossier models s'il n'existe pas
|
||||
std::fs::create_dir_all("models")?;
|
||||
|
||||
println!("Configuration d'entraînement DQN :");
|
||||
println!(" Épisodes : {}", episodes);
|
||||
println!(" Chemin du modèle : {}", model_path);
|
||||
println!(" Sauvegarde tous les {} épisodes", save_every);
|
||||
println!();
|
||||
|
||||
// Configuration DQN
|
||||
let config = DqnConfig {
|
||||
state_size: 36, // state.to_vec size
|
||||
hidden_size: 256,
|
||||
num_actions: TrictracAction::action_space_size(),
|
||||
learning_rate: 0.001,
|
||||
gamma: 0.99,
|
||||
epsilon: 0.9, // Commencer avec plus d'exploration
|
||||
epsilon_decay: 0.995,
|
||||
epsilon_min: 0.01,
|
||||
replay_buffer_size: 10000,
|
||||
batch_size: 32,
|
||||
};
|
||||
|
||||
// Créer et lancer l'entraîneur
|
||||
let mut trainer = DqnTrainer::new(config);
|
||||
trainer.train(episodes, save_every, &model_path)?;
|
||||
|
||||
println!("Entraînement terminé avec succès !");
|
||||
println!("Pour utiliser le modèle entraîné :");
|
||||
println!(
|
||||
" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
|
||||
model_path
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prints the command-line usage of the trainer to standard output.
fn print_help() {
    // One entry per output line; empty entries are blank lines.
    const HELP_LINES: [&str; 15] = [
        "Entraîneur DQN pour Trictrac",
        "",
        "USAGE:",
        " cargo run --bin=train_dqn [OPTIONS]",
        "",
        "OPTIONS:",
        " --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)",
        " --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)",
        " --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)",
        " -h, --help Afficher cette aide",
        "",
        "EXEMPLES:",
        " cargo run --bin=train_dqn",
        " cargo run --bin=train_dqn -- --episodes 5000 --save-every 500",
        " cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000",
    ];

    for line in HELP_LINES.iter() {
        println!("{}", line);
    }
}
|
||||
2
bot/src/dqn/simple/mod.rs
Normal file
2
bot/src/dqn/simple/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod dqn_model;
|
||||
pub mod dqn_trainer;
|
||||
|
|
@ -1,16 +1,14 @@
|
|||
pub mod burnrl;
|
||||
pub mod dqn;
|
||||
pub mod strategy;
|
||||
pub mod trictrac_board;
|
||||
|
||||
use log::debug;
|
||||
use log::{debug, error};
|
||||
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||
pub use strategy::default::DefaultStrategy;
|
||||
pub use strategy::dqn::DqnStrategy;
|
||||
pub use strategy::dqnburn::DqnBurnStrategy;
|
||||
pub use strategy::erroneous_moves::ErroneousStrategy;
|
||||
pub use strategy::random::RandomStrategy;
|
||||
pub use strategy::stable_baselines3::StableBaselines3Strategy;
|
||||
use trictrac_store::{
|
||||
CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage,
|
||||
};
|
||||
|
||||
pub trait BotStrategy: std::fmt::Debug {
|
||||
fn get_game(&self) -> &GameState;
|
||||
|
|
@ -71,14 +69,14 @@ impl Bot {
|
|||
debug!(">>>> {:?} BOT handle", self.color);
|
||||
let game = self.strategy.get_mut_game();
|
||||
let internal_event = if self.color == Color::Black {
|
||||
&event.get_mirror(false)
|
||||
&event.get_mirror()
|
||||
} else {
|
||||
event
|
||||
};
|
||||
|
||||
let init_player_points = game.who_plays().map(|p| (p.points, p.holes));
|
||||
let turn_stage = game.turn_stage;
|
||||
let _ = game.consume(internal_event);
|
||||
game.consume(internal_event);
|
||||
if game.stage == Stage::Ended {
|
||||
debug!("<<<< end {:?} BOT handle", self.color);
|
||||
return None;
|
||||
|
|
@ -126,7 +124,7 @@ impl Bot {
|
|||
return if self.color == Color::Black {
|
||||
debug!(" bot (internal) evt : {internal_event:?} ; points : {player_points:?}");
|
||||
debug!("<<<< end {:?} BOT handle", self.color);
|
||||
internal_event.map(|evt| evt.get_mirror(false))
|
||||
internal_event.map(|evt| evt.get_mirror())
|
||||
} else {
|
||||
debug!("<<<< end {:?} BOT handle", self.color);
|
||||
internal_event
|
||||
|
|
@ -145,7 +143,7 @@ impl Bot {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use trictrac_store::{Dice, Stage};
|
||||
use store::{Dice, Stage};
|
||||
|
||||
#[test]
|
||||
fn test_new() {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use trictrac_store::MoveRules;
|
||||
use store::MoveRules;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DefaultStrategy {
|
||||
|
|
|
|||
174
bot/src/strategy/dqn.rs
Normal file
174
bot/src/strategy/dqn.rs
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use log::info;
|
||||
use std::path::Path;
|
||||
use store::MoveRules;
|
||||
|
||||
use crate::dqn::dqn_common::{get_valid_actions, sample_valid_action, TrictracAction};
|
||||
use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
|
||||
|
||||
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||
#[derive(Debug)]
|
||||
pub struct DqnStrategy {
|
||||
pub game: GameState,
|
||||
pub player_id: PlayerId,
|
||||
pub color: Color,
|
||||
pub model: Option<SimpleNeuralNetwork>,
|
||||
}
|
||||
|
||||
impl Default for DqnStrategy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
game: GameState::default(),
|
||||
player_id: 1,
|
||||
color: Color::White,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DqnStrategy {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn new_with_model<P: AsRef<Path> + std::fmt::Debug>(model_path: P) -> Self {
|
||||
let mut strategy = Self::new();
|
||||
if let Ok(model) = SimpleNeuralNetwork::load(&model_path) {
|
||||
info!("Loading model {model_path:?}");
|
||||
strategy.model = Some(model);
|
||||
}
|
||||
strategy
|
||||
}
|
||||
|
||||
/// Utilise le modèle DQN pour choisir une action valide
|
||||
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||
if let Some(ref model) = self.model {
|
||||
let state = self.game.to_vec_float();
|
||||
let valid_actions = get_valid_actions(&self.game);
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Obtenir les Q-values pour toutes les actions
|
||||
let q_values = model.forward(&state);
|
||||
|
||||
// Trouver la meilleure action valide
|
||||
let mut best_action = &valid_actions[0];
|
||||
let mut best_q_value = f32::NEG_INFINITY;
|
||||
|
||||
for action in &valid_actions {
|
||||
let action_index = action.to_action_index();
|
||||
if action_index < q_values.len() {
|
||||
let q_value = q_values[action_index];
|
||||
if q_value > best_q_value {
|
||||
best_q_value = q_value;
|
||||
best_action = action;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(best_action.clone())
|
||||
} else {
|
||||
// Fallback : action aléatoire valide
|
||||
sample_valid_action(&self.game)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BotStrategy for DqnStrategy {
|
||||
fn get_game(&self) -> &GameState {
|
||||
&self.game
|
||||
}
|
||||
|
||||
fn get_mut_game(&mut self) -> &mut GameState {
|
||||
&mut self.game
|
||||
}
|
||||
|
||||
fn set_color(&mut self, color: Color) {
|
||||
self.color = color;
|
||||
}
|
||||
|
||||
fn set_player_id(&mut self, player_id: PlayerId) {
|
||||
self.player_id = player_id;
|
||||
}
|
||||
|
||||
fn calculate_points(&self) -> u8 {
|
||||
self.game.dice_points.0
|
||||
}
|
||||
|
||||
fn calculate_adv_points(&self) -> u8 {
|
||||
self.game.dice_points.1
|
||||
}
|
||||
|
||||
fn choose_go(&self) -> bool {
|
||||
// Utiliser le DQN pour décider si on continue
|
||||
if let Some(action) = self.get_dqn_action() {
|
||||
matches!(action, TrictracAction::Go)
|
||||
} else {
|
||||
// Fallback : toujours continuer
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||
// Utiliser le DQN pour choisir le mouvement
|
||||
if let Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
from1,
|
||||
from2,
|
||||
}) = self.get_dqn_action()
|
||||
{
|
||||
let dicevals = self.game.dice.values;
|
||||
let (mut dice1, mut dice2) = if dice_order {
|
||||
(dicevals.0, dicevals.1)
|
||||
} else {
|
||||
(dicevals.1, dicevals.0)
|
||||
};
|
||||
|
||||
if from1 == 0 {
|
||||
// empty move
|
||||
dice1 = 0;
|
||||
}
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
if 24 < to1 {
|
||||
// sortie
|
||||
to1 = 0;
|
||||
}
|
||||
if from2 == 0 {
|
||||
// empty move
|
||||
dice2 = 0;
|
||||
}
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
if 24 < to2 {
|
||||
// sortie
|
||||
to2 = 0;
|
||||
}
|
||||
|
||||
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
let chosen_move = if self.color == Color::White {
|
||||
(checker_move1, checker_move2)
|
||||
} else {
|
||||
(checker_move1.mirror(), checker_move2.mirror())
|
||||
};
|
||||
|
||||
return chosen_move;
|
||||
}
|
||||
|
||||
// Fallback : utiliser la stratégie par défaut
|
||||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
let chosen_move = *possible_moves
|
||||
.first()
|
||||
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||
|
||||
if self.color == Color::White {
|
||||
chosen_move
|
||||
} else {
|
||||
(chosen_move.0.mirror(), chosen_move.1.mirror())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4,15 +4,12 @@ use burn_rl::base::{ElemType, Model, State};
|
|||
|
||||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use log::info;
|
||||
use trictrac_store::MoveRules;
|
||||
use store::MoveRules;
|
||||
|
||||
use crate::burnrl::algos::dqn;
|
||||
use crate::burnrl::environment;
|
||||
use trictrac_store::training_common::{
|
||||
get_valid_action_indices, sample_valid_action, TrictracAction,
|
||||
};
|
||||
use crate::dqn::burnrl::{dqn_model, environment, utils};
|
||||
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
|
||||
|
||||
type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>;
|
||||
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
|
||||
|
||||
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||
#[derive(Debug)]
|
||||
|
|
@ -42,7 +39,7 @@ impl DqnBurnStrategy {
|
|||
pub fn new_with_model(model_path: &String) -> Self {
|
||||
info!("Loading model {model_path:?}");
|
||||
let mut strategy = Self::new();
|
||||
strategy.model = dqn::load_model(256, model_path);
|
||||
strategy.model = utils::load_model(256, model_path);
|
||||
strategy
|
||||
}
|
||||
|
||||
|
|
@ -50,7 +47,7 @@ impl DqnBurnStrategy {
|
|||
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||
if let Some(ref model) = self.model {
|
||||
let state = environment::TrictracState::from_game_state(&self.game);
|
||||
if let Ok(valid_actions_indices) = get_valid_action_indices(&self.game) {
|
||||
let valid_actions_indices = get_valid_action_indices(&self.game);
|
||||
if valid_actions_indices.is_empty() {
|
||||
return None; // No valid actions, end of episode
|
||||
}
|
||||
|
|
@ -71,11 +68,9 @@ impl DqnBurnStrategy {
|
|||
}
|
||||
// Get best action (highest q-value)
|
||||
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
return environment::TrictracEnvironment::convert_action(
|
||||
environment::TrictracAction::from(action_index),
|
||||
);
|
||||
}
|
||||
return None;
|
||||
environment::TrictracEnvironment::convert_action(environment::TrictracAction::from(
|
||||
action_index,
|
||||
))
|
||||
} else {
|
||||
// Fallback : action aléatoire valide
|
||||
sample_valid_action(&self.game)
|
||||
|
|
@ -122,8 +117,8 @@ impl BotStrategy for DqnBurnStrategy {
|
|||
// Utiliser le DQN pour choisir le mouvement
|
||||
if let Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
from1,
|
||||
from2,
|
||||
}) = self.get_dqn_action()
|
||||
{
|
||||
let dicevals = self.game.dice.values;
|
||||
|
|
@ -133,66 +128,24 @@ impl BotStrategy for DqnBurnStrategy {
|
|||
(dicevals.1, dicevals.0)
|
||||
};
|
||||
|
||||
assert_eq!(self.color, Color::White);
|
||||
let from1 = self
|
||||
.game
|
||||
.board
|
||||
.get_checker_field(&self.color, checker1 as u8)
|
||||
.unwrap_or(0);
|
||||
|
||||
if from1 == 0 {
|
||||
// empty move
|
||||
dice1 = 0;
|
||||
}
|
||||
let mut to1 = from1;
|
||||
if self.color == Color::White {
|
||||
to1 += dice1 as usize;
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
if 24 < to1 {
|
||||
// sortie
|
||||
to1 = 0;
|
||||
}
|
||||
} else {
|
||||
let fto1 = to1 as i16 - dice1 as i16;
|
||||
to1 = if fto1 < 0 { 0 } else { fto1 as usize };
|
||||
}
|
||||
|
||||
let checker_move1 = trictrac_store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
|
||||
let mut tmp_board = self.game.board.clone();
|
||||
let move_res = tmp_board.move_checker(&self.color, checker_move1);
|
||||
if move_res.is_err() {
|
||||
panic!("could not move {move_res:?}");
|
||||
}
|
||||
let from2 = tmp_board
|
||||
.get_checker_field(&self.color, checker2 as u8)
|
||||
.unwrap_or(0);
|
||||
if from2 == 0 {
|
||||
// empty move
|
||||
dice2 = 0;
|
||||
}
|
||||
let mut to2 = from2;
|
||||
if self.color == Color::White {
|
||||
to2 += dice2 as usize;
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
if 24 < to2 {
|
||||
// sortie
|
||||
to2 = 0;
|
||||
}
|
||||
} else {
|
||||
let fto2 = to2 as i16 - dice2 as i16;
|
||||
to2 = if fto2 < 0 { 0 } else { fto2 as usize };
|
||||
}
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = if self.color == Color::White { 13 } else { 12 };
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
if self.color == Color::White {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
} else {
|
||||
to1 += 1;
|
||||
to2 += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
|
@ -200,7 +153,6 @@ impl BotStrategy for DqnBurnStrategy {
|
|||
let chosen_move = if self.color == Color::White {
|
||||
(checker_move1, checker_move2)
|
||||
} else {
|
||||
// XXX : really ?
|
||||
(checker_move1.mirror(), checker_move2.mirror())
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod client;
|
||||
pub mod default;
|
||||
pub mod dqn;
|
||||
pub mod dqnburn;
|
||||
pub mod erroneous_moves;
|
||||
pub mod random;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||
use rand::{prelude::IndexedRandom, rng};
|
||||
use trictrac_store::MoveRules;
|
||||
use store::MoveRules;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RandomStrategy {
|
||||
|
|
@ -52,7 +51,8 @@ impl BotStrategy for RandomStrategy {
|
|||
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
let mut rng = rng();
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
let mut rng = thread_rng();
|
||||
let choosen_move = possible_moves
|
||||
.choose(&mut rng)
|
||||
.cloned()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use std::io::Read;
|
|||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use trictrac_store::MoveRules;
|
||||
use store::MoveRules;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StableBaselines3Strategy {
|
||||
|
|
@ -66,25 +66,25 @@ impl StableBaselines3Strategy {
|
|||
// Remplir les positions des pièces blanches (valeurs positives)
|
||||
for (pos, count) in self.game.board.get_color_fields(Color::White) {
|
||||
if pos < 24 {
|
||||
board[pos] = count;
|
||||
board[pos] = count as i8;
|
||||
}
|
||||
}
|
||||
|
||||
// Remplir les positions des pièces noires (valeurs négatives)
|
||||
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
|
||||
if pos < 24 {
|
||||
board[pos] = -count;
|
||||
board[pos] = -(count as i8);
|
||||
}
|
||||
}
|
||||
|
||||
// Convertir l'étape du tour en entier
|
||||
let turn_stage = match self.game.turn_stage {
|
||||
trictrac_store::TurnStage::RollDice => 0,
|
||||
trictrac_store::TurnStage::RollWaiting => 1,
|
||||
trictrac_store::TurnStage::MarkPoints => 2,
|
||||
trictrac_store::TurnStage::HoldOrGoChoice => 3,
|
||||
trictrac_store::TurnStage::Move => 4,
|
||||
trictrac_store::TurnStage::MarkAdvPoints => 5,
|
||||
store::TurnStage::RollDice => 0,
|
||||
store::TurnStage::RollWaiting => 1,
|
||||
store::TurnStage::MarkPoints => 2,
|
||||
store::TurnStage::HoldOrGoChoice => 3,
|
||||
store::TurnStage::Move => 4,
|
||||
store::TurnStage::MarkAdvPoints => 5,
|
||||
};
|
||||
|
||||
// Récupérer les points et trous des joueurs
|
||||
|
|
@ -270,3 +270,4 @@ impl BotStrategy for StableBaselines3Strategy {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,171 +0,0 @@
|
|||
// https://docs.rs/board-game/ implementation
|
||||
use board_game::board::{
|
||||
Board as BoardGameBoard, BoardDone, BoardMoves, Outcome, PlayError, Player as BoardGamePlayer,
|
||||
};
|
||||
use board_game::impl_unit_symmetry_board;
|
||||
use internal_iterator::InternalIterator;
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
use std::ops::ControlFlow;
|
||||
use trictrac_store::training_common::{get_valid_actions, TrictracAction};
|
||||
use trictrac_store::Color;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct TrictracBoard(crate::GameState);
|
||||
|
||||
impl Default for TrictracBoard {
|
||||
fn default() -> Self {
|
||||
TrictracBoard(crate::GameState::new_with_players("white", "black"))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TrictracBoard {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl_unit_symmetry_board!(TrictracBoard);
|
||||
|
||||
impl BoardGameBoard for TrictracBoard {
|
||||
// impl TrictracBoard {
|
||||
type Move = TrictracAction;
|
||||
|
||||
fn next_player(&self) -> BoardGamePlayer {
|
||||
self.0
|
||||
.who_plays()
|
||||
.map(|p| {
|
||||
if p.color == Color::Black {
|
||||
BoardGamePlayer::B
|
||||
} else {
|
||||
BoardGamePlayer::A
|
||||
}
|
||||
})
|
||||
.unwrap_or(BoardGamePlayer::A)
|
||||
}
|
||||
|
||||
fn is_available_move(&self, mv: Self::Move) -> Result<bool, BoardDone> {
|
||||
self.check_done()?;
|
||||
let is_valid = mv
|
||||
.to_event(&self.0)
|
||||
.map(|evt| self.0.validate(&evt))
|
||||
.unwrap_or(false);
|
||||
Ok(is_valid)
|
||||
}
|
||||
|
||||
fn play(&mut self, mv: Self::Move) -> Result<(), PlayError> {
|
||||
self.check_can_play(mv)?;
|
||||
if let Some(evt) = mv.to_event(&self.0) {
|
||||
let _ = self.0.consume(&evt);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(PlayError::UnavailableMove)
|
||||
}
|
||||
}
|
||||
|
||||
fn outcome(&self) -> Option<Outcome> {
|
||||
if self.0.stage == crate::Stage::Ended {
|
||||
self.0.determine_winner().map(|player_id| {
|
||||
Outcome::WonBy(if player_id == 1 {
|
||||
BoardGamePlayer::A
|
||||
} else {
|
||||
BoardGamePlayer::B
|
||||
})
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn can_lose_after_move() -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracBoard {
|
||||
pub fn inner(&self) -> &crate::GameState {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub fn to_fen(&self) -> String {
|
||||
self.0.to_string_id()
|
||||
}
|
||||
|
||||
pub fn from_fen(fen: &str) -> Result<TrictracBoard, String> {
|
||||
crate::GameState::from_string_id(fen).map(TrictracBoard)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> BoardMoves<'a, TrictracBoard> for TrictracBoard {
|
||||
type AllMovesIterator = TrictracAllMovesIterator;
|
||||
type AvailableMovesIterator = TrictracAvailableMovesIterator<'a>;
|
||||
|
||||
fn all_possible_moves() -> Self::AllMovesIterator {
|
||||
TrictracAllMovesIterator::default()
|
||||
}
|
||||
|
||||
fn available_moves(&'a self) -> Result<Self::AvailableMovesIterator, BoardDone> {
|
||||
TrictracAvailableMovesIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stateless iterator over the whole action space.
///
/// `Default` is derived: for a unit struct the derived implementation is
/// identical to a hand-written one.
#[derive(Debug, Clone, Default)]
pub struct TrictracAllMovesIterator;
|
||||
|
||||
impl InternalIterator for TrictracAllMovesIterator {
|
||||
type Item = TrictracAction;
|
||||
|
||||
fn try_for_each<R, F: FnMut(Self::Item) -> ControlFlow<R>>(self, mut f: F) -> ControlFlow<R> {
|
||||
f(TrictracAction::Roll)?;
|
||||
f(TrictracAction::Go)?;
|
||||
for dice_order in [false, true] {
|
||||
for checker1 in 0..16 {
|
||||
for checker2 in 0..16 {
|
||||
f(TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrictracAvailableMovesIterator<'a> {
|
||||
board: &'a TrictracBoard,
|
||||
}
|
||||
|
||||
impl<'a> TrictracAvailableMovesIterator<'a> {
|
||||
pub fn new(board: &'a TrictracBoard) -> Result<Self, BoardDone> {
|
||||
board.check_done()?;
|
||||
Ok(TrictracAvailableMovesIterator { board })
|
||||
}
|
||||
|
||||
pub fn board(&self) -> &'a TrictracBoard {
|
||||
self.board
|
||||
}
|
||||
}
|
||||
|
||||
impl InternalIterator for TrictracAvailableMovesIterator<'_> {
|
||||
type Item = TrictracAction;
|
||||
|
||||
fn try_for_each<R, F>(self, f: F) -> ControlFlow<R>
|
||||
where
|
||||
F: FnMut(Self::Item) -> ControlFlow<R>,
|
||||
{
|
||||
match get_valid_actions(&self.board.0) {
|
||||
Ok(actions) => actions.into_iter().try_for_each(f),
|
||||
Err(_) => ControlFlow::Continue(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
8
client_bevy/.cargo/config.toml
Normal file
8
client_bevy/.cargo/config.toml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
[target.x86_64-unknown-linux-gnu]
|
||||
linker = "clang"
|
||||
rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
|
||||
|
||||
# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only'
|
||||
# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains.
|
||||
#[profile.dev]
|
||||
#debug = 1
|
||||
14
client_bevy/Cargo.toml
Normal file
14
client_bevy/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "trictrac-client"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.75"
|
||||
bevy = { version = "0.11.3" }
|
||||
bevy_renet = "0.0.9"
|
||||
bincode = "1.3.3"
|
||||
renet = "0.0.13"
|
||||
store = { path = "../store" }
|
||||
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
Binary file not shown.
BIN
client_bevy/assets/board.png
Normal file
BIN
client_bevy/assets/board.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.9 MiB |
BIN
client_bevy/assets/sound/click.wav
Normal file
BIN
client_bevy/assets/sound/click.wav
Normal file
Binary file not shown.
BIN
client_bevy/assets/sound/throw.wav
Executable file
BIN
client_bevy/assets/sound/throw.wav
Executable file
Binary file not shown.
BIN
client_bevy/assets/tac.png
Normal file
BIN
client_bevy/assets/tac.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.6 KiB |
BIN
client_bevy/assets/tic.png
Normal file
BIN
client_bevy/assets/tic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.4 KiB |
334
client_bevy/src/main.rs
Normal file
334
client_bevy/src/main.rs
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
use std::{net::UdpSocket, time::SystemTime};
|
||||
|
||||
use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES};
|
||||
use store::{GameEvent, GameState, CheckerMove};
|
||||
|
||||
use bevy::prelude::*;
|
||||
use bevy::window::PrimaryWindow;
|
||||
use bevy_renet::{
|
||||
renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient},
|
||||
transport::{client_connected, NetcodeClientPlugin},
|
||||
RenetClientPlugin,
|
||||
};
|
||||
|
||||
#[derive(Debug, Resource)]
|
||||
struct CurrentClientId(u64);
|
||||
|
||||
#[derive(Resource)]
|
||||
struct BevyGameState(GameState);
|
||||
|
||||
impl Default for BevyGameState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
0: GameState::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Resource, Deref, DerefMut)]
|
||||
struct GameUIState {
|
||||
selected_tile: Option<usize>,
|
||||
}
|
||||
|
||||
impl Default for GameUIState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
selected_tile: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Event)]
|
||||
struct BevyGameEvent(GameEvent);
|
||||
|
||||
// This id needs to be the same as the server is using
|
||||
const PROTOCOL_ID: u64 = 2878;
|
||||
|
||||
fn main() {
|
||||
// Get username from stdin args
|
||||
let args = std::env::args().collect::<Vec<String>>();
|
||||
let username = &args[1];
|
||||
|
||||
let (client, transport, client_id) = new_renet_client(&username).unwrap();
|
||||
App::new()
|
||||
// Lets add a nice dark grey background color
|
||||
.insert_resource(ClearColor(Color::hex("282828").unwrap()))
|
||||
.add_plugins(DefaultPlugins.set(WindowPlugin {
|
||||
primary_window: Some(Window {
|
||||
// Adding the username to the window title makes debugging a whole lot easier.
|
||||
title: format!("TricTrac <{}>", username),
|
||||
resolution: (1080.0, 1080.0).into(),
|
||||
..default()
|
||||
}),
|
||||
..default()
|
||||
}))
|
||||
// Add our game state and register GameEvent as a bevy event
|
||||
.insert_resource(BevyGameState::default())
|
||||
.insert_resource(GameUIState::default())
|
||||
.add_event::<BevyGameEvent>()
|
||||
// Renet setup
|
||||
.add_plugins(RenetClientPlugin)
|
||||
.add_plugins(NetcodeClientPlugin)
|
||||
.insert_resource(client)
|
||||
.insert_resource(transport)
|
||||
.insert_resource(CurrentClientId(client_id))
|
||||
.add_systems(Startup, setup)
|
||||
.add_systems(Update, (update_waiting_text, input, update_board, panic_on_error_system))
|
||||
.add_systems(
|
||||
PostUpdate,
|
||||
receive_events_from_server.run_if(client_connected()),
|
||||
)
|
||||
.run();
|
||||
}
|
||||
|
||||
////////// COMPONENTS //////////
|
||||
#[derive(Component)]
|
||||
struct UIRoot;
|
||||
|
||||
#[derive(Component)]
|
||||
struct WaitingText;
|
||||
|
||||
#[derive(Component)]
|
||||
struct Board {
|
||||
squares: [Square; 26]
|
||||
}
|
||||
|
||||
impl Default for Board {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
squares: [Square { count: 0, color: None, position: 0}; 26]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Board {
|
||||
fn square_at(&self, position: usize) -> Square {
|
||||
self.squares[position]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Component, Clone, Copy)]
|
||||
struct Square {
|
||||
count: usize,
|
||||
color: Option<bool>,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
////////// UPDATE SYSTEMS //////////
|
||||
fn update_board(
|
||||
mut commands: Commands,
|
||||
game_state: Res<BevyGameState>,
|
||||
mut game_events: EventReader<BevyGameEvent>,
|
||||
asset_server: Res<AssetServer>,
|
||||
) {
|
||||
for event in game_events.iter() {
|
||||
match event.0 {
|
||||
GameEvent::Move { player_id, moves } => {
|
||||
// trictrac positions, TODO : dépend de player_id
|
||||
let (x, y) = if moves.0.get_to() < 13 { (13 - moves.0.get_to(), 1) } else { (moves.0.get_to() - 13, 0)};
|
||||
let texture =
|
||||
asset_server.load(match game_state.0.players[&player_id].color {
|
||||
store::Color::Black => "tac.png",
|
||||
store::Color::White => "tic.png",
|
||||
});
|
||||
|
||||
info!("spawning tictac sprite");
|
||||
commands.spawn(SpriteBundle {
|
||||
transform: Transform::from_xyz(
|
||||
83.0 * (x as f32 - 1.0),
|
||||
-30.0 + 540.0 * (y as f32 - 1.0),
|
||||
0.0,
|
||||
),
|
||||
sprite: Sprite {
|
||||
custom_size: Some(Vec2::new(83.0, 83.0)),
|
||||
..default()
|
||||
},
|
||||
texture: texture.into(),
|
||||
..default()
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Animates the trailing dots of the "waiting" text, one step per second.
fn update_waiting_text(mut text_query: Query<&mut Text, With<WaitingText>>, time: Res<Time>) {
    if let Ok(mut text) = text_query.get_single_mut() {
        // Cycles through 1, 2, 3 dots.
        let num_dots = (time.elapsed_seconds() as usize % 3) + 1;
        let dots = ".".repeat(num_dots);
        // Pad with spaces so the text keeps a constant width.
        let padding = " ".repeat(3 - num_dots);
        text.sections[0].value = format!("Waiting for an opponent{}{}", dots, padding);
    }
}
|
||||
|
||||
fn input(
|
||||
primary_query: Query<&Window, With<PrimaryWindow>>,
|
||||
// windows: Res<Windows>,
|
||||
input: Res<Input<MouseButton>>,
|
||||
game_state: Res<BevyGameState>,
|
||||
mut game_ui_state: ResMut<GameUIState>,
|
||||
mut client: ResMut<RenetClient>,
|
||||
client_id: Res<CurrentClientId>,
|
||||
) {
|
||||
// We only want to handle inputs once we are ingame
|
||||
if game_state.0.stage != store::Stage::InGame {
|
||||
return;
|
||||
}
|
||||
|
||||
let window = primary_query.get_single().unwrap();
|
||||
if let Some(mouse_position) = window.cursor_position() {
|
||||
// Determine the index of the tile that the mouse is currently over
|
||||
// NOTE: This calculation assumes a fixed window size.
|
||||
// That's fine for now, but consider using the windows size instead.
|
||||
let mut tile_x: usize = (mouse_position.x / 83.0).floor() as usize;
|
||||
let tile_y: usize = (mouse_position.y / 540.0).floor() as usize;
|
||||
if tile_x > 5 {
|
||||
// remove the middle bar offset
|
||||
tile_x = tile_x - 1
|
||||
}
|
||||
// let tile = tile_x + tile_y * 12;
|
||||
|
||||
// traduction en position backgammon
|
||||
let tile = if tile_y == 0 {
|
||||
13 + tile_x
|
||||
} else {
|
||||
12 - tile_x
|
||||
};
|
||||
|
||||
// If mouse is outside of board we do nothing
|
||||
if 23 < tile {
|
||||
return;
|
||||
}
|
||||
|
||||
// If left mouse button is pressed, send a place tile event to the server
|
||||
if input.just_pressed(MouseButton::Left) {
|
||||
info!("select piece at tile {:?}", tile);
|
||||
if game_ui_state.selected_tile.is_some() {
|
||||
let from_tile = game_ui_state.selected_tile.unwrap();
|
||||
info!("sending movement from: {:?} to: {:?} ", from_tile, tile);
|
||||
let event = GameEvent::Move {
|
||||
player_id: client_id.0,
|
||||
moves: (
|
||||
CheckerMove::new(from_tile, tile).unwrap(),
|
||||
CheckerMove::new(from_tile, tile).unwrap()
|
||||
)
|
||||
};
|
||||
client.send_message(0, bincode::serialize(&event).unwrap());
|
||||
}
|
||||
game_ui_state.selected_tile = if game_ui_state.selected_tile.is_some() {
|
||||
None
|
||||
} else {
|
||||
Some(tile)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////// SETUP //////////
|
||||
fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {
|
||||
// Tric Trac is a 2D game
|
||||
// To show 2D sprites we need a 2D camera
|
||||
commands.spawn(Camera2dBundle::default());
|
||||
|
||||
// Spawn board background
|
||||
commands.spawn(SpriteBundle {
|
||||
transform: Transform::from_xyz(0.0, -30.0, 0.0),
|
||||
sprite: Sprite {
|
||||
custom_size: Some(Vec2::new(1080.0, 927.0)),
|
||||
..default()
|
||||
},
|
||||
texture: asset_server.load("board.png").into(),
|
||||
..default()
|
||||
});
|
||||
|
||||
// Spawn pregame ui
|
||||
commands
|
||||
// A container that centers its children on the screen
|
||||
.spawn(NodeBundle {
|
||||
style: Style {
|
||||
position_type: PositionType::Absolute,
|
||||
left: Val::Px(0.0),
|
||||
top: Val::Px(0.0),
|
||||
width: Val::Percent(100.0),
|
||||
height: Val::Percent(100.0),
|
||||
align_items: AlignItems::Center,
|
||||
justify_content: JustifyContent::Center,
|
||||
..default()
|
||||
},
|
||||
..default()
|
||||
})
|
||||
.insert(UIRoot)
|
||||
.with_children(|parent| {
|
||||
// parent.spawn(Board::default()); // panic
|
||||
parent
|
||||
.spawn(TextBundle::from_section(
|
||||
"Waiting for an opponent...",
|
||||
TextStyle {
|
||||
font: asset_server.load("Inconsolata.ttf"),
|
||||
font_size: 24.0,
|
||||
color: Color::hex("ebdbb2").unwrap(),
|
||||
},
|
||||
))
|
||||
.insert(WaitingText);
|
||||
});
|
||||
}
|
||||
|
||||
////////// RENET NETWORKING //////////
|
||||
// Creates a RenetClient thats already connected to a server.
|
||||
// Returns an Err if connection fails
|
||||
fn new_renet_client(
|
||||
username: &String,
|
||||
) -> anyhow::Result<(RenetClient, NetcodeClientTransport, u64)> {
|
||||
let client = RenetClient::new(ConnectionConfig::default());
|
||||
let server_addr = "127.0.0.1:5000".parse()?;
|
||||
let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?;
|
||||
let client_id = current_time.as_millis() as u64;
|
||||
|
||||
// Place username in user data
|
||||
let mut user_data = [0u8; NETCODE_USER_DATA_BYTES];
|
||||
if username.len() > NETCODE_USER_DATA_BYTES - 8 {
|
||||
panic!("Username is too big");
|
||||
}
|
||||
user_data[0..8].copy_from_slice(&(username.len() as u64).to_le_bytes());
|
||||
user_data[8..username.len() + 8].copy_from_slice(username.as_bytes());
|
||||
|
||||
let authentication = ClientAuthentication::Unsecure {
|
||||
server_addr,
|
||||
client_id,
|
||||
user_data: Some(user_data),
|
||||
protocol_id: PROTOCOL_ID,
|
||||
};
|
||||
let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap();
|
||||
|
||||
Ok((client, transport, client_id))
|
||||
}
|
||||
|
||||
fn receive_events_from_server(
|
||||
mut client: ResMut<RenetClient>,
|
||||
mut game_state: ResMut<BevyGameState>,
|
||||
mut game_events: EventWriter<BevyGameEvent>,
|
||||
) {
|
||||
while let Some(message) = client.receive_message(0) {
|
||||
// Whenever the server sends a message we know that it must be a game event
|
||||
let event: GameEvent = bincode::deserialize(&message).unwrap();
|
||||
trace!("{:#?}", event);
|
||||
|
||||
// We trust the server - It's always been good to us!
|
||||
// No need to validate the events it is sending us
|
||||
game_state.0.consume(&event);
|
||||
|
||||
// Send the event into the bevy event system so systems can react to it
|
||||
game_events.send(BevyGameEvent(event));
|
||||
}
|
||||
}
|
||||
|
||||
// If any error is found we just panic
|
||||
fn panic_on_error_system(mut renet_error: EventReader<NetcodeTransportError>) {
|
||||
for e in renet_error.iter() {
|
||||
panic!("{}", e);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
[package]
|
||||
name = "trictrac-client_cli"
|
||||
name = "client_cli"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
|
@ -11,8 +11,8 @@ bincode = "1.3.3"
|
|||
pico-args = "0.5.0"
|
||||
pretty_assertions = "1.4.0"
|
||||
renet = "0.0.13"
|
||||
trictrac-store = { path = "../store" }
|
||||
trictrac-bot = { path = "../bot" }
|
||||
store = { path = "../store" }
|
||||
bot = { path = "../bot" }
|
||||
itertools = "0.13.0"
|
||||
env_logger = "0.11.6"
|
||||
log = "0.4.20"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
use trictrac_bot::{
|
||||
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
|
||||
use bot::{
|
||||
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
|
||||
StableBaselines3Strategy,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::game_runner::GameRunner;
|
||||
use trictrac_store::{CheckerMove, GameEvent, GameState, Stage, TurnStage};
|
||||
use store::{CheckerMove, GameEvent, GameState, Stage, TurnStage};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct AppArgs {
|
||||
|
|
@ -25,8 +25,8 @@ pub struct App {
|
|||
impl App {
|
||||
// Constructs a new instance of [`App`].
|
||||
pub fn new(args: AppArgs) -> Self {
|
||||
let bot_strategies: Vec<Box<dyn BotStrategy>> =
|
||||
args.bot
|
||||
let bot_strategies: Vec<Box<dyn BotStrategy>> = args
|
||||
.bot
|
||||
.as_deref()
|
||||
.map(|str_bots| {
|
||||
str_bots
|
||||
|
|
@ -43,6 +43,7 @@ impl App {
|
|||
}
|
||||
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
||||
as Box<dyn BotStrategy>),
|
||||
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
|
||||
"dqnburn" => {
|
||||
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
|
||||
}
|
||||
|
|
@ -51,9 +52,14 @@ impl App {
|
|||
Some(Box::new(StableBaselines3Strategy::new(path))
|
||||
as Box<dyn BotStrategy>)
|
||||
}
|
||||
s if s.starts_with("dqn:") => {
|
||||
let path = s.trim_start_matches("dqn:");
|
||||
Some(Box::new(DqnStrategy::new_with_model(path))
|
||||
as Box<dyn BotStrategy>)
|
||||
}
|
||||
s if s.starts_with("dqnburn:") => {
|
||||
let path = s.trim_start_matches("dqnburn:");
|
||||
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
||||
Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}")))
|
||||
as Box<dyn BotStrategy>)
|
||||
}
|
||||
_ => None,
|
||||
|
|
@ -108,7 +114,7 @@ impl App {
|
|||
|
||||
pub fn show_history(&self) {
|
||||
for hist in self.game.state.history.iter() {
|
||||
println!("{hist:?}\n");
|
||||
println!("{:?}\n", hist);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -133,9 +139,6 @@ impl App {
|
|||
// &self.game.state.board,
|
||||
// dice,
|
||||
// );
|
||||
self.game.handle_event(&GameEvent::Roll {
|
||||
player_id: self.game.player_id.unwrap(),
|
||||
});
|
||||
self.game.handle_event(&GameEvent::RollResult {
|
||||
player_id: self.game.player_id.unwrap(),
|
||||
dice,
|
||||
|
|
@ -186,7 +189,7 @@ impl App {
|
|||
return;
|
||||
}
|
||||
}
|
||||
println!("invalid move : {input}");
|
||||
println!("invalid move : {}", input);
|
||||
}
|
||||
|
||||
pub fn display(&mut self) -> String {
|
||||
|
|
@ -326,7 +329,6 @@ Player :: holes :: points
|
|||
seed: Some(1327),
|
||||
bot: Some("dummy".into()),
|
||||
});
|
||||
println!("avant : {}", app.display());
|
||||
app.input("roll");
|
||||
app.input("1 3");
|
||||
app.input("1 4");
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use bot::{Bot, BotStrategy};
|
||||
use log::{debug, error};
|
||||
use trictrac_bot::{Bot, BotStrategy};
|
||||
use trictrac_store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage};
|
||||
use store::{CheckerMove, DiceRoller, GameEvent, GameState, PlayerId, TurnStage};
|
||||
|
||||
// Application Game
|
||||
#[derive(Debug, Default)]
|
||||
|
|
@ -67,7 +67,7 @@ impl GameRunner {
|
|||
"--------------- new valid event {event:?} (stage {:?}) -----------",
|
||||
self.state.turn_stage
|
||||
);
|
||||
let _ = self.state.consume(event).inspect_err(|e| error!("{}", e));
|
||||
self.state.consume(event);
|
||||
debug!(
|
||||
" --> stage {:?} ; active player points {:?}",
|
||||
self.state.turn_stage,
|
||||
|
|
@ -77,7 +77,7 @@ impl GameRunner {
|
|||
} else {
|
||||
debug!("{}", self.state);
|
||||
error!("event not valid : {event:?}");
|
||||
// panic!("crash and burn {} \nevt not valid {event:?}", self.state);
|
||||
panic!("crash and burn");
|
||||
&GameEvent::PlayError
|
||||
};
|
||||
|
||||
|
|
@ -117,8 +117,8 @@ impl GameRunner {
|
|||
}
|
||||
|
||||
if let Some(winner) = self.state.determine_winner() {
|
||||
next_event = Some(trictrac_store::GameEvent::EndGame {
|
||||
reason: trictrac_store::EndGameReason::PlayerWon { winner },
|
||||
next_event = Some(store::GameEvent::EndGame {
|
||||
reason: store::EndGameReason::PlayerWon { winner },
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ fn main() -> Result<()> {
|
|||
let args = match parse_args() {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
eprintln!("Error: {e}.");
|
||||
eprintln!("Error: {}.", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
|
@ -63,7 +63,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
|
|||
|
||||
// Help has a higher priority and should be handled separately.
|
||||
if pargs.contains(["-h", "--help"]) {
|
||||
print!("{HELP}");
|
||||
print!("{}", HELP);
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
|
|
@ -78,7 +78,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
|
|||
// It's up to the caller what to do with the remaining arguments.
|
||||
let remaining = pargs.finish();
|
||||
if !remaining.is_empty() {
|
||||
eprintln!("Warning: unused arguments left: {remaining:?}.");
|
||||
eprintln!("Warning: unused arguments left: {:?}.", remaining);
|
||||
}
|
||||
|
||||
Ok(args)
|
||||
|
|
|
|||
14
client_tui/Cargo.toml
Normal file
14
client_tui/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "client_tui"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.89"
|
||||
bincode = "1.3.3"
|
||||
crossterm = "0.28.1"
|
||||
ratatui = "0.28.1"
|
||||
# renet = "0.0.13"
|
||||
store = { path = "../store" }
|
||||
53
client_tui/src/app.rs
Normal file
53
client_tui/src/app.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
/// Application state.
#[derive(Debug, Default)]
pub struct App {
    /// Set to `true` once the application has been asked to exit.
    pub should_quit: bool,
    /// Value shown by the counter widget.
    pub counter: u8,
}

impl App {
    /// Constructs a new instance of [`App`] with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Handles the tick event of the terminal (currently a no-op).
    pub fn tick(&self) {}

    /// Flags the application for exit.
    pub fn quit(&mut self) {
        self.should_quit = true;
    }

    /// Adds one to the counter; the value stays at `u8::MAX` once reached.
    pub fn increment_counter(&mut self) {
        self.counter = self.counter.saturating_add(1);
    }

    /// Subtracts one from the counter; the value stays at zero once reached.
    pub fn decrement_counter(&mut self) {
        self.counter = self.counter.saturating_sub(1);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A default app starts at zero, so one increment must yield 1.
    #[test]
    fn test_app_increment_counter() {
        let mut app = App::default();
        app.increment_counter();
        assert_eq!(app.counter, 1);
    }

    // Decrementing a counter that is already at zero must leave it at zero
    // rather than wrapping around.
    #[test]
    fn test_app_decrement_counter() {
        let mut app = App::default();
        app.decrement_counter();
        assert_eq!(app.counter, 0);
    }
}
|
||||
87
client_tui/src/event.rs
Normal file
87
client_tui/src/event.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
use std::{
|
||||
sync::mpsc,
|
||||
thread,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use crossterm::event::{self, Event as CrosstermEvent, KeyEvent, MouseEvent};
|
||||
|
||||
/// Terminal events forwarded by the event handler thread.
#[derive(Clone, Copy, Debug)]
pub enum Event {
    /// Terminal tick, emitted once the configured tick rate has elapsed.
    Tick,
    /// Key press.
    Key(KeyEvent),
    /// Mouse click/scroll.
    Mouse(MouseEvent),
    /// Terminal resize, carrying the two `u16` values reported by crossterm.
    Resize(u16, u16),
}
|
||||
|
||||
/// Terminal event handler.
///
/// Owns the channel and the background thread that feeds it.
#[derive(Debug)]
pub struct EventHandler {
    /// Event sender channel (kept alongside the clone moved into the thread).
    #[allow(dead_code)]
    sender: mpsc::Sender<Event>,
    /// Event receiver channel.
    receiver: mpsc::Receiver<Event>,
    /// Event handler thread.
    #[allow(dead_code)]
    handler: thread::JoinHandle<()>,
}
|
||||
|
||||
impl EventHandler {
|
||||
// Constructs a new instance of [`EventHandler`].
|
||||
pub fn new(tick_rate: u64) -> Self {
|
||||
let tick_rate = Duration::from_millis(tick_rate);
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let handler = {
|
||||
let sender = sender.clone();
|
||||
thread::spawn(move || {
|
||||
let mut last_tick = Instant::now();
|
||||
loop {
|
||||
let timeout = tick_rate
|
||||
.checked_sub(last_tick.elapsed())
|
||||
.unwrap_or(tick_rate);
|
||||
|
||||
if event::poll(timeout).expect("no events available") {
|
||||
match event::read().expect("unable to read event") {
|
||||
CrosstermEvent::Key(e) => {
|
||||
if e.kind == event::KeyEventKind::Press {
|
||||
sender.send(Event::Key(e))
|
||||
} else {
|
||||
Ok(()) // ignore KeyEventKind::Release on windows
|
||||
}
|
||||
}
|
||||
CrosstermEvent::Mouse(e) => sender.send(Event::Mouse(e)),
|
||||
CrosstermEvent::Resize(w, h) => sender.send(Event::Resize(w, h)),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
.expect("failed to send terminal event")
|
||||
}
|
||||
|
||||
if last_tick.elapsed() >= tick_rate {
|
||||
sender.send(Event::Tick).expect("failed to send tick event");
|
||||
last_tick = Instant::now();
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
Self {
|
||||
sender,
|
||||
receiver,
|
||||
handler,
|
||||
}
|
||||
}
|
||||
|
||||
// Receive the next event from the handler thread.
|
||||
//
|
||||
// This function will always block the current thread if
|
||||
// there is no data available and it's possible for more data to be sent.
|
||||
pub fn next(&self) -> Result<Event> {
|
||||
Ok(self.receiver.recv()?)
|
||||
}
|
||||
}
|
||||
50
client_tui/src/main.rs
Normal file
50
client_tui/src/main.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
// Application.
|
||||
pub mod app;
|
||||
|
||||
// Terminal events handler.
|
||||
pub mod event;
|
||||
|
||||
// Widget renderer.
|
||||
pub mod ui;
|
||||
|
||||
// Terminal user interface.
|
||||
pub mod tui;
|
||||
|
||||
// Application updater.
|
||||
pub mod update;
|
||||
|
||||
use anyhow::Result;
|
||||
use app::App;
|
||||
use event::{Event, EventHandler};
|
||||
use ratatui::{backend::CrosstermBackend, Terminal};
|
||||
use tui::Tui;
|
||||
use update::update;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Create an application.
|
||||
let mut app = App::new();
|
||||
|
||||
// Initialize the terminal user interface.
|
||||
let backend = CrosstermBackend::new(std::io::stderr());
|
||||
let terminal = Terminal::new(backend)?;
|
||||
let events = EventHandler::new(250);
|
||||
let mut tui = Tui::new(terminal, events);
|
||||
tui.enter()?;
|
||||
|
||||
// Start the main loop.
|
||||
while !app.should_quit {
|
||||
// Render the user interface.
|
||||
tui.draw(&mut app)?;
|
||||
// Handle events.
|
||||
match tui.events.next()? {
|
||||
Event::Tick => {}
|
||||
Event::Key(key_event) => update(&mut app, key_event),
|
||||
Event::Mouse(_) => {}
|
||||
Event::Resize(_, _) => {}
|
||||
};
|
||||
}
|
||||
|
||||
// Exit the user interface.
|
||||
tui.exit()?;
|
||||
Ok(())
|
||||
}
|
||||
77
client_tui/src/tui.rs
Normal file
77
client_tui/src/tui.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
use std::{io, panic};
|
||||
|
||||
use anyhow::Result;
|
||||
use crossterm::{
|
||||
event::{DisableMouseCapture, EnableMouseCapture},
|
||||
terminal::{self, EnterAlternateScreen, LeaveAlternateScreen},
|
||||
};
|
||||
|
||||
pub type CrosstermTerminal = ratatui::Terminal<ratatui::backend::CrosstermBackend<std::io::Stderr>>;
|
||||
|
||||
use crate::{app::App, event::EventHandler, ui};
|
||||
|
||||
/// Representation of a terminal user interface.
///
/// It is responsible for setting up the terminal,
/// initializing the interface and handling the draw events.
pub struct Tui {
    /// Interface to the Terminal.
    terminal: CrosstermTerminal,
    /// Terminal event handler.
    pub events: EventHandler,
}
|
||||
|
||||
impl Tui {
    /// Constructs a new instance of [`Tui`].
    pub fn new(terminal: CrosstermTerminal, events: EventHandler) -> Self {
        Self { terminal, events }
    }

    /// Initializes the terminal interface.
    ///
    /// It enables the raw mode, switches stderr to the alternate screen with
    /// mouse capture, installs a panic hook that resets the terminal, then
    /// hides the cursor and clears the screen.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by any of these terminal operations.
    pub fn enter(&mut self) -> Result<()> {
        terminal::enable_raw_mode()?;
        crossterm::execute!(io::stderr(), EnterAlternateScreen, EnableMouseCapture)?;

        // Define a custom panic hook to reset the terminal properties.
        // This way, you won't have your terminal messed up if an unexpected error happens.
        // The previously installed hook is kept and invoked after the reset.
        let panic_hook = panic::take_hook();
        panic::set_hook(Box::new(move |panic| {
            Self::reset().expect("failed to reset the terminal");
            panic_hook(panic);
        }));

        self.terminal.hide_cursor()?;
        self.terminal.clear()?;
        Ok(())
    }

    /// [`Draw`] the terminal interface by [`rendering`] the widgets.
    ///
    /// [`Draw`]: ratatui::Terminal::draw
    /// [`rendering`]: crate::ui::render
    pub fn draw(&mut self, app: &mut App) -> Result<()> {
        self.terminal.draw(|frame| ui::render(app, frame))?;
        Ok(())
    }

    /// Resets the terminal interface.
    ///
    /// This function is also used for the panic hook to revert
    /// the terminal properties if unexpected errors occur.
    fn reset() -> Result<()> {
        terminal::disable_raw_mode()?;
        crossterm::execute!(io::stderr(), LeaveAlternateScreen, DisableMouseCapture)?;
        Ok(())
    }

    /// Exits the terminal interface.
    ///
    /// It disables the raw mode and reverts back the terminal properties,
    /// then shows the cursor again.
    pub fn exit(&mut self) -> Result<()> {
        Self::reset()?;
        self.terminal.show_cursor()?;
        Ok(())
    }
}
|
||||
30
client_tui/src/ui.rs
Normal file
30
client_tui/src/ui.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use ratatui::{
|
||||
prelude::{Alignment, Frame},
|
||||
style::{Color, Style},
|
||||
widgets::{Block, BorderType, Borders, Paragraph},
|
||||
};
|
||||
|
||||
use crate::app::App;
|
||||
|
||||
/// Renders the counter screen over the whole frame area.
///
/// Draws a single centered, yellow paragraph containing the key bindings and
/// the current value of `app.counter`, inside a rounded, fully bordered block
/// titled "Counter App".
pub fn render(app: &mut App, f: &mut Frame) {
    f.render_widget(
        // The lines of the literal below are kept flush-left on purpose:
        // whitespace placed inside the literal would become part of the text.
        Paragraph::new(format!(
            "
Press `Esc`, `Ctrl-C` or `q` to stop running.\n\
Press `j` and `k` to increment and decrement the counter respectively.\n\
Counter: {}
",
            app.counter
        ))
        .block(
            Block::default()
                .title("Counter App")
                .title_alignment(Alignment::Center)
                .borders(Borders::ALL)
                .border_type(BorderType::Rounded),
        )
        .style(Style::default().fg(Color::Yellow))
        .alignment(Alignment::Center),
        f.area(),
    )
}
|
||||
17
client_tui/src/update.rs
Normal file
17
client_tui/src/update.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
|
||||
use crate::app::App;
|
||||
|
||||
pub fn update(app: &mut App, key_event: KeyEvent) {
|
||||
match key_event.code {
|
||||
KeyCode::Esc | KeyCode::Char('q') => app.quit(),
|
||||
KeyCode::Char('c') | KeyCode::Char('C') => {
|
||||
if key_event.modifiers == KeyModifiers::CONTROL {
|
||||
app.quit()
|
||||
}
|
||||
}
|
||||
KeyCode::Right | KeyCode::Char('j') => app.increment_counter(),
|
||||
KeyCode::Left | KeyCode::Char('k') => app.decrement_counter(),
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
42
devenv.lock
42
devenv.lock
|
|
@ -3,10 +3,10 @@
|
|||
"devenv": {
|
||||
"locked": {
|
||||
"dir": "src/modules",
|
||||
"lastModified": 1770390537,
|
||||
"lastModified": 1753667201,
|
||||
"owner": "cachix",
|
||||
"repo": "devenv",
|
||||
"rev": "d6f45cc00829254a9a6f8807c8fbfaf3efa7e629",
|
||||
"rev": "4d584d7686a50387f975879788043e55af9f0ad4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
@ -19,14 +19,14 @@
|
|||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1767039857,
|
||||
"owner": "NixOS",
|
||||
"lastModified": 1747046372,
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab",
|
||||
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
|
|
@ -40,10 +40,10 @@
|
|||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1769939035,
|
||||
"lastModified": 1750779888,
|
||||
"owner": "cachix",
|
||||
"repo": "git-hooks.nix",
|
||||
"rev": "a8ca480175326551d6c4121498316261cbb5b260",
|
||||
"rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
@ -60,10 +60,10 @@
|
|||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1762808025,
|
||||
"lastModified": 1709087332,
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c",
|
||||
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
@ -74,40 +74,24 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1770136044,
|
||||
"lastModified": 1753432016,
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "e576e3c9cf9bad747afcddd9e34f51d18c855b4e",
|
||||
"rev": "6027c30c8e9810896b92429f0092f624f7b1aace",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-25.11",
|
||||
"ref": "nixpkgs-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-cmake3": {
|
||||
"locked": {
|
||||
"lastModified": 1758213207,
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "f4b140d5b253f5e2a1ff4e5506edbf8267724bde",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "f4b140d5b253f5e2a1ff4e5506edbf8267724bde",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"devenv": "devenv",
|
||||
"git-hooks": "git-hooks",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-cmake3": "nixpkgs-cmake3",
|
||||
"pre-commit-hooks": [
|
||||
"git-hooks"
|
||||
]
|
||||
|
|
|
|||
34
devenv.nix
34
devenv.nix
|
|
@ -1,16 +1,13 @@
|
|||
{ inputs, pkgs, ... }:
|
||||
{ pkgs, ... }:
|
||||
|
||||
let
|
||||
pkgs-cmake3 = import inputs.nixpkgs-cmake3 { system = pkgs.stdenv.system; };
|
||||
in
|
||||
{
|
||||
|
||||
packages = [
|
||||
|
||||
# pour burn-rs
|
||||
pkgs.SDL2_gfx
|
||||
# (compilation sdl2-sys)
|
||||
pkgs-cmake3.cmake
|
||||
pkgs.libxcb
|
||||
pkgs.cmake
|
||||
pkgs.libffi
|
||||
pkgs.wayland-scanner
|
||||
|
||||
|
|
@ -18,12 +15,6 @@ in
|
|||
pkgs.samply # code profiler
|
||||
pkgs.feedgnuplot # to visualize bots training results
|
||||
|
||||
# --- AI training with python ---
|
||||
# generate python classes from rust code
|
||||
pkgs.maturin
|
||||
# required by python numpy
|
||||
pkgs.libz
|
||||
|
||||
# for bevy
|
||||
pkgs.alsa-lib
|
||||
pkgs.udev
|
||||
|
|
@ -56,25 +47,6 @@ in
|
|||
# https://devenv.sh/languages/
|
||||
languages.rust.enable = true;
|
||||
|
||||
|
||||
# AI training with python
|
||||
enterShell = ''
|
||||
PYTHONPATH=$PYTHONPATH:$PWD/.devenv/state/venv/lib/python3/site-packages
|
||||
'';
|
||||
|
||||
languages.python = {
|
||||
enable = true;
|
||||
uv.enable = true;
|
||||
venv.enable = true;
|
||||
venv.requirements = "
|
||||
pip
|
||||
gymnasium
|
||||
numpy
|
||||
stable-baselines3
|
||||
shimmy
|
||||
";
|
||||
};
|
||||
|
||||
# https://devenv.sh/scripts/
|
||||
# scripts.hello.exec = "echo hello from $GREET";
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
inputs:
|
||||
nixpkgs:
|
||||
url: github:NixOS/nixpkgs/nixos-25.11
|
||||
nixpkgs-cmake3:
|
||||
url: github:NixOS/nixpkgs/f4b140d5b253f5e2a1ff4e5506edbf8267724bde
|
||||
url: github:NixOS/nixpkgs/nixpkgs-unstable
|
||||
|
|
@ -53,10 +53,6 @@ Client
|
|||
|
||||
### Epic : Bot
|
||||
|
||||
- PGX
|
||||
- https://joe-antognini.github.io/ml/jax-tic-tac-toe
|
||||
- https://www.sotets.uk/pgx/api_usage/
|
||||
|
||||
- OpenAi gym
|
||||
- doc gymnasium <https://gymnasium.farama.org/introduction/basic_usage/>
|
||||
- Rust implementation for OpenAi gym <https://github.com/MathisWellmann/gym-rs>
|
||||
|
|
|
|||
|
|
@ -1,992 +0,0 @@
|
|||
# Plan: C++ OpenSpiel Game via cxx.rs
|
||||
|
||||
> Implementation plan for a native C++ OpenSpiel game for Trictrac, powered by the existing Rust engine through [cxx.rs](https://cxx.rs/) bindings.
|
||||
>
|
||||
> Based on reading: `store/src/pyengine.rs`, `store/src/training_common.rs`, `store/src/game.rs`, `store/src/board.rs`, `store/src/player.rs`, `store/src/game_rules_points.rs`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.h`, `forks/open_spiel/open_spiel/games/backgammon/backgammon.cc`, `forks/open_spiel/open_spiel/spiel.h`, `forks/open_spiel/open_spiel/games/CMakeLists.txt`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Overview
|
||||
|
||||
The Python binding (`pyengine.rs` + `trictrac.py`) wraps the Rust engine via PyO3. The goal here is an analogous C++ binding:
|
||||
|
||||
- **`store/src/cxxengine.rs`** — defines a `#[cxx::bridge]` exposing an opaque `TricTracEngine` Rust type with the same logical API as `pyengine.rs`.
|
||||
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.h`** — C++ header for a `TrictracGame : public Game` and `TrictracState : public State`.
|
||||
- **`forks/open_spiel/open_spiel/games/trictrac/trictrac.cc`** — C++ implementation that holds a `rust::Box<ffi::TricTracEngine>` and delegates all logic to Rust.
|
||||
- Build wired together via **corrosion** (CMake-native Rust integration) and `cxx-build`.
|
||||
|
||||
The resulting C++ game registers itself as `"trictrac"` via `REGISTER_SPIEL_GAME` and is consumable by any OpenSpiel algorithm (AlphaZero, MCTS, etc.) that works with C++ games.
|
||||
|
||||
---
|
||||
|
||||
## 2. Files to Create / Modify
|
||||
|
||||
```
|
||||
trictrac/
|
||||
store/
|
||||
Cargo.toml ← MODIFY: add cxx, cxx-build, staticlib crate-type
|
||||
build.rs ← CREATE: cxx-build bridge registration
|
||||
src/
|
||||
lib.rs ← MODIFY: add cxxengine module
|
||||
cxxengine.rs ← CREATE: #[cxx::bridge] definition + impl
|
||||
|
||||
forks/open_spiel/
|
||||
CMakeLists.txt ← MODIFY: add Corrosion FetchContent
|
||||
open_spiel/
|
||||
games/
|
||||
CMakeLists.txt ← MODIFY: add trictrac/ sources + test
|
||||
trictrac/ ← CREATE directory
|
||||
trictrac.h ← CREATE
|
||||
trictrac.cc ← CREATE
|
||||
trictrac_test.cc ← CREATE
|
||||
|
||||
justfile ← MODIFY: add buildtrictrac target
|
||||
trictrac/
|
||||
justfile ← MODIFY: add cxxlib target
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Step 1 — Rust: `store/Cargo.toml`
|
||||
|
||||
Add `cxx` as a runtime dependency and `cxx-build` as a build dependency. Add `staticlib` to `crate-type` so CMake can link against the Rust code as a static library.
|
||||
|
||||
```toml
|
||||
[package]
|
||||
name = "trictrac-store"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "trictrac_store"
|
||||
# cdylib → Python .so (used by maturin / pyengine)
|
||||
# rlib → used by other Rust crates in the workspace
|
||||
# staticlib → used by C++ consumers (cxxengine)
|
||||
crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.21.7"
|
||||
cxx = "1.0"
|
||||
log = "0.4.20"
|
||||
merge = "0.1.0"
|
||||
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
|
||||
rand = "0.9"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
transpose = "0.2.2"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
```
|
||||
|
||||
> **Note on `staticlib` + `cdylib` coexistence.** Cargo will build all three types when asked. The static library is used by the C++ OpenSpiel build; the cdylib is used by maturin for the Python wheel. They do not interfere. The `rlib` is used internally by other workspace members (`bot`, `client_cli`).
|
||||
|
||||
---
|
||||
|
||||
## 4. Step 2 — Rust: `store/build.rs`
|
||||
|
||||
The `build.rs` script drives `cxx-build`, which compiles the C++ side of the bridge (the generated shim) and tells Cargo where to find the generated header.
|
||||
|
||||
```rust
|
||||
fn main() {
|
||||
cxx_build::bridge("src/cxxengine.rs")
|
||||
.std("c++17")
|
||||
.compile("trictrac-cxx");
|
||||
|
||||
// Re-run if the bridge source changes
|
||||
println!("cargo:rerun-if-changed=src/cxxengine.rs");
|
||||
}
|
||||
```
|
||||
|
||||
`cxx-build` will:
|
||||
|
||||
- Parse `src/cxxengine.rs` for the `#[cxx::bridge]` block.
|
||||
- Generate `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` — the C++ header.
|
||||
- Generate `$OUT_DIR/cxxbridge/sources/trictrac_store/src/cxxengine.rs.cc` — the C++ shim source.
|
||||
- Compile the shim into `libtrictrac-cxx.a` (alongside the Rust `libtrictrac_store.a`).
|
||||
|
||||
---
|
||||
|
||||
## 5. Step 3 — Rust: `store/src/cxxengine.rs`
|
||||
|
||||
This is the heart of the C++ integration. It mirrors `pyengine.rs` in structure but uses `#[cxx::bridge]` instead of PyO3.
|
||||
|
||||
### Design decisions vs. `pyengine.rs`
|
||||
|
||||
| pyengine | cxxengine | Reason |
|
||||
| ------------------------- | ---------------------------- | -------------------------------------------- |
|
||||
| `PyResult<()>` for errors | `Result<()>` | cxx.rs translates `Err` to a C++ exception |
|
||||
| `(u8, u8)` tuple for dice | `DicePair` shared struct | cxx cannot cross tuples |
|
||||
| `Vec<usize>` for actions | `Vec<u64>` | cxx does not support `usize` |
|
||||
| `[i32; 2]` for scores | `PlayerScores` shared struct | cxx cannot cross fixed arrays |
|
||||
| Clone via PyO3 pickling | `clone_engine()` method | OpenSpiel's `State::Clone()` needs deep copy |
|
||||
|
||||
### File content
|
||||
|
||||
```rust
|
||||
//! # C++ bindings for the TricTrac game engine via cxx.rs
|
||||
//!
|
||||
//! Exposes an opaque `TricTracEngine` type and associated functions
|
||||
//! to C++. The C++ side (trictrac.cc) uses `rust::Box<ffi::TricTracEngine>`.
|
||||
//!
|
||||
//! The Rust engine always works from the perspective of White (player 1).
|
||||
//! For Black (player 2), the board is mirrored before computing actions
|
||||
//! and events are mirrored back before applying — exactly as in pyengine.rs.
|
||||
|
||||
use crate::dice::Dice;
|
||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
||||
|
||||
// ── cxx bridge declaration ────────────────────────────────────────────────────
|
||||
|
||||
#[cxx::bridge(namespace = "trictrac_engine")]
|
||||
pub mod ffi {
|
||||
// ── Shared types (visible to both Rust and C++) ───────────────────────────
|
||||
|
||||
/// Two dice values passed from C++ to Rust for a dice-roll event.
|
||||
struct DicePair {
|
||||
die1: u8,
|
||||
die2: u8,
|
||||
}
|
||||
|
||||
/// Both players' scores: holes * 12 + points.
|
||||
struct PlayerScores {
|
||||
score_p1: i32,
|
||||
score_p2: i32,
|
||||
}
|
||||
|
||||
// ── Opaque Rust type exposed to C++ ───────────────────────────────────────
|
||||
|
||||
extern "Rust" {
|
||||
/// Opaque handle to a TricTrac game state.
|
||||
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
|
||||
type TricTracEngine;
|
||||
|
||||
/// Create a new engine, initialise two players, begin with player 1.
|
||||
fn new_trictrac_engine() -> Box<TricTracEngine>;
|
||||
|
||||
/// Return a deep copy of the engine (needed for State::Clone()).
|
||||
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
|
||||
|
||||
// ── Queries ───────────────────────────────────────────────────────────
|
||||
|
||||
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
|
||||
fn needs_roll(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// True when Stage::Ended.
|
||||
fn is_game_ended(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// Active player index: 0 (player 1 / White) or 1 (player 2 / Black).
|
||||
fn current_player_idx(self: &TricTracEngine) -> u64;
|
||||
|
||||
/// Legal action indices for `player_idx`. Returns empty vec if it is
|
||||
/// not that player's turn. Indices are in [0, 513].
|
||||
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Vec<u64>;
|
||||
|
||||
/// Human-readable action description, e.g. "0:Move { dice_order: true … }".
|
||||
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
|
||||
|
||||
/// Both players' scores: holes * 12 + points.
|
||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||
|
||||
/// 36-element state observation vector (i8). Mirrored when `player_idx` is 1 (player 2 / Black).
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||
|
||||
/// Human-readable state description for `player_idx`.
|
||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||
|
||||
/// Full debug representation of the current state.
|
||||
fn to_debug_string(self: &TricTracEngine) -> String;
|
||||
|
||||
// ── Mutations ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Apply a dice roll result. Returns Err if not in RollWaiting stage.
|
||||
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
|
||||
|
||||
/// Apply a player action (move, go, roll). Returns Err if invalid.
|
||||
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Opaque type implementation ────────────────────────────────────────────────
|
||||
|
||||
pub struct TricTracEngine {
|
||||
game_state: GameState,
|
||||
}
|
||||
|
||||
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
Box::new(TricTracEngine { game_state })
|
||||
}
|
||||
|
||||
impl TricTracEngine {
|
||||
fn clone_engine(&self) -> Box<TricTracEngine> {
|
||||
Box::new(TricTracEngine {
|
||||
game_state: self.game_state.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn needs_roll(&self) -> bool {
|
||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
||||
}
|
||||
|
||||
fn is_game_ended(&self) -> bool {
|
||||
self.game_state.stage == Stage::Ended
|
||||
}
|
||||
|
||||
/// Returns 0 for player 1 (White) and 1 for player 2 (Black).
|
||||
fn current_player_idx(&self) -> u64 {
|
||||
self.game_state.active_player_id - 1
|
||||
}
|
||||
|
||||
fn get_legal_actions(&self, player_idx: u64) -> Vec<u64> {
|
||||
if player_idx == self.current_player_idx() {
|
||||
if player_idx == 0 {
|
||||
get_valid_action_indices(&self.game_state)
|
||||
.into_iter()
|
||||
.map(|i| i as u64)
|
||||
.collect()
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
.into_iter()
|
||||
.map(|i| i as u64)
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
|
||||
TrictracAction::from_action_index(action_idx as usize)
|
||||
.map(|a| format!("{}:{}", player_idx, a))
|
||||
.unwrap_or_else(|| "unknown action".into())
|
||||
}
|
||||
|
||||
fn get_players_scores(&self) -> ffi::PlayerScores {
|
||||
ffi::PlayerScores {
|
||||
score_p1: self.score_for(1),
|
||||
score_p2: self.score_for(2),
|
||||
}
|
||||
}
|
||||
|
||||
fn score_for(&self, player_id: u64) -> i32 {
|
||||
if let Some(player) = self.game_state.players.get(&player_id) {
|
||||
player.holes as i32 * 12 + player.points as i32
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||
if player_idx == 0 {
|
||||
format!("{}", self.game_state)
|
||||
} else {
|
||||
format!("{}", self.game_state.mirror())
|
||||
}
|
||||
}
|
||||
|
||||
fn to_debug_string(&self) -> String {
|
||||
format!("{}", self.game_state)
|
||||
}
|
||||
|
||||
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> Result<(), String> {
|
||||
let player_id = self.game_state.active_player_id;
|
||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||
return Err("Not in RollWaiting stage".into());
|
||||
}
|
||||
let dice = Dice {
|
||||
values: (dice.die1, dice.die2),
|
||||
};
|
||||
self.game_state
|
||||
.consume(&GameEvent::RollResult { player_id, dice });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_action(&mut self, action_idx: u64) -> Result<(), String> {
|
||||
let action_idx = action_idx as usize;
|
||||
let needs_mirror = self.game_state.active_player_id == 2;
|
||||
|
||||
let event = TrictracAction::from_action_index(action_idx)
|
||||
.and_then(|a| {
|
||||
let game_state = if needs_mirror {
|
||||
&self.game_state.mirror()
|
||||
} else {
|
||||
&self.game_state
|
||||
};
|
||||
a.to_event(game_state)
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
});
|
||||
|
||||
match event {
|
||||
Some(evt) if self.game_state.validate(&evt) => {
|
||||
self.game_state.consume(&evt);
|
||||
Ok(())
|
||||
}
|
||||
Some(_) => Err("Action is invalid".into()),
|
||||
None => Err("Could not build event from action index".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
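> **Note on `Result<(), String>`**: in the bridge declaration the return type is written with a single type parameter (`Result<()>`); the Rust implementation may use any error type that implements `std::fmt::Display`, which `String` does. A dedicated error type is still preferable because it carries context and works with `?`. Two options: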
|
||||
>
|
||||
> - Use `anyhow::Error` (add `anyhow` dependency).
|
||||
> - Define a thin newtype `struct EngineError(String)` that implements `std::error::Error`.
|
||||
>
|
||||
> The recommended approach is `anyhow`:
|
||||
>
|
||||
> ```toml
|
||||
> [dependencies]
|
||||
> anyhow = "1.0"
|
||||
> ```
|
||||
>
|
||||
> Then `fn apply_action(...) -> Result<(), anyhow::Error>` — cxx.rs will convert this to a C++ exception of type `rust::Error` carrying the message.
|
||||
|
||||
---
|
||||
|
||||
## 6. Step 4 — Rust: `store/src/lib.rs`
|
||||
|
||||
Add the new module:
|
||||
|
||||
```rust
|
||||
// existing modules …
|
||||
mod pyengine;
|
||||
|
||||
// NEW: C++ bindings via cxx.rs
|
||||
pub mod cxxengine;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Step 5 — C++: `trictrac/trictrac.h`
|
||||
|
||||
Modelled closely after `backgammon/backgammon.h`. The state holds a `rust::Box<trictrac_engine::TricTracEngine>` and delegates everything to it.
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac.h
|
||||
#ifndef OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
#define OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/spiel_utils.h"
|
||||
|
||||
// Generated by cxx-build from store/src/cxxengine.rs.
|
||||
// The include path is set by CMake (see CMakeLists.txt).
|
||||
#include "trictrac_store/src/cxxengine.rs.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
|
||||
inline constexpr int kNumPlayers = 2;
|
||||
inline constexpr int kNumChanceOutcomes = 36; // 6 × 6 dice outcomes
|
||||
inline constexpr int kNumDistinctActions = 514; // matches ACTION_SPACE_SIZE in Rust
|
||||
inline constexpr int kStateEncodingSize = 36; // matches to_vec() length in Rust
|
||||
inline constexpr int kDefaultMaxTurns = 1000;
|
||||
|
||||
class TrictracGame;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrictracState
|
||||
// ---------------------------------------------------------------------------
|
||||
class TrictracState : public State {
|
||||
public:
|
||||
explicit TrictracState(std::shared_ptr<const Game> game);
|
||||
TrictracState(const TrictracState& other);
|
||||
|
||||
Player CurrentPlayer() const override;
|
||||
std::vector<Action> LegalActions() const override;
|
||||
std::string ActionToString(Player player, Action move_id) const override;
|
||||
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;
|
||||
std::string ToString() const override;
|
||||
bool IsTerminal() const override;
|
||||
std::vector<double> Returns() const override;
|
||||
std::string ObservationString(Player player) const override;
|
||||
void ObservationTensor(Player player, absl::Span<float> values) const override;
|
||||
std::unique_ptr<State> Clone() const override;
|
||||
|
||||
protected:
|
||||
void DoApplyAction(Action move_id) override;
|
||||
|
||||
private:
|
||||
// Decode a chance action index [0,35] to (die1, die2).
|
||||
// Matches Python: [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
static trictrac_engine::DicePair DecodeChanceAction(Action action);
|
||||
|
||||
// The Rust engine handle. Deep-copied via clone_engine() when cloning state.
|
||||
rust::Box<trictrac_engine::TricTracEngine> engine_;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrictracGame
|
||||
// ---------------------------------------------------------------------------
|
||||
class TrictracGame : public Game {
|
||||
public:
|
||||
explicit TrictracGame(const GameParameters& params);
|
||||
|
||||
int NumDistinctActions() const override { return kNumDistinctActions; }
|
||||
std::unique_ptr<State> NewInitialState() const override;
|
||||
int MaxChanceOutcomes() const override { return kNumChanceOutcomes; }
|
||||
int NumPlayers() const override { return kNumPlayers; }
|
||||
double MinUtility() const override { return 0.0; }
|
||||
double MaxUtility() const override { return 200.0; }
|
||||
int MaxGameLength() const override { return 3 * max_turns_; }
|
||||
int MaxChanceNodesInHistory() const override { return MaxGameLength(); }
|
||||
std::vector<int> ObservationTensorShape() const override {
|
||||
return {kStateEncodingSize};
|
||||
}
|
||||
|
||||
private:
|
||||
int max_turns_;
|
||||
};
|
||||
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
|
||||
#endif // OPEN_SPIEL_GAMES_TRICTRAC_H_
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Step 6 — C++: `trictrac/trictrac.cc`
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac.cc
|
||||
#include "open_spiel/games/trictrac/trictrac.h"
|
||||
|
||||
#include <algorithm>
#include <exception>
#include <memory>
#include <string>
#include <vector>
|
||||
|
||||
#include "open_spiel/abseil-cpp/absl/types/span.h"
|
||||
#include "open_spiel/game_parameters.h"
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/spiel_globals.h"
|
||||
#include "open_spiel/spiel_utils.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
namespace {
|
||||
|
||||
// ── Game registration ────────────────────────────────────────────────────────
|
||||
|
||||
// Static description of the game used for registration.
const GameType kGameType{
    /*short_name=*/"trictrac",
    /*long_name=*/"Trictrac",
    /*dynamics=*/GameType::Dynamics::kSequential,
    /*chance_mode=*/GameType::ChanceMode::kExplicitStochastic,
    /*information=*/GameType::Information::kPerfectInformation,
    /*utility=*/GameType::Utility::kGeneralSum,
    /*reward_model=*/GameType::RewardModel::kRewards,
    /*min_num_players=*/kNumPlayers,
    /*max_num_players=*/kNumPlayers,
    /*provides_information_state_string=*/false,
    /*provides_information_state_tensor=*/false,
    /*provides_observation_string=*/true,
    /*provides_observation_tensor=*/true,
    /*parameter_specification=*/
    {
        {"max_turns", GameParameter(kDefaultMaxTurns)},
    }};
|
||||
|
||||
static std::shared_ptr<const Game> Factory(const GameParameters& params) {
|
||||
return std::make_shared<const TrictracGame>(params);
|
||||
}
|
||||
|
||||
REGISTER_SPIEL_GAME(kGameType, Factory);
|
||||
|
||||
} // namespace
|
||||
|
||||
// ── TrictracGame ─────────────────────────────────────────────────────────────
|
||||
|
||||
// Reads the "max_turns" parameter, falling back to kDefaultMaxTurns.
TrictracGame::TrictracGame(const GameParameters& params)
    : Game(kGameType, params),
      max_turns_(ParameterValue<int>("max_turns", kDefaultMaxTurns)) {
}
|
||||
|
||||
std::unique_ptr<State> TrictracGame::NewInitialState() const {
|
||||
return std::make_unique<TrictracState>(shared_from_this());
|
||||
}
|
||||
|
||||
// ── TrictracState ─────────────────────────────────────────────────────────────
|
||||
|
||||
// Creates a new Rust engine for this state.
TrictracState::TrictracState(std::shared_ptr<const Game> game)
    : State(game), engine_(trictrac_engine::new_trictrac_engine()) {
}
|
||||
|
||||
// Copy constructor: deep-copy the Rust engine via clone_engine().
|
||||
// The engine handle cannot be copied from C++, so the copy is obtained from
// the engine's own clone_engine().
TrictracState::TrictracState(const TrictracState& other)
    : State(other), engine_(other.engine_->clone_engine()) {
}
|
||||
|
||||
std::unique_ptr<State> TrictracState::Clone() const {
|
||||
return std::make_unique<TrictracState>(*this);
|
||||
}
|
||||
|
||||
// ── Current player ────────────────────────────────────────────────────────────
|
||||
|
||||
// Terminal takes precedence over a pending roll; otherwise the engine's
// current player index is returned.
Player TrictracState::CurrentPlayer() const {
  if (engine_->is_game_ended()) {
    return kTerminalPlayerId;
  }
  if (engine_->needs_roll()) {
    return kChancePlayerId;
  }
  const uint64_t idx = engine_->current_player_idx();
  return static_cast<Player>(idx);
}
|
||||
|
||||
// ── Legal actions ─────────────────────────────────────────────────────────────
|
||||
|
||||
std::vector<Action> TrictracState::LegalActions() const {
|
||||
if (IsChanceNode()) {
|
||||
// All 36 dice outcomes are equally likely; return indices 0–35.
|
||||
std::vector<Action> actions(kNumChanceOutcomes);
|
||||
for (int i = 0; i < kNumChanceOutcomes; ++i) actions[i] = i;
|
||||
return actions;
|
||||
}
|
||||
Player player = CurrentPlayer();
|
||||
rust::Vec<uint64_t> rust_actions =
|
||||
engine_->get_legal_actions(static_cast<uint64_t>(player));
|
||||
std::vector<Action> actions;
|
||||
actions.reserve(rust_actions.size());
|
||||
for (uint64_t a : rust_actions) actions.push_back(static_cast<Action>(a));
|
||||
return actions;
|
||||
}
|
||||
|
||||
// ── Chance outcomes ───────────────────────────────────────────────────────────
|
||||
|
||||
std::vector<std::pair<Action, double>> TrictracState::ChanceOutcomes() const {
|
||||
SPIEL_CHECK_TRUE(IsChanceNode());
|
||||
const double p = 1.0 / kNumChanceOutcomes;
|
||||
std::vector<std::pair<Action, double>> outcomes;
|
||||
outcomes.reserve(kNumChanceOutcomes);
|
||||
for (int i = 0; i < kNumChanceOutcomes; ++i) outcomes.emplace_back(i, p);
|
||||
return outcomes;
|
||||
}
|
||||
|
||||
// ── Apply action ──────────────────────────────────────────────────────────────
|
||||
|
||||
/*static*/
|
||||
trictrac_engine::DicePair TrictracState::DecodeChanceAction(Action action) {
|
||||
// Matches: [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
return trictrac_engine::DicePair{
|
||||
/*die1=*/static_cast<uint8_t>(action / 6 + 1),
|
||||
/*die2=*/static_cast<uint8_t>(action % 6 + 1),
|
||||
};
|
||||
}
|
||||
|
||||
// Applies `action` to the Rust engine: a dice roll at a chance node, a player
// action otherwise.
//
// Both engine calls are fallible on the Rust side and report failure by
// throwing. The exception is turned into SpielFatalError so that a rejected
// action is reported with its message instead of escaping as an uncaught
// exception.
void TrictracState::DoApplyAction(Action action) {
  try {
    if (IsChanceNode()) {
      engine_->apply_dice_roll(DecodeChanceAction(action));
    } else {
      engine_->apply_action(static_cast<uint64_t>(action));
    }
  } catch (const std::exception& e) {
    SpielFatalError(std::string("Trictrac engine rejected action ") +
                    std::to_string(action) + ": " + e.what());
  }
}
|
||||
|
||||
// ── Terminal & returns ────────────────────────────────────────────────────────
|
||||
|
||||
bool TrictracState::IsTerminal() const {
|
||||
return engine_->is_game_ended();
|
||||
}
|
||||
|
||||
std::vector<double> TrictracState::Returns() const {
|
||||
trictrac_engine::PlayerScores scores = engine_->get_players_scores();
|
||||
return {static_cast<double>(scores.score_p1),
|
||||
static_cast<double>(scores.score_p2)};
|
||||
}
|
||||
|
||||
// ── Observation ───────────────────────────────────────────────────────────────
|
||||
|
||||
// Textual observation for `player`, produced by the engine.
std::string TrictracState::ObservationString(Player player) const {
  const uint64_t player_idx = static_cast<uint64_t>(player);
  return std::string(engine_->get_observation_string(player_idx));
}
|
||||
|
||||
// Copies the engine's int8 encoding for `player` into `values` as floats.
// Both the output span and the engine tensor must hold kStateEncodingSize
// entries.
void TrictracState::ObservationTensor(Player player,
                                      absl::Span<float> values) const {
  SPIEL_CHECK_EQ(values.size(), kStateEncodingSize);
  const uint64_t player_idx = static_cast<uint64_t>(player);
  rust::Vec<int8_t> encoded = engine_->get_tensor(player_idx);
  SPIEL_CHECK_EQ(encoded.size(), static_cast<size_t>(kStateEncodingSize));
  int slot = 0;
  while (slot < kStateEncodingSize) {
    values[slot] = static_cast<float>(encoded[slot]);
    ++slot;
  }
}
|
||||
|
||||
// ── Strings ───────────────────────────────────────────────────────────────────
|
||||
|
||||
std::string TrictracState::ToString() const {
|
||||
return std::string(engine_->to_debug_string());
|
||||
}
|
||||
|
||||
// Player actions are formatted by the engine; chance actions are rendered
// here as "(die1, die2)".
std::string TrictracState::ActionToString(Player player, Action action) const {
  if (!IsChanceNode()) {
    return std::string(engine_->action_to_string(
        static_cast<uint64_t>(player), static_cast<uint64_t>(action)));
  }
  const trictrac_engine::DicePair dice = DecodeChanceAction(action);
  std::string text = "(";
  text += std::to_string(dice.die1);
  text += ", ";
  text += std::to_string(dice.die2);
  text += ")";
  return text;
}
|
||||
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. Step 7 — C++: `trictrac/trictrac_test.cc`
|
||||
|
||||
```cpp
|
||||
// open_spiel/games/trictrac/trictrac_test.cc
|
||||
#include "open_spiel/games/trictrac/trictrac.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "open_spiel/spiel.h"
|
||||
#include "open_spiel/tests/basic_tests.h"
|
||||
#include "open_spiel/utils/init.h"
|
||||
|
||||
namespace open_spiel {
|
||||
namespace trictrac {
|
||||
namespace {
|
||||
|
||||
void BasicTrictracTests() {
|
||||
testing::LoadGameTest("trictrac");
|
||||
testing::RandomSimTest(*LoadGame("trictrac"), /*num_sims=*/5);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace trictrac
|
||||
} // namespace open_spiel
|
||||
|
||||
// Test entry point: initialises OpenSpiel, runs the Trictrac tests and reports
// success.
int main(int argc, char** argv) {
  // Init takes the program name and a remove_flags switch in addition to the
  // argc/argv pointers; the two-argument form does not match its signature.
  open_spiel::Init(argv[0], &argc, &argv, /*remove_flags=*/true);
  open_spiel::trictrac::BasicTrictracTests();
  std::cout << "trictrac tests passed" << std::endl;
  return 0;
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Step 8 — Build System: `forks/open_spiel/CMakeLists.txt`
|
||||
|
||||
The top-level `CMakeLists.txt` must be extended to bring in **Corrosion**, the standard CMake module for Rust. Add this block before the main `open_spiel` target is defined:
|
||||
|
||||
```cmake
|
||||
# ── Corrosion: CMake integration for Rust ────────────────────────────────────
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
Corrosion
|
||||
GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git
|
||||
GIT_TAG v0.5.1 # pin to a stable release
|
||||
)
|
||||
FetchContent_MakeAvailable(Corrosion)
|
||||
|
||||
# Import the trictrac-store Rust crate.
|
||||
# This creates a CMake target named after the library, 'trictrac_store' (underscore), not after the package name 'trictrac-store'.
|
||||
corrosion_import_crate(
|
||||
MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml
|
||||
CRATES trictrac-store
|
||||
)
|
||||
|
||||
# Generate the cxx bridge from cxxengine.rs.
|
||||
# corrosion_add_cxxbridge:
|
||||
# - runs cxx-build as part of the Rust build
|
||||
# - creates a CMake target 'trictrac_cxx_bridge' that:
|
||||
# * compiles the generated C++ shim
|
||||
# * exposes INTERFACE include dirs for the generated .rs.h header
|
||||
# CRATE is the imported CMake target, which Corrosion names after the library
# (trictrac_store, with an underscore). FILES is resolved relative to the
# crate's src/ directory, so no "src/" prefix is given.
corrosion_add_cxxbridge(trictrac_cxx_bridge
  CRATE trictrac_store
  FILES cxxengine.rs
)
|
||||
```
|
||||
|
||||
> **Where to insert**: After the `cmake_minimum_required` / `project()` lines and before `add_subdirectory(open_spiel)` (or wherever games are pulled in). Check the actual file structure before editing.
|
||||
|
||||
---
|
||||
|
||||
## 11. Step 9 — Build System: `open_spiel/games/CMakeLists.txt`
|
||||
|
||||
Two changes: add the new source files to `GAME_SOURCES`, and add a test target.
|
||||
|
||||
### 11.1 Add to `GAME_SOURCES`
|
||||
|
||||
Find the alphabetically correct position (after `tic_tac_toe`, before `trade_comm`) and add:
|
||||
|
||||
```cmake
|
||||
set(GAME_SOURCES
|
||||
# ... existing games ...
|
||||
trictrac/trictrac.cc
|
||||
trictrac/trictrac.h
|
||||
# ... remaining games ...
|
||||
)
|
||||
```
|
||||
|
||||
### 11.2 Link cxx bridge into OpenSpiel objects
|
||||
|
||||
The `trictrac` sources need the Rust library and cxx bridge linked in. Since the existing build compiles all `GAME_SOURCES` into `${OPEN_SPIEL_OBJECTS}` as a single object library, you need to ensure the Rust library and cxx bridge are linked when that object library is consumed.
|
||||
|
||||
The cleanest approach is to add the link dependencies to the main `open_spiel` library target. Find where `open_spiel` is defined (likely in `open_spiel/CMakeLists.txt`) and add:
|
||||
|
||||
```cmake
|
||||
target_link_libraries(open_spiel
  PUBLIC
    trictrac_cxx_bridge  # C++ shim generated by cxx-build
    trictrac_store       # Rust static library (Corrosion target uses the underscore name)
)
|
||||
```
|
||||
|
||||
If modifying the central `open_spiel` target is too disruptive, create an explicit object library for the trictrac game:
|
||||
|
||||
```cmake
|
||||
add_library(trictrac_game OBJECT
|
||||
trictrac/trictrac.cc
|
||||
trictrac/trictrac.h
|
||||
)
|
||||
target_include_directories(trictrac_game
|
||||
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
target_link_libraries(trictrac_game
  PUBLIC
    trictrac_cxx_bridge
    trictrac_store   # Corrosion target uses the underscore name
    open_spiel_core  # or whatever the core target is called
)
|
||||
```
|
||||
|
||||
Then reference `$<TARGET_OBJECTS:trictrac_game>` in relevant executables.
|
||||
|
||||
### 11.3 Add the test
|
||||
|
||||
```cmake
|
||||
add_executable(trictrac_test
|
||||
trictrac/trictrac_test.cc
|
||||
${OPEN_SPIEL_OBJECTS}
|
||||
$<TARGET_OBJECTS:tests>
|
||||
)
|
||||
target_link_libraries(trictrac_test
  PRIVATE
    trictrac_cxx_bridge
    trictrac_store  # Corrosion target uses the underscore name
)
|
||||
add_test(trictrac_test trictrac_test)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Step 10 — Justfile updates
|
||||
|
||||
### `trictrac/justfile` — add `cxxlib` target
|
||||
|
||||
Builds the Rust crate as a static library (for use by the C++ build) and confirms the generated header exists:
|
||||
|
||||
```just
|
||||
cxxlib:
|
||||
cargo build --release -p trictrac-store
|
||||
@echo "Static lib: $(ls target/release/libtrictrac_store.a)"
|
||||
@echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)"
|
||||
```
|
||||
|
||||
### `forks/open_spiel/justfile` — add `buildtrictrac` and `testtrictrac`
|
||||
|
||||
```just
|
||||
buildtrictrac:
|
||||
# Rebuild the Rust static lib first, then CMake
|
||||
cd ../../trictrac && cargo build --release -p trictrac-store
|
||||
mkdir -p build && cd build && \
|
||||
CXX=$(which clang++) cmake -DCMAKE_BUILD_TYPE=Release ../open_spiel && \
|
||||
make -j$(nproc) trictrac_test
|
||||
|
||||
testtrictrac: buildtrictrac
|
||||
./build/trictrac_test
|
||||
|
||||
playtrictrac_cpp:
|
||||
./build/examples/example --game=trictrac
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 13. Key Design Decisions
|
||||
|
||||
### 13.1 Opaque type with `clone_engine()`
|
||||
|
||||
OpenSpiel's `State::Clone()` must return a fully independent copy of the game state (used extensively by search algorithms). Since `TricTracEngine` is an opaque Rust type, C++ cannot copy it directly. The bridge exposes `clone_engine() -> Box<TricTracEngine>` which calls `.clone()` on the inner `GameState` (which derives `Clone`).
|
||||
|
||||
### 13.2 Action encoding: same 514-element space
|
||||
|
||||
The C++ game uses the same 514-action encoding as the Python version and the Rust training code. This means:
|
||||
|
||||
- The same `TrictracAction::to_action_index` / `from_action_index` mapping applies.
|
||||
- Action 0 = Roll (used as the bridge between Move and the next chance node).
|
||||
- Actions 2–513 = Move variants (checker ordinal pair + dice order).
|
||||
- A trained C++ model and Python model share the same action space.
|
||||
|
||||
### 13.3 Chance outcome ordering
|
||||
|
||||
The dice outcome ordering is identical to the Python version:
|
||||
|
||||
```
|
||||
action → (die1, die2)
|
||||
0 → (1,1) 6 → (2,1) ... 35 → (6,6)
|
||||
```
|
||||
|
||||
(`die1 = action/6 + 1`, `die2 = action%6 + 1`)
|
||||
|
||||
This matches `_roll_from_chance_idx` in `trictrac.py` exactly, ensuring the two implementations are interchangeable in training pipelines.
|
||||
|
||||
### 13.4 `GameType::Utility::kGeneralSum` + `kRewards`
|
||||
|
||||
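Consistent with the Python version. Trictrac is not zero-sum (both players can score positive holes). The original plan was for `Returns()` to report intermediate hole rewards at every state, not just the terminal one; as implemented, `Returns()` yields `{0, 0}` at non-terminal states, because non-zero intermediate returns violated the invariant with the default `Rewards()` (see "Fixes discovered during build").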
|
||||
|
||||
### 13.5 Mirror pattern preserved
|
||||
|
||||
`get_legal_actions` and `apply_action` in `TricTracEngine` mirror the board for player 2 exactly as `pyengine.rs` does. C++ never needs to know about the mirroring — it simply passes `player_idx` and the Rust engine handles the rest.
|
||||
|
||||
### 13.6 `rust::Box` vs `rust::UniquePtr`
|
||||
|
||||
`rust::Box<T>` (where `T` is an `extern "Rust"` type) is the correct choice for ownership of a Rust type from C++. It owns the heap allocation and drops it when the C++ destructor runs. There is no `rust::UniquePtr<T>` on the C++ side: the counterpart is `std::unique_ptr<T>` (bound in Rust as `cxx::UniquePtr<T>`), which is for C++ types held in Rust.
|
||||
|
||||
### 13.7 Separate struct from `pyengine.rs`
|
||||
|
||||
`TricTracEngine` in `cxxengine.rs` is a separate struct from `TricTrac` in `pyengine.rs`. They both wrap `GameState` but are independent. This avoids:
|
||||
|
||||
- PyO3 and cxx attributes conflicting on the same type.
|
||||
- Changes to one binding breaking the other.
|
||||
- Feature-flag complexity.
|
||||
|
||||
---
|
||||
|
||||
## 14. Known Challenges
|
||||
|
||||
### 14.1 Corrosion path resolution
|
||||
|
||||
`corrosion_import_crate(MANIFEST_PATH ...)` takes a path relative to the CMake source directory. Since the Rust crate lives outside the `forks/open_spiel/` directory, the path will be something like `${CMAKE_CURRENT_SOURCE_DIR}/../../trictrac/store/Cargo.toml`. Verify this resolves correctly on all developer machines (absolute paths are safer but less portable).
|
||||
|
||||
### 14.2 `staticlib` + `cdylib` in one crate
|
||||
|
||||
Rust allows `["cdylib", "rlib", "staticlib"]` in one crate, but there are subtle interactions:
|
||||
|
||||
- The `cdylib` build (for maturin) does not need `staticlib`, and building both doubles the compile time.
|
||||
- Consider gating `staticlib` behind a Cargo feature: `crate-type` is not directly feature-gatable, but you can work around this with a separate `Cargo.toml` or a workspace profile.
|
||||
- Alternatively, accept the extra compile time during development.
|
||||
|
||||
### 14.3 Linker symbols from Rust std
|
||||
|
||||
When linking a Rust `staticlib`, the C++ linker must pull in Rust's runtime and standard library symbols. Corrosion handles this automatically by reading the output of `rustc --print native-static-libs` and adding them to the link command. If not using Corrosion, these must be added manually (typically `-ldl -lm -lpthread -lc`).
|
||||
|
||||
### 14.4 `anyhow` for error types
|
||||
|
||||
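cxx.rs only requires the `Err` type in the implementation's `Result<T, E>` to implement `std::fmt::Display` (the bridge declaration itself is written as `Result<T>`), so `String` is accepted. A real error type is still recommended; use `anyhow::Error` or define a thin newtype wrapper: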
|
||||
|
||||
```rust
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EngineError(String);
|
||||
impl fmt::Display for EngineError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) }
|
||||
}
|
||||
impl std::error::Error for EngineError {}
|
||||
```
|
||||
|
||||
On the C++ side, errors become `rust::Error` exceptions. Wrap `DoApplyAction` in a try-catch during development to surface Rust errors as `SpielFatalError`.
|
||||
|
||||
### 14.5 `UndoAction` not implemented
|
||||
|
||||
OpenSpiel algorithms that use tree search (e.g., MCTS) may call `UndoAction`. The Rust engine's `GameState` stores a full `history` vec of `GameEvent`s but does not implement undo — the history is append-only. To support undo, `Clone()` is the only reliable strategy (clone before applying, discard clone if undo needed). OpenSpiel's default `UndoAction` raises `SpielFatalError`, which is acceptable for RL training but blocks game-tree search. If search support is needed, the simplest approach is to store a stack of cloned states inside `TrictracState` and pop on undo.
|
||||
|
||||
### 14.6 Generated header path in `#include`
|
||||
|
||||
The `#include "trictrac_store/src/cxxengine.rs.h"` path used in `trictrac.h` must match the actual path that `cxx-build` (via corrosion) places the generated header. With `corrosion_add_cxxbridge`, this is typically handled by the `trictrac_cxx_bridge` target's `INTERFACE_INCLUDE_DIRECTORIES`, which CMake propagates automatically to any target that links against it. Verify by inspecting the generated build directory.
|
||||
|
||||
### 14.7 `rust::String` to `std::string` conversion
|
||||
|
||||
The bridge methods returning `String` (Rust) appear as `rust::String` in C++. The conversion `std::string(engine_->action_to_string(...))` is valid because `rust::String` provides an explicit conversion operator to `std::string`; an implicit conversion (e.g. `std::string s = engine_->action_to_string(...);`) does not compile. Verify this works with your cxx version; if not, use `engine_->action_to_string(...).c_str()` or `static_cast<std::string>(...)`.
|
||||
|
||||
---
|
||||
|
||||
## 15. Complete File Checklist
|
||||
|
||||
```
|
||||
[ ] trictrac/store/Cargo.toml — add cxx, cxx-build, staticlib
|
||||
[ ] trictrac/store/build.rs — new file: cxx_build::bridge(...)
|
||||
[ ] trictrac/store/src/lib.rs — add `pub mod cxxengine;`
|
||||
[ ] trictrac/store/src/cxxengine.rs — new file: full bridge implementation
|
||||
[ ] trictrac/justfile — add `cxxlib` target
|
||||
[ ] forks/open_spiel/CMakeLists.txt — add Corrosion, corrosion_import_crate, corrosion_add_cxxbridge
|
||||
[ ] forks/open_spiel/open_spiel/games/CMakeLists.txt — add trictrac sources + test
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.h — new file
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — new file
|
||||
[ ] forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — new file
|
||||
[ ] forks/open_spiel/justfile — add buildtrictrac / testtrictrac
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 16. Implementation Order
|
||||
|
||||
Implement in this sequence to get early feedback at each step:
|
||||
|
||||
1. **Rust bridge first** (`Cargo.toml` → `build.rs` → `cxxengine.rs`). Run `cargo build -p trictrac-store` and confirm the static library and generated header are produced.
|
||||
2. **Verify generated header** by locating `$OUT_DIR/cxxbridge/include/trictrac_store/src/cxxengine.rs.h` and inspecting it. Confirm C++ signatures match expectations.
|
||||
3. **CMake Corrosion integration** in `forks/open_spiel/CMakeLists.txt`. Confirm `cmake ..` succeeds and the `trictrac_cxx_bridge` and `trictrac-store` targets exist.
|
||||
4. **Write `trictrac.h` and `trictrac.cc`**. Compile with `make trictrac_test` (or equivalent). Fix linker errors from Rust std symbols if Corrosion doesn't handle them.
|
||||
5. **Run `trictrac_test`**. Use `testing::RandomSimTest` to exercise the full game loop, catch panics from `get_valid_actions` (the `TurnStage::RollWaiting` panic path), and verify game termination.
|
||||
6. **Smoke-test with the example runner**: `./build/examples/example --game=trictrac`.
|
||||
|
||||
---
|
||||
|
||||
# Implementation complete
|
||||
|
||||
All files are in place and `trictrac_test` passes (final scores 168 and 84; all assertions pass).
|
||||
|
||||
What was done
|
||||
|
||||
Rust side (from previous session, already complete):
|
||||
|
||||
- trictrac/store/Cargo.toml — added staticlib, cxx, anyhow, cxx-build
|
||||
- trictrac/store/build.rs — drives cxx-build
|
||||
- trictrac/store/src/cxxengine.rs — full cxx bridge + TricTracEngine impl
|
||||
- trictrac/store/src/lib.rs — added pub mod cxxengine;
|
||||
|
||||
C++ side (this session):
|
||||
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac.h — game header
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac.cc — game implementation
|
||||
- forks/open_spiel/open_spiel/games/trictrac/trictrac_test.cc — basic test
|
||||
|
||||
Build system:
|
||||
|
||||
- forks/open_spiel/open_spiel/CMakeLists.txt — Corrosion + corrosion_import_crate + corrosion_add_cxxbridge
|
||||
- forks/open_spiel/open_spiel/games/CMakeLists.txt — trictrac_game OBJECT target + trictrac_test executable
|
||||
|
||||
Justfiles:
|
||||
|
||||
- trictrac/justfile — added cxxlib target
|
||||
- forks/open_spiel/justfile — added buildtrictrac and testtrictrac
|
||||
|
||||
Fixes discovered during build
|
||||
|
||||
| Issue | Fix |
|
||||
| ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| Corrosion creates trictrac_store (underscore), not trictrac-store | Used trictrac_store in CRATE arg and target_link_libraries |
|
||||
| FILES src/cxxengine.rs doubled src/src/ | Changed to FILES cxxengine.rs (relative to crate's src/) |
|
||||
| Include path changed: not trictrac-store/src/cxxengine.rs.h but trictrac_cxx_bridge/cxxengine.h | Updated #include in trictrac.h |
|
||||
| rust::Error not in inline cxx types | Added #include "rust/cxx.h" to trictrac.cc |
|
||||
| Init() signature differs in this fork | Changed to Init(argv[0], &argc, &argv, true) |
|
||||
| libtrictrac_store.a contains PyO3 code → missing Python symbols | Added Python3::Python to target_link_libraries |
|
||||
| LegalActions() not sorted (OpenSpiel requires ascending) | Added std::sort |
|
||||
| Duplicate actions for doubles | Added std::unique after sort |
|
||||
| Returns() returned non-zero at intermediate states, violating invariant with default Rewards() | Returns() now returns {0, 0} at non-terminal states |
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
# Python bindings
|
||||
|
||||
## Génération bindings
|
||||
|
||||
```sh
|
||||
# Generate trictrac python lib as a wheel
|
||||
maturin build -m store/Cargo.toml --release
|
||||
# Install wheel in local python env
|
||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Pour vérifier l'accès à la lib : lancer le shell interactif `python`
|
||||
|
||||
```python
|
||||
Python 3.13.11 (main, Dec 5 2025, 16:06:33) [GCC 15.2.0] on linux
|
||||
Type "help", "copyright", "credits" or "license" for more information.
|
||||
>>> import trictrac_store
|
||||
>>> game = trictrac_store.TricTrac()
|
||||
>>> game.get_active_player_id()
|
||||
1
|
||||
```
|
||||
|
||||
### Appels depuis python
|
||||
|
||||
`python bot/python/test.py`
|
||||
|
||||
## Interfaces
|
||||
|
||||
## Entraînement
|
||||
|
|
@ -1,536 +0,0 @@
|
|||
# Trictrac — Research Notes: Engine & OpenSpiel Integration
|
||||
|
||||
> Generated from a deep read of `trictrac/store/src/` and `forks/open_spiel/open_spiel/python/games/trictrac.py`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Architecture Overview
|
||||
|
||||
The project connects two codebases through a compiled Python extension:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ trictrac/store/ (Rust crate) │
|
||||
│ - full game rules engine │
|
||||
│ - pyengine.rs → PyO3 bindings │
|
||||
│ compiled by maturin → .whl │
|
||||
└──────────────┬──────────────────────┘
|
||||
│ import trictrac_store
|
||||
┌──────────────▼──────────────────────┐
|
||||
│ forks/open_spiel/.../trictrac.py │
|
||||
│ - TrictracGame (pyspiel.Game) │
|
||||
│ - TrictracState (pyspiel.State) │
|
||||
│ registered as "python_trictrac" │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
Build pipeline:
|
||||
- `just pythonlib` (in `trictrac/`) → `maturin build -m store/Cargo.toml --release` → `.whl` into `target/wheels/`
|
||||
- `just installtrictrac` (in `forks/open_spiel/`) → `pip install --force-reinstall` the wheel into the devenv venv
|
||||
|
||||
The Rust crate is named `trictrac-store` (package) but produces a lib named `trictrac_store` (the Python module name, set in `Cargo.toml` `[lib] name`).
|
||||
|
||||
---
|
||||
|
||||
## 2. Rust Engine: Module Map
|
||||
|
||||
| Module | Responsibility |
|
||||
|---|---|
|
||||
| `board.rs` | Board representation, checker manipulation, quarter analysis |
|
||||
| `dice.rs` | `Dice` struct, `DiceRoller`, bit encoding |
|
||||
| `player.rs` | `Player` struct (score, bredouille), `Color`, `PlayerId`, `CurrentPlayer` |
|
||||
| `game.rs` | `GameState` state machine, `GameEvent` enum, `Stage`/`TurnStage` |
|
||||
| `game_rules_moves.rs` | `MoveRules`: move validation and generation |
|
||||
| `game_rules_points.rs` | `PointsRules`: jan detection and scoring |
|
||||
| `training_common.rs` | `TrictracAction` enum, action-space encoding (size 514) |
|
||||
| `pyengine.rs` | PyO3 Python module exposing `TricTrac` class |
|
||||
| `lib.rs` | Crate root, re-exports |
|
||||
|
||||
---
|
||||
|
||||
## 3. Board Representation
|
||||
|
||||
```rust
|
||||
pub struct Board {
|
||||
positions: [i8; 24],
|
||||
}
|
||||
```
|
||||
|
||||
- 24 fields indexed 0–23 internally, 1–24 externally.
|
||||
- Positive values = White checkers on that field; negative = Black.
|
||||
- Initial state: `[15, 0, ..., 0, -15]` — all 15 white pieces on field 1, all 15 black pieces on field 24.
|
||||
- Field 0 is a sentinel for "exited the board" (never stored in the array).
|
||||
|
||||
**Mirroring** is the central symmetry operation used throughout:
|
||||
|
||||
```rust
|
||||
pub fn mirror(&self) -> Self {
|
||||
let mut positions = self.positions.map(|c| 0 - c);
|
||||
positions.reverse();
|
||||
Board { positions }
|
||||
}
|
||||
```
|
||||
|
||||
This negates all values (swapping who owns each checker) and reverses the array (swapping directions). The entire engine always reasons from White's perspective; Black's moves are handled by mirroring the board first.
|
||||
|
||||
**Quarter structure**: fields 1–6, 7–12, 13–18, 19–24. This maps to the four tables of Trictrac:
|
||||
- 1–6: White's "petit jan" (own table)
|
||||
- 7–12: White's "grand jan"
|
||||
- 13–18: Black's "grand jan" (= White's opponent territory)
|
||||
- 19–24: Black's "petit jan" / White's "jan de retour"
|
||||
|
||||
The "coin de repos" (rest corner) is field 12 for White, field 13 for Black.
|
||||
|
||||
---
|
||||
|
||||
## 4. Dice
|
||||
|
||||
```rust
|
||||
pub struct Dice {
|
||||
pub values: (u8, u8),
|
||||
}
|
||||
```
|
||||
|
||||
Dice are always a pair (never quadrupled for doubles, unlike Backgammon). The `DiceRoller` uses `StdRng` seeded from OS entropy (or an optional fixed seed for tests). Bit encoding: `"{d1:0>3b}{d2:0>3b}"` — 3 bits each, 6 bits total.
|
||||
|
||||
---
|
||||
|
||||
## 5. Player State
|
||||
|
||||
```rust
|
||||
pub struct Player {
|
||||
pub name: String,
|
||||
pub color: Color, // White or Black
|
||||
pub points: u8, // 0–11 (points within current hole)
|
||||
pub holes: u8, // holes won (game ends at >12)
|
||||
pub can_bredouille: bool,
|
||||
pub can_big_bredouille: bool,
|
||||
pub dice_roll_count: u8, // rolls since last new_pick_up()
|
||||
}
|
||||
```
|
||||
|
||||
`PlayerId` is a `u64` alias. Player 1 = White, Player 2 = Black (set at init time; this is fixed for the session in pyengine).
|
||||
|
||||
---
|
||||
|
||||
## 6. Game State Machine
|
||||
|
||||
### Stages
|
||||
|
||||
```rust
|
||||
pub enum Stage { PreGame, InGame, Ended }
|
||||
|
||||
pub enum TurnStage {
|
||||
RollDice, // 1 — player must request a roll
|
||||
RollWaiting, // 0 — waiting for dice result from outside
|
||||
MarkPoints, // 2 — points are being marked (schools mode only)
|
||||
HoldOrGoChoice, // 3 — player won a hole; choose to Go or Hold
|
||||
Move, // 4 — player must move checkers
|
||||
MarkAdvPoints, // 5 — mark opponent's points after the move (schools mode)
|
||||
}
|
||||
```
|
||||
|
||||
### Turn lifecycle (schools disabled — the default in pyengine)
|
||||
|
||||
```
|
||||
RollWaiting
|
||||
│ RollResult → auto-mark points
|
||||
├─[no hole]──→ Move
|
||||
│ │ Move → mark opponent's points → switch player
|
||||
│ └───────────────────────────────→ RollDice (next player)
|
||||
└─[hole won]─→ HoldOrGoChoice
|
||||
├─ Go ──→ new_pick_up() → RollDice (same player)
|
||||
└─ Move ──→ mark opponent's points → switch player → RollDice
|
||||
```
|
||||
|
||||
In schools mode (`schools_enabled = true`), the player explicitly marks their own points (`Mark` event) and then the opponent's points after moving (`MarkAdvPoints` stage).
|
||||
|
||||
### Key events
|
||||
|
||||
```rust
|
||||
pub enum GameEvent {
|
||||
BeginGame { goes_first: PlayerId },
|
||||
EndGame { reason: EndGameReason },
|
||||
PlayerJoined { player_id, name },
|
||||
PlayerDisconnected { player_id },
|
||||
Roll { player_id }, // triggers RollWaiting
|
||||
RollResult { player_id, dice }, // provides dice values
|
||||
Mark { player_id, points }, // explicit point marking (schools mode)
|
||||
Go { player_id }, // choose to restart position after hole
|
||||
Move { player_id, moves: (CheckerMove, CheckerMove) },
|
||||
PlayError,
|
||||
}
|
||||
```
|
||||
|
||||
### Initialization in pyengine
|
||||
|
||||
```rust
|
||||
fn new() -> Self {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
TricTrac { game_state }
|
||||
}
|
||||
```
|
||||
|
||||
Player 1 (White) always goes first. `active_player_id` uses 1-based indexing; pyengine converts to 0-based for the Python side with `active_player_id - 1`.
|
||||
|
||||
---
|
||||
|
||||
## 7. Scoring System (Jans)
|
||||
|
||||
Points are awarded after each dice roll based on "jans" (scoring events) detected by `PointsRules`. All computation assumes White's perspective (board is mirrored for Black before calling).
|
||||
|
||||
### Jan types
|
||||
|
||||
| Jan | Points (normal / doublet) | Direction |
|
||||
|---|---|---|
|
||||
| `TrueHitSmallJan` | 4 / 6 | → active player |
|
||||
| `TrueHitBigJan` | 2 / 4 | → active player |
|
||||
| `TrueHitOpponentCorner` | 4 / 6 | → active player |
|
||||
| `FilledQuarter` | 4 / 6 | → active player |
|
||||
| `FirstPlayerToExit` | 4 / 6 | → active player |
|
||||
| `SixTables` | 4 / 6 | → active player |
|
||||
| `TwoTables` | 4 / 6 | → active player |
|
||||
| `Mezeas` | 4 / 6 | → active player |
|
||||
| `FalseHitSmallJan` | −4 / −6 | → opponent |
|
||||
| `FalseHitBigJan` | −2 / −4 | → opponent |
|
||||
| `ContreTwoTables` | −4 / −6 | → opponent |
|
||||
| `ContreMezeas` | −4 / −6 | → opponent |
|
||||
| `HelplessMan` | −2 / −4 | → opponent |
|
||||
|
||||
A single roll can trigger multiple jans, each scored independently. The jan detection process:
|
||||
1. Try both dice orderings
|
||||
2. Detect "tout d'une" (combined dice move as a virtual single die)
|
||||
3. Prefer true hits over false hits for the same move
|
||||
4. Check quarter-filling opportunities
|
||||
5. Check rare jans (SixTables at roll 3, TwoTables, Mezeas) given specific board positions and talon counts
|
||||
|
||||
### Hole scoring
|
||||
|
||||
```rust
|
||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
let sum_points = p.points + points;
|
||||
let jeux = sum_points / 12; // number of completed holes
|
||||
let holes = match (jeux, p.can_bredouille) {
|
||||
(0, _) => 0,
|
||||
(_, false) => 2 * jeux - 1, // no bredouille bonus
|
||||
(_, true) => 2 * jeux, // bredouille doubles the holes
|
||||
};
|
||||
p.points = sum_points % 12;
|
||||
p.holes += holes;
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
- 12 points = 1 "jeu", which yields 1 or 2 holes depending on bredouille status.
|
||||
- Scoring any points clears the opponent's `can_bredouille`.
|
||||
- Completing a hole resets `can_bredouille` for the scorer.
|
||||
- Game ends when `holes > 12`.
|
||||
- Score reported to OpenSpiel: `holes * 12 + points`.
|
||||
|
||||
### Points for both players
|
||||
|
||||
After a roll, the active player's points (`dice_points.0`) are auto-marked immediately. After the Move, the opponent's points (`dice_points.1`) are marked (they were computed at roll-time from the pre-move board).
|
||||
|
||||
---
|
||||
|
||||
## 8. Move Rules
|
||||
|
||||
`MoveRules` always works from White's perspective. Key constraints enforced by `moves_allowed()`:
|
||||
|
||||
1. **Opponent's corner forbidden**: Cannot land on field 13 (opponent's rest corner for White).
|
||||
2. **Corner needs two checkers**: The rest corner (field 12) must be taken or vacated with exactly 2 checkers simultaneously.
|
||||
3. **Corner by effect vs. by power**: If the corner can be taken directly ("par effet"), you cannot take it "par puissance" (using combined dice).
|
||||
4. **Exit preconditions**: All checkers must be in fields 19–24 before any exit is allowed.
|
||||
5. **Exit by effect priority**: If a normal exit is possible, exceedant moves (using overflow) are forbidden.
|
||||
6. **Farthest checker first**: When exiting with exceedant, must exit the checker at the highest field.
|
||||
7. **Must play all dice**: If both dice can be played, playing only one is invalid.
|
||||
8. **Must play strongest die**: If only one die can be played, it must be the higher value die.
|
||||
9. **Must fill quarter**: If a quarter can be completed, the move must complete it.
|
||||
10. **Cannot block opponent's fillable quarter**: Cannot move into a quarter the opponent can still fill.
|
||||
|
||||
The board state after each die application is simulated to check two-step sequences.
|
||||
|
||||
---
|
||||
|
||||
## 9. Action Space (training_common.rs)
|
||||
|
||||
Total size: **514 actions**.
|
||||
|
||||
| Index | Action | Description |
|
||||
|---|---|---|
|
||||
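| 0 | `Roll` | Request dice roll; in OpenSpiel mode it is the single legal action at `TurnStage::RollDice`, bridging a completed move and the next chance node (see section 14) |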
|
||||
| 1 | `Go` | After winning hole: reset board and continue |
|
||||
| 2–257 | `Move { dice_order: true, checker1, checker2 }` | Move with die[0] first |
|
||||
| 258–513 | `Move { dice_order: false, checker1, checker2 }` | Move with die[1] first |
|
||||
|
||||
Move encoding: `index = 2 + (0 if dice_order else 256) + checker1 * 16 + checker2`
|
||||
|
||||
`checker1` and `checker2` are **ordinal positions** (1-based) of specific checkers counted left-to-right across all White-occupied fields, not field indices. Checker 0 = "no move" (empty move). Range: 0–15 (16 values each).
|
||||
|
||||
### Mirror pattern in get_legal_actions / apply_action
|
||||
|
||||
For player 2 (Black):
|
||||
```rust
|
||||
// get_legal_actions: mirror game state before computing
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
|
||||
// apply_action: convert action → event on mirrored state, then mirror the event back
|
||||
a.to_event(&self.game_state.mirror())
|
||||
.map(|e| e.get_mirror(false))
|
||||
```
|
||||
|
||||
This ensures Black's actions are computed as if Black were White on a mirrored board, then translated back to real-board coordinates.
|
||||
|
||||
---
|
||||
|
||||
## 10. Python Bindings (pyengine.rs)
|
||||
|
||||
The `TricTrac` PyO3 class exposes:
|
||||
|
||||
| Method | Signature | Description |
|
||||
|---|---|---|
|
||||
| `new()` | `→ TricTrac` | Create game, init 2 players, begin with player 1 |
|
||||
| `needs_roll()` | `→ bool` | True when in `RollWaiting` stage |
|
||||
| `is_game_ended()` | `→ bool` | True when `Stage::Ended` |
|
||||
| `current_player_idx()` | `→ u64` | 0 or 1 (active_player_id − 1) |
|
||||
| `get_legal_actions(player_idx)` | `→ Vec<usize>` | Action indices for player; empty if not their turn |
|
||||
| `action_to_string(player_idx, action_idx)` | `→ String` | Human-readable action description |
|
||||
| `apply_dice_roll(dices: (u8, u8))` | `→ PyResult<()>` | Inject dice result; errors if not in RollWaiting |
|
||||
| `apply_action(action_idx)` | `→ PyResult<()>` | Apply a game action; validates before applying |
|
||||
| `get_score(player_id)` | `→ i32` | `holes * 12 + points` for player (1-indexed!) |
|
||||
| `get_players_scores()` | `→ [i32; 2]` | `[score_p1, score_p2]` |
|
||||
| `get_tensor(player_idx)` | `→ Vec<i8>` | 36-element state vector (mirrored when `player_idx` is 1, i.e. Black) |
|
||||
| `get_observation_string(player_idx)` | `→ String` | Human-readable state (mirrored when `player_idx` is 1, i.e. Black) |
|
||||
| `__str__()` | `→ String` | Debug representation of game state |
|
||||
|
||||
Note: `get_score(player_id)` takes a 1-based player ID (1 or 2), unlike `current_player_idx()` which returns 0-based.
|
||||
|
||||
---
|
||||
|
||||
## 11. State Tensor Encoding (36 bytes)
|
||||
|
||||
```
|
||||
[0..23] Board positions (i8): +N white / −N black checkers per field
|
||||
[24] Active player: 0=White, 1=Black
|
||||
[25] TurnStage: 0=RollWaiting, 1=RollDice, 2=MarkPoints, 3=HoldOrGoChoice,
|
||||
4=Move, 5=MarkAdvPoints
|
||||
[26] Dice value 1 (i8)
|
||||
[27] Dice value 2 (i8)
|
||||
[28] White: points (0–11)
|
||||
[29] White: holes (0–12)
|
||||
[30] White: can_bredouille (0 or 1)
|
||||
[31] White: can_big_bredouille (0 or 1)
|
||||
[32] Black: points
|
||||
[33] Black: holes
|
||||
[34] Black: can_bredouille
|
||||
[35] Black: can_big_bredouille
|
||||
```
|
||||
|
||||
When called with `player_idx = 1` (Black, i.e. `PlayerId` 2), the entire state is mirrored first (`game_state.mirror().to_vec()`).
|
||||
|
||||
### State ID (base64 string for hashing)
|
||||
|
||||
108 bits packed as 18 base64 characters:
|
||||
- 77 bits: GNUbg-inspired board position encoding (run-length with separators)
|
||||
- 1 bit: active player color
|
||||
- 3 bits: turn stage
|
||||
- 6 bits: dice (3 bits per die)
|
||||
- 10 bits: white player (4 pts + 4 holes + 2 flags)
|
||||
- 10 bits: black player
|
||||
- Padded to 108 bits, grouped as 18 × 6-bit base64 chunks
|
||||
|
||||
---
|
||||
|
||||
## 12. OpenSpiel Integration (trictrac.py)
|
||||
|
||||
### Game registration
|
||||
|
||||
```python
|
||||
pyspiel.register_game(_GAME_TYPE, TrictracGame)
|
||||
```
|
||||
|
||||
Key parameters:
|
||||
- `short_name = "python_trictrac"`
|
||||
- `dynamics = SEQUENTIAL`
|
||||
- `chance_mode = EXPLICIT_STOCHASTIC`
|
||||
- `information = PERFECT_INFORMATION`
|
||||
- `utility = GENERAL_SUM` (both players can score positive; no zero-sum constraint)
|
||||
- `reward_model = REWARDS` (intermediate rewards, not just terminal)
|
||||
- `num_distinct_actions = 514`
|
||||
- `max_chance_outcomes = 36`
|
||||
- `min_utility = 0.0`, `max_utility = 200.0`
|
||||
- `max_game_length = 3000` (rough estimate)
|
||||
|
||||
### Chance node handling
|
||||
|
||||
When `needs_roll()` is true, the state is a chance node. OpenSpiel samples one of 36 outcomes (uniform):
|
||||
|
||||
```python
|
||||
def _roll_from_chance_idx(self, action):
|
||||
return [(i,j) for i in range(1,7) for j in range(1,7)][action]
|
||||
|
||||
def chance_outcomes(self):
|
||||
p = 1.0 / 36
|
||||
return [(i, p) for i in range(0, 36)]
|
||||
```
|
||||
|
||||
Action 0 → (1,1), action 1 → (1,2), …, action 35 → (6,6). The chance action is then passed to `apply_dice_roll((d1, d2))` on the Rust side.
|
||||
|
||||
### Player action handling
|
||||
|
||||
When not a chance node:
|
||||
```python
|
||||
def _legal_actions(self, player):
|
||||
return self._store.get_legal_actions(player)
|
||||
|
||||
def _apply_action(self, action):
|
||||
self._store.apply_action(action)
|
||||
```
|
||||
|
||||
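The `Roll` action (index 0) is not needed for the very first roll, because pyengine starts the game in `RollWaiting` and the chance node takes over directly. It is, however, returned by `get_legal_actions` whenever the state is in `TurnStage::RollDice`, which is the case after every `Move` (see section 14): the player applies `Roll`, the stage becomes `RollWaiting`, and the next dice roll is again a chance node.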
|
||||
|
||||
### Returns
|
||||
|
||||
```python
|
||||
def returns(self):
|
||||
return self._store.get_players_scores()
|
||||
# → [holes_p1 * 12 + points_p1, holes_p2 * 12 + points_p2]
|
||||
```
|
||||
|
||||
These are cumulative scores available at any point during the game (not just terminal), consistent with `reward_model = REWARDS`.
|
||||
|
||||
---
|
||||
|
||||
## 13. Known Issues and Inconsistencies
|
||||
|
||||
### 13.1 `observation_string` missing return (trictrac.py:156)
|
||||
|
||||
```python
|
||||
def observation_string(self, player):
|
||||
self._store.get_observation_string(player) # result discarded, returns None
|
||||
```
|
||||
|
||||
Should be `return self._store.get_observation_string(player)`.
|
||||
|
||||
### 13.2 `observation_tensor` not populating buffer (trictrac.py:159)
|
||||
|
||||
```python
|
||||
def observation_tensor(self, player, values):
|
||||
self._store.get_tensor(player) # result discarded, values not filled
|
||||
```
|
||||
|
||||
OpenSpiel's API expects `values` (a mutable buffer, typically a flat numpy array) to be filled in-place. The returned `Vec<i8>` from `get_tensor()` is discarded. Should copy data into `values`.
|
||||
|
||||
### 13.3 Debug print statement active (trictrac.py:140)
|
||||
|
||||
```python
|
||||
print("in apply action", self.is_chance_node(), action)
|
||||
```
|
||||
|
||||
This fires on every action application. Should be removed or guarded.
|
||||
|
||||
### 13.4 Color swap on new_pick_up disabled
|
||||
|
||||
In `game.rs:new_pick_up()`:
|
||||
|
||||
```rust
|
||||
// XXX : switch colors
|
||||
// désactivé pour le moment car la vérification des mouvements échoue,
|
||||
// cf. https://code.rhumbs.fr/henri/trictrac/issues/31
|
||||
// p.color = p.color.opponent_color();
|
||||
```
|
||||
|
||||
In authentic Trictrac, players swap colors between "relevés" (pick-ups after a hole is won with Go). This is commented out, so the same player always plays White and the same always plays Black throughout the entire game.
|
||||
|
||||
### 13.5 `can_big_bredouille` tracked but not implemented
|
||||
|
||||
The `can_big_bredouille` flag is stored in `Player` and serialized in state encoding, but the scoring logic never reads it. Grande bredouille (a rare extra bonus) is not implemented.
|
||||
|
||||
### 13.6 `Roll` action is a forced single-choice step in OpenSpiel mode
|
||||
|
||||
`TrictracAction::Roll` (index 0) exists in the 514-action space and in `get_valid_actions()` (for `TurnStage::RollDice`). In pyengine, the game starts at `RollWaiting` (dice have been requested but not yet rolled), so the first roll is handled directly by the chance node mechanism. After every `Move`, however, the turn stage is `RollDice`, and the new active player's only legal action is `Roll`, which leads to `RollWaiting` and the next chance node (see section 14). Action slot 0 is therefore not wasted, but it is a decision point with a single option, which adds one trivial step per turn from OpenSpiel's point of view.
|
||||
|
||||
### 13.7 `get_valid_actions` panics on `RollWaiting`
|
||||
|
||||
```rust
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
||||
panic!("get_valid_actions not implemented for turn stage {:?}", ...)
|
||||
}
|
||||
```
|
||||
|
||||
If `get_legal_actions` were ever called while `needs_roll()` is true, this would panic. OpenSpiel's turn logic avoids this because chance nodes are handled separately, but it is a latent danger.
|
||||
|
||||
### 13.8 PPO training script uses wrong model name
|
||||
|
||||
`trictrac_ppo.py` saves to `ppo_backgammon_model.ckpt` — clearly copied from a backgammon example without renaming. Also uses `tensorflow.compat.v1` despite the PyTorch PPO import.
|
||||
|
||||
### 13.9 Opponent points marked at pre-move board state
|
||||
|
||||
The opponent's `dice_points.1` is computed at roll time (before the active player moves), but applied to the opponent after the move. This means the opponent's scoring is evaluated on the board position that existed before the active player moved — which is per the rules of Trictrac (points are based on where pieces could be hit at the moment of the roll), but it's worth noting this subtlety.
|
||||
|
||||
---
|
||||
|
||||
## 14. Data Flow: A Complete Turn
|
||||
|
||||
```
|
||||
Python (OpenSpiel) → Rust (trictrac_store)
|
||||
─────────────────────────────────────────────────────
|
||||
is_chance_node() ← needs_roll() [TurnStage == RollWaiting]
|
||||
(true at game start)
|
||||
|
||||
chance_outcomes() → [(0,p)..(35,p)]
|
||||
|
||||
_apply_action(chance_idx)
|
||||
_roll_from_chance_idx(idx) → (d1, d2)
|
||||
apply_dice_roll((d1, d2)) → consume(RollResult{dice})
|
||||
→ auto-mark active player's points
|
||||
→ if hole: TurnStage=HoldOrGoChoice
|
||||
→ else: TurnStage=Move
|
||||
|
||||
current_player() → 0 or 1
|
||||
|
||||
_legal_actions(player) ← get_legal_actions(player_idx)
|
||||
→ get_valid_actions on (possibly mirrored) state
|
||||
→ Vec<usize> of valid action indices
|
||||
|
||||
_apply_action(action_idx) → apply_action(action_idx)
|
||||
→ TrictracAction::from_action_index
|
||||
→ to_event on (mirrored) state
|
||||
→ mirror event back if player==2
|
||||
→ validate → consume
|
||||
→ mark opponent points
|
||||
→ switch active player
|
||||
→ TurnStage=RollDice (→ pyengine starts next turn)
|
||||
|
||||
Note: needs_roll() checks RollWaiting, not RollDice.
|
||||
The next is_chance_node() call is false until the Roll action is applied.
|
||||
```
|
||||
|
||||
Note on turn transition: After a `Move` event in `game.rs`, turn stage becomes `RollDice` (not `RollWaiting`). The pyengine `needs_roll()` checks for `RollWaiting`. So after a move, `is_chance_node()` returns false — OpenSpiel will ask for a regular player action. But `get_valid_actions` at `TurnStage::RollDice` returns only `Roll` (index 0), which is **not** the chance path.
|
||||
|
||||
This reveals a subtlety: after the Move event, the active player has already been switched, so `current_player()` returns the new active player, and `get_legal_actions` returns `[0]` (Roll). OpenSpiel then applies action 0, which calls `apply_action(0)` → `TrictracAction::Roll` → `GameEvent::Roll` → TurnStage becomes `RollWaiting`. Then the next call to `is_chance_node()` returns true, and the chance mechanism kicks in again.
|
||||
|
||||
So the full sequence in OpenSpiel terms is:
|
||||
```
|
||||
[Chance] dice roll → [Player] move → [Player] Roll action → [Chance] dice roll → ...
|
||||
```
|
||||
|
||||
The `Roll` action IS used — it is the bridge between Move completion and the next chance node.
|
||||
|
||||
---
|
||||
|
||||
## 15. Summary of Design Choices
|
||||
|
||||
| Choice | Rationale |
|
||||
|---|---|
|
||||
| All rules engine in Rust | Performance, correctness, can be used in other contexts (CLI, native bots) |
|
||||
| Mirror pattern for Black | Avoids duplicating all rule logic for both colors |
|
||||
| Schools disabled by default | Simpler turn structure for RL training; full protocol for human play |
|
||||
| GENERAL_SUM + REWARDS | Trictrac is not strictly zero-sum; intermediate hole rewards are informative for training |
|
||||
| Action index for checkers (not fields) | Reduces action space; ordinal checker numbering is compact |
|
||||
| 514 action slots | 1 Roll + 1 Go + 256 × 2 move combinations (ordered by die priority × 16 × 16 checker pairs) |
|
||||
| Chance node = dice roll | Standard OpenSpiel pattern for stochastic games |
|
||||
775
doc/refs/claudeAIquestion.md
Normal file
775
doc/refs/claudeAIquestion.md
Normal file
|
|
@ -0,0 +1,775 @@
|
|||
# Description
|
||||
|
||||
Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust.
|
||||
Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions.
|
||||
|
||||
Les règles du jeu et l'état d'une partie sont implémentées dans 'store', l'application ligne de commande est implémentée dans 'client_cli', elle permet déjà de jouer contre un bot, ou de faire jouer deux bots l'un contre l'autre.
|
||||
Les stratégies de bots sont implémentées dans le dossier 'bot'.
|
||||
|
||||
Plus précisément, l'état du jeu est défini par le struct GameState dans store/src/game.rs ; la méthode to_string_id() permet de coder cet état de manière compacte dans une chaîne de caractères, mais sans l'historique des coups joués. Le trait fmt::Display est aussi implémenté pour une représentation textuelle plus lisible.
|
||||
|
||||
'client_cli/src/game_runner.rs' contient la logique permettant de faire jouer deux bots l'un contre l'autre.
|
||||
'bot/src/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste.
|
||||
|
||||
J'aimerais maintenant ajouter une stratégie de bot plus forte en entrainant un agent/bot avec OpenAI Gym.
|
||||
|
||||
Pour bénéficier des bibliothèques disponibles en python, comme 'Stable-Baselines3', j'ai utilisé PyO3 pour appeler mon code rust depuis python. Le code rust exposé se trouve dans 'store/src/engine.rs', il ne contient pour le moment que du code d'exemple pour m'assurer que la passerelle entre python et rust fonctionne, mais tout est à implémenter : l'état du jeu, les mouvements disponibles, etc.
|
||||
|
||||
La classe python 'store/python/trictracEnv.py' est elle aussi presque vide, elle ne contient pour le moment qu'un squelette avec des fonctions à implémenter.
|
||||
|
||||
Peux-tu m'aider à modifier engine.rs et trictracEnv.py pour permettre d'entraîner un agent (bot) ?
|
||||
|
||||
J'aimerais utiliser ce que j'ai déjà implémenté en rust : l'état d'une partie et les règles, et la fonction renvoyant la liste des coups disponibles pour un joueur à partir d'un état du jeu (get_possible_moves_sequences), mais je ne sais pas trop comment m'y prendre.
|
||||
|
||||
## Réponse de Claude
|
||||
|
||||
Je vois que je n'ai pas l'autorisation d'utiliser l'outil Edit. Je vais vous montrer les modifications nécessaires pour les deux fichiers principaux.
|
||||
|
||||
Voici le contenu à utiliser pour `store/src/engine.rs` :
|
||||
|
||||
```rust
|
||||
//! # Expose trictrac game state and rules in a python module
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyDict, PyTuple};
|
||||
|
||||
use crate::board::CheckerMove;
|
||||
use crate::dice::Dice;
|
||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||
use crate::game_rules_moves::MoveRules;
|
||||
use crate::game_rules_points::{PointsRules, PossibleJans};
|
||||
use crate::player::{Color, Player, PlayerId};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[pyclass]
|
||||
struct TricTrac {
|
||||
game_state: GameState,
|
||||
dice_roll_sequence: Vec<(u8, u8)>,
|
||||
current_dice_index: usize,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl TricTrac {
|
||||
#[new]
|
||||
fn new() -> Self {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
|
||||
// Initialiser 2 joueurs
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("bot");
|
||||
|
||||
// Commencer la partie avec le joueur 1
|
||||
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
TricTrac {
|
||||
game_state,
|
||||
dice_roll_sequence: Vec::new(),
|
||||
current_dice_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtenir l'état du jeu sous forme de chaîne de caractères compacte
|
||||
fn get_state_id(&self) -> String {
|
||||
self.game_state.to_string_id()
|
||||
}
|
||||
|
||||
/// Obtenir l'état du jeu sous forme de dictionnaire pour faciliter l'entrainement
|
||||
fn get_state_dict(&self) -> PyResult<Py<PyDict>> {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let state_dict = PyDict::new(py);
|
||||
|
||||
// Informations essentielles sur l'état du jeu
|
||||
state_dict.set_item("active_player", self.game_state.active_player_id)?;
|
||||
state_dict.set_item("stage", format!("{:?}", self.game_state.stage))?;
|
||||
state_dict.set_item("turn_stage", format!("{:?}", self.game_state.turn_stage))?;
|
||||
|
||||
// Dés
|
||||
let (dice1, dice2) = self.game_state.dice.values;
|
||||
state_dict.set_item("dice", (dice1, dice2))?;
|
||||
|
||||
// Points des joueurs
|
||||
if let Some(white_player) = self.game_state.get_white_player() {
|
||||
state_dict.set_item("white_points", white_player.points)?;
|
||||
state_dict.set_item("white_holes", white_player.holes)?;
|
||||
}
|
||||
|
||||
if let Some(black_player) = self.game_state.get_black_player() {
|
||||
state_dict.set_item("black_points", black_player.points)?;
|
||||
state_dict.set_item("black_holes", black_player.holes)?;
|
||||
}
|
||||
|
||||
// Positions des pièces
|
||||
let white_positions = self.get_checker_positions(Color::White);
|
||||
let black_positions = self.get_checker_positions(Color::Black);
|
||||
|
||||
state_dict.set_item("white_positions", white_positions)?;
|
||||
state_dict.set_item("black_positions", black_positions)?;
|
||||
|
||||
// État compact pour la comparaison d'états
|
||||
state_dict.set_item("state_id", self.game_state.to_string_id())?;
|
||||
|
||||
Ok(state_dict.into())
|
||||
}
|
||||
|
||||
/// Renvoie les positions des pièces pour un joueur spécifique
|
||||
fn get_checker_positions(&self, color: Color) -> Vec<(usize, i8)> {
|
||||
self.game_state.board.get_color_fields(color)
|
||||
}
|
||||
|
||||
/// Obtenir la liste des mouvements légaux sous forme de paires (from, to)
|
||||
fn get_available_moves(&self) -> Vec<((usize, usize), (usize, usize))> {
|
||||
// L'agent joue toujours le joueur actif
|
||||
let color = self.game_state.player_color_by_id(&self.game_state.active_player_id).unwrap_or(Color::White);
|
||||
|
||||
// Si ce n'est pas le moment de déplacer les pièces, retourner une liste vide
|
||||
if self.game_state.turn_stage != TurnStage::Move && self.game_state.turn_stage != TurnStage::HoldOrGoChoice {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let rules = MoveRules::new(&color, &self.game_state.board, self.game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
// Convertir les mouvements CheckerMove en tuples (from, to) pour Python
|
||||
possible_moves.into_iter()
|
||||
.map(|(move1, move2)| (
|
||||
(move1.get_from(), move1.get_to()),
|
||||
(move2.get_from(), move2.get_to())
|
||||
)).collect()
|
||||
}
|
||||
|
||||
/// Jouer un coup ((from1, to1), (from2, to2))
|
||||
fn play_move(&mut self, moves: ((usize, usize), (usize, usize))) -> bool {
|
||||
let ((from1, to1), (from2, to2)) = moves;
|
||||
|
||||
// Vérifier que c'est au tour du joueur de jouer
|
||||
if self.game_state.turn_stage != TurnStage::Move && self.game_state.turn_stage != TurnStage::HoldOrGoChoice {
|
||||
return false;
|
||||
}
|
||||
|
||||
let move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
let event = GameEvent::Move {
|
||||
player_id: self.game_state.active_player_id,
|
||||
moves: (move1, move2),
|
||||
};
|
||||
|
||||
// Vérifier si le mouvement est valide
|
||||
if !self.game_state.validate(&event) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Exécuter le mouvement
|
||||
self.game_state.consume(&event);
|
||||
|
||||
// Si l'autre joueur doit lancer les dés maintenant, simuler ce lancement
|
||||
if self.game_state.turn_stage == TurnStage::RollDice {
|
||||
self.roll_dice();
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Lancer les dés (soit aléatoirement, soit en utilisant une séquence prédéfinie)
|
||||
fn roll_dice(&mut self) -> (u8, u8) {
|
||||
// Vérifier que c'est au bon moment pour lancer les dés
|
||||
if self.game_state.turn_stage != TurnStage::RollDice && self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||
return self.game_state.dice.values;
|
||||
}
|
||||
|
||||
// Simuler un lancer de dés
|
||||
let dice_values = if !self.dice_roll_sequence.is_empty() && self.current_dice_index < self.dice_roll_sequence.len() {
|
||||
// Utiliser la séquence prédéfinie
|
||||
let dice = self.dice_roll_sequence[self.current_dice_index];
|
||||
self.current_dice_index += 1;
|
||||
dice
|
||||
} else {
|
||||
// Générer aléatoirement
|
||||
(
|
||||
(1 + (rand::random::<u8>() % 6)),
|
||||
(1 + (rand::random::<u8>() % 6))
|
||||
)
|
||||
};
|
||||
|
||||
// Envoyer les événements appropriés
|
||||
let roll_event = GameEvent::Roll {
|
||||
player_id: self.game_state.active_player_id,
|
||||
};
|
||||
|
||||
if self.game_state.validate(&roll_event) {
|
||||
self.game_state.consume(&roll_event);
|
||||
}
|
||||
|
||||
let roll_result_event = GameEvent::RollResult {
|
||||
player_id: self.game_state.active_player_id,
|
||||
dice: Dice { values: dice_values },
|
||||
};
|
||||
|
||||
if self.game_state.validate(&roll_result_event) {
|
||||
self.game_state.consume(&roll_result_event);
|
||||
}
|
||||
|
||||
dice_values
|
||||
}
|
||||
|
||||
/// Marquer des points
|
||||
fn mark_points(&mut self, points: u8) -> bool {
|
||||
// Vérifier que c'est au bon moment pour marquer des points
|
||||
if self.game_state.turn_stage != TurnStage::MarkPoints && self.game_state.turn_stage != TurnStage::MarkAdvPoints {
|
||||
return false;
|
||||
}
|
||||
|
||||
let event = GameEvent::Mark {
|
||||
player_id: self.game_state.active_player_id,
|
||||
points,
|
||||
};
|
||||
|
||||
// Vérifier si l'événement est valide
|
||||
if !self.game_state.validate(&event) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Exécuter l'événement
|
||||
self.game_state.consume(&event);
|
||||
|
||||
// Si l'autre joueur doit lancer les dés maintenant, simuler ce lancement
|
||||
if self.game_state.turn_stage == TurnStage::RollDice {
|
||||
self.roll_dice();
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Choisir de "continuer" (Go) après avoir gagné un trou
|
||||
fn choose_go(&mut self) -> bool {
|
||||
// Vérifier que c'est au bon moment pour choisir de continuer
|
||||
if self.game_state.turn_stage != TurnStage::HoldOrGoChoice {
|
||||
return false;
|
||||
}
|
||||
|
||||
let event = GameEvent::Go {
|
||||
player_id: self.game_state.active_player_id,
|
||||
};
|
||||
|
||||
// Vérifier si l'événement est valide
|
||||
if !self.game_state.validate(&event) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Exécuter l'événement
|
||||
self.game_state.consume(&event);
|
||||
|
||||
// Simuler le lancer de dés pour le prochain tour
|
||||
self.roll_dice();
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Calcule les points maximaux que le joueur actif peut obtenir avec les dés actuels
|
||||
fn calculate_points(&self) -> u8 {
|
||||
let active_player = self.game_state.players.get(&self.game_state.active_player_id);
|
||||
|
||||
if let Some(player) = active_player {
|
||||
let dice_roll_count = player.dice_roll_count;
|
||||
let color = player.color;
|
||||
|
||||
let points_rules = PointsRules::new(&color, &self.game_state.board, self.game_state.dice);
|
||||
let (points, _) = points_rules.get_points(dice_roll_count);
|
||||
|
||||
points
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Réinitialise la partie
|
||||
fn reset(&mut self) {
|
||||
self.game_state = GameState::new(false);
|
||||
|
||||
// Initialiser 2 joueurs
|
||||
self.game_state.init_player("player1");
|
||||
self.game_state.init_player("bot");
|
||||
|
||||
// Commencer la partie avec le joueur 1
|
||||
self.game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
// Réinitialiser l'index de la séquence de dés
|
||||
self.current_dice_index = 0;
|
||||
}
|
||||
|
||||
/// Vérifie si la partie est terminée
|
||||
fn is_done(&self) -> bool {
|
||||
self.game_state.stage == Stage::Ended || self.game_state.determine_winner().is_some()
|
||||
}
|
||||
|
||||
/// Obtenir le gagnant de la partie
|
||||
fn get_winner(&self) -> Option<PlayerId> {
|
||||
self.game_state.determine_winner()
|
||||
}
|
||||
|
||||
/// Obtenir le score du joueur actif (nombre de trous)
|
||||
fn get_score(&self, player_id: PlayerId) -> i32 {
|
||||
if let Some(player) = self.game_state.players.get(&player_id) {
|
||||
player.holes as i32
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtenir l'ID du joueur actif
|
||||
fn get_active_player_id(&self) -> PlayerId {
|
||||
self.game_state.active_player_id
|
||||
}
|
||||
|
||||
/// Définir une séquence de dés à utiliser (pour la reproductibilité)
|
||||
fn set_dice_sequence(&mut self, sequence: Vec<(u8, u8)>) {
|
||||
self.dice_roll_sequence = sequence;
|
||||
self.current_dice_index = 0;
|
||||
}
|
||||
|
||||
/// Afficher l'état du jeu (pour le débogage)
|
||||
fn __str__(&self) -> String {
|
||||
format!("{}", self.game_state)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
#[pymodule]
|
||||
fn trictrac(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<TricTrac>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Et pour `store/python/trictracEnv.py` :
|
||||
|
||||
```python
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
import trictrac # module Rust exposé via PyO3
|
||||
from typing import Dict, List, Tuple, Optional, Any, Union
|
||||
|
||||
class TricTracEnv(gym.Env):
|
||||
"""Environnement OpenAI Gym pour le jeu de Trictrac"""
|
||||
|
||||
metadata = {"render.modes": ["human"]}
|
||||
|
||||
def __init__(self, opponent_strategy="random"):
|
||||
super(TricTracEnv, self).__init__()
|
||||
|
||||
# Instancier le jeu
|
||||
self.game = trictrac.TricTrac()
|
||||
|
||||
# Stratégie de l'adversaire
|
||||
self.opponent_strategy = opponent_strategy
|
||||
|
||||
# Constantes
|
||||
self.MAX_FIELD = 24 # Nombre de cases sur le plateau
|
||||
self.MAX_CHECKERS = 15 # Nombre maximum de pièces par joueur
|
||||
|
||||
# Définition de l'espace d'observation
|
||||
# Format:
|
||||
# - Position des pièces blanches (24)
|
||||
# - Position des pièces noires (24)
|
||||
# - Joueur actif (1: blanc, 2: noir) (1)
|
||||
# - Valeurs des dés (2)
|
||||
# - Points de chaque joueur (2)
|
||||
# - Trous de chaque joueur (2)
|
||||
# - Phase du jeu (1)
|
||||
self.observation_space = spaces.Dict({
|
||||
'board': spaces.Box(low=-self.MAX_CHECKERS, high=self.MAX_CHECKERS, shape=(self.MAX_FIELD,), dtype=np.int8),
|
||||
'active_player': spaces.Discrete(3), # 0: pas de joueur, 1: blanc, 2: noir
|
||||
'dice': spaces.MultiDiscrete([7, 7]), # Valeurs des dés (1-6)
|
||||
'white_points': spaces.Discrete(13), # Points du joueur blanc (0-12)
|
||||
'white_holes': spaces.Discrete(13), # Trous du joueur blanc (0-12)
|
||||
'black_points': spaces.Discrete(13), # Points du joueur noir (0-12)
|
||||
'black_holes': spaces.Discrete(13), # Trous du joueur noir (0-12)
|
||||
'turn_stage': spaces.Discrete(6), # Étape du tour
|
||||
})
|
||||
|
||||
# Définition de l'espace d'action
|
||||
# Format:
|
||||
# - Action type: 0=move, 1=mark, 2=go
|
||||
# - Move: (from1, to1, from2, to2) ou zeros
|
||||
self.action_space = spaces.Dict({
|
||||
'action_type': spaces.Discrete(3),
|
||||
'move': spaces.MultiDiscrete([self.MAX_FIELD + 1, self.MAX_FIELD + 1,
|
||||
self.MAX_FIELD + 1, self.MAX_FIELD + 1])
|
||||
})
|
||||
|
||||
# État courant
|
||||
self.state = self._get_observation()
|
||||
|
||||
# Historique des états pour éviter les situations sans issue
|
||||
self.state_history = []
|
||||
|
||||
# Pour le débogage et l'entraînement
|
||||
self.steps_taken = 0
|
||||
self.max_steps = 1000 # Limite pour éviter les parties infinies
|
||||
|
||||
def reset(self):
|
||||
"""Réinitialise l'environnement et renvoie l'état initial"""
|
||||
self.game.reset()
|
||||
self.state = self._get_observation()
|
||||
self.state_history = []
|
||||
self.steps_taken = 0
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Exécute une action et retourne (state, reward, done, info)
|
||||
|
||||
Action format:
|
||||
{
|
||||
'action_type': 0/1/2, # 0=move, 1=mark, 2=go
|
||||
'move': [from1, to1, from2, to2] # Utilisé seulement si action_type=0
|
||||
}
|
||||
"""
|
||||
action_type = action['action_type']
|
||||
reward = 0
|
||||
done = False
|
||||
info = {}
|
||||
|
||||
# Vérifie que l'action est valide pour le joueur humain (id=1)
|
||||
player_id = self.game.get_active_player_id()
|
||||
is_agent_turn = player_id == 1 # L'agent joue toujours le joueur 1
|
||||
|
||||
if is_agent_turn:
|
||||
# Exécute l'action selon son type
|
||||
if action_type == 0: # Move
|
||||
from1, to1, from2, to2 = action['move']
|
||||
move_made = self.game.play_move(((from1, to1), (from2, to2)))
|
||||
if not move_made:
|
||||
# Pénaliser les mouvements invalides
|
||||
reward -= 2.0
|
||||
info['invalid_move'] = True
|
||||
else:
|
||||
# Petit bonus pour un mouvement valide
|
||||
reward += 0.1
|
||||
elif action_type == 1: # Mark
|
||||
points = self.game.calculate_points()
|
||||
marked = self.game.mark_points(points)
|
||||
if not marked:
|
||||
# Pénaliser les actions invalides
|
||||
reward -= 2.0
|
||||
info['invalid_mark'] = True
|
||||
else:
|
||||
# Bonus pour avoir marqué des points
|
||||
reward += 0.1 * points
|
||||
elif action_type == 2: # Go
|
||||
go_made = self.game.choose_go()
|
||||
if not go_made:
|
||||
# Pénaliser les actions invalides
|
||||
reward -= 2.0
|
||||
info['invalid_go'] = True
|
||||
else:
|
||||
# Petit bonus pour l'action valide
|
||||
reward += 0.1
|
||||
else:
|
||||
# Tour de l'adversaire
|
||||
self._play_opponent_turn()
|
||||
|
||||
# Vérifier si la partie est terminée
|
||||
if self.game.is_done():
|
||||
done = True
|
||||
winner = self.game.get_winner()
|
||||
if winner == 1:
|
||||
# Bonus si l'agent gagne
|
||||
reward += 10.0
|
||||
info['winner'] = 'agent'
|
||||
else:
|
||||
# Pénalité si l'adversaire gagne
|
||||
reward -= 5.0
|
||||
info['winner'] = 'opponent'
|
||||
|
||||
# Récompense basée sur la progression des trous
|
||||
agent_holes = self.game.get_score(1)
|
||||
opponent_holes = self.game.get_score(2)
|
||||
reward += 0.5 * (agent_holes - opponent_holes)
|
||||
|
||||
# Mettre à jour l'état
|
||||
new_state = self._get_observation()
|
||||
|
||||
# Vérifier les états répétés
|
||||
if self._is_state_repeating(new_state):
|
||||
reward -= 0.2 # Pénalité légère pour éviter les boucles
|
||||
info['repeating_state'] = True
|
||||
|
||||
# Ajouter l'état à l'historique
|
||||
self.state_history.append(self._get_state_id())
|
||||
|
||||
# Limiter la durée des parties
|
||||
self.steps_taken += 1
|
||||
if self.steps_taken >= self.max_steps:
|
||||
done = True
|
||||
info['timeout'] = True
|
||||
|
||||
# Comparer les scores en cas de timeout
|
||||
if agent_holes > opponent_holes:
|
||||
reward += 5.0
|
||||
info['winner'] = 'agent'
|
||||
elif opponent_holes > agent_holes:
|
||||
reward -= 2.0
|
||||
info['winner'] = 'opponent'
|
||||
|
||||
self.state = new_state
|
||||
return self.state, reward, done, info
|
||||
|
||||
def _play_opponent_turn(self):
|
||||
"""Simule le tour de l'adversaire avec la stratégie choisie"""
|
||||
player_id = self.game.get_active_player_id()
|
||||
|
||||
# Boucle tant qu'il est au tour de l'adversaire
|
||||
while player_id == 2 and not self.game.is_done():
|
||||
# Action selon l'étape du tour
|
||||
state_dict = self._get_state_dict()
|
||||
turn_stage = state_dict.get('turn_stage')
|
||||
|
||||
if turn_stage == 'RollDice' or turn_stage == 'RollWaiting':
|
||||
self.game.roll_dice()
|
||||
elif turn_stage == 'MarkPoints' or turn_stage == 'MarkAdvPoints':
|
||||
points = self.game.calculate_points()
|
||||
self.game.mark_points(points)
|
||||
elif turn_stage == 'HoldOrGoChoice':
|
||||
# Stratégie simple: toujours continuer (Go)
|
||||
self.game.choose_go()
|
||||
elif turn_stage == 'Move':
|
||||
available_moves = self.game.get_available_moves()
|
||||
if available_moves:
|
||||
if self.opponent_strategy == "random":
|
||||
# Choisir un mouvement au hasard
|
||||
move = available_moves[np.random.randint(0, len(available_moves))]
|
||||
else:
|
||||
# Par défaut, prendre le premier mouvement valide
|
||||
move = available_moves[0]
|
||||
self.game.play_move(move)
|
||||
|
||||
# Mise à jour de l'ID du joueur actif
|
||||
player_id = self.game.get_active_player_id()
|
||||
|
||||
def _get_observation(self):
|
||||
"""Convertit l'état du jeu en un format utilisable par l'apprentissage par renforcement"""
|
||||
state_dict = self._get_state_dict()
|
||||
|
||||
# Créer un tableau représentant le plateau
|
||||
board = np.zeros(self.MAX_FIELD, dtype=np.int8)
|
||||
|
||||
# Remplir les positions des pièces blanches (valeurs positives)
|
||||
white_positions = state_dict.get('white_positions', [])
|
||||
for pos, count in white_positions:
|
||||
if 1 <= pos <= self.MAX_FIELD:
|
||||
board[pos-1] = count
|
||||
|
||||
# Remplir les positions des pièces noires (valeurs négatives)
|
||||
black_positions = state_dict.get('black_positions', [])
|
||||
for pos, count in black_positions:
|
||||
if 1 <= pos <= self.MAX_FIELD:
|
||||
board[pos-1] = -count
|
||||
|
||||
# Créer l'observation complète
|
||||
observation = {
|
||||
'board': board,
|
||||
'active_player': state_dict.get('active_player', 0),
|
||||
'dice': np.array([
|
||||
state_dict.get('dice', (1, 1))[0],
|
||||
state_dict.get('dice', (1, 1))[1]
|
||||
]),
|
||||
'white_points': state_dict.get('white_points', 0),
|
||||
'white_holes': state_dict.get('white_holes', 0),
|
||||
'black_points': state_dict.get('black_points', 0),
|
||||
'black_holes': state_dict.get('black_holes', 0),
|
||||
'turn_stage': self._turn_stage_to_int(state_dict.get('turn_stage', 'RollDice')),
|
||||
}
|
||||
|
||||
return observation
|
||||
|
||||
def _get_state_dict(self) -> Dict:
|
||||
"""Récupère l'état du jeu sous forme de dictionnaire depuis le module Rust"""
|
||||
return self.game.get_state_dict()
|
||||
|
||||
def _get_state_id(self) -> str:
|
||||
"""Récupère l'identifiant unique de l'état actuel"""
|
||||
return self.game.get_state_id()
|
||||
|
||||
def _is_state_repeating(self, new_state) -> bool:
|
||||
"""Vérifie si l'état se répète trop souvent"""
|
||||
state_id = self.game.get_state_id()
|
||||
# Compter les occurrences de l'état dans l'historique récent
|
||||
count = sum(1 for s in self.state_history[-10:] if s == state_id)
|
||||
return count >= 3 # Considéré comme répétitif si l'état apparaît 3 fois ou plus
|
||||
|
||||
def _turn_stage_to_int(self, turn_stage: str) -> int:
|
||||
"""Convertit l'étape du tour en entier pour l'observation"""
|
||||
stages = {
|
||||
'RollDice': 0,
|
||||
'RollWaiting': 1,
|
||||
'MarkPoints': 2,
|
||||
'HoldOrGoChoice': 3,
|
||||
'Move': 4,
|
||||
'MarkAdvPoints': 5
|
||||
}
|
||||
return stages.get(turn_stage, 0)
|
||||
|
||||
def render(self, mode="human"):
|
||||
"""Affiche l'état actuel du jeu"""
|
||||
if mode == "human":
|
||||
print(str(self.game))
|
||||
print(f"État actuel: {self._get_state_id()}")
|
||||
|
||||
# Afficher les actions possibles
|
||||
if self.game.get_active_player_id() == 1:
|
||||
turn_stage = self._get_state_dict().get('turn_stage')
|
||||
print(f"Étape: {turn_stage}")
|
||||
|
||||
if turn_stage == 'Move' or turn_stage == 'HoldOrGoChoice':
|
||||
print("Mouvements possibles:")
|
||||
moves = self.game.get_available_moves()
|
||||
for i, move in enumerate(moves):
|
||||
print(f" {i}: {move}")
|
||||
|
||||
if turn_stage == 'HoldOrGoChoice':
|
||||
print("Option: Go (continuer)")
|
||||
|
||||
def get_action_mask(self):
|
||||
"""Retourne un masque des actions valides dans l'état actuel"""
|
||||
state_dict = self._get_state_dict()
|
||||
turn_stage = state_dict.get('turn_stage')
|
||||
|
||||
# Masque par défaut (toutes les actions sont invalides)
|
||||
mask = {
|
||||
'action_type': np.zeros(3, dtype=bool),
|
||||
'move': np.zeros((self.MAX_FIELD + 1, self.MAX_FIELD + 1,
|
||||
self.MAX_FIELD + 1, self.MAX_FIELD + 1), dtype=bool)
|
||||
}
|
||||
|
||||
if self.game.get_active_player_id() != 1:
|
||||
return mask # Pas au tour de l'agent
|
||||
|
||||
# Activer les types d'actions valides selon l'étape du tour
|
||||
if turn_stage == 'Move' or turn_stage == 'HoldOrGoChoice':
|
||||
mask['action_type'][0] = True # Activer l'action de mouvement
|
||||
|
||||
# Activer les mouvements valides
|
||||
valid_moves = self.game.get_available_moves()
|
||||
for ((from1, to1), (from2, to2)) in valid_moves:
|
||||
mask['move'][from1, to1, from2, to2] = True
|
||||
|
||||
if turn_stage == 'MarkPoints' or turn_stage == 'MarkAdvPoints':
|
||||
mask['action_type'][1] = True # Activer l'action de marquer des points
|
||||
|
||||
if turn_stage == 'HoldOrGoChoice':
|
||||
mask['action_type'][2] = True # Activer l'action de continuer (Go)
|
||||
|
||||
return mask
|
||||
|
||||
def sample_valid_action(self):
|
||||
"""Échantillonne une action valide selon le masque d'actions"""
|
||||
mask = self.get_action_mask()
|
||||
|
||||
# Trouver les types d'actions valides
|
||||
valid_action_types = np.where(mask['action_type'])[0]
|
||||
|
||||
if len(valid_action_types) == 0:
|
||||
# Aucune action valide (pas le tour de l'agent)
|
||||
return {
|
||||
'action_type': 0,
|
||||
'move': np.zeros(4, dtype=np.int32)
|
||||
}
|
||||
|
||||
# Choisir un type d'action
|
||||
action_type = np.random.choice(valid_action_types)
|
||||
|
||||
action = {
|
||||
'action_type': action_type,
|
||||
'move': np.zeros(4, dtype=np.int32)
|
||||
}
|
||||
|
||||
# Si c'est un mouvement, sélectionner un mouvement valide
|
||||
if action_type == 0:
|
||||
valid_moves = np.where(mask['move'])
|
||||
if len(valid_moves[0]) > 0:
|
||||
# Sélectionner un mouvement valide aléatoirement
|
||||
idx = np.random.randint(0, len(valid_moves[0]))
|
||||
from1 = valid_moves[0][idx]
|
||||
to1 = valid_moves[1][idx]
|
||||
from2 = valid_moves[2][idx]
|
||||
to2 = valid_moves[3][idx]
|
||||
action['move'] = np.array([from1, to1, from2, to2], dtype=np.int32)
|
||||
|
||||
return action
|
||||
|
||||
def close(self):
|
||||
"""Nettoie les ressources à la fermeture de l'environnement"""
|
||||
pass
|
||||
|
||||
# Exemple d'utilisation avec Stable-Baselines3
|
||||
def example_usage():
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
# Fonction d'enveloppement pour créer l'environnement
|
||||
def make_env():
|
||||
return TricTracEnv()
|
||||
|
||||
# Créer un environnement vectorisé (peut être parallélisé)
|
||||
env = DummyVecEnv([make_env])
|
||||
|
||||
# Créer le modèle
|
||||
model = PPO("MultiInputPolicy", env, verbose=1)
|
||||
|
||||
# Entraîner le modèle
|
||||
model.learn(total_timesteps=10000)
|
||||
|
||||
# Sauvegarder le modèle
|
||||
model.save("trictrac_ppo")
|
||||
|
||||
print("Entraînement terminé et modèle sauvegardé")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tester l'environnement
|
||||
env = TricTracEnv()
|
||||
obs = env.reset()
|
||||
|
||||
print("Environnement initialisé")
|
||||
env.render()
|
||||
|
||||
# Jouer quelques coups aléatoires
|
||||
for _ in range(10):
|
||||
action = env.sample_valid_action()
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
print(f"\nAction: {action}")
|
||||
print(f"Reward: {reward}")
|
||||
print(f"Info: {info}")
|
||||
env.render()
|
||||
|
||||
if done:
|
||||
print("Game over!")
|
||||
break
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
Ces implémentations vous permettront de créer un environnement Gym pour votre jeu Trictrac qui:
|
||||
|
||||
1. Expose l'état du jeu dans un format approprié pour l'apprentissage par renforcement
|
||||
2. Gère le déroulement des parties et le tour de l'adversaire
|
||||
3. Définit un système de récompenses pour l'entraînement de l'agent
|
||||
4. Fournit des masques d'actions pour garantir que seules les actions valides sont sélectionnées
|
||||
5. Inclut des exemples pour l'intégration avec Stable-Baselines3
|
||||
|
||||
Ce code nécessitera probablement quelques ajustements en fonction de la façon dont votre code Rust est organisé, mais il devrait vous donner une base solide pour commencer.
|
||||
|
||||
Pour utiliser cette implémentation :
|
||||
|
||||
1. Mettez à jour `engine.rs` pour exposer l'API nécessaire à Python
|
||||
2. Compilez la bibliothèque Rust avec maturin : `maturin build -m store/Cargo.toml --release`
|
||||
3. Implémentez `trictracEnv.py` avec l'environnement Gym
|
||||
4. Créez un script d'entraînement qui utilise Stable-Baselines3 ou une autre bibliothèque d'apprentissage par renforcement
|
||||
|
||||
Voulez-vous que j'explique plus en détail certaines parties de l'implémentation ou que je vous aide à résoudre des problèmes spécifiques ?
|
||||
268
doc/refs/claudeAIquestionOnlyRust.md
Normal file
268
doc/refs/claudeAIquestionOnlyRust.md
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
# Description
|
||||
|
||||
Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust.
|
||||
Pour le moment ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions.
|
||||
|
||||
Les règles du jeu et l'état d'une partie sont implémentées dans 'store', l'application ligne de commande est implémentée dans 'client_cli', elle permet de jouer contre un bot, ou de faire jouer deux bots l'un contre l'autre.
|
||||
Les stratégies de bots sont implémentées dans le dossier 'bot'.
|
||||
|
||||
L'état du jeu est défini par le struct GameState dans store/src/game.rs, la méthode to_string_id() permet de coder cet état de manière compacte dans une chaîne de caractères, mais il n'y a pas l'historique des coups joués. Il y a aussi fmt::Display d'implémenté pour une représentation textuelle plus lisible.
|
||||
|
||||
'client_cli/src/game_runner.rs' contient la logique permettant de faire jouer deux bots l'un contre l'autre.
|
||||
'bot/src/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste.
|
||||
|
||||
Dans une précédente session tu m'avais aidé à ajouter une stratégie de bot plus forte en entrainant un agent/bot avec OpenAI Gym en utilisant PyO3 et Stable-Baselines3. Tu m'avais fourni le code pour `store/src/engine.rs` et `store/python/trictracEnv.py`.
|
||||
|
||||
Ça fonctionne, mais je voudrais maintenant passer à une solution entièrement rust.
|
||||
Je pensais utiliser burn et l'algorithme DQN implémenté dans sb3-burn ( <https://github.com/will-maclean/sb3-burn> ). Peux-tu m'aider à faire ça ?
|
||||
|
||||
----- fin d'implémentation :
|
||||
|
||||
Parfait ! Le jeu fonctionne maintenant avec la stratégie DQN ! Le bot DQN a même gagné la partie contre le bot
|
||||
dummy.
|
||||
|
||||
Résumé
|
||||
|
||||
J'ai réussi à migrer l'implémentation d'une solution Python/PyO3/Stable-Baselines3 vers une solution entièrement
|
||||
Rust avec DQN. Voici ce qui a été accompli :
|
||||
|
||||
✅ Ce qui fonctionne maintenant :
|
||||
|
||||
1. Suppression complète de PyO3 du module store pour éviter les conflits
|
||||
2. Implémentation DQN native en Rust avec :
|
||||
|
||||
- Réseau de neurones simplifié (3 couches, ReLU)
|
||||
- Initialisation Xavier/Glorot des poids
|
||||
- Buffer de replay pour l'expérience
|
||||
- Stratégie epsilon-greedy
|
||||
- Sauvegarde/chargement de modèles en JSON
|
||||
|
||||
3. Intégration complète dans le CLI avec la commande :
|
||||
|
||||
- --bot dqn : DQN avec configuration par défaut
|
||||
- --bot dqn:/path/to/model : DQN avec modèle personnalisé
|
||||
|
||||
4. Compilation réussie sans dépendances Burn problématiques
|
||||
|
||||
🔄 Points à améliorer plus tard :
|
||||
|
||||
1. Entraînement complet : Actuellement la méthode train() simule l'entraînement. Il faudrait implémenter la
|
||||
backpropagation complète.
|
||||
2. Espace d'actions plus sophistiqué : L'espace d'actions est simplifié (3 actions). On pourrait l'étendre pour
|
||||
couvrir tous les mouvements possibles.
|
||||
3. Système de récompenses : Le système de récompenses est basique et pourrait être enrichi.
|
||||
4. Migration vers Burn : Quand les problèmes de compatibilité seront résolus, on pourra migrer vers Burn pour de
|
||||
meilleures performances.
|
||||
|
||||
La base DQN est maintenant fonctionnelle et entièrement en Rust ! 🎉
|
||||
|
||||
---
|
||||
|
||||
Ton implémentation démarre avec un bot non entraîné et il s'entraîne lorsqu'il joue, le modèle n'est pas entraîné de manière séparée ? On démarre avec un modèle non entraîné et il s'entraîne en même temps que le bot joue ?
|
||||
|
||||
Je voudrais que le fichier dqn.rs actuel ne contienne que le code nécessaire pour charger un modèle entraîné et jouer contre les adversaires (c'est à dire l'implémentation de BotStrategy). Et donc que tu extraies le code qui permet d'entraîner le modèle dans un autre fichier, et que tu ajoutes une commande pour lancer l'entraînement.
|
||||
|
||||
---
|
||||
|
||||
● Parfait ! Testons maintenant un entraînement court :
|
||||
|
||||
● Bash(cd /home/henri/travaux/programmes/trictrac && cargo run --bin=train_dqn -- --episodes 10 --save-every 5)…
|
||||
⎿ Error: Command timed out after 2m 0.0s warning: unused import: `DqnConfig`
|
||||
--> bot/src/strategy/dqn.rs:5:25
|
||||
|
|
||||
5 | use super::dqn_common::{DqnConfig, SimpleNeuralNetwork, game_state_to_vector};
|
||||
| ^^^^^^^^^
|
||||
|
|
||||
= note: `#[warn(unused_imports)]` on by default
|
||||
|
||||
warning: unused import: `store::MoveRules`
|
||||
--> bot/src/strategy/erroneous_moves.rs:2:5
|
||||
… +175 lines (ctrl+r to see all)
|
||||
|
||||
● Il y a un problème dans la logique d'entraînement - il semble y avoir une boucle infinie avec des erreurs "rules not followed". Il faut corriger la logique du TrictracEnv. Le problème est probablement que l'environnement ne gère pas
|
||||
correctement les étapes du jeu.
|
||||
|
||||
● Update Todos
|
||||
⎿ ☒ Analyser le code existant avec PyO3/Stable-Baselines3
|
||||
☒ Examiner l'environnement Gym Python actuel
|
||||
☒ Créer un environnement DQN natif en Rust
|
||||
☒ Implémenter la stratégie bot DQN en Rust
|
||||
☒ Ajouter les dépendances Burn et sb3-burn
|
||||
☒ Restructurer le code DQN
|
||||
☒ Créer un module d'entraînement séparé
|
||||
☒ Ajouter une commande d'entraînement
|
||||
|
||||
● Parfait ! J'ai réussi à restructurer complètement le code DQN selon tes demandes :
|
||||
|
||||
✅ Restructuration terminée avec succès
|
||||
|
||||
📁 Structure du code :
|
||||
|
||||
1. dqn.rs - Contient uniquement la stratégie DqnStrategy pour charger et utiliser un modèle pré-entraîné
|
||||
2. dqn_common.rs - Code partagé (réseau de neurones, configuration, utilitaires)
|
||||
3. dqn_trainer.rs - Logique d'entraînement complète avec environnement et agent
|
||||
4. train_dqn - Binaire dédié pour lancer l'entraînement
|
||||
|
||||
🚀 Commandes disponibles :
|
||||
|
||||
Pour jouer avec un modèle DQN :
|
||||
cargo run --bin=client_cli -- --bot dqn,dummy # Sans modèle (fallback)
|
||||
cargo run --bin=client_cli -- --bot dqn:/path/to/model.json,dummy # Avec modèle
|
||||
|
||||
Pour entraîner un modèle DQN :
|
||||
cargo run --bin=train_dqn # Entraînement par défaut (1000 épisodes)
|
||||
cargo run --bin=train_dqn -- --help # Voir les options
|
||||
cargo run --bin=train_dqn -- --episodes 5000 --save-every 500
|
||||
|
||||
⚠️ Note sur l'environnement d'entraînement :
|
||||
|
||||
L'entraînement a révélé un problème dans la logique de l'environnement (boucle infinie avec "rules not followed"). C'est normal pour une première implémentation - l'environnement d'entraînement nécessite des ajustements pour bien
|
||||
gérer les règles du jeu Trictrac.
|
||||
|
||||
Le code est maintenant parfaitement séparé entre :
|
||||
|
||||
- Utilisation (stratégie simple qui charge un modèle)
|
||||
- Entraînement (environnement complet avec logique d'apprentissage)
|
||||
|
||||
---
|
||||
|
||||
● ✅ Espace d'actions étendu implémenté avec succès !
|
||||
|
||||
J'ai complètement étendu l'espace d'actions du DQN pour couvrir tous les choix possibles dans le jeu Trictrac :
|
||||
|
||||
🎯 Nouvel espace d'actions :
|
||||
|
||||
TrictracAction enum comprend :
|
||||
|
||||
- Roll - Lancer les dés (index 0)
|
||||
- Mark { points } - Marquer des points (indices 1-13 pour 0-12 points)
|
||||
- Go - Continuer après avoir gagné un trou (index 14)
|
||||
- Move { move1, move2 } - Mouvement de pions (indices 15+)
|
||||
- Encode (from, to) pour les deux pions
|
||||
- Espace total : 331,791 actions possibles
|
||||
|
||||
🧠 Améliorations intelligentes :
|
||||
|
||||
1. Masquage d'actions valides - Le DQN ne considère que les actions légales dans l'état actuel
|
||||
2. Encodage compact - Mouvements encodés en un seul entier pour efficacité
|
||||
3. Décodage automatique - Conversion transparente entre indices et actions
|
||||
4. Stratégie adaptative - La stratégie DQN utilise maintenant le bon type d'action selon le contexte
|
||||
|
||||
🔧 Fonctionnalités ajoutées :
|
||||
|
||||
- get_valid_actions() - Obtient les actions valides pour l'état actuel
|
||||
- TrictracAction::to_action_index() / from_action_index() - Conversion bidirectionnelle
|
||||
- sample_valid_action() - Sélection aléatoire d'une action valide
|
||||
- Réseau de neurones élargi (512 neurones cachés) pour gérer l'espace étendu
|
||||
|
||||
📊 Utilisation dans le jeu :
|
||||
|
||||
Le bot DQN peut maintenant :
|
||||
|
||||
- Choisir le nombre exact de points à marquer (pas seulement le maximum)
|
||||
- Sélectionner des mouvements spécifiques parmi toutes les combinaisons possibles
|
||||
- Décider intelligemment entre "Go" et les mouvements alternatifs
|
||||
|
||||
L'espace d'actions est maintenant parfaitement aligné avec la complexité réelle du jeu Trictrac ! 🎲
|
||||
|
||||
---
|
||||
|
||||
Pour réduire la taille de l'espace des actions, est-ce qu'on peut tenir compte de l'état du jeu pour le choix des actions ?
|
||||
Ça voudrait dire que pour le choix du mouvement par exemple, comme on connait la valeur des dés on a juste à choisir quel dé on joue en premier (2 choix), puis la position de départ du premier pion (25 choix), et la position de départ du deuxième pion (25 choix), la position d'arrivée étant déduite de la valeur des dés.
|
||||
|
||||
....
|
||||
|
||||
● Parfait ! J'ai implémenté l'espace d'actions compactes basé sur le contexte du jeu. Voici les améliorations majeures apportées :
|
||||
|
||||
✅ Réduction drastique de l'espace d'actions :
|
||||
|
||||
- Avant : 331,791 actions possibles à chaque étape
|
||||
- Maintenant : 1-200 actions selon le contexte du jeu
|
||||
|
||||
✅ Actions contextuelles :
|
||||
|
||||
- Roll : Seulement quand c'est possible (1 action)
|
||||
- Mark : Seulement les points valides 0-12 (≤13 actions)
|
||||
- Go : Toujours 1 action
|
||||
- MoveChoice : 2 ordres de dés × ~10 positions de départ × ~10 positions = ~200 actions au lieu de 331,791
|
||||
|
||||
✅ Architecture modulaire :
|
||||
|
||||
- CompactAction : actions simplifiées basées sur le contexte
|
||||
- to_trictrac_action() : conversion automatique vers les actions complètes
|
||||
- get_valid_compact_actions() : génération contextuelle des actions valides
|
||||
- context_action_space_size() : taille dynamique de l'espace d'actions
|
||||
|
||||
L'approche utilise exactement votre suggestion : pour les mouvements, au lieu d'encoder toutes les combinaisons possibles, on choisit :
|
||||
|
||||
1. L'ordre des dés (2 choix max si dés différents)
|
||||
2. Les positions de départ (seulement celles avec des pions)
|
||||
3. Les destinations sont automatiquement calculées selon les règles
|
||||
|
||||
Cela réduit l'espace d'actions de 99.94% tout en gardant toute la richesse du jeu !
|
||||
|
||||
---
|
||||
|
||||
● Parfait ! J'ai créé le fichier train_burn_rl.rs qui utilise votre environnement Burn-RL corrigé.
|
||||
|
||||
Pour lancer l'entraînement avec Burn-RL :
|
||||
|
||||
cargo run --bin=train_burn_rl
|
||||
|
||||
Ou avec des options personnalisées :
|
||||
|
||||
cargo run --bin=train_burn_rl -- --episodes 2000 --save-every 200 --max-steps 1000
|
||||
|
||||
Caractéristiques de cet entraîneur :
|
||||
|
||||
✅ Utilise l'environnement Burn-RL que vous avez corrigé
|
||||
✅ Actions contextuelles via get_valid_actions()
|
||||
✅ Politique epsilon-greedy simple pour commencer
|
||||
✅ Statistiques détaillées avec moyennes mobiles
|
||||
✅ Configuration flexible via arguments CLI
|
||||
✅ Logging progressif pour suivre l'entraînement
|
||||
|
||||
Options disponibles :
|
||||
|
||||
- --episodes : nombre d'épisodes (défaut: 1000)
|
||||
- --save-every : fréquence d'affichage des stats (défaut: 100)
|
||||
- --max-steps : nombre max de steps par épisode (défaut: 500)
|
||||
- --help : aide complète
|
||||
|
||||
Cet entraîneur sert de base pour tester l'environnement Burn-RL. Une fois que tout fonctionne bien, on pourra y intégrer un vrai agent DQN avec réseaux de neurones !
|
||||
|
||||
---
|
||||
|
||||
Session où je n'ai cessé de recevoir ce message :
|
||||
|
||||
⎿ API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute. For details, refer to: <https://docs.anthropic.com/en/api/rate-limits>. You can see the response headers for current usage. Please reduce the prompt length or the maximum tokens requested, or try again later. You may also contact sales at <https://www.anthropic.com/contact-sales> to discuss your options for a rate limit increase."}}) · Retrying in 391 seconds… (attempt 1/10)
|
||||
|
||||
✶ Coaching… (403s · ↑ 382 tokens · esc to interrupt)
|
||||
|
||||
Pour à la fin de la session avoir ce résumé :
|
||||
|
||||
Total cost: $0.95
|
||||
Total duration (API): 1h 24m 22.8s
|
||||
Total duration (wall): 1h 43m 3.5s
|
||||
Total code changes: 746 lines added, 0 lines removed
|
||||
Token usage by model:
|
||||
claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
|
||||
claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
|
||||
|
||||
Mais pourtant 2 millions indiqués dans la page usage : <https://console.anthropic.com/usage>, et 7.88 dollars de consommés sur <https://console.anthropic.com/cost>.
|
||||
|
||||
I just had a claude code session in which I kept having this error, even if the agent didn't seem to read a lot of files : API Error (429 {"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed the rate limit for your organization (813e6b21-ec6f-44c3-a7f0-408244105e5c) of 20,000 input tokens per minute.
|
||||
|
||||
at the end of the session the token usage and cost indicated was this :
|
||||
|
||||
Total cost: $0.95
|
||||
Total duration (API): 1h 24m 22.8s
|
||||
Total duration (wall): 1h 43m 3.5s
|
||||
Total code changes: 746 lines added, 0 lines removed
|
||||
Token usage by model:
|
||||
claude-3-5-haiku: 18.8k input, 443 output, 0 cache read, 0 cache write
|
||||
claude-sonnet: 10 input, 666 output, 0 cache read, 245.6k cache write
|
||||
|
||||
but the usage on the /usage page was 2,073,698 tokens in, and the cost on the /cost page was $7.90.
|
||||
|
||||
When looking at the costs csv file, it seems that it is the "input cache write 5m" that consumed nearly all the tokens ( $7,71 ). Is it a bug ?
|
||||
46
doc/refs/geminiQuestions.md
Normal file
46
doc/refs/geminiQuestions.md
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# Description du projet et question
|
||||
|
||||
Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust.
|
||||
Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions.
|
||||
|
||||
Les règles du jeu et l'état d'une partie sont implémentées dans 'store', l'application ligne de commande est implémentée dans 'client_cli', elle permet déjà de jouer contre un bot, ou de faire jouer deux bots l'un contre l'autre.
|
||||
Les stratégies de bots sont implémentées dans le dossier 'bot'.
|
||||
|
||||
Plus précisément, l'état du jeu est défini par le struct GameState dans store/src/game.rs, la méthode to_string_id() permet de coder cet état de manière compacte dans une chaîne de caractères, mais il n'y a pas l'historique des coups joués. Il y a aussi fmt::Display d'implémenté pour une représentation textuelle plus lisible.
|
||||
|
||||
'client_cli/src/game_runner.rs' contient la logique permettant de faire jouer deux bots l'un contre l'autre.
|
||||
'bot/src/strategy/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste.
|
||||
|
||||
Je cherche maintenant à ajouter des stratégies de bot plus fortes en entraînant un agent/bot par reinforcement learning.
|
||||
|
||||
Une première version avec DQN fonctionne (entraînement avec `cargo run --bin=train_dqn`)
|
||||
Il gagne systématiquement contre le bot par défaut 'dummy' : `cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy`.
|
||||
|
||||
Une version, toujours DQN, mais en utilisant la bibliothèque burn (<https://burn.dev/>) est en cours de développement.
|
||||
|
||||
L'entraînement du modèle se passe dans la fonction "main" du fichier bot/src/burnrl/main.rs. On peut lancer l'exécution avec 'just trainbot'.
|
||||
|
||||
Voici la sortie de l'entraînement lancé avec 'just trainbot' :
|
||||
|
||||
```
|
||||
> Entraînement
|
||||
> {"episode": 0, "reward": -1692.3148, "duration": 1000}
|
||||
> {"episode": 1, "reward": -361.6962, "duration": 1000}
|
||||
> {"episode": 2, "reward": -126.1013, "duration": 1000}
|
||||
> {"episode": 3, "reward": -36.8000, "duration": 1000}
|
||||
> {"episode": 4, "reward": -21.4997, "duration": 1000}
|
||||
> {"episode": 5, "reward": -8.3000, "duration": 1000}
|
||||
> {"episode": 6, "reward": 3.1000, "duration": 1000}
|
||||
> {"episode": 7, "reward": -21.5998, "duration": 1000}
|
||||
> {"episode": 8, "reward": -10.1999, "duration": 1000}
|
||||
> {"episode": 9, "reward": 3.1000, "duration": 1000}
|
||||
> {"episode": 10, "reward": 14.5002, "duration": 1000}
|
||||
> {"episode": 11, "reward": 10.7000, "duration": 1000}
|
||||
> {"episode": 12, "reward": -0.7000, "duration": 1000}
|
||||
|
||||
thread 'main' has overflowed its stack
|
||||
fatal runtime error: stack overflow
|
||||
error: Recipe `trainbot` was terminated on line 25 by signal 6
|
||||
```
|
||||
|
||||
Au bout du 12ème épisode (plus de 6 heures sur ma machine), l'entraînement s'arrête avec une erreur stack overflow. Peux-tu m'aider à diagnostiquer d'où peut provenir le problème ? Y a-t-il des outils qui permettent de détecter les zones de code qui utilisent le plus la stack ? Pour information j'ai vu ce rapport de bug <https://github.com/yunjhongwu/burn-rl-examples/issues/40> , donc peut-être que le problème vient du paquet 'burn-rl'.
|
||||
|
|
@ -1,54 +1,46 @@
|
|||
# Inspirations
|
||||
|
||||
tools
|
||||
|
||||
- config clippy ?
|
||||
- bacon : tests runner (ou loom ?)
|
||||
|
||||
## Rust libs
|
||||
|
||||
cf. <https://blessed.rs/crates>
|
||||
cf. https://blessed.rs/crates
|
||||
|
||||
nombres aléatoires avec seed : <https://richard.dallaway.com/posts/2021-01-04-repeat-resume/>
|
||||
nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-repeat-resume/
|
||||
|
||||
- cli : <https://lib.rs/crates/pico-args> ( ou clap )
|
||||
- cli : https://lib.rs/crates/pico-args ( ou clap )
|
||||
- reseau async : tokio
|
||||
- web serveur : axum (uses tokio)
|
||||
- <https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling>
|
||||
- https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling
|
||||
- db : sqlx
|
||||
|
||||
|
||||
- eyre, color-eyre (Results)
|
||||
- tracing (logging)
|
||||
- rayon ( sync <-> parallel )
|
||||
|
||||
- front : yew + tauri
|
||||
|
||||
- egui
|
||||
|
||||
- <https://docs.rs/board-game/latest/board_game/>
|
||||
|
||||
## network games
|
||||
|
||||
- <https://www.mattkeeter.com/projects/pont/>
|
||||
- <https://github.com/jackadamson/onitama> (wasm, rooms)
|
||||
- <https://github.com/UkoeHB/renet2>
|
||||
- <https://github.com/UkoeHB/bevy_simplenet>
|
||||
- https://docs.rs/board-game/latest/board_game/
|
||||
|
||||
## Others
|
||||
|
||||
- plugins avec <https://github.com/extism/extism>
|
||||
- plugins avec https://github.com/extism/extism
|
||||
|
||||
## Backgammon existing projects
|
||||
|
||||
- go : <https://bgammon.org/blog/20240101-hello-world/>
|
||||
- protocole de communication : <https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md>
|
||||
- ocaml : <https://github.com/jacobhilton/backgammon?tab=readme-ov-file>
|
||||
cli example : <https://www.jacobh.co.uk/backgammon/>
|
||||
- lib rust backgammon
|
||||
- <https://github.com/carlostrub/backgammon>
|
||||
- <https://github.com/marktani/backgammon>
|
||||
- network webtarot
|
||||
- front ?
|
||||
* go : https://bgammon.org/blog/20240101-hello-world/
|
||||
- protocole de communication : https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md
|
||||
* ocaml : https://github.com/jacobhilton/backgammon?tab=readme-ov-file
|
||||
cli example : https://www.jacobh.co.uk/backgammon/
|
||||
* lib rust backgammon
|
||||
- https://github.com/carlostrub/backgammon
|
||||
- https://github.com/marktani/backgammon
|
||||
* network webtarot
|
||||
* front ?
|
||||
|
||||
|
||||
## cli examples
|
||||
|
||||
|
|
@ -82,9 +74,7 @@ Player O estimates that they have a 90.6111% chance of winning.
|
|||
|
||||
Os borne off: none
|
||||
24 23 22 21 20 19 18 17 16 15 14 13
|
||||
|
||||
---
|
||||
|
||||
-------------------------------------------------------------------
|
||||
| v v v v v v | | v v v v v v |
|
||||
| | | |
|
||||
| X O O O | | O O O |
|
||||
|
|
@ -106,9 +96,7 @@ Os borne off: none
|
|||
| X X X | | X O O |
|
||||
| | | |
|
||||
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
|
||||
|
||||
---
|
||||
|
||||
-------------------------------------------------------------------
|
||||
1 2 3 4 5 6 7 8 9 10 11 12
|
||||
Xs borne off: none
|
||||
|
||||
|
|
@ -128,3 +116,4 @@ n - show the next move
|
|||
<enter> - toggle between showing the current and last moves
|
||||
help - show this help text
|
||||
quit - abandon game
|
||||
|
||||
|
|
|
|||
292
doc/research.md
292
doc/research.md
|
|
@ -1,292 +0,0 @@
|
|||
# Trictrac — Research Notes
|
||||
|
||||
## 1. Rust Engine: Module Map
|
||||
|
||||
| Module | Responsibility |
|
||||
| ---------------------- | ------------------------------------------------------------------------- |
|
||||
| `board.rs` | Board representation, checker manipulation, quarter analysis |
|
||||
| `dice.rs` | `Dice` struct, `DiceRoller`, bit encoding |
|
||||
| `player.rs` | `Player` struct (score, bredouille), `Color`, `PlayerId`, `CurrentPlayer` |
|
||||
| `game.rs` | `GameState` state machine, `GameEvent` enum, `Stage`/`TurnStage` |
|
||||
| `game_rules_moves.rs` | `MoveRules`: move validation and generation |
|
||||
| `game_rules_points.rs` | `PointsRules`: jan detection and scoring |
|
||||
| `training_common.rs` | `TrictracAction` enum, action-space encoding (size 514) |
|
||||
| `pyengine.rs` | PyO3 Python module exposing `TricTrac` class |
|
||||
| `lib.rs` | Crate root, re-exports |
|
||||
|
||||
---
|
||||
|
||||
## 2. Board Representation
|
||||
|
||||
```rust
|
||||
pub struct Board {
|
||||
positions: [i8; 24],
|
||||
}
|
||||
```
|
||||
|
||||
- 24 fields indexed 0–23 internally, 1–24 externally.
|
||||
- Positive values = White checkers on that field; negative = Black.
|
||||
- Initial state: `[15, 0, ..., 0, -15]` — all 15 white pieces on field 1, all 15 black pieces on field 24.
|
||||
- Field 0 is a sentinel for "exited the board" (never stored in the array).
|
||||
|
||||
**Mirroring** is the central symmetry operation used throughout:
|
||||
|
||||
```rust
|
||||
pub fn mirror(&self) -> Self {
|
||||
let mut positions = self.positions.map(|c| 0 - c);
|
||||
positions.reverse();
|
||||
Board { positions }
|
||||
}
|
||||
```
|
||||
|
||||
This negates all values (swapping who owns each checker) and reverses the array (swapping directions). The entire engine always reasons from White's perspective; Black's moves are handled by mirroring the board first.
|
||||
|
||||
**Quarter structure**: fields 1–6, 7–12, 13–18, 19–24. This maps to the four tables of Trictrac:
|
||||
|
||||
- 1–6: White's "petit jan" (own table)
|
||||
- 7–12: White's "grand jan"
|
||||
- 13–18: Black's "grand jan" (= White's opponent territory)
|
||||
- 19–24: Black's "petit jan" / White's "jan de retour"
|
||||
|
||||
The "coin de repos" (rest corner) is field 12 for White, field 13 for Black.
|
||||
|
||||
---
|
||||
|
||||
## 3. Dice
|
||||
|
||||
```rust
|
||||
pub struct Dice {
|
||||
pub values: (u8, u8),
|
||||
}
|
||||
```
|
||||
|
||||
Dice are always a pair (never quadrupled for doubles, unlike Backgammon). The `DiceRoller` uses `StdRng` seeded from OS entropy (or an optional fixed seed for tests). Bit encoding: `"{d1:0>3b}{d2:0>3b}"` — 3 bits each, 6 bits total.
|
||||
|
||||
---
|
||||
|
||||
## 4. Player State
|
||||
|
||||
```rust
|
||||
pub struct Player {
|
||||
pub name: String,
|
||||
pub color: Color, // White or Black
|
||||
pub points: u8, // 0–11 (points within current hole)
|
||||
pub holes: u8, // holes won (game ends at >12)
|
||||
pub can_bredouille: bool,
|
||||
pub can_big_bredouille: bool,
|
||||
pub dice_roll_count: u8, // rolls since last new_pick_up()
|
||||
}
|
||||
```
|
||||
|
||||
`PlayerId` is a `u64` alias. Player 1 = White, Player 2 = Black (set at init time; this is fixed for the session in pyengine).
|
||||
|
||||
---
|
||||
|
||||
## 5. Game State Machine
|
||||
|
||||
### Stages
|
||||
|
||||
```rust
|
||||
pub enum Stage { PreGame, InGame, Ended }
|
||||
|
||||
pub enum TurnStage {
|
||||
RollDice, // 1 — player must request a roll
|
||||
RollWaiting, // 0 — waiting for dice result from outside
|
||||
MarkPoints, // 2 — points are being marked (schools mode only)
|
||||
HoldOrGoChoice, // 3 — player won a hole; choose to Go or Hold
|
||||
Move, // 4 — player must move checkers
|
||||
MarkAdvPoints, // 5 — mark opponent's points after the move (schools mode)
|
||||
}
|
||||
```
|
||||
|
||||
### Turn lifecycle (schools disabled — the default)
|
||||
|
||||
```
|
||||
RollWaiting
|
||||
│ RollResult → auto-mark points
|
||||
├─[no hole]──→ Move
|
||||
│ │ Move → mark opponent's points → switch player
|
||||
│ └───────────────────────────────→ RollDice (next player)
|
||||
└─[hole won]─→ HoldOrGoChoice
|
||||
├─ Go ──→ new_pick_up() → RollDice (same player)
|
||||
└─ Move ──→ mark opponent's points → switch player → RollDice
|
||||
```
|
||||
|
||||
In schools mode (`schools_enabled = true`), the player explicitly marks their own points (`Mark` event) and then the opponent's points after moving (`MarkAdvPoints` stage).
|
||||
|
||||
### Key events
|
||||
|
||||
```rust
|
||||
pub enum GameEvent {
|
||||
BeginGame { goes_first: PlayerId },
|
||||
EndGame { reason: EndGameReason },
|
||||
PlayerJoined { player_id, name },
|
||||
PlayerDisconnected { player_id },
|
||||
Roll { player_id }, // triggers RollWaiting
|
||||
RollResult { player_id, dice }, // provides dice values
|
||||
Mark { player_id, points }, // explicit point marking (schools mode)
|
||||
Go { player_id }, // choose to restart position after hole
|
||||
Move { player_id, moves: (CheckerMove, CheckerMove) },
|
||||
PlayError,
|
||||
}
|
||||
```
|
||||
|
||||
### Initialization in pyengine
|
||||
|
||||
```rust
|
||||
fn new() -> Self {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
TricTrac { game_state }
|
||||
}
|
||||
```
|
||||
|
||||
Player 1 (White) always goes first. `active_player_id` uses 1-based indexing; pyengine converts to 0-based for the Python side with `active_player_id - 1`.
|
||||
|
||||
---
|
||||
|
||||
## 6. Scoring System (Jans)
|
||||
|
||||
Points are awarded after each dice roll based on "jans" (scoring events) detected by `PointsRules`. All computation assumes White's perspective (board is mirrored for Black before calling).
|
||||
|
||||
### Jan types
|
||||
|
||||
| Jan | Points (normal / doublet) | Direction |
|
||||
| ----------------------- | ------------------------- | --------------- |
|
||||
| `TrueHitSmallJan` | 4 / 6 | → active player |
|
||||
| `TrueHitBigJan` | 2 / 4 | → active player |
|
||||
| `TrueHitOpponentCorner` | 4 / 6 | → active player |
|
||||
| `FilledQuarter` | 4 / 6 | → active player |
|
||||
| `FirstPlayerToExit` | 4 / 6 | → active player |
|
||||
| `SixTables` | 4 / 6 | → active player |
|
||||
| `TwoTables` | 4 / 6 | → active player |
|
||||
| `Mezeas` | 4 / 6 | → active player |
|
||||
| `FalseHitSmallJan` | −4 / −6 | → opponent |
|
||||
| `FalseHitBigJan` | −2 / −4 | → opponent |
|
||||
| `ContreTwoTables` | −4 / −6 | → opponent |
|
||||
| `ContreMezeas` | −4 / −6 | → opponent |
|
||||
| `HelplessMan` | −2 / −4 | → opponent |
|
||||
|
||||
A single roll can trigger multiple jans, each scored independently. The jan detection process:
|
||||
|
||||
1. Try both dice orderings
|
||||
2. Detect "tout d'une" (combined dice move as a virtual single die)
|
||||
3. Prefer true hits over false hits for the same move
|
||||
4. Check quarter-filling opportunities
|
||||
5. Check rare jans (SixTables at roll 3, TwoTables, Mezeas) given specific board positions and talon counts
|
||||
|
||||
### Hole scoring
|
||||
|
||||
```rust
|
||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
let sum_points = p.points + points;
|
||||
let jeux = sum_points / 12; // number of completed holes
|
||||
let holes = match (jeux, p.can_bredouille) {
|
||||
(0, _) => 0,
|
||||
(_, false) => 2 * jeux - 1, // no bredouille bonus
|
||||
(_, true) => 2 * jeux, // bredouille doubles the holes
|
||||
};
|
||||
p.points = sum_points % 12;
|
||||
p.holes += holes;
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
- 12 points = 1 "jeu", which yields 1 or 2 holes depending on bredouille status.
|
||||
- Scoring any points clears the opponent's `can_bredouille`.
|
||||
- Completing a hole resets `can_bredouille` for the scorer.
|
||||
- Game ends when `holes > 12`.
|
||||
- Score reported to OpenSpiel: `holes * 12 + points`.
|
||||
|
||||
### Points from both rolls
|
||||
|
||||
After a roll, the active player's points (`dice_points.0`) are auto-marked immediately. After the Move, the opponent's points (`dice_points.1`) are marked (they were computed at roll-time from the pre-move board).
|
||||
|
||||
---
|
||||
|
||||
## 7. Move Rules
|
||||
|
||||
`MoveRules` always works from White's perspective. Key constraints enforced by `moves_allowed()`:
|
||||
|
||||
1. **Opponent's corner forbidden**: Cannot land on field 13 (opponent's rest corner for White).
|
||||
2. **Corner needs two checkers**: The rest corner (field 12) must be taken or vacated with exactly 2 checkers simultaneously.
|
||||
3. **Corner by effect vs. by power**: If the corner can be taken directly ("par effet"), you cannot take it "par puissance" (using combined dice).
|
||||
4. **Exit preconditions**: All checkers must be in fields 19–24 before any exit is allowed.
|
||||
5. **Exit by effect priority**: If a normal exit is possible, exceedant moves (using overflow) are forbidden.
|
||||
6. **Farthest checker first**: When exiting with exceedant, must exit the checker at the highest field.
|
||||
7. **Must play all dice**: If both dice can be played, playing only one is invalid.
|
||||
8. **Must play strongest die**: If only one die can be played, it must be the higher value die.
|
||||
9. **Must fill quarter**: If a quarter can be completed, the move must complete it.
|
||||
10. **Cannot block opponent's fillable quarter**: Cannot move into a quarter the opponent can still fill.
|
||||
|
||||
The board state after each die application is simulated to check two-step sequences.
|
||||
|
||||
---
|
||||
|
||||
## 8. Action Space (training_common.rs)
|
||||
|
||||
Total size: **514 actions**.
|
||||
|
||||
| Index | Action | Description |
|
||||
| ------- | ------------------------------------------------ | ---------------------------------------------- |
|
||||
| 0 | `Roll` | Request dice roll (not used in OpenSpiel mode) |
|
||||
| 1 | `Go` | After winning hole: reset board and continue |
|
||||
| 2–257 | `Move { dice_order: true, checker1, checker2 }` | Move with die[0] first |
|
||||
| 258–513 | `Move { dice_order: false, checker1, checker2 }` | Move with die[1] first |
|
||||
|
||||
Move encoding: `index = 2 + (0 if dice_order else 256) + checker1 * 16 + checker2`
|
||||
|
||||
`checker1` and `checker2` are **ordinal positions** (1-based) of specific checkers counted left-to-right across all White-occupied fields, not field indices. Checker 0 = "no move" (empty move). Range: 0–15 (16 values each).
|
||||
|
||||
### Mirror pattern in get_legal_actions / apply_action
|
||||
|
||||
For player 2 (Black):
|
||||
|
||||
```rust
|
||||
// get_legal_actions: mirror game state before computing
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
|
||||
// apply_action: convert action → event on mirrored state, then mirror the event back
|
||||
a.to_event(&self.game_state.mirror())
|
||||
.map(|e| e.get_mirror(false))
|
||||
```
|
||||
|
||||
This ensures Black's actions are computed as if Black were White on a mirrored board, then translated back to real-board coordinates.
|
||||
|
||||
---
|
||||
|
||||
## 9. Known Issues and Inconsistencies
|
||||
|
||||
### 9.1 Color swap on new_pick_up disabled
|
||||
|
||||
In `game.rs:new_pick_up()`:
|
||||
|
||||
```rust
|
||||
// XXX : switch colors
|
||||
// désactivé pour le moment car la vérification des mouvements échoue,
|
||||
// cf. https://code.rhumbs.fr/henri/trictrac/issues/31
|
||||
// p.color = p.color.opponent_color();
|
||||
```
|
||||
|
||||
In authentic Trictrac, players swap colors between "relevés" (pick-ups after a hole is won with Go). This is commented out, so the same player always plays White and the same always plays Black throughout the entire game.
|
||||
|
||||
### 9.2 `can_big_bredouille` tracked but not implemented
|
||||
|
||||
The `can_big_bredouille` flag is stored in `Player` and serialized in state encoding, but the scoring logic never reads it. Grande bredouille (a rare extra bonus) is not implemented.
|
||||
|
||||
### 9.3 `get_valid_actions` panics on `RollWaiting`
|
||||
|
||||
```rust
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
||||
panic!("get_valid_actions not implemented for turn stage {:?}", ...)
|
||||
}
|
||||
```
|
||||
|
||||
If `get_legal_actions` were ever called while `needs_roll()` is true, this would panic.
|
||||
|
||||
### 9.4 Opponent points marked at pre-move board state
|
||||
|
||||
The opponent's `dice_points.1` is computed at roll time (before the active player moves), but applied to the opponent after the move. This means the opponent's scoring is evaluated on the board position that existed before the active player moved — which is per the rules of Trictrac (points are based on where pieces could be hit at the moment of the roll), but it's worth noting this subtlety.
|
||||
|
|
@ -1,172 +0,0 @@
|
|||
@startuml
|
||||
|
||||
class "CheckerMove" {
|
||||
- from: Field
|
||||
- to: Field
|
||||
+ to_display_string()
|
||||
+ new(from: Field, to: Field)
|
||||
+ mirror()
|
||||
+ chain(cmove: Self)
|
||||
+ get_from()
|
||||
+ get_to()
|
||||
+ is_exit()
|
||||
+ doable_with_dice(dice: usize)
|
||||
}
|
||||
|
||||
class "Board" {
|
||||
- positions: [i8;24]
|
||||
+ new()
|
||||
+ mirror()
|
||||
+ set_positions(positions: [ i8 ; 24 ])
|
||||
+ count_checkers(color: Color, from: Field, to: Field)
|
||||
+ to_vec()
|
||||
+ to_gnupg_pos_id()
|
||||
+ to_display_grid(col_size: usize)
|
||||
+ set(color: & Color, field: Field, amount: i8)
|
||||
+ blocked(color: & Color, field: Field)
|
||||
+ passage_blocked(color: & Color, field: Field)
|
||||
+ get_field_checkers(field: Field)
|
||||
+ get_checkers_color(field: Field)
|
||||
+ is_field_in_small_jan(field: Field)
|
||||
+ get_color_fields(color: Color)
|
||||
+ get_color_corner(color: & Color)
|
||||
+ get_possible_moves(color: Color, dice: u8, with_excedants: bool, check_rest_corner_exit: bool, forbid_exits: bool)
|
||||
+ passage_possible(color: & Color, cmove: & CheckerMove)
|
||||
+ move_possible(color: & Color, cmove: & CheckerMove)
|
||||
+ any_quarter_filled(color: Color)
|
||||
+ is_quarter_filled(color: Color, field: Field)
|
||||
+ get_quarter_filling_candidate(color: Color)
|
||||
+ is_quarter_fillable(color: Color, field: Field)
|
||||
- get_quarter_fields(field: Field)
|
||||
+ move_checker(color: & Color, cmove: CheckerMove)
|
||||
+ remove_checker(color: & Color, field: Field)
|
||||
+ add_checker(color: & Color, field: Field)
|
||||
}
|
||||
|
||||
class "MoveRules" {
|
||||
+ board: Board
|
||||
+ dice: Dice
|
||||
+ new(color: & Color, board: & Board, dice: Dice)
|
||||
+ set_board(color: & Color, board: & Board)
|
||||
- get_board_from_color(color: & Color, board: & Board)
|
||||
+ moves_follow_rules(moves: & ( CheckerMove , CheckerMove ))
|
||||
- moves_possible(moves: & ( CheckerMove , CheckerMove ))
|
||||
- moves_follows_dices(moves: & ( CheckerMove , CheckerMove ))
|
||||
- get_move_compatible_dices(cmove: & CheckerMove)
|
||||
+ moves_allowed(moves: & ( CheckerMove , CheckerMove ))
|
||||
- check_opponent_can_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
|
||||
- check_must_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
|
||||
- check_corner_rules(moves: & ( CheckerMove , CheckerMove ))
|
||||
- has_checkers_outside_last_quarter()
|
||||
- check_exit_rules(moves: & ( CheckerMove , CheckerMove ))
|
||||
+ get_possible_moves_sequences(with_excedents: bool, ignored_rules: Vec < TricTracRule >)
|
||||
+ get_scoring_quarter_filling_moves_sequences()
|
||||
- get_sequence_origin_from_destination(sequence: ( CheckerMove , CheckerMove ), destination: Field)
|
||||
+ get_quarter_filling_moves_sequences()
|
||||
- get_possible_moves_sequences_by_dices(dice1: u8, dice2: u8, with_excedents: bool, ignore_empty: bool, ignored_rules: Vec < TricTracRule >)
|
||||
- _get_direct_exit_moves(state: & GameState)
|
||||
- is_move_by_puissance(moves: & ( CheckerMove , CheckerMove ))
|
||||
- can_take_corner_by_effect()
|
||||
}
|
||||
|
||||
class "DiceRoller" {
|
||||
- rng: StdRng
|
||||
+ new(opt_seed: Option < u64 >)
|
||||
+ roll()
|
||||
}
|
||||
|
||||
class "Dice" {
|
||||
+ values: (u8,u8)
|
||||
+ to_bits_string()
|
||||
+ to_display_string()
|
||||
+ is_double()
|
||||
}
|
||||
|
||||
class "GameState" {
|
||||
+ stage: Stage
|
||||
+ turn_stage: TurnStage
|
||||
+ board: Board
|
||||
+ active_player_id: PlayerId
|
||||
+ players: HashMap<PlayerId,Player>
|
||||
+ history: Vec<GameEvent>
|
||||
+ dice: Dice
|
||||
+ dice_points: (u8,u8)
|
||||
+ dice_moves: (CheckerMove,CheckerMove)
|
||||
+ dice_jans: PossibleJans
|
||||
- roll_first: bool
|
||||
+ schools_enabled: bool
|
||||
+ new(schools_enabled: bool)
|
||||
- set_schools_enabled(schools_enabled: bool)
|
||||
- get_active_player()
|
||||
- get_opponent_id()
|
||||
+ to_vec_float()
|
||||
+ to_vec()
|
||||
+ to_string_id()
|
||||
+ who_plays()
|
||||
+ get_white_player()
|
||||
+ get_black_player()
|
||||
+ player_id_by_color(color: Color)
|
||||
+ player_id(player: & Player)
|
||||
+ player_color_by_id(player_id: & PlayerId)
|
||||
+ validate(event: & GameEvent)
|
||||
+ init_player(player_name: & str)
|
||||
- add_player(player_id: PlayerId, player: Player)
|
||||
+ switch_active_player()
|
||||
+ consume(valid_event: & GameEvent)
|
||||
- new_pick_up()
|
||||
- get_rollresult_jans(dice: & Dice)
|
||||
+ determine_winner()
|
||||
- inc_roll_count(player_id: PlayerId)
|
||||
- mark_points(player_id: PlayerId, points: u8)
|
||||
}
|
||||
|
||||
class "Player" {
|
||||
+ name: String
|
||||
+ color: Color
|
||||
+ points: u8
|
||||
+ holes: u8
|
||||
+ can_bredouille: bool
|
||||
+ can_big_bredouille: bool
|
||||
+ dice_roll_count: u8
|
||||
+ new(name: String, color: Color)
|
||||
+ to_bits_string()
|
||||
+ to_vec()
|
||||
}
|
||||
|
||||
class "PointsRules" {
|
||||
+ board: Board
|
||||
+ dice: Dice
|
||||
+ move_rules: MoveRules
|
||||
+ new(color: & Color, board: & Board, dice: Dice)
|
||||
+ set_dice(dice: Dice)
|
||||
+ update_positions(positions: [ i8 ; 24 ])
|
||||
- get_jans(board_ini: & Board, dice_rolls_count: u8)
|
||||
+ get_jans_points(jans: HashMap < Jan , Vec < ( CheckerMove , CheckerMove ) > >)
|
||||
+ get_points(dice_rolls_count: u8)
|
||||
+ get_result_jans(dice_rolls_count: u8)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
"MoveRules" <-- "Board"
|
||||
"MoveRules" <-- "Dice"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
"GameState" <-- "Board"
|
||||
"HashMap<PlayerId,Player>" <-- "Player"
|
||||
"GameState" <-- "HashMap<PlayerId,Player>"
|
||||
"GameState" <-- "Dice"
|
||||
|
||||
|
||||
|
||||
|
||||
"PointsRules" <-- "Board"
|
||||
"PointsRules" <-- "Dice"
|
||||
"PointsRules" <-- "MoveRules"
|
||||
|
||||
@enduml
|
||||
19
justfile
19
justfile
|
|
@ -9,7 +9,7 @@ shell:
|
|||
runcli:
|
||||
RUST_LOG=info cargo run --bin=client_cli
|
||||
runclibots:
|
||||
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk
|
||||
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
|
||||
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||
match:
|
||||
|
|
@ -20,20 +20,17 @@ profile:
|
|||
cargo build --profile profiling
|
||||
samply record ./target/profiling/client_cli --bot dummy,dummy
|
||||
pythonlib:
|
||||
rm -rf target/wheels
|
||||
maturin build -m store/Cargo.toml --release
|
||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||
cxxlib:
|
||||
cargo build --release -p trictrac-store
|
||||
@echo "Static lib: $(ls target/release/libtrictrac_store.a)"
|
||||
@echo "CXX header: $(find target -name 'cxxengine.rs.h' | head -1)"
|
||||
trainbot algo:
|
||||
trainsimple:
|
||||
cargo build --release --bin=train_dqn_simple
|
||||
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
|
||||
trainbot:
|
||||
#python ./store/python/trainModel.py
|
||||
# cargo run --bin=train_dqn # ok
|
||||
# ./bot/scripts/trainValid.sh
|
||||
./bot/scripts/train.sh {{algo}}
|
||||
plottrainbot algo:
|
||||
./bot/scripts/train.sh plot {{algo}}
|
||||
./bot/scripts/trainValid.sh
|
||||
plottrainbot:
|
||||
./bot/scripts/trainValid.sh plot
|
||||
debugtrainbot:
|
||||
cargo build --bin=train_dqn_burn
|
||||
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
||||
|
|
|
|||
14
server/Cargo.toml
Normal file
14
server/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "trictrac-server"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
store = { path = "../store" }
|
||||
env_logger = "0.10.0"
|
||||
log = "0.4.20"
|
||||
pico-args = "0.5.0"
|
||||
renet = "0.0.13"
|
||||
bincode = "1.3.3"
|
||||
147
server/src/main.rs
Normal file
147
server/src/main.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
use log::{info, trace, warn};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
|
||||
use renet::{
|
||||
transport::{
|
||||
NetcodeServerTransport, ServerAuthentication, ServerConfig, NETCODE_USER_DATA_BYTES,
|
||||
},
|
||||
ConnectionConfig, RenetServer, ServerEvent,
|
||||
};
|
||||
|
||||
// Only clients that can provide the same PROTOCOL_ID that the server is using will be able to connect.
|
||||
// This can be used to make sure players use the most recent version of the client for instance.
|
||||
pub const PROTOCOL_ID: u64 = 2878;
|
||||
|
||||
/// Utility function for extracting a players name from renet user data
|
||||
fn name_from_user_data(user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> String {
|
||||
let mut buffer = [0u8; 8];
|
||||
buffer.copy_from_slice(&user_data[0..8]);
|
||||
let mut len = u64::from_le_bytes(buffer) as usize;
|
||||
len = len.min(NETCODE_USER_DATA_BYTES - 8);
|
||||
let data = user_data[8..len + 8].to_vec();
|
||||
String::from_utf8(data).unwrap()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
||||
let mut server = RenetServer::new(ConnectionConfig::default());
|
||||
|
||||
// Setup transport layer
|
||||
const SERVER_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5000);
|
||||
let socket: UdpSocket = UdpSocket::bind(SERVER_ADDR).unwrap();
|
||||
let server_config = ServerConfig {
|
||||
max_clients: 2,
|
||||
protocol_id: PROTOCOL_ID,
|
||||
public_addr: SERVER_ADDR,
|
||||
authentication: ServerAuthentication::Unsecure,
|
||||
};
|
||||
let current_time = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap();
|
||||
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();
|
||||
|
||||
trace!("❂ TricTrac server listening on {}", SERVER_ADDR);
|
||||
|
||||
let mut game_state = store::GameState::default();
|
||||
let mut last_updated = Instant::now();
|
||||
loop {
|
||||
// Update server time
|
||||
let now = Instant::now();
|
||||
let delta_time = now - last_updated;
|
||||
server.update(delta_time);
|
||||
transport.update(delta_time, &mut server).unwrap();
|
||||
last_updated = now;
|
||||
|
||||
// Receive connection events from clients
|
||||
while let Some(event) = server.get_event() {
|
||||
match event {
|
||||
ServerEvent::ClientConnected { client_id } => {
|
||||
let user_data = transport.user_data(client_id).unwrap();
|
||||
|
||||
// Tell the recently joined player about the other player
|
||||
for (player_id, player) in game_state.players.iter() {
|
||||
let event = store::GameEvent::PlayerJoined {
|
||||
player_id: *player_id,
|
||||
name: player.name.clone(),
|
||||
};
|
||||
server.send_message(client_id, 0, bincode::serialize(&event).unwrap());
|
||||
}
|
||||
|
||||
// Add the new player to the game
|
||||
let event = store::GameEvent::PlayerJoined {
|
||||
player_id: client_id,
|
||||
name: name_from_user_data(&user_data),
|
||||
};
|
||||
game_state.consume(&event);
|
||||
|
||||
// Tell all players that a new player has joined
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
|
||||
info!("🎉 Client {} connected.", client_id);
|
||||
// In TicTacTussle the game can begin once two players has joined
|
||||
if game_state.players.len() == 2 {
|
||||
let event = store::GameEvent::BeginGame {
|
||||
goes_first: client_id,
|
||||
};
|
||||
game_state.consume(&event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
trace!("The game gas begun");
|
||||
}
|
||||
}
|
||||
ServerEvent::ClientDisconnected {
|
||||
client_id,
|
||||
reason: _,
|
||||
} => {
|
||||
// First consume a disconnect event
|
||||
let event = store::GameEvent::PlayerDisconnected {
|
||||
player_id: client_id,
|
||||
};
|
||||
game_state.consume(&event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
info!("Client {} disconnected", client_id);
|
||||
|
||||
// Then end the game, since tic tac toe can't go on with a single player
|
||||
let event = store::GameEvent::EndGame {
|
||||
reason: store::EndGameReason::PlayerLeft {
|
||||
player_id: client_id,
|
||||
},
|
||||
};
|
||||
game_state.consume(&event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
|
||||
// NOTE: Since we don't authenticate users we can't do any reconnection attempts.
|
||||
// We simply have no way to know if the next user is the same as the one that disconnected.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Receive GameEvents from clients. Broadcast valid events.
|
||||
for client_id in server.clients_id().into_iter() {
|
||||
while let Some(message) = server.receive_message(client_id, 0) {
|
||||
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
|
||||
if game_state.validate(&event) {
|
||||
game_state.consume(&event);
|
||||
trace!("Player {} sent:\n\t{:#?}", client_id, event);
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
|
||||
// Determine if a player has won the game
|
||||
if let Some(winner) = game_state.determine_winner() {
|
||||
let event = store::GameEvent::EndGame {
|
||||
reason: store::EndGameReason::PlayerWon { winner },
|
||||
};
|
||||
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||
}
|
||||
} else {
|
||||
warn!("Player {} sent invalid event:\n\t{:#?}", client_id, event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transport.send_packets(&mut server);
|
||||
thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,29 +1,20 @@
|
|||
[package]
|
||||
name = "trictrac-store"
|
||||
name = "store"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
name = "trictrac_store"
|
||||
# "cdylib" → Python .so built by maturin (pyengine)
|
||||
# "rlib" → used by other workspace crates (bot, client_cli)
|
||||
# "staticlib" → used by the C++ OpenSpiel game (cxxengine)
|
||||
crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
name = "store"
|
||||
# Only "rlib" is needed for other Rust crates to use this library
|
||||
crate-type = ["rlib"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
base64 = "0.21.7"
|
||||
cxx = "1.0"
|
||||
# provides macros for creating log messages to be used by a logger (for example env_logger)
|
||||
log = "0.4.20"
|
||||
merge = "0.1.0"
|
||||
# generate python lib (with maturin) to be used in AI training
|
||||
pyo3 = { version = "0.23", features = ["extension-module", "abi3-py38"] }
|
||||
rand = "0.9"
|
||||
rand = "0.8.5"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
transpose = "0.2.2"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
fn main() {
|
||||
cxx_build::bridge("src/cxxengine.rs")
|
||||
.std("c++17")
|
||||
.compile("trictrac-cxx");
|
||||
|
||||
println!("cargo:rerun-if-changed=src/cxxengine.rs");
|
||||
}
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
[build-system]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.maturin]
|
||||
# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so)
|
||||
features = ["pyo3/extension-module"]
|
||||
# python-source = "python"
|
||||
|
|
@ -8,7 +8,7 @@ use std::fmt;
|
|||
pub type Field = usize;
|
||||
pub type FieldWithCount = (Field, i8);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Eq, Deserialize, Hash)]
|
||||
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Deserialize)]
|
||||
pub struct CheckerMove {
|
||||
from: Field,
|
||||
to: Field,
|
||||
|
|
@ -94,7 +94,7 @@ impl CheckerMove {
|
|||
}
|
||||
|
||||
/// Represents the Tric Trac board
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Board {
|
||||
positions: [i8; 24],
|
||||
}
|
||||
|
|
@ -158,42 +158,6 @@ impl Board {
|
|||
.unsigned_abs()
|
||||
}
|
||||
|
||||
// get the number of the last checker in a field
|
||||
pub fn get_field_checker(&self, color: &Color, field: Field) -> u8 {
|
||||
assert_eq!(color, &Color::White); // sinon ajouter la gestion des noirs avec mirror
|
||||
let mut total_count: u8 = 0;
|
||||
for (i, checker_count) in self.positions.iter().enumerate() {
|
||||
// count white checkers (checker_count > 0)
|
||||
if *checker_count > 0 {
|
||||
total_count += *checker_count as u8;
|
||||
if field == i + 1 {
|
||||
return total_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
// get the field of the nth checker
|
||||
pub fn get_checker_field(&self, color: &Color, checker_pos: u8) -> Option<Field> {
|
||||
assert_eq!(color, &Color::White); // sinon ajouter la gestion des noirs avec mirror
|
||||
if checker_pos == 0 {
|
||||
return None;
|
||||
}
|
||||
let mut total_count: u8 = 0;
|
||||
for (i, checker_count) in self.positions.iter().enumerate() {
|
||||
// count white checkers (checker_count > 0)
|
||||
if *checker_count > 0 {
|
||||
total_count += *checker_count as u8;
|
||||
}
|
||||
// return the current field if it contains the checker
|
||||
if checker_pos <= total_count {
|
||||
return Some(i + 1);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Vec<i8> {
|
||||
self.positions.to_vec()
|
||||
}
|
||||
|
|
@ -271,7 +235,7 @@ impl Board {
|
|||
.map(|cells| {
|
||||
cells
|
||||
.into_iter()
|
||||
.map(|cell| format!("{cell:>5}"))
|
||||
.map(|cell| format!("{:>5}", cell))
|
||||
.collect::<Vec<String>>()
|
||||
.join("")
|
||||
})
|
||||
|
|
@ -282,7 +246,7 @@ impl Board {
|
|||
.map(|cells| {
|
||||
cells
|
||||
.into_iter()
|
||||
.map(|cell| format!("{cell:>5}"))
|
||||
.map(|cell| format!("{:>5}", cell))
|
||||
.collect::<Vec<String>>()
|
||||
.join("")
|
||||
})
|
||||
|
|
@ -439,7 +403,6 @@ impl Board {
|
|||
check_rest_corner_exit: bool,
|
||||
forbid_exits: bool,
|
||||
) -> Vec<CheckerMove> {
|
||||
// println!("------- board.get_possible_moves...");
|
||||
let mut moves = Vec::new();
|
||||
|
||||
let get_dest = |from| {
|
||||
|
|
@ -454,7 +417,6 @@ impl Board {
|
|||
}
|
||||
};
|
||||
|
||||
// let mut farthest_exit_move = 25;
|
||||
for (field, count) in self.get_color_fields(color) {
|
||||
// check rest corner exit
|
||||
if field == self.get_color_corner(&color) && count == 2 && check_rest_corner_exit {
|
||||
|
|
@ -465,11 +427,8 @@ impl Board {
|
|||
continue;
|
||||
}
|
||||
if !(0..25).contains(&dest) {
|
||||
// if with_excedants && !forbid_exits && field < farthest_exit_move && 2 < count {
|
||||
if with_excedants && !forbid_exits {
|
||||
dest = 0;
|
||||
// farthest_exit_move = field;
|
||||
// println!("farthest is now {farthest_exit_move}");
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -644,55 +603,6 @@ impl Board {
|
|||
self.positions[field - 1] += unit;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn from_gnupg_pos_id(bits: &str) -> Result<Board, String> {
|
||||
let mut positions = [0i8; 24];
|
||||
let mut bit_idx = 0;
|
||||
let bit_chars: Vec<char> = bits.chars().collect();
|
||||
|
||||
// White checkers (points 1 to 24)
|
||||
for i in 0..24 {
|
||||
if bit_idx >= bit_chars.len() {
|
||||
break;
|
||||
}
|
||||
let mut count = 0;
|
||||
while bit_idx < bit_chars.len() && bit_chars[bit_idx] == '1' {
|
||||
count += 1;
|
||||
bit_idx += 1;
|
||||
}
|
||||
positions[i] = count;
|
||||
if bit_idx < bit_chars.len() && bit_chars[bit_idx] == '0' {
|
||||
bit_idx += 1; // Consume the '0' separator
|
||||
}
|
||||
}
|
||||
|
||||
// Black checkers (points 24 down to 1)
|
||||
for i in (0..24).rev() {
|
||||
if bit_idx >= bit_chars.len() {
|
||||
break;
|
||||
}
|
||||
let mut count = 0;
|
||||
while bit_idx < bit_chars.len() && bit_chars[bit_idx] == '1' {
|
||||
count += 1;
|
||||
bit_idx += 1;
|
||||
}
|
||||
|
||||
if positions[i] == 0 {
|
||||
positions[i] = -count;
|
||||
} else if count > 0 {
|
||||
return Err(format!(
|
||||
"Invalid board: checkers of both colors on point {}",
|
||||
i + 1
|
||||
));
|
||||
}
|
||||
|
||||
if bit_idx < bit_chars.len() && bit_chars[bit_idx] == '0' {
|
||||
bit_idx += 1; // Consume the '0' separator
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Board { positions })
|
||||
}
|
||||
}
|
||||
|
||||
// Unit Tests
|
||||
|
|
@ -811,55 +721,4 @@ mod tests {
|
|||
);
|
||||
assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White));
|
||||
}
|
||||
|
||||
#[test]
fn get_checker_field() {
    // White stacks on fields 1..=6: 3, 1, 2, 2, 3, 2 (13 checkers in total).
    let mut white = [0i8; 24];
    white[..6].copy_from_slice(&[3, 1, 2, 2, 3, 2]);

    let mut board = Board::new();
    board.set_positions(&Color::White, white);

    // (checker rank, expected field)
    let cases = [(0u8, None), (5, Some(3)), (6, Some(3)), (14, None)];
    for &(rank, expected) in cases.iter() {
        assert_eq!(expected, board.get_checker_field(&Color::White, rank));
    }
}
|
||||
|
||||
#[test]
fn get_field_checker() {
    // White stacks on fields 1..=6: 3, 1, 2, 2, 3, 2.
    let mut white = [0i8; 24];
    white[..6].copy_from_slice(&[3, 1, 2, 2, 3, 2]);

    let mut board = Board::new();
    board.set_positions(&Color::White, white);

    // (field, expected rank of the last checker on that field)
    let cases = [(2usize, 4u8), (3, 6)];
    for &(field, expected) in cases.iter() {
        assert_eq!(expected, board.get_field_checker(&Color::White, field));
    }
}
|
||||
|
||||
#[test]
fn get_possible_moves() {
    let layout = [
        -8, -3, -1, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 3, 3,
    ];
    let mut board = Board::new();
    board.set_positions(&Color::White, layout);

    // With a die of 3, fields 19..=21 move inside the board and fields
    // 22..=24 are expected to exit (destination 0).
    let expected: Vec<CheckerMove> = [(19, 22), (20, 23), (21, 24), (22, 0), (23, 0), (24, 0)]
        .iter()
        .map(|&(from, to)| CheckerMove::new(from, to).unwrap())
        .collect();

    let actual = board.get_possible_moves(Color::White, 3, true, true, false);
    assert_eq!(expected, actual);
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,252 +0,0 @@
|
|||
//! C++ bindings for the TricTrac game engine via cxx.rs.
|
||||
//!
|
||||
//! Exposes an opaque `TricTracEngine` type to C++. The C++ side
|
||||
//! (open_spiel/games/trictrac/trictrac.cc) holds it via
|
||||
//! `rust::Box<trictrac_engine::TricTracEngine>`.
|
||||
//!
|
||||
//! The Rust engine always reasons from White's (player 1's) perspective.
|
||||
//! For Black (player 2), the board is mirrored before computing actions
|
||||
//! and events are mirrored back before being applied — exactly as in
|
||||
//! pyengine.rs.
|
||||
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
|
||||
use crate::dice::Dice;
|
||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
||||
|
||||
/// Catch any Rust panic and convert it to anyhow::Error so it never
|
||||
/// crosses the C FFI boundary as undefined behaviour.
|
||||
fn catch_panics<F, T>(f: F) -> anyhow::Result<T>
|
||||
where
|
||||
F: FnOnce() -> anyhow::Result<T> + panic::UnwindSafe,
|
||||
{
|
||||
panic::catch_unwind(f).unwrap_or_else(|e| {
|
||||
let msg = e
|
||||
.downcast_ref::<String>()
|
||||
.map(|s| s.as_str())
|
||||
.or_else(|| e.downcast_ref::<&str>().copied())
|
||||
.unwrap_or("unknown panic payload");
|
||||
Err(anyhow::anyhow!("Rust panic in FFI: {}", msg))
|
||||
})
|
||||
}
|
||||
|
||||
// ── cxx bridge declaration ────────────────────────────────────────────────────
|
||||
|
||||
#[cxx::bridge(namespace = "trictrac_engine")]
|
||||
pub mod ffi {
|
||||
// ── Shared types (transparent to both Rust and C++) ───────────────────────
|
||||
|
||||
/// Two dice values passed from C++ when applying a chance outcome.
|
||||
struct DicePair {
|
||||
die1: u8,
|
||||
die2: u8,
|
||||
}
|
||||
|
||||
/// Both players' cumulative scores: holes * 12 + points.
|
||||
struct PlayerScores {
|
||||
score_p1: i32,
|
||||
score_p2: i32,
|
||||
}
|
||||
|
||||
// ── Opaque Rust type and its free-function constructor ────────────────────
|
||||
|
||||
extern "Rust" {
|
||||
/// Opaque handle to a running TricTrac game.
|
||||
/// C++ accesses this only through `rust::Box<TricTracEngine>`.
|
||||
type TricTracEngine;
|
||||
|
||||
/// Construct a fresh engine with two players; player 1 (White) goes first.
|
||||
fn new_trictrac_engine() -> Box<TricTracEngine>;
|
||||
|
||||
/// Deep-copy the engine — required by OpenSpiel's State::Clone().
|
||||
fn clone_engine(self: &TricTracEngine) -> Box<TricTracEngine>;
|
||||
|
||||
// ── Queries ───────────────────────────────────────────────────────────
|
||||
|
||||
/// True when the game is in TurnStage::RollWaiting (OpenSpiel chance node).
|
||||
fn needs_roll(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// True when Stage::Ended.
|
||||
fn is_game_ended(self: &TricTracEngine) -> bool;
|
||||
|
||||
/// Active player index: 0 = player 1 (White), 1 = player 2 (Black).
|
||||
fn current_player_idx(self: &TricTracEngine) -> u64;
|
||||
|
||||
/// Legal action indices for `player_idx` in [0, 513].
|
||||
/// Returns an empty vector when it is not that player's turn.
|
||||
fn get_legal_actions(self: &TricTracEngine, player_idx: u64) -> Result<Vec<u64>>;
|
||||
|
||||
/// Human-readable description of an action index.
|
||||
fn action_to_string(self: &TricTracEngine, player_idx: u64, action_idx: u64) -> String;
|
||||
|
||||
/// Both players' scores.
|
||||
fn get_players_scores(self: &TricTracEngine) -> PlayerScores;
|
||||
|
||||
/// 36-element state vector (i8). Mirrored for player_idx == 1.
|
||||
fn get_tensor(self: &TricTracEngine, player_idx: u64) -> Vec<i8>;
|
||||
|
||||
/// Human-readable state description for `player_idx`.
|
||||
fn get_observation_string(self: &TricTracEngine, player_idx: u64) -> String;
|
||||
|
||||
/// Full debug representation of the current state.
|
||||
fn to_debug_string(self: &TricTracEngine) -> String;
|
||||
|
||||
// ── Mutations ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Apply a dice-roll result. Returns Err (C++ exception) if not in
|
||||
/// the RollWaiting stage.
|
||||
fn apply_dice_roll(self: &mut TricTracEngine, dice: DicePair) -> Result<()>;
|
||||
|
||||
/// Apply a player action. Returns Err (C++ exception) if the action
|
||||
/// is not legal in the current state.
|
||||
fn apply_action(self: &mut TricTracEngine, action_idx: u64) -> Result<()>;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Opaque type ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub struct TricTracEngine {
|
||||
game_state: GameState,
|
||||
}
|
||||
|
||||
// ── Free-function constructor (declared in the bridge as a plain function) ────
|
||||
|
||||
pub fn new_trictrac_engine() -> Box<TricTracEngine> {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
game_state
|
||||
.consume(&GameEvent::BeginGame { goes_first: 1 })
|
||||
.expect("BeginGame failed during engine initialization");
|
||||
Box::new(TricTracEngine { game_state })
|
||||
}
|
||||
|
||||
// ── Method implementations ────────────────────────────────────────────────────
|
||||
|
||||
impl TricTracEngine {
|
||||
fn clone_engine(&self) -> Box<TricTracEngine> {
|
||||
Box::new(TricTracEngine {
|
||||
game_state: self.game_state.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn needs_roll(&self) -> bool {
|
||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
||||
}
|
||||
|
||||
fn is_game_ended(&self) -> bool {
|
||||
self.game_state.stage == Stage::Ended
|
||||
}
|
||||
|
||||
fn current_player_idx(&self) -> u64 {
|
||||
self.game_state.active_player_id - 1
|
||||
}
|
||||
|
||||
fn get_legal_actions(&self, player_idx: u64) -> anyhow::Result<Vec<u64>> {
|
||||
if player_idx != self.current_player_idx() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
catch_panics(AssertUnwindSafe(|| {
|
||||
if player_idx == 0 {
|
||||
get_valid_action_indices(&self.game_state)
|
||||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror)
|
||||
.map(|v| v.into_iter().map(|i| i as u64).collect())
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn action_to_string(&self, player_idx: u64, action_idx: u64) -> String {
|
||||
TrictracAction::from_action_index(action_idx as usize)
|
||||
.map(|a| format!("{}:{}", player_idx, a))
|
||||
.unwrap_or_else(|| "unknown action".into())
|
||||
}
|
||||
|
||||
fn get_players_scores(&self) -> ffi::PlayerScores {
|
||||
ffi::PlayerScores {
|
||||
score_p1: self.score_for(1),
|
||||
score_p2: self.score_for(2),
|
||||
}
|
||||
}
|
||||
|
||||
fn score_for(&self, player_id: u64) -> i32 {
|
||||
self.game_state
|
||||
.players
|
||||
.get(&player_id)
|
||||
.map(|p| p.holes as i32 * 12 + p.points as i32)
|
||||
.unwrap_or(-1)
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||
if player_idx == 0 {
|
||||
format!("{}", self.game_state)
|
||||
} else {
|
||||
format!("{}", self.game_state.mirror())
|
||||
}
|
||||
}
|
||||
|
||||
fn to_debug_string(&self) -> String {
|
||||
format!("{}", self.game_state)
|
||||
}
|
||||
|
||||
fn apply_dice_roll(&mut self, dice: ffi::DicePair) -> anyhow::Result<()> {
|
||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||
anyhow::bail!(
|
||||
"apply_dice_roll: not in RollWaiting stage (currently {:?})",
|
||||
self.game_state.turn_stage
|
||||
);
|
||||
}
|
||||
let player_id = self.game_state.active_player_id;
|
||||
let dice = Dice {
|
||||
values: (dice.die1, dice.die2),
|
||||
};
|
||||
catch_panics(AssertUnwindSafe(|| {
|
||||
self.game_state
|
||||
.consume(&GameEvent::RollResult { player_id, dice })
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}))
|
||||
}
|
||||
|
||||
fn apply_action(&mut self, action_idx: u64) -> anyhow::Result<()> {
|
||||
catch_panics(AssertUnwindSafe(|| {
|
||||
let needs_mirror = self.game_state.active_player_id == 2;
|
||||
|
||||
let event = TrictracAction::from_action_index(action_idx as usize).and_then(|a| {
|
||||
let state = if needs_mirror {
|
||||
&self.game_state.mirror()
|
||||
} else {
|
||||
&self.game_state
|
||||
};
|
||||
a.to_event(state)
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
});
|
||||
|
||||
match event {
|
||||
Some(evt) if self.game_state.validate(&evt) => self
|
||||
.game_state
|
||||
.consume(&evt)
|
||||
.map_err(|e| anyhow::anyhow!(e)),
|
||||
Some(evt) => anyhow::bail!(
|
||||
"apply_action: event {:?} is not valid in current state {}",
|
||||
evt,
|
||||
self.game_state
|
||||
),
|
||||
None => anyhow::bail!(
|
||||
"apply_action: could not build event from action index {}",
|
||||
action_idx
|
||||
),
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
use rand::distr::{Distribution, Uniform};
|
||||
use rand::distributions::{Distribution, Uniform};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ impl DiceRoller {
|
|||
pub fn new(opt_seed: Option<u64>) -> Self {
|
||||
Self {
|
||||
rng: match opt_seed {
|
||||
None => StdRng::from_rng(&mut rand::rng()),
|
||||
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
|
||||
Some(seed) => SeedableRng::seed_from_u64(seed),
|
||||
},
|
||||
}
|
||||
|
|
@ -26,7 +26,7 @@ impl DiceRoller {
|
|||
/// Roll the dice, generating two random numbers between 1 and 6, each replicating a perfect
/// die. We use the operating system's random number generator.
|
||||
pub fn roll(&mut self) -> Dice {
|
||||
let between = Uniform::new_inclusive(1, 6).expect("1 > 6 !?");
|
||||
let between = Uniform::new_inclusive(1, 6);
|
||||
|
||||
let v = (between.sample(&mut self.rng), between.sample(&mut self.rng));
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ impl DiceRoller {
|
|||
/// Represents the two dice
|
||||
///
|
||||
/// Trictrac is always played with two dice.
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize, Default)]
|
||||
pub struct Dice {
|
||||
/// The two dice values
|
||||
pub values: (u8, u8),
|
||||
|
|
@ -55,17 +55,6 @@ impl Dice {
|
|||
format!("{:0>3b}{:0>3b}", self.values.0, self.values.1)
|
||||
}
|
||||
|
||||
pub fn from_bits_string(bits: &str) -> Result<Self, String> {
|
||||
if bits.len() != 6 {
|
||||
return Err("Invalid bit string length for dice".to_string());
|
||||
}
|
||||
let d1_str = &bits[0..3];
|
||||
let d2_str = &bits[3..6];
|
||||
let d1 = u8::from_str_radix(d1_str, 2).map_err(|e| e.to_string())?;
|
||||
let d2 = u8::from_str_radix(d2_str, 2).map_err(|e| e.to_string())?;
|
||||
Ok(Dice { values: (d1, d2) })
|
||||
}
|
||||
|
||||
pub fn to_display_string(self) -> String {
|
||||
format!("{} & {}", self.values.0, self.values.1)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,21 +2,19 @@
|
|||
use crate::board::{Board, CheckerMove};
|
||||
use crate::dice::Dice;
|
||||
use crate::game_rules_moves::MoveRules;
|
||||
use crate::game_rules_points::{PointsRules, PossibleJans, PossibleJansMethods};
|
||||
use crate::game_rules_points::{PointsRules, PossibleJans};
|
||||
use crate::player::{Color, Player, PlayerId};
|
||||
// use anyhow::{Context, Result};
|
||||
use log::{debug, error};
|
||||
use log::{debug, error, info};
|
||||
|
||||
// use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{fmt, str};
|
||||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
|
||||
/// The different stages a game can be in. (not to be confused with the entire "GameState")
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum Stage {
|
||||
PreGame,
|
||||
InGame,
|
||||
|
|
@ -24,7 +22,7 @@ pub enum Stage {
|
|||
}
|
||||
|
||||
/// The different stages a game turn can be in.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum TurnStage {
|
||||
RollDice,
|
||||
RollWaiting,
|
||||
|
|
@ -62,7 +60,7 @@ impl From<TurnStage> for u8 {
|
|||
}
|
||||
|
||||
/// Represents a TricTrac game
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct GameState {
|
||||
pub stage: Stage,
|
||||
pub turn_stage: TurnStage,
|
||||
|
|
@ -91,13 +89,7 @@ impl fmt::Display for GameState {
|
|||
self.stage, self.turn_stage
|
||||
));
|
||||
s.push_str(&format!("Dice: {:?}\n", self.dice));
|
||||
let empty_string = String::from("");
|
||||
s.push_str(&format!(
|
||||
"Who plays: {}\n",
|
||||
self.who_plays()
|
||||
.map(|player| &player.name)
|
||||
.unwrap_or_else(|| &empty_string)
|
||||
));
|
||||
// s.push_str(&format!("Who plays: {}\n", self.who_plays().map(|player| &player.name ).unwrap_or("")));
|
||||
s.push_str(&format!("Board: {:?}\n", self.board));
|
||||
// s.push_str(&format!("History: {:?}\n", self.history));
|
||||
write!(f, "{s}")
|
||||
|
|
@ -122,11 +114,6 @@ impl Default for GameState {
|
|||
}
|
||||
}
|
||||
}
|
||||
impl Hash for GameState {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.to_string_id().hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl GameState {
|
||||
/// Create a new default game
|
||||
|
|
@ -136,51 +123,6 @@ impl GameState {
|
|||
gs
|
||||
}
|
||||
|
||||
pub fn new_with_players(p1_name: &str, p2_name: &str) -> Self {
|
||||
let mut game = Self::default();
|
||||
if let Some(p1) = game.init_player(p1_name) {
|
||||
game.init_player(p2_name);
|
||||
let _ = game
|
||||
.consume(&GameEvent::BeginGame { goes_first: p1 })
|
||||
.inspect_err(|e| error!("{}", e));
|
||||
}
|
||||
game
|
||||
}
|
||||
|
||||
pub fn mirror(&self) -> GameState {
|
||||
let mirrored_active_player = if self.active_player_id == 1 { 2 } else { 1 };
|
||||
let mut mirrored_players = HashMap::new();
|
||||
if let Some(p2) = self.players.get(&2) {
|
||||
mirrored_players.insert(1, p2.mirror());
|
||||
}
|
||||
if let Some(p1) = self.players.get(&1) {
|
||||
mirrored_players.insert(2, p1.mirror());
|
||||
}
|
||||
let mirrored_history = self
|
||||
.history
|
||||
.clone()
|
||||
.iter()
|
||||
.map(|evt| evt.get_mirror(false))
|
||||
.collect();
|
||||
|
||||
let (move1, move2) = self.dice_moves;
|
||||
GameState {
|
||||
stage: self.stage,
|
||||
turn_stage: self.turn_stage,
|
||||
board: self.board.mirror(),
|
||||
active_player_id: mirrored_active_player,
|
||||
// active_player_id: self.active_player_id,
|
||||
players: mirrored_players,
|
||||
history: mirrored_history,
|
||||
dice: self.dice,
|
||||
dice_points: self.dice_points,
|
||||
dice_moves: (move1.mirror(), move2.mirror()),
|
||||
dice_jans: self.dice_jans.mirror(),
|
||||
roll_first: self.roll_first,
|
||||
schools_enabled: self.schools_enabled,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_schools_enabled(&mut self, schools_enabled: bool) {
|
||||
self.schools_enabled = schools_enabled;
|
||||
}
|
||||
|
|
@ -237,13 +179,20 @@ impl GameState {
|
|||
// points, trous, bredouille, grande bredouille length=4 x2 joueurs = 8
|
||||
let white_player: Vec<i8> = self
|
||||
.get_white_player()
|
||||
.map(|p| p.to_vec().iter().map(|&x| x as i8).collect())
|
||||
.unwrap_or(vec![0; 10]);
|
||||
.unwrap()
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|&x| x as i8)
|
||||
.collect();
|
||||
state.extend(white_player);
|
||||
let black_player: Vec<i8> = self
|
||||
.get_black_player()
|
||||
.map(|p| p.to_vec().iter().map(|&x| x as i8).collect())
|
||||
.unwrap_or(vec![0; 10]);
|
||||
.unwrap()
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|&x| x as i8)
|
||||
.collect();
|
||||
// .iter().map(|&x| x as i8) .collect()
|
||||
state.extend(black_player);
|
||||
|
||||
// ensure state has length state_len
|
||||
|
|
@ -255,7 +204,7 @@ impl GameState {
|
|||
}
|
||||
|
||||
/// Calculate game state id :
|
||||
pub fn to_string_id_slow(&self) -> String {
|
||||
pub fn to_string_id(&self) -> String {
|
||||
// Pieces placement -> 77 bits (24 + 23 + 30 max)
|
||||
let mut pos_bits = self.board.to_gnupg_pos_id();
|
||||
|
||||
|
|
@ -290,218 +239,22 @@ impl GameState {
|
|||
pos_bits.push_str(&dice_bits);
|
||||
|
||||
// points 10bits x2 joueurs = 20bits
|
||||
let white_bits = self
|
||||
.get_white_player()
|
||||
.map(|p| p.to_bits_string())
|
||||
.unwrap_or("0000000000".into());
|
||||
let black_bits = self
|
||||
.get_black_player()
|
||||
.map(|p| p.to_bits_string())
|
||||
.unwrap_or("0000000000".into());
|
||||
let white_bits = self.get_white_player().unwrap().to_bits_string();
|
||||
let black_bits = self.get_black_player().unwrap().to_bits_string();
|
||||
pos_bits.push_str(&white_bits);
|
||||
pos_bits.push_str(&black_bits);
|
||||
|
||||
pos_bits = format!("{pos_bits:0<108}");
|
||||
pos_bits = format!("{:0>108}", pos_bits);
|
||||
// println!("{}", pos_bits);
|
||||
// let pos_u8 = pos_bits
|
||||
// .as_bytes()
|
||||
// .chunks(6)
|
||||
// .map(|chunk| str::from_utf8(chunk).unwrap())
|
||||
// .map(|chunk| u8::from_str_radix(chunk, 2).unwrap())
|
||||
// .collect::<Vec<u8>>();
|
||||
|
||||
let pos_u8 = pos_bits
|
||||
.as_bytes()
|
||||
.chunks(6)
|
||||
.map(|chunk| chunk.iter().fold(0u8, |acc, &b| (acc << 1) | (b - b'0')))
|
||||
.map(|chunk| str::from_utf8(chunk).unwrap())
|
||||
.map(|chunk| u8::from_str_radix(chunk, 2).unwrap())
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
general_purpose::STANDARD.encode(pos_u8)
|
||||
}
|
||||
|
||||
pub fn to_string_id(&self) -> String {
|
||||
const TOTAL_BITS: usize = 108;
|
||||
const TOTAL_BYTES: usize = TOTAL_BITS / 6; // 18 bytes
|
||||
|
||||
let mut output = Vec::with_capacity(TOTAL_BYTES);
|
||||
|
||||
let mut current: u8 = 0;
|
||||
let mut bit_count: u8 = 0;
|
||||
|
||||
// helper to push a single bit
|
||||
let push_bit = |bit: u8, output: &mut Vec<u8>, current: &mut u8, bit_count: &mut u8| {
|
||||
*current = (*current << 1) | (bit & 1);
|
||||
*bit_count += 1;
|
||||
|
||||
if *bit_count == 6 {
|
||||
output.push(*current);
|
||||
*current = 0;
|
||||
*bit_count = 0;
|
||||
}
|
||||
};
|
||||
|
||||
// helper to push a string of '0'/'1'
|
||||
let push_bits_str =
|
||||
|bits: &str, output: &mut Vec<u8>, current: &mut u8, bit_count: &mut u8| {
|
||||
for b in bits.bytes() {
|
||||
push_bit(b - b'0', output, current, bit_count);
|
||||
}
|
||||
};
|
||||
|
||||
// --------------------------------------------------
|
||||
// 1️⃣ Board position bits
|
||||
// --------------------------------------------------
|
||||
push_bits_str(
|
||||
&self.board.to_gnupg_pos_id(),
|
||||
&mut output,
|
||||
&mut current,
|
||||
&mut bit_count,
|
||||
);
|
||||
|
||||
// --------------------------------------------------
|
||||
// 2️⃣ Active player (1 bit)
|
||||
// --------------------------------------------------
|
||||
let active_bit = self
|
||||
.who_plays()
|
||||
.map(|player| (player.color == Color::Black) as u8)
|
||||
.unwrap_or(0);
|
||||
|
||||
push_bit(active_bit, &mut output, &mut current, &mut bit_count);
|
||||
|
||||
// --------------------------------------------------
|
||||
// 3️⃣ Turn stage (3 bits)
|
||||
// --------------------------------------------------
|
||||
let stage_bits: u8 = match self.turn_stage {
|
||||
TurnStage::RollWaiting => 0b000,
|
||||
TurnStage::RollDice => 0b001,
|
||||
TurnStage::MarkPoints => 0b010,
|
||||
TurnStage::HoldOrGoChoice => 0b011,
|
||||
TurnStage::Move => 0b100,
|
||||
TurnStage::MarkAdvPoints => 0b101,
|
||||
};
|
||||
|
||||
for i in (0..3).rev() {
|
||||
push_bit(
|
||||
(stage_bits >> i) & 1,
|
||||
&mut output,
|
||||
&mut current,
|
||||
&mut bit_count,
|
||||
);
|
||||
}
|
||||
|
||||
// --------------------------------------------------
|
||||
// 4️⃣ Dice (6 bits)
|
||||
// --------------------------------------------------
|
||||
push_bits_str(
|
||||
&self.dice.to_bits_string(),
|
||||
&mut output,
|
||||
&mut current,
|
||||
&mut bit_count,
|
||||
);
|
||||
|
||||
// --------------------------------------------------
|
||||
// 5️⃣ Players points (10 bits each)
|
||||
// --------------------------------------------------
|
||||
let white_bits = self
|
||||
.get_white_player()
|
||||
.map(|p| p.to_bits_string())
|
||||
.unwrap_or_else(|| "0000000000".to_string());
|
||||
|
||||
let black_bits = self
|
||||
.get_black_player()
|
||||
.map(|p| p.to_bits_string())
|
||||
.unwrap_or_else(|| "0000000000".to_string());
|
||||
|
||||
push_bits_str(&white_bits, &mut output, &mut current, &mut bit_count);
|
||||
push_bits_str(&black_bits, &mut output, &mut current, &mut bit_count);
|
||||
|
||||
// --------------------------------------------------
|
||||
// 6️⃣ Pad remaining bits (if needed)
|
||||
// --------------------------------------------------
|
||||
while output.len() < TOTAL_BYTES {
|
||||
push_bit(0, &mut output, &mut current, &mut bit_count);
|
||||
}
|
||||
|
||||
base64::engine::general_purpose::STANDARD.encode(output)
|
||||
}
|
||||
|
||||
pub fn from_string_id(id: &str) -> Result<Self, String> {
|
||||
let bytes = general_purpose::STANDARD
|
||||
.decode(id)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let bits_str: String = bytes.iter().map(|byte| format!("{:06b}", byte)).collect();
|
||||
|
||||
// The original string was padded to 108 bits.
|
||||
let bits = if bits_str.len() >= 108 {
|
||||
&bits_str[..108]
|
||||
} else {
|
||||
return Err("Invalid decoded string length".to_string());
|
||||
};
|
||||
|
||||
let board_bits = &bits[0..77];
|
||||
let board = Board::from_gnupg_pos_id(board_bits)?;
|
||||
|
||||
let Some(active_player_bit) = bits.chars().nth(77) else {
|
||||
return Err("No bit at 77th position".to_string());
|
||||
};
|
||||
let active_player_color = if active_player_bit == '1' {
|
||||
Color::Black
|
||||
} else {
|
||||
Color::White
|
||||
};
|
||||
|
||||
let turn_stage_bits = &bits[78..81];
|
||||
let turn_stage = match turn_stage_bits {
|
||||
"000" => TurnStage::RollWaiting,
|
||||
"001" => TurnStage::RollDice,
|
||||
"010" => TurnStage::MarkPoints,
|
||||
"011" => TurnStage::HoldOrGoChoice,
|
||||
"100" => TurnStage::Move,
|
||||
"101" => TurnStage::MarkAdvPoints,
|
||||
_ => return Err(format!("Invalid bits for turn stage : {turn_stage_bits}")),
|
||||
};
|
||||
|
||||
let dice_bits = &bits[81..87];
|
||||
let dice = Dice::from_bits_string(dice_bits).map_err(|e| e.to_string())?;
|
||||
|
||||
let white_player_bits = &bits[87..97];
|
||||
let black_player_bits = &bits[97..107];
|
||||
|
||||
let white_player =
|
||||
Player::from_bits_string(white_player_bits, "Player 1".to_string(), Color::White)
|
||||
.map_err(|e| e.to_string())?;
|
||||
let black_player =
|
||||
Player::from_bits_string(black_player_bits, "Player 2".to_string(), Color::Black)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let mut players = HashMap::new();
|
||||
players.insert(1, white_player);
|
||||
players.insert(2, black_player);
|
||||
|
||||
let active_player_id = if active_player_color == Color::White {
|
||||
1
|
||||
} else {
|
||||
2
|
||||
};
|
||||
|
||||
// Some fields are not in the ID, so we use defaults.
|
||||
Ok(GameState {
|
||||
stage: Stage::InGame, // Assume InGame from ID
|
||||
turn_stage,
|
||||
board,
|
||||
active_player_id,
|
||||
players,
|
||||
history: Vec::new(),
|
||||
dice,
|
||||
dice_points: (0, 0),
|
||||
dice_moves: (CheckerMove::default(), CheckerMove::default()),
|
||||
dice_jans: PossibleJans::default(),
|
||||
roll_first: false, // Assume not first roll
|
||||
schools_enabled: false, // Assume disabled
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the player whose turn it is, as reported by `get_active_player()`
/// (`None` when that lookup finds no player).
pub fn who_plays(&self) -> Option<&Player> {
|
||||
self.get_active_player()
|
||||
}
|
||||
|
|
@ -585,24 +338,7 @@ impl GameState {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
Roll { player_id } => {
|
||||
// Check player exists
|
||||
if !self.players.contains_key(player_id) {
|
||||
error!("unknown player_id");
|
||||
return false;
|
||||
}
|
||||
// Check player is currently the one making their move
|
||||
if self.active_player_id != *player_id {
|
||||
error!("not active player_id");
|
||||
return false;
|
||||
}
|
||||
// Check the turn stage
|
||||
if self.turn_stage != TurnStage::RollDice {
|
||||
error!("bad stage {:?}", self.turn_stage);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
RollResult { player_id, dice: _ } => {
|
||||
Roll { player_id } | RollResult { player_id, dice: _ } => {
|
||||
// Check player exists
|
||||
if !self.players.contains_key(player_id) {
|
||||
return false;
|
||||
|
|
@ -611,11 +347,6 @@ impl GameState {
|
|||
if self.active_player_id != *player_id {
|
||||
return false;
|
||||
}
|
||||
// Check the turn stage
|
||||
if self.turn_stage != TurnStage::RollWaiting {
|
||||
error!("bad stage {:?}", self.turn_stage);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Mark {
|
||||
player_id,
|
||||
|
|
@ -690,7 +421,6 @@ impl GameState {
|
|||
*moves
|
||||
};
|
||||
if !rules.moves_follow_rules(&moves) {
|
||||
// println!(">>> rules not followed ");
|
||||
error!("rules not followed ");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -710,7 +440,7 @@ impl GameState {
|
|||
|
||||
pub fn init_player(&mut self, player_name: &str) -> Option<PlayerId> {
|
||||
if self.players.len() > 2 {
|
||||
// println!("more than two players");
|
||||
println!("more than two players");
|
||||
return None;
|
||||
}
|
||||
|
||||
|
|
@ -739,14 +469,23 @@ impl GameState {
|
|||
.next();
|
||||
self.active_player_id = other_player_id.unwrap_or(0);
|
||||
}
|
||||
|
||||
/// Consumes an event, modifying the GameState and adding the event to its history
|
||||
/// NOTE: consume assumes the event to have already been validated and will accept *any* event passed to it
|
||||
pub fn consume(&mut self, valid_event: &GameEvent) -> Result<(), String> {
|
||||
pub fn consume(&mut self, valid_event: &GameEvent) {
|
||||
use GameEvent::*;
|
||||
match valid_event {
|
||||
BeginGame { goes_first } => {
|
||||
self.active_player_id = *goes_first;
|
||||
// if self.who_plays().is_none() {
|
||||
// let active_color = match self.dice.coin() {
|
||||
// false => Color::Black,
|
||||
// true => Color::White,
|
||||
// };
|
||||
// let color_player_id = self.player_id_by_color(active_color);
|
||||
// if color_player_id.is_some() {
|
||||
// self.active_player_id = *color_player_id.unwrap();
|
||||
// }
|
||||
// }
|
||||
self.stage = Stage::InGame;
|
||||
self.turn_stage = TurnStage::RollDice;
|
||||
}
|
||||
|
|
@ -782,16 +521,14 @@ impl GameState {
|
|||
self.dice = *dice;
|
||||
self.inc_roll_count(self.active_player_id);
|
||||
self.turn_stage = TurnStage::MarkPoints;
|
||||
(self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice)?;
|
||||
(self.dice_jans, self.dice_points) = self.get_rollresult_jans(dice);
|
||||
debug!("points from result : {:?}", self.dice_points);
|
||||
if !self.schools_enabled {
|
||||
// Schools are not enabled. We mark points automatically
|
||||
// the points earned by the opponent will be marked on its turn
|
||||
let new_hole = self.mark_points(self.active_player_id, self.dice_points.0);
|
||||
if new_hole {
|
||||
let Some(holes_count) = self.get_active_player().map(|p| p.holes) else {
|
||||
return Err("No active player".into());
|
||||
};
|
||||
let holes_count = self.get_active_player().unwrap().holes;
|
||||
debug!("new hole -> {holes_count:?}");
|
||||
if holes_count > 12 {
|
||||
self.stage = Stage::Ended;
|
||||
|
|
@ -807,10 +544,7 @@ impl GameState {
|
|||
if self.schools_enabled {
|
||||
let new_hole = self.mark_points(*player_id, *points);
|
||||
if new_hole {
|
||||
let Some(holes) = self.get_active_player().map(|p| p.holes) else {
|
||||
return Err("No active player".into());
|
||||
};
|
||||
if holes > 12 {
|
||||
if self.get_active_player().unwrap().holes > 12 {
|
||||
self.stage = Stage::Ended;
|
||||
} else {
|
||||
self.turn_stage = if self.turn_stage == TurnStage::MarkAdvPoints {
|
||||
|
|
@ -830,26 +564,17 @@ impl GameState {
|
|||
}
|
||||
Go { player_id: _ } => self.new_pick_up(),
|
||||
Move { player_id, moves } => {
|
||||
let Some(player) = self.players.get(player_id) else {
|
||||
return Err("unknown player {player_id}".into());
|
||||
};
|
||||
self.board
|
||||
.move_checker(&player.color, moves.0)
|
||||
.map_err(|e| e.to_string())?;
|
||||
self.board
|
||||
.move_checker(&player.color, moves.1)
|
||||
.map_err(|e| e.to_string())?;
|
||||
let player = self.players.get(player_id).unwrap();
|
||||
self.board.move_checker(&player.color, moves.0).unwrap();
|
||||
self.board.move_checker(&player.color, moves.1).unwrap();
|
||||
self.dice_moves = *moves;
|
||||
let Some(active_player_id) = self.players.keys().find(|id| *id != player_id) else {
|
||||
return Err("Can't find player id {id}".into());
|
||||
};
|
||||
self.active_player_id = *active_player_id;
|
||||
self.active_player_id = *self.players.keys().find(|id| *id != player_id).unwrap();
|
||||
self.turn_stage = if self.schools_enabled {
|
||||
TurnStage::MarkAdvPoints
|
||||
} else {
|
||||
// The player has moved, we can mark its opponent's points (which is now the current player)
|
||||
let new_hole = self.mark_points(self.active_player_id, self.dice_points.1);
|
||||
if new_hole && self.get_active_player().map(|p| p.holes).unwrap_or(0) > 12 {
|
||||
if new_hole && self.get_active_player().unwrap().holes > 12 {
|
||||
self.stage = Stage::Ended;
|
||||
}
|
||||
TurnStage::RollDice
|
||||
|
|
@ -858,7 +583,6 @@ impl GameState {
|
|||
PlayError => {}
|
||||
}
|
||||
self.history.push(valid_event.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a new pick up ('relevé') after a player won a hole and choose to 'go',
|
||||
|
|
@ -881,16 +605,14 @@ impl GameState {
|
|||
self.board = Board::new();
|
||||
}
|
||||
|
||||
fn get_rollresult_jans(&self, dice: &Dice) -> Result<(PossibleJans, (u8, u8)), String> {
|
||||
let Some(player) = &self.players.get(&self.active_player_id) else {
|
||||
return Err("No active player".into());
|
||||
};
|
||||
fn get_rollresult_jans(&self, dice: &Dice) -> (PossibleJans, (u8, u8)) {
|
||||
let player = &self.players.get(&self.active_player_id).unwrap();
|
||||
debug!(
|
||||
"get rollresult for {:?} {:?} {:?} (roll count {:?})",
|
||||
player.color, self.board, dice, player.dice_roll_count
|
||||
);
|
||||
let points_rules = PointsRules::new(&player.color, &self.board, *dice);
|
||||
Ok(points_rules.get_result_jans(player.dice_roll_count))
|
||||
points_rules.get_result_jans(player.dice_roll_count)
|
||||
}
|
||||
|
||||
/// Determines if someone has won the game
|
||||
|
|
@ -905,15 +627,13 @@ impl GameState {
|
|||
|
||||
fn inc_roll_count(&mut self, player_id: PlayerId) {
|
||||
self.players.get_mut(&player_id).map(|p| {
|
||||
p.dice_roll_count = p.dice_roll_count.saturating_add(1);
|
||||
if p.dice_roll_count < u8::MAX {
|
||||
p.dice_roll_count += 1;
|
||||
}
|
||||
p
|
||||
});
|
||||
}
|
||||
|
||||
/// Public entry point to the private `mark_points`: forwards `player_id` and
/// `points` unchanged and returns the boolean `mark_points` reports (callers
/// in this file bind it to `new_hole`).
pub fn mark_points_for_bot_training(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
self.mark_points(player_id, points)
|
||||
}
|
||||
|
||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||
// Update player points and holes
|
||||
let mut new_hole = false;
|
||||
|
|
@ -969,14 +689,14 @@ impl GameState {
|
|||
}
|
||||
|
||||
/// The reasons why a game could end
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
|
||||
pub enum EndGameReason {
|
||||
PlayerLeft { player_id: PlayerId },
|
||||
PlayerWon { winner: PlayerId },
|
||||
}
|
||||
|
||||
/// An event that progresses the GameState forward
|
||||
#[derive(Debug, Clone, Serialize, PartialEq, Eq, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, PartialEq, Deserialize)]
|
||||
pub enum GameEvent {
|
||||
BeginGame {
|
||||
goes_first: PlayerId,
|
||||
|
|
@ -1032,12 +752,10 @@ impl GameEvent {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get_mirror(&self, preserve_player: bool) -> Self {
|
||||
pub fn get_mirror(&self) -> Self {
|
||||
// let mut mirror = self.clone();
|
||||
let mirror_player_id = if let Some(player_id) = self.player_id() {
|
||||
if preserve_player {
|
||||
player_id
|
||||
} else if player_id == 1 {
|
||||
if player_id == 1 {
|
||||
2
|
||||
} else {
|
||||
1
|
||||
|
|
@ -1105,16 +823,7 @@ mod tests {
|
|||
let state = init_test_gamestate(TurnStage::RollDice);
|
||||
let string_id = state.to_string_id();
|
||||
// println!("string_id : {}", string_id);
|
||||
assert_eq!(string_id, "Pz84AAAABz8/AAAAAAgAASAG");
|
||||
let new_state = GameState::from_string_id(&string_id).unwrap();
|
||||
assert_eq!(state.board, new_state.board);
|
||||
assert_eq!(state.active_player_id, new_state.active_player_id);
|
||||
assert_eq!(state.turn_stage, new_state.turn_stage);
|
||||
assert_eq!(state.dice, new_state.dice);
|
||||
assert_eq!(
|
||||
state.get_white_player().unwrap().points,
|
||||
new_state.get_white_player().unwrap().points
|
||||
);
|
||||
assert_eq!(string_id, "Hz88AAAAAz8/IAAAAAQAADAD");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1122,7 +831,7 @@ mod tests {
|
|||
let mut game_state = init_test_gamestate(TurnStage::MarkPoints);
|
||||
game_state.schools_enabled = true;
|
||||
let pid = game_state.active_player_id;
|
||||
let _ = game_state.consume(
|
||||
game_state.consume(
|
||||
&(GameEvent::Mark {
|
||||
player_id: pid,
|
||||
points: 13,
|
||||
|
|
@ -1134,7 +843,7 @@ mod tests {
|
|||
assert_eq!(game_state.turn_stage, TurnStage::HoldOrGoChoice);
|
||||
|
||||
// Go
|
||||
let _ = game_state.consume(
|
||||
game_state.consume(
|
||||
&(GameEvent::Go {
|
||||
player_id: game_state.active_player_id,
|
||||
}),
|
||||
|
|
@ -1148,7 +857,7 @@ mod tests {
|
|||
let mut game_state = init_test_gamestate(TurnStage::MarkPoints);
|
||||
game_state.schools_enabled = true;
|
||||
let pid = game_state.active_player_id;
|
||||
let _ = game_state.consume(
|
||||
game_state.consume(
|
||||
&(GameEvent::Mark {
|
||||
player_id: pid,
|
||||
points: 13,
|
||||
|
|
@ -1158,7 +867,7 @@ mod tests {
|
|||
CheckerMove::new(1, 3).unwrap(),
|
||||
CheckerMove::new(1, 3).unwrap(),
|
||||
);
|
||||
let _ = game_state.consume(
|
||||
game_state.consume(
|
||||
&(GameEvent::Move {
|
||||
player_id: game_state.active_player_id,
|
||||
moves,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ use crate::game::GameState;
|
|||
use crate::player::Color;
|
||||
use log::info;
|
||||
use std::cmp;
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(std::cmp::PartialEq, Debug)]
|
||||
pub enum MoveError {
|
||||
|
|
@ -23,7 +22,7 @@ pub enum MoveError {
|
|||
// sans nombre en excédant est possible
|
||||
ExitByEffectPossible,
|
||||
// Sortie avec nombre en excédant d'une dame qui n'est pas la plus éloignée
|
||||
ExitNotFarthest,
|
||||
ExitNotFasthest,
|
||||
// Jeu dans un cadran que l'adversaire peut encore remplir
|
||||
OpponentCanFillQuarter,
|
||||
// remplir cadran si possible & conserver cadran rempli si possible ----
|
||||
|
|
@ -82,8 +81,7 @@ impl MoveRules {
|
|||
let is_allowed = self.moves_allowed(moves);
|
||||
// let is_allowed = self.moves_allowed(moves, ignored_rules);
|
||||
if is_allowed.is_err() {
|
||||
println!("Move not allowed : {:?}", is_allowed.unwrap_err());
|
||||
// info!("Move not allowed : {:?}", is_allowed.unwrap_err());
|
||||
info!("Move not allowed : {:?}", is_allowed.unwrap_err());
|
||||
false
|
||||
} else {
|
||||
true
|
||||
|
|
@ -101,10 +99,6 @@ impl MoveRules {
|
|||
if let Ok((field_count, Some(field_color))) = self.board.get_field_checkers(move0_from)
|
||||
{
|
||||
if color != field_color || field_count < 2 {
|
||||
println!(
|
||||
"Move not physically possible 1. field_color {:?}, count {}",
|
||||
field_color, field_count
|
||||
);
|
||||
info!("Move not physically possible");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -116,7 +110,6 @@ impl MoveRules {
|
|||
if !self.board.passage_possible(color, &moves.0)
|
||||
|| !self.board.move_possible(color, &chained_move)
|
||||
{
|
||||
println!("Tout d'une : Move not physically possible");
|
||||
info!("Tout d'une : Move not physically possible");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -124,11 +117,6 @@ impl MoveRules {
|
|||
|| !self.board.move_possible(color, &moves.1)
|
||||
{
|
||||
// Move is not physically possible
|
||||
println!("Move not physically possible 2");
|
||||
println!(
|
||||
"board: {}, color: {:?} move: {:?}",
|
||||
self.board, color, moves
|
||||
);
|
||||
info!("Move not physically possible");
|
||||
return false;
|
||||
}
|
||||
|
|
@ -271,9 +259,6 @@ impl MoveRules {
|
|||
) -> Result<(), MoveError> {
|
||||
let filling_moves_sequences = self.get_quarter_filling_moves_sequences();
|
||||
if !filling_moves_sequences.contains(moves) && !filling_moves_sequences.is_empty() {
|
||||
if *moves == (CheckerMove::default(), CheckerMove::default()) {
|
||||
println!("filling moves sequences: {:?}", filling_moves_sequences);
|
||||
}
|
||||
return Err(MoveError::MustFillQuarter);
|
||||
}
|
||||
Ok(())
|
||||
|
|
@ -344,87 +329,43 @@ impl MoveRules {
|
|||
return Err(MoveError::ExitByEffectPossible);
|
||||
}
|
||||
|
||||
// The chosen checker must be the farthest from exit
|
||||
// For chained moves (tout d'une), we need to check the board state AFTER the first move
|
||||
// let board_to_check = if moves.0.get_to() == moves.1.get_from() {
|
||||
// // Chained move: apply first move to get the board state
|
||||
// let mut board_copy = self.board.clone();
|
||||
// let _ = board_copy.move_checker(&Color::White, moves.0);
|
||||
// board_copy
|
||||
// } else {
|
||||
// self.board.clone()
|
||||
// };
|
||||
|
||||
let mut board_to_check = self.board.clone();
|
||||
let farthest_on_move1 = Self::get_board_exit_farthest(&board_to_check);
|
||||
|
||||
let _ = board_to_check.move_checker(&Color::White, moves.0);
|
||||
let farthest_on_move2 = Self::get_board_exit_farthest(&board_to_check);
|
||||
|
||||
let (is_move1_exedant, is_move2_exedant) = self.move_excedants(moves);
|
||||
if (is_move1_exedant && moves.0.get_from() != farthest_on_move1)
|
||||
|| (is_move2_exedant && moves.1.get_from() != farthest_on_move2)
|
||||
{
|
||||
return Err(MoveError::ExitNotFarthest);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Tells, for each of the two moves, whether it is played "en excédant",
/// i.e. whether it covers a non-zero distance shorter than the die it is
/// matched with.
///
/// The shorter move is compared with the lower die and the longer move with
/// the higher die; a destination of 0 (exit) is counted as field 25.
/// The returned pair follows the order of `moves`.
fn move_excedants(&self, moves: &(CheckerMove, CheckerMove)) -> (bool, bool) {
    // Distance covered by a move, with the exit (destination 0) at field 25.
    let travelled = |mv: &CheckerMove| {
        let destination = match mv.get_to() {
            0 => 25,
            field => field,
        };
        destination - mv.get_from()
    };
    let first_dist = travelled(&moves.0);
    let second_dist = travelled(&moves.1);

    let (die_a, die_b) = self.dice.values;
    let (low_die, high_die) = if die_a <= die_b {
        (die_a as usize, die_b as usize)
    } else {
        (die_b as usize, die_a as usize)
    };

    // A move is "en excédant" when it is not empty and shorter than its die.
    let is_excedant = |dist: usize, die: usize| dist != 0 && dist < die;

    if first_dist <= second_dist {
        // First move is the shorter one: it goes with the lower die.
        (
            is_excedant(first_dist, low_die),
            is_excedant(second_dist, high_die),
        )
    } else {
        (
            is_excedant(first_dist, high_die),
            is_excedant(second_dist, low_die),
        )
    }
}
|
||||
|
||||
fn get_board_exit_farthest(board: &Board) -> Field {
|
||||
let mut checkers = board.get_color_fields(Color::White);
|
||||
checkers.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
let mut farthest = 25;
|
||||
// Check if we have a filled quarter that must be preserved
|
||||
let has_filled_quarter = board.any_quarter_filled(Color::White);
|
||||
if has_filled_quarter {
|
||||
// When a quarter is filled, we can only exit from fields with >2 checkers
|
||||
// Find the farthest field with >2 checkers (removing one won't break the quarter)
|
||||
let available_checkers: Vec<_> =
|
||||
checkers.iter().filter(|(_, count)| *count > 2).collect();
|
||||
if !available_checkers.is_empty() {
|
||||
// Use the farthest available checker (that won't break the quarter)
|
||||
farthest = available_checkers[0].0;
|
||||
}
|
||||
}
|
||||
if farthest == 25 {
|
||||
if let Some((field, _)) = checkers.first() {
|
||||
// - la dame choisie doit être la plus éloignée de la sortie
|
||||
let mut checkers = self.board.get_color_fields(Color::White);
|
||||
checkers.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
let mut farthest = 24;
|
||||
let mut next_farthest = 24;
|
||||
let mut has_two_checkers = false;
|
||||
if let Some((field, count)) = checkers.first() {
|
||||
farthest = *field;
|
||||
if *count > 1 {
|
||||
next_farthest = *field;
|
||||
has_two_checkers = true;
|
||||
} else if let Some((field, _count)) = checkers.get(1) {
|
||||
next_farthest = *field;
|
||||
has_two_checkers = true;
|
||||
}
|
||||
}
|
||||
farthest
|
||||
|
||||
// s'il reste au moins deux dames, on vérifie que les plus éloignées soint choisies
|
||||
if has_two_checkers {
|
||||
if moves.0.get_to() == 0 && moves.1.get_to() == 0 {
|
||||
// Deux coups sortants en excédant
|
||||
if cmp::max(moves.0.get_from(), moves.1.get_from()) > next_farthest {
|
||||
return Err(MoveError::ExitNotFasthest);
|
||||
}
|
||||
} else {
|
||||
// Un seul coup sortant en excédant le coup sortant doit concerner la plus éloignée du bord
|
||||
let exit_move_field = if moves.0.get_to() == 0 {
|
||||
moves.0.get_from()
|
||||
} else {
|
||||
moves.1.get_from()
|
||||
};
|
||||
if exit_move_field != farthest {
|
||||
return Err(MoveError::ExitNotFasthest);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_possible_moves_sequences(
|
||||
|
|
@ -462,11 +403,6 @@ impl MoveRules {
|
|||
if empty_removed.count() > 0 {
|
||||
moves_seqs.retain(|(c1, c2)| *c1 != EMPTY_MOVE && *c2 != EMPTY_MOVE);
|
||||
}
|
||||
|
||||
// deduplicate
|
||||
let mut set = HashSet::new();
|
||||
moves_seqs.retain(|x| set.insert(*x));
|
||||
|
||||
moves_seqs
|
||||
}
|
||||
|
||||
|
|
@ -547,7 +483,6 @@ impl MoveRules {
|
|||
let mut moves_seqs = Vec::new();
|
||||
let color = &Color::White;
|
||||
let forbid_exits = self.has_checkers_outside_last_quarter();
|
||||
// println!("==== First");
|
||||
for first_move in
|
||||
self.board
|
||||
.get_possible_moves(*color, dice1, with_excedents, false, forbid_exits)
|
||||
|
|
@ -561,59 +496,23 @@ impl MoveRules {
|
|||
// XXX : the goal here is to replicate moves_allowed() checks without using get_possible_moves_sequences to
|
||||
// avoid an infinite loop...
|
||||
let mut has_second_dice_move = false;
|
||||
// println!(" ==== Second");
|
||||
for second_move in
|
||||
board2.get_possible_moves(*color, dice2, with_excedents, true, forbid_exits)
|
||||
{
|
||||
if self
|
||||
.check_corner_rules(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd (corner rule): {:?} - {:?}, {:?}",
|
||||
// e, first_move, second_move
|
||||
// )
|
||||
// })
|
||||
.is_ok()
|
||||
if self.check_corner_rules(&(first_move, second_move)).is_ok()
|
||||
&& self
|
||||
.check_opponent_can_fill_quarter_rule(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd (op fill quarter): {:?} - {:?}, {:?}",
|
||||
// e, first_move, second_move
|
||||
// )
|
||||
// })
|
||||
.is_ok()
|
||||
&& !(self.is_move_by_puissance(&(first_move, second_move))
|
||||
&& self.can_take_corner_by_effect())
|
||||
&& (ignored_rules.contains(&TricTracRule::Exit)
|
||||
|| self
|
||||
.check_exit_rules(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd (exit rule): {:?} - {:?}, {:?}",
|
||||
// e, first_move, second_move
|
||||
// )
|
||||
// })
|
||||
.is_ok())
|
||||
|| self.check_exit_rules(&(first_move, second_move)).is_ok())
|
||||
&& (ignored_rules.contains(&TricTracRule::MustFillQuarter)
|
||||
|| self
|
||||
.check_must_fill_quarter_rule(&(first_move, second_move))
|
||||
// .inspect_err(|e| {
|
||||
// println!(
|
||||
// " 2nd: {:?} - {:?}, {:?} for {:?}",
|
||||
// e, first_move, second_move, self.board
|
||||
// )
|
||||
// })
|
||||
.is_ok())
|
||||
{
|
||||
if second_move.get_to() == 0
|
||||
&& first_move.get_to() == 0
|
||||
&& second_move.get_from() < first_move.get_from()
|
||||
{
|
||||
moves_seqs.push((second_move, first_move));
|
||||
} else {
|
||||
moves_seqs.push((first_move, second_move));
|
||||
}
|
||||
has_second_dice_move = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -864,7 +763,7 @@ mod tests {
|
|||
CheckerMove::new(20, 0).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
);
|
||||
assert_eq!(Err(MoveError::ExitNotFarthest), state.moves_allowed(&moves));
|
||||
assert_eq!(Err(MoveError::ExitNotFasthest), state.moves_allowed(&moves));
|
||||
let moves = (
|
||||
CheckerMove::new(20, 0).unwrap(),
|
||||
CheckerMove::new(21, 0).unwrap(),
|
||||
|
|
@ -885,19 +784,6 @@ mod tests {
|
|||
);
|
||||
assert!(state.moves_follows_dices(&moves));
|
||||
assert!(state.moves_allowed(&moves).is_ok());
|
||||
|
||||
state.board.set_positions(
|
||||
&Color::White,
|
||||
[
|
||||
-5, -2, -2, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 2, 3, 2, 3,
|
||||
],
|
||||
);
|
||||
state.dice.values = (4, 5);
|
||||
let moves = (
|
||||
CheckerMove::new(19, 24).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
);
|
||||
assert!(state.moves_allowed(&moves).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1407,202 +1293,5 @@ mod tests {
|
|||
vec![moves],
|
||||
state.get_possible_moves_sequences(true, vec![])
|
||||
);
|
||||
|
||||
state.board.set_positions(
|
||||
&Color::White,
|
||||
[
|
||||
-8, -4, -1, 0, 0, 0, 0, -1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 2, 2, 2,
|
||||
],
|
||||
);
|
||||
state.dice.values = (1, 4);
|
||||
let moves = (
|
||||
CheckerMove::new(21, 22).unwrap(),
|
||||
CheckerMove::new(22, 0).unwrap(),
|
||||
);
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences(true, vec![])
|
||||
);
|
||||
|
||||
state.dice.values = (5, 3);
|
||||
state.board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
-8, -3, -1, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 3, 3,
|
||||
],
|
||||
);
|
||||
let moves = (
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
CheckerMove::new(24, 0).unwrap(),
|
||||
);
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences(true, vec![])
|
||||
);
|
||||
|
||||
let mut board = Board::new();
|
||||
board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
-3, -3, -2, -2, -2, -2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 3, 8,
|
||||
],
|
||||
);
|
||||
let state = MoveRules::new(&Color::Black, &board, Dice { values: (5, 3) });
|
||||
let moves = (
|
||||
// CheckerMove::new(2, 0).unwrap(),
|
||||
// CheckerMove::new(1, 0).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
CheckerMove::new(24, 0).unwrap(),
|
||||
);
|
||||
assert_eq!(
|
||||
vec![moves],
|
||||
state.get_possible_moves_sequences(true, vec![])
|
||||
);
|
||||
|
||||
let mut board = Board::new();
|
||||
board.set_positions(
|
||||
&crate::Color::White,
|
||||
[
|
||||
-7, -4, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 3, 2, 1, 2, 3, 4,
|
||||
],
|
||||
);
|
||||
let state = MoveRules::new(&Color::White, &board, Dice { values: (2, 6) });
|
||||
let moves = vec![
|
||||
(
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
CheckerMove::new(19, 21).unwrap(),
|
||||
),
|
||||
(
|
||||
CheckerMove::new(19, 21).unwrap(),
|
||||
CheckerMove::new(23, 0).unwrap(),
|
||||
),
|
||||
];
|
||||
assert_eq!(moves, state.get_possible_moves_sequences(true, vec![]));
|
||||
}
|
||||
|
||||
#[test]
fn get_possible_moves_sequences_by_dices() {
    // Shorthand for building a checker move that is known to be valid.
    let mv = |from, to| CheckerMove::new(from, to).unwrap();
    // Sequences for the rules' own dice, optionally with the dice swapped.
    let sequences = |rules: &MoveRules, swap_dice: bool| {
        let (first, second) = rules.dice.values;
        if swap_dice {
            rules.get_possible_moves_sequences_by_dices(second, first, true, false, vec![])
        } else {
            rules.get_possible_moves_sequences_by_dices(first, second, true, false, vec![])
        }
    };

    let mut state = MoveRules::default();

    // Dice 5 & 3: both checkers leave the board.
    state.dice.values = (5, 3);
    state.board.set_positions(
        &crate::Color::White,
        [
            -8, -3, -1, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 3, 3,
        ],
    );
    let expected = (mv(23, 0), mv(24, 0));
    assert_eq!(vec![expected], sequences(&state, false));

    // Dice 4 & 5, played in both orders.
    state.dice.values = (4, 5);
    state.board.set_positions(
        &crate::Color::White,
        [
            -5, -2, -2, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 2, 3, 2, 3,
        ],
    );
    let expected = vec![(mv(19, 23), mv(22, 0))];
    assert_eq!(expected, sequences(&state, false));
    let expected = vec![(mv(19, 24), mv(22, 0))];
    assert_eq!(expected, sequences(&state, true));

    // Dice 2 & 6 on a freshly built board.
    let mut board = Board::new();
    board.set_positions(
        &crate::Color::White,
        [
            -7, -4, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 3, 2, 1, 2, 3, 4,
        ],
    );
    let rules = MoveRules::new(&Color::White, &board, Dice { values: (2, 6) });
    let expected = (mv(19, 21), mv(23, 0));
    assert_eq!(vec![expected], sequences(&rules, false));
}
|
||||
|
||||
#[test]
fn check_exit_rules() {
    let mut rules = MoveRules::default();
    rules.dice.values = (4, 5);
    rules.board.set_positions(
        &crate::Color::White,
        [
            -5, -2, -2, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 2, 3, 2, 3,
        ],
    );

    // The checker leaving from field 22 may be paired with either 19 -> 23 or 19 -> 24.
    for first_target in [23, 24] {
        let moves = (
            CheckerMove::new(19, first_target).unwrap(),
            CheckerMove::new(22, 0).unwrap(),
        );
        assert!(rules.check_exit_rules(&moves).is_ok());
    }
}
|
||||
|
||||
#[test]
fn check_must_fill_quarter_rule() {
    let mut rules = MoveRules::default();
    rules.dice.values = (4, 5);
    rules.board.set_positions(
        &crate::Color::White,
        [
            -5, -2, -2, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 2, 3, 2, 3,
        ],
    );

    // The pair 19 -> 24 and 22 -> 0 must be accepted by the rule.
    let first = CheckerMove::new(19, 24).unwrap();
    let second = CheckerMove::new(22, 0).unwrap();
    assert!(rules.check_must_fill_quarter_rule(&(first, second)).is_ok());
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,26 +69,10 @@ pub type PossibleJans = HashMap<Jan, Vec<(CheckerMove, CheckerMove)>>;
|
|||
pub trait PossibleJansMethods {
|
||||
fn push(&mut self, jan: Jan, cmoves: (CheckerMove, CheckerMove));
|
||||
fn merge(&mut self, other: Self);
|
||||
fn mirror(&self) -> Self;
|
||||
// fn get_points(&self) -> u8;
|
||||
}
|
||||
|
||||
impl PossibleJansMethods for PossibleJans {
|
||||
fn mirror(&self) -> Self {
|
||||
self.clone()
|
||||
.into_iter()
|
||||
.map(|(jan, moves)| {
|
||||
(
|
||||
jan,
|
||||
moves
|
||||
.into_iter()
|
||||
.map(|(m1, m2)| (m1.mirror(), m2.mirror()))
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn push(&mut self, jan: Jan, cmoves: (CheckerMove, CheckerMove)) {
|
||||
if let Some(ways) = self.get_mut(&jan) {
|
||||
if !ways.contains(&cmoves) {
|
||||
|
|
@ -619,7 +603,7 @@ mod tests {
|
|||
);
|
||||
let points_rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 4) });
|
||||
let jans = points_rules.get_result_jans(8);
|
||||
assert!(!jans.0.is_empty());
|
||||
assert!(jans.0.len() > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -644,7 +628,7 @@ mod tests {
|
|||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2,
|
||||
],
|
||||
);
|
||||
let rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
|
||||
let mut rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
|
||||
assert_eq!(12, rules.get_points(5).0);
|
||||
|
||||
// Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4
|
||||
|
|
|
|||
|
|
@ -16,11 +16,3 @@ pub use board::CheckerMove;
|
|||
|
||||
mod dice;
|
||||
pub use dice::{Dice, DiceRoller};
|
||||
|
||||
pub mod training_common;
|
||||
|
||||
// python interface "trictrac_engine" (for AI training..)
|
||||
mod pyengine;
|
||||
|
||||
// C++ interface via cxx.rs (for OpenSpiel C++ integration)
|
||||
pub mod cxxengine;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
use pyo3::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
// This just makes it easier to discern between a player id and any ol' u64
|
||||
pub type PlayerId = u64;
|
||||
|
||||
#[pyclass(eq, eq_int)]
|
||||
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Copy, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Color {
|
||||
White,
|
||||
Black,
|
||||
|
|
@ -22,7 +20,7 @@ impl Color {
|
|||
}
|
||||
|
||||
/// Struct for storing player related data.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Player {
|
||||
pub name: String,
|
||||
pub color: Color,
|
||||
|
|
@ -48,16 +46,6 @@ impl Player {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn mirror(&self) -> Self {
|
||||
let mut player = self.clone();
|
||||
player.color = if self.color == Color::White {
|
||||
Color::Black
|
||||
} else {
|
||||
Color::White
|
||||
};
|
||||
player
|
||||
}
|
||||
|
||||
pub fn to_bits_string(&self) -> String {
|
||||
format!(
|
||||
"{:0>4b}{:0>4b}{:b}{:b}",
|
||||
|
|
@ -65,26 +53,6 @@ impl Player {
|
|||
)
|
||||
}
|
||||
|
||||
pub fn from_bits_string(bits: &str, name: String, color: Color) -> Result<Self, String> {
|
||||
if bits.len() != 10 {
|
||||
return Err("Invalid bit string length for player".to_string());
|
||||
}
|
||||
let points = u8::from_str_radix(&bits[0..4], 2).map_err(|e| e.to_string())?;
|
||||
let holes = u8::from_str_radix(&bits[4..8], 2).map_err(|e| e.to_string())?;
|
||||
let can_bredouille = bits.chars().nth(8).ok_or_else(|| "8th bit unreadable")? == '1';
|
||||
let can_big_bredouille = bits.chars().nth(9).ok_or_else(|| "9th bit unreadable")? == '1';
|
||||
|
||||
Ok(Player {
|
||||
name,
|
||||
color,
|
||||
points,
|
||||
holes,
|
||||
can_bredouille,
|
||||
can_big_bredouille,
|
||||
dice_roll_count: 0, // This info is not in the string id
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Vec<u8> {
|
||||
vec![
|
||||
self.points,
|
||||
|
|
|
|||
|
|
@ -1,146 +0,0 @@
|
|||
//! # Expose trictrac game state and rules in a python module
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::dice::Dice;
|
||||
use crate::game::{GameEvent, GameState, Stage, TurnStage};
|
||||
use crate::player::PlayerId;
|
||||
use crate::training_common::{get_valid_action_indices, TrictracAction};
|
||||
|
||||
#[pyclass]
|
||||
struct TricTrac {
|
||||
game_state: GameState,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl TricTrac {
|
||||
#[new]
|
||||
fn new() -> Self {
|
||||
let mut game_state = GameState::new(false); // schools_enabled = false
|
||||
|
||||
// Initialiser 2 joueurs
|
||||
game_state.init_player("player1");
|
||||
game_state.init_player("player2");
|
||||
|
||||
// Commencer la partie avec le joueur 1
|
||||
let _ = game_state.consume(&GameEvent::BeginGame { goes_first: 1 });
|
||||
|
||||
TricTrac { game_state }
|
||||
}
|
||||
|
||||
fn needs_roll(&self) -> bool {
|
||||
self.game_state.turn_stage == TurnStage::RollWaiting
|
||||
}
|
||||
|
||||
fn is_game_ended(&self) -> bool {
|
||||
self.game_state.stage == Stage::Ended
|
||||
}
|
||||
|
||||
// 0 or 1
|
||||
fn current_player_idx(&self) -> u64 {
|
||||
self.game_state.active_player_id - 1
|
||||
}
|
||||
|
||||
fn get_legal_actions(&self, player_idx: u64) -> Vec<usize> {
|
||||
if player_idx == self.current_player_idx() {
|
||||
if player_idx == 0 {
|
||||
get_valid_action_indices(&self.game_state).unwrap()
|
||||
} else {
|
||||
let mirror = self.game_state.mirror();
|
||||
get_valid_action_indices(&mirror).unwrap()
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
fn action_to_string(&self, player_idx: u64, action_idx: usize) -> String {
|
||||
TrictracAction::from_action_index(action_idx)
|
||||
.map(|a| format!("{}:{}", player_idx, a))
|
||||
.unwrap_or("unknown action".into())
|
||||
}
|
||||
|
||||
fn apply_dice_roll(&mut self, dices: (u8, u8)) -> PyResult<()> {
|
||||
let player_id = self.game_state.active_player_id;
|
||||
|
||||
if self.game_state.turn_stage != TurnStage::RollWaiting {
|
||||
return Err(pyo3::exceptions::PyRuntimeError::new_err(
|
||||
"Not in RollWaiting stage",
|
||||
));
|
||||
}
|
||||
|
||||
let dice = Dice { values: dices };
|
||||
let _ = self
|
||||
.game_state
|
||||
.consume(&GameEvent::RollResult { player_id, dice });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_action(&mut self, action_idx: usize) -> PyResult<()> {
|
||||
if let Some(event) = TrictracAction::from_action_index(action_idx).and_then(|a| {
|
||||
let needs_mirror = self.game_state.active_player_id == 2;
|
||||
let game_state = if needs_mirror {
|
||||
&self.game_state.mirror()
|
||||
} else {
|
||||
&self.game_state
|
||||
};
|
||||
a.to_event(game_state)
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
}) {
|
||||
if self.game_state.validate(&event) {
|
||||
let _ = self.game_state.consume(&event);
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(pyo3::exceptions::PyRuntimeError::new_err(
|
||||
"Action is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(pyo3::exceptions::PyRuntimeError::new_err(
|
||||
"Could not apply action",
|
||||
))
|
||||
}
|
||||
|
||||
/// Get a player total score (holes & points)
|
||||
fn get_score(&self, player_id: PlayerId) -> i32 {
|
||||
if let Some(player) = self.game_state.players.get(&player_id) {
|
||||
player.holes as i32 * 12 + player.points as i32
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
fn get_players_scores(&self) -> [i32; 2] {
|
||||
[self.get_score(1), self.get_score(2)]
|
||||
}
|
||||
|
||||
fn get_tensor(&self, player_idx: u64) -> Vec<i8> {
|
||||
if player_idx == 0 {
|
||||
self.game_state.to_vec()
|
||||
} else {
|
||||
self.game_state.mirror().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_observation_string(&self, player_idx: u64) -> String {
|
||||
if player_idx == 0 {
|
||||
format!("{}", self.game_state)
|
||||
} else {
|
||||
format!("{}", self.game_state.mirror())
|
||||
}
|
||||
}
|
||||
|
||||
/// Afficher l'état du jeu (pour le débogage)
|
||||
fn __str__(&self) -> String {
|
||||
format!("{}", self.game_state)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
#[pymodule]
|
||||
fn trictrac_store(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<TricTrac>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1,460 +0,0 @@
|
|||
/// training_common.rs : environnement avec espace d'actions optimisé
|
||||
/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' )
|
||||
use std::cmp::{max, min};
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
|
||||
use crate::board::Board;
|
||||
use crate::{CheckerMove, Dice, GameEvent, GameState};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Action space: 1 (Roll) + 1 (Go) + 512 (possible moves),
// with 512 = 2 (dice order) * 16 * 16 (checker choice 0-15 for each of the two moves).
pub const ACTION_SPACE_SIZE: usize = 1 + 1 + 2 * 16 * 16;
|
||||
|
||||
/// Types d'actions possibles dans le jeu
|
||||
#[derive(Debug, Copy, Clone, Eq, Serialize, Deserialize, PartialEq)]
|
||||
pub enum TrictracAction {
|
||||
/// Lancer les dés
|
||||
Roll,
|
||||
/// Faire un nouveau 'relevé' (repositionnement des dames à l'état de départ) après avoir gagné un trou,
|
||||
/// au lieu de continuer dans la position courante
|
||||
Go,
|
||||
/// Effectuer un mouvement de pions
|
||||
Move {
|
||||
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
|
||||
checker1: usize, // premier pion à déplacer en numérotant depuis la colonne de départ (0-15) 0 : aucun pion
|
||||
checker2: usize, // deuxième pion (0-15)
|
||||
},
|
||||
// Marquer les points : à activer si support des écoles
|
||||
// Mark,
|
||||
}
|
||||
|
||||
impl Display for TrictracAction {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let s = format!("{self:?}");
|
||||
writeln!(f, "{}", s.chars().rev().collect::<String>())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracAction {
|
||||
/// Encode une action en index pour le réseau de neurones
|
||||
pub fn to_action_index(&self) -> usize {
|
||||
match self {
|
||||
TrictracAction::Roll => 0,
|
||||
TrictracAction::Go => 1,
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
} => {
|
||||
// Encoder les mouvements dans l'espace d'actions
|
||||
// Indices 2+ pour les mouvements
|
||||
// de 2 à 513 (2 à 257 pour dé 1 en premier, 258 à 513 pour dé 2 en premier)
|
||||
let mut start = 2;
|
||||
if !dice_order {
|
||||
// 16 * 16 = 256
|
||||
start += 256;
|
||||
}
|
||||
start + checker1 * 16 + checker2
|
||||
} // TrictracAction::Mark => 514,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mirror(&self) -> TrictracAction {
|
||||
match self {
|
||||
TrictracAction::Roll => TrictracAction::Roll,
|
||||
TrictracAction::Go => TrictracAction::Go,
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
} => TrictracAction::Move {
|
||||
dice_order: *dice_order,
|
||||
checker1: if *checker1 == 0 { 0 } else { 25 - checker1 },
|
||||
checker2: if *checker2 == 0 { 0 } else { 25 - checker2 },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_event(&self, state: &GameState) -> Option<GameEvent> {
|
||||
match self {
|
||||
TrictracAction::Roll => {
|
||||
// Lancer les dés
|
||||
Some(GameEvent::Roll {
|
||||
player_id: state.active_player_id,
|
||||
})
|
||||
}
|
||||
// TrictracAction::Mark => {
|
||||
// // Marquer des points
|
||||
// let points = self.game.
|
||||
// Some(GameEvent::Mark {
|
||||
// player_id: self.active_player_id,
|
||||
// points,
|
||||
// })
|
||||
// }
|
||||
TrictracAction::Go => {
|
||||
// Continuer après avoir gagné un trou
|
||||
Some(GameEvent::Go {
|
||||
player_id: state.active_player_id,
|
||||
})
|
||||
}
|
||||
TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
} => {
|
||||
// Effectuer un mouvement
|
||||
let (dice1, dice2) = if *dice_order {
|
||||
(state.dice.values.0, state.dice.values.1)
|
||||
} else {
|
||||
(state.dice.values.1, state.dice.values.0)
|
||||
};
|
||||
|
||||
let color = &crate::Color::White;
|
||||
let from1 = state
|
||||
.board
|
||||
.get_checker_field(color, *checker1 as u8)
|
||||
.unwrap_or(0);
|
||||
let mut to1 = from1 + dice1 as usize;
|
||||
if 24 < to1 {
|
||||
// exit board
|
||||
to1 = 0;
|
||||
}
|
||||
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
|
||||
let mut tmp_board = state.board.clone();
|
||||
let move_result = tmp_board.move_checker(color, checker_move1);
|
||||
if move_result.is_err() {
|
||||
None
|
||||
// panic!("Error while moving checker {move_result:?}")
|
||||
} else {
|
||||
let from2 = tmp_board
|
||||
.get_checker_field(color, *checker2 as u8)
|
||||
.unwrap_or(0);
|
||||
let mut to2 = from2 + dice2 as usize;
|
||||
if 24 < to2 {
|
||||
// exit board
|
||||
to2 = 0;
|
||||
}
|
||||
|
||||
// Gestion prise de coin par puissance
|
||||
let opp_rest_field = 13;
|
||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||
to1 -= 1;
|
||||
to2 -= 1;
|
||||
}
|
||||
|
||||
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||
|
||||
Some(GameEvent::Move {
|
||||
player_id: state.active_player_id,
|
||||
moves: (checker_move1, checker_move2),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Décode un index d'action en TrictracAction
|
||||
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||
match index {
|
||||
0 => Some(TrictracAction::Roll),
|
||||
1 => Some(TrictracAction::Go),
|
||||
// 514 => Some(TrictracAction::Mark),
|
||||
i if i >= 2 => {
|
||||
let move_code = i - 2;
|
||||
let (dice_order, checker1, checker2) = Self::decode_move(move_code);
|
||||
Some(TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Décode un entier en paire de mouvements
|
||||
fn decode_move(code: usize) -> (bool, usize, usize) {
|
||||
let mut encoded = code;
|
||||
let dice_order = code < 256;
|
||||
if !dice_order {
|
||||
encoded -= 256
|
||||
}
|
||||
let checker1 = encoded / 16;
|
||||
let checker2 = encoded % 16;
|
||||
(dice_order, checker1, checker2)
|
||||
}
|
||||
|
||||
/// Retourne la taille de l'espace d'actions total
|
||||
pub fn action_space_size() -> usize {
|
||||
ACTION_SPACE_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtient les actions valides pour l'état de jeu actuel
|
||||
pub fn get_valid_actions(game_state: &GameState) -> anyhow::Result<Vec<TrictracAction>> {
|
||||
use crate::TurnStage;
|
||||
|
||||
let mut valid_actions = Vec::new();
|
||||
|
||||
let active_player_id = game_state.active_player_id;
|
||||
let player_color = game_state.player_color_by_id(&active_player_id);
|
||||
|
||||
if let Some(color) = player_color {
|
||||
match game_state.turn_stage {
|
||||
TurnStage::RollDice => {
|
||||
valid_actions.push(TrictracAction::Roll);
|
||||
}
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
||||
// valid_actions.push(TrictracAction::Mark);
|
||||
anyhow::bail!(
|
||||
"get_valid_actions not implemented for turn stage {:?}",
|
||||
game_state.turn_stage
|
||||
);
|
||||
}
|
||||
TurnStage::HoldOrGoChoice => {
|
||||
valid_actions.push(TrictracAction::Go);
|
||||
|
||||
// Ajoute aussi les mouvements possibles
|
||||
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1, &move2, &color, game_state,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
TurnStage::Move => {
|
||||
let rules = crate::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||
let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||
if possible_moves.is_empty() {
|
||||
// Empty move
|
||||
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
|
||||
}
|
||||
|
||||
for (move1, move2) in possible_moves {
|
||||
valid_actions.push(checker_moves_to_trictrac_action(
|
||||
&move1, &move2, &color, game_state,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if valid_actions.is_empty() {
|
||||
anyhow::bail!("empty valid_actions for state {game_state}");
|
||||
}
|
||||
Ok(valid_actions)
|
||||
}
|
||||
|
||||
fn checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
color: &crate::Color,
|
||||
state: &GameState,
|
||||
) -> anyhow::Result<TrictracAction> {
|
||||
let dice = &state.dice;
|
||||
let board = &state.board;
|
||||
|
||||
if color == &crate::Color::Black {
|
||||
// Moves are already 'white', so we don't mirror them
|
||||
white_checker_moves_to_trictrac_action(
|
||||
move1,
|
||||
move2,
|
||||
// &move1.clone().mirror(),
|
||||
// &move2.clone().mirror(),
|
||||
dice,
|
||||
&board.clone().mirror(),
|
||||
)
|
||||
// .map(|a| a.mirror())
|
||||
} else {
|
||||
white_checker_moves_to_trictrac_action(move1, move2, dice, board)
|
||||
}
|
||||
}
|
||||
|
||||
fn white_checker_moves_to_trictrac_action(
|
||||
move1: &CheckerMove,
|
||||
move2: &CheckerMove,
|
||||
dice: &Dice,
|
||||
board: &Board,
|
||||
) -> anyhow::Result<TrictracAction> {
|
||||
let to1 = move1.get_to();
|
||||
let to2 = move2.get_to();
|
||||
let from1 = move1.get_from();
|
||||
let from2 = move2.get_from();
|
||||
|
||||
let mut diff_move1 = if to1 > 0 {
|
||||
// Mouvement sans sortie
|
||||
to1 - from1
|
||||
} else {
|
||||
// sortie, on utilise la valeur du dé
|
||||
if to2 > 0 {
|
||||
// sortie pour le mouvement 1 uniquement
|
||||
let dice2 = to2 - from2;
|
||||
if dice2 == dice.values.0 as usize {
|
||||
dice.values.1 as usize
|
||||
} else {
|
||||
dice.values.0 as usize
|
||||
}
|
||||
} else {
|
||||
// double sortie
|
||||
if from1 < from2 {
|
||||
max(dice.values.0, dice.values.1) as usize
|
||||
} else {
|
||||
min(dice.values.0, dice.values.1) as usize
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// modification de diff_move1 si on est dans le cas d'un mouvement par puissance
|
||||
let rest_field = 12;
|
||||
if to1 == rest_field
|
||||
&& to2 == rest_field
|
||||
&& max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
|
||||
{
|
||||
// prise par puissance
|
||||
diff_move1 += 1;
|
||||
}
|
||||
let dice_order = diff_move1 == dice.values.0 as usize;
|
||||
|
||||
let checker1 = board.get_field_checker(&crate::Color::White, from1) as usize;
|
||||
let mut tmp_board = board.clone();
|
||||
// should not raise an error for a valid action
|
||||
tmp_board.move_checker(&crate::Color::White, *move1)?;
|
||||
let checker2 = tmp_board.get_field_checker(&crate::Color::White, from2) as usize;
|
||||
Ok(TrictracAction::Move {
|
||||
dice_order,
|
||||
checker1,
|
||||
checker2,
|
||||
})
|
||||
}
|
||||
|
||||
/// Retourne les indices des actions valides
|
||||
pub fn get_valid_action_indices(game_state: &GameState) -> anyhow::Result<Vec<usize>> {
|
||||
let actions = get_valid_actions(game_state)?;
|
||||
Ok(actions
|
||||
.into_iter()
|
||||
.map(|action| action.to_action_index())
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Sélectionne une action valide aléatoire
|
||||
pub fn sample_valid_action(game_state: &GameState) -> Option<TrictracAction> {
|
||||
use rand::{prelude::IndexedRandom, rng};
|
||||
|
||||
let valid_actions = get_valid_actions(game_state);
|
||||
let mut rng = rng();
|
||||
valid_actions
|
||||
.map(|va| va.choose(&mut rng).cloned())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Position used for the White player cases.
    const WHITE_CASE: [i8; 24] = [
        -8, -3, -1, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 3, 3,
    ];
    /// Position used for the Black player cases.
    const BLACK_CASE: [i8; 24] = [
        -3, -3, -2, -2, -2, -2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 3, 8,
    ];

    #[test]
    fn to_action_index() {
        let action = TrictracAction::Move {
            dice_order: true,
            checker1: 3,
            checker2: 4,
        };
        let index = action.to_action_index();
        // Round trip first, then the exact value: 2 + 3 * 16 + 4.
        assert_eq!(Some(action), TrictracAction::from_action_index(index));
        assert_eq!(54, index);
    }

    #[test]
    fn from_action_index() {
        let expected = TrictracAction::Move {
            dice_order: true,
            checker1: 3,
            checker2: 4,
        };
        assert_eq!(Some(expected), TrictracAction::from_action_index(54));
    }

    #[test]
    fn get_valid_actions() {
        let mut state = GameState::new_with_players("white", "black");
        state.active_player_id = 2;
        state.dice = Dice { values: (5, 3) };
        state.turn_stage = crate::TurnStage::Move;
        state
            .board
            .set_positions(&crate::Color::White, BLACK_CASE);

        let expected = vec![TrictracAction::Move {
            dice_order: true,
            checker1: 11,
            checker2: 13,
        }];
        assert_eq!(Some(expected), super::get_valid_actions(&state).ok());
    }

    #[test]
    fn checker_moves_to_trictrac_action() {
        let mut state = GameState::new_with_players("white", "black");
        state.turn_stage = crate::TurnStage::Move;
        state.dice = Dice { values: (5, 3) };

        // Same pair of moves and same expected action for both players.
        let first = CheckerMove::new(23, 0).unwrap();
        let second = CheckerMove::new(24, 0).unwrap();
        let expected = TrictracAction::Move {
            dice_order: true,
            checker1: 11,
            checker2: 13, // because the 11th has left
        };

        // White player
        state.active_player_id = 1;
        state
            .board
            .set_positions(&crate::Color::White, WHITE_CASE);
        let ttaction = super::checker_moves_to_trictrac_action(
            &first,
            &second,
            &crate::Color::White,
            &state,
        );
        assert_eq!(Some(expected), ttaction.ok());

        // Black player
        state.active_player_id = 2;
        state
            .board
            .set_positions(&crate::Color::White, BLACK_CASE);
        let ttaction = super::checker_moves_to_trictrac_action(
            &first,
            &second,
            &crate::Color::Black,
            &state,
        );
        assert_eq!(Some(expected), ttaction.ok());
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue