feat: add request validation and HEAD method support

- Validate Content-Type for API requests (application/json only)
- Add HEAD method handlers for all routes
- Add 1MB request body limit
- Improve method not allowed responses with Allow header
This commit is contained in:
2026-01-04 19:21:53 -06:00
parent 32f1f88a90
commit edf271bcc6
4 changed files with 81 additions and 18 deletions
Generated
+1
View File
@@ -1617,6 +1617,7 @@ dependencies = [
"futures-util", "futures-util",
"http", "http",
"http-body", "http-body",
"http-body-util",
"iri-string", "iri-string",
"pin-project-lite", "pin-project-lite",
"tower", "tower",
+1 -1
View File
@@ -16,7 +16,7 @@ time = { version = "0.3.44", features = ["formatting", "macros"] }
tokio = { version = "1.49.0", features = ["full"] } tokio = { version = "1.49.0", features = ["full"] }
tokio-util = { version = "0.7.18", features = ["io"] } tokio-util = { version = "0.7.18", features = ["io"] }
tower = "0.5" tower = "0.5"
tower-http = { version = "0.6.8", features = ["trace", "cors"] } tower-http = { version = "0.6.8", features = ["trace", "cors", "limit"] }
tracing = "0.1.44" tracing = "0.1.44"
tracing-subscriber = { version = "0.3.22", features = ["env-filter", "json"] } tracing-subscriber = { version = "0.3.22", features = ["env-filter", "json"] }
ulid = { version = "1", features = ["serde"] } ulid = { version = "1", features = ["serde"] }
+3
View File
@@ -15,6 +15,9 @@ build:
bun run --cwd web build bun run --cwd web build
cargo build --release cargo build --release
serve:
LOG_JSON=true bunx concurrently --raw --prefix none "SOCKET_PATH=/tmp/xevion-bun.sock bun --preload ../console-logger.js --silent --cwd web/build index.js" "target/release/api --listen localhost:8080 --listen /tmp/xevion-api.sock --downstream /tmp/xevion-bun.sock"
check: check:
bun run --cwd web format bun run --cwd web format
bun run --cwd web lint bun run --cwd web lint
+76 -17
View File
@@ -3,13 +3,13 @@ use axum::{
extract::{Request, State}, extract::{Request, State},
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{any, get}, routing::any,
}; };
use clap::Parser; use clap::Parser;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use tower_http::{cors::CorsLayer, trace::TraceLayer}; use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
mod assets; mod assets;
@@ -88,11 +88,12 @@ async fn main() {
let app = Router::new() let app = Router::new()
.nest("/api", api_routes()) .nest("/api", api_routes())
.route("/api/", any(api_root_404_handler)) .route("/api/", any(api_root_404_handler))
.route("/_app/{*path}", get(serve_embedded_asset)) .route("/_app/{*path}", axum::routing::get(serve_embedded_asset).head(serve_embedded_asset))
.fallback(isr_handler) .fallback(isr_handler)
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(RequestIdLayer::new(args.trust_request_id.clone())) .layer(RequestIdLayer::new(args.trust_request_id.clone()))
.layer(CorsLayer::permissive()) .layer(CorsLayer::permissive())
.layer(RequestBodyLimitLayer::new(1_048_576)) // 1MB request body limit
.with_state(state); .with_state(state);
// Spawn a listener for each address // Spawn a listener for each address
@@ -199,9 +200,9 @@ fn is_page_route(path: &str) -> bool {
fn api_routes() -> Router<Arc<AppState>> { fn api_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
.route("/", any(api_root_404_handler)) .route("/", any(api_root_404_handler))
.route("/health", get(health_handler)) .route("/health", axum::routing::get(health_handler).head(health_handler))
.route("/projects", get(projects_handler)) .route("/projects", axum::routing::get(projects_handler).head(projects_handler))
.fallback(api_404_handler) .fallback(api_404_and_method_handler)
} }
// API root 404 handler - explicit 404 for /api and /api/ requests // API root 404 handler - explicit 404 for /api and /api/ requests
@@ -214,16 +215,61 @@ async fn health_handler() -> impl IntoResponse {
(StatusCode::OK, "OK") (StatusCode::OK, "OK")
} }
// API 404 fallback handler - catches unmatched /api/* routes // API 404 and method handler - catches unmatched /api/* routes and validates methods/content-type
async fn api_404_handler(uri: axum::http::Uri) -> impl IntoResponse { async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
tracing::warn!(path = %uri.path(), "API route not found"); let method = req.method();
let uri = req.uri();
let path = uri.path();
// For non-GET/HEAD requests, validate Content-Type
if method != axum::http::Method::GET && method != axum::http::Method::HEAD && method != axum::http::Method::OPTIONS {
let content_type = req.headers()
.get(axum::http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok());
if let Some(ct) = content_type {
// Only accept application/json for request bodies
if !ct.starts_with("application/json") {
return (
StatusCode::UNSUPPORTED_MEDIA_TYPE,
Json(serde_json::json!({
"error": "Unsupported media type",
"message": "API endpoints only accept application/json"
})),
).into_response();
}
} else if method == axum::http::Method::POST || method == axum::http::Method::PUT || method == axum::http::Method::PATCH {
// POST/PUT/PATCH require Content-Type header
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "Missing Content-Type header",
"message": "Content-Type: application/json is required"
})),
).into_response();
}
}
// Route not found
tracing::warn!(path = %path, method = %method, "API route not found");
( (
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(serde_json::json!({ Json(serde_json::json!({
"error": "Not found", "error": "Not found",
"path": uri.path() "path": path
})), })),
) ).into_response()
}
// Simple 404 handler for /api and /api/ that delegates to the main handler
async fn api_404_handler(uri: axum::http::Uri) -> impl IntoResponse {
// Create a minimal request for the handler
let req = Request::builder()
.uri(uri)
.body(axum::body::Body::empty())
.unwrap();
api_404_and_method_handler(req).await
} }
// Project data structure // Project data structure
@@ -283,20 +329,29 @@ async fn projects_handler() -> impl IntoResponse {
// This is the fallback for all routes not matched by /api/* // This is the fallback for all routes not matched by /api/*
#[tracing::instrument(skip(state, req), fields(path = %req.uri().path(), method = %req.method()))] #[tracing::instrument(skip(state, req), fields(path = %req.uri().path(), method = %req.method()))]
async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Response { async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Response {
let method = req.method(); let method = req.method().clone();
let uri = req.uri(); let uri = req.uri();
let path = uri.path(); let path = uri.path();
let query = uri.query().unwrap_or(""); let query = uri.query().unwrap_or("");
// Only allow GET requests outside of /api routes // Only allow GET and HEAD requests outside of /api routes
if method != axum::http::Method::GET { if method != axum::http::Method::GET && method != axum::http::Method::HEAD {
tracing::warn!(method = %method, path = %path, "Non-GET request to non-API route"); tracing::warn!(method = %method, path = %path, "Non-GET/HEAD request to non-API route");
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::ALLOW,
axum::http::HeaderValue::from_static("GET, HEAD, OPTIONS")
);
return ( return (
StatusCode::METHOD_NOT_ALLOWED, StatusCode::METHOD_NOT_ALLOWED,
headers,
"Method not allowed", "Method not allowed",
) )
.into_response(); .into_response();
} }
// For HEAD requests, we'll still proxy to Bun but strip the body later
let is_head = method == axum::http::Method::HEAD;
// Check if API route somehow reached ISR handler (shouldn't happen) // Check if API route somehow reached ISR handler (shouldn't happen)
if path.starts_with("/api/") { if path.starts_with("/api/") {
@@ -386,8 +441,12 @@ async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Respon
} }
} }
// Forward response // Forward response, but strip body for HEAD requests
(status, headers, body).into_response() if is_head {
(status, headers).into_response()
} else {
(status, headers, body).into_response()
}
} }
Err(err) => { Err(err) => {
let duration_ms = start.elapsed().as_millis() as u64; let duration_ms = start.elapsed().as_millis() as u64;