mirror of
https://github.com/Xevion/xevion.dev.git
synced 2026-01-31 02:26:38 -06:00
refactor: consolidate HTTP client for TCP/Unix socket handling
Extract reqwest client creation into dedicated HttpClient abstraction that handles both TCP and Unix socket connections transparently. Simplifies proxy logic by removing duplicate URL construction and client selection throughout the codebase.
This commit is contained in:
Generated
+1
@@ -88,6 +88,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
"thiserror 2.0.17",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ reqwest = { version = "0.13.1", default-features = false, features = ["rustls",
|
|||||||
serde = { version = "1.0.228", features = ["derive"] }
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
serde_json = "1.0.148"
|
serde_json = "1.0.148"
|
||||||
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "time", "migrate"] }
|
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "time", "migrate"] }
|
||||||
|
thiserror = "2.0.17"
|
||||||
time = { version = "0.3.44", features = ["formatting", "macros", "serde"] }
|
time = { version = "0.3.44", features = ["formatting", "macros", "serde"] }
|
||||||
tokio = { version = "1.49.0", features = ["full"] }
|
tokio = { version = "1.49.0", features = ["full"] }
|
||||||
tokio-util = { version = "0.7.18", features = ["io"] }
|
tokio-util = { version = "0.7.18", features = ["io"] }
|
||||||
|
|||||||
+2
-2
@@ -7,7 +7,7 @@ pub use projects::*;
|
|||||||
pub use settings::*;
|
pub use settings::*;
|
||||||
pub use tags::*;
|
pub use tags::*;
|
||||||
|
|
||||||
use sqlx::{PgPool, postgres::PgPoolOptions};
|
use sqlx::{PgPool, postgres::PgPoolOptions, query};
|
||||||
|
|
||||||
/// Database connection pool creation
|
/// Database connection pool creation
|
||||||
pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
|
pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
|
||||||
@@ -20,7 +20,7 @@ pub async fn create_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
|
|||||||
|
|
||||||
/// Health check query
|
/// Health check query
|
||||||
pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> {
|
pub async fn health_check(pool: &PgPool) -> Result<(), sqlx::Error> {
|
||||||
sqlx::query!("SELECT 1 as check")
|
query!("SELECT 1 as check")
|
||||||
.fetch_one(pool)
|
.fetch_one(pool)
|
||||||
.await
|
.await
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
|
|||||||
+31
-34
@@ -1,9 +1,13 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::PgPool;
|
use serde_json::json;
|
||||||
use time::OffsetDateTime;
|
use sqlx::{PgPool, query, query_as};
|
||||||
|
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::{ProjectStatus, slugify};
|
use super::{
|
||||||
|
ProjectStatus, slugify,
|
||||||
|
tags::{ApiTag, DbTag, get_tags_for_project},
|
||||||
|
};
|
||||||
|
|
||||||
// Database model
|
// Database model
|
||||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||||
@@ -43,7 +47,7 @@ pub struct ApiProject {
|
|||||||
pub struct ApiAdminProject {
|
pub struct ApiAdminProject {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub project: ApiProject,
|
pub project: ApiProject,
|
||||||
pub tags: Vec<super::tags::ApiTag>,
|
pub tags: Vec<ApiTag>,
|
||||||
pub status: String,
|
pub status: String,
|
||||||
pub description: String,
|
pub description: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -86,7 +90,7 @@ impl DbProject {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_api_admin_project(&self, tags: Vec<super::tags::DbTag>) -> ApiAdminProject {
|
pub fn to_api_admin_project(&self, tags: Vec<DbTag>) -> ApiAdminProject {
|
||||||
ApiAdminProject {
|
ApiAdminProject {
|
||||||
project: self.to_api_project(),
|
project: self.to_api_project(),
|
||||||
tags: tags.into_iter().map(|t| t.to_api_tag()).collect(),
|
tags: tags.into_iter().map(|t| t.to_api_tag()).collect(),
|
||||||
@@ -94,18 +98,11 @@ impl DbProject {
|
|||||||
description: self.description.clone(),
|
description: self.description.clone(),
|
||||||
github_repo: self.github_repo.clone(),
|
github_repo: self.github_repo.clone(),
|
||||||
demo_url: self.demo_url.clone(),
|
demo_url: self.demo_url.clone(),
|
||||||
created_at: self
|
created_at: self.created_at.format(&Rfc3339).unwrap(),
|
||||||
.created_at
|
updated_at: self.updated_at.format(&Rfc3339).unwrap(),
|
||||||
.format(&time::format_description::well_known::Rfc3339)
|
last_github_activity: self
|
||||||
.unwrap(),
|
.last_github_activity
|
||||||
updated_at: self
|
.map(|dt| dt.format(&Rfc3339).unwrap()),
|
||||||
.updated_at
|
|
||||||
.format(&time::format_description::well_known::Rfc3339)
|
|
||||||
.unwrap(),
|
|
||||||
last_github_activity: self.last_github_activity.map(|dt| {
|
|
||||||
dt.format(&time::format_description::well_known::Rfc3339)
|
|
||||||
.unwrap()
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -150,7 +147,7 @@ pub struct AdminStats {
|
|||||||
// Query functions
|
// Query functions
|
||||||
|
|
||||||
pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
|
pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
|
||||||
sqlx::query_as!(
|
query_as!(
|
||||||
DbProject,
|
DbProject,
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
@@ -176,12 +173,12 @@ pub async fn get_public_projects(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::
|
|||||||
|
|
||||||
pub async fn get_public_projects_with_tags(
|
pub async fn get_public_projects_with_tags(
|
||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
|
) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
|
||||||
let projects = get_public_projects(pool).await?;
|
let projects = get_public_projects(pool).await?;
|
||||||
|
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for project in projects {
|
for project in projects {
|
||||||
let tags = super::tags::get_tags_for_project(pool, project.id).await?;
|
let tags = get_tags_for_project(pool, project.id).await?;
|
||||||
result.push((project, tags));
|
result.push((project, tags));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +187,7 @@ pub async fn get_public_projects_with_tags(
|
|||||||
|
|
||||||
/// Get all projects (admin view - includes hidden)
|
/// Get all projects (admin view - includes hidden)
|
||||||
pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
|
pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sqlx::Error> {
|
||||||
sqlx::query_as!(
|
query_as!(
|
||||||
DbProject,
|
DbProject,
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
@@ -216,12 +213,12 @@ pub async fn get_all_projects_admin(pool: &PgPool) -> Result<Vec<DbProject>, sql
|
|||||||
/// Get all projects with tags (admin view)
|
/// Get all projects with tags (admin view)
|
||||||
pub async fn get_all_projects_with_tags_admin(
|
pub async fn get_all_projects_with_tags_admin(
|
||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
) -> Result<Vec<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
|
) -> Result<Vec<(DbProject, Vec<DbTag>)>, sqlx::Error> {
|
||||||
let projects = get_all_projects_admin(pool).await?;
|
let projects = get_all_projects_admin(pool).await?;
|
||||||
|
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for project in projects {
|
for project in projects {
|
||||||
let tags = super::tags::get_tags_for_project(pool, project.id).await?;
|
let tags = get_tags_for_project(pool, project.id).await?;
|
||||||
result.push((project, tags));
|
result.push((project, tags));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +227,7 @@ pub async fn get_all_projects_with_tags_admin(
|
|||||||
|
|
||||||
/// Get single project by ID
|
/// Get single project by ID
|
||||||
pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProject>, sqlx::Error> {
|
pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProject>, sqlx::Error> {
|
||||||
sqlx::query_as!(
|
query_as!(
|
||||||
DbProject,
|
DbProject,
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
@@ -259,12 +256,12 @@ pub async fn get_project_by_id(pool: &PgPool, id: Uuid) -> Result<Option<DbProje
|
|||||||
pub async fn get_project_by_id_with_tags(
|
pub async fn get_project_by_id_with_tags(
|
||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
id: Uuid,
|
id: Uuid,
|
||||||
) -> Result<Option<(DbProject, Vec<super::tags::DbTag>)>, sqlx::Error> {
|
) -> Result<Option<(DbProject, Vec<DbTag>)>, sqlx::Error> {
|
||||||
let project = get_project_by_id(pool, id).await?;
|
let project = get_project_by_id(pool, id).await?;
|
||||||
|
|
||||||
match project {
|
match project {
|
||||||
Some(p) => {
|
Some(p) => {
|
||||||
let tags = super::tags::get_tags_for_project(pool, p.id).await?;
|
let tags = get_tags_for_project(pool, p.id).await?;
|
||||||
Ok(Some((p, tags)))
|
Ok(Some((p, tags)))
|
||||||
}
|
}
|
||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
@@ -276,7 +273,7 @@ pub async fn get_project_by_slug(
|
|||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
slug: &str,
|
slug: &str,
|
||||||
) -> Result<Option<DbProject>, sqlx::Error> {
|
) -> Result<Option<DbProject>, sqlx::Error> {
|
||||||
sqlx::query_as!(
|
query_as!(
|
||||||
DbProject,
|
DbProject,
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
@@ -316,7 +313,7 @@ pub async fn create_project(
|
|||||||
.map(|s| slugify(s))
|
.map(|s| slugify(s))
|
||||||
.unwrap_or_else(|| slugify(name));
|
.unwrap_or_else(|| slugify(name));
|
||||||
|
|
||||||
sqlx::query_as!(
|
query_as!(
|
||||||
DbProject,
|
DbProject,
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url)
|
INSERT INTO projects (slug, name, short_description, description, status, github_repo, demo_url)
|
||||||
@@ -352,7 +349,7 @@ pub async fn update_project(
|
|||||||
.map(|s| slugify(s))
|
.map(|s| slugify(s))
|
||||||
.unwrap_or_else(|| slugify(name));
|
.unwrap_or_else(|| slugify(name));
|
||||||
|
|
||||||
sqlx::query_as!(
|
query_as!(
|
||||||
DbProject,
|
DbProject,
|
||||||
r#"
|
r#"
|
||||||
UPDATE projects
|
UPDATE projects
|
||||||
@@ -377,7 +374,7 @@ pub async fn update_project(
|
|||||||
|
|
||||||
/// Delete project (CASCADE will handle tags)
|
/// Delete project (CASCADE will handle tags)
|
||||||
pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
|
pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
|
||||||
sqlx::query!("DELETE FROM projects WHERE id = $1", id)
|
query!("DELETE FROM projects WHERE id = $1", id)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -386,7 +383,7 @@ pub async fn delete_project(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error>
|
|||||||
/// Get admin stats
|
/// Get admin stats
|
||||||
pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
|
pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
|
||||||
// Get project counts by status
|
// Get project counts by status
|
||||||
let status_counts = sqlx::query!(
|
let status_counts = query!(
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
status as "status!: ProjectStatus",
|
status as "status!: ProjectStatus",
|
||||||
@@ -398,7 +395,7 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
|
|||||||
.fetch_all(pool)
|
.fetch_all(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut projects_by_status = serde_json::json!({
|
let mut projects_by_status = json!({
|
||||||
"active": 0,
|
"active": 0,
|
||||||
"maintained": 0,
|
"maintained": 0,
|
||||||
"archived": 0,
|
"archived": 0,
|
||||||
@@ -408,12 +405,12 @@ pub async fn get_admin_stats(pool: &PgPool) -> Result<AdminStats, sqlx::Error> {
|
|||||||
let mut total_projects = 0;
|
let mut total_projects = 0;
|
||||||
for row in status_counts {
|
for row in status_counts {
|
||||||
let status_str = format!("{:?}", row.status).to_lowercase();
|
let status_str = format!("{:?}", row.status).to_lowercase();
|
||||||
projects_by_status[status_str] = serde_json::json!(row.count);
|
projects_by_status[status_str] = json!(row.count);
|
||||||
total_projects += row.count;
|
total_projects += row.count;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get total tags
|
// Get total tags
|
||||||
let tag_count = sqlx::query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
|
let tag_count = query!("SELECT COUNT(*)::int as \"count!\" FROM tags")
|
||||||
.fetch_one(pool)
|
.fetch_one(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|||||||
+4
-11
@@ -60,17 +60,10 @@ pub async fn proxy_icons_handler(
|
|||||||
let full_path = format!("/api/icons/{}", path);
|
let full_path = format!("/api/icons/{}", path);
|
||||||
let query = req.uri().query().unwrap_or("");
|
let query = req.uri().query().unwrap_or("");
|
||||||
|
|
||||||
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
|
let path_with_query = if query.is_empty() {
|
||||||
{
|
full_path.clone()
|
||||||
if query.is_empty() {
|
|
||||||
format!("http://localhost{}", full_path)
|
|
||||||
} else {
|
} else {
|
||||||
format!("http://localhost{}?{}", full_path, query)
|
format!("{full_path}?{query}")
|
||||||
}
|
|
||||||
} else if query.is_empty() {
|
|
||||||
format!("{}{}", state.downstream_url, full_path)
|
|
||||||
} else {
|
|
||||||
format!("{}{}?{}", state.downstream_url, full_path, query)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build trusted headers with session info
|
// Build trusted headers with session info
|
||||||
@@ -86,7 +79,7 @@ pub async fn proxy_icons_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match proxy::proxy_to_bun(&bun_url, state, forward_headers).await {
|
match proxy::proxy_to_bun(&path_with_query, state, forward_headers).await {
|
||||||
Ok((status, headers, body)) => (status, headers, body).into_response(),
|
Ok((status, headers, body)) => (status, headers, body).into_response(),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request");
|
tracing::error!(error = %err, path = %full_path, "Failed to proxy icon request");
|
||||||
|
|||||||
+146
@@ -0,0 +1,146 @@
|
|||||||
|
use reqwest::Method;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::time::Duration;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ClientError {
|
||||||
|
#[error("Failed to build reqwest client: {0}")]
|
||||||
|
BuildError(#[from] reqwest::Error),
|
||||||
|
|
||||||
|
#[error("Invalid downstream URL: {0}")]
|
||||||
|
InvalidUrl(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct HttpClient {
|
||||||
|
client: reqwest::Client,
|
||||||
|
target: TargetUrl,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum TargetUrl {
|
||||||
|
Tcp(String), // Base URL like "http://localhost:5173"
|
||||||
|
Unix(PathBuf), // Socket path like "/tmp/bun.sock"
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HttpClient {
|
||||||
|
/// Create a new HttpClient from a downstream URL
|
||||||
|
///
|
||||||
|
/// Accepts:
|
||||||
|
/// - TCP: "http://localhost:5173", "https://example.com"
|
||||||
|
/// - Unix: "/tmp/bun.sock", "./relative.sock"
|
||||||
|
pub fn new(downstream: &str) -> Result<Self, ClientError> {
|
||||||
|
let target = if downstream.starts_with('/') || downstream.starts_with("./") {
|
||||||
|
TargetUrl::Unix(PathBuf::from(downstream))
|
||||||
|
} else if downstream.starts_with("http://") || downstream.starts_with("https://") {
|
||||||
|
TargetUrl::Tcp(downstream.to_string())
|
||||||
|
} else {
|
||||||
|
return Err(ClientError::InvalidUrl(downstream.to_string()));
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
target = ?target,
|
||||||
|
downstream = %downstream,
|
||||||
|
"Creating HTTP client"
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = match &target {
|
||||||
|
TargetUrl::Unix(path) => reqwest::Client::builder()
|
||||||
|
.pool_max_idle_per_host(8)
|
||||||
|
.pool_idle_timeout(Duration::from_secs(600))
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.connect_timeout(Duration::from_secs(3))
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
|
.unix_socket(path.clone())
|
||||||
|
.build()?,
|
||||||
|
TargetUrl::Tcp(_) => reqwest::Client::builder()
|
||||||
|
.pool_max_idle_per_host(8)
|
||||||
|
.pool_idle_timeout(Duration::from_secs(600))
|
||||||
|
.tcp_keepalive(Some(Duration::from_secs(60)))
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.connect_timeout(Duration::from_secs(3))
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
|
.build()?,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self { client, target })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a full URL from a path
|
||||||
|
///
|
||||||
|
/// Examples:
|
||||||
|
/// - TCP target "http://localhost:5173" + "/api/health" → "http://localhost:5173/api/health"
|
||||||
|
/// - Unix target "/tmp/bun.sock" + "/api/health" → "http://localhost/api/health"
|
||||||
|
fn build_url(&self, path: &str) -> String {
|
||||||
|
match &self.target {
|
||||||
|
TargetUrl::Tcp(base) => format!("{}{}", base, path),
|
||||||
|
TargetUrl::Unix(_) => format!("http://localhost{}", path),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
|
||||||
|
self.client.get(self.build_url(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||||
|
self.client.post(self.build_url(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
|
||||||
|
self.client.request(method, self.build_url(path))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tcp_url_construction() {
|
||||||
|
let client = HttpClient::new("http://localhost:5173").unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client.build_url("/api/health"),
|
||||||
|
"http://localhost:5173/api/health"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
client.build_url("/path?query=1"),
|
||||||
|
"http://localhost:5173/path?query=1"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unix_url_construction() {
|
||||||
|
let client = HttpClient::new("/tmp/bun.sock").unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client.build_url("/api/health"),
|
||||||
|
"http://localhost/api/health"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
client.build_url("/path?query=1"),
|
||||||
|
"http://localhost/path?query=1"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_relative_unix_socket() {
|
||||||
|
let client = HttpClient::new("./relative.sock").unwrap();
|
||||||
|
assert!(matches!(client.target, TargetUrl::Unix(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_url() {
|
||||||
|
let result = HttpClient::new("not-a-valid-url");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_https_url() {
|
||||||
|
let client = HttpClient::new("https://example.com").unwrap();
|
||||||
|
assert!(matches!(client.target, TargetUrl::Tcp(_)));
|
||||||
|
assert_eq!(
|
||||||
|
client.build_url("/api/test"),
|
||||||
|
"https://example.com/api/test"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
+7
-41
@@ -1,6 +1,5 @@
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer};
|
use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer};
|
||||||
@@ -13,6 +12,7 @@ mod db;
|
|||||||
mod formatter;
|
mod formatter;
|
||||||
mod handlers;
|
mod handlers;
|
||||||
mod health;
|
mod health;
|
||||||
|
mod http;
|
||||||
mod middleware;
|
mod middleware;
|
||||||
mod og;
|
mod og;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
@@ -125,50 +125,18 @@ async fn main() {
|
|||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create HTTP client for TCP connections with optimized pool settings
|
// Create socket-aware HTTP client
|
||||||
let http_client = reqwest::Client::builder()
|
let client = http::HttpClient::new(&args.downstream).expect("Failed to create HTTP client");
|
||||||
.pool_max_idle_per_host(8)
|
|
||||||
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
|
|
||||||
.tcp_keepalive(Some(Duration::from_secs(60)))
|
|
||||||
.timeout(Duration::from_secs(5)) // Default timeout for SSR
|
|
||||||
.connect_timeout(Duration::from_secs(3))
|
|
||||||
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
|
|
||||||
.build()
|
|
||||||
.expect("Failed to create HTTP client");
|
|
||||||
|
|
||||||
// Create Unix socket client if downstream is a Unix socket
|
|
||||||
let unix_client = if args.downstream.starts_with('/') || args.downstream.starts_with("./") {
|
|
||||||
let path = PathBuf::from(&args.downstream);
|
|
||||||
Some(
|
|
||||||
reqwest::Client::builder()
|
|
||||||
.pool_max_idle_per_host(8)
|
|
||||||
.pool_idle_timeout(Duration::from_secs(600)) // 10 minutes
|
|
||||||
.timeout(Duration::from_secs(5)) // Default timeout for SSR
|
|
||||||
.connect_timeout(Duration::from_secs(3))
|
|
||||||
.redirect(reqwest::redirect::Policy::none()) // Don't follow redirects - pass them through
|
|
||||||
.unix_socket(path)
|
|
||||||
.build()
|
|
||||||
.expect("Failed to create Unix socket client"),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create health checker
|
// Create health checker
|
||||||
let downstream_url_for_health = args.downstream.clone();
|
let client_for_health = client.clone();
|
||||||
let http_client_for_health = http_client.clone();
|
|
||||||
let unix_client_for_health = unix_client.clone();
|
|
||||||
let pool_for_health = pool.clone();
|
let pool_for_health = pool.clone();
|
||||||
|
|
||||||
let health_checker = Arc::new(HealthChecker::new(move || {
|
let health_checker = Arc::new(HealthChecker::new(move || {
|
||||||
let downstream_url = downstream_url_for_health.clone();
|
let client = client_for_health.clone();
|
||||||
let http_client = http_client_for_health.clone();
|
|
||||||
let unix_client = unix_client_for_health.clone();
|
|
||||||
let pool = pool_for_health.clone();
|
let pool = pool_for_health.clone();
|
||||||
|
|
||||||
async move {
|
async move { proxy::perform_health_check(client, Some(pool)).await }
|
||||||
proxy::perform_health_check(downstream_url, http_client, unix_client, Some(pool)).await
|
|
||||||
}
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
let tarpit_config = TarpitConfig::from_env();
|
let tarpit_config = TarpitConfig::from_env();
|
||||||
@@ -186,9 +154,7 @@ async fn main() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let state = Arc::new(AppState {
|
let state = Arc::new(AppState {
|
||||||
downstream_url: args.downstream.clone(),
|
client,
|
||||||
http_client,
|
|
||||||
unix_client,
|
|
||||||
health_checker,
|
health_checker,
|
||||||
tarpit_state,
|
tarpit_state,
|
||||||
pool: pool.clone(),
|
pool: pool.clone(),
|
||||||
|
|||||||
@@ -36,17 +36,9 @@ pub async fn generate_og_image(spec: &OGImageSpec, state: Arc<AppState>) -> Resu
|
|||||||
tracing::Span::current().record("r2_key", &r2_key);
|
tracing::Span::current().record("r2_key", &r2_key);
|
||||||
|
|
||||||
// Call Bun's internal endpoint
|
// Call Bun's internal endpoint
|
||||||
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
|
let response = state
|
||||||
{
|
.client
|
||||||
"http://localhost/internal/ogp/generate".to_string()
|
.post("/internal/ogp/generate")
|
||||||
} else {
|
|
||||||
format!("{}/internal/ogp/generate", state.downstream_url)
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = state.unix_client.as_ref().unwrap_or(&state.http_client);
|
|
||||||
|
|
||||||
let response = client
|
|
||||||
.post(&bun_url)
|
|
||||||
.json(spec)
|
.json(spec)
|
||||||
.timeout(Duration::from_secs(30))
|
.timeout(Duration::from_secs(30))
|
||||||
.send()
|
.send()
|
||||||
|
|||||||
+32
-56
@@ -72,17 +72,10 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
|
|||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
let bun_url = if state.downstream_url.starts_with('/') || state.downstream_url.starts_with("./")
|
let path_with_query = if query.is_empty() {
|
||||||
{
|
path.to_string()
|
||||||
if query.is_empty() {
|
|
||||||
format!("http://localhost{path}")
|
|
||||||
} else {
|
} else {
|
||||||
format!("http://localhost{path}?{query}")
|
format!("{path}?{query}")
|
||||||
}
|
|
||||||
} else if query.is_empty() {
|
|
||||||
format!("{}{}", state.downstream_url, path)
|
|
||||||
} else {
|
|
||||||
format!("{}{}?{}", state.downstream_url, path, query)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build trusted headers to forward to downstream
|
// Build trusted headers to forward to downstream
|
||||||
@@ -120,7 +113,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
|
|||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
match proxy_to_bun(&bun_url, state.clone(), forward_headers).await {
|
match proxy_to_bun(&path_with_query, state.clone(), forward_headers).await {
|
||||||
Ok((status, headers, body)) => {
|
Ok((status, headers, body)) => {
|
||||||
let duration_ms = start.elapsed().as_millis() as u64;
|
let duration_ms = start.elapsed().as_millis() as u64;
|
||||||
let cache = "miss";
|
let cache = "miss";
|
||||||
@@ -182,7 +175,7 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
|
|||||||
let duration_ms = start.elapsed().as_millis() as u64;
|
let duration_ms = start.elapsed().as_millis() as u64;
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
error = %err,
|
error = %err,
|
||||||
url = %bun_url,
|
path = %path_with_query,
|
||||||
duration_ms,
|
duration_ms,
|
||||||
"Failed to proxy to Bun"
|
"Failed to proxy to Bun"
|
||||||
);
|
);
|
||||||
@@ -203,18 +196,12 @@ pub async fn isr_handler(State(state): State<Arc<AppState>>, req: Request) -> Re
|
|||||||
|
|
||||||
/// Proxy a request to Bun SSR
|
/// Proxy a request to Bun SSR
|
||||||
pub async fn proxy_to_bun(
|
pub async fn proxy_to_bun(
|
||||||
url: &str,
|
path: &str,
|
||||||
state: Arc<AppState>,
|
state: Arc<AppState>,
|
||||||
forward_headers: HeaderMap,
|
forward_headers: HeaderMap,
|
||||||
) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> {
|
) -> Result<(StatusCode, HeaderMap, axum::body::Bytes), ProxyError> {
|
||||||
let client = if state.unix_client.is_some() {
|
|
||||||
state.unix_client.as_ref().unwrap()
|
|
||||||
} else {
|
|
||||||
&state.http_client
|
|
||||||
};
|
|
||||||
|
|
||||||
// Build request with forwarded headers
|
// Build request with forwarded headers
|
||||||
let mut request_builder = client.get(url);
|
let mut request_builder = state.client.get(path);
|
||||||
for (name, value) in forward_headers.iter() {
|
for (name, value) in forward_headers.iter() {
|
||||||
request_builder = request_builder.header(name, value);
|
request_builder = request_builder.header(name, value);
|
||||||
}
|
}
|
||||||
@@ -247,25 +234,15 @@ pub async fn proxy_to_bun(
|
|||||||
|
|
||||||
/// Perform health check on Bun SSR and database
|
/// Perform health check on Bun SSR and database
|
||||||
pub async fn perform_health_check(
|
pub async fn perform_health_check(
|
||||||
downstream_url: String,
|
client: crate::http::HttpClient,
|
||||||
http_client: reqwest::Client,
|
|
||||||
unix_client: Option<reqwest::Client>,
|
|
||||||
pool: Option<sqlx::PgPool>,
|
pool: Option<sqlx::PgPool>,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
let url = if downstream_url.starts_with('/') || downstream_url.starts_with("./") {
|
let bun_healthy = match tokio::time::timeout(
|
||||||
"http://localhost/internal/health".to_string()
|
Duration::from_secs(5),
|
||||||
} else {
|
client.get("/internal/health").send(),
|
||||||
format!("{downstream_url}/internal/health")
|
)
|
||||||
};
|
.await
|
||||||
|
{
|
||||||
let client = if unix_client.is_some() {
|
|
||||||
unix_client.as_ref().unwrap()
|
|
||||||
} else {
|
|
||||||
&http_client
|
|
||||||
};
|
|
||||||
|
|
||||||
let bun_healthy =
|
|
||||||
match tokio::time::timeout(Duration::from_secs(5), client.get(&url).send()).await {
|
|
||||||
Ok(Ok(response)) => {
|
Ok(Ok(response)) => {
|
||||||
let is_success = response.status().is_success();
|
let is_success = response.status().is_success();
|
||||||
if !is_success {
|
if !is_success {
|
||||||
@@ -307,33 +284,32 @@ fn should_tarpit(state: &TarpitState, path: &str) -> bool {
|
|||||||
state.config.enabled && tarpit::is_malicious_path(path)
|
state.config.enabled && tarpit::is_malicious_path(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Common handler logic for requests with optional peer info
|
||||||
|
async fn handle_request_with_optional_peer(
|
||||||
|
state: Arc<AppState>,
|
||||||
|
peer: Option<SocketAddr>,
|
||||||
|
req: Request,
|
||||||
|
) -> Response {
|
||||||
|
let path = req.uri().path();
|
||||||
|
|
||||||
|
if should_tarpit(&state.tarpit_state, path) {
|
||||||
|
let peer_info = peer.map(ConnectInfo);
|
||||||
|
tarpit::tarpit_handler(State(state.tarpit_state.clone()), peer_info, req).await
|
||||||
|
} else {
|
||||||
|
isr_handler(State(state), req).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Fallback handler for TCP connections (has access to peer IP)
|
/// Fallback handler for TCP connections (has access to peer IP)
|
||||||
pub async fn fallback_handler_tcp(
|
pub async fn fallback_handler_tcp(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
ConnectInfo(peer): ConnectInfo<SocketAddr>,
|
ConnectInfo(peer): ConnectInfo<SocketAddr>,
|
||||||
req: Request,
|
req: Request,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let path = req.uri().path();
|
handle_request_with_optional_peer(state, Some(peer), req).await
|
||||||
|
|
||||||
if should_tarpit(&state.tarpit_state, path) {
|
|
||||||
tarpit::tarpit_handler(
|
|
||||||
State(state.tarpit_state.clone()),
|
|
||||||
Some(ConnectInfo(peer)),
|
|
||||||
req,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
} else {
|
|
||||||
isr_handler(State(state), req).await
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fallback handler for Unix sockets (no peer IP available)
|
/// Fallback handler for Unix sockets (no peer IP available)
|
||||||
pub async fn fallback_handler_unix(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
pub async fn fallback_handler_unix(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
let path = req.uri().path();
|
handle_request_with_optional_peer(state, None, req).await
|
||||||
|
|
||||||
if should_tarpit(&state.tarpit_state, path) {
|
|
||||||
tarpit::tarpit_handler(State(state.tarpit_state.clone()), None, req).await
|
|
||||||
} else {
|
|
||||||
isr_handler(State(state), req).await
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+31
-50
@@ -1,4 +1,11 @@
|
|||||||
use axum::{Router, extract::Request, http::Uri, response::IntoResponse, routing::any};
|
use axum::{
|
||||||
|
Router,
|
||||||
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{Method, Uri},
|
||||||
|
response::IntoResponse,
|
||||||
|
routing::{any, get, post},
|
||||||
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{assets, handlers, state::AppState};
|
use crate::{assets, handlers, state::AppState};
|
||||||
@@ -9,32 +16,28 @@ pub fn api_routes() -> Router<Arc<AppState>> {
|
|||||||
.route("/", any(api_root_404_handler))
|
.route("/", any(api_root_404_handler))
|
||||||
.route(
|
.route(
|
||||||
"/health",
|
"/health",
|
||||||
axum::routing::get(handlers::health_handler).head(handlers::health_handler),
|
get(handlers::health_handler).head(handlers::health_handler),
|
||||||
)
|
)
|
||||||
// Authentication endpoints (public)
|
// Authentication endpoints (public)
|
||||||
.route("/login", axum::routing::post(handlers::api_login_handler))
|
.route("/login", post(handlers::api_login_handler))
|
||||||
.route("/logout", axum::routing::post(handlers::api_logout_handler))
|
.route("/logout", post(handlers::api_logout_handler))
|
||||||
.route(
|
.route("/session", get(handlers::api_session_handler))
|
||||||
"/session",
|
|
||||||
axum::routing::get(handlers::api_session_handler),
|
|
||||||
)
|
|
||||||
// Projects - GET is public (shows all for admin, only non-hidden for public)
|
// Projects - GET is public (shows all for admin, only non-hidden for public)
|
||||||
// POST/PUT/DELETE require authentication
|
// POST/PUT/DELETE require authentication
|
||||||
.route(
|
.route(
|
||||||
"/projects",
|
"/projects",
|
||||||
axum::routing::get(handlers::projects_handler).post(handlers::create_project_handler),
|
get(handlers::projects_handler).post(handlers::create_project_handler),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/projects/{id}",
|
"/projects/{id}",
|
||||||
axum::routing::get(handlers::get_project_handler)
|
get(handlers::get_project_handler)
|
||||||
.put(handlers::update_project_handler)
|
.put(handlers::update_project_handler)
|
||||||
.delete(handlers::delete_project_handler),
|
.delete(handlers::delete_project_handler),
|
||||||
)
|
)
|
||||||
// Project tags - authentication checked in handlers
|
// Project tags - authentication checked in handlers
|
||||||
.route(
|
.route(
|
||||||
"/projects/{id}/tags",
|
"/projects/{id}/tags",
|
||||||
axum::routing::get(handlers::get_project_tags_handler)
|
get(handlers::get_project_tags_handler).post(handlers::add_project_tag_handler),
|
||||||
.post(handlers::add_project_tag_handler),
|
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/projects/{id}/tags/{tag_id}",
|
"/projects/{id}/tags/{tag_id}",
|
||||||
@@ -43,36 +46,29 @@ pub fn api_routes() -> Router<Arc<AppState>> {
|
|||||||
// Tags - authentication checked in handlers
|
// Tags - authentication checked in handlers
|
||||||
.route(
|
.route(
|
||||||
"/tags",
|
"/tags",
|
||||||
axum::routing::get(handlers::list_tags_handler).post(handlers::create_tag_handler),
|
get(handlers::list_tags_handler).post(handlers::create_tag_handler),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/tags/{slug}",
|
"/tags/{slug}",
|
||||||
axum::routing::get(handlers::get_tag_handler).put(handlers::update_tag_handler),
|
get(handlers::get_tag_handler).put(handlers::update_tag_handler),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/tags/{slug}/related",
|
"/tags/{slug}/related",
|
||||||
axum::routing::get(handlers::get_related_tags_handler),
|
get(handlers::get_related_tags_handler),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/tags/recalculate-cooccurrence",
|
"/tags/recalculate-cooccurrence",
|
||||||
axum::routing::post(handlers::recalculate_cooccurrence_handler),
|
post(handlers::recalculate_cooccurrence_handler),
|
||||||
)
|
)
|
||||||
// Admin stats - requires authentication
|
// Admin stats - requires authentication
|
||||||
.route(
|
.route("/stats", get(handlers::get_admin_stats_handler))
|
||||||
"/stats",
|
|
||||||
axum::routing::get(handlers::get_admin_stats_handler),
|
|
||||||
)
|
|
||||||
// Site settings - GET is public, PUT requires authentication
|
// Site settings - GET is public, PUT requires authentication
|
||||||
.route(
|
.route(
|
||||||
"/settings",
|
"/settings",
|
||||||
axum::routing::get(handlers::get_settings_handler)
|
get(handlers::get_settings_handler).put(handlers::update_settings_handler),
|
||||||
.put(handlers::update_settings_handler),
|
|
||||||
)
|
)
|
||||||
// Icon API - proxy to SvelteKit (authentication handled by SvelteKit)
|
// Icon API - proxy to SvelteKit (authentication handled by SvelteKit)
|
||||||
.route(
|
.route("/icons/{*path}", get(handlers::proxy_icons_handler))
|
||||||
"/icons/{*path}",
|
|
||||||
axum::routing::get(handlers::proxy_icons_handler),
|
|
||||||
)
|
|
||||||
.fallback(api_404_and_method_handler)
|
.fallback(api_404_and_method_handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,19 +79,13 @@ pub fn build_base_router() -> Router<Arc<AppState>> {
|
|||||||
.route("/api/", any(api_root_404_handler))
|
.route("/api/", any(api_root_404_handler))
|
||||||
.route(
|
.route(
|
||||||
"/_app/{*path}",
|
"/_app/{*path}",
|
||||||
axum::routing::get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
|
get(assets::serve_embedded_asset).head(assets::serve_embedded_asset),
|
||||||
)
|
)
|
||||||
.route("/pgp", axum::routing::get(handlers::handle_pgp_route))
|
.route("/pgp", get(handlers::handle_pgp_route))
|
||||||
.route(
|
.route("/publickey.asc", get(handlers::serve_pgp_key))
|
||||||
"/publickey.asc",
|
.route("/pgp.asc", get(handlers::serve_pgp_key))
|
||||||
axum::routing::get(handlers::serve_pgp_key),
|
.route("/.well-known/pgpkey.asc", get(handlers::serve_pgp_key))
|
||||||
)
|
.route("/keys", get(handlers::redirect_to_pgp))
|
||||||
.route("/pgp.asc", axum::routing::get(handlers::serve_pgp_key))
|
|
||||||
.route(
|
|
||||||
"/.well-known/pgpkey.asc",
|
|
||||||
axum::routing::get(handlers::serve_pgp_key),
|
|
||||||
)
|
|
||||||
.route("/keys", axum::routing::get(handlers::redirect_to_pgp))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn api_root_404_handler(uri: Uri) -> impl IntoResponse {
|
async fn api_root_404_handler(uri: Uri) -> impl IntoResponse {
|
||||||
@@ -109,10 +99,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
|
|||||||
let uri = req.uri();
|
let uri = req.uri();
|
||||||
let path = uri.path();
|
let path = uri.path();
|
||||||
|
|
||||||
if method != axum::http::Method::GET
|
if method != Method::GET && method != Method::HEAD && method != Method::OPTIONS {
|
||||||
&& method != axum::http::Method::HEAD
|
|
||||||
&& method != axum::http::Method::OPTIONS
|
|
||||||
{
|
|
||||||
let content_type = req
|
let content_type = req
|
||||||
.headers()
|
.headers()
|
||||||
.get(axum::http::header::CONTENT_TYPE)
|
.get(axum::http::header::CONTENT_TYPE)
|
||||||
@@ -129,10 +116,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
|
|||||||
)
|
)
|
||||||
.into_response();
|
.into_response();
|
||||||
}
|
}
|
||||||
} else if method == axum::http::Method::POST
|
} else if method == Method::POST || method == Method::PUT || method == Method::PATCH {
|
||||||
|| method == axum::http::Method::PUT
|
|
||||||
|| method == axum::http::Method::PATCH
|
|
||||||
{
|
|
||||||
// POST/PUT/PATCH require Content-Type header
|
// POST/PUT/PATCH require Content-Type header
|
||||||
return (
|
return (
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
@@ -158,10 +142,7 @@ async fn api_404_and_method_handler(req: Request) -> impl IntoResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn api_404_handler(uri: Uri) -> impl IntoResponse {
|
async fn api_404_handler(uri: Uri) -> impl IntoResponse {
|
||||||
let req = Request::builder()
|
let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
|
||||||
.uri(uri)
|
|
||||||
.body(axum::body::Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
api_404_and_method_handler(req).await
|
api_404_and_method_handler(req).await
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-4
@@ -1,13 +1,11 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{auth::SessionManager, health::HealthChecker, tarpit::TarpitState};
|
use crate::{auth::SessionManager, health::HealthChecker, http::HttpClient, tarpit::TarpitState};
|
||||||
|
|
||||||
/// Application state shared across all handlers
|
/// Application state shared across all handlers
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub downstream_url: String,
|
pub client: HttpClient,
|
||||||
pub http_client: reqwest::Client,
|
|
||||||
pub unix_client: Option<reqwest::Client>,
|
|
||||||
pub health_checker: Arc<HealthChecker>,
|
pub health_checker: Arc<HealthChecker>,
|
||||||
pub tarpit_state: Arc<TarpitState>,
|
pub tarpit_state: Arc<TarpitState>,
|
||||||
pub pool: sqlx::PgPool,
|
pub pool: sqlx::PgPool,
|
||||||
|
|||||||
+12
-12
@@ -1,5 +1,5 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, HeaderValue, StatusCode, header},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ pub fn is_page_route(path: &str) -> bool {
|
|||||||
|
|
||||||
/// Check if the request accepts HTML responses
|
/// Check if the request accepts HTML responses
|
||||||
pub fn accepts_html(headers: &HeaderMap) -> bool {
|
pub fn accepts_html(headers: &HeaderMap) -> bool {
|
||||||
if let Some(accept) = headers.get(axum::http::header::ACCEPT) {
|
if let Some(accept) = headers.get(header::ACCEPT) {
|
||||||
if let Ok(accept_str) = accept.to_str() {
|
if let Ok(accept_str) = accept.to_str() {
|
||||||
return accept_str.contains("text/html") || accept_str.contains("*/*");
|
return accept_str.contains("text/html") || accept_str.contains("*/*");
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ pub fn accepts_html(headers: &HeaderMap) -> bool {
|
|||||||
/// Determines if request prefers raw content (CLI tools) over HTML
|
/// Determines if request prefers raw content (CLI tools) over HTML
|
||||||
pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
|
pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
|
||||||
// Check User-Agent for known CLI tools first (most reliable)
|
// Check User-Agent for known CLI tools first (most reliable)
|
||||||
if let Some(ua) = headers.get(axum::http::header::USER_AGENT) {
|
if let Some(ua) = headers.get(header::USER_AGENT) {
|
||||||
if let Ok(ua_str) = ua.to_str() {
|
if let Ok(ua_str) = ua.to_str() {
|
||||||
let ua_lower = ua_str.to_lowercase();
|
let ua_lower = ua_str.to_lowercase();
|
||||||
if ua_lower.starts_with("curl/")
|
if ua_lower.starts_with("curl/")
|
||||||
@@ -60,7 +60,7 @@ pub fn prefers_raw_content(headers: &HeaderMap) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check Accept header - if it explicitly prefers text/html, serve HTML
|
// Check Accept header - if it explicitly prefers text/html, serve HTML
|
||||||
if let Some(accept) = headers.get(axum::http::header::ACCEPT) {
|
if let Some(accept) = headers.get(header::ACCEPT) {
|
||||||
if let Ok(accept_str) = accept.to_str() {
|
if let Ok(accept_str) = accept.to_str() {
|
||||||
// If text/html appears before */* in the list, they prefer HTML
|
// If text/html appears before */* in the list, they prefer HTML
|
||||||
if let Some(html_pos) = accept_str.find("text/html") {
|
if let Some(html_pos) = accept_str.find("text/html") {
|
||||||
@@ -88,12 +88,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
|
|||||||
if let Some(html) = assets::get_error_page(status_code) {
|
if let Some(html) = assets::get_error_page(status_code) {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
axum::http::header::CONTENT_TYPE,
|
header::CONTENT_TYPE,
|
||||||
axum::http::HeaderValue::from_static("text/html; charset=utf-8"),
|
HeaderValue::from_static("text/html; charset=utf-8"),
|
||||||
);
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
axum::http::header::CACHE_CONTROL,
|
header::CACHE_CONTROL,
|
||||||
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
|
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
|
||||||
);
|
);
|
||||||
|
|
||||||
(status, headers, html).into_response()
|
(status, headers, html).into_response()
|
||||||
@@ -107,12 +107,12 @@ pub fn serve_error_page(status: StatusCode) -> Response {
|
|||||||
if let Some(fallback_html) = assets::get_error_page(500) {
|
if let Some(fallback_html) = assets::get_error_page(500) {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
axum::http::header::CONTENT_TYPE,
|
header::CONTENT_TYPE,
|
||||||
axum::http::HeaderValue::from_static("text/html; charset=utf-8"),
|
HeaderValue::from_static("text/html; charset=utf-8"),
|
||||||
);
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
axum::http::header::CACHE_CONTROL,
|
header::CACHE_CONTROL,
|
||||||
axum::http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
|
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
|
||||||
);
|
);
|
||||||
|
|
||||||
(status, headers, fallback_html).into_response()
|
(status, headers, fallback_html).into_response()
|
||||||
|
|||||||
+36
-18
@@ -3,34 +3,49 @@ import { env } from "$env/dynamic/private";
|
|||||||
|
|
||||||
const logger = getLogger(["ssr", "lib", "api"]);
|
const logger = getLogger(["ssr", "lib", "api"]);
|
||||||
|
|
||||||
const upstreamUrl = env.UPSTREAM_URL;
|
interface FetchOptions extends RequestInit {
|
||||||
const isUnixSocket =
|
fetch?: typeof fetch;
|
||||||
upstreamUrl?.startsWith("/") || upstreamUrl?.startsWith("./");
|
}
|
||||||
const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl;
|
|
||||||
|
|
||||||
export async function apiFetch<T>(
|
interface BunFetchOptions extends RequestInit {
|
||||||
path: string,
|
unix?: string;
|
||||||
init?: RequestInit & { fetch?: typeof fetch },
|
}
|
||||||
): Promise<T> {
|
|
||||||
|
/**
|
||||||
|
* Create a socket-aware fetch function
|
||||||
|
* Automatically handles Unix socket vs TCP based on UPSTREAM_URL
|
||||||
|
*/
|
||||||
|
function createSmartFetch(upstreamUrl: string | undefined) {
|
||||||
if (!upstreamUrl) {
|
if (!upstreamUrl) {
|
||||||
logger.error("UPSTREAM_URL environment variable not set");
|
const error = "UPSTREAM_URL environment variable not set";
|
||||||
throw new Error("UPSTREAM_URL environment variable not set");
|
logger.error(error);
|
||||||
|
throw new Error(error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const isUnixSocket =
|
||||||
|
upstreamUrl.startsWith("/") || upstreamUrl.startsWith("./");
|
||||||
|
const baseUrl = isUnixSocket ? "http://localhost" : upstreamUrl;
|
||||||
|
|
||||||
|
return async function smartFetch<T>(
|
||||||
|
path: string,
|
||||||
|
options?: FetchOptions,
|
||||||
|
): Promise<T> {
|
||||||
const url = `${baseUrl}${path}`;
|
const url = `${baseUrl}${path}`;
|
||||||
const method = init?.method ?? "GET";
|
const method = options?.method ?? "GET";
|
||||||
|
|
||||||
// Unix sockets require Bun's native fetch (SvelteKit's fetch doesn't support it)
|
// Unix sockets require Bun's native fetch
|
||||||
const fetchFn = isUnixSocket ? fetch : (init?.fetch ?? fetch);
|
// SvelteKit's fetch doesn't support the 'unix' option
|
||||||
|
const fetchFn = isUnixSocket ? fetch : (options?.fetch ?? fetch);
|
||||||
|
|
||||||
const fetchOptions: RequestInit & { unix?: string } = {
|
const fetchOptions: BunFetchOptions = {
|
||||||
...init,
|
...options,
|
||||||
signal: init?.signal ?? AbortSignal.timeout(30_000),
|
signal: options?.signal ?? AbortSignal.timeout(30_000),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Remove custom fetch property from options
|
// Remove custom fetch property from options (not part of standard RequestInit)
|
||||||
delete (fetchOptions as Record<string, unknown>).fetch;
|
delete (fetchOptions as Record<string, unknown>).fetch;
|
||||||
|
|
||||||
|
// Add Unix socket path if needed
|
||||||
if (isUnixSocket) {
|
if (isUnixSocket) {
|
||||||
fetchOptions.unix = upstreamUrl;
|
fetchOptions.unix = upstreamUrl;
|
||||||
}
|
}
|
||||||
@@ -40,7 +55,6 @@ export async function apiFetch<T>(
|
|||||||
url,
|
url,
|
||||||
path,
|
path,
|
||||||
isUnixSocket,
|
isUnixSocket,
|
||||||
upstreamUrl,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -67,4 +81,8 @@ export async function apiFetch<T>(
|
|||||||
});
|
});
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Export the configured smart fetch function
|
||||||
|
export const apiFetch = createSmartFetch(env.UPSTREAM_URL);
|
||||||
|
|||||||
Reference in New Issue
Block a user