feat: allow health check forcing in debug, setup test mocking, plan out integration tests

This commit is contained in:
Ryan Walters
2025-09-18 22:42:00 -05:00
parent 350f92ab21
commit e2c725cb95
8 changed files with 292 additions and 82 deletions

81
Cargo.lock generated
View File

@@ -50,6 +50,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "anstyle"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.99" version = "1.0.99"
@@ -1236,6 +1242,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "downcast"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
[[package]] [[package]]
name = "downcast-rs" name = "downcast-rs"
version = "2.0.2" version = "2.0.2"
@@ -1564,6 +1576,12 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "fragile"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619"
[[package]] [[package]]
name = "fs_extra" name = "fs_extra"
version = "1.3.0" version = "1.3.0"
@@ -2453,7 +2471,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [ dependencies = [
"cfg-if 1.0.3", "cfg-if 1.0.3",
"windows-targets 0.52.6", "windows-targets 0.53.3",
] ]
[[package]] [[package]]
@@ -2693,6 +2711,32 @@ dependencies = [
"ws2_32-sys", "ws2_32-sys",
] ]
[[package]]
name = "mockall"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
dependencies = [
"cfg-if 1.0.3",
"downcast",
"fragile",
"mockall_derive",
"predicates",
"predicates-tree",
]
[[package]]
name = "mockall_derive"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
dependencies = [
"cfg-if 1.0.3",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "moxcms" name = "moxcms"
version = "0.7.5" version = "0.7.5"
@@ -3041,6 +3085,7 @@ dependencies = [
"hyper-util", "hyper-util",
"image", "image",
"jsonwebtoken", "jsonwebtoken",
"mockall",
"oauth2", "oauth2",
"pretty_assertions", "pretty_assertions",
"reqwest", "reqwest",
@@ -3351,6 +3396,32 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "predicates"
version = "3.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
dependencies = [
"anstyle",
"predicates-core",
]
[[package]]
name = "predicates-core"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
[[package]]
name = "predicates-tree"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
dependencies = [
"predicates-core",
"termtree",
]
[[package]] [[package]]
name = "pretty_assertions" name = "pretty_assertions"
version = "1.4.1" version = "1.4.1"
@@ -3530,7 +3601,7 @@ dependencies = [
"once_cell", "once_cell",
"socket2 0.6.0", "socket2 0.6.0",
"tracing", "tracing",
"windows-sys 0.59.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]
@@ -4832,6 +4903,12 @@ dependencies = [
"windows-sys 0.61.0", "windows-sys 0.61.0",
] ]
[[package]]
name = "termtree"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
[[package]] [[package]]
name = "testcontainers" name = "testcontainers"
version = "0.25.0" version = "0.25.0"

View File

@@ -47,6 +47,7 @@ rustls = { version = "0.23", features = ["ring"] }
fast_image_resize = { version = "5.3", features = ["image"] } fast_image_resize = { version = "5.3", features = ["image"] }
image = { version = "0.25", features = ["png", "jpeg"] } image = { version = "0.25", features = ["png", "jpeg"] }
sha2 = "0.10" sha2 = "0.10"
mockall = "0.13.1"
# validator = { version = "0.16", features = ["derive"] } # validator = { version = "0.16", features = ["derive"] }
[dev-dependencies] [dev-dependencies]

View File

@@ -3,25 +3,20 @@ use axum_cookie::CookieLayer;
use dashmap::DashMap; use dashmap::DashMap;
use jsonwebtoken::{DecodingKey, EncodingKey}; use jsonwebtoken::{DecodingKey, EncodingKey};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use std::time::Duration;
use tokio::sync::{Notify, RwLock};
use tokio::task::JoinHandle;
use crate::data::pool::PgPool; use crate::data::pool::PgPool;
use crate::{auth::AuthRegistry, config::Config, image::ImageStorage, routes}; use crate::{auth::AuthRegistry, config::Config, image::ImageStorage, routes};
#[derive(Debug, Clone)] #[derive(Debug, Clone, Default)]
pub struct Health { pub struct Health {
migrations: bool, migrations: bool,
database: bool, database: bool,
} }
impl Health { impl Health {
pub fn new() -> Self {
Self {
migrations: false,
database: false,
}
}
pub fn ok(&self) -> bool { pub fn ok(&self) -> bool {
self.migrations && self.database self.migrations && self.database
} }
@@ -44,10 +39,11 @@ pub struct AppState {
pub db: Arc<PgPool>, pub db: Arc<PgPool>,
pub health: Arc<RwLock<Health>>, pub health: Arc<RwLock<Health>>,
pub image_storage: Arc<ImageStorage>, pub image_storage: Arc<ImageStorage>,
pub healthchecker_task: Arc<RwLock<Option<JoinHandle<()>>>>,
} }
impl AppState { impl AppState {
pub fn new(config: Config, auth: AuthRegistry, db: PgPool) -> Self { pub async fn new(config: Config, auth: AuthRegistry, db: PgPool, shutdown_notify: Arc<Notify>) -> Self {
let jwt_secret = config.jwt_secret.clone(); let jwt_secret = config.jwt_secret.clone();
// Initialize image storage // Initialize image storage
@@ -60,15 +56,71 @@ impl AppState {
} }
}; };
Self { let app_state = Self {
auth: Arc::new(auth), auth: Arc::new(auth),
sessions: Arc::new(DashMap::new()), sessions: Arc::new(DashMap::new()),
jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())), jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())),
jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())), jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())),
db: Arc::new(db), db: Arc::new(db),
health: Arc::new(RwLock::new(Health::new())), health: Arc::new(RwLock::new(Health::default())),
image_storage, image_storage,
healthchecker_task: Arc::new(RwLock::new(None)),
};
// Start the healthchecker task
{
let health_state = app_state.health.clone();
let db_pool = app_state.db.clone();
let healthchecker_task = app_state.healthchecker_task.clone();
let task = tokio::spawn(async move {
tracing::trace!("Health checker task started");
let mut backoff: u32 = 1;
let mut next_sleep = Duration::from_secs(0);
loop {
tokio::select! {
_ = shutdown_notify.notified() => {
tracing::trace!("Health checker received shutdown notification; exiting");
break;
}
_ = tokio::time::sleep(next_sleep) => {
// Run health check
}
}
// Run the actual health check
let ok = sqlx::query("SELECT 1").execute(&*db_pool).await.is_ok();
{
let mut h = health_state.write().await;
h.set_database(ok);
}
if ok {
tracing::trace!(database_ok = true, "Health check succeeded; scheduling next run in 90s");
backoff = 1;
next_sleep = Duration::from_secs(90);
} else {
backoff = (backoff.saturating_mul(2)).min(60);
tracing::trace!(database_ok = false, backoff, "Health check failed; backing off");
next_sleep = Duration::from_secs(backoff as u64);
}
}
});
// Store the task handle
let mut task_handle = healthchecker_task.write().await;
*task_handle = Some(task);
} }
app_state
}
/// Force an immediate health check (debug mode only)
pub async fn check_health(&self) -> bool {
let ok = sqlx::query("SELECT 1").execute(&*self.db).await.is_ok();
let mut h = self.health.write().await;
h.set_database(ok);
ok
} }
} }

View File

@@ -1,4 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use mockall::automock;
use serde::Serialize; use serde::Serialize;
use crate::errors::ErrorResponse; use crate::errors::ErrorResponse;
@@ -12,6 +13,7 @@ pub struct AuthUser {
pub avatar_url: Option<String>, pub avatar_url: Option<String>,
} }
#[automock]
#[async_trait] #[async_trait]
pub trait OAuthProvider: Send + Sync { pub trait OAuthProvider: Send + Sync {
fn id(&self) -> &'static str; fn id(&self) -> &'static str;

View File

@@ -3,8 +3,8 @@ use crate::{
auth::AuthRegistry, auth::AuthRegistry,
config::Config, config::Config,
}; };
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use std::{sync::Arc, time::Duration};
use tracing::{info, trace, warn}; use tracing::{info, trace, warn};
#[cfg(unix)] #[cfg(unix)]
@@ -47,17 +47,16 @@ async fn main() {
panic!("failed to run database migrations: {}", e); panic!("failed to run database migrations: {}", e);
} }
let app_state = AppState::new(config, auth, db); // Create the shutdown notification before creating AppState
let notify = Arc::new(Notify::new());
let app_state = AppState::new(config, auth, db, notify.clone()).await;
{ {
// migrations succeeded // migrations succeeded
let mut h = app_state.health.write().await; let mut h = app_state.health.write().await;
h.set_migrations(true); h.set_migrations(true);
} }
// Extract needed parts for health checker before moving app_state
let health_state = app_state.health.clone();
let db_pool = app_state.db.clone();
let app = create_router(app_state); let app = create_router(app_state);
info!(%addr, "Starting HTTP server bind"); info!(%addr, "Starting HTTP server bind");
@@ -65,45 +64,8 @@ async fn main() {
info!(%addr, "HTTP server listening"); info!(%addr, "HTTP server listening");
// coordinated graceful shutdown with timeout // coordinated graceful shutdown with timeout
let notify = Arc::new(Notify::new());
let (tx_signal, rx_signal) = watch::channel::<Option<Instant>>(None); let (tx_signal, rx_signal) = watch::channel::<Option<Instant>>(None);
// Spawn background health checker (listens for shutdown via notify)
{
let health_state = health_state.clone();
let db_pool = db_pool.clone();
let notify_for_health = notify.clone();
tokio::spawn(async move {
trace!("Health checker task started");
let mut backoff: u32 = 1;
let mut next_sleep = Duration::from_secs(0);
loop {
tokio::select! {
_ = notify_for_health.notified() => {
trace!("Health checker received shutdown notification; exiting");
break;
}
_ = tokio::time::sleep(next_sleep) => {
let ok = sqlx::query("SELECT 1").execute(&*db_pool).await.is_ok();
{
let mut h = health_state.write().await;
h.set_database(ok);
}
if ok {
trace!(database_ok = true, "Health check succeeded; scheduling next run in 90s");
backoff = 1;
next_sleep = Duration::from_secs(90);
} else {
backoff = (backoff.saturating_mul(2)).min(60);
trace!(database_ok = false, backoff, "Health check failed; backing off");
next_sleep = Duration::from_secs(backoff as u64);
}
}
}
}
});
}
{ {
let notify = notify.clone(); let notify = notify.clone();
let tx = tx_signal.clone(); let tx = tx_signal.clone();
@@ -127,9 +89,8 @@ async fn main() {
std::process::exit(1); std::process::exit(1);
}; };
let notify_for_server = notify.clone();
let server = axum::serve(listener, app).with_graceful_shutdown(async move { let server = axum::serve(listener, app).with_graceful_shutdown(async move {
notify_for_server.notified().await; notify.notified().await;
}); });
tokio::select! { tokio::select! {

View File

@@ -371,7 +371,16 @@ pub async fn list_providers_handler(State(app_state): State<AppState>) -> axum::
axum::Json(providers).into_response() axum::Json(providers).into_response()
} }
pub async fn health_handler(State(app_state): State<AppState>) -> axum::response::Response { pub async fn health_handler(
State(app_state): State<AppState>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> axum::response::Response {
// Force health check in debug mode
#[cfg(debug_assertions)]
if params.get("force").is_some() {
app_state.check_health().await;
}
let ok = app_state.health.read().await.ok(); let ok = app_state.health.read().await.ok();
let status = if ok { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE }; let status = if ok { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE };
let body = serde_json::json!({ "ok": ok }); let body = serde_json::json!({ "ok": ok });

View File

@@ -4,16 +4,18 @@ use pacman_server::{
auth::AuthRegistry, auth::AuthRegistry,
config::Config, config::Config,
}; };
use std::sync::Arc;
use testcontainers::{ use testcontainers::{
core::{IntoContainerPort, WaitFor}, core::{IntoContainerPort, WaitFor},
runners::AsyncRunner, runners::AsyncRunner,
ContainerAsync, GenericImage, ImageExt, ContainerAsync, GenericImage, ImageExt,
}; };
use tokio::sync::Notify;
/// Test configuration for integration tests /// Test configuration for integration tests
pub struct TestConfig { pub struct TestConfig {
pub database_url: String, pub database_url: String,
pub _container: ContainerAsync<GenericImage>, pub container: ContainerAsync<GenericImage>,
pub config: Config, pub config: Config,
} }
@@ -45,7 +47,7 @@ impl TestConfig {
Self { Self {
database_url, database_url,
_container: container, container,
config, config,
} }
} }
@@ -87,7 +89,8 @@ pub async fn create_test_app_state(test_config: &TestConfig) -> AppState {
let auth = AuthRegistry::new(&test_config.config).expect("Failed to create auth registry"); let auth = AuthRegistry::new(&test_config.config).expect("Failed to create auth registry");
// Create app state // Create app state
let app_state = AppState::new(test_config.config.clone(), auth, db); let notify = Arc::new(Notify::new());
let app_state = AppState::new(test_config.config.clone(), auth, db, notify).await;
// Set health status to true for tests (migrations and database are both working) // Set health status to true for tests (migrations and database are both working)
{ {

View File

@@ -1,8 +1,10 @@
use axum_test::TestServer; use axum_test::TestServer;
use mockall::predicate::*;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
mod common; mod common;
use common::{create_test_app_state, create_test_router, TestConfig}; use common::{create_test_app_state, create_test_router, TestConfig};
// OAuth provider imports removed as they're not used in these health tests
/// Common setup function for all tests /// Common setup function for all tests
async fn setup_test_server() -> TestServer { async fn setup_test_server() -> TestServer {
@@ -23,30 +25,29 @@ async fn test_basic_endpoints() {
assert_eq!(response.text(), "Hello, World! Visit /auth/github to start OAuth flow."); assert_eq!(response.text(), "Hello, World! Visit /auth/github to start OAuth flow.");
} }
/// Test health endpoint functionality /// Test health endpoint functionality with real database connectivity
#[tokio::test] #[tokio::test]
async fn test_health_endpoint() { async fn test_health_endpoint() {
let server = setup_test_server().await; let test_config = TestConfig::new().await;
let app_state = create_test_app_state(&test_config).await;
// Test health endpoint - wait for health checker to complete initial run let router = create_test_router(app_state.clone());
tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; let server = TestServer::new(router).unwrap();
let mut health_ok = false; // First, verify health endpoint works when database is healthy
let start = tokio::time::Instant::now(); let response = server.get("/health").await;
let timeout = tokio::time::Duration::from_secs(3); assert_eq!(response.status_code(), 200);
while start.elapsed() < timeout { let health_json: serde_json::Value = response.json();
let response = server.get("/health").await; assert_eq!(health_json["ok"], true);
if response.status_code() == 200 {
let health_json: serde_json::Value = response.json();
if health_json["ok"] == true {
health_ok = true;
break;
}
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
assert!(health_ok, "Health endpoint did not return ok=true within 3 seconds"); // Now kill the database container to simulate database failure
drop(test_config.container);
// Now verify health endpoint reports bad health
let response = server.get("/health?force").await;
assert_eq!(response.status_code(), 503); // SERVICE_UNAVAILABLE
let health_json: serde_json::Value = response.json();
assert_eq!(health_json["ok"], false);
} }
/// Test OAuth provider listing and configuration /// Test OAuth provider listing and configuration
@@ -136,3 +137,107 @@ async fn test_database_operations() {
let health_json: serde_json::Value = response.json(); let health_json: serde_json::Value = response.json();
assert_eq!(health_json["ok"], true); assert_eq!(health_json["ok"], true);
} }
/// Test OAuth authorization flow
#[tokio::test]
async fn test_oauth_authorization_flow() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth authorize handler redirects to the provider's authorization page for valid providers
// TODO: Test that the OAuth authorize handler returns an error for unknown providers
// TODO: Test that the OAuth authorize handler sets a link cookie when the link parameter is true
}
/// Test OAuth callback validation
#[tokio::test]
async fn test_oauth_callback_validation() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth callback handler validates the provider exists before processing
// TODO: Test that the OAuth callback handler returns an error when the provider returns an OAuth error
// TODO: Test that the OAuth callback handler returns an error when the authorization code is missing
// TODO: Test that the OAuth callback handler returns an error when the state parameter is missing
}
/// Test OAuth callback processing
#[tokio::test]
async fn test_oauth_callback_processing() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth callback handler exchanges the authorization code for user information successfully
// TODO: Test that the OAuth callback handler handles provider callback errors gracefully
// TODO: Test that the OAuth callback handler creates a session token after successful authentication
// TODO: Test that the OAuth callback handler sets a session cookie after successful authentication
// TODO: Test that the OAuth callback handler redirects to the profile page after successful authentication
}
/// Test account linking flow
#[tokio::test]
async fn test_account_linking_flow() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth callback handler links a new provider to an existing user when link intent is present and session is valid
// TODO: Test that the OAuth callback handler redirects to profile after successful account linking
// TODO: Test that the OAuth callback handler falls back to normal sign-in when link intent is present but no valid session exists
}
/// Test new user registration
#[tokio::test]
async fn test_new_user_registration() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth callback handler creates a new user account when no existing user is found
// TODO: Test that the OAuth callback handler requires an email address for all sign-ins
// TODO: Test that the OAuth callback handler rejects sign-in attempts when no email is available
}
/// Test existing user sign-in
#[tokio::test]
async fn test_existing_user_sign_in() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth callback handler allows sign-in when the provider is already linked to an existing user
// TODO: Test that the OAuth callback handler requires explicit linking when a user with the same email exists and has other providers linked
// TODO: Test that the OAuth callback handler auto-links a provider when a user exists but has no other providers linked
}
/// Test avatar processing
#[tokio::test]
async fn test_avatar_processing() {
let _server = setup_test_server().await;
// TODO: Test that the OAuth callback handler processes user avatars asynchronously without blocking the response
// TODO: Test that the OAuth callback handler handles avatar processing errors gracefully
}
/// Test profile access
#[tokio::test]
async fn test_profile_access() {
let _server = setup_test_server().await;
// TODO: Test that the profile handler returns user information when a valid session exists
// TODO: Test that the profile handler returns an error when no session cookie is present
// TODO: Test that the profile handler returns an error when the session token is invalid
// TODO: Test that the profile handler includes linked providers in the response
// TODO: Test that the profile handler returns an error when the user is not found in the database
}
/// Test logout functionality
#[tokio::test]
async fn test_logout_functionality() {
let _server = setup_test_server().await;
// TODO: Test that the logout handler clears the session if a session was there
// TODO: Test that the logout handler removes the session from memory storage
// TODO: Test that the logout handler clears the session cookie
// TODO: Test that the logout handler redirects to the home page after logout
}
/// Test provider configuration
#[tokio::test]
async fn test_provider_configuration() {
let _server = setup_test_server().await;
// TODO: Test that the providers list handler returns all configured OAuth providers
// TODO: Test that the providers list handler includes provider status (active/inactive)
}