diff --git a/pacman-server/src/auth/discord.rs b/pacman-server/src/auth/discord.rs index 3ea9be6..31ae7f8 100644 --- a/pacman-server/src/auth/discord.rs +++ b/pacman-server/src/auth/discord.rs @@ -1,14 +1,14 @@ use axum::{response::IntoResponse, response::Redirect}; -use dashmap::DashMap; -use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse}; +use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse}; use serde::{Deserialize, Serialize}; -use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; -use std::time::{Duration, Instant}; use tracing::{trace, warn}; -use crate::auth::provider::{AuthUser, OAuthProvider}; +use crate::auth::{ + pkce::PkceManager, + provider::{AuthUser, OAuthProvider}, +}; use crate::errors::ErrorResponse; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -43,15 +43,7 @@ pub async fn fetch_discord_user( pub struct DiscordProvider { pub client: super::OAuthClient, pub http: reqwest::Client, - pkce: DashMap, - last_purge_at_secs: AtomicU32, - pkce_additions: AtomicU32, -} - -#[derive(Debug, Clone)] -struct PkceRecord { - verifier: String, - created_at: Instant, + pkce: PkceManager, } impl DiscordProvider { @@ -59,41 +51,10 @@ impl DiscordProvider { Arc::new(Self { client, http, - pkce: DashMap::new(), - last_purge_at_secs: AtomicU32::new(0), - pkce_additions: AtomicU32::new(0), + pkce: PkceManager::new(), }) } - fn maybe_purge_stale_pkce_entries(&self) { - // Purge when at least 5 minutes passed or more than 128 additions occurred - const PURGE_INTERVAL_SECS: u32 = 5 * 60; - const ADDITIONS_THRESHOLD: u32 = 128; - - let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(d) => d.as_secs() as u32, - Err(_) => return, - }; - - let last = self.last_purge_at_secs.load(Ordering::Relaxed); - let additions = self.pkce_additions.load(Ordering::Relaxed); - if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS { - return; - } - - const PKCE_TTL: Duration = Duration::from_secs(5 * 60); - let now_inst = Instant::now(); - for entry in self.pkce.iter() { - if now_inst.duration_since(entry.value().created_at) > PKCE_TTL { - self.pkce.remove(entry.key()); - } - } - - // Reset counters after purge - self.pkce_additions.store(0, Ordering::Relaxed); - self.last_purge_at_secs.store(now_secs, Ordering::Relaxed); - } - fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String { let ext = if avatar_hash.starts_with("a_") { "gif" } else { "png" }; format!("https://cdn.discordapp.com/avatars/{}/{}.{}", user_id, avatar_hash, ext) @@ -110,7 +71,7 @@ impl OAuthProvider for DiscordProvider { } async fn authorize(&self) -> axum::response::Response { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (pkce_challenge, verifier) = self.pkce.generate_challenge(); let (authorize_url, csrf_state) = self .client .authorize_url(CsrfToken::new_random) @@ -118,19 +79,9 @@ impl OAuthProvider for DiscordProvider { .add_scope(Scope::new("identify".to_string())) .add_scope(Scope::new("email".to_string())) .url(); + self.pkce.store_verifier(csrf_state.secret(), verifier); trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); - // Insert PKCE verifier with timestamp and purge when needed - self.pkce.insert( - csrf_state.secret().to_string(), - PkceRecord { - verifier: pkce_verifier.secret().to_string(), - created_at: Instant::now(), - }, - ); - self.pkce_additions.fetch_add(1, Ordering::Relaxed); - self.maybe_purge_stale_pkce_entries(); - Redirect::to(authorize_url.as_str()).into_response() } @@ -150,31 +101,22 @@ impl OAuthProvider for DiscordProvider { .get("state") .cloned() .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?; - let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else { - warn!("Missing PKCE verifier for state parameter"); + let Some(verifier) = self.pkce.take_verifier(&state) else { + warn!(%state, "Missing or expired PKCE verifier for state parameter"); return Err(ErrorResponse::bad_request( "invalid_request", - Some("missing pkce verifier for state".into()), + Some("missing or expired pkce verifier for state".into()), )); }; - // Verify PKCE TTL - if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) { - warn!("PKCE verifier expired for state parameter"); - return Err(ErrorResponse::bad_request( - "invalid_request", - Some("expired pkce verifier for state".into()), - )); - } - let token = self .client .exchange_code(AuthorizationCode::new(code)) - .set_pkce_verifier(PkceCodeVerifier::new(rec.verifier)) + .set_pkce_verifier(PkceCodeVerifier::new(verifier)) .request_async(&self.http) .await .map_err(|e| { - warn!(error = %e, "Token exchange with Discord failed"); + warn!(error = %e, %state, "Token exchange with Discord failed"); ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string())) })?; diff --git a/pacman-server/src/auth/github.rs b/pacman-server/src/auth/github.rs index 3554006..ac76fcf 100644 --- a/pacman-server/src/auth/github.rs +++ b/pacman-server/src/auth/github.rs @@ -1,15 +1,15 @@ use axum::{response::IntoResponse, response::Redirect}; -use dashmap::DashMap; -use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse}; +use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse}; use serde::{Deserialize, Serialize}; -use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; -use std::time::{Duration, Instant}; use tracing::{trace, warn}; use crate::{ - auth::provider::{AuthUser, OAuthProvider}, + auth::{ + pkce::PkceManager, + provider::{AuthUser, OAuthProvider}, + }, errors::ErrorResponse, }; @@ -56,15 +56,7 @@ pub async fn fetch_github_user( pub struct GitHubProvider { pub client: super::OAuthClient, pub http: reqwest::Client, - pkce: DashMap, - last_purge_at_secs: AtomicU32, - pkce_additions: AtomicU32, -} - -#[derive(Debug, Clone)] -struct PkceRecord { - verifier: String, - created_at: Instant, + pkce: PkceManager, } impl GitHubProvider { @@ -72,9 +64,7 @@ impl GitHubProvider { Arc::new(Self { client, http, - pkce: DashMap::new(), - last_purge_at_secs: AtomicU32::new(0), - pkce_additions: AtomicU32::new(0), + pkce: PkceManager::new(), }) } } @@ -89,7 +79,7 @@ impl OAuthProvider for GitHubProvider { } async fn authorize(&self) -> axum::response::Response { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (pkce_challenge, verifier) = self.pkce.generate_challenge(); let (authorize_url, csrf_state) = self .client .authorize_url(CsrfToken::new_random) @@ -97,17 +87,9 @@ impl OAuthProvider for GitHubProvider { .add_scope(Scope::new("user:email".to_string())) .add_scope(Scope::new("read:user".to_string())) .url(); + // store verifier keyed by the returned state + self.pkce.store_verifier(csrf_state.secret(), verifier); trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); - // Insert PKCE verifier with timestamp and purge when needed - self.pkce.insert( - csrf_state.secret().to_string(), - PkceRecord { - verifier: pkce_verifier.secret().to_string(), - created_at: Instant::now(), - }, - ); - self.pkce_additions.fetch_add(1, Ordering::Relaxed); - self.maybe_purge_stale_pkce_entries(); Redirect::to(authorize_url.as_str()).into_response() } @@ -127,30 +109,22 @@ impl OAuthProvider for GitHubProvider { .get("state") .cloned() .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?; - let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else { - warn!("Missing PKCE verifier for state parameter"); + let Some(verifier) = self.pkce.take_verifier(&state) else { + warn!(%state, "Missing or expired PKCE verifier for state parameter"); return Err(ErrorResponse::bad_request( "invalid_request", - Some("missing pkce verifier for state".into()), + Some("missing or expired pkce verifier for state".into()), )); }; - // Verify PKCE TTL - if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) { - warn!("PKCE verifier expired for state parameter"); - return Err(ErrorResponse::bad_request( - "invalid_request", - Some("expired pkce verifier for state".into()), - )); - } let token = self .client .exchange_code(AuthorizationCode::new(code)) - .set_pkce_verifier(PkceCodeVerifier::new(rec.verifier)) + .set_pkce_verifier(PkceCodeVerifier::new(verifier)) .request_async(&self.http) .await .map_err(|e| { - warn!(error = %e, "Token exchange with GitHub failed"); + warn!(error = %e, %state, "Token exchange with GitHub failed"); ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string())) })?; @@ -177,36 +151,7 @@ impl OAuthProvider for GitHubProvider { } } -impl GitHubProvider { - fn maybe_purge_stale_pkce_entries(&self) { - // Purge when at least 5 minutes passed or more than 128 additions occurred - const PURGE_INTERVAL_SECS: u32 = 5 * 60; - const ADDITIONS_THRESHOLD: u32 = 128; - - let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(d) => d.as_secs() as u32, - Err(_) => return, - }; - - let last = self.last_purge_at_secs.load(Ordering::Relaxed); - let additions = self.pkce_additions.load(Ordering::Relaxed); - if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS { - return; - } - - const PKCE_TTL: Duration = Duration::from_secs(5 * 60); - let now_inst = Instant::now(); - for entry in self.pkce.iter() { - if now_inst.duration_since(entry.value().created_at) > PKCE_TTL { - self.pkce.remove(entry.key()); - } - } - - // Reset counters after purge - self.pkce_additions.store(0, Ordering::Relaxed); - self.last_purge_at_secs.store(now_secs, Ordering::Relaxed); - } -} +impl GitHubProvider {} /// Fetch user emails from GitHub API pub async fn fetch_github_emails( diff --git a/pacman-server/src/auth/mod.rs b/pacman-server/src/auth/mod.rs index 149ff42..49d27e9 100644 --- a/pacman-server/src/auth/mod.rs +++ b/pacman-server/src/auth/mod.rs @@ -7,6 +7,7 @@ use crate::config::Config; pub mod discord; pub mod github; +pub mod pkce; pub mod provider; type OAuthClient = diff --git a/pacman-server/src/auth/pkce.rs b/pacman-server/src/auth/pkce.rs new file mode 100644 index 0000000..44c37a0 --- /dev/null +++ b/pacman-server/src/auth/pkce.rs @@ -0,0 +1,91 @@ +use dashmap::DashMap; +use oauth2::PkceCodeChallenge; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::{Duration, Instant}; +use tracing::{trace, warn}; + +#[derive(Debug, Clone)] +pub struct PkceRecord { + pub verifier: String, + pub created_at: Instant, +} + +pub struct PkceManager { + pkce: DashMap, + last_purge_at_secs: AtomicU32, + pkce_additions: AtomicU32, +} + +impl PkceManager { + pub fn new() -> Self { + Self { + pkce: DashMap::new(), + last_purge_at_secs: AtomicU32::new(0), + pkce_additions: AtomicU32::new(0), + } + } + + pub fn generate_challenge(&self) -> (PkceCodeChallenge, String) { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + trace!("PKCE challenge generated"); + (pkce_challenge, pkce_verifier.secret().to_string()) + } + + pub fn store_verifier(&self, state: &str, verifier: String) { + self.pkce.insert( + state.to_string(), + PkceRecord { + verifier, + created_at: Instant::now(), + }, + ); + self.pkce_additions.fetch_add(1, Ordering::Relaxed); + self.maybe_purge_stale_entries(); + trace!(state = state, "Stored PKCE verifier for state"); + } + + pub fn take_verifier(&self, state: &str) -> Option { + let Some(record) = self.pkce.remove(state).map(|e| e.1) else { + trace!(state = state, "PKCE verifier not found for state"); + return None; + }; + + // Verify PKCE TTL + if Instant::now().duration_since(record.created_at) > Duration::from_secs(5 * 60) { + warn!(state = state, "PKCE verifier expired for state"); + return None; + } + + trace!(state = state, "PKCE verifier retrieved for state"); + Some(record.verifier) + } + + fn maybe_purge_stale_entries(&self) { + // Purge when at least 5 minutes passed or more than 128 additions occurred + const PURGE_INTERVAL_SECS: u32 = 5 * 60; + const ADDITIONS_THRESHOLD: u32 = 128; + + let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { + Ok(d) => d.as_secs() as u32, + Err(_) => return, + }; + + let last = self.last_purge_at_secs.load(Ordering::Relaxed); + let additions = self.pkce_additions.load(Ordering::Relaxed); + if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS { + return; + } + + const PKCE_TTL: Duration = Duration::from_secs(5 * 60); + let now_inst = Instant::now(); + for entry in self.pkce.iter() { + if now_inst.duration_since(entry.value().created_at) > PKCE_TTL { + self.pkce.remove(entry.key()); + } + } + + // Reset counters after purge + self.pkce_additions.store(0, Ordering::Relaxed); + self.last_purge_at_secs.store(now_secs, Ordering::Relaxed); + } +} diff --git a/pacman-server/src/data/user.rs b/pacman-server/src/data/user.rs index 3f5b276..3ea9bd0 100644 --- a/pacman-server/src/data/user.rs +++ b/pacman-server/src/data/user.rs @@ -23,7 +23,7 @@ pub struct OAuthAccount { pub updated_at: chrono::DateTime, } -pub async fn get_user_by_email(pool: &sqlx::PgPool, email: &str) -> Result, sqlx::Error> { +pub async fn find_user_by_email(pool: &sqlx::PgPool, email: &str) -> Result, sqlx::Error> { sqlx::query_as::<_, User>( r#" SELECT id, email, created_at, updated_at @@ -115,7 +115,7 @@ pub async fn get_oauth_account_count_for_user(pool: &sqlx::PgPool, user_id: i64) Ok(rec.0) } -pub async fn get_user_by_provider_id( +pub async fn find_user_by_provider_id( pool: &sqlx::PgPool, provider: &str, provider_user_id: &str, diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 0e17e39..677be73 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use axum::{ extract::{Path, Query, State}, http::StatusCode, @@ -42,6 +44,7 @@ pub async fn oauth_authorize_handler( .http_only(true) .same_site(axum_cookie::prelude::SameSite::Lax) .path("/") + .max_age(std::time::Duration::from_secs(120)) .build(), ); } @@ -57,15 +60,19 @@ pub async fn oauth_callback_handler( Query(params): Query, cookie: CookieManager, ) -> axum::response::Response { + // Validate provider let Some(prov) = app_state.auth.get(&provider) else { warn!(%provider, "Unknown OAuth provider"); return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response(); }; + + // Process callback-returned errors from provider if let Some(error) = params.error { warn!(%provider, error = %error, desc = ?params.error_description, "OAuth callback returned an error"); return ErrorResponse::bad_request(error, params.error_description).into_response(); } - let mut q = std::collections::HashMap::new(); + + let mut q = HashMap::new(); if let Some(v) = params.code { q.insert("code".to_string(), v); } @@ -85,49 +92,56 @@ pub async fn oauth_callback_handler( cookie.remove("link"); } let email = user.email.as_deref(); - let _db_user = if link_cookie.as_deref() == Some("1") { - // Must be logged in already to link - let Some(session_token) = session::get_session_token(&cookie) else { - return ErrorResponse::bad_request("invalid_request", Some("must be signed in to link provider".into())) - .into_response(); - }; - let Some(claims) = session::decode_jwt(&session_token, &app_state.jwt_decoding_key) else { - return ErrorResponse::bad_request("invalid_request", Some("invalid session token".into())).into_response(); - }; - // Resolve current user from session - let (cur_prov, cur_id) = claims.sub.split_once(':').unwrap_or(("", "")); - let current_user = match user_repo::get_user_by_provider_id(&app_state.db, cur_prov, cur_id).await { - Ok(Some(u)) => u, - Ok(None) => { - return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into())) - .into_response(); + // Determine linking intent with a valid session + let is_link = if link_cookie.as_deref() == Some("1") { + match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) { + Some(c) => { + // Perform linking with current session user + let (cur_prov, cur_id) = c.sub.split_once(':').unwrap_or(("", "")); + let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await { + Ok(Some(u)) => u, + Ok(None) => { + warn!("Current session user not found; proceeding as normal sign-in"); + return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into())) + .into_response(); + } + Err(_) => { + return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None) + .into_response(); + } + }; + if let Err(e) = user_repo::link_oauth_account( + &app_state.db, + current_user.id, + &provider, + &user.id, + email, + Some(&user.username), + user.name.as_deref(), + user.avatar_url.as_deref(), + ) + .await + { + warn!(error = %e, %provider, "Failed to link OAuth account"); + return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); + } + return (StatusCode::FOUND, Redirect::to("/profile")).into_response(); } - Err(_) => { - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); + None => { + warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in"); + false } - }; - - // Link provider to current user - if let Err(e) = user_repo::link_oauth_account( - &app_state.db, - current_user.id, - &provider, - &user.id, - email, - Some(&user.username), - user.name.as_deref(), - user.avatar_url.as_deref(), - ) - .await - { - warn!(error = %e, %provider, "Failed to link OAuth account"); - return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response(); } - current_user + } else { + false + }; + + if is_link { + unreachable!(); // handled via early return above } else { // Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow. if let Some(e) = email { - if let Ok(Some(existing)) = user_repo::get_user_by_email(&app_state.db, e).await { + if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await { // Only block if the user already has at least one linked provider. // NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive, // this may lock them out until the provider is reactivated or a manual admin link is performed. @@ -198,9 +212,11 @@ pub async fn oauth_callback_handler( .into_response(); } }; + let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key); session::set_session_cookie(&cookie, &session_token); info!(%provider, "Signed in successfully"); + (StatusCode::FOUND, Redirect::to("/profile")).into_response() } @@ -222,7 +238,7 @@ pub async fn profile_handler(State(app_state): State, cookie: CookieMa return ErrorResponse::unauthorized("invalid session token").into_response(); } }; - match user_repo::get_user_by_provider_id(&app_state.db, prov, prov_user_id).await { + match user_repo::find_user_by_provider_id(&app_state.db, prov, prov_user_id).await { Ok(Some(db_user)) => { // Include linked providers in the profile payload match user_repo::list_user_providers(&app_state.db, db_user.id).await {