refactor: consolidate HTTP client for TCP/Unix socket handling

Extract reqwest client creation into dedicated HttpClient abstraction that handles both TCP and Unix socket connections transparently. Simplifies proxy logic by removing duplicate URL construction and client selection throughout the codebase.
This commit is contained in:
2026-01-07 14:34:32 -06:00
parent dcc496c979
commit dd1ce186d2
13 changed files with 398 additions and 329 deletions
Generated
+1
View File
@@ -88,6 +88,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
"thiserror 2.0.17",
"time", "time",
"tokio", "tokio",
"tokio-util", "tokio-util",
+1
View File
@@ -22,6 +22,7 @@ reqwest = { version = "0.13.1", default-features = false, features = ["rustls",
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.148" serde_json = "1.0.148"
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "time", "migrate"] } sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "time", "migrate"] }
thiserror = "2.0.17"
time = { version = "0.3.44", features = ["formatting", "macros", "serde"] } time = { version = "0.3.44", features = ["formatting", "macros", "serde"] }
tokio = { version = "1.49.0", features = ["full"] } tokio = { version = "1.49.0", features = ["full"] }
tokio-util = { version = "0.7.18", features = ["io"] } tokio-util = { version = "0.7.18", features = ["io"] }
+2 -2
View File
@@ -7,7 +7,7 @@ pub use projects::*;
pub use settings::*; pub use settings::*;
pub use tags::*; pub use tags::*;
use sqlx::{PgPool, postgres::PgPoolOptions}; use sqlx::{PgPool, postgres::PgPoolOptions, query};
/// Database connection pool creation /// Database connection pool creation
pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> { pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
@@ -20,7 +20,7 @@ pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
/// Health check query /// Health check query
pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> { pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> {
sqlx::query!("SELECT 1 as check") query!("SELECT 1 as check")
.fetch_one(pool) .fetch_one(pool)
.await .await
.map(|_| ()) .map(|_| ())
+31 -34
View File
@@ -1,9 +1,13 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgPool; use serde_json::json;
use time::OffsetDateTime; use sqlx::{PgPool, query, query_as};
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
use uuid::Uuid; use uuid::Uuid;
use super::{ProjectStatus, slugify}; use super::{
ProjectStatus, slugify,
tags::{ApiTag, DbTag, get_tags_for_project},
};
// Database model // Database model
#[derive(Debug, Clone, sqlx::FromRow)] #[derive(Debug, Clone, sqlx::FromRow)]
@@ -43,7 +47,7 @@ pub struct ApiProject {
pub struct ApiAdminProject { pub struct ApiAdminProject {
#[serde(flatten)] #[serde(flatten)]
pub project: ApiProject, pub project: ApiProject,
pub tags: Vec<super::tags::ApiTag>, pub tags: Vec<ApiTag>,
pub status: String, pub status: String,
pub description: String, pub description: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -86,7 +90,7 @@ impl DbProject {
} }
} }
pub fn to_api_admin_project(&self, tags: Vec<super::tags::DbTag>) -> ApiAdminProject { pub fn to_api_admin_project(&self, tags: Vec<DbTag>) -> ApiAdminProject {
ApiAdminProject { ApiAdminProject {
project: self.to_api_project(), project: self.to_api_project(),
tags: tags.into_iter().map(|t| t.to_api_tag()).collect(), tags: tags.into_iter().map(|t| t.to_api_tag()).collect(),
@@ -94,18 +98,11 @@ impl DbProject {
description: self.description.clone(), description: self.description.clone(),
github_repo: self.github_repo.clone(), github_repo: self.github_repo.clone(),
demo_url: self.demo_url.clone(), demo_url: self.demo_url.clone(),
created_at: self created_at: self.created_at.format(&Rfc3339).unwrap(),
.created_at updated_at: self.updated_at.format(&Rfc3339).unwrap(),
.format(&time::format_description::well_known::Rfc3339) last_github_activity: self
.unwrap(), .last_github_activity
updated_at: self .map(|dt| dt.format(&Rfc3339).unwrap()),
.updated_at
.format(&time::format_description::well_known::Rfc3339)
.unwrap(),
last_github_activity: self.last_github_activity.map(|dt| {
dt.format(&time::format_description::well_known::Rfc3339)
.unwrap()
}),
} }
} }
} }
@@ -150,7 +147,7 @@ pub struct AdminStats {
// Query functions // Query functions
pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> { pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
sqlx::query_as!( query_as!(
DbProject, DbProject,
r#" r#"
SELECT SELECT
@@ -176,12 +173,12 @@ pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::
pub async fn get_public_projects_with_tags( pub async fn get_public_projects_with_tags(
pool: &PgPool, pool: &PgPool,
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> { ) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let projects = get_public_projects(pool).await?; let projects = get_public_projects(pool).await?;
let mut result = Vec::new(); let mut result = Vec::new();
for project in projects { for project in projects {
let tags = super::tags::get_tags_for_project(pool, project.id).await?; let tags = get_tags_for_project(pool, project.id).await?;
result.push((project, tags)); result.push((project, tags));
} }
@@ -190,7 +187,7 @@ pub async fn get_public_projects_with_tags(
/// Get all projects (admin view - includes hidden) /// Get all projects (admin view - includes hidden)
pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> { pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
sqlx::query_as!( query_as!(
DbProject, DbProject,
r#" r#"
SELECT SELECT
@@ -216,12 +213,12 @@ pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sql
/// Get all projects with tags (admin view) /// Get all projects with tags (admin view)
pub async fn get_all_projects_with_tags_admin( pub async fn get_all_projects_with_tags_admin(
pool: &PgPool, pool: &PgPool,
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> { ) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let projects = get_all_projects_admin(pool).await?; let projects = get_all_projects_admin(pool).await?;
let mut result = Vec::new(); let mut result = Vec::new();
for project in projects { for project in projects {
let tags = super::tags::get_tags_for_project(pool, project.id).await?; let tags = get_tags_for_project(pool, project.id).await?;
result.push((project, tags)); result.push((project, tags));
} }
@@ -230,7 +227,7 @@ pub async fn get_all_projects_with_tags_admin(
/// Get single project by ID /// Get single project by ID
pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProject>, sqlx::Error> { pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProject>, sqlx::Error> {
sqlx::query_as!( query_as!(
DbProject, DbProject,
r#" r#"
SELECT SELECT
@@ -259,12 +256,12 @@ pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProje
pub async fn get_project_by_id_with_tags( pub async fn get_project_by_id_with_tags(
pool: &PgPool, pool: &PgPool,
id: Uuid, id: Uuid,
) -> Result<Option<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> { ) -> Result<Option<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let project = get_project_by_id(pool, id).await?; let project = get_project_by_id(pool, id).await?;
match project { match project {
Some(p) => { Some(p) => {
let tags = super::tags::get_tags_for_project(pool, p.id).await?; let tags = get_tags_for_project(pool, p.id).await?;
Ok(Some((p, tags))) Ok(Some((p, tags)))
} }
None => Ok(None), None => Ok(None),
@@ -276,7 +273,7 @@ pub async fn get_project_by_slug(
pool: &PgPool, pool: &PgPool,
slug: &str, slug: &str,
) -> Result<Option<DbProject>, sqlx::Error> { ) -> Result<Option<DbProject>, sqlx::Error> {
sqlx::query_as!( query_as!(
DbProject, DbProject,
r#" r#"
SELECT SELECT
@@ -316,7 +313,7 @@ pub async fn create_project(
.map(|s| slugify(s)) .map(|s| slugify(s))
.unwrap_or_else(|| slugify(name)); .unwrap_or_else(|| slugify(name));
sqlx::query_as!( query_as!(
DbProject, DbProject,
r#" r#"
INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url) INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url)
@@ -352,7 +349,7 @@ pub async fn update_project(
.map(|s| slugify(s)) .map(|s| slugify(s))
.unwrap_or_else(|| slugify(name)); .unwrap_or_else(|| slugify(name));
sqlx::query_as!( query_as!(
DbProject, DbProject,
r#" r#"
UPDATE projects UPDATE projects
@@ -377,7 +374,7 @@ pub async fn update_project(
/// Delete project (CASCADE will handle tags) /// Delete project (CASCADE will handle tags)
pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> { pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM projects WHERE id = $1", id) query!("DELETE FROM projects WHERE id = $1", id)
.execute(pool) .execute(pool)
.await?; .await?;
Ok(()) Ok(())
@@ -386,7 +383,7 @@ pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error>
/// Get admin stats /// Get admin stats
pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> { pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
// Get project counts by status // Get project counts by status
let status_counts = sqlx::query!( let status_counts = query!(
r#" r#"
SELECT SELECT
status as "status!: ProjectStatus", status as "status!: ProjectStatus",
@@ -398,7 +395,7 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
.fetch_all(pool) .fetch_all(pool)
.await?; .await?;
let mut projects_by_status = serde_json::json!({ let mut projects_by_status = json!({
"active": 0, "active": 0,
"maintained": 0, "maintained": 0,
"archived": 0, "archived": 0,
@@ -408,12 +405,12 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
let mut total_projects = 0; let mut total_projects = 0;
for row in status_counts { for row in status_counts {
let status_str = format!("{:?}", row.status).to_lowercase(); let status_str = format!("{:?}", row.status).to_lowercase();
projects_by_status[status_str] = serde_json::json!(row.count); projects_by_status[status_str] = json!(row.count);
total_projects += row.count; total_projects += row.count;
} }
// Get total tags // Get total tags
let tag_count = sqlx::query!("SELECT COUNT(*)::int as \"count!\" FROM tags") let tag_count = query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
.fetch_one(pool) .fetch_one(pool)
.await?; .await?;
+4 -11
View File
@@ -60,17 +60,10 @@ pub async fn proxy_icons_handler(
let full_path = format!("/api/icons/{}", path); let full_path = format!("/api/icons/{}", path);
let query = req.uri().query().unwrap_or(""); let query = req.uri().query().unwrap_or("");
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./") let path_with_query = if query.is_empty() {
{ full_path.clone()
if query.is_empty() {
format!("http://localhost{}", full_path)
} else {
format!("http://localhost{}?{}", full_path, query)
}
} else if query.is_empty() {
format!("{}{}", state.downstream_url, full_path)
} else { } else {
format!("{}{}?{}", state.downstream_url, full_path, query) format!("{full_path}?{query}")
}; };
// Build trusted headers with session info // Build trusted headers with session info
@@ -86,7 +79,7 @@ pub async fn proxy_icons_handler(
} }
} }
match proxy::proxy_to_bun(&bun_url, state, forward_headers).await { match proxy::proxy_to_bun(&path_with_query, state, forward_headers).await {
Ok((status, headers, body)) => (status, headers, body).into_response(), Ok((status, headers, body)) => (status, headers, body).into_response(),
Err(err) => { Err(err) => {
tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request"); tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request");
+146
View File
@@ -0,0 +1,146 @@
use reqwest::Method;
use std::path::PathBuf;
use std::time::Duration;
use thiserror::Error;
/// Errors that can occur while constructing an [`HttpClient`].
#[derive(Debug, Error)]
pub enum ClientError {
    /// The underlying reqwest client builder failed to produce a client.
    #[error("Failed to build reqwest client: {0}")]
    BuildError(#[from] reqwest::Error),
    /// The downstream string was neither an http(s) URL nor a socket path.
    #[error("Invalid downstream URL: {0}")]
    InvalidUrl(String),
}
/// HTTP client that transparently targets either a TCP base URL or a
/// Unix domain socket, exposing a path-based request API to callers.
#[derive(Clone)]
pub struct HttpClient {
    // Underlying reqwest client, configured per target kind in `new`.
    client: reqwest::Client,
    // Parsed destination; drives URL construction in `build_url`.
    target: TargetUrl,
}
/// Downstream destination parsed from the configured downstream string.
#[derive(Debug, Clone)]
enum TargetUrl {
    Tcp(String),   // Base URL like "http://localhost:5173"
    Unix(PathBuf), // Socket path like "/tmp/bun.sock"
}
impl HttpClient {
    /// Create a new HttpClient from a downstream URL.
    ///
    /// Accepts:
    /// - TCP: "http://localhost:5173", "https://example.com"
    /// - Unix: "/tmp/bun.sock", "./relative.sock"
    ///
    /// # Errors
    /// Returns [`ClientError::InvalidUrl`] when the string is neither an
    /// http(s) URL nor an absolute/relative socket path, and
    /// [`ClientError::BuildError`] if the reqwest builder fails.
    pub fn new(downstream: &str) -> Result<Self, ClientError> {
        let target = if downstream.starts_with('/') || downstream.starts_with("./") {
            TargetUrl::Unix(PathBuf::from(downstream))
        } else if downstream.starts_with("http://") || downstream.starts_with("https://") {
            // Normalize away trailing slashes so `build_url` never produces
            // "base//path" when the operator configures "http://host:port/".
            TargetUrl::Tcp(downstream.trim_end_matches('/').to_string())
        } else {
            return Err(ClientError::InvalidUrl(downstream.to_string()));
        };

        tracing::debug!(
            target = ?target,
            downstream = %downstream,
            "Creating HTTP client"
        );

        // Shared settings for both transports; transport-specific options
        // are layered on below.
        let builder = reqwest::Client::builder()
            .pool_max_idle_per_host(8)
            .pool_idle_timeout(Duration::from_secs(600))
            .timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(3))
            // Pass redirects through to the caller instead of following them.
            .redirect(reqwest::redirect::Policy::none());

        let client = match &target {
            TargetUrl::Unix(path) => builder.unix_socket(path.clone()).build()?,
            // TCP keepalive only applies to real network connections.
            TargetUrl::Tcp(_) => builder
                .tcp_keepalive(Some(Duration::from_secs(60)))
                .build()?,
        };

        Ok(Self { client, target })
    }

    /// Build a full URL from a path.
    ///
    /// Examples:
    /// - TCP target "http://localhost:5173" + "/api/health" → "http://localhost:5173/api/health"
    /// - Unix target "/tmp/bun.sock" + "/api/health" → "http://localhost/api/health"
    fn build_url(&self, path: &str) -> String {
        match &self.target {
            // `base` was normalized in `new` to have no trailing slash.
            TargetUrl::Tcp(base) => format!("{}{}", base, path),
            // reqwest still needs a syntactically valid http URL even when
            // the bytes flow over a Unix socket; the host is ignored.
            TargetUrl::Unix(_) => format!("http://localhost{}", path),
        }
    }

    /// Start a GET request for `path` against the configured target.
    pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
        self.client.get(self.build_url(path))
    }

    /// Start a POST request for `path` against the configured target.
    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
        self.client.post(self.build_url(path))
    }

    /// Start a request with an arbitrary `method` for `path`.
    pub fn request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
        self.client.request(method, self.build_url(path))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // TCP targets prepend the configured base URL to the request path,
    // preserving any query string verbatim.
    #[test]
    fn test_tcp_url_construction() {
        let client = HttpClient::new("http://localhost:5173").unwrap();
        assert_eq!(
            client.build_url("/api/health"),
            "http://localhost:5173/api/health"
        );
        assert_eq!(
            client.build_url("/path?query=1"),
            "http://localhost:5173/path?query=1"
        );
    }

    // Unix targets use a fixed placeholder authority ("localhost"); the
    // actual bytes travel over the socket, so the host is never resolved.
    #[test]
    fn test_unix_url_construction() {
        let client = HttpClient::new("/tmp/bun.sock").unwrap();
        assert_eq!(
            client.build_url("/api/health"),
            "http://localhost/api/health"
        );
        assert_eq!(
            client.build_url("/path?query=1"),
            "http://localhost/path?query=1"
        );
    }

    // "./"-prefixed paths are accepted as relative socket paths.
    #[test]
    fn test_relative_unix_socket() {
        let client = HttpClient::new("./relative.sock").unwrap();
        assert!(matches!(client.target, TargetUrl::Unix(_)));
    }

    // Strings that are neither http(s) URLs nor socket paths are rejected.
    #[test]
    fn test_invalid_url() {
        let result = HttpClient::new("not-a-valid-url");
        assert!(result.is_err());
    }

    // https scheme is treated as a TCP target just like http.
    #[test]
    fn test_https_url() {
        let client = HttpClient::new("https://example.com").unwrap();
        assert!(matches!(client.target, TargetUrl::Tcp(_)));
        assert_eq!(
            client.build_url("/api/test"),
            "https://example.com/api/test"
        );
    }
}
+7 -41
View File
@@ -1,6 +1,5 @@
use clap::Parser; use clap::Parser;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer}; use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer};
@@ -13,6 +12,7 @@ mod db;
mod formatter; mod formatter;
mod handlers; mod handlers;
mod health; mod health;
mod http;
mod middleware; mod middleware;
mod og; mod og;
mod proxy; mod proxy;
@@ -125,50 +125,18 @@ async fn main() {
std::process::exit(1); std::process::exit(1);
} }
// Create HTTP client for TCP connections with optimized pool settings // Create socket-aware HTTP client
let http_client = reqwest::Client::builder() let client = http::HttpClient::new(&args.downstream).expect("Failed to create HTTP client");
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
.tcp_keepalive(Some(Duration::from_secs(60)))
.timeout(Duration::from_secs(5)) // Default timeout for SSR
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
.build()
.expect("Failed to create HTTP client");
// Create Unix socket client if downstream is a Unix socket
let unix_client = if args.downstream.starts_with('/') || args.downstream.starts_with("./") {
let path = PathBuf::from(&args.downstream);
Some(
reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
.timeout(Duration::from_secs(5)) // Default timeout for SSR
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
.unix_socket(path)
.build()
.expect("Failed to create Unix socket client"),
)
} else {
None
};
// Create health checker // Create health checker
let downstream_url_for_health = args.downstream.clone(); let client_for_health = client.clone();
let http_client_for_health = http_client.clone();
let unix_client_for_health = unix_client.clone();
let pool_for_health = pool.clone(); let pool_for_health = pool.clone();
let health_checker = Arc::new(HealthChecker::new(move || { let health_checker = Arc::new(HealthChecker::new(move || {
let downstream_url = downstream_url_for_health.clone(); let client = client_for_health.clone();
let http_client = http_client_for_health.clone();
let unix_client = unix_client_for_health.clone();
let pool = pool_for_health.clone(); let pool = pool_for_health.clone();
async move { async move { proxy::perform_health_check(client, Some(pool)).await }
proxy::perform_health_check(downstream_url, http_client, unix_client, Some(pool)).await
}
})); }));
let tarpit_config = TarpitConfig::from_env(); let tarpit_config = TarpitConfig::from_env();
@@ -186,9 +154,7 @@ async fn main() {
); );
let state = Arc::new(AppState { let state = Arc::new(AppState {
downstream_url: args.downstream.clone(), client,
http_client,
unix_client,
health_checker, health_checker,
tarpit_state, tarpit_state,
pool: pool.clone(), pool: pool.clone(),
+3 -11
View File
@@ -36,17 +36,9 @@ pub async fn generate_og_image(spec: &OGImageSpec, state: Arc<AppState>) -> Resu
tracing::Span::current().record("r2_key", &r2_key); tracing::Span::current().record("r2_key", &r2_key);
// Call Bun's internal endpoint // Call Bun's internal endpoint
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./") let response = state
{ .client
"http://localhost/internal/ogp/generate".to_string() .post("/internal/ogp/generate")
} else {
format!("{}/internal/ogp/generate", state.downstream_url)
};
let client = state.unix_client.as_ref().unwrap_or(&state.http_client);
let response = client
.post(&bun_url)
.json(spec) .json(spec)
.timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(30))
.send() .send()
+50 -74
View File
@@ -72,17 +72,10 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
return response; return response;
} }
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./") let path_with_query = if query.is_empty() {
{ path.to_string()
if query.is_empty() {
format!("http://localhost{path}")
} else {
format!("http://localhost{path}?{query}")
}
} else if query.is_empty() {
format!("{}{}", state.downstream_url, path)
} else { } else {
format!("{}{}?{}", state.downstream_url, path, query) format!("{path}?{query}")
}; };
// Build trusted headers to forward to downstream // Build trusted headers to forward to downstream
@@ -120,7 +113,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
let start = std::time::Instant::now(); let start = std::time::Instant::now();
match proxy_to_bun(&bun_url, state.clone(), forward_headers).await { match proxy_to_bun(&path_with_query, state.clone(), forward_headers).await {
Ok((status, headers, body)) => { Ok((status, headers, body)) => {
let duration_ms = start.elapsed().as_millis() as u64; let duration_ms = start.elapsed().as_millis() as u64;
let cache = "miss"; let cache = "miss";
@@ -182,7 +175,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
let duration_ms = start.elapsed().as_millis() as u64; let duration_ms = start.elapsed().as_millis() as u64;
tracing::error!( tracing::error!(
error = %err, error = %err,
url = %bun_url, path = %path_with_query,
duration_ms, duration_ms,
"Failed to proxy to Bun" "Failed to proxy to Bun"
); );
@@ -203,18 +196,12 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
/// Proxy a request to Bun SSR /// Proxy a request to Bun SSR
pub async fn proxy_to_bun( pub async fn proxy_to_bun(
url: &str, path: &str,
state: Arc<AppState>, state: Arc<AppState>,
forward_headers: HeaderMap, forward_headers: HeaderMap,
) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> { ) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> {
let client = if state.unix_client.is_some() {
state.unix_client.as_ref().unwrap()
} else {
&state.http_client
};
// Build request with forwarded headers // Build request with forwarded headers
let mut request_builder = client.get(url); let mut request_builder = state.client.get(path);
for (name, value) in forward_headers.iter() { for (name, value) in forward_headers.iter() {
request_builder = request_builder.header(name, value); request_builder = request_builder.header(name, value);
} }
@@ -247,45 +234,35 @@ pub async fn proxy_to_bun(
/// Perform health check on Bun SSR and database /// Perform health check on Bun SSR and database
pub async fn perform_health_check( pub async fn perform_health_check(
downstream_url: String, client: crate::http::HttpClient,
http_client: reqwest::Client,
unix_client: Option<reqwest::Client>,
pool: Option<sqlx::PgPool>, pool: Option<sqlx::PgPool>,
) -> bool { ) -> bool {
let url = if downstream_url.starts_with('/') || downstream_url.starts_with("./") { let bun_healthy = match tokio::time::timeout(
"http://localhost/internal/health".to_string() Duration::from_secs(5),
} else { client.get("/internal/health").send(),
format!("{downstream_url}/internal/health") )
.await
{
Ok(Ok(response)) => {
let is_success = response.status().is_success();
if !is_success {
tracing::warn!(
status = response.status().as_u16(),
"Health check failed: Bun returned non-success status"
);
}
is_success
}
Ok(Err(err)) => {
tracing::error!(error = %err, "Health check failed: cannot reach Bun");
false
}
Err(_) => {
tracing::error!("Health check failed: timeout after 5s");
false
}
}; };
let client = if unix_client.is_some() {
unix_client.as_ref().unwrap()
} else {
&http_client
};
let bun_healthy =
match tokio::time::timeout(Duration::from_secs(5), client.get(&url).send()).await {
Ok(Ok(response)) => {
let is_success = response.status().is_success();
if !is_success {
tracing::warn!(
status = response.status().as_u16(),
"Health check failed: Bun returned non-success status"
);
}
is_success
}
Ok(Err(err)) => {
tracing::error!(error = %err, "Health check failed: cannot reach Bun");
false
}
Err(_) => {
tracing::error!("Health check failed: timeout after 5s");
false
}
};
// Check database // Check database
let db_healthy = if let Some(pool) = pool { let db_healthy = if let Some(pool) = pool {
match db::health_check(&pool).await { match db::health_check(&pool).await {
@@ -307,33 +284,32 @@ fn should_tarpit(state: &TarpitState, path: &str) -> bool {
state.config.enabled && tarpit::is_malicious_path(path) state.config.enabled && tarpit::is_malicious_path(path)
} }
/// Common handler logic for requests with optional peer info
async fn handle_request_with_optional_peer(
state: Arc<AppState>,
peer: Option<SocketAddr>,
req: Request,
) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
let peer_info = peer.map(ConnectInfo);
tarpit::tarpit_handler(State(state.tarpit_state.clone()), peer_info, req).await
} else {
isr_handler(State(state), req).await
}
}
/// Fallback handler for TCP connections (has access to peer IP) /// Fallback handler for TCP connections (has access to peer IP)
pub async fn fallback_handler_tcp( pub async fn fallback_handler_tcp(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
ConnectInfo(peer): ConnectInfo<SocketAddr>, ConnectInfo(peer): ConnectInfo<SocketAddr>,
req: Request, req: Request,
) -> Response { ) -> Response {
let path = req.uri().path(); handle_request_with_optional_peer(state, Some(peer), req).await
if should_tarpit(&state.tarpit_state, path) {
tarpit::tarpit_handler(
State(state.tarpit_state.clone()),
Some(ConnectInfo(peer)),
req,
)
.await
} else {
isr_handler(State(state), req).await
}
} }
/// Fallback handler for Unix sockets (no peer IP available) /// Fallback handler for Unix sockets (no peer IP available)
pub async fn fallback_handler_unix(State(state): State<Arc<AppState>>, req: Request) -> Response { pub async fn fallback_handler_unix(State(state): State<Arc<AppState>>, req: Request) -> Response {
let path = req.uri().path(); handle_request_with_optional_peer(state, None, req).await
if should_tarpit(&state.tarpit_state, path) {
tarpit::tarpit_handler(State(state.tarpit_state.clone()), None, req).await
} else {
isr_handler(State(state), req).await
}
} }
+31 -50
View File
@@ -1,4 +1,11 @@
use axum::{Router, extract::Request, http::Uri, response::IntoResponse, routing::any}; use axum::{
Router,
body::Body,
extract::Request,
http::{Method, Uri},
response::IntoResponse,
routing::{any, get, post},
};
use std::sync::Arc; use std::sync::Arc;
use crate::{assets, handlers, state::AppState}; use crate::{assets, handlers, state::AppState};
@@ -9,32 +16,28 @@ pub fn api_routes() -> Router<Arc<AppState>> {
.route("/", any(api_root_404_handler)) .route("/", any(api_root_404_handler))
.route( .route(
"/health", "/health",
axum::routing::get(handlers::health_handler).head(handlers::health_handler), get(handlers::health_handler).head(handlers::health_handler),
) )
// Authentication endpoints (public) // Authentication endpoints (public)
.route("/login", axum::routing::post(handlers::api_login_handler)) .route("/login", post(handlers::api_login_handler))
.route("/logout", axum::routing::post(handlers::api_logout_handler)) .route("/logout", post(handlers::api_logout_handler))
.route( .route("/session", get(handlers::api_session_handler))
"/session",
axum::routing::get(handlers::api_session_handler),
)
// Projects - GET is public (shows all for admin, only non-hidden for public) // Projects - GET is public (shows all for admin, only non-hidden for public)
// POST/PUT/DELETE require authentication // POST/PUT/DELETE require authentication
.route( .route(
"/projects", "/projects",
axum::routing::get(handlers::projects_handler).post(handlers::create_project_handler), get(handlers::projects_handler).post(handlers::create_project_handler),
) )
.route( .route(
"/projects/{id}", "/projects/{id}",
axum::routing::get(handlers::get_project_handler) get(handlers::get_project_handler)
.put(handlers::update_project_handler) .put(handlers::update_project_handler)
.delete(handlers::delete_project_handler), .delete(handlers::delete_project_handler),
) )
// Project tags - authentication checked in handlers // Project tags - authentication checked in handlers
.route( .route(
"/projects/{id}/tags", "/projects/{id}/tags",
axum::routing::get(handlers::get_project_tags_handler) get(handlers::get_project_tags_handler).post(handlers::add_project_tag_handler),
.post(handlers::add_project_tag_handler),
) )
.route( .route(
"/projects/{id}/tags/{tag_id}", "/projects/{id}/tags/{tag_id}",
@@ -43,36 +46,29 @@ pub fn api_routes() -> Router<Arc<AppState>> {
// Tags - authentication checked in handlers // Tags - authentication checked in handlers
.route( .route(
"/tags", "/tags",
axum::routing::get(handlers::list_tags_handler).post(handlers::create_tag_handler), get(handlers::list_tags_handler).post(handlers::create_tag_handler),
) )
.route( .route(
"/tags/{slug}", "/tags/{slug}",
axum::routing::get(handlers::get_tag_handler).put(handlers::update_tag_handler), get(handlers::get_tag_handler).put(handlers::update_tag_handler),
) )
.route( .route(
"/tags/{slug}/related", "/tags/{slug}/related",
axum::routing::get(handlers::get_related_tags_handler), get(handlers::get_related_tags_handler),
) )
.route( .route(
"/tags/recalculate-cooccurrence", "/tags/recalculate-cooccurrence",
axum::routing::post(handlers::recalculate_cooccurrence_handler), post(handlers::recalculate_cooccurrence_handler),
) )
// Admin stats - requires authentication // Admin stats - requires authentication
.route( .route("/stats", get(handlers::get_admin_stats_handler))
"/stats",
axum::routing::get(handlers::get_admin_stats_handler),
)
// Site settings - GET is public, PUT requires authentication // Site settings - GET is public, PUT requires authentication
.route( .route(
"/settings", "/settings",
axum::routing::get(handlers::get_settings_handler) get(handlers::get_settings_handler).put(handlers::update_settings_handler),
.put(handlers::update_settings_handler),
) )
// Icon API - proxy to SvelteKit (authentication handled by SvelteKit) // Icon API - proxy to SvelteKit (authentication handled by SvelteKit)
.route( .route("/icons/{*path}", get(handlers::proxy_icons_handler))
"/icons/{*path}",
axum::routing::get(handlers::proxy_icons_handler),
)
.fallback(api_404_and_method_handler) .fallback(api_404_and_method_handler)
} }
@@ -83,19 +79,13 @@ pub fn build_base_router() -> Router<Arc<AppState>> {
.route("/api/", any(api_root_404_handler)) .route("/api/", any(api_root_404_handler))
.route( .route(
"/_app/{*path}", "/_app/{*path}",
axum::routing::get(assets::serve_embedded_asset).head(assets::serve_embedded_asset), get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
) )
.route("/pgp", axum::routing::get(handlers::handle_pgp_route)) .route("/pgp", get(handlers::handle_pgp_route))
.route( .route("/publickey.asc", get(handlers::serve_pgp_key))
"/publickey.asc", .route("/pgp.asc", get(handlers::serve_pgp_key))
axum::routing::get(handlers::serve_pgp_key), .route("/.well-known/pgpkey.asc", get(handlers::serve_pgp_key))
) .route("/keys", get(handlers::redirect_to_pgp))
.route("/pgp.asc", axum::routing::get(handlers::serve_pgp_key))
.route(
"/.well-known/pgpkey.asc",
axum::routing::get(handlers::serve_pgp_key),
)
.route("/keys", axum::routing::get(handlers::redirect_to_pgp))
} }
async fn api_root_404_handler(uri: Uri) -> impl IntoResponse { async fn api_root_404_handler(uri: Uri) -> impl IntoResponse {
@@ -109,10 +99,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
let uri = req.uri(); let uri = req.uri();
let path = uri.path(); let path = uri.path();
if method != axum::http::Method::GET if method != Method::GET && method != Method::HEAD && method != Method::OPTIONS {
&& method != axum::http::Method::HEAD
&& method != axum::http::Method::OPTIONS
{
let content_type = req let content_type = req
.headers() .headers()
.get(axum::http::header::CONTENT_TYPE) .get(axum::http::header::CONTENT_TYPE)
@@ -129,10 +116,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
) )
.into_response(); .into_response();
} }
} else if method == axum::http::Method::POST } else if method == Method::POST || method == Method::PUT || method == Method::PATCH {
|| method == axum::http::Method::PUT
|| method == axum::http::Method::PATCH
{
// POST/PUT/PATCH require Content-Type header // POST/PUT/PATCH require Content-Type header
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
@@ -158,10 +142,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
} }
async fn api_404_handler(uri: Uri) -> impl IntoResponse { async fn api_404_handler(uri: Uri) -> impl IntoResponse {
let req = Request::builder() let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
.uri(uri)
.body(axum::body::Body::empty())
.unwrap();
api_404_and_method_handler(req).await api_404_and_method_handler(req).await
} }
+2 -4
View File
@@ -1,13 +1,11 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{auth::SessionManager, health::HealthChecker, tarpit::TarpitState}; use crate::{auth::SessionManager, health::HealthChecker, http::HttpClient, tarpit::TarpitState};
/// Application state shared across all handlers /// Application state shared across all handlers
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pub downstream_url: String, pub client: HttpClient,
pub http_client: reqwest::Client,
pub unix_client: Option<reqwest::Client>,
pub health_checker: Arc<HealthChecker>, pub health_checker: Arc<HealthChecker>,
pub tarpit_state: Arc<TarpitState>, pub tarpit_state: Arc<TarpitState>,
pub pool: sqlx::PgPool, pub pool: sqlx::PgPool,
+12 -12
View File
@@ -1,5 +1,5 @@
use axum::{ use axum::{
http::{HeaderMap, StatusCode}, http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
@@ -34,7 +34,7 @@ pub fn is_page_route(path: &str) -> bool {
/// Check if the request accepts HTML responses /// Check if the request accepts HTML responses
pub fn accepts_html(headers: &HeaderMap) -> bool { pub fn accepts_html(headers: &HeaderMap) -> bool {
if let Some(accept) = headers.get(axum::http::header::ACCEPT) { if let Some(accept) = headers.get(header::ACCEPT) {
if let Ok(accept_str) = accept.to_str() { if let Ok(accept_str) = accept.to_str() {
return accept_str.contains("text/html") || accept_str.contains("*/*"); return accept_str.contains("text/html") || accept_str.contains("*/*");
} }
@@ -46,7 +46,7 @@ pub fn accepts_html(headers: &HeaderMap) -> bool {
/// Determines if request prefers raw content (CLI tools) over HTML /// Determines if request prefers raw content (CLI tools) over HTML
pub fn prefers_raw_content(headers: &HeaderMap) -> bool { pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
// Check User-Agent for known CLI tools first (most reliable) // Check User-Agent for known CLI tools first (most reliable)
if let Some(ua) = headers.get(axum::http::header::USER_AGENT) { if let Some(ua) = headers.get(header::USER_AGENT) {
if let Ok(ua_str) = ua.to_str() { if let Ok(ua_str) = ua.to_str() {
let ua_lower = ua_str.to_lowercase(); let ua_lower = ua_str.to_lowercase();
if ua_lower.starts_with("curl/") if ua_lower.starts_with("curl/")
@@ -60,7 +60,7 @@ pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
} }
// Check Accept header - if it explicitly prefers text/html, serve HTML // Check Accept header - if it explicitly prefers text/html, serve HTML
if let Some(accept) = headers.get(axum::http::header::ACCEPT) { if let Some(accept) = headers.get(header::ACCEPT) {
if let Ok(accept_str) = accept.to_str() { if let Ok(accept_str) = accept.to_str() {
// If text/html appears before */* in the list, they prefer HTML // If text/html appears before */* in the list, they prefer HTML
if let Some(html_pos) = accept_str.find("text/html") { if let Some(html_pos) = accept_str.find("text/html") {
@@ -88,12 +88,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
if let Some(html) = assets::get_error_page(status_code) { if let Some(html) = assets::get_error_page(status_code) {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(
axum::http::header::CONTENT_TYPE, header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/html; charset=utf-8"), HeaderValue::from_static("text/html; charset=utf-8"),
); );
headers.insert( headers.insert(
axum::http::header::CACHE_CONTROL, header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"), HeaderValue::from_static("no-cache, no-store, must-revalidate"),
); );
(status, headers, html).into_response() (status, headers, html).into_response()
@@ -107,12 +107,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
if let Some(fallback_html) = assets::get_error_page(500) { if let Some(fallback_html) = assets::get_error_page(500) {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(
axum::http::header::CONTENT_TYPE, header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/html; charset=utf-8"), HeaderValue::from_static("text/html; charset=utf-8"),
); );
headers.insert( headers.insert(
axum::http::header::CACHE_CONTROL, header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"), HeaderValue::from_static("no-cache, no-store, must-revalidate"),
); );
(status, headers, fallback_html).into_response() (status, headers, fallback_html).into_response()
+68 -50
View File
@@ -3,68 +3,86 @@ import { env } from "$env/dynamic/private";
const logger = getLogger(["ssr", "lib", "api"]); const logger = getLogger(["ssr", "lib", "api"]);
const upstreamUrl = env.UPSTREAM_URL; interface FetchOptions extends RequestInit {
const isUnixSocket = fetch?: typeof fetch;
upstreamUrl?.startsWith("/") || upstreamUrl?.startsWith("./"); }
const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl;
export async function apiFetch<T>( interface BunFetchOptions extends RequestInit {
path: string, unix?: string;
init?: RequestInit & { fetch?: typeof fetch }, }
): Promise<T> {
/**
* Create a socket-aware fetch function
* Automatically handles Unix socket vs TCP based on UPSTREAM_URL
*/
function createSmartFetch(upstreamUrl: string | undefined) {
if (!upstreamUrl) { if (!upstreamUrl) {
logger.error("UPSTREAM_URL environment variable not set"); const error = "UPSTREAM_URL environment variable not set";
throw new Error("UPSTREAM_URL environment variable not set"); logger.error(error);
throw new Error(error);
} }
const url = `${baseUrl}${path}`; const isUnixSocket =
const method = init?.method ?? "GET"; upstreamUrl.startsWith("/") || upstreamUrl.startsWith("./");
const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl;
// Unix sockets require Bun's native fetch (SvelteKit's fetch doesn't support it) return async function smartFetch<T>(
const fetchFn = isUnixSocket ? fetch : (init?.fetch ?? fetch); path: string,
options?: FetchOptions,
): Promise<T> {
const url = `${baseUrl}${path}`;
const method = options?.method ?? "GET";
const fetchOptions: RequestInit & { unix?: string } = { // Unix sockets require Bun's native fetch
...init, // SvelteKit's fetch doesn't support the 'unix' option
signal: init?.signal ?? AbortSignal.timeout(30_000), const fetchFn = isUnixSocket ? fetch : (options?.fetch ?? fetch);
};
// Remove custom fetch property from options const fetchOptions: BunFetchOptions = {
delete (fetchOptions as Record<string, unknown>).fetch; ...options,
signal: options?.signal ?? AbortSignal.timeout(30_000),
};
if (isUnixSocket) { // Remove custom fetch property from options (not part of standard RequestInit)
fetchOptions.unix = upstreamUrl; delete (fetchOptions as Record<string, unknown>).fetch;
}
logger.debug("API request", { // Add Unix socket path if needed
method, if (isUnixSocket) {
url, fetchOptions.unix = upstreamUrl;
path,
isUnixSocket,
upstreamUrl,
});
try {
const response = await fetchFn(url, fetchOptions);
if (!response.ok) {
logger.error("API request failed", {
method,
url,
status: response.status,
statusText: response.statusText,
});
throw new Error(`API error: ${response.status} ${response.statusText}`);
} }
const data = await response.json(); logger.debug("API request", {
logger.debug("API response", { method, url, status: response.status });
return data;
} catch (error) {
logger.error("API request exception", {
method, method,
url, url,
error: error instanceof Error ? error.message : String(error), path,
isUnixSocket,
}); });
throw error;
} try {
const response = await fetchFn(url, fetchOptions);
if (!response.ok) {
logger.error("API request failed", {
method,
url,
status: response.status,
statusText: response.statusText,
});
throw new Error(`API error: ${response.status} ${response.statusText}`);
}
const data = await response.json();
logger.debug("API response", { method, url, status: response.status });
return data;
} catch (error) {
logger.error("API request exception", {
method,
url,
error: error instanceof Error ? error.message : String(error),
});
throw error;
}
};
} }
// Export the configured smart fetch function
export const apiFetch = createSmartFetch(env.UPSTREAM_URL);