feat: proper shutdown timeout handling

This commit is contained in:
Ryan Walters
2025-09-17 03:41:13 -05:00
parent 750b47b609
commit 8d9c0621c9

View File

@@ -9,8 +9,11 @@ mod auth;
mod config; mod config;
mod errors; mod errors;
mod session; mod session;
use std::sync::Arc;
use std::time::{Duration, Instant};
#[cfg(unix)] #[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind}; use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::{watch, Notify};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@@ -24,6 +27,7 @@ async fn main() {
let config: Config = config::load_config(); let config: Config = config::load_config();
let addr = std::net::SocketAddr::new(config.host, config.port); let addr = std::net::SocketAddr::new(config.host, config.port);
let shutdown_timeout = std::time::Duration::from_secs(config.shutdown_timeout_seconds as u64);
let auth = AuthRegistry::new(&config).expect("auth initializer"); let auth = AuthRegistry::new(&config).expect("auth initializer");
let app = Router::new() let app = Router::new()
@@ -36,13 +40,57 @@ async fn main() {
.layer(CookieLayer::default()); .layer(CookieLayer::default());
let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal()) // coordinated graceful shutdown with timeout
.await let notify = Arc::new(Notify::new());
.unwrap(); let (tx_signal, rx_signal) = watch::channel::<Option<Instant>>(None);
{
let notify = notify.clone();
let tx = tx_signal.clone();
tokio::spawn(async move {
let signaled_at = shutdown_signal().await;
let _ = tx.send(Some(signaled_at));
notify.notify_waiters();
});
}
let mut rx_for_timeout = rx_signal.clone();
let timeout_task = async move {
// wait until first signal observed
while rx_for_timeout.borrow().is_none() {
if rx_for_timeout.changed().await.is_err() {
return; // channel closed
}
}
tokio::time::sleep(shutdown_timeout).await;
eprintln!("shutdown timeout elapsed (>{:.2?}) - forcing exit", shutdown_timeout);
std::process::exit(1);
};
let notify_for_server = notify.clone();
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
notify_for_server.notified().await;
});
tokio::select! {
res = server => {
// server finished; if we had a signal, print remaining time
let now = Instant::now();
if let Some(signaled_at) = *rx_signal.borrow() {
let elapsed = now.duration_since(signaled_at);
if elapsed < shutdown_timeout {
let remaining = shutdown_timeout - elapsed;
eprintln!("graceful shutdown complete, remaining time: {:.2?}", remaining);
}
}
res.unwrap();
}
_ = timeout_task => {}
}
} }
async fn shutdown_signal() { async fn shutdown_signal() -> Instant {
let ctrl_c = async { let ctrl_c = async {
tokio::signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); tokio::signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
eprintln!("received Ctrl+C, shutting down"); eprintln!("received Ctrl+C, shutting down");
@@ -59,7 +107,7 @@ async fn shutdown_signal() {
let sigterm = std::future::pending::<()>(); let sigterm = std::future::pending::<()>();
tokio::select! { tokio::select! {
_ = ctrl_c => {} _ = ctrl_c => { Instant::now() }
_ = sigterm => {} _ = sigterm => { Instant::now() }
} }
} }