From 56e02e72539da3f45f3b4fd6ee358ad7820b009d Mon Sep 17 00:00:00 2001 From: Ryan Walters Date: Wed, 17 Sep 2025 13:18:58 -0500 Subject: [PATCH] refactor: remove unnecessary HashMap for passing code/state strings, formatter lifetime tweak --- pacman-server/src/auth/discord.rs | 21 +++------------------ pacman-server/src/auth/github.rs | 21 +++------------------ pacman-server/src/auth/provider.rs | 4 +--- pacman-server/src/formatter.rs | 3 +-- pacman-server/src/routes.rs | 23 +++++++++++++---------- 5 files changed, 21 insertions(+), 51 deletions(-) diff --git a/pacman-server/src/auth/discord.rs b/pacman-server/src/auth/discord.rs index 31ae7f8..5dc4a5a 100644 --- a/pacman-server/src/auth/discord.rs +++ b/pacman-server/src/auth/discord.rs @@ -85,23 +85,8 @@ impl OAuthProvider for DiscordProvider { Redirect::to(authorize_url.as_str()).into_response() } - async fn handle_callback(&self, query: &std::collections::HashMap) -> Result { - if let Some(err) = query.get("error") { - warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error"); - return Err(ErrorResponse::bad_request( - err.clone(), - query.get("error_description").cloned(), - )); - } - let code = query - .get("code") - .cloned() - .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?; - let state = query - .get("state") - .cloned() - .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?; - let Some(verifier) = self.pkce.take_verifier(&state) else { + async fn handle_callback(&self, code: &str, state: &str) -> Result { + let Some(verifier) = self.pkce.take_verifier(state) else { warn!(%state, "Missing or expired PKCE verifier for state parameter"); return Err(ErrorResponse::bad_request( "invalid_request", @@ -111,7 +96,7 @@ impl OAuthProvider for DiscordProvider { let token = self .client - .exchange_code(AuthorizationCode::new(code)) + .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(PkceCodeVerifier::new(verifier)) .request_async(&self.http) .await diff --git a/pacman-server/src/auth/github.rs b/pacman-server/src/auth/github.rs index ac76fcf..420f44d 100644 --- a/pacman-server/src/auth/github.rs +++ b/pacman-server/src/auth/github.rs @@ -93,23 +93,8 @@ impl OAuthProvider for GitHubProvider { Redirect::to(authorize_url.as_str()).into_response() } - async fn handle_callback(&self, query: &std::collections::HashMap) -> Result { - if let Some(err) = query.get("error") { - warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error"); - return Err(ErrorResponse::bad_request( - err.clone(), - query.get("error_description").cloned(), - )); - } - let code = query - .get("code") - .cloned() - .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?; - let state = query - .get("state") - .cloned() - .ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?; - let Some(verifier) = self.pkce.take_verifier(&state) else { + async fn handle_callback(&self, code: &str, state: &str) -> Result { + let Some(verifier) = self.pkce.take_verifier(state) else { warn!(%state, "Missing or expired PKCE verifier for state parameter"); return Err(ErrorResponse::bad_request( "invalid_request", @@ -119,7 +104,7 @@ impl OAuthProvider for GitHubProvider { let token = self .client - .exchange_code(AuthorizationCode::new(code)) + .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(PkceCodeVerifier::new(verifier)) .request_async(&self.http) .await diff --git a/pacman-server/src/auth/provider.rs b/pacman-server/src/auth/provider.rs index 4623c3b..5c3e032 100644 --- a/pacman-server/src/auth/provider.rs +++ b/pacman-server/src/auth/provider.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use async_trait::async_trait; use serde::Serialize; @@ -24,5 +22,5 @@ pub trait OAuthProvider: Send + Sync { async fn authorize(&self) -> axum::response::Response; - async fn handle_callback(&self, query: &HashMap) -> Result; + async fn handle_callback(&self, code: &str, state: &str) -> Result; } diff --git a/pacman-server/src/formatter.rs b/pacman-server/src/formatter.rs index 8f0c2c3..37993a9 100644 --- a/pacman-server/src/formatter.rs +++ b/pacman-server/src/formatter.rs @@ -112,8 +112,7 @@ where message: &'a mut Option, fields: &'a mut Map, } - - impl<'a> Visit for FieldVisitor<'a> { + impl Visit for FieldVisitor<'_> { fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) { let key = field.name(); if key == "message" { diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 677be73..7eb2e7f 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use axum::{ extract::{Path, Query, State}, http::StatusCode, @@ -72,26 +70,30 @@ pub async fn oauth_callback_handler( return ErrorResponse::bad_request(error, params.error_description).into_response(); } - let mut q = HashMap::new(); - if let Some(v) = params.code { - q.insert("code".to_string(), v); - } - if let Some(v) = params.state { - q.insert("state".to_string(), v); - } - let user = match prov.handle_callback(&q).await { + // Acquire required parameters + let Some(code) = params.code.as_deref() else { + return ErrorResponse::bad_request("invalid_request", Some("missing code".into())).into_response(); + }; + let Some(state) = params.state.as_deref() else { + return ErrorResponse::bad_request("invalid_request", Some("missing state".into())).into_response(); + }; + + // Handle callback from provider + let user = match prov.handle_callback(code, state).await { Ok(u) => u, Err(e) => { warn!(%provider, "OAuth callback handling failed"); return e.into_response(); } }; + // Linking or sign-in flow. Determine link intent from cookie (set at authorize time) let link_cookie = cookie.get("link").map(|c| c.value().to_string()); if link_cookie.is_some() { cookie.remove("link"); } let email = user.email.as_deref(); + // Determine linking intent with a valid session let is_link = if link_cookie.as_deref() == Some("1") { match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) { @@ -213,6 +215,7 @@ pub async fn oauth_callback_handler( } }; + // Create session token let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key); session::set_session_cookie(&cookie, &session_token); info!(%provider, "Signed in successfully");