Compare commits

...

28 commits

Author SHA1 Message Date
Henri Bourcereau 2ef1f7ee50 Merge branch 'release/v0.1.0' 2025-08-30 13:29:07 +02:00
Henri Bourcereau 73cc6ee67e doc 2025-08-30 13:28:00 +02:00
Henri Bourcereau f2a89f60bc feat: Karel Peeters board game implementation 2025-08-28 19:20:06 +02:00
Henri Bourcereau 866ba611a6 fix: train.sh parsing 2025-08-26 17:12:19 +02:00
Henri Bourcereau e1b8d7e679 feat: bot training configuration file 2025-08-22 09:24:01 +02:00
Henri Bourcereau 8f41cc1412 feat: bot all algos 2025-08-21 17:39:45 +02:00
Henri Bourcereau 0c58490f87 feat: bot sac & ppo save & load 2025-08-21 14:35:25 +02:00
Henri Bourcereau afeb3561e0 refacto: bot one exec 2025-08-21 11:30:25 +02:00
Henri Bourcereau 18e85744d6 refacto: burnrl 2025-08-20 14:08:04 +02:00
Henri Bourcereau 97167ff389 feat: wip bot burn sac 2025-08-19 21:40:02 +02:00
Henri Bourcereau 088124fad1 feat: wip bot burn ppo 2025-08-19 17:46:22 +02:00
Henri Bourcereau fcd50bc0f2 refacto: bot directories 2025-08-19 16:27:37 +02:00
Henri Bourcereau e66921fcce refact models paths 2025-08-18 17:44:01 +02:00
Henri Bourcereau 2499c3377f refact script train bot 2025-08-17 17:42:59 +02:00
Henri Bourcereau a7aa087b18 fix: train bad move 2025-08-17 16:14:06 +02:00
Henri Bourcereau 1dc29d0ff0 chore:refacto clippy 2025-08-17 15:59:53 +02:00
Henri Bourcereau db9560dfac fix dqn burn small 2025-08-16 21:47:12 +02:00
Henri Bourcereau 47a8502b63 fix validations & client_cli 2025-08-16 17:59:00 +02:00
Henri Bourcereau c1e99a5f35 wip (tests fails) 2025-08-16 16:39:25 +02:00
Henri Bourcereau 56d155b911 wip debug 2025-08-16 11:13:31 +02:00
Henri Bourcereau d313cb6151 burnrl_big like before 2025-08-15 21:08:23 +02:00
Henri Bourcereau 93624c425d wip burnrl_big 2025-08-15 18:39:09 +02:00
Henri Bourcereau 86a67ae66a fix: train bot opponent rewards 2025-08-13 18:08:35 +02:00
Henri Bourcereau ac14341cf9 doc: schema store 2025-08-13 15:29:04 +02:00
Henri Bourcereau cfc19e6064 compile ok but diverge 2025-08-12 21:56:52 +02:00
Henri Bourcereau ec6ae26d38 wip reduction TrictracAction 2025-08-12 17:56:41 +02:00
Henri Bourcereau 5370eb4307 Merge branch 'feature/botTrainValidMoves' into develop 2025-08-11 18:56:17 +02:00
Henri Bourcereau 1fb04209f5 doc params train bot 2025-08-10 17:46:09 +02:00
51 changed files with 3610 additions and 734 deletions

Cargo.lock generated

@ -2,6 +2,15 @@
# It is not intended for manual editing.
version = 4
[[package]]
name = "addr2line"
version = "0.24.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1"
dependencies = [
"gimli",
]
[[package]]
name = "adler2"
version = "2.0.1"
@ -158,6 +167,24 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "arimaa_engine_step"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1c6726d7896a539a62e157b05fa4b7308ffb7872f2b4a2a592d5adb19837861"
dependencies = [
"anyhow",
"itertools 0.10.5",
"log",
"regex",
]
[[package]]
name = "arrayvec"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
[[package]]
name = "arrayvec"
version = "0.7.6"
@ -204,7 +231,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f3efb2ca85bc610acfa917b5aaa36f3fcbebed5b3182d7f877b02531c4b80c8"
dependencies = [
"anyhow",
"arrayvec",
"arrayvec 0.7.6",
"log",
"nom",
"num-rational",
@ -217,7 +244,22 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98922d6a4cfbcb08820c69d8eeccc05bb1f29bfa06b4f5b1dbfe9a868bd7608e"
dependencies = [
"arrayvec",
"arrayvec 0.7.6",
]
[[package]]
name = "backtrace"
version = "0.3.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002"
dependencies = [
"addr2line",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
"windows-targets 0.52.6",
]
[[package]]
@ -314,13 +356,39 @@ dependencies = [
"generic-array",
]
[[package]]
name = "board-game"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "647fc8459363368aae04df3d21da37094430c57dd993d09be2792133d5365e3e"
dependencies = [
"arimaa_engine_step",
"cast_trait",
"chess",
"decorum",
"internal-iterator",
"itertools 0.10.5",
"lazy_static",
"nohash-hasher",
"nom",
"num-traits",
"once_cell",
"rand 0.8.5",
"rand_xoshiro",
"rayon",
"static_assertions",
]
[[package]]
name = "bot"
version = "0.1.0"
dependencies = [
"board-game",
"burn",
"burn-rl",
"confy",
"env_logger 0.10.0",
"internal-iterator",
"log",
"pretty_assertions",
"rand 0.8.5",
@ -796,6 +864,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cast_trait"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4f8d981c476baadf74cd52897866a1d279d3e14e2d5e2d9af045210e0ae6128"
[[package]]
name = "castaway"
version = "0.2.3"
@ -862,6 +936,18 @@ dependencies = [
"zeroize",
]
[[package]]
name = "chess"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ed299b171ec34f372945ad6726f7bc1d2afd5f59fb8380f64f48e2bab2f0ec8"
dependencies = [
"arrayvec 0.5.2",
"failure",
"nodrop",
"rand 0.7.3",
]
[[package]]
name = "cipher"
version = "0.4.4"
@ -917,7 +1003,7 @@ checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.1.14",
"unicode-width 0.2.0",
]
[[package]]
@ -964,6 +1050,18 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "confy"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f29222b549d4e3ded127989d523da9e928918d0d0d7f7c1690b439d0d538bae9"
dependencies = [
"directories",
"serde",
"thiserror 2.0.12",
"toml",
]
[[package]]
name = "constant_time_eq"
version = "0.1.5"
@ -1433,6 +1531,15 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "decorum"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "281759d3c8a14f5c3f0c49363be56810fcd7f910422f97f2db850c2920fde5cf"
dependencies = [
"num-traits",
]
[[package]]
name = "deranged"
version = "0.4.0"
@ -1524,6 +1631,15 @@ dependencies = [
"subtle",
]
[[package]]
name = "directories"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16f5094c54661b38d03bd7e50df373292118db60b585c08a411c6d840017fe7d"
dependencies = [
"dirs-sys 0.5.0",
]
[[package]]
name = "dirs"
version = "5.0.1"
@ -1737,6 +1853,28 @@ dependencies = [
"zune-inflate",
]
[[package]]
name = "failure"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d32e9bd16cc02eae7db7ef620b392808b89f6a5e16bb3497d159c6b92a0f4f86"
dependencies = [
"backtrace",
"failure_derive",
]
[[package]]
name = "failure_derive"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
"synstructure 0.12.6",
]
[[package]]
name = "fallible-iterator"
version = "0.3.0"
@ -2170,6 +2308,12 @@ dependencies = [
"weezl",
]
[[package]]
name = "gimli"
version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "gix-features"
version = "0.42.1"
@ -2352,7 +2496,7 @@ dependencies = [
"num-traits",
"ordered-float 5.0.0",
"rand 0.8.5",
"rand_pcg",
"rand_pcg 0.3.1",
"sdl2",
"serde",
]
@ -2551,6 +2695,12 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "internal-iterator"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "969ee3fc68ec2e88eb21434ce4d9b7e1600d1ce92ff974560a6c4a304f5124b9"
[[package]]
name = "interpolate_name"
version = "0.2.4"
@ -2579,6 +2729,15 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
@ -2937,7 +3096,7 @@ version = "25.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b977c445f26e49757f9aca3631c3b8b836942cb278d69a92e7b80d3b24da632"
dependencies = [
"arrayvec",
"arrayvec 0.7.6",
"bit-set",
"bitflags 2.9.1",
"cfg_aliases",
@ -3014,6 +3173,18 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "nodrop"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "nohash-hasher"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
[[package]]
name = "nom"
version = "7.1.3"
@ -3213,6 +3384,15 @@ dependencies = [
"malloc_buf",
]
[[package]]
name = "object"
version = "0.36.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
dependencies = [
"memchr",
]
[[package]]
name = "octets"
version = "0.2.0"
@ -3570,6 +3750,18 @@ dependencies = [
"uuid",
]
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
"rand_pcg 0.2.1",
]
[[package]]
name = "rand"
version = "0.8.5"
@ -3592,6 +3784,16 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
@ -3612,6 +3814,12 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
[[package]]
name = "rand_core"
version = "0.6.4"
@ -3641,6 +3849,24 @@ dependencies = [
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
name = "rand_pcg"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
name = "rand_pcg"
version = "0.3.1"
@ -3650,6 +3876,15 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_xoshiro"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "range-alloc"
version = "0.1.4"
@ -3707,7 +3942,7 @@ checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9"
dependencies = [
"arbitrary",
"arg_enum_proc_macro",
"arrayvec",
"arrayvec 0.7.6",
"av1-grain",
"bitstream-io",
"built",
@ -3991,6 +4226,12 @@ dependencies = [
"smallvec",
]
[[package]]
name = "rustc-demangle"
version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace"
[[package]]
name = "rustc-hash"
version = "1.1.0"
@ -4500,6 +4741,18 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
"unicode-xid",
]
[[package]]
name = "synstructure"
version = "0.13.2"
@ -4784,9 +5037,16 @@ dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"toml_write",
"winnow",
]
[[package]]
name = "toml_write"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
[[package]]
name = "torch-sys"
version = "0.19.0"
@ -5244,7 +5504,7 @@ version = "25.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec8fb398f119472be4d80bc3647339f56eb63b2a331f6a3d16e25d8144197dd9"
dependencies = [
"arrayvec",
"arrayvec 0.7.6",
"bitflags 2.9.1",
"cfg_aliases",
"document-features",
@ -5272,7 +5532,7 @@ version = "25.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7b882196f8368511d613c6aeec80655160db6646aebddf8328879a88d54e500"
dependencies = [
"arrayvec",
"arrayvec 0.7.6",
"bit-set",
"bit-vec",
"bitflags 2.9.1",
@ -5331,7 +5591,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f968767fe4d3d33747bbd1473ccd55bf0f6451f55d733b5597e67b5deab4ad17"
dependencies = [
"android_system_properties",
"arrayvec",
"arrayvec 0.7.6",
"ash",
"bit-set",
"bitflags 2.9.1",
@ -5754,7 +6014,7 @@ dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
"synstructure",
"synstructure 0.13.2",
]
[[package]]
@ -5795,7 +6055,7 @@ dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
"synstructure",
"synstructure 0.13.2",
]
[[package]]


@ -6,16 +6,12 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[[bin]]
name = "train_dqn_burn_valid"
path = "src/dqn/burnrl_valid/main.rs"
[[bin]]
name = "train_dqn_burn"
path = "src/dqn/burnrl/main.rs"
name = "burn_train"
path = "src/burnrl/main.rs"
[[bin]]
name = "train_dqn_simple"
path = "src/dqn/simple/main.rs"
path = "src/dqn_simple/main.rs"
[dependencies]
pretty_assertions = "1.4.0"
@ -27,3 +23,6 @@ env_logger = "0.10"
burn = { version = "0.17", features = ["ndarray", "autodiff"] }
burn-rl = { git = "https://github.com/yunjhongwu/burn-rl-examples.git", package = "burn-rl" }
log = "0.4.20"
confy = "1.0.0"
board-game = "0.8.2"
internal-iterator = "0.2.3"
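
The confy dependency backs the new training configuration file (commit e1b8d7e679). The bot's actual Config type is not part of this compare, so the following is only a hedged sketch of how such a file might be loaded with confy; the struct name, field subset and file path are assumptions, and serde derives are assumed to be available in the crate:

use serde::{Deserialize, Serialize};

// Illustrative stand-in for the real config type; not taken from this compare.
#[derive(Debug, Default, Serialize, Deserialize)]
struct BotTrainConfig {
    num_episodes: usize,
    max_steps: usize,
    dense_size: usize,
    save_path: Option<String>,
}

fn main() -> Result<(), confy::ConfyError> {
    // confy::load_path reads a TOML file, creating it with defaults if it is missing.
    let conf: BotTrainConfig = confy::load_path("bot/train.toml")?;
    println!("training for {} episodes", conf.num_episodes);
    Ok(())
}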


@ -1,38 +1,50 @@
#!/usr/bin/env sh
#!/usr/bin/env bash
ROOT="$(cd "$(dirname "$0")" && pwd)/../.."
LOGS_DIR="$ROOT/bot/models/logs"
CFG_SIZE=12
CFG_SIZE=17
BINBOT=burn_train
# BINBOT=train_ppo_burn
# BINBOT=train_dqn_burn
# BINBOT=train_dqn_burn_big
# BINBOT=train_dqn_burn_before
OPPONENT="random"
PLOT_EXT="png"
train() {
cargo build --release --bin=train_dqn_burn
NAME="train_$(date +%Y-%m-%d_%H:%M:%S)"
LOGS="$LOGS_DIR/$NAME.out"
mkdir -p "$LOGS_DIR"
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/train_dqn_burn" | tee "$LOGS"
ALGO=$1
cargo build --release --bin=$BINBOT
NAME="$(date +%Y-%m-%d_%H:%M:%S)"
LOGS="$LOGS_DIR/$ALGO/$NAME.out"
mkdir -p "$LOGS_DIR/$ALGO"
LD_LIBRARY_PATH="$ROOT/target/release" "$ROOT/target/release/$BINBOT" $ALGO | tee "$LOGS"
}
plot() {
NAME=$(ls "$LOGS_DIR" | tail -n 1)
LOGS="$LOGS_DIR/$NAME"
cfgs=$(head -n $CFG_SIZE "$LOGS")
ALGO=$1
NAME=$(ls -rt "$LOGS_DIR/$ALGO" | grep -v png | tail -n 1)
LOGS="$LOGS_DIR/$ALGO/$NAME"
cfgs=$(grep -v "info:" "$LOGS" | head -n $CFG_SIZE)
for cfg in $cfgs; do
eval "$cfg"
done
# tail -n +$((CFG_SIZE + 2)) "$LOGS"
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
grep -v "info:" |
awk -F '[ ,]' '{print $5}' |
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$ALGO/$NAME.$PLOT_EXT"
}
if [ "$1" = "plot" ]; then
plot
if [[ -z "$1" ]]; then
echo "Usage : train [plot] <algo>"
elif [ "$1" = "plot" ]; then
if [[ -z "$2" ]]; then
echo "Usage : train [plot] <algo>"
else
plot $2
fi
else
train
train $1
fi
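
With this version the algorithm is an explicit argument: for example, train.sh dqn builds burn_train and tees its output to bot/models/logs/dqn/<timestamp>.out, while train.sh plot dqn plots the reward column of the most recent log for that algorithm and writes the image next to it.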


@ -17,7 +17,7 @@ train() {
}
plot() {
NAME=$(ls "$LOGS_DIR" | tail -n 1)
NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1)
LOGS="$LOGS_DIR/$NAME"
cfgs=$(head -n $CFG_SIZE "$LOGS")
for cfg in $cfgs; do
@ -31,8 +31,19 @@ plot() {
feedgnuplot --lines --points --unset grid --title "adv = $OPPONENT ; density = $dense_size ; decay = $eps_decay ; max steps = $max_steps" --terminal $PLOT_EXT >"$LOGS_DIR/$OPPONENT-$dense_size-$eps_decay-$max_steps-$NAME.$PLOT_EXT"
}
avg() {
NAME=$(ls -rt "$LOGS_DIR" | grep -v "png" | tail -n 1)
LOGS="$LOGS_DIR/$NAME"
echo $LOGS
tail -n +$((CFG_SIZE + 2)) "$LOGS" |
grep -v "info:" |
awk -F '[ ,]' '{print $5}' | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }'
}
if [ "$1" = "plot" ]; then
plot
elif [ "$1" = "avg" ]; then
avg
else
train
fi
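
The added avg subcommand prints the mean of the fifth comma/space-separated field of the latest log (the per-episode reward in the JSON lines emitted by the trainers), giving a quick summary figure for the most recent run.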


@ -1,15 +1,16 @@
use crate::dqn::burnrl_valid::environment::TrictracEnvironment;
use crate::dqn::burnrl_valid::utils::soft_update_linear;
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
use std::fmt;
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
@ -62,67 +63,20 @@ impl<B: Backend> DQNModel<B> for Net<B> {
#[allow(unused)]
const MEMORY_SIZE: usize = 8192;
pub struct DqnConfig {
pub max_steps: usize,
pub num_episodes: usize,
pub dense_size: usize,
pub eps_start: f64,
pub eps_end: f64,
pub eps_decay: f64,
pub gamma: f32,
pub tau: f32,
pub learning_rate: f32,
pub batch_size: usize,
pub clip_grad: f32,
}
impl fmt::Display for DqnConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = String::new();
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
s.push_str(&format!("gamma={:?}\n", self.gamma));
s.push_str(&format!("tau={:?}\n", self.tau));
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
write!(f, "{s}")
}
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
max_steps: 2000,
num_episodes: 1000,
dense_size: 256,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
}
}
}
type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)]
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
conf: &DqnConfig,
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
// ) -> DQN<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
// env.as_mut().min_steps = conf.min_steps;
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
@ -189,8 +143,13 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
if snapshot.done() || episode_duration >= conf.max_steps {
let envmut = env.as_mut();
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
* 100.0)
.round() as u32;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
envmut.goodmoves_count,
goodmoves_ratio,
envmut.pointrolls_count,
now.elapsed().unwrap().as_secs(),
);
@ -202,5 +161,35 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
}
}
}
agent
let valid_agent = agent.valid();
if let Some(path) = &conf.save_path {
save_model(valid_agent.model().as_ref().unwrap(), path);
}
valid_agent
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}
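
The per-algorithm DqnConfig removed above is replaced by a shared Config in crate::burnrl::utils, which is not itself included in this compare. Below is a hedged reconstruction, with the field list inferred from the conf.* reads in the dqn/ppo/sac modules of this diff; names, types and grouping are guesses wherever the diff does not show them:

// Sketch only; the real bot/src/burnrl/utils.rs is not shown in this compare.
pub struct Config {
    pub max_steps: usize,
    pub num_episodes: usize,
    pub dense_size: usize,
    pub min_steps: f32, // used by the "_big" environments
    // DQN exploration schedule
    pub eps_start: f64,
    pub eps_end: f64,
    pub eps_decay: f64,
    // shared optimisation parameters
    pub gamma: f32,
    pub tau: f32,
    pub learning_rate: f32,
    pub batch_size: usize,
    pub clip_grad: f32,
    // PPO-specific
    pub lambda: f32,
    pub epsilon_clip: f32,
    pub critic_weight: f32,
    pub entropy_weight: f32,
    pub epochs: usize,
    // SAC-specific
    pub min_probability: f32,
    // where save_model writes the trained network (".mpk" is appended)
    pub save_path: Option<String>,
}

Like the old DqnConfig, it presumably also implements Display as key=value lines, since train.sh still evals the first CFG_SIZE lines of each log to recover the run parameters.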


@ -1,15 +1,16 @@
use crate::dqn::burnrl::environment::TrictracEnvironment;
use crate::dqn::burnrl::utils::soft_update_linear;
use crate::burnrl::environment_big::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, ElemType, Environment, Memory, Model, State};
use std::fmt;
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
@ -62,71 +63,19 @@ impl<B: Backend> DQNModel<B> for Net<B> {
#[allow(unused)]
const MEMORY_SIZE: usize = 8192;
pub struct DqnConfig {
pub min_steps: f32,
pub max_steps: usize,
pub num_episodes: usize,
pub dense_size: usize,
pub eps_start: f64,
pub eps_end: f64,
pub eps_decay: f64,
pub gamma: f32,
pub tau: f32,
pub learning_rate: f32,
pub batch_size: usize,
pub clip_grad: f32,
}
impl fmt::Display for DqnConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = String::new();
s.push_str(&format!("min_steps={:?}\n", self.min_steps));
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
s.push_str(&format!("gamma={:?}\n", self.gamma));
s.push_str(&format!("tau={:?}\n", self.tau));
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
write!(f, "{s}")
}
}
impl Default for DqnConfig {
fn default() -> Self {
Self {
min_steps: 250.0,
max_steps: 2000,
num_episodes: 1000,
dense_size: 256,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
}
}
}
type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)]
pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
conf: &DqnConfig,
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
) -> DQN<E, B, Net<B>> {
// ) -> impl Agent<E> {
// ) -> DQN<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().min_steps = conf.min_steps;
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
@ -193,9 +142,13 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
if snapshot.done() || episode_duration >= conf.max_steps {
let envmut = env.as_mut();
let goodmoves_ratio = ((envmut.goodmoves_count as f32 / episode_duration as f32)
* 100.0)
.round() as u32;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"rollpoints\":{}, \"duration\": {}}}",
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"goodmoves\": {}, \"ratio\": {}%, \"rollpoints\":{}, \"duration\": {}}}",
envmut.goodmoves_count,
goodmoves_ratio,
envmut.pointrolls_count,
now.elapsed().unwrap().as_secs(),
);
@ -207,5 +160,35 @@ pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
}
}
}
agent
let valid_agent = agent.valid();
if let Some(path) = &conf.save_path {
save_model(valid_agent.model().as_ref().unwrap(), path);
}
valid_agent
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}


@ -0,0 +1,189 @@
use crate::burnrl::environment_valid::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::relu;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::DQN;
use burn_rl::agent::{DQNModel, DQNTrainingConfig};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Net<B> {
#[allow(unused)]
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
(self.linear_0, self.linear_1, self.linear_2)
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Net<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
relu(self.linear_2.forward(layer_1_output))
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> DQNModel<B> for Net<B> {
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = this.consume();
Self {
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
}
}
}
#[allow(unused)]
const MEMORY_SIZE: usize = 8192;
type MyAgent<E, B> = DQN<E, B, Net<B>>;
#[allow(unused)]
// pub fn run<E: Environment + AsMut<TrictracEnvironment>, B: AutodiffBackend>(
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> DQN<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
let mut agent = MyAgent::new(model);
// let config = DQNTrainingConfig::default();
let config = DQNTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut policy_net = agent.model().as_ref().unwrap().clone();
let mut step = 0_usize;
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward: ElemType = 0.0;
let mut episode_duration = 0_usize;
let mut state = env.state();
let mut now = SystemTime::now();
while !episode_done {
let eps_threshold = conf.eps_end
+ (conf.eps_start - conf.eps_end) * f64::exp(-(step as f64) / conf.eps_decay);
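// Exponential epsilon decay: with the defaults formerly in DqnConfig (eps_start = 0.9,
// eps_end = 0.05, eps_decay = 1000.0) the exploration threshold is ~0.90 at step 0,
// ~0.36 at step 1000 and ~0.09 at step 3000.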
let action =
DQN::<E, B, Net<B>>::react_with_exploration(&policy_net, state, eps_threshold);
let snapshot = env.step(action);
episode_reward +=
<<E as Environment>::RewardType as Into<ElemType>>::into(snapshot.reward().clone());
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
if config.batch_size < memory.len() {
policy_net =
agent.train::<MEMORY_SIZE>(policy_net, &memory, &mut optimizer, &config);
}
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= conf.max_steps {
let envmut = env.as_mut();
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"epsilon\": {eps_threshold:.3}, \"rollpoints\":{}, \"duration\": {}}}",
envmut.pointrolls_count,
now.elapsed().unwrap().as_secs(),
);
env.reset();
episode_done = true;
now = SystemTime::now();
} else {
state = *snapshot.state();
}
}
}
let valid_agent = agent.valid();
if let Some(path) = &conf.save_path {
save_model(valid_agent.model().as_ref().unwrap(), path);
}
valid_agent
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}


@ -0,0 +1,9 @@
pub mod dqn;
pub mod dqn_big;
pub mod dqn_valid;
pub mod ppo;
pub mod ppo_big;
pub mod ppo_valid;
pub mod sac;
pub mod sac_big;
pub mod sac_valid;
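
These nine modules back the single burn_train binary, which receives the algorithm name as its first argument (train.sh passes $ALGO straight through). The dispatcher in bot/src/burnrl/main.rs is not part of this compare; a minimal sketch of the shape it might take, with the actual run::<Env, Backend> calls left as comments because the concrete type parameters are not shown here:

// Hypothetical sketch only; the real bot/src/burnrl/main.rs is not included in this compare.
fn main() {
    let algo = std::env::args().nth(1).unwrap_or_else(|| "dqn".to_string());
    match algo.as_str() {
        // e.g. algos::dqn::run::<TrictracEnvironment, Autodiff<NdArray>>(&conf, false)
        "dqn" | "dqn_big" | "dqn_valid" => println!("would dispatch to a DQN variant: {algo}"),
        "ppo" | "ppo_big" | "ppo_valid" => println!("would dispatch to a PPO variant: {algo}"),
        "sac" | "sac_big" | "sac_valid" => println!("would dispatch to a SAC variant: {algo}"),
        other => eprintln!("unknown algo: {other}"),
    }
}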

bot/src/burnrl/algos/ppo.rs Normal file

@ -0,0 +1,191 @@
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
linear: Linear<B>,
linear_actor: Linear<B>,
linear_critic: Linear<B>,
}
impl<B: Backend> Net<B> {
#[allow(unused)]
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
let initializer = Initializer::XavierUniform { gain: 1.0 };
Self {
linear: LinearConfig::new(input_size, dense_size)
.with_initializer(initializer.clone())
.init(&Default::default()),
linear_actor: LinearConfig::new(dense_size, output_size)
.with_initializer(initializer.clone())
.init(&Default::default()),
linear_critic: LinearConfig::new(dense_size, 1)
.with_initializer(initializer)
.init(&Default::default()),
}
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
let layer_0_output = relu(self.linear.forward(input));
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
let values = self.linear_critic.forward(layer_0_output);
PPOOutput::<B>::new(policies, values)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear.forward(input));
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
}
}
impl<B: Backend> PPOModel<B> for Net<B> {}
#[allow(unused)]
const MEMORY_SIZE: usize = 512;
type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> PPO<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let mut model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
let agent = MyAgent::default();
let config = PPOTrainingConfig {
gamma: conf.gamma,
lambda: conf.lambda,
epsilon_clip: conf.epsilon_clip,
critic_weight: conf.critic_weight,
entropy_weight: conf.entropy_weight,
learning_rate: conf.learning_rate,
epochs: conf.epochs,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut now = SystemTime::now();
env.reset();
while !episode_done {
let state = env.state();
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
episode_duration += 1;
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
}
}
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
memory.clear();
}
if let Some(path) = &conf.save_path {
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
}
agent.valid(model)
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}


@ -0,0 +1,191 @@
use crate::burnrl::environment_big::TrictracEnvironment;
use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
linear: Linear<B>,
linear_actor: Linear<B>,
linear_critic: Linear<B>,
}
impl<B: Backend> Net<B> {
#[allow(unused)]
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
let initializer = Initializer::XavierUniform { gain: 1.0 };
Self {
linear: LinearConfig::new(input_size, dense_size)
.with_initializer(initializer.clone())
.init(&Default::default()),
linear_actor: LinearConfig::new(dense_size, output_size)
.with_initializer(initializer.clone())
.init(&Default::default()),
linear_critic: LinearConfig::new(dense_size, 1)
.with_initializer(initializer)
.init(&Default::default()),
}
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
let layer_0_output = relu(self.linear.forward(input));
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
let values = self.linear_critic.forward(layer_0_output);
PPOOutput::<B>::new(policies, values)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear.forward(input));
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
}
}
impl<B: Backend> PPOModel<B> for Net<B> {}
#[allow(unused)]
const MEMORY_SIZE: usize = 512;
type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> PPO<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let mut model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
let agent = MyAgent::default();
let config = PPOTrainingConfig {
gamma: conf.gamma,
lambda: conf.lambda,
epsilon_clip: conf.epsilon_clip,
critic_weight: conf.critic_weight,
entropy_weight: conf.entropy_weight,
learning_rate: conf.learning_rate,
epochs: conf.epochs,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut now = SystemTime::now();
env.reset();
while !episode_done {
let state = env.state();
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
episode_duration += 1;
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
}
}
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
memory.clear();
}
if let Some(path) = &conf.save_path {
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
}
agent.valid(model)
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}


@ -0,0 +1,191 @@
use crate::burnrl::environment_valid::TrictracEnvironment;
use crate::burnrl::utils::Config;
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Initializer, Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{PPOModel, PPOOutput, PPOTrainingConfig, PPO};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::env;
use std::fs;
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
linear: Linear<B>,
linear_actor: Linear<B>,
linear_critic: Linear<B>,
}
impl<B: Backend> Net<B> {
#[allow(unused)]
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
let initializer = Initializer::XavierUniform { gain: 1.0 };
Self {
linear: LinearConfig::new(input_size, dense_size)
.with_initializer(initializer.clone())
.init(&Default::default()),
linear_actor: LinearConfig::new(dense_size, output_size)
.with_initializer(initializer.clone())
.init(&Default::default()),
linear_critic: LinearConfig::new(dense_size, 1)
.with_initializer(initializer)
.init(&Default::default()),
}
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>, Tensor<B, 2>> for Net<B> {
fn forward(&self, input: Tensor<B, 2>) -> PPOOutput<B> {
let layer_0_output = relu(self.linear.forward(input));
let policies = softmax(self.linear_actor.forward(layer_0_output.clone()), 1);
let values = self.linear_critic.forward(layer_0_output);
PPOOutput::<B>::new(policies, values)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear.forward(input));
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
}
}
impl<B: Backend> PPOModel<B> for Net<B> {}
#[allow(unused)]
const MEMORY_SIZE: usize = 512;
type MyAgent<E, B> = PPO<E, B, Net<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
// ) -> PPO<E, B, Net<B>> {
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let mut model = Net::<B>::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
let agent = MyAgent::default();
let config = PPOTrainingConfig {
gamma: conf.gamma,
lambda: conf.lambda,
epsilon_clip: conf.epsilon_clip,
critic_weight: conf.critic_weight,
entropy_weight: conf.entropy_weight,
learning_rate: conf.learning_rate,
epochs: conf.epochs,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut optimizer = AdamWConfig::new()
.with_grad_clipping(config.clip_grad.clone())
.init();
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut now = SystemTime::now();
env.reset();
while !episode_done {
let state = env.state();
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &model) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
episode_duration += 1;
episode_done = snapshot.done() || episode_duration >= conf.max_steps;
}
}
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs(),
);
now = SystemTime::now();
model = MyAgent::train::<MEMORY_SIZE>(model, &memory, &mut optimizer, &config);
memory.clear();
}
if let Some(path) = &conf.save_path {
let device = NdArrayDevice::default();
let recorder = CompactRecorder::new();
let tmp_path = env::temp_dir().join("tmp_model.mpk");
// Save the trained model (backend B) to a temporary file
recorder
.record(model.clone().into_record(), tmp_path.clone())
.expect("Failed to save temporary model");
// Create a new model instance with the target backend (NdArray)
let model_to_save: Net<NdArray<ElemType>> = Net::new(
<<E as Environment>::StateType as State>::size(),
conf.dense_size,
<<E as Environment>::ActionType as Action>::size(),
);
// Load the record from the temporary file into the new model
let record = recorder
.load(tmp_path.clone(), &device)
.expect("Failed to load temporary model");
let model_with_loaded_weights = model_to_save.load_record(record);
// Clean up the temporary file
fs::remove_file(tmp_path).expect("Failed to remove temporary model file");
save_model(&model_with_loaded_weights, path);
}
agent.valid(model)
}
pub fn save_model(model: &Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Net<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}

bot/src/burnrl/algos/sac.rs Normal file

@ -0,0 +1,221 @@
use crate::burnrl::environment::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Actor<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Actor<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
softmax(self.linear_2.forward(layer_1_output), 1)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> SACActor<B> for Actor<B> {}
#[derive(Module, Debug)]
pub struct Critic<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Critic<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
(self.linear_0, self.linear_1, self.linear_2)
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
self.linear_2.forward(layer_1_output)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> SACCritic<B> for Critic<B> {
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = this.consume();
Self {
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
}
}
}
#[allow(unused)]
const MEMORY_SIZE: usize = 4096;
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let state_dim = <<E as Environment>::StateType as State>::size();
let action_dim = <<E as Environment>::ActionType as Action>::size();
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
let mut agent = MyAgent::default();
let config = SACTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
min_probability: conf.min_probability,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
let mut optimizer = SACOptimizer::new(
optimizer_config.clone().init(),
optimizer_config.clone().init(),
optimizer_config.clone().init(),
optimizer_config.init(),
);
let mut step = 0_usize;
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut state = env.state();
let mut now = SystemTime::now();
while !episode_done {
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
if config.batch_size < memory.len() {
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
}
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= conf.max_steps {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs()
);
now = SystemTime::now();
} else {
state = *snapshot.state();
}
}
}
}
let valid_agent = agent.valid(nets.actor);
if let Some(path) = &conf.save_path {
if let Some(model) = valid_agent.model() {
save_model(model, path);
}
}
valid_agent
}
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Actor::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}
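
Only the actor is persisted for SAC: agent.valid(nets.actor) hands the actor to the validation agent, and save_model/load_model operate on Actor<NdArray<ElemType>> alone, so the two critics exist only during training. A hedged usage sketch for reloading it (the module path, dense_size and file path below are assumptions):

// Returns None if "models/sac.mpk" does not exist; 256 must match the dense_size used at training time.
fn load_actor_for_play(
) -> Option<crate::burnrl::algos::sac::Actor<burn::backend::NdArray<burn_rl::base::ElemType>>> {
    crate::burnrl::algos::sac::load_model(256, &String::from("models/sac"))
}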


@ -0,0 +1,222 @@
use crate::burnrl::environment_big::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Actor<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Actor<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
softmax(self.linear_2.forward(layer_1_output), 1)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> SACActor<B> for Actor<B> {}
#[derive(Module, Debug)]
pub struct Critic<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Critic<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
(self.linear_0, self.linear_1, self.linear_2)
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
self.linear_2.forward(layer_1_output)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> SACCritic<B> for Critic<B> {
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = this.consume();
Self {
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
}
}
}
#[allow(unused)]
const MEMORY_SIZE: usize = 4096;
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let state_dim = <<E as Environment>::StateType as State>::size();
let action_dim = <<E as Environment>::ActionType as Action>::size();
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
let mut agent = MyAgent::default();
let config = SACTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
min_probability: conf.min_probability,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
let mut optimizer = SACOptimizer::new(
optimizer_config.clone().init(),
optimizer_config.clone().init(),
optimizer_config.clone().init(),
optimizer_config.init(),
);
let mut step = 0_usize;
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut state = env.state();
let mut now = SystemTime::now();
while !episode_done {
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
if config.batch_size < memory.len() {
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
}
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= conf.max_steps {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs()
);
now = SystemTime::now();
} else {
state = *snapshot.state();
}
}
}
}
let valid_agent = agent.valid(nets.actor);
if let Some(path) = &conf.save_path {
if let Some(model) = valid_agent.model() {
save_model(model, path);
}
}
valid_agent
}
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Actor::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}

View file

@ -0,0 +1,222 @@
use crate::burnrl::environment_valid::TrictracEnvironment;
use crate::burnrl::utils::{soft_update_linear, Config};
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamWConfig;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};
use std::time::SystemTime;
#[derive(Module, Debug)]
pub struct Actor<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Actor<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Actor<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
softmax(self.linear_2.forward(layer_1_output), 1)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> SACActor<B> for Actor<B> {}
#[derive(Module, Debug)]
pub struct Critic<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}
impl<B: Backend> Critic<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
fn consume(self) -> (Linear<B>, Linear<B>, Linear<B>) {
(self.linear_0, self.linear_1, self.linear_2)
}
}
impl<B: Backend> Model<B, Tensor<B, 2>, Tensor<B, 2>> for Critic<B> {
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));
self.linear_2.forward(layer_1_output)
}
fn infer(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.forward(input)
}
}
impl<B: Backend> SACCritic<B> for Critic<B> {
fn soft_update(this: Self, that: &Self, tau: ElemType) -> Self {
let (linear_0, linear_1, linear_2) = this.consume();
Self {
linear_0: soft_update_linear(linear_0, &that.linear_0, tau),
linear_1: soft_update_linear(linear_1, &that.linear_1, tau),
linear_2: soft_update_linear(linear_2, &that.linear_2, tau),
}
}
}
#[allow(unused)]
const MEMORY_SIZE: usize = 4096;
type MyAgent<E, B> = SAC<E, B, Actor<B>>;
#[allow(unused)]
pub fn run<
E: Environment + AsMut<TrictracEnvironment>,
B: AutodiffBackend<InnerBackend = NdArray>,
>(
conf: &Config,
visualized: bool,
) -> impl Agent<E> {
let mut env = E::new(visualized);
env.as_mut().max_steps = conf.max_steps;
let state_dim = <<E as Environment>::StateType as State>::size();
let action_dim = <<E as Environment>::ActionType as Action>::size();
let actor = Actor::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_1 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let critic_2 = Critic::<B>::new(state_dim, conf.dense_size, action_dim);
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);
let mut agent = MyAgent::default();
let config = SACTrainingConfig {
gamma: conf.gamma,
tau: conf.tau,
learning_rate: conf.learning_rate,
min_probability: conf.min_probability,
batch_size: conf.batch_size,
clip_grad: Some(burn::grad_clipping::GradientClippingConfig::Value(
conf.clip_grad,
)),
};
let mut memory = Memory::<E, B, MEMORY_SIZE>::default();
let optimizer_config = AdamWConfig::new().with_grad_clipping(config.clip_grad.clone());
let mut optimizer = SACOptimizer::new(
optimizer_config.clone().init(),
optimizer_config.clone().init(),
optimizer_config.clone().init(),
optimizer_config.init(),
);
let mut step = 0_usize;
for episode in 0..conf.num_episodes {
let mut episode_done = false;
let mut episode_reward = 0.0;
let mut episode_duration = 0_usize;
let mut state = env.state();
let mut now = SystemTime::now();
while !episode_done {
if let Some(action) = MyAgent::<E, _>::react_with_model(&state, &nets.actor) {
let snapshot = env.step(action);
episode_reward += <<E as Environment>::RewardType as Into<ElemType>>::into(
snapshot.reward().clone(),
);
memory.push(
state,
*snapshot.state(),
action,
snapshot.reward().clone(),
snapshot.done(),
);
if config.batch_size < memory.len() {
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
}
step += 1;
episode_duration += 1;
if snapshot.done() || episode_duration >= conf.max_steps {
env.reset();
episode_done = true;
println!(
"{{\"episode\": {episode}, \"reward\": {episode_reward:.4}, \"steps count\": {episode_duration}, \"duration\": {}}}",
now.elapsed().unwrap().as_secs()
);
now = SystemTime::now();
} else {
state = *snapshot.state();
}
}
}
}
let valid_agent = agent.valid(nets.actor);
if let Some(path) = &conf.save_path {
if let Some(model) = valid_agent.model() {
save_model(model, path);
}
}
valid_agent
}
pub fn save_model(model: &Actor<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}.mpk");
println!("info: Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<Actor<NdArray<ElemType>>> {
let model_path = format!("{path}.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
Actor::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}

View file

@ -0,0 +1,424 @@
use std::io::Write;
use crate::training_common;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -1.12121;
const REWARD_VALID_MOVE: f32 = 1.12121;
const REWARD_RATIO: f32 = 0.01;
const WIN_POINTS: f32 = 1.0;
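// [Editor's sketch, not part of this diff] The constants above drive the reward
// shaping applied by execute_action() further down. Condensed here for a single
// agent action, assuming the dice bonus is only granted when the action was a
// roll that scored:
fn shaped_reward(valid: bool, rolled: bool, points: u8, adv_points: u8) -> f32 {
    if !valid {
        // the distinctive ERROR_REWARD value doubles as a sentinel so invalid
        // moves can be spotted in the statistics (`reward != ERROR_REWARD` below)
        return ERROR_REWARD;
    }
    let mut reward = REWARD_VALID_MOVE;
    if rolled {
        // proportional to the point differential obtained on the scoring roll
        reward += REWARD_RATIO * (points as f32 - adv_points as f32);
    }
    reward
}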
/// Trictrac game state for burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
pub data: [i8; 36], // vector representation of the game state
}
impl State for TrictracState {
type Data = [i8; 36];
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
Tensor::from_floats(self.data, &B::Device::default())
}
fn size() -> usize {
36
}
}
impl TrictracState {
/// Converts a GameState into a TrictracState
pub fn from_game_state(game_state: &GameState) -> Self {
let state_vec = game_state.to_vec();
let mut data = [0; 36];
// Copy the data, making sure we stay within the array size
let copy_len = state_vec.len().min(36);
data[..copy_len].copy_from_slice(&state_vec[..copy_len]);
TrictracState { data }
}
}
/// Possible Trictrac actions for burn-rl
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrictracAction {
// u32 as required by burn_rl::base::Action type
pub index: u32,
}
impl Action for TrictracAction {
fn random() -> Self {
use rand::{thread_rng, Rng};
let mut rng = thread_rng();
TrictracAction {
index: rng.gen_range(0..Self::size() as u32),
}
}
fn enumerate() -> Vec<Self> {
(0..Self::size() as u32)
.map(|index| TrictracAction { index })
.collect()
}
fn size() -> usize {
514
}
}
impl From<u32> for TrictracAction {
fn from(index: u32) -> Self {
TrictracAction { index }
}
}
impl From<TrictracAction> for u32 {
fn from(action: TrictracAction) -> u32 {
action.index
}
}
/// Trictrac environment for burn-rl
#[derive(Debug)]
pub struct TrictracEnvironment {
pub game: GameState,
active_player_id: PlayerId,
opponent_id: PlayerId,
current_state: TrictracState,
episode_reward: f32,
pub step_count: usize,
pub best_ratio: f32,
pub max_steps: usize,
pub pointrolls_count: usize,
pub goodmoves_count: usize,
pub goodmoves_ratio: f32,
pub visualized: bool,
}
impl Environment for TrictracEnvironment {
type StateType = TrictracState;
type ActionType = TrictracAction;
type RewardType = f32;
fn new(visualized: bool) -> Self {
let mut game = GameState::new(false);
// Add the two players
game.init_player("DQN Agent");
game.init_player("Opponent");
let player1_id = 1;
let player2_id = 2;
// Start the game
game.consume(&GameEvent::BeginGame { goes_first: 1 });
let current_state = TrictracState::from_game_state(&game);
TrictracEnvironment {
game,
active_player_id: player1_id,
opponent_id: player2_id,
current_state,
episode_reward: 0.0,
step_count: 0,
best_ratio: 0.0,
max_steps: 2000,
pointrolls_count: 0,
goodmoves_count: 0,
goodmoves_ratio: 0.0,
visualized,
}
}
fn state(&self) -> Self::StateType {
self.current_state
}
fn reset(&mut self) -> Snapshot<Self> {
// Reset the game
let history = self.game.history.clone();
self.game = GameState::new(false);
self.game.init_player("DQN Agent");
self.game.init_player("Opponent");
// Start the game
self.game.consume(&GameEvent::BeginGame { goes_first: 1 });
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward = 0.0;
self.goodmoves_ratio = if self.step_count == 0 {
0.0
} else {
self.goodmoves_count as f32 / self.step_count as f32
};
self.best_ratio = self.best_ratio.max(self.goodmoves_ratio);
let _warning = if self.best_ratio > 0.7 && self.goodmoves_ratio < 0.1 {
let path = "bot/models/logs/debug.log";
if let Ok(mut out) = std::fs::File::create(path) {
write!(out, "{history:?}").expect("could not write history log");
}
"!!!!"
} else {
""
};
// println!(
// "info: correct moves: {} ({}%) {}",
// self.goodmoves_count,
// (100.0 * self.goodmoves_ratio).round() as u32,
// warning
// );
self.step_count = 0;
self.pointrolls_count = 0;
self.goodmoves_count = 0;
Snapshot::new(self.current_state, 0.0, false)
}
fn step(&mut self, action: Self::ActionType) -> Snapshot<Self> {
self.step_count += 1;
// Convert the burn-rl action into a Trictrac action
let trictrac_action = Self::convert_action(action);
let mut reward = 0.0;
let is_rollpoint;
// Execute the action if it is the DQN agent's turn
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
(reward, is_rollpoint) = self.execute_action(action);
if is_rollpoint {
self.pointrolls_count += 1;
}
if reward != ERROR_REWARD {
self.goodmoves_count += 1;
}
} else {
// Action cannot be converted: penalty
panic!("action non convertible");
//reward = -0.5;
}
}
// Let the opponent play (simple strategy)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
reward += self.play_opponent_if_needed();
}
// Check whether the game is over
// let max_steps = self.max_steps;
// let max_steps = self.min_steps
// + (self.max_steps as f32 - self.min_steps)
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
// Final reward based on the result
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
reward += WIN_POINTS; // win
} else {
reward -= WIN_POINTS; // loss
}
}
}
let terminated = done || self.step_count >= self.max_steps;
// let terminated = done || self.step_count >= max_steps.round() as usize;
// Update the state
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward += reward;
if self.visualized && terminated {
println!(
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
self.episode_reward, self.step_count
);
}
Snapshot::new(self.current_state, reward, terminated)
}
}
impl TrictracEnvironment {
/// Converts a burn-rl action into a Trictrac action
pub fn convert_action(action: TrictracAction) -> Option<training_common::TrictracAction> {
training_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Converts an index within the set of valid actions into a Trictrac action
#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
game_state: &GameState,
) -> Option<training_common::TrictracAction> {
use training_common::get_valid_actions;
// Get the valid actions in the current context
let valid_actions = get_valid_actions(game_state);
if valid_actions.is_empty() {
return None;
}
// Map the action index onto a valid action
let action_index = (action.index as usize) % valid_actions.len();
Some(valid_actions[action_index].clone())
}
/// Executes a Trictrac action in the game
// fn execute_action(
// &mut self,
// action: training_common::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: training_common::TrictracAction) -> (f32, bool) {
use training_common::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;
// Apply the event if it is valid
if let Some(event) = action.to_event(&self.game) {
if self.game.validate(&event) {
self.game.consume(&event);
reward += REWARD_VALID_MOVE;
// Simulate the dice result after a Roll
if matches!(action, TrictracAction::Roll) {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
player_id: self.active_player_id,
dice: store::Dice {
values: dice_values,
},
};
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += REWARD_RATIO * (points as f32 - adv_points as f32);
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
}
// Reward proportional to the points scored
}
}
} else {
// Penalty for an invalid action:
// cancel the rewards accumulated so far
// and use a recognisable sentinel value for the statistics
reward = ERROR_REWARD;
}
} else {
reward = ERROR_REWARD;
}
(reward, is_rollpoint)
}
/// Plays the opponent with a simple strategy
fn play_opponent_if_needed(&mut self) -> f32 {
let mut reward = 0.0;
// If it is the opponent's turn, play automatically
if self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// Use the default strategy for the opponent
use crate::BotStrategy;
let mut strategy = crate::strategy::random::RandomStrategy::default();
strategy.set_player_id(self.opponent_id);
if let Some(color) = self.game.player_color_by_id(&self.opponent_id) {
strategy.set_color(color);
}
*strategy.get_mut_game() = self.game.clone();
// Execute the action according to the turn_stage
let mut calculate_points = false;
let opponent_color = store::Color::Black;
let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id,
},
TurnStage::RollWaiting => {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true;
GameEvent::RollResult {
player_id: self.opponent_id,
dice: store::Dice {
values: dice_values,
},
}
}
TurnStage::MarkPoints => {
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
GameEvent::Mark {
player_id: self.opponent_id,
points: points_rules.get_points(dice_roll_count).0,
}
}
TurnStage::MarkAdvPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
// no reward: already counted during White's turn
GameEvent::Mark {
player_id: self.opponent_id,
points: points_rules.get_points(dice_roll_count).1,
}
}
TurnStage::HoldOrGoChoice => {
// Simple strategy: always keep going
GameEvent::Go {
player_id: self.opponent_id,
}
}
TurnStage::Move => GameEvent::Move {
player_id: self.opponent_id,
moves: strategy.choose_move(),
},
};
if self.game.validate(&event) {
self.game.consume(&event);
if calculate_points {
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Reward proportional to the points scored
reward -= REWARD_RATIO * (points as f32 - adv_points as f32);
}
}
}
reward
}
}
impl AsMut<TrictracEnvironment> for TrictracEnvironment {
fn as_mut(&mut self) -> &mut Self {
self
}
}

View file

@ -1,9 +1,14 @@
use crate::dqn::dqn_common;
use crate::training_common_big;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
use store::{GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
const ERROR_REWARD: f32 = -2.12121;
const REWARD_VALID_MOVE: f32 = 2.12121;
const REWARD_RATIO: f32 = 0.01;
const WIN_POINTS: f32 = 0.1;
/// État du jeu Trictrac pour burn-rl
#[derive(Debug, Clone, Copy)]
pub struct TrictracState {
@ -84,7 +89,6 @@ pub struct TrictracEnvironment {
current_state: TrictracState,
episode_reward: f32,
pub step_count: usize,
pub min_steps: f32,
pub max_steps: usize,
pub pointrolls_count: usize,
pub goodmoves_count: usize,
@ -117,7 +121,6 @@ impl Environment for TrictracEnvironment {
current_state,
episode_reward: 0.0,
step_count: 0,
min_steps: 250.0,
max_steps: 2000,
pointrolls_count: 0,
goodmoves_count: 0,
@ -165,8 +168,7 @@ impl Environment for TrictracEnvironment {
let trictrac_action = Self::convert_action(action);
let mut reward = 0.0;
let mut is_rollpoint = false;
let mut terminated = false;
let is_rollpoint;
// Exécuter l'action si c'est le tour de l'agent DQN
if self.game.active_player_id == self.active_player_id {
@ -175,8 +177,9 @@ impl Environment for TrictracEnvironment {
if is_rollpoint {
self.pointrolls_count += 1;
}
if reward != Self::ERROR_REWARD {
if reward != ERROR_REWARD {
self.goodmoves_count += 1;
// println!("{str_action}");
}
} else {
// Action non convertible, pénalité
@ -186,31 +189,32 @@ impl Environment for TrictracEnvironment {
// Faire jouer l'adversaire (stratégie simple)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// print!(":");
reward += self.play_opponent_if_needed();
}
// Vérifier si la partie est terminée
let max_steps = self.min_steps
+ (self.max_steps as f32 - self.min_steps)
* f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
// let max_steps = self.max_steps
// let max_steps = self.min_steps
// + (self.max_steps as f32 - self.min_steps)
// * f32::exp((self.goodmoves_ratio - 1.0) / 0.25);
let done = self.game.stage == Stage::Ended || self.game.determine_winner().is_some();
if done {
// Récompense finale basée sur le résultat
if let Some(winner_id) = self.game.determine_winner() {
if winner_id == self.active_player_id {
reward += 50.0; // Victoire
reward += WIN_POINTS; // Victoire
} else {
reward -= 25.0; // Défaite
reward -= WIN_POINTS; // Défaite
}
}
}
let terminated = done || self.step_count >= max_steps.round() as usize;
let terminated = done || self.step_count >= self.max_steps;
// Mettre à jour l'état
self.current_state = TrictracState::from_game_state(&self.game);
self.episode_reward += reward;
if self.visualized && terminated {
println!(
"Episode terminé. Récompense totale: {:.2}, Étapes: {}",
@ -223,21 +227,19 @@ impl Environment for TrictracEnvironment {
}
impl TrictracEnvironment {
const ERROR_REWARD: f32 = -1.12121;
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
#[allow(dead_code)]
fn convert_valid_action_index(
&self,
action: TrictracAction,
game_state: &GameState,
) -> Option<dqn_common::TrictracAction> {
use dqn_common::get_valid_actions;
) -> Option<training_common_big::TrictracAction> {
use training_common_big::get_valid_actions;
// Obtenir les actions valides dans le contexte actuel
let valid_actions = get_valid_actions(game_state);
@ -254,18 +256,19 @@ impl TrictracEnvironment {
/// Exécute une action Trictrac dans le jeu
// fn execute_action(
// &mut self,
// action: dqn_common::TrictracAction,
// action:training_common_big::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
use dqn_common::TrictracAction;
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
use training_common_big::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;
let mut need_roll = false;
let event = match action {
TrictracAction::Roll => {
// Lancer les dés
reward += 0.1;
need_roll = true;
Some(GameEvent::Roll {
player_id: self.active_player_id,
})
@ -281,7 +284,6 @@ impl TrictracEnvironment {
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
reward += 0.2;
Some(GameEvent::Go {
player_id: self.active_player_id,
})
@ -310,7 +312,6 @@ impl TrictracEnvironment {
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
reward += 0.2;
Some(GameEvent::Move {
player_id: self.active_player_id,
moves: (checker_move1, checker_move2),
@ -322,9 +323,10 @@ impl TrictracEnvironment {
if let Some(event) = event {
if self.game.validate(&event) {
self.game.consume(&event);
reward += REWARD_VALID_MOVE;
// Simuler le résultat des dés après un Roll
if matches!(action, TrictracAction::Roll) {
// if matches!(action, TrictracAction::Roll) {
if need_roll {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
let dice_event = GameEvent::RollResult {
@ -333,10 +335,11 @@ impl TrictracEnvironment {
values: dice_values,
},
};
// print!("o");
if self.game.validate(&dice_event) {
self.game.consume(&dice_event);
let (points, adv_points) = self.game.dice_points;
reward += Self::REWARD_RATIO * (points - adv_points) as f32;
reward += REWARD_RATIO * (points - adv_points) as f32;
if points > 0 {
is_rollpoint = true;
// println!("info: rolled for {reward}");
@ -348,7 +351,7 @@ impl TrictracEnvironment {
// Pénalité pour action invalide
// on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
reward = Self::ERROR_REWARD;
reward = ERROR_REWARD;
}
}
@ -357,6 +360,7 @@ impl TrictracEnvironment {
/// Fait jouer l'adversaire avec une stratégie simple
fn play_opponent_if_needed(&mut self) -> f32 {
// print!("z?");
let mut reward = 0.0;
// Si c'est le tour de l'adversaire, jouer automatiquement
@ -372,6 +376,8 @@ impl TrictracEnvironment {
*strategy.get_mut_game() = self.game.clone();
// Exécuter l'action selon le turn_stage
let mut calculate_points = false;
let opponent_color = store::Color::Black;
let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id,
@ -379,6 +385,7 @@ impl TrictracEnvironment {
TurnStage::RollWaiting => {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true; // comment to replicate burnrl_before
GameEvent::RollResult {
player_id: self.opponent_id,
dice: store::Dice {
@ -387,25 +394,21 @@ impl TrictracEnvironment {
}
}
TurnStage::MarkPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.opponent_id,
points,
}
panic!("in play_opponent_if_needed > TurnStage::MarkPoints");
// let dice_roll_count = self
// .game
// .players
// .get(&self.opponent_id)
// .unwrap()
// .dice_roll_count;
// let points_rules =
// PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
// GameEvent::Mark {
// player_id: self.opponent_id,
// points: points_rules.get_points(dice_roll_count).0,
// }
}
TurnStage::MarkAdvPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
@ -414,11 +417,10 @@ impl TrictracEnvironment {
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let points = points_rules.get_points(dice_roll_count).1;
// pas de reward : déjà comptabilisé lors du tour de blanc
GameEvent::Mark {
player_id: self.opponent_id,
points,
points: points_rules.get_points(dice_roll_count).1,
}
}
TurnStage::HoldOrGoChoice => {
@ -435,6 +437,25 @@ impl TrictracEnvironment {
if self.game.validate(&event) {
self.game.consume(&event);
// print!(".");
if calculate_points {
// print!("x");
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
// Récompense proportionnelle aux points
let adv_reward = REWARD_RATIO * (points - adv_points) as f32;
reward -= adv_reward;
// if adv_reward != 0.0 {
// println!("info: opponent : {adv_reward} -> {reward}");
// }
}
}
}
reward

View file

@ -1,4 +1,4 @@
use crate::dqn::dqn_common;
use crate::training_common_big;
use burn::{prelude::Backend, tensor::Tensor};
use burn_rl::base::{Action, Environment, Snapshot, State};
use rand::{thread_rng, Rng};
@ -156,17 +156,26 @@ impl Environment for TrictracEnvironment {
if self.game.active_player_id == self.active_player_id {
if let Some(action) = trictrac_action {
(reward, is_rollpoint) = self.execute_action(action);
// if reward != 0.0 {
// println!("info: self rew {reward}");
// }
if is_rollpoint {
self.pointrolls_count += 1;
}
} else {
// Action non convertible, pénalité
println!("info: action non convertible -> -1 {trictrac_action:?}");
reward = -1.0;
}
}
// Faire jouer l'adversaire (stratégie simple)
while self.game.active_player_id == self.opponent_id && self.game.stage != Stage::Ended {
// let op_rew = self.play_opponent_if_needed();
// if op_rew != 0.0 {
// println!("info: op rew {op_rew}");
// }
// reward += op_rew;
reward += self.play_opponent_if_needed();
}
@ -205,16 +214,16 @@ impl TrictracEnvironment {
const REWARD_RATIO: f32 = 1.0;
/// Convertit une action burn-rl vers une action Trictrac
pub fn convert_action(action: TrictracAction) -> Option<dqn_common::TrictracAction> {
dqn_common::TrictracAction::from_action_index(action.index.try_into().unwrap())
pub fn convert_action(action: TrictracAction) -> Option<training_common_big::TrictracAction> {
training_common_big::TrictracAction::from_action_index(action.index.try_into().unwrap())
}
/// Convertit l'index d'une action au sein des actions valides vers une action Trictrac
fn convert_valid_action_index(
&self,
action: TrictracAction,
) -> Option<dqn_common::TrictracAction> {
use dqn_common::get_valid_actions;
) -> Option<training_common_big::TrictracAction> {
use training_common_big::get_valid_actions;
// Obtenir les actions valides dans le contexte actuel
let valid_actions = get_valid_actions(&self.game);
@ -231,10 +240,10 @@ impl TrictracEnvironment {
/// Exécute une action Trictrac dans le jeu
// fn execute_action(
// &mut self,
// action: dqn_common::TrictracAction,
// action: training_common_big::TrictracAction,
// ) -> Result<f32, Box<dyn std::error::Error>> {
fn execute_action(&mut self, action: dqn_common::TrictracAction) -> (f32, bool) {
use dqn_common::TrictracAction;
fn execute_action(&mut self, action: training_common_big::TrictracAction) -> (f32, bool) {
use training_common_big::TrictracAction;
let mut reward = 0.0;
let mut is_rollpoint = false;
@ -322,6 +331,7 @@ impl TrictracEnvironment {
// Pénalité pour action invalide
// on annule les précédents reward
// et on indique une valeur reconnaissable pour statistiques
println!("info: action invalide -> err_reward");
reward = Self::ERROR_REWARD;
}
}
@ -346,6 +356,8 @@ impl TrictracEnvironment {
*strategy.get_mut_game() = self.game.clone();
// Exécuter l'action selon le turn_stage
let mut calculate_points = false;
let opponent_color = store::Color::Black;
let event = match self.game.turn_stage {
TurnStage::RollDice => GameEvent::Roll {
player_id: self.opponent_id,
@ -353,6 +365,7 @@ impl TrictracEnvironment {
TurnStage::RollWaiting => {
let mut rng = thread_rng();
let dice_values = (rng.gen_range(1..=6), rng.gen_range(1..=6));
calculate_points = true;
GameEvent::RollResult {
player_id: self.opponent_id,
dice: store::Dice {
@ -361,7 +374,6 @@ impl TrictracEnvironment {
}
}
TurnStage::MarkPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
@ -370,16 +382,12 @@ impl TrictracEnvironment {
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= Self::REWARD_RATIO * (points - adv_points) as f32; // Récompense proportionnelle aux points
GameEvent::Mark {
player_id: self.opponent_id,
points,
points: points_rules.get_points(dice_roll_count).0,
}
}
TurnStage::MarkAdvPoints => {
let opponent_color = store::Color::Black;
let dice_roll_count = self
.game
.players
@ -409,6 +417,19 @@ impl TrictracEnvironment {
if self.game.validate(&event) {
self.game.consume(&event);
if calculate_points {
let dice_roll_count = self
.game
.players
.get(&self.opponent_id)
.unwrap()
.dice_roll_count;
let points_rules =
PointsRules::new(&opponent_color, &self.game.board, self.game.dice);
let (points, adv_points) = points_rules.get_points(dice_roll_count);
reward -= Self::REWARD_RATIO * (points - adv_points) as f32;
// Récompense proportionnelle aux points
}
}
}
reward

123
bot/src/burnrl/main.rs Normal file
View file

@ -0,0 +1,123 @@
use bot::burnrl::algos::{
dqn, dqn_big, dqn_valid, ppo, ppo_big, ppo_valid, sac, sac_big, sac_valid,
};
use bot::burnrl::environment::TrictracEnvironment;
use bot::burnrl::environment_big::TrictracEnvironment as TrictracEnvironmentBig;
use bot::burnrl::environment_valid::TrictracEnvironment as TrictracEnvironmentValid;
use bot::burnrl::utils::{demo_model, Config};
use burn::backend::{Autodiff, NdArray};
use burn_rl::base::ElemType;
use std::env;
type Backend = Autodiff<NdArray<ElemType>>;
fn main() {
let args: Vec<String> = env::args().collect();
let algo = &args[1];
// let dir_path = &args[2];
let path = format!("bot/models/burnrl_{algo}");
println!(
"info: loading configuration from file {:?}",
confy::get_configuration_file_path("trictrac_bot", None).unwrap()
);
let mut conf: Config = confy::load("trictrac_bot", None).expect("Could not load config");
conf.save_path = Some(path.clone());
println!("{conf}----------");
match algo.as_str() {
"dqn" => {
let _agent = dqn::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = dqn::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironment, _, _> =
burn_rl::agent::DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"dqn_big" => {
let _agent = dqn_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = dqn_big::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentBig, _, _> =
burn_rl::agent::DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"dqn_valid" => {
let _agent = dqn_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = dqn_valid::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::DQN<TrictracEnvironmentValid, _, _> =
burn_rl::agent::DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"sac" => {
let _agent = sac::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = sac::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironment, _, _> =
burn_rl::agent::SAC::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"sac_big" => {
let _agent = sac_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = sac_big::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentBig, _, _> =
burn_rl::agent::SAC::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"sac_valid" => {
let _agent = sac_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = sac_valid::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::SAC<TrictracEnvironmentValid, _, _> =
burn_rl::agent::SAC::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"ppo" => {
let _agent = ppo::run::<TrictracEnvironment, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironment, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"ppo_big" => {
let _agent = ppo_big::run::<TrictracEnvironmentBig, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo_big::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentBig, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
"ppo_valid" => {
let _agent = ppo_valid::run::<TrictracEnvironmentValid, Backend>(&conf, false);
println!("> Chargement du modèle pour test");
let loaded_model = ppo_valid::load_model(conf.dense_size, &path);
let loaded_agent: burn_rl::agent::PPO<TrictracEnvironmentValid, _, _> =
burn_rl::agent::PPO::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}
&_ => {
println!("unknown algo {algo}");
}
}
}

5
bot/src/burnrl/mod.rs Normal file
View file

@ -0,0 +1,5 @@
pub mod algos;
pub mod environment;
pub mod environment_big;
pub mod environment_valid;
pub mod utils;

132
bot/src/burnrl/utils.rs Normal file
View file

@ -0,0 +1,132 @@
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn_rl::base::{Agent, ElemType, Environment};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct Config {
pub save_path: Option<String>,
pub max_steps: usize, // max steps by episode
pub num_episodes: usize,
pub dense_size: usize, // neural network complexity
// discount factor; higher = favours long-term strategies
pub gamma: f32,
// soft update rate for the target network; lower = slower adaptation, less sensitive to lucky streaks
pub tau: f32,
// step size; too low = slow learning, too high = risk of never converging
pub learning_rate: f32,
// number of past experiences sampled to compute the mean error
pub batch_size: usize,
// maximum correction applied to the gradient (default 100)
pub clip_grad: f32,
// ---- for SAC
pub min_probability: f32,
// ---- for DQN
// epsilon initial value (0.9 => more exploration)
pub eps_start: f64,
pub eps_end: f64,
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
pub eps_decay: f64,
// ---- for PPO
pub lambda: f32,
pub epsilon_clip: f32,
pub critic_weight: f32,
pub entropy_weight: f32,
pub epochs: usize,
}
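// [Editor's sketch, not part of this diff] The epsilon schedule described in the
// comments above, written out as a helper. With the defaults below (eps_start 0.9,
// eps_end 0.05, eps_decay 1000.0) epsilon falls to roughly 0.36 once `step`
// reaches eps_decay, then keeps decaying towards eps_end.
fn epsilon_at(step: usize, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
    eps_end + (eps_start - eps_end) * (-(step as f64) / eps_decay).exp()
}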
impl Default for Config {
fn default() -> Self {
Self {
save_path: None,
max_steps: 2000,
num_episodes: 1000,
dense_size: 256,
gamma: 0.999,
tau: 0.005,
learning_rate: 0.001,
batch_size: 32,
clip_grad: 100.0,
min_probability: 1e-9,
eps_start: 0.9,
eps_end: 0.05,
eps_decay: 1000.0,
lambda: 0.95,
epsilon_clip: 0.2,
critic_weight: 0.5,
entropy_weight: 0.01,
epochs: 8,
}
}
}
impl std::fmt::Display for Config {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let mut s = String::new();
s.push_str(&format!("max_steps={:?}\n", self.max_steps));
s.push_str(&format!("num_episodes={:?}\n", self.num_episodes));
s.push_str(&format!("dense_size={:?}\n", self.dense_size));
s.push_str(&format!("eps_start={:?}\n", self.eps_start));
s.push_str(&format!("eps_end={:?}\n", self.eps_end));
s.push_str(&format!("eps_decay={:?}\n", self.eps_decay));
s.push_str(&format!("gamma={:?}\n", self.gamma));
s.push_str(&format!("tau={:?}\n", self.tau));
s.push_str(&format!("learning_rate={:?}\n", self.learning_rate));
s.push_str(&format!("batch_size={:?}\n", self.batch_size));
s.push_str(&format!("clip_grad={:?}\n", self.clip_grad));
s.push_str(&format!("min_probability={:?}\n", self.min_probability));
s.push_str(&format!("lambda={:?}\n", self.lambda));
s.push_str(&format!("epsilon_clip={:?}\n", self.epsilon_clip));
s.push_str(&format!("critic_weight={:?}\n", self.critic_weight));
s.push_str(&format!("entropy_weight={:?}\n", self.entropy_weight));
s.push_str(&format!("epochs={:?}\n", self.epochs));
write!(f, "{s}")
}
}
pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
let mut env = E::new(true);
let mut state = env.state();
let mut done = false;
while !done {
if let Some(action) = agent.react(&state) {
let snapshot = env.step(action);
state = *snapshot.state();
done = snapshot.done();
}
}
}
fn soft_update_tensor<const N: usize, B: Backend>(
this: &Param<Tensor<B, N>>,
that: &Param<Tensor<B, N>>,
tau: ElemType,
) -> Param<Tensor<B, N>> {
let that_weight = that.val();
let this_weight = this.val();
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
Param::initialized(ParamId::new(), new_weight)
}
pub fn soft_update_linear<B: Backend>(
this: Linear<B>,
that: &Linear<B>,
tau: ElemType,
) -> Linear<B> {
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
let bias = match (&this.bias, &that.bias) {
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
_ => None,
};
Linear::<B> { weight, bias }
}
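// [Editor's sketch, not part of this diff] soft_update_tensor/soft_update_linear
// implement a Polyak ("soft") target update: new_target = (1 - tau) * target + tau * online.
// With the default tau = 0.005 the target network moves only 0.5% of the way
// towards the online network per update, which keeps the target values stable.
#[cfg(test)]
mod soft_update_sketch {
    // Scalar version of the rule applied element-wise by soft_update_tensor.
    fn soft(target: f32, online: f32, tau: f32) -> f32 {
        target * (1.0 - tau) + online * tau
    }

    #[test]
    fn target_moves_slowly_towards_online() {
        let updated = soft(0.0, 1.0, 0.005);
        assert!((updated - 0.005).abs() < 1e-6);
    }
}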

View file

@ -1,53 +0,0 @@
use bot::dqn::burnrl::{
dqn_model, environment,
utils::{demo_model, load_model, save_model},
};
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN;
use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
fn main() {
// println!("> Entraînement");
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 40, // 40
min_steps: 500.0, // 1000 min of max steps by episode (mise à jour par la fonction)
max_steps: 3000, // 1000 max steps by episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
eps_decay: 2000.0, // 1000 ?
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
// plus lente moins sensible aux coups de chance
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
// converger
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
let path = "models/burn_dqn_40".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path);
println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}

View file

@ -1,3 +0,0 @@
pub mod dqn_model;
pub mod environment;
pub mod utils;

View file

@ -1,114 +0,0 @@
use crate::dqn::burnrl::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{Action, ElemType, Environment, State};
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}_model.mpk");
println!("Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
let model_path = format!("{path}_model.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
dqn_model::Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true);
let mut done = false;
while !done {
// let action = match infer_action(&agent, &env, state) {
let action = match infer_action(&agent, &env) {
Some(value) => value,
None => break,
};
// Execute action
let snapshot = env.step(action);
done = snapshot.done();
}
}
fn infer_action<B: Backend, M: DQNModel<B>>(
agent: &DQN<TrictracEnvironment, B, M>,
env: &TrictracEnvironment,
) -> Option<TrictracAction> {
let state = env.state();
// Get q-values
let q_values = agent
.model()
.as_ref()
.unwrap()
.infer(state.to_tensor().unsqueeze());
// Get valid actions
let valid_actions_indices = get_valid_action_indices(&env.game);
if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode
}
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
let action = TrictracAction::from(action_index);
Some(action)
}
fn soft_update_tensor<const N: usize, B: Backend>(
this: &Param<Tensor<B, N>>,
that: &Param<Tensor<B, N>>,
tau: ElemType,
) -> Param<Tensor<B, N>> {
let that_weight = that.val();
let this_weight = this.val();
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
Param::initialized(ParamId::new(), new_weight)
}
pub fn soft_update_linear<B: Backend>(
this: Linear<B>,
that: &Linear<B>,
tau: ElemType,
) -> Linear<B> {
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
let bias = match (&this.bias, &that.bias) {
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
_ => None,
};
Linear::<B> { weight, bias }
}

View file

@ -1,52 +0,0 @@
use bot::dqn::burnrl_valid::{
dqn_model, environment,
utils::{demo_model, load_model, save_model},
};
use burn::backend::{Autodiff, NdArray};
use burn_rl::agent::DQN;
use burn_rl::base::ElemType;
type Backend = Autodiff<NdArray<ElemType>>;
type Env = environment::TrictracEnvironment;
fn main() {
// println!("> Entraînement");
// See also MEMORY_SIZE in dqn_model.rs : 8192
let conf = dqn_model::DqnConfig {
// defaults
num_episodes: 100, // 40
max_steps: 1000, // 1000 max steps by episode
dense_size: 256, // 128 neural network complexity (default 128)
eps_start: 0.9, // 0.9 epsilon initial value (0.9 => more exploration)
eps_end: 0.05, // 0.05
// eps_decay higher = epsilon decrease slower
// used in : epsilon = eps_end + (eps_start - eps_end) * e^(-step / eps_decay);
// epsilon is updated at the start of each episode
eps_decay: 2000.0, // 1000 ?
gamma: 0.999, // 0.999 discount factor. Plus élevé = encourage stratégies à long terme
tau: 0.005, // 0.005 soft update rate. Taux de mise à jour du réseau cible. Plus bas = adaptation
// plus lente moins sensible aux coups de chance
learning_rate: 0.001, // 0.001 taille du pas. Bas : plus lent, haut : risque de ne jamais
// converger
batch_size: 32, // 32 nombre d'expériences passées sur lesquelles pour calcul de l'erreur moy.
clip_grad: 100.0, // 100 limite max de correction à apporter au gradient (default 100)
};
println!("{conf}----------");
let agent = dqn_model::run::<Env, Backend>(&conf, false); //true);
let valid_agent = agent.valid();
println!("> Sauvegarde du modèle de validation");
let path = "bot/models/burn_dqn_valid_40".to_string();
save_model(valid_agent.model().as_ref().unwrap(), &path);
println!("> Chargement du modèle pour test");
let loaded_model = load_model(conf.dense_size, &path);
let loaded_agent = DQN::new(loaded_model.unwrap());
println!("> Test avec le modèle chargé");
demo_model(loaded_agent);
}

View file

@ -1,3 +0,0 @@
pub mod dqn_model;
pub mod environment;
pub mod utils;

View file

@ -1,114 +0,0 @@
use crate::dqn::burnrl_valid::{
dqn_model,
environment::{TrictracAction, TrictracEnvironment},
};
use crate::dqn::dqn_common::get_valid_action_indices;
use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::module::{Module, Param, ParamId};
use burn::nn::Linear;
use burn::record::{CompactRecorder, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use burn_rl::agent::{DQNModel, DQN};
use burn_rl::base::{Action, ElemType, Environment, State};
pub fn save_model(model: &dqn_model::Net<NdArray<ElemType>>, path: &String) {
let recorder = CompactRecorder::new();
let model_path = format!("{path}_model.mpk");
println!("Modèle de validation sauvegardé : {model_path}");
recorder
.record(model.clone().into_record(), model_path.into())
.unwrap();
}
pub fn load_model(dense_size: usize, path: &String) -> Option<dqn_model::Net<NdArray<ElemType>>> {
let model_path = format!("{path}_model.mpk");
// println!("Chargement du modèle depuis : {model_path}");
CompactRecorder::new()
.load(model_path.into(), &NdArrayDevice::default())
.map(|record| {
dqn_model::Net::new(
<TrictracEnvironment as Environment>::StateType::size(),
dense_size,
<TrictracEnvironment as Environment>::ActionType::size(),
)
.load_record(record)
})
.ok()
}
pub fn demo_model<B: Backend, M: DQNModel<B>>(agent: DQN<TrictracEnvironment, B, M>) {
let mut env = TrictracEnvironment::new(true);
let mut done = false;
while !done {
// let action = match infer_action(&agent, &env, state) {
let action = match infer_action(&agent, &env) {
Some(value) => value,
None => break,
};
// Execute action
let snapshot = env.step(action);
done = snapshot.done();
}
}
fn infer_action<B: Backend, M: DQNModel<B>>(
agent: &DQN<TrictracEnvironment, B, M>,
env: &TrictracEnvironment,
) -> Option<TrictracAction> {
let state = env.state();
// Get q-values
let q_values = agent
.model()
.as_ref()
.unwrap()
.infer(state.to_tensor().unsqueeze());
// Get valid actions
let valid_actions_indices = get_valid_action_indices(&env.game);
if valid_actions_indices.is_empty() {
return None; // No valid actions, end of episode
}
// Set non valid actions q-values to lowest
let mut masked_q_values = q_values.clone();
let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
for (index, q_value) in q_values_vec.iter().enumerate() {
if !valid_actions_indices.contains(&index) {
masked_q_values = masked_q_values.clone().mask_fill(
masked_q_values.clone().equal_elem(*q_value),
f32::NEG_INFINITY,
);
}
}
// Get best action (highest q-value)
let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
let action = TrictracAction::from(action_index);
Some(action)
}
fn soft_update_tensor<const N: usize, B: Backend>(
this: &Param<Tensor<B, N>>,
that: &Param<Tensor<B, N>>,
tau: ElemType,
) -> Param<Tensor<B, N>> {
let that_weight = that.val();
let this_weight = this.val();
let new_weight = this_weight * (1.0 - tau) + that_weight * tau;
Param::initialized(ParamId::new(), new_weight)
}
pub fn soft_update_linear<B: Backend>(
this: Linear<B>,
that: &Linear<B>,
tau: ElemType,
) -> Linear<B> {
let weight = soft_update_tensor(&this.weight, &that.weight, tau);
let bias = match (&this.bias, &that.bias) {
(Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),
_ => None,
};
Linear::<B> { weight, bias }
}

View file

@ -1,5 +0,0 @@
pub mod burnrl;
pub mod dqn_common;
pub mod simple;
pub mod burnrl_valid;

View file

@ -1,4 +1,4 @@
use crate::dqn::dqn_common::TrictracAction;
use crate::training_common_big::TrictracAction;
use serde::{Deserialize, Serialize};
/// Configuration pour l'agent DQN
@ -151,4 +151,3 @@ impl SimpleNeuralNetwork {
Ok(network)
}
}

View file

@ -6,7 +6,7 @@ use std::collections::VecDeque;
use store::{GameEvent, MoveRules, PointsRules, Stage, TurnStage};
use super::dqn_model::{DqnConfig, SimpleNeuralNetwork};
use crate::dqn::dqn_common::{get_valid_actions, TrictracAction};
use crate::training_common_big::{get_valid_actions, TrictracAction};
/// Expérience pour le buffer de replay
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -55,6 +55,10 @@ impl ReplayBuffer {
batch
}
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
pub fn len(&self) -> usize {
self.buffer.len()
}
@ -457,7 +461,7 @@ impl DqnTrainer {
save_every: usize,
model_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
println!("Démarrage de l'entraînement DQN pour {} épisodes", episodes);
println!("Démarrage de l'entraînement DQN pour {episodes} épisodes");
for episode in 1..=episodes {
let reward = self.train_episode();
@ -474,16 +478,16 @@ impl DqnTrainer {
}
if episode % save_every == 0 {
let save_path = format!("{}_episode_{}.json", model_path, episode);
let save_path = format!("{model_path}_episode_{episode}.json");
self.agent.save_model(&save_path)?;
println!("Modèle sauvegardé : {}", save_path);
println!("Modèle sauvegardé : {save_path}");
}
}
// Sauvegarder le modèle final
let final_path = format!("{}_final.json", model_path);
let final_path = format!("{model_path}_final.json");
self.agent.save_model(&final_path)?;
println!("Modèle final sauvegardé : {}", final_path);
println!("Modèle final sauvegardé : {final_path}");
Ok(())
}

View file

@ -1,6 +1,6 @@
use bot::dqn::dqn_common::TrictracAction;
use bot::dqn::simple::dqn_model::DqnConfig;
use bot::dqn::simple::dqn_trainer::DqnTrainer;
use bot::dqn_simple::dqn_model::DqnConfig;
use bot::dqn_simple::dqn_trainer::DqnTrainer;
use bot::training_common::TrictracAction;
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -60,9 +60,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
std::fs::create_dir_all("models")?;
println!("Configuration d'entraînement DQN :");
println!(" Épisodes : {}", episodes);
println!(" Chemin du modèle : {}", model_path);
println!(" Sauvegarde tous les {} épisodes", save_every);
println!(" Épisodes : {episodes}");
println!(" Chemin du modèle : {model_path}");
println!(" Sauvegarde tous les {save_every} épisodes");
println!();
// Configuration DQN
@ -85,10 +85,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Entraînement terminé avec succès !");
println!("Pour utiliser le modèle entraîné :");
println!(
" cargo run --bin=client_cli -- --bot dqn:{}_final.json,dummy",
model_path
);
println!(" cargo run --bin=client_cli -- --bot dqn:{model_path}_final.json,dummy");
Ok(())
}


@ -1,7 +1,11 @@
pub mod dqn;
pub mod burnrl;
pub mod dqn_simple;
pub mod strategy;
pub mod training_common;
pub mod training_common_big;
pub mod trictrac_board;
use log::{debug, error};
use log::debug;
use store::{CheckerMove, Color, GameEvent, GameState, PlayerId, PointsRules, Stage, TurnStage};
pub use strategy::default::DefaultStrategy;
pub use strategy::dqn::DqnStrategy;


@ -3,8 +3,8 @@ use log::info;
use std::path::Path;
use store::MoveRules;
use crate::dqn::dqn_common::{get_valid_actions, sample_valid_action, TrictracAction};
use crate::dqn::simple::dqn_model::SimpleNeuralNetwork;
use crate::dqn_simple::dqn_model::SimpleNeuralNetwork;
use crate::training_common_big::{get_valid_actions, sample_valid_action, TrictracAction};
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
#[derive(Debug)]


@ -6,10 +6,11 @@ use crate::{BotStrategy, CheckerMove, Color, GameState, PlayerId};
use log::info;
use store::MoveRules;
use crate::dqn::burnrl::{dqn_model, environment, utils};
use crate::dqn::dqn_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
use crate::burnrl::algos::dqn;
use crate::burnrl::environment;
use crate::training_common::{get_valid_action_indices, sample_valid_action, TrictracAction};
type DqnBurnNetwork = dqn_model::Net<NdArray<ElemType>>;
type DqnBurnNetwork = dqn::Net<NdArray<ElemType>>;
/// Stratégie DQN pour le bot - ne fait que charger et utiliser un modèle pré-entraîné
#[derive(Debug)]
@ -39,7 +40,7 @@ impl DqnBurnStrategy {
pub fn new_with_model(model_path: &String) -> Self {
info!("Loading model {model_path:?}");
let mut strategy = Self::new();
strategy.model = utils::load_model(256, model_path);
strategy.model = dqn::load_model(256, model_path);
strategy
}
@ -117,8 +118,8 @@ impl BotStrategy for DqnBurnStrategy {
// Utiliser le DQN pour choisir le mouvement
if let Some(TrictracAction::Move {
dice_order,
from1,
from2,
checker1,
checker2,
}) = self.get_dqn_action()
{
let dicevals = self.game.dice.values;
@ -128,23 +129,65 @@ impl BotStrategy for DqnBurnStrategy {
(dicevals.1, dicevals.0)
};
assert_eq!(self.color, Color::White);
let from1 = self
.game
.board
.get_checker_field(&self.color, checker1 as u8)
.unwrap_or(0);
if from1 == 0 {
// empty move
dice1 = 0;
}
let mut to1 = from1 + dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
let mut to1 = from1;
if self.color == Color::White {
to1 += dice1 as usize;
if 24 < to1 {
// sortie
to1 = 0;
}
} else {
let fto1 = to1 as i16 - dice1 as i16;
to1 = if fto1 < 0 { 0 } else { fto1 as usize };
}
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = self.game.board.clone();
let move_res = tmp_board.move_checker(&self.color, checker_move1);
if move_res.is_err() {
panic!("could not move {move_res:?}");
}
let from2 = tmp_board
.get_checker_field(&self.color, checker2 as u8)
.unwrap_or(0);
if from2 == 0 {
// empty move
dice2 = 0;
}
let mut to2 = from2 + dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
let mut to2 = from2;
if self.color == Color::White {
to2 += dice2 as usize;
if 24 < to2 {
// sortie
to2 = 0;
}
} else {
let fto2 = to2 as i16 - dice2 as i16;
to2 = if fto2 < 0 { 0 } else { fto2 as usize };
}
// Gestion prise de coin par puissance
let opp_rest_field = if self.color == Color::White { 13 } else { 12 };
if to1 == opp_rest_field && to2 == opp_rest_field {
if self.color == Color::White {
to1 -= 1;
to2 -= 1;
} else {
to1 += 1;
to2 += 1;
}
}
let checker_move1 = CheckerMove::new(from1, to1).unwrap_or_default();
@ -153,6 +196,7 @@ impl BotStrategy for DqnBurnStrategy {
let chosen_move = if self.color == Color::White {
(checker_move1, checker_move2)
} else {
// XXX : really ?
(checker_move1.mirror(), checker_move2.mirror())
};


@ -66,14 +66,14 @@ impl StableBaselines3Strategy {
// Remplir les positions des pièces blanches (valeurs positives)
for (pos, count) in self.game.board.get_color_fields(Color::White) {
if pos < 24 {
board[pos] = count as i8;
board[pos] = count;
}
}
// Remplir les positions des pièces noires (valeurs négatives)
for (pos, count) in self.game.board.get_color_fields(Color::Black) {
if pos < 24 {
board[pos] = -(count as i8);
board[pos] = -count;
}
}
@ -270,4 +270,3 @@ impl BotStrategy for StableBaselines3Strategy {
}
}
}

bot/src/training_common.rs (new file, 351 lines)

@ -0,0 +1,351 @@
use std::cmp::{max, min};
use std::fmt::{Debug, Display, Formatter};
use serde::{Deserialize, Serialize};
use store::{CheckerMove, GameEvent, GameState};
/// Types d'actions possibles dans le jeu
#[derive(Debug, Copy, Clone, Eq, Serialize, Deserialize, PartialEq)]
pub enum TrictracAction {
/// Lancer les dés
Roll,
/// Continuer après avoir gagné un trou
Go,
/// Effectuer un mouvement de pions
Move {
dice_order: bool, // true = utiliser dice[0] en premier, false = dice[1] en premier
checker1: usize, // premier pion à déplacer en numérotant depuis la colonne de départ (0-15) 0 : aucun pion
checker2: usize, // deuxième pion (0-15)
},
// Marquer les points : à activer si support des écoles
// Mark,
}
impl Display for TrictracAction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let s = format!("{self:?}");
writeln!(f, "{}", s.chars().rev().collect::<String>())?;
Ok(())
}
}
impl TrictracAction {
/// Encode une action en index pour le réseau de neurones
pub fn to_action_index(&self) -> usize {
match self {
TrictracAction::Roll => 0,
TrictracAction::Go => 1,
TrictracAction::Move {
dice_order,
checker1,
checker2,
} => {
// Encoder les mouvements dans l'espace d'actions
// Indices 2+ pour les mouvements
// de 2 à 513 (2 à 257 pour dé 1 en premier, 258 à 513 pour dé 2 en premier)
let mut start = 2;
if !dice_order {
// 16 * 16 = 256
start += 256;
}
start + checker1 * 16 + checker2
} // TrictracAction::Mark => 514,
}
}
pub fn to_event(&self, state: &GameState) -> Option<GameEvent> {
match self {
TrictracAction::Roll => {
// Lancer les dés
Some(GameEvent::Roll {
player_id: state.active_player_id,
})
}
// TrictracAction::Mark => {
// // Marquer des points
// let points = self.game.
// Some(GameEvent::Mark {
// player_id: self.active_player_id,
// points,
// })
// }
TrictracAction::Go => {
// Continuer après avoir gagné un trou
Some(GameEvent::Go {
player_id: state.active_player_id,
})
}
TrictracAction::Move {
dice_order,
checker1,
checker2,
} => {
// Effectuer un mouvement
let (dice1, dice2) = if *dice_order {
(state.dice.values.0, state.dice.values.1)
} else {
(state.dice.values.1, state.dice.values.0)
};
let color = &store::Color::White;
let from1 = state
.board
.get_checker_field(color, *checker1 as u8)
.unwrap_or(0);
let mut to1 = from1 + dice1 as usize;
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let mut tmp_board = state.board.clone();
let move_result = tmp_board.move_checker(color, checker_move1);
if move_result.is_err() {
None
// panic!("Error while moving checker {move_result:?}")
} else {
let from2 = tmp_board
.get_checker_field(color, *checker2 as u8)
.unwrap_or(0);
let mut to2 = from2 + dice2 as usize;
// Gestion prise de coin par puissance
let opp_rest_field = 13;
if to1 == opp_rest_field && to2 == opp_rest_field {
to1 -= 1;
to2 -= 1;
}
let checker_move1 = store::CheckerMove::new(from1, to1).unwrap_or_default();
let checker_move2 = store::CheckerMove::new(from2, to2).unwrap_or_default();
Some(GameEvent::Move {
player_id: state.active_player_id,
moves: (checker_move1, checker_move2),
})
}
}
}
}
/// Décode un index d'action en TrictracAction
pub fn from_action_index(index: usize) -> Option<TrictracAction> {
match index {
0 => Some(TrictracAction::Roll),
1 => Some(TrictracAction::Go),
// 514 => Some(TrictracAction::Mark),
i if i >= 2 => {
let move_code = i - 2;
let (dice_order, checker1, checker2) = Self::decode_move(move_code);
Some(TrictracAction::Move {
dice_order,
checker1,
checker2,
})
}
_ => None,
}
}
/// Décode un entier en paire de mouvements
fn decode_move(code: usize) -> (bool, usize, usize) {
let mut encoded = code;
let dice_order = code < 256;
if !dice_order {
encoded -= 256
}
let checker1 = encoded / 16;
let checker2 = encoded % 16;
(dice_order, checker1, checker2)
}
/// Retourne la taille de l'espace d'actions total
pub fn action_space_size() -> usize {
// 1 (Roll) + 1 (Go) + mouvements possibles
// Pour les mouvements : 2*25*25 = 1250 (choix du dé + position 0-24 pour chaque from)
// Mais on peut optimiser en limitant aux positions valides (1-24)
2 + (2 * 16 * 16) // = 514
}
// pub fn to_game_event(&self, player_id: PlayerId, dice: Dice) -> GameEvent {
// match action {
// TrictracAction::Roll => Some(GameEvent::Roll { player_id }),
// TrictracAction::Mark => Some(GameEvent::Mark { player_id, points }),
// TrictracAction::Go => Some(GameEvent::Go { player_id }),
// TrictracAction::Move {
// dice_order,
// from1,
// from2,
// } => {
// // Effectuer un mouvement
// let checker_move1 = store::CheckerMove::new(move1.0, move1.1).unwrap_or_default();
// let checker_move2 = store::CheckerMove::new(move2.0, move2.1).unwrap_or_default();
//
// Some(GameEvent::Move {
// player_id: self.agent_player_id,
// moves: (checker_move1, checker_move2),
// })
// }
// };
// }
}
/// Obtient les actions valides pour l'état de jeu actuel
pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
use store::TurnStage;
let mut valid_actions = Vec::new();
let active_player_id = game_state.active_player_id;
let player_color = game_state.player_color_by_id(&active_player_id);
if let Some(color) = player_color {
match game_state.turn_stage {
TurnStage::RollDice => {
valid_actions.push(TrictracAction::Roll);
}
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
// valid_actions.push(TrictracAction::Mark);
panic!(
"get_valid_actions not implemented for turn stage {:?}",
game_state.turn_stage
);
}
TurnStage::HoldOrGoChoice => {
valid_actions.push(TrictracAction::Go);
// Ajoute aussi les mouvements possibles
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
assert_eq!(color, store::Color::White);
for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state,
));
}
}
TurnStage::Move => {
let rules = store::MoveRules::new(&color, &game_state.board, game_state.dice);
let possible_moves = rules.get_possible_moves_sequences(true, vec![]);
// Modififier checker_moves_to_trictrac_action si on doit gérer Black
assert_eq!(color, store::Color::White);
for (move1, move2) in possible_moves {
valid_actions.push(checker_moves_to_trictrac_action(
&move1, &move2, &color, game_state,
));
}
}
}
}
if valid_actions.is_empty() {
panic!("empty valid_actions for state {game_state}");
}
valid_actions
}
// Valid only for White player
fn checker_moves_to_trictrac_action(
move1: &CheckerMove,
move2: &CheckerMove,
color: &store::Color,
state: &crate::GameState,
) -> TrictracAction {
let to1 = move1.get_to();
let to2 = move2.get_to();
let from1 = move1.get_from();
let from2 = move2.get_from();
let dice = state.dice;
let mut diff_move1 = if to1 > 0 {
// Mouvement sans sortie
to1 - from1
} else {
// sortie, on utilise la valeur du dé
if to2 > 0 {
// sortie pour le mouvement 1 uniquement
let dice2 = to2 - from2;
if dice2 == dice.values.0 as usize {
dice.values.1 as usize
} else {
dice.values.0 as usize
}
} else {
// double sortie
if from1 < from2 {
max(dice.values.0, dice.values.1) as usize
} else {
min(dice.values.0, dice.values.1) as usize
}
}
};
// modification de diff_move1 si on est dans le cas d'un mouvement par puissance
let rest_field = 12;
if to1 == rest_field
&& to2 == rest_field
&& max(dice.values.0 as usize, dice.values.1 as usize) + min(from1, from2) != rest_field
{
// prise par puissance
diff_move1 += 1;
}
let dice_order = diff_move1 == dice.values.0 as usize;
let checker1 = state.board.get_field_checker(color, from1) as usize;
let mut tmp_board = state.board.clone();
// should not raise an error for a valid action
let move_res = tmp_board.move_checker(color, *move1);
if move_res.is_err() {
panic!("error while moving checker {move_res:?}");
}
let checker2 = tmp_board.get_field_checker(color, from2) as usize;
TrictracAction::Move {
dice_order,
checker1,
checker2,
}
}
/// Retourne les indices des actions valides
pub fn get_valid_action_indices(game_state: &crate::GameState) -> Vec<usize> {
get_valid_actions(game_state)
.into_iter()
.map(|action| action.to_action_index())
.collect()
}
/// Sélectionne une action valide aléatoire
pub fn sample_valid_action(game_state: &crate::GameState) -> Option<TrictracAction> {
use rand::{seq::SliceRandom, thread_rng};
let valid_actions = get_valid_actions(game_state);
let mut rng = thread_rng();
valid_actions.choose(&mut rng).cloned()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn to_action_index() {
let action = TrictracAction::Move {
dice_order: true,
checker1: 3,
checker2: 4,
};
let index = action.to_action_index();
assert_eq!(Some(action), TrictracAction::from_action_index(index));
assert_eq!(54, index);
}
#[test]
fn from_action_index() {
let action = TrictracAction::Move {
dice_order: true,
checker1: 3,
checker2: 4,
};
assert_eq!(Some(action), TrictracAction::from_action_index(54));
}
}
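For reference, the encoding above packs every action into a 514-slot space: index 0 is `Roll`, index 1 is `Go`, and moves occupy 2..=513 with a 256-slot block per dice order. A minimal round-trip sketch, assuming the `bot` crate layout declared in lib.rs:

```rust
use bot::training_common::TrictracAction;

fn main() {
    let action = TrictracAction::Move {
        dice_order: false, // second die played first => offset by 256
        checker1: 3,
        checker2: 4,
    };
    let index = action.to_action_index(); // 2 + 256 + 3 * 16 + 4
    assert_eq!(index, 310);
    assert_eq!(Some(action), TrictracAction::from_action_index(index));
    assert_eq!(TrictracAction::action_space_size(), 514);
}
```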


@ -117,10 +117,14 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
if let Some(color) = player_color {
match game_state.turn_stage {
TurnStage::RollDice | TurnStage::RollWaiting => {
TurnStage::RollDice => {
valid_actions.push(TrictracAction::Roll);
}
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
TurnStage::MarkPoints | TurnStage::MarkAdvPoints | TurnStage::RollWaiting => {
panic!(
"get_valid_actions not implemented for turn stage {:?}",
game_state.turn_stage
);
// valid_actions.push(TrictracAction::Mark);
}
TurnStage::HoldOrGoChoice => {
@ -157,6 +161,9 @@ pub fn get_valid_actions(game_state: &crate::GameState) -> Vec<TrictracAction> {
}
}
if valid_actions.is_empty() {
panic!("empty valid_actions for state {game_state}");
}
valid_actions
}

bot/src/trictrac_board.rs (new file, 149 lines)

@ -0,0 +1,149 @@
// https://docs.rs/board-game/ implementation
use crate::training_common::{get_valid_actions, TrictracAction};
use board_game::board::{
Board as BoardGameBoard, BoardDone, BoardMoves, Outcome, PlayError, Player as BoardGamePlayer,
};
use board_game::impl_unit_symmetry_board;
use internal_iterator::InternalIterator;
use std::fmt;
use std::ops::ControlFlow;
use store::Color;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TrictracBoard(crate::GameState);
impl Default for TrictracBoard {
fn default() -> Self {
TrictracBoard(crate::GameState::new_with_players("white", "black"))
}
}
impl fmt::Display for TrictracBoard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl_unit_symmetry_board!(TrictracBoard);
impl BoardGameBoard for TrictracBoard {
// impl TrictracBoard {
type Move = TrictracAction;
fn next_player(&self) -> BoardGamePlayer {
self.0
.who_plays()
.map(|p| {
if p.color == Color::Black {
BoardGamePlayer::B
} else {
BoardGamePlayer::A
}
})
.unwrap_or(BoardGamePlayer::A)
}
fn is_available_move(&self, mv: Self::Move) -> Result<bool, BoardDone> {
self.check_done()?;
let is_valid = mv
.to_event(&self.0)
.map(|evt| self.0.validate(&evt))
.unwrap_or(false);
Ok(is_valid)
}
fn play(&mut self, mv: Self::Move) -> Result<(), PlayError> {
self.check_can_play(mv)?;
self.0.consume(&mv.to_event(&self.0).unwrap());
Ok(())
}
fn outcome(&self) -> Option<Outcome> {
if self.0.stage == crate::Stage::Ended {
self.0.determine_winner().map(|player_id| {
Outcome::WonBy(if player_id == 1 {
BoardGamePlayer::A
} else {
BoardGamePlayer::B
})
})
} else {
None
}
}
fn can_lose_after_move() -> bool {
true
}
}
impl<'a> BoardMoves<'a, TrictracBoard> for TrictracBoard {
type AllMovesIterator = TrictracAllMovesIterator;
type AvailableMovesIterator = TrictracAvailableMovesIterator<'a>;
fn all_possible_moves() -> Self::AllMovesIterator {
TrictracAllMovesIterator::default()
}
fn available_moves(&'a self) -> Result<Self::AvailableMovesIterator, BoardDone> {
TrictracAvailableMovesIterator::new(self)
}
}
#[derive(Debug, Clone)]
pub struct TrictracAllMovesIterator;
impl Default for TrictracAllMovesIterator {
fn default() -> Self {
TrictracAllMovesIterator
}
}
impl InternalIterator for TrictracAllMovesIterator {
type Item = TrictracAction;
fn try_for_each<R, F: FnMut(Self::Item) -> ControlFlow<R>>(self, mut f: F) -> ControlFlow<R> {
f(TrictracAction::Roll)?;
f(TrictracAction::Go)?;
for dice_order in [false, true] {
for checker1 in 0..16 {
for checker2 in 0..16 {
f(TrictracAction::Move {
dice_order,
checker1,
checker2,
})?;
}
}
}
ControlFlow::Continue(())
}
}
#[derive(Debug, Clone)]
pub struct TrictracAvailableMovesIterator<'a> {
board: &'a TrictracBoard,
}
impl<'a> TrictracAvailableMovesIterator<'a> {
pub fn new(board: &'a TrictracBoard) -> Result<Self, BoardDone> {
board.check_done()?;
Ok(TrictracAvailableMovesIterator { board })
}
pub fn board(&self) -> &'a TrictracBoard {
self.board
}
}
impl InternalIterator for TrictracAvailableMovesIterator<'_> {
type Item = TrictracAction;
fn try_for_each<R, F>(self, f: F) -> ControlFlow<R>
where
F: FnMut(Self::Item) -> ControlFlow<R>,
{
get_valid_actions(&self.board.0).into_iter().try_for_each(f)
}
}
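With the `board_game` traits implemented, the wrapped `GameState` can be driven through the generic `Board` API. A minimal sketch, assuming `trictrac_board` and `training_common` are exported from the `bot` crate as declared in lib.rs above:

```rust
use board_game::board::Board;
use bot::training_common::TrictracAction;
use bot::trictrac_board::TrictracBoard;

fn main() {
    let board = TrictracBoard::default();
    // A fresh game is expected to start with the active player rolling the dice.
    let can_roll = board.is_available_move(TrictracAction::Roll);
    println!("next player: {:?}", board.next_player());
    println!("roll available: {can_roll:?}");
    println!("outcome: {:?}", board.outcome());
}
```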


@ -59,7 +59,7 @@ impl App {
}
s if s.starts_with("dqnburn:") => {
let path = s.trim_start_matches("dqnburn:");
Some(Box::new(DqnBurnStrategy::new_with_model(&format!("{path}")))
Some(Box::new(DqnBurnStrategy::new_with_model(&path.to_string()))
as Box<dyn BotStrategy>)
}
_ => None,
@ -114,7 +114,7 @@ impl App {
pub fn show_history(&self) {
for hist in self.game.state.history.iter() {
println!("{:?}\n", hist);
println!("{hist:?}\n");
}
}
@ -139,6 +139,9 @@ impl App {
// &self.game.state.board,
// dice,
// );
self.game.handle_event(&GameEvent::Roll {
player_id: self.game.player_id.unwrap(),
});
self.game.handle_event(&GameEvent::RollResult {
player_id: self.game.player_id.unwrap(),
dice,
@ -189,7 +192,7 @@ impl App {
return;
}
}
println!("invalid move : {}", input);
println!("invalid move : {input}");
}
pub fn display(&mut self) -> String {
@ -329,6 +332,7 @@ Player :: holes :: points
seed: Some(1327),
bot: Some("dummy".into()),
});
println!("avant : {}", app.display());
app.input("roll");
app.input("1 3");
app.input("1 4");


@ -77,7 +77,7 @@ impl GameRunner {
} else {
debug!("{}", self.state);
error!("event not valid : {event:?}");
panic!("crash and burn");
// panic!("crash and burn {} \nevt not valid {event:?}", self.state);
&GameEvent::PlayError
};


@ -35,7 +35,7 @@ fn main() -> Result<()> {
let args = match parse_args() {
Ok(v) => v,
Err(e) => {
eprintln!("Error: {}.", e);
eprintln!("Error: {e}.");
std::process::exit(1);
}
};
@ -63,7 +63,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
// Help has a higher priority and should be handled separately.
if pargs.contains(["-h", "--help"]) {
print!("{}", HELP);
print!("{HELP}");
std::process::exit(0);
}
@ -78,7 +78,7 @@ fn parse_args() -> Result<AppArgs, pico_args::Error> {
// It's up to the caller what to do with the remaining arguments.
let remaining = pargs.finish();
if !remaining.is_empty() {
eprintln!("Warning: unused arguments left: {:?}.", remaining);
eprintln!("Warning: unused arguments left: {remaining:?}.");
}
Ok(args)


@ -1,4 +1,4 @@
# Project description and question
# Project description
I am developing a TricTrac game (<https://fr.wikipedia.org/wiki/Trictrac>) in the Rust language.
For now I am focusing on the simple command-line application, so ignore the 'client_bevy', 'client_tui', and 'server' directories, which will only be used in later iterations.
@ -12,35 +12,8 @@ Plus précisément, l'état du jeu est défini par le struct GameState dans stor
'bot/src/strategy/default.rs' contains the code of a basic bot strategy: it builds the list of valid moves (with the get_possible_moves_sequences method of store::MoveRules) and simply plays the first one in the list.
I now want to add stronger bot strategies by training an agent/bot with reinforcement learning.
I use the burn library (<https://burn.dev/>).
A first DQN version works (trained with `cargo run --bin=train_dqn`).
It systematically beats the default 'dummy' bot: `cargo run --bin=client_cli -- --bot dqn:./models/dqn_model_final.json,dummy`.
A version using the DQN algorithm can be launched with `cargo run --bin=burn_train -- dqn`. It runs the training, saves the resulting model, then reloads it from disk to test the agent. Training happens in the 'run' function of bot/src/burnrl/dqn_model.rs, model saving in the 'save_model' function and loading in the 'load_model' function.
Another version, still DQN but using the burn library (<https://burn.dev/>), is under development.
The model is trained in the "main" function of bot/src/burnrl/main.rs. It can be run with 'just trainbot'.
Here is the output of a training run launched with 'just trainbot':
```
> Entraînement
> {"episode": 0, "reward": -1692.3148, "duration": 1000}
> {"episode": 1, "reward": -361.6962, "duration": 1000}
> {"episode": 2, "reward": -126.1013, "duration": 1000}
> {"episode": 3, "reward": -36.8000, "duration": 1000}
> {"episode": 4, "reward": -21.4997, "duration": 1000}
> {"episode": 5, "reward": -8.3000, "duration": 1000}
> {"episode": 6, "reward": 3.1000, "duration": 1000}
> {"episode": 7, "reward": -21.5998, "duration": 1000}
> {"episode": 8, "reward": -10.1999, "duration": 1000}
> {"episode": 9, "reward": 3.1000, "duration": 1000}
> {"episode": 10, "reward": 14.5002, "duration": 1000}
> {"episode": 11, "reward": 10.7000, "duration": 1000}
> {"episode": 12, "reward": -0.7000, "duration": 1000}
thread 'main' has overflowed its stack
fatal runtime error: stack overflow
error: Recipe `trainbot` was terminated on line 25 by signal 6
```
After the 12th episode (more than 6 hours on my machine), training stops with a stack overflow error. Can you help me figure out where the problem comes from? Are there tools that can identify the code paths that use the most stack? For reference, I saw this bug report <https://github.com/yunjhongwu/burn-rl-examples/issues/40>, so the problem may come from the 'burn-rl' package.
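One cheap check before deeper profiling: run the training loop on a thread with an explicitly enlarged stack and see whether the overflow disappears (the `RUST_MIN_STACK` environment variable only applies to newly spawned threads, not to the main thread). A sketch, where `run_training` and the 256 MiB figure are placeholders:

```rust
use std::thread;

fn main() {
    // Placeholder standing in for the existing training entry point.
    fn run_training() { /* ... */ }

    let handle = thread::Builder::new()
        .name("trainbot".into())
        .stack_size(256 * 1024 * 1024) // 256 MiB, example value only
        .spawn(run_training)
        .expect("failed to spawn training thread");
    handle.join().expect("training thread panicked");
}
```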
I am trying to do the same for the PPO (file bot/src/burnrl/ppo_model.rs) and SAC (file bot/src/burnrl/sac_model.rs) algorithms: the 'run' functions are implemented, but not the 'save_model' and 'load_model' functions. Can you implement them?
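For the missing `save_model` / `load_model`, a hedged sketch (not a drop-in implementation) following the record-based save/load pattern shown in the burn book (`CompactRecorder`, `Module::save_file` / `Module::load_file`); `PpoNet` is a hypothetical stand-in for the actual network type, and the exact signatures depend on the burn version pinned in Cargo.toml:

```rust
use burn::module::Module;
use burn::record::CompactRecorder;
use burn::tensor::backend::Backend;

// `PpoNet<B>` is a placeholder for the real network defined in ppo_model.rs.
pub fn save_model<B: Backend>(model: PpoNet<B>, path: &str) {
    model
        .save_file(path, &CompactRecorder::new())
        .expect("failed to save model record");
}

pub fn load_model<B: Backend>(model: PpoNet<B>, path: &str, device: &B::Device) -> PpoNet<B> {
    model
        .load_file(path, &CompactRecorder::new(), device)
        .expect("failed to load model record")
}
```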


@ -1,46 +1,52 @@
# Inspirations
tools
- config clippy ?
- bacon : tests runner (ou loom ?)
- config clippy ?
- bacon : tests runner (ou loom ?)
## Rust libs
cf. https://blessed.rs/crates
cf. <https://blessed.rs/crates>
nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-repeat-resume/
nombres aléatoires avec seed : <https://richard.dallaway.com/posts/2021-01-04-repeat-resume/>
- cli : https://lib.rs/crates/pico-args ( ou clap )
- cli : <https://lib.rs/crates/pico-args> ( ou clap )
- reseau async : tokio
- web serveur : axum (uses tokio)
- https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling
- <https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling>
- db : sqlx
- eyre, color-eyre (Results)
- tracing (logging)
- rayon ( sync <-> parallel )
- front : yew + tauri
- front : yew + tauri
- egui
- https://docs.rs/board-game/latest/board_game/
- <https://docs.rs/board-game/latest/board_game/>
## network games
- <https://www.mattkeeter.com/projects/pont/>
- <https://github.com/jackadamson/onitama> (wasm, rooms)
- <https://github.com/UkoeHB/renet2>
## Others
- plugins avec https://github.com/extism/extism
- plugins avec <https://github.com/extism/extism>
## Backgammon existing projects
* go : https://bgammon.org/blog/20240101-hello-world/
- protocole de communication : https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md
* ocaml : https://github.com/jacobhilton/backgammon?tab=readme-ov-file
cli example : https://www.jacobh.co.uk/backgammon/
* lib rust backgammon
- https://github.com/carlostrub/backgammon
- https://github.com/marktani/backgammon
* network webtarot
* front ?
- go : <https://bgammon.org/blog/20240101-hello-world/>
- protocole de communication : <https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md>
- ocaml : <https://github.com/jacobhilton/backgammon?tab=readme-ov-file>
cli example : <https://www.jacobh.co.uk/backgammon/>
- lib rust backgammon
- <https://github.com/carlostrub/backgammon>
- <https://github.com/marktani/backgammon>
- network webtarot
- front ?
## cli examples
@ -48,7 +54,7 @@ nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-re
(No game) new game
gnubg rolls 3, anthon rolls 1.
GNU Backgammon Positions ID: 4HPwATDgc/ABMA
Match ID : MIEFAAAAAAAA
+12-11-10--9--8--7-------6--5--4--3--2--1-+ O: gnubg
@ -64,7 +70,7 @@ nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-re
| O X | | X O |
| O X | | X O | 0 points
+13-14-15-16-17-18------19-20-21-22-23-24-+ X: anthon
gnubg moves 8/5 6/5.
### jacobh
@ -72,33 +78,37 @@ nombres aléatoires avec seed : https://richard.dallaway.com/posts/2021-01-04-re
Move 11: player O rolls a 6-2.
Player O estimates that they have a 90.6111% chance of winning.
Os borne off: none
24 23 22 21 20 19 18 17 16 15 14 13
-------------------------------------------------------------------
| v v v v v v | | v v v v v v |
| | | |
| X O O O | | O O O |
| X O O O | | O O |
| O | | |
| | X | |
| | | |
| | | |
| | | |
| | | |
|------------------------------| |------------------------------|
| | | |
| | | |
| | | |
| | | |
| X | | |
| X X | | X |
| X X X | | X O |
| X X X | | X O O |
| | | |
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
-------------------------------------------------------------------
1 2 3 4 5 6 7 8 9 10 11 12
Xs borne off: none
Os borne off: none
24 23 22 21 20 19 18 17 16 15 14 13
---
| v v v v v v | | v v v v v v |
| | | |
| X O O O | | O O O |
| X O O O | | O O |
| O | | |
| | X | |
| | | |
| | | |
| | | |
| | | |
|------------------------------| |------------------------------|
| | | |
| | | |
| | | |
| | | |
| X | | |
| X X | | X |
| X X X | | X O |
| X X X | | X O O |
| | | |
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
---
1 2 3 4 5 6 7 8 9 10 11 12
Xs borne off: none
Move 12: player X rolls a 6-3.
Your move (? for help): bar/22
@ -107,13 +117,12 @@ Your move (? for help): ?
Enter the start and end positions, separated by a forward slash (or any non-numeric character), of each counter you want to move.
Each position should be number from 1 to 24, "bar" or "off".
Unlike in standard notation, you should enter each counter movement individually. For example:
24/18 18/13
bar/3 13/10 13/10 8/5
2/off 1/off
24/18 18/13
bar/3 13/10 13/10 8/5
2/off 1/off
You can also enter these commands:
p - show the previous move
n - show the next move
<enter> - toggle between showing the current and last moves
help - show this help text
quit - abandon game
p - show the previous move
n - show the next move
<enter> - toggle between showing the current and last moves
help - show this help text
quit - abandon game

doc/store.puml (new file, 172 lines)

@ -0,0 +1,172 @@
@startuml
class "CheckerMove" {
- from: Field
- to: Field
+ to_display_string()
+ new(from: Field, to: Field)
+ mirror()
+ chain(cmove: Self)
+ get_from()
+ get_to()
+ is_exit()
+ doable_with_dice(dice: usize)
}
class "Board" {
- positions: [i8;24]
+ new()
+ mirror()
+ set_positions(positions: [ i8 ; 24 ])
+ count_checkers(color: Color, from: Field, to: Field)
+ to_vec()
+ to_gnupg_pos_id()
+ to_display_grid(col_size: usize)
+ set(color: & Color, field: Field, amount: i8)
+ blocked(color: & Color, field: Field)
+ passage_blocked(color: & Color, field: Field)
+ get_field_checkers(field: Field)
+ get_checkers_color(field: Field)
+ is_field_in_small_jan(field: Field)
+ get_color_fields(color: Color)
+ get_color_corner(color: & Color)
+ get_possible_moves(color: Color, dice: u8, with_excedants: bool, check_rest_corner_exit: bool, forbid_exits: bool)
+ passage_possible(color: & Color, cmove: & CheckerMove)
+ move_possible(color: & Color, cmove: & CheckerMove)
+ any_quarter_filled(color: Color)
+ is_quarter_filled(color: Color, field: Field)
+ get_quarter_filling_candidate(color: Color)
+ is_quarter_fillable(color: Color, field: Field)
- get_quarter_fields(field: Field)
+ move_checker(color: & Color, cmove: CheckerMove)
+ remove_checker(color: & Color, field: Field)
+ add_checker(color: & Color, field: Field)
}
class "MoveRules" {
+ board: Board
+ dice: Dice
+ new(color: & Color, board: & Board, dice: Dice)
+ set_board(color: & Color, board: & Board)
- get_board_from_color(color: & Color, board: & Board)
+ moves_follow_rules(moves: & ( CheckerMove , CheckerMove ))
- moves_possible(moves: & ( CheckerMove , CheckerMove ))
- moves_follows_dices(moves: & ( CheckerMove , CheckerMove ))
- get_move_compatible_dices(cmove: & CheckerMove)
+ moves_allowed(moves: & ( CheckerMove , CheckerMove ))
- check_opponent_can_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
- check_must_fill_quarter_rule(moves: & ( CheckerMove , CheckerMove ))
- check_corner_rules(moves: & ( CheckerMove , CheckerMove ))
- has_checkers_outside_last_quarter()
- check_exit_rules(moves: & ( CheckerMove , CheckerMove ))
+ get_possible_moves_sequences(with_excedents: bool, ignored_rules: Vec < TricTracRule >)
+ get_scoring_quarter_filling_moves_sequences()
- get_sequence_origin_from_destination(sequence: ( CheckerMove , CheckerMove ), destination: Field)
+ get_quarter_filling_moves_sequences()
- get_possible_moves_sequences_by_dices(dice1: u8, dice2: u8, with_excedents: bool, ignore_empty: bool, ignored_rules: Vec < TricTracRule >)
- _get_direct_exit_moves(state: & GameState)
- is_move_by_puissance(moves: & ( CheckerMove , CheckerMove ))
- can_take_corner_by_effect()
}
class "DiceRoller" {
- rng: StdRng
+ new(opt_seed: Option < u64 >)
+ roll()
}
class "Dice" {
+ values: (u8,u8)
+ to_bits_string()
+ to_display_string()
+ is_double()
}
class "GameState" {
+ stage: Stage
+ turn_stage: TurnStage
+ board: Board
+ active_player_id: PlayerId
+ players: HashMap<PlayerId,Player>
+ history: Vec<GameEvent>
+ dice: Dice
+ dice_points: (u8,u8)
+ dice_moves: (CheckerMove,CheckerMove)
+ dice_jans: PossibleJans
- roll_first: bool
+ schools_enabled: bool
+ new(schools_enabled: bool)
- set_schools_enabled(schools_enabled: bool)
- get_active_player()
- get_opponent_id()
+ to_vec_float()
+ to_vec()
+ to_string_id()
+ who_plays()
+ get_white_player()
+ get_black_player()
+ player_id_by_color(color: Color)
+ player_id(player: & Player)
+ player_color_by_id(player_id: & PlayerId)
+ validate(event: & GameEvent)
+ init_player(player_name: & str)
- add_player(player_id: PlayerId, player: Player)
+ switch_active_player()
+ consume(valid_event: & GameEvent)
- new_pick_up()
- get_rollresult_jans(dice: & Dice)
+ determine_winner()
- inc_roll_count(player_id: PlayerId)
- mark_points(player_id: PlayerId, points: u8)
}
class "Player" {
+ name: String
+ color: Color
+ points: u8
+ holes: u8
+ can_bredouille: bool
+ can_big_bredouille: bool
+ dice_roll_count: u8
+ new(name: String, color: Color)
+ to_bits_string()
+ to_vec()
}
class "PointsRules" {
+ board: Board
+ dice: Dice
+ move_rules: MoveRules
+ new(color: & Color, board: & Board, dice: Dice)
+ set_dice(dice: Dice)
+ update_positions(positions: [ i8 ; 24 ])
- get_jans(board_ini: & Board, dice_rolls_count: u8)
+ get_jans_points(jans: HashMap < Jan , Vec < ( CheckerMove , CheckerMove ) > >)
+ get_points(dice_rolls_count: u8)
+ get_result_jans(dice_rolls_count: u8)
}
"MoveRules" <-- "Board"
"MoveRules" <-- "Dice"
"GameState" <-- "Board"
"HashMap<PlayerId,Player>" <-- "Player"
"GameState" <-- "HashMap<PlayerId,Player>"
"GameState" <-- "Dice"
"PointsRules" <-- "Board"
"PointsRules" <-- "Dice"
"PointsRules" <-- "MoveRules"
@enduml


@ -9,7 +9,7 @@ shell:
runcli:
RUST_LOG=info cargo run --bin=client_cli
runclibots:
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burn_dqn_model.mpk
cargo run --bin=client_cli -- --bot random,dqnburn:./bot/models/burnrl_dqn_40.mpk
#cargo run --bin=client_cli -- --bot dqn:./bot/models/dqn_model_final.json,dummy
# RUST_LOG=info cargo run --bin=client_cli -- --bot dummy,dqn
match:
@ -25,12 +25,13 @@ pythonlib:
trainsimple:
cargo build --release --bin=train_dqn_simple
LD_LIBRARY_PATH=./target/release ./target/release/train_dqn_simple | tee /tmp/train.out
trainbot:
trainbot algo:
#python ./store/python/trainModel.py
# cargo run --bin=train_dqn # ok
./bot/scripts/trainValid.sh
plottrainbot:
./bot/scripts/trainValid.sh plot
# ./bot/scripts/trainValid.sh
./bot/scripts/train.sh {{algo}}
plottrainbot algo:
./bot/scripts/train.sh plot {{algo}}
debugtrainbot:
cargo build --bin=train_dqn_burn
RUST_BACKTRACE=1 LD_LIBRARY_PATH=./target/debug ./target/debug/train_dqn_burn


@ -43,7 +43,7 @@ fn main() {
.unwrap();
let mut transport = NetcodeServerTransport::new(current_time, server_config, socket).unwrap();
trace!("❂ TricTrac server listening on {}", SERVER_ADDR);
trace!("❂ TricTrac server listening on {SERVER_ADDR}");
let mut game_state = store::GameState::default();
let mut last_updated = Instant::now();
@ -80,7 +80,7 @@ fn main() {
// Tell all players that a new player has joined
server.broadcast_message(0, bincode::serialize(&event).unwrap());
info!("🎉 Client {} connected.", client_id);
info!("🎉 Client {client_id} connected.");
// In TicTacTussle the game can begin once two players has joined
if game_state.players.len() == 2 {
let event = store::GameEvent::BeginGame {
@ -101,7 +101,7 @@ fn main() {
};
game_state.consume(&event);
server.broadcast_message(0, bincode::serialize(&event).unwrap());
info!("Client {} disconnected", client_id);
info!("Client {client_id} disconnected");
// Then end the game, since tic tac toe can't go on with a single player
let event = store::GameEvent::EndGame {
@ -124,7 +124,7 @@ fn main() {
if let Ok(event) = bincode::deserialize::<store::GameEvent>(&message) {
if game_state.validate(&event) {
game_state.consume(&event);
trace!("Player {} sent:\n\t{:#?}", client_id, event);
trace!("Player {client_id} sent:\n\t{event:#?}");
server.broadcast_message(0, bincode::serialize(&event).unwrap());
// Determine if a player has won the game
@ -135,7 +135,7 @@ fn main() {
server.broadcast_message(0, bincode::serialize(&event).unwrap());
}
} else {
warn!("Player {} sent invalid event:\n\t{:#?}", client_id, event);
warn!("Player {client_id} sent invalid event:\n\t{event:#?}");
}
}
}


@ -8,7 +8,7 @@ use std::fmt;
pub type Field = usize;
pub type FieldWithCount = (Field, i8);
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Deserialize)]
#[derive(Debug, Copy, Clone, Serialize, PartialEq, Eq, Deserialize)]
pub struct CheckerMove {
from: Field,
to: Field,
@ -94,7 +94,7 @@ impl CheckerMove {
}
/// Represents the Tric Trac board
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Board {
positions: [i8; 24],
}
@ -158,6 +158,42 @@ impl Board {
.unsigned_abs()
}
// get the number of the last checker in a field
pub fn get_field_checker(&self, color: &Color, field: Field) -> u8 {
assert_eq!(color, &Color::White); // sinon ajouter la gestion des noirs avec mirror
let mut total_count: u8 = 0;
for (i, checker_count) in self.positions.iter().enumerate() {
// count white checkers (checker_count > 0)
if *checker_count > 0 {
total_count += *checker_count as u8;
if field == i + 1 {
return total_count;
}
}
}
0
}
// get the field of the nth checker
pub fn get_checker_field(&self, color: &Color, checker_pos: u8) -> Option<Field> {
assert_eq!(color, &Color::White); // sinon ajouter la gestion des noirs avec mirror
if checker_pos == 0 {
return None;
}
let mut total_count: u8 = 0;
for (i, checker_count) in self.positions.iter().enumerate() {
// count white checkers (checker_count > 0)
if *checker_count > 0 {
total_count += *checker_count as u8;
}
// return the current field if it contains the checker
if checker_pos <= total_count {
return Some(i + 1);
}
}
None
}
pub fn to_vec(&self) -> Vec<i8> {
self.positions.to_vec()
}
@ -235,7 +271,7 @@ impl Board {
.map(|cells| {
cells
.into_iter()
.map(|cell| format!("{:>5}", cell))
.map(|cell| format!("{cell:>5}"))
.collect::<Vec<String>>()
.join("")
})
@ -246,7 +282,7 @@ impl Board {
.map(|cells| {
cells
.into_iter()
.map(|cell| format!("{:>5}", cell))
.map(|cell| format!("{cell:>5}"))
.collect::<Vec<String>>()
.join("")
})
@ -721,4 +757,32 @@ mod tests {
);
assert_eq!(vec![2], board.get_quarter_filling_candidate(Color::White));
}
#[test]
fn get_checker_field() {
let mut board = Board::new();
board.set_positions(
&Color::White,
[
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
assert_eq!(None, board.get_checker_field(&Color::White, 0));
assert_eq!(Some(3), board.get_checker_field(&Color::White, 5));
assert_eq!(Some(3), board.get_checker_field(&Color::White, 6));
assert_eq!(None, board.get_checker_field(&Color::White, 14));
}
#[test]
fn get_field_checker() {
let mut board = Board::new();
board.set_positions(
&Color::White,
[
3, 1, 2, 2, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
);
assert_eq!(4, board.get_field_checker(&Color::White, 2));
assert_eq!(6, board.get_field_checker(&Color::White, 3));
}
}


@ -44,7 +44,7 @@ impl DiceRoller {
/// Represents the two dice
///
/// Trictrac is always played with two dice.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize, Default)]
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize, Default)]
pub struct Dice {
/// The two dice values
pub values: (u8, u8),


@ -4,7 +4,7 @@ use crate::dice::Dice;
use crate::game_rules_moves::MoveRules;
use crate::game_rules_points::{PointsRules, PossibleJans};
use crate::player::{Color, Player, PlayerId};
use log::{debug, error, info};
use log::{debug, error};
// use itertools::Itertools;
use serde::{Deserialize, Serialize};
@ -60,7 +60,7 @@ impl From<TurnStage> for u8 {
}
/// Represents a TricTrac game
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GameState {
pub stage: Stage,
pub turn_stage: TurnStage,
@ -123,6 +123,15 @@ impl GameState {
gs
}
pub fn new_with_players(p1_name: &str, p2_name: &str) -> Self {
let mut game = Self::default();
if let Some(p1) = game.init_player(p1_name) {
game.init_player(p2_name);
game.consume(&GameEvent::BeginGame { goes_first: p1 });
}
game
}
fn set_schools_enabled(&mut self, schools_enabled: bool) {
self.schools_enabled = schools_enabled;
}
@ -244,7 +253,7 @@ impl GameState {
pos_bits.push_str(&white_bits);
pos_bits.push_str(&black_bits);
pos_bits = format!("{:0>108}", pos_bits);
pos_bits = format!("{pos_bits:0>108}");
// println!("{}", pos_bits);
let pos_u8 = pos_bits
.as_bytes()
@ -338,7 +347,7 @@ impl GameState {
return false;
}
}
Roll { player_id } | RollResult { player_id, dice: _ } => {
Roll { player_id } => {
// Check player exists
if !self.players.contains_key(player_id) {
return false;
@ -347,6 +356,26 @@ impl GameState {
if self.active_player_id != *player_id {
return false;
}
// Check the turn stage
if self.turn_stage != TurnStage::RollDice {
error!("bad stage {:?}", self.turn_stage);
return false;
}
}
RollResult { player_id, dice: _ } => {
// Check player exists
if !self.players.contains_key(player_id) {
return false;
}
// Check player is currently the one making their move
if self.active_player_id != *player_id {
return false;
}
// Check the turn stage
if self.turn_stage != TurnStage::RollWaiting {
error!("bad stage {:?}", self.turn_stage);
return false;
}
}
Mark {
player_id,
@ -627,9 +656,7 @@ impl GameState {
fn inc_roll_count(&mut self, player_id: PlayerId) {
self.players.get_mut(&player_id).map(|p| {
if p.dice_roll_count < u8::MAX {
p.dice_roll_count += 1;
}
p.dice_roll_count = p.dice_roll_count.saturating_add(1);
p
});
}
@ -689,14 +716,14 @@ impl GameState {
}
/// The reasons why a game could end
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Deserialize)]
pub enum EndGameReason {
PlayerLeft { player_id: PlayerId },
PlayerWon { winner: PlayerId },
}
/// An event that progresses the GameState forward
#[derive(Debug, Clone, Serialize, PartialEq, Deserialize)]
#[derive(Debug, Clone, Serialize, PartialEq, Eq, Deserialize)]
pub enum GameEvent {
BeginGame {
goes_first: PlayerId,


@ -603,7 +603,7 @@ mod tests {
);
let points_rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 4) });
let jans = points_rules.get_result_jans(8);
assert!(jans.0.len() > 0);
assert!(!jans.0.is_empty());
}
#[test]
@ -628,7 +628,7 @@ mod tests {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, -2,
],
);
let mut rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
let rules = PointsRules::new(&Color::Black, &board, Dice { values: (2, 3) });
assert_eq!(12, rules.get_points(5).0);
// Battre à vrai une dame située dans la table des grands jans : 2 + 2 = 4


@ -4,7 +4,7 @@ use std::fmt;
// This just makes it easier to dissern between a player id and any ol' u64
pub type PlayerId = u64;
#[derive(Copy, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Copy, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Color {
White,
Black,
@ -20,7 +20,7 @@ impl Color {
}
/// Struct for storing player related data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Player {
pub name: String,
pub color: Color,