mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-10 06:07:55 -06:00
refactor: remove unnecessary HashMap for passing code/state strings, formatter lifetime tweak
This commit is contained in:
@@ -85,23 +85,8 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
Redirect::to(authorize_url.as_str()).into_response()
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_callback(&self, query: &std::collections::HashMap<String, String>) -> Result<AuthUser, ErrorResponse> {
|
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
|
||||||
if let Some(err) = query.get("error") {
|
let Some(verifier) = self.pkce.take_verifier(state) else {
|
||||||
warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error");
|
|
||||||
return Err(ErrorResponse::bad_request(
|
|
||||||
err.clone(),
|
|
||||||
query.get("error_description").cloned(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
let code = query
|
|
||||||
.get("code")
|
|
||||||
.cloned()
|
|
||||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?;
|
|
||||||
let state = query
|
|
||||||
.get("state")
|
|
||||||
.cloned()
|
|
||||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
|
||||||
let Some(verifier) = self.pkce.take_verifier(&state) else {
|
|
||||||
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
||||||
return Err(ErrorResponse::bad_request(
|
return Err(ErrorResponse::bad_request(
|
||||||
"invalid_request",
|
"invalid_request",
|
||||||
@@ -111,7 +96,7 @@ impl OAuthProvider for DiscordProvider {
|
|||||||
|
|
||||||
let token = self
|
let token = self
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(code))
|
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
||||||
.request_async(&self.http)
|
.request_async(&self.http)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -93,23 +93,8 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
Redirect::to(authorize_url.as_str()).into_response()
|
Redirect::to(authorize_url.as_str()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_callback(&self, query: &std::collections::HashMap<String, String>) -> Result<AuthUser, ErrorResponse> {
|
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse> {
|
||||||
if let Some(err) = query.get("error") {
|
let Some(verifier) = self.pkce.take_verifier(state) else {
|
||||||
warn!(error = %err, desc = query.get("error_description").map(|s| s.as_str()), "OAuth callback contained an error");
|
|
||||||
return Err(ErrorResponse::bad_request(
|
|
||||||
err.clone(),
|
|
||||||
query.get("error_description").cloned(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
let code = query
|
|
||||||
.get("code")
|
|
||||||
.cloned()
|
|
||||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?;
|
|
||||||
let state = query
|
|
||||||
.get("state")
|
|
||||||
.cloned()
|
|
||||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
|
||||||
let Some(verifier) = self.pkce.take_verifier(&state) else {
|
|
||||||
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
warn!(%state, "Missing or expired PKCE verifier for state parameter");
|
||||||
return Err(ErrorResponse::bad_request(
|
return Err(ErrorResponse::bad_request(
|
||||||
"invalid_request",
|
"invalid_request",
|
||||||
@@ -119,7 +104,7 @@ impl OAuthProvider for GitHubProvider {
|
|||||||
|
|
||||||
let token = self
|
let token = self
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(code))
|
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
||||||
.request_async(&self.http)
|
.request_async(&self.http)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
@@ -24,5 +22,5 @@ pub trait OAuthProvider: Send + Sync {
|
|||||||
|
|
||||||
async fn authorize(&self) -> axum::response::Response;
|
async fn authorize(&self) -> axum::response::Response;
|
||||||
|
|
||||||
async fn handle_callback(&self, query: &HashMap<String, String>) -> Result<AuthUser, ErrorResponse>;
|
async fn handle_callback(&self, code: &str, state: &str) -> Result<AuthUser, ErrorResponse>;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,8 +112,7 @@ where
|
|||||||
message: &'a mut Option<String>,
|
message: &'a mut Option<String>,
|
||||||
fields: &'a mut Map<String, Value>,
|
fields: &'a mut Map<String, Value>,
|
||||||
}
|
}
|
||||||
|
impl Visit for FieldVisitor<'_> {
|
||||||
impl<'a> Visit for FieldVisitor<'a> {
|
|
||||||
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
|
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
|
||||||
let key = field.name();
|
let key = field.name();
|
||||||
if key == "message" {
|
if key == "message" {
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, Query, State},
|
extract::{Path, Query, State},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
@@ -72,26 +70,30 @@ pub async fn oauth_callback_handler(
|
|||||||
return ErrorResponse::bad_request(error, params.error_description).into_response();
|
return ErrorResponse::bad_request(error, params.error_description).into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut q = HashMap::new();
|
// Acquire required parameters
|
||||||
if let Some(v) = params.code {
|
let Some(code) = params.code.as_deref() else {
|
||||||
q.insert("code".to_string(), v);
|
return ErrorResponse::bad_request("invalid_request", Some("missing code".into())).into_response();
|
||||||
}
|
};
|
||||||
if let Some(v) = params.state {
|
let Some(state) = params.state.as_deref() else {
|
||||||
q.insert("state".to_string(), v);
|
return ErrorResponse::bad_request("invalid_request", Some("missing state".into())).into_response();
|
||||||
}
|
};
|
||||||
let user = match prov.handle_callback(&q).await {
|
|
||||||
|
// Handle callback from provider
|
||||||
|
let user = match prov.handle_callback(code, state).await {
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(%provider, "OAuth callback handling failed");
|
warn!(%provider, "OAuth callback handling failed");
|
||||||
return e.into_response();
|
return e.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
|
// Linking or sign-in flow. Determine link intent from cookie (set at authorize time)
|
||||||
let link_cookie = cookie.get("link").map(|c| c.value().to_string());
|
let link_cookie = cookie.get("link").map(|c| c.value().to_string());
|
||||||
if link_cookie.is_some() {
|
if link_cookie.is_some() {
|
||||||
cookie.remove("link");
|
cookie.remove("link");
|
||||||
}
|
}
|
||||||
let email = user.email.as_deref();
|
let email = user.email.as_deref();
|
||||||
|
|
||||||
// Determine linking intent with a valid session
|
// Determine linking intent with a valid session
|
||||||
let is_link = if link_cookie.as_deref() == Some("1") {
|
let is_link = if link_cookie.as_deref() == Some("1") {
|
||||||
match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) {
|
match session::get_session_token(&cookie).and_then(|t| session::decode_jwt(&t, &app_state.jwt_decoding_key)) {
|
||||||
@@ -213,6 +215,7 @@ pub async fn oauth_callback_handler(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create session token
|
||||||
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
||||||
session::set_session_cookie(&cookie, &session_token);
|
session::set_session_cookie(&cookie, &session_token);
|
||||||
info!(%provider, "Signed in successfully");
|
info!(%provider, "Signed in successfully");
|
||||||
|
|||||||
Reference in New Issue
Block a user