feat: setup smarter PKCE map purging & BasicClient type alias, smarter EnvFilter string building

This commit is contained in:
Ryan Walters
2025-09-17 04:06:52 -05:00
parent 92acb07b04
commit 8e23fb66a4
2 changed files with 64 additions and 34 deletions

View File

@@ -3,6 +3,7 @@ use dashmap::DashMap;
use oauth2::{basic::BasicClient, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse}; use oauth2::{basic::BasicClient, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tracing::{trace, warn}; use tracing::{trace, warn};
@@ -12,6 +13,10 @@ use crate::{
errors::ErrorResponse, errors::ErrorResponse,
}; };
// Private type alias for the OAuth2 BasicClient specialized type
type OAuthClient =
BasicClient<oauth2::EndpointSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointSet>;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitHubUser { pub struct GitHubUser {
pub id: u64, pub id: u64,
@@ -53,32 +58,27 @@ pub async fn fetch_github_user(
} }
pub struct GitHubProvider { pub struct GitHubProvider {
pub client: BasicClient< pub client: OAuthClient,
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>,
pub http: reqwest::Client, pub http: reqwest::Client,
pkce: DashMap<String, (String, Instant)>, pkce: DashMap<String, PkceRecord>,
last_purge_at_secs: AtomicU32,
pkce_additions: AtomicU32,
}
#[derive(Debug, Clone)]
struct PkceRecord {
verifier: String,
created_at: Instant,
} }
impl GitHubProvider { impl GitHubProvider {
pub fn new( pub fn new(client: OAuthClient, http: reqwest::Client) -> Arc<Self> {
client: BasicClient<
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>,
http: reqwest::Client,
) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
client, client,
http, http,
pkce: DashMap::new(), pkce: DashMap::new(),
last_purge_at_secs: AtomicU32::new(0),
pkce_additions: AtomicU32::new(0),
}) })
} }
} }
@@ -101,18 +101,17 @@ impl OAuthProvider for GitHubProvider {
.add_scope(Scope::new("user:email".to_string())) .add_scope(Scope::new("user:email".to_string()))
.add_scope(Scope::new("read:user".to_string())) .add_scope(Scope::new("read:user".to_string()))
.url(); .url();
// Insert PKCE verifier with timestamp and purge stale entries
let now = Instant::now();
self.pkce
.insert(csrf_state.secret().to_string(), (pkce_verifier.secret().to_string(), now));
// Best-effort cleanup to avoid unbounded growth
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
for entry in self.pkce.iter() {
if now.duration_since(entry.value().1) > PKCE_TTL {
self.pkce.remove(entry.key());
}
}
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL"); trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
// Insert PKCE verifier with timestamp and purge when needed
self.pkce.insert(
csrf_state.secret().to_string(),
PkceRecord {
verifier: pkce_verifier.secret().to_string(),
created_at: Instant::now(),
},
);
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
self.maybe_purge_stale_pkce_entries();
Redirect::to(authorize_url.as_str()).into_response() Redirect::to(authorize_url.as_str()).into_response()
} }
@@ -132,7 +131,7 @@ impl OAuthProvider for GitHubProvider {
.get("state") .get("state")
.cloned() .cloned()
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?; .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
let Some((verifier, created_at)) = self.pkce.remove(&state).map(|e| e.1) else { let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else {
warn!("Missing PKCE verifier for state parameter"); warn!("Missing PKCE verifier for state parameter");
return Err(ErrorResponse::bad_request( return Err(ErrorResponse::bad_request(
"invalid_request", "invalid_request",
@@ -140,7 +139,7 @@ impl OAuthProvider for GitHubProvider {
)); ));
}; };
// Verify PKCE TTL // Verify PKCE TTL
if Instant::now().duration_since(created_at) > Duration::from_secs(5 * 60) { if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) {
warn!("PKCE verifier expired for state parameter"); warn!("PKCE verifier expired for state parameter");
return Err(ErrorResponse::bad_request( return Err(ErrorResponse::bad_request(
"invalid_request", "invalid_request",
@@ -151,7 +150,7 @@ impl OAuthProvider for GitHubProvider {
let token = self let token = self
.client .client
.exchange_code(AuthorizationCode::new(code)) .exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(PkceCodeVerifier::new(verifier)) .set_pkce_verifier(PkceCodeVerifier::new(rec.verifier))
.request_async(&self.http) .request_async(&self.http)
.await .await
.map_err(|e| { .map_err(|e| {
@@ -182,6 +181,37 @@ impl OAuthProvider for GitHubProvider {
} }
} }
impl GitHubProvider {
fn maybe_purge_stale_pkce_entries(&self) {
// Purge when at least 5 minutes passed or more than 128 additions occurred
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
const ADDITIONS_THRESHOLD: u32 = 128;
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(d) => d.as_secs() as u32,
Err(_) => return,
};
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
let additions = self.pkce_additions.load(Ordering::Relaxed);
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
return;
}
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
let now_inst = Instant::now();
for entry in self.pkce.iter() {
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
self.pkce.remove(entry.key());
}
}
// Reset counters after purge
self.pkce_additions.store(0, Ordering::Relaxed);
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
}
}
/// Fetch user emails from GitHub API /// Fetch user emails from GitHub API
pub async fn fetch_github_emails( pub async fn fetch_github_emails(
http_client: &reqwest::Client, http_client: &reqwest::Client,

View File

@@ -7,8 +7,8 @@ use crate::formatter;
/// Configure and initialize logging for the application /// Configure and initialize logging for the application
pub fn setup_logging(_config: &Config) { pub fn setup_logging(_config: &Config) {
// Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere // Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere
let filter = let filter = EnvFilter::try_from_default_env()
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn,pacman_server=info,pacman_server::auth=info")); .unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME"))));
// Default to pretty for local dev; switchable later if we add CLI // Default to pretty for local dev; switchable later if we add CLI
let use_pretty = cfg!(debug_assertions); let use_pretty = cfg!(debug_assertions);