feat: users table with sqlx, migrations, data persistence

This commit is contained in:
Ryan Walters
2025-09-17 09:43:52 -05:00
parent ac1417aabc
commit 1cf3b901e8
11 changed files with 181 additions and 28 deletions

1
Cargo.lock generated
View File

@@ -1814,6 +1814,7 @@ dependencies = [
"async-trait",
"axum",
"axum-cookie",
"chrono",
"dashmap",
"dotenvy",
"figment",

View File

@@ -26,6 +26,7 @@ sqlx = { version = "0.8", features = [
"postgres",
"chrono",
] }
chrono = { version = "0.4", features = ["serde", "clock"] }
figment = { version = "0.10", features = ["env"] }
dotenvy = "0.15"
dashmap = "6.1"

View File

@@ -0,0 +1,15 @@
-- users table
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
username TEXT NOT NULL,
display_name TEXT NULL,
email TEXT NULL,
avatar_url TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (provider, provider_user_id)
);
CREATE INDEX IF NOT EXISTS idx_users_provider ON users (provider, provider_user_id);

View File

@@ -2,6 +2,7 @@ use dashmap::DashMap;
use jsonwebtoken::{DecodingKey, EncodingKey};
use std::sync::Arc;
use crate::data::pool::PgPool;
use crate::{auth::AuthRegistry, config::Config};
#[derive(Clone)]
@@ -11,10 +12,11 @@ pub struct AppState {
pub sessions: Arc<DashMap<String, crate::auth::provider::AuthUser>>,
pub jwt_encoding_key: Arc<EncodingKey>,
pub jwt_decoding_key: Arc<DecodingKey>,
pub db: Arc<PgPool>,
}
impl AppState {
pub fn new(config: Config, auth: AuthRegistry) -> Self {
pub fn new(config: Config, auth: AuthRegistry, db: PgPool) -> Self {
let jwt_secret = config.jwt_secret.clone();
Self {
@@ -23,6 +25,7 @@ impl AppState {
sessions: Arc::new(DashMap::new()),
jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())),
jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())),
db: Arc::new(db),
}
}
}

View File

@@ -57,7 +57,7 @@ impl AuthRegistry {
self.providers.get(id)
}
pub fn iter(&self) -> impl Iterator<Item = (&'static str, &Arc<dyn provider::OAuthProvider>)> {
self.providers.iter().map(|(k, v)| (*k, v))
pub fn values(&self) -> impl Iterator<Item = &Arc<dyn provider::OAuthProvider>> {
self.providers.values()
}
}

View File

@@ -0,0 +1,2 @@
pub mod pool;
pub mod user;

View File

@@ -0,0 +1,16 @@
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use tracing::{info, warn};
pub type PgPool = Pool<Postgres>;
pub async fn create_pool(database_url: &str, max_connections: u32) -> PgPool {
info!("Connecting to PostgreSQL");
PgPoolOptions::new()
.max_connections(max_connections)
.connect(database_url)
.await
.unwrap_or_else(|e| {
warn!(error = %e, "Failed to connect to PostgreSQL");
panic!("database connect failed: {}", e);
})
}

View File

@@ -0,0 +1,69 @@
use serde::Serialize;
use sqlx::FromRow;
#[derive(Debug, Clone, Serialize, FromRow)]
pub struct User {
pub id: i64,
pub provider: String,
pub provider_user_id: String,
pub username: String,
pub display_name: Option<String>,
pub email: Option<String>,
pub avatar_url: Option<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
}
pub async fn upsert_user(
pool: &sqlx::PgPool,
provider: &str,
provider_user_id: &str,
username: &str,
display_name: Option<&str>,
email: Option<&str>,
avatar_url: Option<&str>,
) -> Result<User, sqlx::Error> {
let rec = sqlx::query_as::<_, User>(
r#"
INSERT INTO users (provider, provider_user_id, username, display_name, email, avatar_url)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (provider, provider_user_id)
DO UPDATE SET
username = EXCLUDED.username,
display_name = EXCLUDED.display_name,
email = EXCLUDED.email,
avatar_url = EXCLUDED.avatar_url,
updated_at = NOW()
RETURNING id, provider, provider_user_id, username, display_name, email, avatar_url, created_at, updated_at
"#,
)
.bind(provider)
.bind(provider_user_id)
.bind(username)
.bind(display_name)
.bind(email)
.bind(avatar_url)
.fetch_one(pool)
.await?;
Ok(rec)
}
pub async fn get_user_by_provider_id(
pool: &sqlx::PgPool,
provider: &str,
provider_user_id: &str,
) -> Result<Option<User>, sqlx::Error> {
let rec = sqlx::query_as::<_, User>(
r#"
SELECT id, provider, provider_user_id, username, display_name, email, avatar_url, created_at, updated_at
FROM users
WHERE provider = $1 AND provider_user_id = $2
"#,
)
.bind(provider)
.bind(provider_user_id)
.fetch_optional(pool)
.await?;
Ok(rec)
}

View File

@@ -9,6 +9,7 @@ mod routes;
mod app;
mod auth;
mod config;
mod data;
mod errors;
mod session;
use std::sync::Arc;
@@ -36,6 +37,11 @@ async fn main() {
let addr = std::net::SocketAddr::new(config.host, config.port);
let shutdown_timeout = std::time::Duration::from_secs(config.shutdown_timeout_seconds as u64);
let auth = AuthRegistry::new(&config).expect("auth initializer");
let db = data::pool::create_pool(&config.database_url, 10).await;
// Run database migrations at startup
if let Err(e) = sqlx::migrate!("./migrations").run(&db).await {
panic!("failed to run database migrations: {}", e);
}
let app = Router::new()
.route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." }))
@@ -44,7 +50,7 @@ async fn main() {
.route("/auth/{provider}/callback", get(routes::oauth_callback_handler))
.route("/logout", get(routes::logout_handler))
.route("/profile", get(routes::profile_handler))
.with_state(AppState::new(config, auth))
.with_state(AppState::new(config, auth, db))
.layer(CookieLayer::default());
info!(%addr, "Starting HTTP server bind");
@@ -90,8 +96,8 @@ async fn main() {
if let Some(signaled_at) = *rx_signal.borrow() {
let elapsed = now.duration_since(signaled_at);
if elapsed < shutdown_timeout {
let remaining = shutdown_timeout - elapsed;
info!(remaining = ?remaining, "Graceful shutdown complete");
let remaining = format!("{:.2?}", shutdown_timeout - elapsed);
info!(remaining = remaining, "Graceful shutdown complete");
}
}
res.unwrap();

View File

@@ -5,8 +5,9 @@ use axum::{
};
use axum_cookie::CookieManager;
use serde::Serialize;
use tracing::{debug, info, trace, warn};
use tracing::{debug, info, instrument, trace, warn};
use crate::data::user as user_repo;
use crate::{app::AppState, errors::ErrorResponse, session};
#[derive(Debug, serde::Deserialize)]
@@ -17,6 +18,7 @@ pub struct AuthQuery {
pub error_description: Option<String>,
}
#[instrument(skip_all, fields(provider = %provider))]
pub async fn oauth_authorize_handler(
State(app_state): State<AppState>,
Path(provider): Path<String>,
@@ -25,12 +27,13 @@ pub async fn oauth_authorize_handler(
warn!(%provider, "Unknown OAuth provider");
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
};
trace!(provider = %provider, "Starting OAuth authorization");
trace!("Starting OAuth authorization");
let resp = prov.authorize().await;
trace!(provider = %provider, "Redirecting to provider authorization page");
trace!("Redirecting to provider authorization page");
resp
}
#[instrument(skip_all, fields(provider = %provider))]
pub async fn oauth_callback_handler(
State(app_state): State<AppState>,
Path(provider): Path<String>,
@@ -59,28 +62,63 @@ pub async fn oauth_callback_handler(
return e.into_response();
}
};
let session_token = session::create_jwt_for_user(&user, &app_state.jwt_encoding_key);
app_state.sessions.insert(session_token.clone(), user);
// Persist or update in database
match user_repo::upsert_user(
&app_state.db,
&provider,
&user.id,
&user.username,
user.name.as_deref(),
user.email.as_deref(),
user.avatar_url.as_deref(),
)
.await
{
Ok(_db_user) => {}
Err(e) => {
warn!(error = %e, provider = %provider, "Failed to upsert user in database");
}
}
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
session::set_session_cookie(&cookie, &session_token);
info!(%provider, "Signed in successfully");
(StatusCode::FOUND, Redirect::to("/profile")).into_response()
}
#[instrument(skip_all)]
pub async fn profile_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
let Some(token_str) = session::get_session_token(&cookie) else {
debug!("Missing session cookie");
return ErrorResponse::unauthorized("missing session cookie").into_response();
};
if !session::verify_jwt(&token_str, &app_state.jwt_decoding_key) {
let Some(claims) = session::decode_jwt(&token_str, &app_state.jwt_decoding_key) else {
debug!("Invalid session token");
return ErrorResponse::unauthorized("invalid session token").into_response();
};
// sub format: provider:provider_user_id
let (prov, prov_user_id) = match claims.sub.split_once(':') {
Some((p, id)) => (p, id),
None => {
debug!("Malformed session token subject");
return ErrorResponse::unauthorized("invalid session token").into_response();
}
};
match user_repo::get_user_by_provider_id(&app_state.db, prov, prov_user_id).await {
Ok(Some(db_user)) => axum::Json(db_user).into_response(),
Ok(None) => {
debug!("User not found for session");
ErrorResponse::unauthorized("session not found").into_response()
}
Err(e) => {
warn!(error = %e, "Failed to fetch user for session");
ErrorResponse::with_status(
StatusCode::INTERNAL_SERVER_ERROR,
"database_error",
Some("could not fetch user".into()),
)
.into_response()
}
}
if let Some(user) = app_state.sessions.get(&token_str) {
trace!("Fetched user profile");
return axum::Json(&*user).into_response();
}
debug!("Session not found");
ErrorResponse::unauthorized("session not found").into_response()
}
pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
@@ -95,16 +133,18 @@ pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieMan
#[derive(Serialize)]
struct ProviderInfo {
provider: &'static str,
id: &'static str,
name: &'static str,
active: bool,
}
pub async fn list_providers_handler(State(app_state): State<AppState>) -> axum::response::Response {
let providers: Vec<ProviderInfo> = app_state
.auth
.iter()
.map(|(id, provider)| ProviderInfo {
provider: id,
.values()
.map(|provider| ProviderInfo {
id: provider.id(),
name: provider.label(),
active: provider.active(),
})
.collect();

View File

@@ -11,19 +11,19 @@ pub const JWT_TTL_SECS: u64 = 60 * 60; // 1 hour
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Claims {
pub sub: String,
pub sub: String, // format: "{provider}:{provider_user_id}"
pub name: Option<String>,
pub iat: usize,
pub exp: usize,
}
pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> String {
pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time went backwards")
.as_secs() as usize;
let claims = Claims {
sub: user.username.clone(),
sub: format!("{}:{}", provider, user.id),
name: user.name.clone(),
iat: now,
exp: now + JWT_TTL_SECS as usize,
@@ -33,14 +33,14 @@ pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> Strin
token
}
pub fn verify_jwt(token: &str, decoding_key: &DecodingKey) -> bool {
pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option<Claims> {
let mut validation = Validation::new(Algorithm::HS256);
validation.leeway = 30;
match decode::<Claims>(token, decoding_key, &validation) {
Ok(_) => true,
Ok(data) => Some(data.claims),
Err(e) => {
warn!(error = %e, "Session JWT verification failed");
false
None
}
}
}