refactor: remove unnecessary HashMap for passing code/state strings, formatter lifetime tweak

This commit is contained in:
Ryan Walters
2025-09-17 13:18:58 -05:00
parent e2f3f6790f
commit 56e02e7253
5 changed files with 21 additions and 51 deletions

View File

@@ -85,23 +85,8 @@ impl OAuthProvider for DiscordProvider {
Redirect::to(authorize_url.as_str()).into_response() Redirect::to(authorize_url.as_str()).into_response()
} }
async fn handle_callback(&self, query: &std::collections::HashMap<String, String>) -> Result<AuthUser, ErrorResponse> { async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
if let Some(err) = query.get("error") { let Some(verifier) = self.pkce.take_verifier(state) else {
warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error");
return Err(ErrorResponse::bad_request(
err.clone(),
query.get("error_description").cloned(),
));
}
let code = query
.get("code")
.cloned()
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?;
let state = query
.get("state")
.cloned()
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
let Some(verifier) = self.pkce.take_verifier(&state) else {
warn!(%state, "Missing or expired PKCE verifier for state parameter"); warn!(%state, "Missing or expired PKCE verifier for state parameter");
return Err(ErrorResponse::bad_request( return Err(ErrorResponse::bad_request(
"invalid_request", "invalid_request",
@@ -111,7 +96,7 @@ impl OAuthProvider for DiscordProvider {
let token = self let token = self
.client .client
.exchange_code(AuthorizationCode::new(code)) .exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(PkceCodeVerifier::new(verifier)) .set_pkce_verifier(PkceCodeVerifier::new(verifier))
.request_async(&self.http) .request_async(&self.http)
.await .await

View File

@@ -93,23 +93,8 @@ impl OAuthProvider for GitHubProvider {
Redirect::to(authorize_url.as_str()).into_response() Redirect::to(authorize_url.as_str()).into_response()
} }
async fn handle_callback(&self, query: &std::collections::HashMap<String, String>) -> Result<AuthUser, ErrorResponse> { async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
if let Some(err) = query.get("error") { let Some(verifier) = self.pkce.take_verifier(state) else {
warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error");
return Err(ErrorResponse::bad_request(
err.clone(),
query.get("error_description").cloned(),
));
}
let code = query
.get("code")
.cloned()
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?;
let state = query
.get("state")
.cloned()
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
let Some(verifier) = self.pkce.take_verifier(&state) else {
warn!(%state, "Missing or expired PKCE verifier for state parameter"); warn!(%state, "Missing or expired PKCE verifier for state parameter");
return Err(ErrorResponse::bad_request( return Err(ErrorResponse::bad_request(
"invalid_request", "invalid_request",
@@ -119,7 +104,7 @@ impl OAuthProvider for GitHubProvider {
let token = self let token = self
.client .client
.exchange_code(AuthorizationCode::new(code)) .exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(PkceCodeVerifier::new(verifier)) .set_pkce_verifier(PkceCodeVerifier::new(verifier))
.request_async(&self.http) .request_async(&self.http)
.await .await

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
use serde::Serialize; use serde::Serialize;
@@ -24,5 +22,5 @@ pub trait OAuthProvider: Send + Sync {
async fn authorize(&self) -> axum::response::Response; async fn authorize(&self) -> axum::response::Response;
async fn handle_callback(&self, query: &HashMap<String, String>) -> Result<AuthUser, ErrorResponse>; async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse>;
} }

View File

@@ -112,8 +112,7 @@ where
message: &'a mut Option<String>, message: &'a mut Option<String>,
fields: &'a mut Map<String, Value>, fields: &'a mut Map<String, Value>,
} }
impl Visit for FieldVisitor<'_> {
impl<'a> Visit for FieldVisitor<'a> {
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) { fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
let key = field.name(); let key = field.name();
if key == "message" { if key == "message" {

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use axum::{ use axum::{
extract::{Path, Query, State}, extract::{Path, Query, State},
http::StatusCode, http::StatusCode,
@@ -72,26 +70,30 @@ pub async fn oauth_callback_handler(
return ErrorResponse::bad_request(error, params.error_description).into_response(); return ErrorResponse::bad_request(error, params.error_description).into_response();
} }
let mut q = HashMap::new(); // Acquire required parameters
if let Some(v) = params.code { let Some(code) = params.code.as_deref() else {
q.insert("code".to_string(), v); return ErrorResponse::bad_request("invalid_request", Some("missing code".into())).into_response();
} };
if let Some(v) = params.state { let Some(state) = params.state.as_deref() else {
q.insert("state".to_string(), v); return ErrorResponse::bad_request("invalid_request", Some("missing state".into())).into_response();
} };
let user = match prov.handle_callback(&q).await {
// Handle callback from provider
let user = match prov.handle_callback(code, state).await {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
warn!(%provider, "OAuth callback handling failed"); warn!(%provider, "OAuth callback handling failed");
return e.into_response(); return e.into_response();
} }
}; };
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time) // Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
let link_cookie = cookie.get("link").map(|c| c.value().to_string()); let link_cookie = cookie.get("link").map(|c| c.value().to_string());
if link_cookie.is_some() { if link_cookie.is_some() {
cookie.remove("link"); cookie.remove("link");
} }
let email = user.email.as_deref(); let email = user.email.as_deref();
// Determine linking intent with a valid session // Determine linking intent with a valid session
let is_link = if link_cookie.as_deref() == Some("1") { let is_link = if link_cookie.as_deref() == Some("1") {
match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) { match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) {
@@ -213,6 +215,7 @@ pub async fn oauth_callback_handler(
} }
}; };
// Create session token
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key); let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
session::set_session_cookie(&cookie, &session_token); session::set_session_cookie(&cookie, &session_token);
info!(%provider, "Signed in successfully"); info!(%provider, "Signed in successfully");