From c524fdb3e7736e75e21556ced17c0bea487281d4 Mon Sep 17 00:00:00 2001 From: Ryan Walters Date: Wed, 24 Sep 2025 13:38:31 -0500 Subject: [PATCH] fix: rewrite oauth provider linking system, add email_verified attribute for providers --- Cargo.toml | 4 + build.rs | 4 + pacman-server/src/auth/discord.rs | 9 +- pacman-server/src/auth/github.rs | 37 ++- pacman-server/src/auth/provider.rs | 2 + pacman-server/src/data/user.rs | 46 +-- pacman-server/src/routes.rs | 253 ++++----------- pacman-server/tests/oauth.rs | 481 ++++++++++------------------- pacman-server/tests/sessions.rs | 33 +- 9 files changed, 296 insertions(+), 573 deletions(-) create mode 100644 build.rs diff --git a/Cargo.toml b/Cargo.toml index 65d2e45..1adcbfa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,10 @@ publish = false [profile.dev] incremental = true +# Improve build times by optimizing sqlx-macros +[profile.dev.package.sqlx-macros] +opt-level = 3 + # Release profile for profiling (essentially the default 'release' profile with debug enabled) [profile.profile] inherits = "release" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..25c3936 --- /dev/null +++ b/build.rs @@ -0,0 +1,4 @@ +fn main() { + // trigger recompilation when a new migration is added + println!("cargo:rerun-if-changed=migrations"); +} diff --git a/pacman-server/src/auth/discord.rs b/pacman-server/src/auth/discord.rs index a75c8e0..1aa7949 100644 --- a/pacman-server/src/auth/discord.rs +++ b/pacman-server/src/auth/discord.rs @@ -15,6 +15,7 @@ pub struct DiscordUser { pub username: String, pub global_name: Option, pub email: Option, + pub verified: Option, pub avatar: Option, } @@ -109,11 +110,17 @@ impl OAuthProvider for DiscordProvider { _ => None, }; + let (email, email_verified) = match (&user.email, user.verified) { + (Some(e), Some(true)) => (Some(e.clone()), true), + _ => (None, false), + }; + Ok(AuthUser { id: user.id, username: user.username, name: user.global_name, - email: user.email, + email, + email_verified, avatar_url, }) } diff --git a/pacman-server/src/auth/github.rs b/pacman-server/src/auth/github.rs index 6caf1bb..64e28ac 100644 --- a/pacman-server/src/auth/github.rs +++ b/pacman-server/src/auth/github.rs @@ -51,6 +51,28 @@ pub async fn fetch_github_user( Ok(user) } +/// Fetch user emails from GitHub API +pub async fn fetch_github_emails( + http_client: &reqwest::Client, + access_token: &str, +) -> Result, Box> { + let response = http_client + .get("https://api.github.com/user/emails") + .header("Authorization", format!("Bearer {}", access_token)) + .header("Accept", "application/vnd.github.v3+json") + .header("User-Agent", crate::config::USER_AGENT) + .send() + .await?; + + if !response.status().is_success() { + warn!(status = %response.status(), endpoint = "/user/emails", "GitHub API returned an error"); + return Err(format!("GitHub API error: {}", response.status()).into()); + } + + let emails: Vec = response.json().await?; + Ok(emails) +} + pub struct GitHubProvider { pub client: super::OAuthClient, pub http: reqwest::Client, @@ -112,11 +134,24 @@ impl OAuthProvider for GitHubProvider { ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e))) })?; + let emails = fetch_github_emails(&self.http, access_token).await.map_err(|e| { + warn!(error = %e, "Failed to fetch GitHub user emails"); + ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e))) + })?; + + let primary_email = emails.into_iter().find(|e| e.primary && e.verified); + + let (email, email_verified) = match primary_email { + Some(e) => (Some(e.email), true), + None => (user.email, false), + }; + Ok(AuthUser { id: user.id.to_string(), username: user.login, name: user.name, - email: user.email, + email, + email_verified, avatar_url: Some(user.avatar_url), }) } diff --git a/pacman-server/src/auth/provider.rs b/pacman-server/src/auth/provider.rs index 2455c85..13312c1 100644 --- a/pacman-server/src/auth/provider.rs +++ b/pacman-server/src/auth/provider.rs @@ -20,6 +20,8 @@ pub struct AuthUser { pub name: Option, // An email address for the user. Not always available. pub email: Option, + // Whether the email address has been verified by the provider. + pub email_verified: bool, // An avatar URL for the user. Not always available. pub avatar_url: Option, } diff --git a/pacman-server/src/data/user.rs b/pacman-server/src/data/user.rs index 942f2f0..479cd9a 100644 --- a/pacman-server/src/data/user.rs +++ b/pacman-server/src/data/user.rs @@ -66,54 +66,18 @@ pub async fn link_oauth_account( .await } -pub async fn create_user( - pool: &sqlx::PgPool, - provider_username: &str, - provider_display_name: Option<&str>, - provider_email: Option<&str>, - provider_avatar_url: Option<&str>, - provider: &str, - provider_user_id: &str, -) -> Result { - let user = sqlx::query_as::<_, User>( +pub async fn create_user(pool: &sqlx::PgPool, email: Option<&str>) -> Result { + sqlx::query_as::<_, User>( r#" INSERT INTO users (email) VALUES ($1) + ON CONFLICT (email) DO UPDATE SET email = EXCLUDED.email RETURNING id, email, created_at, updated_at "#, ) - .bind(provider_email) + .bind(email) .fetch_one(pool) - .await?; - - // Create oauth link - let _linked = link_oauth_account( - pool, - user.id, - provider, - provider_user_id, - provider_email, - Some(provider_username), - provider_display_name, - provider_avatar_url, - ) - .await?; - - Ok(user) -} - -pub async fn get_oauth_account_count_for_user(pool: &sqlx::PgPool, user_id: i64) -> Result { - let rec: (i64,) = sqlx::query_as( - r#" - SELECT COUNT(*)::BIGINT AS count - FROM oauth_accounts - WHERE user_id = $1 - "#, - ) - .bind(user_id) - .fetch_one(pool) - .await?; - Ok(rec.0) + .await } pub async fn find_user_by_provider_id( diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 62e0528..a836802 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -4,9 +4,8 @@ use axum::{ response::{IntoResponse, Redirect}, }; use axum_cookie::CookieManager; -use jsonwebtoken::{encode, Algorithm, Header}; use serde::Serialize; -use tracing::{debug, debug_span, info, instrument, trace, warn}; +use tracing::{debug, debug_span, info, instrument, trace, warn, Instrument}; use crate::data::user as user_repo; use crate::{app::AppState, errors::ErrorResponse, session}; @@ -19,11 +18,6 @@ pub struct OAuthCallbackParams { pub error_description: Option, } -#[derive(Debug, serde::Deserialize)] -pub struct AuthorizeQuery { - pub link: Option, -} - /// Handles the beginning of the OAuth authorization flow. /// /// Requires the `provider` path parameter, which determines the OAuth provider to use. @@ -31,79 +25,20 @@ pub struct AuthorizeQuery { pub async fn oauth_authorize_handler( State(app_state): State, Path(provider): Path, - Query(aq): Query, cookie: CookieManager, ) -> axum::response::Response { let Some(prov) = app_state.auth.get(&provider) else { warn!(%provider, "Unknown OAuth provider"); return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response(); }; - - let is_linking = aq.link == Some(true); - - // Persist link intent using a short-lived cookie; callbacks won't carry our query params. - if is_linking { - cookie.add( - axum_cookie::cookie::Cookie::builder("link", "1") - .http_only(true) - .same_site(axum_cookie::prelude::SameSite::Lax) - .path("/") - // TODO: Pick a reasonable max age that aligns with how long OAuth providers can successfully complete the flow. - .max_age(std::time::Duration::from_secs(60 * 60)) - .build(), - ); - } - trace!(linking = %is_linking, "Starting OAuth authorization"); - - // Try to acquire the existing session (PKCE session is ignored) - let existing_session = match session::get_session_token(&cookie) { - Some(token) => match session::decode_jwt(&token, &app_state.jwt_decoding_key) { - Some(claims) if !session::is_pkce_session(&claims) => Some(claims), - Some(_) => { - debug!("existing session ignored; it is a PKCE session"); - None - } - None => { - debug!("invalid session token"); - return ErrorResponse::unauthorized("invalid session token").into_response(); - } - }, - None => { - debug!("missing session cookie"); - return ErrorResponse::unauthorized("missing session cookie").into_response(); - } - }; - - // If linking is enabled, error if the session doesn't exist or is a PKCE session - if is_linking && existing_session.is_none() { - warn!("missing session cookie during linking flow, refusing"); - return ErrorResponse::unauthorized("missing session cookie").into_response(); - } + trace!("Starting OAuth authorization"); let auth_info = match prov.authorize(&app_state.jwt_encoding_key).await { Ok(info) => info, Err(e) => return e.into_response(), }; - let final_token = if let Some(mut claims) = existing_session { - // We have a user session and are linking. Merge PKCE info into it. - if let Some(pkce_claims) = session::decode_jwt(&auth_info.session_token, &app_state.jwt_decoding_key) { - claims.pkce_verifier = pkce_claims.pkce_verifier; - claims.csrf_state = pkce_claims.csrf_state; - - // re-encode - encode(&Header::new(Algorithm::HS256), &claims, &app_state.jwt_encoding_key).expect("jwt sign") - } else { - warn!("Failed to decode PKCE session token during linking flow"); - // Fallback to just using the PKCE token, which will break linking but not panic. - auth_info.session_token - } - } else { - // Not linking or no existing session, just use the new token. - auth_info.session_token - }; - - session::set_session_cookie(&cookie, &final_token); + session::set_session_cookie(&cookie, &auth_info.session_token); trace!("Redirecting to provider authorization page"); Redirect::to(auth_info.authorize_url.as_str()).into_response() } @@ -149,146 +84,62 @@ pub async fn oauth_callback_handler( } }; - debug!(cookies = ?cookie.cookie().iter().collect::>(), "Cookies"); - - // Linking or sign-in flow. Determine link intent from cookie (set at authorize time) - let link_cookie = cookie.get("link").map(|c| c.value().to_string()); - if link_cookie.is_some() { - cookie.remove("link"); - } - let email = user.email.as_deref(); - - // Determine linking intent with a valid session - if link_cookie.as_deref() == Some("1") { - debug!("Link intent present"); - - if let Some(claims) = - session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) - { - // Perform linking with current session user - let (cur_prov, cur_id) = claims.subject.split_once(':').unwrap_or(("", "")); - let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await { - Ok(Some(u)) => u, - Ok(None) => { - warn!("Current session user not found; proceeding as normal sign-in"); - return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into())) - .into_response(); - } - Err(_) => { - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); - } - }; - if let Err(e) = user_repo::link_oauth_account( - &app_state.db, - current_user.id, - &provider, - &user.id, - email, - Some(&user.username), - user.name.as_deref(), - user.avatar_url.as_deref(), - ) - .await - { - warn!(error = %e, %provider, "Failed to link OAuth account"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); - } - return (StatusCode::FOUND, Redirect::to("/profile")).into_response(); - } else { - warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in"); + // --- Simplified Sign-in / Sign-up Flow --- + let linking_span = debug_span!("account_linking", provider_user_id = %user.id, provider_email = ?user.email, email_verified = %user.email_verified); + let db_user_result: Result = async { + // 1. Check if we already have this specific provider account linked + if let Some(user) = user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await? { + debug!(user_id = %user.id, "Found existing user by provider ID"); + return Ok(user); } - } - // Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow. - if let Some(e) = email { - if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await { - // Only block if the user already has at least one linked provider. - // NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive, - // this may lock them out until the provider is reactivated or a manual admin link is performed. - match user_repo::get_oauth_account_count_for_user(&app_state.db, existing.id).await { - Ok(count) if count > 0 => { - // Check if the "new" provider is already linked to the user - match user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await { - Ok(Some(_)) => { - debug!( - %provider, - %existing.id, - "Provider already linked to user, signing in normally"); - } - Ok(None) => { - debug!( - %provider, - %existing.id, - "Provider not linked to user, failing" - ); - return ErrorResponse::bad_request( - "account_exists", - Some(format!( - "An account already exists for {}. Sign in with your existing provider, then visit /auth/{}?link=true to add this provider.", - e, provider - )), - ) - .into_response(); - } - Err(e) => { - warn!(error = %e, %provider, "Failed to find user by provider ID"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) - .into_response(); - } - } - } - Ok(_) => { - // No providers linked yet: safe to associate this provider - if let Err(e) = user_repo::link_oauth_account( - &app_state.db, - existing.id, - &provider, - &user.id, - email, - Some(&user.username), - user.name.as_deref(), - user.avatar_url.as_deref(), - ) - .await - { - warn!(error = %e, %provider, "Failed to link OAuth account to existing user with no providers"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) - .into_response(); - } - } - Err(e) => { - warn!(error = %e, "Failed to count oauth accounts for user"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); + // 2. If not, try to find an existing user by verified email to link to + let user_to_link = if user.email_verified { + if let Some(email) = user.email.as_deref() { + // Try to find a user with this email + if let Some(existing_user) = user_repo::find_user_by_email(&app_state.db, email).await? { + debug!(user_id = %existing_user.id, "Found existing user by email, linking new provider"); + existing_user + } else { + // No user with this email, create a new one + debug!("No user found by email, creating a new one"); + user_repo::create_user(&app_state.db, Some(email)).await? } + } else { + // Verified, but no email for some reason. Create a user without an email. + user_repo::create_user(&app_state.db, None).await? } } else { - // Create new user with email - match user_repo::create_user( - &app_state.db, - &user.username, - user.name.as_deref(), - email, - user.avatar_url.as_deref(), - &provider, - &user.id, - ) - .await - { - Ok(u) => u, - Err(e) => { - warn!(error = %e, %provider, "Failed to create user"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); - } - }; - } - } else { - // No email available: disallow sign-in for safety - return ErrorResponse::bad_request( - "invalid_request", - Some("account has no email; sign in with a different provider".into()), + // No verified email, so we must create a new user without an email. + debug!("No verified email, creating a new user"); + user_repo::create_user(&app_state.db, None).await? + }; + + // 3. Link the new provider account to our user record (whether old or new) + user_repo::link_oauth_account( + &app_state.db, + user_to_link.id, + &provider, + &user.id, + user.email.as_deref(), + Some(&user.username), + user.name.as_deref(), + user.avatar_url.as_deref(), ) - .into_response(); + .await?; + + Ok(user_to_link) } + .instrument(linking_span) + .await; + + let _: user_repo::User = match db_user_result { + Ok(u) => u, + Err(e) => { + warn!(error = %(&e as &dyn std::error::Error), "Failed to process user linking/creation"); + return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); + } + }; // Create session token let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key); diff --git a/pacman-server/tests/oauth.rs b/pacman-server/tests/oauth.rs index a62268a..4c1dd8f 100644 --- a/pacman-server/tests/oauth.rs +++ b/pacman-server/tests/oauth.rs @@ -2,9 +2,10 @@ use std::{collections::HashMap, sync::Arc}; use pacman_server::{ auth::{ - provider::{MockOAuthProvider, OAuthProvider}, + provider::{AuthUser, MockOAuthProvider, OAuthProvider}, AuthRegistry, }, + data::user as user_repo, session, }; use pretty_assertions::assert_eq; @@ -13,41 +14,13 @@ use time::Duration; mod common; use crate::common::{test_context, TestContext}; -/// Test OAuth authorization flows +/// Test the basic authorization redirect flow #[tokio::test] -async fn test_oauth_authorization_flows() { - let TestContext { server, .. } = test_context().call().await; - - // Test OAuth authorize endpoint (should redirect) - let response = server.get("/auth/github").await; - assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth - - // Test OAuth authorize endpoint for Discord - let response = server.get("/auth/discord").await; - assert_eq!(response.status_code(), 303); // Redirect to Discord OAuth - - // Test unknown provider - let response = server.get("/auth/unknown").await; - assert_eq!(response.status_code(), 400); // Bad request for unknown provider -} - -/// Test OAuth callback handling -#[tokio::test] -async fn test_oauth_callback_handling() { - let TestContext { server, .. } = test_context().call().await; - - // Test OAuth callback with missing parameters (should fail gracefully) - let response = server.get("/auth/github/callback").await; - assert_eq!(response.status_code(), 400); // Bad request for missing code/state -} - -/// Test OAuth authorization flow -#[tokio::test] -async fn test_oauth_authorization_flow() { +async fn test_oauth_authorization_redirect() { let mut mock = MockOAuthProvider::new(); mock.expect_authorize().returning(|encoding_key| { Ok(pacman_server::auth::provider::AuthorizeInfo { - authorize_url: "https://example.com".parse().unwrap(), + authorize_url: "https://example.com/auth".parse().unwrap(), session_token: session::create_pkce_session("verifier", "state", encoding_key), }) }); @@ -59,194 +32,26 @@ async fn test_oauth_authorization_flow() { let TestContext { server, app_state, .. } = test_context().auth_registry(mock_registry).call().await; - // Test that valid handlers redirect - let response = server.get("/auth/mock").await; - assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth - - // Test that unknown handlers return an error - let response = server.get("/auth/unknown").await; - assert_eq!(response.status_code(), 400); // Bad request for unknown provider - - // Test that session cookie is set let response = server.get("/auth/mock").await; assert_eq!(response.status_code(), 303); - let cookies = { - let cookies = response.cookies(); - cookies.iter().cloned().collect::>() - }; - assert_eq!(cookies.len(), 1); - assert_eq!(cookies[0].name(), "session"); - let claims = session::decode_jwt(cookies[0].value(), &app_state.jwt_decoding_key).unwrap(); - assert!(session::is_pkce_session(&claims)); + assert_eq!(response.headers().get("location").unwrap(), "https://example.com/auth"); - // Test that link parameter redirects and sets a link cookie - let response = server.get("/auth/mock?link=true").await; - assert_eq!(response.status_code(), 303); - assert_eq!(response.maybe_cookie("link").is_some(), true); - assert_eq!(response.maybe_cookie("link").unwrap().value(), "1"); + let session_cookie = response.cookie("session"); + let claims = session::decode_jwt(session_cookie.value(), &app_state.jwt_decoding_key).unwrap(); + assert!(session::is_pkce_session(&claims), "A PKCE session should be set"); } -/// Test OAuth callback validation -#[tokio::test] -async fn test_oauth_callback_validation() { - let mut mock = MockOAuthProvider::new(); - mock.expect_handle_callback() - .times(0) // Should not be called - .returning(|_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "123".to_string(), - username: "testuser".to_string(), - name: None, - email: Some("test@example.com".to_string()), - avatar_url: None, - }) - }); - let provider: Arc = Arc::new(mock); - let mock_registry = AuthRegistry { - providers: HashMap::from([("mock", provider)]), - }; - - let TestContext { server, .. } = test_context().use_database(true).auth_registry(mock_registry).call().await; - - // Test that an unknown provider returns an error - let response = server.get("/auth/unknown/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 400); - - // Test that a provider-returned error is handled - let response = server.get("/auth/mock/callback?error=access_denied").await; - assert_eq!(response.status_code(), 400); - - // Test that a missing code returns an error - let response = server.get("/auth/mock/callback?state=b").await; - assert_eq!(response.status_code(), 400); - - // Test that a missing state returns an error - let response = server.get("/auth/mock/callback?code=a").await; - assert_eq!(response.status_code(), 400); -} - -/// Test OAuth callback processing -#[tokio::test] -async fn test_oauth_callback_processing() { - let mut mock = MockOAuthProvider::new(); - mock.expect_handle_callback().returning(|_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "123".to_string(), - username: "testuser".to_string(), - name: None, - email: Some("processing@example.com".to_string()), - avatar_url: None, - }) - }); - let provider: Arc = Arc::new(mock); - let mock_registry = AuthRegistry { - providers: HashMap::from([("mock", provider)]), - }; - - let context = test_context().use_database(true).auth_registry(mock_registry).call().await; - - // Test that a successful callback redirects and sets a session cookie - let response = context.server.get("/auth/mock/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 302); - assert_eq!(response.headers().get("location").unwrap(), "/profile"); - assert!(response.maybe_cookie("session").is_some()); -} - -/// Test account linking flow -#[tokio::test] -async fn test_account_linking_flow() { - let mut initial_provider_mock = MockOAuthProvider::new(); - initial_provider_mock.expect_handle_callback().returning(move |_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "123".to_string(), - username: "linkuser".to_string(), - name: None, - email: Some("link@example.com".to_string()), - avatar_url: None, - }) - }); - let initial_provider: Arc = Arc::new(initial_provider_mock); - - let mut link_provider_mock = MockOAuthProvider::new(); - link_provider_mock.expect_authorize().returning(|encoding_key| { - Ok(pacman_server::auth::provider::AuthorizeInfo { - authorize_url: "https://example.com".parse().unwrap(), - session_token: session::create_pkce_session("verifier", "state", encoding_key), - }) - }); - link_provider_mock.expect_handle_callback().returning(|_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "456".to_string(), - username: "linkuser_new".to_string(), - name: None, - email: Some("link@example.com".to_string()), - avatar_url: None, - }) - }); - let link_provider: Arc = Arc::new(link_provider_mock); - - let registry = AuthRegistry { - providers: HashMap::from([("mock_initial", initial_provider), ("mock_link", link_provider)]), - }; - let context = test_context().use_database(true).auth_registry(registry).call().await; - - // 1. Create an initial user and provider link - let user = pacman_server::data::user::create_user( - &context.app_state.db, - "linkuser", - None, - Some("link@example.com"), - None, - "mock_initial", - "123", - ) - .await - .expect("Failed to create user"); - - { - let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id) - .await - .expect("Failed to list user's initial provider(s)"); - assert_eq!(providers.len(), 1, "User should have one provider"); - assert!(providers.iter().any(|p| p.provider == "mock_initial")); - } - - // 2. Create a session for this user - let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 302); - - // Begin linking flow - let response = context.server.get("/auth/mock_link?link=true").await; - assert_eq!(response.status_code(), 303); - - // 3. Perform the linking call - let response = context.server.get("/auth/mock_link/callback?code=a&state=b").await; - - assert_eq!(response.status_code(), 303, "Post-linking response should be a redirect"); - assert_eq!( - response.headers().get("location").unwrap(), - "/profile", - "Post-linking response should redirect to /profile" - ); - - let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id) - .await - .expect("Failed to list user's providers"); - assert_eq!(providers.len(), 2, "User should have two providers"); - assert!(providers.iter().any(|p| p.provider == "mock_initial")); - assert!(providers.iter().any(|p| p.provider == "mock_link")); -} - -/// Test new user registration +/// Test new user registration via OAuth callback #[tokio::test] async fn test_new_user_registration() { let mut mock = MockOAuthProvider::new(); mock.expect_handle_callback().returning(|_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "123".to_string(), - username: "testuser".to_string(), + Ok(AuthUser { + id: "newuser123".to_string(), + username: "new_user".to_string(), name: None, - email: Some("newuser@example.com".to_string()), + email: Some("new@example.com".to_string()), + email_verified: true, avatar_url: None, }) }); @@ -258,27 +63,151 @@ async fn test_new_user_registration() { let context = test_context().use_database(true).auth_registry(mock_registry).call().await; - // Test that the OAuth callback handler creates a new user account when no existing user is found let response = context.server.get("/auth/mock/callback?code=a&state=b").await; assert_eq!(response.status_code(), 302); assert_eq!(response.headers().get("location").unwrap(), "/profile"); - let user = pacman_server::data::user::find_user_by_email(&context.app_state.db, "newuser@example.com") + + // Verify user and oauth_account were created + let user = user_repo::find_user_by_email(&context.app_state.db, "new@example.com") + .await + .unwrap() + .expect("User should be created"); + assert_eq!(user.email, Some("new@example.com".to_string())); + + let providers = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap(); + assert_eq!(providers.len(), 1); + assert_eq!(providers[0].provider, "mock"); + assert_eq!(providers[0].provider_user_id, "newuser123"); +} + +/// Test sign-in for an existing user with an already-linked provider +#[tokio::test] +async fn test_existing_user_signin() { + let mut mock = MockOAuthProvider::new(); + mock.expect_handle_callback().returning(|_, _, _, _| { + Ok(AuthUser { + id: "existing123".to_string(), + username: "existing_user".to_string(), + name: None, + email: Some("existing@example.com".to_string()), + email_verified: true, + avatar_url: None, + }) + }); + + let provider: Arc = Arc::new(mock); + let mock_registry = AuthRegistry { + providers: HashMap::from([("mock", provider)]), + }; + + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Pre-create the user and link + let user = user_repo::create_user(&context.app_state.db, Some("existing@example.com")) + .await + .unwrap(); + user_repo::link_oauth_account( + &context.app_state.db, + user.id, + "mock", + "existing123", + Some("existing@example.com"), + Some("existing_user"), + None, + None, + ) + .await + .unwrap(); + + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302, "Should sign in successfully"); + assert_eq!(response.headers().get("location").unwrap(), "/profile"); + + // Verify no new user was created + let users = sqlx::query("SELECT * FROM users") + .fetch_all(&context.app_state.db) + .await + .unwrap(); + assert_eq!(users.len(), 1, "No new user should be created"); +} + +/// Test implicit account linking via a shared verified email +#[tokio::test] +async fn test_implicit_account_linking() { + // 1. User signs in with 'provider-a' + let mut mock_a = MockOAuthProvider::new(); + mock_a.expect_handle_callback().returning(|_, _, _, _| { + Ok(AuthUser { + id: "user_a_123".to_string(), + username: "user_a".to_string(), + name: None, + email: Some("shared@example.com".to_string()), + email_verified: true, + avatar_url: None, + }) + }); + + // 2. Later, the same user signs in with 'provider-b' + let mut mock_b = MockOAuthProvider::new(); + mock_b.expect_handle_callback().returning(|_, _, _, _| { + Ok(AuthUser { + id: "user_b_456".to_string(), + username: "user_b".to_string(), + name: None, + email: Some("shared@example.com".to_string()), + email_verified: true, + avatar_url: None, + }) + }); + + let provider_a: Arc = Arc::new(mock_a); + let provider_b: Arc = Arc::new(mock_b); + let mock_registry = AuthRegistry { + providers: HashMap::from([("provider-a", provider_a), ("provider-b", provider_b)]), + }; + + let context = test_context().use_database(true).auth_registry(mock_registry).call().await; + + // Action 1: Sign in with provider-a, creating the initial user + let response1 = context.server.get("/auth/provider-a/callback?code=a&state=b").await; + assert_eq!(response1.status_code(), 302); + + let user = user_repo::find_user_by_email(&context.app_state.db, "shared@example.com") .await .unwrap() .unwrap(); - assert_eq!(user.email, Some("newuser@example.com".to_string())); + let providers1 = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap(); + assert_eq!(providers1.len(), 1); + assert_eq!(providers1[0].provider, "provider-a"); + + // Action 2: Sign in with provider-b + let response2 = context.server.get("/auth/provider-b/callback?code=a&state=b").await; + assert_eq!(response2.status_code(), 302); + + // Assertions: No new user, but a new provider link + let users = sqlx::query("SELECT * FROM users") + .fetch_all(&context.app_state.db) + .await + .unwrap(); + assert_eq!(users.len(), 1, "A new user should NOT have been created"); + + let providers2 = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap(); + assert_eq!(providers2.len(), 2, "A new provider should have been linked"); + assert!(providers2.iter().any(|p| p.provider == "provider-a")); + assert!(providers2.iter().any(|p| p.provider == "provider-b")); } -/// Test OAuth callback handler rejects sign-in attempts when no email is available +/// Test that an unverified email does NOT link accounts #[tokio::test] -async fn test_oauth_callback_no_email() { +async fn test_unverified_email_creates_new_account() { let mut mock = MockOAuthProvider::new(); mock.expect_handle_callback().returning(|_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "456".to_string(), - username: "noemailuser".to_string(), + Ok(AuthUser { + id: "unverified123".to_string(), + username: "unverified_user".to_string(), name: None, - email: None, + email: Some("unverified@example.com".to_string()), + email_verified: false, avatar_url: None, }) }); @@ -290,86 +219,20 @@ async fn test_oauth_callback_no_email() { let context = test_context().use_database(true).auth_registry(mock_registry).call().await; - // Test that the OAuth callback handler rejects sign-in attempts when no email is available + // Pre-create a user with the same email, but they will not be linked. + user_repo::create_user(&context.app_state.db, Some("unverified@example.com")) + .await + .unwrap(); + let response = context.server.get("/auth/mock/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 400); -} + assert_eq!(response.status_code(), 302); -/// Test existing user sign-in with new provider fails -#[tokio::test] -async fn test_existing_user_sign_in_new_provider_fails() { - let mut mock = MockOAuthProvider::new(); - mock.expect_handle_callback().returning(move |_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "456".to_string(), // Different provider ID - username: "existinguser_newprovider".to_string(), - name: None, - email: Some("existing@example.com".to_string()), - avatar_url: None, - }) - }); - let provider: Arc = Arc::new(mock); - let mock_registry = AuthRegistry { - providers: HashMap::from([("mock_new", provider)]), - }; - - let context = test_context().use_database(true).auth_registry(mock_registry).call().await; - - // Create a user with one linked provider - pacman_server::data::user::create_user( - &context.app_state.db, - "existinguser", - None, - Some("existing@example.com"), - None, - "mock", - "123", - ) - .await - .unwrap(); - - // A user with the email exists, but has one provider. If they sign in with a *new* provider, it should fail. - let response = context.server.get("/auth/mock_new/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 400); // Should fail and ask to link explicitly. -} - -/// Test existing user sign-in with existing provider succeeds -#[tokio::test] -async fn test_existing_user_sign_in_existing_provider_succeeds() { - let mut mock = MockOAuthProvider::new(); - mock.expect_handle_callback().returning(move |_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { - id: "123".to_string(), // Same provider ID as created user - username: "existinguser".to_string(), - name: None, - email: Some("existing@example.com".to_string()), - avatar_url: None, - }) - }); - let provider: Arc = Arc::new(mock); - let mock_registry = AuthRegistry { - providers: HashMap::from([("mock", provider)]), - }; - - let context = test_context().use_database(true).auth_registry(mock_registry).call().await; - - // Create a user with one linked provider - pacman_server::data::user::create_user( - &context.app_state.db, - "existinguser", - None, - Some("existing@example.com"), - None, - "mock", - "123", - ) - .await - .unwrap(); - - // Test sign-in with an existing linked provider. - let response = context.server.get("/auth/mock/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 302); // Should sign in successfully - assert_eq!(response.headers().get("location").unwrap(), "/profile"); + // Should create a second user because the email wasn't trusted for linking + let users = sqlx::query("SELECT * FROM users") + .fetch_all(&context.app_state.db) + .await + .unwrap(); + assert_eq!(users.len(), 2, "A new user should be created for the unverified email"); } /// Test logout functionality @@ -377,11 +240,12 @@ async fn test_existing_user_sign_in_existing_provider_succeeds() { async fn test_logout_functionality() { let mut mock = MockOAuthProvider::new(); mock.expect_handle_callback().returning(|_, _, _, _| { - Ok(pacman_server::auth::provider::AuthUser { + Ok(AuthUser { id: "123".to_string(), username: "testuser".to_string(), name: None, email: Some("test@example.com".to_string()), + email_verified: true, avatar_url: None, }) }); @@ -395,27 +259,14 @@ async fn test_logout_functionality() { // Sign in to establish a session let response = context.server.get("/auth/mock/callback?code=a&state=b").await; assert_eq!(response.status_code(), 302); - let session_cookie = response.cookie("session").clone(); // Test that the logout handler clears the session cookie and redirects let response = context.server.get("/logout").await; - // Redirect assertions assert_eq!(response.status_code(), 302); - assert!( - response.headers().contains_key("location"), - "Response redirect should have a location header" - ); + assert!(response.headers().contains_key("location")); - // Cookie assertions - assert_eq!( - response.cookie("session").value(), - "removed", - "Session cookie should be removed" - ); - assert_eq!( - response.cookie("session").max_age(), - Some(Duration::ZERO), - "Session cookie should have a max age of 0" - ); + let cookie = response.cookie("session"); + assert_eq!(cookie.value(), "removed"); + assert_eq!(cookie.max_age(), Some(Duration::ZERO)); } diff --git a/pacman-server/tests/sessions.rs b/pacman-server/tests/sessions.rs index 155a399..4dd6bd6 100644 --- a/pacman-server/tests/sessions.rs +++ b/pacman-server/tests/sessions.rs @@ -1,7 +1,7 @@ mod common; -use crate::common::{test_context, TestContext}; +use crate::common::test_context; use cookie::Cookie; -use pacman_server::session; +use pacman_server::{data::user as user_repo, session}; use pretty_assertions::assert_eq; @@ -9,25 +9,30 @@ use pretty_assertions::assert_eq; async fn test_session_management() { let context = test_context().use_database(true).call().await; - // 1. Create a user - let user = - pacman_server::data::user::create_user(&context.app_state.db, "testuser", None, None, None, "test_provider", "123") - .await - .unwrap(); + // 1. Create a user and link a provider account + let user = user_repo::create_user(&context.app_state.db, Some("test@example.com")) + .await + .unwrap(); + let provider_account = user_repo::link_oauth_account( + &context.app_state.db, + user.id, + "test_provider", + "123", + Some("test@example.com"), + Some("testuser"), + None, + None, + ) + .await + .unwrap(); // 2. Create a session token for the user - let provider_account = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id) - .await - .unwrap() - .into_iter() - .find(|p| p.provider == "test_provider") - .unwrap(); - let auth_user = pacman_server::auth::provider::AuthUser { id: provider_account.provider_user_id, username: provider_account.username.unwrap(), name: provider_account.display_name, email: user.email, + email_verified: true, avatar_url: provider_account.avatar_url, }; let token = session::create_jwt_for_user("test_provider", &auth_user, &context.app_state.jwt_encoding_key);