diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 903eaec..62e0528 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -4,14 +4,15 @@ use axum::{ response::{IntoResponse, Redirect}, }; use axum_cookie::CookieManager; +use jsonwebtoken::{encode, Algorithm, Header}; use serde::Serialize; -use tracing::{debug, info, instrument, span, trace, warn}; +use tracing::{debug, debug_span, info, instrument, trace, warn}; use crate::data::user as user_repo; use crate::{app::AppState, errors::ErrorResponse, session}; #[derive(Debug, serde::Deserialize)] -pub struct AuthQuery { +pub struct OAuthCallbackParams { pub code: Option, pub state: Option, pub error: Option, @@ -23,6 +24,9 @@ pub struct AuthorizeQuery { pub link: Option, } +/// Handles the beginning of the OAuth authorization flow. +/// +/// Requires the `provider` path parameter, which determines the OAuth provider to use. #[instrument(skip_all, fields(provider = %provider))] pub async fn oauth_authorize_handler( State(app_state): State, @@ -34,32 +38,84 @@ pub async fn oauth_authorize_handler( warn!(%provider, "Unknown OAuth provider"); return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response(); }; - trace!("Starting OAuth authorization"); + + let is_linking = aq.link == Some(true); + // Persist link intent using a short-lived cookie; callbacks won't carry our query params. - if aq.link == Some(true) { + if is_linking { cookie.add( axum_cookie::cookie::Cookie::builder("link", "1") .http_only(true) .same_site(axum_cookie::prelude::SameSite::Lax) .path("/") - .max_age(std::time::Duration::from_secs(120)) + // TODO: Pick a reasonable max age that aligns with how long OAuth providers can successfully complete the flow. + .max_age(std::time::Duration::from_secs(60 * 60)) .build(), ); } + trace!(linking = %is_linking, "Starting OAuth authorization"); + + // Try to acquire the existing session (PKCE session is ignored) + let existing_session = match session::get_session_token(&cookie) { + Some(token) => match session::decode_jwt(&token, &app_state.jwt_decoding_key) { + Some(claims) if !session::is_pkce_session(&claims) => Some(claims), + Some(_) => { + debug!("existing session ignored; it is a PKCE session"); + None + } + None => { + debug!("invalid session token"); + return ErrorResponse::unauthorized("invalid session token").into_response(); + } + }, + None => { + debug!("missing session cookie"); + return ErrorResponse::unauthorized("missing session cookie").into_response(); + } + }; + + // If linking is enabled, error if the session doesn't exist or is a PKCE session + if is_linking && existing_session.is_none() { + warn!("missing session cookie during linking flow, refusing"); + return ErrorResponse::unauthorized("missing session cookie").into_response(); + } + let auth_info = match prov.authorize(&app_state.jwt_encoding_key).await { Ok(info) => info, Err(e) => return e.into_response(), }; - session::set_session_cookie(&cookie, &auth_info.session_token); + let final_token = if let Some(mut claims) = existing_session { + // We have a user session and are linking. Merge PKCE info into it. + if let Some(pkce_claims) = session::decode_jwt(&auth_info.session_token, &app_state.jwt_decoding_key) { + claims.pkce_verifier = pkce_claims.pkce_verifier; + claims.csrf_state = pkce_claims.csrf_state; + + // re-encode + encode(&Header::new(Algorithm::HS256), &claims, &app_state.jwt_encoding_key).expect("jwt sign") + } else { + warn!("Failed to decode PKCE session token during linking flow"); + // Fallback to just using the PKCE token, which will break linking but not panic. + auth_info.session_token + } + } else { + // Not linking or no existing session, just use the new token. + auth_info.session_token + }; + + session::set_session_cookie(&cookie, &final_token); trace!("Redirecting to provider authorization page"); Redirect::to(auth_info.authorize_url.as_str()).into_response() } +/// Handles the callback from the OAuth provider after the user has authorized the app. +/// +/// Requires the `provider` path parameter, which determines the OAuth provider to use for finishing the OAuth flow. +/// Requires the `code` and `state` query parameters, which are returned by the OAuth provider after the user has authorized the app. pub async fn oauth_callback_handler( State(app_state): State, Path(provider): Path, - Query(params): Query, + Query(params): Query, cookie: CookieManager, ) -> axum::response::Response { // Validate provider @@ -82,7 +138,7 @@ pub async fn oauth_callback_handler( return ErrorResponse::bad_request("invalid_request", Some("missing state".into())).into_response(); }; - span!(tracing::Level::DEBUG, "oauth_callback_handler", provider = %provider, code = %code, state = %state); + debug_span!("oauth_callback_handler", provider = %provider, code = %code, state = %state); // Handle callback from provider let user = match prov.handle_callback(code, state, &cookie, &app_state.jwt_decoding_key).await { @@ -103,147 +159,136 @@ pub async fn oauth_callback_handler( let email = user.email.as_deref(); // Determine linking intent with a valid session - let is_link = if link_cookie.as_deref() == Some("1") { + if link_cookie.as_deref() == Some("1") { debug!("Link intent present"); - match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) { - Some(c) => { - // Perform linking with current session user - let (cur_prov, cur_id) = c.subject.split_once(':').unwrap_or(("", "")); - let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await { - Ok(Some(u)) => u, - Ok(None) => { - warn!("Current session user not found; proceeding as normal sign-in"); - return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into())) - .into_response(); - } - Err(_) => { - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) - .into_response(); - } - }; - if let Err(e) = user_repo::link_oauth_account( - &app_state.db, - current_user.id, - &provider, - &user.id, - email, - Some(&user.username), - user.name.as_deref(), - user.avatar_url.as_deref(), - ) - .await - { - warn!(error = %e, %provider, "Failed to link OAuth account"); + if let Some(claims) = + session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) + { + // Perform linking with current session user + let (cur_prov, cur_id) = claims.subject.split_once(':').unwrap_or(("", "")); + let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await { + Ok(Some(u)) => u, + Ok(None) => { + warn!("Current session user not found; proceeding as normal sign-in"); + return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into())) + .into_response(); + } + Err(_) => { return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); } - return (StatusCode::FOUND, Redirect::to("/profile")).into_response(); - } - None => { - warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in"); - false + }; + if let Err(e) = user_repo::link_oauth_account( + &app_state.db, + current_user.id, + &provider, + &user.id, + email, + Some(&user.username), + user.name.as_deref(), + user.avatar_url.as_deref(), + ) + .await + { + warn!(error = %e, %provider, "Failed to link OAuth account"); + return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); } + return (StatusCode::FOUND, Redirect::to("/profile")).into_response(); + } else { + warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in"); } - } else { - false - }; + } - if is_link { - unreachable!(); // handled via early return above - } else { - // Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow. - if let Some(e) = email { - if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await { - // Only block if the user already has at least one linked provider. - // NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive, - // this may lock them out until the provider is reactivated or a manual admin link is performed. - match user_repo::get_oauth_account_count_for_user(&app_state.db, existing.id).await { - Ok(count) if count > 0 => { - // Check if the "new" provider is already linked to the user - match user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await { - Ok(Some(_)) => { - debug!( - %provider, - %existing.id, - "Provider already linked to user, signing in normally"); - } - Ok(None) => { - debug!( - %provider, - %existing.id, - "Provider not linked to user, failing" - ); - return ErrorResponse::bad_request( - "account_exists", - Some(format!( - "An account already exists for {}. Sign in with your existing provider, then visit /auth/{}?link=true to add this provider.", - e, provider - )), - ) - .into_response(); - } - Err(e) => { - warn!(error = %e, %provider, "Failed to find user by provider ID"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) - .into_response(); - } + // Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow. + if let Some(e) = email { + if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await { + // Only block if the user already has at least one linked provider. + // NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive, + // this may lock them out until the provider is reactivated or a manual admin link is performed. + match user_repo::get_oauth_account_count_for_user(&app_state.db, existing.id).await { + Ok(count) if count > 0 => { + // Check if the "new" provider is already linked to the user + match user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await { + Ok(Some(_)) => { + debug!( + %provider, + %existing.id, + "Provider already linked to user, signing in normally"); } - } - Ok(_) => { - // No providers linked yet: safe to associate this provider - if let Err(e) = user_repo::link_oauth_account( - &app_state.db, - existing.id, - &provider, - &user.id, - email, - Some(&user.username), - user.name.as_deref(), - user.avatar_url.as_deref(), - ) - .await - { - warn!(error = %e, %provider, "Failed to link OAuth account to existing user with no providers"); + Ok(None) => { + debug!( + %provider, + %existing.id, + "Provider not linked to user, failing" + ); + return ErrorResponse::bad_request( + "account_exists", + Some(format!( + "An account already exists for {}. Sign in with your existing provider, then visit /auth/{}?link=true to add this provider.", + e, provider + )), + ) + .into_response(); + } + Err(e) => { + warn!(error = %e, %provider, "Failed to find user by provider ID"); return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) .into_response(); } } - Err(e) => { - warn!(error = %e, "Failed to count oauth accounts for user"); + } + Ok(_) => { + // No providers linked yet: safe to associate this provider + if let Err(e) = user_repo::link_oauth_account( + &app_state.db, + existing.id, + &provider, + &user.id, + email, + Some(&user.username), + user.name.as_deref(), + user.avatar_url.as_deref(), + ) + .await + { + warn!(error = %e, %provider, "Failed to link OAuth account to existing user with no providers"); return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) .into_response(); } } - } else { - // Create new user with email - match user_repo::create_user( - &app_state.db, - &user.username, - user.name.as_deref(), - email, - user.avatar_url.as_deref(), - &provider, - &user.id, - ) - .await - { - Ok(u) => u, - Err(e) => { - warn!(error = %e, %provider, "Failed to create user"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) - .into_response(); - } - }; + Err(e) => { + warn!(error = %e, "Failed to count oauth accounts for user"); + return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); + } } } else { - // No email available: disallow sign-in for safety - return ErrorResponse::bad_request( - "invalid_request", - Some("account has no email; sign in with a different provider".into()), + // Create new user with email + match user_repo::create_user( + &app_state.db, + &user.username, + user.name.as_deref(), + email, + user.avatar_url.as_deref(), + &provider, + &user.id, ) - .into_response(); + .await + { + Ok(u) => u, + Err(e) => { + warn!(error = %e, %provider, "Failed to create user"); + return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); + } + }; } - }; + } else { + // No email available: disallow sign-in for safety + return ErrorResponse::bad_request( + "invalid_request", + Some("account has no email; sign in with a different provider".into()), + ) + .into_response(); + } // Create session token let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key); @@ -282,6 +327,9 @@ pub async fn oauth_callback_handler( (StatusCode::FOUND, Redirect::to("/profile")).into_response() } +/// Handles the request to the profile endpoint. +/// +/// Requires the `session` cookie to be present. pub async fn profile_handler(State(app_state): State, cookie: CookieManager) -> axum::response::Response { let Some(token_str) = session::get_session_token(&cookie) else { debug!("Missing session cookie"); diff --git a/pacman-server/src/session.rs b/pacman-server/src/session.rs index 7c70ded..d28bd30 100644 --- a/pacman-server/src/session.rs +++ b/pacman-server/src/session.rs @@ -64,7 +64,7 @@ pub fn create_pkce_session(pkce_verifier: &str, csrf_state: &str, encoding_key: /// Checks if a session is a PKCE flow session pub fn is_pkce_session(claims: &Claims) -> bool { - claims.subject == "pkce_flow" && claims.pkce_verifier.is_some() && claims.csrf_state.is_some() + claims.pkce_verifier.is_some() && claims.csrf_state.is_some() } pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option { diff --git a/pacman-server/tests/common/mod.rs b/pacman-server/tests/common/mod.rs index f4d5114..cacc70f 100644 --- a/pacman-server/tests/common/mod.rs +++ b/pacman-server/tests/common/mod.rs @@ -17,6 +17,7 @@ use tracing::{debug, debug_span, Instrument}; static CRYPTO_INIT: Once = Once::new(); /// Test configuration for integration tests +/// Do not destructure this struct if you need the database, it will be dropped implicitly, which will kill the database container prematurely. #[allow(dead_code)] pub struct TestContext { pub config: Config, diff --git a/pacman-server/tests/oauth.rs b/pacman-server/tests/oauth.rs index ec56ba6..a62268a 100644 --- a/pacman-server/tests/oauth.rs +++ b/pacman-server/tests/oauth.rs @@ -1,8 +1,11 @@ use std::{collections::HashMap, sync::Arc}; -use pacman_server::auth::{ - provider::{MockOAuthProvider, OAuthProvider}, - AuthRegistry, +use pacman_server::{ + auth::{ + provider::{MockOAuthProvider, OAuthProvider}, + AuthRegistry, + }, + session, }; use pretty_assertions::assert_eq; use time::Duration; @@ -42,10 +45,10 @@ async fn test_oauth_callback_handling() { #[tokio::test] async fn test_oauth_authorization_flow() { let mut mock = MockOAuthProvider::new(); - mock.expect_authorize().returning(|_| { + mock.expect_authorize().returning(|encoding_key| { Ok(pacman_server::auth::provider::AuthorizeInfo { authorize_url: "https://example.com".parse().unwrap(), - session_token: "a_token".to_string(), + session_token: session::create_pkce_session("verifier", "state", encoding_key), }) }); @@ -54,7 +57,7 @@ async fn test_oauth_authorization_flow() { providers: HashMap::from([("mock", provider)]), }; - let TestContext { server, .. } = test_context().auth_registry(mock_registry).call().await; + let TestContext { server, app_state, .. } = test_context().auth_registry(mock_registry).call().await; // Test that valid handlers redirect let response = server.get("/auth/mock").await; @@ -73,7 +76,8 @@ async fn test_oauth_authorization_flow() { }; assert_eq!(cookies.len(), 1); assert_eq!(cookies[0].name(), "session"); - assert_eq!(cookies[0].value(), "a_token"); + let claims = session::decode_jwt(cookies[0].value(), &app_state.jwt_decoding_key).unwrap(); + assert!(session::is_pkce_session(&claims)); // Test that link parameter redirects and sets a link cookie let response = server.get("/auth/mock?link=true").await; @@ -164,10 +168,10 @@ async fn test_account_linking_flow() { let initial_provider: Arc = Arc::new(initial_provider_mock); let mut link_provider_mock = MockOAuthProvider::new(); - link_provider_mock.expect_authorize().returning(|_| { + link_provider_mock.expect_authorize().returning(|encoding_key| { Ok(pacman_server::auth::provider::AuthorizeInfo { authorize_url: "https://example.com".parse().unwrap(), - session_token: "b_token".to_string(), + session_token: session::create_pkce_session("verifier", "state", encoding_key), }) }); link_provider_mock.expect_handle_callback().returning(|_, _, _, _| { @@ -208,36 +212,15 @@ async fn test_account_linking_flow() { } // 2. Create a session for this user - let session_cookie = { - let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await; - assert_eq!(response.status_code(), 302); - assert!(response.maybe_cookie("session").is_some(), "Session cookie should be set"); - - response.cookie("session").clone() - }; - tracing::debug!(cookie = %session_cookie, "Session cookie acquired"); + let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await; + assert_eq!(response.status_code(), 302); // Begin linking flow - let link_cookie = { - let response = context - .server - .get("/auth/mock_link?link=true") - .add_cookie(session_cookie.clone()) - .await; - assert_eq!(response.status_code(), 303); - assert_eq!(response.maybe_cookie("link").unwrap().value(), "1"); - - response.cookie("link").clone() - }; - tracing::debug!(cookie = %link_cookie, "Link cookie acquired"); + let response = context.server.get("/auth/mock_link?link=true").await; + assert_eq!(response.status_code(), 303); // 3. Perform the linking call - let response = context - .server - .get("/auth/mock_link/callback?code=a&state=b") - .add_cookie(link_cookie) - .add_cookie(session_cookie.clone()) - .await; + let response = context.server.get("/auth/mock_link/callback?code=a&state=b").await; assert_eq!(response.status_code(), 303, "Post-linking response should be a redirect"); assert_eq!( @@ -415,14 +398,7 @@ async fn test_logout_functionality() { let session_cookie = response.cookie("session").clone(); // Test that the logout handler clears the session cookie and redirects - let response = context - .server - .get("/logout") - .add_cookie(cookie::Cookie::new( - session_cookie.name().to_string(), - session_cookie.value().to_string(), - )) - .await; + let response = context.server.get("/logout").await; // Redirect assertions assert_eq!(response.status_code(), 302); diff --git a/pacman-server/tests/sessions.rs b/pacman-server/tests/sessions.rs index 7b7d6f0..155a399 100644 --- a/pacman-server/tests/sessions.rs +++ b/pacman-server/tests/sessions.rs @@ -1,18 +1,50 @@ mod common; use crate::common::{test_context, TestContext}; +use cookie::Cookie; +use pacman_server::session; use pretty_assertions::assert_eq; -/// Test session management endpoints #[tokio::test] async fn test_session_management() { - let TestContext { server, .. } = test_context().use_database(true).call().await; + let context = test_context().use_database(true).call().await; - // Test logout endpoint (should redirect) - let response = server.get("/logout").await; - assert_eq!(response.status_code(), 302); // Redirect to home + // 1. Create a user + let user = + pacman_server::data::user::create_user(&context.app_state.db, "testuser", None, None, None, "test_provider", "123") + .await + .unwrap(); - // Test profile endpoint without session (should be unauthorized) - let response = server.get("/profile").await; + // 2. Create a session token for the user + let provider_account = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id) + .await + .unwrap() + .into_iter() + .find(|p| p.provider == "test_provider") + .unwrap(); + + let auth_user = pacman_server::auth::provider::AuthUser { + id: provider_account.provider_user_id, + username: provider_account.username.unwrap(), + name: provider_account.display_name, + email: user.email, + avatar_url: provider_account.avatar_url, + }; + let token = session::create_jwt_for_user("test_provider", &auth_user, &context.app_state.jwt_encoding_key); + + // 3. Make a request to the protected route WITH the session, expect success + let response = context + .server + .get("/profile") + .add_cookie(Cookie::new(session::SESSION_COOKIE_NAME, token)) + .await; + assert_eq!(response.status_code(), 200); + + // 4. Sign out + let response = context.server.get("/logout").await; + assert_eq!(response.status_code(), 302); // Redirect after logout + + // 5. Make a request to the protected route without a session, expect failure + let response = context.server.get("/profile").await; assert_eq!(response.status_code(), 401); // Unauthorized without session }