feat: add Rust reverse proxy with JSON logging

- Axum-based API server with Unix socket and TCP support
- Custom tracing formatters for Railway-compatible JSON logs
- SvelteKit hooks and Vite plugin for unified logging
- Justfile updates for a concurrent dev workflow with the hl log viewer
2026-01-04 18:21:00 -06:00
parent 07ea1c093e
commit d86027d27a
19 changed files with 3069 additions and 11 deletions
@@ -0,0 +1,142 @@
use clap::Parser;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::PathBuf;
use std::str::FromStr;
/// Server configuration parsed from CLI arguments and environment variables
#[derive(Parser, Debug)]
#[command(name = "api")]
#[command(about = "xevion.dev API server with ISR caching", long_about = None)]
pub struct Args {
/// Address(es) to listen on. Can be host:port, :port, or Unix socket path.
/// Can be specified multiple times.
/// Examples: :8080, 0.0.0.0:8080, [::]:8080, /tmp/api.sock
#[arg(long, env = "LISTEN_ADDR", value_delimiter = ',', required = true)]
pub listen: Vec<ListenAddr>,
/// Downstream Bun SSR server URL or Unix socket path
/// Examples: http://localhost:5173, /tmp/bun.sock
#[arg(long, env = "DOWNSTREAM_URL", required = true)]
pub downstream: String,
/// Optional header name to trust for request IDs (e.g., X-Railway-Request-Id)
#[arg(long, env = "TRUST_REQUEST_ID")]
pub trust_request_id: Option<String>,
}
/// Address to listen on - either TCP or Unix socket
#[derive(Debug, Clone)]
pub enum ListenAddr {
Tcp(SocketAddr),
Unix(PathBuf),
}
impl FromStr for ListenAddr {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Unix socket: starts with / or ./
if s.starts_with('/') || s.starts_with("./") {
return Ok(ListenAddr::Unix(PathBuf::from(s)));
}
// Shorthand :port -> 127.0.0.1:port
if let Some(port_str) = s.strip_prefix(':') {
let port: u16 = port_str
.parse()
.map_err(|_| format!("Invalid port number: {}", port_str))?;
return Ok(ListenAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))));
}
// Try parsing as a socket address (handles both IPv4 and IPv6)
// This supports formats like: 0.0.0.0:8080, [::]:8080, 192.168.1.1:3000
match s.parse::<SocketAddr>() {
Ok(addr) => Ok(ListenAddr::Tcp(addr)),
Err(_) => {
// Try resolving as hostname:port
match s.to_socket_addrs() {
Ok(mut addrs) => addrs
.next()
.ok_or_else(|| format!("Could not resolve address: {}", s))
.map(ListenAddr::Tcp),
Err(_) => Err(format!(
"Invalid address '{}'. Expected host:port, :port, or Unix socket path",
s
)),
}
}
}
}
}
impl std::fmt::Display for ListenAddr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ListenAddr::Tcp(addr) => write!(f, "{}", addr),
ListenAddr::Unix(path) => write!(f, "{}", path.display()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_shorthand_port() {
let addr: ListenAddr = ":8080".parse().unwrap();
match addr {
ListenAddr::Tcp(socket) => {
assert_eq!(socket.port(), 8080);
assert_eq!(socket.ip().to_string(), "127.0.0.1");
}
_ => panic!("Expected TCP address"),
}
}
#[test]
fn test_parse_ipv4() {
let addr: ListenAddr = "0.0.0.0:8080".parse().unwrap();
match addr {
ListenAddr::Tcp(socket) => {
assert_eq!(socket.port(), 8080);
assert_eq!(socket.ip().to_string(), "0.0.0.0");
}
_ => panic!("Expected TCP address"),
}
}
#[test]
fn test_parse_ipv6() {
let addr: ListenAddr = "[::]:8080".parse().unwrap();
match addr {
ListenAddr::Tcp(socket) => {
assert_eq!(socket.port(), 8080);
assert_eq!(socket.ip().to_string(), "::");
}
_ => panic!("Expected TCP address"),
}
}
#[test]
fn test_parse_unix_socket() {
let addr: ListenAddr = "/tmp/api.sock".parse().unwrap();
match addr {
ListenAddr::Unix(path) => {
assert_eq!(path, PathBuf::from("/tmp/api.sock"));
}
_ => panic!("Expected Unix socket"),
}
}
#[test]
fn test_parse_relative_unix_socket() {
let addr: ListenAddr = "./api.sock".parse().unwrap();
match addr {
ListenAddr::Unix(path) => {
assert_eq!(path, PathBuf::from("./api.sock"));
}
_ => panic!("Expected Unix socket"),
}
}
}
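Since the comma-delimited --listen / LISTEN_ADDR handling relies on clap's value_delimiter together with the FromStr impl above, here is a short sketch (illustrative only, not part of the diff; it assumes the Args struct above is in scope) showing one value expanding into two listeners:

use clap::Parser;

fn demo() {
    // Equivalent to LISTEN_ADDR=":8080,/tmp/api.sock" with an HTTP downstream
    let args = Args::try_parse_from([
        "api",
        "--listen",
        ":8080,/tmp/api.sock",
        "--downstream",
        "http://localhost:5173",
    ])
    .expect("arguments should parse");

    // Prints "127.0.0.1:8080" and "/tmp/api.sock" via the Display impl above
    for addr in &args.listen {
        println!("listening on {addr}");
    }
}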
@@ -0,0 +1,280 @@
//! Custom tracing formatter for Railway-compatible structured logging
use nu_ansi_term::Color;
use serde::Serialize;
use serde_json::{Map, Value};
use std::fmt;
use time::macros::format_description;
use time::{format_description::FormatItem, OffsetDateTime};
use tracing::field::{Field, Visit};
use tracing::{Event, Level, Subscriber};
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields, FormattedFields};
use tracing_subscriber::registry::LookupSpan;
/// Cached format description for timestamps with 3 subsecond digits (milliseconds)
const TIMESTAMP_FORMAT: &[FormatItem<'static>] =
format_description!("[hour]:[minute]:[second].[subsecond digits:3]");
/// A custom formatter with enhanced timestamp formatting and colored output
///
/// Provides human-readable output for local development with:
/// - Colored log levels
/// - Timestamp with millisecond precision
/// - Span context with hierarchy
/// - Clean field formatting
pub struct CustomPrettyFormatter;
impl<S, N> FormatEvent<S, N> for CustomPrettyFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(
&self,
ctx: &FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &Event<'_>,
) -> fmt::Result {
let meta = event.metadata();
// 1) Timestamp (dimmed when ANSI)
let now = OffsetDateTime::now_utc();
let formatted_time = now.format(&TIMESTAMP_FORMAT).map_err(|e| {
eprintln!("Failed to format timestamp: {}", e);
fmt::Error
})?;
write_dimmed(&mut writer, formatted_time)?;
writer.write_char(' ')?;
// 2) Colored 5-char level
write_colored_level(&mut writer, meta.level())?;
writer.write_char(' ')?;
// 3) Span scope chain (bold names, fields in braces, dimmed ':')
if let Some(scope) = ctx.event_scope() {
let mut saw_any = false;
for span in scope.from_root() {
write_bold(&mut writer, span.metadata().name())?;
saw_any = true;
write_dimmed(&mut writer, ":")?;
let ext = span.extensions();
                if let Some(fields) = ext.get::<FormattedFields<N>>() {
if !fields.fields.is_empty() {
write_bold(&mut writer, "{")?;
writer.write_str(fields.fields.as_str())?;
write_bold(&mut writer, "}")?;
}
}
write_dimmed(&mut writer, ":")?;
}
if saw_any {
writer.write_char(' ')?;
}
}
// 4) Target (dimmed), then a space
if writer.has_ansi_escapes() {
write!(writer, "{}: ", Color::DarkGray.paint(meta.target()))?;
} else {
write!(writer, "{}: ", meta.target())?;
}
// 5) Event fields
ctx.format_fields(writer.by_ref(), event)?;
// 6) Newline
writeln!(writer)
}
}
/// A custom JSON formatter that flattens fields to root level for Railway
///
/// Outputs logs in Railway-compatible format:
/// ```json
/// {
/// "message": "...",
/// "level": "...",
/// "target": "...",
/// "customAttribute": "..."
/// }
/// ```
///
/// This format allows Railway to:
/// - Parse the `message` field correctly
/// - Filter by `level` and custom attributes using `@attribute:value`
/// - Preserve multi-line logs like stack traces
pub struct CustomJsonFormatter;
impl<S, N> FormatEvent<S, N> for CustomJsonFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(
&self,
ctx: &FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &Event<'_>,
) -> fmt::Result {
let meta = event.metadata();
#[derive(Serialize)]
struct EventFields {
timestamp: String,
message: String,
level: String,
target: String,
#[serde(flatten)]
fields: Map<String, Value>,
}
let (message, fields) = {
let mut message: Option<String> = None;
let mut fields: Map<String, Value> = Map::new();
struct FieldVisitor<'a> {
message: &'a mut Option<String>,
fields: &'a mut Map<String, Value>,
}
impl<'a> Visit for FieldVisitor<'a> {
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
let key = field.name();
if key == "message" {
*self.message = Some(format!("{:?}", value));
} else {
self.fields
.insert(key.to_string(), Value::String(format!("{:?}", value)));
}
}
fn record_str(&mut self, field: &Field, value: &str) {
let key = field.name();
if key == "message" {
*self.message = Some(value.to_string());
} else {
self.fields
.insert(key.to_string(), Value::String(value.to_string()));
}
}
fn record_i64(&mut self, field: &Field, value: i64) {
let key = field.name();
if key != "message" {
self.fields.insert(
key.to_string(),
Value::Number(serde_json::Number::from(value)),
);
}
}
fn record_u64(&mut self, field: &Field, value: u64) {
let key = field.name();
if key != "message" {
self.fields.insert(
key.to_string(),
Value::Number(serde_json::Number::from(value)),
);
}
}
fn record_bool(&mut self, field: &Field, value: bool) {
let key = field.name();
if key != "message" {
self.fields.insert(key.to_string(), Value::Bool(value));
}
}
}
let mut visitor = FieldVisitor {
message: &mut message,
fields: &mut fields,
};
event.record(&mut visitor);
// Collect span information from the span hierarchy
// Flatten all span fields directly into root level
if let Some(scope) = ctx.event_scope() {
for span in scope.from_root() {
// Extract span fields by parsing the stored extension data
// The fields are stored as a formatted string, so we need to parse them
let ext = span.extensions();
if let Some(formatted_fields) = ext.get::<FormattedFields<N>>() {
let field_str = formatted_fields.fields.as_str();
// Parse key=value pairs from the formatted string
// Format is typically: key=value key2=value2
for pair in field_str.split_whitespace() {
if let Some((key, value)) = pair.split_once('=') {
// Remove quotes if present
let value = value.trim_matches('"').trim_matches('\'');
fields.insert(key.to_string(), Value::String(value.to_string()));
}
}
}
}
}
(message, fields)
};
let json = EventFields {
timestamp: OffsetDateTime::now_utc()
.format(&time::format_description::well_known::Rfc3339)
.unwrap_or_else(|_| String::from("1970-01-01T00:00:00Z")),
message: message.unwrap_or_default(),
level: meta.level().to_string().to_lowercase(),
target: meta.target().to_string(),
fields,
};
writeln!(
writer,
"{}",
serde_json::to_string(&json).unwrap_or_else(|_| "{}".to_string())
)
}
}
/// Write the verbosity level with colored output
fn write_colored_level(writer: &mut Writer<'_>, level: &Level) -> fmt::Result {
if writer.has_ansi_escapes() {
let colored = match *level {
Level::TRACE => Color::Purple.paint("TRACE"),
Level::DEBUG => Color::Blue.paint("DEBUG"),
Level::INFO => Color::Green.paint(" INFO"),
Level::WARN => Color::Yellow.paint(" WARN"),
Level::ERROR => Color::Red.paint("ERROR"),
};
write!(writer, "{}", colored)
} else {
            // Right-align to width 5 for alignment
match *level {
Level::TRACE => write!(writer, "{:>5}", "TRACE"),
Level::DEBUG => write!(writer, "{:>5}", "DEBUG"),
Level::INFO => write!(writer, "{:>5}", " INFO"),
Level::WARN => write!(writer, "{:>5}", " WARN"),
Level::ERROR => write!(writer, "{:>5}", "ERROR"),
}
}
}
fn write_dimmed(writer: &mut Writer<'_>, s: impl fmt::Display) -> fmt::Result {
if writer.has_ansi_escapes() {
write!(writer, "{}", Color::DarkGray.paint(s.to_string()))
} else {
write!(writer, "{}", s)
}
}
fn write_bold(writer: &mut Writer<'_>, s: impl fmt::Display) -> fmt::Result {
if writer.has_ansi_escapes() {
write!(writer, "{}", Color::White.bold().paint(s.to_string()))
} else {
write!(writer, "{}", s)
}
}
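To make the flattened Railway shape concrete, here is an illustrative setup snippet (not part of the commit) that assumes the formatter is reachable as crate::formatter::CustomJsonFormatter, matching the mod formatter; declaration in main.rs below:

use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

fn demo_json_logging() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::fmt::layer()
                .event_format(crate::formatter::CustomJsonFormatter)
                .with_ansi(false),
        )
        .init();

    // Roughly: {"timestamp":"...","message":"user logged in","level":"info",
    //           "target":"...","user_id":42}
    // Fields sit at the root, so Railway filters like @user_id:42 work directly.
    tracing::info!(user_id = 42, "user logged in");
}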
@@ -0,0 +1,418 @@
use axum::{
Json, Router,
extract::{Request, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::get,
};
use clap::Parser;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
mod config;
mod formatter;
mod middleware;
use config::{Args, ListenAddr};
use formatter::{CustomJsonFormatter, CustomPrettyFormatter};
use middleware::RequestIdLayer;
fn init_tracing() {
let use_json = std::env::var("LOG_JSON")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
// Build the EnvFilter
// Priority: RUST_LOG > LOG_LEVEL > default
let filter = if let Ok(rust_log) = std::env::var("RUST_LOG") {
// RUST_LOG overwrites everything
EnvFilter::new(rust_log)
} else {
// Get LOG_LEVEL for our crate, default based on build profile
let our_level = std::env::var("LOG_LEVEL").unwrap_or_else(|_| {
if cfg!(debug_assertions) {
"debug".to_string()
} else {
"info".to_string()
}
});
// Default other crates to WARN, our crate to LOG_LEVEL
EnvFilter::new(format!("warn,api={}", our_level))
};
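    // Illustrative resolutions of the precedence above (examples, not exhaustive):
    //   RUST_LOG="api=trace,hyper=info"     -> used verbatim; LOG_LEVEL is ignored
    //   LOG_LEVEL="trace", RUST_LOG unset   -> filter becomes "warn,api=trace"
    //   neither set, release build          -> filter becomes "warn,api=info"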
if use_json {
tracing_subscriber::registry()
.with(filter)
.with(
tracing_subscriber::fmt::layer()
.event_format(CustomJsonFormatter)
.fmt_fields(tracing_subscriber::fmt::format::DefaultFields::new())
.with_ansi(false), // Disable ANSI codes in JSON mode
)
.init();
} else {
tracing_subscriber::registry()
.with(filter)
.with(tracing_subscriber::fmt::layer().event_format(CustomPrettyFormatter))
.init();
}
}
#[tokio::main]
async fn main() {
// Initialize tracing with configurable format and levels
init_tracing();
// Parse CLI arguments and environment variables
let args = Args::parse();
// Validate we have at least one listen address
if args.listen.is_empty() {
eprintln!("Error: At least one --listen address is required");
std::process::exit(1);
}
// Create shared application state
let state = Arc::new(AppState {
downstream_url: args.downstream.clone(),
});
// Build router with shared state
let app = Router::new()
.nest("/api", api_routes().fallback(api_404_handler))
.fallback(isr_handler)
.layer(TraceLayer::new_for_http())
.layer(RequestIdLayer::new(args.trust_request_id.clone()))
.layer(CorsLayer::permissive())
.with_state(state);
// Spawn a listener for each address
let mut tasks = Vec::new();
for listen_addr in &args.listen {
let app = app.clone();
let listen_addr = listen_addr.clone();
let task = tokio::spawn(async move {
match listen_addr {
ListenAddr::Tcp(addr) => {
let listener = tokio::net::TcpListener::bind(addr)
.await
.expect("Failed to bind TCP listener");
// Format as clickable URL
let url = if addr.is_ipv6() {
format!("http://[{}]:{}", addr.ip(), addr.port())
} else {
format!("http://{}:{}", addr.ip(), addr.port())
};
tracing::info!(url, "Listening on TCP");
axum::serve(listener, app)
.await
.expect("Server error on TCP listener");
}
ListenAddr::Unix(path) => {
// Remove existing socket file if it exists
let _ = std::fs::remove_file(&path);
let listener = tokio::net::UnixListener::bind(&path)
.expect("Failed to bind Unix socket listener");
tracing::info!(socket = %path.display(), "Listening on Unix socket");
axum::serve(listener, app)
.await
.expect("Server error on Unix socket listener");
}
}
});
tasks.push(task);
}
// Wait for all listeners (this will run forever unless interrupted)
for task in tasks {
task.await.expect("Listener task panicked");
}
}
/// Shared application state
#[derive(Clone)]
struct AppState {
downstream_url: String,
}
/// Custom error type for proxy operations
#[derive(Debug)]
enum ProxyError {
/// Network error (connection failed, timeout, etc.)
Network(reqwest::Error),
}
impl std::fmt::Display for ProxyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProxyError::Network(e) => write!(f, "Network error: {}", e),
}
}
}
impl std::error::Error for ProxyError {}
/// Check if a path represents a static asset that should be logged at TRACE level
fn is_static_asset(path: &str) -> bool {
path.starts_with("/node_modules/")
|| path.starts_with("/@") // Vite internals like /@vite/client, /@fs/, /@id/
|| path.starts_with("/.svelte-kit/")
|| path.starts_with("/.well-known/")
|| path.ends_with(".woff2")
|| path.ends_with(".woff")
|| path.ends_with(".ttf")
|| path.ends_with(".ico")
|| path.ends_with(".png")
|| path.ends_with(".jpg")
|| path.ends_with(".svg")
|| path.ends_with(".webp")
|| path.ends_with(".css")
|| path.ends_with(".js")
|| path.ends_with(".map")
}
/// Check if a path represents a page route (heuristic: no file extension)
fn is_page_route(path: &str) -> bool {
!path.starts_with("/node_modules/")
&& !path.starts_with("/@")
&& !path.starts_with("/.svelte-kit/")
&& !path.contains('.') // Simple heuristic: no extension = likely a page
}
// API routes for data endpoints
fn api_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/health", get(health_handler))
.route("/projects", get(projects_handler))
}
// Health check endpoint
async fn health_handler() -> impl IntoResponse {
(StatusCode::OK, "OK")
}
// API 404 fallback handler - catches unmatched /api/* routes
async fn api_404_handler(uri: axum::http::Uri) -> impl IntoResponse {
tracing::warn!(path = %uri.path(), "API route not found");
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "Not found",
"path": uri.path()
})),
)
}
// Project data structure
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ProjectLink {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Project {
id: String,
name: String,
#[serde(rename = "shortDescription")]
short_description: String,
#[serde(skip_serializing_if = "Option::is_none")]
icon: Option<String>,
links: Vec<ProjectLink>,
}
// Projects endpoint - returns hardcoded project data for now
async fn projects_handler() -> impl IntoResponse {
let projects = vec![
Project {
id: "1".to_string(),
name: "xevion.dev".to_string(),
short_description: "Personal portfolio with fuzzy tag discovery".to_string(),
icon: None,
links: vec![ProjectLink {
url: "https://github.com/Xevion/xevion.dev".to_string(),
title: Some("GitHub".to_string()),
}],
},
Project {
id: "2".to_string(),
name: "Contest".to_string(),
short_description: "Competitive programming problem archive".to_string(),
icon: None,
links: vec![
ProjectLink {
url: "https://github.com/Xevion/contest".to_string(),
title: Some("GitHub".to_string()),
},
ProjectLink {
url: "https://contest.xevion.dev".to_string(),
title: Some("Demo".to_string()),
},
],
},
];
Json(projects)
}
// ISR handler - proxies to Bun SSR server
// This is the fallback for all routes not matched by /api/*
#[tracing::instrument(skip(state, req), fields(path = %req.uri().path()))]
async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Response {
let uri = req.uri();
let path = uri.path();
let query = uri.query().unwrap_or("");
// Check if API route somehow reached ISR handler (shouldn't happen)
if path.starts_with("/api/") {
tracing::error!("API request reached ISR handler - routing bug!");
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal routing error",
)
.into_response();
}
// Build URL for Bun server
let bun_url = if query.is_empty() {
format!("{}{}", state.downstream_url, path)
} else {
format!("{}{}?{}", state.downstream_url, path, query)
};
// Track request timing
let start = std::time::Instant::now();
// TODO: Add ISR caching layer here (moka, singleflight, stale-while-revalidate)
// For now, just proxy directly to Bun
match proxy_to_bun(&bun_url, &state.downstream_url).await {
Ok((status, headers, body)) => {
let duration_ms = start.elapsed().as_millis() as u64;
let cache = "miss"; // Hardcoded for now, will change when caching is implemented
// Intelligent logging based on path type and status
let is_static = is_static_asset(path);
let is_page = is_page_route(path);
match (status.as_u16(), is_static, is_page) {
// Static assets - success at TRACE
(200..=299, true, _) => {
tracing::trace!(status = status.as_u16(), duration_ms, cache, "ISR request");
}
// Static assets - 404 at WARN
(404, true, _) => {
tracing::warn!(
status = status.as_u16(),
duration_ms,
cache,
"ISR request - missing asset"
);
}
// Static assets - server error at ERROR
(500..=599, true, _) => {
tracing::error!(
status = status.as_u16(),
duration_ms,
cache,
"ISR request - server error"
);
}
// Page routes - success at DEBUG
(200..=299, _, true) => {
tracing::debug!(status = status.as_u16(), duration_ms, cache, "ISR request");
}
// Page routes - 404 silent (normal case for non-existent pages)
(404, _, true) => {}
                // Any remaining (non-static) server error at ERROR
                (500..=599, _, _) => {
tracing::error!(
status = status.as_u16(),
duration_ms,
cache,
"ISR request - server error"
);
}
// Default fallback - DEBUG
_ => {
tracing::debug!(status = status.as_u16(), duration_ms, cache, "ISR request");
}
}
// Forward response
(status, headers, body).into_response()
}
Err(err) => {
let duration_ms = start.elapsed().as_millis() as u64;
tracing::error!(
error = %err,
url = %bun_url,
duration_ms,
"Failed to proxy to Bun"
);
(
StatusCode::BAD_GATEWAY,
format!("Failed to render page: {}", err),
)
.into_response()
}
}
}
// Proxy a request to the Bun SSR server, returning status, headers and body
async fn proxy_to_bun(
url: &str,
downstream_url: &str,
) -> Result<(StatusCode, HeaderMap, String), ProxyError> {
// Check if downstream is a Unix socket path
let client = if downstream_url.starts_with('/') || downstream_url.starts_with("./") {
// Unix socket
let path = PathBuf::from(downstream_url);
reqwest::Client::builder()
.unix_socket(path)
.build()
.map_err(ProxyError::Network)?
} else {
// Regular HTTP
reqwest::Client::new()
};
let response = client.get(url).send().await.map_err(ProxyError::Network)?;
// Extract status code
let status = StatusCode::from_u16(response.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
// Convert reqwest headers to axum HeaderMap
let mut headers = HeaderMap::new();
for (name, value) in response.headers() {
// Skip hop-by-hop headers and content-length (axum will recalculate it)
let name_str = name.as_str();
if name_str == "transfer-encoding"
|| name_str == "connection"
|| name_str == "content-length"
{
continue;
}
if let Ok(header_name) = axum::http::HeaderName::try_from(name.as_str()) {
if let Ok(header_value) = axum::http::HeaderValue::try_from(value.as_bytes()) {
headers.insert(header_name, header_value);
}
}
}
let body = response.text().await.map_err(ProxyError::Network)?;
Ok((status, headers, body))
}
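Because the log-level routing above leans entirely on the two path heuristics, a hypothetical test module (not in the commit; it would sit alongside the handlers in main.rs) makes the classification concrete:

#[cfg(test)]
mod path_tests {
    use super::{is_page_route, is_static_asset};

    #[test]
    fn classifies_typical_paths() {
        // Static assets: logged at TRACE on success, WARN on 404
        assert!(is_static_asset("/favicon.ico"));
        assert!(is_static_asset("/@vite/client"));
        assert!(!is_static_asset("/projects"));

        // Page routes: logged at DEBUG on success, silent on 404
        assert!(is_page_route("/projects"));
        assert!(!is_page_route("/app.css")); // has an extension, so not a page
        assert!(!is_page_route("/@vite/client"));
    }
}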
@@ -0,0 +1,85 @@
//! Request ID middleware for distributed tracing and correlation
use axum::{
body::Body,
extract::Request,
http::HeaderName,
response::Response,
};
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tracing::Instrument;
/// Layer that creates request ID spans for all requests
#[derive(Clone)]
pub struct RequestIdLayer {
/// Optional header name to trust for request IDs
trust_header: Option<HeaderName>,
}
impl RequestIdLayer {
/// Create a new request ID layer
pub fn new(trust_header: Option<String>) -> Self {
Self {
trust_header: trust_header.and_then(|h| h.parse().ok()),
}
}
}
impl<S> Layer<S> for RequestIdLayer {
type Service = RequestIdService<S>;
fn layer(&self, inner: S) -> Self::Service {
RequestIdService {
inner,
trust_header: self.trust_header.clone(),
}
}
}
/// Service that extracts or generates request IDs and creates tracing spans
#[derive(Clone)]
pub struct RequestIdService<S> {
inner: S,
trust_header: Option<HeaderName>,
}
impl<S> Service<Request> for RequestIdService<S>
where
S: Service<Request, Response = Response<Body>> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
    fn call(&mut self, req: Request) -> Self::Future {
        // Extract the request ID from the trusted header, or generate a fresh ULID
        let req_id = self
            .trust_header
            .as_ref()
            .and_then(|header| req.headers().get(header))
            .and_then(|value| value.to_str().ok())
            .map(|s| s.to_string())
            .unwrap_or_else(|| ulid::Ulid::new().to_string());
        // Create a tracing span for this request
        let span = tracing::info_span!("request", req_id = %req_id);
        // Call the inner service inside the span, then attach the span to the
        // returned future with `Instrument`. Holding an `enter()` guard across
        // an `.await` would leave the span incorrectly entered between polls.
        let future = {
            let _enter = span.enter();
            self.inner.call(req)
        };
        Box::pin(future.instrument(span))
    }
}
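For completeness, a minimal wiring sketch (it mirrors the layering already done in main.rs; the route and header name are examples only, and RequestIdLayer is assumed to be in scope): trust the platform's request-ID header when running behind a known edge proxy, otherwise fall back to generated ULIDs.

use axum::{routing::get, Router};

fn router_with_request_ids() -> Router {
    // e.g. TRUST_REQUEST_ID=X-Railway-Request-Id on Railway; unset for local dev
    let trusted = std::env::var("TRUST_REQUEST_ID").ok();
    Router::new()
        .route("/api/health", get(|| async { "OK" }))
        .layer(RequestIdLayer::new(trusted))
}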