diff --git a/Cargo.lock b/Cargo.lock index bbbd760..03c8a8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1814,6 +1814,7 @@ dependencies = [ "async-trait", "axum", "axum-cookie", + "chrono", "dashmap", "dotenvy", "figment", diff --git a/pacman-server/Cargo.toml b/pacman-server/Cargo.toml index be33d7c..789fb62 100644 --- a/pacman-server/Cargo.toml +++ b/pacman-server/Cargo.toml @@ -26,6 +26,7 @@ sqlx = { version = "0.8", features = [ "postgres", "chrono", ] } +chrono = { version = "0.4", features = ["serde", "clock"] } figment = { version = "0.10", features = ["env"] } dotenvy = "0.15" dashmap = "6.1" diff --git a/pacman-server/migrations/20240917120000_init_users.sql b/pacman-server/migrations/20240917120000_init_users.sql new file mode 100644 index 0000000..cc06314 --- /dev/null +++ b/pacman-server/migrations/20240917120000_init_users.sql @@ -0,0 +1,15 @@ +-- users table +CREATE TABLE IF NOT EXISTS users ( + id BIGSERIAL PRIMARY KEY, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + username TEXT NOT NULL, + display_name TEXT NULL, + email TEXT NULL, + avatar_url TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (provider, provider_user_id) +); + +CREATE INDEX IF NOT EXISTS idx_users_provider ON users (provider, provider_user_id); diff --git a/pacman-server/src/app.rs b/pacman-server/src/app.rs index 0edf9f5..85c454f 100644 --- a/pacman-server/src/app.rs +++ b/pacman-server/src/app.rs @@ -2,6 +2,7 @@ use dashmap::DashMap; use jsonwebtoken::{DecodingKey, EncodingKey}; use std::sync::Arc; +use crate::data::pool::PgPool; use crate::{auth::AuthRegistry, config::Config}; #[derive(Clone)] @@ -11,10 +12,11 @@ pub struct AppState { pub sessions: Arc>, pub jwt_encoding_key: Arc, pub jwt_decoding_key: Arc, + pub db: Arc, } impl AppState { - pub fn new(config: Config, auth: AuthRegistry) -> Self { + pub fn new(config: Config, auth: AuthRegistry, db: PgPool) -> Self { let jwt_secret = config.jwt_secret.clone(); Self { @@ -23,6 +25,7 @@ impl AppState { sessions: Arc::new(DashMap::new()), jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())), jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())), + db: Arc::new(db), } } } diff --git a/pacman-server/src/auth/mod.rs b/pacman-server/src/auth/mod.rs index bfcc94f..149ff42 100644 --- a/pacman-server/src/auth/mod.rs +++ b/pacman-server/src/auth/mod.rs @@ -57,7 +57,7 @@ impl AuthRegistry { self.providers.get(id) } - pub fn iter(&self) -> impl Iterator)> { - self.providers.iter().map(|(k, v)| (*k, v)) + pub fn values(&self) -> impl Iterator> { + self.providers.values() } } diff --git a/pacman-server/src/data/mod.rs b/pacman-server/src/data/mod.rs new file mode 100644 index 0000000..59e8749 --- /dev/null +++ b/pacman-server/src/data/mod.rs @@ -0,0 +1,2 @@ +pub mod pool; +pub mod user; diff --git a/pacman-server/src/data/pool.rs b/pacman-server/src/data/pool.rs new file mode 100644 index 0000000..e33e35d --- /dev/null +++ b/pacman-server/src/data/pool.rs @@ -0,0 +1,16 @@ +use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; +use tracing::{info, warn}; + +pub type PgPool = Pool; + +pub async fn create_pool(database_url: &str, max_connections: u32) -> PgPool { + info!("Connecting to PostgreSQL"); + PgPoolOptions::new() + .max_connections(max_connections) + .connect(database_url) + .await + .unwrap_or_else(|e| { + warn!(error = %e, "Failed to connect to PostgreSQL"); + panic!("database connect failed: {}", e); + }) +} diff --git a/pacman-server/src/data/user.rs b/pacman-server/src/data/user.rs new file mode 100644 index 0000000..a9f1b0c --- /dev/null +++ b/pacman-server/src/data/user.rs @@ -0,0 +1,69 @@ +use serde::Serialize; +use sqlx::FromRow; + +#[derive(Debug, Clone, Serialize, FromRow)] +pub struct User { + pub id: i64, + pub provider: String, + pub provider_user_id: String, + pub username: String, + pub display_name: Option, + pub email: Option, + pub avatar_url: Option, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +pub async fn upsert_user( + pool: &sqlx::PgPool, + provider: &str, + provider_user_id: &str, + username: &str, + display_name: Option<&str>, + email: Option<&str>, + avatar_url: Option<&str>, +) -> Result { + let rec = sqlx::query_as::<_, User>( + r#" + INSERT INTO users (provider, provider_user_id, username, display_name, email, avatar_url) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (provider, provider_user_id) + DO UPDATE SET + username = EXCLUDED.username, + display_name = EXCLUDED.display_name, + email = EXCLUDED.email, + avatar_url = EXCLUDED.avatar_url, + updated_at = NOW() + RETURNING id, provider, provider_user_id, username, display_name, email, avatar_url, created_at, updated_at + "#, + ) + .bind(provider) + .bind(provider_user_id) + .bind(username) + .bind(display_name) + .bind(email) + .bind(avatar_url) + .fetch_one(pool) + .await?; + + Ok(rec) +} + +pub async fn get_user_by_provider_id( + pool: &sqlx::PgPool, + provider: &str, + provider_user_id: &str, +) -> Result, sqlx::Error> { + let rec = sqlx::query_as::<_, User>( + r#" + SELECT id, provider, provider_user_id, username, display_name, email, avatar_url, created_at, updated_at + FROM users + WHERE provider = $1 AND provider_user_id = $2 + "#, + ) + .bind(provider) + .bind(provider_user_id) + .fetch_optional(pool) + .await?; + Ok(rec) +} diff --git a/pacman-server/src/main.rs b/pacman-server/src/main.rs index c0d5cc6..0432500 100644 --- a/pacman-server/src/main.rs +++ b/pacman-server/src/main.rs @@ -9,6 +9,7 @@ mod routes; mod app; mod auth; mod config; +mod data; mod errors; mod session; use std::sync::Arc; @@ -36,6 +37,11 @@ async fn main() { let addr = std::net::SocketAddr::new(config.host, config.port); let shutdown_timeout = std::time::Duration::from_secs(config.shutdown_timeout_seconds as u64); let auth = AuthRegistry::new(&config).expect("auth initializer"); + let db = data::pool::create_pool(&config.database_url, 10).await; + // Run database migrations at startup + if let Err(e) = sqlx::migrate!("./migrations").run(&db).await { + panic!("failed to run database migrations: {}", e); + } let app = Router::new() .route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." })) @@ -44,7 +50,7 @@ async fn main() { .route("/auth/{provider}/callback", get(routes::oauth_callback_handler)) .route("/logout", get(routes::logout_handler)) .route("/profile", get(routes::profile_handler)) - .with_state(AppState::new(config, auth)) + .with_state(AppState::new(config, auth, db)) .layer(CookieLayer::default()); info!(%addr, "Starting HTTP server bind"); @@ -90,8 +96,8 @@ async fn main() { if let Some(signaled_at) = *rx_signal.borrow() { let elapsed = now.duration_since(signaled_at); if elapsed < shutdown_timeout { - let remaining = shutdown_timeout - elapsed; - info!(remaining = ?remaining, "Graceful shutdown complete"); + let remaining = format!("{:.2?}", shutdown_timeout - elapsed); + info!(remaining = remaining, "Graceful shutdown complete"); } } res.unwrap(); diff --git a/pacman-server/src/routes.rs b/pacman-server/src/routes.rs index 2fd7b5b..073905f 100644 --- a/pacman-server/src/routes.rs +++ b/pacman-server/src/routes.rs @@ -5,8 +5,9 @@ use axum::{ }; use axum_cookie::CookieManager; use serde::Serialize; -use tracing::{debug, info, trace, warn}; +use tracing::{debug, info, instrument, trace, warn}; +use crate::data::user as user_repo; use crate::{app::AppState, errors::ErrorResponse, session}; #[derive(Debug, serde::Deserialize)] @@ -17,6 +18,7 @@ pub struct AuthQuery { pub error_description: Option, } +#[instrument(skip_all, fields(provider = %provider))] pub async fn oauth_authorize_handler( State(app_state): State, Path(provider): Path, @@ -25,12 +27,13 @@ pub async fn oauth_authorize_handler( warn!(%provider, "Unknown OAuth provider"); return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response(); }; - trace!(provider = %provider, "Starting OAuth authorization"); + trace!("Starting OAuth authorization"); let resp = prov.authorize().await; - trace!(provider = %provider, "Redirecting to provider authorization page"); + trace!("Redirecting to provider authorization page"); resp } +#[instrument(skip_all, fields(provider = %provider))] pub async fn oauth_callback_handler( State(app_state): State, Path(provider): Path, @@ -59,28 +62,63 @@ pub async fn oauth_callback_handler( return e.into_response(); } }; - let session_token = session::create_jwt_for_user(&user, &app_state.jwt_encoding_key); - app_state.sessions.insert(session_token.clone(), user); + // Persist or update in database + match user_repo::upsert_user( + &app_state.db, + &provider, + &user.id, + &user.username, + user.name.as_deref(), + user.email.as_deref(), + user.avatar_url.as_deref(), + ) + .await + { + Ok(_db_user) => {} + Err(e) => { + warn!(error = %e, provider = %provider, "Failed to upsert user in database"); + } + } + let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key); session::set_session_cookie(&cookie, &session_token); info!(%provider, "Signed in successfully"); (StatusCode::FOUND, Redirect::to("/profile")).into_response() } +#[instrument(skip_all)] pub async fn profile_handler(State(app_state): State, cookie: CookieManager) -> axum::response::Response { let Some(token_str) = session::get_session_token(&cookie) else { debug!("Missing session cookie"); return ErrorResponse::unauthorized("missing session cookie").into_response(); }; - if !session::verify_jwt(&token_str, &app_state.jwt_decoding_key) { + let Some(claims) = session::decode_jwt(&token_str, &app_state.jwt_decoding_key) else { debug!("Invalid session token"); return ErrorResponse::unauthorized("invalid session token").into_response(); + }; + // sub format: provider:provider_user_id + let (prov, prov_user_id) = match claims.sub.split_once(':') { + Some((p, id)) => (p, id), + None => { + debug!("Malformed session token subject"); + return ErrorResponse::unauthorized("invalid session token").into_response(); + } + }; + match user_repo::get_user_by_provider_id(&app_state.db, prov, prov_user_id).await { + Ok(Some(db_user)) => axum::Json(db_user).into_response(), + Ok(None) => { + debug!("User not found for session"); + ErrorResponse::unauthorized("session not found").into_response() + } + Err(e) => { + warn!(error = %e, "Failed to fetch user for session"); + ErrorResponse::with_status( + StatusCode::INTERNAL_SERVER_ERROR, + "database_error", + Some("could not fetch user".into()), + ) + .into_response() + } } - if let Some(user) = app_state.sessions.get(&token_str) { - trace!("Fetched user profile"); - return axum::Json(&*user).into_response(); - } - debug!("Session not found"); - ErrorResponse::unauthorized("session not found").into_response() } pub async fn logout_handler(State(app_state): State, cookie: CookieManager) -> axum::response::Response { @@ -95,16 +133,18 @@ pub async fn logout_handler(State(app_state): State, cookie: CookieMan #[derive(Serialize)] struct ProviderInfo { - provider: &'static str, + id: &'static str, + name: &'static str, active: bool, } pub async fn list_providers_handler(State(app_state): State) -> axum::response::Response { let providers: Vec = app_state .auth - .iter() - .map(|(id, provider)| ProviderInfo { - provider: id, + .values() + .map(|provider| ProviderInfo { + id: provider.id(), + name: provider.label(), active: provider.active(), }) .collect(); diff --git a/pacman-server/src/session.rs b/pacman-server/src/session.rs index 4ab7061..ee541cb 100644 --- a/pacman-server/src/session.rs +++ b/pacman-server/src/session.rs @@ -11,19 +11,19 @@ pub const JWT_TTL_SECS: u64 = 60 * 60; // 1 hour #[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct Claims { - pub sub: String, + pub sub: String, // format: "{provider}:{provider_user_id}" pub name: Option, pub iat: usize, pub exp: usize, } -pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> String { +pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time went backwards") .as_secs() as usize; let claims = Claims { - sub: user.username.clone(), + sub: format!("{}:{}", provider, user.id), name: user.name.clone(), iat: now, exp: now + JWT_TTL_SECS as usize, @@ -33,14 +33,14 @@ pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> Strin token } -pub fn verify_jwt(token: &str, decoding_key: &DecodingKey) -> bool { +pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option { let mut validation = Validation::new(Algorithm::HS256); validation.leeway = 30; match decode::(token, decoding_key, &validation) { - Ok(_) => true, + Ok(data) => Some(data.claims), Err(e) => { warn!(error = %e, "Session JWT verification failed"); - false + None } } }