mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-06 11:15:46 -06:00
refactor(auth): implement session-based PKCE and eliminate provider duplication
- Replace in-memory PKCE storage with encrypted session cookies - Add PKCE verifier and CSRF state fields to JWT Claims struct - Move common PKCE validation logic to OAuthProvider trait - Extract provider-specific methods for token exchange and user fetching - Remove PkceManager and DashMap-based storage system - Update GitHub and Discord providers to use new session-based approach
This commit is contained in:
@@ -1,15 +1,15 @@
|
|||||||
use axum::{response::IntoResponse, response::Redirect};
|
use axum::{response::IntoResponse, response::Redirect};
|
||||||
|
use axum_cookie::CookieManager;
|
||||||
|
use jsonwebtoken::EncodingKey;
|
||||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{trace, warn};
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
use crate::auth::{
|
use crate::auth::provider::{AuthUser, OAuthProvider};
|
||||||
pkce::PkceManager,
|
|
||||||
provider::{AuthUser, OAuthProvider},
|
|
||||||
};
|
|
||||||
use crate::errors::ErrorResponse;
|
use crate::errors::ErrorResponse;
|
||||||
|
use crate::session;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct DiscordUser {
|
pub struct DiscordUser {
|
||||||
@@ -43,16 +43,11 @@ pub async fn fetch_discord_user(
|
|||||||
pub struct DiscordProvider {
|
pub struct DiscordProvider {
|
||||||
pub client: super::OAuthClient,
|
pub client: super::OAuthClient,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
pkce: PkceManager,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DiscordProvider {
|
impl DiscordProvider {
|
||||||
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
||||||
Arc::new(Self {
|
Arc::new(Self { client, http })
|
||||||
client,
|
|
||||||
http,
|
|
||||||
pkce: PkceManager::default(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String {
|
fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String {
|
||||||
@@ -70,8 +65,8 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
"Discord"
|
"Discord"
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response {
|
async fn authorize(&self, cookie: &CookieManager, encoding_key: &EncodingKey) -> axum::response::Response {
|
||||||
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
|
let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
|
||||||
let (authorize_url, csrf_state) = self
|
let (authorize_url, csrf_state) = self
|
||||||
.client
|
.client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
@@ -79,38 +74,35 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
.add_scope(Scope::new("identify".to_string()))
|
.add_scope(Scope::new("identify".to_string()))
|
||||||
.add_scope(Scope::new("email".to_string()))
|
.add_scope(Scope::new("email".to_string()))
|
||||||
.url();
|
.url();
|
||||||
self.pkce.store_verifier(csrf_state.secret(), verifier);
|
|
||||||
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
|
||||||
|
|
||||||
|
// Store PKCE verifier and CSRF state in session
|
||||||
|
let session_token = session::create_pkce_session(pkce_verifier.secret(), csrf_state.secret(), encoding_key);
|
||||||
|
session::set_session_cookie(cookie, &session_token);
|
||||||
|
|
||||||
|
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
||||||
Redirect::to(authorize_url.as_str()).into_response()
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
|
async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result<String, ErrorResponse> {
|
||||||
let Some(verifier) = self.pkce.take_verifier(state) else {
|
|
||||||
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
|
||||||
return Err(ErrorResponse::bad_request(
|
|
||||||
"invalid_request",
|
|
||||||
Some("missing or expired pkce verifier for state".into()),
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
let token = self
|
let token = self
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(code.to_string()))
|
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
.set_pkce_verifier(PkceCodeVerifier::new(verifier.to_string()))
|
||||||
.request_async(&self.http)
|
.request_async(&self.http)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!(error = %e, %state, "Token exchange with Discord failed");
|
warn!(error = %e, "Token exchange with Discord failed");
|
||||||
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let user = fetch_discord_user(&self.http, token.access_token().secret())
|
Ok(token.access_token().secret().to_string())
|
||||||
.await
|
}
|
||||||
.map_err(|e| {
|
|
||||||
warn!(error = %e, "Failed to fetch Discord user profile");
|
async fn fetch_user_from_token(&self, access_token: &str) -> Result<AuthUser, ErrorResponse> {
|
||||||
ErrorResponse::bad_gateway("discord_api_error", Some(format!("failed to fetch user: {}", e)))
|
let user = fetch_discord_user(&self.http, access_token).await.map_err(|e| {
|
||||||
})?;
|
warn!(error = %e, "Failed to fetch Discord user profile");
|
||||||
|
ErrorResponse::bad_gateway("discord_api_error", Some(format!("failed to fetch user: {}", e)))
|
||||||
|
})?;
|
||||||
|
|
||||||
let avatar_url = match (&user.id, &user.avatar) {
|
let avatar_url = match (&user.id, &user.avatar) {
|
||||||
(id, Some(hash)) => Some(Self::avatar_url_for(id, hash)),
|
(id, Some(hash)) => Some(Self::avatar_url_for(id, hash)),
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use axum::{response::IntoResponse, response::Redirect};
|
use axum::{response::IntoResponse, response::Redirect};
|
||||||
|
use axum_cookie::CookieManager;
|
||||||
|
use jsonwebtoken::EncodingKey;
|
||||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@@ -6,11 +8,9 @@ use std::sync::Arc;
|
|||||||
use tracing::{trace, warn};
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{
|
auth::provider::{AuthUser, OAuthProvider},
|
||||||
pkce::PkceManager,
|
|
||||||
provider::{AuthUser, OAuthProvider},
|
|
||||||
},
|
|
||||||
errors::ErrorResponse,
|
errors::ErrorResponse,
|
||||||
|
session,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -56,16 +56,11 @@ pub async fn fetch_github_user(
|
|||||||
pub struct GitHubProvider {
|
pub struct GitHubProvider {
|
||||||
pub client: super::OAuthClient,
|
pub client: super::OAuthClient,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
pkce: PkceManager,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GitHubProvider {
|
impl GitHubProvider {
|
||||||
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
||||||
Arc::new(Self {
|
Arc::new(Self { client, http })
|
||||||
client,
|
|
||||||
http,
|
|
||||||
pkce: PkceManager::default(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,8 +73,8 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
"GitHub"
|
"GitHub"
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response {
|
async fn authorize(&self, cookie: &CookieManager, encoding_key: &EncodingKey) -> axum::response::Response {
|
||||||
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
|
let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
|
||||||
let (authorize_url, csrf_state) = self
|
let (authorize_url, csrf_state) = self
|
||||||
.client
|
.client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
@@ -87,44 +82,35 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
.add_scope(Scope::new("user:email".to_string()))
|
.add_scope(Scope::new("user:email".to_string()))
|
||||||
.add_scope(Scope::new("read:user".to_string()))
|
.add_scope(Scope::new("read:user".to_string()))
|
||||||
.url();
|
.url();
|
||||||
// store verifier keyed by the returned state
|
|
||||||
self.pkce.store_verifier(csrf_state.secret(), verifier);
|
// Store PKCE verifier and CSRF state in session
|
||||||
|
let session_token = session::create_pkce_session(pkce_verifier.secret(), csrf_state.secret(), encoding_key);
|
||||||
|
session::set_session_cookie(cookie, &session_token);
|
||||||
|
|
||||||
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
||||||
Redirect::to(authorize_url.as_str()).into_response()
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
|
async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result<String, ErrorResponse> {
|
||||||
let Some(verifier) = self.pkce.take_verifier(state) else {
|
|
||||||
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
|
||||||
return Err(ErrorResponse::bad_request(
|
|
||||||
"invalid_request",
|
|
||||||
Some("missing or expired pkce verifier for state".into()),
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
let token = self
|
let token = self
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(code.to_string()))
|
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
.set_pkce_verifier(PkceCodeVerifier::new(verifier.to_string()))
|
||||||
.request_async(&self.http)
|
.request_async(&self.http)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!(error = %e, %state, "Token exchange with GitHub failed");
|
warn!(error = %e, "Token exchange with GitHub failed");
|
||||||
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let user = fetch_github_user(&self.http, token.access_token().secret())
|
Ok(token.access_token().secret().to_string())
|
||||||
.await
|
}
|
||||||
.map_err(|e| {
|
|
||||||
warn!(error = %e, "Failed to fetch GitHub user profile");
|
async fn fetch_user_from_token(&self, access_token: &str) -> Result<AuthUser, ErrorResponse> {
|
||||||
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e)))
|
let user = fetch_github_user(&self.http, access_token).await.map_err(|e| {
|
||||||
})?;
|
warn!(error = %e, "Failed to fetch GitHub user profile");
|
||||||
let _emails = fetch_github_emails(&self.http, token.access_token().secret())
|
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e)))
|
||||||
.await
|
})?;
|
||||||
.map_err(|e| {
|
|
||||||
warn!(error = %e, "Failed to fetch GitHub user emails");
|
|
||||||
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e)))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(AuthUser {
|
Ok(AuthUser {
|
||||||
id: user.id.to_string(),
|
id: user.id.to_string(),
|
||||||
@@ -135,26 +121,3 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GitHubProvider {}
|
|
||||||
|
|
||||||
/// Fetch user emails from GitHub API
|
|
||||||
pub async fn fetch_github_emails(
|
|
||||||
http_client: &reqwest::Client,
|
|
||||||
access_token: &str,
|
|
||||||
) -> Result<Vec<GitHubEmail>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let response = http_client
|
|
||||||
.get("https://api.github.com/user/emails")
|
|
||||||
.header("Authorization", format!("Bearer {}", access_token))
|
|
||||||
.header("Accept", "application/vnd.github.v3+json")
|
|
||||||
.header("User-Agent", crate::config::USER_AGENT)
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
return Err(format!("GitHub API error: {}", response.status()).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let emails: Vec<GitHubEmail> = response.json().await?;
|
|
||||||
Ok(emails)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ use crate::config::Config;
|
|||||||
|
|
||||||
pub mod discord;
|
pub mod discord;
|
||||||
pub mod github;
|
pub mod github;
|
||||||
pub mod pkce;
|
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
|
|
||||||
type OAuthClient =
|
type OAuthClient =
|
||||||
|
|||||||
@@ -1,84 +0,0 @@
|
|||||||
use dashmap::DashMap;
|
|
||||||
use oauth2::PkceCodeChallenge;
|
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use tracing::{trace, warn};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct PkceRecord {
|
|
||||||
pub verifier: String,
|
|
||||||
pub created_at: Instant,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct PkceManager {
|
|
||||||
pkce: DashMap<String, PkceRecord>,
|
|
||||||
last_purge_at_secs: AtomicU32,
|
|
||||||
pkce_additions: AtomicU32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PkceManager {
|
|
||||||
pub fn generate_challenge(&self) -> (PkceCodeChallenge, String) {
|
|
||||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
|
||||||
trace!("PKCE challenge generated");
|
|
||||||
(pkce_challenge, pkce_verifier.secret().to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn store_verifier(&self, state: &str, verifier: String) {
|
|
||||||
self.pkce.insert(
|
|
||||||
state.to_string(),
|
|
||||||
PkceRecord {
|
|
||||||
verifier,
|
|
||||||
created_at: Instant::now(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
|
|
||||||
self.maybe_purge_stale_entries();
|
|
||||||
trace!(state = state, "Stored PKCE verifier for state");
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn take_verifier(&self, state: &str) -> Option<String> {
|
|
||||||
let Some(record) = self.pkce.remove(state).map(|e| e.1) else {
|
|
||||||
trace!(state = state, "PKCE verifier not found for state");
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Verify PKCE TTL
|
|
||||||
if Instant::now().duration_since(record.created_at) > Duration::from_secs(5 * 60) {
|
|
||||||
warn!(state = state, "PKCE verifier expired for state");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
trace!(state = state, "PKCE verifier retrieved for state");
|
|
||||||
Some(record.verifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn maybe_purge_stale_entries(&self) {
|
|
||||||
// Purge when at least 5 minutes passed or more than 128 additions occurred
|
|
||||||
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
|
|
||||||
const ADDITIONS_THRESHOLD: u32 = 128;
|
|
||||||
|
|
||||||
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
|
||||||
Ok(d) => d.as_secs() as u32,
|
|
||||||
Err(_) => return,
|
|
||||||
};
|
|
||||||
|
|
||||||
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
|
|
||||||
let additions = self.pkce_additions.load(Ordering::Relaxed);
|
|
||||||
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
|
||||||
let now_inst = Instant::now();
|
|
||||||
for entry in self.pkce.iter() {
|
|
||||||
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
|
|
||||||
self.pkce.remove(entry.key());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset counters after purge
|
|
||||||
self.pkce_additions.store(0, Ordering::Relaxed);
|
|
||||||
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,28 +1,120 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use axum_cookie::CookieManager;
|
||||||
|
use jsonwebtoken::{DecodingKey, EncodingKey};
|
||||||
use mockall::automock;
|
use mockall::automock;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use crate::errors::ErrorResponse;
|
use crate::errors::ErrorResponse;
|
||||||
|
use crate::session;
|
||||||
|
|
||||||
|
// A user object returned from the provider after authentication.
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
pub struct AuthUser {
|
pub struct AuthUser {
|
||||||
|
// A unique identifier for the user, from the provider.
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
// A username from the provider. Generally unique, a handle for the user.
|
||||||
pub username: String,
|
pub username: String,
|
||||||
|
|
||||||
|
// A display name for the user. Not always available.
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
// An email address for the user. Not always available.
|
||||||
pub email: Option<String>,
|
pub email: Option<String>,
|
||||||
|
// An avatar URL for the user. Not always available.
|
||||||
pub avatar_url: Option<String>,
|
pub avatar_url: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[automock]
|
#[automock]
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait OAuthProvider: Send + Sync {
|
pub trait OAuthProvider: Send + Sync {
|
||||||
|
// Builds a server response to redirect the user to the provider's authorization page.
|
||||||
|
// This generally also includes beginning a PKCE flow (proof key for code exchange).
|
||||||
|
// The cookie manager is used to store the PKCE verifier in the session.
|
||||||
|
async fn authorize(&self, cookie: &CookieManager, encoding_key: &EncodingKey) -> axum::response::Response;
|
||||||
|
|
||||||
|
// Handles the callback from the provider after the user has authorized the app.
|
||||||
|
// This generally also includes completing the PKCE flow (proof key for code exchange).
|
||||||
|
// The cookie manager is used to retrieve the PKCE verifier from the session.
|
||||||
|
async fn handle_callback(
|
||||||
|
&self,
|
||||||
|
code: &str,
|
||||||
|
state: &str,
|
||||||
|
cookie: &CookieManager,
|
||||||
|
decoding_key: &DecodingKey,
|
||||||
|
) -> Result<AuthUser, ErrorResponse> {
|
||||||
|
// Common PKCE session validation and token exchange logic
|
||||||
|
let verifier = self.validate_pkce_session(cookie, state, decoding_key).await?;
|
||||||
|
let access_token = self.exchange_code_for_token(code, &verifier).await?;
|
||||||
|
let user = self.fetch_user_from_token(&access_token).await?;
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the PKCE session and returns the verifier
|
||||||
|
async fn validate_pkce_session(
|
||||||
|
&self,
|
||||||
|
cookie: &CookieManager,
|
||||||
|
state: &str,
|
||||||
|
decoding_key: &DecodingKey,
|
||||||
|
) -> Result<String, ErrorResponse> {
|
||||||
|
// Get the session token and verify it's a PKCE session
|
||||||
|
let Some(session_token) = session::get_session_token(cookie) else {
|
||||||
|
warn!(%state, "Missing session cookie during OAuth callback");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("missing session cookie".into()),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(claims) = session::decode_jwt(&session_token, decoding_key) else {
|
||||||
|
warn!(%state, "Invalid session token during OAuth callback");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("invalid session token".into()),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify this is a PKCE session and the state matches
|
||||||
|
if !session::is_pkce_session(&claims) {
|
||||||
|
warn!(%state, "Session is not a PKCE session");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("invalid session type".into()),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.csrf_state.as_deref() != Some(state) {
|
||||||
|
warn!(%state, "CSRF state mismatch during OAuth callback");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("state parameter mismatch".into()),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(verifier) = claims.pkce_verifier else {
|
||||||
|
warn!(%state, "Missing PKCE verifier in session");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("missing pkce verifier".into()),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(verifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchanges the authorization code for an access token using PKCE
|
||||||
|
async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result<String, ErrorResponse>;
|
||||||
|
|
||||||
|
// Fetches user information from the provider using the access token
|
||||||
|
async fn fetch_user_from_token(&self, access_token: &str) -> Result<AuthUser, ErrorResponse>;
|
||||||
|
|
||||||
|
// The provider's unique identifier (e.g. "discord")
|
||||||
fn id(&self) -> &'static str;
|
fn id(&self) -> &'static str;
|
||||||
|
|
||||||
|
// The provider's display name (e.g. "Discord")
|
||||||
fn label(&self) -> &'static str;
|
fn label(&self) -> &'static str;
|
||||||
|
|
||||||
|
// Whether the provider is active (defaults to true for now)
|
||||||
fn active(&self) -> bool {
|
fn active(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response;
|
|
||||||
|
|
||||||
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse>;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ pub async fn oauth_authorize_handler(
|
|||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let resp = prov.authorize().await;
|
let resp = prov.authorize(&cookie, &app_state.jwt_encoding_key).await;
|
||||||
trace!("Redirecting to provider authorization page");
|
trace!("Redirecting to provider authorization page");
|
||||||
resp
|
resp
|
||||||
}
|
}
|
||||||
@@ -80,7 +80,7 @@ pub async fn oauth_callback_handler(
|
|||||||
span!(tracing::Level::DEBUG, "oauth_callback_handler", provider = %provider, code = %code, state = %state);
|
span!(tracing::Level::DEBUG, "oauth_callback_handler", provider = %provider, code = %code, state = %state);
|
||||||
|
|
||||||
// Handle callback from provider
|
// Handle callback from provider
|
||||||
let user = match prov.handle_callback(code, state).await {
|
let user = match prov.handle_callback(code, state, &cookie, &app_state.jwt_decoding_key).await {
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(%provider, "OAuth callback handling failed");
|
warn!(%provider, "OAuth callback handling failed");
|
||||||
|
|||||||
@@ -15,6 +15,11 @@ pub struct Claims {
|
|||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
pub iat: usize,
|
pub iat: usize,
|
||||||
pub exp: usize,
|
pub exp: usize,
|
||||||
|
// PKCE flow fields - only present during OAuth flow
|
||||||
|
#[serde(rename = "ver", skip_serializing_if = "Option::is_none")]
|
||||||
|
pub pkce_verifier: Option<String>,
|
||||||
|
#[serde(rename = "st", skip_serializing_if = "Option::is_none")]
|
||||||
|
pub csrf_state: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String {
|
pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String {
|
||||||
@@ -27,12 +32,38 @@ pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &Encod
|
|||||||
name: user.name.clone(),
|
name: user.name.clone(),
|
||||||
iat: now,
|
iat: now,
|
||||||
exp: now + JWT_TTL_SECS as usize,
|
exp: now + JWT_TTL_SECS as usize,
|
||||||
|
pkce_verifier: None,
|
||||||
|
csrf_state: None,
|
||||||
};
|
};
|
||||||
let token = encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign");
|
let token = encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign");
|
||||||
trace!(sub = %claims.sub, exp = claims.exp, "Created session JWT");
|
trace!(sub = %claims.sub, exp = claims.exp, "Created session JWT");
|
||||||
token
|
token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a temporary session for PKCE flow with verifier and CSRF state
|
||||||
|
pub fn create_pkce_session(pkce_verifier: &str, csrf_state: &str, encoding_key: &EncodingKey) -> String {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time went backwards")
|
||||||
|
.as_secs() as usize;
|
||||||
|
let claims = Claims {
|
||||||
|
sub: "pkce_flow".to_string(), // Special marker for PKCE flow
|
||||||
|
name: None,
|
||||||
|
iat: now,
|
||||||
|
exp: now + JWT_TTL_SECS as usize,
|
||||||
|
pkce_verifier: Some(pkce_verifier.to_string()),
|
||||||
|
csrf_state: Some(csrf_state.to_string()),
|
||||||
|
};
|
||||||
|
let token = encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign");
|
||||||
|
trace!(csrf_state = %csrf_state, "Created PKCE session JWT");
|
||||||
|
token
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if a session is a PKCE flow session
|
||||||
|
pub fn is_pkce_session(claims: &Claims) -> bool {
|
||||||
|
claims.sub == "pkce_flow" && claims.pkce_verifier.is_some() && claims.csrf_state.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option<Claims> {
|
pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option<Claims> {
|
||||||
let mut validation = Validation::new(Algorithm::HS256);
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
validation.leeway = 30;
|
validation.leeway = 30;
|
||||||
|
|||||||
Reference in New Issue
Block a user