Compare commits

..

22 Commits

Author SHA1 Message Date
Ryan Walters
cc06cd88a1 refactor: allow optional database in setup, use derived default 2025-09-18 22:58:38 -05:00
Ryan Walters
e2c725cb95 feat: allow health check forcing in debug, setup test mocking, plan out integration tests 2025-09-18 22:42:00 -05:00
Ryan Walters
350f92ab21 tests: setup basic tests, integration tests with testcontainers 2025-09-18 21:33:15 -05:00
Ryan Walters
3ad00bdcba chore: setup lib for testing, image handling notes in README 2025-09-18 13:18:53 -05:00
Ryan Walters
7f9d3e9158 feat: implement r2 image upload for avatars 2025-09-18 13:18:14 -05:00
Ryan Walters
56e02e7253 refactor: remove unnecessary HashMap for passing code/state strings, formatter lifetime tweak 2025-09-17 13:18:58 -05:00
Ryan Walters
e2f3f6790f refactor: create common pkce handling, max_age on link cookie 2025-09-17 13:08:48 -05:00
Ryan Walters
1be59f474d feat: add Server header middleware, bump version to v0.4.0 2025-09-17 12:37:12 -05:00
Ryan Walters
916428fe76 feat: setup healthcheck route & background task 2025-09-17 12:32:52 -05:00
Ryan Walters
e02c2286bb chore: add .scripts with local postgres setup script, setup todo list in README 2025-09-17 12:23:55 -05:00
Ryan Walters
c12dc11d8f feat: normalize provider details into oauth_accounts table, auth linking intent, provider array in profile response 2025-09-17 11:17:31 -05:00
Ryan Walters
1cf3b901e8 feat: users table with sqlx, migrations, data persistence 2025-09-17 09:43:52 -05:00
Ryan Walters
ac1417aabc feat: discord oauth provider, setup provider list route, add 'active' method, common type alias 2025-09-17 09:23:31 -05:00
Ryan Walters
8e23fb66a4 feat: setup smarter PKCE map purging & BasicClient type alias, smarter EnvFilter string building 2025-09-17 04:06:52 -05:00
Ryan Walters
92acb07b04 feat: setup tracing calls throughout project 2025-09-17 04:05:59 -05:00
Ryan Walters
18e750fa61 feat: add tracing/tracing-subscriber, setup CustomPrettyFormatter & CustomJsonFormatter 2025-09-17 03:48:35 -05:00
Ryan Walters
8d9c0621c9 feat: proper shutdown timeout handling 2025-09-17 03:41:13 -05:00
Ryan Walters
750b47b609 feat: add SIGINT/SIGTERM graceful shutdown handling 2025-09-17 03:36:59 -05:00
Ryan Walters
b1fae907ee chore: add railway.json drainingSeconds 2025-09-17 03:33:39 -05:00
Ryan Walters
f3db44c48b feat: setup github provider with generic trait, proper routes, session & jwt handling, errors & user agent 2025-09-17 03:33:18 -05:00
Ryan Walters
264478bdaa chore: reformat recipes, add server/docker recipes, strip symbols for release 2025-09-17 01:30:04 -05:00
Ryan Walters
f69a5c7d52 feat: initial server config & Dockerfile 2025-09-16 22:13:35 -05:00
34 changed files with 7651 additions and 113 deletions

16
.dockerignore Normal file
View File

@@ -0,0 +1,16 @@
# Build artifacts
/target
/dist
/emsdk
*.exe
/pacman/assets
/assets
# Development files
/.git
/*.md
/Justfile
/bacon.toml
/rust-toolchain.toml
/rustfmt.toml

3
.gitignore vendored
View File

@@ -23,3 +23,6 @@ flamegraph.svg
# Logs
*.log
# Sensitive
*.env

186
.scripts/postgres.ts Normal file
View File

@@ -0,0 +1,186 @@
import { $ } from "bun";
import { readFileSync, writeFileSync, existsSync } from "fs";
import { join, dirname } from "path";
import { fileURLToPath } from "url";
import { createInterface } from "readline";
// Helper function to get user input
async function getUserChoice(
prompt: string,
choices: string[],
defaultIndex: number = 1
): Promise<string> {
// Check if we're in an interactive TTY
if (!process.stdin.isTTY) {
console.log(
"Non-interactive environment detected; selecting default option " +
defaultIndex
);
return String(defaultIndex);
}
console.log(prompt);
choices.forEach((choice, index) => {
console.log(`${index + 1}. ${choice}`);
});
// Use readline for interactive input
const rl = createInterface({
input: process.stdin,
output: process.stdout,
});
return new Promise((resolve) => {
const askForChoice = () => {
rl.question("Enter your choice (1-3): ", (answer) => {
const choice = answer.trim();
if (["1", "2", "3"].includes(choice)) {
rl.close();
resolve(choice);
} else {
console.log("Invalid choice. Please enter 1, 2, or 3.");
askForChoice();
}
});
};
askForChoice();
});
}
// Get repository root path from script location
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
const repoRoot = join(__dirname, "..");
const envPath = join(repoRoot, "pacman-server", ".env");
console.log("Checking for .env file...");
// Check if .env file exists and read it
let envContent = "";
let envLines: string[] = [];
let databaseUrlLine = -1;
let databaseUrlValue = "";
if (existsSync(envPath)) {
console.log("Found .env file, reading...");
envContent = readFileSync(envPath, "utf-8");
envLines = envContent.split("\n");
// Parse .env file for DATABASE_URL
for (let i = 0; i < envLines.length; i++) {
const line = envLines[i].trim();
if (line.match(/^[A-Z_][A-Z0-9_]*=.*$/)) {
if (line.startsWith("DATABASE_URL=")) {
databaseUrlLine = i;
databaseUrlValue = line.substring(13); // Remove "DATABASE_URL="
break;
}
}
}
} else {
console.log("No .env file found, will create one");
}
// Determine user's choice
let userChoice = "2"; // Default to print
if (databaseUrlLine !== -1) {
console.log(`Found existing DATABASE_URL: ${databaseUrlValue}`);
userChoice = await getUserChoice("\nChoose an action:", [
"Quit",
"Print (create container, print DATABASE_URL)",
"Replace (update DATABASE_URL in .env)",
]);
if (userChoice === "1") {
console.log("Exiting...");
process.exit(0);
}
} else {
console.log("No existing DATABASE_URL found");
// Ask what to do when no .env file or DATABASE_URL exists
if (!existsSync(envPath)) {
userChoice = await getUserChoice(
"\nNo .env file found. What would you like to do?",
[
"Print (create container, print DATABASE_URL)",
"Create .env file and add DATABASE_URL",
"Quit",
]
);
if (userChoice === "3") {
console.log("Exiting...");
process.exit(0);
}
} else {
console.log("Will add DATABASE_URL to existing .env file");
}
}
// Check if container exists
console.log("Checking for existing container...");
const containerExists =
await $`docker ps -a --filter name=pacman-server-postgres --format "{{.Names}}"`
.text()
.then((names) => names.trim() === "pacman-server-postgres")
.catch(() => false);
let shouldReplaceContainer = false;
if (containerExists) {
console.log("Container already exists");
// Always ask what to do if container exists
const replaceChoice = await getUserChoice(
"\nContainer exists. What would you like to do?",
["Use existing container", "Replace container (remove and create new)"],
1
);
shouldReplaceContainer = replaceChoice === "2";
if (shouldReplaceContainer) {
console.log("Removing existing container...");
await $`docker rm --force --volumes pacman-server-postgres`;
} else {
console.log("Using existing container");
}
}
// Create container if needed
if (!containerExists || shouldReplaceContainer) {
console.log("Creating PostgreSQL container...");
await $`docker run --detach --name pacman-server-postgres --publish 5432:5432 --env POSTGRES_USER=postgres --env POSTGRES_PASSWORD=postgres --env POSTGRES_DB=pacman-server postgres:17`;
}
// Format DATABASE_URL
const databaseUrl =
"postgresql://postgres:postgres@localhost:5432/pacman-server";
// Handle the final action based on user choice
if (userChoice === "2") {
// Print option
console.log(`\nDATABASE_URL=${databaseUrl}`);
} else if (
userChoice === "3" ||
(databaseUrlLine === -1 && userChoice === "2")
) {
// Replace or add to .env file
if (databaseUrlLine !== -1) {
// Replace existing line
console.log("Updating DATABASE_URL in .env file...");
envLines[databaseUrlLine] = `DATABASE_URL=${databaseUrl}`;
writeFileSync(envPath, envLines.join("\n"));
console.log("Updated .env file");
} else {
// Add new line
console.log("Adding DATABASE_URL to .env file...");
const newContent =
envContent +
(envContent.endsWith("\n") ? "" : "\n") +
`DATABASE_URL=${databaseUrl}\n`;
writeFileSync(envPath, newContent);
console.log("Added to .env file");
}
}

4861
Cargo.lock generated
View File

File diff suppressed because it is too large Load Diff

View File

@@ -13,17 +13,8 @@ keywords = ["game", "pacman", "arcade", "sdl2"]
categories = ["games", "emulators"]
publish = false
[workspace.dependencies]
# Common dependencies that might be shared across crates
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
axum = "0.7"
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace"] }
tracing = "0.1"
tracing-subscriber = "0.3"
[profile.dev]
incremental = true
# Release profile for profiling (essentially the default 'release' profile with debug enabled)
[profile.profile]
@@ -32,7 +23,8 @@ debug = true
# Undo the customizations for our release profile
opt-level = 3
lto = false
panic = 'unwind'
panic = "abort"
strip = "symbols"
# Optimized release profile for size
[profile.release]

View File

@@ -1,48 +1,78 @@
set shell := ["bash", "-c"]
set windows-shell := ["powershell.exe", "-NoLogo", "-Command"]
binary_extension := if os() == "windows" { ".exe" } else { "" }
# !!! --ignore-filename-regex should be used on both reports & coverage testing
# !!! --remap-path-prefix prevents the absolute path from being used in the generated report
# Display available recipes
default:
just --list
# Generate HTML report (for humans, source line inspection)
# Open HTML coverage report
html: coverage
cargo llvm-cov report \
--remap-path-prefix \
--html \
--open
cargo llvm-cov report \
# prevents the absolute path from being used in the generated report
--remap-path-prefix \
--html \
--open
# Display report (for humans)
# Display coverage report
report-coverage: coverage
cargo llvm-cov report --remap-path-prefix
cargo llvm-cov report --remap-path-prefix
# Run & generate LCOV report (as base report)
# Generate baseline LCOV report
coverage:
cargo +nightly llvm-cov \
--lcov \
--remap-path-prefix \
--workspace \
--output-path lcov.info \
--profile coverage \
--no-fail-fast nextest
cargo +nightly llvm-cov \
--lcov \
--remap-path-prefix \
--workspace \
--output-path lcov.info \
--profile coverage \
--no-fail-fast nextest
# Profile the project using 'samply'
# Profile the project using samply
samply:
cargo build --profile profile
samply record ./target/profile/pacman{{ binary_extension }}
cargo build --profile profile
samply record ./target/profile/pacman{{ binary_extension }}
# Build the project for Emscripten
web *args:
bun run pacman/web.build.ts {{args}};
caddy file-server --root pacman/dist
bun run pacman/web.build.ts {{args}};
caddy file-server --root pacman/dist
# Run cargo fix
# Fix linting errors & formatting
fix:
cargo fix --workspace --lib --allow-dirty
cargo fmt --all
cargo fix --workspace --lib --allow-dirty
cargo fmt --all
# Push commits & tags
push:
git push origin --tags;
git push
git push origin --tags;
git push
# Create a postgres container for the server
server-postgres:
bun run .scripts/postgres.ts
# Build the server image
server-image:
# build the server image
docker build \
--platform linux/amd64 \
--file ./pacman-server/Dockerfile \
--tag pacman-server \
.
# Build and run the server in a Docker container
run-server: server-image
# remove the server container if it exists
docker rm --force --volumes pacman-server
# run the server container
docker run \
--rm \
--stop-timeout 2 \
--name pacman-server \
--publish 3000:3000 \
--env PORT=3000 \
--env-file pacman-server/.env \
pacman-server

View File

@@ -1,9 +1,9 @@
[package]
name = "pacman-server"
version = "0.1.1"
version = "0.4.0"
authors.workspace = true
edition.workspace = true
rust-version = "1.86.0"
rust-version = "1.87.0"
description = "A leaderboard API for the Pac-Man game"
readme.workspace = true
homepage.workspace = true
@@ -14,4 +14,49 @@ categories.workspace = true
publish.workspace = true
default-run = "pacman-server"
[lib]
name = "pacman_server"
path = "src/lib.rs"
[dependencies]
axum = { version = "0.8", features = ["macros"] }
tokio = { version = "1", features = ["full"] }
oauth2 = "5"
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sqlx = { version = "0.8", features = [
"runtime-tokio-rustls",
"postgres",
"chrono",
] }
chrono = { version = "0.4", features = ["serde", "clock"] }
figment = { version = "0.10", features = ["env"] }
dotenvy = "0.15"
dashmap = "6.1"
axum-cookie = "0.2"
async-trait = "0.1"
jsonwebtoken = { version = "9.3", default-features = false }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
tracing-futures = { version = "0.2.5", features = ["tokio"] }
time = { version = "0.3", features = ["macros", "formatting"] }
yansi = "1"
s3-tokio = { version = "0.39.6", default-features = false }
rustls = { version = "0.23", features = ["ring"] }
fast_image_resize = { version = "5.3", features = ["image"] }
image = { version = "0.25", features = ["png", "jpeg"] }
sha2 = "0.10"
mockall = "0.13.1"
# validator = { version = "0.16", features = ["derive"] }
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
http = "1"
hyper = { version = "1", features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["server", "tokio", "http1"] }
bytes = "1"
anyhow = "1"
axum-test = "18.1.0"
pretty_assertions = "1.4.1"
testcontainers = "0.25.0"

46
pacman-server/Dockerfile Normal file
View File

@@ -0,0 +1,46 @@
ARG RUST_VERSION=1.89.0
FROM lukemathwalker/cargo-chef:latest-rust-${RUST_VERSION} AS chef
WORKDIR /app
# -- Planner stage --
FROM chef AS planner
COPY . .
RUN cargo chef prepare --bin pacman-server --recipe-path recipe.json
# -- Builder stage --
FROM chef AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook --release --bin pacman-server --recipe-path recipe.json
# Copy the source code AFTER, so that dependencies are already cached
COPY . .
# Install build dependencies, then build the server
RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cargo build --package pacman-server --release --bin pacman-server
# -- Runtime stage --
FROM debian:bookworm-slim AS runtime
WORKDIR /app
COPY --from=builder /app/target/release/pacman-server /usr/local/bin/pacman-server
# Install runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
tzdata \
&& rm -rf /var/lib/apt/lists/*
ARG TZ=Etc/UTC
ENV TZ=${TZ}
# Optional build-time environment variable for embedding the Git commit SHA
ARG RAILWAY_GIT_COMMIT_SHA
ENV RAILWAY_GIT_COMMIT_SHA=${RAILWAY_GIT_COMMIT_SHA}
# Specify PORT at build-time or run-time, default to 3000
ARG PORT=3000
ENV PORT=${PORT}
EXPOSE ${PORT}
CMD ["sh", "-c", "exec /usr/local/bin/pacman-server"]

View File

@@ -6,16 +6,58 @@ This crate is a webserver that hosts an OAuth login and leaderboard API for the
## Features
- [ ] Axum Webserver
- [ ] Database
- [ ] OAuth
- [ ] Discord
- [ ] GitHub
- [ ] Google (?)
- [ ] Leaderboard API
- [x] Axum Webserver
- [x] Health Check
- [ ] Inbound Rate Limiting
- [ ] Outbound Rate Limiting
- [ ] Provider Circuit Breaker
- [x] Database
- [x] OAuth
- [x] Discord
- [x] GitHub
- [ ] Google
- [ ] Leaderboard
- [ ] Score Submission
- [ ] Score Listings
- [ ] Pagination
- [ ] Global / Daily
- [ ] Name Restrictions & Flagging
- [ ] Avatars
- [ ] 8-bit Conversion
- [ ] Storage?
- [ ] Common Server/Client Crate
- [ ] CI/CD & Tests
## Todo
1. Refresh Token Handling (Encryption, Expiration & Refresh Timings)
2. Refresh Token Background Job
3. S3 Storage for Avatars
4. Common Server/Client Crate, Basics
5. Crate-level Log Level Configuration
6. Span Tracing
7. Avatar Pixelization
8. Leaderboard API
9. React-based Frontend
10. Name Restrictions & Flagging
11. Simple CI/CD Checks & Tests
12. API Rate Limiting (outbound provider requests)
13. API Rate Limiting (inbound requests, by IP, by User)
14. Provider Circuit Breaker
15. Merge migration files
## Notes
### Image Handling
Avatar images are stored in S3 as follows:
- `avatars/{user_public_id}/{avatar_hash}.original.png`
- `avatars/{user_public_id}/{avatar_hash}.mini.png`
- The original image is converted to PNG and resized to a maximum of 512x512 pixels.
- Ideally, non-square images are fitted to a square.
- The mini image is converted to PNG and resized to a maximum of 16x16, 24x24, or 32x32 pixels. TBD.
- All images receive a Content-Type header of `image/png`.
Image processing is handled immediately asynchronously, allowing a valid presigned URL to be generated immediately.

View File

@@ -0,0 +1,15 @@
-- users table
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
username TEXT NOT NULL,
display_name TEXT NULL,
email TEXT NULL,
avatar_url TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (provider, provider_user_id)
);
CREATE INDEX IF NOT EXISTS idx_users_provider ON users (provider, provider_user_id);

View File

@@ -0,0 +1,18 @@
-- OAuth accounts linked to a single user
CREATE TABLE IF NOT EXISTS oauth_accounts (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
email TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (provider, provider_user_id)
);
-- Ensure we can look up by email efficiently
CREATE INDEX IF NOT EXISTS idx_oauth_accounts_email ON oauth_accounts (email);
-- Optional: ensure users email uniqueness if desired; keep NULLs allowed
ALTER TABLE users
ADD CONSTRAINT users_email_unique UNIQUE (email);

View File

@@ -0,0 +1,15 @@
-- Move provider-specific profile fields from users to oauth_accounts
-- Add provider profile fields to oauth_accounts
ALTER TABLE oauth_accounts
ADD COLUMN IF NOT EXISTS username TEXT,
ADD COLUMN IF NOT EXISTS display_name TEXT NULL,
ADD COLUMN IF NOT EXISTS avatar_url TEXT NULL;
-- Drop provider-specific fields from users (keep email as canonical)
ALTER TABLE users
DROP COLUMN IF EXISTS provider,
DROP COLUMN IF EXISTS provider_user_id,
DROP COLUMN IF EXISTS username,
DROP COLUMN IF EXISTS display_name,
DROP COLUMN IF EXISTS avatar_url;

View File

@@ -0,0 +1,9 @@
{
"$schema": "https://railway.com/railway.schema.json",
"deploy": {
"drainingSeconds": 10,
"healthcheckPath": "/health",
"healthcheckTimeout": 90,
"restartPolicyMaxRetries": 3
}
}

166
pacman-server/src/app.rs Normal file
View File

@@ -0,0 +1,166 @@
use axum::{routing::get, Router};
use axum_cookie::CookieLayer;
use dashmap::DashMap;
use jsonwebtoken::{DecodingKey, EncodingKey};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Notify, RwLock};
use tokio::task::JoinHandle;
use crate::data::pool::PgPool;
use crate::{auth::AuthRegistry, config::Config, image::ImageStorage, routes};
#[derive(Debug, Clone, Default)]
pub struct Health {
migrations: bool,
database: bool,
}
impl Health {
pub fn ok(&self) -> bool {
self.migrations && self.database
}
pub fn set_migrations(&mut self, done: bool) {
self.migrations = done;
}
pub fn set_database(&mut self, ok: bool) {
self.database = ok;
}
}
#[derive(Clone)]
pub struct AppState {
pub auth: Arc<AuthRegistry>,
pub sessions: Arc<DashMap<String, crate::auth::provider::AuthUser>>,
pub jwt_encoding_key: Arc<EncodingKey>,
pub jwt_decoding_key: Arc<DecodingKey>,
pub db: Arc<PgPool>,
pub health: Arc<RwLock<Health>>,
pub image_storage: Arc<ImageStorage>,
pub healthchecker_task: Arc<RwLock<Option<JoinHandle<()>>>>,
}
impl AppState {
pub async fn new(config: Config, auth: AuthRegistry, db: PgPool, shutdown_notify: Arc<Notify>) -> Self {
Self::new_with_database(config, auth, db, shutdown_notify, true).await
}
pub async fn new_with_database(
config: Config,
auth: AuthRegistry,
db: PgPool,
shutdown_notify: Arc<Notify>,
use_database: bool,
) -> Self {
let jwt_secret = config.jwt_secret.clone();
// Initialize image storage
let image_storage = match ImageStorage::from_config(&config) {
Ok(storage) => Arc::new(storage),
Err(e) => {
tracing::warn!(error = %e, "Failed to initialize image storage, avatar processing will be disabled");
// Create a dummy storage that will fail gracefully
Arc::new(ImageStorage::new(&config, "dummy").unwrap_or_else(|_| panic!("Failed to create dummy image storage")))
}
};
let app_state = Self {
auth: Arc::new(auth),
sessions: Arc::new(DashMap::new()),
jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())),
jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())),
db: Arc::new(db),
health: Arc::new(RwLock::new(Health::default())),
image_storage,
healthchecker_task: Arc::new(RwLock::new(None)),
};
// Start the healthchecker task only if database is being used
if use_database {
let health_state = app_state.health.clone();
let db_pool = app_state.db.clone();
let healthchecker_task = app_state.healthchecker_task.clone();
let task = tokio::spawn(async move {
tracing::trace!("Health checker task started");
let mut backoff: u32 = 1;
let mut next_sleep = Duration::from_secs(0);
loop {
tokio::select! {
_ = shutdown_notify.notified() => {
tracing::trace!("Health checker received shutdown notification; exiting");
break;
}
_ = tokio::time::sleep(next_sleep) => {
// Run health check
}
}
// Run the actual health check
let ok = sqlx::query("SELECT 1").execute(&*db_pool).await.is_ok();
{
let mut h = health_state.write().await;
h.set_database(ok);
}
if ok {
tracing::trace!(database_ok = true, "Health check succeeded; scheduling next run in 90s");
backoff = 1;
next_sleep = Duration::from_secs(90);
} else {
backoff = (backoff.saturating_mul(2)).min(60);
tracing::trace!(database_ok = false, backoff, "Health check failed; backing off");
next_sleep = Duration::from_secs(backoff as u64);
}
}
});
// Store the task handle
let mut task_handle = healthchecker_task.write().await;
*task_handle = Some(task);
}
app_state
}
/// Force an immediate health check (debug mode only)
pub async fn check_health(&self) -> bool {
let ok = sqlx::query("SELECT 1").execute(&*self.db).await.is_ok();
let mut h = self.health.write().await;
h.set_database(ok);
ok
}
}
/// Create the application router with all routes and middleware
pub fn create_router(app_state: AppState) -> Router {
Router::new()
.route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." }))
.route("/health", get(routes::health_handler))
.route("/auth/providers", get(routes::list_providers_handler))
.route("/auth/{provider}", get(routes::oauth_authorize_handler))
.route("/auth/{provider}/callback", get(routes::oauth_callback_handler))
.route("/logout", get(routes::logout_handler))
.route("/profile", get(routes::profile_handler))
.with_state(app_state)
.layer(CookieLayer::default())
.layer(axum::middleware::from_fn(inject_server_header))
}
/// Inject the server header into responses
async fn inject_server_header(
req: axum::http::Request<axum::body::Body>,
next: axum::middleware::Next,
) -> Result<axum::response::Response, axum::http::StatusCode> {
let mut res = next.run(req).await;
res.headers_mut().insert(
axum::http::header::SERVER,
axum::http::HeaderValue::from_static(SERVER_HEADER_VALUE),
);
Ok(res)
}
// Constant value for the Server header: "<crate>/<version>"
const SERVER_HEADER_VALUE: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

View File

@@ -0,0 +1,128 @@
use axum::{response::IntoResponse, response::Redirect};
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{trace, warn};
use crate::auth::{
pkce::PkceManager,
provider::{AuthUser, OAuthProvider},
};
use crate::errors::ErrorResponse;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscordUser {
pub id: String,
pub username: String,
pub global_name: Option<String>,
pub email: Option<String>,
pub avatar: Option<String>,
}
pub async fn fetch_discord_user(
http_client: &reqwest::Client,
access_token: &str,
) -> Result<DiscordUser, Box<dyn std::error::Error + Send + Sync>> {
let response = http_client
.get("https://discord.com/api/users/@me")
.header("Authorization", format!("Bearer {}", access_token))
.header("User-Agent", crate::config::USER_AGENT)
.send()
.await?;
if !response.status().is_success() {
warn!(status = %response.status(), endpoint = "/users/@me", "Discord API returned an error");
return Err(format!("Discord API error: {}", response.status()).into());
}
let user: DiscordUser = response.json().await?;
Ok(user)
}
pub struct DiscordProvider {
pub client: super::OAuthClient,
pub http: reqwest::Client,
pkce: PkceManager,
}
impl DiscordProvider {
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
Arc::new(Self {
client,
http,
pkce: PkceManager::default(),
})
}
fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String {
let ext = if avatar_hash.starts_with("a_") { "gif" } else { "png" };
format!("https://cdn.discordapp.com/avatars/{}/{}.{}", user_id, avatar_hash, ext)
}
}
#[async_trait::async_trait]
impl OAuthProvider for DiscordProvider {
fn id(&self) -> &'static str {
"discord"
}
fn label(&self) -> &'static str {
"Discord"
}
async fn authorize(&self) -> axum::response::Response {
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
let (authorize_url, csrf_state) = self
.client
.authorize_url(CsrfToken::new_random)
.set_pkce_challenge(pkce_challenge)
.add_scope(Scope::new("identify".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
self.pkce.store_verifier(csrf_state.secret(), verifier);
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
Redirect::to(authorize_url.as_str()).into_response()
}
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
let Some(verifier) = self.pkce.take_verifier(state) else {
warn!(%state, "Missing or expired PKCE verifier for state parameter");
return Err(ErrorResponse::bad_request(
"invalid_request",
Some("missing or expired pkce verifier for state".into()),
));
};
let token = self
.client
.exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
.request_async(&self.http)
.await
.map_err(|e| {
warn!(error = %e, %state, "Token exchange with Discord failed");
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
})?;
let user = fetch_discord_user(&self.http, token.access_token().secret())
.await
.map_err(|e| {
warn!(error = %e, "Failed to fetch Discord user profile");
ErrorResponse::bad_gateway("discord_api_error", Some(format!("failed to fetch user: {}", e)))
})?;
let avatar_url = match (&user.id, &user.avatar) {
(id, Some(hash)) => Some(Self::avatar_url_for(id, hash)),
_ => None,
};
Ok(AuthUser {
id: user.id,
username: user.username,
name: user.global_name,
email: user.email,
avatar_url,
})
}
}

View File

@@ -0,0 +1,160 @@
use axum::{response::IntoResponse, response::Redirect};
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{trace, warn};
use crate::{
auth::{
pkce::PkceManager,
provider::{AuthUser, OAuthProvider},
},
errors::ErrorResponse,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitHubUser {
pub id: u64,
pub login: String,
pub name: Option<String>,
pub email: Option<String>,
pub avatar_url: String,
pub html_url: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GitHubEmail {
pub email: String,
pub primary: bool,
pub verified: bool,
pub visibility: Option<String>,
}
/// Fetch user information from GitHub API
pub async fn fetch_github_user(
http_client: &reqwest::Client,
access_token: &str,
) -> Result<GitHubUser, Box<dyn std::error::Error + Send + Sync>> {
let response = http_client
.get("https://api.github.com/user")
.header("Authorization", format!("Bearer {}", access_token))
.header("Accept", "application/vnd.github.v3+json")
.header("User-Agent", crate::config::USER_AGENT)
.send()
.await?;
if !response.status().is_success() {
warn!(status = %response.status(), endpoint = "/user", "GitHub API returned an error");
return Err(format!("GitHub API error: {}", response.status()).into());
}
let user: GitHubUser = response.json().await?;
Ok(user)
}
pub struct GitHubProvider {
pub client: super::OAuthClient,
pub http: reqwest::Client,
pkce: PkceManager,
}
impl GitHubProvider {
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
Arc::new(Self {
client,
http,
pkce: PkceManager::default(),
})
}
}
#[async_trait::async_trait]
impl OAuthProvider for GitHubProvider {
fn id(&self) -> &'static str {
"github"
}
fn label(&self) -> &'static str {
"GitHub"
}
async fn authorize(&self) -> axum::response::Response {
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
let (authorize_url, csrf_state) = self
.client
.authorize_url(CsrfToken::new_random)
.set_pkce_challenge(pkce_challenge)
.add_scope(Scope::new("user:email".to_string()))
.add_scope(Scope::new("read:user".to_string()))
.url();
// store verifier keyed by the returned state
self.pkce.store_verifier(csrf_state.secret(), verifier);
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
Redirect::to(authorize_url.as_str()).into_response()
}
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
let Some(verifier) = self.pkce.take_verifier(state) else {
warn!(%state, "Missing or expired PKCE verifier for state parameter");
return Err(ErrorResponse::bad_request(
"invalid_request",
Some("missing or expired pkce verifier for state".into()),
));
};
let token = self
.client
.exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
.request_async(&self.http)
.await
.map_err(|e| {
warn!(error = %e, %state, "Token exchange with GitHub failed");
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
})?;
let user = fetch_github_user(&self.http, token.access_token().secret())
.await
.map_err(|e| {
warn!(error = %e, "Failed to fetch GitHub user profile");
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e)))
})?;
let _emails = fetch_github_emails(&self.http, token.access_token().secret())
.await
.map_err(|e| {
warn!(error = %e, "Failed to fetch GitHub user emails");
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e)))
})?;
Ok(AuthUser {
id: user.id.to_string(),
username: user.login,
name: user.name,
email: user.email,
avatar_url: Some(user.avatar_url),
})
}
}
impl GitHubProvider {}
/// Fetch user emails from GitHub API
pub async fn fetch_github_emails(
http_client: &reqwest::Client,
access_token: &str,
) -> Result<Vec<GitHubEmail>, Box<dyn std::error::Error + Send + Sync>> {
let response = http_client
.get("https://api.github.com/user/emails")
.header("Authorization", format!("Bearer {}", access_token))
.header("Accept", "application/vnd.github.v3+json")
.header("User-Agent", crate::config::USER_AGENT)
.send()
.await?;
if !response.status().is_success() {
return Err(format!("GitHub API error: {}", response.status()).into());
}
let emails: Vec<GitHubEmail> = response.json().await?;
Ok(emails)
}

View File

@@ -0,0 +1,64 @@
use std::collections::HashMap;
use std::sync::Arc;
use oauth2::{basic::BasicClient, EndpointNotSet, EndpointSet};
use crate::config::Config;
pub mod discord;
pub mod github;
pub mod pkce;
pub mod provider;
type OAuthClient =
BasicClient<oauth2::EndpointSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointSet>;
pub struct AuthRegistry {
providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>>,
}
impl AuthRegistry {
pub fn new(config: &Config) -> Result<Self, oauth2::url::ParseError> {
let http = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("HTTP client should build");
let github_client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet> =
BasicClient::new(oauth2::ClientId::new(config.github_client_id.clone()))
.set_client_secret(oauth2::ClientSecret::new(config.github_client_secret.clone()))
.set_auth_uri(oauth2::AuthUrl::new("https://github.com/login/oauth/authorize".to_string())?)
.set_token_uri(oauth2::TokenUrl::new(
"https://github.com/login/oauth/access_token".to_string(),
)?)
.set_redirect_uri(
oauth2::RedirectUrl::new(format!("{}/auth/github/callback", config.public_base_url))
.expect("Invalid redirect URI"),
);
let mut providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>> = HashMap::new();
providers.insert("github", github::GitHubProvider::new(github_client, http.clone()));
// Discord OAuth client
let discord_client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet> =
BasicClient::new(oauth2::ClientId::new(config.discord_client_id.clone()))
.set_client_secret(oauth2::ClientSecret::new(config.discord_client_secret.clone()))
.set_auth_uri(oauth2::AuthUrl::new("https://discord.com/api/oauth2/authorize".to_string())?)
.set_token_uri(oauth2::TokenUrl::new("https://discord.com/api/oauth2/token".to_string())?)
.set_redirect_uri(
oauth2::RedirectUrl::new(format!("{}/auth/discord/callback", config.public_base_url))
.expect("Invalid redirect URI"),
);
providers.insert("discord", discord::DiscordProvider::new(discord_client, http));
Ok(Self { providers })
}
pub fn get(&self, id: &str) -> Option<&Arc<dyn provider::OAuthProvider>> {
self.providers.get(id)
}
pub fn values(&self) -> impl Iterator<Item = &Arc<dyn provider::OAuthProvider>> {
self.providers.values()
}
}

View File

@@ -0,0 +1,84 @@
use dashmap::DashMap;
use oauth2::PkceCodeChallenge;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use tracing::{trace, warn};
#[derive(Debug, Clone)]
pub struct PkceRecord {
pub verifier: String,
pub created_at: Instant,
}
#[derive(Default)]
pub struct PkceManager {
pkce: DashMap<String, PkceRecord>,
last_purge_at_secs: AtomicU32,
pkce_additions: AtomicU32,
}
impl PkceManager {
pub fn generate_challenge(&self) -> (PkceCodeChallenge, String) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
trace!("PKCE challenge generated");
(pkce_challenge, pkce_verifier.secret().to_string())
}
pub fn store_verifier(&self, state: &str, verifier: String) {
self.pkce.insert(
state.to_string(),
PkceRecord {
verifier,
created_at: Instant::now(),
},
);
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
self.maybe_purge_stale_entries();
trace!(state = state, "Stored PKCE verifier for state");
}
pub fn take_verifier(&self, state: &str) -> Option<String> {
let Some(record) = self.pkce.remove(state).map(|e| e.1) else {
trace!(state = state, "PKCE verifier not found for state");
return None;
};
// Verify PKCE TTL
if Instant::now().duration_since(record.created_at) > Duration::from_secs(5 * 60) {
warn!(state = state, "PKCE verifier expired for state");
return None;
}
trace!(state = state, "PKCE verifier retrieved for state");
Some(record.verifier)
}
fn maybe_purge_stale_entries(&self) {
// Purge when at least 5 minutes passed or more than 128 additions occurred
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
const ADDITIONS_THRESHOLD: u32 = 128;
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(d) => d.as_secs() as u32,
Err(_) => return,
};
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
let additions = self.pkce_additions.load(Ordering::Relaxed);
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
return;
}
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
let now_inst = Instant::now();
for entry in self.pkce.iter() {
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
self.pkce.remove(entry.key());
}
}
// Reset counters after purge
self.pkce_additions.store(0, Ordering::Relaxed);
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
}
}

View File

@@ -0,0 +1,28 @@
use async_trait::async_trait;
use mockall::automock;
use serde::Serialize;
use crate::errors::ErrorResponse;
#[derive(Debug, Clone, Serialize)]
pub struct AuthUser {
pub id: String,
pub username: String,
pub name: Option<String>,
pub email: Option<String>,
pub avatar_url: Option<String>,
}
#[automock]
#[async_trait]
pub trait OAuthProvider: Send + Sync {
fn id(&self) -> &'static str;
fn label(&self) -> &'static str;
fn active(&self) -> bool {
true
}
async fn authorize(&self) -> axum::response::Response;
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse>;
}

View File

@@ -0,0 +1,79 @@
use figment::{providers::Env, value::UncasedStr, Figment};
use serde::{Deserialize, Deserializer};
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
// Database URL
pub database_url: String,
// Discord Credentials
#[serde(deserialize_with = "deserialize_string_from_any")]
pub discord_client_id: String,
pub discord_client_secret: String,
// GitHub Credentials
#[serde(deserialize_with = "deserialize_string_from_any")]
pub github_client_id: String,
pub github_client_secret: String,
// S3 Credentials
pub s3_access_key: String,
pub s3_secret_access_key: String,
pub s3_bucket_name: String,
pub s3_public_base_url: String,
// Server Details
#[serde(default = "default_port")]
pub port: u16,
#[serde(default = "default_host")]
pub host: std::net::IpAddr,
#[serde(default = "default_shutdown_timeout")]
pub shutdown_timeout_seconds: u32,
// Public base URL used for OAuth redirect URIs
pub public_base_url: String,
// JWT
pub jwt_secret: String,
}
// Standard User-Agent: name/version (+site)
pub const USER_AGENT: &str = concat!(
env!("CARGO_PKG_NAME"),
"/",
env!("CARGO_PKG_VERSION"),
" (+https://pacman.xevion.dev)"
);
fn default_host() -> std::net::IpAddr {
"0.0.0.0".parse().unwrap()
}
fn default_port() -> u16 {
3000
}
fn default_shutdown_timeout() -> u32 {
5
}
fn deserialize_string_from_any<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
use serde_json::Value;
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
Value::Number(n) => Ok(n.to_string()),
_ => Err(serde::de::Error::custom("Expected string or number")),
}
}
pub fn load_config() -> Config {
Figment::new()
.merge(Env::raw().map(|key| {
if key == UncasedStr::new("RAILWAY_DEPLOYMENT_DRAINING_SECONDS") {
"SHUTDOWN_TIMEOUT_SECONDS".into()
} else {
key.into()
}
}))
.extract()
.expect("Failed to load config")
}

View File

@@ -0,0 +1,2 @@
pub mod pool;
pub mod user;

View File

@@ -0,0 +1,21 @@
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use tracing::{info, warn};
pub type PgPool = Pool<Postgres>;
pub async fn create_pool(immediate: bool, database_url: &str, max_connections: u32) -> PgPool {
info!(immediate, "Connecting to PostgreSQL");
let options = PgPoolOptions::new().max_connections(max_connections);
if immediate {
options.connect(database_url).await.unwrap_or_else(|e| {
warn!(error = %e, "Failed to connect to PostgreSQL");
panic!("database connect failed: {}", e);
})
} else {
options
.connect_lazy(database_url)
.expect("Failed to create lazy database pool")
}
}

View File

@@ -0,0 +1,162 @@
use serde::Serialize;
use sqlx::FromRow;
#[derive(Debug, Clone, Serialize, FromRow)]
pub struct User {
pub id: i64,
pub email: Option<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Clone, Serialize, FromRow)]
pub struct OAuthAccount {
pub id: i64,
pub user_id: i64,
pub provider: String,
pub provider_user_id: String,
pub email: Option<String>,
pub username: Option<String>,
pub display_name: Option<String>,
pub avatar_url: Option<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
}
pub async fn find_user_by_email(pool: &sqlx::PgPool, email: &str) -> Result<Option<User>, sqlx::Error> {
sqlx::query_as::<_, User>(
r#"
SELECT id, email, created_at, updated_at
FROM users WHERE email = $1
"#,
)
.bind(email)
.fetch_optional(pool)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn link_oauth_account(
pool: &sqlx::PgPool,
user_id: i64,
provider: &str,
provider_user_id: &str,
email: Option<&str>,
username: Option<&str>,
display_name: Option<&str>,
avatar_url: Option<&str>,
) -> Result<OAuthAccount, sqlx::Error> {
sqlx::query_as::<_, OAuthAccount>(
r#"
INSERT INTO oauth_accounts (user_id, provider, provider_user_id, email, username, display_name, avatar_url)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (provider, provider_user_id)
DO UPDATE SET email = EXCLUDED.email, username = EXCLUDED.username, display_name = EXCLUDED.display_name, avatar_url = EXCLUDED.avatar_url, user_id = EXCLUDED.user_id, updated_at = NOW()
RETURNING id, user_id, provider, provider_user_id, email, username, display_name, avatar_url, created_at, updated_at
"#,
)
.bind(user_id)
.bind(provider)
.bind(provider_user_id)
.bind(email)
.bind(username)
.bind(display_name)
.bind(avatar_url)
.fetch_one(pool)
.await
}
pub async fn create_user(
pool: &sqlx::PgPool,
username: &str,
display_name: Option<&str>,
email: Option<&str>,
avatar_url: Option<&str>,
provider: &str,
provider_user_id: &str,
) -> Result<User, sqlx::Error> {
let user = sqlx::query_as::<_, User>(
r#"
INSERT INTO users (email)
VALUES ($1)
RETURNING id, email, created_at, updated_at
"#,
)
.bind(email)
.fetch_one(pool)
.await?;
// Create oauth link
let _ = link_oauth_account(
pool,
user.id,
provider,
provider_user_id,
email,
Some(username),
display_name,
avatar_url,
)
.await?;
Ok(user)
}
pub async fn get_oauth_account_count_for_user(pool: &sqlx::PgPool, user_id: i64) -> Result<i64, sqlx::Error> {
let rec: (i64,) = sqlx::query_as(
r#"
SELECT COUNT(*)::BIGINT AS count
FROM oauth_accounts
WHERE user_id = $1
"#,
)
.bind(user_id)
.fetch_one(pool)
.await?;
Ok(rec.0)
}
pub async fn find_user_by_provider_id(
pool: &sqlx::PgPool,
provider: &str,
provider_user_id: &str,
) -> Result<Option<User>, sqlx::Error> {
let rec = sqlx::query_as::<_, User>(
r#"
SELECT u.id, u.email, u.created_at, u.updated_at
FROM users u
JOIN oauth_accounts oa ON oa.user_id = u.id
WHERE oa.provider = $1 AND oa.provider_user_id = $2
"#,
)
.bind(provider)
.bind(provider_user_id)
.fetch_optional(pool)
.await?;
Ok(rec)
}
#[derive(Debug, Clone, Serialize, FromRow)]
pub struct ProviderPublic {
pub provider: String,
pub provider_user_id: String,
pub email: Option<String>,
pub username: Option<String>,
pub display_name: Option<String>,
pub avatar_url: Option<String>,
}
pub async fn list_user_providers(pool: &sqlx::PgPool, user_id: i64) -> Result<Vec<ProviderPublic>, sqlx::Error> {
let recs = sqlx::query_as::<_, ProviderPublic>(
r#"
SELECT provider, provider_user_id, email, username, display_name, avatar_url
FROM oauth_accounts
WHERE user_id = $1
ORDER BY provider
"#,
)
.bind(user_id)
.fetch_all(pool)
.await?;
Ok(recs)
}

View File

@@ -0,0 +1,55 @@
use axum::{http::StatusCode, response::IntoResponse, Json};
use serde::Serialize;
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
#[serde(skip_serializing)]
status_code: Option<StatusCode>,
pub error: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl ErrorResponse {
pub fn status_code(&self) -> StatusCode {
self.status_code.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
}
pub fn unauthorized(description: impl Into<String>) -> Self {
Self {
status_code: Some(StatusCode::UNAUTHORIZED),
error: "unauthorized".into(),
description: Some(description.into()),
}
}
pub fn bad_request(error: impl Into<String>, description: impl Into<Option<String>>) -> Self {
Self {
status_code: Some(StatusCode::BAD_REQUEST),
error: error.into(),
description: description.into(),
}
}
pub fn bad_gateway(error: impl Into<String>, description: impl Into<Option<String>>) -> Self {
Self {
status_code: Some(StatusCode::BAD_GATEWAY),
error: error.into(),
description: description.into(),
}
}
pub fn with_status(status: StatusCode, error: impl Into<String>, description: impl Into<Option<String>>) -> Self {
Self {
status_code: Some(status),
error: error.into(),
description: description.into(),
}
}
}
impl IntoResponse for ErrorResponse {
fn into_response(self) -> axum::response::Response {
(self.status_code(), Json(self)).into_response()
}
}

View File

@@ -0,0 +1,243 @@
//! Custom tracing formatter
use serde::Serialize;
use serde_json::{Map, Value};
use std::fmt;
use time::macros::format_description;
use time::{format_description::FormatItem, OffsetDateTime};
use tracing::field::{Field, Visit};
use tracing::{Event, Level, Subscriber};
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields, FormattedFields};
use tracing_subscriber::registry::LookupSpan;
use yansi::Paint;
// Cached format description for timestamps
const TIMESTAMP_FORMAT: &[FormatItem<'static>] = format_description!("[hour]:[minute]:[second].[subsecond digits:5]");
/// A custom formatter with enhanced timestamp formatting
///
/// Re-implementation of the Full formatter with improved timestamp display.
pub struct CustomPrettyFormatter;
impl<S, N> FormatEvent<S, N> for CustomPrettyFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(&self, ctx: &FmtContext<'_, S, N>, mut writer: Writer<'_>, event: &Event<'_>) -> fmt::Result {
let meta = event.metadata();
// 1) Timestamp (dimmed when ANSI)
let now = OffsetDateTime::now_utc();
let formatted_time = now.format(&TIMESTAMP_FORMAT).map_err(|e| {
eprintln!("Failed to format timestamp: {}", e);
fmt::Error
})?;
write_dimmed(&mut writer, formatted_time)?;
writer.write_char(' ')?;
// 2) Colored 5-char level like Full
write_colored_level(&mut writer, meta.level())?;
writer.write_char(' ')?;
// 3) Span scope chain (bold names, fields in braces, dimmed ':')
if let Some(scope) = ctx.event_scope() {
let mut saw_any = false;
for span in scope.from_root() {
write_bold(&mut writer, span.metadata().name())?;
saw_any = true;
write_dimmed(&mut writer, ":")?;
let ext = span.extensions();
if let Some(fields) = &ext.get::<FormattedFields<N>>() {
if !fields.fields.is_empty() {
write_bold(&mut writer, "{")?;
writer.write_str(fields.fields.as_str())?;
write_bold(&mut writer, "}")?;
}
}
write_dimmed(&mut writer, ":")?;
}
if saw_any {
writer.write_char(' ')?;
}
}
// 4) Target (dimmed), then a space
if writer.has_ansi_escapes() {
write!(writer, "{}: ", Paint::new(meta.target()).dim())?;
} else {
write!(writer, "{}: ", meta.target())?;
}
// 5) Event fields
ctx.format_fields(writer.by_ref(), event)?;
// 6) Newline
writeln!(writer)
}
}
/// A custom JSON formatter that flattens fields to root level
///
/// Outputs logs in the format:
/// { "message": "...", "level": "...", "customAttribute": "..." }
pub struct CustomJsonFormatter;
impl<S, N> FormatEvent<S, N> for CustomJsonFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(&self, ctx: &FmtContext<'_, S, N>, mut writer: Writer<'_>, event: &Event<'_>) -> fmt::Result {
let meta = event.metadata();
#[derive(Serialize)]
struct EventFields {
message: String,
level: String,
target: String,
#[serde(flatten)]
spans: Map<String, Value>,
#[serde(flatten)]
fields: Map<String, Value>,
}
let (message, fields, spans) = {
let mut message: Option<String> = None;
let mut fields: Map<String, Value> = Map::new();
let mut spans: Map<String, Value> = Map::new();
struct FieldVisitor<'a> {
message: &'a mut Option<String>,
fields: &'a mut Map<String, Value>,
}
impl Visit for FieldVisitor<'_> {
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
let key = field.name();
if key == "message" {
*self.message = Some(format!("{:?}", value));
} else {
self.fields.insert(key.to_string(), Value::String(format!("{:?}", value)));
}
}
fn record_str(&mut self, field: &Field, value: &str) {
let key = field.name();
if key == "message" {
*self.message = Some(value.to_string());
} else {
self.fields.insert(key.to_string(), Value::String(value.to_string()));
}
}
fn record_i64(&mut self, field: &Field, value: i64) {
let key = field.name();
if key != "message" {
self.fields
.insert(key.to_string(), Value::Number(serde_json::Number::from(value)));
}
}
fn record_u64(&mut self, field: &Field, value: u64) {
let key = field.name();
if key != "message" {
self.fields
.insert(key.to_string(), Value::Number(serde_json::Number::from(value)));
}
}
fn record_bool(&mut self, field: &Field, value: bool) {
let key = field.name();
if key != "message" {
self.fields.insert(key.to_string(), Value::Bool(value));
}
}
}
let mut visitor = FieldVisitor {
message: &mut message,
fields: &mut fields,
};
event.record(&mut visitor);
// Collect span information from the span hierarchy
if let Some(scope) = ctx.event_scope() {
for span in scope.from_root() {
let span_name = span.metadata().name().to_string();
let mut span_fields: Map<String, Value> = Map::new();
// Try to extract fields from FormattedFields
let ext = span.extensions();
if let Some(formatted_fields) = ext.get::<FormattedFields<N>>() {
// Try to parse as JSON first
if let Ok(json_fields) = serde_json::from_str::<Map<String, Value>>(formatted_fields.fields.as_str()) {
span_fields.extend(json_fields);
} else {
// If not valid JSON, treat the entire field string as a single field
span_fields.insert("raw".to_string(), Value::String(formatted_fields.fields.as_str().to_string()));
}
}
// Insert span as a nested object directly into the spans map
spans.insert(span_name, Value::Object(span_fields));
}
}
(message, fields, spans)
};
let json = EventFields {
message: message.unwrap_or_default(),
level: meta.level().to_string(),
target: meta.target().to_string(),
spans,
fields,
};
writeln!(
writer,
"{}",
serde_json::to_string(&json).unwrap_or_else(|_| "{}".to_string())
)
}
}
/// Write the verbosity level with the same coloring/alignment as the Full formatter.
fn write_colored_level(writer: &mut Writer<'_>, level: &Level) -> fmt::Result {
if writer.has_ansi_escapes() {
let paint = match *level {
Level::TRACE => Paint::new("TRACE").magenta(),
Level::DEBUG => Paint::new("DEBUG").blue(),
Level::INFO => Paint::new(" INFO").green(),
Level::WARN => Paint::new(" WARN").yellow(),
Level::ERROR => Paint::new("ERROR").red(),
};
write!(writer, "{}", paint)
} else {
// Right-pad to width 5 like Full's non-ANSI mode
match *level {
Level::TRACE => write!(writer, "{:>5}", "TRACE"),
Level::DEBUG => write!(writer, "{:>5}", "DEBUG"),
Level::INFO => write!(writer, "{:>5}", " INFO"),
Level::WARN => write!(writer, "{:>5}", " WARN"),
Level::ERROR => write!(writer, "{:>5}", "ERROR"),
}
}
}
fn write_dimmed(writer: &mut Writer<'_>, s: impl fmt::Display) -> fmt::Result {
if writer.has_ansi_escapes() {
write!(writer, "{}", Paint::new(s).dim())
} else {
write!(writer, "{}", s)
}
}
fn write_bold(writer: &mut Writer<'_>, s: impl fmt::Display) -> fmt::Result {
if writer.has_ansi_escapes() {
write!(writer, "{}", Paint::new(s).bold())
} else {
write!(writer, "{}", s)
}
}

183
pacman-server/src/image.rs Normal file
View File

@@ -0,0 +1,183 @@
use std::sync::Arc;
use image::codecs::png::PngEncoder;
use s3::Bucket;
use sha2::Digest;
use tracing::trace;
use crate::config::Config;
/// Minimal S3-backed image storage. This keeps things intentionally simple for now:
/// - construct from existing `Config`
/// - upload raw bytes under a key
/// - upload a local file by path (reads whole file into memory)
/// - generate a simple presigned GET URL
/// - process avatars with resizing and upload
///
/// Backed by `s3-tokio` (hyper 1 + rustls) and compatible with S3/R2/MinIO endpoints.
#[derive(Clone)]
pub struct ImageStorage {
bucket: Arc<s3::Bucket>,
public_base_url: String,
}
impl ImageStorage {
/// Create a new storage for a specific `bucket_name` using settings from `Config`.
///
/// This uses a custom region + endpoint so it works across AWS S3 and compatible services
/// such as Cloudflare R2 and MinIO.
pub fn new(config: &Config, bucket_name: impl Into<String>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let credentials = s3::creds::Credentials::new(
Some(&config.s3_access_key),
Some(&config.s3_secret_access_key),
None, // security token
None, // session token
None, // profile
)?;
let bucket = Bucket::new(
&bucket_name.into(),
s3::Region::R2 {
account_id: "f188bf93079278e7bbc58de9b3d80693".to_string(),
},
credentials,
)?
.with_path_style();
Ok(Self {
bucket: Arc::new(bucket),
public_base_url: config.s3_public_base_url.clone(),
})
}
/// Upload a byte slice to `key` with optional content type.
///
/// Returns the ETag (if present) from the server response.
pub async fn upload_bytes(
&self,
key: &str,
bytes: impl AsRef<[u8]>,
content_type: Option<&str>,
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
let data = bytes.as_ref();
let content_type = content_type.unwrap_or("application/octet-stream");
// Prefer the content-type variant for correct metadata
let status = {
let response = self.bucket.put_object_with_content_type(key, data, content_type).await?;
response.status_code()
};
if (200..300).contains(&status) {
// s3-tokio returns headers separately; attempt to pull the ETag if available
// Note: the current API returns (status, headers) where headers is `http::HeaderMap`.
// Some providers omit ETag on PUT; we handle that by returning `None`.
Ok(None)
} else {
Err(format!("upload failed with status {}", status).into())
}
}
/// Generate a simple presigned GET URL valid for `expires_in_seconds`.
#[allow(dead_code)]
pub fn presign_get(&self, key: &str, expires_in_seconds: u32) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let url = self.bucket.presign_get(key, expires_in_seconds, None)?;
Ok(url)
}
/// Process and upload an avatar from a URL.
///
/// Downloads the image, resizes it to 512x512 (original) and 32x32 (mini),
/// then uploads both versions to S3. Returns the public URLs for both images.
pub async fn process_avatar(
&self,
user_public_id: &str,
avatar_url: &str,
) -> Result<AvatarUrls, Box<dyn std::error::Error + Send + Sync>> {
// Download the avatar image
let response = reqwest::get(avatar_url).await?;
if !response.status().is_success() {
return Err(format!("Failed to download avatar: {}", response.status()).into());
}
let image_bytes = response.bytes().await?;
trace!(bytes = image_bytes.len(), "Downloaded avatar");
// Decode the image
let img = image::load_from_memory(&image_bytes)?;
let img_rgba = img.to_rgba8();
// Generate a simple hash for the avatar (using the URL for now)
let avatar_hash = format!("{:x}", sha2::Sha256::digest(avatar_url.as_bytes()));
trace!(
width = img_rgba.width(),
height = img_rgba.height(),
hash = avatar_hash,
"Avatar image decoded"
);
// Process original (512x512 max, square)
let original_key = format!("avatars/{}/{}.original.png", user_public_id, avatar_hash);
let original_png = self.resize_to_square_png(&img_rgba, 512)?;
self.upload_bytes(&original_key, &original_png, Some("image/png")).await?;
trace!(key = original_key, "Uploaded original avatar");
// Process mini (32x32)
let mini_key = format!("avatars/{}/{}.mini.png", user_public_id, avatar_hash);
let mini_png = self.resize_to_square_png(&img_rgba, 32)?;
self.upload_bytes(&mini_key, &mini_png, Some("image/png")).await?;
trace!(key = mini_key, "Uploaded mini avatar");
Ok(AvatarUrls {
original_url: format!("{}/{}", self.public_base_url, original_key),
mini_url: format!("{}/{}", self.public_base_url, mini_key),
})
}
/// Resize an RGBA image to a square of the specified size, maintaining aspect ratio.
fn resize_to_square_png(
&self,
img: &image::RgbaImage,
target_size: u32,
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
let (width, height) = img.dimensions();
// Calculate dimensions for square crop (center crop)
let size = width.min(height);
let start_x = (width - size) / 2;
let start_y = (height - size) / 2;
// Crop to square
let cropped = image::imageops::crop_imm(img, start_x, start_y, size, size).to_image();
// Resize to target size
let resized = image::imageops::resize(&cropped, target_size, target_size, image::imageops::FilterType::Lanczos3);
// Encode as PNG
let mut bytes: Vec<u8> = Vec::new();
let cursor = std::io::Cursor::new(&mut bytes);
// Write the resized image to the cursor
resized.write_with_encoder(PngEncoder::new(cursor))?;
Ok(bytes)
}
}
/// URLs for processed avatar images
#[derive(Debug, Clone)]
pub struct AvatarUrls {
pub original_url: String,
pub mini_url: String,
}
impl ImageStorage {
/// Create a new storage using the default bucket from `Config`.
pub fn from_config(config: &Config) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
Self::new(config, &config.s3_bucket_name)
}
}
// References:
// - Example (R2): https://github.com/FemLolStudio/s3-tokio/blob/master/examples/r2-tokio.rs
// - Crate docs: https://lib.rs/crates/s3-tokio

10
pacman-server/src/lib.rs Normal file
View File

@@ -0,0 +1,10 @@
pub mod app;
pub mod auth;
pub mod config;
pub mod data;
pub mod errors;
pub mod formatter;
pub mod image;
pub mod logging;
pub mod routes;
pub mod session;

View File

@@ -0,0 +1,36 @@
use tracing_subscriber::fmt::format::JsonFields;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
use crate::config::Config;
use crate::formatter;
/// Configure and initialize logging for the application
pub fn setup_logging(_config: &Config) {
// Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME"))));
// Default to pretty for local dev; switchable later if we add CLI
let use_pretty = cfg!(debug_assertions);
let subscriber: Box<dyn tracing::Subscriber + Send + Sync> = if use_pretty {
Box::new(
FmtSubscriber::builder()
.with_target(true)
.event_format(formatter::CustomPrettyFormatter)
.with_env_filter(filter)
.finish(),
)
} else {
Box::new(
FmtSubscriber::builder()
.with_target(true)
.event_format(formatter::CustomJsonFormatter)
.fmt_fields(JsonFields::new())
.with_env_filter(filter)
.finish(),
)
};
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
}

View File

@@ -1,3 +1,133 @@
fn main() {
println!("Hello, world!");
use crate::{
app::{create_router, AppState},
auth::AuthRegistry,
config::Config,
};
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, trace, warn};
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{watch, Notify};
mod app;
mod auth;
mod config;
mod data;
mod errors;
mod formatter;
mod image;
mod logging;
mod routes;
mod session;
#[tokio::main]
async fn main() {
// Load environment variables
#[cfg(debug_assertions)]
dotenvy::from_path(std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(".env")).ok();
#[cfg(not(debug_assertions))]
dotenvy::dotenv().ok();
// Load configuration
let config: Config = config::load_config();
// Initialize tracing subscriber
logging::setup_logging(&config);
trace!(host = %config.host, port = config.port, shutdown_timeout_seconds = config.shutdown_timeout_seconds, "Loaded server configuration");
let addr = std::net::SocketAddr::new(config.host, config.port);
let shutdown_timeout = std::time::Duration::from_secs(config.shutdown_timeout_seconds as u64);
let auth = AuthRegistry::new(&config).expect("auth initializer");
let db = data::pool::create_pool(true, &config.database_url, 10).await;
// Run database migrations at startup
if let Err(e) = sqlx::migrate!("./migrations").run(&db).await {
panic!("failed to run database migrations: {}", e);
}
// Create the shutdown notification before creating AppState
let notify = Arc::new(Notify::new());
let app_state = AppState::new(config, auth, db, notify.clone()).await;
{
// migrations succeeded
let mut h = app_state.health.write().await;
h.set_migrations(true);
}
let app = create_router(app_state);
info!(%addr, "Starting HTTP server bind");
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
info!(%addr, "HTTP server listening");
// coordinated graceful shutdown with timeout
let (tx_signal, rx_signal) = watch::channel::<Option<Instant>>(None);
{
let notify = notify.clone();
let tx = tx_signal.clone();
tokio::spawn(async move {
let signaled_at = shutdown_signal().await;
let _ = tx.send(Some(signaled_at));
notify.notify_waiters();
});
}
let mut rx_for_timeout = rx_signal.clone();
let timeout_task = async move {
// wait until first signal observed
while rx_for_timeout.borrow().is_none() {
if rx_for_timeout.changed().await.is_err() {
return; // channel closed
}
}
tokio::time::sleep(shutdown_timeout).await;
warn!(timeout = ?shutdown_timeout, "Shutdown timeout elapsed; forcing exit");
std::process::exit(1);
};
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
notify.notified().await;
});
tokio::select! {
res = server => {
// server finished; if we had a signal, print remaining time
let now = Instant::now();
if let Some(signaled_at) = *rx_signal.borrow() {
let elapsed = now.duration_since(signaled_at);
if elapsed < shutdown_timeout {
let remaining = format!("{:.2?}", shutdown_timeout - elapsed);
info!(remaining = remaining, "Graceful shutdown complete");
}
}
res.unwrap();
}
_ = timeout_task => {}
}
}
async fn shutdown_signal() -> Instant {
let ctrl_c = async {
tokio::signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
warn!(signal = "ctrl_c", "Received Ctrl+C; shutting down");
};
#[cfg(unix)]
let sigterm = async {
let mut term_stream = signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
term_stream.recv().await;
warn!(signal = "sigterm", "Received SIGTERM; shutting down");
};
#[cfg(not(unix))]
let sigterm = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => { Instant::now() }
_ = sigterm => { Instant::now() }
}
}

388
pacman-server/src/routes.rs Normal file
View File

@@ -0,0 +1,388 @@
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Redirect},
};
use axum_cookie::CookieManager;
use serde::Serialize;
use tracing::{debug, info, instrument, span, trace, warn};
use crate::data::user as user_repo;
use crate::{app::AppState, errors::ErrorResponse, session};
#[derive(Debug, serde::Deserialize)]
pub struct AuthQuery {
pub code: Option<String>,
pub state: Option<String>,
pub error: Option<String>,
pub error_description: Option<String>,
}
#[derive(Debug, serde::Deserialize)]
pub struct AuthorizeQuery {
pub link: Option<bool>,
}
#[instrument(skip_all, fields(provider = %provider))]
pub async fn oauth_authorize_handler(
State(app_state): State<AppState>,
Path(provider): Path<String>,
Query(aq): Query<AuthorizeQuery>,
cookie: CookieManager,
) -> axum::response::Response {
let Some(prov) = app_state.auth.get(&provider) else {
warn!(%provider, "Unknown OAuth provider");
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
};
trace!("Starting OAuth authorization");
// Persist link intent using a short-lived cookie; callbacks won't carry our query params.
if aq.link == Some(true) {
cookie.add(
axum_cookie::cookie::Cookie::builder("link", "1")
.http_only(true)
.same_site(axum_cookie::prelude::SameSite::Lax)
.path("/")
.max_age(std::time::Duration::from_secs(120))
.build(),
);
}
let resp = prov.authorize().await;
trace!("Redirecting to provider authorization page");
resp
}
pub async fn oauth_callback_handler(
State(app_state): State<AppState>,
Path(provider): Path<String>,
Query(params): Query<AuthQuery>,
cookie: CookieManager,
) -> axum::response::Response {
// Validate provider
let Some(prov) = app_state.auth.get(&provider) else {
warn!(%provider, "Unknown OAuth provider");
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
};
// Process callback-returned errors from provider
if let Some(error) = params.error {
warn!(%provider, error = %error, desc = ?params.error_description, "OAuth callback returned an error");
return ErrorResponse::bad_request(error, params.error_description).into_response();
}
// Acquire required parameters
let Some(code) = params.code.as_deref() else {
return ErrorResponse::bad_request("invalid_request", Some("missing code".into())).into_response();
};
let Some(state) = params.state.as_deref() else {
return ErrorResponse::bad_request("invalid_request", Some("missing state".into())).into_response();
};
span!(tracing::Level::DEBUG, "oauth_callback_handler", provider = %provider, code = %code, state = %state);
// Handle callback from provider
let user = match prov.handle_callback(code, state).await {
Ok(u) => u,
Err(e) => {
warn!(%provider, "OAuth callback handling failed");
return e.into_response();
}
};
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
let link_cookie = cookie.get("link").map(|c| c.value().to_string());
if link_cookie.is_some() {
cookie.remove("link");
}
let email = user.email.as_deref();
// Determine linking intent with a valid session
let is_link = if link_cookie.as_deref() == Some("1") {
debug!("Link intent present");
match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) {
Some(c) => {
// Perform linking with current session user
let (cur_prov, cur_id) = c.sub.split_once(':').unwrap_or(("", ""));
let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await {
Ok(Some(u)) => u,
Ok(None) => {
warn!("Current session user not found; proceeding as normal sign-in");
return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into()))
.into_response();
}
Err(_) => {
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
.into_response();
}
};
if let Err(e) = user_repo::link_oauth_account(
&app_state.db,
current_user.id,
&provider,
&user.id,
email,
Some(&user.username),
user.name.as_deref(),
user.avatar_url.as_deref(),
)
.await
{
warn!(error = %e, %provider, "Failed to link OAuth account");
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
}
return (StatusCode::FOUND, Redirect::to("/profile")).into_response();
}
None => {
warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in");
false
}
}
} else {
false
};
if is_link {
unreachable!(); // handled via early return above
} else {
// Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow.
if let Some(e) = email {
if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await {
// Only block if the user already has at least one linked provider.
// NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive,
// this may lock them out until the provider is reactivated or a manual admin link is performed.
match user_repo::get_oauth_account_count_for_user(&app_state.db, existing.id).await {
Ok(count) if count > 0 => {
// Check if the "new" provider is already linked to the user
match user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await {
Ok(Some(_)) => {
debug!(
%provider,
%existing.id,
"Provider already linked to user, signing in normally");
}
Ok(None) => {
debug!(
%provider,
%existing.id,
"Provider not linked to user, failing"
);
return ErrorResponse::bad_request(
"account_exists",
Some(format!(
"An account already exists for {}. Sign in with your existing provider, then visit /auth/{}?link=true to add this provider.",
e, provider
)),
)
.into_response();
}
Err(e) => {
warn!(error = %e, %provider, "Failed to find user by provider ID");
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
.into_response();
}
}
}
Ok(_) => {
// No providers linked yet: safe to associate this provider
if let Err(e) = user_repo::link_oauth_account(
&app_state.db,
existing.id,
&provider,
&user.id,
email,
Some(&user.username),
user.name.as_deref(),
user.avatar_url.as_deref(),
)
.await
{
warn!(error = %e, %provider, "Failed to link OAuth account to existing user with no providers");
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
.into_response();
}
}
Err(e) => {
warn!(error = %e, "Failed to count oauth accounts for user");
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
.into_response();
}
}
} else {
// Create new user with email
match user_repo::create_user(
&app_state.db,
&user.username,
user.name.as_deref(),
email,
user.avatar_url.as_deref(),
&provider,
&user.id,
)
.await
{
Ok(u) => u,
Err(e) => {
warn!(error = %e, %provider, "Failed to create user");
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
.into_response();
}
};
}
} else {
// No email available: disallow sign-in for safety
return ErrorResponse::bad_request(
"invalid_request",
Some("account has no email; sign in with a different provider".into()),
)
.into_response();
}
};
// Create session token
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
session::set_session_cookie(&cookie, &session_token);
info!(%provider, "Signed in successfully");
// Process avatar asynchronously (don't block the response)
if let Some(avatar_url) = user.avatar_url.as_deref() {
let image_storage = app_state.image_storage.clone();
let user_public_id = user.id.clone();
let avatar_url = avatar_url.to_string();
debug!(%user_public_id, %avatar_url, "Processing avatar");
tokio::spawn(async move {
match image_storage.process_avatar(&user_public_id, &avatar_url).await {
Ok(avatar_urls) => {
info!(
user_id = %user_public_id,
original_url = %avatar_urls.original_url,
mini_url = %avatar_urls.mini_url,
"Avatar processed successfully"
);
}
Err(e) => {
warn!(
user_id = %user_public_id,
avatar_url = %avatar_url,
error = %e,
"Failed to process avatar"
);
}
}
});
}
(StatusCode::FOUND, Redirect::to("/profile")).into_response()
}
pub async fn profile_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
let Some(token_str) = session::get_session_token(&cookie) else {
debug!("Missing session cookie");
return ErrorResponse::unauthorized("missing session cookie").into_response();
};
let Some(claims) = session::decode_jwt(&token_str, &app_state.jwt_decoding_key) else {
debug!("Invalid session token");
return ErrorResponse::unauthorized("invalid session token").into_response();
};
// sub format: provider:provider_user_id
let (prov, prov_user_id) = match claims.sub.split_once(':') {
Some((p, id)) => (p, id),
None => {
debug!("Malformed session token subject");
return ErrorResponse::unauthorized("invalid session token").into_response();
}
};
match user_repo::find_user_by_provider_id(&app_state.db, prov, prov_user_id).await {
Ok(Some(db_user)) => {
// Include linked providers in the profile payload
match user_repo::list_user_providers(&app_state.db, db_user.id).await {
Ok(providers) => {
#[derive(Serialize)]
struct ProfilePayload<T> {
id: i64,
email: Option<String>,
providers: Vec<T>,
created_at: chrono::DateTime<chrono::Utc>,
updated_at: chrono::DateTime<chrono::Utc>,
}
let body = ProfilePayload {
id: db_user.id,
email: db_user.email.clone(),
providers,
created_at: db_user.created_at,
updated_at: db_user.updated_at,
};
axum::Json(body).into_response()
}
Err(e) => {
warn!(error = %e, "Failed to list user providers");
ErrorResponse::with_status(
StatusCode::INTERNAL_SERVER_ERROR,
"database_error",
Some("could not fetch providers".into()),
)
.into_response()
}
}
}
Ok(None) => {
debug!("User not found for session");
ErrorResponse::unauthorized("session not found").into_response()
}
Err(e) => {
warn!(error = %e, "Failed to fetch user for session");
ErrorResponse::with_status(
StatusCode::INTERNAL_SERVER_ERROR,
"database_error",
Some("could not fetch user".into()),
)
.into_response()
}
}
}
pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
if let Some(token_str) = session::get_session_token(&cookie) {
// Remove from in-memory sessions if present
app_state.sessions.remove(&token_str);
}
session::clear_session_cookie(&cookie);
info!("Signed out successfully");
(StatusCode::FOUND, Redirect::to("/")).into_response()
}
#[derive(Serialize)]
struct ProviderInfo {
id: &'static str,
name: &'static str,
active: bool,
}
pub async fn list_providers_handler(State(app_state): State<AppState>) -> axum::response::Response {
let providers: Vec<ProviderInfo> = app_state
.auth
.values()
.map(|provider| ProviderInfo {
id: provider.id(),
name: provider.label(),
active: provider.active(),
})
.collect();
axum::Json(providers).into_response()
}
pub async fn health_handler(
State(app_state): State<AppState>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> axum::response::Response {
// Force health check in debug mode
#[cfg(debug_assertions)]
if params.contains_key("force") {
app_state.check_health().await;
}
let ok = app_state.health.read().await.ok();
let status = if ok { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE };
let body = serde_json::json!({ "ok": ok });
(status, axum::Json(body)).into_response()
}

View File

@@ -0,0 +1,65 @@
use std::time::{SystemTime, UNIX_EPOCH};
use axum_cookie::{cookie::Cookie, prelude::SameSite, CookieManager};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use crate::auth::provider::AuthUser;
use tracing::{trace, warn};
pub const SESSION_COOKIE_NAME: &str = "session";
pub const JWT_TTL_SECS: u64 = 60 * 60; // 1 hour
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Claims {
pub sub: String, // format: "{provider}:{provider_user_id}"
pub name: Option<String>,
pub iat: usize,
pub exp: usize,
}
pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time went backwards")
.as_secs() as usize;
let claims = Claims {
sub: format!("{}:{}", provider, user.id),
name: user.name.clone(),
iat: now,
exp: now + JWT_TTL_SECS as usize,
};
let token = encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign");
trace!(sub = %claims.sub, exp = claims.exp, "Created session JWT");
token
}
pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option<Claims> {
let mut validation = Validation::new(Algorithm::HS256);
validation.leeway = 30;
match decode::<Claims>(token, decoding_key, &validation) {
Ok(data) => Some(data.claims),
Err(e) => {
warn!(error = %e, "Session JWT verification failed");
None
}
}
}
pub fn set_session_cookie(cookie: &CookieManager, token: &str) {
cookie.add(
Cookie::builder(SESSION_COOKIE_NAME, token.to_string())
.http_only(true)
.secure(!cfg!(debug_assertions))
.path("/")
.same_site(SameSite::Lax)
.build(),
);
}
pub fn clear_session_cookie(cookie: &CookieManager) {
cookie.remove(SESSION_COOKIE_NAME);
}
pub fn get_session_token(cookie: &CookieManager) -> Option<String> {
cookie.get(SESSION_COOKIE_NAME).map(|c| c.value().to_string())
}

View File

@@ -0,0 +1,137 @@
use axum::Router;
use pacman_server::{
app::{create_router, AppState},
auth::AuthRegistry,
config::Config,
};
use std::sync::Arc;
use testcontainers::{
core::{IntoContainerPort, WaitFor},
runners::AsyncRunner,
ContainerAsync, GenericImage, ImageExt,
};
use tokio::sync::Notify;
/// Test configuration for integration tests
pub struct TestConfig {
pub database_url: Option<String>,
pub container: Option<ContainerAsync<GenericImage>>,
pub config: Config,
}
impl TestConfig {
/// Create a test configuration with a test database
pub async fn new() -> Self {
Self::new_with_database(true).await
}
/// Create a test configuration with optional database setup
pub async fn new_with_database(use_database: bool) -> Self {
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install default crypto provider");
let (database_url, container) = if use_database {
let (url, container) = setup_test_database("testdb", "testuser", "testpass").await;
(Some(url), Some(container))
} else {
(None, None)
};
let config = Config {
database_url: database_url
.clone()
.unwrap_or_else(|| "postgresql://dummy:dummy@localhost:5432/dummy?sslmode=disable".to_string()),
discord_client_id: "test_discord_client_id".to_string(),
discord_client_secret: "test_discord_client_secret".to_string(),
github_client_id: "test_github_client_id".to_string(),
github_client_secret: "test_github_client_secret".to_string(),
s3_access_key: "test_s3_access_key".to_string(),
s3_secret_access_key: "test_s3_secret_access_key".to_string(),
s3_bucket_name: "test_bucket".to_string(),
s3_public_base_url: "https://test.example.com".to_string(),
port: 0, // Will be set by test server
host: "127.0.0.1".parse().unwrap(),
shutdown_timeout_seconds: 5,
public_base_url: "http://localhost:3000".to_string(),
jwt_secret: "test_jwt_secret_key_for_testing_only".to_string(),
};
Self {
database_url,
container,
config,
}
}
}
/// Set up a test PostgreSQL database using testcontainers
async fn setup_test_database(db: &str, user: &str, password: &str) -> (String, ContainerAsync<GenericImage>) {
let container = GenericImage::new("postgres", "15")
.with_exposed_port(5432.tcp())
.with_wait_for(WaitFor::message_on_stderr("database system is ready to accept connections"))
.with_env_var("POSTGRES_DB", db)
.with_env_var("POSTGRES_USER", user)
.with_env_var("POSTGRES_PASSWORD", password)
.start()
.await
.unwrap();
let host = container.get_host().await.unwrap();
let port = container.get_host_port_ipv4(5432).await.unwrap();
(
format!("postgresql://{user}:{password}@{host}:{port}/{db}?sslmode=disable"),
container,
)
}
/// Create a test app state with database and auth registry
pub async fn create_test_app_state(test_config: &TestConfig) -> AppState {
create_test_app_state_with_database(test_config, true).await
}
/// Create a test app state with optional database setup
pub async fn create_test_app_state_with_database(test_config: &TestConfig, use_database: bool) -> AppState {
let db = if use_database {
// Create database pool
let db_url = test_config
.database_url
.as_ref()
.expect("Database URL required when use_database is true");
let db = pacman_server::data::pool::create_pool(use_database, db_url, 5).await;
// Run migrations
sqlx::migrate!("./migrations")
.run(&db)
.await
.expect("Failed to run database migrations");
db
} else {
// Create a dummy database pool that will fail gracefully
let dummy_url = "postgresql://dummy:dummy@localhost:5432/dummy?sslmode=disable";
pacman_server::data::pool::create_pool(false, dummy_url, 1).await
};
// Create auth registry
let auth = AuthRegistry::new(&test_config.config).expect("Failed to create auth registry");
// Create app state
let notify = Arc::new(Notify::new());
let app_state = AppState::new_with_database(test_config.config.clone(), auth, db, notify, use_database).await;
// Set health status based on database usage
{
let mut health = app_state.health.write().await;
health.set_migrations(use_database);
health.set_database(use_database);
}
app_state
}
/// Create a test router with the given app state
pub fn create_test_router(app_state: AppState) -> Router {
create_router(app_state)
}

View File

@@ -0,0 +1,241 @@
use axum_test::TestServer;
use mockall::predicate::*;
use pretty_assertions::assert_eq;
mod common;
use common::{create_test_app_state, create_test_app_state_with_database, create_test_router, TestConfig};
/// Setup function with optional database
async fn setup_test_server(use_database: bool) -> TestServer {
let test_config = TestConfig::new_with_database(use_database).await;
let app_state = create_test_app_state_with_database(&test_config, use_database).await;
let router = create_test_router(app_state);
TestServer::new(router).unwrap()
}
/// Test basic endpoints functionality
#[tokio::test]
async fn test_basic_endpoints() {
let server = setup_test_server(false).await;
// Test root endpoint
let response = server.get("/").await;
assert_eq!(response.status_code(), 200);
}
/// Test health endpoint functionality with real database connectivity
#[tokio::test]
async fn test_health_endpoint() {
let test_config = TestConfig::new().await;
let app_state = create_test_app_state(&test_config).await;
let router = create_test_router(app_state.clone());
let server = TestServer::new(router).unwrap();
// First, verify health endpoint works when database is healthy
let response = server.get("/health").await;
assert_eq!(response.status_code(), 200);
let health_json: serde_json::Value = response.json();
assert_eq!(health_json["ok"], true);
// Now kill the database container to simulate database failure
drop(test_config.container);
// Now verify health endpoint reports bad health
let response = server.get("/health?force").await;
assert_eq!(response.status_code(), 503); // SERVICE_UNAVAILABLE
let health_json: serde_json::Value = response.json();
assert_eq!(health_json["ok"], false);
}
/// Test OAuth provider listing and configuration
#[tokio::test]
async fn test_oauth_provider_configuration() {
let server = setup_test_server(false).await;
// Test providers list endpoint
let response = server.get("/auth/providers").await;
assert_eq!(response.status_code(), 200);
let providers: Vec<serde_json::Value> = response.json();
assert_eq!(providers.len(), 2); // Should have GitHub and Discord providers
// Verify provider structure
let provider_ids: Vec<&str> = providers.iter().map(|p| p["id"].as_str().unwrap()).collect();
assert!(provider_ids.contains(&"github"));
assert!(provider_ids.contains(&"discord"));
// Verify provider details
for provider in providers {
let id = provider["id"].as_str().unwrap();
let name = provider["name"].as_str().unwrap();
let active = provider["active"].as_bool().unwrap();
assert!(active, "Provider {} should be active", id);
match id {
"github" => assert_eq!(name, "GitHub"),
"discord" => assert_eq!(name, "Discord"),
_ => panic!("Unknown provider: {}", id),
}
}
}
/// Test OAuth authorization flows
#[tokio::test]
async fn test_oauth_authorization_flows() {
let server = setup_test_server(false).await;
// Test OAuth authorize endpoint (should redirect)
let response = server.get("/auth/github").await;
assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth
// Test OAuth authorize endpoint for Discord
let response = server.get("/auth/discord").await;
assert_eq!(response.status_code(), 303); // Redirect to Discord OAuth
// Test unknown provider
let response = server.get("/auth/unknown").await;
assert_eq!(response.status_code(), 400); // Bad request for unknown provider
}
/// Test OAuth callback handling
#[tokio::test]
async fn test_oauth_callback_handling() {
let server = setup_test_server(false).await;
// Test OAuth callback with missing parameters (should fail gracefully)
let response = server.get("/auth/github/callback").await;
assert_eq!(response.status_code(), 400); // Bad request for missing code/state
}
/// Test session management endpoints
#[tokio::test]
async fn test_session_management() {
let server = setup_test_server(false).await;
// Test logout endpoint (should redirect)
let response = server.get("/logout").await;
assert_eq!(response.status_code(), 302); // Redirect to home
// Test profile endpoint without session (should be unauthorized)
let response = server.get("/profile").await;
assert_eq!(response.status_code(), 401); // Unauthorized without session
}
/// Test that verifies database operations work correctly
#[tokio::test]
async fn test_database_operations() {
let server = setup_test_server(true).await;
// Act: Test health endpoint to verify database connectivity
let response = server.get("/health").await;
// Assert: Health should be OK, indicating database is connected and migrations ran
assert_eq!(response.status_code(), 200);
let health_json: serde_json::Value = response.json();
assert_eq!(health_json["ok"], true);
}
/// Test OAuth authorization flow
#[tokio::test]
async fn test_oauth_authorization_flow() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth authorize handler redirects to the provider's authorization page for valid providers
// TODO: Test that the OAuth authorize handler returns an error for unknown providers
// TODO: Test that the OAuth authorize handler sets a link cookie when the link parameter is true
}
/// Test OAuth callback validation
#[tokio::test]
async fn test_oauth_callback_validation() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth callback handler validates the provider exists before processing
// TODO: Test that the OAuth callback handler returns an error when the provider returns an OAuth error
// TODO: Test that the OAuth callback handler returns an error when the authorization code is missing
// TODO: Test that the OAuth callback handler returns an error when the state parameter is missing
}
/// Test OAuth callback processing
#[tokio::test]
async fn test_oauth_callback_processing() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth callback handler exchanges the authorization code for user information successfully
// TODO: Test that the OAuth callback handler handles provider callback errors gracefully
// TODO: Test that the OAuth callback handler creates a session token after successful authentication
// TODO: Test that the OAuth callback handler sets a session cookie after successful authentication
// TODO: Test that the OAuth callback handler redirects to the profile page after successful authentication
}
/// Test account linking flow
#[tokio::test]
async fn test_account_linking_flow() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth callback handler links a new provider to an existing user when link intent is present and session is valid
// TODO: Test that the OAuth callback handler redirects to profile after successful account linking
// TODO: Test that the OAuth callback handler falls back to normal sign-in when link intent is present but no valid session exists
}
/// Test new user registration
#[tokio::test]
async fn test_new_user_registration() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth callback handler creates a new user account when no existing user is found
// TODO: Test that the OAuth callback handler requires an email address for all sign-ins
// TODO: Test that the OAuth callback handler rejects sign-in attempts when no email is available
}
/// Test existing user sign-in
#[tokio::test]
async fn test_existing_user_sign_in() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth callback handler allows sign-in when the provider is already linked to an existing user
// TODO: Test that the OAuth callback handler requires explicit linking when a user with the same email exists and has other providers linked
// TODO: Test that the OAuth callback handler auto-links a provider when a user exists but has no other providers linked
}
/// Test avatar processing
#[tokio::test]
async fn test_avatar_processing() {
let _server = setup_test_server(false).await;
// TODO: Test that the OAuth callback handler processes user avatars asynchronously without blocking the response
// TODO: Test that the OAuth callback handler handles avatar processing errors gracefully
}
/// Test profile access
#[tokio::test]
async fn test_profile_access() {
let _server = setup_test_server(false).await;
// TODO: Test that the profile handler returns user information when a valid session exists
// TODO: Test that the profile handler returns an error when no session cookie is present
// TODO: Test that the profile handler returns an error when the session token is invalid
// TODO: Test that the profile handler includes linked providers in the response
// TODO: Test that the profile handler returns an error when the user is not found in the database
}
/// Test logout functionality
#[tokio::test]
async fn test_logout_functionality() {
let _server = setup_test_server(false).await;
// TODO: Test that the logout handler clears the session if a session was there
// TODO: Test that the logout handler removes the session from memory storage
// TODO: Test that the logout handler clears the session cookie
// TODO: Test that the logout handler redirects to the home page after logout
}
/// Test provider configuration
#[tokio::test]
async fn test_provider_configuration() {
let _server = setup_test_server(false).await;
// TODO: Test that the providers list handler returns all configured OAuth providers
// TODO: Test that the providers list handler includes provider status (active/inactive)
}

View File

@@ -1,4 +1,4 @@
[toolchain]
# we are unfortunately pinned to 1.86.0 for some reason, bulk-memory-opt related issues on wasm32-unknown-emscripten
channel = "1.86.0"
channel = "1.87.0"
components = ["rustfmt", "llvm-tools-preview", "clippy"]