From 0b5aeceb510566618dfbf5e0f24a8af89fefac05 Mon Sep 17 00:00:00 2001 From: Ryan Walters Date: Fri, 19 Sep 2025 17:35:53 -0500 Subject: [PATCH] feat: improve test reliability and add request tracing - Add retry configuration for flaky tests (2 retries for default, 3 for OAuth) - Configure test groups with proper concurrency limits (serial: 1, integration: 4) - Add tower-http tracing layer with custom span formatting for HTTP requests - Simplify database pool handling by removing unnecessary Arc wrapper - Improve test context setup with better logging and error handling - Refactor user creation parameters for better clarity and consistency - Add debug logging for OAuth cookie handling --- .config/nextest.toml | 11 +- Cargo.lock | 2 + pacman-server/Cargo.toml | 1 + pacman-server/src/app.rs | 38 ++- pacman-server/src/data/user.rs | 20 +- pacman-server/src/logging.rs | 58 ++--- pacman-server/src/main.rs | 2 +- pacman-server/src/routes.rs | 2 + pacman-server/tests/common/mod.rs | 58 +++-- pacman-server/tests/oauth.rs | 379 +++++++++++++++++++++++++----- 10 files changed, 446 insertions(+), 125 deletions(-) diff --git a/.config/nextest.toml b/.config/nextest.toml index b9fb77c..1467b56 100644 --- a/.config/nextest.toml +++ b/.config/nextest.toml @@ -1,6 +1,7 @@ [profile.default] fail-fast = false -slow-timeout = { period = "5s", terminate-after = 6 } # max 30 seconds +slow-timeout = { period = "5s", terminate-after = 3 } # max 15 seconds +retries = 2 # CI machines are pretty slow, so we need to increase the timeout [profile.ci] @@ -11,12 +12,18 @@ slow-timeout = { period = "30s", terminate-after = 4 } # max 2 minutes for slow slow-timeout = { period = "45s", terminate-after = 5 } # max 3.75 minutes for slow tests status-level = "none" -[[profile.default.overrides]] # Integration tests in SDL2 run serially (may not be required) +[[profile.default.overrides]] filter = 'test(pacman::game::)' test-group = 'serial' +# Integration tests run max 4 at a time +[[profile.default.overrides]] +filter = 'test(pacman-server::tests::oauth)' +test-group = 'integration' +retries = 3 [test-groups] # Ensure serial tests don't run in parallel serial = { max-threads = 1 } +integration = { max-threads = 4 } diff --git a/Cargo.lock b/Cargo.lock index 9f79173..4b1840a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3130,6 +3130,7 @@ dependencies = [ "testcontainers", "time", "tokio 1.47.1", + "tower-http", "tracing", "tracing-futures", "tracing-subscriber", @@ -5474,6 +5475,7 @@ dependencies = [ "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] diff --git a/pacman-server/Cargo.toml b/pacman-server/Cargo.toml index d9ecfa2..70142e4 100644 --- a/pacman-server/Cargo.toml +++ b/pacman-server/Cargo.toml @@ -44,6 +44,7 @@ jsonwebtoken = { version = "9.3", default-features = false } tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } tracing-futures = { version = "0.2.5", features = ["tokio"] } +tower-http = { version = "0.6", features = ["trace"] } time = { version = "0.3", features = ["macros", "formatting"] } yansi = "1" s3-tokio = { version = "0.39.6", default-features = false } diff --git a/pacman-server/src/app.rs b/pacman-server/src/app.rs index 61c285e..a4891bc 100644 --- a/pacman-server/src/app.rs +++ b/pacman-server/src/app.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::{Notify, RwLock}; use tokio::task::JoinHandle; +use tracing::info_span; use crate::data::pool::PgPool; use crate::{auth::AuthRegistry, config::Config, image::ImageStorage, routes}; @@ -36,7 +37,7 @@ pub struct AppState { pub sessions: Arc>, pub jwt_encoding_key: Arc, pub jwt_decoding_key: Arc, - pub db: Arc, + pub db: PgPool, pub health: Arc>, pub image_storage: Arc, pub healthchecker_task: Arc>>>, @@ -71,7 +72,7 @@ impl AppState { sessions: Arc::new(DashMap::new()), jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())), jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())), - db: Arc::new(db), + db: db, health: Arc::new(RwLock::new(Health::default())), image_storage, healthchecker_task: Arc::new(RwLock::new(None)), @@ -100,7 +101,7 @@ impl AppState { } // Run the actual health check - let ok = sqlx::query("SELECT 1").execute(&*db_pool).await.is_ok(); + let ok = sqlx::query("SELECT 1").execute(&db_pool).await.is_ok(); { let mut h = health_state.write().await; h.set_database(ok); @@ -127,13 +128,35 @@ impl AppState { /// Force an immediate health check (debug mode only) pub async fn check_health(&self) -> bool { - let ok = sqlx::query("SELECT 1").execute(&*self.db).await.is_ok(); + let ok = sqlx::query("SELECT 1").execute(&self.db).await.is_ok(); let mut h = self.health.write().await; h.set_database(ok); ok } } +/// Create a custom span for HTTP requests with reduced verbosity +pub fn make_span(request: &axum::http::Request) -> tracing::Span { + let path = request + .uri() + .path_and_query() + .map(|v| v.as_str()) + .unwrap_or_else(|| request.uri().path()); + + if request.method() == axum::http::Method::GET { + info_span!( + "request", + path = %path, + ) + } else { + info_span!( + "request", + method = %request.method(), + path = %path, + ) + } +} + /// Create the application router with all routes and middleware pub fn create_router(app_state: AppState) -> Router { Router::new() @@ -147,6 +170,13 @@ pub fn create_router(app_state: AppState) -> Router { .with_state(app_state) .layer(CookieLayer::default()) .layer(axum::middleware::from_fn(inject_server_header)) + .layer( + tower_http::trace::TraceLayer::new_for_http() + .make_span_with(make_span) + .on_request(|_request: &axum::http::Request, _span: &tracing::Span| { + // Disable request logging by doing nothing + }), + ) } /// Inject the server header into responses diff --git a/pacman-server/src/data/user.rs b/pacman-server/src/data/user.rs index 046973a..942f2f0 100644 --- a/pacman-server/src/data/user.rs +++ b/pacman-server/src/data/user.rs @@ -68,10 +68,10 @@ pub async fn link_oauth_account( pub async fn create_user( pool: &sqlx::PgPool, - username: &str, - display_name: Option<&str>, - email: Option<&str>, - avatar_url: Option<&str>, + provider_username: &str, + provider_display_name: Option<&str>, + provider_email: Option<&str>, + provider_avatar_url: Option<&str>, provider: &str, provider_user_id: &str, ) -> Result { @@ -82,20 +82,20 @@ pub async fn create_user( RETURNING id, email, created_at, updated_at "#, ) - .bind(email) + .bind(provider_email) .fetch_one(pool) .await?; // Create oauth link - let _ = link_oauth_account( + let _linked = link_oauth_account( pool, user.id, provider, provider_user_id, - email, - Some(username), - display_name, - avatar_url, + provider_email, + Some(provider_username), + provider_display_name, + provider_avatar_url, ) .await?; diff --git a/pacman-server/src/logging.rs b/pacman-server/src/logging.rs index a61ed4c..37b1464 100644 --- a/pacman-server/src/logging.rs +++ b/pacman-server/src/logging.rs @@ -1,36 +1,38 @@ -use tracing_subscriber::fmt::format::JsonFields; -use tracing_subscriber::{EnvFilter, FmtSubscriber}; +use tracing_subscriber::{fmt::format::JsonFields, EnvFilter, FmtSubscriber}; -use crate::config::Config; use crate::formatter; +static SUBSCRIBER_INIT: std::sync::Once = std::sync::Once::new(); + /// Configure and initialize logging for the application -pub fn setup_logging(_config: &Config) { - // Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere - let filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME")))); +pub fn setup_logging() { + SUBSCRIBER_INIT.call_once(|| { + // Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME")))); - // Default to pretty for local dev; switchable later if we add CLI - let use_pretty = cfg!(debug_assertions); + // Default to pretty for local dev; switchable later if we add CLI + let use_pretty = cfg!(debug_assertions); - let subscriber: Box = if use_pretty { - Box::new( - FmtSubscriber::builder() - .with_target(true) - .event_format(formatter::CustomPrettyFormatter) - .with_env_filter(filter) - .finish(), - ) - } else { - Box::new( - FmtSubscriber::builder() - .with_target(true) - .event_format(formatter::CustomJsonFormatter) - .fmt_fields(JsonFields::new()) - .with_env_filter(filter) - .finish(), - ) - }; + let subscriber: Box = if use_pretty { + Box::new( + FmtSubscriber::builder() + .with_target(true) + .event_format(formatter::CustomPrettyFormatter) + .with_env_filter(filter) + .finish(), + ) + } else { + Box::new( + FmtSubscriber::builder() + .with_target(true) + .event_format(formatter::CustomJsonFormatter) + .fmt_fields(JsonFields::new()) + .with_env_filter(filter) + .finish(), + ) + }; - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + }); } diff --git a/pacman-server/src/main.rs b/pacman-server/src/main.rs index a44a314..3b809cc 100644 --- a/pacman-server/src/main.rs +++ b/pacman-server/src/main.rs @@ -45,7 +45,7 @@ async fn main() { let config: Config = config::load_config(); // Initialize tracing subscriber - logging::setup_logging(&config); + logging::setup_logging(); trace!(host = %config.host, port = config.port, shutdown_timeout_seconds = config.shutdown_timeout_seconds, "Loaded server configuration"); let addr = std::net::SocketAddr::new(config.host, config.port); diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 476b11d..903eaec 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -93,6 +93,8 @@ pub async fn oauth_callback_handler( } }; + debug!(cookies = ?cookie.cookie().iter().collect::>(), "Cookies"); + // Linking or sign-in flow. Determine link intent from cookie (set at authorize time) let link_cookie = cookie.get("link").map(|c| c.value().to_string()); if link_cookie.is_some() { diff --git a/pacman-server/tests/common/mod.rs b/pacman-server/tests/common/mod.rs index 57fc7d0..71f6e1f 100644 --- a/pacman-server/tests/common/mod.rs +++ b/pacman-server/tests/common/mod.rs @@ -12,6 +12,7 @@ use testcontainers::{ ContainerAsync, GenericImage, ImageExt, }; use tokio::sync::Notify; +use tracing::{debug, debug_span, Instrument}; static CRYPTO_INIT: Once = Once::new(); @@ -26,16 +27,40 @@ pub struct TestContext { } #[builder] -pub async fn test_context(use_database: bool, auth_registry: Option) -> TestContext { +pub async fn test_context(#[builder(default = false)] use_database: bool, auth_registry: Option) -> TestContext { CRYPTO_INIT.call_once(|| { rustls::crypto::ring::default_provider() .install_default() .expect("Failed to install default crypto provider"); }); + // Set up logging + std::env::set_var("RUST_LOG", "debug,sqlx=info"); + pacman_server::logging::setup_logging(); let (database_url, container) = if use_database { - let (url, container) = setup_test_database("testdb", "testuser", "testpass").await; - (Some(url), Some(container)) + let db = "testdb"; + let user = "testuser"; + let password = "testpass"; + + // Create container request + let container_request = GenericImage::new("postgres", "15") + .with_exposed_port(5432.tcp()) + .with_wait_for(WaitFor::message_on_stderr("database system is ready to accept connections")) + .with_env_var("POSTGRES_DB", db) + .with_env_var("POSTGRES_USER", user) + .with_env_var("POSTGRES_PASSWORD", password); + + tracing::trace!(request = ?container_request, "Acquiring postgres testcontainer"); + let container = container_request.start().await.unwrap(); + let host = container.get_host().await.unwrap(); + let port = container.get_host_port_ipv4(5432).await.unwrap(); + + tracing::debug!(host = %host, port = %port, "Test database ready"); + + ( + Some(format!("postgresql://{user}:{password}@{host}:{port}/{db}?sslmode=disable")), + Some(container), + ) } else { (None, None) }; @@ -63,8 +88,10 @@ pub async fn test_context(use_database: bool, auth_registry: Option (String, ContainerAsync) { - let container = GenericImage::new("postgres", "15") - .with_exposed_port(5432.tcp()) - .with_wait_for(WaitFor::message_on_stderr("database system is ready to accept connections")) - .with_env_var("POSTGRES_DB", db) - .with_env_var("POSTGRES_USER", user) - .with_env_var("POSTGRES_PASSWORD", password) - .start() - .await - .unwrap(); - - let host = container.get_host().await.unwrap(); - let port = container.get_host_port_ipv4(5432).await.unwrap(); - - ( - format!("postgresql://{user}:{password}@{host}:{port}/{db}?sslmode=disable"), - container, - ) -} diff --git a/pacman-server/tests/oauth.rs b/pacman-server/tests/oauth.rs index d2bf302..ec56ba6 100644 --- a/pacman-server/tests/oauth.rs +++ b/pacman-server/tests/oauth.rs @@ -5,6 +5,7 @@ use pacman_server::auth::{ AuthRegistry, }; use pretty_assertions::assert_eq; +use time::Duration; mod common; use crate::common::{test_context, TestContext}; @@ -12,7 +13,7 @@ use crate::common::{test_context, TestContext}; /// Test OAuth authorization flows #[tokio::test] async fn test_oauth_authorization_flows() { - let TestContext { server, .. } = test_context().use_database(false).call().await; + let TestContext { server, .. } = test_context().call().await; // Test OAuth authorize endpoint (should redirect) let response = server.get("/auth/github").await; @@ -30,7 +31,7 @@ async fn test_oauth_authorization_flows() { /// Test OAuth callback handling #[tokio::test] async fn test_oauth_callback_handling() { - let TestContext { server, .. } = test_context().use_database(false).call().await; + let TestContext { server, .. } = test_context().call().await; // Test OAuth callback with missing parameters (should fail gracefully) let response = server.get("/auth/github/callback").await; @@ -53,7 +54,7 @@ async fn test_oauth_authorization_flow() { providers: HashMap::from([("mock", provider)]), }; - let TestContext { server, .. } = test_context().use_database(false).auth_registry(mock_registry).call().await; + let TestContext { server, .. } = test_context().auth_registry(mock_registry).call().await; // Test that valid handlers redirect let response = server.get("/auth/mock").await; @@ -84,93 +85,361 @@ async fn test_oauth_authorization_flow() { /// Test OAuth callback validation #[tokio::test] async fn test_oauth_callback_validation() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback() + .times(0) // Should not be called + .returning(|_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "123".to_string(), + username: "testuser".to_string(), + name: None, + email: Some("test@example.com".to_string()), + avatar_url: None, + }) + }); + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; - // TODO: Test that the OAuth callback handler validates the provider exists before processing - // TODO: Test that the OAuth callback handler returns an error when the provider returns an OAuth error - // TODO: Test that the OAuth callback handler returns an error when the authorization code is missing - // TODO: Test that the OAuth callback handler returns an error when the state parameter is missing + let TestContext { server, .. } = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Test that an unknown provider returns an error + let response = server.get("/auth/unknown/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 400); + + // Test that a provider-returned error is handled + let response = server.get("/auth/mock/callback?error=access_denied").await; + assert_eq!(response.status_code(), 400); + + // Test that a missing code returns an error + let response = server.get("/auth/mock/callback?state=b").await; + assert_eq!(response.status_code(), 400); + + // Test that a missing state returns an error + let response = server.get("/auth/mock/callback?code=a").await; + assert_eq!(response.status_code(), 400); } /// Test OAuth callback processing #[tokio::test] async fn test_oauth_callback_processing() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(|_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "123".to_string(), + username: "testuser".to_string(), + name: None, + email: Some("processing@example.com".to_string()), + avatar_url: None, + }) + }); + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; - // TODO: Test that the OAuth callback handler exchanges the authorization code for user information successfully - // TODO: Test that the OAuth callback handler handles provider callback errors gracefully - // TODO: Test that the OAuth callback handler creates a session token after successful authentication - // TODO: Test that the OAuth callback handler sets a session cookie after successful authentication - // TODO: Test that the OAuth callback handler redirects to the profile page after successful authentication + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Test that a successful callback redirects and sets a session cookie + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302); + assert_eq!(response.headers().get("location").unwrap(), "/profile"); + assert!(response.maybe_cookie("session").is_some()); } /// Test account linking flow #[tokio::test] async fn test_account_linking_flow() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; + let mut initial_provider_mock = MockOAuthProvider::new(); + initial_provider_mock.expect_handle_callback().returning(move |_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "123".to_string(), + username: "linkuser".to_string(), + name: None, + email: Some("link@example.com".to_string()), + avatar_url: None, + }) + }); + let initial_provider: Arc = Arc::new(initial_provider_mock); - // TODO: Test that the OAuth callback handler links a new provider to an existing user when link intent is present and session is valid - // TODO: Test that the OAuth callback handler redirects to profile after successful account linking - // TODO: Test that the OAuth callback handler falls back to normal sign-in when link intent is present but no valid session exists + let mut link_provider_mock = MockOAuthProvider::new(); + link_provider_mock.expect_authorize().returning(|_| { + Ok(pacman_server::auth::provider::AuthorizeInfo { + authorize_url: "https://example.com".parse().unwrap(), + session_token: "b_token".to_string(), + }) + }); + link_provider_mock.expect_handle_callback().returning(|_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "456".to_string(), + username: "linkuser_new".to_string(), + name: None, + email: Some("link@example.com".to_string()), + avatar_url: None, + }) + }); + let link_provider: Arc = Arc::new(link_provider_mock); + + let registry = AuthRegistry { + providers: HashMap::from([("mock_initial", initial_provider), ("mock_link", link_provider)]), + }; + let context = test_context().use_database(true).auth_registry(registry).call().await; + + // 1. Create an initial user and provider link + let user = pacman_server::data::user::create_user( + &context.app_state.db, + "linkuser", + None, + Some("link@example.com"), + None, + "mock_initial", + "123", + ) + .await + .expect("Failed to create user"); + + { + let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id) + .await + .expect("Failed to list user's initial provider(s)"); + assert_eq!(providers.len(), 1, "User should have one provider"); + assert!(providers.iter().any(|p| p.provider == "mock_initial")); + } + + // 2. Create a session for this user + let session_cookie = { + let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302); + assert!(response.maybe_cookie("session").is_some(), "Session cookie should be set"); + + response.cookie("session").clone() + }; + tracing::debug!(cookie = %session_cookie, "Session cookie acquired"); + + // Begin linking flow + let link_cookie = { + let response = context + .server + .get("/auth/mock_link?link=true") + .add_cookie(session_cookie.clone()) + .await; + assert_eq!(response.status_code(), 303); + assert_eq!(response.maybe_cookie("link").unwrap().value(), "1"); + + response.cookie("link").clone() + }; + tracing::debug!(cookie = %link_cookie, "Link cookie acquired"); + + // 3. Perform the linking call + let response = context + .server + .get("/auth/mock_link/callback?code=a&state=b") + .add_cookie(link_cookie) + .add_cookie(session_cookie.clone()) + .await; + + assert_eq!(response.status_code(), 303, "Post-linking response should be a redirect"); + assert_eq!( + response.headers().get("location").unwrap(), + "/profile", + "Post-linking response should redirect to /profile" + ); + + let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id) + .await + .expect("Failed to list user's providers"); + assert_eq!(providers.len(), 2, "User should have two providers"); + assert!(providers.iter().any(|p| p.provider == "mock_initial")); + assert!(providers.iter().any(|p| p.provider == "mock_link")); } /// Test new user registration #[tokio::test] async fn test_new_user_registration() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(|_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "123".to_string(), + username: "testuser".to_string(), + name: None, + email: Some("newuser@example.com".to_string()), + avatar_url: None, + }) + }); - // TODO: Test that the OAuth callback handler creates a new user account when no existing user is found - // TODO: Test that the OAuth callback handler requires an email address for all sign-ins - // TODO: Test that the OAuth callback handler rejects sign-in attempts when no email is available + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; + + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Test that the OAuth callback handler creates a new user account when no existing user is found + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302); + assert_eq!(response.headers().get("location").unwrap(), "/profile"); + let user = pacman_server::data::user::find_user_by_email(&context.app_state.db, "newuser@example.com") + .await + .unwrap() + .unwrap(); + assert_eq!(user.email, Some("newuser@example.com".to_string())); } -/// Test existing user sign-in +/// Test OAuth callback handler rejects sign-in attempts when no email is available #[tokio::test] -async fn test_existing_user_sign_in() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; +async fn test_oauth_callback_no_email() { + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(|_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "456".to_string(), + username: "noemailuser".to_string(), + name: None, + email: None, + avatar_url: None, + }) + }); - // TODO: Test that the OAuth callback handler allows sign-in when the provider is already linked to an existing user - // TODO: Test that the OAuth callback handler requires explicit linking when a user with the same email exists and has other providers linked - // TODO: Test that the OAuth callback handler auto-links a provider when a user exists but has no other providers linked + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; + + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Test that the OAuth callback handler rejects sign-in attempts when no email is available + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 400); } -/// Test avatar processing +/// Test existing user sign-in with new provider fails #[tokio::test] -async fn test_avatar_processing() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; +async fn test_existing_user_sign_in_new_provider_fails() { + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(move |_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "456".to_string(), // Different provider ID + username: "existinguser_newprovider".to_string(), + name: None, + email: Some("existing@example.com".to_string()), + avatar_url: None, + }) + }); + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock_new", provider)]), + }; - // TODO: Test that the OAuth callback handler processes user avatars asynchronously without blocking the response - // TODO: Test that the OAuth callback handler handles avatar processing errors gracefully + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Create a user with one linked provider + pacman_server::data::user::create_user( + &context.app_state.db, + "existinguser", + None, + Some("existing@example.com"), + None, + "mock", + "123", + ) + .await + .unwrap(); + + // A user with the email exists, but has one provider. If they sign in with a *new* provider, it should fail. + let response = context.server.get("/auth/mock_new/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 400); // Should fail and ask to link explicitly. } -/// Test profile access +/// Test existing user sign-in with existing provider succeeds #[tokio::test] -async fn test_profile_access() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; +async fn test_existing_user_sign_in_existing_provider_succeeds() { + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(move |_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "123".to_string(), // Same provider ID as created user + username: "existinguser".to_string(), + name: None, + email: Some("existing@example.com".to_string()), + avatar_url: None, + }) + }); + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; - // TODO: Test that the profile handler returns user information when a valid session exists - // TODO: Test that the profile handler returns an error when no session cookie is present - // TODO: Test that the profile handler returns an error when the session token is invalid - // TODO: Test that the profile handler includes linked providers in the response - // TODO: Test that the profile handler returns an error when the user is not found in the database + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Create a user with one linked provider + pacman_server::data::user::create_user( + &context.app_state.db, + "existinguser", + None, + Some("existing@example.com"), + None, + "mock", + "123", + ) + .await + .unwrap(); + + // Test sign-in with an existing linked provider. + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302); // Should sign in successfully + assert_eq!(response.headers().get("location").unwrap(), "/profile"); } /// Test logout functionality #[tokio::test] async fn test_logout_functionality() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(|_, _, _, _| { + Ok(pacman_server::auth::provider::AuthUser { + id: "123".to_string(), + username: "testuser".to_string(), + name: None, + email: Some("test@example.com".to_string()), + avatar_url: None, + }) + }); + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; - // TODO: Test that the logout handler clears the session if a session was there - // TODO: Test that the logout handler removes the session from memory storage - // TODO: Test that the logout handler clears the session cookie - // TODO: Test that the logout handler redirects to the home page after logout -} - -/// Test provider configuration -#[tokio::test] -async fn test_provider_configuration() { - let TestContext { server: _server, .. } = test_context().use_database(false).call().await; - - // TODO: Test that the providers list handler returns all configured OAuth providers - // TODO: Test that the providers list handler includes provider status (active/inactive) + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Sign in to establish a session + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302); + let session_cookie = response.cookie("session").clone(); + + // Test that the logout handler clears the session cookie and redirects + let response = context + .server + .get("/logout") + .add_cookie(cookie::Cookie::new( + session_cookie.name().to_string(), + session_cookie.value().to_string(), + )) + .await; + + // Redirect assertions + assert_eq!(response.status_code(), 302); + assert!( + response.headers().contains_key("location"), + "Response redirect should have a location header" + ); + + // Cookie assertions + assert_eq!( + response.cookie("session").value(), + "removed", + "Session cookie should be removed" + ); + assert_eq!( + response.cookie("session").max_age(), + Some(Duration::ZERO), + "Session cookie should have a max age of 0" + ); }