mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-14 14:12:26 -06:00
refactor: create common pkce handling, max_age on link cookie
This commit is contained in:
@@ -1,14 +1,14 @@
|
|||||||
use axum::{response::IntoResponse, response::Redirect};
|
use axum::{response::IntoResponse, response::Redirect};
|
||||||
use dashmap::DashMap;
|
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
||||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use tracing::{trace, warn};
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
use crate::auth::provider::{AuthUser, OAuthProvider};
|
use crate::auth::{
|
||||||
|
pkce::PkceManager,
|
||||||
|
provider::{AuthUser, OAuthProvider},
|
||||||
|
};
|
||||||
use crate::errors::ErrorResponse;
|
use crate::errors::ErrorResponse;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -43,15 +43,7 @@ pub async fn fetch_discord_user(
|
|||||||
pub struct DiscordProvider {
|
pub struct DiscordProvider {
|
||||||
pub client: super::OAuthClient,
|
pub client: super::OAuthClient,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
pkce: DashMap<String, PkceRecord>,
|
pkce: PkceManager,
|
||||||
last_purge_at_secs: AtomicU32,
|
|
||||||
pkce_additions: AtomicU32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct PkceRecord {
|
|
||||||
verifier: String,
|
|
||||||
created_at: Instant,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DiscordProvider {
|
impl DiscordProvider {
|
||||||
@@ -59,41 +51,10 @@ impl DiscordProvider {
|
|||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
client,
|
client,
|
||||||
http,
|
http,
|
||||||
pkce: DashMap::new(),
|
pkce: PkceManager::new(),
|
||||||
last_purge_at_secs: AtomicU32::new(0),
|
|
||||||
pkce_additions: AtomicU32::new(0),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn maybe_purge_stale_pkce_entries(&self) {
|
|
||||||
// Purge when at least 5 minutes passed or more than 128 additions occurred
|
|
||||||
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
|
|
||||||
const ADDITIONS_THRESHOLD: u32 = 128;
|
|
||||||
|
|
||||||
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
|
||||||
Ok(d) => d.as_secs() as u32,
|
|
||||||
Err(_) => return,
|
|
||||||
};
|
|
||||||
|
|
||||||
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
|
|
||||||
let additions = self.pkce_additions.load(Ordering::Relaxed);
|
|
||||||
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
|
||||||
let now_inst = Instant::now();
|
|
||||||
for entry in self.pkce.iter() {
|
|
||||||
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
|
|
||||||
self.pkce.remove(entry.key());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset counters after purge
|
|
||||||
self.pkce_additions.store(0, Ordering::Relaxed);
|
|
||||||
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String {
|
fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String {
|
||||||
let ext = if avatar_hash.starts_with("a_") { "gif" } else { "png" };
|
let ext = if avatar_hash.starts_with("a_") { "gif" } else { "png" };
|
||||||
format!("https://cdn.discordapp.com/avatars/{}/{}.{}", user_id, avatar_hash, ext)
|
format!("https://cdn.discordapp.com/avatars/{}/{}.{}", user_id, avatar_hash, ext)
|
||||||
@@ -110,7 +71,7 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response {
|
async fn authorize(&self) -> axum::response::Response {
|
||||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
|
||||||
let (authorize_url, csrf_state) = self
|
let (authorize_url, csrf_state) = self
|
||||||
.client
|
.client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
@@ -118,19 +79,9 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
.add_scope(Scope::new("identify".to_string()))
|
.add_scope(Scope::new("identify".to_string()))
|
||||||
.add_scope(Scope::new("email".to_string()))
|
.add_scope(Scope::new("email".to_string()))
|
||||||
.url();
|
.url();
|
||||||
|
self.pkce.store_verifier(csrf_state.secret(), verifier);
|
||||||
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
||||||
|
|
||||||
// Insert PKCE verifier with timestamp and purge when needed
|
|
||||||
self.pkce.insert(
|
|
||||||
csrf_state.secret().to_string(),
|
|
||||||
PkceRecord {
|
|
||||||
verifier: pkce_verifier.secret().to_string(),
|
|
||||||
created_at: Instant::now(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
|
|
||||||
self.maybe_purge_stale_pkce_entries();
|
|
||||||
|
|
||||||
Redirect::to(authorize_url.as_str()).into_response()
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,31 +101,22 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
.get("state")
|
.get("state")
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
||||||
let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else {
|
let Some(verifier) = self.pkce.take_verifier(&state) else {
|
||||||
warn!("Missing PKCE verifier for state parameter");
|
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
||||||
return Err(ErrorResponse::bad_request(
|
return Err(ErrorResponse::bad_request(
|
||||||
"invalid_request",
|
"invalid_request",
|
||||||
Some("missing pkce verifier for state".into()),
|
Some("missing or expired pkce verifier for state".into()),
|
||||||
));
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
// Verify PKCE TTL
|
|
||||||
if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) {
|
|
||||||
warn!("PKCE verifier expired for state parameter");
|
|
||||||
return Err(ErrorResponse::bad_request(
|
|
||||||
"invalid_request",
|
|
||||||
Some("expired pkce verifier for state".into()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let token = self
|
let token = self
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(code))
|
.exchange_code(AuthorizationCode::new(code))
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(rec.verifier))
|
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
||||||
.request_async(&self.http)
|
.request_async(&self.http)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!(error = %e, "Token exchange with Discord failed");
|
warn!(error = %e, %state, "Token exchange with Discord failed");
|
||||||
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
use axum::{response::IntoResponse, response::Redirect};
|
use axum::{response::IntoResponse, response::Redirect};
|
||||||
use dashmap::DashMap;
|
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, Scope, TokenResponse};
|
||||||
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use tracing::{trace, warn};
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::provider::{AuthUser, OAuthProvider},
|
auth::{
|
||||||
|
pkce::PkceManager,
|
||||||
|
provider::{AuthUser, OAuthProvider},
|
||||||
|
},
|
||||||
errors::ErrorResponse,
|
errors::ErrorResponse,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -56,15 +56,7 @@ pub async fn fetch_github_user(
|
|||||||
pub struct GitHubProvider {
|
pub struct GitHubProvider {
|
||||||
pub client: super::OAuthClient,
|
pub client: super::OAuthClient,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
pkce: DashMap<String, PkceRecord>,
|
pkce: PkceManager,
|
||||||
last_purge_at_secs: AtomicU32,
|
|
||||||
pkce_additions: AtomicU32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct PkceRecord {
|
|
||||||
verifier: String,
|
|
||||||
created_at: Instant,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GitHubProvider {
|
impl GitHubProvider {
|
||||||
@@ -72,9 +64,7 @@ impl GitHubProvider {
|
|||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
client,
|
client,
|
||||||
http,
|
http,
|
||||||
pkce: DashMap::new(),
|
pkce: PkceManager::new(),
|
||||||
last_purge_at_secs: AtomicU32::new(0),
|
|
||||||
pkce_additions: AtomicU32::new(0),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,7 +79,7 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response {
|
async fn authorize(&self) -> axum::response::Response {
|
||||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
let (pkce_challenge, verifier) = self.pkce.generate_challenge();
|
||||||
let (authorize_url, csrf_state) = self
|
let (authorize_url, csrf_state) = self
|
||||||
.client
|
.client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
@@ -97,17 +87,9 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
.add_scope(Scope::new("user:email".to_string()))
|
.add_scope(Scope::new("user:email".to_string()))
|
||||||
.add_scope(Scope::new("read:user".to_string()))
|
.add_scope(Scope::new("read:user".to_string()))
|
||||||
.url();
|
.url();
|
||||||
|
// store verifier keyed by the returned state
|
||||||
|
self.pkce.store_verifier(csrf_state.secret(), verifier);
|
||||||
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
||||||
// Insert PKCE verifier with timestamp and purge when needed
|
|
||||||
self.pkce.insert(
|
|
||||||
csrf_state.secret().to_string(),
|
|
||||||
PkceRecord {
|
|
||||||
verifier: pkce_verifier.secret().to_string(),
|
|
||||||
created_at: Instant::now(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
|
|
||||||
self.maybe_purge_stale_pkce_entries();
|
|
||||||
Redirect::to(authorize_url.as_str()).into_response()
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,30 +109,22 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
.get("state")
|
.get("state")
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
||||||
let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else {
|
let Some(verifier) = self.pkce.take_verifier(&state) else {
|
||||||
warn!("Missing PKCE verifier for state parameter");
|
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
||||||
return Err(ErrorResponse::bad_request(
|
return Err(ErrorResponse::bad_request(
|
||||||
"invalid_request",
|
"invalid_request",
|
||||||
Some("missing pkce verifier for state".into()),
|
Some("missing or expired pkce verifier for state".into()),
|
||||||
));
|
));
|
||||||
};
|
};
|
||||||
// Verify PKCE TTL
|
|
||||||
if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) {
|
|
||||||
warn!("PKCE verifier expired for state parameter");
|
|
||||||
return Err(ErrorResponse::bad_request(
|
|
||||||
"invalid_request",
|
|
||||||
Some("expired pkce verifier for state".into()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let token = self
|
let token = self
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(code))
|
.exchange_code(AuthorizationCode::new(code))
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(rec.verifier))
|
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
||||||
.request_async(&self.http)
|
.request_async(&self.http)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!(error = %e, "Token exchange with GitHub failed");
|
warn!(error = %e, %state, "Token exchange with GitHub failed");
|
||||||
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -177,36 +151,7 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GitHubProvider {
|
impl GitHubProvider {}
|
||||||
fn maybe_purge_stale_pkce_entries(&self) {
|
|
||||||
// Purge when at least 5 minutes passed or more than 128 additions occurred
|
|
||||||
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
|
|
||||||
const ADDITIONS_THRESHOLD: u32 = 128;
|
|
||||||
|
|
||||||
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
|
||||||
Ok(d) => d.as_secs() as u32,
|
|
||||||
Err(_) => return,
|
|
||||||
};
|
|
||||||
|
|
||||||
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
|
|
||||||
let additions = self.pkce_additions.load(Ordering::Relaxed);
|
|
||||||
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
|
||||||
let now_inst = Instant::now();
|
|
||||||
for entry in self.pkce.iter() {
|
|
||||||
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
|
|
||||||
self.pkce.remove(entry.key());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset counters after purge
|
|
||||||
self.pkce_additions.store(0, Ordering::Relaxed);
|
|
||||||
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Fetch user emails from GitHub API
|
/// Fetch user emails from GitHub API
|
||||||
pub async fn fetch_github_emails(
|
pub async fn fetch_github_emails(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use crate::config::Config;
|
|||||||
|
|
||||||
pub mod discord;
|
pub mod discord;
|
||||||
pub mod github;
|
pub mod github;
|
||||||
|
pub mod pkce;
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
|
|
||||||
type OAuthClient =
|
type OAuthClient =
|
||||||
|
|||||||
91
pacman-server/src/auth/pkce.rs
Normal file
91
pacman-server/src/auth/pkce.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
use dashmap::DashMap;
|
||||||
|
use oauth2::PkceCodeChallenge;
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PkceRecord {
|
||||||
|
pub verifier: String,
|
||||||
|
pub created_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PkceManager {
|
||||||
|
pkce: DashMap<String, PkceRecord>,
|
||||||
|
last_purge_at_secs: AtomicU32,
|
||||||
|
pkce_additions: AtomicU32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PkceManager {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
pkce: DashMap::new(),
|
||||||
|
last_purge_at_secs: AtomicU32::new(0),
|
||||||
|
pkce_additions: AtomicU32::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_challenge(&self) -> (PkceCodeChallenge, String) {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
trace!("PKCE challenge generated");
|
||||||
|
(pkce_challenge, pkce_verifier.secret().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_verifier(&self, state: &str, verifier: String) {
|
||||||
|
self.pkce.insert(
|
||||||
|
state.to_string(),
|
||||||
|
PkceRecord {
|
||||||
|
verifier,
|
||||||
|
created_at: Instant::now(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.maybe_purge_stale_entries();
|
||||||
|
trace!(state = state, "Stored PKCE verifier for state");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_verifier(&self, state: &str) -> Option<String> {
|
||||||
|
let Some(record) = self.pkce.remove(state).map(|e| e.1) else {
|
||||||
|
trace!(state = state, "PKCE verifier not found for state");
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify PKCE TTL
|
||||||
|
if Instant::now().duration_since(record.created_at) > Duration::from_secs(5 * 60) {
|
||||||
|
warn!(state = state, "PKCE verifier expired for state");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
trace!(state = state, "PKCE verifier retrieved for state");
|
||||||
|
Some(record.verifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn maybe_purge_stale_entries(&self) {
|
||||||
|
// Purge when at least 5 minutes passed or more than 128 additions occurred
|
||||||
|
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
|
||||||
|
const ADDITIONS_THRESHOLD: u32 = 128;
|
||||||
|
|
||||||
|
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
||||||
|
Ok(d) => d.as_secs() as u32,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
|
||||||
|
let additions = self.pkce_additions.load(Ordering::Relaxed);
|
||||||
|
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
||||||
|
let now_inst = Instant::now();
|
||||||
|
for entry in self.pkce.iter() {
|
||||||
|
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
|
||||||
|
self.pkce.remove(entry.key());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset counters after purge
|
||||||
|
self.pkce_additions.store(0, Ordering::Relaxed);
|
||||||
|
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,7 +23,7 @@ pub struct OAuthAccount {
|
|||||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_user_by_email(pool: &sqlx::PgPool, email: &str) -> Result<Option<User>, sqlx::Error> {
|
pub async fn find_user_by_email(pool: &sqlx::PgPool, email: &str) -> Result<Option<User>, sqlx::Error> {
|
||||||
sqlx::query_as::<_, User>(
|
sqlx::query_as::<_, User>(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, email, created_at, updated_at
|
SELECT id, email, created_at, updated_at
|
||||||
@@ -115,7 +115,7 @@ pub async fn get_oauth_account_count_for_user(pool: &sqlx::PgPool, user_id: i64)
|
|||||||
Ok(rec.0)
|
Ok(rec.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_user_by_provider_id(
|
pub async fn find_user_by_provider_id(
|
||||||
pool: &sqlx::PgPool,
|
pool: &sqlx::PgPool,
|
||||||
provider: &str,
|
provider: &str,
|
||||||
provider_user_id: &str,
|
provider_user_id: &str,
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, Query, State},
|
extract::{Path, Query, State},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
@@ -42,6 +44,7 @@ pub async fn oauth_authorize_handler(
|
|||||||
.http_only(true)
|
.http_only(true)
|
||||||
.same_site(axum_cookie::prelude::SameSite::Lax)
|
.same_site(axum_cookie::prelude::SameSite::Lax)
|
||||||
.path("/")
|
.path("/")
|
||||||
|
.max_age(std::time::Duration::from_secs(120))
|
||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -57,15 +60,19 @@ pub async fn oauth_callback_handler(
|
|||||||
Query(params): Query<AuthQuery>,
|
Query(params): Query<AuthQuery>,
|
||||||
cookie: CookieManager,
|
cookie: CookieManager,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
|
// Validate provider
|
||||||
let Some(prov) = app_state.auth.get(&provider) else {
|
let Some(prov) = app_state.auth.get(&provider) else {
|
||||||
warn!(%provider, "Unknown OAuth provider");
|
warn!(%provider, "Unknown OAuth provider");
|
||||||
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Process callback-returned errors from provider
|
||||||
if let Some(error) = params.error {
|
if let Some(error) = params.error {
|
||||||
warn!(%provider, error = %error, desc = ?params.error_description, "OAuth callback returned an error");
|
warn!(%provider, error = %error, desc = ?params.error_description, "OAuth callback returned an error");
|
||||||
return ErrorResponse::bad_request(error, params.error_description).into_response();
|
return ErrorResponse::bad_request(error, params.error_description).into_response();
|
||||||
}
|
}
|
||||||
let mut q = std::collections::HashMap::new();
|
|
||||||
|
let mut q = HashMap::new();
|
||||||
if let Some(v) = params.code {
|
if let Some(v) = params.code {
|
||||||
q.insert("code".to_string(), v);
|
q.insert("code".to_string(), v);
|
||||||
}
|
}
|
||||||
@@ -85,29 +92,24 @@ pub async fn oauth_callback_handler(
|
|||||||
cookie.remove("link");
|
cookie.remove("link");
|
||||||
}
|
}
|
||||||
let email = user.email.as_deref();
|
let email = user.email.as_deref();
|
||||||
let _db_user = if link_cookie.as_deref() == Some("1") {
|
// Determine linking intent with a valid session
|
||||||
// Must be logged in already to link
|
let is_link = if link_cookie.as_deref() == Some("1") {
|
||||||
let Some(session_token) = session::get_session_token(&cookie) else {
|
match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) {
|
||||||
return ErrorResponse::bad_request("invalid_request", Some("must be signed in to link provider".into()))
|
Some(c) => {
|
||||||
.into_response();
|
// Perform linking with current session user
|
||||||
};
|
let (cur_prov, cur_id) = c.sub.split_once(':').unwrap_or(("", ""));
|
||||||
let Some(claims) = session::decode_jwt(&session_token, &app_state.jwt_decoding_key) else {
|
let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await {
|
||||||
return ErrorResponse::bad_request("invalid_request", Some("invalid session token".into())).into_response();
|
|
||||||
};
|
|
||||||
// Resolve current user from session
|
|
||||||
let (cur_prov, cur_id) = claims.sub.split_once(':').unwrap_or(("", ""));
|
|
||||||
let current_user = match user_repo::get_user_by_provider_id(&app_state.db, cur_prov, cur_id).await {
|
|
||||||
Ok(Some(u)) => u,
|
Ok(Some(u)) => u,
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
|
warn!("Current session user not found; proceeding as normal sign-in");
|
||||||
return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into()))
|
return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into()))
|
||||||
.into_response();
|
.into_response();
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Link provider to current user
|
|
||||||
if let Err(e) = user_repo::link_oauth_account(
|
if let Err(e) = user_repo::link_oauth_account(
|
||||||
&app_state.db,
|
&app_state.db,
|
||||||
current_user.id,
|
current_user.id,
|
||||||
@@ -123,11 +125,23 @@ pub async fn oauth_callback_handler(
|
|||||||
warn!(error = %e, %provider, "Failed to link OAuth account");
|
warn!(error = %e, %provider, "Failed to link OAuth account");
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||||
}
|
}
|
||||||
current_user
|
return (StatusCode::FOUND, Redirect::to("/profile")).into_response();
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_link {
|
||||||
|
unreachable!(); // handled via early return above
|
||||||
} else {
|
} else {
|
||||||
// Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow.
|
// Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow.
|
||||||
if let Some(e) = email {
|
if let Some(e) = email {
|
||||||
if let Ok(Some(existing)) = user_repo::get_user_by_email(&app_state.db, e).await {
|
if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await {
|
||||||
// Only block if the user already has at least one linked provider.
|
// Only block if the user already has at least one linked provider.
|
||||||
// NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive,
|
// NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive,
|
||||||
// this may lock them out until the provider is reactivated or a manual admin link is performed.
|
// this may lock them out until the provider is reactivated or a manual admin link is performed.
|
||||||
@@ -198,9 +212,11 @@ pub async fn oauth_callback_handler(
|
|||||||
.into_response();
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
||||||
session::set_session_cookie(&cookie, &session_token);
|
session::set_session_cookie(&cookie, &session_token);
|
||||||
info!(%provider, "Signed in successfully");
|
info!(%provider, "Signed in successfully");
|
||||||
|
|
||||||
(StatusCode::FOUND, Redirect::to("/profile")).into_response()
|
(StatusCode::FOUND, Redirect::to("/profile")).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,7 +238,7 @@ pub async fn profile_handler(State(app_state): State<AppState>, cookie: CookieMa
|
|||||||
return ErrorResponse::unauthorized("invalid session token").into_response();
|
return ErrorResponse::unauthorized("invalid session token").into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
match user_repo::get_user_by_provider_id(&app_state.db, prov, prov_user_id).await {
|
match user_repo::find_user_by_provider_id(&app_state.db, prov, prov_user_id).await {
|
||||||
Ok(Some(db_user)) => {
|
Ok(Some(db_user)) => {
|
||||||
// Include linked providers in the profile payload
|
// Include linked providers in the profile payload
|
||||||
match user_repo::list_user_providers(&app_state.db, db_user.id).await {
|
match user_repo::list_user_providers(&app_state.db, db_user.id).await {
|
||||||
|
|||||||
Reference in New Issue
Block a user