Compare commits
No commits in common. "main" and "feature/botTrainValidMoves" have entirely different histories.
main
...
feature/bo
2660
Cargo.lock
generated
2660
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,4 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
members = ["client_cli", "bot", "store"]
|
members = ["client_tui", "client_cli", "bot", "server", "store"]
|
||||||
|
|
|
||||||
38
README.md
38
README.md
|
|
@ -1,41 +1,7 @@
|
||||||
# Trictrac
|
# Trictrac
|
||||||
|
|
||||||
This is a game of [Trictrac](https://en.wikipedia.org/wiki/Trictrac) rust implementation.
|
Game of [Trictrac](https://en.wikipedia.org/wiki/Trictrac) in rust.
|
||||||
|
|
||||||
The project is on its early stages.
|
wip
|
||||||
Rules (without "schools") are implemented, as well as a rudimentary terminal interface which allow you to play against a bot which plays randomly.
|
|
||||||
|
|
||||||
Training of AI bots is the work in progress.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
`cargo run --bin=client_cli -- --bot random`
|
|
||||||
|
|
||||||
## Roadmap
|
|
||||||
|
|
||||||
- [x] rules
|
|
||||||
- [x] command line interface
|
|
||||||
- [x] basic bot (random play)
|
|
||||||
- [ ] AI bot
|
|
||||||
- [ ] network game
|
|
||||||
- [ ] web client
|
|
||||||
|
|
||||||
## Code structure
|
|
||||||
|
|
||||||
- game rules and game state are implemented in the _store/_ folder.
|
|
||||||
- the command-line application is implemented in _client_cli/_; it allows you to play against a bot, or to have two bots play against each other
|
|
||||||
- the bots algorithms and the training of their models are implemented in the _bot/_ folder
|
|
||||||
|
|
||||||
### _store_ package
|
|
||||||
|
|
||||||
The game state is defined by the `GameState` struct in _store/src/game.rs_. The `to_string_id()` method allows this state to be encoded compactly in a string (without the played moves history). For a more readable textual representation, the `fmt::Display` trait is implemented.
|
|
||||||
|
|
||||||
### _client_cli_ package
|
|
||||||
|
|
||||||
`client_cli/src/game_runner.rs` contains the logic to make two bots play against each other.
|
|
||||||
|
|
||||||
### _bot_ package
|
|
||||||
|
|
||||||
- `bot/src/strategy/default.rs` contains the code for a basic bot strategy: it determines the list of valid moves (using the `get_possible_moves_sequences` method of `store::MoveRules`) and simply executes the first move in the list.
|
|
||||||
- `bot/src/strategy/dqnburn.rs` is another bot strategy that uses a reinforcement learning trained model with the DQN algorithm via the burn library (<https://burn.dev/>).
|
|
||||||
- `bot/scripts/trains.sh` allows you to train agents using different algorithms (DQN, PPO, SAC).
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,16 @@ edition = "2021"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "burn_train"
|
name = "train_dqn_burn_valid"
|
||||||
path = "src/burnrl/main.rs"
|
path = "src/dqn/burnrl_valid/main.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "train_dqn_burn"
|
||||||
|
path = "src/dqn/burnrl/main.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "train_dqn_simple"
|
||||||
|
path = "src/dqn/simple/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
pretty_assertions = "1.4.0"
|
pretty_assertions = "1.4.0"
|
||||||
|
|
@ -16,9 +24,6 @@ serde_json = "1.0"
|
||||||
store = { path = "../store" }
|
store = { path = "../store" }
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
env_logger = "0.10"
|
env_logger = "0.10"
|
||||||
burn = { version = "0.18", features = ["ndarray", "autodiff"] }
|
burn = { version = "0.17", features = ["ndarray", "autodiff"] }
|
||||||
burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
|
burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
|
||||||
log = "0.4.20"
|
log = "0.4.20"
|
||||||
confy = "1.0.0"
|
|
||||||
board-game = "0.8.2"
|
|
||||||
internal-iterator = "0.2.3"
|
|
||||||
|
|
|
||||||
|
|
@ -1,50 +1,38 @@
|
||||||
#!/usr/bin/env bash
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
|
||||||
LOGS_DIR="$ROOT/bot/models/logs"
|
LOGS_DIR="$ROOT/bot/models/logs"
|
||||||
|
|
||||||
CFG_SIZE=17
|
CFG_SIZE=12
|
||||||
BINBOT=burn_train
|
|
||||||
# BINBOT=train_ppo_burn
|
|
||||||
# BINBOT=train_dqn_burn
|
|
||||||
# BINBOT=train_dqn_burn_big
|
|
||||||
# BINBOT=train_dqn_burn_before
|
|
||||||
OPPONENT="random"
|
OPPONENT="random"
|
||||||
|
|
||||||
PLOT_EXT="png"
|
PLOT_EXT="png"
|
||||||
|
|
||||||
train() {
|
train() {
|
||||||
ALGO=$1
|
cargo build --release --bin=train_dqn_burn
|
||||||
cargo build --release --bin=$BINBOT
|
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
|
||||||
NAME="$(date +%Y-%m-%d_%H:%M:%S)"
|
LOGS="$LOGS_DIR/$NAME.out"
|
||||||
LOGS="$LOGS_DIR/$ALGO/$NAME.out"
|
mkdir -p "$LOGS_DIR"
|
||||||
mkdir -p "$LOGS_DIR/$ALGO"
|
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
|
||||||
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" $ALGO | tee "$LOGS"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
plot() {
|
plot() {
|
||||||
ALGO=$1
|
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||||
NAME=$(ls -rt "$LOGS_DIR/$ALGO" | grep -v png | tail -n 1)
|
LOGS="$LOGS_DIR/$NAME"
|
||||||
LOGS="$LOGS_DIR/$ALGO/$NAME"
|
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||||
cfgs=$(grep -v "info:" "$LOGS" | head -n $CFG_SIZE)
|
|
||||||
for cfg in $cfgs; do
|
for cfg in $cfgs; do
|
||||||
eval "$cfg"
|
eval "$cfg"
|
||||||
done
|
done
|
||||||
|
|
||||||
|
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
|
||||||
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
||||||
grep -v "info:" |
|
grep -v "info:" |
|
||||||
awk -F '[ ,]' '{print $5}' |
|
awk -F '[ ,]' '{print $5}' |
|
||||||
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$ALGO/$NAME.$PLOT_EXT"
|
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||||
}
|
}
|
||||||
|
|
||||||
if [[ -z "$1" ]]; then
|
if [ "$1" = "plot" ]; then
|
||||||
echo "Usage : train [plot] <algo>"
|
plot
|
||||||
elif [ "$1" = "plot" ]; then
|
|
||||||
if [[ -z "$2" ]]; then
|
|
||||||
echo "Usage : train [plot] <algo>"
|
|
||||||
else
|
|
||||||
plot $2
|
|
||||||
fi
|
|
||||||
else
|
else
|
||||||
train $1
|
train
|
||||||
fi
|
fi
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ train() {
|
||||||
}
|
}
|
||||||
|
|
||||||
plot() {
|
plot() {
|
||||||
NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1)
|
NAME=$(ls "$LOGS_DIR" | tail -n 1)
|
||||||
LOGS="$LOGS_DIR/$NAME"
|
LOGS="$LOGS_DIR/$NAME"
|
||||||
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
cfgs=$(head -n $CFG_SIZE "$LOGS")
|
||||||
for cfg in $cfgs; do
|
for cfg in $cfgs; do
|
||||||
|
|
@ -31,19 +31,8 @@ plot() {
|
||||||
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
|
||||||
}
|
}
|
||||||
|
|
||||||
avg() {
|
|
||||||
NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1)
|
|
||||||
LOGS="$LOGS_DIR/$NAME"
|
|
||||||
echo $LOGS
|
|
||||||
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
|
|
||||||
grep -v "info:" |
|
|
||||||
awk -F '[ ,]' '{print $5}' | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }'
|
|
||||||
}
|
|
||||||
|
|
||||||
if [ "$1" = "plot" ]; then
|
if [ "$1" = "plot" ]; then
|
||||||
plot
|
plot
|
||||||
elif [ "$1" = "avg" ]; then
|
|
||||||
avg
|
|
||||||
else
|
else
|
||||||
train
|
train
|
||||||
fi
|
fi
|
||||||
|
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
pub mod dqn;
|
|
||||||
pub mod dqn_valid;
|
|
||||||
pub mod ppo;
|
|
||||||
pub mod ppo_valid;
|
|
||||||
pub mod sac;
|
|
||||||
pub mod sac_valid;
|
|
||||||
|
|
@ -1,191 +0,0 @@
|
||||||
use crate::burnrl::environment::TrictracEnvironment;
|
|
||||||
use crate::burnrl::utils::Config;
|
|
||||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
|
||||||
use burn::module::Module;
|
|
||||||
use burn::nn::{Initializer, Linear, LinearConfig};
|
|
||||||
use burn::optim::AdamWConfig;
|
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
|
||||||
use burn::tensor::activation::{relu, softmax};
|
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
|
||||||
use burn::tensor::Tensor;
|
|
||||||
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
|
||||||
use std::env;
|
|
||||||
use std::fs;
|
|
||||||
use std::time::SystemTime;
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
|
||||||
pub struct Net<B: Backend> {
|
|
||||||
linear: Linear<B>,
|
|
||||||
linear_actor: Linear<B>,
|
|
||||||
linear_critic: Linear<B>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Net<B> {
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
|
||||||
let initializer = Initializer::XavierUniform { gain: 1.0 };
|
|
||||||
Self {
|
|
||||||
linear: LinearConfig::new(input_size, dense_size)
|
|
||||||
.with_initializer(initializer.clone())
|
|
||||||
.init(&Default::default()),
|
|
||||||
linear_actor: LinearConfig::new(dense_size, output_size)
|
|
||||||
.with_initializer(initializer.clone())
|
|
||||||
.init(&Default::default()),
|
|
||||||
linear_critic: LinearConfig::new(dense_size, 1)
|
|
||||||
.with_initializer(initializer)
|
|
||||||
.init(&Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
|
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
|
|
||||||
let layer_0_output = relu(self.linear.forward(input));
|
|
||||||
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
|
|
||||||
let values = self.linear_critic.forward(layer_0_output);
|
|
||||||
|
|
||||||
PPOOutput::<B>::new(policies, values)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
let layer_0_output = relu(self.linear.forward(input));
|
|
||||||
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> PPOModel<B> for Net<B> {}
|
|
||||||
#[allow(unused)]
|
|
||||||
const MEMORY_SIZE: usize = 512;
|
|
||||||
|
|
||||||
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn run<
|
|
||||||
E: Environment + AsMut<TrictracEnvironment>,
|
|
||||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
|
||||||
>(
|
|
||||||
conf: &Config,
|
|
||||||
visualized: bool,
|
|
||||||
// ) -> PPO<E, B, Net<B>> {
|
|
||||||
) -> impl Agent<E> {
|
|
||||||
let mut env = E::new(visualized);
|
|
||||||
env.as_mut().max_steps = conf.max_steps;
|
|
||||||
|
|
||||||
let mut model = Net::<B>::new(
|
|
||||||
<<E as Environment>::StateType as State>::size(),
|
|
||||||
conf.dense_size,
|
|
||||||
<<E as Environment>::ActionType as Action>::size(),
|
|
||||||
);
|
|
||||||
let agent = MyAgent::default();
|
|
||||||
let config = PPOTrainingConfig {
|
|
||||||
gamma: conf.gamma,
|
|
||||||
lambda: conf.lambda,
|
|
||||||
epsilon_clip: conf.epsilon_clip,
|
|
||||||
critic_weight: conf.critic_weight,
|
|
||||||
entropy_weight: conf.entropy_weight,
|
|
||||||
learning_rate: conf.learning_rate,
|
|
||||||
epochs: conf.epochs,
|
|
||||||
batch_size: conf.batch_size,
|
|
||||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
|
||||||
conf.clip_grad,
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut optimizer = AdamWConfig::new()
|
|
||||||
.with_grad_clipping(config.clip_grad.clone())
|
|
||||||
.init();
|
|
||||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
|
||||||
for episode in 0..conf.num_episodes {
|
|
||||||
let mut episode_done = false;
|
|
||||||
let mut episode_reward = 0.0;
|
|
||||||
let mut episode_duration = 0_usize;
|
|
||||||
let mut now = SystemTime::now();
|
|
||||||
|
|
||||||
env.reset();
|
|
||||||
while !episode_done {
|
|
||||||
let state = env.state();
|
|
||||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
|
|
||||||
let snapshot = env.step(action);
|
|
||||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
memory.push(
|
|
||||||
state,
|
|
||||||
*snapshot.state(),
|
|
||||||
action,
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
snapshot.done(),
|
|
||||||
);
|
|
||||||
|
|
||||||
episode_duration += 1;
|
|
||||||
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
println!(
|
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
|
||||||
now.elapsed().unwrap().as_secs(),
|
|
||||||
);
|
|
||||||
|
|
||||||
now = SystemTime::now();
|
|
||||||
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
|
|
||||||
memory.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(path) = &conf.save_path {
|
|
||||||
let device = NdArrayDevice::default();
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
|
||||||
|
|
||||||
// Save the trained model (backend B) to a temporary file
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), tmp_path.clone())
|
|
||||||
.expect("Failed to save temporary model");
|
|
||||||
|
|
||||||
// Create a new model instance with the target backend (NdArray)
|
|
||||||
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
|
||||||
<<E as Environment>::StateType as State>::size(),
|
|
||||||
conf.dense_size,
|
|
||||||
<<E as Environment>::ActionType as Action>::size(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Load the record from the temporary file into the new model
|
|
||||||
let record = recorder
|
|
||||||
.load(tmp_path.clone(), &device)
|
|
||||||
.expect("Failed to load temporary model");
|
|
||||||
let model_with_loaded_weights = model_to_save.load_record(record);
|
|
||||||
|
|
||||||
// Clean up the temporary file
|
|
||||||
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
|
||||||
|
|
||||||
save_model(&model_with_loaded_weights, path);
|
|
||||||
}
|
|
||||||
agent.valid(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
// println!("Chargement du modèle depuis : {model_path}");
|
|
||||||
|
|
||||||
CompactRecorder::new()
|
|
||||||
.load(model_path.into(), &NdArrayDevice::default())
|
|
||||||
.map(|record| {
|
|
||||||
Net::new(
|
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
})
|
|
||||||
.ok()
|
|
||||||
}
|
|
||||||
|
|
@ -1,191 +0,0 @@
|
||||||
use crate::burnrl::environment_valid::TrictracEnvironment;
|
|
||||||
use crate::burnrl::utils::Config;
|
|
||||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
|
||||||
use burn::module::Module;
|
|
||||||
use burn::nn::{Initializer, Linear, LinearConfig};
|
|
||||||
use burn::optim::AdamWConfig;
|
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
|
||||||
use burn::tensor::activation::{relu, softmax};
|
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
|
||||||
use burn::tensor::Tensor;
|
|
||||||
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
|
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
|
||||||
use std::env;
|
|
||||||
use std::fs;
|
|
||||||
use std::time::SystemTime;
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
|
||||||
pub struct Net<B: Backend> {
|
|
||||||
linear: Linear<B>,
|
|
||||||
linear_actor: Linear<B>,
|
|
||||||
linear_critic: Linear<B>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Net<B> {
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
|
||||||
let initializer = Initializer::XavierUniform { gain: 1.0 };
|
|
||||||
Self {
|
|
||||||
linear: LinearConfig::new(input_size, dense_size)
|
|
||||||
.with_initializer(initializer.clone())
|
|
||||||
.init(&Default::default()),
|
|
||||||
linear_actor: LinearConfig::new(dense_size, output_size)
|
|
||||||
.with_initializer(initializer.clone())
|
|
||||||
.init(&Default::default()),
|
|
||||||
linear_critic: LinearConfig::new(dense_size, 1)
|
|
||||||
.with_initializer(initializer)
|
|
||||||
.init(&Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
|
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
|
|
||||||
let layer_0_output = relu(self.linear.forward(input));
|
|
||||||
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
|
|
||||||
let values = self.linear_critic.forward(layer_0_output);
|
|
||||||
|
|
||||||
PPOOutput::<B>::new(policies, values)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
let layer_0_output = relu(self.linear.forward(input));
|
|
||||||
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> PPOModel<B> for Net<B> {}
|
|
||||||
#[allow(unused)]
|
|
||||||
const MEMORY_SIZE: usize = 512;
|
|
||||||
|
|
||||||
type MyAgent<E, B> = PPO<E, B, Net<B>>;
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn run<
|
|
||||||
E: Environment + AsMut<TrictracEnvironment>,
|
|
||||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
|
||||||
>(
|
|
||||||
conf: &Config,
|
|
||||||
visualized: bool,
|
|
||||||
// ) -> PPO<E, B, Net<B>> {
|
|
||||||
) -> impl Agent<E> {
|
|
||||||
let mut env = E::new(visualized);
|
|
||||||
env.as_mut().max_steps = conf.max_steps;
|
|
||||||
|
|
||||||
let mut model = Net::<B>::new(
|
|
||||||
<<E as Environment>::StateType as State>::size(),
|
|
||||||
conf.dense_size,
|
|
||||||
<<E as Environment>::ActionType as Action>::size(),
|
|
||||||
);
|
|
||||||
let agent = MyAgent::default();
|
|
||||||
let config = PPOTrainingConfig {
|
|
||||||
gamma: conf.gamma,
|
|
||||||
lambda: conf.lambda,
|
|
||||||
epsilon_clip: conf.epsilon_clip,
|
|
||||||
critic_weight: conf.critic_weight,
|
|
||||||
entropy_weight: conf.entropy_weight,
|
|
||||||
learning_rate: conf.learning_rate,
|
|
||||||
epochs: conf.epochs,
|
|
||||||
batch_size: conf.batch_size,
|
|
||||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
|
||||||
conf.clip_grad,
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut optimizer = AdamWConfig::new()
|
|
||||||
.with_grad_clipping(config.clip_grad.clone())
|
|
||||||
.init();
|
|
||||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
|
||||||
for episode in 0..conf.num_episodes {
|
|
||||||
let mut episode_done = false;
|
|
||||||
let mut episode_reward = 0.0;
|
|
||||||
let mut episode_duration = 0_usize;
|
|
||||||
let mut now = SystemTime::now();
|
|
||||||
|
|
||||||
env.reset();
|
|
||||||
while !episode_done {
|
|
||||||
let state = env.state();
|
|
||||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
|
|
||||||
let snapshot = env.step(action);
|
|
||||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
memory.push(
|
|
||||||
state,
|
|
||||||
*snapshot.state(),
|
|
||||||
action,
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
snapshot.done(),
|
|
||||||
);
|
|
||||||
|
|
||||||
episode_duration += 1;
|
|
||||||
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
println!(
|
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
|
||||||
now.elapsed().unwrap().as_secs(),
|
|
||||||
);
|
|
||||||
|
|
||||||
now = SystemTime::now();
|
|
||||||
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
|
|
||||||
memory.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(path) = &conf.save_path {
|
|
||||||
let device = NdArrayDevice::default();
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let tmp_path = env::temp_dir().join("tmp_model.mpk");
|
|
||||||
|
|
||||||
// Save the trained model (backend B) to a temporary file
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), tmp_path.clone())
|
|
||||||
.expect("Failed to save temporary model");
|
|
||||||
|
|
||||||
// Create a new model instance with the target backend (NdArray)
|
|
||||||
let model_to_save: Net<NdArray<ElemType>> = Net::new(
|
|
||||||
<<E as Environment>::StateType as State>::size(),
|
|
||||||
conf.dense_size,
|
|
||||||
<<E as Environment>::ActionType as Action>::size(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Load the record from the temporary file into the new model
|
|
||||||
let record = recorder
|
|
||||||
.load(tmp_path.clone(), &device)
|
|
||||||
.expect("Failed to load temporary model");
|
|
||||||
let model_with_loaded_weights = model_to_save.load_record(record);
|
|
||||||
|
|
||||||
// Clean up the temporary file
|
|
||||||
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
|
|
||||||
|
|
||||||
save_model(&model_with_loaded_weights, path);
|
|
||||||
}
|
|
||||||
agent.valid(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
// println!("Chargement du modèle depuis : {model_path}");
|
|
||||||
|
|
||||||
CompactRecorder::new()
|
|
||||||
.load(model_path.into(), &NdArrayDevice::default())
|
|
||||||
.map(|record| {
|
|
||||||
Net::new(
|
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
})
|
|
||||||
.ok()
|
|
||||||
}
|
|
||||||
|
|
@ -1,221 +0,0 @@
|
||||||
use crate::burnrl::environment::TrictracEnvironment;
|
|
||||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
|
||||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
|
||||||
use burn::module::Module;
|
|
||||||
use burn::nn::{Linear, LinearConfig};
|
|
||||||
use burn::optim::AdamWConfig;
|
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
|
||||||
use burn::tensor::activation::{relu, softmax};
|
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
|
||||||
use burn::tensor::Tensor;
|
|
||||||
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
|
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
|
||||||
use std::time::SystemTime;
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
|
||||||
pub struct Actor<B: Backend> {
|
|
||||||
linear_0: Linear<B>,
|
|
||||||
linear_1: Linear<B>,
|
|
||||||
linear_2: Linear<B>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Actor<B> {
|
|
||||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
|
||||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
|
||||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
|
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
let layer_0_output = relu(self.linear_0.forward(input));
|
|
||||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
|
||||||
|
|
||||||
softmax(self.linear_2.forward(layer_1_output), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
self.forward(input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> SACActor<B> for Actor<B> {}
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
|
||||||
pub struct Critic<B: Backend> {
|
|
||||||
linear_0: Linear<B>,
|
|
||||||
linear_1: Linear<B>,
|
|
||||||
linear_2: Linear<B>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Critic<B> {
|
|
||||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
|
||||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
|
||||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
|
||||||
(self.linear_0, self.linear_1, self.linear_2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
|
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
let layer_0_output = relu(self.linear_0.forward(input));
|
|
||||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
|
||||||
|
|
||||||
self.linear_2.forward(layer_1_output)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
self.forward(input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> SACCritic<B> for Critic<B> {
|
|
||||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
|
||||||
let (linear_0, linear_1, linear_2) = this.consume();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
|
||||||
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
|
||||||
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
const MEMORY_SIZE: usize = 4096;
|
|
||||||
|
|
||||||
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn run<
|
|
||||||
E: Environment + AsMut<TrictracEnvironment>,
|
|
||||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
|
||||||
>(
|
|
||||||
conf: &Config,
|
|
||||||
visualized: bool,
|
|
||||||
) -> impl Agent<E> {
|
|
||||||
let mut env = E::new(visualized);
|
|
||||||
env.as_mut().max_steps = conf.max_steps;
|
|
||||||
let state_dim = <<E as Environment>::StateType as State>::size();
|
|
||||||
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
|
||||||
|
|
||||||
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
|
||||||
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
|
||||||
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
|
||||||
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
|
||||||
|
|
||||||
let mut agent = MyAgent::default();
|
|
||||||
|
|
||||||
let config = SACTrainingConfig {
|
|
||||||
gamma: conf.gamma,
|
|
||||||
tau: conf.tau,
|
|
||||||
learning_rate: conf.learning_rate,
|
|
||||||
min_probability: conf.min_probability,
|
|
||||||
batch_size: conf.batch_size,
|
|
||||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
|
||||||
conf.clip_grad,
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
|
||||||
|
|
||||||
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
|
|
||||||
|
|
||||||
let mut optimizer = SACOptimizer::new(
|
|
||||||
optimizer_config.clone().init(),
|
|
||||||
optimizer_config.clone().init(),
|
|
||||||
optimizer_config.clone().init(),
|
|
||||||
optimizer_config.init(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut step = 0_usize;
|
|
||||||
|
|
||||||
for episode in 0..conf.num_episodes {
|
|
||||||
let mut episode_done = false;
|
|
||||||
let mut episode_reward = 0.0;
|
|
||||||
let mut episode_duration = 0_usize;
|
|
||||||
let mut state = env.state();
|
|
||||||
let mut now = SystemTime::now();
|
|
||||||
|
|
||||||
while !episode_done {
|
|
||||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
|
|
||||||
let snapshot = env.step(action);
|
|
||||||
|
|
||||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
memory.push(
|
|
||||||
state,
|
|
||||||
*snapshot.state(),
|
|
||||||
action,
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
snapshot.done(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if config.batch_size < memory.len() {
|
|
||||||
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
|
|
||||||
}
|
|
||||||
|
|
||||||
step += 1;
|
|
||||||
episode_duration += 1;
|
|
||||||
|
|
||||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
|
||||||
env.reset();
|
|
||||||
episode_done = true;
|
|
||||||
|
|
||||||
println!(
|
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
|
||||||
now.elapsed().unwrap().as_secs()
|
|
||||||
);
|
|
||||||
now = SystemTime::now();
|
|
||||||
} else {
|
|
||||||
state = *snapshot.state();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let valid_agent = agent.valid(nets.actor);
|
|
||||||
if let Some(path) = &conf.save_path {
|
|
||||||
if let Some(model) = valid_agent.model() {
|
|
||||||
save_model(model, path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
valid_agent
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
// println!("Chargement du modèle depuis : {model_path}");
|
|
||||||
|
|
||||||
CompactRecorder::new()
|
|
||||||
.load(model_path.into(), &NdArrayDevice::default())
|
|
||||||
.map(|record| {
|
|
||||||
Actor::new(
|
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
})
|
|
||||||
.ok()
|
|
||||||
}
|
|
||||||
|
|
@ -1,222 +0,0 @@
|
||||||
use crate::burnrl::environment_valid::TrictracEnvironment;
|
|
||||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
|
||||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
|
||||||
use burn::module::Module;
|
|
||||||
use burn::nn::{Linear, LinearConfig};
|
|
||||||
use burn::optim::AdamWConfig;
|
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
|
||||||
use burn::tensor::activation::{relu, softmax};
|
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
|
||||||
use burn::tensor::Tensor;
|
|
||||||
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
|
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
|
||||||
use std::time::SystemTime;
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
|
||||||
pub struct Actor<B: Backend> {
|
|
||||||
linear_0: Linear<B>,
|
|
||||||
linear_1: Linear<B>,
|
|
||||||
linear_2: Linear<B>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Actor<B> {
|
|
||||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
|
||||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
|
||||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
|
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
let layer_0_output = relu(self.linear_0.forward(input));
|
|
||||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
|
||||||
|
|
||||||
softmax(self.linear_2.forward(layer_1_output), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
self.forward(input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> SACActor<B> for Actor<B> {}
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
|
||||||
pub struct Critic<B: Backend> {
|
|
||||||
linear_0: Linear<B>,
|
|
||||||
linear_1: Linear<B>,
|
|
||||||
linear_2: Linear<B>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Critic<B> {
|
|
||||||
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
|
|
||||||
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
|
|
||||||
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
|
|
||||||
(self.linear_0, self.linear_1, self.linear_2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
|
|
||||||
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
let layer_0_output = relu(self.linear_0.forward(input));
|
|
||||||
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
|
|
||||||
|
|
||||||
self.linear_2.forward(layer_1_output)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
|
||||||
self.forward(input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> SACCritic<B> for Critic<B> {
|
|
||||||
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
|
|
||||||
let (linear_0, linear_1, linear_2) = this.consume();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
|
|
||||||
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
|
|
||||||
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
const MEMORY_SIZE: usize = 4096;
|
|
||||||
|
|
||||||
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn run<
|
|
||||||
E: Environment + AsMut<TrictracEnvironment>,
|
|
||||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
|
||||||
>(
|
|
||||||
conf: &Config,
|
|
||||||
visualized: bool,
|
|
||||||
) -> impl Agent<E> {
|
|
||||||
let mut env = E::new(visualized);
|
|
||||||
env.as_mut().max_steps = conf.max_steps;
|
|
||||||
let state_dim = <<E as Environment>::StateType as State>::size();
|
|
||||||
let action_dim = <<E as Environment>::ActionType as Action>::size();
|
|
||||||
|
|
||||||
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
|
|
||||||
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
|
||||||
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
|
|
||||||
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
|
|
||||||
|
|
||||||
let mut agent = MyAgent::default();
|
|
||||||
|
|
||||||
let config = SACTrainingConfig {
|
|
||||||
gamma: conf.gamma,
|
|
||||||
tau: conf.tau,
|
|
||||||
learning_rate: conf.learning_rate,
|
|
||||||
min_probability: conf.min_probability,
|
|
||||||
batch_size: conf.batch_size,
|
|
||||||
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
|
|
||||||
conf.clip_grad,
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
|
|
||||||
|
|
||||||
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
|
|
||||||
|
|
||||||
let mut optimizer = SACOptimizer::new(
|
|
||||||
optimizer_config.clone().init(),
|
|
||||||
optimizer_config.clone().init(),
|
|
||||||
optimizer_config.clone().init(),
|
|
||||||
optimizer_config.init(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut step = 0_usize;
|
|
||||||
|
|
||||||
for episode in 0..conf.num_episodes {
|
|
||||||
let mut episode_done = false;
|
|
||||||
let mut episode_reward = 0.0;
|
|
||||||
let mut episode_duration = 0_usize;
|
|
||||||
let mut state = env.state();
|
|
||||||
let mut now = SystemTime::now();
|
|
||||||
|
|
||||||
while !episode_done {
|
|
||||||
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
|
|
||||||
let snapshot = env.step(action);
|
|
||||||
|
|
||||||
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
memory.push(
|
|
||||||
state,
|
|
||||||
*snapshot.state(),
|
|
||||||
action,
|
|
||||||
snapshot.reward().clone(),
|
|
||||||
snapshot.done(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if config.batch_size < memory.len() {
|
|
||||||
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
|
|
||||||
}
|
|
||||||
|
|
||||||
step += 1;
|
|
||||||
episode_duration += 1;
|
|
||||||
|
|
||||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
|
||||||
env.reset();
|
|
||||||
episode_done = true;
|
|
||||||
|
|
||||||
println!(
|
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
|
|
||||||
now.elapsed().unwrap().as_secs()
|
|
||||||
);
|
|
||||||
now = SystemTime::now();
|
|
||||||
} else {
|
|
||||||
state = *snapshot.state();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let valid_agent = agent.valid(nets.actor);
|
|
||||||
if let Some(path) = &conf.save_path {
|
|
||||||
if let Some(model) = valid_agent.model() {
|
|
||||||
save_model(model, path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
valid_agent
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
// println!("Chargement du modèle depuis : {model_path}");
|
|
||||||
|
|
||||||
CompactRecorder::new()
|
|
||||||
.load(model_path.into(), &NdArrayDevice::default())
|
|
||||||
.map(|record| {
|
|
||||||
Actor::new(
|
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
})
|
|
||||||
.ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,90 +0,0 @@
|
||||||
use bot::burnrl::algos::{dqn, dqn_valid, ppo, ppo_valid, sac, sac_valid};
|
|
||||||
use bot::burnrl::environment::TrictracEnvironment;
|
|
||||||
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
|
|
||||||
use bot::burnrl::utils::{demo_model, Config};
|
|
||||||
use burn::backend::{Autodiff, NdArray};
|
|
||||||
use burn_rl::base::ElemType;
|
|
||||||
use std::env;
|
|
||||||
|
|
||||||
type Backend = Autodiff<NdArray<ElemType>>;
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
let args: Vec<String> = env::args().collect();
|
|
||||||
let algo = &args[1];
|
|
||||||
// let dir_path = &args[2];
|
|
||||||
|
|
||||||
let path = format!("bot/models/burnrl_{algo}");
|
|
||||||
println!(
|
|
||||||
"info: loading configuration from file {:?}",
|
|
||||||
confy::get_configuration_file_path("trictrac_bot", None).unwrap()
|
|
||||||
);
|
|
||||||
let mut conf: Config = confy::load("trictrac_bot", None).expect("Could not load config");
|
|
||||||
conf.save_path = Some(path.clone());
|
|
||||||
println!("{conf}----------");
|
|
||||||
|
|
||||||
match algo.as_str() {
|
|
||||||
"dqn" => {
|
|
||||||
let _agent = dqn::run::<TrictracEnvironment, Backend>(&conf, false);
|
|
||||||
println!("> Chargement du modèle pour test");
|
|
||||||
let loaded_model = dqn::load_model(conf.dense_size, &path);
|
|
||||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
|
|
||||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
|
||||||
demo_model(loaded_agent);
|
|
||||||
}
|
|
||||||
"dqn_valid" => {
|
|
||||||
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
|
||||||
println!("> Chargement du modèle pour test");
|
|
||||||
let loaded_model = dqn_valid::load_model(conf.dense_size, &path);
|
|
||||||
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentValid, _, _> =
|
|
||||||
burn_rl::agent::DQN::new(loaded_model.unwrap());
|
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
|
||||||
demo_model(loaded_agent);
|
|
||||||
}
|
|
||||||
"sac" => {
|
|
||||||
let _agent = sac::run::<TrictracEnvironment, Backend>(&conf, false);
|
|
||||||
println!("> Chargement du modèle pour test");
|
|
||||||
let loaded_model = sac::load_model(conf.dense_size, &path);
|
|
||||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
|
|
||||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
|
||||||
demo_model(loaded_agent);
|
|
||||||
}
|
|
||||||
"sac_valid" => {
|
|
||||||
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
|
||||||
println!("> Chargement du modèle pour test");
|
|
||||||
let loaded_model = sac_valid::load_model(conf.dense_size, &path);
|
|
||||||
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentValid, _, _> =
|
|
||||||
burn_rl::agent::SAC::new(loaded_model.unwrap());
|
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
|
||||||
demo_model(loaded_agent);
|
|
||||||
}
|
|
||||||
"ppo" => {
|
|
||||||
let _agent = ppo::run::<TrictracEnvironment, Backend>(&conf, false);
|
|
||||||
println!("> Chargement du modèle pour test");
|
|
||||||
let loaded_model = ppo::load_model(conf.dense_size, &path);
|
|
||||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
|
|
||||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
|
||||||
demo_model(loaded_agent);
|
|
||||||
}
|
|
||||||
"ppo_valid" => {
|
|
||||||
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
|
|
||||||
println!("> Chargement du modèle pour test");
|
|
||||||
let loaded_model = ppo_valid::load_model(conf.dense_size, &path);
|
|
||||||
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentValid, _, _> =
|
|
||||||
burn_rl::agent::PPO::new(loaded_model.unwrap());
|
|
||||||
|
|
||||||
println!("> Test avec le modèle chargé");
|
|
||||||
demo_model(loaded_agent);
|
|
||||||
}
|
|
||||||
&_ => {
|
|
||||||
println!("unknown algo {algo}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
pub mod algos;
|
|
||||||
pub mod environment;
|
|
||||||
pub mod environment_valid;
|
|
||||||
pub mod utils;
|
|
||||||
|
|
@ -1,132 +0,0 @@
|
||||||
use burn::module::{Param, ParamId};
|
|
||||||
use burn::nn::Linear;
|
|
||||||
use burn::tensor::backend::Backend;
|
|
||||||
use burn::tensor::Tensor;
|
|
||||||
use burn_rl::base::{Agent, ElemType, Environment};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct Config {
|
|
||||||
pub save_path: Option<String>,
|
|
||||||
pub max_steps: usize, // max steps by episode
|
|
||||||
pub num_episodes: usize,
|
|
||||||
pub dense_size: usize, // neural network complexity
|
|
||||||
|
|
||||||
// discount factor. Plus élevé = encourage stratégies à long terme
|
|
||||||
pub gamma: f32,
|
|
||||||
// soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation plus lente moins sensible aux coups de chance
|
|
||||||
pub tau: f32,
|
|
||||||
// taille du pas. Bas : plus lent, haut : risque de ne jamais
|
|
||||||
pub learning_rate: f32,
|
|
||||||
// nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
|
||||||
pub batch_size: usize,
|
|
||||||
// limite max de correction à apporter au gradient (default 100)
|
|
||||||
pub clip_grad: f32,
|
|
||||||
|
|
||||||
// ---- for SAC
|
|
||||||
pub min_probability: f32,
|
|
||||||
|
|
||||||
// ---- for DQN
|
|
||||||
// epsilon initial value (0.9 => more exploration)
|
|
||||||
pub eps_start: f64,
|
|
||||||
pub eps_end: f64,
|
|
||||||
// eps_decay higher = epsilon decrease slower
|
|
||||||
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
|
||||||
// epsilon is updated at the start of each episode
|
|
||||||
pub eps_decay: f64,
|
|
||||||
|
|
||||||
// ---- for PPO
|
|
||||||
pub lambda: f32,
|
|
||||||
pub epsilon_clip: f32,
|
|
||||||
pub critic_weight: f32,
|
|
||||||
pub entropy_weight: f32,
|
|
||||||
pub epochs: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
save_path: None,
|
|
||||||
max_steps: 2000,
|
|
||||||
num_episodes: 1000,
|
|
||||||
dense_size: 256,
|
|
||||||
gamma: 0.999,
|
|
||||||
tau: 0.005,
|
|
||||||
learning_rate: 0.001,
|
|
||||||
batch_size: 32,
|
|
||||||
clip_grad: 100.0,
|
|
||||||
min_probability: 1e-9,
|
|
||||||
eps_start: 0.9,
|
|
||||||
eps_end: 0.05,
|
|
||||||
eps_decay: 1000.0,
|
|
||||||
lambda: 0.95,
|
|
||||||
epsilon_clip: 0.2,
|
|
||||||
critic_weight: 0.5,
|
|
||||||
entropy_weight: 0.01,
|
|
||||||
epochs: 8,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for Config {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
let mut s = String::new();
|
|
||||||
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
|
||||||
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
|
||||||
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
|
||||||
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
|
||||||
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
|
||||||
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
|
||||||
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
|
||||||
s.push_str(&format!("tau={:?}\n", self.tau));
|
|
||||||
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
|
||||||
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
|
||||||
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
|
||||||
s.push_str(&format!("min_probability={:?}\n", self.min_probability));
|
|
||||||
s.push_str(&format!("lambda={:?}\n", self.lambda));
|
|
||||||
s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip));
|
|
||||||
s.push_str(&format!("critic_weight={:?}\n", self.critic_weight));
|
|
||||||
s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight));
|
|
||||||
s.push_str(&format!("epochs={:?}\n", self.epochs));
|
|
||||||
write!(f, "{s}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
|
|
||||||
let mut env = E::new(true);
|
|
||||||
let mut state = env.state();
|
|
||||||
let mut done = false;
|
|
||||||
while !done {
|
|
||||||
if let Some(action) = agent.react(&state) {
|
|
||||||
let snapshot = env.step(action);
|
|
||||||
state = *snapshot.state();
|
|
||||||
done = snapshot.done();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn soft_update_tensor<const N: usize, B: Backend>(
|
|
||||||
this: &Param<Tensor<B, N>>,
|
|
||||||
that: &Param<Tensor<B, N>>,
|
|
||||||
tau: ElemType,
|
|
||||||
) -> Param<Tensor<B, N>> {
|
|
||||||
let that_weight = that.val();
|
|
||||||
let this_weight = this.val();
|
|
||||||
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
|
||||||
|
|
||||||
Param::initialized(ParamId::new(), new_weight)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn soft_update_linear<B: Backend>(
|
|
||||||
this: Linear<B>,
|
|
||||||
that: &Linear<B>,
|
|
||||||
tau: ElemType,
|
|
||||||
) -> Linear<B> {
|
|
||||||
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
|
||||||
let bias = match (&this.bias, &that.bias) {
|
|
||||||
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
Linear::<B> { weight, bias }
|
|
||||||
}
|
|
||||||
|
|
@ -1,16 +1,15 @@
|
||||||
use crate::burnrl::environment_valid::TrictracEnvironment;
|
use crate::dqn::burnrl::environment::TrictracEnvironment;
|
||||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
use crate::dqn::burnrl::utils::soft_update_linear;
|
||||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::nn::{Linear, LinearConfig};
|
use burn::nn::{Linear, LinearConfig};
|
||||||
use burn::optim::AdamWConfig;
|
use burn::optim::AdamWConfig;
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
|
||||||
use burn::tensor::activation::relu;
|
use burn::tensor::activation::relu;
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||||
use burn::tensor::Tensor;
|
use burn::tensor::Tensor;
|
||||||
use burn_rl::agent::DQN;
|
use burn_rl::agent::DQN;
|
||||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::fmt;
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
|
|
@ -63,19 +62,71 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
const MEMORY_SIZE: usize = 8192;
|
const MEMORY_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
pub struct DqnConfig {
|
||||||
|
pub min_steps: f32,
|
||||||
|
pub max_steps: usize,
|
||||||
|
pub num_episodes: usize,
|
||||||
|
pub dense_size: usize,
|
||||||
|
pub eps_start: f64,
|
||||||
|
pub eps_end: f64,
|
||||||
|
pub eps_decay: f64,
|
||||||
|
|
||||||
|
pub gamma: f32,
|
||||||
|
pub tau: f32,
|
||||||
|
pub learning_rate: f32,
|
||||||
|
pub batch_size: usize,
|
||||||
|
pub clip_grad: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for DqnConfig {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
let mut s = String::new();
|
||||||
|
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
|
||||||
|
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||||
|
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||||
|
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||||
|
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||||
|
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||||
|
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||||
|
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||||
|
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||||
|
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||||
|
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||||
|
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||||
|
write!(f, "{s}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
min_steps: 250.0,
|
||||||
|
max_steps: 2000,
|
||||||
|
num_episodes: 1000,
|
||||||
|
dense_size: 256,
|
||||||
|
eps_start: 0.9,
|
||||||
|
eps_end: 0.05,
|
||||||
|
eps_decay: 1000.0,
|
||||||
|
|
||||||
|
gamma: 0.999,
|
||||||
|
tau: 0.005,
|
||||||
|
learning_rate: 0.001,
|
||||||
|
batch_size: 32,
|
||||||
|
clip_grad: 100.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||||
pub fn run<
|
conf: &DqnConfig,
|
||||||
E: Environment + AsMut<TrictracEnvironment>,
|
|
||||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
|
||||||
>(
|
|
||||||
conf: &Config,
|
|
||||||
visualized: bool,
|
visualized: bool,
|
||||||
// ) -> DQN<E, B, Net<B>> {
|
) -> DQN<E, B, Net<B>> {
|
||||||
) -> impl Agent<E> {
|
// ) -> impl Agent<E> {
|
||||||
let mut env = E::new(visualized);
|
let mut env = E::new(visualized);
|
||||||
|
env.as_mut().min_steps = conf.min_steps;
|
||||||
env.as_mut().max_steps = conf.max_steps;
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
|
||||||
let model = Net::<B>::new(
|
let model = Net::<B>::new(
|
||||||
|
|
@ -143,7 +194,8 @@ pub fn run<
|
||||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||||
let envmut = env.as_mut();
|
let envmut = env.as_mut();
|
||||||
println!(
|
println!(
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
|
||||||
|
envmut.goodmoves_count,
|
||||||
envmut.pointrolls_count,
|
envmut.pointrolls_count,
|
||||||
now.elapsed().unwrap().as_secs(),
|
now.elapsed().unwrap().as_secs(),
|
||||||
);
|
);
|
||||||
|
|
@ -155,35 +207,5 @@ pub fn run<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let valid_agent = agent.valid();
|
agent
|
||||||
if let Some(path) = &conf.save_path {
|
|
||||||
save_model(valid_agent.model().as_ref().unwrap(), path);
|
|
||||||
}
|
|
||||||
valid_agent
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
// println!("Chargement du modèle depuis : {model_path}");
|
|
||||||
|
|
||||||
CompactRecorder::new()
|
|
||||||
.load(model_path.into(), &NdArrayDevice::default())
|
|
||||||
.map(|record| {
|
|
||||||
Net::new(
|
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
})
|
|
||||||
.ok()
|
|
||||||
}
|
}
|
||||||
|
|
@ -1,16 +1,9 @@
|
||||||
use std::io::Write;
|
use crate::dqn::dqn_common;
|
||||||
|
|
||||||
use crate::training_common;
|
|
||||||
use burn::{prelude::Backend, tensor::Tensor};
|
use burn::{prelude::Backend, tensor::Tensor};
|
||||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
const ERROR_REWARD: f32 = -1.0012121;
|
|
||||||
const REWARD_VALID_MOVE: f32 = 1.0012121;
|
|
||||||
const REWARD_RATIO: f32 = 0.1;
|
|
||||||
const WIN_POINTS: f32 = 100.0;
|
|
||||||
|
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct TrictracState {
|
pub struct TrictracState {
|
||||||
|
|
@ -66,7 +59,7 @@ impl Action for TrictracAction {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size() -> usize {
|
fn size() -> usize {
|
||||||
514
|
1252
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -91,7 +84,7 @@ pub struct TrictracEnvironment {
|
||||||
current_state: TrictracState,
|
current_state: TrictracState,
|
||||||
episode_reward: f32,
|
episode_reward: f32,
|
||||||
pub step_count: usize,
|
pub step_count: usize,
|
||||||
pub best_ratio: f32,
|
pub min_steps: f32,
|
||||||
pub max_steps: usize,
|
pub max_steps: usize,
|
||||||
pub pointrolls_count: usize,
|
pub pointrolls_count: usize,
|
||||||
pub goodmoves_count: usize,
|
pub goodmoves_count: usize,
|
||||||
|
|
@ -124,7 +117,7 @@ impl Environment for TrictracEnvironment {
|
||||||
current_state,
|
current_state,
|
||||||
episode_reward: 0.0,
|
episode_reward: 0.0,
|
||||||
step_count: 0,
|
step_count: 0,
|
||||||
best_ratio: 0.0,
|
min_steps: 250.0,
|
||||||
max_steps: 2000,
|
max_steps: 2000,
|
||||||
pointrolls_count: 0,
|
pointrolls_count: 0,
|
||||||
goodmoves_count: 0,
|
goodmoves_count: 0,
|
||||||
|
|
@ -139,7 +132,6 @@ impl Environment for TrictracEnvironment {
|
||||||
|
|
||||||
fn reset(&mut self) -> Snapshot<Self> {
|
fn reset(&mut self) -> Snapshot<Self> {
|
||||||
// Réinitialiser le jeu
|
// Réinitialiser le jeu
|
||||||
let history = self.game.history.clone();
|
|
||||||
self.game = GameState::new(false);
|
self.game = GameState::new(false);
|
||||||
self.game.init_player("DQN Agent");
|
self.game.init_player("DQN Agent");
|
||||||
self.game.init_player("Opponent");
|
self.game.init_player("Opponent");
|
||||||
|
|
@ -154,22 +146,11 @@ impl Environment for TrictracEnvironment {
|
||||||
} else {
|
} else {
|
||||||
self.goodmoves_count as f32 / self.step_count as f32
|
self.goodmoves_count as f32 / self.step_count as f32
|
||||||
};
|
};
|
||||||
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
|
println!(
|
||||||
let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
|
"info: correct moves: {} ({}%)",
|
||||||
let path = "bot/models/logs/debug.log";
|
self.goodmoves_count,
|
||||||
if let Ok(mut out) = std::fs::File::create(path) {
|
(100.0 * self.goodmoves_ratio).round() as u32
|
||||||
write!(out, "{history:?}").expect("could not write history log");
|
);
|
||||||
}
|
|
||||||
"!!!!"
|
|
||||||
} else {
|
|
||||||
""
|
|
||||||
};
|
|
||||||
// println!(
|
|
||||||
// "info: correct moves: {} ({}%) {}",
|
|
||||||
// self.goodmoves_count,
|
|
||||||
// (100.0 * self.goodmoves_ratio).round() as u32,
|
|
||||||
// warning
|
|
||||||
// );
|
|
||||||
self.step_count = 0;
|
self.step_count = 0;
|
||||||
self.pointrolls_count = 0;
|
self.pointrolls_count = 0;
|
||||||
self.goodmoves_count = 0;
|
self.goodmoves_count = 0;
|
||||||
|
|
@ -184,7 +165,8 @@ impl Environment for TrictracEnvironment {
|
||||||
let trictrac_action = Self::convert_action(action);
|
let trictrac_action = Self::convert_action(action);
|
||||||
|
|
||||||
let mut reward = 0.0;
|
let mut reward = 0.0;
|
||||||
let is_rollpoint;
|
let mut is_rollpoint = false;
|
||||||
|
let mut terminated = false;
|
||||||
|
|
||||||
// Exécuter l'action si c'est le tour de l'agent DQN
|
// Exécuter l'action si c'est le tour de l'agent DQN
|
||||||
if self.game.active_player_id == self.active_player_id {
|
if self.game.active_player_id == self.active_player_id {
|
||||||
|
|
@ -193,13 +175,12 @@ impl Environment for TrictracEnvironment {
|
||||||
if is_rollpoint {
|
if is_rollpoint {
|
||||||
self.pointrolls_count += 1;
|
self.pointrolls_count += 1;
|
||||||
}
|
}
|
||||||
if reward != ERROR_REWARD {
|
if reward != Self::ERROR_REWARD {
|
||||||
self.goodmoves_count += 1;
|
self.goodmoves_count += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Action non convertible, pénalité
|
// Action non convertible, pénalité
|
||||||
panic!("action non convertible");
|
reward = -0.5;
|
||||||
//reward = -0.5;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -209,24 +190,22 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vérifier si la partie est terminée
|
// Vérifier si la partie est terminée
|
||||||
// let max_steps = self.max_steps;
|
let max_steps = self.min_steps
|
||||||
// let max_steps = self.min_steps
|
+ (self.max_steps as f32 - self.min_steps)
|
||||||
// + (self.max_steps as f32 - self.min_steps)
|
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
||||||
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
|
|
||||||
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
|
||||||
|
|
||||||
if done {
|
if done {
|
||||||
// Récompense finale basée sur le résultat
|
// Récompense finale basée sur le résultat
|
||||||
if let Some(winner_id) = self.game.determine_winner() {
|
if let Some(winner_id) = self.game.determine_winner() {
|
||||||
if winner_id == self.active_player_id {
|
if winner_id == self.active_player_id {
|
||||||
reward += WIN_POINTS; // Victoire
|
reward += 50.0; // Victoire
|
||||||
} else {
|
} else {
|
||||||
reward -= WIN_POINTS; // Défaite
|
reward -= 25.0; // Défaite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let terminated = done || self.step_count >= self.max_steps;
|
let terminated = done || self.step_count >= max_steps.round() as usize;
|
||||||
// let terminated = done || self.step_count >= max_steps.round() as usize;
|
|
||||||
|
|
||||||
// Mettre à jour l'état
|
// Mettre à jour l'état
|
||||||
self.current_state = TrictracState::from_game_state(&self.game);
|
self.current_state = TrictracState::from_game_state(&self.game);
|
||||||
|
|
@ -244,19 +223,21 @@ impl Environment for TrictracEnvironment {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrictracEnvironment {
|
impl TrictracEnvironment {
|
||||||
|
const ERROR_REWARD: f32 = -1.12121;
|
||||||
|
const REWARD_RATIO: f32 = 1.0;
|
||||||
|
|
||||||
/// Convertit une action burn-rl vers une action Trictrac
|
/// Convertit une action burn-rl vers une action Trictrac
|
||||||
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
|
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||||
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||||
#[allow(dead_code)]
|
|
||||||
fn convert_valid_action_index(
|
fn convert_valid_action_index(
|
||||||
&self,
|
&self,
|
||||||
action: TrictracAction,
|
action: TrictracAction,
|
||||||
game_state: &GameState,
|
game_state: &GameState,
|
||||||
) -> Option<training_common::TrictracAction> {
|
) -> Option<dqn_common::TrictracAction> {
|
||||||
use training_common::get_valid_actions;
|
use dqn_common::get_valid_actions;
|
||||||
|
|
||||||
// Obtenir les actions valides dans le contexte actuel
|
// Obtenir les actions valides dans le contexte actuel
|
||||||
let valid_actions = get_valid_actions(game_state);
|
let valid_actions = get_valid_actions(game_state);
|
||||||
|
|
@ -273,19 +254,75 @@ impl TrictracEnvironment {
|
||||||
/// Exécute une action Trictrac dans le jeu
|
/// Exécute une action Trictrac dans le jeu
|
||||||
// fn execute_action(
|
// fn execute_action(
|
||||||
// &mut self,
|
// &mut self,
|
||||||
// action: training_common::TrictracAction,
|
// action: dqn_common::TrictracAction,
|
||||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||||
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
|
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
|
||||||
use training_common::TrictracAction;
|
use dqn_common::TrictracAction;
|
||||||
|
|
||||||
let mut reward = 0.0;
|
let mut reward = 0.0;
|
||||||
let mut is_rollpoint = false;
|
let mut is_rollpoint = false;
|
||||||
|
|
||||||
|
let event = match action {
|
||||||
|
TrictracAction::Roll => {
|
||||||
|
// Lancer les dés
|
||||||
|
reward += 0.1;
|
||||||
|
Some(GameEvent::Roll {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// TrictracAction::Mark => {
|
||||||
|
// // Marquer des points
|
||||||
|
// let points = self.game.
|
||||||
|
// reward += 0.1 * points as f32;
|
||||||
|
// Some(GameEvent::Mark {
|
||||||
|
// player_id: self.active_player_id,
|
||||||
|
// points,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
TrictracAction::Go => {
|
||||||
|
// Continuer après avoir gagné un trou
|
||||||
|
reward += 0.2;
|
||||||
|
Some(GameEvent::Go {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Effectuer un mouvement
|
||||||
|
let (dice1, dice2) = if dice_order {
|
||||||
|
(self.game.dice.values.0, self.game.dice.values.1)
|
||||||
|
} else {
|
||||||
|
(self.game.dice.values.1, self.game.dice.values.0)
|
||||||
|
};
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
|
||||||
|
// Gestion prise de coin par puissance
|
||||||
|
let opp_rest_field = 13;
|
||||||
|
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||||
|
to1 -= 1;
|
||||||
|
to2 -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
reward += 0.2;
|
||||||
|
Some(GameEvent::Move {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
moves: (checker_move1, checker_move2),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Appliquer l'événement si valide
|
// Appliquer l'événement si valide
|
||||||
if let Some(event) = action.to_event(&self.game) {
|
if let Some(event) = event {
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
// reward += REWARD_VALID_MOVE;
|
|
||||||
// Simuler le résultat des dés après un Roll
|
// Simuler le résultat des dés après un Roll
|
||||||
if matches!(action, TrictracAction::Roll) {
|
if matches!(action, TrictracAction::Roll) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
@ -299,7 +336,7 @@ impl TrictracEnvironment {
|
||||||
if self.game.validate(&dice_event) {
|
if self.game.validate(&dice_event) {
|
||||||
self.game.consume(&dice_event);
|
self.game.consume(&dice_event);
|
||||||
let (points, adv_points) = self.game.dice_points;
|
let (points, adv_points) = self.game.dice_points;
|
||||||
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
|
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||||
if points > 0 {
|
if points > 0 {
|
||||||
is_rollpoint = true;
|
is_rollpoint = true;
|
||||||
// println!("info: rolled for {reward}");
|
// println!("info: rolled for {reward}");
|
||||||
|
|
@ -311,12 +348,8 @@ impl TrictracEnvironment {
|
||||||
// Pénalité pour action invalide
|
// Pénalité pour action invalide
|
||||||
// on annule les précédents reward
|
// on annule les précédents reward
|
||||||
// et on indique une valeur reconnaissable pour statistiques
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
reward = ERROR_REWARD;
|
reward = Self::ERROR_REWARD;
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
reward = ERROR_REWARD;
|
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(reward, is_rollpoint)
|
(reward, is_rollpoint)
|
||||||
|
|
@ -339,8 +372,6 @@ impl TrictracEnvironment {
|
||||||
*strategy.get_mut_game() = self.game.clone();
|
*strategy.get_mut_game() = self.game.clone();
|
||||||
|
|
||||||
// Exécuter l'action selon le turn_stage
|
// Exécuter l'action selon le turn_stage
|
||||||
let mut calculate_points = false;
|
|
||||||
let opponent_color = store::Color::Black;
|
|
||||||
let event = match self.game.turn_stage {
|
let event = match self.game.turn_stage {
|
||||||
TurnStage::RollDice => GameEvent::Roll {
|
TurnStage::RollDice => GameEvent::Roll {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
|
|
@ -348,7 +379,6 @@ impl TrictracEnvironment {
|
||||||
TurnStage::RollWaiting => {
|
TurnStage::RollWaiting => {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
calculate_points = true;
|
|
||||||
GameEvent::RollResult {
|
GameEvent::RollResult {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
dice: store::Dice {
|
dice: store::Dice {
|
||||||
|
|
@ -357,6 +387,7 @@ impl TrictracEnvironment {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::MarkPoints => {
|
TurnStage::MarkPoints => {
|
||||||
|
let opponent_color = store::Color::Black;
|
||||||
let dice_roll_count = self
|
let dice_roll_count = self
|
||||||
.game
|
.game
|
||||||
.players
|
.players
|
||||||
|
|
@ -365,9 +396,12 @@ impl TrictracEnvironment {
|
||||||
.dice_roll_count;
|
.dice_roll_count;
|
||||||
let points_rules =
|
let points_rules =
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
|
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||||
|
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
GameEvent::Mark {
|
GameEvent::Mark {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
points: points_rules.get_points(dice_roll_count).0,
|
points,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::MarkAdvPoints => {
|
TurnStage::MarkAdvPoints => {
|
||||||
|
|
@ -380,10 +414,11 @@ impl TrictracEnvironment {
|
||||||
.dice_roll_count;
|
.dice_roll_count;
|
||||||
let points_rules =
|
let points_rules =
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
|
let points = points_rules.get_points(dice_roll_count).1;
|
||||||
// pas de reward : déjà comptabilisé lors du tour de blanc
|
// pas de reward : déjà comptabilisé lors du tour de blanc
|
||||||
GameEvent::Mark {
|
GameEvent::Mark {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
points: points_rules.get_points(dice_roll_count).1,
|
points,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::HoldOrGoChoice => {
|
TurnStage::HoldOrGoChoice => {
|
||||||
|
|
@ -400,19 +435,6 @@ impl TrictracEnvironment {
|
||||||
|
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
if calculate_points {
|
|
||||||
let dice_roll_count = self
|
|
||||||
.game
|
|
||||||
.players
|
|
||||||
.get(&self.opponent_id)
|
|
||||||
.unwrap()
|
|
||||||
.dice_roll_count;
|
|
||||||
let points_rules =
|
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
|
||||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
|
||||||
// Récompense proportionnelle aux points
|
|
||||||
reward -= REWARD_RATIO * (points as f32 - adv_points as f32);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
reward
|
reward
|
||||||
53
bot/src/dqn/burnrl/main.rs
Normal file
53
bot/src/dqn/burnrl/main.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
use bot::dqn::burnrl::{
|
||||||
|
dqn_model, environment,
|
||||||
|
utils::{demo_model, load_model, save_model},
|
||||||
|
};
|
||||||
|
use burn::backend::{Autodiff, NdArray};
|
||||||
|
use burn_rl::agent::DQN;
|
||||||
|
use burn_rl::base::ElemType;
|
||||||
|
|
||||||
|
type Backend = Autodiff<NdArray<ElemType>>;
|
||||||
|
type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// println!("> Entraînement");
|
||||||
|
|
||||||
|
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||||
|
let conf = dqn_model::DqnConfig {
|
||||||
|
// defaults
|
||||||
|
num_episodes: 40, // 40
|
||||||
|
min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction)
|
||||||
|
max_steps: 3000, // 1000 max steps by episode
|
||||||
|
dense_size: 256, // 128 neural network complexity (default 128)
|
||||||
|
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
|
||||||
|
eps_end: 0.05, // 0.05
|
||||||
|
// eps_decay higher = epsilon decrease slower
|
||||||
|
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||||
|
// epsilon is updated at the start of each episode
|
||||||
|
eps_decay: 2000.0, // 1000 ?
|
||||||
|
|
||||||
|
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
|
||||||
|
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||||
|
// plus lente moins sensible aux coups de chance
|
||||||
|
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||||
|
// converger
|
||||||
|
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||||
|
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||||
|
};
|
||||||
|
println!("{conf}----------");
|
||||||
|
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||||
|
|
||||||
|
let valid_agent = agent.valid();
|
||||||
|
|
||||||
|
println!("> Sauvegarde du modèle de validation");
|
||||||
|
|
||||||
|
let path = "models/burn_dqn_40".to_string();
|
||||||
|
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||||
|
|
||||||
|
println!("> Chargement du modèle pour test");
|
||||||
|
let loaded_model = load_model(conf.dense_size, &path);
|
||||||
|
let loaded_agent = DQN::new(loaded_model.unwrap());
|
||||||
|
|
||||||
|
println!("> Test avec le modèle chargé");
|
||||||
|
demo_model(loaded_agent);
|
||||||
|
}
|
||||||
3
bot/src/dqn/burnrl/mod.rs
Normal file
3
bot/src/dqn/burnrl/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub mod dqn_model;
|
||||||
|
pub mod environment;
|
||||||
|
pub mod utils;
|
||||||
114
bot/src/dqn/burnrl/utils.rs
Normal file
114
bot/src/dqn/burnrl/utils.rs
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
use crate::dqn::burnrl::{
|
||||||
|
dqn_model,
|
||||||
|
environment::{TrictracAction, TrictracEnvironment},
|
||||||
|
};
|
||||||
|
use crate::dqn::dqn_common::get_valid_action_indices;
|
||||||
|
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||||
|
use burn::module::{Module, Param, ParamId};
|
||||||
|
use burn::nn::Linear;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
use burn::tensor::backend::Backend;
|
||||||
|
use burn::tensor::cast::ToElement;
|
||||||
|
use burn::tensor::Tensor;
|
||||||
|
use burn_rl::agent::{DQNModel, DQN};
|
||||||
|
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||||
|
|
||||||
|
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{path}_model.mpk");
|
||||||
|
println!("Modèle de validation sauvegardé : {model_path}");
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
|
||||||
|
let model_path = format!("{path}_model.mpk");
|
||||||
|
// println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
|
CompactRecorder::new()
|
||||||
|
.load(model_path.into(), &NdArrayDevice::default())
|
||||||
|
.map(|record| {
|
||||||
|
dqn_model::Net::new(
|
||||||
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
|
)
|
||||||
|
.load_record(record)
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||||
|
let mut env = TrictracEnvironment::new(true);
|
||||||
|
let mut done = false;
|
||||||
|
while !done {
|
||||||
|
// let action = match infer_action(&agent, &env, state) {
|
||||||
|
let action = match infer_action(&agent, &env) {
|
||||||
|
Some(value) => value,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
// Execute action
|
||||||
|
let snapshot = env.step(action);
|
||||||
|
done = snapshot.done();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_action<B: Backend, M: DQNModel<B>>(
|
||||||
|
agent: &DQN<TrictracEnvironment, B, M>,
|
||||||
|
env: &TrictracEnvironment,
|
||||||
|
) -> Option<TrictracAction> {
|
||||||
|
let state = env.state();
|
||||||
|
// Get q-values
|
||||||
|
let q_values = agent
|
||||||
|
.model()
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.infer(state.to_tensor().unsqueeze());
|
||||||
|
// Get valid actions
|
||||||
|
let valid_actions_indices = get_valid_action_indices(&env.game);
|
||||||
|
if valid_actions_indices.is_empty() {
|
||||||
|
return None; // No valid actions, end of episode
|
||||||
|
}
|
||||||
|
// Set non valid actions q-values to lowest
|
||||||
|
let mut masked_q_values = q_values.clone();
|
||||||
|
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||||
|
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||||
|
if !valid_actions_indices.contains(&index) {
|
||||||
|
masked_q_values = masked_q_values.clone().mask_fill(
|
||||||
|
masked_q_values.clone().equal_elem(*q_value),
|
||||||
|
f32::NEG_INFINITY,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get best action (highest q-value)
|
||||||
|
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||||
|
let action = TrictracAction::from(action_index);
|
||||||
|
Some(action)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||||
|
this: &Param<Tensor<B, N>>,
|
||||||
|
that: &Param<Tensor<B, N>>,
|
||||||
|
tau: ElemType,
|
||||||
|
) -> Param<Tensor<B, N>> {
|
||||||
|
let that_weight = that.val();
|
||||||
|
let this_weight = this.val();
|
||||||
|
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||||
|
|
||||||
|
Param::initialized(ParamId::new(), new_weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn soft_update_linear<B: Backend>(
|
||||||
|
this: Linear<B>,
|
||||||
|
that: &Linear<B>,
|
||||||
|
tau: ElemType,
|
||||||
|
) -> Linear<B> {
|
||||||
|
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||||
|
let bias = match (&this.bias, &that.bias) {
|
||||||
|
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Linear::<B> { weight, bias }
|
||||||
|
}
|
||||||
|
|
@ -1,16 +1,15 @@
|
||||||
use crate::burnrl::environment::TrictracEnvironment;
|
use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
|
||||||
use crate::burnrl::utils::{soft_update_linear, Config};
|
use crate::dqn::burnrl_valid::utils::soft_update_linear;
|
||||||
use burn::backend::{ndarray::NdArrayDevice, NdArray};
|
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::nn::{Linear, LinearConfig};
|
use burn::nn::{Linear, LinearConfig};
|
||||||
use burn::optim::AdamWConfig;
|
use burn::optim::AdamWConfig;
|
||||||
use burn::record::{CompactRecorder, Recorder};
|
|
||||||
use burn::tensor::activation::relu;
|
use burn::tensor::activation::relu;
|
||||||
use burn::tensor::backend::{AutodiffBackend, Backend};
|
use burn::tensor::backend::{AutodiffBackend, Backend};
|
||||||
use burn::tensor::Tensor;
|
use burn::tensor::Tensor;
|
||||||
use burn_rl::agent::DQN;
|
use burn_rl::agent::DQN;
|
||||||
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
|
||||||
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
|
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
|
||||||
|
use std::fmt;
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
|
|
@ -63,20 +62,67 @@ impl<B: Backend> DQNModel<B> for Net<B> {
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
const MEMORY_SIZE: usize = 8192;
|
const MEMORY_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
pub struct DqnConfig {
|
||||||
|
pub max_steps: usize,
|
||||||
|
pub num_episodes: usize,
|
||||||
|
pub dense_size: usize,
|
||||||
|
pub eps_start: f64,
|
||||||
|
pub eps_end: f64,
|
||||||
|
pub eps_decay: f64,
|
||||||
|
|
||||||
|
pub gamma: f32,
|
||||||
|
pub tau: f32,
|
||||||
|
pub learning_rate: f32,
|
||||||
|
pub batch_size: usize,
|
||||||
|
pub clip_grad: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for DqnConfig {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
let mut s = String::new();
|
||||||
|
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
|
||||||
|
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
|
||||||
|
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
|
||||||
|
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
|
||||||
|
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
|
||||||
|
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
|
||||||
|
s.push_str(&format!("gamma={:?}\n", self.gamma));
|
||||||
|
s.push_str(&format!("tau={:?}\n", self.tau));
|
||||||
|
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
|
||||||
|
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
|
||||||
|
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
|
||||||
|
write!(f, "{s}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_steps: 2000,
|
||||||
|
num_episodes: 1000,
|
||||||
|
dense_size: 256,
|
||||||
|
eps_start: 0.9,
|
||||||
|
eps_end: 0.05,
|
||||||
|
eps_decay: 1000.0,
|
||||||
|
|
||||||
|
gamma: 0.999,
|
||||||
|
tau: 0.005,
|
||||||
|
learning_rate: 0.001,
|
||||||
|
batch_size: 32,
|
||||||
|
clip_grad: 100.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
type MyAgent<E, B> = DQN<E, B, Net<B>>;
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
|
||||||
pub fn run<
|
conf: &DqnConfig,
|
||||||
E: Environment + AsMut<TrictracEnvironment>,
|
|
||||||
B: AutodiffBackend<InnerBackend = NdArray>,
|
|
||||||
>(
|
|
||||||
conf: &Config,
|
|
||||||
visualized: bool,
|
visualized: bool,
|
||||||
// ) -> DQN<E, B, Net<B>> {
|
) -> DQN<E, B, Net<B>> {
|
||||||
) -> impl Agent<E> {
|
// ) -> impl Agent<E> {
|
||||||
let mut env = E::new(visualized);
|
let mut env = E::new(visualized);
|
||||||
// env.as_mut().min_steps = conf.min_steps;
|
|
||||||
env.as_mut().max_steps = conf.max_steps;
|
env.as_mut().max_steps = conf.max_steps;
|
||||||
|
|
||||||
let model = Net::<B>::new(
|
let model = Net::<B>::new(
|
||||||
|
|
@ -143,13 +189,8 @@ pub fn run<
|
||||||
|
|
||||||
if snapshot.done() || episode_duration >= conf.max_steps {
|
if snapshot.done() || episode_duration >= conf.max_steps {
|
||||||
let envmut = env.as_mut();
|
let envmut = env.as_mut();
|
||||||
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
|
|
||||||
* 100.0)
|
|
||||||
.round() as u32;
|
|
||||||
println!(
|
println!(
|
||||||
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
|
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
|
||||||
envmut.goodmoves_count,
|
|
||||||
goodmoves_ratio,
|
|
||||||
envmut.pointrolls_count,
|
envmut.pointrolls_count,
|
||||||
now.elapsed().unwrap().as_secs(),
|
now.elapsed().unwrap().as_secs(),
|
||||||
);
|
);
|
||||||
|
|
@ -161,35 +202,5 @@ pub fn run<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let valid_agent = agent.valid();
|
agent
|
||||||
if let Some(path) = &conf.save_path {
|
|
||||||
save_model(valid_agent.model().as_ref().unwrap(), path);
|
|
||||||
}
|
|
||||||
valid_agent
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
|
|
||||||
let recorder = CompactRecorder::new();
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
println!("info: Modèle de validation sauvegardé : {model_path}");
|
|
||||||
recorder
|
|
||||||
.record(model.clone().into_record(), model_path.into())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
|
|
||||||
let model_path = format!("{path}.mpk");
|
|
||||||
// println!("Chargement du modèle depuis : {model_path}");
|
|
||||||
|
|
||||||
CompactRecorder::new()
|
|
||||||
.load(model_path.into(), &NdArrayDevice::default())
|
|
||||||
.map(|record| {
|
|
||||||
Net::new(
|
|
||||||
<TrictracEnvironment as Environment>::StateType::size(),
|
|
||||||
dense_size,
|
|
||||||
<TrictracEnvironment as Environment>::ActionType::size(),
|
|
||||||
)
|
|
||||||
.load_record(record)
|
|
||||||
})
|
|
||||||
.ok()
|
|
||||||
}
|
}
|
||||||
|
|
@ -1,12 +1,9 @@
|
||||||
use crate::training_common;
|
use crate::dqn::dqn_common;
|
||||||
use burn::{prelude::Backend, tensor::Tensor};
|
use burn::{prelude::Backend, tensor::Tensor};
|
||||||
use burn_rl::base::{Action, Environment, Snapshot, State};
|
use burn_rl::base::{Action, Environment, Snapshot, State};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
const ERROR_REWARD: f32 = -1.0012121;
|
|
||||||
const REWARD_RATIO: f32 = 0.1;
|
|
||||||
|
|
||||||
/// État du jeu Trictrac pour burn-rl
|
/// État du jeu Trictrac pour burn-rl
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct TrictracState {
|
pub struct TrictracState {
|
||||||
|
|
@ -159,26 +156,17 @@ impl Environment for TrictracEnvironment {
|
||||||
if self.game.active_player_id == self.active_player_id {
|
if self.game.active_player_id == self.active_player_id {
|
||||||
if let Some(action) = trictrac_action {
|
if let Some(action) = trictrac_action {
|
||||||
(reward, is_rollpoint) = self.execute_action(action);
|
(reward, is_rollpoint) = self.execute_action(action);
|
||||||
// if reward != 0.0 {
|
|
||||||
// println!("info: self rew {reward}");
|
|
||||||
// }
|
|
||||||
if is_rollpoint {
|
if is_rollpoint {
|
||||||
self.pointrolls_count += 1;
|
self.pointrolls_count += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Action non convertible, pénalité
|
// Action non convertible, pénalité
|
||||||
println!("info: action non convertible -> -1 {trictrac_action:?}");
|
|
||||||
reward = -1.0;
|
reward = -1.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Faire jouer l'adversaire (stratégie simple)
|
// Faire jouer l'adversaire (stratégie simple)
|
||||||
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
|
||||||
// let op_rew = self.play_opponent_if_needed();
|
|
||||||
// if op_rew != 0.0 {
|
|
||||||
// println!("info: op rew {op_rew}");
|
|
||||||
// }
|
|
||||||
// reward += op_rew;
|
|
||||||
reward += self.play_opponent_if_needed();
|
reward += self.play_opponent_if_needed();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -217,16 +205,16 @@ impl TrictracEnvironment {
|
||||||
const REWARD_RATIO: f32 = 1.0;
|
const REWARD_RATIO: f32 = 1.0;
|
||||||
|
|
||||||
/// Convertit une action burn-rl vers une action Trictrac
|
/// Convertit une action burn-rl vers une action Trictrac
|
||||||
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
|
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
|
||||||
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
|
||||||
fn convert_valid_action_index(
|
fn convert_valid_action_index(
|
||||||
&self,
|
&self,
|
||||||
action: TrictracAction,
|
action: TrictracAction,
|
||||||
) -> Option<training_common::TrictracAction> {
|
) -> Option<dqn_common::TrictracAction> {
|
||||||
use training_common::get_valid_actions;
|
use dqn_common::get_valid_actions;
|
||||||
|
|
||||||
// Obtenir les actions valides dans le contexte actuel
|
// Obtenir les actions valides dans le contexte actuel
|
||||||
let valid_actions = get_valid_actions(&self.game);
|
let valid_actions = get_valid_actions(&self.game);
|
||||||
|
|
@ -243,19 +231,72 @@ impl TrictracEnvironment {
|
||||||
/// Exécute une action Trictrac dans le jeu
|
/// Exécute une action Trictrac dans le jeu
|
||||||
// fn execute_action(
|
// fn execute_action(
|
||||||
// &mut self,
|
// &mut self,
|
||||||
// action: training_common::TrictracAction,
|
// action: dqn_common::TrictracAction,
|
||||||
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
// ) -> Result<f32, Box<dyn std::error::Error>> {
|
||||||
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
|
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
|
||||||
use training_common::TrictracAction;
|
use dqn_common::TrictracAction;
|
||||||
|
|
||||||
let mut reward = 0.0;
|
let mut reward = 0.0;
|
||||||
let mut is_rollpoint = false;
|
let mut is_rollpoint = false;
|
||||||
|
|
||||||
|
let event = match action {
|
||||||
|
TrictracAction::Roll => {
|
||||||
|
// Lancer les dés
|
||||||
|
Some(GameEvent::Roll {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// TrictracAction::Mark => {
|
||||||
|
// // Marquer des points
|
||||||
|
// let points = self.game.
|
||||||
|
// reward += 0.1 * points as f32;
|
||||||
|
// Some(GameEvent::Mark {
|
||||||
|
// player_id: self.active_player_id,
|
||||||
|
// points,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
TrictracAction::Go => {
|
||||||
|
// Continuer après avoir gagné un trou
|
||||||
|
Some(GameEvent::Go {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Effectuer un mouvement
|
||||||
|
let (dice1, dice2) = if dice_order {
|
||||||
|
(self.game.dice.values.0, self.game.dice.values.1)
|
||||||
|
} else {
|
||||||
|
(self.game.dice.values.1, self.game.dice.values.0)
|
||||||
|
};
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
|
||||||
|
// Gestion prise de coin par puissance
|
||||||
|
let opp_rest_field = 13;
|
||||||
|
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||||
|
to1 -= 1;
|
||||||
|
to2 -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
Some(GameEvent::Move {
|
||||||
|
player_id: self.active_player_id,
|
||||||
|
moves: (checker_move1, checker_move2),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Appliquer l'événement si valide
|
// Appliquer l'événement si valide
|
||||||
if let Some(event) = action.to_event(&self.game) {
|
if let Some(event) = event {
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
// reward += REWARD_VALID_MOVE;
|
|
||||||
// Simuler le résultat des dés après un Roll
|
// Simuler le résultat des dés après un Roll
|
||||||
if matches!(action, TrictracAction::Roll) {
|
if matches!(action, TrictracAction::Roll) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
@ -269,7 +310,7 @@ impl TrictracEnvironment {
|
||||||
if self.game.validate(&dice_event) {
|
if self.game.validate(&dice_event) {
|
||||||
self.game.consume(&dice_event);
|
self.game.consume(&dice_event);
|
||||||
let (points, adv_points) = self.game.dice_points;
|
let (points, adv_points) = self.game.dice_points;
|
||||||
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
|
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
|
||||||
if points > 0 {
|
if points > 0 {
|
||||||
is_rollpoint = true;
|
is_rollpoint = true;
|
||||||
// println!("info: rolled for {reward}");
|
// println!("info: rolled for {reward}");
|
||||||
|
|
@ -281,12 +322,8 @@ impl TrictracEnvironment {
|
||||||
// Pénalité pour action invalide
|
// Pénalité pour action invalide
|
||||||
// on annule les précédents reward
|
// on annule les précédents reward
|
||||||
// et on indique une valeur reconnaissable pour statistiques
|
// et on indique une valeur reconnaissable pour statistiques
|
||||||
reward = ERROR_REWARD;
|
reward = Self::ERROR_REWARD;
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
reward = ERROR_REWARD;
|
|
||||||
self.game.mark_points_for_bot_training(self.opponent_id, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(reward, is_rollpoint)
|
(reward, is_rollpoint)
|
||||||
|
|
@ -309,8 +346,6 @@ impl TrictracEnvironment {
|
||||||
*strategy.get_mut_game() = self.game.clone();
|
*strategy.get_mut_game() = self.game.clone();
|
||||||
|
|
||||||
// Exécuter l'action selon le turn_stage
|
// Exécuter l'action selon le turn_stage
|
||||||
let mut calculate_points = false;
|
|
||||||
let opponent_color = store::Color::Black;
|
|
||||||
let event = match self.game.turn_stage {
|
let event = match self.game.turn_stage {
|
||||||
TurnStage::RollDice => GameEvent::Roll {
|
TurnStage::RollDice => GameEvent::Roll {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
|
|
@ -318,7 +353,6 @@ impl TrictracEnvironment {
|
||||||
TurnStage::RollWaiting => {
|
TurnStage::RollWaiting => {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
calculate_points = true;
|
|
||||||
GameEvent::RollResult {
|
GameEvent::RollResult {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
dice: store::Dice {
|
dice: store::Dice {
|
||||||
|
|
@ -327,6 +361,7 @@ impl TrictracEnvironment {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::MarkPoints => {
|
TurnStage::MarkPoints => {
|
||||||
|
let opponent_color = store::Color::Black;
|
||||||
let dice_roll_count = self
|
let dice_roll_count = self
|
||||||
.game
|
.game
|
||||||
.players
|
.players
|
||||||
|
|
@ -335,12 +370,16 @@ impl TrictracEnvironment {
|
||||||
.dice_roll_count;
|
.dice_roll_count;
|
||||||
let points_rules =
|
let points_rules =
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
||||||
|
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||||
|
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
GameEvent::Mark {
|
GameEvent::Mark {
|
||||||
player_id: self.opponent_id,
|
player_id: self.opponent_id,
|
||||||
points: points_rules.get_points(dice_roll_count).0,
|
points,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::MarkAdvPoints => {
|
TurnStage::MarkAdvPoints => {
|
||||||
|
let opponent_color = store::Color::Black;
|
||||||
let dice_roll_count = self
|
let dice_roll_count = self
|
||||||
.game
|
.game
|
||||||
.players
|
.players
|
||||||
|
|
@ -370,19 +409,6 @@ impl TrictracEnvironment {
|
||||||
|
|
||||||
if self.game.validate(&event) {
|
if self.game.validate(&event) {
|
||||||
self.game.consume(&event);
|
self.game.consume(&event);
|
||||||
if calculate_points {
|
|
||||||
let dice_roll_count = self
|
|
||||||
.game
|
|
||||||
.players
|
|
||||||
.get(&self.opponent_id)
|
|
||||||
.unwrap()
|
|
||||||
.dice_roll_count;
|
|
||||||
let points_rules =
|
|
||||||
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
|
|
||||||
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
|
||||||
reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
|
|
||||||
// Récompense proportionnelle aux points
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
reward
|
reward
|
||||||
52
bot/src/dqn/burnrl_valid/main.rs
Normal file
52
bot/src/dqn/burnrl_valid/main.rs
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
use bot::dqn::burnrl_valid::{
|
||||||
|
dqn_model, environment,
|
||||||
|
utils::{demo_model, load_model, save_model},
|
||||||
|
};
|
||||||
|
use burn::backend::{Autodiff, NdArray};
|
||||||
|
use burn_rl::agent::DQN;
|
||||||
|
use burn_rl::base::ElemType;
|
||||||
|
|
||||||
|
type Backend = Autodiff<NdArray<ElemType>>;
|
||||||
|
type Env = environment::TrictracEnvironment;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// println!("> Entraînement");
|
||||||
|
|
||||||
|
// See also MEMORY_SIZE in dqn_model.rs : 8192
|
||||||
|
let conf = dqn_model::DqnConfig {
|
||||||
|
// defaults
|
||||||
|
num_episodes: 100, // 40
|
||||||
|
max_steps: 1000, // 1000 max steps by episode
|
||||||
|
dense_size: 256, // 128 neural network complexity (default 128)
|
||||||
|
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
|
||||||
|
eps_end: 0.05, // 0.05
|
||||||
|
// eps_decay higher = epsilon decrease slower
|
||||||
|
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
|
||||||
|
// epsilon is updated at the start of each episode
|
||||||
|
eps_decay: 2000.0, // 1000 ?
|
||||||
|
|
||||||
|
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
|
||||||
|
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
|
||||||
|
// plus lente moins sensible aux coups de chance
|
||||||
|
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
|
||||||
|
// converger
|
||||||
|
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
|
||||||
|
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
|
||||||
|
};
|
||||||
|
println!("{conf}----------");
|
||||||
|
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
|
||||||
|
|
||||||
|
let valid_agent = agent.valid();
|
||||||
|
|
||||||
|
println!("> Sauvegarde du modèle de validation");
|
||||||
|
|
||||||
|
let path = "bot/models/burn_dqn_valid_40".to_string();
|
||||||
|
save_model(valid_agent.model().as_ref().unwrap(), &path);
|
||||||
|
|
||||||
|
println!("> Chargement du modèle pour test");
|
||||||
|
let loaded_model = load_model(conf.dense_size, &path);
|
||||||
|
let loaded_agent = DQN::new(loaded_model.unwrap());
|
||||||
|
|
||||||
|
println!("> Test avec le modèle chargé");
|
||||||
|
demo_model(loaded_agent);
|
||||||
|
}
|
||||||
3
bot/src/dqn/burnrl_valid/mod.rs
Normal file
3
bot/src/dqn/burnrl_valid/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub mod dqn_model;
|
||||||
|
pub mod environment;
|
||||||
|
pub mod utils;
|
||||||
114
bot/src/dqn/burnrl_valid/utils.rs
Normal file
114
bot/src/dqn/burnrl_valid/utils.rs
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
use crate::dqn::burnrl_valid::{
|
||||||
|
dqn_model,
|
||||||
|
environment::{TrictracAction, TrictracEnvironment},
|
||||||
|
};
|
||||||
|
use crate::dqn::dqn_common::get_valid_action_indices;
|
||||||
|
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
|
||||||
|
use burn::module::{Module, Param, ParamId};
|
||||||
|
use burn::nn::Linear;
|
||||||
|
use burn::record::{CompactRecorder, Recorder};
|
||||||
|
use burn::tensor::backend::Backend;
|
||||||
|
use burn::tensor::cast::ToElement;
|
||||||
|
use burn::tensor::Tensor;
|
||||||
|
use burn_rl::agent::{DQNModel, DQN};
|
||||||
|
use burn_rl::base::{Action, ElemType, Environment, State};
|
||||||
|
|
||||||
|
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
|
||||||
|
let recorder = CompactRecorder::new();
|
||||||
|
let model_path = format!("{path}_model.mpk");
|
||||||
|
println!("Modèle de validation sauvegardé : {model_path}");
|
||||||
|
recorder
|
||||||
|
.record(model.clone().into_record(), model_path.into())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
|
||||||
|
let model_path = format!("{path}_model.mpk");
|
||||||
|
// println!("Chargement du modèle depuis : {model_path}");
|
||||||
|
|
||||||
|
CompactRecorder::new()
|
||||||
|
.load(model_path.into(), &NdArrayDevice::default())
|
||||||
|
.map(|record| {
|
||||||
|
dqn_model::Net::new(
|
||||||
|
<TrictracEnvironment as Environment>::StateType::size(),
|
||||||
|
dense_size,
|
||||||
|
<TrictracEnvironment as Environment>::ActionType::size(),
|
||||||
|
)
|
||||||
|
.load_record(record)
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
|
||||||
|
let mut env = TrictracEnvironment::new(true);
|
||||||
|
let mut done = false;
|
||||||
|
while !done {
|
||||||
|
// let action = match infer_action(&agent, &env, state) {
|
||||||
|
let action = match infer_action(&agent, &env) {
|
||||||
|
Some(value) => value,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
// Execute action
|
||||||
|
let snapshot = env.step(action);
|
||||||
|
done = snapshot.done();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_action<B: Backend, M: DQNModel<B>>(
|
||||||
|
agent: &DQN<TrictracEnvironment, B, M>,
|
||||||
|
env: &TrictracEnvironment,
|
||||||
|
) -> Option<TrictracAction> {
|
||||||
|
let state = env.state();
|
||||||
|
// Get q-values
|
||||||
|
let q_values = agent
|
||||||
|
.model()
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.infer(state.to_tensor().unsqueeze());
|
||||||
|
// Get valid actions
|
||||||
|
let valid_actions_indices = get_valid_action_indices(&env.game);
|
||||||
|
if valid_actions_indices.is_empty() {
|
||||||
|
return None; // No valid actions, end of episode
|
||||||
|
}
|
||||||
|
// Set non valid actions q-values to lowest
|
||||||
|
let mut masked_q_values = q_values.clone();
|
||||||
|
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||||
|
for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||||
|
if !valid_actions_indices.contains(&index) {
|
||||||
|
masked_q_values = masked_q_values.clone().mask_fill(
|
||||||
|
masked_q_values.clone().equal_elem(*q_value),
|
||||||
|
f32::NEG_INFINITY,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get best action (highest q-value)
|
||||||
|
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||||
|
let action = TrictracAction::from(action_index);
|
||||||
|
Some(action)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn soft_update_tensor<const N: usize, B: Backend>(
|
||||||
|
this: &Param<Tensor<B, N>>,
|
||||||
|
that: &Param<Tensor<B, N>>,
|
||||||
|
tau: ElemType,
|
||||||
|
) -> Param<Tensor<B, N>> {
|
||||||
|
let that_weight = that.val();
|
||||||
|
let this_weight = this.val();
|
||||||
|
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
|
||||||
|
|
||||||
|
Param::initialized(ParamId::new(), new_weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn soft_update_linear<B: Backend>(
|
||||||
|
this: Linear<B>,
|
||||||
|
that: &Linear<B>,
|
||||||
|
tau: ElemType,
|
||||||
|
) -> Linear<B> {
|
||||||
|
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
|
||||||
|
let bias = match (&this.bias, &that.bias) {
|
||||||
|
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Linear::<B> { weight, bias }
|
||||||
|
}
|
||||||
|
|
@ -1,17 +1,10 @@
|
||||||
/// training_common.rs : environnement avec espace d'actions optimisé
|
|
||||||
/// (514 au lieu de 1252 pour training_common_big.rs de la branche 'big_and_full' )
|
|
||||||
use std::cmp::{max, min};
|
use std::cmp::{max, min};
|
||||||
use std::fmt::{Debug, Display, Formatter};
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use store::{CheckerMove, GameEvent, GameState};
|
use store::{CheckerMove, Dice};
|
||||||
|
|
||||||
// 1 (Roll) + 1 (Go) + mouvements possibles
|
|
||||||
// Pour les mouvements : 2*16*16 = 514 (choix du dé + choix de la dame 0-15 pour chaque from)
|
|
||||||
pub const ACTION_SPACE_SIZE: usize = 514;
|
|
||||||
|
|
||||||
/// Types d'actions possibles dans le jeu
|
/// Types d'actions possibles dans le jeu
|
||||||
#[derive(Debug, Copy, Clone, Eq, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub enum TrictracAction {
|
pub enum TrictracAction {
|
||||||
/// Lancer les dés
|
/// Lancer les dés
|
||||||
Roll,
|
Roll,
|
||||||
|
|
@ -20,21 +13,13 @@ pub enum TrictracAction {
|
||||||
/// Effectuer un mouvement de pions
|
/// Effectuer un mouvement de pions
|
||||||
Move {
|
Move {
|
||||||
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
|
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
|
||||||
checker1: usize, // premier pion à déplacer en numérotant depuis la colonne de départ (0-15) 0 : aucun pion
|
from1: usize, // position de départ du premier pion (0-24)
|
||||||
checker2: usize, // deuxième pion (0-15)
|
from2: usize, // position de départ du deuxième pion (0-24)
|
||||||
},
|
},
|
||||||
// Marquer les points : à activer si support des écoles
|
// Marquer les points : à activer si support des écoles
|
||||||
// Mark,
|
// Mark,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for TrictracAction {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
let s = format!("{self:?}");
|
|
||||||
writeln!(f, "{}", s.chars().rev().collect::<String>())?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrictracAction {
|
impl TrictracAction {
|
||||||
/// Encode une action en index pour le réseau de neurones
|
/// Encode une action en index pour le réseau de neurones
|
||||||
pub fn to_action_index(&self) -> usize {
|
pub fn to_action_index(&self) -> usize {
|
||||||
|
|
@ -43,91 +28,19 @@ impl TrictracAction {
|
||||||
TrictracAction::Go => 1,
|
TrictracAction::Go => 1,
|
||||||
TrictracAction::Move {
|
TrictracAction::Move {
|
||||||
dice_order,
|
dice_order,
|
||||||
checker1,
|
from1,
|
||||||
checker2,
|
from2,
|
||||||
} => {
|
} => {
|
||||||
// Encoder les mouvements dans l'espace d'actions
|
// Encoder les mouvements dans l'espace d'actions
|
||||||
// Indices 2+ pour les mouvements
|
// Indices 2+ pour les mouvements
|
||||||
// de 2 à 513 (2 à 257 pour dé 1 en premier, 258 à 513 pour dé 2 en premier)
|
// de 2 à 1251 (2 à 626 pour dé 1 en premier, 627 à 1251 pour dé 2 en premier)
|
||||||
let mut start = 2;
|
let mut start = 2;
|
||||||
if !dice_order {
|
if !dice_order {
|
||||||
// 16 * 16 = 256
|
// 25 * 25 = 625
|
||||||
start += 256;
|
start += 625;
|
||||||
}
|
}
|
||||||
start + checker1 * 16 + checker2
|
start + from1 * 25 + from2
|
||||||
} // TrictracAction::Mark => 514,
|
} // TrictracAction::Mark => 1252,
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_event(&self, state: &GameState) -> Option<GameEvent> {
|
|
||||||
match self {
|
|
||||||
TrictracAction::Roll => {
|
|
||||||
// Lancer les dés
|
|
||||||
Some(GameEvent::Roll {
|
|
||||||
player_id: state.active_player_id,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
// TrictracAction::Mark => {
|
|
||||||
// // Marquer des points
|
|
||||||
// let points = self.game.
|
|
||||||
// Some(GameEvent::Mark {
|
|
||||||
// player_id: self.active_player_id,
|
|
||||||
// points,
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
TrictracAction::Go => {
|
|
||||||
// Continuer après avoir gagné un trou
|
|
||||||
Some(GameEvent::Go {
|
|
||||||
player_id: state.active_player_id,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
TrictracAction::Move {
|
|
||||||
dice_order,
|
|
||||||
checker1,
|
|
||||||
checker2,
|
|
||||||
} => {
|
|
||||||
// Effectuer un mouvement
|
|
||||||
let (dice1, dice2) = if *dice_order {
|
|
||||||
(state.dice.values.0, state.dice.values.1)
|
|
||||||
} else {
|
|
||||||
(state.dice.values.1, state.dice.values.0)
|
|
||||||
};
|
|
||||||
|
|
||||||
let color = &store::Color::White;
|
|
||||||
let from1 = state
|
|
||||||
.board
|
|
||||||
.get_checker_field(color, *checker1 as u8)
|
|
||||||
.unwrap_or(0);
|
|
||||||
let mut to1 = from1 + dice1 as usize;
|
|
||||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
|
||||||
|
|
||||||
let mut tmp_board = state.board.clone();
|
|
||||||
let move_result = tmp_board.move_checker(color, checker_move1);
|
|
||||||
if move_result.is_err() {
|
|
||||||
None
|
|
||||||
// panic!("Error while moving checker {move_result:?}")
|
|
||||||
} else {
|
|
||||||
let from2 = tmp_board
|
|
||||||
.get_checker_field(color, *checker2 as u8)
|
|
||||||
.unwrap_or(0);
|
|
||||||
let mut to2 = from2 + dice2 as usize;
|
|
||||||
|
|
||||||
// Gestion prise de coin par puissance
|
|
||||||
let opp_rest_field = 13;
|
|
||||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
|
||||||
to1 -= 1;
|
|
||||||
to2 -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
|
||||||
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
|
||||||
|
|
||||||
Some(GameEvent::Move {
|
|
||||||
player_id: state.active_player_id,
|
|
||||||
moves: (checker_move1, checker_move2),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -135,15 +48,15 @@ impl TrictracAction {
|
||||||
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
|
||||||
match index {
|
match index {
|
||||||
0 => Some(TrictracAction::Roll),
|
0 => Some(TrictracAction::Roll),
|
||||||
|
// 1252 => Some(TrictracAction::Mark),
|
||||||
1 => Some(TrictracAction::Go),
|
1 => Some(TrictracAction::Go),
|
||||||
// 514 => Some(TrictracAction::Mark),
|
i if i >= 3 => {
|
||||||
i if i >= 2 => {
|
let move_code = i - 3;
|
||||||
let move_code = i - 2;
|
let (dice_order, from1, from2) = Self::decode_move(move_code);
|
||||||
let (dice_order, checker1, checker2) = Self::decode_move(move_code);
|
|
||||||
Some(TrictracAction::Move {
|
Some(TrictracAction::Move {
|
||||||
dice_order,
|
dice_order,
|
||||||
checker1,
|
from1,
|
||||||
checker2,
|
from2,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
|
|
@ -153,18 +66,21 @@ impl TrictracAction {
|
||||||
/// Décode un entier en paire de mouvements
|
/// Décode un entier en paire de mouvements
|
||||||
fn decode_move(code: usize) -> (bool, usize, usize) {
|
fn decode_move(code: usize) -> (bool, usize, usize) {
|
||||||
let mut encoded = code;
|
let mut encoded = code;
|
||||||
let dice_order = code < 256;
|
let dice_order = code < 626;
|
||||||
if !dice_order {
|
if !dice_order {
|
||||||
encoded -= 256
|
encoded -= 625
|
||||||
}
|
}
|
||||||
let checker1 = encoded / 16;
|
let from1 = encoded / 25;
|
||||||
let checker2 = encoded % 16;
|
let from2 = 1 + encoded % 25;
|
||||||
(dice_order, checker1, checker2)
|
(dice_order, from1, from2)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retourne la taille de l'espace d'actions total
|
/// Retourne la taille de l'espace d'actions total
|
||||||
pub fn action_space_size() -> usize {
|
pub fn action_space_size() -> usize {
|
||||||
ACTION_SPACE_SIZE
|
// 1 (Roll) + 1 (Go) + mouvements possibles
|
||||||
|
// Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from)
|
||||||
|
// Mais on peut optimiser en limitant aux positions valides (1-24)
|
||||||
|
2 + (2 * 25 * 25) // = 1252
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
|
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
|
||||||
|
|
@ -201,15 +117,11 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
|
|
||||||
if let Some(color) = player_color {
|
if let Some(color) = player_color {
|
||||||
match game_state.turn_stage {
|
match game_state.turn_stage {
|
||||||
TurnStage::RollDice => {
|
TurnStage::RollDice | TurnStage::RollWaiting => {
|
||||||
valid_actions.push(TrictracAction::Roll);
|
valid_actions.push(TrictracAction::Roll);
|
||||||
}
|
}
|
||||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
|
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
|
||||||
// valid_actions.push(TrictracAction::Mark);
|
// valid_actions.push(TrictracAction::Mark);
|
||||||
panic!(
|
|
||||||
"get_valid_actions not implemented for turn stage {:?}",
|
|
||||||
game_state.turn_stage
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
TurnStage::HoldOrGoChoice => {
|
TurnStage::HoldOrGoChoice => {
|
||||||
valid_actions.push(TrictracAction::Go);
|
valid_actions.push(TrictracAction::Go);
|
||||||
|
|
@ -222,32 +134,29 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
assert_eq!(color, store::Color::White);
|
assert_eq!(color, store::Color::White);
|
||||||
for (move1, move2) in possible_moves {
|
for (move1, move2) in possible_moves {
|
||||||
valid_actions.push(checker_moves_to_trictrac_action(
|
valid_actions.push(checker_moves_to_trictrac_action(
|
||||||
&move1, &move2, &color, game_state,
|
&move1,
|
||||||
|
&move2,
|
||||||
|
&game_state.dice,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TurnStage::Move => {
|
TurnStage::Move => {
|
||||||
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
|
||||||
let mut possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
if possible_moves.is_empty() {
|
|
||||||
// Empty move
|
|
||||||
possible_moves.push((CheckerMove::default(), CheckerMove::default()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
|
||||||
assert_eq!(color, store::Color::White);
|
assert_eq!(color, store::Color::White);
|
||||||
for (move1, move2) in possible_moves {
|
for (move1, move2) in possible_moves {
|
||||||
valid_actions.push(checker_moves_to_trictrac_action(
|
valid_actions.push(checker_moves_to_trictrac_action(
|
||||||
&move1, &move2, &color, game_state,
|
&move1,
|
||||||
|
&move2,
|
||||||
|
&game_state.dice,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if valid_actions.is_empty() {
|
|
||||||
panic!("empty valid_actions for state {game_state}");
|
|
||||||
}
|
|
||||||
valid_actions
|
valid_actions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -255,14 +164,12 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
|
||||||
fn checker_moves_to_trictrac_action(
|
fn checker_moves_to_trictrac_action(
|
||||||
move1: &CheckerMove,
|
move1: &CheckerMove,
|
||||||
move2: &CheckerMove,
|
move2: &CheckerMove,
|
||||||
color: &store::Color,
|
dice: &Dice,
|
||||||
state: &crate::GameState,
|
|
||||||
) -> TrictracAction {
|
) -> TrictracAction {
|
||||||
let to1 = move1.get_to();
|
let to1 = move1.get_to();
|
||||||
let to2 = move2.get_to();
|
let to2 = move2.get_to();
|
||||||
let from1 = move1.get_from();
|
let from1 = move1.get_from();
|
||||||
let from2 = move2.get_from();
|
let from2 = move2.get_from();
|
||||||
let dice = state.dice;
|
|
||||||
|
|
||||||
let mut diff_move1 = if to1 > 0 {
|
let mut diff_move1 = if to1 > 0 {
|
||||||
// Mouvement sans sortie
|
// Mouvement sans sortie
|
||||||
|
|
@ -296,20 +203,10 @@ fn checker_moves_to_trictrac_action(
|
||||||
// prise par puissance
|
// prise par puissance
|
||||||
diff_move1 += 1;
|
diff_move1 += 1;
|
||||||
}
|
}
|
||||||
let dice_order = diff_move1 == dice.values.0 as usize;
|
|
||||||
|
|
||||||
let checker1 = state.board.get_field_checker(color, from1) as usize;
|
|
||||||
let mut tmp_board = state.board.clone();
|
|
||||||
// should not raise an error for a valid action
|
|
||||||
let move_res = tmp_board.move_checker(color, *move1);
|
|
||||||
if move_res.is_err() {
|
|
||||||
panic!("error while moving checker {move_res:?}");
|
|
||||||
}
|
|
||||||
let checker2 = tmp_board.get_field_checker(color, from2) as usize;
|
|
||||||
TrictracAction::Move {
|
TrictracAction::Move {
|
||||||
dice_order,
|
dice_order: diff_move1 == dice.values.0 as usize,
|
||||||
checker1,
|
from1: move1.get_from(),
|
||||||
checker2,
|
from2: move2.get_from(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -338,21 +235,21 @@ mod tests {
|
||||||
fn to_action_index() {
|
fn to_action_index() {
|
||||||
let action = TrictracAction::Move {
|
let action = TrictracAction::Move {
|
||||||
dice_order: true,
|
dice_order: true,
|
||||||
checker1: 3,
|
from1: 3,
|
||||||
checker2: 4,
|
from2: 4,
|
||||||
};
|
};
|
||||||
let index = action.to_action_index();
|
let index = action.to_action_index();
|
||||||
assert_eq!(Some(action), TrictracAction::from_action_index(index));
|
assert_eq!(Some(action), TrictracAction::from_action_index(index));
|
||||||
assert_eq!(54, index);
|
assert_eq!(81, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn from_action_index() {
|
fn from_action_index() {
|
||||||
let action = TrictracAction::Move {
|
let action = TrictracAction::Move {
|
||||||
dice_order: true,
|
dice_order: true,
|
||||||
checker1: 3,
|
from1: 3,
|
||||||
checker2: 4,
|
from2: 4,
|
||||||
};
|
};
|
||||||
assert_eq!(Some(action), TrictracAction::from_action_index(54));
|
assert_eq!(Some(action), TrictracAction::from_action_index(81));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
5
bot/src/dqn/mod.rs
Normal file
5
bot/src/dqn/mod.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
pub mod burnrl;
|
||||||
|
pub mod dqn_common;
|
||||||
|
pub mod simple;
|
||||||
|
|
||||||
|
pub mod burnrl_valid;
|
||||||
154
bot/src/dqn/simple/dqn_model.rs
Normal file
154
bot/src/dqn/simple/dqn_model.rs
Normal file
|
|
@ -0,0 +1,154 @@
|
||||||
|
use crate::dqn::dqn_common::TrictracAction;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Configuration pour l'agent DQN
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DqnConfig {
|
||||||
|
pub state_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub num_actions: usize,
|
||||||
|
pub learning_rate: f64,
|
||||||
|
pub gamma: f64,
|
||||||
|
pub epsilon: f64,
|
||||||
|
pub epsilon_decay: f64,
|
||||||
|
pub epsilon_min: f64,
|
||||||
|
pub replay_buffer_size: usize,
|
||||||
|
pub batch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
state_size: 36,
|
||||||
|
hidden_size: 512, // Augmenter la taille pour gérer l'espace d'actions élargi
|
||||||
|
num_actions: TrictracAction::action_space_size(),
|
||||||
|
learning_rate: 0.001,
|
||||||
|
gamma: 0.99,
|
||||||
|
epsilon: 0.1,
|
||||||
|
epsilon_decay: 0.995,
|
||||||
|
epsilon_min: 0.01,
|
||||||
|
replay_buffer_size: 10000,
|
||||||
|
batch_size: 32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Réseau de neurones DQN simplifié (matrice de poids basique)
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SimpleNeuralNetwork {
|
||||||
|
pub weights1: Vec<Vec<f32>>,
|
||||||
|
pub biases1: Vec<f32>,
|
||||||
|
pub weights2: Vec<Vec<f32>>,
|
||||||
|
pub biases2: Vec<f32>,
|
||||||
|
pub weights3: Vec<Vec<f32>>,
|
||||||
|
pub biases3: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SimpleNeuralNetwork {
|
||||||
|
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
|
||||||
|
// Initialisation aléatoire des poids avec Xavier/Glorot
|
||||||
|
let scale1 = (2.0 / input_size as f32).sqrt();
|
||||||
|
let weights1 = (0..hidden_size)
|
||||||
|
.map(|_| {
|
||||||
|
(0..input_size)
|
||||||
|
.map(|_| rng.gen_range(-scale1..scale1))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let biases1 = vec![0.0; hidden_size];
|
||||||
|
|
||||||
|
let scale2 = (2.0 / hidden_size as f32).sqrt();
|
||||||
|
let weights2 = (0..hidden_size)
|
||||||
|
.map(|_| {
|
||||||
|
(0..hidden_size)
|
||||||
|
.map(|_| rng.gen_range(-scale2..scale2))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let biases2 = vec![0.0; hidden_size];
|
||||||
|
|
||||||
|
let scale3 = (2.0 / hidden_size as f32).sqrt();
|
||||||
|
let weights3 = (0..output_size)
|
||||||
|
.map(|_| {
|
||||||
|
(0..hidden_size)
|
||||||
|
.map(|_| rng.gen_range(-scale3..scale3))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let biases3 = vec![0.0; output_size];
|
||||||
|
|
||||||
|
Self {
|
||||||
|
weights1,
|
||||||
|
biases1,
|
||||||
|
weights2,
|
||||||
|
biases2,
|
||||||
|
weights3,
|
||||||
|
biases3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||||
|
// Première couche
|
||||||
|
let mut layer1: Vec<f32> = self.biases1.clone();
|
||||||
|
for (i, neuron_weights) in self.weights1.iter().enumerate() {
|
||||||
|
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||||
|
if j < input.len() {
|
||||||
|
layer1[i] += input[j] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
layer1[i] = layer1[i].max(0.0); // ReLU
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deuxième couche
|
||||||
|
let mut layer2: Vec<f32> = self.biases2.clone();
|
||||||
|
for (i, neuron_weights) in self.weights2.iter().enumerate() {
|
||||||
|
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||||
|
if j < layer1.len() {
|
||||||
|
layer2[i] += layer1[j] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
layer2[i] = layer2[i].max(0.0); // ReLU
|
||||||
|
}
|
||||||
|
|
||||||
|
// Couche de sortie
|
||||||
|
let mut output: Vec<f32> = self.biases3.clone();
|
||||||
|
for (i, neuron_weights) in self.weights3.iter().enumerate() {
|
||||||
|
for (j, &weight) in neuron_weights.iter().enumerate() {
|
||||||
|
if j < layer2.len() {
|
||||||
|
output[i] += layer2[j] * weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_best_action(&self, input: &[f32]) -> usize {
|
||||||
|
let q_values = self.forward(input);
|
||||||
|
q_values
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||||
|
.map(|(index, _)| index)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save<P: AsRef<std::path::Path>>(
|
||||||
|
&self,
|
||||||
|
path: P,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let data = serde_json::to_string_pretty(self)?;
|
||||||
|
std::fs::write(path, data)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
let data = std::fs::read_to_string(path)?;
|
||||||
|
let network = serde_json::from_str(&data)?;
|
||||||
|
Ok(network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
490
bot/src/dqn/simple/dqn_trainer.rs
Normal file
490
bot/src/dqn/simple/dqn_trainer.rs
Normal file
|
|
@ -0,0 +1,490 @@
|
||||||
|
use crate::{CheckerMove, Color, GameState, PlayerId};
|
||||||
|
use rand::prelude::SliceRandom;
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
|
||||||
|
|
||||||
|
use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
|
||||||
|
use crate::dqn::dqn_common::{get_valid_actions, TrictracAction};
|
||||||
|
|
||||||
|
/// Expérience pour le buffer de replay
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Experience {
|
||||||
|
pub state: Vec<f32>,
|
||||||
|
pub action: TrictracAction,
|
||||||
|
pub reward: f32,
|
||||||
|
pub next_state: Vec<f32>,
|
||||||
|
pub done: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer de replay pour stocker les expériences
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ReplayBuffer {
|
||||||
|
buffer: VecDeque<Experience>,
|
||||||
|
capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplayBuffer {
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: VecDeque::with_capacity(capacity),
|
||||||
|
capacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, experience: Experience) {
|
||||||
|
if self.buffer.len() >= self.capacity {
|
||||||
|
self.buffer.pop_front();
|
||||||
|
}
|
||||||
|
self.buffer.push_back(experience);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let len = self.buffer.len();
|
||||||
|
if len < batch_size {
|
||||||
|
return self.buffer.iter().cloned().collect();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut batch = Vec::with_capacity(batch_size);
|
||||||
|
for _ in 0..batch_size {
|
||||||
|
let idx = rng.gen_range(0..len);
|
||||||
|
batch.push(self.buffer[idx].clone());
|
||||||
|
}
|
||||||
|
batch
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent DQN pour l'apprentissage par renforcement
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DqnAgent {
|
||||||
|
config: DqnConfig,
|
||||||
|
model: SimpleNeuralNetwork,
|
||||||
|
target_model: SimpleNeuralNetwork,
|
||||||
|
replay_buffer: ReplayBuffer,
|
||||||
|
epsilon: f64,
|
||||||
|
step_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnAgent {
|
||||||
|
pub fn new(config: DqnConfig) -> Self {
|
||||||
|
let model =
|
||||||
|
SimpleNeuralNetwork::new(config.state_size, config.hidden_size, config.num_actions);
|
||||||
|
let target_model = model.clone();
|
||||||
|
let replay_buffer = ReplayBuffer::new(config.replay_buffer_size);
|
||||||
|
let epsilon = config.epsilon;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
model,
|
||||||
|
target_model,
|
||||||
|
replay_buffer,
|
||||||
|
epsilon,
|
||||||
|
step_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn select_action(&mut self, game_state: &GameState, state: &[f32]) -> TrictracAction {
|
||||||
|
let valid_actions = get_valid_actions(game_state);
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
// Fallback si aucune action valide
|
||||||
|
return TrictracAction::Roll;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
if rng.gen::<f64>() < self.epsilon {
|
||||||
|
// Exploration : action valide aléatoire
|
||||||
|
valid_actions
|
||||||
|
.choose(&mut rng)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(TrictracAction::Roll)
|
||||||
|
} else {
|
||||||
|
// Exploitation : meilleure action valide selon le modèle
|
||||||
|
let q_values = self.model.forward(state);
|
||||||
|
|
||||||
|
let mut best_action = &valid_actions[0];
|
||||||
|
let mut best_q_value = f32::NEG_INFINITY;
|
||||||
|
|
||||||
|
for action in &valid_actions {
|
||||||
|
let action_index = action.to_action_index();
|
||||||
|
if action_index < q_values.len() {
|
||||||
|
let q_value = q_values[action_index];
|
||||||
|
if q_value > best_q_value {
|
||||||
|
best_q_value = q_value;
|
||||||
|
best_action = action;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
best_action.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_experience(&mut self, experience: Experience) {
|
||||||
|
self.replay_buffer.push(experience);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train(&mut self) {
|
||||||
|
if self.replay_buffer.len() < self.config.batch_size {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pour l'instant, on simule l'entraînement en mettant à jour epsilon
|
||||||
|
// Dans une implémentation complète, ici on ferait la backpropagation
|
||||||
|
self.epsilon = (self.epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
|
||||||
|
self.step_count += 1;
|
||||||
|
|
||||||
|
// Mise à jour du target model tous les 100 steps
|
||||||
|
if self.step_count % 100 == 0 {
|
||||||
|
self.target_model = self.model.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_model<P: AsRef<std::path::Path>>(
|
||||||
|
&self,
|
||||||
|
path: P,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
self.model.save(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_epsilon(&self) -> f64 {
|
||||||
|
self.epsilon
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_step_count(&self) -> usize {
|
||||||
|
self.step_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Environnement Trictrac pour l'entraînement
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TrictracEnv {
|
||||||
|
pub game_state: GameState,
|
||||||
|
pub agent_player_id: PlayerId,
|
||||||
|
pub opponent_player_id: PlayerId,
|
||||||
|
pub agent_color: Color,
|
||||||
|
pub max_steps: usize,
|
||||||
|
pub current_step: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TrictracEnv {
|
||||||
|
fn default() -> Self {
|
||||||
|
let mut game_state = GameState::new(false);
|
||||||
|
game_state.init_player("agent");
|
||||||
|
game_state.init_player("opponent");
|
||||||
|
|
||||||
|
Self {
|
||||||
|
game_state,
|
||||||
|
agent_player_id: 1,
|
||||||
|
opponent_player_id: 2,
|
||||||
|
agent_color: Color::White,
|
||||||
|
max_steps: 1000,
|
||||||
|
current_step: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrictracEnv {
|
||||||
|
pub fn reset(&mut self) -> Vec<f32> {
|
||||||
|
self.game_state = GameState::new(false);
|
||||||
|
self.game_state.init_player("agent");
|
||||||
|
self.game_state.init_player("opponent");
|
||||||
|
|
||||||
|
// Commencer la partie
|
||||||
|
self.game_state.consume(&GameEvent::BeginGame {
|
||||||
|
goes_first: self.agent_player_id,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.current_step = 0;
|
||||||
|
self.game_state.to_vec_float()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step(&mut self, action: TrictracAction) -> (Vec<f32>, f32, bool) {
|
||||||
|
let mut reward = 0.0;
|
||||||
|
|
||||||
|
// Appliquer l'action de l'agent
|
||||||
|
if self.game_state.active_player_id == self.agent_player_id {
|
||||||
|
reward += self.apply_agent_action(action);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Faire jouer l'adversaire (stratégie simple)
|
||||||
|
while self.game_state.active_player_id == self.opponent_player_id
|
||||||
|
&& self.game_state.stage != Stage::Ended
|
||||||
|
{
|
||||||
|
reward += self.play_opponent_turn();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vérifier si la partie est terminée
|
||||||
|
let done = self.game_state.stage == Stage::Ended
|
||||||
|
|| self.game_state.determine_winner().is_some()
|
||||||
|
|| self.current_step >= self.max_steps;
|
||||||
|
|
||||||
|
// Récompense finale si la partie est terminée
|
||||||
|
if done {
|
||||||
|
if let Some(winner) = self.game_state.determine_winner() {
|
||||||
|
if winner == self.agent_player_id {
|
||||||
|
reward += 100.0; // Bonus pour gagner
|
||||||
|
} else {
|
||||||
|
reward -= 50.0; // Pénalité pour perdre
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.current_step += 1;
|
||||||
|
let next_state = self.game_state.to_vec_float();
|
||||||
|
(next_state, reward, done)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_agent_action(&mut self, action: TrictracAction) -> f32 {
|
||||||
|
let mut reward = 0.0;
|
||||||
|
|
||||||
|
let event = match action {
|
||||||
|
TrictracAction::Roll => {
|
||||||
|
// Lancer les dés
|
||||||
|
reward += 0.1;
|
||||||
|
Some(GameEvent::Roll {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// TrictracAction::Mark => {
|
||||||
|
// // Marquer des points
|
||||||
|
// let points = self.game_state.
|
||||||
|
// reward += 0.1 * points as f32;
|
||||||
|
// Some(GameEvent::Mark {
|
||||||
|
// player_id: self.agent_player_id,
|
||||||
|
// points,
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
TrictracAction::Go => {
|
||||||
|
// Continuer après avoir gagné un trou
|
||||||
|
reward += 0.2;
|
||||||
|
Some(GameEvent::Go {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
} => {
|
||||||
|
// Effectuer un mouvement
|
||||||
|
let (dice1, dice2) = if dice_order {
|
||||||
|
(self.game_state.dice.values.0, self.game_state.dice.values.1)
|
||||||
|
} else {
|
||||||
|
(self.game_state.dice.values.1, self.game_state.dice.values.0)
|
||||||
|
};
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
|
||||||
|
// Gestion prise de coin par puissance
|
||||||
|
let opp_rest_field = 13;
|
||||||
|
if to1 == opp_rest_field && to2 == opp_rest_field {
|
||||||
|
to1 -= 1;
|
||||||
|
to2 -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
reward += 0.2;
|
||||||
|
Some(GameEvent::Move {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
moves: (checker_move1, checker_move2),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Appliquer l'événement si valide
|
||||||
|
if let Some(event) = event {
|
||||||
|
if self.game_state.validate(&event) {
|
||||||
|
self.game_state.consume(&event);
|
||||||
|
|
||||||
|
// Simuler le résultat des dés après un Roll
|
||||||
|
if matches!(action, TrictracAction::Roll) {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
|
let dice_event = GameEvent::RollResult {
|
||||||
|
player_id: self.agent_player_id,
|
||||||
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
if self.game_state.validate(&dice_event) {
|
||||||
|
self.game_state.consume(&dice_event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Pénalité pour action invalide
|
||||||
|
reward -= 2.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reward
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO : use default bot strategy
|
||||||
|
fn play_opponent_turn(&mut self) -> f32 {
|
||||||
|
let mut reward = 0.0;
|
||||||
|
let event = match self.game_state.turn_stage {
|
||||||
|
TurnStage::RollDice => GameEvent::Roll {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
},
|
||||||
|
TurnStage::RollWaiting => {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
|
||||||
|
GameEvent::RollResult {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
dice: store::Dice {
|
||||||
|
values: dice_values,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::MarkAdvPoints | TurnStage::MarkPoints => {
|
||||||
|
let opponent_color = self.agent_color.opponent_color();
|
||||||
|
let dice_roll_count = self
|
||||||
|
.game_state
|
||||||
|
.players
|
||||||
|
.get(&self.opponent_player_id)
|
||||||
|
.unwrap()
|
||||||
|
.dice_roll_count;
|
||||||
|
let points_rules = PointsRules::new(
|
||||||
|
&opponent_color,
|
||||||
|
&self.game_state.board,
|
||||||
|
self.game_state.dice,
|
||||||
|
);
|
||||||
|
let (points, adv_points) = points_rules.get_points(dice_roll_count);
|
||||||
|
reward -= 0.3 * (points - adv_points) as f32; // Récompense proportionnelle aux points
|
||||||
|
|
||||||
|
GameEvent::Mark {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
points,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::Move => {
|
||||||
|
let opponent_color = self.agent_color.opponent_color();
|
||||||
|
let rules = MoveRules::new(
|
||||||
|
&opponent_color,
|
||||||
|
&self.game_state.board,
|
||||||
|
self.game_state.dice,
|
||||||
|
);
|
||||||
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
|
// Stratégie simple : choix aléatoire
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let choosen_move = *possible_moves
|
||||||
|
.choose(&mut rng)
|
||||||
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||||
|
|
||||||
|
GameEvent::Move {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
moves: if opponent_color == Color::White {
|
||||||
|
choosen_move
|
||||||
|
} else {
|
||||||
|
(choosen_move.0.mirror(), choosen_move.1.mirror())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TurnStage::HoldOrGoChoice => {
|
||||||
|
// Stratégie simple : toujours continuer
|
||||||
|
GameEvent::Go {
|
||||||
|
player_id: self.opponent_player_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if self.game_state.validate(&event) {
|
||||||
|
self.game_state.consume(&event);
|
||||||
|
}
|
||||||
|
reward
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Entraîneur pour le modèle DQN
|
||||||
|
pub struct DqnTrainer {
|
||||||
|
agent: DqnAgent,
|
||||||
|
env: TrictracEnv,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnTrainer {
|
||||||
|
pub fn new(config: DqnConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
agent: DqnAgent::new(config),
|
||||||
|
env: TrictracEnv::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train_episode(&mut self) -> f32 {
|
||||||
|
let mut total_reward = 0.0;
|
||||||
|
let mut state = self.env.reset();
|
||||||
|
// let mut step_count = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// step_count += 1;
|
||||||
|
let action = self.agent.select_action(&self.env.game_state, &state);
|
||||||
|
let (next_state, reward, done) = self.env.step(action.clone());
|
||||||
|
total_reward += reward;
|
||||||
|
|
||||||
|
let experience = Experience {
|
||||||
|
state: state.clone(),
|
||||||
|
action,
|
||||||
|
reward,
|
||||||
|
next_state: next_state.clone(),
|
||||||
|
done,
|
||||||
|
};
|
||||||
|
self.agent.store_experience(experience);
|
||||||
|
self.agent.train();
|
||||||
|
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// if step_count % 100 == 0 {
|
||||||
|
// println!("{:?}", next_state);
|
||||||
|
// }
|
||||||
|
state = next_state;
|
||||||
|
}
|
||||||
|
|
||||||
|
total_reward
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train(
|
||||||
|
&mut self,
|
||||||
|
episodes: usize,
|
||||||
|
save_every: usize,
|
||||||
|
model_path: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
|
||||||
|
|
||||||
|
for episode in 1..=episodes {
|
||||||
|
let reward = self.train_episode();
|
||||||
|
|
||||||
|
if episode % 100 == 0 {
|
||||||
|
println!(
|
||||||
|
"Épisode {}/{}: Récompense = {:.2}, Epsilon = {:.3}, Steps = {}",
|
||||||
|
episode,
|
||||||
|
episodes,
|
||||||
|
reward,
|
||||||
|
self.agent.get_epsilon(),
|
||||||
|
self.agent.get_step_count()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if episode % save_every == 0 {
|
||||||
|
let save_path = format!("{}_episode_{}.json", model_path, episode);
|
||||||
|
self.agent.save_model(&save_path)?;
|
||||||
|
println!("Modèle sauvegardé : {}", save_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sauvegarder le modèle final
|
||||||
|
let final_path = format!("{}_final.json", model_path);
|
||||||
|
self.agent.save_model(&final_path)?;
|
||||||
|
println!("Modèle final sauvegardé : {}", final_path);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
112
bot/src/dqn/simple/main.rs
Normal file
112
bot/src/dqn/simple/main.rs
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
use bot::dqn::dqn_common::TrictracAction;
|
||||||
|
use bot::dqn::simple::dqn_model::DqnConfig;
|
||||||
|
use bot::dqn::simple::dqn_trainer::DqnTrainer;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
env_logger::init();
|
||||||
|
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
|
||||||
|
// Paramètres par défaut
|
||||||
|
let mut episodes = 1000;
|
||||||
|
let mut model_path = "models/dqn_model".to_string();
|
||||||
|
let mut save_every = 100;
|
||||||
|
|
||||||
|
// Parser les arguments de ligne de commande
|
||||||
|
let mut i = 1;
|
||||||
|
while i < args.len() {
|
||||||
|
match args[i].as_str() {
|
||||||
|
"--episodes" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
episodes = args[i + 1].parse().unwrap_or(1000);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --episodes nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--model-path" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
model_path = args[i + 1].clone();
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --model-path nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--save-every" => {
|
||||||
|
if i + 1 < args.len() {
|
||||||
|
save_every = args[i + 1].parse().unwrap_or(100);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
eprintln!("Erreur : --save-every nécessite une valeur");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"--help" | "-h" => {
|
||||||
|
print_help();
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
eprintln!("Argument inconnu : {}", args[i]);
|
||||||
|
print_help();
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Créer le dossier models s'il n'existe pas
|
||||||
|
std::fs::create_dir_all("models")?;
|
||||||
|
|
||||||
|
println!("Configuration d'entraînement DQN :");
|
||||||
|
println!(" Épisodes : {}", episodes);
|
||||||
|
println!(" Chemin du modèle : {}", model_path);
|
||||||
|
println!(" Sauvegarde tous les {} épisodes", save_every);
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Configuration DQN
|
||||||
|
let config = DqnConfig {
|
||||||
|
state_size: 36, // state.to_vec size
|
||||||
|
hidden_size: 256,
|
||||||
|
num_actions: TrictracAction::action_space_size(),
|
||||||
|
learning_rate: 0.001,
|
||||||
|
gamma: 0.99,
|
||||||
|
epsilon: 0.9, // Commencer avec plus d'exploration
|
||||||
|
epsilon_decay: 0.995,
|
||||||
|
epsilon_min: 0.01,
|
||||||
|
replay_buffer_size: 10000,
|
||||||
|
batch_size: 32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Créer et lancer l'entraîneur
|
||||||
|
let mut trainer = DqnTrainer::new(config);
|
||||||
|
trainer.train(episodes, save_every, &model_path)?;
|
||||||
|
|
||||||
|
println!("Entraînement terminé avec succès !");
|
||||||
|
println!("Pour utiliser le modèle entraîné :");
|
||||||
|
println!(
|
||||||
|
" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
|
||||||
|
model_path
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_help() {
|
||||||
|
println!("Entraîneur DQN pour Trictrac");
|
||||||
|
println!();
|
||||||
|
println!("USAGE:");
|
||||||
|
println!(" cargo run --bin=train_dqn [OPTIONS]");
|
||||||
|
println!();
|
||||||
|
println!("OPTIONS:");
|
||||||
|
println!(" --episodes <NUM> Nombre d'épisodes d'entraînement (défaut: 1000)");
|
||||||
|
println!(" --model-path <PATH> Chemin de base pour sauvegarder les modèles (défaut: models/dqn_model)");
|
||||||
|
println!(" --save-every <NUM> Sauvegarder le modèle tous les N épisodes (défaut: 100)");
|
||||||
|
println!(" -h, --help Afficher cette aide");
|
||||||
|
println!();
|
||||||
|
println!("EXEMPLES:");
|
||||||
|
println!(" cargo run --bin=train_dqn");
|
||||||
|
println!(" cargo run --bin=train_dqn -- --episodes 5000 --save-every 500");
|
||||||
|
println!(" cargo run --bin=train_dqn -- --model-path models/my_model --episodes 2000");
|
||||||
|
}
|
||||||
2
bot/src/dqn/simple/mod.rs
Normal file
2
bot/src/dqn/simple/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod dqn_model;
|
||||||
|
pub mod dqn_trainer;
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
pub mod burnrl;
|
pub mod dqn;
|
||||||
pub mod strategy;
|
pub mod strategy;
|
||||||
pub mod training_common;
|
|
||||||
pub mod trictrac_board;
|
|
||||||
|
|
||||||
use log::debug;
|
use log::{debug, error};
|
||||||
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
|
||||||
pub use strategy::default::DefaultStrategy;
|
pub use strategy::default::DefaultStrategy;
|
||||||
|
pub use strategy::dqn::DqnStrategy;
|
||||||
pub use strategy::dqnburn::DqnBurnStrategy;
|
pub use strategy::dqnburn::DqnBurnStrategy;
|
||||||
pub use strategy::erroneous_moves::ErroneousStrategy;
|
pub use strategy::erroneous_moves::ErroneousStrategy;
|
||||||
pub use strategy::random::RandomStrategy;
|
pub use strategy::random::RandomStrategy;
|
||||||
|
|
|
||||||
174
bot/src/strategy/dqn.rs
Normal file
174
bot/src/strategy/dqn.rs
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||||
|
use log::info;
|
||||||
|
use std::path::Path;
|
||||||
|
use store::MoveRules;
|
||||||
|
|
||||||
|
use crate::dqn::dqn_common::{get_valid_actions, sample_valid_action, TrictracAction};
|
||||||
|
use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
|
||||||
|
|
||||||
|
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DqnStrategy {
|
||||||
|
pub game: GameState,
|
||||||
|
pub player_id: PlayerId,
|
||||||
|
pub color: Color,
|
||||||
|
pub model: Option<SimpleNeuralNetwork>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DqnStrategy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
game: GameState::default(),
|
||||||
|
player_id: 1,
|
||||||
|
color: Color::White,
|
||||||
|
model: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DqnStrategy {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_model<P: AsRef<Path> + std::fmt::Debug>(model_path: P) -> Self {
|
||||||
|
let mut strategy = Self::new();
|
||||||
|
if let Ok(model) = SimpleNeuralNetwork::load(&model_path) {
|
||||||
|
info!("Loading model {model_path:?}");
|
||||||
|
strategy.model = Some(model);
|
||||||
|
}
|
||||||
|
strategy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilise le modèle DQN pour choisir une action valide
|
||||||
|
fn get_dqn_action(&self) -> Option<TrictracAction> {
|
||||||
|
if let Some(ref model) = self.model {
|
||||||
|
let state = self.game.to_vec_float();
|
||||||
|
let valid_actions = get_valid_actions(&self.game);
|
||||||
|
|
||||||
|
if valid_actions.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtenir les Q-values pour toutes les actions
|
||||||
|
let q_values = model.forward(&state);
|
||||||
|
|
||||||
|
// Trouver la meilleure action valide
|
||||||
|
let mut best_action = &valid_actions[0];
|
||||||
|
let mut best_q_value = f32::NEG_INFINITY;
|
||||||
|
|
||||||
|
for action in &valid_actions {
|
||||||
|
let action_index = action.to_action_index();
|
||||||
|
if action_index < q_values.len() {
|
||||||
|
let q_value = q_values[action_index];
|
||||||
|
if q_value > best_q_value {
|
||||||
|
best_q_value = q_value;
|
||||||
|
best_action = action;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(best_action.clone())
|
||||||
|
} else {
|
||||||
|
// Fallback : action aléatoire valide
|
||||||
|
sample_valid_action(&self.game)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BotStrategy for DqnStrategy {
|
||||||
|
fn get_game(&self) -> &GameState {
|
||||||
|
&self.game
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mut_game(&mut self) -> &mut GameState {
|
||||||
|
&mut self.game
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_color(&mut self, color: Color) {
|
||||||
|
self.color = color;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_player_id(&mut self, player_id: PlayerId) {
|
||||||
|
self.player_id = player_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_points(&self) -> u8 {
|
||||||
|
self.game.dice_points.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_adv_points(&self) -> u8 {
|
||||||
|
self.game.dice_points.1
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_go(&self) -> bool {
|
||||||
|
// Utiliser le DQN pour décider si on continue
|
||||||
|
if let Some(action) = self.get_dqn_action() {
|
||||||
|
matches!(action, TrictracAction::Go)
|
||||||
|
} else {
|
||||||
|
// Fallback : toujours continuer
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn choose_move(&self) -> (CheckerMove, CheckerMove) {
|
||||||
|
// Utiliser le DQN pour choisir le mouvement
|
||||||
|
if let Some(TrictracAction::Move {
|
||||||
|
dice_order,
|
||||||
|
from1,
|
||||||
|
from2,
|
||||||
|
}) = self.get_dqn_action()
|
||||||
|
{
|
||||||
|
let dicevals = self.game.dice.values;
|
||||||
|
let (mut dice1, mut dice2) = if dice_order {
|
||||||
|
(dicevals.0, dicevals.1)
|
||||||
|
} else {
|
||||||
|
(dicevals.1, dicevals.0)
|
||||||
|
};
|
||||||
|
|
||||||
|
if from1 == 0 {
|
||||||
|
// empty move
|
||||||
|
dice1 = 0;
|
||||||
|
}
|
||||||
|
let mut to1 = from1 + dice1 as usize;
|
||||||
|
if 24 < to1 {
|
||||||
|
// sortie
|
||||||
|
to1 = 0;
|
||||||
|
}
|
||||||
|
if from2 == 0 {
|
||||||
|
// empty move
|
||||||
|
dice2 = 0;
|
||||||
|
}
|
||||||
|
let mut to2 = from2 + dice2 as usize;
|
||||||
|
if 24 < to2 {
|
||||||
|
// sortie
|
||||||
|
to2 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
let checker_move2 = CheckerMove::new(from2, to2).unwrap_or_default();
|
||||||
|
|
||||||
|
let chosen_move = if self.color == Color::White {
|
||||||
|
(checker_move1, checker_move2)
|
||||||
|
} else {
|
||||||
|
(checker_move1.mirror(), checker_move2.mirror())
|
||||||
|
};
|
||||||
|
|
||||||
|
return chosen_move;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback : utiliser la stratégie par défaut
|
||||||
|
let rules = MoveRules::new(&self.color, &self.game.board, self.game.dice);
|
||||||
|
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
|
||||||
|
|
||||||
|
let chosen_move = *possible_moves
|
||||||
|
.first()
|
||||||
|
.unwrap_or(&(CheckerMove::default(), CheckerMove::default()));
|
||||||
|
|
||||||
|
if self.color == Color::White {
|
||||||
|
chosen_move
|
||||||
|
} else {
|
||||||
|
(chosen_move.0.mirror(), chosen_move.1.mirror())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -6,11 +6,10 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
|
||||||
use log::info;
|
use log::info;
|
||||||
use store::MoveRules;
|
use store::MoveRules;
|
||||||
|
|
||||||
use crate::burnrl::algos::dqn;
|
use crate::dqn::burnrl::{dqn_model, environment, utils};
|
||||||
use crate::burnrl::environment;
|
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
|
||||||
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
|
|
||||||
|
|
||||||
type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>;
|
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
|
||||||
|
|
||||||
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -40,7 +39,7 @@ impl DqnBurnStrategy {
|
||||||
pub fn new_with_model(model_path: &String) -> Self {
|
pub fn new_with_model(model_path: &String) -> Self {
|
||||||
info!("Loading model {model_path:?}");
|
info!("Loading model {model_path:?}");
|
||||||
let mut strategy = Self::new();
|
let mut strategy = Self::new();
|
||||||
strategy.model = dqn::load_model(256, model_path);
|
strategy.model = utils::load_model(256, model_path);
|
||||||
strategy
|
strategy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -118,8 +117,8 @@ impl BotStrategy for DqnBurnStrategy {
|
||||||
// Utiliser le DQN pour choisir le mouvement
|
// Utiliser le DQN pour choisir le mouvement
|
||||||
if let Some(TrictracAction::Move {
|
if let Some(TrictracAction::Move {
|
||||||
dice_order,
|
dice_order,
|
||||||
checker1,
|
from1,
|
||||||
checker2,
|
from2,
|
||||||
}) = self.get_dqn_action()
|
}) = self.get_dqn_action()
|
||||||
{
|
{
|
||||||
let dicevals = self.game.dice.values;
|
let dicevals = self.game.dice.values;
|
||||||
|
|
@ -129,65 +128,23 @@ impl BotStrategy for DqnBurnStrategy {
|
||||||
(dicevals.1, dicevals.0)
|
(dicevals.1, dicevals.0)
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(self.color, Color::White);
|
|
||||||
let from1 = self
|
|
||||||
.game
|
|
||||||
.board
|
|
||||||
.get_checker_field(&self.color, checker1 as u8)
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
if from1 == 0 {
|
if from1 == 0 {
|
||||||
// empty move
|
// empty move
|
||||||
dice1 = 0;
|
dice1 = 0;
|
||||||
}
|
}
|
||||||
let mut to1 = from1;
|
let mut to1 = from1 + dice1 as usize;
|
||||||
if self.color == Color::White {
|
if 24 < to1 {
|
||||||
to1 += dice1 as usize;
|
// sortie
|
||||||
if 24 < to1 {
|
to1 = 0;
|
||||||
// sortie
|
|
||||||
to1 = 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let fto1 = to1 as i16 - dice1 as i16;
|
|
||||||
to1 = if fto1 < 0 { 0 } else { fto1 as usize };
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
|
|
||||||
|
|
||||||
let mut tmp_board = self.game.board.clone();
|
|
||||||
let move_res = tmp_board.move_checker(&self.color, checker_move1);
|
|
||||||
if move_res.is_err() {
|
|
||||||
panic!("could not move {move_res:?}");
|
|
||||||
}
|
|
||||||
let from2 = tmp_board
|
|
||||||
.get_checker_field(&self.color, checker2 as u8)
|
|
||||||
.unwrap_or(0);
|
|
||||||
if from2 == 0 {
|
if from2 == 0 {
|
||||||
// empty move
|
// empty move
|
||||||
dice2 = 0;
|
dice2 = 0;
|
||||||
}
|
}
|
||||||
let mut to2 = from2;
|
let mut to2 = from2 + dice2 as usize;
|
||||||
if self.color == Color::White {
|
if 24 < to2 {
|
||||||
to2 += dice2 as usize;
|
// sortie
|
||||||
if 24 < to2 {
|
to2 = 0;
|
||||||
// sortie
|
|
||||||
to2 = 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let fto2 = to2 as i16 - dice2 as i16;
|
|
||||||
to2 = if fto2 < 0 { 0 } else { fto2 as usize };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gestion prise de coin par puissance
|
|
||||||
let opp_rest_field = if self.color == Color::White { 13 } else { 12 };
|
|
||||||
if to1 == opp_rest_field && to2 == opp_rest_field {
|
|
||||||
if self.color == Color::White {
|
|
||||||
to1 -= 1;
|
|
||||||
to2 -= 1;
|
|
||||||
} else {
|
|
||||||
to1 += 1;
|
|
||||||
to2 += 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
|
||||||
|
|
@ -196,7 +153,6 @@ impl BotStrategy for DqnBurnStrategy {
|
||||||
let chosen_move = if self.color == Color::White {
|
let chosen_move = if self.color == Color::White {
|
||||||
(checker_move1, checker_move2)
|
(checker_move1, checker_move2)
|
||||||
} else {
|
} else {
|
||||||
// XXX : really ?
|
|
||||||
(checker_move1.mirror(), checker_move2.mirror())
|
(checker_move1.mirror(), checker_move2.mirror())
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod default;
|
pub mod default;
|
||||||
|
pub mod dqn;
|
||||||
pub mod dqnburn;
|
pub mod dqnburn;
|
||||||
pub mod erroneous_moves;
|
pub mod erroneous_moves;
|
||||||
pub mod random;
|
pub mod random;
|
||||||
|
|
|
||||||
|
|
@ -66,14 +66,14 @@ impl StableBaselines3Strategy {
|
||||||
// Remplir les positions des pièces blanches (valeurs positives)
|
// Remplir les positions des pièces blanches (valeurs positives)
|
||||||
for (pos, count) in self.game.board.get_color_fields(Color::White) {
|
for (pos, count) in self.game.board.get_color_fields(Color::White) {
|
||||||
if pos < 24 {
|
if pos < 24 {
|
||||||
board[pos] = count;
|
board[pos] = count as i8;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remplir les positions des pièces noires (valeurs négatives)
|
// Remplir les positions des pièces noires (valeurs négatives)
|
||||||
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
|
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
|
||||||
if pos < 24 {
|
if pos < 24 {
|
||||||
board[pos] = -count;
|
board[pos] = -(count as i8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -270,3 +270,4 @@ impl BotStrategy for StableBaselines3Strategy {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,164 +0,0 @@
|
||||||
// https://docs.rs/board-game/ implementation
|
|
||||||
use crate::training_common::{get_valid_actions, TrictracAction};
|
|
||||||
use board_game::board::{
|
|
||||||
Board as BoardGameBoard, BoardDone, BoardMoves, Outcome, PlayError, Player as BoardGamePlayer,
|
|
||||||
};
|
|
||||||
use board_game::impl_unit_symmetry_board;
|
|
||||||
use internal_iterator::InternalIterator;
|
|
||||||
use std::fmt;
|
|
||||||
use std::hash::Hash;
|
|
||||||
use std::ops::ControlFlow;
|
|
||||||
use store::Color;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
|
||||||
pub struct TrictracBoard(crate::GameState);
|
|
||||||
|
|
||||||
impl Default for TrictracBoard {
|
|
||||||
fn default() -> Self {
|
|
||||||
TrictracBoard(crate::GameState::new_with_players("white", "black"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for TrictracBoard {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
self.0.fmt(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl_unit_symmetry_board!(TrictracBoard);
|
|
||||||
|
|
||||||
impl BoardGameBoard for TrictracBoard {
|
|
||||||
// impl TrictracBoard {
|
|
||||||
type Move = TrictracAction;
|
|
||||||
|
|
||||||
fn next_player(&self) -> BoardGamePlayer {
|
|
||||||
self.0
|
|
||||||
.who_plays()
|
|
||||||
.map(|p| {
|
|
||||||
if p.color == Color::Black {
|
|
||||||
BoardGamePlayer::B
|
|
||||||
} else {
|
|
||||||
BoardGamePlayer::A
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.unwrap_or(BoardGamePlayer::A)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_available_move(&self, mv: Self::Move) -> Result<bool, BoardDone> {
|
|
||||||
self.check_done()?;
|
|
||||||
let is_valid = mv
|
|
||||||
.to_event(&self.0)
|
|
||||||
.map(|evt| self.0.validate(&evt))
|
|
||||||
.unwrap_or(false);
|
|
||||||
Ok(is_valid)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn play(&mut self, mv: Self::Move) -> Result<(), PlayError> {
|
|
||||||
self.check_can_play(mv)?;
|
|
||||||
self.0.consume(&mv.to_event(&self.0).unwrap());
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn outcome(&self) -> Option<Outcome> {
|
|
||||||
if self.0.stage == crate::Stage::Ended {
|
|
||||||
self.0.determine_winner().map(|player_id| {
|
|
||||||
Outcome::WonBy(if player_id == 1 {
|
|
||||||
BoardGamePlayer::A
|
|
||||||
} else {
|
|
||||||
BoardGamePlayer::B
|
|
||||||
})
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn can_lose_after_move() -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrictracBoard {
|
|
||||||
pub fn inner(&self) -> &crate::GameState {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_fen(&self) -> String {
|
|
||||||
self.0.to_string_id()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_fen(fen: &str) -> Result<TrictracBoard, String> {
|
|
||||||
crate::GameState::from_string_id(fen).map(TrictracBoard)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> BoardMoves<'a, TrictracBoard> for TrictracBoard {
|
|
||||||
type AllMovesIterator = TrictracAllMovesIterator;
|
|
||||||
type AvailableMovesIterator = TrictracAvailableMovesIterator<'a>;
|
|
||||||
|
|
||||||
fn all_possible_moves() -> Self::AllMovesIterator {
|
|
||||||
TrictracAllMovesIterator::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn available_moves(&'a self) -> Result<Self::AvailableMovesIterator, BoardDone> {
|
|
||||||
TrictracAvailableMovesIterator::new(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TrictracAllMovesIterator;
|
|
||||||
|
|
||||||
impl Default for TrictracAllMovesIterator {
|
|
||||||
fn default() -> Self {
|
|
||||||
TrictracAllMovesIterator
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InternalIterator for TrictracAllMovesIterator {
|
|
||||||
type Item = TrictracAction;
|
|
||||||
|
|
||||||
fn try_for_each<R, F: FnMut(Self::Item) -> ControlFlow<R>>(self, mut f: F) -> ControlFlow<R> {
|
|
||||||
f(TrictracAction::Roll)?;
|
|
||||||
f(TrictracAction::Go)?;
|
|
||||||
for dice_order in [false, true] {
|
|
||||||
for checker1 in 0..16 {
|
|
||||||
for checker2 in 0..16 {
|
|
||||||
f(TrictracAction::Move {
|
|
||||||
dice_order,
|
|
||||||
checker1,
|
|
||||||
checker2,
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ControlFlow::Continue(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TrictracAvailableMovesIterator<'a> {
|
|
||||||
board: &'a TrictracBoard,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> TrictracAvailableMovesIterator<'a> {
|
|
||||||
pub fn new(board: &'a TrictracBoard) -> Result<Self, BoardDone> {
|
|
||||||
board.check_done()?;
|
|
||||||
Ok(TrictracAvailableMovesIterator { board })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn board(&self) -> &'a TrictracBoard {
|
|
||||||
self.board
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InternalIterator for TrictracAvailableMovesIterator<'_> {
|
|
||||||
type Item = TrictracAction;
|
|
||||||
|
|
||||||
fn try_for_each<R, F>(self, f: F) -> ControlFlow<R>
|
|
||||||
where
|
|
||||||
F: FnMut(Self::Item) -> ControlFlow<R>,
|
|
||||||
{
|
|
||||||
get_valid_actions(&self.board.0).into_iter().try_for_each(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
8
client_bevy/.cargo/config.toml
Normal file
8
client_bevy/.cargo/config.toml
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
[target.x86_64-unknown-linux-gnu]
|
||||||
|
linker = "clang"
|
||||||
|
rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
|
||||||
|
|
||||||
|
# Optional: Uncommenting the following improves compile times, but reduces the amount of debug info to 'line number tables only'
|
||||||
|
# In most cases the gains are negligible, but if you are on macos and have slow compile times you should see significant gains.
|
||||||
|
#[profile.dev]
|
||||||
|
#debug = 1
|
||||||
14
client_bevy/Cargo.toml
Normal file
14
client_bevy/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "trictrac-client"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.75"
|
||||||
|
bevy = { version = "0.11.3" }
|
||||||
|
bevy_renet = "0.0.9"
|
||||||
|
bincode = "1.3.3"
|
||||||
|
renet = "0.0.13"
|
||||||
|
store = { path = "../store" }
|
||||||
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
BIN
client_bevy/assets/Inconsolata.ttf
Normal file
Binary file not shown.
BIN
client_bevy/assets/board.png
Normal file
BIN
client_bevy/assets/board.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.9 MiB |
BIN
client_bevy/assets/sound/click.wav
Normal file
BIN
client_bevy/assets/sound/click.wav
Normal file
Binary file not shown.
BIN
client_bevy/assets/sound/throw.wav
Executable file
BIN
client_bevy/assets/sound/throw.wav
Executable file
Binary file not shown.
BIN
client_bevy/assets/tac.png
Normal file
BIN
client_bevy/assets/tac.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.6 KiB |
BIN
client_bevy/assets/tic.png
Normal file
BIN
client_bevy/assets/tic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.4 KiB |
334
client_bevy/src/main.rs
Normal file
334
client_bevy/src/main.rs
Normal file
|
|
@ -0,0 +1,334 @@
|
||||||
|
use std::{net::UdpSocket, time::SystemTime};
|
||||||
|
|
||||||
|
use renet::transport::{NetcodeClientTransport, NetcodeTransportError, NETCODE_USER_DATA_BYTES};
|
||||||
|
use store::{GameEvent, GameState, CheckerMove};
|
||||||
|
|
||||||
|
use bevy::prelude::*;
|
||||||
|
use bevy::window::PrimaryWindow;
|
||||||
|
use bevy_renet::{
|
||||||
|
renet::{transport::ClientAuthentication, ConnectionConfig, RenetClient},
|
||||||
|
transport::{client_connected, NetcodeClientPlugin},
|
||||||
|
RenetClientPlugin,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Resource)]
|
||||||
|
struct CurrentClientId(u64);
|
||||||
|
|
||||||
|
#[derive(Resource)]
|
||||||
|
struct BevyGameState(GameState);
|
||||||
|
|
||||||
|
impl Default for BevyGameState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
0: GameState::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Resource, Deref, DerefMut)]
|
||||||
|
struct GameUIState {
|
||||||
|
selected_tile: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for GameUIState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
selected_tile: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Event)]
|
||||||
|
struct BevyGameEvent(GameEvent);
|
||||||
|
|
||||||
|
// This id needs to be the same as the server is using
|
||||||
|
const PROTOCOL_ID: u64 = 2878;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// Get username from stdin args
|
||||||
|
let args = std::env::args().collect::<Vec<String>>();
|
||||||
|
let username = &args[1];
|
||||||
|
|
||||||
|
let (client, transport, client_id) = new_renet_client(&username).unwrap();
|
||||||
|
App::new()
|
||||||
|
// Lets add a nice dark grey background color
|
||||||
|
.insert_resource(ClearColor(Color::hex("282828").unwrap()))
|
||||||
|
.add_plugins(DefaultPlugins.set(WindowPlugin {
|
||||||
|
primary_window: Some(Window {
|
||||||
|
// Adding the username to the window title makes debugging a whole lot easier.
|
||||||
|
title: format!("TricTrac <{}>", username),
|
||||||
|
resolution: (1080.0, 1080.0).into(),
|
||||||
|
..default()
|
||||||
|
}),
|
||||||
|
..default()
|
||||||
|
}))
|
||||||
|
// Add our game state and register GameEvent as a bevy event
|
||||||
|
.insert_resource(BevyGameState::default())
|
||||||
|
.insert_resource(GameUIState::default())
|
||||||
|
.add_event::<BevyGameEvent>()
|
||||||
|
// Renet setup
|
||||||
|
.add_plugins(RenetClientPlugin)
|
||||||
|
.add_plugins(NetcodeClientPlugin)
|
||||||
|
.insert_resource(client)
|
||||||
|
.insert_resource(transport)
|
||||||
|
.insert_resource(CurrentClientId(client_id))
|
||||||
|
.add_systems(Startup, setup)
|
||||||
|
.add_systems(Update, (update_waiting_text, input, update_board, panic_on_error_system))
|
||||||
|
.add_systems(
|
||||||
|
PostUpdate,
|
||||||
|
receive_events_from_server.run_if(client_connected()),
|
||||||
|
)
|
||||||
|
.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// COMPONENTS //////////
|
||||||
|
#[derive(Component)]
|
||||||
|
struct UIRoot;
|
||||||
|
|
||||||
|
#[derive(Component)]
|
||||||
|
struct WaitingText;
|
||||||
|
|
||||||
|
#[derive(Component)]
|
||||||
|
struct Board {
|
||||||
|
squares: [Square; 26]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Board {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
squares: [Square { count: 0, color: None, position: 0}; 26]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Board {
|
||||||
|
fn square_at(&self, position: usize) -> Square {
|
||||||
|
self.squares[position]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Component, Clone, Copy)]
|
||||||
|
struct Square {
|
||||||
|
count: usize,
|
||||||
|
color: Option<bool>,
|
||||||
|
position: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// UPDATE SYSTEMS //////////
|
||||||
|
fn update_board(
|
||||||
|
mut commands: Commands,
|
||||||
|
game_state: Res<BevyGameState>,
|
||||||
|
mut game_events: EventReader<BevyGameEvent>,
|
||||||
|
asset_server: Res<AssetServer>,
|
||||||
|
) {
|
||||||
|
for event in game_events.iter() {
|
||||||
|
match event.0 {
|
||||||
|
GameEvent::Move { player_id, moves } => {
|
||||||
|
// trictrac positions, TODO : dépend de player_id
|
||||||
|
let (x, y) = if moves.0.get_to() < 13 { (13 - moves.0.get_to(), 1) } else { (moves.0.get_to() - 13, 0)};
|
||||||
|
let texture =
|
||||||
|
asset_server.load(match game_state.0.players[&player_id].color {
|
||||||
|
store::Color::Black => "tac.png",
|
||||||
|
store::Color::White => "tic.png",
|
||||||
|
});
|
||||||
|
|
||||||
|
info!("spawning tictac sprite");
|
||||||
|
commands.spawn(SpriteBundle {
|
||||||
|
transform: Transform::from_xyz(
|
||||||
|
83.0 * (x as f32 - 1.0),
|
||||||
|
-30.0 + 540.0 * (y as f32 - 1.0),
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
sprite: Sprite {
|
||||||
|
custom_size: Some(Vec2::new(83.0, 83.0)),
|
||||||
|
..default()
|
||||||
|
},
|
||||||
|
texture: texture.into(),
|
||||||
|
..default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_waiting_text(mut text_query: Query<&mut Text, With<WaitingText>>, time: Res<Time>) {
|
||||||
|
if let Ok(mut text) = text_query.get_single_mut() {
|
||||||
|
let num_dots = (time.elapsed_seconds() as usize % 3) + 1;
|
||||||
|
text.sections[0].value = format!(
|
||||||
|
"Waiting for an opponent{}{}",
|
||||||
|
".".repeat(num_dots as usize),
|
||||||
|
// Pad with spaces to avoid text changing width and dancing all around the screen 🕺
|
||||||
|
" ".repeat(3 - num_dots as usize)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input(
|
||||||
|
primary_query: Query<&Window, With<PrimaryWindow>>,
|
||||||
|
// windows: Res<Windows>,
|
||||||
|
input: Res<Input<MouseButton>>,
|
||||||
|
game_state: Res<BevyGameState>,
|
||||||
|
mut game_ui_state: ResMut<GameUIState>,
|
||||||
|
mut client: ResMut<RenetClient>,
|
||||||
|
client_id: Res<CurrentClientId>,
|
||||||
|
) {
|
||||||
|
// We only want to handle inputs once we are ingame
|
||||||
|
if game_state.0.stage != store::Stage::InGame {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let window = primary_query.get_single().unwrap();
|
||||||
|
if let Some(mouse_position) = window.cursor_position() {
|
||||||
|
// Determine the index of the tile that the mouse is currently over
|
||||||
|
// NOTE: This calculation assumes a fixed window size.
|
||||||
|
// That's fine for now, but consider using the windows size instead.
|
||||||
|
let mut tile_x: usize = (mouse_position.x / 83.0).floor() as usize;
|
||||||
|
let tile_y: usize = (mouse_position.y / 540.0).floor() as usize;
|
||||||
|
if tile_x > 5 {
|
||||||
|
// remove the middle bar offset
|
||||||
|
tile_x = tile_x - 1
|
||||||
|
}
|
||||||
|
// let tile = tile_x + tile_y * 12;
|
||||||
|
|
||||||
|
// traduction en position backgammon
|
||||||
|
let tile = if tile_y == 0 {
|
||||||
|
13 + tile_x
|
||||||
|
} else {
|
||||||
|
12 - tile_x
|
||||||
|
};
|
||||||
|
|
||||||
|
// If mouse is outside of board we do nothing
|
||||||
|
if 23 < tile {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If left mouse button is pressed, send a place tile event to the server
|
||||||
|
if input.just_pressed(MouseButton::Left) {
|
||||||
|
info!("select piece at tile {:?}", tile);
|
||||||
|
if game_ui_state.selected_tile.is_some() {
|
||||||
|
let from_tile = game_ui_state.selected_tile.unwrap();
|
||||||
|
info!("sending movement from: {:?} to: {:?} ", from_tile, tile);
|
||||||
|
let event = GameEvent::Move {
|
||||||
|
player_id: client_id.0,
|
||||||
|
moves: (
|
||||||
|
CheckerMove::new(from_tile, tile).unwrap(),
|
||||||
|
CheckerMove::new(from_tile, tile).unwrap()
|
||||||
|
)
|
||||||
|
};
|
||||||
|
client.send_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
}
|
||||||
|
game_ui_state.selected_tile = if game_ui_state.selected_tile.is_some() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// SETUP //////////
|
||||||
|
fn setup(mut commands: Commands, asset_server: Res<AssetServer>) {
|
||||||
|
// Tric Trac is a 2D game
|
||||||
|
// To show 2D sprites we need a 2D camera
|
||||||
|
commands.spawn(Camera2dBundle::default());
|
||||||
|
|
||||||
|
// Spawn board background
|
||||||
|
commands.spawn(SpriteBundle {
|
||||||
|
transform: Transform::from_xyz(0.0, -30.0, 0.0),
|
||||||
|
sprite: Sprite {
|
||||||
|
custom_size: Some(Vec2::new(1080.0, 927.0)),
|
||||||
|
..default()
|
||||||
|
},
|
||||||
|
texture: asset_server.load("board.png").into(),
|
||||||
|
..default()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Spawn pregame ui
|
||||||
|
commands
|
||||||
|
// A container that centers its children on the screen
|
||||||
|
.spawn(NodeBundle {
|
||||||
|
style: Style {
|
||||||
|
position_type: PositionType::Absolute,
|
||||||
|
left: Val::Px(0.0),
|
||||||
|
top: Val::Px(0.0),
|
||||||
|
width: Val::Percent(100.0),
|
||||||
|
height: Val::Percent(100.0),
|
||||||
|
align_items: AlignItems::Center,
|
||||||
|
justify_content: JustifyContent::Center,
|
||||||
|
..default()
|
||||||
|
},
|
||||||
|
..default()
|
||||||
|
})
|
||||||
|
.insert(UIRoot)
|
||||||
|
.with_children(|parent| {
|
||||||
|
// parent.spawn(Board::default()); // panic
|
||||||
|
parent
|
||||||
|
.spawn(TextBundle::from_section(
|
||||||
|
"Waiting for an opponent...",
|
||||||
|
TextStyle {
|
||||||
|
font: asset_server.load("Inconsolata.ttf"),
|
||||||
|
font_size: 24.0,
|
||||||
|
color: Color::hex("ebdbb2").unwrap(),
|
||||||
|
},
|
||||||
|
))
|
||||||
|
.insert(WaitingText);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
////////// RENET NETWORKING //////////
|
||||||
|
// Creates a RenetClient thats already connected to a server.
|
||||||
|
// Returns an Err if connection fails
|
||||||
|
fn new_renet_client(
|
||||||
|
username: &String,
|
||||||
|
) -> anyhow::Result<(RenetClient, NetcodeClientTransport, u64)> {
|
||||||
|
let client = RenetClient::new(ConnectionConfig::default());
|
||||||
|
let server_addr = "127.0.0.1:5000".parse()?;
|
||||||
|
let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||||
|
let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?;
|
||||||
|
let client_id = current_time.as_millis() as u64;
|
||||||
|
|
||||||
|
// Place username in user data
|
||||||
|
let mut user_data = [0u8; NETCODE_USER_DATA_BYTES];
|
||||||
|
if username.len() > NETCODE_USER_DATA_BYTES - 8 {
|
||||||
|
panic!("Username is too big");
|
||||||
|
}
|
||||||
|
user_data[0..8].copy_from_slice(&(username.len() as u64).to_le_bytes());
|
||||||
|
user_data[8..username.len() + 8].copy_from_slice(username.as_bytes());
|
||||||
|
|
||||||
|
let authentication = ClientAuthentication::Unsecure {
|
||||||
|
server_addr,
|
||||||
|
client_id,
|
||||||
|
user_data: Some(user_data),
|
||||||
|
protocol_id: PROTOCOL_ID,
|
||||||
|
};
|
||||||
|
let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap();
|
||||||
|
|
||||||
|
Ok((client, transport, client_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn receive_events_from_server(
|
||||||
|
mut client: ResMut<RenetClient>,
|
||||||
|
mut game_state: ResMut<BevyGameState>,
|
||||||
|
mut game_events: EventWriter<BevyGameEvent>,
|
||||||
|
) {
|
||||||
|
while let Some(message) = client.receive_message(0) {
|
||||||
|
// Whenever the server sends a message we know that it must be a game event
|
||||||
|
let event: GameEvent = bincode::deserialize(&message).unwrap();
|
||||||
|
trace!("{:#?}", event);
|
||||||
|
|
||||||
|
// We trust the server - It's always been good to us!
|
||||||
|
// No need to validate the events it is sending us
|
||||||
|
game_state.0.consume(&event);
|
||||||
|
|
||||||
|
// Send the event into the bevy event system so systems can react to it
|
||||||
|
game_events.send(BevyGameEvent(event));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If any error is found we just panic
|
||||||
|
fn panic_on_error_system(mut renet_error: EventReader<NetcodeTransportError>) {
|
||||||
|
for e in renet_error.iter() {
|
||||||
|
panic!("{}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bot::{
|
use bot::{
|
||||||
BotStrategy, DefaultStrategy, DqnBurnStrategy, ErroneousStrategy, RandomStrategy,
|
BotStrategy, DefaultStrategy, DqnBurnStrategy, DqnStrategy, ErroneousStrategy, RandomStrategy,
|
||||||
StableBaselines3Strategy,
|
StableBaselines3Strategy,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
@ -25,11 +25,11 @@ pub struct App {
|
||||||
impl App {
|
impl App {
|
||||||
// Constructs a new instance of [`App`].
|
// Constructs a new instance of [`App`].
|
||||||
pub fn new(args: AppArgs) -> Self {
|
pub fn new(args: AppArgs) -> Self {
|
||||||
let bot_strategies: Vec<Box<dyn BotStrategy>> =
|
let bot_strategies: Vec<Box<dyn BotStrategy>> = args
|
||||||
args.bot
|
.bot
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.map(|str_bots| {
|
.map(|str_bots| {
|
||||||
str_bots
|
str_bots
|
||||||
.split(",")
|
.split(",")
|
||||||
.filter_map(|s| match s.trim() {
|
.filter_map(|s| match s.trim() {
|
||||||
"dummy" => {
|
"dummy" => {
|
||||||
|
|
@ -43,6 +43,7 @@ impl App {
|
||||||
}
|
}
|
||||||
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
"ai" => Some(Box::new(StableBaselines3Strategy::default())
|
||||||
as Box<dyn BotStrategy>),
|
as Box<dyn BotStrategy>),
|
||||||
|
"dqn" => Some(Box::new(DqnStrategy::default()) as Box<dyn BotStrategy>),
|
||||||
"dqnburn" => {
|
"dqnburn" => {
|
||||||
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
|
Some(Box::new(DqnBurnStrategy::default()) as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
|
|
@ -51,16 +52,21 @@ impl App {
|
||||||
Some(Box::new(StableBaselines3Strategy::new(path))
|
Some(Box::new(StableBaselines3Strategy::new(path))
|
||||||
as Box<dyn BotStrategy>)
|
as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
|
s if s.starts_with("dqn:") => {
|
||||||
|
let path = s.trim_start_matches("dqn:");
|
||||||
|
Some(Box::new(DqnStrategy::new_with_model(path))
|
||||||
|
as Box<dyn BotStrategy>)
|
||||||
|
}
|
||||||
s if s.starts_with("dqnburn:") => {
|
s if s.starts_with("dqnburn:") => {
|
||||||
let path = s.trim_start_matches("dqnburn:");
|
let path = s.trim_start_matches("dqnburn:");
|
||||||
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
|
Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}")))
|
||||||
as Box<dyn BotStrategy>)
|
as Box<dyn BotStrategy>)
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
let schools_enabled = false;
|
let schools_enabled = false;
|
||||||
let should_quit = bot_strategies.len() > 1;
|
let should_quit = bot_strategies.len() > 1;
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -108,7 +114,7 @@ impl App {
|
||||||
|
|
||||||
pub fn show_history(&self) {
|
pub fn show_history(&self) {
|
||||||
for hist in self.game.state.history.iter() {
|
for hist in self.game.state.history.iter() {
|
||||||
println!("{hist:?}\n");
|
println!("{:?}\n", hist);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -133,9 +139,6 @@ impl App {
|
||||||
// &self.game.state.board,
|
// &self.game.state.board,
|
||||||
// dice,
|
// dice,
|
||||||
// );
|
// );
|
||||||
self.game.handle_event(&GameEvent::Roll {
|
|
||||||
player_id: self.game.player_id.unwrap(),
|
|
||||||
});
|
|
||||||
self.game.handle_event(&GameEvent::RollResult {
|
self.game.handle_event(&GameEvent::RollResult {
|
||||||
player_id: self.game.player_id.unwrap(),
|
player_id: self.game.player_id.unwrap(),
|
||||||
dice,
|
dice,
|
||||||
|
|
@ -186,7 +189,7 @@ impl App {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("invalid move : {input}");
|
println!("invalid move : {}", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn display(&mut self) -> String {
|
pub fn display(&mut self) -> String {
|
||||||
|
|
@ -326,7 +329,6 @@ Player :: holes :: points
|
||||||
seed: Some(1327),
|
seed: Some(1327),
|
||||||
bot: Some("dummy".into()),
|
bot: Some("dummy".into()),
|
||||||
});
|
});
|
||||||
println!("avant : {}", app.display());
|
|
||||||
app.input("roll");
|
app.input("roll");
|
||||||
app.input("1 3");
|
app.input("1 3");
|
||||||
app.input("1 4");
|
app.input("1 4");
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ impl GameRunner {
|
||||||
} else {
|
} else {
|
||||||
debug!("{}", self.state);
|
debug!("{}", self.state);
|
||||||
error!("event not valid : {event:?}");
|
error!("event not valid : {event:?}");
|
||||||
// panic!("crash and burn {} \nevt not valid {event:?}", self.state);
|
panic!("crash and burn");
|
||||||
&GameEvent::PlayError
|
&GameEvent::PlayError
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ fn main() -> Result<()> {
|
||||||
let args = match parse_args() {
|
let args = match parse_args() {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Error: {e}.");
|
eprintln!("Error: {}.", e);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -63,7 +63,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
|
||||||
|
|
||||||
// Help has a higher priority and should be handled separately.
|
// Help has a higher priority and should be handled separately.
|
||||||
if pargs.contains(["-h", "--help"]) {
|
if pargs.contains(["-h", "--help"]) {
|
||||||
print!("{HELP}");
|
print!("{}", HELP);
|
||||||
std::process::exit(0);
|
std::process::exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -78,7 +78,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
|
||||||
// It's up to the caller what to do with the remaining arguments.
|
// It's up to the caller what to do with the remaining arguments.
|
||||||
let remaining = pargs.finish();
|
let remaining = pargs.finish();
|
||||||
if !remaining.is_empty() {
|
if !remaining.is_empty() {
|
||||||
eprintln!("Warning: unused arguments left: {remaining:?}.");
|
eprintln!("Warning: unused arguments left: {:?}.", remaining);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(args)
|
Ok(args)
|
||||||
|
|
|
||||||
14
client_tui/Cargo.toml
Normal file
14
client_tui/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "client_tui"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.89"
|
||||||
|
bincode = "1.3.3"
|
||||||
|
crossterm = "0.28.1"
|
||||||
|
ratatui = "0.28.1"
|
||||||
|
# renet = "0.0.13"
|
||||||
|
store = { path = "../store" }
|
||||||
53
client_tui/src/app.rs
Normal file
53
client_tui/src/app.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
// Application.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct App {
|
||||||
|
// should the application exit?
|
||||||
|
pub should_quit: bool,
|
||||||
|
// counter
|
||||||
|
pub counter: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl App {
|
||||||
|
// Constructs a new instance of [`App`].
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles the tick event of the terminal.
|
||||||
|
pub fn tick(&self) {}
|
||||||
|
|
||||||
|
// Set running to false to quit the application.
|
||||||
|
pub fn quit(&mut self) {
|
||||||
|
self.should_quit = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_counter(&mut self) {
|
||||||
|
if let Some(res) = self.counter.checked_add(1) {
|
||||||
|
self.counter = res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decrement_counter(&mut self) {
|
||||||
|
if let Some(res) = self.counter.checked_sub(1) {
|
||||||
|
self.counter = res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
#[test]
|
||||||
|
fn test_app_increment_counter() {
|
||||||
|
let mut app = App::default();
|
||||||
|
app.increment_counter();
|
||||||
|
assert_eq!(app.counter, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_app_decrement_counter() {
|
||||||
|
let mut app = App::default();
|
||||||
|
app.decrement_counter();
|
||||||
|
assert_eq!(app.counter, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
87
client_tui/src/event.rs
Normal file
87
client_tui/src/event.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
use std::{
|
||||||
|
sync::mpsc,
|
||||||
|
thread,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use crossterm::event::{self, Event as CrosstermEvent, KeyEvent, MouseEvent};
|
||||||
|
|
||||||
|
// Terminal events.
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub enum Event {
|
||||||
|
// Terminal tick.
|
||||||
|
Tick,
|
||||||
|
// Key press.
|
||||||
|
Key(KeyEvent),
|
||||||
|
// Mouse click/scroll.
|
||||||
|
Mouse(MouseEvent),
|
||||||
|
// Terminal resize.
|
||||||
|
Resize(u16, u16),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminal event handler.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EventHandler {
|
||||||
|
// Event sender channel.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
sender: mpsc::Sender<Event>,
|
||||||
|
// Event receiver channel.
|
||||||
|
receiver: mpsc::Receiver<Event>,
|
||||||
|
// Event handler thread.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
handler: thread::JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventHandler {
|
||||||
|
// Constructs a new instance of [`EventHandler`].
|
||||||
|
pub fn new(tick_rate: u64) -> Self {
|
||||||
|
let tick_rate = Duration::from_millis(tick_rate);
|
||||||
|
let (sender, receiver) = mpsc::channel();
|
||||||
|
let handler = {
|
||||||
|
let sender = sender.clone();
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut last_tick = Instant::now();
|
||||||
|
loop {
|
||||||
|
let timeout = tick_rate
|
||||||
|
.checked_sub(last_tick.elapsed())
|
||||||
|
.unwrap_or(tick_rate);
|
||||||
|
|
||||||
|
if event::poll(timeout).expect("no events available") {
|
||||||
|
match event::read().expect("unable to read event") {
|
||||||
|
CrosstermEvent::Key(e) => {
|
||||||
|
if e.kind == event::KeyEventKind::Press {
|
||||||
|
sender.send(Event::Key(e))
|
||||||
|
} else {
|
||||||
|
Ok(()) // ignore KeyEventKind::Release on windows
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CrosstermEvent::Mouse(e) => sender.send(Event::Mouse(e)),
|
||||||
|
CrosstermEvent::Resize(w, h) => sender.send(Event::Resize(w, h)),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
.expect("failed to send terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
|
if last_tick.elapsed() >= tick_rate {
|
||||||
|
sender.send(Event::Tick).expect("failed to send tick event");
|
||||||
|
last_tick = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
Self {
|
||||||
|
sender,
|
||||||
|
receiver,
|
||||||
|
handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive the next event from the handler thread.
|
||||||
|
//
|
||||||
|
// This function will always block the current thread if
|
||||||
|
// there is no data available and it's possible for more data to be sent.
|
||||||
|
pub fn next(&self) -> Result<Event> {
|
||||||
|
Ok(self.receiver.recv()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
50
client_tui/src/main.rs
Normal file
50
client_tui/src/main.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Application.
|
||||||
|
pub mod app;
|
||||||
|
|
||||||
|
// Terminal events handler.
|
||||||
|
pub mod event;
|
||||||
|
|
||||||
|
// Widget renderer.
|
||||||
|
pub mod ui;
|
||||||
|
|
||||||
|
// Terminal user interface.
|
||||||
|
pub mod tui;
|
||||||
|
|
||||||
|
// Application updater.
|
||||||
|
pub mod update;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use app::App;
|
||||||
|
use event::{Event, EventHandler};
|
||||||
|
use ratatui::{backend::CrosstermBackend, Terminal};
|
||||||
|
use tui::Tui;
|
||||||
|
use update::update;
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
// Create an application.
|
||||||
|
let mut app = App::new();
|
||||||
|
|
||||||
|
// Initialize the terminal user interface.
|
||||||
|
let backend = CrosstermBackend::new(std::io::stderr());
|
||||||
|
let terminal = Terminal::new(backend)?;
|
||||||
|
let events = EventHandler::new(250);
|
||||||
|
let mut tui = Tui::new(terminal, events);
|
||||||
|
tui.enter()?;
|
||||||
|
|
||||||
|
// Start the main loop.
|
||||||
|
while !app.should_quit {
|
||||||
|
// Render the user interface.
|
||||||
|
tui.draw(&mut app)?;
|
||||||
|
// Handle events.
|
||||||
|
match tui.events.next()? {
|
||||||
|
Event::Tick => {}
|
||||||
|
Event::Key(key_event) => update(&mut app, key_event),
|
||||||
|
Event::Mouse(_) => {}
|
||||||
|
Event::Resize(_, _) => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exit the user interface.
|
||||||
|
tui.exit()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
77
client_tui/src/tui.rs
Normal file
77
client_tui/src/tui.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
use std::{io, panic};
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use crossterm::{
|
||||||
|
event::{DisableMouseCapture, EnableMouseCapture},
|
||||||
|
terminal::{self, EnterAlternateScreen, LeaveAlternateScreen},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub type CrosstermTerminal = ratatui::Terminal<ratatui::backend::CrosstermBackend<std::io::Stderr>>;
|
||||||
|
|
||||||
|
use crate::{app::App, event::EventHandler, ui};
|
||||||
|
|
||||||
|
// Representation of a terminal user interface.
|
||||||
|
//
|
||||||
|
// It is responsible for setting up the terminal,
|
||||||
|
// initializing the interface and handling the draw events.
|
||||||
|
pub struct Tui {
|
||||||
|
// Interface to the Terminal.
|
||||||
|
terminal: CrosstermTerminal,
|
||||||
|
// Terminal event handler.
|
||||||
|
pub events: EventHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tui {
|
||||||
|
// Constructs a new instance of [`Tui`].
|
||||||
|
pub fn new(terminal: CrosstermTerminal, events: EventHandler) -> Self {
|
||||||
|
Self { terminal, events }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initializes the terminal interface.
|
||||||
|
//
|
||||||
|
// It enables the raw mode and sets terminal properties.
|
||||||
|
pub fn enter(&mut self) -> Result<()> {
|
||||||
|
terminal::enable_raw_mode()?;
|
||||||
|
crossterm::execute!(io::stderr(), EnterAlternateScreen, EnableMouseCapture)?;
|
||||||
|
|
||||||
|
// Define a custom panic hook to reset the terminal properties.
|
||||||
|
// This way, you won't have your terminal messed up if an unexpected error happens.
|
||||||
|
let panic_hook = panic::take_hook();
|
||||||
|
panic::set_hook(Box::new(move |panic| {
|
||||||
|
Self::reset().expect("failed to reset the terminal");
|
||||||
|
panic_hook(panic);
|
||||||
|
}));
|
||||||
|
|
||||||
|
self.terminal.hide_cursor()?;
|
||||||
|
self.terminal.clear()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// [`Draw`] the terminal interface by [`rendering`] the widgets.
|
||||||
|
//
|
||||||
|
// [`Draw`]: tui::Terminal::draw
|
||||||
|
// [`rendering`]: crate::ui:render
|
||||||
|
pub fn draw(&mut self, app: &mut App) -> Result<()> {
|
||||||
|
self.terminal.draw(|frame| ui::render(app, frame))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resets the terminal interface.
|
||||||
|
//
|
||||||
|
// This function is also used for the panic hook to revert
|
||||||
|
// the terminal properties if unexpected errors occur.
|
||||||
|
fn reset() -> Result<()> {
|
||||||
|
terminal::disable_raw_mode()?;
|
||||||
|
crossterm::execute!(io::stderr(), LeaveAlternateScreen, DisableMouseCapture)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exits the terminal interface.
|
||||||
|
//
|
||||||
|
// It disables the raw mode and reverts back the terminal properties.
|
||||||
|
pub fn exit(&mut self) -> Result<()> {
|
||||||
|
Self::reset()?;
|
||||||
|
self.terminal.show_cursor()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
30
client_tui/src/ui.rs
Normal file
30
client_tui/src/ui.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
use ratatui::{
|
||||||
|
prelude::{Alignment, Frame},
|
||||||
|
style::{Color, Style},
|
||||||
|
widgets::{Block, BorderType, Borders, Paragraph},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::app::App;
|
||||||
|
|
||||||
|
pub fn render(app: &mut App, f: &mut Frame) {
|
||||||
|
f.render_widget(
|
||||||
|
Paragraph::new(format!(
|
||||||
|
"
|
||||||
|
Press `Esc`, `Ctrl-C` or `q` to stop running.\n\
|
||||||
|
Press `j` and `k` to increment and decrement the counter respectively.\n\
|
||||||
|
Counter: {}
|
||||||
|
",
|
||||||
|
app.counter
|
||||||
|
))
|
||||||
|
.block(
|
||||||
|
Block::default()
|
||||||
|
.title("Counter App")
|
||||||
|
.title_alignment(Alignment::Center)
|
||||||
|
.borders(Borders::ALL)
|
||||||
|
.border_type(BorderType::Rounded),
|
||||||
|
)
|
||||||
|
.style(Style::default().fg(Color::Yellow))
|
||||||
|
.alignment(Alignment::Center),
|
||||||
|
f.area(),
|
||||||
|
)
|
||||||
|
}
|
||||||
17
client_tui/src/update.rs
Normal file
17
client_tui/src/update.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
|
|
||||||
|
use crate::app::App;
|
||||||
|
|
||||||
|
pub fn update(app: &mut App, key_event: KeyEvent) {
|
||||||
|
match key_event.code {
|
||||||
|
KeyCode::Esc | KeyCode::Char('q') => app.quit(),
|
||||||
|
KeyCode::Char('c') | KeyCode::Char('C') => {
|
||||||
|
if key_event.modifiers == KeyModifiers::CONTROL {
|
||||||
|
app.quit()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KeyCode::Right | KeyCode::Char('j') => app.increment_counter(),
|
||||||
|
KeyCode::Left | KeyCode::Char('k') => app.decrement_counter(),
|
||||||
|
_ => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
46
doc/refs/geminiQuestions.md
Normal file
46
doc/refs/geminiQuestions.md
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Description du projet et question
|
||||||
|
|
||||||
|
Je développe un jeu de TricTrac (<https://fr.wikipedia.org/wiki/Trictrac>) dans le langage rust.
|
||||||
|
Pour le moment je me concentre sur l'application en ligne de commande simple, donc ne t'occupe pas des dossiers 'client_bevy', 'client_tui', et 'server' qui ne seront utilisés que pour de prochaines évolutions.
|
||||||
|
|
||||||
|
Les règles du jeu et l'état d'une partie sont implémentées dans 'store', l'application ligne de commande est implémentée dans 'client_cli', elle permet déjà de jouer contre un bot, ou de faire jouer deux bots l'un contre l'autre.
|
||||||
|
Les stratégies de bots sont implémentées dans le dossier 'bot'.
|
||||||
|
|
||||||
|
Plus précisément, l'état du jeu est défini par le struct GameState dans store/src/game.rs, la méthode to_string_id() permet de coder cet état de manière compacte dans une chaîne de caractères, mais il n'y a pas l'historique des coups joués. Il y a aussi fmt::Display d'implémenté pour une representation textuelle plus lisible.
|
||||||
|
|
||||||
|
'client_cli/src/game_runner.rs' contient la logique permettant de faire jouer deux bots l'un contre l'autre.
|
||||||
|
'bot/src/strategy/default.rs' contient le code d'une stratégie de bot basique : il détermine la liste des mouvements valides (avec la méthode get_possible_moves_sequences de store::MoveRules) et joue simplement le premier de la liste.
|
||||||
|
|
||||||
|
Je cherche maintenant à ajouter des stratégies de bot plus fortes en entrainant un agent/bot par reinforcement learning.
|
||||||
|
|
||||||
|
Une première version avec DQN fonctionne (entraînement avec `cargo run -bin=train_dqn`)
|
||||||
|
Il gagne systématiquement contre le bot par défaut 'dummy' : `cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy`.
|
||||||
|
|
||||||
|
Une version, toujours DQN, mais en utilisant la bibliothèque burn (<https://burn.dev/>) est en cours de développement.
|
||||||
|
|
||||||
|
L'entraînement du modèle se passe dans la fonction "main" du fichier bot/src/burnrl/main.rs. On peut lancer l'exécution avec 'just trainbot'.
|
||||||
|
|
||||||
|
Voici la sortie de l'entraînement lancé avec 'just trainbot' :
|
||||||
|
|
||||||
|
```
|
||||||
|
> Entraînement
|
||||||
|
> {"episode": 0, "reward": -1692.3148, "duration": 1000}
|
||||||
|
> {"episode": 1, "reward": -361.6962, "duration": 1000}
|
||||||
|
> {"episode": 2, "reward": -126.1013, "duration": 1000}
|
||||||
|
> {"episode": 3, "reward": -36.8000, "duration": 1000}
|
||||||
|
> {"episode": 4, "reward": -21.4997, "duration": 1000}
|
||||||
|
> {"episode": 5, "reward": -8.3000, "duration": 1000}
|
||||||
|
> {"episode": 6, "reward": 3.1000, "duration": 1000}
|
||||||
|
> {"episode": 7, "reward": -21.5998, "duration": 1000}
|
||||||
|
> {"episode": 8, "reward": -10.1999, "duration": 1000}
|
||||||
|
> {"episode": 9, "reward": 3.1000, "duration": 1000}
|
||||||
|
> {"episode": 10, "reward": 14.5002, "duration": 1000}
|
||||||
|
> {"episode": 11, "reward": 10.7000, "duration": 1000}
|
||||||
|
> {"episode": 12, "reward": -0.7000, "duration": 1000}
|
||||||
|
|
||||||
|
thread 'main' has overflowed its stack
|
||||||
|
fatal runtime error: stack overflow
|
||||||
|
error: Recipe `trainbot` was terminated on line 25 by signal 6
|
||||||
|
```
|
||||||
|
|
||||||
|
Au bout du 12ème épisode (plus de 6 heures sur ma machine), l'entraînement s'arrête avec une erreur stack overlow. Peux-tu m'aider à diagnostiquer d'où peut provenir le problème ? Y a-t-il des outils qui permettent de détecter les zones de code qui utilisent le plus la stack ? Pour information j'ai vu ce rapport de bug <https://github.com/yunjhongwu/burn-rl-examples/issues/40> , donc peut-être que le problème vient du paquet 'burl-rl'.
|
||||||
|
|
@ -1,54 +1,46 @@
|
||||||
# Inspirations
|
# Inspirations
|
||||||
|
|
||||||
tools
|
tools
|
||||||
|
- config clippy ?
|
||||||
- config clippy ?
|
- bacon : tests runner (ou loom ?)
|
||||||
- bacon : tests runner (ou loom ?)
|
|
||||||
|
|
||||||
## Rust libs
|
## Rust libs
|
||||||
|
|
||||||
cf. <https://blessed.rs/crates>
|
cf. https://blessed.rs/crates
|
||||||
|
|
||||||
nombres aléatoires avec seed : <https://richard.dallaway.com/posts/2021-01-04-repeat-resume/>
|
nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-repeat-resume/
|
||||||
|
|
||||||
- cli : <https://lib.rs/crates/pico-args> ( ou clap )
|
- cli : https://lib.rs/crates/pico-args ( ou clap )
|
||||||
- reseau async : tokio
|
- reseau async : tokio
|
||||||
- web serveur : axum (uses tokio)
|
- web serveur : axum (uses tokio)
|
||||||
- <https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling>
|
- https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling
|
||||||
- db : sqlx
|
- db : sqlx
|
||||||
|
|
||||||
|
|
||||||
- eyre, color-eyre (Results)
|
- eyre, color-eyre (Results)
|
||||||
- tracing (logging)
|
- tracing (logging)
|
||||||
- rayon ( sync <-> parallel )
|
- rayon ( sync <-> parallel )
|
||||||
|
|
||||||
- front : yew + tauri
|
- front : yew + tauri
|
||||||
|
|
||||||
- egui
|
- egui
|
||||||
|
|
||||||
- <https://docs.rs/board-game/latest/board_game/>
|
- https://docs.rs/board-game/latest/board_game/
|
||||||
|
|
||||||
## network games
|
|
||||||
|
|
||||||
- <https://www.mattkeeter.com/projects/pont/>
|
|
||||||
- <https://github.com/jackadamson/onitama> (wasm, rooms)
|
|
||||||
- <https://github.com/UkoeHB/renet2>
|
|
||||||
- <https://github.com/UkoeHB/bevy_simplenet>
|
|
||||||
|
|
||||||
## Others
|
## Others
|
||||||
|
- plugins avec https://github.com/extism/extism
|
||||||
- plugins avec <https://github.com/extism/extism>
|
|
||||||
|
|
||||||
## Backgammon existing projects
|
## Backgammon existing projects
|
||||||
|
|
||||||
- go : <https://bgammon.org/blog/20240101-hello-world/>
|
* go : https://bgammon.org/blog/20240101-hello-world/
|
||||||
- protocole de communication : <https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md>
|
- protocole de communication : https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md
|
||||||
- ocaml : <https://github.com/jacobhilton/backgammon?tab=readme-ov-file>
|
* ocaml : https://github.com/jacobhilton/backgammon?tab=readme-ov-file
|
||||||
cli example : <https://www.jacobh.co.uk/backgammon/>
|
cli example : https://www.jacobh.co.uk/backgammon/
|
||||||
- lib rust backgammon
|
* lib rust backgammon
|
||||||
- <https://github.com/carlostrub/backgammon>
|
- https://github.com/carlostrub/backgammon
|
||||||
- <https://github.com/marktani/backgammon>
|
- https://github.com/marktani/backgammon
|
||||||
- network webtarot
|
* network webtarot
|
||||||
- front ?
|
* front ?
|
||||||
|
|
||||||
|
|
||||||
## cli examples
|
## cli examples
|
||||||
|
|
||||||
|
|
@ -81,35 +73,31 @@ Move 11: player O rolls a 6-2.
|
||||||
Player O estimates that they have a 90.6111% chance of winning.
|
Player O estimates that they have a 90.6111% chance of winning.
|
||||||
|
|
||||||
Os borne off: none
|
Os borne off: none
|
||||||
24 23 22 21 20 19 18 17 16 15 14 13
|
24 23 22 21 20 19 18 17 16 15 14 13
|
||||||
|
-------------------------------------------------------------------
|
||||||
---
|
| v v v v v v | | v v v v v v |
|
||||||
|
| | | |
|
||||||
| v v v v v v | | v v v v v v |
|
| X O O O | | O O O |
|
||||||
| | | |
|
| X O O O | | O O |
|
||||||
| X O O O | | O O O |
|
| O | | |
|
||||||
| X O O O | | O O |
|
| | X | |
|
||||||
| O | | |
|
| | | |
|
||||||
| | X | |
|
| | | |
|
||||||
| | | |
|
| | | |
|
||||||
| | | |
|
| | | |
|
||||||
| | | |
|
|------------------------------| |------------------------------|
|
||||||
| | | |
|
| | | |
|
||||||
|------------------------------| |------------------------------|
|
| | | |
|
||||||
| | | |
|
| | | |
|
||||||
| | | |
|
| | | |
|
||||||
| | | |
|
| X | | |
|
||||||
| | | |
|
| X X | | X |
|
||||||
| X | | |
|
| X X X | | X O |
|
||||||
| X X | | X |
|
| X X X | | X O O |
|
||||||
| X X X | | X O |
|
| | | |
|
||||||
| X X X | | X O O |
|
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
|
||||||
| | | |
|
-------------------------------------------------------------------
|
||||||
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
|
1 2 3 4 5 6 7 8 9 10 11 12
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
1 2 3 4 5 6 7 8 9 10 11 12
|
|
||||||
Xs borne off: none
|
Xs borne off: none
|
||||||
|
|
||||||
Move 12: player X rolls a 6-3.
|
Move 12: player X rolls a 6-3.
|
||||||
|
|
@ -119,12 +107,13 @@ Your move (? for help): ?
|
||||||
Enter the start and end positions, separated by a forward slash (or any non-numeric character), of each counter you want to move.
|
Enter the start and end positions, separated by a forward slash (or any non-numeric character), of each counter you want to move.
|
||||||
Each position should be number from 1 to 24, "bar" or "off".
|
Each position should be number from 1 to 24, "bar" or "off".
|
||||||
Unlike in standard notation, you should enter each counter movement individually. For example:
|
Unlike in standard notation, you should enter each counter movement individually. For example:
|
||||||
24/18 18/13
|
24/18 18/13
|
||||||
bar/3 13/10 13/10 8/5
|
bar/3 13/10 13/10 8/5
|
||||||
2/off 1/off
|
2/off 1/off
|
||||||
You can also enter these commands:
|
You can also enter these commands:
|
||||||
p - show the previous move
|
p - show the previous move
|
||||||
n - show the next move
|
n - show the next move
|
||||||
<enter> - toggle between showing the current and last moves
|
<enter> - toggle between showing the current and last moves
|
||||||
help - show this help text
|
help - show this help text
|
||||||
quit - abandon game
|
quit - abandon game
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,172 +0,0 @@
|
||||||
@startuml
|
|
||||||
|
|
||||||
class "CheckerMove" {
|
|
||||||
- from: Field
|
|
||||||
- to: Field
|
|
||||||
+ to_display_string()
|
|
||||||
+ new(from: Field, to: Field)
|
|
||||||
+ mirror()
|
|
||||||
+ chain(cmove: Self)
|
|
||||||
+ get_from()
|
|
||||||
+ get_to()
|
|
||||||
+ is_exit()
|
|
||||||
+ doable_with_dice(dice: usize)
|
|
||||||
}
|
|
||||||
|
|
||||||
class "Board" {
|
|
||||||
- positions: [i8;24]
|
|
||||||
+ new()
|
|
||||||
+ mirror()
|
|
||||||
+ set_positions(positions: [ i8 ; 24 ])
|
|
||||||
+ count_checkers(color: Color, from: Field, to: Field)
|
|
||||||
+ to_vec()
|
|
||||||
+ to_gnupg_pos_id()
|
|
||||||
+ to_display_grid(col_size: usize)
|
|
||||||
+ set(color: & Color, field: Field, amount: i8)
|
|
||||||
+ blocked(color: & Color, field: Field)
|
|
||||||
+ passage_blocked(color: & Color, field: Field)
|
|
||||||
+ get_field_checkers(field: Field)
|
|
||||||
+ get_checkers_color(field: Field)
|
|
||||||
+ is_field_in_small_jan(field: Field)
|
|
||||||
+ get_color_fields(color: Color)
|
|
||||||
+ get_color_corner(color: & Color)
|
|
||||||
+ get_possible_moves(color: Color, dice: u8, with_excedants: bool, check_rest_corner_exit: bool, forbid_exits: bool)
|
|
||||||
+ passage_possible(color: & Color, cmove: & CheckerMove)
|
|
||||||
+ move_possible(color: & Color, cmove: & CheckerMove)
|
|
||||||
+ any_quarter_filled(color: Color)
|
|
||||||
+ is_quarter_filled(color: Color, field: Field)
|
|
||||||
+ get_quarter_filling_candidate(color: Color)
|
|
||||||
+ is_quarter_fillable(color: Color, field: Field)
|
|
||||||
- get_quarter_fields(field: Field)
|
|
||||||
+ move_checker(color: & Color, cmove: CheckerMove)
|
|
||||||
+ remove_checker(color: & Color, field: Field)
|
|
||||||
+ add_checker(color: & Color, field: Field)
|
|
||||||
}
|
|
||||||
|
|
||||||
class "MoveRules" {
|
|
||||||
+ board: Board
|
|
||||||
+ dice: Dice
|
|
||||||
+ new(color: & Color, board: & Board, dice: Dice)
|
|
||||||
+ set_board(color: & Color, board: & Board)
|
|
||||||
- get_board_from_color(color: & Color, board: & Board)
|
|
||||||
+ moves_follow_rules(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- moves_possible(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- moves_follows_dices(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- get_move_compatible_dices(cmove: & CheckerMove)
|
|
||||||
+ moves_allowed(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- check_opponent_can_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- check_must_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- check_corner_rules(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- has_checkers_outside_last_quarter()
|
|
||||||
- check_exit_rules(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
+ get_possible_moves_sequences(with_excedents: bool, ignored_rules: Vec < TricTracRule >)
|
|
||||||
+ get_scoring_quarter_filling_moves_sequences()
|
|
||||||
- get_sequence_origin_from_destination(sequence: ( CheckerMove , CheckerMove ), destination: Field)
|
|
||||||
+ get_quarter_filling_moves_sequences()
|
|
||||||
- get_possible_moves_sequences_by_dices(dice1: u8, dice2: u8, with_excedents: bool, ignore_empty: bool, ignored_rules: Vec < TricTracRule >)
|
|
||||||
- _get_direct_exit_moves(state: & GameState)
|
|
||||||
- is_move_by_puissance(moves: & ( CheckerMove , CheckerMove ))
|
|
||||||
- can_take_corner_by_effect()
|
|
||||||
}
|
|
||||||
|
|
||||||
class "DiceRoller" {
|
|
||||||
- rng: StdRng
|
|
||||||
+ new(opt_seed: Option < u64 >)
|
|
||||||
+ roll()
|
|
||||||
}
|
|
||||||
|
|
||||||
class "Dice" {
|
|
||||||
+ values: (u8,u8)
|
|
||||||
+ to_bits_string()
|
|
||||||
+ to_display_string()
|
|
||||||
+ is_double()
|
|
||||||
}
|
|
||||||
|
|
||||||
class "GameState" {
|
|
||||||
+ stage: Stage
|
|
||||||
+ turn_stage: TurnStage
|
|
||||||
+ board: Board
|
|
||||||
+ active_player_id: PlayerId
|
|
||||||
+ players: HashMap<PlayerId,Player>
|
|
||||||
+ history: Vec<GameEvent>
|
|
||||||
+ dice: Dice
|
|
||||||
+ dice_points: (u8,u8)
|
|
||||||
+ dice_moves: (CheckerMove,CheckerMove)
|
|
||||||
+ dice_jans: PossibleJans
|
|
||||||
- roll_first: bool
|
|
||||||
+ schools_enabled: bool
|
|
||||||
+ new(schools_enabled: bool)
|
|
||||||
- set_schools_enabled(schools_enabled: bool)
|
|
||||||
- get_active_player()
|
|
||||||
- get_opponent_id()
|
|
||||||
+ to_vec_float()
|
|
||||||
+ to_vec()
|
|
||||||
+ to_string_id()
|
|
||||||
+ who_plays()
|
|
||||||
+ get_white_player()
|
|
||||||
+ get_black_player()
|
|
||||||
+ player_id_by_color(color: Color)
|
|
||||||
+ player_id(player: & Player)
|
|
||||||
+ player_color_by_id(player_id: & PlayerId)
|
|
||||||
+ validate(event: & GameEvent)
|
|
||||||
+ init_player(player_name: & str)
|
|
||||||
- add_player(player_id: PlayerId, player: Player)
|
|
||||||
+ switch_active_player()
|
|
||||||
+ consume(valid_event: & GameEvent)
|
|
||||||
- new_pick_up()
|
|
||||||
- get_rollresult_jans(dice: & Dice)
|
|
||||||
+ determine_winner()
|
|
||||||
- inc_roll_count(player_id: PlayerId)
|
|
||||||
- mark_points(player_id: PlayerId, points: u8)
|
|
||||||
}
|
|
||||||
|
|
||||||
class "Player" {
|
|
||||||
+ name: String
|
|
||||||
+ color: Color
|
|
||||||
+ points: u8
|
|
||||||
+ holes: u8
|
|
||||||
+ can_bredouille: bool
|
|
||||||
+ can_big_bredouille: bool
|
|
||||||
+ dice_roll_count: u8
|
|
||||||
+ new(name: String, color: Color)
|
|
||||||
+ to_bits_string()
|
|
||||||
+ to_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
class "PointsRules" {
|
|
||||||
+ board: Board
|
|
||||||
+ dice: Dice
|
|
||||||
+ move_rules: MoveRules
|
|
||||||
+ new(color: & Color, board: & Board, dice: Dice)
|
|
||||||
+ set_dice(dice: Dice)
|
|
||||||
+ update_positions(positions: [ i8 ; 24 ])
|
|
||||||
- get_jans(board_ini: & Board, dice_rolls_count: u8)
|
|
||||||
+ get_jans_points(jans: HashMap < Jan , Vec < ( CheckerMove , CheckerMove ) > >)
|
|
||||||
+ get_points(dice_rolls_count: u8)
|
|
||||||
+ get_result_jans(dice_rolls_count: u8)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"MoveRules" <-- "Board"
|
|
||||||
"MoveRules" <-- "Dice"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"GameState" <-- "Board"
|
|
||||||
"HashMap<PlayerId,Player>" <-- "Player"
|
|
||||||
"GameState" <-- "HashMap<PlayerId,Player>"
|
|
||||||
"GameState" <-- "Dice"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"PointsRules" <-- "Board"
|
|
||||||
"PointsRules" <-- "Dice"
|
|
||||||
"PointsRules" <-- "MoveRules"
|
|
||||||
|
|
||||||
@enduml
|
|
||||||
14
justfile
14
justfile
|
|
@ -9,7 +9,7 @@ shell:
|
||||||
runcli:
|
runcli:
|
||||||
RUST_LOG=info cargo run --bin=client_cli
|
RUST_LOG=info cargo run --bin=client_cli
|
||||||
runclibots:
|
runclibots:
|
||||||
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk
|
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
|
||||||
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
|
||||||
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
|
||||||
match:
|
match:
|
||||||
|
|
@ -22,13 +22,15 @@ profile:
|
||||||
pythonlib:
|
pythonlib:
|
||||||
maturin build -m store/Cargo.toml --release
|
maturin build -m store/Cargo.toml --release
|
||||||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||||
trainbot algo:
|
trainsimple:
|
||||||
|
cargo build --release --bin=train_dqn_simple
|
||||||
|
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
|
||||||
|
trainbot:
|
||||||
#python ./store/python/trainModel.py
|
#python ./store/python/trainModel.py
|
||||||
# cargo run --bin=train_dqn # ok
|
# cargo run --bin=train_dqn # ok
|
||||||
# ./bot/scripts/trainValid.sh
|
./bot/scripts/trainValid.sh
|
||||||
./bot/scripts/train.sh {{algo}}
|
plottrainbot:
|
||||||
plottrainbot algo:
|
./bot/scripts/trainValid.sh plot
|
||||||
./bot/scripts/train.sh plot {{algo}}
|
|
||||||
debugtrainbot:
|
debugtrainbot:
|
||||||
cargo build --bin=train_dqn_burn
|
cargo build --bin=train_dqn_burn
|
||||||
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn
|
||||||
|
|
|
||||||
14
server/Cargo.toml
Normal file
14
server/Cargo.toml
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "trictrac-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
store = { path = "../store" }
|
||||||
|
env_logger = "0.10.0"
|
||||||
|
log = "0.4.20"
|
||||||
|
pico-args = "0.5.0"
|
||||||
|
renet = "0.0.13"
|
||||||
|
bincode = "1.3.3"
|
||||||
147
server/src/main.rs
Normal file
147
server/src/main.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
use log::{info, trace, warn};
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::{Duration, Instant, SystemTime};
|
||||||
|
|
||||||
|
use renet::{
|
||||||
|
transport::{
|
||||||
|
NetcodeServerTransport, ServerAuthentication, ServerConfig, NETCODE_USER_DATA_BYTES,
|
||||||
|
},
|
||||||
|
ConnectionConfig, RenetServer, ServerEvent,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Only clients that can provide the same PROTOCOL_ID that the server is using will be able to connect.
|
||||||
|
// This can be used to make sure players use the most recent version of the client for instance.
|
||||||
|
pub const PROTOCOL_ID: u64 = 2878;
|
||||||
|
|
||||||
|
/// Utility function for extracting a players name from renet user data
|
||||||
|
fn name_from_user_data(user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> String {
|
||||||
|
let mut buffer = [0u8; 8];
|
||||||
|
buffer.copy_from_slice(&user_data[0..8]);
|
||||||
|
let mut len = u64::from_le_bytes(buffer) as usize;
|
||||||
|
len = len.min(NETCODE_USER_DATA_BYTES - 8);
|
||||||
|
let data = user_data[8..len + 8].to_vec();
|
||||||
|
String::from_utf8(data).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
env_logger::init();
|
||||||
|
|
||||||
|
let mut server = RenetServer::new(ConnectionConfig::default());
|
||||||
|
|
||||||
|
// Setup transport layer
|
||||||
|
const SERVER_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5000);
|
||||||
|
let socket: UdpSocket = UdpSocket::bind(SERVER_ADDR).unwrap();
|
||||||
|
let server_config = ServerConfig {
|
||||||
|
max_clients: 2,
|
||||||
|
protocol_id: PROTOCOL_ID,
|
||||||
|
public_addr: SERVER_ADDR,
|
||||||
|
authentication: ServerAuthentication::Unsecure,
|
||||||
|
};
|
||||||
|
let current_time = SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap();
|
||||||
|
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();
|
||||||
|
|
||||||
|
trace!("❂ TricTrac server listening on {}", SERVER_ADDR);
|
||||||
|
|
||||||
|
let mut game_state = store::GameState::default();
|
||||||
|
let mut last_updated = Instant::now();
|
||||||
|
loop {
|
||||||
|
// Update server time
|
||||||
|
let now = Instant::now();
|
||||||
|
let delta_time = now - last_updated;
|
||||||
|
server.update(delta_time);
|
||||||
|
transport.update(delta_time, &mut server).unwrap();
|
||||||
|
last_updated = now;
|
||||||
|
|
||||||
|
// Receive connection events from clients
|
||||||
|
while let Some(event) = server.get_event() {
|
||||||
|
match event {
|
||||||
|
ServerEvent::ClientConnected { client_id } => {
|
||||||
|
let user_data = transport.user_data(client_id).unwrap();
|
||||||
|
|
||||||
|
// Tell the recently joined player about the other player
|
||||||
|
for (player_id, player) in game_state.players.iter() {
|
||||||
|
let event = store::GameEvent::PlayerJoined {
|
||||||
|
player_id: *player_id,
|
||||||
|
name: player.name.clone(),
|
||||||
|
};
|
||||||
|
server.send_message(client_id, 0, bincode::serialize(&event).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the new player to the game
|
||||||
|
let event = store::GameEvent::PlayerJoined {
|
||||||
|
player_id: client_id,
|
||||||
|
name: name_from_user_data(&user_data),
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
|
||||||
|
// Tell all players that a new player has joined
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
|
||||||
|
info!("🎉 Client {} connected.", client_id);
|
||||||
|
// In TicTacTussle the game can begin once two players has joined
|
||||||
|
if game_state.players.len() == 2 {
|
||||||
|
let event = store::GameEvent::BeginGame {
|
||||||
|
goes_first: client_id,
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
trace!("The game gas begun");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ServerEvent::ClientDisconnected {
|
||||||
|
client_id,
|
||||||
|
reason: _,
|
||||||
|
} => {
|
||||||
|
// First consume a disconnect event
|
||||||
|
let event = store::GameEvent::PlayerDisconnected {
|
||||||
|
player_id: client_id,
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
info!("Client {} disconnected", client_id);
|
||||||
|
|
||||||
|
// Then end the game, since tic tac toe can't go on with a single player
|
||||||
|
let event = store::GameEvent::EndGame {
|
||||||
|
reason: store::EndGameReason::PlayerLeft {
|
||||||
|
player_id: client_id,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
game_state.consume(&event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
|
||||||
|
// NOTE: Since we don't authenticate users we can't do any reconnection attempts.
|
||||||
|
// We simply have no way to know if the next user is the same as the one that disconnected.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive GameEvents from clients. Broadcast valid events.
|
||||||
|
for client_id in server.clients_id().into_iter() {
|
||||||
|
while let Some(message) = server.receive_message(client_id, 0) {
|
||||||
|
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
|
||||||
|
if game_state.validate(&event) {
|
||||||
|
game_state.consume(&event);
|
||||||
|
trace!("Player {} sent:\n\t{:#?}", client_id, event);
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
|
||||||
|
// Determine if a player has won the game
|
||||||
|
if let Some(winner) = game_state.determine_winner() {
|
||||||
|
let event = store::GameEvent::EndGame {
|
||||||
|
reason: store::EndGameReason::PlayerWon { winner },
|
||||||
|
};
|
||||||
|
server.broadcast_message(0, bincode::serialize(&event).unwrap());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Player {} sent invalid event:\n\t{:#?}", client_id, event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
transport.send_packets(&mut server);
|
||||||
|
thread::sleep(Duration::from_millis(50));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,7 +8,7 @@ use std::fmt;
|
||||||
pub type Field = usize;
|
pub type Field = usize;
|
||||||
pub type FieldWithCount = (Field, i8);
|
pub type FieldWithCount = (Field, i8);
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Deserialize)]
|
||||||
pub struct CheckerMove {
|
pub struct CheckerMove {
|
||||||
from: Field,
|
from: Field,
|
||||||
to: Field,
|
to: Field,
|
||||||
|
|
@ -94,7 +94,7 @@ impl CheckerMove {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents the Tric Trac board
|
/// Represents the Tric Trac board
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct Board {
|
pub struct Board {
|
||||||
positions: [i8; 24],
|
positions: [i8; 24],
|
||||||
}
|
}
|
||||||
|
|
@ -158,42 +158,6 @@ impl Board {
|
||||||
.unsigned_abs()
|
.unsigned_abs()
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the number of the last checker in a field
|
|
||||||
pub fn get_field_checker(&self, color: &Color, field: Field) -> u8 {
|
|
||||||
assert_eq!(color, &Color::White); // sinon ajouter la gestion des noirs avec mirror
|
|
||||||
let mut total_count: u8 = 0;
|
|
||||||
for (i, checker_count) in self.positions.iter().enumerate() {
|
|
||||||
// count white checkers (checker_count > 0)
|
|
||||||
if *checker_count > 0 {
|
|
||||||
total_count += *checker_count as u8;
|
|
||||||
if field == i + 1 {
|
|
||||||
return total_count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the field of the nth checker
|
|
||||||
pub fn get_checker_field(&self, color: &Color, checker_pos: u8) -> Option<Field> {
|
|
||||||
assert_eq!(color, &Color::White); // sinon ajouter la gestion des noirs avec mirror
|
|
||||||
if checker_pos == 0 {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let mut total_count: u8 = 0;
|
|
||||||
for (i, checker_count) in self.positions.iter().enumerate() {
|
|
||||||
// count white checkers (checker_count > 0)
|
|
||||||
if *checker_count > 0 {
|
|
||||||
total_count += *checker_count as u8;
|
|
||||||
}
|
|
||||||
// return the current field if it contains the checker
|
|
||||||
if checker_pos <= total_count {
|
|
||||||
return Some(i + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_vec(&self) -> Vec<i8> {
|
pub fn to_vec(&self) -> Vec<i8> {
|
||||||
self.positions.to_vec()
|
self.positions.to_vec()
|
||||||
}
|
}
|
||||||
|
|
@ -271,7 +235,7 @@ impl Board {
|
||||||
.map(|cells| {
|
.map(|cells| {
|
||||||
cells
|
cells
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|cell| format!("{cell:>5}"))
|
.map(|cell| format!("{:>5}", cell))
|
||||||
.collect::<Vec<String>>()
|
.collect::<Vec<String>>()
|
||||||
.join("")
|
.join("")
|
||||||
})
|
})
|
||||||
|
|
@ -282,7 +246,7 @@ impl Board {
|
||||||
.map(|cells| {
|
.map(|cells| {
|
||||||
cells
|
cells
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|cell| format!("{cell:>5}"))
|
.map(|cell| format!("{:>5}", cell))
|
||||||
.collect::<Vec<String>>()
|
.collect::<Vec<String>>()
|
||||||
.join("")
|
.join("")
|
||||||
})
|
})
|
||||||
|
|
@ -639,55 +603,6 @@ impl Board {
|
||||||
self.positions[field - 1] += unit;
|
self.positions[field - 1] += unit;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_gnupg_pos_id(bits: &str) -> Result<Board, String> {
|
|
||||||
let mut positions = [0i8; 24];
|
|
||||||
let mut bit_idx = 0;
|
|
||||||
let bit_chars: Vec<char> = bits.chars().collect();
|
|
||||||
|
|
||||||
// White checkers (points 1 to 24)
|
|
||||||
for i in 0..24 {
|
|
||||||
if bit_idx >= bit_chars.len() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let mut count = 0;
|
|
||||||
while bit_idx < bit_chars.len() && bit_chars[bit_idx] == '1' {
|
|
||||||
count += 1;
|
|
||||||
bit_idx += 1;
|
|
||||||
}
|
|
||||||
positions[i] = count;
|
|
||||||
if bit_idx < bit_chars.len() && bit_chars[bit_idx] == '0' {
|
|
||||||
bit_idx += 1; // Consume the '0' separator
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Black checkers (points 24 down to 1)
|
|
||||||
for i in (0..24).rev() {
|
|
||||||
if bit_idx >= bit_chars.len() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let mut count = 0;
|
|
||||||
while bit_idx < bit_chars.len() && bit_chars[bit_idx] == '1' {
|
|
||||||
count += 1;
|
|
||||||
bit_idx += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if positions[i] == 0 {
|
|
||||||
positions[i] = -count;
|
|
||||||
} else if count > 0 {
|
|
||||||
return Err(format!(
|
|
||||||
"Invalid board: checkers of both colors on point {}",
|
|
||||||
i + 1
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if bit_idx < bit_chars.len() && bit_chars[bit_idx] == '0' {
|
|
||||||
bit_idx += 1; // Consume the '0' separator
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Board { positions })
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unit Tests
|
// Unit Tests
|
||||||
|
|
@ -806,32 +721,4 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White));
|
assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_checker_field() {
|
|
||||||
let mut board = Board::new();
|
|
||||||
board.set_positions(
|
|
||||||
&Color::White,
|
|
||||||
[
|
|
||||||
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
assert_eq!(None, board.get_checker_field(&Color::White, 0));
|
|
||||||
assert_eq!(Some(3), board.get_checker_field(&Color::White, 5));
|
|
||||||
assert_eq!(Some(3), board.get_checker_field(&Color::White, 6));
|
|
||||||
assert_eq!(None, board.get_checker_field(&Color::White, 14));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_field_checker() {
|
|
||||||
let mut board = Board::new();
|
|
||||||
board.set_positions(
|
|
||||||
&Color::White,
|
|
||||||
[
|
|
||||||
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
assert_eq!(4, board.get_field_checker(&Color::White, 2));
|
|
||||||
assert_eq!(6, board.get_field_checker(&Color::White, 3));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ impl DiceRoller {
|
||||||
/// Represents the two dice
|
/// Represents the two dice
|
||||||
///
|
///
|
||||||
/// Trictrac is always played with two dice.
|
/// Trictrac is always played with two dice.
|
||||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize, Default)]
|
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize, Default)]
|
||||||
pub struct Dice {
|
pub struct Dice {
|
||||||
/// The two dice values
|
/// The two dice values
|
||||||
pub values: (u8, u8),
|
pub values: (u8, u8),
|
||||||
|
|
@ -55,17 +55,6 @@ impl Dice {
|
||||||
format!("{:0>3b}{:0>3b}", self.values.0, self.values.1)
|
format!("{:0>3b}{:0>3b}", self.values.0, self.values.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_bits_string(bits: &str) -> Result<Self, String> {
|
|
||||||
if bits.len() != 6 {
|
|
||||||
return Err("Invalid bit string length for dice".to_string());
|
|
||||||
}
|
|
||||||
let d1_str = &bits[0..3];
|
|
||||||
let d2_str = &bits[3..6];
|
|
||||||
let d1 = u8::from_str_radix(d1_str, 2).map_err(|e| e.to_string())?;
|
|
||||||
let d2 = u8::from_str_radix(d2_str, 2).map_err(|e| e.to_string())?;
|
|
||||||
Ok(Dice { values: (d1, d2) })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_display_string(self) -> String {
|
pub fn to_display_string(self) -> String {
|
||||||
format!("{} & {}", self.values.0, self.values.1)
|
format!("{} & {}", self.values.0, self.values.1)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,18 +4,17 @@ use crate::dice::Dice;
|
||||||
use crate::game_rules_moves::MoveRules;
|
use crate::game_rules_moves::MoveRules;
|
||||||
use crate::game_rules_points::{PointsRules, PossibleJans};
|
use crate::game_rules_points::{PointsRules, PossibleJans};
|
||||||
use crate::player::{Color, Player, PlayerId};
|
use crate::player::{Color, Player, PlayerId};
|
||||||
use log::{debug, error};
|
use log::{debug, error, info};
|
||||||
|
|
||||||
// use itertools::Itertools;
|
// use itertools::Itertools;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::{Hash, Hasher};
|
|
||||||
use std::{fmt, str};
|
use std::{fmt, str};
|
||||||
|
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
|
|
||||||
/// The different stages a game can be in. (not to be confused with the entire "GameState")
|
/// The different stages a game can be in. (not to be confused with the entire "GameState")
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub enum Stage {
|
pub enum Stage {
|
||||||
PreGame,
|
PreGame,
|
||||||
InGame,
|
InGame,
|
||||||
|
|
@ -23,7 +22,7 @@ pub enum Stage {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The different stages a game turn can be in.
|
/// The different stages a game turn can be in.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub enum TurnStage {
|
pub enum TurnStage {
|
||||||
RollDice,
|
RollDice,
|
||||||
RollWaiting,
|
RollWaiting,
|
||||||
|
|
@ -61,7 +60,7 @@ impl From<TurnStage> for u8 {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents a TricTrac game
|
/// Represents a TricTrac game
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct GameState {
|
pub struct GameState {
|
||||||
pub stage: Stage,
|
pub stage: Stage,
|
||||||
pub turn_stage: TurnStage,
|
pub turn_stage: TurnStage,
|
||||||
|
|
@ -115,11 +114,6 @@ impl Default for GameState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl Hash for GameState {
|
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
|
||||||
self.to_string_id().hash(state);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GameState {
|
impl GameState {
|
||||||
/// Create a new default game
|
/// Create a new default game
|
||||||
|
|
@ -129,15 +123,6 @@ impl GameState {
|
||||||
gs
|
gs
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_with_players(p1_name: &str, p2_name: &str) -> Self {
|
|
||||||
let mut game = Self::default();
|
|
||||||
if let Some(p1) = game.init_player(p1_name) {
|
|
||||||
game.init_player(p2_name);
|
|
||||||
game.consume(&GameEvent::BeginGame { goes_first: p1 });
|
|
||||||
}
|
|
||||||
game
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_schools_enabled(&mut self, schools_enabled: bool) {
|
fn set_schools_enabled(&mut self, schools_enabled: bool) {
|
||||||
self.schools_enabled = schools_enabled;
|
self.schools_enabled = schools_enabled;
|
||||||
}
|
}
|
||||||
|
|
@ -259,7 +244,7 @@ impl GameState {
|
||||||
pos_bits.push_str(&white_bits);
|
pos_bits.push_str(&white_bits);
|
||||||
pos_bits.push_str(&black_bits);
|
pos_bits.push_str(&black_bits);
|
||||||
|
|
||||||
pos_bits = format!("{pos_bits:0<108}");
|
pos_bits = format!("{:0>108}", pos_bits);
|
||||||
// println!("{}", pos_bits);
|
// println!("{}", pos_bits);
|
||||||
let pos_u8 = pos_bits
|
let pos_u8 = pos_bits
|
||||||
.as_bytes()
|
.as_bytes()
|
||||||
|
|
@ -270,81 +255,6 @@ impl GameState {
|
||||||
general_purpose::STANDARD.encode(pos_u8)
|
general_purpose::STANDARD.encode(pos_u8)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_string_id(id: &str) -> Result<Self, String> {
|
|
||||||
let bytes = general_purpose::STANDARD
|
|
||||||
.decode(id)
|
|
||||||
.map_err(|e| e.to_string())?;
|
|
||||||
|
|
||||||
let bits_str: String = bytes.iter().map(|byte| format!("{:06b}", byte)).collect();
|
|
||||||
|
|
||||||
// The original string was padded to 108 bits.
|
|
||||||
let bits = if bits_str.len() >= 108 {
|
|
||||||
&bits_str[..108]
|
|
||||||
} else {
|
|
||||||
return Err("Invalid decoded string length".to_string());
|
|
||||||
};
|
|
||||||
|
|
||||||
let board_bits = &bits[0..77];
|
|
||||||
let board = Board::from_gnupg_pos_id(board_bits)?;
|
|
||||||
|
|
||||||
let active_player_bit = bits.chars().nth(77).unwrap();
|
|
||||||
let active_player_color = if active_player_bit == '1' {
|
|
||||||
Color::Black
|
|
||||||
} else {
|
|
||||||
Color::White
|
|
||||||
};
|
|
||||||
|
|
||||||
let turn_stage_bits = &bits[78..81];
|
|
||||||
let turn_stage = match turn_stage_bits {
|
|
||||||
"000" => TurnStage::RollWaiting,
|
|
||||||
"001" => TurnStage::RollDice,
|
|
||||||
"010" => TurnStage::MarkPoints,
|
|
||||||
"011" => TurnStage::HoldOrGoChoice,
|
|
||||||
"100" => TurnStage::Move,
|
|
||||||
"101" => TurnStage::MarkAdvPoints,
|
|
||||||
_ => return Err(format!("Invalid bits for turn stage : {turn_stage_bits}")),
|
|
||||||
};
|
|
||||||
|
|
||||||
let dice_bits = &bits[81..87];
|
|
||||||
let dice = Dice::from_bits_string(dice_bits).map_err(|e| e.to_string())?;
|
|
||||||
|
|
||||||
let white_player_bits = &bits[87..97];
|
|
||||||
let black_player_bits = &bits[97..107];
|
|
||||||
|
|
||||||
let white_player =
|
|
||||||
Player::from_bits_string(white_player_bits, "Player 1".to_string(), Color::White)
|
|
||||||
.map_err(|e| e.to_string())?;
|
|
||||||
let black_player =
|
|
||||||
Player::from_bits_string(black_player_bits, "Player 2".to_string(), Color::Black)
|
|
||||||
.map_err(|e| e.to_string())?;
|
|
||||||
|
|
||||||
let mut players = HashMap::new();
|
|
||||||
players.insert(1, white_player);
|
|
||||||
players.insert(2, black_player);
|
|
||||||
|
|
||||||
let active_player_id = if active_player_color == Color::White {
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
2
|
|
||||||
};
|
|
||||||
|
|
||||||
// Some fields are not in the ID, so we use defaults.
|
|
||||||
Ok(GameState {
|
|
||||||
stage: Stage::InGame, // Assume InGame from ID
|
|
||||||
turn_stage,
|
|
||||||
board,
|
|
||||||
active_player_id,
|
|
||||||
players,
|
|
||||||
history: Vec::new(),
|
|
||||||
dice,
|
|
||||||
dice_points: (0, 0),
|
|
||||||
dice_moves: (CheckerMove::default(), CheckerMove::default()),
|
|
||||||
dice_jans: PossibleJans::default(),
|
|
||||||
roll_first: false, // Assume not first roll
|
|
||||||
schools_enabled: false, // Assume disabled
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn who_plays(&self) -> Option<&Player> {
|
pub fn who_plays(&self) -> Option<&Player> {
|
||||||
self.get_active_player()
|
self.get_active_player()
|
||||||
}
|
}
|
||||||
|
|
@ -428,7 +338,7 @@ impl GameState {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Roll { player_id } => {
|
Roll { player_id } | RollResult { player_id, dice: _ } => {
|
||||||
// Check player exists
|
// Check player exists
|
||||||
if !self.players.contains_key(player_id) {
|
if !self.players.contains_key(player_id) {
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -437,26 +347,6 @@ impl GameState {
|
||||||
if self.active_player_id != *player_id {
|
if self.active_player_id != *player_id {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Check the turn stage
|
|
||||||
if self.turn_stage != TurnStage::RollDice {
|
|
||||||
error!("bad stage {:?}", self.turn_stage);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RollResult { player_id, dice: _ } => {
|
|
||||||
// Check player exists
|
|
||||||
if !self.players.contains_key(player_id) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Check player is currently the one making their move
|
|
||||||
if self.active_player_id != *player_id {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Check the turn stage
|
|
||||||
if self.turn_stage != TurnStage::RollWaiting {
|
|
||||||
error!("bad stage {:?}", self.turn_stage);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Mark {
|
Mark {
|
||||||
player_id,
|
player_id,
|
||||||
|
|
@ -737,15 +627,13 @@ impl GameState {
|
||||||
|
|
||||||
fn inc_roll_count(&mut self, player_id: PlayerId) {
|
fn inc_roll_count(&mut self, player_id: PlayerId) {
|
||||||
self.players.get_mut(&player_id).map(|p| {
|
self.players.get_mut(&player_id).map(|p| {
|
||||||
p.dice_roll_count = p.dice_roll_count.saturating_add(1);
|
if p.dice_roll_count < u8::MAX {
|
||||||
|
p.dice_roll_count += 1;
|
||||||
|
}
|
||||||
p
|
p
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mark_points_for_bot_training(&mut self, player_id: PlayerId, points: u8) -> bool {
|
|
||||||
self.mark_points(player_id, points)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
fn mark_points(&mut self, player_id: PlayerId, points: u8) -> bool {
|
||||||
// Update player points and holes
|
// Update player points and holes
|
||||||
let mut new_hole = false;
|
let mut new_hole = false;
|
||||||
|
|
@ -801,14 +689,14 @@ impl GameState {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The reasons why a game could end
|
/// The reasons why a game could end
|
||||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
|
||||||
pub enum EndGameReason {
|
pub enum EndGameReason {
|
||||||
PlayerLeft { player_id: PlayerId },
|
PlayerLeft { player_id: PlayerId },
|
||||||
PlayerWon { winner: PlayerId },
|
PlayerWon { winner: PlayerId },
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An event that progresses the GameState forward
|
/// An event that progresses the GameState forward
|
||||||
#[derive(Debug, Clone, Serialize, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Serialize, PartialEq, Deserialize)]
|
||||||
pub enum GameEvent {
|
pub enum GameEvent {
|
||||||
BeginGame {
|
BeginGame {
|
||||||
goes_first: PlayerId,
|
goes_first: PlayerId,
|
||||||
|
|
@ -935,16 +823,7 @@ mod tests {
|
||||||
let state = init_test_gamestate(TurnStage::RollDice);
|
let state = init_test_gamestate(TurnStage::RollDice);
|
||||||
let string_id = state.to_string_id();
|
let string_id = state.to_string_id();
|
||||||
// println!("string_id : {}", string_id);
|
// println!("string_id : {}", string_id);
|
||||||
assert_eq!(string_id, "Pz84AAAABz8/AAAAAAgAASAG");
|
assert_eq!(string_id, "Hz88AAAAAz8/IAAAAAQAADAD");
|
||||||
let new_state = GameState::from_string_id(&string_id).unwrap();
|
|
||||||
assert_eq!(state.board, new_state.board);
|
|
||||||
assert_eq!(state.active_player_id, new_state.active_player_id);
|
|
||||||
assert_eq!(state.turn_stage, new_state.turn_stage);
|
|
||||||
assert_eq!(state.dice, new_state.dice);
|
|
||||||
assert_eq!(
|
|
||||||
state.get_white_player().unwrap().points,
|
|
||||||
new_state.get_white_player().unwrap().points
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -603,7 +603,7 @@ mod tests {
|
||||||
);
|
);
|
||||||
let points_rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 4) });
|
let points_rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 4) });
|
||||||
let jans = points_rules.get_result_jans(8);
|
let jans = points_rules.get_result_jans(8);
|
||||||
assert!(!jans.0.is_empty());
|
assert!(jans.0.len() > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -628,7 +628,7 @@ mod tests {
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
let rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
|
let mut rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
|
||||||
assert_eq!(12, rules.get_points(5).0);
|
assert_eq!(12, rules.get_points(5).0);
|
||||||
|
|
||||||
// Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4
|
// Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use std::fmt;
|
||||||
// This just makes it easier to dissern between a player id and any ol' u64
|
// This just makes it easier to dissern between a player id and any ol' u64
|
||||||
pub type PlayerId = u64;
|
pub type PlayerId = u64;
|
||||||
|
|
||||||
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Copy, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum Color {
|
pub enum Color {
|
||||||
White,
|
White,
|
||||||
Black,
|
Black,
|
||||||
|
|
@ -20,7 +20,7 @@ impl Color {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Struct for storing player related data.
|
/// Struct for storing player related data.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct Player {
|
pub struct Player {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub color: Color,
|
pub color: Color,
|
||||||
|
|
@ -53,26 +53,6 @@ impl Player {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_bits_string(bits: &str, name: String, color: Color) -> Result<Self, String> {
|
|
||||||
if bits.len() != 10 {
|
|
||||||
return Err("Invalid bit string length for player".to_string());
|
|
||||||
}
|
|
||||||
let points = u8::from_str_radix(&bits[0..4], 2).map_err(|e| e.to_string())?;
|
|
||||||
let holes = u8::from_str_radix(&bits[4..8], 2).map_err(|e| e.to_string())?;
|
|
||||||
let can_bredouille = bits.chars().nth(8).unwrap() == '1';
|
|
||||||
let can_big_bredouille = bits.chars().nth(9).unwrap() == '1';
|
|
||||||
|
|
||||||
Ok(Player {
|
|
||||||
name,
|
|
||||||
color,
|
|
||||||
points,
|
|
||||||
holes,
|
|
||||||
can_bredouille,
|
|
||||||
can_big_bredouille,
|
|
||||||
dice_roll_count: 0, // This info is not in the string id
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_vec(&self) -> Vec<u8> {
|
pub fn to_vec(&self) -> Vec<u8> {
|
||||||
vec![
|
vec![
|
||||||
self.points,
|
self.points,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue