refact(docs): clean docs
This commit is contained in: parent 3b9a1277d8, commit a0e3cf5f19
16 changed files with 2 additions and 897 deletions

782 doc/refs/bot_rl/spiel_bot_research.md (new file)

@ -0,0 +1,782 @@
# spiel_bot: Rust-native AlphaZero Training Crate for Trictrac
|
||||
|
||||
## 0. Context and Scope
|
||||
|
||||
The existing `bot` crate already uses **Burn 0.20** with the `burn-rl` library
|
||||
(DQN, PPO, SAC) against a random opponent. It uses the old 36-value `to_vec()`
|
||||
encoding and handles only the `Move`/`HoldOrGoChoice` stages, outsourcing every
|
||||
other stage to an inline random-opponent loop.
|
||||
|
||||
`spiel_bot` is a new workspace crate that replaces the OpenSpiel C++ dependency
|
||||
for **self-play training**. Its goals:
|
||||
|
||||
- Provide a minimal, clean **game-environment abstraction** (the "Rust OpenSpiel")
|
||||
that works with Trictrac's multi-stage turn model and stochastic dice.
|
||||
- Implement **AlphaZero** (MCTS + policy-value network + self-play replay buffer)
|
||||
as the first algorithm.
|
||||
- Remain **modular**: adding DQN or PPO later requires only a new
|
||||
`impl Algorithm for Dqn` without touching the environment or network layers.
|
||||
- Use the 217-value `to_tensor()` encoding and `get_valid_actions()` from
|
||||
`trictrac-store`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Library Landscape
|
||||
|
||||
### 1.1 Neural Network Frameworks
|
||||
|
||||
| Crate | Autodiff | GPU | Pure Rust | Maturity | Notes |
|
||||
| --------------- | ------------------ | --------------------- | ---------------------------- | -------------------------------- | ---------------------------------- |
|
||||
| **Burn 0.20** | yes | wgpu / CUDA (via tch) | yes | active, breaking API every minor | already used in `bot/` |
|
||||
| **tch-rs 0.17** | yes (via LibTorch) | CUDA / MPS | no (requires LibTorch ~2 GB) | very mature | full PyTorch; best raw performance |
|
||||
| **Candle 0.8** | partial | CUDA | yes | stable, HuggingFace-backed | better for inference than training |
|
||||
| ndarray alone | no | no | yes | mature | array ops only; no autograd |
|
||||
|
||||
**Recommendation: Burn** — consistent with the existing `bot/` crate, no C++
|
||||
runtime needed, the `ndarray` backend is sufficient for CPU training and can
|
||||
switch to `wgpu` (GPU without CUDA driver) or `tch` (LibTorch, fastest) by
|
||||
changing one type alias.
|
||||
|
||||
`tch-rs` would be the best choice for raw training throughput (it is the most
|
||||
battle-tested backend for RL) but adds a 2 GB LibTorch download and breaks the
|
||||
pure-Rust constraint. If training speed becomes the bottleneck after prototyping,
|
||||
switching `spiel_bot` to `tch-rs` is a one-line backend swap.
|
||||
|
||||
### 1.2 Other Key Crates
|
||||
|
||||
| Crate | Role |
|
||||
| -------------------- | ----------------------------------------------------------------- |
|
||||
| `rand 0.9` | dice sampling, replay buffer shuffling (already in store) |
|
||||
| `rayon` | parallel self-play: `(0..n_games).into_par_iter().map(play_game)` |
|
||||
| `crossbeam-channel` | optional producer/consumer pipeline (self-play workers → trainer) |
|
||||
| `serde / serde_json` | replay buffer snapshots, checkpoint metadata |
|
||||
| `anyhow` | error propagation (already used everywhere) |
|
||||
| `indicatif` | training progress bars |
|
||||
| `tracing` | structured logging per episode/iteration |
|
||||
|
||||
### 1.3 What `burn-rl` Provides (and Does Not)
|
||||
|
||||
The external `burn-rl` crate (from `github.com/yunjhongwu/burn-rl-examples`)
|
||||
provides DQN, PPO, SAC agents via a `burn_rl::base::{Environment, State, Action}`
|
||||
trait. It does **not** provide:
|
||||
|
||||
- MCTS or any tree-search algorithm
|
||||
- Two-player self-play
|
||||
- Legal action masking during training
|
||||
- Chance-node handling
|
||||
|
||||
For AlphaZero, `burn-rl` is not useful. The `spiel_bot` crate will define its
|
||||
own (simpler, more targeted) traits and implement MCTS from scratch.
|
||||
|
||||
---
|
||||
|
||||
## 2. Trictrac-Specific Design Constraints
|
||||
|
||||
### 2.1 Multi-Stage Turn Model
|
||||
|
||||
A Trictrac turn passes through up to six `TurnStage` values. Only two involve
|
||||
genuine player choice:
|
||||
|
||||
| TurnStage | Node type | Handler |
|
||||
| ---------------- | ------------------------------- | ------------------------------- |
|
||||
| `RollDice` | Forced (player initiates roll) | Auto-apply `GameEvent::Roll` |
|
||||
| `RollWaiting` | **Chance** (dice outcome) | Sample dice, apply `RollResult` |
|
||||
| `MarkPoints` | Forced (score is deterministic) | Auto-apply `GameEvent::Mark` |
|
||||
| `HoldOrGoChoice` | **Player decision** | MCTS / policy network |
|
||||
| `Move` | **Player decision** | MCTS / policy network |
|
||||
| `MarkAdvPoints` | Forced | Auto-apply `GameEvent::Mark` |
|
||||
|
||||
The environment wrapper advances through forced/chance stages automatically so
|
||||
that from the algorithm's perspective every node it sees is a genuine player
|
||||
decision.
|
||||
|
||||
### 2.2 Stochastic Dice in MCTS
|
||||
|
||||
AlphaZero was designed for deterministic games (Chess, Go). For Trictrac, dice
|
||||
introduce stochasticity. Three approaches exist:
|
||||
|
||||
**A. Outcome sampling (recommended)**
|
||||
During each MCTS simulation, when a chance node is reached, sample one dice
|
||||
outcome at random and continue. After many simulations the expected value
|
||||
converges. This is the approach used by OpenSpiel's MCTS for stochastic games
|
||||
and requires no changes to the standard PUCT formula.
|
||||
|
||||
**B. Chance-node averaging (expectimax)**
|
||||
At each chance node, expand all 21 unique dice pairs weighted by their
|
||||
probability (doublet: 1/36 each × 6; non-doublet: 2/36 each × 15). This is
|
||||
exact but multiplies the branching factor by ~21 at every dice roll, making it
|
||||
prohibitively expensive.
|
||||
|
||||
**C. Condition on dice in the observation (current approach)**
|
||||
Dice values are already encoded at indices [192–193] of `to_tensor()`. The
|
||||
network naturally conditions on the rolled dice when it evaluates a position.
|
||||
MCTS only runs on player-decision nodes _after_ the dice have been sampled;
|
||||
chance nodes are bypassed by the environment wrapper (approach A). The policy
|
||||
and value heads learn to play optimally given any dice pair.
|
||||
|
||||
**Use approach A + C together**: the environment samples dice automatically
|
||||
(chance node bypass), and the 217-dim tensor encodes the dice so the network
|
||||
can exploit them.
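
For reference, a minimal sketch of the two chance-node treatments, using only `rand` (no project types): outcome sampling draws one ordered pair per simulation, while the expectimax alternative would enumerate the 21 unique pairs with the weights given above.

```rust
use rand::Rng;

/// Outcome sampling (approach A): draw one ordered dice pair per simulation.
/// Two independent d6 already reproduce the correct distribution
/// (each doublet 1/36, each unordered non-doublet 2/36).
fn sample_dice(rng: &mut impl Rng) -> (u8, u8) {
    (rng.random_range(1u8..=6), rng.random_range(1u8..=6))
}

/// Expectimax weighting (approach B): the 21 unique pairs and their probabilities.
/// Expanding all of them at every chance node multiplies the branching factor by ~21.
fn unique_dice_outcomes() -> Vec<((u8, u8), f32)> {
    let mut outcomes = Vec::with_capacity(21);
    for d1 in 1..=6u8 {
        for d2 in d1..=6u8 {
            let p = if d1 == d2 { 1.0 / 36.0 } else { 2.0 / 36.0 };
            outcomes.push(((d1, d2), p));
        }
    }
    outcomes // probabilities sum to 1.0
}
```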
|
||||
|
||||
### 2.3 Perspective / Mirroring
|
||||
|
||||
All move rules and tensor encoding are defined from White's perspective.
|
||||
`to_tensor()` must always be called after mirroring the state for Black.
|
||||
The environment wrapper handles this transparently: every observation returned
|
||||
to an algorithm is already in the active player's perspective.
|
||||
|
||||
### 2.4 Legal Action Masking
|
||||
|
||||
A crucial difference from the existing `bot/` code: instead of penalizing
|
||||
invalid actions with `ERROR_REWARD`, the policy head logits are **masked**
|
||||
before softmax — illegal action logits are set to `-inf`. This prevents the
|
||||
network from wasting capacity on illegal moves and eliminates the need for the
|
||||
penalty-reward hack.
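
A minimal sketch of the masking step on plain `f32` slices (framework-agnostic; the Burn version applies the same idea to the logits tensor before sampling or cross-entropy):

```rust
/// Softmax over policy logits with illegal actions masked out.
/// `legal` holds the indices of currently legal actions; everything else
/// gets probability exactly 0.
fn masked_softmax(logits: &[f32], legal: &[usize]) -> Vec<f32> {
    let mut probs = vec![0.0f32; logits.len()];
    // Subtract the max legal logit for numerical stability.
    let max = legal.iter().map(|&a| logits[a]).fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for &a in legal {
        let e = (logits[a] - max).exp();
        probs[a] = e;
        sum += e;
    }
    for &a in legal {
        probs[a] /= sum;
    }
    probs
}
```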
|
||||
|
||||
---
|
||||
|
||||
## 3. Proposed Crate Architecture
|
||||
|
||||
```
|
||||
spiel_bot/
|
||||
├── Cargo.toml
|
||||
└── src/
|
||||
├── lib.rs # re-exports; feature flags: "alphazero", "dqn", "ppo"
|
||||
│
|
||||
├── env/
|
||||
│ ├── mod.rs # GameEnv trait — the minimal OpenSpiel interface
|
||||
│ └── trictrac.rs # TrictracEnv: impl GameEnv using trictrac-store
|
||||
│
|
||||
├── mcts/
|
||||
│ ├── mod.rs # MctsConfig, run_mcts() entry point
|
||||
│ ├── node.rs # MctsNode (visit count, W, prior, children)
|
||||
│ └── search.rs # simulate(), backup(), select_action()
|
||||
│
|
||||
├── network/
|
||||
│ ├── mod.rs # PolicyValueNet trait
|
||||
│ └── resnet.rs # Burn ResNet: Linear + residual blocks + two heads
|
||||
│
|
||||
├── alphazero/
|
||||
│ ├── mod.rs # AlphaZeroConfig
|
||||
│ ├── selfplay.rs # generate_episode() -> Vec<TrainSample>
|
||||
│ ├── replay.rs # ReplayBuffer (VecDeque, capacity, shuffle)
|
||||
│ └── trainer.rs # training loop: selfplay → sample → loss → update
|
||||
│
|
||||
└── agent/
|
||||
├── mod.rs # Agent trait
|
||||
├── random.rs # RandomAgent (baseline)
|
||||
└── mcts_agent.rs # MctsAgent: uses trained network for inference
|
||||
```
|
||||
|
||||
Future algorithms slot in without touching the above:
|
||||
|
||||
```
|
||||
├── dqn/ # (future) DQN: impl Algorithm + own replay buffer
|
||||
└── ppo/ # (future) PPO: impl Algorithm + rollout buffer
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Core Traits
|
||||
|
||||
### 4.1 `GameEnv` — the minimal OpenSpiel interface
|
||||
|
||||
```rust
|
||||
use rand::Rng;
|
||||
|
||||
/// Who controls the current node.
|
||||
pub enum Player {
|
||||
P1, // player index 0
|
||||
P2, // player index 1
|
||||
Chance, // dice roll
|
||||
Terminal, // game over
|
||||
}
|
||||
|
||||
pub trait GameEnv: Clone + Send + Sync + 'static {
|
||||
type State: Clone + Send + Sync;
|
||||
|
||||
/// Fresh game state.
|
||||
fn new_game(&self) -> Self::State;
|
||||
|
||||
/// Who acts at this node.
|
||||
fn current_player(&self, s: &Self::State) -> Player;
|
||||
|
||||
/// Legal action indices (always in [0, action_space())).
|
||||
/// Empty only at Terminal nodes.
|
||||
fn legal_actions(&self, s: &Self::State) -> Vec<usize>;
|
||||
|
||||
/// Apply a player action (must be legal).
|
||||
fn apply(&self, s: &mut Self::State, action: usize);
|
||||
|
||||
/// Advance a Chance node by sampling dice; no-op at non-Chance nodes.
|
||||
fn apply_chance(&self, s: &mut Self::State, rng: &mut impl Rng);
|
||||
|
||||
/// Observation tensor from `pov`'s perspective (0 or 1).
|
||||
/// Returns 217 f32 values for Trictrac.
|
||||
fn observation(&self, s: &Self::State, pov: usize) -> Vec<f32>;
|
||||
|
||||
/// Flat observation size (217 for Trictrac).
|
||||
fn obs_size(&self) -> usize;
|
||||
|
||||
/// Total action-space size (514 for Trictrac).
|
||||
fn action_space(&self) -> usize;
|
||||
|
||||
/// Game outcome per player, or None if not Terminal.
|
||||
/// Values in [-1, 1]: +1 = win, -1 = loss, 0 = draw.
|
||||
fn returns(&self, s: &Self::State) -> Option<[f32; 2]>;
|
||||
}
|
||||
```
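
A sketch of how the trait is driven, corresponding to the random-playthrough unit test suggested in §11 (it assumes only the `GameEnv` trait and `Player` enum above):

```rust
use rand::Rng;

/// Play one game with uniformly random legal actions; returns the final returns.
fn play_random_game<E: GameEnv>(env: &E, rng: &mut impl Rng) -> [f32; 2] {
    let mut state = env.new_game();
    loop {
        match env.current_player(&state) {
            Player::Terminal => return env.returns(&state).expect("terminal state has returns"),
            Player::Chance => env.apply_chance(&mut state, rng),
            Player::P1 | Player::P2 => {
                let legal = env.legal_actions(&state);
                // Pick one legal action uniformly at random.
                let action = legal[rng.random_range(0..legal.len())];
                env.apply(&mut state, action);
            }
        }
    }
}
```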
|
||||
|
||||
### 4.2 `PolicyValueNet` — neural network interface
|
||||
|
||||
```rust
|
||||
use burn::prelude::*;
|
||||
|
||||
pub trait PolicyValueNet<B: Backend>: Send + Sync {
|
||||
/// Forward pass.
|
||||
/// `obs`: [batch, obs_size] tensor.
|
||||
/// Returns: (policy_logits [batch, action_space], value [batch]).
|
||||
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
|
||||
|
||||
/// Save weights to `path`.
|
||||
fn save(&self, path: &std::path::Path) -> anyhow::Result<()>;
|
||||
|
||||
/// Load weights from `path`.
|
||||
fn load(path: &std::path::Path) -> anyhow::Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.3 `Agent` — player policy interface
|
||||
|
||||
```rust
|
||||
pub trait Agent: Send {
|
||||
/// Select an action index given the current game state observation.
|
||||
/// `legal`: mask of valid action indices.
|
||||
fn select_action(&mut self, obs: &[f32], legal: &[usize]) -> usize;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. MCTS Implementation
|
||||
|
||||
### 5.1 Node
|
||||
|
||||
```rust
|
||||
pub struct MctsNode {
|
||||
n: u32, // visit count N(s, a)
|
||||
w: f32, // sum of backed-up values W(s, a)
|
||||
p: f32, // prior from policy head P(s, a)
|
||||
children: Vec<(usize, MctsNode)>, // (action_idx, child)
|
||||
is_expanded: bool,
|
||||
}
|
||||
|
||||
impl MctsNode {
|
||||
pub fn q(&self) -> f32 {
|
||||
if self.n == 0 { 0.0 } else { self.w / self.n as f32 }
|
||||
}
|
||||
|
||||
/// PUCT score used for selection.
|
||||
pub fn puct(&self, parent_n: u32, c_puct: f32) -> f32 {
|
||||
self.q() + c_puct * self.p * (parent_n as f32).sqrt() / (1.0 + self.n as f32)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 5.2 Simulation Loop
|
||||
|
||||
One MCTS simulation (for deterministic decision nodes):
|
||||
|
||||
```
|
||||
1. SELECTION — traverse from root, always pick child with highest PUCT,
|
||||
auto-advancing forced/chance nodes via env.apply_chance().
|
||||
2. EXPANSION — at first unvisited leaf: call network.forward(obs) to get
|
||||
(policy_logits, value). Mask illegal actions, softmax
|
||||
the remaining logits → priors P(s,a) for each child.
|
||||
3. BACKUP — propagate -value up the tree (negate at each level because
|
||||
perspective alternates between P1 and P2).
|
||||
```
|
||||
|
||||
After `n_simulations` iterations, action selection at the root:
|
||||
|
||||
```rust
|
||||
// During training: sample proportional to N^(1/temperature)
|
||||
// During evaluation: argmax N
|
||||
fn select_action(root: &MctsNode, temperature: f32) -> usize { ... }
|
||||
```
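
One possible body for this function, shown on raw `(action, visit count)` pairs rather than the real `MctsNode` children (a sketch, not the final `search.rs` code):

```rust
use rand::Rng;

/// Pick a root action from children visit counts.
/// temperature > 0: sample proportional to N^(1/T); temperature == 0: argmax N.
fn select_action_from_visits(
    children: &[(usize, u32)], // (action_idx, visit count)
    temperature: f32,
    rng: &mut impl Rng,
) -> usize {
    if temperature == 0.0 {
        return children.iter().max_by_key(|(_, n)| *n).map(|(a, _)| *a).unwrap();
    }
    let weights: Vec<f64> = children
        .iter()
        .map(|(_, n)| (*n as f64).powf(1.0 / temperature as f64))
        .collect();
    let total: f64 = weights.iter().sum();
    if total <= 0.0 {
        return children[0].0; // no visits yet: fall back to the first child
    }
    // Roulette-wheel sampling over the tempered visit counts.
    let mut r = rng.random_range(0.0..total);
    for ((a, _), w) in children.iter().zip(&weights) {
        r -= *w;
        if r <= 0.0 {
            return *a;
        }
    }
    children.last().unwrap().0
}
```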
|
||||
|
||||
### 5.3 Configuration
|
||||
|
||||
```rust
|
||||
pub struct MctsConfig {
|
||||
pub n_simulations: usize, // e.g. 200
|
||||
pub c_puct: f32, // exploration constant, e.g. 1.5
|
||||
pub dirichlet_alpha: f32, // root noise for exploration, e.g. 0.3
|
||||
pub dirichlet_eps: f32, // noise weight, e.g. 0.25
|
||||
pub temperature: f32, // action sampling temperature (anneals to 0)
|
||||
}
|
||||
```
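
The root noise referenced by `dirichlet_alpha` / `dirichlet_eps` can be mixed into the priors as sketched below. This sketch samples the Dirichlet by normalising independent Gamma draws and assumes the `rand_distr` crate as an extra dependency; in practice it is applied only to the priors of the root's legal actions.

```rust
use rand::Rng;
use rand_distr::{Distribution, Gamma};

/// Mix Dirichlet(alpha) noise into the root priors (AlphaZero-style):
/// p_i <- (1 - eps) * p_i + eps * noise_i, with noise ~ Dirichlet(alpha, ..., alpha).
/// `priors` holds only the priors of the root's legal actions.
fn add_root_noise(priors: &mut [f32], alpha: f32, eps: f32, rng: &mut impl Rng) {
    let gamma = Gamma::new(alpha as f64, 1.0).expect("alpha must be positive");
    let draws: Vec<f64> = priors.iter().map(|_| gamma.sample(rng)).collect();
    let total: f64 = draws.iter().sum();
    if total <= 0.0 {
        return; // degenerate draw; keep the original priors
    }
    for (p, g) in priors.iter_mut().zip(&draws) {
        let noise = (g / total) as f32;
        *p = (1.0 - eps) * *p + eps * noise;
    }
}
```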
|
||||
|
||||
### 5.4 Handling Chance Nodes Inside MCTS
|
||||
|
||||
When simulation reaches a Chance node (dice roll), the environment automatically
|
||||
samples dice and advances to the next decision node. The MCTS tree does **not**
|
||||
branch on dice outcomes — it treats the sampled outcome as the state. This
|
||||
corresponds to "outcome sampling" (approach A from §2.2). Because each
|
||||
simulation independently samples dice, the Q-values at player nodes converge to
|
||||
their expected value over many simulations.
|
||||
|
||||
---
|
||||
|
||||
## 6. Network Architecture
|
||||
|
||||
### 6.1 ResNet Policy-Value Network
|
||||
|
||||
A single trunk with residual blocks, then two heads:
|
||||
|
||||
```
|
||||
Input: [batch, 217]
|
||||
↓
|
||||
Linear(217 → 512) + ReLU
|
||||
↓
|
||||
ResBlock × 4 (Linear(512→512) + BN + ReLU + Linear(512→512) + BN + skip + ReLU)
|
||||
↓ trunk output [batch, 512]
|
||||
├─ Policy head: Linear(512 → 514) → logits (masked softmax at use site)
|
||||
└─ Value head: Linear(512 → 1) → tanh (output in [-1, 1])
|
||||
```
|
||||
|
||||
Burn implementation sketch:
|
||||
|
||||
```rust
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TrictracNet<B: Backend> {
|
||||
input: Linear<B>,
|
||||
res_blocks: Vec<ResBlock<B>>,
|
||||
policy_head: Linear<B>,
|
||||
value_head: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TrictracNet<B> {
|
||||
pub fn forward(&self, obs: Tensor<B, 2>)
|
||||
-> (Tensor<B, 2>, Tensor<B, 1>)
|
||||
{
|
||||
let x = activation::relu(self.input.forward(obs));
|
||||
let x = self.res_blocks.iter().fold(x, |x, b| b.forward(x));
|
||||
let policy = self.policy_head.forward(x.clone()); // raw logits
|
||||
let value = activation::tanh(self.value_head.forward(x))
|
||||
.squeeze(1);
|
||||
(policy, value)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
A simpler MLP (no residual blocks) is sufficient for a first version and much
|
||||
faster to train: `Linear(217→512) + ReLU + Linear(512→256) + ReLU` then two
|
||||
heads.
|
||||
|
||||
### 6.2 Loss Function
|
||||
|
||||
```
|
||||
L = MSE(value_pred, z)
|
||||
+ CrossEntropy(policy_logits_masked, π_mcts)
|
||||
- c_l2 * L2_regularization
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `z` = game outcome (±1) from the active player's perspective
|
||||
- `π_mcts` = normalized MCTS visit counts at the root (the policy target)
|
||||
- Legal action masking is applied before computing CrossEntropy
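
A sketch of this loss with Burn tensors (assuming `log_softmax`, `sum_dim`, `mean` and `mul_scalar` behave as in Burn 0.20; the L2 term is typically delegated to the optimizer's weight decay):

```rust
use burn::prelude::*;
use burn::tensor::activation::log_softmax;

/// Combined AlphaZero loss for one batch.
/// `policy_logits`: [batch, action_space], already masked on illegal actions.
/// `pi_target`:     [batch, action_space], normalized MCTS visit counts.
/// `value_pred`:    [batch], tanh output of the value head.
/// `z_target`:      [batch], game outcomes in [-1, 1].
fn alphazero_loss<B: Backend>(
    policy_logits: Tensor<B, 2>,
    pi_target: Tensor<B, 2>,
    value_pred: Tensor<B, 1>,
    z_target: Tensor<B, 1>,
) -> Tensor<B, 1> {
    // Cross-entropy with a soft target: -sum_a pi(a) * log_softmax(logits)(a).
    // Note: mask illegal logits with a large negative constant (e.g. -1e9) rather
    // than literal -inf, so that 0 * log p stays finite here.
    let log_p = log_softmax(policy_logits, 1);
    let policy_loss = (pi_target * log_p).sum_dim(1).mean().mul_scalar(-1.0);
    // MSE between predicted value and game outcome.
    let diff = value_pred - z_target;
    let value_loss = (diff.clone() * diff).mean();
    policy_loss + value_loss
}
```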
|
||||
|
||||
---
|
||||
|
||||
## 7. AlphaZero Training Loop
|
||||
|
||||
```
|
||||
INIT
|
||||
network ← random weights
|
||||
replay ← empty ReplayBuffer(capacity = 100_000)
|
||||
|
||||
LOOP forever:
|
||||
── Self-play phase ──────────────────────────────────────────────
|
||||
(parallel with rayon, n_workers games at once)
|
||||
for each game:
|
||||
state ← env.new_game()
|
||||
samples = []
|
||||
while not terminal:
|
||||
advance forced/chance nodes automatically
|
||||
obs ← env.observation(state, current_player)
|
||||
legal ← env.legal_actions(state)
|
||||
π, root_value ← mcts.run(state, network, config)
|
||||
action ← sample from π (with temperature)
|
||||
samples.push((obs, π, current_player))
|
||||
env.apply(state, action)
|
||||
z ← env.returns(state) // final scores
|
||||
for (obs, π, player) in samples:
|
||||
replay.push(TrainSample { obs, policy: π, value: z[player] })
|
||||
|
||||
── Training phase ───────────────────────────────────────────────
|
||||
for each gradient step:
|
||||
batch ← replay.sample(batch_size)
|
||||
(policy_logits, value_pred) ← network.forward(batch.obs)
|
||||
loss ← mse(value_pred, batch.value) + xent(policy_logits, batch.policy)
|
||||
optimizer.step(loss.backward())
|
||||
|
||||
── Evaluation (every N iterations) ─────────────────────────────
|
||||
win_rate ← evaluate(network_new vs network_prev, n_eval_games)
|
||||
if win_rate > 0.55: save checkpoint
|
||||
```
|
||||
|
||||
### 7.1 Replay Buffer
|
||||
|
||||
```rust
|
||||
pub struct TrainSample {
|
||||
pub obs: Vec<f32>, // 217 values
|
||||
pub policy: Vec<f32>, // 514 values (normalized MCTS visit counts)
|
||||
pub value: f32, // game outcome ∈ {-1, 0, +1}
|
||||
}
|
||||
|
||||
pub struct ReplayBuffer {
|
||||
data: VecDeque<TrainSample>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl ReplayBuffer {
|
||||
pub fn push(&mut self, s: TrainSample) {
|
||||
if self.data.len() == self.capacity { self.data.pop_front(); }
|
||||
self.data.push_back(s);
|
||||
}
|
||||
|
||||
pub fn sample(&self, n: usize, rng: &mut impl Rng) -> Vec<&TrainSample> {
    // Sample without replacement; assumes rand::seq::IteratorRandom is in scope (rand 0.9).
    use rand::seq::IteratorRandom;
    self.data.iter().choose_multiple(rng, n.min(self.data.len()))
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7.2 Parallelism Strategy
|
||||
|
||||
Self-play is embarrassingly parallel (each game is independent):
|
||||
|
||||
```rust
|
||||
let samples: Vec<TrainSample> = (0..n_games)
|
||||
.into_par_iter() // rayon
|
||||
.flat_map(|_| generate_episode(&env, &network, &mcts_config))
|
||||
.collect();
|
||||
```
|
||||
|
||||
Note: Burn's `NdArray` backend is not `Send` by default when using autodiff.
|
||||
Self-play uses inference-only (no gradient tape), so a `NdArray<f32>` backend
|
||||
(without `Autodiff` wrapper) is `Send`. Training runs on the main thread with
|
||||
`Autodiff<NdArray<f32>>`.
|
||||
|
||||
For larger scale, a producer-consumer architecture (crossbeam-channel) separates
|
||||
self-play workers from the training thread, allowing continuous data generation
|
||||
while the GPU trains.
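
A minimal sketch of that pipeline with a bounded channel, so self-play back-pressures when the trainer falls behind (the worker body and sample type are placeholders, not project types):

```rust
use crossbeam_channel::bounded;
use std::thread;

// Placeholder sample type; stands in for TrainSample.
struct Sample;

fn selfplay_worker() -> Vec<Sample> {
    // ... run one self-play episode with the latest inference snapshot ...
    Vec::new()
}

fn main() {
    // Bounded queue between self-play workers and the single training thread.
    let (tx, rx) = bounded::<Vec<Sample>>(64);

    // Producers: generate episodes continuously until the channel closes.
    let mut workers = Vec::new();
    for _ in 0..4 {
        let tx = tx.clone();
        workers.push(thread::spawn(move || {
            while tx.send(selfplay_worker()).is_ok() {}
        }));
    }
    drop(tx); // the trainer sees disconnection once all workers stop

    // Consumer: the training thread drains episodes into the replay buffer.
    for episode in rx.iter().take(1_000) {
        let _ = episode; // push into ReplayBuffer, then run gradient steps
    }
    // Dropping `rx` closes the channel; workers exit on their next failed send.
}
```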
|
||||
|
||||
---
|
||||
|
||||
## 8. `TrictracEnv` Implementation Sketch
|
||||
|
||||
```rust
|
||||
use trictrac_store::{
|
||||
training_common::{get_valid_actions, TrictracAction, ACTION_SPACE_SIZE},
|
||||
Dice, DiceRoller, GameEvent, GameState, Stage, TurnStage,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TrictracEnv;
|
||||
|
||||
impl GameEnv for TrictracEnv {
|
||||
type State = GameState;
|
||||
|
||||
fn new_game(&self) -> GameState {
|
||||
GameState::new_with_players("P1", "P2")
|
||||
}
|
||||
|
||||
fn current_player(&self, s: &GameState) -> Player {
|
||||
match s.stage {
|
||||
Stage::Ended => Player::Terminal,
|
||||
_ => match s.turn_stage {
|
||||
TurnStage::RollWaiting => Player::Chance,
|
||||
_ => if s.active_player_id == 1 { Player::P1 } else { Player::P2 },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn legal_actions(&self, s: &GameState) -> Vec<usize> {
|
||||
let view = if s.active_player_id == 2 { s.mirror() } else { s.clone() };
|
||||
get_valid_action_indices(&view).unwrap_or_default()
|
||||
}
|
||||
|
||||
fn apply(&self, s: &mut GameState, action_idx: usize) {
|
||||
// advance all forced/chance nodes first, then apply the player action
|
||||
self.advance_forced(s);
|
||||
let needs_mirror = s.active_player_id == 2;
|
||||
let view = if needs_mirror { s.mirror() } else { s.clone() };
|
||||
if let Some(event) = TrictracAction::from_action_index(action_idx)
|
||||
.and_then(|a| a.to_event(&view))
|
||||
.map(|e| if needs_mirror { e.get_mirror(false) } else { e })
|
||||
{
|
||||
let _ = s.consume(&event);
|
||||
}
|
||||
// advance any forced stages that follow
|
||||
self.advance_forced(s);
|
||||
}
|
||||
|
||||
fn apply_chance(&self, s: &mut GameState, rng: &mut impl Rng) {
|
||||
// RollDice → RollWaiting
|
||||
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
|
||||
// RollWaiting → next stage
|
||||
let dice = Dice { values: (rng.random_range(1u8..=6), rng.random_range(1u8..=6)) };
|
||||
let _ = s.consume(&GameEvent::RollResult { player_id: s.active_player_id, dice });
|
||||
self.advance_forced(s);
|
||||
}
|
||||
|
||||
fn observation(&self, s: &GameState, pov: usize) -> Vec<f32> {
|
||||
if pov == 0 { s.to_tensor() } else { s.mirror().to_tensor() }
|
||||
}
|
||||
|
||||
fn obs_size(&self) -> usize { 217 }
|
||||
fn action_space(&self) -> usize { ACTION_SPACE_SIZE }
|
||||
|
||||
fn returns(&self, s: &GameState) -> Option<[f32; 2]> {
|
||||
if s.stage != Stage::Ended { return None; }
|
||||
// Convert hole+point scores to ±1 outcome
|
||||
let s1 = s.players.get(&1).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
|
||||
let s2 = s.players.get(&2).map(|p| p.holes as i32 * 12 + p.points as i32).unwrap_or(0);
|
||||
Some(match s1.cmp(&s2) {
|
||||
std::cmp::Ordering::Greater => [ 1.0, -1.0],
|
||||
std::cmp::Ordering::Less => [-1.0, 1.0],
|
||||
std::cmp::Ordering::Equal => [ 0.0, 0.0],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TrictracEnv {
|
||||
/// Advance through all forced (non-decision, non-chance) stages.
|
||||
fn advance_forced(&self, s: &mut GameState) {
|
||||
use trictrac_store::PointsRules;
|
||||
loop {
|
||||
match s.turn_stage {
|
||||
TurnStage::MarkPoints | TurnStage::MarkAdvPoints => {
|
||||
// Scoring is deterministic; compute and apply automatically.
|
||||
let color = s.player_color_by_id(&s.active_player_id)
|
||||
.unwrap_or(trictrac_store::Color::White);
|
||||
let drc = s.players.get(&s.active_player_id)
|
||||
.map(|p| p.dice_roll_count).unwrap_or(0);
|
||||
let pr = PointsRules::new(&color, &s.board, s.dice);
|
||||
let pts = pr.get_points(drc);
|
||||
let points = if s.turn_stage == TurnStage::MarkPoints { pts.0 } else { pts.1 };
|
||||
let _ = s.consume(&GameEvent::Mark {
|
||||
player_id: s.active_player_id, points,
|
||||
});
|
||||
}
|
||||
TurnStage::RollDice => {
|
||||
// RollDice is a forced "initiate roll" action with no real choice.
|
||||
let _ = s.consume(&GameEvent::Roll { player_id: s.active_player_id });
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. Cargo.toml Changes
|
||||
|
||||
### 9.1 Add `spiel_bot` to the workspace
|
||||
|
||||
```toml
|
||||
# Cargo.toml (workspace root)
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["client_cli", "bot", "store", "spiel_bot"]
|
||||
```
|
||||
|
||||
### 9.2 `spiel_bot/Cargo.toml`
|
||||
|
||||
```toml
|
||||
[package]
|
||||
name = "spiel_bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["alphazero"]
|
||||
alphazero = []
|
||||
# dqn = [] # future
|
||||
# ppo = [] # future
|
||||
|
||||
[dependencies]
|
||||
trictrac-store = { path = "../store" }
|
||||
anyhow = "1"
|
||||
rand = "0.9"
|
||||
rayon = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# Burn: NdArray for pure-Rust CPU training
|
||||
# Replace NdArray with Wgpu or Tch for GPU.
|
||||
burn = { version = "0.20", features = ["ndarray", "autodiff"] }
|
||||
|
||||
# Optional: progress display and structured logging
|
||||
indicatif = "0.17"
|
||||
tracing = "0.1"
|
||||
|
||||
[[bin]]
|
||||
name = "az_train"
|
||||
path = "src/bin/az_train.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "az_eval"
|
||||
path = "src/bin/az_eval.rs"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Comparison: `bot` crate vs `spiel_bot`
|
||||
|
||||
| Aspect | `bot` (existing) | `spiel_bot` (proposed) |
|
||||
| ---------------- | --------------------------- | -------------------------------------------- |
|
||||
| State encoding | 36 i8 `to_vec()` | 217 f32 `to_tensor()` |
|
||||
| Algorithms | DQN, PPO, SAC via `burn-rl` | AlphaZero (MCTS) |
|
||||
| Opponent | hardcoded random | self-play |
|
||||
| Invalid actions | penalise with reward | legal action mask (no penalty) |
|
||||
| Dice handling | inline sampling in step() | `Chance` node in `GameEnv` trait |
|
||||
| Stochastic turns | manual per-stage code | `advance_forced()` in env wrapper |
|
||||
| Burn dep | yes (0.20) | yes (0.20), same backend |
|
||||
| `burn-rl` dep | yes | no |
|
||||
| C++ dep | no | no |
|
||||
| Python dep | no | no |
|
||||
| Modularity | one entry point per algo | `GameEnv` + `Agent` traits; algo is a plugin |
|
||||
|
||||
The two crates are **complementary**: `bot` is a working DQN/PPO baseline;
|
||||
`spiel_bot` adds MCTS-based self-play on top of a cleaner abstraction. The
|
||||
`TrictracEnv` in `spiel_bot` can also back-fill into `bot` if desired (just
|
||||
replace `TrictracEnvironment` with `TrictracEnv`).
|
||||
|
||||
---
|
||||
|
||||
## 11. Implementation Order
|
||||
|
||||
1. **`env/`**: `GameEnv` trait + `TrictracEnv` + unit tests (run a random game
|
||||
through the trait, verify terminal state and returns).
|
||||
2. **`network/`**: `PolicyValueNet` trait + MLP stub (no residual blocks yet) +
|
||||
Burn forward/backward pass test with dummy data.
|
||||
3. **`mcts/`**: `MctsNode` + `simulate()` + `select_action()` + property tests
|
||||
(visit counts sum to `n_simulations`, legal mask respected).
|
||||
4. **`alphazero/`**: `generate_episode()` + `ReplayBuffer` + training loop stub
|
||||
(one iteration, check loss decreases).
|
||||
5. **Integration test**: run 100 self-play games with a tiny network (1 res block,
|
||||
64 hidden units), verify the training loop completes without panics.
|
||||
6. **Benchmarks**: measure games/second, steps/second (target: ≥ 500 games/s
|
||||
on CPU, consistent with `random_game` throughput).
|
||||
7. **Upgrade network**: 4 residual blocks, 512 hidden units; schedule
|
||||
hyperparameter sweep.
|
||||
8. **`az_eval` binary**: play `MctsAgent` (trained) vs `RandomAgent`, report
|
||||
win rate every checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## 12. Key Open Questions
|
||||
|
||||
1. **Scoring as returns**: Trictrac scores (holes × 12 + points) are unbounded.
|
||||
AlphaZero needs ±1 returns. Simple option: win/loss at game end (whoever
|
||||
scored more holes). Better option: normalize the score margin. The final
|
||||
choice affects how the value head is trained.
|
||||
|
||||
2. **Episode length**: Trictrac games average ~600 steps (`random_game` data).
|
||||
MCTS with 200 simulations per step means ~120k network evaluations per game.
|
||||
At batch inference this is feasible on CPU; on GPU it becomes fast. Consider
|
||||
limiting `n_simulations` to 50–100 for early training.
|
||||
|
||||
3. **`HoldOrGoChoice` strategy**: The `Go` action resets the board (new relevé).
|
||||
This is a long-horizon decision that AlphaZero handles naturally via MCTS
|
||||
lookahead, but needs careful value normalization (a "Go" restarts scoring
|
||||
within the same game).
|
||||
|
||||
4. **`burn-rl` reuse**: The existing DQN/PPO code in `bot/` could be migrated
|
||||
to use `TrictracEnv` from `spiel_bot`, consolidating the environment logic.
|
||||
This is optional but reduces code duplication.
|
||||
|
||||
5. **Dirichlet noise parameters**: Standard AlphaZero uses α = 0.3 for Chess,
|
||||
0.03 for Go. For Trictrac with action space 514, empirical tuning is needed.
|
||||
A reasonable starting point: α = 10 / mean_legal_actions ≈ 0.1.
|
||||
|
||||
## Implementation results
|
||||
|
||||
All benchmarks compile and run. Here's the complete results table:
|
||||
|
||||
| Group | Benchmark | Time |
|
||||
| ------- | ----------------------- | --------------------- |
|
||||
| env | apply_chance | 3.87 µs |
|
||||
| | legal_actions | 1.91 µs |
|
||||
| | observation (to_tensor) | 341 ns |
|
||||
| | random_game (baseline) | 3.55 ms → 282 games/s |
|
||||
| network | mlp_b1 hidden=64 | 94.9 µs |
|
||||
| | mlp_b32 hidden=64 | 141 µs |
|
||||
| | mlp_b1 hidden=256 | 352 µs |
|
||||
| | mlp_b32 hidden=256 | 479 µs |
|
||||
| mcts | zero_eval n=1 | 6.8 µs |
|
||||
| | zero_eval n=5 | 23.9 µs |
|
||||
| | zero_eval n=20 | 90.9 µs |
|
||||
| | mlp64 n=1 | 203 µs |
|
||||
| | mlp64 n=5 | 622 µs |
|
||||
| | mlp64 n=20 | 2.30 ms |
|
||||
| episode | trictrac n=1 | 51.8 ms → 19 games/s |
|
||||
| | trictrac n=2 | 145 ms → 7 games/s |
|
||||
| train | mlp64 Adam b=16 | 1.93 ms |
|
||||
| | mlp64 Adam b=64 | 2.68 ms |
|
||||
|
||||
Key observations:
|
||||
|
||||
- random_game baseline: 282 games/s (short of the ≥ 500 target — game state ops dominate at 3.9 µs/apply_chance, ~600 steps/game)
|
||||
- observation (217-value tensor): only 341 ns — not a bottleneck
|
||||
- legal_actions: 1.9 µs — well optimised
|
||||
- Network (MLP hidden=64): 95 µs per call — the dominant MCTS cost; with n=1 each episode step costs ~200 µs
|
||||
- Tree traversal (zero_eval): only 6.8 µs for n=1 — MCTS overhead is minimal
|
||||
- Full episode n=1: 51.8 ms (19 games/s); the 95 µs × ~2 calls × ~600 moves accounts for most of it
|
||||
- Training: 2.7 ms/step at batch=64 → 370 steps/s
|
||||
|
||||
### Summary of Step 8
|
||||
|
||||
spiel_bot/src/bin/az_eval.rs — a self-contained evaluation binary:
|
||||
|
||||
- CLI flags: --checkpoint, --arch mlp|resnet, --hidden, --n-games, --n-sim, --seed, --c-puct
|
||||
- No checkpoint → random weights (useful as a sanity baseline — should converge toward 50%)
|
||||
- Game loop: alternates MctsAgent as P1 / P2 against a RandomAgent, n_games per side
|
||||
- MctsAgent: run_mcts + greedy select_action (temperature=0, no Dirichlet noise)
|
||||
- Output: win/draw/loss per side + combined decisive win rate
|
||||
|
||||
Typical usage after training:
|
||||
```sh
cargo run -p spiel_bot --bin az_eval --release -- \
  --checkpoint checkpoints/iter_100.mpk --arch resnet --n-games 200 --n-sim 100
```
|
||||
|
||||
### az_train
|
||||
|
||||
#### Fresh MLP training (default: 100 iters, 10 games, 100 sims, save every 10)
|
||||
|
||||
```sh
cargo run -p spiel_bot --bin az_train --release
```
|
||||
|
||||
#### ResNet, more games, custom output dir
|
||||
|
||||
```sh
cargo run -p spiel_bot --bin az_train --release -- \
  --arch resnet --n-iter 200 --n-games 20 --n-sim 100 \
  --save-every 20 --out checkpoints/
```
|
||||
|
||||
#### Resume from iteration 50
|
||||
|
||||
```sh
cargo run -p spiel_bot --bin az_train --release -- \
  --resume checkpoints/iter_0050.mpk --arch mlp --n-iter 50
```
|
||||
|
||||
What the binary does each iteration:
|
||||
|
||||
1. Calls model.valid() to get a zero-overhead inference copy for self-play
|
||||
2. Runs n_games episodes via generate_episode (temperature=1 for first --temp-drop moves, then greedy)
|
||||
3. Pushes samples into a ReplayBuffer (capacity --replay-cap)
|
||||
4. Runs n_train gradient steps via train_step with cosine LR annealing from --lr down to --lr-min
|
||||
5. Saves a .mpk checkpoint every --save-every iterations and always on the last
|
||||
253 doc/refs/bot_rl/tensor_research.md (new file)

@ -0,0 +1,253 @@
# Tensor research
|
||||
|
||||
## Current tensor anatomy
|
||||
|
||||
[0..23] board.positions[i]: i8 ∈ [-15,+15], positive=white, negative=black (combined!)
|
||||
[24] active player color: 0 or 1
|
||||
[25] turn_stage: 1–5
|
||||
[26–27] dice values (raw 1–6)
|
||||
[28–31] white: points, holes, can_bredouille, can_big_bredouille
|
||||
[32–35] black: same
|
||||
─────────────────────────────────
|
||||
Total 36 floats
|
||||
|
||||
The C++ side (ObservationTensorShape() → {kStateEncodingSize}) treats this as a flat 1D vector, so OpenSpiel's
|
||||
AlphaZero uses a fully-connected network.
|
||||
|
||||
### Fundamental problems with the current encoding
|
||||
|
||||
1. Colors mixed into a signed integer. A single value encodes both whose checker is there and how many. The network
|
||||
must learn from a value of -3 that (a) it's the opponent, (b) there are 3 of them, and (c) both facts interact with
|
||||
all the quarter-filling logic. Two separate, semantically clean channels would be much easier to learn from.
|
||||
|
||||
2. No normalization. Dice (1–6), counts (−15 to +15), booleans (0/1), points (0–12) coexist without scaling. Gradient
|
||||
flow during training is uneven.
|
||||
|
||||
3. Quarter fill status is completely absent. Filling a quarter is the dominant strategic goal in Trictrac — it
|
||||
triggers all scoring. The network has to discover from raw counts that six adjacent fields each having ≥2 checkers
|
||||
produces a score. Including this explicitly is the single highest-value addition.
|
||||
|
||||
4. Exit readiness is absent. Whether all own checkers are in the last quarter (fields 19–24) governs an entirely
|
||||
different mode of play. Knowing this explicitly avoids the network having to sum 18 entries and compare against 0.
|
||||
|
||||
5. dice_roll_count is missing. Used for "jan de 3 coups" (must fill the small jan within 3 dice rolls from the
|
||||
starting position). It's in the Player struct but not exported.
|
||||
|
||||
## Key Trictrac distinctions from backgammon that shape the encoding
|
||||
|
||||
| Concept | Backgammon | Trictrac |
|
||||
| ------------------------- | ---------------------- | --------------------------------------------------------- |
|
||||
| Hitting a blot | Removes checker to bar | Scores points, checker stays |
|
||||
| 1-checker field | Vulnerable (bar risk) | Vulnerable (battage target) but not physically threatened |
|
||||
| 2-checker field | Safe "point" | Minimum for quarter fill (critical threshold) |
|
||||
| 3-checker field | Safe with spare | Safe with spare |
|
||||
| Strategic goal early | Block and prime | Fill quarters (all 6 fields ≥ 2) |
|
||||
| Both colors on a field | Impossible | Perfectly legal |
|
||||
| Rest corner (field 12/13) | Does not exist | Special two-checker rules |
|
||||
|
||||
The critical thresholds — 1, 2, 3 — align exactly with TD-Gammon's encoding rationale. Splitting them into binary
|
||||
indicators directly teaches the network the phase transitions the game hinges on.
|
||||
|
||||
## Options
|
||||
|
||||
### Option A — Separated colors, TD-Gammon per-field encoding (flat 1D)
|
||||
|
||||
The minimum viable improvement.
|
||||
|
||||
For each of the 24 fields, encode own and opponent separately with 4 indicators each:
|
||||
|
||||
own_1[i]: 1.0 if exactly 1 own checker at field i (blot — battage target)
|
||||
own_2[i]: 1.0 if exactly 2 own checkers (minimum for quarter fill)
|
||||
own_3[i]: 1.0 if exactly 3 own checkers (stable with 1 spare)
|
||||
own_x[i]: max(0, count − 3) (overflow)
|
||||
opp_1[i]: same for opponent
|
||||
…
|
||||
|
||||
Plus unchanged game-state fields (turn stage, dice, scores), replacing the current to_vec().
|
||||
|
||||
Size: 24 × 8 = 192 (board) + 2 (dice) + 1 (current player) + 1 (turn stage) + 8 (scores) = 204
|
||||
Cost: Tensor is 5.7× larger. In practice the MCTS bottleneck is game tree expansion, not tensor fill; measured
|
||||
overhead is negligible.
|
||||
Benefit: Eliminates the color-mixing problem; the 1-checker vs. 2-checker distinction is now explicit. Learning from
|
||||
scratch will be substantially faster and the converged policy quality better.
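
A sketch of the per-field part of this encoding, written as free-standing functions over the signed counts already stored in positions (assuming the state has been mirrored to the active player's perspective, so positive = own; the final layout belongs to to_vec()):

```rust
/// Encode one color's checker count on a field into 4 indicators, TD-Gammon style.
fn encode_field(count: u8) -> [f32; 4] {
    [
        if count == 1 { 1.0 } else { 0.0 }, // blot (battage target)
        if count == 2 { 1.0 } else { 0.0 }, // minimum for quarter fill
        if count == 3 { 1.0 } else { 0.0 }, // stable with one spare
        count.saturating_sub(3) as f32,     // overflow beyond 3
    ]
}

/// Board part of Option A: 24 fields x (4 own + 4 opponent) = 192 values.
/// `positions[i]` is the current signed encoding (positive = own, negative = opponent
/// once the state has been mirrored to the active player's perspective).
fn encode_board(positions: &[i8; 24]) -> Vec<f32> {
    let mut out = Vec::with_capacity(192);
    for &p in positions {
        let own = p.max(0) as u8;
        let opp = (-p).max(0) as u8;
        out.extend_from_slice(&encode_field(own));
        out.extend_from_slice(&encode_field(opp));
    }
    out
}
```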
|
||||
|
||||
### Option B — Option A + Trictrac-specific derived features (flat 1D)
|
||||
|
||||
Recommended starting point.
|
||||
|
||||
Add on top of Option A:
|
||||
|
||||
// Quarter fill status — the single most important derived feature
|
||||
quarter_filled_own[q] (q=0..3): 1.0 if own quarter q is fully filled (≥2 on all 6 fields)
|
||||
quarter_filled_opp[q] (q=0..3): same for opponent
|
||||
→ 8 values
|
||||
|
||||
// Exit readiness
|
||||
can_exit_own: 1.0 if all own checkers are in fields 19–24
|
||||
can_exit_opp: same for opponent
|
||||
→ 2 values
|
||||
|
||||
// Rest corner status (field 12/13)
|
||||
own_corner_taken: 1.0 if field 12 has ≥2 own checkers
|
||||
opp_corner_taken: 1.0 if field 13 has ≥2 opponent checkers
|
||||
→ 2 values
|
||||
|
||||
// Jan de 3 coups counter (normalized)
|
||||
dice_roll_count_own: dice_roll_count / 3.0 (clamped to 1.0)
|
||||
→ 1 value
|
||||
|
||||
Size: 204 + 8 + 2 + 2 + 1 = 217
|
||||
Training benefit: Quarter fill status is what an expert player reads at a glance. Providing it explicitly can halve
|
||||
the number of self-play games needed to learn the basic strategic structure. The corner status similarly removes
|
||||
expensive inference from the network.
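
A sketch of the own-side derived features, computed on the already-mirrored positions. The quarter boundaries (six consecutive fields per quarter) and the corner index are assumptions to be aligned with the actual board layout; the opponent's copies come from evaluating the same function on the opposite mirror.

```rust
/// Own-side Option B derived features, computed on a state already mirrored to
/// the active player's perspective (positive counts = own checkers).
fn derived_features_own(positions: &[i8; 24]) -> Vec<f32> {
    let mut out = Vec::new();

    // quarter_filled_own[q]: 1.0 if all 6 fields of quarter q hold >= 2 own checkers.
    for q in 0..4 {
        let filled = positions[q * 6..(q + 1) * 6].iter().all(|&p| p >= 2);
        out.push(if filled { 1.0 } else { 0.0 });
    }

    // can_exit_own: 1.0 if no own checker remains outside the last quarter (fields 19-24).
    let outside = positions[..18].iter().any(|&p| p > 0);
    out.push(if outside { 0.0 } else { 1.0 });

    // own_corner_taken: 1.0 if the rest corner holds >= 2 own checkers
    // (field 12 in the document's 1-based numbering, index 11 here: an assumption).
    out.push(if positions[11] >= 2 { 1.0 } else { 0.0 });

    out
}
```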
|
||||
|
||||
### Option C — Option B + richer positional features (flat 1D)
|
||||
|
||||
More complete, higher sample efficiency, minor extra cost.
|
||||
|
||||
Add on top of Option B:
|
||||
|
||||
// Per-quarter fill fraction — how close to filling each quarter
|
||||
own_quarter_fill_fraction[q] (q=0..3): (count of fields with ≥2 own checkers in quarter q) / 6.0
|
||||
opp_quarter_fill_fraction[q] (q=0..3): same for opponent
|
||||
→ 8 values
|
||||
|
||||
// Blot counts — number of own/opponent single-checker fields globally
|
||||
// (tells the network at a glance how much battage risk/opportunity exists)
|
||||
own_blot_count: (number of own fields with exactly 1 checker) / 15.0
|
||||
opp_blot_count: same for opponent
|
||||
→ 2 values
|
||||
|
||||
// Bredouille would-double multiplier (already present, but explicitly scaled)
|
||||
// No change needed, already binary
|
||||
|
||||
Size: 217 + 8 + 2 = 227
|
||||
Tradeoff: The fill fractions are partially redundant with the TD-Gammon per-field counts, but they save the network
|
||||
from summing across a quarter. The redundancy is not harmful (it gives explicit shortcuts).
|
||||
|
||||
### Option D — 2D spatial tensor {K, 24}
|
||||
|
||||
For CNN-based networks. Best eventual architecture but requires changing the training setup.
|
||||
|
||||
Shape {14, 24} — 14 feature channels over 24 field positions:
|
||||
|
||||
Channel 0: own_count_1 (blot)
|
||||
Channel 1: own_count_2
|
||||
Channel 2: own_count_3
|
||||
Channel 3: own_count_overflow (float)
|
||||
Channel 4: opp_count_1
|
||||
Channel 5: opp_count_2
|
||||
Channel 6: opp_count_3
|
||||
Channel 7: opp_count_overflow
|
||||
Channel 8: own_corner_mask (1.0 at field 12)
|
||||
Channel 9: opp_corner_mask (1.0 at field 13)
|
||||
Channel 10: final_quarter_mask (1.0 at fields 19–24)
|
||||
Channel 11: quarter_filled_own (constant 1.0 across the 6 fields of any filled own quarter)
|
||||
Channel 12: quarter_filled_opp (same for opponent)
|
||||
Channel 13: dice_reach (1.0 at fields reachable this turn by own checkers)
|
||||
|
||||
Global scalars (dice, scores, bredouille, etc.) embedded as extra all-constant channels, e.g. one channel with uniform
|
||||
value dice1/6.0 across all 24 positions, another for dice2/6.0, etc. Alternatively pack them into a leading "global"
|
||||
row by returning shape {K, 25} with position 0 holding global features.
|
||||
|
||||
Size: 14 × 24 + few global channels ≈ 336–384
|
||||
C++ change needed: ObservationTensorShape() → {14, 24} (or {kNumChannels, 24}), kStateEncodingSize updated
|
||||
accordingly.
|
||||
Training setup change needed: The AlphaZero config must specify a ResNet/ConvNet rather than an MLP. OpenSpiel's
|
||||
alpha_zero.cc uses CreateTorchResnet() which already handles 2D input when the tensor shape has 3 dimensions ({C, H,
|
||||
W}). Shape {14, 24} would be treated as 2D with a 1D spatial dimension.
|
||||
Benefit: A convolutional network with kernel size 6 (= quarter width) would naturally learn quarter patterns. Kernel
|
||||
size 2–3 captures adjacent-field "tout d'une" interactions.
|
||||
|
||||
### On 3D tensors
|
||||
|
||||
Shape {K, 4, 6} — K features × 4 quarters × 6 fields — is the most semantically natural for Trictrac. The quarter is
|
||||
the fundamental tactical unit. A 2D conv over this shape (quarters × fields) would learn quarter-level patterns and
|
||||
field-within-quarter patterns jointly.
|
||||
|
||||
However, 3D tensors require a 3D convolutional network, which OpenSpiel's AlphaZero doesn't use out of the box. The
|
||||
extra architecture work makes this premature unless you're already building a custom network. The information content
|
||||
is the same as Option D.
|
||||
|
||||
### Recommendation
|
||||
|
||||
Start with Option B (217 values, flat 1D, kStateEncodingSize = 217). It requires only changes to to_vec() in Rust and
|
||||
the one constant in the C++ header — no architecture changes, no training pipeline changes. The three additions
|
||||
(quarter fill status, exit readiness, corner status) are the features a human expert reads before deciding their move.
|
||||
|
||||
Plan Option D as a follow-up once you have a baseline trained on Option B. The 2D spatial CNN becomes worthwhile when
|
||||
the MCTS games-per-second is high enough that the limit shifts from sample efficiency to wall-clock training time.
|
||||
|
||||
Costs summary:
|
||||
|
||||
| Option | Size | Rust change | C++ change | Architecture change | Expected sample-efficiency gain |
|
||||
| ------- | ---- | ---------------- | ----------------------- | ------------------- | ------------------------------- |
|
||||
| Current | 36 | — | — | — | baseline |
|
||||
| A | 204 | to_vec() rewrite | constant update | none | moderate (color separation) |
|
||||
| B | 217 | to_vec() rewrite | constant update | none | large (quarter fill explicit) |
|
||||
| C | 227 | to_vec() rewrite | constant update | none | large + moderate |
|
||||
| D | ~360 | to_vec() rewrite | constant + shape update | CNN required | large + spatial |
|
||||
|
||||
One concrete implementation note: since get_tensor() in cxxengine.rs calls game_state.mirror().to_vec() for player 2,
|
||||
the new to_vec() must express everything from the active player's perspective (which the mirror already handles for
|
||||
the board). The quarter fill status and corner status should therefore be computed on the already-mirrored state,
|
||||
which they will be if computed inside to_vec().
|
||||
|
||||
## Other algorithms
|
||||
|
||||
The recommended features (Option B) are the same or more important for DQN/PPO. But two things do shift meaningfully.
|
||||
|
||||
### 1. Without MCTS, feature quality matters more
|
||||
|
||||
AlphaZero has a safety net: even a weak policy network produces decent play once MCTS has run a few hundred
|
||||
simulations, because the tree search compensates for imprecise network estimates. DQN and PPO have no such backup —
|
||||
the network must learn the full strategic structure directly from gradient updates.
|
||||
|
||||
This means the quarter-fill status, exit readiness, and corner features from Option B are more important for DQN/PPO,
|
||||
not less. With AlphaZero you can get away with a mediocre tensor for longer. With PPO in particular, which is less
|
||||
sample-efficient than MCTS-based methods, a poorly represented state can make the game nearly unlearnable from
|
||||
scratch.
|
||||
|
||||
### 2. Normalization becomes mandatory, not optional
|
||||
|
||||
AlphaZero's value target is bounded (by MaxUtility) and MCTS normalizes visit counts into a policy. DQN bootstraps
|
||||
Q-values via TD updates, and PPO has gradient clipping but is still sensitive to input scale. With heterogeneous raw
|
||||
values (dice 1–6, counts 0–15, booleans 0/1, points 0–12) in the same vector, gradient flow is uneven and training can
|
||||
be unstable.
|
||||
|
||||
For DQN/PPO, every feature in the tensor should be in [0, 1]:
|
||||
|
||||
dice values: / 6.0
|
||||
checker counts: overflow channel / 12.0
|
||||
points: / 12.0
|
||||
holes: / 12.0
|
||||
dice_roll_count: / 3.0 (clamped)
|
||||
|
||||
Booleans and the TD-Gammon binary indicators are already in [0, 1].
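
A sketch of the corresponding scaling when the scalar features are appended to the observation vector (divisors as listed above):

```rust
/// Scale the heterogeneous scalar features into [0, 1] before appending
/// them to the observation vector.
fn push_scalars(out: &mut Vec<f32>, dice: (u8, u8), points: u8, holes: u8, dice_roll_count: u8) {
    out.push(dice.0 as f32 / 6.0);
    out.push(dice.1 as f32 / 6.0);
    out.push(points as f32 / 12.0);
    out.push(holes as f32 / 12.0);
    out.push((dice_roll_count as f32 / 3.0).min(1.0)); // clamped "jan de 3 coups" counter
}
```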
|
||||
|
||||
### 3. The shape question depends on architecture, not algorithm
|
||||
|
||||
| Architecture | Shape | When to use |
|
||||
| ------------------------------------ | ---------------------------- | ------------------------------------------------------------------- |
|
||||
| MLP | {217} flat | Any algorithm, simplest baseline |
|
||||
| 1D CNN (conv over 24 fields) | {K, 24} | When you want spatial locality (adjacent fields, quarter patterns) |
|
||||
| 2D CNN (conv over quarters × fields) | {K, 4, 6} | Most semantically natural for Trictrac, but requires custom network |
|
||||
| Transformer | {24, K} (sequence of fields) | Attention over field positions; overkill for now |
|
||||
|
||||
The choice between these is independent of whether you use AlphaZero, DQN, or PPO. It depends on whether you want
|
||||
convolutions, and DQN/PPO give you more architectural freedom than OpenSpiel's AlphaZero (which uses a fixed ResNet
|
||||
template). With a custom DQN/PPO implementation you can use a 2D CNN immediately without touching the C++ side at all
|
||||
— you just reshape the flat tensor in Python before passing it to the network.
|
||||
|
||||
### One thing that genuinely changes: value function perspective
|
||||
|
||||
AlphaZero and ego-centric PPO always see the board from the active player's perspective (handled by mirror()). This
|
||||
works well.
|
||||
|
||||
DQN in a two-player game sometimes uses a canonical absolute representation (always White's view, with an explicit
|
||||
current-player indicator), because a single Q-network estimates action values for both players simultaneously. With
|
||||
the current ego-centric mirroring, the same board position looks different depending on whose turn it is, and DQN must
|
||||
learn both "sides" through the same weights — which it can do, but a canonical representation removes the ambiguity.
|
||||
This is a minor point for a symmetric game like Trictrac, but worth keeping in mind.
|
||||
|
||||
Bottom line: Stick with Option B (217 values, normalized), flat 1D. If you later add a CNN, reshape in Python — there's no need to change the Rust/C++ tensor format. The features themselves are the same regardless of algorithm.
|
||||
|
|
@ -1,130 +0,0 @@
|
|||
# Inspirations
|
||||
|
||||
tools
|
||||
|
||||
- configure clippy?
|
||||
- bacon: test runner (or loom?)
|
||||
|
||||
## Rust libs
|
||||
|
||||
cf. <https://blessed.rs/crates>
|
||||
|
||||
seeded random numbers: <https://richard.dallaway.com/posts/2021-01-04-repeat-resume/>
|
||||
|
||||
- cli: <https://lib.rs/crates/pico-args> (or clap)
|
||||
- async networking: tokio
|
||||
- web server: axum (uses tokio)
|
||||
- <https://fasterthanli.me/series/updating-fasterthanli-me-for-2022/part-2#the-opinions-of-axum-also-nice-error-handling>
|
||||
- db : sqlx
|
||||
|
||||
- eyre, color-eyre (Results)
|
||||
- tracing (logging)
|
||||
- rayon ( sync <-> parallel )
|
||||
|
||||
- front : yew + tauri
|
||||
|
||||
- egui
|
||||
|
||||
- <https://docs.rs/board-game/latest/board_game/>
|
||||
|
||||
## network games
|
||||
|
||||
- <https://www.mattkeeter.com/projects/pont/>
|
||||
- <https://github.com/jackadamson/onitama> (wasm, rooms)
|
||||
- <https://github.com/UkoeHB/renet2>
|
||||
- <https://github.com/UkoeHB/bevy_simplenet>
|
||||
|
||||
## Others
|
||||
|
||||
- plugins with <https://github.com/extism/extism>
|
||||
|
||||
## Backgammon existing projects
|
||||
|
||||
- go : <https://bgammon.org/blog/20240101-hello-world/>
|
||||
- communication protocol: <https://code.rocket9labs.com/tslocum/bgammon/src/branch/main/PROTOCOL.md>
|
||||
- ocaml : <https://github.com/jacobhilton/backgammon?tab=readme-ov-file>
|
||||
cli example : <https://www.jacobh.co.uk/backgammon/>
|
||||
- Rust backgammon libraries
|
||||
- <https://github.com/carlostrub/backgammon>
|
||||
- <https://github.com/marktani/backgammon>
|
||||
- network webtarot
|
||||
- front ?
|
||||
|
||||
## cli examples
|
||||
|
||||
### GnuBackgammon
|
||||
|
||||
(No game) new game
|
||||
gnubg rolls 3, anthon rolls 1.
|
||||
|
||||
GNU Backgammon Positions ID: 4HPwATDgc/ABMA
|
||||
Match ID : MIEFAAAAAAAA
|
||||
+12-11-10--9--8--7-------6--5--4--3--2--1-+ O: gnubg
|
||||
| X O | | O X | 0 points
|
||||
| X O | | O X | Rolled 31
|
||||
| X O | | O |
|
||||
| X | | O |
|
||||
| X | | O |
|
||||
^| |BAR| | (Cube: 1)
|
||||
| O | | X |
|
||||
| O | | X |
|
||||
| O X | | X |
|
||||
| O X | | X O |
|
||||
| O X | | X O | 0 points
|
||||
+13-14-15-16-17-18------19-20-21-22-23-24-+ X: anthon
|
||||
|
||||
gnubg moves 8/5 6/5.
|
||||
|
||||
### jacobh
|
||||
|
||||
Move 11: player O rolls a 6-2.
|
||||
Player O estimates that they have a 90.6111% chance of winning.
|
||||
|
||||
Os borne off: none
|
||||
24 23 22 21 20 19 18 17 16 15 14 13
|
||||
|
||||
---
|
||||
|
||||
| v v v v v v | | v v v v v v |
|
||||
| | | |
|
||||
| X O O O | | O O O |
|
||||
| X O O O | | O O |
|
||||
| O | | |
|
||||
| | X | |
|
||||
| | | |
|
||||
| | | |
|
||||
| | | |
|
||||
| | | |
|
||||
|------------------------------| |------------------------------|
|
||||
| | | |
|
||||
| | | |
|
||||
| | | |
|
||||
| | | |
|
||||
| X | | |
|
||||
| X X | | X |
|
||||
| X X X | | X O |
|
||||
| X X X | | X O O |
|
||||
| | | |
|
||||
| ^ ^ ^ ^ ^ ^ | | ^ ^ ^ ^ ^ ^ |
|
||||
|
||||
---
|
||||
|
||||
1 2 3 4 5 6 7 8 9 10 11 12
|
||||
Xs borne off: none
|
||||
|
||||
Move 12: player X rolls a 6-3.
|
||||
Your move (? for help): bar/22
|
||||
Illegal move: it is possible to move more.
|
||||
Your move (? for help): ?
|
||||
Enter the start and end positions, separated by a forward slash (or any non-numeric character), of each counter you want to move.
|
||||
Each position should be number from 1 to 24, "bar" or "off".
|
||||
Unlike in standard notation, you should enter each counter movement individually. For example:
|
||||
24/18 18/13
|
||||
bar/3 13/10 13/10 8/5
|
||||
2/off 1/off
|
||||
You can also enter these commands:
|
||||
p - show the previous move
|
||||
n - show the next move
|
||||
<enter> - toggle between showing the current and last moves
|
||||
help - show this help text
|
||||
quit - abandon game
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
# Journal
|
||||
|
||||
```sh
|
||||
devenv init
|
||||
cargo init
|
||||
cargo add pico-args
|
||||
```
|
||||
|
||||
Store / server / client organisation following <https://herluf-ba.github.io/making-a-turn-based-multiplayer-game-in-rust-01-whats-a-turn-based-game-anyway>
|
||||
|
||||
_store_ is the library containing the _reducer_ that transforms the game state according to events. It is used by the _server_ and the _client_. Only events are transmitted between clients and the server.
|
||||
|
||||
## Config neovim debugger launchers
|
||||
|
||||
This is configured in the neovim config (lua/plugins/overrides.lua).
|
||||
|
||||
## Store organisation
|
||||
|
||||
lib
|
||||
|
||||
- game::GameState
|
||||
- error
|
||||
- dice
|
||||
- board
|
||||
- user
|
||||
- user
|
||||
|
||||
## Move-determination algorithm
|
||||
|
||||
- strategy::choose_move
|
||||
|
||||
- GameRules.get_possible_moves_sequences(with_excedents: bool)
|
||||
- get_possible_moves_sequences_by_dices(dice_max, dice_min, with_excedents, false);
|
||||
- get_possible_moves_sequences_by_dices(dice_min, dice_max, with_excedents, true);
|
||||
- has_checkers_outside_last_quarter() ok
|
||||
- board.get_possible_moves ok
|
||||
- check_corner_rules(&(first_move, second_move)) ok
|
||||
|
||||
- handle_event
|
||||
- state.validate (ok)
|
||||
- rules.moves_follow_rules (ok)
|
||||
- moves_possible ok
|
||||
- moves_follows_dices ok
|
||||
- moves_allowed (ok)
|
||||
- check_corner_rules ok
|
||||
- can_take_corner_by_effect ok
|
||||
- get_possible_moves_sequences -> cf. l.15
|
||||
- check_exit_rules
|
||||
- get_possible_moves_sequences(without exedents) -> cf l.15
|
||||
- get_quarter_filling_moves_sequences
|
||||
- get_possible_moves_sequences -> cf l.15
|
||||
- state.consume (RollResult) (ok)
|
||||
- get_rollresult_jans -> points_rules.get_result_jans (ok)
|
||||
- get_jans (ok)
|
||||
- get_jans_by_ordered_dice (ok)
|
||||
- get_jans_by_ordered_dice ( dices.poped )
|
||||
- move_rules.get_scoring_quarter_filling_moves_sequences (ok)
|
||||
- get_quarter_filling_moves_sequences cf l.8 (ok)
|
||||
- board.get_quarter_filling_candidate -> is_quarter_fillable ok
|
||||
- move_rules.get_possible_moves_sequence -> cf l.15
|
||||
- get_jans_points -> jan.get_points ok
|
||||
|
|
@ -1,417 +0,0 @@
|
|||
# Outputs
|
||||
|
||||
## 50 episodes - 1000 steps max - desktop
|
||||
|
||||
{"episode": 0, "reward": -1798.7162, "steps count": 1000, "duration": 11}
|
||||
{"episode": 1, "reward": -1794.8162, "steps count": 1000, "duration": 32}
|
||||
{"episode": 2, "reward": -1387.7109, "steps count": 1000, "duration": 58}
|
||||
{"episode": 3, "reward": -42.5005, "steps count": 1000, "duration": 82}
|
||||
{"episode": 4, "reward": -48.2005, "steps count": 1000, "duration": 109}
|
||||
{"episode": 5, "reward": 1.2000, "steps count": 1000, "duration": 141}
|
||||
{"episode": 6, "reward": 8.8000, "steps count": 1000, "duration": 184}
|
||||
{"episode": 7, "reward": 6.9002, "steps count": 1000, "duration": 219}
|
||||
{"episode": 8, "reward": 16.5001, "steps count": 1000, "duration": 248}
|
||||
{"episode": 9, "reward": -2.6000, "steps count": 1000, "duration": 281}
|
||||
{"episode": 10, "reward": 3.0999, "steps count": 1000, "duration": 324}
|
||||
{"episode": 11, "reward": -34.7004, "steps count": 1000, "duration": 497}
|
||||
{"episode": 12, "reward": -15.7998, "steps count": 1000, "duration": 466}
|
||||
{"episode": 13, "reward": 6.9000, "steps count": 1000, "duration": 496}
|
||||
{"episode": 14, "reward": 6.3000, "steps count": 1000, "duration": 540}
|
||||
{"episode": 15, "reward": -2.6000, "steps count": 1000, "duration": 581}
|
||||
{"episode": 16, "reward": -33.0003, "steps count": 1000, "duration": 641}
|
||||
{"episode": 17, "reward": -36.8000, "steps count": 1000, "duration": 665}
|
||||
{"episode": 18, "reward": -10.1997, "steps count": 1000, "duration": 753}
|
||||
{"episode": 19, "reward": -88.1014, "steps count": 1000, "duration": 837}
|
||||
{"episode": 20, "reward": -57.5002, "steps count": 1000, "duration": 881}
|
||||
{"episode": 21, "reward": -17.7997, "steps count": 1000, "duration": 1159}
|
||||
{"episode": 22, "reward": -25.4000, "steps count": 1000, "duration": 1235}
|
||||
{"episode": 23, "reward": -104.4013, "steps count": 995, "duration": 1290}
|
||||
{"episode": 24, "reward": -268.6004, "steps count": 1000, "duration": 1322}
|
||||
{"episode": 25, "reward": -743.6052, "steps count": 1000, "duration": 1398}
|
||||
{"episode": 26, "reward": -821.5029, "steps count": 1000, "duration": 1427}
|
||||
{"episode": 27, "reward": -211.5993, "steps count": 1000, "duration": 1409}
|
||||
{"episode": 28, "reward": -276.1974, "steps count": 1000, "duration": 1463}
|
||||
{"episode": 29, "reward": -222.9980, "steps count": 1000, "duration": 1509}
|
||||
{"episode": 30, "reward": -298.9973, "steps count": 1000, "duration": 1560}
|
||||
{"episode": 31, "reward": -164.0011, "steps count": 1000, "duration": 1752}
|
||||
{"episode": 32, "reward": -221.0990, "steps count": 1000, "duration": 1807}
|
||||
{"episode": 33, "reward": -260.9996, "steps count": 1000, "duration": 1730}
|
||||
{"episode": 34, "reward": -420.5959, "steps count": 1000, "duration": 1767}
|
||||
{"episode": 35, "reward": -407.2964, "steps count": 1000, "duration": 1815}
|
||||
{"episode": 36, "reward": -291.2966, "steps count": 1000, "duration": 1870}
|
||||
|
||||
thread 'main' has overflowed its stack
|
||||
fatal runtime error: stack overflow, aborting
|
||||
error: Recipe `trainbot` was terminated on line 24 by signal 6
|
||||
|
||||
## 50 episodes - 700 steps max - desktop
|
||||
|
||||
const MEMORY_SIZE: usize = 4096;
|
||||
const DENSE_SIZE: usize = 128;
|
||||
const EPS_DECAY: f64 = 1000.0;
|
||||
const EPS_START: f64 = 0.9;
|
||||
const EPS_END: f64 = 0.05;
|
||||
|
||||
> Training
|
||||
> {"episode": 0, "reward": -862.8993, "steps count": 700, "duration": 6}
|
||||
> {"episode": 1, "reward": -418.8971, "steps count": 700, "duration": 13}
|
||||
> {"episode": 2, "reward": -64.9999, "steps count": 453, "duration": 14}
|
||||
> {"episode": 3, "reward": -142.8002, "steps count": 700, "duration": 31}
|
||||
> {"episode": 4, "reward": -74.4004, "steps count": 700, "duration": 45}
|
||||
> {"episode": 5, "reward": -40.2002, "steps count": 700, "duration": 58}
|
||||
> {"episode": 6, "reward": -21.1998, "steps count": 700, "duration": 70}
|
||||
> {"episode": 7, "reward": 99.7000, "steps count": 642, "duration": 79}
|
||||
> {"episode": 8, "reward": -5.9999, "steps count": 700, "duration": 99}
|
||||
> {"episode": 9, "reward": -7.8999, "steps count": 700, "duration": 118}
|
||||
> {"episode": 10, "reward": 92.5000, "steps count": 624, "duration": 117}
|
||||
> {"episode": 11, "reward": -17.1998, "steps count": 700, "duration": 144}
|
||||
> {"episode": 12, "reward": 1.7000, "steps count": 700, "duration": 157}
|
||||
> {"episode": 13, "reward": -7.9000, "steps count": 700, "duration": 172}
|
||||
> {"episode": 14, "reward": -7.9000, "steps count": 700, "duration": 196}
|
||||
> {"episode": 15, "reward": -2.8000, "steps count": 700, "duration": 214}
|
||||
> {"episode": 16, "reward": 16.8002, "steps count": 700, "duration": 250}
|
||||
> {"episode": 17, "reward": -47.7001, "steps count": 700, "duration": 272}
|
||||
> k{"episode": 18, "reward": -13.6000, "steps count": 700, "duration": 288}
|
||||
> {"episode": 19, "reward": -79.9002, "steps count": 700, "duration": 304}
|
||||
> {"episode": 20, "reward": -355.5985, "steps count": 700, "duration": 317}
|
||||
> {"episode": 21, "reward": -205.5001, "steps count": 700, "duration": 333}
|
||||
> {"episode": 22, "reward": -207.3974, "steps count": 700, "duration": 348}
|
||||
> {"episode": 23, "reward": -161.7999, "steps count": 700, "duration": 367}
|
||||
|
||||
---

```rust
const MEMORY_SIZE: usize = 8192;
const DENSE_SIZE: usize = 128;
const EPS_DECAY: f64 = 10000.0;
const EPS_START: f64 = 0.9;
const EPS_END: f64 = 0.05;
```

> Entraînement
> {"episode": 0, "reward": -1119.9921, "steps count": 700, "duration": 6}
> {"episode": 1, "reward": -928.6963, "steps count": 700, "duration": 13}
> {"episode": 2, "reward": -364.5009, "steps count": 380, "duration": 11}
> {"episode": 3, "reward": -797.5981, "steps count": 700, "duration": 28}
> {"episode": 4, "reward": -577.5994, "steps count": 599, "duration": 34}
> {"episode": 5, "reward": -725.2992, "steps count": 700, "duration": 49}
> {"episode": 6, "reward": -638.8995, "steps count": 700, "duration": 59}
> {"episode": 7, "reward": -1039.1932, "steps count": 700, "duration": 73}
> field invalid : White, 3, Board { positions: [13, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -2, 0, -11] }

```
thread 'main' panicked at store/src/game.rs:556:65:
called `Result::unwrap()` on an `Err` value: FieldInvalid
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
error: Recipe `trainbot` failed on line 27 with exit code 101
```

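This crash is an `unwrap()` on a `FieldInvalid` result inside `store/src/game.rs`. A self-contained sketch of the alternative, handling the error instead of aborting the run; all names and the validity check below are placeholders, not the real `store` API:

```rust
// Sketch only: branch on the move result instead of unwrapping it.
#[derive(Debug)]
enum GameError {
    FieldInvalid,
}

// Stand-in for applying a checker move; the real legality rule lives in the store.
fn apply_move(field_is_legal: bool) -> Result<(), GameError> {
    if field_is_legal {
        Ok(())
    } else {
        Err(GameError::FieldInvalid)
    }
}

// Penalise an illegal move and keep the episode alive instead of panicking.
fn step_reward(field_is_legal: bool) -> f64 {
    match apply_move(field_is_legal) {
        Ok(()) => 0.0,
        Err(GameError::FieldInvalid) => -1.0,
    }
}

fn main() {
    assert_eq!(step_reward(true), 0.0);
    assert_eq!(step_reward(false), -1.0);
}
```

In the real loop the penalty would be paired with falling back to one of the legal moves so the episode can continue.
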
---

```rust
#[allow(unused)]
const MEMORY_SIZE: usize = 8192;
const DENSE_SIZE: usize = 256;
const EPS_DECAY: f64 = 10000.0;
const EPS_START: f64 = 0.9;
const EPS_END: f64 = 0.05;
```

> Entraînement
> {"episode": 0, "reward": -1102.6925, "steps count": 700, "duration": 9}
> field invalid : White, 6, Board { positions: [14, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, -13] }

```
thread 'main' panicked at store/src/game.rs:556:65:
called `Result::unwrap()` on an `Err` value: FieldInvalid
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
error: Recipe `trainbot` failed on line 27 with exit code 101
```

---

```rust
const MEMORY_SIZE: usize = 8192;
const DENSE_SIZE: usize = 256;
const EPS_DECAY: f64 = 1000.0;
const EPS_START: f64 = 0.9;
const EPS_END: f64 = 0.05;
```

> Entraînement
> {"episode": 0, "reward": -1116.2921, "steps count": 700, "duration": 9}
> {"episode": 1, "reward": -1116.2922, "steps count": 700, "duration": 18}
> {"episode": 2, "reward": -1119.9921, "steps count": 700, "duration": 29}
> {"episode": 3, "reward": -1089.1927, "steps count": 700, "duration": 41}
> {"episode": 4, "reward": -1116.2921, "steps count": 700, "duration": 53}
> {"episode": 5, "reward": -684.8043, "steps count": 700, "duration": 66}
> {"episode": 6, "reward": 0.3000, "steps count": 700, "duration": 80}
> {"episode": 7, "reward": 2.0000, "steps count": 700, "duration": 96}
> {"episode": 8, "reward": 30.9001, "steps count": 700, "duration": 112}
> {"episode": 9, "reward": 0.3000, "steps count": 700, "duration": 128}
> {"episode": 10, "reward": 0.3000, "steps count": 700, "duration": 141}
> {"episode": 11, "reward": 8.8000, "steps count": 700, "duration": 155}
> {"episode": 12, "reward": 7.1000, "steps count": 700, "duration": 169}
> {"episode": 13, "reward": 17.3001, "steps count": 700, "duration": 190}
> {"episode": 14, "reward": -107.9005, "steps count": 700, "duration": 210}
> {"episode": 15, "reward": 7.1001, "steps count": 700, "duration": 236}
> {"episode": 16, "reward": 17.3001, "steps count": 700, "duration": 268}
> {"episode": 17, "reward": 7.1000, "steps count": 700, "duration": 283}
> {"episode": 18, "reward": -5.9000, "steps count": 700, "duration": 300}
> {"episode": 19, "reward": -36.8009, "steps count": 700, "duration": 316}
> {"episode": 20, "reward": 19.0001, "steps count": 700, "duration": 332}
> {"episode": 21, "reward": 113.3000, "steps count": 461, "duration": 227}
> field invalid : White, 1, Board { positions: [0, 2, 2, 0, 2, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -7, -2, -1, 0, -1, -1] }

```
thread 'main' panicked at store/src/game.rs:556:65:
called `Result::unwrap()` on an `Err` value: FieldInvalid
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
error: Recipe `trainbot` failed on line 27 with exit code 101
```

---

```rust
num_episodes: 50,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 1000.0,
```

> Entraînement
> {"episode": 0, "reward": -1118.8921, "steps count": 700, "duration": 9}
> {"episode": 1, "reward": -1119.9921, "steps count": 700, "duration": 17}
> {"episode": 2, "reward": -1118.8921, "steps count": 700, "duration": 28}
> {"episode": 3, "reward": -283.5977, "steps count": 700, "duration": 41}
> {"episode": 4, "reward": -23.4998, "steps count": 700, "duration": 54}
> {"episode": 5, "reward": -31.9999, "steps count": 700, "duration": 68}
> {"episode": 6, "reward": 2.0000, "steps count": 700, "duration": 82}
> {"episode": 7, "reward": 109.3000, "steps count": 192, "duration": 26}
> {"episode": 8, "reward": -4.8000, "steps count": 700, "duration": 102}
> {"episode": 9, "reward": 15.6001, "steps count": 700, "duration": 124}
> {"episode": 10, "reward": 15.6002, "steps count": 700, "duration": 144}
> {"episode": 11, "reward": -65.7008, "steps count": 700, "duration": 162}
> {"episode": 12, "reward": 19.0002, "steps count": 700, "duration": 182}
> {"episode": 13, "reward": 20.7001, "steps count": 700, "duration": 197}
> {"episode": 14, "reward": 12.2002, "steps count": 700, "duration": 229}
> {"episode": 15, "reward": -32.0007, "steps count": 700, "duration": 242}
> {"episode": 16, "reward": 10.5000, "steps count": 700, "duration": 287}
> {"episode": 17, "reward": 24.1001, "steps count": 700, "duration": 318}
> {"episode": 18, "reward": 25.8002, "steps count": 700, "duration": 335}
> {"episode": 19, "reward": 29.2001, "steps count": 700, "duration": 367}
> {"episode": 20, "reward": 9.1000, "steps count": 700, "duration": 366}
> {"episode": 21, "reward": 3.7001, "steps count": 700, "duration": 398}
> {"episode": 22, "reward": 10.5000, "steps count": 700, "duration": 417}
> {"episode": 23, "reward": 10.5000, "steps count": 700, "duration": 438}
> {"episode": 24, "reward": 13.9000, "steps count": 700, "duration": 444}
> {"episode": 25, "reward": 7.1000, "steps count": 700, "duration": 486}
> {"episode": 26, "reward": 12.2001, "steps count": 700, "duration": 499}
> {"episode": 27, "reward": 8.8001, "steps count": 700, "duration": 554}
> {"episode": 28, "reward": -6.5000, "steps count": 700, "duration": 608}
> {"episode": 29, "reward": -3.1000, "steps count": 700, "duration": 633}
> {"episode": 30, "reward": -32.0001, "steps count": 700, "duration": 696}
> {"episode": 31, "reward": 22.4002, "steps count": 700, "duration": 843}
> {"episode": 32, "reward": -77.9004, "steps count": 700, "duration": 817}
> {"episode": 33, "reward": -368.5993, "steps count": 700, "duration": 827}
> {"episode": 34, "reward": -254.6986, "steps count": 700, "duration": 852}
> {"episode": 35, "reward": -433.1992, "steps count": 700, "duration": 884}
> {"episode": 36, "reward": -521.6010, "steps count": 700, "duration": 905}
> {"episode": 37, "reward": -71.1004, "steps count": 700, "duration": 930}
> {"episode": 38, "reward": -251.0004, "steps count": 700, "duration": 956}
> {"episode": 39, "reward": -594.7045, "steps count": 700, "duration": 982}
> {"episode": 40, "reward": -154.4001, "steps count": 700, "duration": 1008}
> {"episode": 41, "reward": -171.3994, "steps count": 700, "duration": 1033}
> {"episode": 42, "reward": -118.7004, "steps count": 700, "duration": 1059}
> {"episode": 43, "reward": -137.4003, "steps count": 700, "duration": 1087}

```
thread 'main' has overflowed its stack
fatal runtime error: stack overflow, aborting
error: Recipe `trainbot` was terminated on line 27 by signal 6
```

---

```rust
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 1500, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 1000.0,
```

> Entraînement
> {"episode": 0, "reward": -2399.9993, "steps count": 1500, "duration": 31}
> {"episode": 1, "reward": -2061.6736, "steps count": 1500, "duration": 81}
> {"episode": 2, "reward": -48.9010, "steps count": 1500, "duration": 145}
> {"episode": 3, "reward": 3.8000, "steps count": 1500, "duration": 215}
> {"episode": 4, "reward": -6.3999, "steps count": 1500, "duration": 302}
> {"episode": 5, "reward": 20.8004, "steps count": 1500, "duration": 374}
> {"episode": 6, "reward": 49.6992, "steps count": 1500, "duration": 469}
> {"episode": 7, "reward": 29.3002, "steps count": 1500, "duration": 597}
> {"episode": 8, "reward": 34.3999, "steps count": 1500, "duration": 710}
> {"episode": 9, "reward": 115.3003, "steps count": 966, "duration": 515}
> {"episode": 10, "reward": 25.9004, "steps count": 1500, "duration": 852}
> {"episode": 11, "reward": -122.0007, "steps count": 1500, "duration": 1017}
> {"episode": 12, "reward": -274.9966, "steps count": 1500, "duration": 1073}
> {"episode": 13, "reward": 54.8994, "steps count": 651, "duration": 518}
> {"episode": 14, "reward": -439.8978, "steps count": 1500, "duration": 1244}
> {"episode": 15, "reward": -506.1997, "steps count": 1500, "duration": 1676}
> {"episode": 16, "reward": -829.5031, "steps count": 1500, "duration": 1855}
> {"episode": 17, "reward": -545.2961, "steps count": 1500, "duration": 1892}
> {"episode": 18, "reward": -795.2026, "steps count": 1500, "duration": 2008}
> {"episode": 19, "reward": -637.1031, "steps count": 1500, "duration": 2124}
> {"episode": 20, "reward": -989.6997, "steps count": 1500, "duration": 2241}

```
thread 'main' has overflowed its stack
fatal runtime error: stack overflow, aborting
error: Recipe `trainbot` was terminated on line 27 by signal 6
```

---

```rust
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 1000, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 10000.0,
```

> Entraînement
> {"episode": 0, "reward": -1598.8848, "steps count": 1000, "duration": 16}
> {"episode": 1, "reward": -1531.9866, "steps count": 1000, "duration": 34}
> {"episode": 2, "reward": -515.6000, "steps count": 530, "duration": 25}
> {"episode": 3, "reward": -396.1008, "steps count": 441, "duration": 27}
> {"episode": 4, "reward": -540.6996, "steps count": 605, "duration": 43}
> {"episode": 5, "reward": -976.0975, "steps count": 1000, "duration": 89}
> {"episode": 6, "reward": -1014.2944, "steps count": 1000, "duration": 117}
> {"episode": 7, "reward": -806.7012, "steps count": 1000, "duration": 140}
> {"episode": 8, "reward": -1276.6891, "steps count": 1000, "duration": 166}
> {"episode": 9, "reward": -1554.3855, "steps count": 1000, "duration": 197}
> {"episode": 10, "reward": -1178.3925, "steps count": 1000, "duration": 219}
> {"episode": 11, "reward": -1457.4869, "steps count": 1000, "duration": 258}
> {"episode": 12, "reward": -1475.8882, "steps count": 1000, "duration": 291}

---

```rust
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 1000, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 3000.0,
```

> Entraînement
> {"episode": 0, "reward": -1598.8848, "steps count": 1000, "duration": 15}
> {"episode": 1, "reward": -1599.9847, "steps count": 1000, "duration": 33}
> {"episode": 2, "reward": -751.7018, "steps count": 1000, "duration": 57}
> {"episode": 3, "reward": -402.8979, "steps count": 1000, "duration": 81}
> {"episode": 4, "reward": -289.2985, "steps count": 1000, "duration": 108}
> {"episode": 5, "reward": -231.4988, "steps count": 1000, "duration": 140}
> {"episode": 6, "reward": -138.0006, "steps count": 1000, "duration": 165}
> {"episode": 7, "reward": -145.0998, "steps count": 1000, "duration": 200}
> {"episode": 8, "reward": -60.4005, "steps count": 1000, "duration": 236}
> {"episode": 9, "reward": -35.7999, "steps count": 1000, "duration": 276}
> {"episode": 10, "reward": -42.2002, "steps count": 1000, "duration": 313}
> {"episode": 11, "reward": 69.0002, "steps count": 874, "duration": 300}
> {"episode": 12, "reward": 93.2000, "steps count": 421, "duration": 153}
> {"episode": 13, "reward": -324.9010, "steps count": 866, "duration": 364}
> {"episode": 14, "reward": -1331.3883, "steps count": 1000, "duration": 478}
> {"episode": 15, "reward": -1544.5859, "steps count": 1000, "duration": 514}
> {"episode": 16, "reward": -1599.9847, "steps count": 1000, "duration": 552}

---

New points...

```rust
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 1000, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 3000.0,
```

> Entraînement
> {"episode": 0, "reward": -1798.1161, "steps count": 1000, "duration": 15}
> {"episode": 1, "reward": -1800.0162, "steps count": 1000, "duration": 34}
> {"episode": 2, "reward": -1718.6151, "steps count": 1000, "duration": 57}
> {"episode": 3, "reward": -1369.5055, "steps count": 1000, "duration": 82}
> {"episode": 4, "reward": -321.5974, "steps count": 1000, "duration": 115}
> {"episode": 5, "reward": -213.2988, "steps count": 1000, "duration": 148}
> {"episode": 6, "reward": -175.4995, "steps count": 1000, "duration": 172}
> {"episode": 7, "reward": -126.1011, "steps count": 1000, "duration": 203}
> {"episode": 8, "reward": -105.1011, "steps count": 1000, "duration": 242}
> {"episode": 9, "reward": -46.3007, "steps count": 1000, "duration": 281}
> {"episode": 10, "reward": -57.7006, "steps count": 1000, "duration": 323}
> {"episode": 11, "reward": -15.7997, "steps count": 1000, "duration": 354}
> {"episode": 12, "reward": -38.6999, "steps count": 1000, "duration": 414}
> {"episode": 13, "reward": 10.7002, "steps count": 1000, "duration": 513}
> {"episode": 14, "reward": -10.1999, "steps count": 1000, "duration": 585}
> {"episode": 15, "reward": -8.3000, "steps count": 1000, "duration": 644}
> {"episode": 16, "reward": -463.4984, "steps count": 973, "duration": 588}
> {"episode": 17, "reward": -148.8951, "steps count": 1000, "duration": 646}
> {"episode": 18, "reward": 3.0999, "steps count": 1000, "duration": 676}
> {"episode": 19, "reward": -12.0999, "steps count": 1000, "duration": 753}
> {"episode": 20, "reward": 6.9000, "steps count": 1000, "duration": 801}
> {"episode": 21, "reward": 14.5001, "steps count": 1000, "duration": 850}
> {"episode": 22, "reward": -19.6999, "steps count": 1000, "duration": 937}
> {"episode": 23, "reward": 83.0000, "steps count": 456, "duration": 532}
> {"episode": 24, "reward": -13.9998, "steps count": 1000, "duration": 1236}
> {"episode": 25, "reward": 25.9003, "steps count": 1000, "duration": 1264}
> {"episode": 26, "reward": 1.2002, "steps count": 1000, "duration": 1349}
> {"episode": 27, "reward": 3.1000, "steps count": 1000, "duration": 1364}
> {"episode": 28, "reward": -6.4000, "steps count": 1000, "duration": 1392}
> {"episode": 29, "reward": -4.4998, "steps count": 1000, "duration": 1444}
> {"episode": 30, "reward": 3.1000, "steps count": 1000, "duration": 1611}

```
thread 'main' has overflowed its stack
fatal runtime error: stack overflow, aborting
```

---

```rust
num_episodes: 40,
// memory_size: 8192, // must be set in dqn_model.rs with the MEMORY_SIZE constant
// max_steps: 700, // must be set in environment.rs with the MAX_STEPS constant
dense_size: 256, // neural network complexity
eps_start: 0.9, // epsilon initial value (0.9 => more exploration)
eps_end: 0.05,
eps_decay: 3000.0,
```

{"episode": 0, "reward": -1256.1014, "steps count": 700, "duration": 9}
|
||||
{"episode": 1, "reward": -1256.1013, "steps count": 700, "duration": 20}
|
||||
{"episode": 2, "reward": -1256.1014, "steps count": 700, "duration": 31}
|
||||
{"episode": 3, "reward": -1258.7015, "steps count": 700, "duration": 44}
|
||||
{"episode": 4, "reward": -1206.8009, "steps count": 700, "duration": 56}
|
||||
{"episode": 5, "reward": -473.2974, "steps count": 700, "duration": 68}
|
||||
{"episode": 6, "reward": -285.2984, "steps count": 700, "duration": 82}
|
||||
{"episode": 7, "reward": -332.6987, "steps count": 700, "duration": 103}
|
||||
{"episode": 8, "reward": -359.2984, "steps count": 700, "duration": 114}
|
||||
{"episode": 9, "reward": -118.7008, "steps count": 700, "duration": 125}
|
||||
{"episode": 10, "reward": -83.9004, "steps count": 700, "duration": 144}
|
||||
{"episode": 11, "reward": -68.7006, "steps count": 700, "duration": 165}
|
||||
{"episode": 12, "reward": -49.7002, "steps count": 700, "duration": 180}
|
||||
{"episode": 13, "reward": -68.7002, "steps count": 700, "duration": 204}
|
||||
{"episode": 14, "reward": -38.3001, "steps count": 700, "duration": 223}
|
||||
{"episode": 15, "reward": -19.2999, "steps count": 700, "duration": 240}
|
||||
{"episode": 16, "reward": -19.1998, "steps count": 700, "duration": 254}
|
||||
{"episode": 17, "reward": -21.1999, "steps count": 700, "duration": 250}
|
||||
{"episode": 18, "reward": -26.8998, "steps count": 700, "duration": 280}
|
||||
{"episode": 19, "reward": -11.6999, "steps count": 700, "duration": 301}
|
||||
{"episode": 20, "reward": -13.5998, "steps count": 700, "duration": 317}
|
||||
{"episode": 21, "reward": 5.4000, "steps count": 700, "duration": 334}
|
||||
{"episode": 22, "reward": 3.5000, "steps count": 700, "duration": 353}
|
||||
{"episode": 23, "reward": 13.0000, "steps count": 700, "duration": 374}
|
||||
{"episode": 24, "reward": 7.3001, "steps count": 700, "duration": 391}
|
||||
{"episode": 25, "reward": -4.1000, "steps count": 700, "duration": 408}
|
||||
{"episode": 26, "reward": -17.3998, "steps count": 700, "duration": 437}
|
||||
{"episode": 27, "reward": 11.1001, "steps count": 700, "duration": 480}
|
||||
{"episode": 28, "reward": -4.1000, "steps count": 700, "duration": 505}
|
||||
{"episode": 29, "reward": -13.5999, "steps count": 700, "duration": 522}
|
||||
{"episode": 30, "reward": -0.3000, "steps count": 700, "duration": 540}
|
||||
{"episode": 31, "reward": -15.4998, "steps count": 700, "duration": 572}
|
||||
{"episode": 32, "reward": 14.9001, "steps count": 700, "duration": 630}
|
||||
{"episode": 33, "reward": -4.1000, "steps count": 700, "duration": 729}
|
||||
{"episode": 34, "reward": 5.4000, "steps count": 700, "duration": 777}
|
||||
{"episode": 35, "reward": 7.3000, "steps count": 700, "duration": 748}
|
||||
{"episode": 36, "reward": 9.2001, "steps count": 700, "duration": 767}
|
||||
{"episode": 37, "reward": 13.0001, "steps count": 700, "duration": 791}
|
||||
{"episode": 38, "reward": -13.5999, "steps count": 700, "duration": 813}
|
||||
{"episode": 39, "reward": 26.3002, "steps count": 700, "duration": 838}
|
||||
|
||||
> Sauvegarde du modèle de validation
|
||||
> Modèle de validation sauvegardé : models/burn_dqn_50_model.mpk
|
||||
> Chargement du modèle pour test
|
||||
> Chargement du modèle depuis : models/burn_dqn_50_model.mpk
|
||||
> Test avec le modèle chargé
|
||||
> Episode terminé. Récompense totale: 70.00, Étapes: 700
|
||||
@@ -1,182 +0,0 @@

```rust
use std::{
    collections::HashMap,
    net::{SocketAddr, UdpSocket},
    sync::mpsc::{self, Receiver, TryRecvError},
    thread,
    time::{Duration, Instant, SystemTime},
};

use renet::{
    transport::{
        ClientAuthentication, NetcodeClientTransport, NetcodeServerTransport, ServerAuthentication,
        ServerConfig, NETCODE_USER_DATA_BYTES,
    },
    ClientId, ConnectionConfig, DefaultChannel, RenetClient, RenetServer, ServerEvent,
};

// Helper struct to pass an username in the user data
struct Username(String);

impl Username {
    fn to_netcode_user_data(&self) -> [u8; NETCODE_USER_DATA_BYTES] {
        let mut user_data = [0u8; NETCODE_USER_DATA_BYTES];
        if self.0.len() > NETCODE_USER_DATA_BYTES - 8 {
            panic!("Username is too big");
        }
        user_data[0..8].copy_from_slice(&(self.0.len() as u64).to_le_bytes());
        user_data[8..self.0.len() + 8].copy_from_slice(self.0.as_bytes());

        user_data
    }

    fn from_user_data(user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> Self {
        let mut buffer = [0u8; 8];
        buffer.copy_from_slice(&user_data[0..8]);
        let mut len = u64::from_le_bytes(buffer) as usize;
        len = len.min(NETCODE_USER_DATA_BYTES - 8);
        let data = user_data[8..len + 8].to_vec();
        let username = String::from_utf8(data).unwrap();
        Self(username)
    }
}

fn main() {
    env_logger::init();
    println!("Usage: server [SERVER_PORT] or client [SERVER_ADDR] [USER_NAME]");
    let args: Vec<String> = std::env::args().collect();

    let exec_type = &args[1];
    match exec_type.as_str() {
        "client" => {
            let server_addr: SocketAddr = args[2].parse().unwrap();
            let username = Username(args[3].clone());
            client(server_addr, username);
        }
        "server" => {
            let server_addr: SocketAddr = format!("0.0.0.0:{}", args[2]).parse().unwrap();
            server(server_addr);
        }
        _ => {
            println!("Invalid argument, first one must be \"client\" or \"server\".");
        }
    }
}

const PROTOCOL_ID: u64 = 7;

fn server(public_addr: SocketAddr) {
    let connection_config = ConnectionConfig::default();
    let mut server: RenetServer = RenetServer::new(connection_config);

    let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap();
    let server_config = ServerConfig {
        current_time,
        max_clients: 64,
        protocol_id: PROTOCOL_ID,
        public_addresses: vec![public_addr],
        authentication: ServerAuthentication::Unsecure,
    };
    let socket: UdpSocket = UdpSocket::bind(public_addr).unwrap();

    let mut transport = NetcodeServerTransport::new(server_config, socket).unwrap();

    let mut usernames: HashMap<ClientId, String> = HashMap::new();
    let mut received_messages = vec![];
    let mut last_updated = Instant::now();

    loop {
        let now = Instant::now();
        let duration = now - last_updated;
        last_updated = now;

        server.update(duration);
        transport.update(duration, &mut server).unwrap();

        received_messages.clear();

        while let Some(event) = server.get_event() {
            match event {
                ServerEvent::ClientConnected { client_id } => {
                    let user_data = transport.user_data(client_id).unwrap();
                    let username = Username::from_user_data(&user_data);
                    usernames.insert(client_id, username.0);
                    println!("Client {} connected.", client_id)
                }
                ServerEvent::ClientDisconnected { client_id, reason } => {
                    println!("Client {} disconnected: {}", client_id, reason);
                    usernames.remove_entry(&client_id);
                }
            }
        }

        for client_id in server.clients_id() {
            while let Some(message) = server.receive_message(client_id, DefaultChannel::ReliableOrdered) {
                let text = String::from_utf8(message.into()).unwrap();
                let username = usernames.get(&client_id).unwrap();
                println!("Client {} ({}) sent text: {}", username, client_id, text);
                let text = format!("{}: {}", username, text);
                received_messages.push(text);
            }
        }

        for text in received_messages.iter() {
            server.broadcast_message(DefaultChannel::ReliableOrdered, text.as_bytes().to_vec());
        }

        transport.send_packets(&mut server);
        thread::sleep(Duration::from_millis(50));
    }
}

fn client(server_addr: SocketAddr, username: Username) {
    let connection_config = ConnectionConfig::default();
    let mut client = RenetClient::new(connection_config);

    let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
    let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap();
    let client_id = current_time.as_millis() as u64;
    let authentication = ClientAuthentication::Unsecure {
        server_addr,
        client_id,
        user_data: Some(username.to_netcode_user_data()),
        protocol_id: PROTOCOL_ID,
    };

    let mut transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap();
    let stdin_channel: Receiver<String> = spawn_stdin_channel();

    let mut last_updated = Instant::now();
    loop {
        let now = Instant::now();
        let duration = now - last_updated;
        last_updated = now;

        client.update(duration);
        transport.update(duration, &mut client).unwrap();

        if transport.is_connected() {
            match stdin_channel.try_recv() {
                Ok(text) => client.send_message(DefaultChannel::ReliableOrdered, text.as_bytes().to_vec()),
                Err(TryRecvError::Empty) => {}
                Err(TryRecvError::Disconnected) => panic!("Channel disconnected"),
            }

            while let Some(text) = client.receive_message(DefaultChannel::ReliableOrdered) {
                let text = String::from_utf8(text.into()).unwrap();
                println!("{}", text);
            }
        }

        transport.send_packets(&mut client).unwrap();
        thread::sleep(Duration::from_millis(50));
    }
}

fn spawn_stdin_channel() -> Receiver<String> {
    let (tx, rx) = mpsc::channel::<String>();
    thread::spawn(move || loop {
        let mut buffer = String::new();
        std::io::stdin().read_line(&mut buffer).unwrap();
        tx.send(buffer.trim_end().to_string()).unwrap();
    });
    rx
}
```