mirror of
https://github.com/Xevion/Pac-Man.git
synced 2025-12-10 00:07:52 -06:00
feat: users table with sqlx, migrations, data persistence
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1814,6 +1814,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum-cookie",
|
||||
"chrono",
|
||||
"dashmap",
|
||||
"dotenvy",
|
||||
"figment",
|
||||
|
||||
@@ -26,6 +26,7 @@ sqlx = { version = "0.8", features = [
|
||||
"postgres",
|
||||
"chrono",
|
||||
] }
|
||||
chrono = { version = "0.4", features = ["serde", "clock"] }
|
||||
figment = { version = "0.10", features = ["env"] }
|
||||
dotenvy = "0.15"
|
||||
dashmap = "6.1"
|
||||
|
||||
15
pacman-server/migrations/20240917120000_init_users.sql
Normal file
15
pacman-server/migrations/20240917120000_init_users.sql
Normal file
@@ -0,0 +1,15 @@
|
||||
-- users table
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
provider TEXT NOT NULL,
|
||||
provider_user_id TEXT NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
display_name TEXT NULL,
|
||||
email TEXT NULL,
|
||||
avatar_url TEXT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (provider, provider_user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_users_provider ON users (provider, provider_user_id);
|
||||
@@ -2,6 +2,7 @@ use dashmap::DashMap;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::data::pool::PgPool;
|
||||
use crate::{auth::AuthRegistry, config::Config};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -11,10 +12,11 @@ pub struct AppState {
|
||||
pub sessions: Arc<DashMap<String, crate::auth::provider::AuthUser>>,
|
||||
pub jwt_encoding_key: Arc<EncodingKey>,
|
||||
pub jwt_decoding_key: Arc<DecodingKey>,
|
||||
pub db: Arc<PgPool>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(config: Config, auth: AuthRegistry) -> Self {
|
||||
pub fn new(config: Config, auth: AuthRegistry, db: PgPool) -> Self {
|
||||
let jwt_secret = config.jwt_secret.clone();
|
||||
|
||||
Self {
|
||||
@@ -23,6 +25,7 @@ impl AppState {
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
jwt_encoding_key: Arc::new(EncodingKey::from_secret(jwt_secret.as_bytes())),
|
||||
jwt_decoding_key: Arc::new(DecodingKey::from_secret(jwt_secret.as_bytes())),
|
||||
db: Arc::new(db),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ impl AuthRegistry {
|
||||
self.providers.get(id)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&'static str, &Arc<dyn provider::OAuthProvider>)> {
|
||||
self.providers.iter().map(|(k, v)| (*k, v))
|
||||
pub fn values(&self) -> impl Iterator<Item = &Arc<dyn provider::OAuthProvider>> {
|
||||
self.providers.values()
|
||||
}
|
||||
}
|
||||
|
||||
2
pacman-server/src/data/mod.rs
Normal file
2
pacman-server/src/data/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod pool;
|
||||
pub mod user;
|
||||
16
pacman-server/src/data/pool.rs
Normal file
16
pacman-server/src/data/pool.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub type PgPool = Pool<Postgres>;
|
||||
|
||||
pub async fn create_pool(database_url: &str, max_connections: u32) -> PgPool {
|
||||
info!("Connecting to PostgreSQL");
|
||||
PgPoolOptions::new()
|
||||
.max_connections(max_connections)
|
||||
.connect(database_url)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
warn!(error = %e, "Failed to connect to PostgreSQL");
|
||||
panic!("database connect failed: {}", e);
|
||||
})
|
||||
}
|
||||
69
pacman-server/src/data/user.rs
Normal file
69
pacman-server/src/data/user.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use serde::Serialize;
|
||||
use sqlx::FromRow;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, FromRow)]
|
||||
pub struct User {
|
||||
pub id: i64,
|
||||
pub provider: String,
|
||||
pub provider_user_id: String,
|
||||
pub username: String,
|
||||
pub display_name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub avatar_url: Option<String>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
pub async fn upsert_user(
|
||||
pool: &sqlx::PgPool,
|
||||
provider: &str,
|
||||
provider_user_id: &str,
|
||||
username: &str,
|
||||
display_name: Option<&str>,
|
||||
email: Option<&str>,
|
||||
avatar_url: Option<&str>,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
let rec = sqlx::query_as::<_, User>(
|
||||
r#"
|
||||
INSERT INTO users (provider, provider_user_id, username, display_name, email, avatar_url)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (provider, provider_user_id)
|
||||
DO UPDATE SET
|
||||
username = EXCLUDED.username,
|
||||
display_name = EXCLUDED.display_name,
|
||||
email = EXCLUDED.email,
|
||||
avatar_url = EXCLUDED.avatar_url,
|
||||
updated_at = NOW()
|
||||
RETURNING id, provider, provider_user_id, username, display_name, email, avatar_url, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(provider)
|
||||
.bind(provider_user_id)
|
||||
.bind(username)
|
||||
.bind(display_name)
|
||||
.bind(email)
|
||||
.bind(avatar_url)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(rec)
|
||||
}
|
||||
|
||||
pub async fn get_user_by_provider_id(
|
||||
pool: &sqlx::PgPool,
|
||||
provider: &str,
|
||||
provider_user_id: &str,
|
||||
) -> Result<Option<User>, sqlx::Error> {
|
||||
let rec = sqlx::query_as::<_, User>(
|
||||
r#"
|
||||
SELECT id, provider, provider_user_id, username, display_name, email, avatar_url, created_at, updated_at
|
||||
FROM users
|
||||
WHERE provider = $1 AND provider_user_id = $2
|
||||
"#,
|
||||
)
|
||||
.bind(provider)
|
||||
.bind(provider_user_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(rec)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ mod routes;
|
||||
mod app;
|
||||
mod auth;
|
||||
mod config;
|
||||
mod data;
|
||||
mod errors;
|
||||
mod session;
|
||||
use std::sync::Arc;
|
||||
@@ -36,6 +37,11 @@ async fn main() {
|
||||
let addr = std::net::SocketAddr::new(config.host, config.port);
|
||||
let shutdown_timeout = std::time::Duration::from_secs(config.shutdown_timeout_seconds as u64);
|
||||
let auth = AuthRegistry::new(&config).expect("auth initializer");
|
||||
let db = data::pool::create_pool(&config.database_url, 10).await;
|
||||
// Run database migrations at startup
|
||||
if let Err(e) = sqlx::migrate!("./migrations").run(&db).await {
|
||||
panic!("failed to run database migrations: {}", e);
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World! Visit /auth/github to start OAuth flow." }))
|
||||
@@ -44,7 +50,7 @@ async fn main() {
|
||||
.route("/auth/{provider}/callback", get(routes::oauth_callback_handler))
|
||||
.route("/logout", get(routes::logout_handler))
|
||||
.route("/profile", get(routes::profile_handler))
|
||||
.with_state(AppState::new(config, auth))
|
||||
.with_state(AppState::new(config, auth, db))
|
||||
.layer(CookieLayer::default());
|
||||
|
||||
info!(%addr, "Starting HTTP server bind");
|
||||
@@ -90,8 +96,8 @@ async fn main() {
|
||||
if let Some(signaled_at) = *rx_signal.borrow() {
|
||||
let elapsed = now.duration_since(signaled_at);
|
||||
if elapsed < shutdown_timeout {
|
||||
let remaining = shutdown_timeout - elapsed;
|
||||
info!(remaining = ?remaining, "Graceful shutdown complete");
|
||||
let remaining = format!("{:.2?}", shutdown_timeout - elapsed);
|
||||
info!(remaining = remaining, "Graceful shutdown complete");
|
||||
}
|
||||
}
|
||||
res.unwrap();
|
||||
|
||||
@@ -5,8 +5,9 @@ use axum::{
|
||||
};
|
||||
use axum_cookie::CookieManager;
|
||||
use serde::Serialize;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
use tracing::{debug, info, instrument, trace, warn};
|
||||
|
||||
use crate::data::user as user_repo;
|
||||
use crate::{app::AppState, errors::ErrorResponse, session};
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
@@ -17,6 +18,7 @@ pub struct AuthQuery {
|
||||
pub error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(provider = %provider))]
|
||||
pub async fn oauth_authorize_handler(
|
||||
State(app_state): State<AppState>,
|
||||
Path(provider): Path<String>,
|
||||
@@ -25,12 +27,13 @@ pub async fn oauth_authorize_handler(
|
||||
warn!(%provider, "Unknown OAuth provider");
|
||||
return ErrorResponse::bad_request("invalid_provider", Some(provider)).into_response();
|
||||
};
|
||||
trace!(provider = %provider, "Starting OAuth authorization");
|
||||
trace!("Starting OAuth authorization");
|
||||
let resp = prov.authorize().await;
|
||||
trace!(provider = %provider, "Redirecting to provider authorization page");
|
||||
trace!("Redirecting to provider authorization page");
|
||||
resp
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(provider = %provider))]
|
||||
pub async fn oauth_callback_handler(
|
||||
State(app_state): State<AppState>,
|
||||
Path(provider): Path<String>,
|
||||
@@ -59,28 +62,63 @@ pub async fn oauth_callback_handler(
|
||||
return e.into_response();
|
||||
}
|
||||
};
|
||||
let session_token = session::create_jwt_for_user(&user, &app_state.jwt_encoding_key);
|
||||
app_state.sessions.insert(session_token.clone(), user);
|
||||
// Persist or update in database
|
||||
match user_repo::upsert_user(
|
||||
&app_state.db,
|
||||
&provider,
|
||||
&user.id,
|
||||
&user.username,
|
||||
user.name.as_deref(),
|
||||
user.email.as_deref(),
|
||||
user.avatar_url.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_db_user) => {}
|
||||
Err(e) => {
|
||||
warn!(error = %e, provider = %provider, "Failed to upsert user in database");
|
||||
}
|
||||
}
|
||||
let session_token = session::create_jwt_for_user(&provider, &user, &app_state.jwt_encoding_key);
|
||||
session::set_session_cookie(&cookie, &session_token);
|
||||
info!(%provider, "Signed in successfully");
|
||||
(StatusCode::FOUND, Redirect::to("/profile")).into_response()
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub async fn profile_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
|
||||
let Some(token_str) = session::get_session_token(&cookie) else {
|
||||
debug!("Missing session cookie");
|
||||
return ErrorResponse::unauthorized("missing session cookie").into_response();
|
||||
};
|
||||
if !session::verify_jwt(&token_str, &app_state.jwt_decoding_key) {
|
||||
let Some(claims) = session::decode_jwt(&token_str, &app_state.jwt_decoding_key) else {
|
||||
debug!("Invalid session token");
|
||||
return ErrorResponse::unauthorized("invalid session token").into_response();
|
||||
};
|
||||
// sub format: provider:provider_user_id
|
||||
let (prov, prov_user_id) = match claims.sub.split_once(':') {
|
||||
Some((p, id)) => (p, id),
|
||||
None => {
|
||||
debug!("Malformed session token subject");
|
||||
return ErrorResponse::unauthorized("invalid session token").into_response();
|
||||
}
|
||||
};
|
||||
match user_repo::get_user_by_provider_id(&app_state.db, prov, prov_user_id).await {
|
||||
Ok(Some(db_user)) => axum::Json(db_user).into_response(),
|
||||
Ok(None) => {
|
||||
debug!("User not found for session");
|
||||
ErrorResponse::unauthorized("session not found").into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to fetch user for session");
|
||||
ErrorResponse::with_status(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"database_error",
|
||||
Some("could not fetch user".into()),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
if let Some(user) = app_state.sessions.get(&token_str) {
|
||||
trace!("Fetched user profile");
|
||||
return axum::Json(&*user).into_response();
|
||||
}
|
||||
debug!("Session not found");
|
||||
ErrorResponse::unauthorized("session not found").into_response()
|
||||
}
|
||||
|
||||
pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieManager) -> axum::response::Response {
|
||||
@@ -95,16 +133,18 @@ pub async fn logout_handler(State(app_state): State<AppState>, cookie: CookieMan
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ProviderInfo {
|
||||
provider: &'static str,
|
||||
id: &'static str,
|
||||
name: &'static str,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
pub async fn list_providers_handler(State(app_state): State<AppState>) -> axum::response::Response {
|
||||
let providers: Vec<ProviderInfo> = app_state
|
||||
.auth
|
||||
.iter()
|
||||
.map(|(id, provider)| ProviderInfo {
|
||||
provider: id,
|
||||
.values()
|
||||
.map(|provider| ProviderInfo {
|
||||
id: provider.id(),
|
||||
name: provider.label(),
|
||||
active: provider.active(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -11,19 +11,19 @@ pub const JWT_TTL_SECS: u64 = 60 * 60; // 1 hour
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub sub: String, // format: "{provider}:{provider_user_id}"
|
||||
pub name: Option<String>,
|
||||
pub iat: usize,
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> String {
|
||||
pub fn create_jwt_for_user(provider: &str, user: &AuthUser, encoding_key: &EncodingKey) -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time went backwards")
|
||||
.as_secs() as usize;
|
||||
let claims = Claims {
|
||||
sub: user.username.clone(),
|
||||
sub: format!("{}:{}", provider, user.id),
|
||||
name: user.name.clone(),
|
||||
iat: now,
|
||||
exp: now + JWT_TTL_SECS as usize,
|
||||
@@ -33,14 +33,14 @@ pub fn create_jwt_for_user(user: &AuthUser, encoding_key: &EncodingKey) -> Strin
|
||||
token
|
||||
}
|
||||
|
||||
pub fn verify_jwt(token: &str, decoding_key: &DecodingKey) -> bool {
|
||||
pub fn decode_jwt(token: &str, decoding_key: &DecodingKey) -> Option<Claims> {
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.leeway = 30;
|
||||
match decode::<Claims>(token, decoding_key, &validation) {
|
||||
Ok(_) => true,
|
||||
Ok(data) => Some(data.claims),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Session JWT verification failed");
|
||||
false
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user