fix: rewrite oauth provider linking system, add email_verified attribute for providers

This commit is contained in:
Ryan Walters
2025-09-24 13:38:31 -05:00
parent bdd3c74a2d
commit c524fdb3e7
9 changed files with 296 additions and 573 deletions
+166 -315
View File
@@ -2,9 +2,10 @@ use std::{collections::HashMap, sync::Arc};
use pacman_server::{
auth::{
provider::{MockOAuthProvider, OAuthProvider},
provider::{AuthUser, MockOAuthProvider, OAuthProvider},
AuthRegistry,
},
data::user as user_repo,
session,
};
use pretty_assertions::assert_eq;
@@ -13,41 +14,13 @@ use time::Duration;
mod common;
use crate::common::{test_context, TestContext};
/// Test OAuth authorization flows
/// Test the basic authorization redirect flow
#[tokio::test]
async fn test_oauth_authorization_flows() {
let TestContext { server, .. } = test_context().call().await;
// Test OAuth authorize endpoint (should redirect)
let response = server.get("/auth/github").await;
assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth
// Test OAuth authorize endpoint for Discord
let response = server.get("/auth/discord").await;
assert_eq!(response.status_code(), 303); // Redirect to Discord OAuth
// Test unknown provider
let response = server.get("/auth/unknown").await;
assert_eq!(response.status_code(), 400); // Bad request for unknown provider
}
/// Test OAuth callback handling
#[tokio::test]
async fn test_oauth_callback_handling() {
let TestContext { server, .. } = test_context().call().await;
// Test OAuth callback with missing parameters (should fail gracefully)
let response = server.get("/auth/github/callback").await;
assert_eq!(response.status_code(), 400); // Bad request for missing code/state
}
/// Test OAuth authorization flow
#[tokio::test]
async fn test_oauth_authorization_flow() {
async fn test_oauth_authorization_redirect() {
let mut mock = MockOAuthProvider::new();
mock.expect_authorize().returning(|encoding_key| {
Ok(pacman_server::auth::provider::AuthorizeInfo {
authorize_url: "https://example.com".parse().unwrap(),
authorize_url: "https://example.com/auth".parse().unwrap(),
session_token: session::create_pkce_session("verifier", "state", encoding_key),
})
});
@@ -59,194 +32,26 @@ async fn test_oauth_authorization_flow() {
let TestContext { server, app_state, .. } = test_context().auth_registry(mock_registry).call().await;
// Test that valid handlers redirect
let response = server.get("/auth/mock").await;
assert_eq!(response.status_code(), 303); // Redirect to GitHub OAuth
// Test that unknown handlers return an error
let response = server.get("/auth/unknown").await;
assert_eq!(response.status_code(), 400); // Bad request for unknown provider
// Test that session cookie is set
let response = server.get("/auth/mock").await;
assert_eq!(response.status_code(), 303);
let cookies = {
let cookies = response.cookies();
cookies.iter().cloned().collect::<Vec<_>>()
};
assert_eq!(cookies.len(), 1);
assert_eq!(cookies[0].name(), "session");
let claims = session::decode_jwt(cookies[0].value(), &app_state.jwt_decoding_key).unwrap();
assert!(session::is_pkce_session(&claims));
assert_eq!(response.headers().get("location").unwrap(), "https://example.com/auth");
// Test that link parameter redirects and sets a link cookie
let response = server.get("/auth/mock?link=true").await;
assert_eq!(response.status_code(), 303);
assert_eq!(response.maybe_cookie("link").is_some(), true);
assert_eq!(response.maybe_cookie("link").unwrap().value(), "1");
let session_cookie = response.cookie("session");
let claims = session::decode_jwt(session_cookie.value(), &app_state.jwt_decoding_key).unwrap();
assert!(session::is_pkce_session(&claims), "A PKCE session should be set");
}
/// Test OAuth callback validation
#[tokio::test]
async fn test_oauth_callback_validation() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback()
.times(0) // Should not be called
.returning(|_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "123".to_string(),
username: "testuser".to_string(),
name: None,
email: Some("test@example.com".to_string()),
avatar_url: None,
})
});
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
let mock_registry = AuthRegistry {
providers: HashMap::from([("mock", provider)]),
};
let TestContext { server, .. } = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Test that an unknown provider returns an error
let response = server.get("/auth/unknown/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 400);
// Test that a provider-returned error is handled
let response = server.get("/auth/mock/callback?error=access_denied").await;
assert_eq!(response.status_code(), 400);
// Test that a missing code returns an error
let response = server.get("/auth/mock/callback?state=b").await;
assert_eq!(response.status_code(), 400);
// Test that a missing state returns an error
let response = server.get("/auth/mock/callback?code=a").await;
assert_eq!(response.status_code(), 400);
}
/// Test OAuth callback processing
#[tokio::test]
async fn test_oauth_callback_processing() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(|_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "123".to_string(),
username: "testuser".to_string(),
name: None,
email: Some("processing@example.com".to_string()),
avatar_url: None,
})
});
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
let mock_registry = AuthRegistry {
providers: HashMap::from([("mock", provider)]),
};
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Test that a successful callback redirects and sets a session cookie
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 302);
assert_eq!(response.headers().get("location").unwrap(), "/profile");
assert!(response.maybe_cookie("session").is_some());
}
/// Test account linking flow
#[tokio::test]
async fn test_account_linking_flow() {
let mut initial_provider_mock = MockOAuthProvider::new();
initial_provider_mock.expect_handle_callback().returning(move |_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "123".to_string(),
username: "linkuser".to_string(),
name: None,
email: Some("link@example.com".to_string()),
avatar_url: None,
})
});
let initial_provider: Arc<dyn OAuthProvider> = Arc::new(initial_provider_mock);
let mut link_provider_mock = MockOAuthProvider::new();
link_provider_mock.expect_authorize().returning(|encoding_key| {
Ok(pacman_server::auth::provider::AuthorizeInfo {
authorize_url: "https://example.com".parse().unwrap(),
session_token: session::create_pkce_session("verifier", "state", encoding_key),
})
});
link_provider_mock.expect_handle_callback().returning(|_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "456".to_string(),
username: "linkuser_new".to_string(),
name: None,
email: Some("link@example.com".to_string()),
avatar_url: None,
})
});
let link_provider: Arc<dyn OAuthProvider> = Arc::new(link_provider_mock);
let registry = AuthRegistry {
providers: HashMap::from([("mock_initial", initial_provider), ("mock_link", link_provider)]),
};
let context = test_context().use_database(true).auth_registry(registry).call().await;
// 1. Create an initial user and provider link
let user = pacman_server::data::user::create_user(
&context.app_state.db,
"linkuser",
None,
Some("link@example.com"),
None,
"mock_initial",
"123",
)
.await
.expect("Failed to create user");
{
let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
.await
.expect("Failed to list user's initial provider(s)");
assert_eq!(providers.len(), 1, "User should have one provider");
assert!(providers.iter().any(|p| p.provider == "mock_initial"));
}
// 2. Create a session for this user
let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 302);
// Begin linking flow
let response = context.server.get("/auth/mock_link?link=true").await;
assert_eq!(response.status_code(), 303);
// 3. Perform the linking call
let response = context.server.get("/auth/mock_link/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 303, "Post-linking response should be a redirect");
assert_eq!(
response.headers().get("location").unwrap(),
"/profile",
"Post-linking response should redirect to /profile"
);
let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
.await
.expect("Failed to list user's providers");
assert_eq!(providers.len(), 2, "User should have two providers");
assert!(providers.iter().any(|p| p.provider == "mock_initial"));
assert!(providers.iter().any(|p| p.provider == "mock_link"));
}
/// Test new user registration
/// Test new user registration via OAuth callback
#[tokio::test]
async fn test_new_user_registration() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(|_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "123".to_string(),
username: "testuser".to_string(),
Ok(AuthUser {
id: "newuser123".to_string(),
username: "new_user".to_string(),
name: None,
email: Some("newuser@example.com".to_string()),
email: Some("new@example.com".to_string()),
email_verified: true,
avatar_url: None,
})
});
@@ -258,27 +63,151 @@ async fn test_new_user_registration() {
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Test that the OAuth callback handler creates a new user account when no existing user is found
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 302);
assert_eq!(response.headers().get("location").unwrap(), "/profile");
let user = pacman_server::data::user::find_user_by_email(&context.app_state.db, "newuser@example.com")
// Verify user and oauth_account were created
let user = user_repo::find_user_by_email(&context.app_state.db, "new@example.com")
.await
.unwrap()
.expect("User should be created");
assert_eq!(user.email, Some("new@example.com".to_string()));
let providers = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].provider, "mock");
assert_eq!(providers[0].provider_user_id, "newuser123");
}
/// Test sign-in for an existing user with an already-linked provider
#[tokio::test]
async fn test_existing_user_signin() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(|_, _, _, _| {
Ok(AuthUser {
id: "existing123".to_string(),
username: "existing_user".to_string(),
name: None,
email: Some("existing@example.com".to_string()),
email_verified: true,
avatar_url: None,
})
});
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
let mock_registry = AuthRegistry {
providers: HashMap::from([("mock", provider)]),
};
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Pre-create the user and link
let user = user_repo::create_user(&context.app_state.db, Some("existing@example.com"))
.await
.unwrap();
user_repo::link_oauth_account(
&context.app_state.db,
user.id,
"mock",
"existing123",
Some("existing@example.com"),
Some("existing_user"),
None,
None,
)
.await
.unwrap();
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 302, "Should sign in successfully");
assert_eq!(response.headers().get("location").unwrap(), "/profile");
// Verify no new user was created
let users = sqlx::query("SELECT * FROM users")
.fetch_all(&context.app_state.db)
.await
.unwrap();
assert_eq!(users.len(), 1, "No new user should be created");
}
/// Test implicit account linking via a shared verified email
#[tokio::test]
async fn test_implicit_account_linking() {
// 1. User signs in with 'provider-a'
let mut mock_a = MockOAuthProvider::new();
mock_a.expect_handle_callback().returning(|_, _, _, _| {
Ok(AuthUser {
id: "user_a_123".to_string(),
username: "user_a".to_string(),
name: None,
email: Some("shared@example.com".to_string()),
email_verified: true,
avatar_url: None,
})
});
// 2. Later, the same user signs in with 'provider-b'
let mut mock_b = MockOAuthProvider::new();
mock_b.expect_handle_callback().returning(|_, _, _, _| {
Ok(AuthUser {
id: "user_b_456".to_string(),
username: "user_b".to_string(),
name: None,
email: Some("shared@example.com".to_string()),
email_verified: true,
avatar_url: None,
})
});
let provider_a: Arc<dyn OAuthProvider> = Arc::new(mock_a);
let provider_b: Arc<dyn OAuthProvider> = Arc::new(mock_b);
let mock_registry = AuthRegistry {
providers: HashMap::from([("provider-a", provider_a), ("provider-b", provider_b)]),
};
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Action 1: Sign in with provider-a, creating the initial user
let response1 = context.server.get("/auth/provider-a/callback?code=a&state=b").await;
assert_eq!(response1.status_code(), 302);
let user = user_repo::find_user_by_email(&context.app_state.db, "shared@example.com")
.await
.unwrap()
.unwrap();
assert_eq!(user.email, Some("newuser@example.com".to_string()));
let providers1 = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap();
assert_eq!(providers1.len(), 1);
assert_eq!(providers1[0].provider, "provider-a");
// Action 2: Sign in with provider-b
let response2 = context.server.get("/auth/provider-b/callback?code=a&state=b").await;
assert_eq!(response2.status_code(), 302);
// Assertions: No new user, but a new provider link
let users = sqlx::query("SELECT * FROM users")
.fetch_all(&context.app_state.db)
.await
.unwrap();
assert_eq!(users.len(), 1, "A new user should NOT have been created");
let providers2 = user_repo::list_user_providers(&context.app_state.db, user.id).await.unwrap();
assert_eq!(providers2.len(), 2, "A new provider should have been linked");
assert!(providers2.iter().any(|p| p.provider == "provider-a"));
assert!(providers2.iter().any(|p| p.provider == "provider-b"));
}
/// Test OAuth callback handler rejects sign-in attempts when no email is available
/// Test that an unverified email does NOT link accounts
#[tokio::test]
async fn test_oauth_callback_no_email() {
async fn test_unverified_email_creates_new_account() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(|_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "456".to_string(),
username: "noemailuser".to_string(),
Ok(AuthUser {
id: "unverified123".to_string(),
username: "unverified_user".to_string(),
name: None,
email: None,
email: Some("unverified@example.com".to_string()),
email_verified: false,
avatar_url: None,
})
});
@@ -290,86 +219,20 @@ async fn test_oauth_callback_no_email() {
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Test that the OAuth callback handler rejects sign-in attempts when no email is available
// Pre-create a user with the same email, but they will not be linked.
user_repo::create_user(&context.app_state.db, Some("unverified@example.com"))
.await
.unwrap();
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 400);
}
assert_eq!(response.status_code(), 302);
/// Test existing user sign-in with new provider fails
#[tokio::test]
async fn test_existing_user_sign_in_new_provider_fails() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(move |_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "456".to_string(), // Different provider ID
username: "existinguser_newprovider".to_string(),
name: None,
email: Some("existing@example.com".to_string()),
avatar_url: None,
})
});
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
let mock_registry = AuthRegistry {
providers: HashMap::from([("mock_new", provider)]),
};
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Create a user with one linked provider
pacman_server::data::user::create_user(
&context.app_state.db,
"existinguser",
None,
Some("existing@example.com"),
None,
"mock",
"123",
)
.await
.unwrap();
// A user with the email exists, but has one provider. If they sign in with a *new* provider, it should fail.
let response = context.server.get("/auth/mock_new/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 400); // Should fail and ask to link explicitly.
}
/// Test existing user sign-in with existing provider succeeds
#[tokio::test]
async fn test_existing_user_sign_in_existing_provider_succeeds() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(move |_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
id: "123".to_string(), // Same provider ID as created user
username: "existinguser".to_string(),
name: None,
email: Some("existing@example.com".to_string()),
avatar_url: None,
})
});
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
let mock_registry = AuthRegistry {
providers: HashMap::from([("mock", provider)]),
};
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
// Create a user with one linked provider
pacman_server::data::user::create_user(
&context.app_state.db,
"existinguser",
None,
Some("existing@example.com"),
None,
"mock",
"123",
)
.await
.unwrap();
// Test sign-in with an existing linked provider.
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 302); // Should sign in successfully
assert_eq!(response.headers().get("location").unwrap(), "/profile");
// Should create a second user because the email wasn't trusted for linking
let users = sqlx::query("SELECT * FROM users")
.fetch_all(&context.app_state.db)
.await
.unwrap();
assert_eq!(users.len(), 2, "A new user should be created for the unverified email");
}
/// Test logout functionality
@@ -377,11 +240,12 @@ async fn test_existing_user_sign_in_existing_provider_succeeds() {
async fn test_logout_functionality() {
let mut mock = MockOAuthProvider::new();
mock.expect_handle_callback().returning(|_, _, _, _| {
Ok(pacman_server::auth::provider::AuthUser {
Ok(AuthUser {
id: "123".to_string(),
username: "testuser".to_string(),
name: None,
email: Some("test@example.com".to_string()),
email_verified: true,
avatar_url: None,
})
});
@@ -395,27 +259,14 @@ async fn test_logout_functionality() {
// Sign in to establish a session
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
assert_eq!(response.status_code(), 302);
let session_cookie = response.cookie("session").clone();
// Test that the logout handler clears the session cookie and redirects
let response = context.server.get("/logout").await;
// Redirect assertions
assert_eq!(response.status_code(), 302);
assert!(
response.headers().contains_key("location"),
"Response redirect should have a location header"
);
assert!(response.headers().contains_key("location"));
// Cookie assertions
assert_eq!(
response.cookie("session").value(),
"removed",
"Session cookie should be removed"
);
assert_eq!(
response.cookie("session").max_age(),
Some(Duration::ZERO),
"Session cookie should have a max age of 0"
);
let cookie = response.cookie("session");
assert_eq!(cookie.value(), "removed");
assert_eq!(cookie.max_age(), Some(Duration::ZERO));
}