refactor: consolidate HTTP client for TCP/Unix socket handling

Extract reqwest client creation into dedicated HttpClient abstraction that handles both TCP and Unix socket connections transparently. Simplifies proxy logic by removing duplicate URL construction and client selection throughout the codebase.
This commit is contained in:
2026-01-07 14:34:32 -06:00
parent dcc496c979
commit dd1ce186d2
13 changed files with 398 additions and 329 deletions
Generated
+1
View File
@@ -88,6 +88,7 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
"thiserror 2.0.17",
"time",
"tokio",
"tokio-util",
+1
View File
@@ -22,6 +22,7 @@ reqwest = { version = "0.13.1", default-features = false, features = ["rustls",
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.148"
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "time", "migrate"] }
thiserror = "2.0.17"
time = { version = "0.3.44", features = ["formatting", "macros", "serde"] }
tokio = { version = "1.49.0", features = ["full"] }
tokio-util = { version = "0.7.18", features = ["io"] }
+2 -2
View File
@@ -7,7 +7,7 @@ pub use projects::*;
pub use settings::*;
pub use tags::*;
use sqlx::{PgPool, postgres::PgPoolOptions};
use sqlx::{PgPool, postgres::PgPoolOptions, query};
/// Database connection pool creation
pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
@@ -20,7 +20,7 @@ pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
/// Health check query
pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> {
sqlx::query!("SELECT 1 as check")
query!("SELECT 1 as check")
.fetch_one(pool)
.await
.map(|_| ())
+31 -34
View File
@@ -1,9 +1,13 @@
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use time::OffsetDateTime;
use serde_json::json;
use sqlx::{PgPool, query, query_as};
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
use uuid::Uuid;
use super::{ProjectStatus, slugify};
use super::{
ProjectStatus, slugify,
tags::{ApiTag, DbTag, get_tags_for_project},
};
// Database model
#[derive(Debug, Clone, sqlx::FromRow)]
@@ -43,7 +47,7 @@ pub struct ApiProject {
pub struct ApiAdminProject {
#[serde(flatten)]
pub project: ApiProject,
pub tags: Vec<super::tags::ApiTag>,
pub tags: Vec<ApiTag>,
pub status: String,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -86,7 +90,7 @@ impl DbProject {
}
}
pub fn to_api_admin_project(&self, tags: Vec<super::tags::DbTag>) -> ApiAdminProject {
pub fn to_api_admin_project(&self, tags: Vec<DbTag>) -> ApiAdminProject {
ApiAdminProject {
project: self.to_api_project(),
tags: tags.into_iter().map(|t| t.to_api_tag()).collect(),
@@ -94,18 +98,11 @@ impl DbProject {
description: self.description.clone(),
github_repo: self.github_repo.clone(),
demo_url: self.demo_url.clone(),
created_at: self
.created_at
.format(&time::format_description::well_known::Rfc3339)
.unwrap(),
updated_at: self
.updated_at
.format(&time::format_description::well_known::Rfc3339)
.unwrap(),
last_github_activity: self.last_github_activity.map(|dt| {
dt.format(&time::format_description::well_known::Rfc3339)
.unwrap()
}),
created_at: self.created_at.format(&Rfc3339).unwrap(),
updated_at: self.updated_at.format(&Rfc3339).unwrap(),
last_github_activity: self
.last_github_activity
.map(|dt| dt.format(&Rfc3339).unwrap()),
}
}
}
@@ -150,7 +147,7 @@ pub struct AdminStats {
// Query functions
pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
@@ -176,12 +173,12 @@ pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::
pub async fn get_public_projects_with_tags(
pool: &PgPool,
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let projects = get_public_projects(pool).await?;
let mut result = Vec::new();
for project in projects {
let tags = super::tags::get_tags_for_project(pool, project.id).await?;
let tags = get_tags_for_project(pool, project.id).await?;
result.push((project, tags));
}
@@ -190,7 +187,7 @@ pub async fn get_public_projects_with_tags(
/// Get all projects (admin view - includes hidden)
pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
@@ -216,12 +213,12 @@ pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sql
/// Get all projects with tags (admin view)
pub async fn get_all_projects_with_tags_admin(
pool: &PgPool,
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let projects = get_all_projects_admin(pool).await?;
let mut result = Vec::new();
for project in projects {
let tags = super::tags::get_tags_for_project(pool, project.id).await?;
let tags = get_tags_for_project(pool, project.id).await?;
result.push((project, tags));
}
@@ -230,7 +227,7 @@ pub async fn get_all_projects_with_tags_admin(
/// Get single project by ID
pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
@@ -259,12 +256,12 @@ pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProje
pub async fn get_project_by_id_with_tags(
pool: &PgPool,
id: Uuid,
) -> Result<Option<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
) -> Result<Option<(DbProject, Vec<DbTag>)>, sqlx::Error> {
let project = get_project_by_id(pool, id).await?;
match project {
Some(p) => {
let tags = super::tags::get_tags_for_project(pool, p.id).await?;
let tags = get_tags_for_project(pool, p.id).await?;
Ok(Some((p, tags)))
}
None => Ok(None),
@@ -276,7 +273,7 @@ pub async fn get_project_by_slug(
pool: &PgPool,
slug: &str,
) -> Result<Option<DbProject>, sqlx::Error> {
sqlx::query_as!(
query_as!(
DbProject,
r#"
SELECT
@@ -316,7 +313,7 @@ pub async fn create_project(
.map(|s| slugify(s))
.unwrap_or_else(|| slugify(name));
sqlx::query_as!(
query_as!(
DbProject,
r#"
INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url)
@@ -352,7 +349,7 @@ pub async fn update_project(
.map(|s| slugify(s))
.unwrap_or_else(|| slugify(name));
sqlx::query_as!(
query_as!(
DbProject,
r#"
UPDATE projects
@@ -377,7 +374,7 @@ pub async fn update_project(
/// Delete project (CASCADE will handle tags)
pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM projects WHERE id = $1", id)
query!("DELETE FROM projects WHERE id = $1", id)
.execute(pool)
.await?;
Ok(())
@@ -386,7 +383,7 @@ pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error>
/// Get admin stats
pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
// Get project counts by status
let status_counts = sqlx::query!(
let status_counts = query!(
r#"
SELECT
status as "status!: ProjectStatus",
@@ -398,7 +395,7 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
.fetch_all(pool)
.await?;
let mut projects_by_status = serde_json::json!({
let mut projects_by_status = json!({
"active": 0,
"maintained": 0,
"archived": 0,
@@ -408,12 +405,12 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
let mut total_projects = 0;
for row in status_counts {
let status_str = format!("{:?}", row.status).to_lowercase();
projects_by_status[status_str] = serde_json::json!(row.count);
projects_by_status[status_str] = json!(row.count);
total_projects += row.count;
}
// Get total tags
let tag_count = sqlx::query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
let tag_count = query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
.fetch_one(pool)
.await?;
+4 -11
View File
@@ -60,17 +60,10 @@ pub async fn proxy_icons_handler(
let full_path = format!("/api/icons/{}", path);
let query = req.uri().query().unwrap_or("");
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
{
if query.is_empty() {
format!("http://localhost{}", full_path)
} else {
format!("http://localhost{}?{}", full_path, query)
}
} else if query.is_empty() {
format!("{}{}", state.downstream_url, full_path)
let path_with_query = if query.is_empty() {
full_path.clone()
} else {
format!("{}{}?{}", state.downstream_url, full_path, query)
format!("{full_path}?{query}")
};
// Build trusted headers with session info
@@ -86,7 +79,7 @@ pub async fn proxy_icons_handler(
}
}
match proxy::proxy_to_bun(&bun_url, state, forward_headers).await {
match proxy::proxy_to_bun(&path_with_query, state, forward_headers).await {
Ok((status, headers, body)) => (status, headers, body).into_response(),
Err(err) => {
tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request");
+146
View File
@@ -0,0 +1,146 @@
use reqwest::Method;
use std::path::PathBuf;
use std::time::Duration;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ClientError {
#[error("Failed to build reqwest client: {0}")]
BuildError(#[from] reqwest::Error),
#[error("Invalid downstream URL: {0}")]
InvalidUrl(String),
}
#[derive(Clone)]
pub struct HttpClient {
client: reqwest::Client,
target: TargetUrl,
}
#[derive(Debug, Clone)]
enum TargetUrl {
Tcp(String), // Base URL like "http://localhost:5173"
Unix(PathBuf), // Socket path like "/tmp/bun.sock"
}
impl HttpClient {
/// Create a new HttpClient from a downstream URL
///
/// Accepts:
/// - TCP: "http://localhost:5173", "https://example.com"
/// - Unix: "/tmp/bun.sock", "./relative.sock"
pub fn new(downstream: &str) -> Result<Self, ClientError> {
let target = if downstream.starts_with('/') || downstream.starts_with("./") {
TargetUrl::Unix(PathBuf::from(downstream))
} else if downstream.starts_with("http://") || downstream.starts_with("https://") {
TargetUrl::Tcp(downstream.to_string())
} else {
return Err(ClientError::InvalidUrl(downstream.to_string()));
};
tracing::debug!(
target = ?target,
downstream = %downstream,
"Creating HTTP client"
);
let client = match &target {
TargetUrl::Unix(path) => reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600))
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none())
.unix_socket(path.clone())
.build()?,
TargetUrl::Tcp(_) => reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600))
.tcp_keepalive(Some(Duration::from_secs(60)))
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none())
.build()?,
};
Ok(Self { client, target })
}
/// Build a full URL from a path
///
/// Examples:
/// - TCP target "http://localhost:5173" + "/api/health" → "http://localhost:5173/api/health"
/// - Unix target "/tmp/bun.sock" + "/api/health" → "http://localhost/api/health"
fn build_url(&self, path: &str) -> String {
match &self.target {
TargetUrl::Tcp(base) => format!("{}{}", base, path),
TargetUrl::Unix(_) => format!("http://localhost{}", path),
}
}
pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
self.client.get(self.build_url(path))
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
self.client.post(self.build_url(path))
}
pub fn request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
self.client.request(method, self.build_url(path))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tcp_url_construction() {
let client = HttpClient::new("http://localhost:5173").unwrap();
assert_eq!(
client.build_url("/api/health"),
"http://localhost:5173/api/health"
);
assert_eq!(
client.build_url("/path?query=1"),
"http://localhost:5173/path?query=1"
);
}
#[test]
fn test_unix_url_construction() {
let client = HttpClient::new("/tmp/bun.sock").unwrap();
assert_eq!(
client.build_url("/api/health"),
"http://localhost/api/health"
);
assert_eq!(
client.build_url("/path?query=1"),
"http://localhost/path?query=1"
);
}
#[test]
fn test_relative_unix_socket() {
let client = HttpClient::new("./relative.sock").unwrap();
assert!(matches!(client.target, TargetUrl::Unix(_)));
}
#[test]
fn test_invalid_url() {
let result = HttpClient::new("not-a-valid-url");
assert!(result.is_err());
}
#[test]
fn test_https_url() {
let client = HttpClient::new("https://example.com").unwrap();
assert!(matches!(client.target, TargetUrl::Tcp(_)));
assert_eq!(
client.build_url("/api/test"),
"https://example.com/api/test"
);
}
}
+7 -41
View File
@@ -1,6 +1,5 @@
use clap::Parser;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer};
@@ -13,6 +12,7 @@ mod db;
mod formatter;
mod handlers;
mod health;
mod http;
mod middleware;
mod og;
mod proxy;
@@ -125,50 +125,18 @@ async fn main() {
std::process::exit(1);
}
// Create HTTP client for TCP connections with optimized pool settings
let http_client = reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
.tcp_keepalive(Some(Duration::from_secs(60)))
.timeout(Duration::from_secs(5)) // Default timeout for SSR
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
.build()
.expect("Failed to create HTTP client");
// Create Unix socket client if downstream is a Unix socket
let unix_client = if args.downstream.starts_with('/') || args.downstream.starts_with("./") {
let path = PathBuf::from(&args.downstream);
Some(
reqwest::Client::builder()
.pool_max_idle_per_host(8)
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
.timeout(Duration::from_secs(5)) // Default timeout for SSR
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
.unix_socket(path)
.build()
.expect("Failed to create Unix socket client"),
)
} else {
None
};
// Create socket-aware HTTP client
let client = http::HttpClient::new(&args.downstream).expect("Failed to create HTTP client");
// Create health checker
let downstream_url_for_health = args.downstream.clone();
let http_client_for_health = http_client.clone();
let unix_client_for_health = unix_client.clone();
let client_for_health = client.clone();
let pool_for_health = pool.clone();
let health_checker = Arc::new(HealthChecker::new(move || {
let downstream_url = downstream_url_for_health.clone();
let http_client = http_client_for_health.clone();
let unix_client = unix_client_for_health.clone();
let client = client_for_health.clone();
let pool = pool_for_health.clone();
async move {
proxy::perform_health_check(downstream_url, http_client, unix_client, Some(pool)).await
}
async move { proxy::perform_health_check(client, Some(pool)).await }
}));
let tarpit_config = TarpitConfig::from_env();
@@ -186,9 +154,7 @@ async fn main() {
);
let state = Arc::new(AppState {
downstream_url: args.downstream.clone(),
http_client,
unix_client,
client,
health_checker,
tarpit_state,
pool: pool.clone(),
+3 -11
View File
@@ -36,17 +36,9 @@ pub async fn generate_og_image(spec: &OGImageSpec, state: Arc<AppState>) -> Resu
tracing::Span::current().record("r2_key", &r2_key);
// Call Bun's internal endpoint
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
{
"http://localhost/internal/ogp/generate".to_string()
} else {
format!("{}/internal/ogp/generate", state.downstream_url)
};
let client = state.unix_client.as_ref().unwrap_or(&state.http_client);
let response = client
.post(&bun_url)
let response = state
.client
.post("/internal/ogp/generate")
.json(spec)
.timeout(Duration::from_secs(30))
.send()
+50 -74
View File
@@ -72,17 +72,10 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
return response;
}
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
{
if query.is_empty() {
format!("http://localhost{path}")
} else {
format!("http://localhost{path}?{query}")
}
} else if query.is_empty() {
format!("{}{}", state.downstream_url, path)
let path_with_query = if query.is_empty() {
path.to_string()
} else {
format!("{}{}?{}", state.downstream_url, path, query)
format!("{path}?{query}")
};
// Build trusted headers to forward to downstream
@@ -120,7 +113,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
let start = std::time::Instant::now();
match proxy_to_bun(&bun_url, state.clone(), forward_headers).await {
match proxy_to_bun(&path_with_query, state.clone(), forward_headers).await {
Ok((status, headers, body)) => {
let duration_ms = start.elapsed().as_millis() as u64;
let cache = "miss";
@@ -182,7 +175,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
let duration_ms = start.elapsed().as_millis() as u64;
tracing::error!(
error = %err,
url = %bun_url,
path = %path_with_query,
duration_ms,
"Failed to proxy to Bun"
);
@@ -203,18 +196,12 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
/// Proxy a request to Bun SSR
pub async fn proxy_to_bun(
url: &str,
path: &str,
state: Arc<AppState>,
forward_headers: HeaderMap,
) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> {
let client = if state.unix_client.is_some() {
state.unix_client.as_ref().unwrap()
} else {
&state.http_client
};
// Build request with forwarded headers
let mut request_builder = client.get(url);
let mut request_builder = state.client.get(path);
for (name, value) in forward_headers.iter() {
request_builder = request_builder.header(name, value);
}
@@ -247,45 +234,35 @@ pub async fn proxy_to_bun(
/// Perform health check on Bun SSR and database
pub async fn perform_health_check(
downstream_url: String,
http_client: reqwest::Client,
unix_client: Option<reqwest::Client>,
client: crate::http::HttpClient,
pool: Option<sqlx::PgPool>,
) -> bool {
let url = if downstream_url.starts_with('/') || downstream_url.starts_with("./") {
"http://localhost/internal/health".to_string()
} else {
format!("{downstream_url}/internal/health")
let bun_healthy = match tokio::time::timeout(
Duration::from_secs(5),
client.get("/internal/health").send(),
)
.await
{
Ok(Ok(response)) => {
let is_success = response.status().is_success();
if !is_success {
tracing::warn!(
status = response.status().as_u16(),
"Health check failed: Bun returned non-success status"
);
}
is_success
}
Ok(Err(err)) => {
tracing::error!(error = %err, "Health check failed: cannot reach Bun");
false
}
Err(_) => {
tracing::error!("Health check failed: timeout after 5s");
false
}
};
let client = if unix_client.is_some() {
unix_client.as_ref().unwrap()
} else {
&http_client
};
let bun_healthy =
match tokio::time::timeout(Duration::from_secs(5), client.get(&url).send()).await {
Ok(Ok(response)) => {
let is_success = response.status().is_success();
if !is_success {
tracing::warn!(
status = response.status().as_u16(),
"Health check failed: Bun returned non-success status"
);
}
is_success
}
Ok(Err(err)) => {
tracing::error!(error = %err, "Health check failed: cannot reach Bun");
false
}
Err(_) => {
tracing::error!("Health check failed: timeout after 5s");
false
}
};
// Check database
let db_healthy = if let Some(pool) = pool {
match db::health_check(&pool).await {
@@ -307,33 +284,32 @@ fn should_tarpit(state: &TarpitState, path: &str) -> bool {
state.config.enabled && tarpit::is_malicious_path(path)
}
/// Common handler logic for requests with optional peer info
async fn handle_request_with_optional_peer(
state: Arc<AppState>,
peer: Option<SocketAddr>,
req: Request,
) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
let peer_info = peer.map(ConnectInfo);
tarpit::tarpit_handler(State(state.tarpit_state.clone()), peer_info, req).await
} else {
isr_handler(State(state), req).await
}
}
/// Fallback handler for TCP connections (has access to peer IP)
pub async fn fallback_handler_tcp(
State(state): State<Arc<AppState>>,
ConnectInfo(peer): ConnectInfo<SocketAddr>,
req: Request,
) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
tarpit::tarpit_handler(
State(state.tarpit_state.clone()),
Some(ConnectInfo(peer)),
req,
)
.await
} else {
isr_handler(State(state), req).await
}
handle_request_with_optional_peer(state, Some(peer), req).await
}
/// Fallback handler for Unix sockets (no peer IP available)
pub async fn fallback_handler_unix(State(state): State<Arc<AppState>>, req: Request) -> Response {
let path = req.uri().path();
if should_tarpit(&state.tarpit_state, path) {
tarpit::tarpit_handler(State(state.tarpit_state.clone()), None, req).await
} else {
isr_handler(State(state), req).await
}
handle_request_with_optional_peer(state, None, req).await
}
+31 -50
View File
@@ -1,4 +1,11 @@
use axum::{Router, extract::Request, http::Uri, response::IntoResponse, routing::any};
use axum::{
Router,
body::Body,
extract::Request,
http::{Method, Uri},
response::IntoResponse,
routing::{any, get, post},
};
use std::sync::Arc;
use crate::{assets, handlers, state::AppState};
@@ -9,32 +16,28 @@ pub fn api_routes() -> Router<Arc<AppState>> {
.route("/", any(api_root_404_handler))
.route(
"/health",
axum::routing::get(handlers::health_handler).head(handlers::health_handler),
get(handlers::health_handler).head(handlers::health_handler),
)
// Authentication endpoints (public)
.route("/login", axum::routing::post(handlers::api_login_handler))
.route("/logout", axum::routing::post(handlers::api_logout_handler))
.route(
"/session",
axum::routing::get(handlers::api_session_handler),
)
.route("/login", post(handlers::api_login_handler))
.route("/logout", post(handlers::api_logout_handler))
.route("/session", get(handlers::api_session_handler))
// Projects - GET is public (shows all for admin, only non-hidden for public)
// POST/PUT/DELETE require authentication
.route(
"/projects",
axum::routing::get(handlers::projects_handler).post(handlers::create_project_handler),
get(handlers::projects_handler).post(handlers::create_project_handler),
)
.route(
"/projects/{id}",
axum::routing::get(handlers::get_project_handler)
get(handlers::get_project_handler)
.put(handlers::update_project_handler)
.delete(handlers::delete_project_handler),
)
// Project tags - authentication checked in handlers
.route(
"/projects/{id}/tags",
axum::routing::get(handlers::get_project_tags_handler)
.post(handlers::add_project_tag_handler),
get(handlers::get_project_tags_handler).post(handlers::add_project_tag_handler),
)
.route(
"/projects/{id}/tags/{tag_id}",
@@ -43,36 +46,29 @@ pub fn api_routes() -> Router<Arc<AppState>> {
// Tags - authentication checked in handlers
.route(
"/tags",
axum::routing::get(handlers::list_tags_handler).post(handlers::create_tag_handler),
get(handlers::list_tags_handler).post(handlers::create_tag_handler),
)
.route(
"/tags/{slug}",
axum::routing::get(handlers::get_tag_handler).put(handlers::update_tag_handler),
get(handlers::get_tag_handler).put(handlers::update_tag_handler),
)
.route(
"/tags/{slug}/related",
axum::routing::get(handlers::get_related_tags_handler),
get(handlers::get_related_tags_handler),
)
.route(
"/tags/recalculate-cooccurrence",
axum::routing::post(handlers::recalculate_cooccurrence_handler),
post(handlers::recalculate_cooccurrence_handler),
)
// Admin stats - requires authentication
.route(
"/stats",
axum::routing::get(handlers::get_admin_stats_handler),
)
.route("/stats", get(handlers::get_admin_stats_handler))
// Site settings - GET is public, PUT requires authentication
.route(
"/settings",
axum::routing::get(handlers::get_settings_handler)
.put(handlers::update_settings_handler),
get(handlers::get_settings_handler).put(handlers::update_settings_handler),
)
// Icon API - proxy to SvelteKit (authentication handled by SvelteKit)
.route(
"/icons/{*path}",
axum::routing::get(handlers::proxy_icons_handler),
)
.route("/icons/{*path}", get(handlers::proxy_icons_handler))
.fallback(api_404_and_method_handler)
}
@@ -83,19 +79,13 @@ pub fn build_base_router() -> Router<Arc<AppState>> {
.route("/api/", any(api_root_404_handler))
.route(
"/_app/{*path}",
axum::routing::get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
)
.route("/pgp", axum::routing::get(handlers::handle_pgp_route))
.route(
"/publickey.asc",
axum::routing::get(handlers::serve_pgp_key),
)
.route("/pgp.asc", axum::routing::get(handlers::serve_pgp_key))
.route(
"/.well-known/pgpkey.asc",
axum::routing::get(handlers::serve_pgp_key),
)
.route("/keys", axum::routing::get(handlers::redirect_to_pgp))
.route("/pgp", get(handlers::handle_pgp_route))
.route("/publickey.asc", get(handlers::serve_pgp_key))
.route("/pgp.asc", get(handlers::serve_pgp_key))
.route("/.well-known/pgpkey.asc", get(handlers::serve_pgp_key))
.route("/keys", get(handlers::redirect_to_pgp))
}
async fn api_root_404_handler(uri: Uri) -> impl IntoResponse {
@@ -109,10 +99,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
let uri = req.uri();
let path = uri.path();
if method != axum::http::Method::GET
&& method != axum::http::Method::HEAD
&& method != axum::http::Method::OPTIONS
{
if method != Method::GET && method != Method::HEAD && method != Method::OPTIONS {
let content_type = req
.headers()
.get(axum::http::header::CONTENT_TYPE)
@@ -129,10 +116,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
)
.into_response();
}
} else if method == axum::http::Method::POST
|| method == axum::http::Method::PUT
|| method == axum::http::Method::PATCH
{
} else if method == Method::POST || method == Method::PUT || method == Method::PATCH {
// POST/PUT/PATCH require Content-Type header
return (
StatusCode::BAD_REQUEST,
@@ -158,10 +142,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
}
async fn api_404_handler(uri: Uri) -> impl IntoResponse {
let req = Request::builder()
.uri(uri)
.body(axum::body::Body::empty())
.unwrap();
let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
api_404_and_method_handler(req).await
}
+2 -4
View File
@@ -1,13 +1,11 @@
use std::sync::Arc;
use crate::{auth::SessionManager, health::HealthChecker, tarpit::TarpitState};
use crate::{auth::SessionManager, health::HealthChecker, http::HttpClient, tarpit::TarpitState};
/// Application state shared across all handlers
#[derive(Clone)]
pub struct AppState {
pub downstream_url: String,
pub http_client: reqwest::Client,
pub unix_client: Option<reqwest::Client>,
pub client: HttpClient,
pub health_checker: Arc<HealthChecker>,
pub tarpit_state: Arc<TarpitState>,
pub pool: sqlx::PgPool,
+12 -12
View File
@@ -1,5 +1,5 @@
use axum::{
http::{HeaderMap, StatusCode},
http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response},
};
@@ -34,7 +34,7 @@ pub fn is_page_route(path: &str) -> bool {
/// Check if the request accepts HTML responses
pub fn accepts_html(headers: &HeaderMap) -> bool {
if let Some(accept) = headers.get(axum::http::header::ACCEPT) {
if let Some(accept) = headers.get(header::ACCEPT) {
if let Ok(accept_str) = accept.to_str() {
return accept_str.contains("text/html") || accept_str.contains("*/*");
}
@@ -46,7 +46,7 @@ pub fn accepts_html(headers: &HeaderMap) -> bool {
/// Determines if request prefers raw content (CLI tools) over HTML
pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
// Check User-Agent for known CLI tools first (most reliable)
if let Some(ua) = headers.get(axum::http::header::USER_AGENT) {
if let Some(ua) = headers.get(header::USER_AGENT) {
if let Ok(ua_str) = ua.to_str() {
let ua_lower = ua_str.to_lowercase();
if ua_lower.starts_with("curl/")
@@ -60,7 +60,7 @@ pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
}
// Check Accept header - if it explicitly prefers text/html, serve HTML
if let Some(accept) = headers.get(axum::http::header::ACCEPT) {
if let Some(accept) = headers.get(header::ACCEPT) {
if let Ok(accept_str) = accept.to_str() {
// If text/html appears before */* in the list, they prefer HTML
if let Some(html_pos) = accept_str.find("text/html") {
@@ -88,12 +88,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
if let Some(html) = assets::get_error_page(status_code) {
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/html; charset=utf-8"),
header::CONTENT_TYPE,
HeaderValue::from_static("text/html; charset=utf-8"),
);
headers.insert(
axum::http::header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
header::CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
);
(status, headers, html).into_response()
@@ -107,12 +107,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
if let Some(fallback_html) = assets::get_error_page(500) {
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/html; charset=utf-8"),
header::CONTENT_TYPE,
HeaderValue::from_static("text/html; charset=utf-8"),
);
headers.insert(
axum::http::header::CACHE_CONTROL,
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
header::CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
);
(status, headers, fallback_html).into_response()
+68 -50
View File
@@ -3,68 +3,86 @@ import { env } from "$env/dynamic/private";
const logger = getLogger(["ssr", "lib", "api"]);
const upstreamUrl = env.UPSTREAM_URL;
const isUnixSocket =
upstreamUrl?.startsWith("/") || upstreamUrl?.startsWith("./");
const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl;
interface FetchOptions extends RequestInit {
fetch?: typeof fetch;
}
export async function apiFetch<T>(
path: string,
init?: RequestInit & { fetch?: typeof fetch },
): Promise<T> {
interface BunFetchOptions extends RequestInit {
unix?: string;
}
/**
* Create a socket-aware fetch function
* Automatically handles Unix socket vs TCP based on UPSTREAM_URL
*/
function createSmartFetch(upstreamUrl: string | undefined) {
if (!upstreamUrl) {
logger.error("UPSTREAM_URL environment variable not set");
throw new Error("UPSTREAM_URL environment variable not set");
const error = "UPSTREAM_URL environment variable not set";
logger.error(error);
throw new Error(error);
}
const url = `${baseUrl}${path}`;
const method = init?.method ?? "GET";
const isUnixSocket =
upstreamUrl.startsWith("/") || upstreamUrl.startsWith("./");
const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl;
// Unix sockets require Bun's native fetch (SvelteKit's fetch doesn't support it)
const fetchFn = isUnixSocket ? fetch : (init?.fetch ?? fetch);
return async function smartFetch<T>(
path: string,
options?: FetchOptions,
): Promise<T> {
const url = `${baseUrl}${path}`;
const method = options?.method ?? "GET";
const fetchOptions: RequestInit & { unix?: string } = {
...init,
signal: init?.signal ?? AbortSignal.timeout(30_000),
};
// Unix sockets require Bun's native fetch
// SvelteKit's fetch doesn't support the 'unix' option
const fetchFn = isUnixSocket ? fetch : (options?.fetch ?? fetch);
// Remove custom fetch property from options
delete (fetchOptions as Record<string, unknown>).fetch;
const fetchOptions: BunFetchOptions = {
...options,
signal: options?.signal ?? AbortSignal.timeout(30_000),
};
if (isUnixSocket) {
fetchOptions.unix = upstreamUrl;
}
// Remove custom fetch property from options (not part of standard RequestInit)
delete (fetchOptions as Record<string, unknown>).fetch;
logger.debug("API request", {
method,
url,
path,
isUnixSocket,
upstreamUrl,
});
try {
const response = await fetchFn(url, fetchOptions);
if (!response.ok) {
logger.error("API request failed", {
method,
url,
status: response.status,
statusText: response.statusText,
});
throw new Error(`API error: ${response.status} ${response.statusText}`);
// Add Unix socket path if needed
if (isUnixSocket) {
fetchOptions.unix = upstreamUrl;
}
const data = await response.json();
logger.debug("API response", { method, url, status: response.status });
return data;
} catch (error) {
logger.error("API request exception", {
logger.debug("API request", {
method,
url,
error: error instanceof Error ? error.message : String(error),
path,
isUnixSocket,
});
throw error;
}
try {
const response = await fetchFn(url, fetchOptions);
if (!response.ok) {
logger.error("API request failed", {
method,
url,
status: response.status,
statusText: response.statusText,
});
throw new Error(`API error: ${response.status} ${response.statusText}`);
}
const data = await response.json();
logger.debug("API response", { method, url, status: response.status });
return data;
} catch (error) {
logger.error("API request exception", {
method,
url,
error: error instanceof Error ? error.message : String(error),
});
throw error;
}
};
}
// Export the configured smart fetch function
export const apiFetch = createSmartFetch(env.UPSTREAM_URL);