diff --git a/Cargo.lock b/Cargo.lock index 7a9880d..cbb213c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,6 +88,7 @@ dependencies = [ "serde", "serde_json", "sqlx", + "thiserror 2.0.17", "time", "tokio", "tokio-util", diff --git a/Cargo.toml b/Cargo.toml index cdef5da..2f417da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ reqwest = { version = "0.13.1", default-features = false, features = ["rustls", serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.148" sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "time", "migrate"] } +thiserror = "2.0.17" time = { version = "0.3.44", features = ["formatting", "macros", "serde"] } tokio = { version = "1.49.0", features = ["full"] } tokio-util = { version = "0.7.18", features = ["io"] } diff --git a/src/db/mod.rs b/src/db/mod.rs index d9ef3cd..1519407 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -7,7 +7,7 @@ pub use projects::*; pub use settings::*; pub use tags::*; -use sqlx::{PgPool, postgres::PgPoolOptions}; +use sqlx::{PgPool, postgres::PgPoolOptions, query}; /// Database connection pool creation pub async fn create_pool(database_url: &str) -> Result { @@ -20,7 +20,7 @@ pub async fn create_pool(database_url: &str) -> Result { /// Health check query pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> { - sqlx::query!("SELECT 1 as check") + query!("SELECT 1 as check") .fetch_one(pool) .await .map(|_| ()) diff --git a/src/db/projects.rs b/src/db/projects.rs index e03df1b..c81bd8c 100644 --- a/src/db/projects.rs +++ b/src/db/projects.rs @@ -1,9 +1,13 @@ use serde::{Deserialize, Serialize}; -use sqlx::PgPool; -use time::OffsetDateTime; +use serde_json::json; +use sqlx::{PgPool, query, query_as}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339}; use uuid::Uuid; -use super::{ProjectStatus, slugify}; +use super::{ + ProjectStatus, slugify, + tags::{ApiTag, DbTag, get_tags_for_project}, +}; // Database model #[derive(Debug, Clone, sqlx::FromRow)] @@ -43,7 +47,7 @@ pub struct ApiProject { pub struct ApiAdminProject { #[serde(flatten)] pub project: ApiProject, - pub tags: Vec, + pub tags: Vec, pub status: String, pub description: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -86,7 +90,7 @@ impl DbProject { } } - pub fn to_api_admin_project(&self, tags: Vec) -> ApiAdminProject { + pub fn to_api_admin_project(&self, tags: Vec) -> ApiAdminProject { ApiAdminProject { project: self.to_api_project(), tags: tags.into_iter().map(|t| t.to_api_tag()).collect(), @@ -94,18 +98,11 @@ impl DbProject { description: self.description.clone(), github_repo: self.github_repo.clone(), demo_url: self.demo_url.clone(), - created_at: self - .created_at - .format(&time::format_description::well_known::Rfc3339) - .unwrap(), - updated_at: self - .updated_at - .format(&time::format_description::well_known::Rfc3339) - .unwrap(), - last_github_activity: self.last_github_activity.map(|dt| { - dt.format(&time::format_description::well_known::Rfc3339) - .unwrap() - }), + created_at: self.created_at.format(&Rfc3339).unwrap(), + updated_at: self.updated_at.format(&Rfc3339).unwrap(), + last_github_activity: self + .last_github_activity + .map(|dt| dt.format(&Rfc3339).unwrap()), } } } @@ -150,20 +147,20 @@ pub struct AdminStats { // Query functions pub async fn get_public_projects(pool: &PgPool) -> Result, sqlx::Error> { - sqlx::query_as!( + query_as!( DbProject, r#" - SELECT - id, - slug, + SELECT + id, + slug, name, short_description, - description, - status as "status: ProjectStatus", - github_repo, - demo_url, - last_github_activity, - created_at, + description, + status as "status: ProjectStatus", + github_repo, + demo_url, + last_github_activity, + created_at, updated_at FROM projects WHERE status != 'hidden' @@ -176,12 +173,12 @@ pub async fn get_public_projects(pool: &PgPool) -> Result, sqlx:: pub async fn get_public_projects_with_tags( pool: &PgPool, -) -> Result)>, sqlx::Error> { +) -> Result)>, sqlx::Error> { let projects = get_public_projects(pool).await?; let mut result = Vec::new(); for project in projects { - let tags = super::tags::get_tags_for_project(pool, project.id).await?; + let tags = get_tags_for_project(pool, project.id).await?; result.push((project, tags)); } @@ -190,20 +187,20 @@ pub async fn get_public_projects_with_tags( /// Get all projects (admin view - includes hidden) pub async fn get_all_projects_admin(pool: &PgPool) -> Result, sqlx::Error> { - sqlx::query_as!( + query_as!( DbProject, r#" - SELECT - id, - slug, + SELECT + id, + slug, name, short_description, - description, - status as "status: ProjectStatus", - github_repo, - demo_url, - last_github_activity, - created_at, + description, + status as "status: ProjectStatus", + github_repo, + demo_url, + last_github_activity, + created_at, updated_at FROM projects ORDER BY updated_at DESC @@ -216,12 +213,12 @@ pub async fn get_all_projects_admin(pool: &PgPool) -> Result, sql /// Get all projects with tags (admin view) pub async fn get_all_projects_with_tags_admin( pool: &PgPool, -) -> Result)>, sqlx::Error> { +) -> Result)>, sqlx::Error> { let projects = get_all_projects_admin(pool).await?; let mut result = Vec::new(); for project in projects { - let tags = super::tags::get_tags_for_project(pool, project.id).await?; + let tags = get_tags_for_project(pool, project.id).await?; result.push((project, tags)); } @@ -230,21 +227,21 @@ pub async fn get_all_projects_with_tags_admin( /// Get single project by ID pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result, sqlx::Error> { - sqlx::query_as!( + query_as!( DbProject, r#" - SELECT - id, - slug, + SELECT + id, + slug, name, short_description, - description, - status as "status: ProjectStatus", - github_repo, - demo_url, + description, + status as "status: ProjectStatus", + github_repo, + demo_url, - last_github_activity, - created_at, + last_github_activity, + created_at, updated_at FROM projects WHERE id = $1 @@ -259,12 +256,12 @@ pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result Result)>, sqlx::Error> { +) -> Result)>, sqlx::Error> { let project = get_project_by_id(pool, id).await?; match project { Some(p) => { - let tags = super::tags::get_tags_for_project(pool, p.id).await?; + let tags = get_tags_for_project(pool, p.id).await?; Ok(Some((p, tags))) } None => Ok(None), @@ -276,21 +273,21 @@ pub async fn get_project_by_slug( pool: &PgPool, slug: &str, ) -> Result, sqlx::Error> { - sqlx::query_as!( + query_as!( DbProject, r#" - SELECT - id, - slug, + SELECT + id, + slug, name, short_description, - description, - status as "status: ProjectStatus", - github_repo, - demo_url, + description, + status as "status: ProjectStatus", + github_repo, + demo_url, - last_github_activity, - created_at, + last_github_activity, + created_at, updated_at FROM projects WHERE slug = $1 @@ -316,12 +313,12 @@ pub async fn create_project( .map(|s| slugify(s)) .unwrap_or_else(|| slugify(name)); - sqlx::query_as!( + query_as!( DbProject, r#" INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url) VALUES ($1, $2, $3, $4, $5, $6, $7) - RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus", + RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus", github_repo, demo_url, last_github_activity, created_at, updated_at "#, slug, @@ -352,14 +349,14 @@ pub async fn update_project( .map(|s| slugify(s)) .unwrap_or_else(|| slugify(name)); - sqlx::query_as!( + query_as!( DbProject, r#" UPDATE projects - SET slug = $2, name = $3, short_description = $4, description = $5, + SET slug = $2, name = $3, short_description = $4, description = $5, status = $6, github_repo = $7, demo_url = $8 WHERE id = $1 - RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus", + RETURNING id, slug, name, short_description, description, status as "status: ProjectStatus", github_repo, demo_url, last_github_activity, created_at, updated_at "#, id, @@ -377,7 +374,7 @@ pub async fn update_project( /// Delete project (CASCADE will handle tags) pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> { - sqlx::query!("DELETE FROM projects WHERE id = $1", id) + query!("DELETE FROM projects WHERE id = $1", id) .execute(pool) .await?; Ok(()) @@ -386,9 +383,9 @@ pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> /// Get admin stats pub async fn get_admin_stats(pool: &PgPool) -> Result { // Get project counts by status - let status_counts = sqlx::query!( + let status_counts = query!( r#" - SELECT + SELECT status as "status!: ProjectStatus", COUNT(*)::int as "count!" FROM projects @@ -398,7 +395,7 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result { .fetch_all(pool) .await?; - let mut projects_by_status = serde_json::json!({ + let mut projects_by_status = json!({ "active": 0, "maintained": 0, "archived": 0, @@ -408,12 +405,12 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result { let mut total_projects = 0; for row in status_counts { let status_str = format!("{:?}", row.status).to_lowercase(); - projects_by_status[status_str] = serde_json::json!(row.count); + projects_by_status[status_str] = json!(row.count); total_projects += row.count; } // Get total tags - let tag_count = sqlx::query!("SELECT COUNT(*)::int as \"count!\" FROM tags") + let tag_count = query!("SELECT COUNT(*)::int as \"count!\" FROM tags") .fetch_one(pool) .await?; diff --git a/src/handlers/assets.rs b/src/handlers/assets.rs index c6aa61f..be4e864 100644 --- a/src/handlers/assets.rs +++ b/src/handlers/assets.rs @@ -60,17 +60,10 @@ pub async fn proxy_icons_handler( let full_path = format!("/api/icons/{}", path); let query = req.uri().query().unwrap_or(""); - let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./") - { - if query.is_empty() { - format!("http://localhost{}", full_path) - } else { - format!("http://localhost{}?{}", full_path, query) - } - } else if query.is_empty() { - format!("{}{}", state.downstream_url, full_path) + let path_with_query = if query.is_empty() { + full_path.clone() } else { - format!("{}{}?{}", state.downstream_url, full_path, query) + format!("{full_path}?{query}") }; // Build trusted headers with session info @@ -86,7 +79,7 @@ pub async fn proxy_icons_handler( } } - match proxy::proxy_to_bun(&bun_url, state, forward_headers).await { + match proxy::proxy_to_bun(&path_with_query, state, forward_headers).await { Ok((status, headers, body)) => (status, headers, body).into_response(), Err(err) => { tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request"); diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..e5f542b --- /dev/null +++ b/src/http.rs @@ -0,0 +1,146 @@ +use reqwest::Method; +use std::path::PathBuf; +use std::time::Duration; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ClientError { + #[error("Failed to build reqwest client: {0}")] + BuildError(#[from] reqwest::Error), + + #[error("Invalid downstream URL: {0}")] + InvalidUrl(String), +} + +#[derive(Clone)] +pub struct HttpClient { + client: reqwest::Client, + target: TargetUrl, +} + +#[derive(Debug, Clone)] +enum TargetUrl { + Tcp(String), // Base URL like "http://localhost:5173" + Unix(PathBuf), // Socket path like "/tmp/bun.sock" +} + +impl HttpClient { + /// Create a new HttpClient from a downstream URL + /// + /// Accepts: + /// - TCP: "http://localhost:5173", "https://example.com" + /// - Unix: "/tmp/bun.sock", "./relative.sock" + pub fn new(downstream: &str) -> Result { + let target = if downstream.starts_with('/') || downstream.starts_with("./") { + TargetUrl::Unix(PathBuf::from(downstream)) + } else if downstream.starts_with("http://") || downstream.starts_with("https://") { + TargetUrl::Tcp(downstream.to_string()) + } else { + return Err(ClientError::InvalidUrl(downstream.to_string())); + }; + + tracing::debug!( + target = ?target, + downstream = %downstream, + "Creating HTTP client" + ); + + let client = match &target { + TargetUrl::Unix(path) => reqwest::Client::builder() + .pool_max_idle_per_host(8) + .pool_idle_timeout(Duration::from_secs(600)) + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(3)) + .redirect(reqwest::redirect::Policy::none()) + .unix_socket(path.clone()) + .build()?, + TargetUrl::Tcp(_) => reqwest::Client::builder() + .pool_max_idle_per_host(8) + .pool_idle_timeout(Duration::from_secs(600)) + .tcp_keepalive(Some(Duration::from_secs(60))) + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(3)) + .redirect(reqwest::redirect::Policy::none()) + .build()?, + }; + + Ok(Self { client, target }) + } + + /// Build a full URL from a path + /// + /// Examples: + /// - TCP target "http://localhost:5173" + "/api/health" → "http://localhost:5173/api/health" + /// - Unix target "/tmp/bun.sock" + "/api/health" → "http://localhost/api/health" + fn build_url(&self, path: &str) -> String { + match &self.target { + TargetUrl::Tcp(base) => format!("{}{}", base, path), + TargetUrl::Unix(_) => format!("http://localhost{}", path), + } + } + + pub fn get(&self, path: &str) -> reqwest::RequestBuilder { + self.client.get(self.build_url(path)) + } + + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { + self.client.post(self.build_url(path)) + } + + pub fn request(&self, method: Method, path: &str) -> reqwest::RequestBuilder { + self.client.request(method, self.build_url(path)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tcp_url_construction() { + let client = HttpClient::new("http://localhost:5173").unwrap(); + assert_eq!( + client.build_url("/api/health"), + "http://localhost:5173/api/health" + ); + assert_eq!( + client.build_url("/path?query=1"), + "http://localhost:5173/path?query=1" + ); + } + + #[test] + fn test_unix_url_construction() { + let client = HttpClient::new("/tmp/bun.sock").unwrap(); + assert_eq!( + client.build_url("/api/health"), + "http://localhost/api/health" + ); + assert_eq!( + client.build_url("/path?query=1"), + "http://localhost/path?query=1" + ); + } + + #[test] + fn test_relative_unix_socket() { + let client = HttpClient::new("./relative.sock").unwrap(); + assert!(matches!(client.target, TargetUrl::Unix(_))); + } + + #[test] + fn test_invalid_url() { + let result = HttpClient::new("not-a-valid-url"); + assert!(result.is_err()); + } + + #[test] + fn test_https_url() { + let client = HttpClient::new("https://example.com").unwrap(); + assert!(matches!(client.target, TargetUrl::Tcp(_))); + assert_eq!( + client.build_url("/api/test"), + "https://example.com/api/test" + ); + } +} diff --git a/src/main.rs b/src/main.rs index fc34ed2..0b56e0c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use clap::Parser; use std::net::SocketAddr; -use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer}; @@ -13,6 +12,7 @@ mod db; mod formatter; mod handlers; mod health; +mod http; mod middleware; mod og; mod proxy; @@ -125,50 +125,18 @@ async fn main() { std::process::exit(1); } - // Create HTTP client for TCP connections with optimized pool settings - let http_client = reqwest::Client::builder() - .pool_max_idle_per_host(8) - .pool_idle_timeout(Duration::from_secs(600)) // 10 minutes - .tcp_keepalive(Some(Duration::from_secs(60))) - .timeout(Duration::from_secs(5)) // Default timeout for SSR - .connect_timeout(Duration::from_secs(3)) - .redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through - .build() - .expect("Failed to create HTTP client"); - - // Create Unix socket client if downstream is a Unix socket - let unix_client = if args.downstream.starts_with('/') || args.downstream.starts_with("./") { - let path = PathBuf::from(&args.downstream); - Some( - reqwest::Client::builder() - .pool_max_idle_per_host(8) - .pool_idle_timeout(Duration::from_secs(600)) // 10 minutes - .timeout(Duration::from_secs(5)) // Default timeout for SSR - .connect_timeout(Duration::from_secs(3)) - .redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through - .unix_socket(path) - .build() - .expect("Failed to create Unix socket client"), - ) - } else { - None - }; + // Create socket-aware HTTP client + let client = http::HttpClient::new(&args.downstream).expect("Failed to create HTTP client"); // Create health checker - let downstream_url_for_health = args.downstream.clone(); - let http_client_for_health = http_client.clone(); - let unix_client_for_health = unix_client.clone(); + let client_for_health = client.clone(); let pool_for_health = pool.clone(); let health_checker = Arc::new(HealthChecker::new(move || { - let downstream_url = downstream_url_for_health.clone(); - let http_client = http_client_for_health.clone(); - let unix_client = unix_client_for_health.clone(); + let client = client_for_health.clone(); let pool = pool_for_health.clone(); - async move { - proxy::perform_health_check(downstream_url, http_client, unix_client, Some(pool)).await - } + async move { proxy::perform_health_check(client, Some(pool)).await } })); let tarpit_config = TarpitConfig::from_env(); @@ -186,9 +154,7 @@ async fn main() { ); let state = Arc::new(AppState { - downstream_url: args.downstream.clone(), - http_client, - unix_client, + client, health_checker, tarpit_state, pool: pool.clone(), diff --git a/src/og.rs b/src/og.rs index b05989e..6db0c55 100644 --- a/src/og.rs +++ b/src/og.rs @@ -36,17 +36,9 @@ pub async fn generate_og_image(spec: &OGImageSpec, state: Arc) -> Resu tracing::Span::current().record("r2_key", &r2_key); // Call Bun's internal endpoint - let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./") - { - "http://localhost/internal/ogp/generate".to_string() - } else { - format!("{}/internal/ogp/generate", state.downstream_url) - }; - - let client = state.unix_client.as_ref().unwrap_or(&state.http_client); - - let response = client - .post(&bun_url) + let response = state + .client + .post("/internal/ogp/generate") .json(spec) .timeout(Duration::from_secs(30)) .send() diff --git a/src/proxy.rs b/src/proxy.rs index adceea6..188df3b 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -72,17 +72,10 @@ pub async fn isr_handler(State(state): State>, req: Request) -> Re return response; } - let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./") - { - if query.is_empty() { - format!("http://localhost{path}") - } else { - format!("http://localhost{path}?{query}") - } - } else if query.is_empty() { - format!("{}{}", state.downstream_url, path) + let path_with_query = if query.is_empty() { + path.to_string() } else { - format!("{}{}?{}", state.downstream_url, path, query) + format!("{path}?{query}") }; // Build trusted headers to forward to downstream @@ -120,7 +113,7 @@ pub async fn isr_handler(State(state): State>, req: Request) -> Re let start = std::time::Instant::now(); - match proxy_to_bun(&bun_url, state.clone(), forward_headers).await { + match proxy_to_bun(&path_with_query, state.clone(), forward_headers).await { Ok((status, headers, body)) => { let duration_ms = start.elapsed().as_millis() as u64; let cache = "miss"; @@ -182,7 +175,7 @@ pub async fn isr_handler(State(state): State>, req: Request) -> Re let duration_ms = start.elapsed().as_millis() as u64; tracing::error!( error = %err, - url = %bun_url, + path = %path_with_query, duration_ms, "Failed to proxy to Bun" ); @@ -203,18 +196,12 @@ pub async fn isr_handler(State(state): State>, req: Request) -> Re /// Proxy a request to Bun SSR pub async fn proxy_to_bun( - url: &str, + path: &str, state: Arc, forward_headers: HeaderMap, ) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> { - let client = if state.unix_client.is_some() { - state.unix_client.as_ref().unwrap() - } else { - &state.http_client - }; - // Build request with forwarded headers - let mut request_builder = client.get(url); + let mut request_builder = state.client.get(path); for (name, value) in forward_headers.iter() { request_builder = request_builder.header(name, value); } @@ -247,45 +234,35 @@ pub async fn proxy_to_bun( /// Perform health check on Bun SSR and database pub async fn perform_health_check( - downstream_url: String, - http_client: reqwest::Client, - unix_client: Option, + client: crate::http::HttpClient, pool: Option, ) -> bool { - let url = if downstream_url.starts_with('/') || downstream_url.starts_with("./") { - "http://localhost/internal/health".to_string() - } else { - format!("{downstream_url}/internal/health") + let bun_healthy = match tokio::time::timeout( + Duration::from_secs(5), + client.get("/internal/health").send(), + ) + .await + { + Ok(Ok(response)) => { + let is_success = response.status().is_success(); + if !is_success { + tracing::warn!( + status = response.status().as_u16(), + "Health check failed: Bun returned non-success status" + ); + } + is_success + } + Ok(Err(err)) => { + tracing::error!(error = %err, "Health check failed: cannot reach Bun"); + false + } + Err(_) => { + tracing::error!("Health check failed: timeout after 5s"); + false + } }; - let client = if unix_client.is_some() { - unix_client.as_ref().unwrap() - } else { - &http_client - }; - - let bun_healthy = - match tokio::time::timeout(Duration::from_secs(5), client.get(&url).send()).await { - Ok(Ok(response)) => { - let is_success = response.status().is_success(); - if !is_success { - tracing::warn!( - status = response.status().as_u16(), - "Health check failed: Bun returned non-success status" - ); - } - is_success - } - Ok(Err(err)) => { - tracing::error!(error = %err, "Health check failed: cannot reach Bun"); - false - } - Err(_) => { - tracing::error!("Health check failed: timeout after 5s"); - false - } - }; - // Check database let db_healthy = if let Some(pool) = pool { match db::health_check(&pool).await { @@ -307,33 +284,32 @@ fn should_tarpit(state: &TarpitState, path: &str) -> bool { state.config.enabled && tarpit::is_malicious_path(path) } +/// Common handler logic for requests with optional peer info +async fn handle_request_with_optional_peer( + state: Arc, + peer: Option, + req: Request, +) -> Response { + let path = req.uri().path(); + + if should_tarpit(&state.tarpit_state, path) { + let peer_info = peer.map(ConnectInfo); + tarpit::tarpit_handler(State(state.tarpit_state.clone()), peer_info, req).await + } else { + isr_handler(State(state), req).await + } +} + /// Fallback handler for TCP connections (has access to peer IP) pub async fn fallback_handler_tcp( State(state): State>, ConnectInfo(peer): ConnectInfo, req: Request, ) -> Response { - let path = req.uri().path(); - - if should_tarpit(&state.tarpit_state, path) { - tarpit::tarpit_handler( - State(state.tarpit_state.clone()), - Some(ConnectInfo(peer)), - req, - ) - .await - } else { - isr_handler(State(state), req).await - } + handle_request_with_optional_peer(state, Some(peer), req).await } /// Fallback handler for Unix sockets (no peer IP available) pub async fn fallback_handler_unix(State(state): State>, req: Request) -> Response { - let path = req.uri().path(); - - if should_tarpit(&state.tarpit_state, path) { - tarpit::tarpit_handler(State(state.tarpit_state.clone()), None, req).await - } else { - isr_handler(State(state), req).await - } + handle_request_with_optional_peer(state, None, req).await } diff --git a/src/routes.rs b/src/routes.rs index 171664c..fb2c1e0 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,4 +1,11 @@ -use axum::{Router, extract::Request, http::Uri, response::IntoResponse, routing::any}; +use axum::{ + Router, + body::Body, + extract::Request, + http::{Method, Uri}, + response::IntoResponse, + routing::{any, get, post}, +}; use std::sync::Arc; use crate::{assets, handlers, state::AppState}; @@ -9,32 +16,28 @@ pub fn api_routes() -> Router> { .route("/", any(api_root_404_handler)) .route( "/health", - axum::routing::get(handlers::health_handler).head(handlers::health_handler), + get(handlers::health_handler).head(handlers::health_handler), ) // Authentication endpoints (public) - .route("/login", axum::routing::post(handlers::api_login_handler)) - .route("/logout", axum::routing::post(handlers::api_logout_handler)) - .route( - "/session", - axum::routing::get(handlers::api_session_handler), - ) + .route("/login", post(handlers::api_login_handler)) + .route("/logout", post(handlers::api_logout_handler)) + .route("/session", get(handlers::api_session_handler)) // Projects - GET is public (shows all for admin, only non-hidden for public) // POST/PUT/DELETE require authentication .route( "/projects", - axum::routing::get(handlers::projects_handler).post(handlers::create_project_handler), + get(handlers::projects_handler).post(handlers::create_project_handler), ) .route( "/projects/{id}", - axum::routing::get(handlers::get_project_handler) + get(handlers::get_project_handler) .put(handlers::update_project_handler) .delete(handlers::delete_project_handler), ) // Project tags - authentication checked in handlers .route( "/projects/{id}/tags", - axum::routing::get(handlers::get_project_tags_handler) - .post(handlers::add_project_tag_handler), + get(handlers::get_project_tags_handler).post(handlers::add_project_tag_handler), ) .route( "/projects/{id}/tags/{tag_id}", @@ -43,36 +46,29 @@ pub fn api_routes() -> Router> { // Tags - authentication checked in handlers .route( "/tags", - axum::routing::get(handlers::list_tags_handler).post(handlers::create_tag_handler), + get(handlers::list_tags_handler).post(handlers::create_tag_handler), ) .route( "/tags/{slug}", - axum::routing::get(handlers::get_tag_handler).put(handlers::update_tag_handler), + get(handlers::get_tag_handler).put(handlers::update_tag_handler), ) .route( "/tags/{slug}/related", - axum::routing::get(handlers::get_related_tags_handler), + get(handlers::get_related_tags_handler), ) .route( "/tags/recalculate-cooccurrence", - axum::routing::post(handlers::recalculate_cooccurrence_handler), + post(handlers::recalculate_cooccurrence_handler), ) // Admin stats - requires authentication - .route( - "/stats", - axum::routing::get(handlers::get_admin_stats_handler), - ) + .route("/stats", get(handlers::get_admin_stats_handler)) // Site settings - GET is public, PUT requires authentication .route( "/settings", - axum::routing::get(handlers::get_settings_handler) - .put(handlers::update_settings_handler), + get(handlers::get_settings_handler).put(handlers::update_settings_handler), ) // Icon API - proxy to SvelteKit (authentication handled by SvelteKit) - .route( - "/icons/{*path}", - axum::routing::get(handlers::proxy_icons_handler), - ) + .route("/icons/{*path}", get(handlers::proxy_icons_handler)) .fallback(api_404_and_method_handler) } @@ -83,19 +79,13 @@ pub fn build_base_router() -> Router> { .route("/api/", any(api_root_404_handler)) .route( "/_app/{*path}", - axum::routing::get(assets::serve_embedded_asset).head(assets::serve_embedded_asset), + get(assets::serve_embedded_asset).head(assets::serve_embedded_asset), ) - .route("/pgp", axum::routing::get(handlers::handle_pgp_route)) - .route( - "/publickey.asc", - axum::routing::get(handlers::serve_pgp_key), - ) - .route("/pgp.asc", axum::routing::get(handlers::serve_pgp_key)) - .route( - "/.well-known/pgpkey.asc", - axum::routing::get(handlers::serve_pgp_key), - ) - .route("/keys", axum::routing::get(handlers::redirect_to_pgp)) + .route("/pgp", get(handlers::handle_pgp_route)) + .route("/publickey.asc", get(handlers::serve_pgp_key)) + .route("/pgp.asc", get(handlers::serve_pgp_key)) + .route("/.well-known/pgpkey.asc", get(handlers::serve_pgp_key)) + .route("/keys", get(handlers::redirect_to_pgp)) } async fn api_root_404_handler(uri: Uri) -> impl IntoResponse { @@ -109,10 +99,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse { let uri = req.uri(); let path = uri.path(); - if method != axum::http::Method::GET - && method != axum::http::Method::HEAD - && method != axum::http::Method::OPTIONS - { + if method != Method::GET && method != Method::HEAD && method != Method::OPTIONS { let content_type = req .headers() .get(axum::http::header::CONTENT_TYPE) @@ -129,10 +116,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse { ) .into_response(); } - } else if method == axum::http::Method::POST - || method == axum::http::Method::PUT - || method == axum::http::Method::PATCH - { + } else if method == Method::POST || method == Method::PUT || method == Method::PATCH { // POST/PUT/PATCH require Content-Type header return ( StatusCode::BAD_REQUEST, @@ -158,10 +142,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse { } async fn api_404_handler(uri: Uri) -> impl IntoResponse { - let req = Request::builder() - .uri(uri) - .body(axum::body::Body::empty()) - .unwrap(); + let req = Request::builder().uri(uri).body(Body::empty()).unwrap(); api_404_and_method_handler(req).await } diff --git a/src/state.rs b/src/state.rs index addcb8b..d6b5673 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,13 +1,11 @@ use std::sync::Arc; -use crate::{auth::SessionManager, health::HealthChecker, tarpit::TarpitState}; +use crate::{auth::SessionManager, health::HealthChecker, http::HttpClient, tarpit::TarpitState}; /// Application state shared across all handlers #[derive(Clone)] pub struct AppState { - pub downstream_url: String, - pub http_client: reqwest::Client, - pub unix_client: Option, + pub client: HttpClient, pub health_checker: Arc, pub tarpit_state: Arc, pub pool: sqlx::PgPool, diff --git a/src/utils.rs b/src/utils.rs index 3e4d486..511d706 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ use axum::{ - http::{HeaderMap, StatusCode}, + http::{HeaderMap, HeaderValue, StatusCode, header}, response::{IntoResponse, Response}, }; @@ -34,7 +34,7 @@ pub fn is_page_route(path: &str) -> bool { /// Check if the request accepts HTML responses pub fn accepts_html(headers: &HeaderMap) -> bool { - if let Some(accept) = headers.get(axum::http::header::ACCEPT) { + if let Some(accept) = headers.get(header::ACCEPT) { if let Ok(accept_str) = accept.to_str() { return accept_str.contains("text/html") || accept_str.contains("*/*"); } @@ -46,7 +46,7 @@ pub fn accepts_html(headers: &HeaderMap) -> bool { /// Determines if request prefers raw content (CLI tools) over HTML pub fn prefers_raw_content(headers: &HeaderMap) -> bool { // Check User-Agent for known CLI tools first (most reliable) - if let Some(ua) = headers.get(axum::http::header::USER_AGENT) { + if let Some(ua) = headers.get(header::USER_AGENT) { if let Ok(ua_str) = ua.to_str() { let ua_lower = ua_str.to_lowercase(); if ua_lower.starts_with("curl/") @@ -60,7 +60,7 @@ pub fn prefers_raw_content(headers: &HeaderMap) -> bool { } // Check Accept header - if it explicitly prefers text/html, serve HTML - if let Some(accept) = headers.get(axum::http::header::ACCEPT) { + if let Some(accept) = headers.get(header::ACCEPT) { if let Ok(accept_str) = accept.to_str() { // If text/html appears before */* in the list, they prefer HTML if let Some(html_pos) = accept_str.find("text/html") { @@ -88,12 +88,12 @@ pub fn serve_error_page(status: StatusCode) -> Response { if let Some(html) = assets::get_error_page(status_code) { let mut headers = HeaderMap::new(); headers.insert( - axum::http::header::CONTENT_TYPE, - axum::http::HeaderValue::from_static("text/html; charset=utf-8"), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html; charset=utf-8"), ); headers.insert( - axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"), + header::CACHE_CONTROL, + HeaderValue::from_static("no-cache, no-store, must-revalidate"), ); (status, headers, html).into_response() @@ -107,12 +107,12 @@ pub fn serve_error_page(status: StatusCode) -> Response { if let Some(fallback_html) = assets::get_error_page(500) { let mut headers = HeaderMap::new(); headers.insert( - axum::http::header::CONTENT_TYPE, - axum::http::HeaderValue::from_static("text/html; charset=utf-8"), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html; charset=utf-8"), ); headers.insert( - axum::http::header::CACHE_CONTROL, - axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"), + header::CACHE_CONTROL, + HeaderValue::from_static("no-cache, no-store, must-revalidate"), ); (status, headers, fallback_html).into_response() diff --git a/web/src/lib/api.server.ts b/web/src/lib/api.server.ts index 37a6c31..be61f4f 100644 --- a/web/src/lib/api.server.ts +++ b/web/src/lib/api.server.ts @@ -3,68 +3,86 @@ import { env } from "$env/dynamic/private"; const logger = getLogger(["ssr", "lib", "api"]); -const upstreamUrl = env.UPSTREAM_URL; -const isUnixSocket = - upstreamUrl?.startsWith("/") || upstreamUrl?.startsWith("./"); -const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl; +interface FetchOptions extends RequestInit { + fetch?: typeof fetch; +} -export async function apiFetch( - path: string, - init?: RequestInit & { fetch?: typeof fetch }, -): Promise { +interface BunFetchOptions extends RequestInit { + unix?: string; +} + +/** + * Create a socket-aware fetch function + * Automatically handles Unix socket vs TCP based on UPSTREAM_URL + */ +function createSmartFetch(upstreamUrl: string | undefined) { if (!upstreamUrl) { - logger.error("UPSTREAM_URL environment variable not set"); - throw new Error("UPSTREAM_URL environment variable not set"); + const error = "UPSTREAM_URL environment variable not set"; + logger.error(error); + throw new Error(error); } - const url = `${baseUrl}${path}`; - const method = init?.method ?? "GET"; + const isUnixSocket = + upstreamUrl.startsWith("/") || upstreamUrl.startsWith("./"); + const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl; - // Unix sockets require Bun's native fetch (SvelteKit's fetch doesn't support it) - const fetchFn = isUnixSocket ? fetch : (init?.fetch ?? fetch); + return async function smartFetch( + path: string, + options?: FetchOptions, + ): Promise { + const url = `${baseUrl}${path}`; + const method = options?.method ?? "GET"; - const fetchOptions: RequestInit & { unix?: string } = { - ...init, - signal: init?.signal ?? AbortSignal.timeout(30_000), - }; + // Unix sockets require Bun's native fetch + // SvelteKit's fetch doesn't support the 'unix' option + const fetchFn = isUnixSocket ? fetch : (options?.fetch ?? fetch); - // Remove custom fetch property from options - delete (fetchOptions as Record).fetch; + const fetchOptions: BunFetchOptions = { + ...options, + signal: options?.signal ?? AbortSignal.timeout(30_000), + }; - if (isUnixSocket) { - fetchOptions.unix = upstreamUrl; - } + // Remove custom fetch property from options (not part of standard RequestInit) + delete (fetchOptions as Record).fetch; - logger.debug("API request", { - method, - url, - path, - isUnixSocket, - upstreamUrl, - }); - - try { - const response = await fetchFn(url, fetchOptions); - - if (!response.ok) { - logger.error("API request failed", { - method, - url, - status: response.status, - statusText: response.statusText, - }); - throw new Error(`API error: ${response.status} ${response.statusText}`); + // Add Unix socket path if needed + if (isUnixSocket) { + fetchOptions.unix = upstreamUrl; } - const data = await response.json(); - logger.debug("API response", { method, url, status: response.status }); - return data; - } catch (error) { - logger.error("API request exception", { + logger.debug("API request", { method, url, - error: error instanceof Error ? error.message : String(error), + path, + isUnixSocket, }); - throw error; - } + + try { + const response = await fetchFn(url, fetchOptions); + + if (!response.ok) { + logger.error("API request failed", { + method, + url, + status: response.status, + statusText: response.statusText, + }); + throw new Error(`API error: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + logger.debug("API response", { method, url, status: response.status }); + return data; + } catch (error) { + logger.error("API request exception", { + method, + url, + error: error instanceof Error ? error.message : String(error), + }); + throw error; + } + }; } + +// Export the configured smart fetch function +export const apiFetch = createSmartFetch(env.UPSTREAM_URL);