mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-05 23:15:40 -06:00
fix: rewrite oauth provider linking system, add email_verified attribute for providers
This commit is contained in:
@@ -16,6 +16,10 @@ publish = false
|
|||||||
[profile.dev]
|
[profile.dev]
|
||||||
incremental = true
|
incremental = true
|
||||||
|
|
||||||
|
# Improve build times by optimizing sqlx-macros
|
||||||
|
[profile.dev.package.sqlx-macros]
|
||||||
|
opt-level = 3
|
||||||
|
|
||||||
# Release profile for profiling (essentially the default 'release' profile with debug enabled)
|
# Release profile for profiling (essentially the default 'release' profile with debug enabled)
|
||||||
[profile.profile]
|
[profile.profile]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
|||||||
4
build.rs
Normal file
4
build.rs
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
fn main() {
|
||||||
|
// trigger recompilation when a new migration is added
|
||||||
|
println!("cargo:rerun-if-changed=migrations");
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ pub struct DiscordUser {
|
|||||||
pub username: String,
|
pub username: String,
|
||||||
pub global_name: Option<String>,
|
pub global_name: Option<String>,
|
||||||
pub email: Option<String>,
|
pub email: Option<String>,
|
||||||
|
pub verified: Option<bool>,
|
||||||
pub avatar: Option<String>,
|
pub avatar: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,11 +110,17 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let (email, email_verified) = match (&user.email, user.verified) {
|
||||||
|
(Some(e), Some(true)) => (Some(e.clone()), true),
|
||||||
|
_ => (None, false),
|
||||||
|
};
|
||||||
|
|
||||||
Ok(AuthUser {
|
Ok(AuthUser {
|
||||||
id: user.id,
|
id: user.id,
|
||||||
username: user.username,
|
username: user.username,
|
||||||
name: user.global_name,
|
name: user.global_name,
|
||||||
email: user.email,
|
email,
|
||||||
|
email_verified,
|
||||||
avatar_url,
|
avatar_url,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,6 +51,28 @@ pub async fn fetch_github_user(
|
|||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetch user emails from GitHub API
|
||||||
|
pub async fn fetch_github_emails(
|
||||||
|
http_client: &reqwest::Client,
|
||||||
|
access_token: &str,
|
||||||
|
) -> Result<Vec<GitHubEmail>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let response = http_client
|
||||||
|
.get("https://api.github.com/user/emails")
|
||||||
|
.header("Authorization", format!("Bearer {}", access_token))
|
||||||
|
.header("Accept", "application/vnd.github.v3+json")
|
||||||
|
.header("User-Agent", crate::config::USER_AGENT)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
warn!(status = %response.status(), endpoint = "/user/emails", "GitHub API returned an error");
|
||||||
|
return Err(format!("GitHub API error: {}", response.status()).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let emails: Vec<GitHubEmail> = response.json().await?;
|
||||||
|
Ok(emails)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct GitHubProvider {
|
pub struct GitHubProvider {
|
||||||
pub client: super::OAuthClient,
|
pub client: super::OAuthClient,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
@@ -112,11 +134,24 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e)))
|
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e)))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let emails = fetch_github_emails(&self.http, access_token).await.map_err(|e| {
|
||||||
|
warn!(error = %e, "Failed to fetch GitHub user emails");
|
||||||
|
ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e)))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let primary_email = emails.into_iter().find(|e| e.primary && e.verified);
|
||||||
|
|
||||||
|
let (email, email_verified) = match primary_email {
|
||||||
|
Some(e) => (Some(e.email), true),
|
||||||
|
None => (user.email, false),
|
||||||
|
};
|
||||||
|
|
||||||
Ok(AuthUser {
|
Ok(AuthUser {
|
||||||
id: user.id.to_string(),
|
id: user.id.to_string(),
|
||||||
username: user.login,
|
username: user.login,
|
||||||
name: user.name,
|
name: user.name,
|
||||||
email: user.email,
|
email,
|
||||||
|
email_verified,
|
||||||
avatar_url: Some(user.avatar_url),
|
avatar_url: Some(user.avatar_url),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ pub struct AuthUser {
|
|||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
// An email address for the user. Not always available.
|
// An email address for the user. Not always available.
|
||||||
pub email: Option<String>,
|
pub email: Option<String>,
|
||||||
|
// Whether the email address has been verified by the provider.
|
||||||
|
pub email_verified: bool,
|
||||||
// An avatar URL for the user. Not always available.
|
// An avatar URL for the user. Not always available.
|
||||||
pub avatar_url: Option<String>,
|
pub avatar_url: Option<String>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,54 +66,18 @@ pub async fn link_oauth_account(
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_user(
|
pub async fn create_user(pool: &sqlx::PgPool, email: Option<&str>) -> Result<User, sqlx::Error> {
|
||||||
pool: &sqlx::PgPool,
|
sqlx::query_as::<_, User>(
|
||||||
provider_username: &str,
|
|
||||||
provider_display_name: Option<&str>,
|
|
||||||
provider_email: Option<&str>,
|
|
||||||
provider_avatar_url: Option<&str>,
|
|
||||||
provider: &str,
|
|
||||||
provider_user_id: &str,
|
|
||||||
) -> Result<User, sqlx::Error> {
|
|
||||||
let user = sqlx::query_as::<_, User>(
|
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO users (email)
|
INSERT INTO users (email)
|
||||||
VALUES ($1)
|
VALUES ($1)
|
||||||
|
ON CONFLICT (email) DO UPDATE SET email = EXCLUDED.email
|
||||||
RETURNING id, email, created_at, updated_at
|
RETURNING id, email, created_at, updated_at
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(provider_email)
|
.bind(email)
|
||||||
.fetch_one(pool)
|
.fetch_one(pool)
|
||||||
.await?;
|
.await
|
||||||
|
|
||||||
// Create oauth link
|
|
||||||
let _linked = link_oauth_account(
|
|
||||||
pool,
|
|
||||||
user.id,
|
|
||||||
provider,
|
|
||||||
provider_user_id,
|
|
||||||
provider_email,
|
|
||||||
Some(provider_username),
|
|
||||||
provider_display_name,
|
|
||||||
provider_avatar_url,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(user)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_oauth_account_count_for_user(pool: &sqlx::PgPool, user_id: i64) -> Result<i64, sqlx::Error> {
|
|
||||||
let rec: (i64,) = sqlx::query_as(
|
|
||||||
r#"
|
|
||||||
SELECT COUNT(*)::BIGINT AS count
|
|
||||||
FROM oauth_accounts
|
|
||||||
WHERE user_id = $1
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(user_id)
|
|
||||||
.fetch_one(pool)
|
|
||||||
.await?;
|
|
||||||
Ok(rec.0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn find_user_by_provider_id(
|
pub async fn find_user_by_provider_id(
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ use axum::{
|
|||||||
response::{IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use axum_cookie::CookieManager;
|
use axum_cookie::CookieManager;
|
||||||
use jsonwebtoken::{encode, Algorithm, Header};
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tracing::{debug, debug_span, info, instrument, trace, warn};
|
use tracing::{debug, debug_span, info, instrument, trace, warn, Instrument};
|
||||||
|
|
||||||
use crate::data::user as user_repo;
|
use crate::data::user as user_repo;
|
||||||
use crate::{app::AppState, errors::ErrorResponse, session};
|
use crate::{app::AppState, errors::ErrorResponse, session};
|
||||||
@@ -19,11 +18,6 @@ pub struct OAuthCallbackParams {
|
|||||||
pub error_description: Option<String>,
|
pub error_description: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
|
||||||
pub struct AuthorizeQuery {
|
|
||||||
pub link: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handles the beginning of the OAuth authorization flow.
|
/// Handles the beginning of the OAuth authorization flow.
|
||||||
///
|
///
|
||||||
/// Requires the `provider` path parameter, which determines the OAuth provider to use.
|
/// Requires the `provider` path parameter, which determines the OAuth provider to use.
|
||||||
@@ -31,79 +25,20 @@ pub struct AuthorizeQuery {
|
|||||||
pub async fn oauth_authorize_handler(
|
pub async fn oauth_authorize_handler(
|
||||||
State(app_state): State<AppState>,
|
State(app_state): State<AppState>,
|
||||||
Path(provider): Path<String>,
|
Path(provider): Path<String>,
|
||||||
Query(aq): Query<AuthorizeQuery>,
|
|
||||||
cookie: CookieManager,
|
cookie: CookieManager,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
let Some(prov) = app_state.auth.get(&provider) else {
|
let Some(prov) = app_state.auth.get(&provider) else {
|
||||||
warn!(%provider, "Unknown OAuth provider");
|
warn!(%provider, "Unknown OAuth provider");
|
||||||
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
||||||
};
|
};
|
||||||
|
trace!("Starting OAuth authorization");
|
||||||
let is_linking = aq.link == Some(true);
|
|
||||||
|
|
||||||
// Persist link intent using a short-lived cookie; callbacks won't carry our query params.
|
|
||||||
if is_linking {
|
|
||||||
cookie.add(
|
|
||||||
axum_cookie::cookie::Cookie::builder("link", "1")
|
|
||||||
.http_only(true)
|
|
||||||
.same_site(axum_cookie::prelude::SameSite::Lax)
|
|
||||||
.path("/")
|
|
||||||
// TODO: Pick a reasonable max age that aligns with how long OAuth providers can successfully complete the flow.
|
|
||||||
.max_age(std::time::Duration::from_secs(60 * 60))
|
|
||||||
.build(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
trace!(linking = %is_linking, "Starting OAuth authorization");
|
|
||||||
|
|
||||||
// Try to acquire the existing session (PKCE session is ignored)
|
|
||||||
let existing_session = match session::get_session_token(&cookie) {
|
|
||||||
Some(token) => match session::decode_jwt(&token, &app_state.jwt_decoding_key) {
|
|
||||||
Some(claims) if !session::is_pkce_session(&claims) => Some(claims),
|
|
||||||
Some(_) => {
|
|
||||||
debug!("existing session ignored; it is a PKCE session");
|
|
||||||
None
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
debug!("invalid session token");
|
|
||||||
return ErrorResponse::unauthorized("invalid session token").into_response();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => {
|
|
||||||
debug!("missing session cookie");
|
|
||||||
return ErrorResponse::unauthorized("missing session cookie").into_response();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// If linking is enabled, error if the session doesn't exist or is a PKCE session
|
|
||||||
if is_linking && existing_session.is_none() {
|
|
||||||
warn!("missing session cookie during linking flow, refusing");
|
|
||||||
return ErrorResponse::unauthorized("missing session cookie").into_response();
|
|
||||||
}
|
|
||||||
|
|
||||||
let auth_info = match prov.authorize(&app_state.jwt_encoding_key).await {
|
let auth_info = match prov.authorize(&app_state.jwt_encoding_key).await {
|
||||||
Ok(info) => info,
|
Ok(info) => info,
|
||||||
Err(e) => return e.into_response(),
|
Err(e) => return e.into_response(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let final_token = if let Some(mut claims) = existing_session {
|
session::set_session_cookie(&cookie, &auth_info.session_token);
|
||||||
// We have a user session and are linking. Merge PKCE info into it.
|
|
||||||
if let Some(pkce_claims) = session::decode_jwt(&auth_info.session_token, &app_state.jwt_decoding_key) {
|
|
||||||
claims.pkce_verifier = pkce_claims.pkce_verifier;
|
|
||||||
claims.csrf_state = pkce_claims.csrf_state;
|
|
||||||
|
|
||||||
// re-encode
|
|
||||||
encode(&Header::new(Algorithm::HS256), &claims, &app_state.jwt_encoding_key).expect("jwt sign")
|
|
||||||
} else {
|
|
||||||
warn!("Failed to decode PKCE session token during linking flow");
|
|
||||||
// Fallback to just using the PKCE token, which will break linking but not panic.
|
|
||||||
auth_info.session_token
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Not linking or no existing session, just use the new token.
|
|
||||||
auth_info.session_token
|
|
||||||
};
|
|
||||||
|
|
||||||
session::set_session_cookie(&cookie, &final_token);
|
|
||||||
trace!("Redirecting to provider authorization page");
|
trace!("Redirecting to provider authorization page");
|
||||||
Redirect::to(auth_info.authorize_url.as_str()).into_response()
|
Redirect::to(auth_info.authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
@@ -149,146 +84,62 @@ pub async fn oauth_callback_handler(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!(cookies = ?cookie.cookie().iter().collect::<Vec<_>>(), "Cookies");
|
// --- Simplified Sign-in / Sign-up Flow ---
|
||||||
|
let linking_span = debug_span!("account_linking", provider_user_id = %user.id, provider_email = ?user.email, email_verified = %user.email_verified);
|
||||||
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
|
let db_user_result: Result<user_repo::User, sqlx::Error> = async {
|
||||||
let link_cookie = cookie.get("link").map(|c| c.value().to_string());
|
// 1. Check if we already have this specific provider account linked
|
||||||
if link_cookie.is_some() {
|
if let Some(user) = user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await? {
|
||||||
cookie.remove("link");
|
debug!(user_id = %user.id, "Found existing user by provider ID");
|
||||||
}
|
return Ok(user);
|
||||||
let email = user.email.as_deref();
|
|
||||||
|
|
||||||
// Determine linking intent with a valid session
|
|
||||||
if link_cookie.as_deref() == Some("1") {
|
|
||||||
debug!("Link intent present");
|
|
||||||
|
|
||||||
if let Some(claims) =
|
|
||||||
session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key))
|
|
||||||
{
|
|
||||||
// Perform linking with current session user
|
|
||||||
let (cur_prov, cur_id) = claims.subject.split_once(':').unwrap_or(("", ""));
|
|
||||||
let current_user = match user_repo::find_user_by_provider_id(&app_state.db, cur_prov, cur_id).await {
|
|
||||||
Ok(Some(u)) => u,
|
|
||||||
Ok(None) => {
|
|
||||||
warn!("Current session user not found; proceeding as normal sign-in");
|
|
||||||
return ErrorResponse::bad_request("invalid_request", Some("current session user not found".into()))
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if let Err(e) = user_repo::link_oauth_account(
|
|
||||||
&app_state.db,
|
|
||||||
current_user.id,
|
|
||||||
&provider,
|
|
||||||
&user.id,
|
|
||||||
email,
|
|
||||||
Some(&user.username),
|
|
||||||
user.name.as_deref(),
|
|
||||||
user.avatar_url.as_deref(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!(error = %e, %provider, "Failed to link OAuth account");
|
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
|
||||||
}
|
|
||||||
return (StatusCode::FOUND, Redirect::to("/profile")).into_response();
|
|
||||||
} else {
|
|
||||||
warn!(%provider, "Link intent present but session missing/invalid; proceeding as normal sign-in");
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Normal sign-in: do NOT auto-link by email (security). If email exists, require linking flow.
|
// 2. If not, try to find an existing user by verified email to link to
|
||||||
if let Some(e) = email {
|
let user_to_link = if user.email_verified {
|
||||||
if let Ok(Some(existing)) = user_repo::find_user_by_email(&app_state.db, e).await {
|
if let Some(email) = user.email.as_deref() {
|
||||||
// Only block if the user already has at least one linked provider.
|
// Try to find a user with this email
|
||||||
// NOTE: We do not check whether providers are currently active. If a user has exactly one provider and it is inactive,
|
if let Some(existing_user) = user_repo::find_user_by_email(&app_state.db, email).await? {
|
||||||
// this may lock them out until the provider is reactivated or a manual admin link is performed.
|
debug!(user_id = %existing_user.id, "Found existing user by email, linking new provider");
|
||||||
match user_repo::get_oauth_account_count_for_user(&app_state.db, existing.id).await {
|
existing_user
|
||||||
Ok(count) if count > 0 => {
|
} else {
|
||||||
// Check if the "new" provider is already linked to the user
|
// No user with this email, create a new one
|
||||||
match user_repo::find_user_by_provider_id(&app_state.db, &provider, &user.id).await {
|
debug!("No user found by email, creating a new one");
|
||||||
Ok(Some(_)) => {
|
user_repo::create_user(&app_state.db, Some(email)).await?
|
||||||
debug!(
|
|
||||||
%provider,
|
|
||||||
%existing.id,
|
|
||||||
"Provider already linked to user, signing in normally");
|
|
||||||
}
|
|
||||||
Ok(None) => {
|
|
||||||
debug!(
|
|
||||||
%provider,
|
|
||||||
%existing.id,
|
|
||||||
"Provider not linked to user, failing"
|
|
||||||
);
|
|
||||||
return ErrorResponse::bad_request(
|
|
||||||
"account_exists",
|
|
||||||
Some(format!(
|
|
||||||
"An account already exists for {}. Sign in with your existing provider, then visit /auth/{}?link=true to add this provider.",
|
|
||||||
e, provider
|
|
||||||
)),
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!(error = %e, %provider, "Failed to find user by provider ID");
|
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => {
|
|
||||||
// No providers linked yet: safe to associate this provider
|
|
||||||
if let Err(e) = user_repo::link_oauth_account(
|
|
||||||
&app_state.db,
|
|
||||||
existing.id,
|
|
||||||
&provider,
|
|
||||||
&user.id,
|
|
||||||
email,
|
|
||||||
Some(&user.username),
|
|
||||||
user.name.as_deref(),
|
|
||||||
user.avatar_url.as_deref(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!(error = %e, %provider, "Failed to link OAuth account to existing user with no providers");
|
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!(error = %e, "Failed to count oauth accounts for user");
|
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Verified, but no email for some reason. Create a user without an email.
|
||||||
|
user_repo::create_user(&app_state.db, None).await?
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Create new user with email
|
// No verified email, so we must create a new user without an email.
|
||||||
match user_repo::create_user(
|
debug!("No verified email, creating a new user");
|
||||||
&app_state.db,
|
user_repo::create_user(&app_state.db, None).await?
|
||||||
&user.username,
|
};
|
||||||
user.name.as_deref(),
|
|
||||||
email,
|
// 3. Link the new provider account to our user record (whether old or new)
|
||||||
user.avatar_url.as_deref(),
|
user_repo::link_oauth_account(
|
||||||
&provider,
|
&app_state.db,
|
||||||
&user.id,
|
user_to_link.id,
|
||||||
)
|
&provider,
|
||||||
.await
|
&user.id,
|
||||||
{
|
user.email.as_deref(),
|
||||||
Ok(u) => u,
|
Some(&user.username),
|
||||||
Err(e) => {
|
user.name.as_deref(),
|
||||||
warn!(error = %e, %provider, "Failed to create user");
|
user.avatar_url.as_deref(),
|
||||||
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No email available: disallow sign-in for safety
|
|
||||||
return ErrorResponse::bad_request(
|
|
||||||
"invalid_request",
|
|
||||||
Some("account has no email; sign in with a different provider".into()),
|
|
||||||
)
|
)
|
||||||
.into_response();
|
.await?;
|
||||||
|
|
||||||
|
Ok(user_to_link)
|
||||||
}
|
}
|
||||||
|
.instrument(linking_span)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let _: user_repo::User = match db_user_result {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %(&e as &dyn std::error::Error), "Failed to process user linking/creation");
|
||||||
|
return ErrorResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, "database_error", None).into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Create session token
|
// Create session token
|
||||||
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ use std::{collections::HashMap, sync::Arc};
|
|||||||
|
|
||||||
use pacman_server::{
|
use pacman_server::{
|
||||||
auth::{
|
auth::{
|
||||||
provider::{MockOAuthProvider, OAuthProvider},
|
provider::{AuthUser, MockOAuthProvider, OAuthProvider},
|
||||||
AuthRegistry,
|
AuthRegistry,
|
||||||
},
|
},
|
||||||
|
data::user as user_repo,
|
||||||
session,
|
session,
|
||||||
};
|
};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
@@ -13,41 +14,13 @@ use time::Duration;
|
|||||||
mod common;
|
mod common;
|
||||||
use crate::common::{test_context, TestContext};
|
use crate::common::{test_context, TestContext};
|
||||||
|
|
||||||
/// Test OAuth authorization flows
|
/// Test the basic authorization redirect flow
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oauth_authorization_flows() {
|
async fn test_oauth_authorization_redirect() {
|
||||||
let TestContext { server, .. } = test_context().call().await;
|
|
||||||
|
|
||||||
// Test OAuth authorize endpoint (should redirect)
|
|
||||||
let response = server.get("/auth/github").await;
|
|
||||||
assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth
|
|
||||||
|
|
||||||
// Test OAuth authorize endpoint for Discord
|
|
||||||
let response = server.get("/auth/discord").await;
|
|
||||||
assert_eq!(response.status_code(), 303); // Redirect to Discord OAuth
|
|
||||||
|
|
||||||
// Test unknown provider
|
|
||||||
let response = server.get("/auth/unknown").await;
|
|
||||||
assert_eq!(response.status_code(), 400); // Bad request for unknown provider
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test OAuth callback handling
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_oauth_callback_handling() {
|
|
||||||
let TestContext { server, .. } = test_context().call().await;
|
|
||||||
|
|
||||||
// Test OAuth callback with missing parameters (should fail gracefully)
|
|
||||||
let response = server.get("/auth/github/callback").await;
|
|
||||||
assert_eq!(response.status_code(), 400); // Bad request for missing code/state
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test OAuth authorization flow
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_oauth_authorization_flow() {
|
|
||||||
let mut mock = MockOAuthProvider::new();
|
let mut mock = MockOAuthProvider::new();
|
||||||
mock.expect_authorize().returning(|encoding_key| {
|
mock.expect_authorize().returning(|encoding_key| {
|
||||||
Ok(pacman_server::auth::provider::AuthorizeInfo {
|
Ok(pacman_server::auth::provider::AuthorizeInfo {
|
||||||
authorize_url: "https://example.com".parse().unwrap(),
|
authorize_url: "https://example.com/auth".parse().unwrap(),
|
||||||
session_token: session::create_pkce_session("verifier", "state", encoding_key),
|
session_token: session::create_pkce_session("verifier", "state", encoding_key),
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@@ -59,194 +32,26 @@ async fn test_oauth_authorization_flow() {
|
|||||||
|
|
||||||
let TestContext { server, app_state, .. } = test_context().auth_registry(mock_registry).call().await;
|
let TestContext { server, app_state, .. } = test_context().auth_registry(mock_registry).call().await;
|
||||||
|
|
||||||
// Test that valid handlers redirect
|
|
||||||
let response = server.get("/auth/mock").await;
|
|
||||||
assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth
|
|
||||||
|
|
||||||
// Test that unknown handlers return an error
|
|
||||||
let response = server.get("/auth/unknown").await;
|
|
||||||
assert_eq!(response.status_code(), 400); // Bad request for unknown provider
|
|
||||||
|
|
||||||
// Test that session cookie is set
|
|
||||||
let response = server.get("/auth/mock").await;
|
let response = server.get("/auth/mock").await;
|
||||||
assert_eq!(response.status_code(), 303);
|
assert_eq!(response.status_code(), 303);
|
||||||
let cookies = {
|
assert_eq!(response.headers().get("location").unwrap(), "https://example.com/auth");
|
||||||
let cookies = response.cookies();
|
|
||||||
cookies.iter().cloned().collect::<Vec<_>>()
|
|
||||||
};
|
|
||||||
assert_eq!(cookies.len(), 1);
|
|
||||||
assert_eq!(cookies[0].name(), "session");
|
|
||||||
let claims = session::decode_jwt(cookies[0].value(), &app_state.jwt_decoding_key).unwrap();
|
|
||||||
assert!(session::is_pkce_session(&claims));
|
|
||||||
|
|
||||||
// Test that link parameter redirects and sets a link cookie
|
let session_cookie = response.cookie("session");
|
||||||
let response = server.get("/auth/mock?link=true").await;
|
let claims = session::decode_jwt(session_cookie.value(), &app_state.jwt_decoding_key).unwrap();
|
||||||
assert_eq!(response.status_code(), 303);
|
assert!(session::is_pkce_session(&claims), "A PKCE session should be set");
|
||||||
assert_eq!(response.maybe_cookie("link").is_some(), true);
|
|
||||||
assert_eq!(response.maybe_cookie("link").unwrap().value(), "1");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test OAuth callback validation
|
/// Test new user registration via OAuth callback
|
||||||
#[tokio::test]
|
|
||||||
async fn test_oauth_callback_validation() {
|
|
||||||
let mut mock = MockOAuthProvider::new();
|
|
||||||
mock.expect_handle_callback()
|
|
||||||
.times(0) // Should not be called
|
|
||||||
.returning(|_, _, _, _| {
|
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
|
||||||
id: "123".to_string(),
|
|
||||||
username: "testuser".to_string(),
|
|
||||||
name: None,
|
|
||||||
email: Some("test@example.com".to_string()),
|
|
||||||
avatar_url: None,
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
|
||||||
let mock_registry = AuthRegistry {
|
|
||||||
providers: HashMap::from([("mock", provider)]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let TestContext { server, .. } = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
|
||||||
|
|
||||||
// Test that an unknown provider returns an error
|
|
||||||
let response = server.get("/auth/unknown/callback?code=a&state=b").await;
|
|
||||||
assert_eq!(response.status_code(), 400);
|
|
||||||
|
|
||||||
// Test that a provider-returned error is handled
|
|
||||||
let response = server.get("/auth/mock/callback?error=access_denied").await;
|
|
||||||
assert_eq!(response.status_code(), 400);
|
|
||||||
|
|
||||||
// Test that a missing code returns an error
|
|
||||||
let response = server.get("/auth/mock/callback?state=b").await;
|
|
||||||
assert_eq!(response.status_code(), 400);
|
|
||||||
|
|
||||||
// Test that a missing state returns an error
|
|
||||||
let response = server.get("/auth/mock/callback?code=a").await;
|
|
||||||
assert_eq!(response.status_code(), 400);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test OAuth callback processing
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_oauth_callback_processing() {
|
|
||||||
let mut mock = MockOAuthProvider::new();
|
|
||||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
|
||||||
id: "123".to_string(),
|
|
||||||
username: "testuser".to_string(),
|
|
||||||
name: None,
|
|
||||||
email: Some("processing@example.com".to_string()),
|
|
||||||
avatar_url: None,
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
|
||||||
let mock_registry = AuthRegistry {
|
|
||||||
providers: HashMap::from([("mock", provider)]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
|
||||||
|
|
||||||
// Test that a successful callback redirects and sets a session cookie
|
|
||||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
|
||||||
assert_eq!(response.status_code(), 302);
|
|
||||||
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
|
||||||
assert!(response.maybe_cookie("session").is_some());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test account linking flow
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_account_linking_flow() {
|
|
||||||
let mut initial_provider_mock = MockOAuthProvider::new();
|
|
||||||
initial_provider_mock.expect_handle_callback().returning(move |_, _, _, _| {
|
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
|
||||||
id: "123".to_string(),
|
|
||||||
username: "linkuser".to_string(),
|
|
||||||
name: None,
|
|
||||||
email: Some("link@example.com".to_string()),
|
|
||||||
avatar_url: None,
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let initial_provider: Arc<dyn OAuthProvider> = Arc::new(initial_provider_mock);
|
|
||||||
|
|
||||||
let mut link_provider_mock = MockOAuthProvider::new();
|
|
||||||
link_provider_mock.expect_authorize().returning(|encoding_key| {
|
|
||||||
Ok(pacman_server::auth::provider::AuthorizeInfo {
|
|
||||||
authorize_url: "https://example.com".parse().unwrap(),
|
|
||||||
session_token: session::create_pkce_session("verifier", "state", encoding_key),
|
|
||||||
})
|
|
||||||
});
|
|
||||||
link_provider_mock.expect_handle_callback().returning(|_, _, _, _| {
|
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
|
||||||
id: "456".to_string(),
|
|
||||||
username: "linkuser_new".to_string(),
|
|
||||||
name: None,
|
|
||||||
email: Some("link@example.com".to_string()),
|
|
||||||
avatar_url: None,
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let link_provider: Arc<dyn OAuthProvider> = Arc::new(link_provider_mock);
|
|
||||||
|
|
||||||
let registry = AuthRegistry {
|
|
||||||
providers: HashMap::from([("mock_initial", initial_provider), ("mock_link", link_provider)]),
|
|
||||||
};
|
|
||||||
let context = test_context().use_database(true).auth_registry(registry).call().await;
|
|
||||||
|
|
||||||
// 1. Create an initial user and provider link
|
|
||||||
let user = pacman_server::data::user::create_user(
|
|
||||||
&context.app_state.db,
|
|
||||||
"linkuser",
|
|
||||||
None,
|
|
||||||
Some("link@example.com"),
|
|
||||||
None,
|
|
||||||
"mock_initial",
|
|
||||||
"123",
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.expect("Failed to create user");
|
|
||||||
|
|
||||||
{
|
|
||||||
let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
|
|
||||||
.await
|
|
||||||
.expect("Failed to list user's initial provider(s)");
|
|
||||||
assert_eq!(providers.len(), 1, "User should have one provider");
|
|
||||||
assert!(providers.iter().any(|p| p.provider == "mock_initial"));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Create a session for this user
|
|
||||||
let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await;
|
|
||||||
assert_eq!(response.status_code(), 302);
|
|
||||||
|
|
||||||
// Begin linking flow
|
|
||||||
let response = context.server.get("/auth/mock_link?link=true").await;
|
|
||||||
assert_eq!(response.status_code(), 303);
|
|
||||||
|
|
||||||
// 3. Perform the linking call
|
|
||||||
let response = context.server.get("/auth/mock_link/callback?code=a&state=b").await;
|
|
||||||
|
|
||||||
assert_eq!(response.status_code(), 303, "Post-linking response should be a redirect");
|
|
||||||
assert_eq!(
|
|
||||||
response.headers().get("location").unwrap(),
|
|
||||||
"/profile",
|
|
||||||
"Post-linking response should redirect to /profile"
|
|
||||||
);
|
|
||||||
|
|
||||||
let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
|
|
||||||
.await
|
|
||||||
.expect("Failed to list user's providers");
|
|
||||||
assert_eq!(providers.len(), 2, "User should have two providers");
|
|
||||||
assert!(providers.iter().any(|p| p.provider == "mock_initial"));
|
|
||||||
assert!(providers.iter().any(|p| p.provider == "mock_link"));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test new user registration
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_new_user_registration() {
|
async fn test_new_user_registration() {
|
||||||
let mut mock = MockOAuthProvider::new();
|
let mut mock = MockOAuthProvider::new();
|
||||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
Ok(AuthUser {
|
||||||
id: "123".to_string(),
|
id: "newuser123".to_string(),
|
||||||
username: "testuser".to_string(),
|
username: "new_user".to_string(),
|
||||||
name: None,
|
name: None,
|
||||||
email: Some("newuser@example.com".to_string()),
|
email: Some("new@example.com".to_string()),
|
||||||
|
email_verified: true,
|
||||||
avatar_url: None,
|
avatar_url: None,
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@@ -258,27 +63,151 @@ async fn test_new_user_registration() {
|
|||||||
|
|
||||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||||
|
|
||||||
// Test that the OAuth callback handler creates a new user account when no existing user is found
|
|
||||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||||
assert_eq!(response.status_code(), 302);
|
assert_eq!(response.status_code(), 302);
|
||||||
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
||||||
let user = pacman_server::data::user::find_user_by_email(&context.app_state.db, "newuser@example.com")
|
|
||||||
|
// Verify user and oauth_account were created
|
||||||
|
let user = user_repo::find_user_by_email(&context.app_state.db, "new@example.com")
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("User should be created");
|
||||||
|
assert_eq!(user.email, Some("new@example.com".to_string()));
|
||||||
|
|
||||||
|
let providers = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap();
|
||||||
|
assert_eq!(providers.len(), 1);
|
||||||
|
assert_eq!(providers[0].provider, "mock");
|
||||||
|
assert_eq!(providers[0].provider_user_id, "newuser123");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test sign-in for an existing user with an already-linked provider
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_existing_user_signin() {
|
||||||
|
let mut mock = MockOAuthProvider::new();
|
||||||
|
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||||
|
Ok(AuthUser {
|
||||||
|
id: "existing123".to_string(),
|
||||||
|
username: "existing_user".to_string(),
|
||||||
|
name: None,
|
||||||
|
email: Some("existing@example.com".to_string()),
|
||||||
|
email_verified: true,
|
||||||
|
avatar_url: None,
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||||
|
let mock_registry = AuthRegistry {
|
||||||
|
providers: HashMap::from([("mock", provider)]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||||
|
|
||||||
|
// Pre-create the user and link
|
||||||
|
let user = user_repo::create_user(&context.app_state.db, Some("existing@example.com"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
user_repo::link_oauth_account(
|
||||||
|
&context.app_state.db,
|
||||||
|
user.id,
|
||||||
|
"mock",
|
||||||
|
"existing123",
|
||||||
|
Some("existing@example.com"),
|
||||||
|
Some("existing_user"),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||||
|
assert_eq!(response.status_code(), 302, "Should sign in successfully");
|
||||||
|
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
||||||
|
|
||||||
|
// Verify no new user was created
|
||||||
|
let users = sqlx::query("SELECT * FROM users")
|
||||||
|
.fetch_all(&context.app_state.db)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(users.len(), 1, "No new user should be created");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test implicit account linking via a shared verified email
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_implicit_account_linking() {
|
||||||
|
// 1. User signs in with 'provider-a'
|
||||||
|
let mut mock_a = MockOAuthProvider::new();
|
||||||
|
mock_a.expect_handle_callback().returning(|_, _, _, _| {
|
||||||
|
Ok(AuthUser {
|
||||||
|
id: "user_a_123".to_string(),
|
||||||
|
username: "user_a".to_string(),
|
||||||
|
name: None,
|
||||||
|
email: Some("shared@example.com".to_string()),
|
||||||
|
email_verified: true,
|
||||||
|
avatar_url: None,
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
// 2. Later, the same user signs in with 'provider-b'
|
||||||
|
let mut mock_b = MockOAuthProvider::new();
|
||||||
|
mock_b.expect_handle_callback().returning(|_, _, _, _| {
|
||||||
|
Ok(AuthUser {
|
||||||
|
id: "user_b_456".to_string(),
|
||||||
|
username: "user_b".to_string(),
|
||||||
|
name: None,
|
||||||
|
email: Some("shared@example.com".to_string()),
|
||||||
|
email_verified: true,
|
||||||
|
avatar_url: None,
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let provider_a: Arc<dyn OAuthProvider> = Arc::new(mock_a);
|
||||||
|
let provider_b: Arc<dyn OAuthProvider> = Arc::new(mock_b);
|
||||||
|
let mock_registry = AuthRegistry {
|
||||||
|
providers: HashMap::from([("provider-a", provider_a), ("provider-b", provider_b)]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||||
|
|
||||||
|
// Action 1: Sign in with provider-a, creating the initial user
|
||||||
|
let response1 = context.server.get("/auth/provider-a/callback?code=a&state=b").await;
|
||||||
|
assert_eq!(response1.status_code(), 302);
|
||||||
|
|
||||||
|
let user = user_repo::find_user_by_email(&context.app_state.db, "shared@example.com")
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(user.email, Some("newuser@example.com".to_string()));
|
let providers1 = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap();
|
||||||
|
assert_eq!(providers1.len(), 1);
|
||||||
|
assert_eq!(providers1[0].provider, "provider-a");
|
||||||
|
|
||||||
|
// Action 2: Sign in with provider-b
|
||||||
|
let response2 = context.server.get("/auth/provider-b/callback?code=a&state=b").await;
|
||||||
|
assert_eq!(response2.status_code(), 302);
|
||||||
|
|
||||||
|
// Assertions: No new user, but a new provider link
|
||||||
|
let users = sqlx::query("SELECT * FROM users")
|
||||||
|
.fetch_all(&context.app_state.db)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(users.len(), 1, "A new user should NOT have been created");
|
||||||
|
|
||||||
|
let providers2 = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap();
|
||||||
|
assert_eq!(providers2.len(), 2, "A new provider should have been linked");
|
||||||
|
assert!(providers2.iter().any(|p| p.provider == "provider-a"));
|
||||||
|
assert!(providers2.iter().any(|p| p.provider == "provider-b"));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test OAuth callback handler rejects sign-in attempts when no email is available
|
/// Test that an unverified email does NOT link accounts
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_oauth_callback_no_email() {
|
async fn test_unverified_email_creates_new_account() {
|
||||||
let mut mock = MockOAuthProvider::new();
|
let mut mock = MockOAuthProvider::new();
|
||||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
Ok(AuthUser {
|
||||||
id: "456".to_string(),
|
id: "unverified123".to_string(),
|
||||||
username: "noemailuser".to_string(),
|
username: "unverified_user".to_string(),
|
||||||
name: None,
|
name: None,
|
||||||
email: None,
|
email: Some("unverified@example.com".to_string()),
|
||||||
|
email_verified: false,
|
||||||
avatar_url: None,
|
avatar_url: None,
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@@ -290,86 +219,20 @@ async fn test_oauth_callback_no_email() {
|
|||||||
|
|
||||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||||
|
|
||||||
// Test that the OAuth callback handler rejects sign-in attempts when no email is available
|
// Pre-create a user with the same email, but they will not be linked.
|
||||||
|
user_repo::create_user(&context.app_state.db, Some("unverified@example.com"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||||
assert_eq!(response.status_code(), 400);
|
assert_eq!(response.status_code(), 302);
|
||||||
}
|
|
||||||
|
|
||||||
/// Test existing user sign-in with new provider fails
|
// Should create a second user because the email wasn't trusted for linking
|
||||||
#[tokio::test]
|
let users = sqlx::query("SELECT * FROM users")
|
||||||
async fn test_existing_user_sign_in_new_provider_fails() {
|
.fetch_all(&context.app_state.db)
|
||||||
let mut mock = MockOAuthProvider::new();
|
.await
|
||||||
mock.expect_handle_callback().returning(move |_, _, _, _| {
|
.unwrap();
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
assert_eq!(users.len(), 2, "A new user should be created for the unverified email");
|
||||||
id: "456".to_string(), // Different provider ID
|
|
||||||
username: "existinguser_newprovider".to_string(),
|
|
||||||
name: None,
|
|
||||||
email: Some("existing@example.com".to_string()),
|
|
||||||
avatar_url: None,
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
|
||||||
let mock_registry = AuthRegistry {
|
|
||||||
providers: HashMap::from([("mock_new", provider)]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
|
||||||
|
|
||||||
// Create a user with one linked provider
|
|
||||||
pacman_server::data::user::create_user(
|
|
||||||
&context.app_state.db,
|
|
||||||
"existinguser",
|
|
||||||
None,
|
|
||||||
Some("existing@example.com"),
|
|
||||||
None,
|
|
||||||
"mock",
|
|
||||||
"123",
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// A user with the email exists, but has one provider. If they sign in with a *new* provider, it should fail.
|
|
||||||
let response = context.server.get("/auth/mock_new/callback?code=a&state=b").await;
|
|
||||||
assert_eq!(response.status_code(), 400); // Should fail and ask to link explicitly.
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test existing user sign-in with existing provider succeeds
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_existing_user_sign_in_existing_provider_succeeds() {
|
|
||||||
let mut mock = MockOAuthProvider::new();
|
|
||||||
mock.expect_handle_callback().returning(move |_, _, _, _| {
|
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
|
||||||
id: "123".to_string(), // Same provider ID as created user
|
|
||||||
username: "existinguser".to_string(),
|
|
||||||
name: None,
|
|
||||||
email: Some("existing@example.com".to_string()),
|
|
||||||
avatar_url: None,
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
|
||||||
let mock_registry = AuthRegistry {
|
|
||||||
providers: HashMap::from([("mock", provider)]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
|
||||||
|
|
||||||
// Create a user with one linked provider
|
|
||||||
pacman_server::data::user::create_user(
|
|
||||||
&context.app_state.db,
|
|
||||||
"existinguser",
|
|
||||||
None,
|
|
||||||
Some("existing@example.com"),
|
|
||||||
None,
|
|
||||||
"mock",
|
|
||||||
"123",
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Test sign-in with an existing linked provider.
|
|
||||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
|
||||||
assert_eq!(response.status_code(), 302); // Should sign in successfully
|
|
||||||
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test logout functionality
|
/// Test logout functionality
|
||||||
@@ -377,11 +240,12 @@ async fn test_existing_user_sign_in_existing_provider_succeeds() {
|
|||||||
async fn test_logout_functionality() {
|
async fn test_logout_functionality() {
|
||||||
let mut mock = MockOAuthProvider::new();
|
let mut mock = MockOAuthProvider::new();
|
||||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||||
Ok(pacman_server::auth::provider::AuthUser {
|
Ok(AuthUser {
|
||||||
id: "123".to_string(),
|
id: "123".to_string(),
|
||||||
username: "testuser".to_string(),
|
username: "testuser".to_string(),
|
||||||
name: None,
|
name: None,
|
||||||
email: Some("test@example.com".to_string()),
|
email: Some("test@example.com".to_string()),
|
||||||
|
email_verified: true,
|
||||||
avatar_url: None,
|
avatar_url: None,
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@@ -395,27 +259,14 @@ async fn test_logout_functionality() {
|
|||||||
// Sign in to establish a session
|
// Sign in to establish a session
|
||||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||||
assert_eq!(response.status_code(), 302);
|
assert_eq!(response.status_code(), 302);
|
||||||
let session_cookie = response.cookie("session").clone();
|
|
||||||
|
|
||||||
// Test that the logout handler clears the session cookie and redirects
|
// Test that the logout handler clears the session cookie and redirects
|
||||||
let response = context.server.get("/logout").await;
|
let response = context.server.get("/logout").await;
|
||||||
|
|
||||||
// Redirect assertions
|
|
||||||
assert_eq!(response.status_code(), 302);
|
assert_eq!(response.status_code(), 302);
|
||||||
assert!(
|
assert!(response.headers().contains_key("location"));
|
||||||
response.headers().contains_key("location"),
|
|
||||||
"Response redirect should have a location header"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Cookie assertions
|
let cookie = response.cookie("session");
|
||||||
assert_eq!(
|
assert_eq!(cookie.value(), "removed");
|
||||||
response.cookie("session").value(),
|
assert_eq!(cookie.max_age(), Some(Duration::ZERO));
|
||||||
"removed",
|
|
||||||
"Session cookie should be removed"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
response.cookie("session").max_age(),
|
|
||||||
Some(Duration::ZERO),
|
|
||||||
"Session cookie should have a max age of 0"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
mod common;
|
mod common;
|
||||||
use crate::common::{test_context, TestContext};
|
use crate::common::test_context;
|
||||||
use cookie::Cookie;
|
use cookie::Cookie;
|
||||||
use pacman_server::session;
|
use pacman_server::{data::user as user_repo, session};
|
||||||
|
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
@@ -9,25 +9,30 @@ use pretty_assertions::assert_eq;
|
|||||||
async fn test_session_management() {
|
async fn test_session_management() {
|
||||||
let context = test_context().use_database(true).call().await;
|
let context = test_context().use_database(true).call().await;
|
||||||
|
|
||||||
// 1. Create a user
|
// 1. Create a user and link a provider account
|
||||||
let user =
|
let user = user_repo::create_user(&context.app_state.db, Some("test@example.com"))
|
||||||
pacman_server::data::user::create_user(&context.app_state.db, "testuser", None, None, None, "test_provider", "123")
|
.await
|
||||||
.await
|
.unwrap();
|
||||||
.unwrap();
|
let provider_account = user_repo::link_oauth_account(
|
||||||
|
&context.app_state.db,
|
||||||
|
user.id,
|
||||||
|
"test_provider",
|
||||||
|
"123",
|
||||||
|
Some("test@example.com"),
|
||||||
|
Some("testuser"),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// 2. Create a session token for the user
|
// 2. Create a session token for the user
|
||||||
let provider_account = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.into_iter()
|
|
||||||
.find(|p| p.provider == "test_provider")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let auth_user = pacman_server::auth::provider::AuthUser {
|
let auth_user = pacman_server::auth::provider::AuthUser {
|
||||||
id: provider_account.provider_user_id,
|
id: provider_account.provider_user_id,
|
||||||
username: provider_account.username.unwrap(),
|
username: provider_account.username.unwrap(),
|
||||||
name: provider_account.display_name,
|
name: provider_account.display_name,
|
||||||
email: user.email,
|
email: user.email,
|
||||||
|
email_verified: true,
|
||||||
avatar_url: provider_account.avatar_url,
|
avatar_url: provider_account.avatar_url,
|
||||||
};
|
};
|
||||||
let token = session::create_jwt_for_user("test_provider", &auth_user, &context.app_state.jwt_encoding_key);
|
let token = session::create_jwt_for_user("test_provider", &auth_user, &context.app_state.jwt_encoding_key);
|
||||||
|
|||||||
Reference in New Issue
Block a user