From 67c9460c8421f2f677eb6e617fc61f142ea45d53 Mon Sep 17 00:00:00 2001 From: Ryan Walters Date: Fri, 19 Sep 2025 10:23:33 -0500 Subject: [PATCH] refactor(auth): implement session-based PKCE and eliminate provider duplication - Replace in-memory PKCE storage with encrypted session cookies - Add PKCE verifier and CSRF state fields to JWT Claims struct - Move common PKCE validation logic to OAuthProvider trait - Extract provider-specific methods for token exchange and user fetching - Remove PkceManager and DashMap-based storage system - Update GitHub and Discord providers to use new session-based approach --- pacman-server/src/auth/discord.rs | 54 +++++++--------- pacman-server/src/auth/github.rs | 83 +++++++----------------- pacman-server/src/auth/mod.rs | 1 - pacman-server/src/auth/pkce.rs | 84 ------------------------ pacman-server/src/auth/provider.rs | 100 +++++++++++++++++++++++++++-- pacman-server/src/routes.rs | 4 +- pacman-server/src/session.rs | 31 +++++++++ 7 files changed, 175 insertions(+), 182 deletions(-) delete mode 100644 pacman-server/src/auth/pkce.rs diff --git a/pacman-server/src/auth/discord.rs b/pacman-server/src/auth/discord.rs index 3d44e38..4d82020 100644 --- a/pacman-server/src/auth/discord.rs +++ b/pacman-server/src/auth/discord.rs @@ -1,15 +1,15 @@ use axum::{response::IntoResponse, response::Redirect}; +use axum_cookie::CookieManager; +use jsonwebtoken::EncodingKey; use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::{trace, warn}; -use crate::auth::{ - pkce::PkceManager, - provider::{AuthUser, OAuthProvider}, -}; +use crate::auth::provider::{AuthUser, OAuthProvider}; use crate::errors::ErrorResponse; +use crate::session; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DiscordUser { @@ -43,16 +43,11 @@ pub async fn fetch_discord_user( pub struct DiscordProvider { pub client: super::OAuthClient, pub http: reqwest::Client, - pkce: PkceManager, } impl DiscordProvider { pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc { - Arc::new(Self { - client, - http, - pkce: PkceManager::default(), - }) + Arc::new(Self { client, http }) } fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String { @@ -70,8 +65,8 @@ impl OAuthProvider for DiscordProvider { "Discord" } - async fn authorize(&self) -> axum::response::Response { - let (pkce_challenge, verifier) = self.pkce.generate_challenge(); + async fn authorize(&self, cookie: &CookieManager, encoding_key: &EncodingKey) -> axum::response::Response { + let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256(); let (authorize_url, csrf_state) = self .client .authorize_url(CsrfToken::new_random) @@ -79,38 +74,35 @@ impl OAuthProvider for DiscordProvider { .add_scope(Scope::new("identify".to_string())) .add_scope(Scope::new("email".to_string())) .url(); - self.pkce.store_verifier(csrf_state.secret(), verifier); - trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); + // Store PKCE verifier and CSRF state in session + let session_token = session::create_pkce_session(pkce_verifier.secret(), csrf_state.secret(), encoding_key); + session::set_session_cookie(cookie, &session_token); + + trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); Redirect::to(authorize_url.as_str()).into_response() } - async fn handle_callback(&self, code: &str, state: &str) -> Result { - let Some(verifier) = self.pkce.take_verifier(state) else { - warn!(%state, "Missing or expired PKCE verifier for state parameter"); - return Err(ErrorResponse::bad_request( - "invalid_request", - Some("missing or expired pkce verifier for state".into()), - )); - }; - + async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result { let token = self .client .exchange_code(AuthorizationCode::new(code.to_string())) - .set_pkce_verifier(PkceCodeVerifier::new(verifier)) + .set_pkce_verifier(PkceCodeVerifier::new(verifier.to_string())) .request_async(&self.http) .await .map_err(|e| { - warn!(error = %e, %state, "Token exchange with Discord failed"); + warn!(error = %e, "Token exchange with Discord failed"); ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string())) })?; - let user = fetch_discord_user(&self.http, token.access_token().secret()) - .await - .map_err(|e| { - warn!(error = %e, "Failed to fetch Discord user profile"); - ErrorResponse::bad_gateway("discord_api_error", Some(format!("failed to fetch user: {}", e))) - })?; + Ok(token.access_token().secret().to_string()) + } + + async fn fetch_user_from_token(&self, access_token: &str) -> Result { + let user = fetch_discord_user(&self.http, access_token).await.map_err(|e| { + warn!(error = %e, "Failed to fetch Discord user profile"); + ErrorResponse::bad_gateway("discord_api_error", Some(format!("failed to fetch user: {}", e))) + })?; let avatar_url = match (&user.id, &user.avatar) { (id, Some(hash)) => Some(Self::avatar_url_for(id, hash)), diff --git a/pacman-server/src/auth/github.rs b/pacman-server/src/auth/github.rs index ffbb07a..8729234 100644 --- a/pacman-server/src/auth/github.rs +++ b/pacman-server/src/auth/github.rs @@ -1,4 +1,6 @@ use axum::{response::IntoResponse, response::Redirect}; +use axum_cookie::CookieManager; +use jsonwebtoken::EncodingKey; use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse}; use serde::{Deserialize, Serialize}; @@ -6,11 +8,9 @@ use std::sync::Arc; use tracing::{trace, warn}; use crate::{ - auth::{ - pkce::PkceManager, - provider::{AuthUser, OAuthProvider}, - }, + auth::provider::{AuthUser, OAuthProvider}, errors::ErrorResponse, + session, }; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -56,16 +56,11 @@ pub async fn fetch_github_user( pub struct GitHubProvider { pub client: super::OAuthClient, pub http: reqwest::Client, - pkce: PkceManager, } impl GitHubProvider { pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc { - Arc::new(Self { - client, - http, - pkce: PkceManager::default(), - }) + Arc::new(Self { client, http }) } } @@ -78,8 +73,8 @@ impl OAuthProvider for GitHubProvider { "GitHub" } - async fn authorize(&self) -> axum::response::Response { - let (pkce_challenge, verifier) = self.pkce.generate_challenge(); + async fn authorize(&self, cookie: &CookieManager, encoding_key: &EncodingKey) -> axum::response::Response { + let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256(); let (authorize_url, csrf_state) = self .client .authorize_url(CsrfToken::new_random) @@ -87,44 +82,35 @@ impl OAuthProvider for GitHubProvider { .add_scope(Scope::new("user:email".to_string())) .add_scope(Scope::new("read:user".to_string())) .url(); - // store verifier keyed by the returned state - self.pkce.store_verifier(csrf_state.secret(), verifier); + + // Store PKCE verifier and CSRF state in session + let session_token = session::create_pkce_session(pkce_verifier.secret(), csrf_state.secret(), encoding_key); + session::set_session_cookie(cookie, &session_token); + trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); Redirect::to(authorize_url.as_str()).into_response() } - async fn handle_callback(&self, code: &str, state: &str) -> Result { - let Some(verifier) = self.pkce.take_verifier(state) else { - warn!(%state, "Missing or expired PKCE verifier for state parameter"); - return Err(ErrorResponse::bad_request( - "invalid_request", - Some("missing or expired pkce verifier for state".into()), - )); - }; - + async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result { let token = self .client .exchange_code(AuthorizationCode::new(code.to_string())) - .set_pkce_verifier(PkceCodeVerifier::new(verifier)) + .set_pkce_verifier(PkceCodeVerifier::new(verifier.to_string())) .request_async(&self.http) .await .map_err(|e| { - warn!(error = %e, %state, "Token exchange with GitHub failed"); + warn!(error = %e, "Token exchange with GitHub failed"); ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string())) })?; - let user = fetch_github_user(&self.http, token.access_token().secret()) - .await - .map_err(|e| { - warn!(error = %e, "Failed to fetch GitHub user profile"); - ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e))) - })?; - let _emails = fetch_github_emails(&self.http, token.access_token().secret()) - .await - .map_err(|e| { - warn!(error = %e, "Failed to fetch GitHub user emails"); - ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e))) - })?; + Ok(token.access_token().secret().to_string()) + } + + async fn fetch_user_from_token(&self, access_token: &str) -> Result { + let user = fetch_github_user(&self.http, access_token).await.map_err(|e| { + warn!(error = %e, "Failed to fetch GitHub user profile"); + ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e))) + })?; Ok(AuthUser { id: user.id.to_string(), @@ -135,26 +121,3 @@ impl OAuthProvider for GitHubProvider { }) } } - -impl GitHubProvider {} - -/// Fetch user emails from GitHub API -pub async fn fetch_github_emails( - http_client: &reqwest::Client, - access_token: &str, -) -> Result, Box> { - let response = http_client - .get("https://api.github.com/user/emails") - .header("Authorization", format!("Bearer {}", access_token)) - .header("Accept", "application/vnd.github.v3+json") - .header("User-Agent", crate::config::USER_AGENT) - .send() - .await?; - - if !response.status().is_success() { - return Err(format!("GitHub API error: {}", response.status()).into()); - } - - let emails: Vec = response.json().await?; - Ok(emails) -} diff --git a/pacman-server/src/auth/mod.rs b/pacman-server/src/auth/mod.rs index 49d27e9..149ff42 100644 --- a/pacman-server/src/auth/mod.rs +++ b/pacman-server/src/auth/mod.rs @@ -7,7 +7,6 @@ use crate::config::Config; pub mod discord; pub mod github; -pub mod pkce; pub mod provider; type OAuthClient = diff --git a/pacman-server/src/auth/pkce.rs b/pacman-server/src/auth/pkce.rs deleted file mode 100644 index c3d520e..0000000 --- a/pacman-server/src/auth/pkce.rs +++ /dev/null @@ -1,84 +0,0 @@ -use dashmap::DashMap; -use oauth2::PkceCodeChallenge; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::time::{Duration, Instant}; -use tracing::{trace, warn}; - -#[derive(Debug, Clone)] -pub struct PkceRecord { - pub verifier: String, - pub created_at: Instant, -} - -#[derive(Default)] -pub struct PkceManager { - pkce: DashMap, - last_purge_at_secs: AtomicU32, - pkce_additions: AtomicU32, -} - -impl PkceManager { - pub fn generate_challenge(&self) -> (PkceCodeChallenge, String) { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - trace!("PKCE challenge generated"); - (pkce_challenge, pkce_verifier.secret().to_string()) - } - - pub fn store_verifier(&self, state: &str, verifier: String) { - self.pkce.insert( - state.to_string(), - PkceRecord { - verifier, - created_at: Instant::now(), - }, - ); - self.pkce_additions.fetch_add(1, Ordering::Relaxed); - self.maybe_purge_stale_entries(); - trace!(state = state, "Stored PKCE verifier for state"); - } - - pub fn take_verifier(&self, state: &str) -> Option { - let Some(record) = self.pkce.remove(state).map(|e| e.1) else { - trace!(state = state, "PKCE verifier not found for state"); - return None; - }; - - // Verify PKCE TTL - if Instant::now().duration_since(record.created_at) > Duration::from_secs(5 * 60) { - warn!(state = state, "PKCE verifier expired for state"); - return None; - } - - trace!(state = state, "PKCE verifier retrieved for state"); - Some(record.verifier) - } - - fn maybe_purge_stale_entries(&self) { - // Purge when at least 5 minutes passed or more than 128 additions occurred - const PURGE_INTERVAL_SECS: u32 = 5 * 60; - const ADDITIONS_THRESHOLD: u32 = 128; - - let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(d) => d.as_secs() as u32, - Err(_) => return, - }; - - let last = self.last_purge_at_secs.load(Ordering::Relaxed); - let additions = self.pkce_additions.load(Ordering::Relaxed); - if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS { - return; - } - - const PKCE_TTL: Duration = Duration::from_secs(5 * 60); - let now_inst = Instant::now(); - for entry in self.pkce.iter() { - if now_inst.duration_since(entry.value().created_at) > PKCE_TTL { - self.pkce.remove(entry.key()); - } - } - - // Reset counters after purge - self.pkce_additions.store(0, Ordering::Relaxed); - self.last_purge_at_secs.store(now_secs, Ordering::Relaxed); - } -} diff --git a/pacman-server/src/auth/provider.rs b/pacman-server/src/auth/provider.rs index 45edf26..1b69d64 100644 --- a/pacman-server/src/auth/provider.rs +++ b/pacman-server/src/auth/provider.rs @@ -1,28 +1,120 @@ use async_trait::async_trait; +use axum_cookie::CookieManager; +use jsonwebtoken::{DecodingKey, EncodingKey}; use mockall::automock; use serde::Serialize; +use tracing::warn; use crate::errors::ErrorResponse; +use crate::session; +// A user object returned from the provider after authentication. #[derive(Debug, Clone, Serialize)] pub struct AuthUser { + // A unique identifier for the user, from the provider. pub id: String, + // A username from the provider. Generally unique, a handle for the user. pub username: String, + + // A display name for the user. Not always available. pub name: Option, + // An email address for the user. Not always available. pub email: Option, + // An avatar URL for the user. Not always available. pub avatar_url: Option, } #[automock] #[async_trait] pub trait OAuthProvider: Send + Sync { + // Builds a server response to redirect the user to the provider's authorization page. + // This generally also includes beginning a PKCE flow (proof key for code exchange). + // The cookie manager is used to store the PKCE verifier in the session. + async fn authorize(&self, cookie: &CookieManager, encoding_key: &EncodingKey) -> axum::response::Response; + + // Handles the callback from the provider after the user has authorized the app. + // This generally also includes completing the PKCE flow (proof key for code exchange). + // The cookie manager is used to retrieve the PKCE verifier from the session. + async fn handle_callback( + &self, + code: &str, + state: &str, + cookie: &CookieManager, + decoding_key: &DecodingKey, + ) -> Result { + // Common PKCE session validation and token exchange logic + let verifier = self.validate_pkce_session(cookie, state, decoding_key).await?; + let access_token = self.exchange_code_for_token(code, &verifier).await?; + let user = self.fetch_user_from_token(&access_token).await?; + Ok(user) + } + + // Validates the PKCE session and returns the verifier + async fn validate_pkce_session( + &self, + cookie: &CookieManager, + state: &str, + decoding_key: &DecodingKey, + ) -> Result { + // Get the session token and verify it's a PKCE session + let Some(session_token) = session::get_session_token(cookie) else { + warn!(%state, "Missing session cookie during OAuth callback"); + return Err(ErrorResponse::bad_request( + "invalid_request", + Some("missing session cookie".into()), + )); + }; + + let Some(claims) = session::decode_jwt(&session_token, decoding_key) else { + warn!(%state, "Invalid session token during OAuth callback"); + return Err(ErrorResponse::bad_request( + "invalid_request", + Some("invalid session token".into()), + )); + }; + + // Verify this is a PKCE session and the state matches + if !session::is_pkce_session(&claims) { + warn!(%state, "Session is not a PKCE session"); + return Err(ErrorResponse::bad_request( + "invalid_request", + Some("invalid session type".into()), + )); + } + + if claims.csrf_state.as_deref() != Some(state) { + warn!(%state, "CSRF state mismatch during OAuth callback"); + return Err(ErrorResponse::bad_request( + "invalid_request", + Some("state parameter mismatch".into()), + )); + } + + let Some(verifier) = claims.pkce_verifier else { + warn!(%state, "Missing PKCE verifier in session"); + return Err(ErrorResponse::bad_request( + "invalid_request", + Some("missing pkce verifier".into()), + )); + }; + + Ok(verifier) + } + + // Exchanges the authorization code for an access token using PKCE + async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result; + + // Fetches user information from the provider using the access token + async fn fetch_user_from_token(&self, access_token: &str) -> Result; + + // The provider's unique identifier (e.g. "discord") fn id(&self) -> &'static str; + + // The provider's display name (e.g. "Discord") fn label(&self) -> &'static str; + + // Whether the provider is active (defaults to true for now) fn active(&self) -> bool { true } - - async fn authorize(&self) -> axum::response::Response; - - async fn handle_callback(&self, code: &str, state: &str) -> Result; } diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 8e1bf50..b5c47a7 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -46,7 +46,7 @@ pub async fn oauth_authorize_handler( .build(), ); } - let resp = prov.authorize().await; + let resp = prov.authorize(&cookie, &app_state.jwt_encoding_key).await; trace!("Redirecting to provider authorization page"); resp } @@ -80,7 +80,7 @@ pub async fn oauth_callback_handler( span!(tracing::Level::DEBUG, "oauth_callback_handler", provider = %provider, code = %code, state = %state); // Handle callback from provider - let user = match prov.handle_callback(code, state).await { + let user = match prov.handle_callback(code, state, &cookie, &app_state.jwt_decoding_key).await { Ok(u) => u, Err(e) => { warn!(%provider, "OAuth callback handling failed"); diff --git a/pacman-server/src/session.rs b/pacman-server/src/session.rs index ee541cb..ef31d84 100644 --- a/pacman-server/src/session.rs +++ b/pacman-server/src/session.rs @@ -15,6 +15,11 @@ pub struct Claims { pub name: Option, pub iat: usize, pub exp: usize, + // PKCE flow fields - only present during OAuth flow + #[serde(rename = "ver", skip_serializing_if = "Option::is_none")] + pub pkce_verifier: Option, + #[serde(rename = "st", skip_serializing_if = "Option::is_none")] + pub csrf_state: Option, } pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String { @@ -27,12 +32,38 @@ pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &Encod name: user.name.clone(), iat: now, exp: now + JWT_TTL_SECS as usize, + pkce_verifier: None, + csrf_state: None, }; let token = encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign"); trace!(sub = %claims.sub, exp = claims.exp, "Created session JWT"); token } +/// Creates a temporary session for PKCE flow with verifier and CSRF state +pub fn create_pkce_session(pkce_verifier: &str, csrf_state: &str, encoding_key: &EncodingKey) -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time went backwards") + .as_secs() as usize; + let claims = Claims { + sub: "pkce_flow".to_string(), // Special marker for PKCE flow + name: None, + iat: now, + exp: now + JWT_TTL_SECS as usize, + pkce_verifier: Some(pkce_verifier.to_string()), + csrf_state: Some(csrf_state.to_string()), + }; + let token = encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign"); + trace!(csrf_state = %csrf_state, "Created PKCE session JWT"); + token +} + +/// Checks if a session is a PKCE flow session +pub fn is_pkce_session(claims: &Claims) -> bool { + claims.sub == "pkce_flow" && claims.pkce_verifier.is_some() && claims.csrf_state.is_some() +} + pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option { let mut validation = Validation::new(Algorithm::HS256); validation.leeway = 30;