diff --git a/pacman-server/src/auth/github.rs b/pacman-server/src/auth/github.rs index f677431..53e380b 100644 --- a/pacman-server/src/auth/github.rs +++ b/pacman-server/src/auth/github.rs @@ -3,6 +3,7 @@ use dashmap::DashMap; use oauth2::{basic::BasicClient, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse}; use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::{trace, warn}; @@ -12,6 +13,10 @@ use crate::{ errors::ErrorResponse, }; +// Private type alias for the OAuth2 BasicClient specialized type +type OAuthClient = + BasicClient; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GitHubUser { pub id: u64, @@ -53,32 +58,27 @@ pub async fn fetch_github_user( } pub struct GitHubProvider { - pub client: BasicClient< - oauth2::EndpointSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointSet, - >, + pub client: OAuthClient, pub http: reqwest::Client, - pkce: DashMap, + pkce: DashMap, + last_purge_at_secs: AtomicU32, + pkce_additions: AtomicU32, +} + +#[derive(Debug, Clone)] +struct PkceRecord { + verifier: String, + created_at: Instant, } impl GitHubProvider { - pub fn new( - client: BasicClient< - oauth2::EndpointSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointSet, - >, - http: reqwest::Client, - ) -> Arc { + pub fn new(client: OAuthClient, http: reqwest::Client) -> Arc { Arc::new(Self { client, http, pkce: DashMap::new(), + last_purge_at_secs: AtomicU32::new(0), + pkce_additions: AtomicU32::new(0), }) } } @@ -101,18 +101,17 @@ impl OAuthProvider for GitHubProvider { .add_scope(Scope::new("user:email".to_string())) .add_scope(Scope::new("read:user".to_string())) .url(); - // Insert PKCE verifier with timestamp and purge stale entries - let now = Instant::now(); - self.pkce - .insert(csrf_state.secret().to_string(), (pkce_verifier.secret().to_string(), now)); - // Best-effort cleanup to avoid unbounded growth - const PKCE_TTL: Duration = Duration::from_secs(5 * 60); - for entry in self.pkce.iter() { - if now.duration_since(entry.value().1) > PKCE_TTL { - self.pkce.remove(entry.key()); - } - } trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); + // Insert PKCE verifier with timestamp and purge when needed + self.pkce.insert( + csrf_state.secret().to_string(), + PkceRecord { + verifier: pkce_verifier.secret().to_string(), + created_at: Instant::now(), + }, + ); + self.pkce_additions.fetch_add(1, Ordering::Relaxed); + self.maybe_purge_stale_pkce_entries(); Redirect::to(authorize_url.as_str()).into_response() } @@ -132,7 +131,7 @@ impl OAuthProvider for GitHubProvider { .get("state") .cloned() .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?; - let Some((verifier, created_at)) = self.pkce.remove(&state).map(|e| e.1) else { + let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else { warn!("Missing PKCE verifier for state parameter"); return Err(ErrorResponse::bad_request( "invalid_request", @@ -140,7 +139,7 @@ impl OAuthProvider for GitHubProvider { )); }; // Verify PKCE TTL - if Instant::now().duration_since(created_at) > Duration::from_secs(5 * 60) { + if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) { warn!("PKCE verifier expired for state parameter"); return Err(ErrorResponse::bad_request( "invalid_request", @@ -151,7 +150,7 @@ impl OAuthProvider for GitHubProvider { let token = self .client .exchange_code(AuthorizationCode::new(code)) - .set_pkce_verifier(PkceCodeVerifier::new(verifier)) + .set_pkce_verifier(PkceCodeVerifier::new(rec.verifier)) .request_async(&self.http) .await .map_err(|e| { @@ -182,6 +181,37 @@ impl OAuthProvider for GitHubProvider { } } +impl GitHubProvider { + fn maybe_purge_stale_pkce_entries(&self) { + // Purge when at least 5 minutes passed or more than 128 additions occurred + const PURGE_INTERVAL_SECS: u32 = 5 * 60; + const ADDITIONS_THRESHOLD: u32 = 128; + + let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { + Ok(d) => d.as_secs() as u32, + Err(_) => return, + }; + + let last = self.last_purge_at_secs.load(Ordering::Relaxed); + let additions = self.pkce_additions.load(Ordering::Relaxed); + if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS { + return; + } + + const PKCE_TTL: Duration = Duration::from_secs(5 * 60); + let now_inst = Instant::now(); + for entry in self.pkce.iter() { + if now_inst.duration_since(entry.value().created_at) > PKCE_TTL { + self.pkce.remove(entry.key()); + } + } + + // Reset counters after purge + self.pkce_additions.store(0, Ordering::Relaxed); + self.last_purge_at_secs.store(now_secs, Ordering::Relaxed); + } +} + /// Fetch user emails from GitHub API pub async fn fetch_github_emails( http_client: &reqwest::Client, diff --git a/pacman-server/src/logging.rs b/pacman-server/src/logging.rs index 8f261f8..a61ed4c 100644 --- a/pacman-server/src/logging.rs +++ b/pacman-server/src/logging.rs @@ -7,8 +7,8 @@ use crate::formatter; /// Configure and initialize logging for the application pub fn setup_logging(_config: &Config) { // Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere - let filter = - EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn,pacman_server=info,pacman_server::auth=info")); + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME")))); // Default to pretty for local dev; switchable later if we add CLI let use_pretty = cfg!(debug_assertions);