refactor: consolidate HTTP client for TCP/Unix socket handling

Extract reqwest client creation into dedicated HttpClient abstraction that handles both TCP and Unix socket connections transparently. Simplifies proxy logic by removing duplicate URL construction and client selection throughout the codebase.
This commit is contained in:
2026-01-07 14:34:32 -06:00
parent dcc496c979
commit dd1ce186d2
13 changed files with 398 additions and 329 deletions
+2 -2
View File
@@ -7,7 +7,7 @@ pub use projects::*;
pub use settings::*;
pub use tags::*;
use sqlx::{PgPool, postgres::PgPoolOptions};
use sqlx::{PgPool, postgres::PgPoolOptions, query};
/// Database connection pool creation
pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
@@ -20,7 +20,7 @@ pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
/// Health check query
pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> {
sqlx::query!("SELECT 1 as check")
query!("SELECT 1 as check")
.fetch_one(pool)
.await
.map(|_| ())
+71 -74
View File
@@ -1,9 +1,13 @@
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use time::OffsetDateTime;
use serde_json::json;
use sqlx::{PgPool, query, query_as};
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
use uuid::Uuid;
use super::{ProjectStatus, slugify};
use super::{
ProjectStatus, slugify,
tags::{ApiTag, DbTag, get_tags_for_project},
};
// Database model
#[derive(Debug, Clone, sqlx::FromRow)]
@@ -43,7 +47,7 @@ pub struct ApiProject {
pub struct ApiAdminProject {
#[serde(flatten)]
pub project: ApiProject,
pub tags: Vec<super::tags::ApiTag>,
pub tags: Vec<ApiTag>,
pub status: String,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -86,7 +90,7 @@ impl DbProject {
}
}
pub fn to_api_admin_project(&self, tags: Vec<super::tags::DbTag>) -> ApiAdminProject {
pub fn to_api_admin_project(&self, tags: Vec<DbTag>) -> ApiAdminProject {
ApiAdminProject {
project: self.to_api_project(),
tags: tags.into_iter().map(|t| t.to_api_tag()).collect(),
@@ -94,18 +98,11 @@ impl DbProject {
description: self.description.clone(),
github_repo: self.github_repo.clone(),
demo_url: self.demo_url.clone(),
created_at: self
.created_at
.format(&time::format_description::well_known::Rfc3339)
.unwrap(),
updated_at: self
.updated_at
.format(&time::format_description::well_known::Rfc3339)
.unwrap(),
last_github_activity: self.last_github_activity.map(|dt| {
dt.format(&time::format_description::well_known::Rfc3339)
.unwrap()
}),
created_at: self.created_at.format(&Rfc3339).unwrap(),
updated_at: self.updated_at.format(&Rfc3339).unwrap(),
last_github_activity: self
.last_github_activity
.map(|dt| dt.format(&Rfc3339).unwrap()),
}
}
}
@@ -150,20 +147,20 @@ pub struct AdminStats {
// Query functions
pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
id,
slug,
SELECT
id,
slug,
name,
short_description,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
last_github_activity,
created_at,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
last_github_activity,
created_at,
updated_at
FROM projects
WHERE status != 'hidden'
@@ -176,12 +173,12 @@ pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::
pub async fn get_public_projects_with_tags(
pool: &PgPool,
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let projects = get_public_projects(pool).await?;
let mut result = Vec::new();
for project in projects {
let tags = super::tags::get_tags_for_project(pool, project.id).await?;
let tags = get_tags_for_project(pool, project.id).await?;
result.push((project, tags));
}
@@ -190,20 +187,20 @@ pub async fn get_public_projects_with_tags(
/// Get all projects (admin view - includes hidden)
pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
id,
slug,
SELECT
id,
slug,
name,
short_description,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
last_github_activity,
created_at,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
last_github_activity,
created_at,
updated_at
FROM projects
ORDER BY updated_at DESC
@@ -216,12 +213,12 @@ pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sql
/// Get all projects with tags (admin view)
pub async fn get_all_projects_with_tags_admin(
pool: &PgPool,
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let projects = get_all_projects_admin(pool).await?;
let mut result = Vec::new();
for project in projects {
let tags = super::tags::get_tags_for_project(pool, project.id).await?;
let tags = get_tags_for_project(pool, project.id).await?;
result.push((project, tags));
}
@@ -230,21 +227,21 @@ pub async fn get_all_projects_with_tags_admin(
/// Get single project by ID
pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
id,
slug,
SELECT
id,
slug,
name,
short_description,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
last_github_activity,
created_at,
last_github_activity,
created_at,
updated_at
FROM projects
WHERE id = $1
@@ -259,12 +256,12 @@ pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProje
pub async fn get_project_by_id_with_tags(
pool: &PgPool,
id: Uuid,
) -> Result<Option<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
) -> Result<Option<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let project = get_project_by_id(pool, id).await?;
match project {
Some(p) => {
let tags = super::tags::get_tags_for_project(pool, p.id).await?;
let tags = get_tags_for_project(pool, p.id).await?;
Ok(Some((p, tags)))
}
None => Ok(None),
@@ -276,21 +273,21 @@ pub async fn get_project_by_slug(
pool: &PgPool,
slug: &str,
) -> Result<Option<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
id,
slug,
SELECT
id,
slug,
name,
short_description,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
description,
status as "status: ProjectStatus",
github_repo,
demo_url,
last_github_activity,
created_at,
last_github_activity,
created_at,
updated_at
FROM projects
WHERE slug = $1
@@ -316,12 +313,12 @@ pub async fn create_project(
.map(|s| slugify(s))
.unwrap_or_else(|| slugify(name));
sqlx::query_as!(
query_as!(
DbProject,
r#"
INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus",
RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus",
github_repo, demo_url, last_github_activity, created_at, updated_at
"#,
slug,
@@ -352,14 +349,14 @@ pub async fn update_project(
.map(|s| slugify(s))
.unwrap_or_else(|| slugify(name));
sqlx::query_as!(
query_as!(
DbProject,
r#"
UPDATE projects
SET slug = $2, name = $3, short_description = $4, description = $5,
SET slug = $2, name = $3, short_description = $4, description = $5,
status = $6, github_repo = $7, demo_url = $8
WHERE id = $1
RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus",
RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus",
github_repo, demo_url, last_github_activity, created_at, updated_at
"#,
id,
@@ -377,7 +374,7 @@ pub async fn update_project(
/// Delete project (CASCADE will handle tags)
pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM projects WHERE id = $1", id)
query!("DELETE FROM projects WHERE id = $1", id)
.execute(pool)
.await?;
Ok(())
@@ -386,9 +383,9 @@ pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error>
/// Get admin stats
pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
// Get project counts by status
let status_counts = sqlx::query!(
let status_counts = query!(
r#"
SELECT
SELECT
status as "status!: ProjectStatus",
COUNT(*)::int as "count!"
FROM projects
@@ -398,7 +395,7 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
.fetch_all(pool)
.await?;
let mut projects_by_status = serde_json::json!({
let mut projects_by_status = json!({
"active": 0,
"maintained": 0,
"archived": 0,
@@ -408,12 +405,12 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
let mut total_projects = 0;
for row in status_counts {
let status_str = format!("{:?}", row.status).to_lowercase();
projects_by_status[status_str] = serde_json::json!(row.count);
projects_by_status[status_str] = json!(row.count);
total_projects += row.count;
}
// Get total tags
let tag_count = sqlx::query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
let tag_count = query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
.fetch_one(pool)
.await?;
+4 -11
View File
@@ -60,17 +60,10 @@ pub async fn proxy_icons_handler(
let full_path = format!("/api/icons/{}", path);
let query = req.uri().query().unwrap_or("");
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
{
if query.is_empty() {
format!("http://localhost{}", full_path)
} else {
format!("http://localhost{}?{}", full_path, query)
}
} else if query.is_empty() {
format!("{}{}", state.downstream_url, full_path)
let path_with_query = if query.is_empty() {
full_path.clone()
} else {
format!("{}{}?{}", state.downstream_url, full_path, query)
format!("{full_path}?{query}")
};
// Build trusted headers with session info
@@ -86,7 +79,7 @@ pub async fn proxy_icons_handler(
}
}
match proxy::proxy_to_bun(&bun_url, state, forward_headers).await {
match proxy::proxy_to_bun(&path_with_query, state, forward_headers).await {
Ok((status, headers, body)) => (status, headers, body).into_response(),
Err(err) => {
tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request");
+146
View File
@@ -0,0 +1,146 @@
use reqwest::Method;
use std::path::PathBuf;
use std::time::Duration;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ClientError {
#[error("Failed to build reqwest client: {0}")]
BuildError(#[from] reqwest::Error),
#[error("Invalid downstream URL: {0}")]
InvalidUrl(String),
}
#[derive(Clone)]
pub struct HttpClient {
client: reqwest::Client,
target: TargetUrl,
}
#[derive(Debug, Clone)]
enum TargetUrl {
Tcp(String), // Base URL like "http://localhost:5173"
Unix(PathBuf), // Socket path like "/tmp/bun.sock"
}
impl HttpClient {
/// Create a new HttpClient from a downstream URL
///
/// Accepts:
/// - TCP: "http://localhost:5173", "https://example.com"
/// - Unix: "/tmp/bun.sock", "./relative.sock"
pub fn new(downstream: &str) -> Result<Self, ClientError> {
let target = if downstream.starts_with('/') || downstream.starts_with("./") {
TargetUrl::Unix(PathBuf::from(downstream))
} else if downstream.starts_with("http://") || downstream.starts_with("https://") {
TargetUrl::Tcp(downstream.to_string())
} else {
return Err(ClientError::InvalidUrl(downstream.to_string()));
};
tracing::debug!(
target = ?target,
downstream = %downstream,
"Creating HTTP client"
);
let client = match &target {
TargetUrl::Unix(path) => reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600))
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none())
.unix_socket(path.clone())
.build()?,
TargetUrl::Tcp(_) => reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600))
.tcp_keepalive(Some(Duration::from_secs(60)))
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none())
.build()?,
};
Ok(Self { client, target })
}
/// Build a full URL from a path
///
/// Examples:
/// - TCP target "http://localhost:5173" + "/api/health" → "http://localhost:5173/api/health"
/// - Unix target "/tmp/bun.sock" + "/api/health" → "http://localhost/api/health"
fn build_url(&self, path: &str) -> String {
match &self.target {
TargetUrl::Tcp(base) => format!("{}{}", base, path),
TargetUrl::Unix(_) => format!("http://localhost{}", path),
}
}
pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
self.client.get(self.build_url(path))
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
self.client.post(self.build_url(path))
}
pub fn request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
self.client.request(method, self.build_url(path))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tcp_url_construction() {
let client = HttpClient::new("http://localhost:5173").unwrap();
assert_eq!(
client.build_url("/api/health"),
"http://localhost:5173/api/health"
);
assert_eq!(
client.build_url("/path?query=1"),
"http://localhost:5173/path?query=1"
);
}
#[test]
fn test_unix_url_construction() {
let client = HttpClient::new("/tmp/bun.sock").unwrap();
assert_eq!(
client.build_url("/api/health"),
"http://localhost/api/health"
);
assert_eq!(
client.build_url("/path?query=1"),
"http://localhost/path?query=1"
);
}
#[test]
fn test_relative_unix_socket() {
let client = HttpClient::new("./relative.sock").unwrap();
assert!(matches!(client.target, TargetUrl::Unix(_)));
}
#[test]
fn test_invalid_url() {
let result = HttpClient::new("not-a-valid-url");
assert!(result.is_err());
}
#[test]
fn test_https_url() {
let client = HttpClient::new("https://example.com").unwrap();
assert!(matches!(client.target, TargetUrl::Tcp(_)));
assert_eq!(
client.build_url("/api/test"),
"https://example.com/api/test"
);
}
}
+7 -41
View File
@@ -1,6 +1,5 @@
use clap::Parser;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer};
@@ -13,6 +12,7 @@ mod db;
mod formatter;
mod handlers;
mod health;
mod http;
mod middleware;
mod og;
mod proxy;
@@ -125,50 +125,18 @@ async fn main() {
std::process::exit(1);
}
// Create HTTP client for TCP connections with optimized pool settings
let http_client = reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
.tcp_keepalive(Some(Duration::from_secs(60)))
.timeout(Duration::from_secs(5)) // Default timeout for SSR
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
.build()
.expect("Failed to create HTTP client");
// Create Unix socket client if downstream is a Unix socket
let unix_client = if args.downstream.starts_with('/') || args.downstream.starts_with("./") {
let path = PathBuf::from(&args.downstream);
Some(
reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
.timeout(Duration::from_secs(5)) // Default timeout for SSR
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
.unix_socket(path)
.build()
.expect("Failed to create Unix socket client"),
)
} else {
None
};
// Create socket-aware HTTP client
let client = http::HttpClient::new(&args.downstream).expect("Failed to create HTTP client");
// Create health checker
let downstream_url_for_health = args.downstream.clone();
let http_client_for_health = http_client.clone();
let unix_client_for_health = unix_client.clone();
let client_for_health = client.clone();
let pool_for_health = pool.clone();
let health_checker = Arc::new(HealthChecker::new(move || {
let downstream_url = downstream_url_for_health.clone();
let http_client = http_client_for_health.clone();
let unix_client = unix_client_for_health.clone();
let client = client_for_health.clone();
let pool = pool_for_health.clone();
async move {
proxy::perform_health_check(downstream_url, http_client, unix_client, Some(pool)).await
}
async move { proxy::perform_health_check(client, Some(pool)).await }
}));
let tarpit_config = TarpitConfig::from_env();
@@ -186,9 +154,7 @@ async fn main() {
);
let state = Arc::new(AppState {
downstream_url: args.downstream.clone(),
http_client,
unix_client,
client,
health_checker,
tarpit_state,
pool: pool.clone(),
+3 -11
View File
@@ -36,17 +36,9 @@ pub async fn generate_og_image(spec: &OGImageSpec, state: Arc<AppState>) -> Resu
tracing::Span::current().record("r2_key", &r2_key);
// Call Bun's internal endpoint
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
{
"http://localhost/internal/ogp/generate".to_string()
} else {
format!("{}/internal/ogp/generate", state.downstream_url)
};
let client = state.unix_client.as_ref().unwrap_or(&state.http_client);
let response = client
.post(&bun_url)
let response = state
.client
.post("/internal/ogp/generate")
.json(spec)
.timeout(Duration::from_secs(30))
.send()
+50 -74
View File
@@ -72,17 +72,10 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
return response;
}
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
{
if query.is_empty() {
format!("http://localhost{path}")
} else {
format!("http://localhost{path}?{query}")
}
} else if query.is_empty() {
format!("{}{}", state.downstream_url, path)
let path_with_query = if query.is_empty() {
path.to_string()
} else {
format!("{}{}?{}", state.downstream_url, path, query)
format!("{path}?{query}")
};
// Build trusted headers to forward to downstream
@@ -120,7 +113,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
let start = std::time::Instant::now();
match proxy_to_bun(&bun_url, state.clone(), forward_headers).await {
match proxy_to_bun(&path_with_query, state.clone(), forward_headers).await {
Ok((status, headers, body)) => {
let duration_ms = start.elapsed().as_millis() as u64;
let cache = "miss";
@@ -182,7 +175,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
let duration_ms = start.elapsed().as_millis() as u64;
tracing::error!(
error = %err,
url = %bun_url,
path = %path_with_query,
duration_ms,
"Failed to proxy to Bun"
);
@@ -203,18 +196,12 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
/// Proxy a request to Bun SSR
pub async fn proxy_to_bun(
url: &str,
path: &str,
state: Arc<AppState>,
forward_headers: HeaderMap,
) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> {
let client = if state.unix_client.is_some() {
state.unix_client.as_ref().unwrap()
} else {
&state.http_client
};
// Build request with forwarded headers
let mut request_builder = client.get(url);
let mut request_builder = state.client.get(path);
for (name, value) in forward_headers.iter() {
request_builder = request_builder.header(name, value);
}
@@ -247,45 +234,35 @@ pub async fn proxy_to_bun(
/// Perform health check on Bun SSR and database
pub async fn perform_health_check(
downstream_url: String,
http_client: reqwest::Client,
unix_client: Option<reqwest::Client>,
client: crate::http::HttpClient,
pool: Option<sqlx::PgPool>,
) -> bool {
let url = if downstream_url.starts_with('/') || downstream_url.starts_with("./") {
"http://localhost/internal/health".to_string()
} else {
format!("{downstream_url}/internal/health")
let bun_healthy = match tokio::time::timeout(
Duration::from_secs(5),
client.get("/internal/health").send(),
)
.await
{
Ok(Ok(response)) => {
let is_success = response.status().is_success();
if !is_success {
tracing::warn!(
status = response.status().as_u16(),
"Health check failed: Bun returned non-success status"
);
}
is_success
}
Ok(Err(err)) => {
tracing::error!(error = %err, "Health check failed: cannot reach Bun");
false
}
Err(_) => {
tracing::error!("Health check failed: timeout after 5s");
false
}
};
let client = if unix_client.is_some() {
unix_client.as_ref().unwrap()
} else {
&http_client
};
let bun_healthy =
match tokio::time::timeout(Duration::from_secs(5), client.get(&url).send()).await {
Ok(Ok(response)) => {
let is_success = response.status().is_success();
if !is_success {
tracing::warn!(
status = response.status().as_u16(),
"Health check failed: Bun returned non-success status"
);
}
is_success
}
Ok(Err(err)) => {
tracing::error!(error = %err, "Health check failed: cannot reach Bun");
false
}
Err(_) => {
tracing::error!("Health check failed: timeout after 5s");
false
}
};
// Check database
let db_healthy = if let Some(pool) = pool {
match db::health_check(&pool).await {
@@ -307,33 +284,32 @@ fn should_tarpit(state: &TarpitState, path: &str) -> bool {
state.config.enabled && tarpit::is_malicious_path(path)
}
/// Common handler logic for requests with optional peer info
async fn handle_request_with_optional_peer(
state: Arc<AppState>,
peer: Option<SocketAddr>,
req: Request,
) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
let peer_info = peer.map(ConnectInfo);
tarpit::tarpit_handler(State(state.tarpit_state.clone()), peer_info, req).await
} else {
isr_handler(State(state), req).await
}
}
/// Fallback handler for TCP connections (has access to peer IP)
pub async fn fallback_handler_tcp(
State(state): State<Arc<AppState>>,
ConnectInfo(peer): ConnectInfo<SocketAddr>,
req: Request,
) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
tarpit::tarpit_handler(
State(state.tarpit_state.clone()),
Some(ConnectInfo(peer)),
req,
)
.await
} else {
isr_handler(State(state), req).await
}
handle_request_with_optional_peer(state, Some(peer), req).await
}
/// Fallback handler for Unix sockets (no peer IP available)
pub async fn fallback_handler_unix(State(state): State<Arc<AppState>>, req: Request) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
tarpit::tarpit_handler(State(state.tarpit_state.clone()), None, req).await
} else {
isr_handler(State(state), req).await
}
handle_request_with_optional_peer(state, None, req).await
}
+31 -50
View File
@@ -1,4 +1,11 @@
use axum::{Router, extract::Request, http::Uri, response::IntoResponse, routing::any};
use axum::{
Router,
body::Body,
extract::Request,
http::{Method, Uri},
response::IntoResponse,
routing::{any, get, post},
};
use std::sync::Arc;
use crate::{assets, handlers, state::AppState};
@@ -9,32 +16,28 @@ pub fn api_routes() -> Router<Arc<AppState>> {
.route("/", any(api_root_404_handler))
.route(
"/health",
axum::routing::get(handlers::health_handler).head(handlers::health_handler),
get(handlers::health_handler).head(handlers::health_handler),
)
// Authentication endpoints (public)
.route("/login", axum::routing::post(handlers::api_login_handler))
.route("/logout", axum::routing::post(handlers::api_logout_handler))
.route(
"/session",
axum::routing::get(handlers::api_session_handler),
)
.route("/login", post(handlers::api_login_handler))
.route("/logout", post(handlers::api_logout_handler))
.route("/session", get(handlers::api_session_handler))
// Projects - GET is public (shows all for admin, only non-hidden for public)
// POST/PUT/DELETE require authentication
.route(
"/projects",
axum::routing::get(handlers::projects_handler).post(handlers::create_project_handler),
get(handlers::projects_handler).post(handlers::create_project_handler),
)
.route(
"/projects/{id}",
axum::routing::get(handlers::get_project_handler)
get(handlers::get_project_handler)
.put(handlers::update_project_handler)
.delete(handlers::delete_project_handler),
)
// Project tags - authentication checked in handlers
.route(
"/projects/{id}/tags",
axum::routing::get(handlers::get_project_tags_handler)
.post(handlers::add_project_tag_handler),
get(handlers::get_project_tags_handler).post(handlers::add_project_tag_handler),
)
.route(
"/projects/{id}/tags/{tag_id}",
@@ -43,36 +46,29 @@ pub fn api_routes() -> Router<Arc<AppState>> {
// Tags - authentication checked in handlers
.route(
"/tags",
axum::routing::get(handlers::list_tags_handler).post(handlers::create_tag_handler),
get(handlers::list_tags_handler).post(handlers::create_tag_handler),
)
.route(
"/tags/{slug}",
axum::routing::get(handlers::get_tag_handler).put(handlers::update_tag_handler),
get(handlers::get_tag_handler).put(handlers::update_tag_handler),
)
.route(
"/tags/{slug}/related",
axum::routing::get(handlers::get_related_tags_handler),
get(handlers::get_related_tags_handler),
)
.route(
"/tags/recalculate-cooccurrence",
axum::routing::post(handlers::recalculate_cooccurrence_handler),
post(handlers::recalculate_cooccurrence_handler),
)
// Admin stats - requires authentication
.route(
"/stats",
axum::routing::get(handlers::get_admin_stats_handler),
)
.route("/stats", get(handlers::get_admin_stats_handler))
// Site settings - GET is public, PUT requires authentication
.route(
"/settings",
axum::routing::get(handlers::get_settings_handler)
.put(handlers::update_settings_handler),
get(handlers::get_settings_handler).put(handlers::update_settings_handler),
)
// Icon API - proxy to SvelteKit (authentication handled by SvelteKit)
.route(
"/icons/{*path}",
axum::routing::get(handlers::proxy_icons_handler),
)
.route("/icons/{*path}", get(handlers::proxy_icons_handler))
.fallback(api_404_and_method_handler)
}
@@ -83,19 +79,13 @@ pub fn build_base_router() -> Router<Arc<AppState>> {
.route("/api/", any(api_root_404_handler))
.route(
"/_app/{*path}",
axum::routing::get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
)
.route("/pgp", axum::routing::get(handlers::handle_pgp_route))
.route(
"/publickey.asc",
axum::routing::get(handlers::serve_pgp_key),
)
.route("/pgp.asc", axum::routing::get(handlers::serve_pgp_key))
.route(
"/.well-known/pgpkey.asc",
axum::routing::get(handlers::serve_pgp_key),
)
.route("/keys", axum::routing::get(handlers::redirect_to_pgp))
.route("/pgp", get(handlers::handle_pgp_route))
.route("/publickey.asc", get(handlers::serve_pgp_key))
.route("/pgp.asc", get(handlers::serve_pgp_key))
.route("/.well-known/pgpkey.asc", get(handlers::serve_pgp_key))
.route("/keys", get(handlers::redirect_to_pgp))
}
async fn api_root_404_handler(uri: Uri) -> impl IntoResponse {
@@ -109,10 +99,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
let uri = req.uri();
let path = uri.path();
if method != axum::http::Method::GET
&& method != axum::http::Method::HEAD
&& method != axum::http::Method::OPTIONS
{
if method != Method::GET && method != Method::HEAD && method != Method::OPTIONS {
let content_type = req
.headers()
.get(axum::http::header::CONTENT_TYPE)
@@ -129,10 +116,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
)
.into_response();
}
} else if method == axum::http::Method::POST
|| method == axum::http::Method::PUT
|| method == axum::http::Method::PATCH
{
} else if method == Method::POST || method == Method::PUT || method == Method::PATCH {
// POST/PUT/PATCH require Content-Type header
return (
StatusCode::BAD_REQUEST,
@@ -158,10 +142,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
}
async fn api_404_handler(uri: Uri) -> impl IntoResponse {
let req = Request::builder()
.uri(uri)
.body(axum::body::Body::empty())
.unwrap();
let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
api_404_and_method_handler(req).await
}
+2 -4
View File
@@ -1,13 +1,11 @@
use std::sync::Arc;
use crate::{auth::SessionManager, health::HealthChecker, tarpit::TarpitState};
use crate::{auth::SessionManager, health::HealthChecker, http::HttpClient, tarpit::TarpitState};
/// Application state shared across all handlers
#[derive(Clone)]
pub struct AppState {
pub downstream_url: String,
pub http_client: reqwest::Client,
pub unix_client: Option<reqwest::Client>,
pub client: HttpClient,
pub health_checker: Arc<HealthChecker>,
pub tarpit_state: Arc<TarpitState>,
pub pool: sqlx::PgPool,
+12 -12
View File
@@ -1,5 +1,5 @@
use axum::{
http::{HeaderMap, StatusCode},
http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response},
};
@@ -34,7 +34,7 @@ pub fn is_page_route(path: &str) -> bool {
/// Check if the request accepts HTML responses
pub fn accepts_html(headers: &HeaderMap) -> bool {
if let Some(accept) = headers.get(axum::http::header::ACCEPT) {
if let Some(accept) = headers.get(header::ACCEPT) {
if let Ok(accept_str) = accept.to_str() {
return accept_str.contains("text/html") || accept_str.contains("*/*");
}
@@ -46,7 +46,7 @@ pub fn accepts_html(headers: &HeaderMap) -> bool {
/// Determines if request prefers raw content (CLI tools) over HTML
pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
// Check User-Agent for known CLI tools first (most reliable)
if let Some(ua) = headers.get(axum::http::header::USER_AGENT) {
if let Some(ua) = headers.get(header::USER_AGENT) {
if let Ok(ua_str) = ua.to_str() {
let ua_lower = ua_str.to_lowercase();
if ua_lower.starts_with("curl/")
@@ -60,7 +60,7 @@ pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
}
// Check Accept header - if it explicitly prefers text/html, serve HTML
if let Some(accept) = headers.get(axum::http::header::ACCEPT) {
if let Some(accept) = headers.get(header::ACCEPT) {
if let Ok(accept_str) = accept.to_str() {
// If text/html appears before */* in the list, they prefer HTML
if let Some(html_pos) = accept_str.find("text/html") {
@@ -88,12 +88,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
if let Some(html) = assets::get_error_page(status_code) {
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/html; charset=utf-8"),
header::CONTENT_TYPE,
HeaderValue::from_static("text/html; charset=utf-8"),
);
headers.insert(
axum::http::header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
header::CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
);
(status, headers, html).into_response()
@@ -107,12 +107,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
if let Some(fallback_html) = assets::get_error_page(500) {
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/html; charset=utf-8"),
header::CONTENT_TYPE,
HeaderValue::from_static("text/html; charset=utf-8"),
);
headers.insert(
axum::http::header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
header::CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
);
(status, headers, fallback_html).into_response()