mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-06 05:15:49 -06:00
feat: discord oauth provider, setup provider list route, add 'active' method, common type alias
This commit is contained in:
201
pacman-server/src/auth/discord.rs
Normal file
201
pacman-server/src/auth/discord.rs
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
use axum::{response::IntoResponse, response::Redirect};
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
|
use crate::auth::provider::{AuthUser, OAuthProvider};
|
||||||
|
use crate::errors::ErrorResponse;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DiscordUser {
|
||||||
|
pub id: String,
|
||||||
|
pub username: String,
|
||||||
|
pub global_name: Option<String>,
|
||||||
|
pub email: Option<String>,
|
||||||
|
pub avatar: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn fetch_discord_user(
|
||||||
|
http_client: &reqwest::Client,
|
||||||
|
access_token: &str,
|
||||||
|
) -> Result<DiscordUser, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let response = http_client
|
||||||
|
.get("https://discord.com/api/users/@me")
|
||||||
|
.header("Authorization", format!("Bearer {}", access_token))
|
||||||
|
.header("User-Agent", crate::config::USER_AGENT)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
warn!(status = %response.status(), endpoint = "/users/@me", "Discord API returned an error");
|
||||||
|
return Err(format!("Discord API error: {}", response.status()).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let user: DiscordUser = response.json().await?;
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DiscordProvider {
|
||||||
|
pub client: super::OAuthClient,
|
||||||
|
pub http: reqwest::Client,
|
||||||
|
pkce: DashMap<String, PkceRecord>,
|
||||||
|
last_purge_at_secs: AtomicU32,
|
||||||
|
pkce_additions: AtomicU32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct PkceRecord {
|
||||||
|
verifier: String,
|
||||||
|
created_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiscordProvider {
|
||||||
|
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
client,
|
||||||
|
http,
|
||||||
|
pkce: DashMap::new(),
|
||||||
|
last_purge_at_secs: AtomicU32::new(0),
|
||||||
|
pkce_additions: AtomicU32::new(0),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn maybe_purge_stale_pkce_entries(&self) {
|
||||||
|
// Purge when at least 5 minutes passed or more than 128 additions occurred
|
||||||
|
const PURGE_INTERVAL_SECS: u32 = 5 * 60;
|
||||||
|
const ADDITIONS_THRESHOLD: u32 = 128;
|
||||||
|
|
||||||
|
let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
||||||
|
Ok(d) => d.as_secs() as u32,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
let last = self.last_purge_at_secs.load(Ordering::Relaxed);
|
||||||
|
let additions = self.pkce_additions.load(Ordering::Relaxed);
|
||||||
|
if additions < ADDITIONS_THRESHOLD && now_secs.saturating_sub(last) < PURGE_INTERVAL_SECS {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
||||||
|
let now_inst = Instant::now();
|
||||||
|
for entry in self.pkce.iter() {
|
||||||
|
if now_inst.duration_since(entry.value().created_at) > PKCE_TTL {
|
||||||
|
self.pkce.remove(entry.key());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset counters after purge
|
||||||
|
self.pkce_additions.store(0, Ordering::Relaxed);
|
||||||
|
self.last_purge_at_secs.store(now_secs, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn avatar_url_for(user_id: &str, avatar_hash: &str) -> String {
|
||||||
|
let ext = if avatar_hash.starts_with("a_") { "gif" } else { "png" };
|
||||||
|
format!("https://cdn.discordapp.com/avatars/{}/{}.{}", user_id, avatar_hash, ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl OAuthProvider for DiscordProvider {
|
||||||
|
fn id(&self) -> &'static str {
|
||||||
|
"discord"
|
||||||
|
}
|
||||||
|
fn label(&self) -> &'static str {
|
||||||
|
"Discord"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn authorize(&self) -> axum::response::Response {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
let (authorize_url, csrf_state) = self
|
||||||
|
.client
|
||||||
|
.authorize_url(CsrfToken::new_random)
|
||||||
|
.set_pkce_challenge(pkce_challenge)
|
||||||
|
.add_scope(Scope::new("identify".to_string()))
|
||||||
|
.add_scope(Scope::new("email".to_string()))
|
||||||
|
.url();
|
||||||
|
trace!(state = %csrf_state.secret(), "Generated OAuth authorization URL");
|
||||||
|
|
||||||
|
// Insert PKCE verifier with timestamp and purge when needed
|
||||||
|
self.pkce.insert(
|
||||||
|
csrf_state.secret().to_string(),
|
||||||
|
PkceRecord {
|
||||||
|
verifier: pkce_verifier.secret().to_string(),
|
||||||
|
created_at: Instant::now(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
self.pkce_additions.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.maybe_purge_stale_pkce_entries();
|
||||||
|
|
||||||
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_callback(&self, query: &std::collections::HashMap<String, String>) -> Result<AuthUser, ErrorResponse> {
|
||||||
|
if let Some(err) = query.get("error") {
|
||||||
|
warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
err.clone(),
|
||||||
|
query.get("error_description").cloned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let code = query
|
||||||
|
.get("code")
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?;
|
||||||
|
let state = query
|
||||||
|
.get("state")
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
||||||
|
let Some(rec) = self.pkce.remove(&state).map(|e| e.1) else {
|
||||||
|
warn!("Missing PKCE verifier for state parameter");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("missing pkce verifier for state".into()),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify PKCE TTL
|
||||||
|
if Instant::now().duration_since(rec.created_at) > Duration::from_secs(5 * 60) {
|
||||||
|
warn!("PKCE verifier expired for state parameter");
|
||||||
|
return Err(ErrorResponse::bad_request(
|
||||||
|
"invalid_request",
|
||||||
|
Some("expired pkce verifier for state".into()),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let token = self
|
||||||
|
.client
|
||||||
|
.exchange_code(AuthorizationCode::new(code))
|
||||||
|
.set_pkce_verifier(PkceCodeVerifier::new(rec.verifier))
|
||||||
|
.request_async(&self.http)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
warn!(error = %e, "Token exchange with Discord failed");
|
||||||
|
ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string()))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let user = fetch_discord_user(&self.http, token.access_token().secret())
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
warn!(error = %e, "Failed to fetch Discord user profile");
|
||||||
|
ErrorResponse::bad_gateway("discord_api_error", Some(format!("failed to fetch user: {}", e)))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let avatar_url = match (&user.id, &user.avatar) {
|
||||||
|
(id, Some(hash)) => Some(Self::avatar_url_for(id, hash)),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(AuthUser {
|
||||||
|
id: user.id,
|
||||||
|
username: user.username,
|
||||||
|
name: user.global_name,
|
||||||
|
email: user.email,
|
||||||
|
avatar_url,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
use axum::{response::IntoResponse, response::Redirect};
|
use axum::{response::IntoResponse, response::Redirect};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use oauth2::{basic::BasicClient, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
@@ -13,10 +13,6 @@ use crate::{
|
|||||||
errors::ErrorResponse,
|
errors::ErrorResponse,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Private type alias for the OAuth2 BasicClient specialized type
|
|
||||||
type OAuthClient =
|
|
||||||
BasicClient<oauth2::EndpointSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointSet>;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct GitHubUser {
|
pub struct GitHubUser {
|
||||||
pub id: u64,
|
pub id: u64,
|
||||||
@@ -58,7 +54,7 @@ pub async fn fetch_github_user(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct GitHubProvider {
|
pub struct GitHubProvider {
|
||||||
pub client: OAuthClient,
|
pub client: super::OAuthClient,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
pkce: DashMap<String, PkceRecord>,
|
pkce: DashMap<String, PkceRecord>,
|
||||||
last_purge_at_secs: AtomicU32,
|
last_purge_at_secs: AtomicU32,
|
||||||
@@ -72,7 +68,7 @@ struct PkceRecord {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl GitHubProvider {
|
impl GitHubProvider {
|
||||||
pub fn new(client: OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
pub fn new(client: super::OAuthClient, http: reqwest::Client) -> Arc<Self> {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
client,
|
client,
|
||||||
http,
|
http,
|
||||||
|
|||||||
@@ -5,9 +5,13 @@ use oauth2::{basic::BasicClient, EndpointNotSet, EndpointSet};
|
|||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|
||||||
|
pub mod discord;
|
||||||
pub mod github;
|
pub mod github;
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
|
|
||||||
|
type OAuthClient =
|
||||||
|
BasicClient<oauth2::EndpointSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointNotSet, oauth2::EndpointSet>;
|
||||||
|
|
||||||
pub struct AuthRegistry {
|
pub struct AuthRegistry {
|
||||||
providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>>,
|
providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>>,
|
||||||
}
|
}
|
||||||
@@ -32,7 +36,19 @@ impl AuthRegistry {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let mut providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>> = HashMap::new();
|
let mut providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>> = HashMap::new();
|
||||||
providers.insert("github", github::GitHubProvider::new(github_client, http));
|
providers.insert("github", github::GitHubProvider::new(github_client, http.clone()));
|
||||||
|
|
||||||
|
// Discord OAuth client
|
||||||
|
let discord_client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet> =
|
||||||
|
BasicClient::new(oauth2::ClientId::new(config.discord_client_id.clone()))
|
||||||
|
.set_client_secret(oauth2::ClientSecret::new(config.discord_client_secret.clone()))
|
||||||
|
.set_auth_uri(oauth2::AuthUrl::new("https://discord.com/api/oauth2/authorize".to_string())?)
|
||||||
|
.set_token_uri(oauth2::TokenUrl::new("https://discord.com/api/oauth2/token".to_string())?)
|
||||||
|
.set_redirect_uri(
|
||||||
|
oauth2::RedirectUrl::new(format!("{}/auth/discord/callback", config.public_base_url))
|
||||||
|
.expect("Invalid redirect URI"),
|
||||||
|
);
|
||||||
|
providers.insert("discord", discord::DiscordProvider::new(discord_client, http));
|
||||||
|
|
||||||
Ok(Self { providers })
|
Ok(Self { providers })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ pub struct AuthUser {
|
|||||||
pub trait OAuthProvider: Send + Sync {
|
pub trait OAuthProvider: Send + Sync {
|
||||||
fn id(&self) -> &'static str;
|
fn id(&self) -> &'static str;
|
||||||
fn label(&self) -> &'static str;
|
fn label(&self) -> &'static str;
|
||||||
|
fn active(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response;
|
async fn authorize(&self) -> axum::response::Response;
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ async fn main() {
|
|||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." }))
|
.route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." }))
|
||||||
|
.route("/auth/providers", get(routes::list_providers_handler))
|
||||||
.route("/auth/{provider}", get(routes::oauth_authorize_handler))
|
.route("/auth/{provider}", get(routes::oauth_authorize_handler))
|
||||||
.route("/auth/{provider}/callback", get(routes::oauth_callback_handler))
|
.route("/auth/{provider}/callback", get(routes::oauth_callback_handler))
|
||||||
.route("/logout", get(routes::logout_handler))
|
.route("/logout", get(routes::logout_handler))
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use axum::{
|
|||||||
response::{IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use axum_cookie::CookieManager;
|
use axum_cookie::CookieManager;
|
||||||
|
use serde::Serialize;
|
||||||
use tracing::{debug, info, trace, warn};
|
use tracing::{debug, info, trace, warn};
|
||||||
|
|
||||||
use crate::{app::AppState, errors::ErrorResponse, session};
|
use crate::{app::AppState, errors::ErrorResponse, session};
|
||||||
@@ -91,3 +92,21 @@ pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieMan
|
|||||||
info!("Signed out successfully");
|
info!("Signed out successfully");
|
||||||
(StatusCode::FOUND, Redirect::to("/")).into_response()
|
(StatusCode::FOUND, Redirect::to("/")).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ProviderInfo {
|
||||||
|
provider: &'static str,
|
||||||
|
active: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_providers_handler(State(app_state): State<AppState>) -> axum::response::Response {
|
||||||
|
let providers: Vec<ProviderInfo> = app_state
|
||||||
|
.auth
|
||||||
|
.iter()
|
||||||
|
.map(|(id, provider)| ProviderInfo {
|
||||||
|
provider: id,
|
||||||
|
active: provider.active(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
axum::Json(providers).into_response()
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user