mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-05 23:15:40 -06:00
feat: improve test reliability and add request tracing
- Add retry configuration for flaky tests (2 retries for default, 3 for OAuth) - Configure test groups with proper concurrency limits (serial: 1, integration: 4) - Add tower-http tracing layer with custom span formatting for HTTP requests - Simplify database pool handling by removing unnecessary Arc wrapper - Improve test context setup with better logging and error handling - Refactor user creation parameters for better clarity and consistency - Add debug logging for OAuth cookie handling
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
[profile.default]
|
||||
fail-fast = false
|
||||
slow-timeout = { period = "5s", terminate-after = 6 } # max 30 seconds
|
||||
slow-timeout = { period = "5s", terminate-after = 3 } # max 15 seconds
|
||||
retries = 2
|
||||
|
||||
# CI machines are pretty slow, so we need to increase the timeout
|
||||
[profile.ci]
|
||||
@@ -11,12 +12,18 @@ slow-timeout = { period = "30s", terminate-after = 4 } # max 2 minutes for slow
|
||||
slow-timeout = { period = "45s", terminate-after = 5 } # max 3.75 minutes for slow tests
|
||||
status-level = "none"
|
||||
|
||||
[[profile.default.overrides]]
|
||||
# Integration tests in SDL2 run serially (may not be required)
|
||||
[[profile.default.overrides]]
|
||||
filter = 'test(pacman::game::)'
|
||||
test-group = 'serial'
|
||||
|
||||
# Integration tests run max 4 at a time
|
||||
[[profile.default.overrides]]
|
||||
filter = 'test(pacman-server::tests::oauth)'
|
||||
test-group = 'integration'
|
||||
retries = 3
|
||||
|
||||
[test-groups]
|
||||
# Ensure serial tests don't run in parallel
|
||||
serial = { max-threads = 1 }
|
||||
integration = { max-threads = 4 }
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3130,6 +3130,7 @@ dependencies = [
|
||||
"testcontainers",
|
||||
"time",
|
||||
"tokio 1.47.1",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-futures",
|
||||
"tracing-subscriber",
|
||||
@@ -5474,6 +5475,7 @@ dependencies = [
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -44,6 +44,7 @@ jsonwebtoken = { version = "9.3", default-features = false }
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
|
||||
tracing-futures = { version = "0.2.5", features = ["tokio"] }
|
||||
tower-http = { version = "0.6", features = ["trace"] }
|
||||
time = { version = "0.3", features = ["macros", "formatting"] }
|
||||
yansi = "1"
|
||||
s3-tokio = { version = "0.39.6", default-features = false }
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{Notify, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::info_span;
|
||||
|
||||
use crate::data::pool::PgPool;
|
||||
use crate::{auth::AuthRegistry, config::Config, image::ImageStorage, routes};
|
||||
@@ -36,7 +37,7 @@ pub struct AppState {
|
||||
pub sessions: Arc<DashMap<String, crate::auth::provider::AuthUser>>,
|
||||
pub jwt_encoding_key: Arc<EncodingKey>,
|
||||
pub jwt_decoding_key: Arc<DecodingKey>,
|
||||
pub db: Arc<PgPool>,
|
||||
pub db: PgPool,
|
||||
pub health: Arc<RwLock<Health>>,
|
||||
pub image_storage: Arc<ImageStorage>,
|
||||
pub healthchecker_task: Arc<RwLock<Option<JoinHandle<()>>>>,
|
||||
@@ -71,7 +72,7 @@ impl AppState {
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())),
|
||||
jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())),
|
||||
db: Arc::new(db),
|
||||
db: db,
|
||||
health: Arc::new(RwLock::new(Health::default())),
|
||||
image_storage,
|
||||
healthchecker_task: Arc::new(RwLock::new(None)),
|
||||
@@ -100,7 +101,7 @@ impl AppState {
|
||||
}
|
||||
|
||||
// Run the actual health check
|
||||
let ok = sqlx::query("SELECT 1").execute(&*db_pool).await.is_ok();
|
||||
let ok = sqlx::query("SELECT 1").execute(&db_pool).await.is_ok();
|
||||
{
|
||||
let mut h = health_state.write().await;
|
||||
h.set_database(ok);
|
||||
@@ -127,13 +128,35 @@ impl AppState {
|
||||
|
||||
/// Force an immediate health check (debug mode only)
|
||||
pub async fn check_health(&self) -> bool {
|
||||
let ok = sqlx::query("SELECT 1").execute(&*self.db).await.is_ok();
|
||||
let ok = sqlx::query("SELECT 1").execute(&self.db).await.is_ok();
|
||||
let mut h = self.health.write().await;
|
||||
h.set_database(ok);
|
||||
ok
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a custom span for HTTP requests with reduced verbosity
|
||||
pub fn make_span<B>(request: &axum::http::Request<B>) -> tracing::Span {
|
||||
let path = request
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|v| v.as_str())
|
||||
.unwrap_or_else(|| request.uri().path());
|
||||
|
||||
if request.method() == axum::http::Method::GET {
|
||||
info_span!(
|
||||
"request",
|
||||
path = %path,
|
||||
)
|
||||
} else {
|
||||
info_span!(
|
||||
"request",
|
||||
method = %request.method(),
|
||||
path = %path,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create the application router with all routes and middleware
|
||||
pub fn create_router(app_state: AppState) -> Router {
|
||||
Router::new()
|
||||
@@ -147,6 +170,13 @@ pub fn create_router(app_state: AppState) -> Router {
|
||||
.with_state(app_state)
|
||||
.layer(CookieLayer::default())
|
||||
.layer(axum::middleware::from_fn(inject_server_header))
|
||||
.layer(
|
||||
tower_http::trace::TraceLayer::new_for_http()
|
||||
.make_span_with(make_span)
|
||||
.on_request(|_request: &axum::http::Request<axum::body::Body>, _span: &tracing::Span| {
|
||||
// Disable request logging by doing nothing
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Inject the server header into responses
|
||||
|
||||
@@ -68,10 +68,10 @@ pub async fn link_oauth_account(
|
||||
|
||||
pub async fn create_user(
|
||||
pool: &sqlx::PgPool,
|
||||
username: &str,
|
||||
display_name: Option<&str>,
|
||||
email: Option<&str>,
|
||||
avatar_url: Option<&str>,
|
||||
provider_username: &str,
|
||||
provider_display_name: Option<&str>,
|
||||
provider_email: Option<&str>,
|
||||
provider_avatar_url: Option<&str>,
|
||||
provider: &str,
|
||||
provider_user_id: &str,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
@@ -82,20 +82,20 @@ pub async fn create_user(
|
||||
RETURNING id, email, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(email)
|
||||
.bind(provider_email)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
// Create oauth link
|
||||
let _ = link_oauth_account(
|
||||
let _linked = link_oauth_account(
|
||||
pool,
|
||||
user.id,
|
||||
provider,
|
||||
provider_user_id,
|
||||
email,
|
||||
Some(username),
|
||||
display_name,
|
||||
avatar_url,
|
||||
provider_email,
|
||||
Some(provider_username),
|
||||
provider_display_name,
|
||||
provider_avatar_url,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -1,36 +1,38 @@
|
||||
use tracing_subscriber::fmt::format::JsonFields;
|
||||
use tracing_subscriber::{EnvFilter, FmtSubscriber};
|
||||
use tracing_subscriber::{fmt::format::JsonFields, EnvFilter, FmtSubscriber};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::formatter;
|
||||
|
||||
static SUBSCRIBER_INIT: std::sync::Once = std::sync::Once::new();
|
||||
|
||||
/// Configure and initialize logging for the application
|
||||
pub fn setup_logging(_config: &Config) {
|
||||
// Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME"))));
|
||||
pub fn setup_logging() {
|
||||
SUBSCRIBER_INIT.call_once(|| {
|
||||
// Allow RUST_LOG to override levels; default to info for our crate and warn elsewhere
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(format!("warn,{name}=info,{name}::auth=info", name = env!("CARGO_CRATE_NAME"))));
|
||||
|
||||
// Default to pretty for local dev; switchable later if we add CLI
|
||||
let use_pretty = cfg!(debug_assertions);
|
||||
// Default to pretty for local dev; switchable later if we add CLI
|
||||
let use_pretty = cfg!(debug_assertions);
|
||||
|
||||
let subscriber: Box<dyn tracing::Subscriber + Send + Sync> = if use_pretty {
|
||||
Box::new(
|
||||
FmtSubscriber::builder()
|
||||
.with_target(true)
|
||||
.event_format(formatter::CustomPrettyFormatter)
|
||||
.with_env_filter(filter)
|
||||
.finish(),
|
||||
)
|
||||
} else {
|
||||
Box::new(
|
||||
FmtSubscriber::builder()
|
||||
.with_target(true)
|
||||
.event_format(formatter::CustomJsonFormatter)
|
||||
.fmt_fields(JsonFields::new())
|
||||
.with_env_filter(filter)
|
||||
.finish(),
|
||||
)
|
||||
};
|
||||
let subscriber: Box<dyn tracing::Subscriber + Send + Sync> = if use_pretty {
|
||||
Box::new(
|
||||
FmtSubscriber::builder()
|
||||
.with_target(true)
|
||||
.event_format(formatter::CustomPrettyFormatter)
|
||||
.with_env_filter(filter)
|
||||
.finish(),
|
||||
)
|
||||
} else {
|
||||
Box::new(
|
||||
FmtSubscriber::builder()
|
||||
.with_target(true)
|
||||
.event_format(formatter::CustomJsonFormatter)
|
||||
.fmt_fields(JsonFields::new())
|
||||
.with_env_filter(filter)
|
||||
.finish(),
|
||||
)
|
||||
};
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ async fn main() {
|
||||
let config: Config = config::load_config();
|
||||
|
||||
// Initialize tracing subscriber
|
||||
logging::setup_logging(&config);
|
||||
logging::setup_logging();
|
||||
trace!(host = %config.host, port = config.port, shutdown_timeout_seconds = config.shutdown_timeout_seconds, "Loaded server configuration");
|
||||
|
||||
let addr = std::net::SocketAddr::new(config.host, config.port);
|
||||
|
||||
@@ -93,6 +93,8 @@ pub async fn oauth_callback_handler(
|
||||
}
|
||||
};
|
||||
|
||||
debug!(cookies = ?cookie.cookie().iter().collect::<Vec<_>>(), "Cookies");
|
||||
|
||||
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
|
||||
let link_cookie = cookie.get("link").map(|c| c.value().to_string());
|
||||
if link_cookie.is_some() {
|
||||
|
||||
@@ -12,6 +12,7 @@ use testcontainers::{
|
||||
ContainerAsync, GenericImage, ImageExt,
|
||||
};
|
||||
use tokio::sync::Notify;
|
||||
use tracing::{debug, debug_span, Instrument};
|
||||
|
||||
static CRYPTO_INIT: Once = Once::new();
|
||||
|
||||
@@ -26,16 +27,40 @@ pub struct TestContext {
|
||||
}
|
||||
|
||||
#[builder]
|
||||
pub async fn test_context(use_database: bool, auth_registry: Option<AuthRegistry>) -> TestContext {
|
||||
pub async fn test_context(#[builder(default = false)] use_database: bool, auth_registry: Option<AuthRegistry>) -> TestContext {
|
||||
CRYPTO_INIT.call_once(|| {
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install default crypto provider");
|
||||
});
|
||||
|
||||
// Set up logging
|
||||
std::env::set_var("RUST_LOG", "debug,sqlx=info");
|
||||
pacman_server::logging::setup_logging();
|
||||
let (database_url, container) = if use_database {
|
||||
let (url, container) = setup_test_database("testdb", "testuser", "testpass").await;
|
||||
(Some(url), Some(container))
|
||||
let db = "testdb";
|
||||
let user = "testuser";
|
||||
let password = "testpass";
|
||||
|
||||
// Create container request
|
||||
let container_request = GenericImage::new("postgres", "15")
|
||||
.with_exposed_port(5432.tcp())
|
||||
.with_wait_for(WaitFor::message_on_stderr("database system is ready to accept connections"))
|
||||
.with_env_var("POSTGRES_DB", db)
|
||||
.with_env_var("POSTGRES_USER", user)
|
||||
.with_env_var("POSTGRES_PASSWORD", password);
|
||||
|
||||
tracing::trace!(request = ?container_request, "Acquiring postgres testcontainer");
|
||||
let container = container_request.start().await.unwrap();
|
||||
let host = container.get_host().await.unwrap();
|
||||
let port = container.get_host_port_ipv4(5432).await.unwrap();
|
||||
|
||||
tracing::debug!(host = %host, port = %port, "Test database ready");
|
||||
|
||||
(
|
||||
Some(format!("postgresql://{user}:{password}@{host}:{port}/{db}?sslmode=disable")),
|
||||
Some(container),
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
@@ -63,8 +88,10 @@ pub async fn test_context(use_database: bool, auth_registry: Option<AuthRegistry
|
||||
// Run migrations
|
||||
sqlx::migrate!("./migrations")
|
||||
.run(&db)
|
||||
.instrument(debug_span!("running_migrations"))
|
||||
.await
|
||||
.expect("Failed to run database migrations");
|
||||
debug!("Database migrations ran successfully");
|
||||
|
||||
db
|
||||
} else {
|
||||
@@ -88,32 +115,13 @@ pub async fn test_context(use_database: bool, auth_registry: Option<AuthRegistry
|
||||
}
|
||||
|
||||
let router = create_router(app_state.clone());
|
||||
let mut server = TestServer::new(router).unwrap();
|
||||
server.save_cookies();
|
||||
|
||||
TestContext {
|
||||
server: TestServer::new(router).unwrap(),
|
||||
server,
|
||||
app_state,
|
||||
config,
|
||||
container,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set up a test PostgreSQL database using testcontainers
|
||||
async fn setup_test_database(db: &str, user: &str, password: &str) -> (String, ContainerAsync<GenericImage>) {
|
||||
let container = GenericImage::new("postgres", "15")
|
||||
.with_exposed_port(5432.tcp())
|
||||
.with_wait_for(WaitFor::message_on_stderr("database system is ready to accept connections"))
|
||||
.with_env_var("POSTGRES_DB", db)
|
||||
.with_env_var("POSTGRES_USER", user)
|
||||
.with_env_var("POSTGRES_PASSWORD", password)
|
||||
.start()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let host = container.get_host().await.unwrap();
|
||||
let port = container.get_host_port_ipv4(5432).await.unwrap();
|
||||
|
||||
(
|
||||
format!("postgresql://{user}:{password}@{host}:{port}/{db}?sslmode=disable"),
|
||||
container,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use pacman_server::auth::{
|
||||
AuthRegistry,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use time::Duration;
|
||||
|
||||
mod common;
|
||||
use crate::common::{test_context, TestContext};
|
||||
@@ -12,7 +13,7 @@ use crate::common::{test_context, TestContext};
|
||||
/// Test OAuth authorization flows
|
||||
#[tokio::test]
|
||||
async fn test_oauth_authorization_flows() {
|
||||
let TestContext { server, .. } = test_context().use_database(false).call().await;
|
||||
let TestContext { server, .. } = test_context().call().await;
|
||||
|
||||
// Test OAuth authorize endpoint (should redirect)
|
||||
let response = server.get("/auth/github").await;
|
||||
@@ -30,7 +31,7 @@ async fn test_oauth_authorization_flows() {
|
||||
/// Test OAuth callback handling
|
||||
#[tokio::test]
|
||||
async fn test_oauth_callback_handling() {
|
||||
let TestContext { server, .. } = test_context().use_database(false).call().await;
|
||||
let TestContext { server, .. } = test_context().call().await;
|
||||
|
||||
// Test OAuth callback with missing parameters (should fail gracefully)
|
||||
let response = server.get("/auth/github/callback").await;
|
||||
@@ -53,7 +54,7 @@ async fn test_oauth_authorization_flow() {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
let TestContext { server, .. } = test_context().use_database(false).auth_registry(mock_registry).call().await;
|
||||
let TestContext { server, .. } = test_context().auth_registry(mock_registry).call().await;
|
||||
|
||||
// Test that valid handlers redirect
|
||||
let response = server.get("/auth/mock").await;
|
||||
@@ -84,93 +85,361 @@ async fn test_oauth_authorization_flow() {
|
||||
/// Test OAuth callback validation
|
||||
#[tokio::test]
|
||||
async fn test_oauth_callback_validation() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback()
|
||||
.times(0) // Should not be called
|
||||
.returning(|_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "123".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
name: None,
|
||||
email: Some("test@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
// TODO: Test that the OAuth callback handler validates the provider exists before processing
|
||||
// TODO: Test that the OAuth callback handler returns an error when the provider returns an OAuth error
|
||||
// TODO: Test that the OAuth callback handler returns an error when the authorization code is missing
|
||||
// TODO: Test that the OAuth callback handler returns an error when the state parameter is missing
|
||||
let TestContext { server, .. } = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Test that an unknown provider returns an error
|
||||
let response = server.get("/auth/unknown/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 400);
|
||||
|
||||
// Test that a provider-returned error is handled
|
||||
let response = server.get("/auth/mock/callback?error=access_denied").await;
|
||||
assert_eq!(response.status_code(), 400);
|
||||
|
||||
// Test that a missing code returns an error
|
||||
let response = server.get("/auth/mock/callback?state=b").await;
|
||||
assert_eq!(response.status_code(), 400);
|
||||
|
||||
// Test that a missing state returns an error
|
||||
let response = server.get("/auth/mock/callback?code=a").await;
|
||||
assert_eq!(response.status_code(), 400);
|
||||
}
|
||||
|
||||
/// Test OAuth callback processing
|
||||
#[tokio::test]
|
||||
async fn test_oauth_callback_processing() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "123".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
name: None,
|
||||
email: Some("processing@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
// TODO: Test that the OAuth callback handler exchanges the authorization code for user information successfully
|
||||
// TODO: Test that the OAuth callback handler handles provider callback errors gracefully
|
||||
// TODO: Test that the OAuth callback handler creates a session token after successful authentication
|
||||
// TODO: Test that the OAuth callback handler sets a session cookie after successful authentication
|
||||
// TODO: Test that the OAuth callback handler redirects to the profile page after successful authentication
|
||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Test that a successful callback redirects and sets a session cookie
|
||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 302);
|
||||
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
||||
assert!(response.maybe_cookie("session").is_some());
|
||||
}
|
||||
|
||||
/// Test account linking flow
|
||||
#[tokio::test]
|
||||
async fn test_account_linking_flow() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
let mut initial_provider_mock = MockOAuthProvider::new();
|
||||
initial_provider_mock.expect_handle_callback().returning(move |_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "123".to_string(),
|
||||
username: "linkuser".to_string(),
|
||||
name: None,
|
||||
email: Some("link@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let initial_provider: Arc<dyn OAuthProvider> = Arc::new(initial_provider_mock);
|
||||
|
||||
// TODO: Test that the OAuth callback handler links a new provider to an existing user when link intent is present and session is valid
|
||||
// TODO: Test that the OAuth callback handler redirects to profile after successful account linking
|
||||
// TODO: Test that the OAuth callback handler falls back to normal sign-in when link intent is present but no valid session exists
|
||||
let mut link_provider_mock = MockOAuthProvider::new();
|
||||
link_provider_mock.expect_authorize().returning(|_| {
|
||||
Ok(pacman_server::auth::provider::AuthorizeInfo {
|
||||
authorize_url: "https://example.com".parse().unwrap(),
|
||||
session_token: "b_token".to_string(),
|
||||
})
|
||||
});
|
||||
link_provider_mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "456".to_string(),
|
||||
username: "linkuser_new".to_string(),
|
||||
name: None,
|
||||
email: Some("link@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let link_provider: Arc<dyn OAuthProvider> = Arc::new(link_provider_mock);
|
||||
|
||||
let registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock_initial", initial_provider), ("mock_link", link_provider)]),
|
||||
};
|
||||
let context = test_context().use_database(true).auth_registry(registry).call().await;
|
||||
|
||||
// 1. Create an initial user and provider link
|
||||
let user = pacman_server::data::user::create_user(
|
||||
&context.app_state.db,
|
||||
"linkuser",
|
||||
None,
|
||||
Some("link@example.com"),
|
||||
None,
|
||||
"mock_initial",
|
||||
"123",
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
{
|
||||
let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
|
||||
.await
|
||||
.expect("Failed to list user's initial provider(s)");
|
||||
assert_eq!(providers.len(), 1, "User should have one provider");
|
||||
assert!(providers.iter().any(|p| p.provider == "mock_initial"));
|
||||
}
|
||||
|
||||
// 2. Create a session for this user
|
||||
let session_cookie = {
|
||||
let response = context.server.get("/auth/mock_initial/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 302);
|
||||
assert!(response.maybe_cookie("session").is_some(), "Session cookie should be set");
|
||||
|
||||
response.cookie("session").clone()
|
||||
};
|
||||
tracing::debug!(cookie = %session_cookie, "Session cookie acquired");
|
||||
|
||||
// Begin linking flow
|
||||
let link_cookie = {
|
||||
let response = context
|
||||
.server
|
||||
.get("/auth/mock_link?link=true")
|
||||
.add_cookie(session_cookie.clone())
|
||||
.await;
|
||||
assert_eq!(response.status_code(), 303);
|
||||
assert_eq!(response.maybe_cookie("link").unwrap().value(), "1");
|
||||
|
||||
response.cookie("link").clone()
|
||||
};
|
||||
tracing::debug!(cookie = %link_cookie, "Link cookie acquired");
|
||||
|
||||
// 3. Perform the linking call
|
||||
let response = context
|
||||
.server
|
||||
.get("/auth/mock_link/callback?code=a&state=b")
|
||||
.add_cookie(link_cookie)
|
||||
.add_cookie(session_cookie.clone())
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status_code(), 303, "Post-linking response should be a redirect");
|
||||
assert_eq!(
|
||||
response.headers().get("location").unwrap(),
|
||||
"/profile",
|
||||
"Post-linking response should redirect to /profile"
|
||||
);
|
||||
|
||||
let providers = pacman_server::data::user::list_user_providers(&context.app_state.db, user.id)
|
||||
.await
|
||||
.expect("Failed to list user's providers");
|
||||
assert_eq!(providers.len(), 2, "User should have two providers");
|
||||
assert!(providers.iter().any(|p| p.provider == "mock_initial"));
|
||||
assert!(providers.iter().any(|p| p.provider == "mock_link"));
|
||||
}
|
||||
|
||||
/// Test new user registration
|
||||
#[tokio::test]
|
||||
async fn test_new_user_registration() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "123".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
name: None,
|
||||
email: Some("newuser@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
|
||||
// TODO: Test that the OAuth callback handler creates a new user account when no existing user is found
|
||||
// TODO: Test that the OAuth callback handler requires an email address for all sign-ins
|
||||
// TODO: Test that the OAuth callback handler rejects sign-in attempts when no email is available
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Test that the OAuth callback handler creates a new user account when no existing user is found
|
||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 302);
|
||||
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
||||
let user = pacman_server::data::user::find_user_by_email(&context.app_state.db, "newuser@example.com")
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(user.email, Some("newuser@example.com".to_string()));
|
||||
}
|
||||
|
||||
/// Test existing user sign-in
|
||||
/// Test OAuth callback handler rejects sign-in attempts when no email is available
|
||||
#[tokio::test]
|
||||
async fn test_existing_user_sign_in() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
async fn test_oauth_callback_no_email() {
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "456".to_string(),
|
||||
username: "noemailuser".to_string(),
|
||||
name: None,
|
||||
email: None,
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
|
||||
// TODO: Test that the OAuth callback handler allows sign-in when the provider is already linked to an existing user
|
||||
// TODO: Test that the OAuth callback handler requires explicit linking when a user with the same email exists and has other providers linked
|
||||
// TODO: Test that the OAuth callback handler auto-links a provider when a user exists but has no other providers linked
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Test that the OAuth callback handler rejects sign-in attempts when no email is available
|
||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 400);
|
||||
}
|
||||
|
||||
/// Test avatar processing
|
||||
/// Test existing user sign-in with new provider fails
|
||||
#[tokio::test]
|
||||
async fn test_avatar_processing() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
async fn test_existing_user_sign_in_new_provider_fails() {
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback().returning(move |_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "456".to_string(), // Different provider ID
|
||||
username: "existinguser_newprovider".to_string(),
|
||||
name: None,
|
||||
email: Some("existing@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock_new", provider)]),
|
||||
};
|
||||
|
||||
// TODO: Test that the OAuth callback handler processes user avatars asynchronously without blocking the response
|
||||
// TODO: Test that the OAuth callback handler handles avatar processing errors gracefully
|
||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Create a user with one linked provider
|
||||
pacman_server::data::user::create_user(
|
||||
&context.app_state.db,
|
||||
"existinguser",
|
||||
None,
|
||||
Some("existing@example.com"),
|
||||
None,
|
||||
"mock",
|
||||
"123",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// A user with the email exists, but has one provider. If they sign in with a *new* provider, it should fail.
|
||||
let response = context.server.get("/auth/mock_new/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 400); // Should fail and ask to link explicitly.
|
||||
}
|
||||
|
||||
/// Test profile access
|
||||
/// Test existing user sign-in with existing provider succeeds
|
||||
#[tokio::test]
|
||||
async fn test_profile_access() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
async fn test_existing_user_sign_in_existing_provider_succeeds() {
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback().returning(move |_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "123".to_string(), // Same provider ID as created user
|
||||
username: "existinguser".to_string(),
|
||||
name: None,
|
||||
email: Some("existing@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
// TODO: Test that the profile handler returns user information when a valid session exists
|
||||
// TODO: Test that the profile handler returns an error when no session cookie is present
|
||||
// TODO: Test that the profile handler returns an error when the session token is invalid
|
||||
// TODO: Test that the profile handler includes linked providers in the response
|
||||
// TODO: Test that the profile handler returns an error when the user is not found in the database
|
||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Create a user with one linked provider
|
||||
pacman_server::data::user::create_user(
|
||||
&context.app_state.db,
|
||||
"existinguser",
|
||||
None,
|
||||
Some("existing@example.com"),
|
||||
None,
|
||||
"mock",
|
||||
"123",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test sign-in with an existing linked provider.
|
||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 302); // Should sign in successfully
|
||||
assert_eq!(response.headers().get("location").unwrap(), "/profile");
|
||||
}
|
||||
|
||||
/// Test logout functionality
|
||||
#[tokio::test]
|
||||
async fn test_logout_functionality() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
let mut mock = MockOAuthProvider::new();
|
||||
mock.expect_handle_callback().returning(|_, _, _, _| {
|
||||
Ok(pacman_server::auth::provider::AuthUser {
|
||||
id: "123".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
name: None,
|
||||
email: Some("test@example.com".to_string()),
|
||||
avatar_url: None,
|
||||
})
|
||||
});
|
||||
let provider: Arc<dyn OAuthProvider> = Arc::new(mock);
|
||||
let mock_registry = AuthRegistry {
|
||||
providers: HashMap::from([("mock", provider)]),
|
||||
};
|
||||
|
||||
// TODO: Test that the logout handler clears the session if a session was there
|
||||
// TODO: Test that the logout handler removes the session from memory storage
|
||||
// TODO: Test that the logout handler clears the session cookie
|
||||
// TODO: Test that the logout handler redirects to the home page after logout
|
||||
}
|
||||
|
||||
/// Test provider configuration
|
||||
#[tokio::test]
|
||||
async fn test_provider_configuration() {
|
||||
let TestContext { server: _server, .. } = test_context().use_database(false).call().await;
|
||||
|
||||
// TODO: Test that the providers list handler returns all configured OAuth providers
|
||||
// TODO: Test that the providers list handler includes provider status (active/inactive)
|
||||
let context = test_context().use_database(true).auth_registry(mock_registry).call().await;
|
||||
|
||||
// Sign in to establish a session
|
||||
let response = context.server.get("/auth/mock/callback?code=a&state=b").await;
|
||||
assert_eq!(response.status_code(), 302);
|
||||
let session_cookie = response.cookie("session").clone();
|
||||
|
||||
// Test that the logout handler clears the session cookie and redirects
|
||||
let response = context
|
||||
.server
|
||||
.get("/logout")
|
||||
.add_cookie(cookie::Cookie::new(
|
||||
session_cookie.name().to_string(),
|
||||
session_cookie.value().to_string(),
|
||||
))
|
||||
.await;
|
||||
|
||||
// Redirect assertions
|
||||
assert_eq!(response.status_code(), 302);
|
||||
assert!(
|
||||
response.headers().contains_key("location"),
|
||||
"Response redirect should have a location header"
|
||||
);
|
||||
|
||||
// Cookie assertions
|
||||
assert_eq!(
|
||||
response.cookie("session").value(),
|
||||
"removed",
|
||||
"Session cookie should be removed"
|
||||
);
|
||||
assert_eq!(
|
||||
response.cookie("session").max_age(),
|
||||
Some(Duration::ZERO),
|
||||
"Session cookie should have a max age of 0"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user