chore: integrate multiplayer code (wip)

This commit is contained in:
Henri Bourcereau 2026-04-22 17:42:05 +02:00
parent 2838d59f30
commit 4f5e21becb
66 changed files with 6423 additions and 18 deletions

View file

@ -0,0 +1,7 @@
[package]
name = "protocol"
version = "0.1.0"
edition = "2024"
[dependencies]
serde = { version = "1.0.228", features = ["derive"] }

View file

@ -0,0 +1,72 @@
//! The ids for messages that we use. They will be used consistently across the server and the client.
//! Also contains the protocol structure for joining a game.
use serde::{Deserialize, Serialize};
/// The buffer sizes for the channels for intra VPS communication.
pub const CHANNEL_BUFFER_SIZE: usize = 256;
// Client -> Server message headers.
/// The message to announce a new client (Client->Server), followed by u16 client id.
pub const NEW_CLIENT: u8 = 0;
/// The message size for a new client (Header + Client Id) (u8 + u16).
pub const NEW_CLIENT_MSG_SIZE: usize = 3;
/// A client disconnects from the game (Client->Server) and is removed from the room. Followed by u16 client id.
pub const CLIENT_DISCONNECTS: u8 = 1;
/// The disconnect client message size (Header + Client Id) (u8 + u16).
pub const CLIENT_DISCONNECT_MSG_SIZE: usize = 3;
/// Client -> Server RPC, followed by u16 client id, followed by payload from postcard or other coding.
pub const SERVER_RPC: u8 = 2;
/// The disconnection message that is used for disconnecting without any arguments, that gets passed through the web socket layer.
pub const CLIENT_DISCONNECTS_SELF: u8 = 3;
// Server -> Client message headers (separate numbering space from Client -> Server).
/// The server disconnects from the game and the room gets closed.
pub const SERVER_DISCONNECTS: u8 = 0;
/// The disconnection message is just the header byte itself.
pub const SERVER_DISCONNECT_MSG_SIZE: usize = 1;
/// A client gets kicked, meant for the situation when no more clients should get accepted. Followed by u16 client id. The receiving tokio task has to act on its own. (Server -> Client)
pub const CLIENT_GETS_KICKED: u8 = 1;
/// Delta update. Followed by payload for every delta update. May carry several delta messages in one pass.
pub const DELTA_UPDATE: u8 = 2;
/// Flagging a full update. Followed by payload for full update.
pub const FULL_UPDATE: u8 = 3;
/// The message to reset the game. This is also followed by a full update. The difference is that every client will get the full update.
pub const RESET: u8 = 4;
/// The error message we add (header byte followed by a UTF-8 error text).
pub const SERVER_ERROR: u8 = 5;
/// The response message for the handshake.
pub const HAND_SHAKE_RESPONSE: u8 = 6;
// Sizes of entries.
/// For the handshake we respond with header (u8), player id (u16), rule variation (u16), and reconnect token (u64) = 13 bytes.
pub const HAND_SHAKE_RESPONSE_SIZE: usize = 13;
/// The size of a client id on the wire (u16).
pub const CLIENT_ID_SIZE: usize = 2;
/// The join request. This struct is used on the server and on the client.
///
/// Serialized with postcard and sent as the first binary WebSocket message.
#[derive(Deserialize, Serialize)]
pub struct JoinRequest {
    /// Which game do we want to join.
    pub game_id: String,
    /// Which room do we want to join.
    pub room_id: String,
    /// The rule variation that is applied; this only gets interpreted when a room is constructed.
    pub rule_variation: u16,
    /// Do we want to create a room and act as a server (host)?
    pub create_room: bool,
    /// Reconnect token from a previous session. `None` = fresh join/create, `Some` = reconnect.
    pub reconnect_token: Option<u64>,
}

View file

@ -0,0 +1,29 @@
[package]
name = "relay-server"
version = "0.1.0"
edition = "2024"
[dependencies]
tokio = {version = "1.48.0", features = ["full"]}
axum = { version = "0.8.7", features = ["ws"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"
futures-util = "0.3.31"
postcard = "1.1.3"
bytes = "1.11.0"
tracing = "0.1.41"
tower-http = { version = "0.6.7", features = ["fs", "cors"] }
protocol = {path = "../protocol"}
rand = "0.8"
# User management / auth
sqlx = { version = "0.8", features = ["sqlite", "runtime-tokio", "migrate"] }
tower-sessions = "0.14"
tower-sessions-sqlx-store = { version = "0.15", features = ["sqlite"] }
axum-login = "0.18"
argon2 = "0.5"
time = "0.3"
thiserror = "1"

View file

@ -0,0 +1,10 @@
[
{
"name" : "tic-tac-toe",
"max_players" : 10
},
{
"name" : "Ternio",
"max_players" : 3
}
]

View file

@ -0,0 +1,24 @@
-- Registered accounts. username and email are enforced unique at the DB level.
CREATE TABLE IF NOT EXISTS users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    username TEXT NOT NULL UNIQUE,
    email TEXT NOT NULL UNIQUE,
    -- Password hash string; plaintext passwords are never stored.
    password_hash TEXT NOT NULL,
    -- Unix timestamp (seconds).
    created_at INTEGER NOT NULL
);
-- One row per hosted room / game session.
CREATE TABLE IF NOT EXISTS game_records (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    game_id TEXT NOT NULL,
    room_code TEXT NOT NULL,
    started_at INTEGER NOT NULL,
    -- NULL while the game is still running.
    ended_at INTEGER,
    -- Opaque result payload supplied by the game; NULL if never reported.
    result TEXT
);
-- Links players (registered or anonymous) to the games they took part in.
CREATE TABLE IF NOT EXISTS game_participants (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    game_record_id INTEGER NOT NULL REFERENCES game_records(id),
    -- NULL for anonymous (not logged-in) players.
    user_id INTEGER REFERENCES users(id),
    -- In-room player id (u16 on the Rust side).
    player_id INTEGER NOT NULL,
    -- 'win' / 'loss' / 'draw', or NULL when unknown.
    outcome TEXT
);

View file

@ -0,0 +1,3 @@
-- Prevent duplicate participant rows if POST /games/result is called more than once.
-- Paired with INSERT OR IGNORE on the Rust side, so repeated reports are silent no-ops.
CREATE UNIQUE INDEX IF NOT EXISTS idx_participants_unique
ON game_participants(game_record_id, player_id);

View file

@ -0,0 +1,95 @@
//! Authentication backend for axum-login.
//!
//! Implements [`AuthUser`] on [`db::User`] and provides [`AuthBackend`] which
//! validates credentials against the database using Argon2 password hashing.
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::password_hash::rand_core::OsRng;
use argon2::Argon2;
use axum_login::{AuthUser, AuthnBackend, UserId};
use sqlx::SqlitePool;
use crate::db;
// ── AuthUser ─────────────────────────────────────────────────────────────────
impl AuthUser for db::User {
    type Id = i64;
    /// Stable identifier stored in the session cookie.
    fn id(&self) -> Self::Id {
        self.id
    }
    /// Changing the password invalidates all existing sessions for this user,
    /// since axum-login validates sessions against this hash.
    fn session_auth_hash(&self) -> &[u8] {
        self.password_hash.as_bytes()
    }
}
// ── Credentials ──────────────────────────────────────────────────────────────
/// Login credentials as submitted by the client.
#[derive(Clone)]
pub struct Credentials {
    /// The account's unique username.
    pub username: String,
    /// The plaintext password; only ever compared against the stored hash.
    pub password: String,
}
// ── Error ────────────────────────────────────────────────────────────────────
/// Errors the authentication backend can produce.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// Underlying SQLite query failed.
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// Hashing a password or parsing a stored hash failed (details deliberately not exposed).
    #[error("password hashing error")]
    PasswordHash,
}
// ── Backend ───────────────────────────────────────────────────────────────────
/// axum-login backend that validates credentials against the `users` table.
#[derive(Clone)]
pub struct AuthBackend {
    // Shared SQLite pool; cloning the backend clones the pool handle cheaply.
    pool: SqlitePool,
}
impl AuthBackend {
    /// Creates a backend on top of an already-initialized pool.
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }
}
impl AuthnBackend for AuthBackend {
    type User = db::User;
    type Credentials = Credentials;
    type Error = AuthError;
    /// Verifies a username/password pair.
    ///
    /// Returns `Ok(None)` both for an unknown username and for a wrong password,
    /// so callers cannot distinguish the two cases (no user enumeration).
    async fn authenticate(
        &self,
        creds: Self::Credentials,
    ) -> Result<Option<Self::User>, Self::Error> {
        let Some(user) = db::get_user_by_username(&self.pool, &creds.username).await? else {
            return Ok(None);
        };
        // A stored hash that fails to parse is a server-side data problem, not a bad login.
        let parsed = PasswordHash::new(&user.password_hash).map_err(|_| AuthError::PasswordHash)?;
        let valid = Argon2::default()
            .verify_password(creds.password.as_bytes(), &parsed)
            .is_ok();
        Ok(valid.then_some(user))
    }
    /// Loads the session's user by id; `Ok(None)` when the account no longer exists.
    async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
        Ok(db::get_user_by_id(&self.pool, *user_id).await?)
    }
}
// ── Password hashing helper ───────────────────────────────────────────────────
/// Hashes a plaintext password with Argon2id. Used by the registration endpoint.
///
/// Returns the hash in PHC string format, ready to be stored in the database,
/// or [`AuthError::PasswordHash`] if hashing fails.
pub fn hash_password(password: &str) -> Result<String, AuthError> {
    // A fresh random salt per password; the salt is embedded in the PHC string.
    let salt = SaltString::generate(&mut OsRng);
    let hashed = Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map_err(|_| AuthError::PasswordHash)?;
    Ok(hashed.to_string())
}

View file

@ -0,0 +1,214 @@
//! Database access layer.
//!
//! All SQLite interaction is funnelled through this module. Functions return
//! `sqlx::Result` so callers can handle errors uniformly.
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{SqlitePool, pool::PoolOptions};
use std::time::{SystemTime, UNIX_EPOCH};
/// A registered user as stored in the database.
#[derive(Clone, Debug, sqlx::FromRow)]
pub struct User {
    /// Primary key (SQLite rowid).
    pub id: i64,
    /// Unique login name.
    pub username: String,
    /// Unique e-mail address.
    pub email: String,
    /// Password hash string as produced by the auth layer (never the plaintext).
    pub password_hash: String,
    /// Unix timestamp (seconds) of account creation.
    pub created_at: i64,
}
/// Aggregated game statistics for a user's public profile.
#[derive(sqlx::FromRow)]
pub struct UserStats {
    /// Total games the user participated in.
    pub total: i64,
    /// Games recorded with outcome 'win'.
    pub wins: i64,
    /// Games recorded with outcome 'loss'.
    pub losses: i64,
    /// Games recorded with outcome 'draw'.
    pub draws: i64,
}
/// A condensed game entry returned by [`get_user_games`].
#[derive(sqlx::FromRow)]
pub struct GameSummary {
    /// Game record primary key.
    pub id: i64,
    /// Which game was played.
    pub game_id: String,
    /// The room code the game ran under.
    pub room_code: String,
    /// Unix timestamp (seconds) when the room opened.
    pub started_at: i64,
    /// Unix timestamp (seconds) when the record was closed; `None` while still open.
    pub ended_at: Option<i64>,
    /// Opaque result payload; `None` if no result was ever reported.
    pub result: Option<String>,
    /// This user's personal outcome in the game, if recorded.
    pub outcome: Option<String>,
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics only if the system clock is set before 1970-01-01, which indicates a
/// grossly misconfigured host rather than a recoverable runtime error.
fn now_unix() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs() as i64
}
/// Opens (or creates) the SQLite database at `path` and runs all pending migrations.
///
/// Panics (via `expect`) on any setup failure — this runs once at startup, where
/// aborting the process is the right call.
pub async fn init_db(path: &str) -> SqlitePool {
    // Make sure the parent directory exists; SQLite does not create directories itself.
    if let Some(parent) = std::path::Path::new(path).parent() {
        if !parent.as_os_str().is_empty() {
            tokio::fs::create_dir_all(parent)
                .await
                .expect("Failed to create database directory");
        }
    }
    let pool = PoolOptions::<sqlx::Sqlite>::new()
        .max_connections(5)
        .connect_with(
            SqliteConnectOptions::new()
                .filename(path)
                .create_if_missing(true),
        )
        .await
        .expect("Failed to open SQLite database");
    // NOTE(review): the migrations directory path is baked in at compile time via
    // CARGO_MANIFEST_DIR, so the deployed binary expects that path to exist on the
    // machine it runs on — confirm this is intended for production deployments.
    sqlx::migrate::Migrator::new(
        std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")),
    )
    .await
    .expect("Failed to locate migrations directory")
    .run(&pool)
    .await
    .expect("Failed to run database migrations");
    pool
}
// ── Users ────────────────────────────────────────────────────────────────────
/// Inserts a new user row (with the current timestamp) and returns its id.
///
/// The UNIQUE constraints on `username` / `email` surface as `sqlx::Error` on
/// duplicates; callers decide how to map that.
pub async fn create_user(
    pool: &SqlitePool,
    username: &str,
    email: &str,
    password_hash: &str,
) -> sqlx::Result<i64> {
    let inserted = sqlx::query(
        "INSERT INTO users (username, email, password_hash, created_at) VALUES (?, ?, ?, ?)",
    )
    .bind(username)
    .bind(email)
    .bind(password_hash)
    .bind(now_unix())
    .execute(pool)
    .await?;
    Ok(inserted.last_insert_rowid())
}
/// Fetches a user by primary key; `Ok(None)` when no such user exists.
pub async fn get_user_by_id(pool: &SqlitePool, id: i64) -> sqlx::Result<Option<User>> {
    sqlx::query_as::<_, User>(
        "SELECT id, username, email, password_hash, created_at FROM users WHERE id = ?",
    )
    .bind(id)
    .fetch_optional(pool)
    .await
}
/// Fetches a user by their unique username; `Ok(None)` when no such user exists.
pub async fn get_user_by_username(pool: &SqlitePool, username: &str) -> sqlx::Result<Option<User>> {
    sqlx::query_as::<_, User>(
        "SELECT id, username, email, password_hash, created_at FROM users WHERE username = ?",
    )
    .bind(username)
    .fetch_optional(pool)
    .await
}
// ── Game records ─────────────────────────────────────────────────────────────
/// Creates a new game record when a room opens. Returns the record id.
///
/// `started_at` is stamped with the current time; `ended_at` / `result` stay
/// NULL until [`close_game_record`] is called.
pub async fn insert_game_record(
    pool: &SqlitePool,
    game_id: &str,
    room_code: &str,
) -> sqlx::Result<i64> {
    let id = sqlx::query(
        "INSERT INTO game_records (game_id, room_code, started_at) VALUES (?, ?, ?)",
    )
    .bind(game_id)
    .bind(room_code)
    .bind(now_unix())
    .execute(pool)
    .await?
    .last_insert_rowid();
    Ok(id)
}
/// Stamps `ended_at` and stores the opaque result JSON supplied by the game.
///
/// Idempotent with respect to an already-closed record: the update simply
/// matches zero rows, which is not treated as an error.
pub async fn close_game_record(
    pool: &SqlitePool,
    record_id: i64,
    result_json: Option<&str>,
) -> sqlx::Result<()> {
    // AND ended_at IS NULL prevents overwriting a result already set by POST /games/result
    sqlx::query(
        "UPDATE game_records SET ended_at = ?, result = ? WHERE id = ? AND ended_at IS NULL",
    )
    .bind(now_unix())
    .bind(result_json)
    .bind(record_id)
    .execute(pool)
    .await?;
    Ok(())
}
/// Records a player's participation in a game. `user_id` is `None` for anonymous players.
///
/// Uses INSERT OR IGNORE together with the unique index on
/// (game_record_id, player_id), so reporting the same participant twice is a no-op.
pub async fn insert_participant(
    pool: &SqlitePool,
    record_id: i64,
    user_id: Option<i64>,
    player_id: u16,
    outcome: Option<&str>,
) -> sqlx::Result<()> {
    sqlx::query(
        "INSERT OR IGNORE INTO game_participants (game_record_id, user_id, player_id, outcome)
         VALUES (?, ?, ?, ?)",
    )
    .bind(record_id)
    .bind(user_id)
    // u16 -> i64 widening is lossless; SQLite stores INTEGER as i64.
    .bind(player_id as i64)
    .bind(outcome)
    .execute(pool)
    .await?;
    Ok(())
}
/// Returns win/loss/draw counts for a user. All values are 0 when the user has no games.
///
/// Rows with a NULL or unrecognized outcome count toward `total` but toward
/// none of the per-outcome buckets.
pub async fn get_user_stats(pool: &SqlitePool, user_id: i64) -> sqlx::Result<UserStats> {
    sqlx::query_as::<_, UserStats>(
        "SELECT
            COUNT(*) as total,
            COALESCE(SUM(CASE WHEN outcome = 'win' THEN 1 ELSE 0 END), 0) as wins,
            COALESCE(SUM(CASE WHEN outcome = 'loss' THEN 1 ELSE 0 END), 0) as losses,
            COALESCE(SUM(CASE WHEN outcome = 'draw' THEN 1 ELSE 0 END), 0) as draws
         FROM game_participants
         WHERE user_id = ?",
    )
    .bind(user_id)
    .fetch_one(pool)
    .await
}
/// Returns a paginated list of games a user participated in, newest first.
///
/// `page` is zero-based and `per_page` is the page size. Negative inputs are
/// clamped to zero: SQLite interprets a negative LIMIT as "no limit", so an
/// unchecked hostile query string could otherwise dump the entire table.
pub async fn get_user_games(
    pool: &SqlitePool,
    user_id: i64,
    page: i64,
    per_page: i64,
) -> sqlx::Result<Vec<GameSummary>> {
    let per_page = per_page.max(0);
    // saturating_mul keeps a pathological page number from overflowing i64.
    let offset = page.max(0).saturating_mul(per_page);
    sqlx::query_as::<_, GameSummary>(
        "SELECT gr.id, gr.game_id, gr.room_code, gr.started_at, gr.ended_at, gr.result, gp.outcome
         FROM game_records gr
         JOIN game_participants gp ON gp.game_record_id = gr.id
         WHERE gp.user_id = ?
         ORDER BY gr.started_at DESC
         LIMIT ? OFFSET ?",
    )
    .bind(user_id)
    .bind(per_page)
    .bind(offset)
    .fetch_all(pool)
    .await
}

View file

@ -0,0 +1,599 @@
//! This module does the whole initialization and handshake thing.
//! The general protocol of connecting is :
//! WASM Client -> Websocket: postcard serialized join request.
//! Websocket -> WASM Client: u16 player id, u16 rule variation, u64 reconnect token.
use crate::db;
use crate::hand_shake::ClientServerSpecificData::{Client, Server};
use crate::hand_shake::DisconnectEndpointSpecification::{DisconnectClient, DisconnectServer};
use crate::lobby::{AppState, Room};
use axum::extract::ws::Message::Binary;
use axum::extract::ws::{Message, WebSocket};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{sink::SinkExt, stream::StreamExt};
use postcard::from_bytes;
use protocol::{
CHANNEL_BUFFER_SIZE, CLIENT_DISCONNECT_MSG_SIZE, CLIENT_DISCONNECTS, HAND_SHAKE_RESPONSE,
HAND_SHAKE_RESPONSE_SIZE, JoinRequest, NEW_CLIENT, NEW_CLIENT_MSG_SIZE,
SERVER_DISCONNECT_MSG_SIZE, SERVER_DISCONNECTS, SERVER_ERROR,
};
use rand::random;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{broadcast, mpsc};
/// Is called on error; sends a text message because the websocket layer cannot
/// interpret closing frames with a payload. The error text is therefore encoded
/// as a binary message (SERVER_ERROR header + UTF-8 bytes), followed by a close frame.
async fn send_closing_message(sender: &mut SplitSink<WebSocket, Message>, closing_message: String) {
    let payload = closing_message.into_bytes();
    let mut framed = BytesMut::with_capacity(1 + payload.len());
    framed.put_u8(SERVER_ERROR);
    framed.put_slice(&payload);
    // Best-effort: the peer may already be gone, so send errors are ignored.
    let _ = sender.send(Message::Binary(framed.into())).await;
    let _ = sender.send(Message::Close(None)).await;
}
/// The handshake result we get for joining the room.
pub struct HandshakeResult {
    /// The id of the player we play (0 is always the host).
    pub player_id: u16,
    /// The complete identifier of the room as stored in the hashmap.
    pub room_id: String,
    /// The rule variation we apply.
    pub rule_variation: u16,
    /// The reconnect token for this player — sent back to the client for localStorage storage.
    pub token: u64,
    /// The internal connection information (channel halves differ for host vs. client).
    pub specific_data: ClientServerSpecificData,
}
/// Contains all the channel information for internal communication.
pub enum ClientServerSpecificData {
    /// In this case we are servicing the server (host): receive messages from clients,
    /// broadcast messages to all of them.
    Server(Receiver<Bytes>, broadcast::Sender<Bytes>),
    /// In this case we are servicing a client: receive host broadcasts, send to the host.
    Client(broadcast::Receiver<Bytes>, Sender<Bytes>),
}
/// The data we need to keep around for disconnect handling and cleanup.
pub struct DisconnectData {
    /// The id of the player we play.
    pub player_id: u16,
    /// The complete identifier of the room as stored in the hashmap.
    pub room_id: String,
    /// The outgoing channel used to announce the disconnect.
    pub sender: DisconnectEndpointSpecification,
}
/// Contains the information where to send error data to in case of disconnection.
pub enum DisconnectEndpointSpecification {
    /// If we are servicing the server, we broadcast the info to all clients.
    DisconnectServer(broadcast::Sender<Bytes>),
    /// If we are servicing the client, we send data to the server.
    DisconnectClient(Sender<Bytes>),
}
/// Construction of [`DisconnectData`] from a handshake result.
impl From<&HandshakeResult> for DisconnectData {
    fn from(value: &HandshakeResult) -> Self {
        // Only the outgoing channel half is kept; the receiving half stays with the task.
        let sender = match &value.specific_data {
            Server(_, internal_sender) => DisconnectServer(internal_sender.clone()),
            Client(_, internal_sender) => DisconnectClient(internal_sender.clone()),
        };
        DisconnectData {
            player_id: value.player_id,
            room_id: value.room_id.clone(),
            sender,
        }
    }
}
/// An initial connection result, produced once a join request has been parsed
/// and the requested game has been verified to exist.
struct InitialConnectionResult {
    /// Flags if we are a server (i.e. the client asked to create the room and host it).
    is_server: bool,
    /// The complete room key ("room#game") we use for internal administration.
    compound_room_id: String,
    /// Which game do we want to join.
    game_id: String,
    /// Which room do we want to join.
    room_id: String,
    /// The rule variation that is applied; this only gets interpreted when a room is constructed.
    rule_variation: u16,
    /// The maximum amount of players a room allows (0 = infinite).
    max_players: u16,
    /// Reconnect token from the client, if this is a reconnect attempt.
    reconnect_token: Option<u64>,
}
/// Reads the join request from the web socket, verifies the game exists, and
/// generates the final (compound) room name.
///
/// Returns `None` after sending an error/close message on any failure: socket
/// closed early, receive error, unparseable request, or unknown game.
async fn get_initial_query(
    sender: &mut SplitSink<WebSocket, Message>,
    receiver: &mut SplitStream<WebSocket>,
    state: Arc<AppState>,
) -> Option<InitialConnectionResult> {
    // First we wait for the room opening/joining request: the first *binary* message.
    let my_data = loop {
        let Some(raw_data) = receiver.next().await else {
            tracing::warn!("WebSocket closed before handshake completed");
            send_closing_message(sender, "Initial error during handshake.".into()).await;
            return None;
        };
        match raw_data {
            Err(err) => {
                tracing::error!(?err, "Initial error during handshake.");
                send_closing_message(sender, "Initial error during handshake.".into()).await;
                return None;
            }
            Ok(Binary(data)) => {
                break data;
            }
            // We do not care about any other message like ping pong messages.
            Ok(_) => {}
        }
    };
    // Now we have some data and try to convert it into the required format.
    let working_struct = match from_bytes::<JoinRequest>(&my_data) {
        Ok(req) => req,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to parse join request");
            send_closing_message(sender, "Failed to parse join request.".into()).await;
            return None;
        }
    };
    // Single map lookup instead of contains_key + index: avoids hashing twice and
    // removes the panic path of indexing. The read lock is released at end of scope.
    let max_players = {
        let games = state.configs.read().await;
        games.get(&working_struct.game_id).copied()
    };
    let Some(max_players) = max_players else {
        tracing::error!(
            optional_game = working_struct.game_id,
            "Requested illegal game."
        );
        send_closing_message(sender, format!("Unknown game {}.", &working_struct.game_id)).await;
        return None;
    };
    // The final room id is the combination of game and room id.
    let room_id = format!(
        "{}#{}",
        working_struct.room_id.as_str(),
        working_struct.game_id.as_str()
    );
    let is_server = working_struct.create_room;
    Some(InitialConnectionResult {
        is_server,
        compound_room_id: room_id,
        game_id: working_struct.game_id,
        room_id: working_struct.room_id,
        rule_variation: working_struct.rule_variation,
        max_players,
        reconnect_token: working_struct.reconnect_token,
    })
}
/// Connects and eventually establishes a room.
///
/// Dispatches the handshake to the reconnect, server (room creation), or client
/// (room join) flow, depending on what the initial join request asked for.
pub async fn init_and_connect(
    sender: &mut SplitSink<WebSocket, Message>,
    receiver: &mut SplitStream<WebSocket>,
    state: Arc<AppState>,
    user_id: Option<i64>,
) -> Option<HandshakeResult> {
    let start_result = get_initial_query(sender, receiver, state.clone()).await?;
    // A reconnect token takes precedence over the create_room flag.
    match start_result.reconnect_token {
        Some(token) => process_handshake_reconnect(sender, state, start_result, token, user_id).await,
        None if start_result.is_server => {
            process_handshake_server(sender, state, start_result, user_id).await
        }
        None => process_handshake_client(sender, state, start_result, user_id).await,
    }
}
/// Does the handshake if we are servicing a client joining an existing room.
///
/// Looks up the room, enforces the player limit, allocates a fresh client id and
/// reconnect token, registers the player, and notifies the host via `NEW_CLIENT`.
/// If notifying the host fails, *all* of the room bookkeeping is rolled back
/// (previously `connected_players` and `user_ids` kept stale entries, which
/// would have mis-informed a later host reconnect).
async fn process_handshake_client(
    sender: &mut SplitSink<WebSocket, Message>,
    state: Arc<AppState>,
    initial_result: InitialConnectionResult,
    user_id: Option<i64>,
) -> Option<HandshakeResult> {
    let mut rooms = state.rooms.lock().await;
    let Some(local_room) = rooms.get_mut(&initial_result.compound_room_id) else {
        drop(rooms);
        send_closing_message(
            sender,
            format!(
                "Room {} does not exist for game {}.",
                &initial_result.room_id, &initial_result.game_id
            ),
        )
        .await;
        return None;
    };
    // Do we fit in? max_players == 0 means "infinite".
    if initial_result.max_players != 0 && local_room.amount_of_players >= initial_result.max_players
    {
        drop(rooms);
        send_closing_message(
            sender,
            format!(
                "Room {} exceeded max amount of players {}.",
                &initial_result.room_id, initial_result.max_players
            ),
        )
        .await;
        return None;
    }
    // Safeguard against the case that we have run out of client ids.
    if local_room.next_client_id > u16::MAX - 100 {
        drop(rooms);
        send_closing_message(
            sender,
            format!("Room {} run out of client ids.", &initial_result.room_id),
        )
        .await;
        tracing::error!("Server run out of client ids.");
        return None;
    }
    // Register the player while still holding the room lock.
    local_room.amount_of_players += 1;
    let player_id = local_room.next_client_id;
    local_room.next_client_id += 1;
    let token: u64 = random();
    local_room.player_tokens.insert(player_id, token);
    local_room.connected_players.push(player_id);
    local_room.user_ids.insert(player_id, user_id);
    let to_server_sender = local_room.to_host_sender.clone();
    let receiver = local_room.host_to_client_broadcaster.subscribe();
    let rule_variation = local_room.rule_variation;
    drop(rooms);
    // Here we send a message to the server, that a new client has joined.
    let mut msg = BytesMut::with_capacity(NEW_CLIENT_MSG_SIZE);
    msg.put_u8(NEW_CLIENT); // Message-Type
    msg.put_u16(player_id); // player id.
    let result = to_server_sender.send(msg.into()).await;
    if let Err(error) = result {
        // We have to leave the room again — roll back *everything* registered
        // above, including connected_players and user_ids.
        let mut rooms = state.rooms.lock().await;
        if let Some(room) = rooms.get_mut(&initial_result.compound_room_id) {
            room.amount_of_players -= 1;
            room.player_tokens.remove(&player_id);
            room.connected_players.retain(|&p| p != player_id);
            room.user_ids.remove(&player_id);
        }
        drop(rooms);
        tracing::error!(?error, "Server unexpectedly left during handshake");
        send_closing_message(sender, "Server unexpectedly left during handshake".into()).await;
        return None;
    }
    Some(HandshakeResult {
        room_id: initial_result.compound_room_id,
        player_id,
        rule_variation,
        token,
        specific_data: Client(receiver, to_server_sender),
    })
}
/// Opens a new room and generates the handshake result for the server (host).
///
/// The game record is inserted *before* taking the rooms lock so we never do DB
/// I/O while holding it. If the room turns out to already exist, the freshly
/// created record is closed again so it does not leak as a forever-open game.
async fn process_handshake_server(
    sender: &mut SplitSink<WebSocket, Message>,
    state: Arc<AppState>,
    initial_result: InitialConnectionResult,
    user_id: Option<i64>,
) -> Option<HandshakeResult> {
    // Insert a game record before taking the rooms lock (best-effort: failures don't abort the handshake).
    let game_record_id =
        match db::insert_game_record(&state.db, &initial_result.game_id, &initial_result.room_id)
            .await
        {
            Ok(id) => Some(id),
            Err(e) => {
                tracing::warn!(
                    "Failed to create game record for room {}: {e}",
                    initial_result.room_id
                );
                None
            }
        };
    let mut rooms = state.rooms.lock().await;
    if rooms.contains_key(&initial_result.compound_room_id) {
        drop(rooms);
        // We lost the race for this room id: close the record created above so it
        // does not linger as an open game (previously it leaked).
        if let Some(record_id) = game_record_id {
            if let Err(e) = db::close_game_record(&state.db, record_id, None).await {
                tracing::warn!("Failed to close orphaned game record {record_id}: {e}");
            }
        }
        send_closing_message(
            sender,
            format!(
                "Room {} already exists for game {}.",
                &initial_result.room_id, &initial_result.game_id
            ),
        )
        .await;
        // User error no need for error tracing.
        return None;
    }
    // Here we create a new room. The host is always player id 0.
    let (to_server_sender, to_server_receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
    let (to_client_sender, _) = broadcast::channel(CHANNEL_BUFFER_SIZE);
    let token: u64 = random();
    let mut player_tokens = HashMap::new();
    player_tokens.insert(0u16, token);
    let mut user_ids = HashMap::new();
    user_ids.insert(0u16, user_id);
    let new_room = Room {
        next_client_id: 1,
        amount_of_players: 1,
        rule_variation: initial_result.rule_variation,
        to_host_sender: to_server_sender,
        host_to_client_broadcaster: to_client_sender.clone(),
        player_tokens,
        host_connected: true,
        connected_players: Vec::new(),
        game_record_id,
        user_ids,
    };
    rooms.insert(initial_result.compound_room_id.clone(), new_room);
    drop(rooms);
    let hand_shake_result = HandshakeResult {
        room_id: initial_result.compound_room_id,
        player_id: 0,
        rule_variation: initial_result.rule_variation,
        token,
        specific_data: Server(to_server_receiver, to_client_sender),
    };
    Some(hand_shake_result)
}
/// Reconnects a previously connected player (host or client) using their stored token.
///
/// **Client reconnect**: resubscribes to the broadcast channel and notifies the host
/// via `NEW_CLIENT` so it delivers a fresh `FULL_UPDATE`. A token whose player is
/// already connected is rejected (mirrors the host path; previously a duplicated
/// token could register the same player twice and corrupt the room bookkeeping).
///
/// **Host reconnect**: creates a new mpsc channel (the old one died with the WebSocket),
/// replaces `room.to_host_sender`, and queues `NEW_CLIENT` / `CLIENT_DISCONNECTS`
/// messages so the host backend can reconstruct who is currently in the room.
async fn process_handshake_reconnect(
    sender: &mut SplitSink<WebSocket, Message>,
    state: Arc<AppState>,
    initial_result: InitialConnectionResult,
    reconnect_token: u64,
    user_id: Option<i64>,
) -> Option<HandshakeResult> {
    let mut rooms = state.rooms.lock().await;
    let Some(local_room) = rooms.get_mut(&initial_result.compound_room_id) else {
        drop(rooms);
        send_closing_message(
            sender,
            format!(
                "Room {} no longer exists for game {}.",
                &initial_result.room_id, &initial_result.game_id
            ),
        )
        .await;
        return None;
    };
    // Find the player whose token matches.
    let player_id = match local_room
        .player_tokens
        .iter()
        .find(|&(_, &t)| t == reconnect_token)
        .map(|(&id, _)| id)
    {
        Some(id) => id,
        None => {
            drop(rooms);
            tracing::warn!("Reconnect attempt with invalid token in room {}", &initial_result.room_id);
            send_closing_message(sender, "Invalid reconnect token.".into()).await;
            return None;
        }
    };
    // ------------------------------------------------------------------ Host reconnect
    if player_id == 0 {
        if local_room.host_connected {
            drop(rooms);
            send_closing_message(sender, "Host is already connected.".into()).await;
            return None;
        }
        // Create a fresh mpsc channel (the previous receiver was dropped when the
        // host's WebSocket closed).
        let (new_sender, new_receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        local_room.to_host_sender = new_sender.clone();
        local_room.host_connected = true;
        local_room.user_ids.insert(0u16, user_id);
        let broadcaster = local_room.host_to_client_broadcaster.clone();
        let rule_variation = local_room.rule_variation;
        // Collect the players we need to notify about.
        let connected = local_room.connected_players.clone();
        let all_non_host: Vec<u16> = local_room
            .player_tokens
            .keys()
            .filter(|&&pid| pid != 0)
            .copied()
            .collect();
        drop(rooms);
        // Queue NEW_CLIENT for every currently connected player so the host backend
        // increments remote_player_count and sends a FULL_UPDATE.
        for pid in &connected {
            let mut msg = BytesMut::with_capacity(NEW_CLIENT_MSG_SIZE);
            msg.put_u8(NEW_CLIENT);
            msg.put_u16(*pid);
            let _ = new_sender.send(msg.into()).await;
        }
        // Queue CLIENT_DISCONNECTS for players who left while the host was away so
        // the backend can start their grace-period timers.
        for pid in all_non_host {
            if !connected.contains(&pid) {
                let mut msg = BytesMut::with_capacity(CLIENT_DISCONNECT_MSG_SIZE);
                msg.put_u8(CLIENT_DISCONNECTS);
                msg.put_u16(pid);
                let _ = new_sender.send(msg.into()).await;
            }
        }
        tracing::info!(room = &initial_result.room_id, "Host reconnected");
        return Some(HandshakeResult {
            room_id: initial_result.compound_room_id,
            player_id: 0,
            rule_variation,
            token: reconnect_token,
            specific_data: Server(new_receiver, broadcaster),
        });
    }
    // ---------------------------------------------------------------- Client reconnect
    // Reject a token that is already in use by a live connection — the same guard
    // the host path has above. Without it a second socket presenting the same
    // token would double-register the player.
    if local_room.connected_players.contains(&player_id) {
        drop(rooms);
        send_closing_message(sender, "Player is already connected.".into()).await;
        return None;
    }
    local_room.amount_of_players += 1;
    local_room.connected_players.push(player_id);
    local_room.user_ids.insert(player_id, user_id);
    let to_server_sender = local_room.to_host_sender.clone();
    let broadcast_receiver = local_room.host_to_client_broadcaster.subscribe();
    let rule_variation = local_room.rule_variation;
    drop(rooms);
    // Notify the host that this player has rejoined so it sends a FULL_UPDATE.
    let mut msg = BytesMut::with_capacity(NEW_CLIENT_MSG_SIZE);
    msg.put_u8(NEW_CLIENT);
    msg.put_u16(player_id);
    if let Err(error) = to_server_sender.send(msg.into()).await {
        // Roll back the registration done above.
        let mut rooms = state.rooms.lock().await;
        if let Some(room) = rooms.get_mut(&initial_result.compound_room_id) {
            room.amount_of_players -= 1;
            room.connected_players.retain(|&p| p != player_id);
        }
        drop(rooms);
        tracing::error!(?error, "Host unavailable during reconnect handshake");
        send_closing_message(sender, "Host is no longer available.".into()).await;
        return None;
    }
    tracing::info!(
        player_id,
        room = &initial_result.room_id,
        "Player reconnected"
    );
    Some(HandshakeResult {
        room_id: initial_result.compound_room_id,
        player_id,
        rule_variation,
        token: reconnect_token,
        specific_data: Client(broadcast_receiver, to_server_sender),
    })
}
/// Informs the partner of the connection result; returns a bool as a success flag.
///
/// The response layout is fixed: header (u8), player id (u16), rule variation
/// (u16), reconnect token (u64) — exactly `HAND_SHAKE_RESPONSE_SIZE` bytes.
pub async fn inform_client_of_connection(
    sender: &mut SplitSink<WebSocket, Message>,
    status: &HandshakeResult,
) -> bool {
    let mut response = BytesMut::with_capacity(HAND_SHAKE_RESPONSE_SIZE);
    response.put_u8(HAND_SHAKE_RESPONSE);
    response.put_u16(status.player_id);
    response.put_u16(status.rule_variation);
    response.put_u64(status.token);
    sender.send(Message::Binary(response.into())).await.is_ok()
}
/// Performs the shutdown of the system and sends a last message.
///
/// Called after a WebSocket task ends for any reason. A host disconnect starts a
/// 30-second grace window before the room is torn down; a client disconnect is
/// reported to the host immediately while the reconnect token stays valid.
pub async fn shutdown_connection(
    wrapped_sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    disconnect_data: DisconnectData,
    app_state: Arc<AppState>,
    error_message: &'static str,
) {
    match disconnect_data.sender {
        DisconnectServer(broadcaster) => {
            // Mark the host as disconnected and start a 30-second grace period.
            // If the host reconnects within that window the grace task does nothing;
            // otherwise it broadcasts SERVER_DISCONNECTS and removes the room.
            {
                // Scoped so the rooms lock is released before spawning the task.
                let mut rooms = app_state.rooms.lock().await;
                if let Some(room) = rooms.get_mut(&disconnect_data.room_id) {
                    room.host_connected = false;
                }
            }
            let state_clone = app_state.clone();
            let room_id = disconnect_data.room_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
                let game_record_id = {
                    let mut rooms = state_clone.rooms.lock().await;
                    if let Some(room) = rooms.get(&room_id) {
                        if !room.host_connected {
                            // Grace period expired without a host reconnect: tear down.
                            let record_id = room.game_record_id;
                            rooms.remove(&room_id);
                            record_id
                        } else {
                            return; // host reconnected
                        }
                    } else {
                        return; // room already removed
                    }
                };
                // Room lock released — broadcast and close the DB record.
                let mut msg = BytesMut::with_capacity(SERVER_DISCONNECT_MSG_SIZE);
                msg.put_u8(SERVER_DISCONNECTS);
                let _ = broadcaster.send(msg.into());
                tracing::info!(room_id, "Host grace period expired — room removed");
                if let Some(record_id) = game_record_id {
                    if let Err(e) = db::close_game_record(&state_clone.db, record_id, None).await {
                        tracing::warn!("Failed to close game record {record_id}: {e}");
                    }
                }
            });
        }
        DisconnectClient(sender) => {
            // Inform server first.
            let mut msg = BytesMut::with_capacity(CLIENT_DISCONNECT_MSG_SIZE);
            msg.put_u8(CLIENT_DISCONNECTS);
            msg.put_u16(disconnect_data.player_id);
            let _ = sender.send(msg.into()).await;
            // Subtract one client from the room.
            let mut rooms = app_state.rooms.lock().await;
            // Check if the room still exists.
            if let Some(room) = rooms.get_mut(&disconnect_data.room_id) {
                room.amount_of_players -= 1;
                room.connected_players.retain(|&p| p != disconnect_data.player_id);
                // Note: we intentionally keep the token in player_tokens so the
                // client can use it to reconnect as long as the room exists.
            }
            drop(rooms);
        }
    }
    let mut sender = wrapped_sender.lock().await;
    // Send the message to the WASM point.
    send_closing_message(&mut sender, error_message.into()).await;
}

View file

@ -0,0 +1,399 @@
//! HTTP endpoints for user management (Phases 2 & 4).
//!
//! Routes:
//! POST /auth/register
//! POST /auth/login
//! POST /auth/logout
//! GET /auth/me
//! GET /users/:username
//! GET /users/:username/games?page=0&per_page=20
//! POST /games/result
use axum::{
Json, Router,
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
};
use axum_login::AuthSession;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::sync::Arc;
use crate::auth::{AuthBackend, Credentials, hash_password};
use crate::db;
use crate::lobby::AppState;
// ── Router ────────────────────────────────────────────────────────────────────
/// Builds the user-management API router: auth endpoints, public profiles,
/// and game-record endpoints. Merged into the main app router in `main`.
pub fn router() -> Router<Arc<AppState>> {
    let auth_routes = Router::new()
        .route("/auth/register", post(register))
        .route("/auth/login", post(login))
        .route("/auth/logout", post(logout))
        .route("/auth/me", get(me));
    let user_routes = Router::new()
        .route("/users/{username}", get(user_profile))
        .route("/users/{username}/games", get(user_games));
    let game_routes = Router::new()
        .route("/games/result", post(game_result))
        .route("/games/{id}", get(game_detail));
    auth_routes.merge(user_routes).merge(game_routes)
}
// ── Error type ────────────────────────────────────────────────────────────────
/// Unified error type for every HTTP handler in this module.
///
/// Each variant maps to one HTTP status code in the `IntoResponse` impl below.
enum AppError {
    /// Any sqlx failure; rendered as 500 — details are logged, never leaked.
    Database(sqlx::Error),
    /// 404 — requested user or game record does not exist.
    NotFound,
    /// 409 — e.g. username or email already taken.
    Conflict(&'static str),
    /// 400 — invalid request payload; the message is sent to the client.
    BadRequest(&'static str),
    /// 401 — missing or invalid credentials.
    Unauthorized,
    /// 500 — unexpected internal failure (password hashing, session store, …).
    Internal,
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
match self {
AppError::Database(e) => {
tracing::error!("database error: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response()
}
AppError::NotFound => StatusCode::NOT_FOUND.into_response(),
AppError::Conflict(msg) => (StatusCode::CONFLICT, msg).into_response(),
AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg).into_response(),
AppError::Unauthorized => StatusCode::UNAUTHORIZED.into_response(),
AppError::Internal => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
}
}
}
impl From<sqlx::Error> for AppError {
fn from(e: sqlx::Error) -> Self {
AppError::Database(e)
}
}
fn is_unique_violation(e: &sqlx::Error) -> bool {
matches!(e, sqlx::Error::Database(db_err) if db_err.message().contains("UNIQUE constraint failed"))
}
// ── Request / response bodies ─────────────────────────────────────────────────
/// Payload for POST /auth/register.
#[derive(Deserialize)]
struct RegisterBody {
    username: String,
    email: String,
    password: String,
}
/// Payload for POST /auth/login.
#[derive(Deserialize)]
struct LoginBody {
    username: String,
    password: String,
}
/// Minimal identity info returned after register/login and by GET /auth/me.
#[derive(Serialize)]
struct MeResponse {
    id: i64,
    username: String,
}
/// Public profile with aggregate game statistics (GET /users/:username).
#[derive(Serialize)]
struct UserProfileResponse {
    id: i64,
    username: String,
    // Creation time as stored in users.created_at (i64 — presumably a Unix
    // timestamp; confirm against the schema).
    created_at: i64,
    total_games: i64,
    wins: i64,
    losses: i64,
    draws: i64,
}
/// Pagination query string for GET /users/:username/games.
#[derive(Deserialize)]
struct GamesQuery {
    /// Zero-based page index; defaults to 0 when absent.
    #[serde(default)]
    page: i64,
    /// Page size; defaults to 20 and is clamped to 1..=100 by the handler.
    #[serde(default = "default_per_page")]
    per_page: i64,
}
/// Serde default for [`GamesQuery::per_page`].
fn default_per_page() -> i64 { 20 }
/// Response body for GET /users/:username/games.
#[derive(Serialize)]
struct GamesResponse {
    games: Vec<GameSummaryResponse>,
}
/// One entry of a user's game history.
#[derive(Serialize)]
struct GameSummaryResponse {
    id: i64,
    game_id: String,
    room_code: String,
    started_at: i64,
    /// None while the game is still open.
    ended_at: Option<i64>,
    /// Opaque game-specific result JSON, stored verbatim; None while open.
    result: Option<String>,
    /// This user's outcome ("win" / "loss" / "draw"); None when not recorded.
    outcome: Option<String>,
}
/// Field-for-field conversion from the DB-layer summary type.
impl From<db::GameSummary> for GameSummaryResponse {
    fn from(g: db::GameSummary) -> Self {
        Self {
            id: g.id,
            game_id: g.game_id,
            room_code: g.room_code,
            started_at: g.started_at,
            ended_at: g.ended_at,
            result: g.result,
            outcome: g.outcome,
        }
    }
}
// ── Handlers ──────────────────────────────────────────────────────────────────
/// POST /auth/register — creates a user account and immediately logs it in.
///
/// Validations: username 3-30 characters, password at least 8 characters,
/// email must contain '@'. A duplicate username or email yields 409 Conflict;
/// success returns 201 with the new identity.
async fn register(
    mut auth_session: AuthSession<AuthBackend>,
    State(state): State<Arc<AppState>>,
    Json(body): Json<RegisterBody>,
) -> Result<impl IntoResponse, AppError> {
    // Count characters, not bytes, so multi-byte (non-ASCII) input is measured
    // the way the error messages promise.
    let username_chars = body.username.chars().count();
    if !(3..=30).contains(&username_chars) {
        return Err(AppError::BadRequest("username must be 3-30 characters"));
    }
    if body.password.chars().count() < 8 {
        return Err(AppError::BadRequest("password must be at least 8 characters"));
    }
    if !body.email.contains('@') {
        return Err(AppError::BadRequest("invalid email address"));
    }
    let hash = hash_password(&body.password).map_err(|_| AppError::Internal)?;
    let user_id = db::create_user(&state.db, &body.username, &body.email, &hash)
        .await
        .map_err(|e| {
            if is_unique_violation(&e) {
                AppError::Conflict("username or email already taken")
            } else {
                AppError::Database(e)
            }
        })?;
    // Re-read the freshly created row; it missing here is an internal invariant
    // violation, not a client error.
    let user = db::get_user_by_id(&state.db, user_id)
        .await?
        .ok_or(AppError::Internal)?;
    // Establish the session so the client is logged in right away.
    auth_session.login(&user).await.map_err(|_| AppError::Internal)?;
    Ok((
        StatusCode::CREATED,
        Json(MeResponse {
            id: user.id,
            username: user.username,
        }),
    ))
}
/// POST /auth/login — verifies credentials and opens a session.
///
/// Unknown user or wrong password → 401; a backend failure → 500.
async fn login(
    mut auth_session: AuthSession<AuthBackend>,
    Json(body): Json<LoginBody>,
) -> Result<impl IntoResponse, AppError> {
    let creds = Credentials {
        username: body.username,
        password: body.password,
    };
    let user = auth_session
        .authenticate(creds)
        .await
        .map_err(|_| AppError::Internal)?
        .ok_or(AppError::Unauthorized)?;
    auth_session.login(&user).await.map_err(|_| AppError::Internal)?;
    let identity = MeResponse {
        id: user.id,
        username: user.username,
    };
    Ok(Json(identity))
}
/// POST /auth/logout — destroys the current session; 204 on success.
async fn logout(mut auth_session: AuthSession<AuthBackend>) -> Result<StatusCode, AppError> {
    match auth_session.logout().await {
        Ok(_) => Ok(StatusCode::NO_CONTENT),
        Err(_) => Err(AppError::Internal),
    }
}
/// GET /auth/me — identity of the logged-in user, or 401 when anonymous.
async fn me(auth_session: AuthSession<AuthBackend>) -> Result<impl IntoResponse, AppError> {
    let Some(user) = auth_session.user else {
        return Ok(StatusCode::UNAUTHORIZED.into_response());
    };
    let identity = MeResponse {
        id: user.id,
        username: user.username,
    };
    Ok(Json(identity).into_response())
}
/// GET /users/:username — public profile plus win/loss/draw statistics.
async fn user_profile(
    Path(username): Path<String>,
    State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, AppError> {
    let user = match db::get_user_by_username(&state.db, &username).await? {
        Some(u) => u,
        None => return Err(AppError::NotFound),
    };
    let stats = db::get_user_stats(&state.db, user.id).await?;
    let profile = UserProfileResponse {
        id: user.id,
        username: user.username,
        created_at: user.created_at,
        total_games: stats.total,
        wins: stats.wins,
        losses: stats.losses,
        draws: stats.draws,
    };
    Ok(Json(profile))
}
async fn user_games(
Path(username): Path<String>,
Query(query): Query<GamesQuery>,
State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, AppError> {
let per_page = query.per_page.clamp(1, 100);
let page = query.page.max(0);
let user = db::get_user_by_username(&state.db, &username)
.await?
.ok_or(AppError::NotFound)?;
let summaries = db::get_user_games(&state.db, user.id, page, per_page).await?;
Ok(Json(GamesResponse {
games: summaries.into_iter().map(Into::into).collect(),
}))
}
// ── Game detail (Phase 5) ─────────────────────────────────────────────────────
/// Row shape for the `game_records` lookup in [`game_detail`].
#[derive(sqlx::FromRow, Serialize)]
struct GameRecordRow {
    id: i64,
    game_id: String,
    room_code: String,
    started_at: i64,
    /// None while the game is still open.
    ended_at: Option<i64>,
    /// Opaque game-specific result JSON; None while open.
    result: Option<String>,
}
/// One participant joined with the owning user's name (NULL for anonymous players).
#[derive(sqlx::FromRow, Serialize)]
struct ParticipantWithUsername {
    player_id: i64,
    outcome: Option<String>,
    /// None when the participant had no account (LEFT JOIN miss).
    username: Option<String>,
}
/// Response body for GET /games/:id.
#[derive(Serialize)]
struct GameDetailResponse {
    id: i64,
    game_id: String,
    room_code: String,
    started_at: i64,
    ended_at: Option<i64>,
    result: Option<String>,
    participants: Vec<ParticipantWithUsername>,
}
/// GET /games/:id — a single game record together with its participants.
async fn game_detail(
    Path(id): Path<i64>,
    State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, AppError> {
    let maybe_record = sqlx::query_as::<_, GameRecordRow>(
        "SELECT id, game_id, room_code, started_at, ended_at, result
         FROM game_records WHERE id = ?",
    )
    .bind(id)
    .fetch_optional(&state.db)
    .await?;
    let Some(record) = maybe_record else {
        return Err(AppError::NotFound);
    };
    // LEFT JOIN keeps anonymous participants (user_id NULL → username None).
    let participants = sqlx::query_as::<_, ParticipantWithUsername>(
        "SELECT gp.player_id, gp.outcome, u.username
         FROM game_participants gp
         LEFT JOIN users u ON u.id = gp.user_id
         WHERE gp.game_record_id = ?
         ORDER BY gp.player_id",
    )
    .bind(id)
    .fetch_all(&state.db)
    .await?;
    Ok(Json(GameDetailResponse {
        id: record.id,
        game_id: record.game_id,
        room_code: record.room_code,
        started_at: record.started_at,
        ended_at: record.ended_at,
        result: record.result,
        participants,
    }))
}
// ── Game result recording (Phase 4) ──────────────────────────────────────────
/// Payload for POST /games/result, sent by the game host when a game ends.
#[derive(Deserialize)]
struct GameResultBody {
    room_code: String,
    game_id: String,
    /// Opaque game-specific result, stored verbatim as JSON.
    result: JsonValue,
    /// Per-player outcomes keyed by player_id as a string ("0", "1", …).
    /// Accepted values: "win", "loss", "draw". Missing keys → NULL outcome.
    #[serde(default)]
    outcomes: HashMap<String, String>,
}
/// Response body for POST /games/result: the DB row id of the closed record.
#[derive(Serialize)]
struct GameResultResponse {
    game_record_id: i64,
}
/// Called by the WASM host when a game ends.
///
/// Knowing the room code plus game ID is treated as sufficient authorisation
/// (the same trust level as joining via WebSocket). Retries are safe:
/// `close_game_record` is idempotent (no-op if already closed) and participant
/// inserts use `INSERT OR IGNORE`.
async fn game_result(
    State(state): State<Arc<AppState>>,
    Json(body): Json<GameResultBody>,
) -> Result<impl IntoResponse, AppError> {
    let compound_id = format!("{}#{}", body.room_code, body.game_id);
    // Copy what we need out of the room map, then release the lock before any DB work.
    let rooms = state.rooms.lock().await;
    let room = rooms.get(&compound_id).ok_or(AppError::NotFound)?;
    let game_record_id = room.game_record_id.ok_or(AppError::NotFound)?;
    let user_ids = room.user_ids.clone();
    drop(rooms);
    let result_json = serde_json::to_string(&body.result)
        .map_err(|_| AppError::BadRequest("could not serialise result"))?;
    db::close_game_record(&state.db, game_record_id, Some(&result_json)).await?;
    for (player_id, user_id) in &user_ids {
        // Outcomes arrive keyed by the player id rendered as a decimal string.
        let outcome = body.outcomes.get(&player_id.to_string()).map(String::as_str);
        db::insert_participant(&state.db, game_record_id, *user_id, *player_id, outcome).await?;
    }
    tracing::info!(
        game_record_id,
        room = body.room_code,
        "Game result recorded"
    );
    Ok(Json(GameResultResponse { game_record_id }))
}

View file

@ -0,0 +1,91 @@
//! This module handles game rooms where players connect and exchange messages.
//! It provides:
//! - [`Room`]: A game session with host-to-client broadcast channels
//! - [`AppState`]: Global state holding all active rooms and game configurations
//! - [`reload_config`]: Hot-reloading of game settings from `GameConfig.json`
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::fs;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::{broadcast, mpsc};
/// One entry of `GameConfig.json`: a game the relay is willing to host.
#[derive(Serialize, Deserialize)]
pub struct GameEntry {
    /// The name of the game; used as the lookup key in [`AppState::configs`].
    pub name: String,
    /// The maximum amount of players per room (0 = no limit).
    pub max_players: u16,
}
/// The parsed shape of `GameConfig.json`: a flat list of game entries.
type EntryList = Vec<GameEntry>;
/// One game session: a host plus the clients that joined it.
///
/// All `Room` access goes through `AppState::rooms`, a `Mutex<HashMap<..>>`,
/// so individual fields need no further synchronisation.
pub struct Room {
    /// The next id a client gets; assigned consecutively.
    pub next_client_id: u16, // Mutated under the rooms-map Mutex.
    /// The amount of players currently in the room.
    pub amount_of_players: u16, // Mutated under the rooms-map Mutex.
    /// This is a status counter for rule variation in a game (like coop vs semi-coop).
    pub rule_variation: u16,
    /// The sender to send messages to the host.
    pub to_host_sender: mpsc::Sender<Bytes>, // Clone-able, no Mutex needed.
    /// The broadcast sender the clients subscribe to for host messages.
    pub host_to_client_broadcaster: broadcast::Sender<Bytes>, // Clone-able, no Mutex needed.
    /// Reconnect tokens keyed by player id. Used to authenticate reconnect attempts.
    pub player_tokens: HashMap<u16, u64>,
    /// Whether the host WebSocket is currently active. False during the grace period
    /// after host disconnect — the grace-period task will clean up the room if the
    /// host does not reconnect in time.
    pub host_connected: bool,
    /// IDs of non-host players whose WebSocket is currently active.
    /// Used to replay NEW_CLIENT / CLIENT_DISCONNECTS when the host reconnects.
    pub connected_players: Vec<u16>,
    /// Row id in `game_records` for this session. None when no authenticated player created the room.
    pub game_record_id: Option<i64>,
    /// Maps in-game player_id → database user_id. None means the player is anonymous.
    pub user_ids: HashMap<u16, Option<i64>>,
}
/// Global server state, shared (via `Arc`) across all handlers and tasks.
pub struct AppState {
    /// All active rooms, keyed by their compound session id.
    pub rooms: Mutex<HashMap<String, Room>>,
    /// Mapping from game name to the maximum amount of players allowed;
    /// hot-reloadable via [`reload_config`].
    pub configs: RwLock<HashMap<String, u16>>,
    /// SQLite connection pool — shared across all request handlers.
    pub db: SqlitePool,
}
impl AppState {
    /// Creates an empty application state around an existing connection pool.
    pub fn new(db: SqlitePool) -> Self {
        let rooms = Mutex::new(HashMap::new());
        let configs = RwLock::new(HashMap::new());
        Self { rooms, configs, db }
    }
}
/// Re-reads `GameConfig.json` and atomically replaces the game → max-players map.
///
/// On any read or parse failure the previous configuration stays in effect and
/// a human-readable error string is returned.
pub async fn reload_config(state: &Arc<AppState>) -> Result<(), String> {
    let json_content = match fs::read_to_string("GameConfig.json").await {
        Ok(content) => content,
        Err(e) => return Err(format!("Failed to read file: {}", e)),
    };
    let raw_data: EntryList = match serde_json::from_str(&json_content) {
        Ok(entries) => entries,
        Err(e) => return Err(format!("Failed to parse JSON: {}", e)),
    };
    let mut new_configs = HashMap::new();
    for entry in raw_data {
        new_configs.insert(entry.name, entry.max_players);
    }
    // Swap the whole map in one step under the write lock.
    *state.configs.write().await = new_configs;
    Ok(())
}

View file

@ -0,0 +1,239 @@
mod auth;
mod db;
mod hand_shake;
mod http;
mod lobby;
mod message_relay;
use crate::auth::AuthBackend;
use crate::hand_shake::{
ClientServerSpecificData, DisconnectData, inform_client_of_connection, init_and_connect,
shutdown_connection,
};
use crate::lobby::{AppState, reload_config};
use crate::message_relay::{handle_client_logic, handle_server_logic};
use axum::Router;
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::routing::get;
use axum_login::{AuthManagerLayerBuilder, AuthSession};
use bytes::Bytes;
use futures_util::SinkExt;
use futures_util::stream::StreamExt;
use std::sync::Arc;
use std::time::Duration;
use time::Duration as TimeDuration;
use tokio::sync::Mutex;
use axum::http::{HeaderName, Method};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
use tower_sessions::{Expiry, SessionManagerLayer};
use tower_sessions_sqlx_store::SqliteStore;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
/// Activates error tracing, spawns a watchdog task to eliminate eventual dead rooms, then sets up the routing
/// system serving the WebSockets plus the enlist/reload pages. The server listens on port 8080.
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()),
        )
        .with(
            tracing_subscriber::fmt::layer()
                .with_file(true)
                .with_line_number(true)
                .with_target(true) // Module path (e.g. relay_server::processing_module)
                .with_thread_ids(true) // Thread id (helpful for Tokio)
                .with_thread_names(true), // Thread name
        )
        .init();
    // SQLite pool; sessions are persisted in the same database.
    let db_path = std::env::var("DATABASE_PATH").unwrap_or_else(|_| "data/relay.db".to_string());
    let pool = db::init_db(&db_path).await;
    let session_store = SqliteStore::new(pool.clone());
    session_store
        .migrate()
        .await
        .expect("Failed to initialize session store");
    // NOTE(review): with_secure(false) sends session cookies over plain HTTP —
    // fine for local dev; confirm TLS termination before production.
    let session_layer = SessionManagerLayer::new(session_store)
        .with_secure(false)
        .with_expiry(Expiry::OnInactivity(TimeDuration::days(30)));
    let auth_backend = AuthBackend::new(pool.clone());
    let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
    let app_state = Arc::new(AppState::new(pool));
    // Fallback watchdog: periodically sweep rooms that escaped normal cleanup.
    let watchdog_state = app_state.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1200)); // 20 min
        loop {
            interval.tick().await;
            cleanup_dead_rooms(&watchdog_state).await;
        }
    });
    // The game list must load at least once; refuse to start without it.
    let initial = reload_config(&app_state).await;
    if let Err(message) = initial {
        tracing::error!(message, "Initial load error.");
        panic!("Initial load error: {}", message);
    }
    // Credentials (cookies) are allowed, so origins must be listed explicitly.
    let cors = CorsLayer::new()
        .allow_origin(AllowOrigin::list([
            "http://localhost:9091".parse().unwrap(), // tic-tac-toe dev server
            "http://localhost:9092".parse().unwrap(), // portal dev server
        ]))
        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
        .allow_headers([
            HeaderName::from_static("content-type"),
            HeaderName::from_static("cookie"),
        ])
        .allow_credentials(true);
    let app = Router::new()
        .route("/reload", get(reload_handler))
        .route("/enlist", get(enlist_handler))
        .route("/ws", get(websocket_handler))
        .merge(http::router())
        .nest_service("/portal", ServeDir::new("portal").not_found_service(ServeFile::new("portal/index.html")))
        .with_state(app_state)
        .layer(auth_layer)
        .layer(cors)
        .fallback_service(ServeDir::new(".").not_found_service(ServeFile::new("index.html")));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080")
        .await
        .unwrap();
    axum::serve(listener, app).await.unwrap();
}
/// Fallback watchdog: removes rooms whose host connection is gone.
///
/// Rooms are normally torn down by their own connection handlers; this sweep
/// only catches rooms that slipped through.
///
/// Rooms with `host_connected == false` are in the post-disconnect grace
/// period and are deliberately left alone: the grace-period task spawned by
/// `shutdown_connection` owns their cleanup, and removing them here would
/// prevent a host from reconnecting within the grace window.
async fn cleanup_dead_rooms(state: &Arc<AppState>) {
    let mut rooms = state.rooms.lock().await;
    rooms.retain(|room_id, room| {
        // Grace-period rooms are kept; the grace task decides their fate.
        if !room.host_connected {
            return true;
        }
        // Host is flagged as connected but its channel is gone — dead room.
        let is_alive = !room.to_host_sender.is_closed();
        if !is_alive {
            tracing::info!("Removing dead room: {}", room_id);
        }
        is_alive
    });
}
/// Renders a plain-text listing of all rooms: rule variation, player count and
/// whether the host channel is still alive.
async fn enlist_handler(State(state): State<Arc<AppState>>) -> String {
    let rooms = state.rooms.lock().await;
    let mut lines = Vec::with_capacity(rooms.len());
    for (name, room) in rooms.iter() {
        lines.push(format!(
            "Room: {:<30} Variation: {:03} Players: {:03} is alive: {}",
            name,
            room.rule_variation,
            room.amount_of_players,
            !room.to_host_sender.is_closed()
        ));
    }
    lines.join("\n")
}
/// Forces a reload of the config file and lists the resulting games, so new
/// games can be added without restarting the service. On failure the old
/// configuration is kept and the error is reported instead.
async fn reload_handler(State(state): State<Arc<AppState>>) -> String {
    if let Err(e) = reload_config(&state).await {
        return format!("Config reload failed: {}", e);
    }
    let configs = state.configs.read().await;
    let mut lines = Vec::with_capacity(configs.len());
    for (key, players) in configs.iter() {
        lines.push(format!(
            "Game: {:<40} Maximum Amount of Players: {}",
            key, players
        ));
    }
    lines.join("\n")
}
/// Entry point for `/ws`: captures the (optional) logged-in user id from the
/// session and upgrades the HTTP request to a WebSocket.
async fn websocket_handler(
    ws: WebSocketUpgrade,
    auth_session: AuthSession<AuthBackend>,
    State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
    let user_id = auth_session.user.as_ref().map(|u| u.id);
    ws.on_upgrade(move |socket| websocket(socket, state, user_id))
}
/// Runs one WebSocket connection end-to-end: handshake, then role-specific
/// relaying (host vs. client logic), then coordinated shutdown.
async fn websocket(stream: WebSocket, state: Arc<AppState>, user_id: Option<i64>) {
    // By splitting, we can send and receive at the same time.
    let (mut sender, mut receiver) = stream.split();
    // Abort early when the handshake fails — nothing has been registered yet.
    let Some(base_data) =
        init_and_connect(&mut sender, &mut receiver, state.clone(), user_id).await
    else {
        return;
    };
    let disconnect_data = DisconnectData::from(&base_data);
    let success = inform_client_of_connection(&mut sender, &base_data).await;
    // The sender is shared between the relay logic and the keep-alive task.
    let wrapped_sender = Arc::new(Mutex::new(sender));
    // Keep-alive: ping every 30 s until sending fails (peer gone).
    let ping_sender = wrapped_sender.clone();
    let ping_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(30));
        interval.tick().await; // The first tick fires immediately; skip it.
        loop {
            interval.tick().await;
            let mut s = ping_sender.lock().await;
            if s.send(Message::Ping(Bytes::new())).await.is_err() {
                break;
            }
        }
    });
    let mut error_message = "Connection to server lost";
    if success {
        match base_data.specific_data {
            ClientServerSpecificData::Server(internal_receiver, internal_sender) => {
                error_message = handle_server_logic(
                    wrapped_sender.clone(),
                    receiver,
                    internal_receiver,
                    internal_sender,
                )
                .await;
            }
            ClientServerSpecificData::Client(internal_receiver, internal_sender) => {
                error_message = handle_client_logic(
                    wrapped_sender.clone(),
                    receiver,
                    internal_receiver,
                    internal_sender,
                    base_data.player_id,
                )
                .await;
            }
        }
    }
    // Stop pinging before the closing message so the two cannot interleave.
    ping_task.abort();
    shutdown_connection(wrapped_sender, disconnect_data, state, error_message).await;
}

View file

@ -0,0 +1,354 @@
//! WebSocket message routing for the relay server.
//!
//! This module handles bidirectional communication between game hosts and clients.
//! It spawns paired Tokio tasks for each connection that:
//! - Validate and filter messages by type (preventing illegal commands)
//! - Route host broadcasts to subscribed clients
//! - Forward client RPCs to the host with injected player IDs
//! - Manage sync state so clients only receive deltas after a full update
//!
//! The relay server never interprets game logic — it only validates message types
//! and routes bytes between endpoints.
use axum::extract::ws::{Message, WebSocket};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use protocol::*;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::broadcast;
use tokio::sync::broadcast::Sender;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::mpsc::Receiver;
/// Drives a game-host connection by running its two relay directions
/// concurrently.
///
/// One task forwards aggregated client traffic (joins, disconnects, RPCs) to
/// the host; the other broadcasts host output (updates, kicks) to the clients.
/// As soon as either direction finishes — connection loss, protocol violation
/// or an intentional disconnect — the sibling task is aborted; the caller is
/// responsible for cleaning up the room afterwards.
///
/// # Returns
/// A static string describing why the connection ended (for logging/debugging).
pub async fn handle_server_logic(
    sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    receiver: SplitStream<WebSocket>,
    internal_receiver: Receiver<Bytes>,
    internal_sender: broadcast::Sender<Bytes>,
) -> &'static str {
    let mut to_host =
        tokio::spawn(async move { send_logic_server(sender, internal_receiver).await });
    let mut from_host =
        tokio::spawn(async move { receive_logic_server(receiver, internal_sender).await });
    // Whichever direction finishes first decides the outcome; abort the other.
    let outcome = tokio::select! {
        finished = &mut to_host => { from_host.abort(); finished }
        finished = &mut from_host => { to_host.abort(); finished }
    };
    match outcome {
        Ok(reason) => reason,
        Err(err) => {
            tracing::error!(?err, "Error while handling server logic.");
            "Internal panic in server side logic."
        }
    }
}
/// Consumes binary frames from the game host and fans them out to all clients.
///
/// Accepted host messages: [`CLIENT_GETS_KICKED`], [`DELTA_UPDATE`],
/// [`FULL_UPDATE`] and [`RESET`] (broadcast verbatim), plus
/// [`SERVER_DISCONNECTS`] which ends the loop gracefully. Any other message
/// type is rejected as a protocol violation.
async fn receive_logic_server(
    mut receiver: SplitStream<WebSocket>,
    internal_sender: Sender<Bytes>,
) -> &'static str {
    loop {
        let frame = match receiver.next().await {
            Some(Ok(frame)) => frame,
            Some(Err(_)) | None => return "Connection lost.",
        };
        let Message::Binary(bytes) = frame else {
            // Other frame kinds (ping/pong) are handled by axum itself.
            continue;
        };
        let Some(&command) = bytes.first() else {
            tracing::error!("Illegal empty message in receive logic server.");
            return "Illegal empty message received.";
        };
        match command {
            // A graceful shutdown initiated by the host — expected.
            SERVER_DISCONNECTS => return "Server disconnected intentionally",
            CLIENT_GETS_KICKED | DELTA_UPDATE | FULL_UPDATE | RESET => {
                // Relay the frame untouched. A send error only means no client
                // is subscribed right now; the host should not really be
                // talking then, but it is unclear whether that is strictly
                // avoidable — hence a warning rather than an error.
                if let Err(error) = internal_sender.send(bytes) {
                    tracing::warn!(?error, "Sending to no clients.");
                }
            }
            _ => {
                tracing::error!(
                    message_type = command,
                    "Illegal message type Server->Client."
                );
                return "Illegal Server -> Client command.";
            }
        }
    }
}
/// Relays aggregated client traffic to the game host.
///
/// Only [`NEW_CLIENT`], [`CLIENT_DISCONNECTS`] and [`SERVER_RPC`] may travel
/// in this direction; anything else aborts the connection. The WebSocket
/// sender lock is taken per message, so delivery to the host stays sequential.
async fn send_logic_server(
    sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    mut internal_receiver: Receiver<Bytes>,
) -> &'static str {
    while let Some(bytes) = internal_receiver.recv().await {
        let Some(&command) = bytes.first() else {
            tracing::error!("Illegal internal empty message in send logic server.");
            return "Illegal empty message received.";
        };
        let is_allowed = matches!(command, NEW_CLIENT | CLIENT_DISCONNECTS | SERVER_RPC);
        if !is_allowed {
            tracing::error!(
                message_type = command,
                "Unknown internal Client->Server command"
            );
            return "Unknown internal Client->Server command";
        }
        // Forward the frame untouched.
        if let Err(err) = sender.lock().await.send(Message::Binary(bytes)).await {
            tracing::error!(?err, "Error in communication with server endpoint.");
            return "Error in communication with server endpoint.";
        }
    }
    // We own the closing of this channel, so running out of messages here is unexpected.
    tracing::error!("Internal channel on server was unexpectedly closed.");
    "Internal channel closed."
}
/// Drives a game-client connection by running its two relay directions
/// concurrently.
///
/// One task delivers host broadcasts to this client (applying sync-state
/// filtering), the other forwards the client's RPCs to the host with the
/// player id injected. When either direction finishes, its sibling is aborted.
///
/// # Arguments
/// * `player_id` - Unique identifier assigned to this client for the session
///
/// # Returns
/// A static string describing why the connection ended.
pub async fn handle_client_logic(
    sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    receiver: SplitStream<WebSocket>,
    internal_receiver: tokio::sync::broadcast::Receiver<Bytes>,
    internal_sender: tokio::sync::mpsc::Sender<Bytes>,
    player_id: u16,
) -> &'static str {
    let mut to_client =
        tokio::spawn(async move { send_logic_client(sender, internal_receiver, player_id).await });
    let mut from_client = tokio::spawn(
        async move { receive_logic_client(receiver, internal_sender, player_id).await },
    );
    // Whichever direction finishes first decides the outcome; abort the other.
    let outcome = tokio::select! {
        finished = &mut to_client => { from_client.abort(); finished }
        finished = &mut from_client => { to_client.abort(); finished }
    };
    match outcome {
        Ok(reason) => reason,
        Err(err) => {
            tracing::error!(?err, "Internal panic in client side logic.");
            "Internal panic in client side logic."
        }
    }
}
/// Consumes binary frames from a client and forwards them to the host.
///
/// Accepted client messages:
/// - [`SERVER_RPC`]: rewritten from `[SERVER_RPC, payload…]` to
///   `[SERVER_RPC, player_id_high, player_id_low, payload…]` so the host
///   knows which player sent the action, then forwarded.
/// - [`CLIENT_DISCONNECTS_SELF`]: ends the loop (graceful disconnect).
///
/// Any other command is a protocol violation and ends the connection.
async fn receive_logic_client(
    mut receiver: SplitStream<WebSocket>,
    internal_sender: tokio::sync::mpsc::Sender<Bytes>,
    player_id: u16,
) -> &'static str {
    loop {
        let frame = match receiver.next().await {
            Some(Ok(frame)) => frame,
            Some(Err(_)) | None => return "Connection lost.",
        };
        let Message::Binary(bytes) = frame else {
            // Other frame kinds need no handling here.
            continue;
        };
        let Some(&command) = bytes.first() else {
            tracing::error!("Illegal empty message received in receive logic client.");
            return "Illegal empty message received.";
        };
        match command {
            SERVER_RPC => {
                // Splice the player id in directly behind the command byte.
                let mut tagged = BytesMut::with_capacity(bytes.len() + CLIENT_ID_SIZE);
                tagged.put_u8(SERVER_RPC);
                tagged.put_u16(player_id);
                tagged.put_slice(&bytes[1..]);
                if let Err(error) = internal_sender.send(tagged.into()).await {
                    tracing::error!(?error, "Error in internal broadcast.");
                    return "Error in internal broadcast.";
                }
            }
            CLIENT_DISCONNECTS_SELF => return "Client disconnected intentionally",
            _ => {
                tracing::error!(command = ?bytes[0], "Illegal command from client.");
                return "Illegal Command from client";
            }
        }
    }
}
/// Delivers host broadcasts to a specific client with sync state management.
///
/// # Sync State Machine
/// Clients start unsynced and must receive a [`FULL_UPDATE`] or [`RESET`] before
/// processing [`DELTA_UPDATE`] messages. This prevents clients from applying
/// deltas to an unknown base state.
///
/// ```text
/// [Unsynced] --FULL_UPDATE---> [Synced]
/// [Unsynced] --RESET---------> [Synced]
/// [Synced]   --DELTA_UPDATE--> [Synced]   (forwarded)
/// [Unsynced] --DELTA_UPDATE--> [Unsynced] (dropped)
/// ```
///
/// # Filtered Messages
/// - [`CLIENT_GETS_KICKED`]: Only terminates if `player_id` matches
/// - [`SERVER_DISCONNECTS`]: Always terminates
///
/// # Error Handling
/// Returns immediately if the broadcast channel lags (buffer overflow),
/// as the client cannot recover from missed messages.
async fn send_logic_client(
    sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    mut internal_receiver: tokio::sync::broadcast::Receiver<Bytes>,
    player_id: u16,
) -> &'static str {
    // Tracks whether this client has a valid base state (see state machine above).
    let mut is_synced = false;
    loop {
        let state = internal_receiver.recv().await;
        match state {
            Err(RecvError::Closed) => {
                tracing::error!("Internal channel closed.");
                return "Internal channel closed.";
            }
            Err(RecvError::Lagged(skipped)) => {
                // Missed broadcasts cannot be replayed — give up on this client.
                tracing::warn!(
                    skipped_messages = skipped,
                    "Lagging started on internal channel."
                );
                return "Lagging on internal channel - Computer too slow.";
            }
            Ok(mut bytes) => {
                if bytes.is_empty() {
                    tracing::error!("Illegal empty message received.");
                    return "Illegal empty message received.";
                }
                match bytes[0] {
                    SERVER_DISCONNECTS => {
                        return "Server has left the game.";
                    }
                    CLIENT_GETS_KICKED => {
                        // Header byte + u16 player id must be present.
                        if bytes.len() < 3 {
                            tracing::error!("Malformed CLIENT_GETS_KICKED message");
                            return "Malformed message received.";
                        }
                        bytes.get_u8(); // Skip command byte
                        let meant_client = bytes.get_u16();
                        // Kicks are broadcast to everyone; only act when this client is targeted.
                        if meant_client == player_id {
                            return "We got rejected by server.";
                        }
                    }
                    DELTA_UPDATE => {
                        if is_synced {
                            let res = sender.lock().await.send(Message::Binary(bytes)).await;
                            if let Err(error) = res {
                                tracing::error!(
                                    ?error,
                                    "Error in communication with client endpoint."
                                );
                                return "Error in communication with client endpoint.";
                            }
                        }
                        // Silently drop deltas for unsynced clients
                    }
                    FULL_UPDATE => {
                        if !is_synced {
                            is_synced = true;
                            let res = sender.lock().await.send(Message::Binary(bytes)).await;
                            if let Err(error) = res {
                                tracing::error!(
                                    ?error,
                                    "Error in communication with client endpoint."
                                );
                                return "Error in communication with client endpoint.";
                            }
                        }
                        // Drop redundant full updates for already synced clients
                    }
                    RESET => {
                        // A reset defines a fresh base state, so forwarding it always syncs us.
                        is_synced = true;
                        let res = sender.lock().await.send(Message::Binary(bytes)).await;
                        if let Err(error) = res {
                            tracing::error!(?error, "Error in communication with client endpoint.");
                            return "Error in communication with client endpoint.";
                        }
                    }
                    _ => {
                        tracing::error!(
                            message = bytes[0],
                            "Illegal message on client side received."
                        );
                        return "Illegal message on client side received.";
                    }
                }
            }
        }
    }
}