mirror of
https://github.com/Xevion/Pac-Man.git
synced 2026-01-31 12:25:04 -06:00
fix: rewrite oauth provider linking system, add email_verified attribute for providers
This commit is contained in:
@@ -15,6 +15,7 @@ pub struct DiscordUser {
|
||||
pub username: String,
|
||||
pub global_name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub verified: Option<bool>,
|
||||
pub avatar: Option<String>,
|
||||
}
|
||||
|
||||
@@ -109,11 +110,17 @@ impl OAuthProvider for DiscordProvider {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let (email, email_verified) = match (&user.email, user.verified) {
|
||||
(Some(e), Some(true)) => (Some(e.clone()), true),
|
||||
_ => (None, false),
|
||||
};
|
||||
|
||||
Ok(AuthUser {
|
||||
id: user.id,
|
||||
username: user.username,
|
||||
name: user.global_name,
|
||||
email: user.email,
|
||||
email,
|
||||
email_verified,
|
||||
avatar_url,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -51,6 +51,28 @@ pub async fn fetch_github_user(
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// Fetch user emails from GitHub API
|
||||
pub async fn fetch_github_emails(
|
||||
http_client: &reqwest::Client,
|
||||
access_token: &str,
|
||||
) -> Result<Vec<GitHubEmail>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = http_client
|
||||
.get("https://api.github.com/user/emails")
|
||||
.header("Authorization", format!("Bearer {}", access_token))
|
||||
.header("Accept", "application/vnd.github.v3+json")
|
||||
.header("User-Agent", crate::config::USER_AGENT)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
warn!(status = %response.status(), endpoint = "/user/emails", "GitHub API returned an error");
|
||||
return Err(format!("GitHub API error: {}", response.status()).into());
|
||||
}
|
||||
|
||||
let emails: Vec<GitHubEmail> = response.json().await?;
|
||||
Ok(emails)
|
||||
}
|
||||
|
||||
pub struct GitHubProvider {
|
||||
pub client: super::OAuthClient,
|
||||
pub http: reqwest::Client,
|
||||
@@ -112,11 +134,24 @@ impl OAuthProvider for GitHubProvider {
|
||||
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e)))
|
||||
})?;
|
||||
|
||||
let emails = fetch_github_emails(&self.http, access_token).await.map_err(|e| {
|
||||
warn!(error = %e, "Failed to fetch GitHub user emails");
|
||||
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e)))
|
||||
})?;
|
||||
|
||||
let primary_email = emails.into_iter().find(|e| e.primary && e.verified);
|
||||
|
||||
let (email, email_verified) = match primary_email {
|
||||
Some(e) => (Some(e.email), true),
|
||||
None => (user.email, false),
|
||||
};
|
||||
|
||||
Ok(AuthUser {
|
||||
id: user.id.to_string(),
|
||||
username: user.login,
|
||||
name: user.name,
|
||||
email: user.email,
|
||||
email,
|
||||
email_verified,
|
||||
avatar_url: Some(user.avatar_url),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ pub struct AuthUser {
|
||||
pub name: Option<String>,
|
||||
// An email address for the user. Not always available.
|
||||
pub email: Option<String>,
|
||||
// Whether the email address has been verified by the provider.
|
||||
pub email_verified: bool,
|
||||
// An avatar URL for the user. Not always available.
|
||||
pub avatar_url: Option<String>,
|
||||
}
|
||||
|
||||
@@ -66,54 +66,18 @@ pub async fn link_oauth_account(
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create_user(
|
||||
pool: &sqlx::PgPool,
|
||||
provider_username: &str,
|
||||
provider_display_name: Option<&str>,
|
||||
provider_email: Option<&str>,
|
||||
provider_avatar_url: Option<&str>,
|
||||
provider: &str,
|
||||
provider_user_id: &str,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
let user = sqlx::query_as::<_, User>(
|
||||
pub async fn create_user(pool: &sqlx::PgPool, email: Option<&str>) -> Result<User, sqlx::Error> {
|
||||
sqlx::query_as::<_, User>(
|
||||
r#"
|
||||
INSERT INTO users (email)
|
||||
VALUES ($1)
|
||||
ON CONFLICT (email) DO UPDATE SET email = EXCLUDED.email
|
||||
RETURNING id, email, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(provider_email)
|
||||
.bind(email)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
// Create oauth link
|
||||
let _linked = link_oauth_account(
|
||||
pool,
|
||||
user.id,
|
||||
provider,
|
||||
provider_user_id,
|
||||
provider_email,
|
||||
Some(provider_username),
|
||||
provider_display_name,
|
||||
provider_avatar_url,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub async fn get_oauth_account_count_for_user(pool: &sqlx::PgPool, user_id: i64) -> Result<i64, sqlx::Error> {
|
||||
let rec: (i64,) = sqlx::query_as(
|
||||
r#"
|
||||
SELECT COUNT(*)::BIGINT AS count
|
||||
FROM oauth_accounts
|
||||
WHERE user_id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
Ok(rec.0)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_user_by_provider_id(
|
||||
|
||||
+52
-201
@@ -4,9 +4,8 @@ use axum::{
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_cookie::CookieManager;
|
||||
use jsonwebtoken::{encode, Algorithm, Header};
|
||||
use serde::Serialize;
|
||||
use tracing::{debug, debug_span, info, instrument, trace, warn};
|
||||
use tracing::{debug, debug_span, info, instrument, trace, warn, Instrument};
|
||||
|
||||
use crate::data::user as user_repo;
|
||||
use crate::{app::AppState, errors::ErrorResponse, session};
|
||||
@@ -19,11 +18,6 @@ pub struct OAuthCallbackParams {
|
||||
pub error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct AuthorizeQuery {
|
||||
pub link: Option<bool>,
|
||||
}
|
||||
|
||||
/// Handles the beginning of the OAuth authorization flow.
|
||||
///
|
||||
/// Requires the `provider` path parameter, which determines the OAuth provider to use.
|
||||
@@ -31,79 +25,20 @@ pub struct AuthorizeQuery {
|
||||
pub async fn oauth_authorize_handler(
|
||||
State(app_state): State<AppState>,
|
||||
Path(provider): Path<String>,
|
||||
Query(aq): Query<AuthorizeQuery>,
|
||||
cookie: CookieManager,
|
||||
) -> axum::response::Response {
|
||||
let Some(prov) = app_state.auth.get(&provider) else {
|
||||
warn!(%provider, "Unknown OAuth provider");
|
||||
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
||||
};
|
||||
|
||||
let is_linking = aq.link == Some(true);
|
||||
|
||||
// Persist link intent using a short-lived cookie; callbacks won't carry our query params.
|
||||
if is_linking {
|
||||
cookie.add(
|
||||
axum_cookie::cookie::Cookie::builder("link", "1")
|
||||
.http_only(true)
|
||||
.same_site(axum_cookie::prelude::SameSite::Lax)
|
||||
.path("/")
|
||||
// TODO: Pick a reasonable max age that aligns with how long OAuth providers can successfully complete the flow.
|
||||
.max_age(std::time::Duration::from_secs(60 * 60))
|
||||
.build(),
|
||||
);
|
||||
}
|
||||
trace!(linking = %is_linking, "Starting OAuth authorization");
|
||||
|
||||
// Try to acquire the existing session (PKCE session is ignored)
|
||||
let existing_session = match session::get_session_token(&cookie) {
|
||||
Some(token) => match session::decode_jwt(&token, &app_state.jwt_decoding_key) {
|
||||
Some(claims) if !session::is_pkce_session(&claims) => Some(claims),
|
||||
Some(_) => {
|
||||
debug!("existing session ignored; it is a PKCE session");
|
||||
None
|
||||
}
|
||||
None => {
|
||||
debug!("invalid session token");
|
||||
return ErrorResponse::unauthorized("invalid session token").into_response();
|
||||
}
|
||||
},
|
||||
None => {
|
||||
debug!("missing session cookie");
|
||||
return ErrorResponse::unauthorized("missing session cookie").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// If linking is enabled, error if the session doesn't exist or is a PKCE session
|
||||
if is_linking && existing_session.is_none() {
|
||||
warn!("missing session cookie during linking flow, refusing");
|
||||
return ErrorResponse::unauthorized("missing session cookie").into_response();
|
||||
}
|
||||
trace!("Starting OAuth authorization");
|
||||
|
||||
let auth_info = match prov.authorize(&app_state.jwt_encoding_key).await {
|
||||
Ok(info) => info,
|
||||
Err(e) => return e.into_response(),
|
||||
};
|
||||
|
||||
let final_token = if let Some(mut claims) = existing_session {
|
||||
// We have a user session and are linking. Merge PKCE info into it.
|
||||
if let Some(pkce_claims) = session::decode_jwt(&auth_info.session_token, &app_state.jwt_decoding_key) {
|
||||
claims.pkce_verifier = pkce_claims.pkce_verifier;
|
||||
claims.csrf_state = pkce_claims.csrf_state;
|
||||
|
||||
// re-encode
|
||||
encode(&Header::new(Algorithm::HS256), &claims, &app_state.jwt_encoding_key).expect("jwt sign")
|
||||
} else {
|
||||
warn!("Failed to decode PKCE session token during linking flow");
|
||||
// Fallback to just using the PKCE token, which will break linking but not panic.
|
||||
auth_info.session_token
|
||||
}
|
||||
} else {
|
||||
// Not linking or no existing session, just use the new token.
|
||||
auth_info.session_token
|
||||
};
|
||||
|
||||
session::set_session_cookie(&cookie, &final_token);
|
||||
session::set_session_cookie(&cookie, &auth_info.session_token);
|
||||
trace!("Redirecting to provider authorization page");
|
||||
Redirect::to(auth_info.authorize_url.as_str()).into_response()
|
||||
}
|
||||
@@ -149,146 +84,62 @@ pub async fn oauth_callback_handler(
|
||||
}
|
||||
};
|
||||
|
||||
debug!(cookies = ?cookie.cookie().iter().collect::<Vec<_>>(), "Cookies");
|
||||
|
||||
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
|
||||
let link_cookie = cookie.get("link").map(|c| c.value().to_string());
|
||||
if link_cookie.is_some() {
|
||||
cookie.remove("link");
|
||||
}
|
||||
let email = user.email.as_deref();
|
||||
|
||||
// Determine linking intent with a valid session
|
||||
if link_cookie.as_deref() == Some("1") {
|
||||
debug!("Link intent present");
|
||||
|
||||
if let Some(claims) =
|
||||
session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key))
|
||||
{
|
||||
// Perform linking with current session user
|
||||
let (cur_prov, cur_id) = claims.subject.split_once(':').unwrap_or(("", ""));
|
||||
let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await {
|
||||
Ok(Some(u)) => u,
|
||||
Ok(None) => {
|
||||
warn!("Current session user not found; proceeding as normal sign-in");
|
||||
return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into()))
|
||||
.into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||
}
|
||||
};
|
||||
if let Err(e) = user_repo::link_oauth_account(
|
||||
&app_state.db,
|
||||
current_user.id,
|
||||
&provider,
|
||||
&user.id,
|
||||
email,
|
||||
Some(&user.username),
|
||||
user.name.as_deref(),
|
||||
user.avatar_url.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(error = %e, %provider, "Failed to link OAuth account");
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||
}
|
||||
return (StatusCode::FOUND, Redirect::to("/profile")).into_response();
|
||||
} else {
|
||||
warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in");
|
||||
// --- Simplified Sign-in / Sign-up Flow ---
|
||||
let linking_span = debug_span!("account_linking", provider_user_id = %user.id, provider_email = ?user.email, email_verified = %user.email_verified);
|
||||
let db_user_result: Result<user_repo::User, sqlx::Error> = async {
|
||||
// 1. Check if we already have this specific provider account linked
|
||||
if let Some(user) = user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await? {
|
||||
debug!(user_id = %user.id, "Found existing user by provider ID");
|
||||
return Ok(user);
|
||||
}
|
||||
}
|
||||
|
||||
// Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow.
|
||||
if let Some(e) = email {
|
||||
if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await {
|
||||
// Only block if the user already has at least one linked provider.
|
||||
// NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive,
|
||||
// this may lock them out until the provider is reactivated or a manual admin link is performed.
|
||||
match user_repo::get_oauth_account_count_for_user(&app_state.db, existing.id).await {
|
||||
Ok(count) if count > 0 => {
|
||||
// Check if the "new" provider is already linked to the user
|
||||
match user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await {
|
||||
Ok(Some(_)) => {
|
||||
debug!(
|
||||
%provider,
|
||||
%existing.id,
|
||||
"Provider already linked to user, signing in normally");
|
||||
}
|
||||
Ok(None) => {
|
||||
debug!(
|
||||
%provider,
|
||||
%existing.id,
|
||||
"Provider not linked to user, failing"
|
||||
);
|
||||
return ErrorResponse::bad_request(
|
||||
"account_exists",
|
||||
Some(format!(
|
||||
"An account already exists for {}. Sign in with your existing provider, then visit /auth/{}?link=true to add this provider.",
|
||||
e, provider
|
||||
)),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, %provider, "Failed to find user by provider ID");
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_) => {
|
||||
// No providers linked yet: safe to associate this provider
|
||||
if let Err(e) = user_repo::link_oauth_account(
|
||||
&app_state.db,
|
||||
existing.id,
|
||||
&provider,
|
||||
&user.id,
|
||||
email,
|
||||
Some(&user.username),
|
||||
user.name.as_deref(),
|
||||
user.avatar_url.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(error = %e, %provider, "Failed to link OAuth account to existing user with no providers");
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to count oauth accounts for user");
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||
// 2. If not, try to find an existing user by verified email to link to
|
||||
let user_to_link = if user.email_verified {
|
||||
if let Some(email) = user.email.as_deref() {
|
||||
// Try to find a user with this email
|
||||
if let Some(existing_user) = user_repo::find_user_by_email(&app_state.db, email).await? {
|
||||
debug!(user_id = %existing_user.id, "Found existing user by email, linking new provider");
|
||||
existing_user
|
||||
} else {
|
||||
// No user with this email, create a new one
|
||||
debug!("No user found by email, creating a new one");
|
||||
user_repo::create_user(&app_state.db, Some(email)).await?
|
||||
}
|
||||
} else {
|
||||
// Verified, but no email for some reason. Create a user without an email.
|
||||
user_repo::create_user(&app_state.db, None).await?
|
||||
}
|
||||
} else {
|
||||
// Create new user with email
|
||||
match user_repo::create_user(
|
||||
&app_state.db,
|
||||
&user.username,
|
||||
user.name.as_deref(),
|
||||
email,
|
||||
user.avatar_url.as_deref(),
|
||||
&provider,
|
||||
&user.id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
warn!(error = %e, %provider, "Failed to create user");
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||
}
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// No email available: disallow sign-in for safety
|
||||
return ErrorResponse::bad_request(
|
||||
"invalid_request",
|
||||
Some("account has no email; sign in with a different provider".into()),
|
||||
// No verified email, so we must create a new user without an email.
|
||||
debug!("No verified email, creating a new user");
|
||||
user_repo::create_user(&app_state.db, None).await?
|
||||
};
|
||||
|
||||
// 3. Link the new provider account to our user record (whether old or new)
|
||||
user_repo::link_oauth_account(
|
||||
&app_state.db,
|
||||
user_to_link.id,
|
||||
&provider,
|
||||
&user.id,
|
||||
user.email.as_deref(),
|
||||
Some(&user.username),
|
||||
user.name.as_deref(),
|
||||
user.avatar_url.as_deref(),
|
||||
)
|
||||
.into_response();
|
||||
.await?;
|
||||
|
||||
Ok(user_to_link)
|
||||
}
|
||||
.instrument(linking_span)
|
||||
.await;
|
||||
|
||||
let _: user_repo::User = match db_user_result {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
warn!(error = %(&e as &dyn std::error::Error), "Failed to process user linking/creation");
|
||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Create session token
|
||||
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
||||
|
||||
Reference in New Issue
Block a user