mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-07 13:15:54 -06:00
feat: setup github provider with generic trait, proper routes, session & jwt handling, errors & user agent
This commit is contained in:
88
Cargo.lock
generated
88
Cargo.lock
generated
@@ -87,6 +87,17 @@ dependencies = [
|
||||
"portable-atomic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atoi"
|
||||
version = "2.0.0"
|
||||
@@ -127,6 +138,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
|
||||
dependencies = [
|
||||
"axum-core",
|
||||
"axum-macros",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
@@ -154,6 +166,19 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-cookie"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd42d7564a7d36703869e6191aa2dc29d72c9158501035da83ad934eca2e7ca1"
|
||||
dependencies = [
|
||||
"axum-core",
|
||||
"cookie-rs",
|
||||
"http",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.5.2"
|
||||
@@ -174,6 +199,17 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-macros"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.75"
|
||||
@@ -263,7 +299,7 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"critical-section",
|
||||
"foldhash",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.5",
|
||||
"portable-atomic",
|
||||
"portable-atomic-util",
|
||||
"serde",
|
||||
@@ -454,6 +490,12 @@ version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
|
||||
|
||||
[[package]]
|
||||
name = "cookie-rs"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69deedbff42d34777a2c86794e57629898a686d84efcc28dfa8fbb515c8884ce"
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -525,6 +567,20 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "6.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deprecate-until"
|
||||
version = "0.1.1"
|
||||
@@ -941,6 +997,12 @@ dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
@@ -959,7 +1021,7 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
|
||||
dependencies = [
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1278,7 +1340,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92119844f513ffa41556430369ab02c295a3578af21cf945caa3e9e0c2481ac3"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1339,6 +1401,19 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonwebtoken"
|
||||
version = "9.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"js-sys",
|
||||
"ring",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
@@ -1736,10 +1811,13 @@ version = "0.1.1"
|
||||
name = "pacman-server"
|
||||
version = "0.1.1"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum",
|
||||
"chrono",
|
||||
"axum-cookie",
|
||||
"dashmap",
|
||||
"dotenvy",
|
||||
"figment",
|
||||
"jsonwebtoken",
|
||||
"oauth2",
|
||||
"reqwest",
|
||||
"serde",
|
||||
@@ -2630,7 +2708,7 @@ dependencies = [
|
||||
"futures-intrusive",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.5",
|
||||
"hashlink",
|
||||
"indexmap",
|
||||
"log",
|
||||
|
||||
@@ -13,8 +13,6 @@ keywords = ["game", "pacman", "arcade", "sdl2"]
|
||||
categories = ["games", "emulators"]
|
||||
publish = false
|
||||
|
||||
[workspace.dependencies]
|
||||
|
||||
[profile.dev]
|
||||
incremental = true
|
||||
|
||||
|
||||
@@ -15,13 +15,12 @@ publish.workspace = true
|
||||
default-run = "pacman-server"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.8"
|
||||
axum = { version = "0.8", features = ["macros"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
oauth2 = "5"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
sqlx = { version = "0.8", features = [
|
||||
"runtime-tokio-rustls",
|
||||
"postgres",
|
||||
@@ -29,9 +28,8 @@ sqlx = { version = "0.8", features = [
|
||||
] }
|
||||
figment = { version = "0.10", features = ["env"] }
|
||||
dotenvy = "0.15"
|
||||
|
||||
# Validation
|
||||
dashmap = "6.1"
|
||||
axum-cookie = "0.2"
|
||||
async-trait = "0.1"
|
||||
jsonwebtoken = { version = "9.3", default-features = false }
|
||||
# validator = { version = "0.16", features = ["derive"] }
|
||||
|
||||
# JWT for internal sessions
|
||||
# jsonwebtoken = "8.3"
|
||||
|
||||
28
pacman-server/src/app.rs
Normal file
28
pacman-server/src/app.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use dashmap::DashMap;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{auth::AuthRegistry, config::Config};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub config: Arc<Config>,
|
||||
pub auth: Arc<AuthRegistry>,
|
||||
pub sessions: Arc<DashMap<String, crate::auth::provider::AuthUser>>,
|
||||
pub jwt_encoding_key: Arc<EncodingKey>,
|
||||
pub jwt_decoding_key: Arc<DecodingKey>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(config: Config, auth: AuthRegistry) -> Self {
|
||||
let jwt_secret = config.jwt_secret.clone();
|
||||
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
auth: Arc::new(auth),
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())),
|
||||
jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())),
|
||||
}
|
||||
}
|
||||
}
|
||||
189
pacman-server/src/auth/github.rs
Normal file
189
pacman-server/src/auth/github.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use axum::{response::IntoResponse, response::Redirect};
|
||||
use dashmap::DashMap;
|
||||
use oauth2::{basic::BasicClient, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::{
|
||||
auth::provider::{AuthUser, OAuthProvider},
|
||||
errors::ErrorResponse,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GitHubUser {
|
||||
pub id: u64,
|
||||
pub login: String,
|
||||
pub name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub avatar_url: String,
|
||||
pub html_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GitHubEmail {
|
||||
pub email: String,
|
||||
pub primary: bool,
|
||||
pub verified: bool,
|
||||
pub visibility: Option<String>,
|
||||
}
|
||||
|
||||
/// Fetch user information from GitHub API
|
||||
pub async fn fetch_github_user(
|
||||
http_client: &reqwest::Client,
|
||||
access_token: &str,
|
||||
) -> Result<GitHubUser, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = http_client
|
||||
.get("https://api.github.com/user")
|
||||
.header("Authorization", format!("Bearer {}", access_token))
|
||||
.header("Accept", "application/vnd.github.v3+json")
|
||||
.header("User-Agent", crate::config::USER_AGENT)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("GitHub API error: {}", response.status()).into());
|
||||
}
|
||||
|
||||
let user: GitHubUser = response.json().await?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub struct GitHubProvider {
|
||||
pub client: BasicClient<
|
||||
oauth2::EndpointSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointSet,
|
||||
>,
|
||||
pub http: reqwest::Client,
|
||||
pkce: DashMap<String, (String, Instant)>,
|
||||
}
|
||||
|
||||
impl GitHubProvider {
|
||||
pub fn new(
|
||||
client: BasicClient<
|
||||
oauth2::EndpointSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointSet,
|
||||
>,
|
||||
http: reqwest::Client,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
client,
|
||||
http,
|
||||
pkce: DashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OAuthProvider for GitHubProvider {
|
||||
fn id(&self) -> &'static str {
|
||||
"github"
|
||||
}
|
||||
fn label(&self) -> &'static str {
|
||||
"GitHub"
|
||||
}
|
||||
|
||||
async fn authorize(&self) -> axum::response::Response {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let (authorize_url, csrf_state) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.add_scope(Scope::new("user:email".to_string()))
|
||||
.add_scope(Scope::new("read:user".to_string()))
|
||||
.url();
|
||||
// Insert PKCE verifier with timestamp and purge stale entries
|
||||
let now = Instant::now();
|
||||
self.pkce
|
||||
.insert(csrf_state.secret().to_string(), (pkce_verifier.secret().to_string(), now));
|
||||
// Best-effort cleanup to avoid unbounded growth
|
||||
const PKCE_TTL: Duration = Duration::from_secs(5 * 60);
|
||||
for entry in self.pkce.iter() {
|
||||
if now.duration_since(entry.value().1) > PKCE_TTL {
|
||||
self.pkce.remove(entry.key());
|
||||
}
|
||||
}
|
||||
Redirect::to(authorize_url.as_str()).into_response()
|
||||
}
|
||||
|
||||
async fn handle_callback(&self, query: &std::collections::HashMap<String, String>) -> Result<AuthUser, ErrorResponse> {
|
||||
if let Some(err) = query.get("error") {
|
||||
return Err(ErrorResponse::bad_request(
|
||||
err.clone(),
|
||||
query.get("error_description").cloned(),
|
||||
));
|
||||
}
|
||||
let code = query
|
||||
.get("code")
|
||||
.cloned()
|
||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing code".into())))?;
|
||||
let state = query
|
||||
.get("state")
|
||||
.cloned()
|
||||
.ok_or_else(|| ErrorResponse::bad_request("invalid_request", Some("missing state".into())))?;
|
||||
let Some((verifier, created_at)) = self.pkce.remove(&state).map(|e| e.1) else {
|
||||
return Err(ErrorResponse::bad_request(
|
||||
"invalid_request",
|
||||
Some("missing pkce verifier for state".into()),
|
||||
));
|
||||
};
|
||||
// Verify PKCE TTL
|
||||
if Instant::now().duration_since(created_at) > Duration::from_secs(5 * 60) {
|
||||
return Err(ErrorResponse::bad_request(
|
||||
"invalid_request",
|
||||
Some("expired pkce verifier for state".into()),
|
||||
));
|
||||
}
|
||||
|
||||
let token = self
|
||||
.client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(PkceCodeVerifier::new(verifier))
|
||||
.request_async(&self.http)
|
||||
.await
|
||||
.map_err(|e| ErrorResponse::bad_gateway("token_exchange_failed", Some(e.to_string())))?;
|
||||
|
||||
let user = fetch_github_user(&self.http, token.access_token().secret())
|
||||
.await
|
||||
.map_err(|e| ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch user: {}", e))))?;
|
||||
let _emails = fetch_github_emails(&self.http, token.access_token().secret())
|
||||
.await
|
||||
.map_err(|e| ErrorResponse::bad_gateway("github_api_error", Some(format!("failed to fetch emails: {}", e))))?;
|
||||
|
||||
Ok(AuthUser {
|
||||
id: user.id.to_string(),
|
||||
username: user.login,
|
||||
name: user.name,
|
||||
email: user.email,
|
||||
avatar_url: Some(user.avatar_url),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch user emails from GitHub API
|
||||
pub async fn fetch_github_emails(
|
||||
http_client: &reqwest::Client,
|
||||
access_token: &str,
|
||||
) -> Result<Vec<GitHubEmail>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = http_client
|
||||
.get("https://api.github.com/user/emails")
|
||||
.header("Authorization", format!("Bearer {}", access_token))
|
||||
.header("Accept", "application/vnd.github.v3+json")
|
||||
.header("User-Agent", crate::config::USER_AGENT)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("GitHub API error: {}", response.status()).into());
|
||||
}
|
||||
|
||||
let emails: Vec<GitHubEmail> = response.json().await?;
|
||||
Ok(emails)
|
||||
}
|
||||
47
pacman-server/src/auth/mod.rs
Normal file
47
pacman-server/src/auth/mod.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use oauth2::{basic::BasicClient, EndpointNotSet, EndpointSet};
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
pub mod github;
|
||||
pub mod provider;
|
||||
|
||||
pub struct AuthRegistry {
|
||||
providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>>,
|
||||
}
|
||||
|
||||
impl AuthRegistry {
|
||||
pub fn new(config: &Config) -> Result<Self, oauth2::url::ParseError> {
|
||||
let http = reqwest::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("HTTP client should build");
|
||||
|
||||
let github_client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet> =
|
||||
BasicClient::new(oauth2::ClientId::new(config.github_client_id.clone()))
|
||||
.set_client_secret(oauth2::ClientSecret::new(config.github_client_secret.clone()))
|
||||
.set_auth_uri(oauth2::AuthUrl::new("https://github.com/login/oauth/authorize".to_string())?)
|
||||
.set_token_uri(oauth2::TokenUrl::new(
|
||||
"https://github.com/login/oauth/access_token".to_string(),
|
||||
)?)
|
||||
.set_redirect_uri(
|
||||
oauth2::RedirectUrl::new(format!("{}/auth/github/callback", config.public_base_url))
|
||||
.expect("Invalid redirect URI"),
|
||||
);
|
||||
|
||||
let mut providers: HashMap<&'static str, Arc<dyn provider::OAuthProvider>> = HashMap::new();
|
||||
providers.insert("github", github::GitHubProvider::new(github_client, http));
|
||||
|
||||
Ok(Self { providers })
|
||||
}
|
||||
|
||||
pub fn get(&self, id: &str) -> Option<&Arc<dyn provider::OAuthProvider>> {
|
||||
self.providers.get(id)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&'static str, &Arc<dyn provider::OAuthProvider>)> {
|
||||
self.providers.iter().map(|(k, v)| (*k, v))
|
||||
}
|
||||
}
|
||||
25
pacman-server/src/auth/provider.rs
Normal file
25
pacman-server/src/auth/provider.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::errors::ErrorResponse;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AuthUser {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub avatar_url: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuthProvider: Send + Sync {
|
||||
fn id(&self) -> &'static str;
|
||||
fn label(&self) -> &'static str;
|
||||
|
||||
async fn authorize(&self) -> axum::response::Response;
|
||||
|
||||
async fn handle_callback(&self, query: &HashMap<String, String>) -> Result<AuthUser, ErrorResponse>;
|
||||
}
|
||||
@@ -25,8 +25,20 @@ pub struct Config {
|
||||
pub host: std::net::IpAddr,
|
||||
#[serde(default = "default_shutdown_timeout")]
|
||||
pub shutdown_timeout_seconds: u32,
|
||||
// Public base URL used for OAuth redirect URIs
|
||||
pub public_base_url: String,
|
||||
// JWT
|
||||
pub jwt_secret: String,
|
||||
}
|
||||
|
||||
// Standard User-Agent: name/version (+site)
|
||||
pub const USER_AGENT: &str = concat!(
|
||||
env!("CARGO_PKG_NAME"),
|
||||
"/",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
" (+https://pacman.xevion.dev)"
|
||||
);
|
||||
|
||||
fn default_host() -> std::net::IpAddr {
|
||||
"0.0.0.0".parse().unwrap()
|
||||
}
|
||||
@@ -57,7 +69,7 @@ pub fn load_config() -> Config {
|
||||
Figment::new()
|
||||
.merge(Env::raw().map(|key| {
|
||||
if key == UncasedStr::new("RAILWAY_DEPLOYMENT_DRAINING_SECONDS") {
|
||||
"SHUTDOWN_TIMEOUT".into()
|
||||
"SHUTDOWN_TIMEOUT_SECONDS".into()
|
||||
} else {
|
||||
key.into()
|
||||
}
|
||||
|
||||
55
pacman-server/src/errors.rs
Normal file
55
pacman-server/src/errors.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
#[serde(skip_serializing)]
|
||||
status_code: Option<StatusCode>,
|
||||
pub error: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub fn status_code(&self) -> StatusCode {
|
||||
self.status_code.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
pub fn unauthorized(description: impl Into<String>) -> Self {
|
||||
Self {
|
||||
status_code: Some(StatusCode::UNAUTHORIZED),
|
||||
error: "unauthorized".into(),
|
||||
description: Some(description.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bad_request(error: impl Into<String>, description: impl Into<Option<String>>) -> Self {
|
||||
Self {
|
||||
status_code: Some(StatusCode::BAD_REQUEST),
|
||||
error: error.into(),
|
||||
description: description.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bad_gateway(error: impl Into<String>, description: impl Into<Option<String>>) -> Self {
|
||||
Self {
|
||||
status_code: Some(StatusCode::BAD_GATEWAY),
|
||||
error: error.into(),
|
||||
description: description.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_status(status: StatusCode, error: impl Into<String>, description: impl Into<Option<String>>) -> Self {
|
||||
Self {
|
||||
status_code: Some(status),
|
||||
error: error.into(),
|
||||
description: description.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ErrorResponse {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
(self.status_code(), Json(self)).into_response()
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,38 @@
|
||||
use axum::Router;
|
||||
use axum::{routing::get, Router};
|
||||
use axum_cookie::CookieLayer;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::{app::AppState, auth::AuthRegistry, config::Config};
|
||||
mod routes;
|
||||
|
||||
mod app;
|
||||
mod auth;
|
||||
mod config;
|
||||
mod errors;
|
||||
mod session;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Load environment variables
|
||||
#[cfg(debug_assertions)]
|
||||
dotenvy::from_path(format!("{}.env", env!("CARGO_MANIFEST_DIR"))).ok();
|
||||
dotenvy::from_path(std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(".env")).ok();
|
||||
#[cfg(not(debug_assertions))]
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// Load configuration
|
||||
let config: Config = config::load_config();
|
||||
|
||||
let app = Router::new().fallback(|| async { "Hello, World!" });
|
||||
|
||||
let addr = std::net::SocketAddr::new(config.host, config.port);
|
||||
let auth = AuthRegistry::new(&config).expect("auth initializer");
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." }))
|
||||
.route("/auth/{provider}", get(routes::oauth_authorize_handler))
|
||||
.route("/auth/{provider}/callback", get(routes::oauth_callback_handler))
|
||||
.route("/logout", get(routes::logout_handler))
|
||||
.route("/profile", get(routes::profile_handler))
|
||||
.with_state(AppState::new(config, auth))
|
||||
.layer(CookieLayer::default());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
77
pacman-server/src/routes.rs
Normal file
77
pacman-server/src/routes.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_cookie::CookieManager;
|
||||
|
||||
use crate::{app::AppState, errors::ErrorResponse, session};
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct AuthQuery {
|
||||
pub code: Option<String>,
|
||||
pub state: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub error_description: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn oauth_authorize_handler(
|
||||
State(app_state): State<AppState>,
|
||||
Path(provider): Path<String>,
|
||||
) -> axum::response::Response {
|
||||
let Some(prov) = app_state.auth.get(&provider) else {
|
||||
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
||||
};
|
||||
prov.authorize().await
|
||||
}
|
||||
|
||||
pub async fn oauth_callback_handler(
|
||||
State(app_state): State<AppState>,
|
||||
Path(provider): Path<String>,
|
||||
Query(params): Query<AuthQuery>,
|
||||
cookie: CookieManager,
|
||||
) -> axum::response::Response {
|
||||
let Some(prov) = app_state.auth.get(&provider) else {
|
||||
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
||||
};
|
||||
if let Some(error) = params.error {
|
||||
return ErrorResponse::bad_request(error, params.error_description).into_response();
|
||||
}
|
||||
let mut q = std::collections::HashMap::new();
|
||||
if let Some(v) = params.code {
|
||||
q.insert("code".to_string(), v);
|
||||
}
|
||||
if let Some(v) = params.state {
|
||||
q.insert("state".to_string(), v);
|
||||
}
|
||||
let user = match prov.handle_callback(&q).await {
|
||||
Ok(u) => u,
|
||||
Err(e) => return e.into_response(),
|
||||
};
|
||||
let session_token = session::create_jwt_for_user(&user, &app_state.jwt_encoding_key);
|
||||
app_state.sessions.insert(session_token.clone(), user);
|
||||
session::set_session_cookie(&cookie, &session_token);
|
||||
(StatusCode::FOUND, Redirect::to("/profile")).into_response()
|
||||
}
|
||||
|
||||
pub async fn profile_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
|
||||
let Some(token_str) = session::get_session_token(&cookie) else {
|
||||
return ErrorResponse::unauthorized("missing session cookie").into_response();
|
||||
};
|
||||
if !session::verify_jwt(&token_str, &app_state.jwt_decoding_key) {
|
||||
return ErrorResponse::unauthorized("invalid session token").into_response();
|
||||
}
|
||||
if let Some(user) = app_state.sessions.get(&token_str) {
|
||||
return axum::Json(&*user).into_response();
|
||||
}
|
||||
ErrorResponse::unauthorized("session not found").into_response()
|
||||
}
|
||||
|
||||
pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
|
||||
if let Some(token_str) = session::get_session_token(&cookie) {
|
||||
// Remove from in-memory sessions if present
|
||||
app_state.sessions.remove(&token_str);
|
||||
}
|
||||
session::clear_session_cookie(&cookie);
|
||||
(StatusCode::FOUND, Redirect::to("/")).into_response()
|
||||
}
|
||||
62
pacman-server/src/session.rs
Normal file
62
pacman-server/src/session.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use axum_cookie::{cookie::Cookie, prelude::SameSite, CookieManager};
|
||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||
|
||||
use crate::auth::provider::AuthUser;
|
||||
|
||||
pub const SESSION_COOKIE_NAME: &str = "session";
|
||||
pub const JWT_TTL_SECS: u64 = 60 * 60; // 1 hour
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub name: Option<String>,
|
||||
pub iat: usize,
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time went backwards")
|
||||
.as_secs() as usize;
|
||||
let claims = Claims {
|
||||
sub: user.username.clone(),
|
||||
name: user.name.clone(),
|
||||
iat: now,
|
||||
exp: now + JWT_TTL_SECS as usize,
|
||||
};
|
||||
encode(&Header::new(Algorithm::HS256), &claims, encoding_key).expect("jwt sign")
|
||||
}
|
||||
|
||||
pub fn verify_jwt(token: &str, decoding_key: &DecodingKey) -> bool {
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.leeway = 30;
|
||||
match decode::<Claims>(token, decoding_key, &validation) {
|
||||
Ok(_) => true,
|
||||
Err(e) => {
|
||||
println!("JWT verification failed: {:?}", e);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_session_cookie(cookie: &CookieManager, token: &str) {
|
||||
cookie.add(
|
||||
Cookie::builder(SESSION_COOKIE_NAME, token.to_string())
|
||||
.http_only(true)
|
||||
.secure(!cfg!(debug_assertions))
|
||||
.path("/")
|
||||
.same_site(SameSite::Lax)
|
||||
.build(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn clear_session_cookie(cookie: &CookieManager) {
|
||||
cookie.remove(SESSION_COOKIE_NAME);
|
||||
}
|
||||
|
||||
pub fn get_session_token(cookie: &CookieManager) -> Option<String> {
|
||||
cookie.get(SESSION_COOKIE_NAME).map(|c| c.value().to_string())
|
||||
}
|
||||
Reference in New Issue
Block a user