mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-13 10:12:19 -06:00
refactor: create common pkce handling, max_age on link cookie
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
use axum::{response::IntoResponse, response::Redirect};
|
||||
use dashmap::DashMap;
|
||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use crate::{
|
||||
auth::provider::{AuthUser, OAuthProvider},
|
||||
auth::{
|
||||
pkce::PkceManager,
|
||||
provider::{AuthUser, OAuthProvider},
|
||||
},
|
||||
errors::ErrorResponse,
|
||||
};
|
||||
|
||||
@@ -56,15 +56,7 @@ pub async fn fetch_github_user(
|
||||
pub struct GitHubProvider {
|
||||
pub client: super::OAuthClient,
|
||||
pub http: reqwest::Client,
|
||||
pkce: DashMap<String, PkceRecord>,
|
||||
last_purge_at_secs: AtomicU32,
|
||||
pkce_additions: AtomicU32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PkceRecord {
|
||||
verifier: String,
|
||||
created_at: Instant,
|
||||
pkce: PkceManager,
|
||||
}
|
||||
|
||||
impl GitHubProvider {
|
||||
@@ -72,9 +64,7 @@ impl GitHubProvider {
|
||||
Arc::new(Self {
|
||||
client,
|
||||
http,
|
||||
pkce: DashMap::new(),
|
||||
last_purge_at_secs: AtomicU32::new(0),
|
||||
pkce_additions: AtomicU32::new(0),
|
||||
pkce: PkceManager::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -89,7 +79,7 @@ impl OAuthProvider for GitHubProvider {
|
||||
}
|
||||
|
||||
async fn authorize(&self) -> axum::response::Response {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
|
||||
let (authorize_url, csrf_state) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
@@ -97,17 +87,9 @@ impl OAuthProvider for GitHubProvider {
|
||||
.add_scope(Scope::new("user:email".to_string()))
|
||||
.add_scope(Scope::new("read:user".to_string()))
|
||||
.url();
|
||||
// store verifier keyed by the returned state
|
||||
self.pkce.store_verifier(csrf_state.secret(), verifier);
|
||||
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
||||
// Insert PKCE verifier with timestamp and purge when needed
|
||||
self.pkce.insert(
|
||||
csrf_state.secret().to_string(),
|
||||
PkceRecord {
|
||||
verifier: pkce_verifier.secret().to_string(),
|
||||
created_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
|
||||
self.maybe_purge_stale_pkce_entries();
|
||||
Redirect::to(authorize_url.as_str()).into_response()
|
||||
}
|
||||
|
||||
@@ -127,30 +109,22 @@ impl OAuthProvider for GitHubProvider {
|
||||
.get("state")
|
||||
.cloned()
|
||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
||||
let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else {
|
||||
warn!("Missing PKCE verifier for state parameter");
|
||||
let Some(verifier) = self.pkce.take_verifier(&state) else {
|
||||
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
||||
return Err(ErrorResponse::bad_request(
|
||||
"invalid_request",
|
||||
Some("missing pkce verifier for state".into()),
|
||||
Some("missing or expired pkce verifier for state".into()),
|
||||
));
|
||||
};
|
||||
// Verify PKCE TTL
|
||||
if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) {
|
||||
warn!("PKCE verifier expired for state parameter");
|
||||
return Err(ErrorResponse::bad_request(
|
||||
"invalid_request",
|
||||
Some("expired pkce verifier for state".into()),
|
||||
));
|
||||
}
|
||||
|
||||
let token = self
|
||||
.client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(PkceCodeVerifier::new(rec.verifier))
|
||||
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
||||
.request_async(&self.http)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!(error = %e, "Token exchange with GitHub failed");
|
||||
warn!(error = %e, %state, "Token exchange with GitHub failed");
|
||||
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
||||
})?;
|
||||
|
||||
@@ -177,36 +151,7 @@ impl OAuthProvider for GitHubProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl GitHubProvider {
|
||||
fn maybe_purge_stale_pkce_entries(&self) {
|
||||
// Purge when at least 5 minutes passed or more than 128 additions occurred
|
||||
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
|
||||
const ADDITIONS_THRESHOLD: u32 = 128;
|
||||
|
||||
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
||||
Ok(d) => d.as_secs() as u32,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
|
||||
let additions = self.pkce_additions.load(Ordering::Relaxed);
|
||||
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
|
||||
return;
|
||||
}
|
||||
|
||||
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
||||
let now_inst = Instant::now();
|
||||
for entry in self.pkce.iter() {
|
||||
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
|
||||
self.pkce.remove(entry.key());
|
||||
}
|
||||
}
|
||||
|
||||
// Reset counters after purge
|
||||
self.pkce_additions.store(0, Ordering::Relaxed);
|
||||
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
impl GitHubProvider {}
|
||||
|
||||
/// Fetch user emails from GitHub API
|
||||
pub async fn fetch_github_emails(
|
||||
|
||||
Reference in New Issue
Block a user