Compare commits

...

2 Commits

Author SHA1 Message Date
Ryan Walters
47c23459f1 refactor: implement comprehensive graceful shutdown across all services
Implements graceful shutdown with broadcast channels and proper timeout handling
for scraper workers, scheduler, bot service, and status update tasks. Introduces
centralized shutdown utilities and improves service manager to handle parallel
shutdown with per-service timeouts instead of shared timeout budgets.

Key changes:
- Add utils module with shutdown helper functions
- Update ScraperService to return errors on shutdown failures
- Refactor scheduler with cancellable work tasks and 5s grace period
- Extract worker shutdown logic into helper methods for clarity
- Add broadcast channel shutdown support to BotService and status task
- Improve ServiceManager to shutdown services in parallel with individual timeouts
2025-11-03 02:10:01 -06:00
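The parallel, per-service timeout model this commit describes can be reduced to a small generic sketch (illustrative tokio code, not taken from this repository; it assumes the tokio and futures crates):

```rust
use std::time::Duration;
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    // One broadcast channel; every service task holds its own receiver.
    let (shutdown_tx, _) = broadcast::channel::<()>(1);

    // Spawn a few stand-in "services" that run until they see the signal.
    let handles: Vec<_> = (0..3)
        .map(|id| {
            let mut shutdown_rx = shutdown_tx.subscribe();
            tokio::spawn(async move {
                // ... real service work would run here ...
                let _ = shutdown_rx.recv().await;
                println!("service {id} shutting down");
            })
        })
        .collect();

    // Signal every service at once.
    let _ = shutdown_tx.send(());

    // Await each handle under its own timeout; the waits run concurrently,
    // so one slow service cannot consume another service's budget.
    let results = futures::future::join_all(
        handles
            .into_iter()
            .map(|h| async move { tokio::time::timeout(Duration::from_secs(8), h).await }),
    )
    .await;

    let timed_out = results.iter().filter(|r| r.is_err()).count();
    println!("{timed_out} service(s) missed the shutdown deadline");
}
```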
Ryan Walters
8af9b0a1a2 refactor(scraper): implement graceful shutdown with broadcast channels
Replace task abortion with broadcast-based graceful shutdown for scheduler and workers. Implement cancellation tokens for in-progress work with 5s timeout. Add tokio-util dependency for CancellationToken support. Update ServiceManager to use completion channels and abort handles for better service lifecycle control.
2025-11-03 01:22:12 -06:00
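The broadcast-plus-CancellationToken mechanics this commit introduces look roughly like the following minimal sketch (assumes the tokio and tokio-util crates; the names are illustrative, not the repository's types):

```rust
use std::time::Duration;
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel::<()>(1);
    let cancel = CancellationToken::new();

    // In-progress work that checks its token at an await point.
    let work = tokio::spawn({
        let cancel = cancel.clone();
        async move {
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_secs(60)) => println!("work finished"),
                _ = cancel.cancelled() => println!("work cancelled cooperatively"),
            }
        }
    });

    // Somewhere else, shutdown is requested.
    let _ = shutdown_tx.send(());

    // On receiving the signal: cancel in-flight work and allow a 5s grace period.
    let _ = shutdown_rx.recv().await;
    cancel.cancel();
    if tokio::time::timeout(Duration::from_secs(5), work).await.is_err() {
        eprintln!("in-flight work missed the 5s grace period; abandoning it");
    }
}
```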
13 changed files with 481 additions and 232 deletions

Cargo.lock (generated)

@@ -254,6 +254,7 @@ dependencies = [
  "time",
  "tl",
  "tokio",
+ "tokio-util",
  "tower-http",
  "tracing",
  "tracing-subscriber",

@@ -36,6 +36,7 @@ sqlx = { version = "0.8.6", features = [
 thiserror = "2.0.16"
 time = "0.3.43"
 tokio = { version = "1.47.1", features = ["full"] }
+tokio-util = "0.7"
 tl = "0.7.8"
 tracing = "0.1.41"
 tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }

@@ -42,11 +42,7 @@ impl App {
         // Check if the database URL is via private networking
         let is_private = config.database_url.contains("railway.internal");
-        let slow_threshold = if is_private {
-            Duration::from_millis(200)
-        } else {
-            Duration::from_millis(500)
-        };
+        let slow_threshold = Duration::from_millis(if is_private { 200 } else { 500 });
 
         // Create database connection pool
         let db_pool = PgPoolOptions::new()
@@ -108,10 +104,6 @@ impl App {
                 .register_service(ServiceName::Scraper.as_str(), scraper_service);
         }
 
-        if services.contains(&ServiceName::Bot) {
-            // Bot service will be set up separately in run() method since it's async
-        }
-
         // Check if any services are enabled
         if !self.service_manager.has_services() && !services.contains(&ServiceName::Bot) {
             error!("No services enabled. Cannot start application.");
@@ -123,10 +115,28 @@ impl App {
     /// Setup bot service if enabled
     pub async fn setup_bot_service(&mut self) -> Result<(), anyhow::Error> {
-        let client = BotService::create_client(&self.config, self.app_state.clone())
-            .await
-            .expect("Failed to create Discord client");
-        let bot_service = Box::new(BotService::new(client));
+        use std::sync::Arc;
+        use tokio::sync::{broadcast, Mutex};
+
+        // Create shutdown channel for status update task
+        let (status_shutdown_tx, status_shutdown_rx) = broadcast::channel(1);
+        let status_task_handle = Arc::new(Mutex::new(None));
+
+        let client = BotService::create_client(
+            &self.config,
+            self.app_state.clone(),
+            status_task_handle.clone(),
+            status_shutdown_rx,
+        )
+        .await
+        .expect("Failed to create Discord client");
+
+        let bot_service = Box::new(BotService::new(
+            client,
+            status_task_handle,
+            status_shutdown_tx,
+        ));
+
         self.service_manager
             .register_service(ServiceName::Bot.as_str(), bot_service);
         Ok(())
@@ -147,9 +157,4 @@ impl App {
     pub fn config(&self) -> &Config {
         &self.config
     }
-
-    /// Get a reference to the app state
-    pub fn app_state(&self) -> &AppState {
-        &self.app_state
-    }
 }

@@ -11,4 +11,5 @@ pub mod scraper;
 pub mod services;
 pub mod signals;
 pub mod state;
+pub mod utils;
 pub mod web;

@@ -3,14 +3,15 @@ pub mod scheduler;
 pub mod worker;
 
 use crate::banner::BannerApi;
-use crate::services::Service;
 use sqlx::PgPool;
 use std::sync::Arc;
+use tokio::sync::broadcast;
 use tokio::task::JoinHandle;
-use tracing::info;
+use tracing::{info, warn};
 
 use self::scheduler::Scheduler;
 use self::worker::Worker;
+use crate::services::Service;
 
 /// The main service that will be managed by the application's `ServiceManager`.
 ///
@@ -21,6 +22,7 @@ pub struct ScraperService {
     banner_api: Arc<BannerApi>,
     scheduler_handle: Option<JoinHandle<()>>,
     worker_handles: Vec<JoinHandle<()>>,
+    shutdown_tx: Option<broadcast::Sender<()>>,
 }
 
 impl ScraperService {
@@ -31,6 +33,7 @@ impl ScraperService {
             banner_api,
             scheduler_handle: None,
             worker_handles: Vec::new(),
+            shutdown_tx: None,
         }
     }
@@ -38,9 +41,14 @@ impl ScraperService {
     pub fn start(&mut self) {
         info!("ScraperService starting");
 
+        // Create shutdown channel
+        let (shutdown_tx, _) = broadcast::channel(1);
+        self.shutdown_tx = Some(shutdown_tx.clone());
+
         let scheduler = Scheduler::new(self.db_pool.clone(), self.banner_api.clone());
+        let shutdown_rx = shutdown_tx.subscribe();
         let scheduler_handle = tokio::spawn(async move {
-            scheduler.run().await;
+            scheduler.run(shutdown_rx).await;
         });
         self.scheduler_handle = Some(scheduler_handle);
         info!("Scheduler task spawned");
@@ -48,8 +56,9 @@ impl ScraperService {
         let worker_count = 4; // This could be configurable
         for i in 0..worker_count {
             let worker = Worker::new(i, self.db_pool.clone(), self.banner_api.clone());
+            let shutdown_rx = shutdown_tx.subscribe();
             let worker_handle = tokio::spawn(async move {
-                worker.run().await;
+                worker.run(shutdown_rx).await;
             });
             self.worker_handles.push(worker_handle);
         }
@@ -59,17 +68,6 @@
         );
     }
-
-    /// Signals all child tasks to gracefully shut down.
-    pub async fn shutdown(&mut self) {
-        info!("Shutting down scraper service");
-
-        if let Some(handle) = self.scheduler_handle.take() {
-            handle.abort();
-        }
-
-        for handle in self.worker_handles.drain(..) {
-            handle.abort();
-        }
-
-        info!("Scraper service shutdown");
-    }
 }
 
 #[async_trait::async_trait]
@@ -85,7 +83,32 @@ impl Service for ScraperService {
     }
 
     async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
-        self.shutdown().await;
+        info!("Shutting down scraper service");
+
+        // Send shutdown signal to all tasks
+        if let Some(shutdown_tx) = self.shutdown_tx.take() {
+            let _ = shutdown_tx.send(());
+        } else {
+            warn!("No shutdown channel found for scraper service");
+            return Err(anyhow::anyhow!("No shutdown channel available"));
+        }
+
+        // Collect all handles
+        let mut all_handles = Vec::new();
+        if let Some(handle) = self.scheduler_handle.take() {
+            all_handles.push(handle);
+        }
+        all_handles.append(&mut self.worker_handles);
+
+        // Wait for all tasks to complete (no internal timeout - let ServiceManager handle it)
+        let results = futures::future::join_all(all_handles).await;
+        let failed = results.iter().filter(|r| r.is_err()).count();
+
+        if failed > 0 {
+            warn!(failed_count = failed, "Some scraper tasks panicked during shutdown");
+            return Err(anyhow::anyhow!("{} task(s) panicked", failed));
+        }
+
+        info!("All scraper tasks shutdown gracefully");
         Ok(())
     }
 }

@@ -6,8 +6,10 @@ use serde_json::json;
 use sqlx::PgPool;
 use std::sync::Arc;
 use std::time::Duration;
+use tokio::sync::broadcast;
 use tokio::time;
-use tracing::{debug, error, info, trace};
+use tokio_util::sync::CancellationToken;
+use tracing::{debug, error, info, trace, warn};
 
 /// Periodically analyzes data and enqueues prioritized scrape jobs.
 pub struct Scheduler {
@@ -23,22 +25,81 @@ impl Scheduler {
         }
     }
 
-    /// Runs the scheduler's main loop.
-    pub async fn run(&self) {
+    /// Runs the scheduler's main loop with graceful shutdown support.
+    ///
+    /// The scheduler wakes up every 60 seconds to analyze data and enqueue jobs.
+    /// When a shutdown signal is received:
+    /// 1. Any in-progress scheduling work is gracefully cancelled via CancellationToken
+    /// 2. The scheduler waits up to 5 seconds for work to complete
+    /// 3. If timeout occurs, the task is abandoned (it will be aborted when dropped)
+    ///
+    /// This ensures that shutdown is responsive even if scheduling work is blocked.
+    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
         info!("Scheduler service started");
-        let mut interval = time::interval(Duration::from_secs(60)); // Runs every minute
+
+        let work_interval = Duration::from_secs(60);
+        let mut next_run = time::Instant::now();
+        let mut current_work: Option<(tokio::task::JoinHandle<()>, CancellationToken)> = None;
 
         loop {
-            interval.tick().await;
-            // Scheduler analyzing data...
-            if let Err(e) = self.schedule_jobs().await {
-                error!(error = ?e, "Failed to schedule jobs");
+            tokio::select! {
+                _ = time::sleep_until(next_run) => {
+                    let cancel_token = CancellationToken::new();
+
+                    // Spawn work in separate task to allow graceful cancellation during shutdown.
+                    // Without this, shutdown would have to wait for the full scheduling cycle.
+                    let work_handle = tokio::spawn({
+                        let db_pool = self.db_pool.clone();
+                        let banner_api = self.banner_api.clone();
+                        let cancel_token = cancel_token.clone();
+                        async move {
+                            tokio::select! {
+                                result = Self::schedule_jobs_impl(&db_pool, &banner_api) => {
+                                    if let Err(e) = result {
+                                        error!(error = ?e, "Failed to schedule jobs");
+                                    }
+                                }
+                                _ = cancel_token.cancelled() => {
+                                    debug!("Scheduling work cancelled gracefully");
+                                }
+                            }
+                        }
+                    });
+
+                    current_work = Some((work_handle, cancel_token));
+                    next_run = time::Instant::now() + work_interval;
+                }
+                _ = shutdown_rx.recv() => {
+                    info!("Scheduler received shutdown signal");
+
+                    if let Some((handle, cancel_token)) = current_work.take() {
+                        cancel_token.cancel();
+
+                        // Wait briefly for graceful completion
+                        if tokio::time::timeout(Duration::from_secs(5), handle).await.is_err() {
+                            warn!("Scheduling work did not complete within 5s, abandoning");
+                        } else {
+                            debug!("Scheduling work completed gracefully");
+                        }
+                    }
+
+                    info!("Scheduler exiting gracefully");
+                    break;
+                }
             }
         }
     }
 
-    /// The core logic for deciding what jobs to create.
-    async fn schedule_jobs(&self) -> Result<()> {
+    /// Core scheduling logic that analyzes data and creates scrape jobs.
+    ///
+    /// Strategy:
+    /// 1. Fetch all subjects for the current term from Banner API
+    /// 2. Query existing jobs in a single batch query
+    /// 3. Create jobs only for subjects that don't have pending jobs
+    ///
+    /// This is a static method (not &self) to allow it to be called from spawned tasks.
+    async fn schedule_jobs_impl(db_pool: &PgPool, banner_api: &BannerApi) -> Result<()> {
        // For now, we will implement a simple baseline scheduling strategy:
        // 1. Get a list of all subjects from the Banner API.
        // 2. Query existing jobs for all subjects in a single query.
@@ -47,7 +108,7 @@
         debug!(term = term, "Enqueuing subject jobs");
 
-        let subjects = self.banner_api.get_subjects("", &term, 1, 500).await?;
+        let subjects = banner_api.get_subjects("", &term, 1, 500).await?;
         debug!(
             subject_count = subjects.len(),
             "Retrieved subjects from API"
@@ -61,12 +122,12 @@
         // Query existing jobs for all subjects in a single query
         let existing_jobs: Vec<(serde_json::Value,)> = sqlx::query_as(
             "SELECT target_payload FROM scrape_jobs
             WHERE target_type = $1 AND target_payload = ANY($2) AND locked_at IS NULL",
         )
         .bind(TargetType::Subject)
         .bind(&subject_payloads)
-        .fetch_all(&self.db_pool)
+        .fetch_all(db_pool)
         .await?;
 
         // Convert to a HashSet for efficient lookup
@@ -95,7 +156,7 @@
         // Insert all new jobs in a single batch
         if !new_jobs.is_empty() {
             let now = chrono::Utc::now();
-            let mut tx = self.db_pool.begin().await?;
+            let mut tx = db_pool.begin().await?;
 
             for (payload, subject_code) in new_jobs {
                 sqlx::query(

@@ -5,6 +5,7 @@ use crate::scraper::jobs::{JobError, JobType};
 use sqlx::PgPool;
 use std::sync::Arc;
 use std::time::Duration;
+use tokio::sync::broadcast;
 use tokio::time;
 use tracing::{debug, error, info, trace, warn};
@@ -28,79 +29,47 @@ impl Worker {
     }
 
     /// Runs the worker's main loop.
-    pub async fn run(&self) {
-        info!(worker_id = self.id, "Worker started.");
-        loop {
-            match self.fetch_and_lock_job().await {
-                Ok(Some(job)) => {
-                    let job_id = job.id;
-                    debug!(worker_id = self.id, job_id = job.id, "Processing job");
-
-                    match self.process_job(job).await {
-                        Ok(()) => {
-                            debug!(worker_id = self.id, job_id, "Job completed");
-                            // If successful, delete the job.
-                            if let Err(delete_err) = self.delete_job(job_id).await {
-                                error!(
-                                    worker_id = self.id,
-                                    job_id,
-                                    ?delete_err,
-                                    "Failed to delete job"
-                                );
-                            }
-                        }
-                        Err(JobError::Recoverable(e)) => {
-                            // Check if the error is due to an invalid session
-                            if let Some(BannerApiError::InvalidSession(_)) =
-                                e.downcast_ref::<BannerApiError>()
-                            {
-                                warn!(
-                                    worker_id = self.id,
-                                    job_id, "Invalid session detected. Forcing session refresh."
-                                );
-                            } else {
-                                error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
-                            }
-
-                            // Unlock the job so it can be retried
-                            if let Err(unlock_err) = self.unlock_job(job_id).await {
-                                error!(
-                                    worker_id = self.id,
-                                    job_id,
-                                    ?unlock_err,
-                                    "Failed to unlock job"
-                                );
-                            }
-                        }
-                        Err(JobError::Unrecoverable(e)) => {
-                            error!(
-                                worker_id = self.id,
-                                job_id,
-                                error = ?e,
-                                "Job corrupted, deleting"
-                            );
-                            // Parse errors are unrecoverable - delete the job
-                            if let Err(delete_err) = self.delete_job(job_id).await {
-                                error!(
-                                    worker_id = self.id,
-                                    job_id,
-                                    ?delete_err,
-                                    "Failed to delete corrupted job"
-                                );
-                            }
-                        }
-                    }
-                }
-                Ok(None) => {
-                    // No job found, wait for a bit before polling again.
-                    trace!(worker_id = self.id, "No jobs available, waiting");
-                    time::sleep(Duration::from_secs(5)).await;
-                }
-                Err(e) => {
-                    warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
-                    // Wait before retrying to avoid spamming errors.
-                    time::sleep(Duration::from_secs(10)).await;
-                }
-            }
-        }
-    }
+    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
+        info!(worker_id = self.id, "Worker started");
+
+        loop {
+            // Fetch and lock a job, racing against shutdown signal
+            let job = tokio::select! {
+                _ = shutdown_rx.recv() => {
+                    info!(worker_id = self.id, "Worker received shutdown signal, exiting gracefully");
+                    break;
+                }
+                result = self.fetch_and_lock_job() => {
+                    match result {
+                        Ok(Some(job)) => job,
+                        Ok(None) => {
+                            trace!(worker_id = self.id, "No jobs available, waiting");
+                            time::sleep(Duration::from_secs(5)).await;
+                            continue;
+                        }
+                        Err(e) => {
+                            warn!(worker_id = self.id, error = ?e, "Failed to fetch job, waiting");
+                            time::sleep(Duration::from_secs(10)).await;
+                            continue;
+                        }
+                    }
+                }
+            };
+
+            let job_id = job.id;
+            debug!(worker_id = self.id, job_id, "Processing job");
+
+            // Process the job, racing against shutdown signal
+            let process_result = tokio::select! {
+                _ = shutdown_rx.recv() => {
+                    self.handle_shutdown_during_processing(job_id).await;
+                    break;
+                }
+                result = self.process_job(job) => result
+            };
+
+            // Handle the job processing result
+            self.handle_job_result(job_id, process_result).await;
+        }
+    }
@@ -169,4 +138,59 @@
         info!(worker_id = self.id, job_id, "Job unlocked for retry");
         Ok(())
     }
+
+    /// Handle shutdown signal received during job processing
+    async fn handle_shutdown_during_processing(&self, job_id: i32) {
+        info!(worker_id = self.id, job_id, "Shutdown received during job processing");
+
+        if let Err(e) = self.unlock_job(job_id).await {
+            warn!(
+                worker_id = self.id,
+                job_id,
+                error = ?e,
+                "Failed to unlock job during shutdown"
+            );
+        } else {
+            debug!(worker_id = self.id, job_id, "Job unlocked during shutdown");
+        }
+
+        info!(worker_id = self.id, "Worker exiting gracefully");
+    }
+
+    /// Handle the result of job processing
+    async fn handle_job_result(&self, job_id: i32, result: Result<(), JobError>) {
+        match result {
+            Ok(()) => {
+                debug!(worker_id = self.id, job_id, "Job completed successfully");
+                if let Err(e) = self.delete_job(job_id).await {
+                    error!(worker_id = self.id, job_id, error = ?e, "Failed to delete completed job");
+                }
+            }
+            Err(JobError::Recoverable(e)) => {
+                self.handle_recoverable_error(job_id, e).await;
+            }
+            Err(JobError::Unrecoverable(e)) => {
+                error!(worker_id = self.id, job_id, error = ?e, "Job corrupted, deleting");
+                if let Err(e) = self.delete_job(job_id).await {
+                    error!(worker_id = self.id, job_id, error = ?e, "Failed to delete corrupted job");
+                }
+            }
+        }
+    }
+
+    /// Handle recoverable errors by logging appropriately and unlocking the job
+    async fn handle_recoverable_error(&self, job_id: i32, e: anyhow::Error) {
+        if let Some(BannerApiError::InvalidSession(_)) = e.downcast_ref::<BannerApiError>() {
+            warn!(
+                worker_id = self.id,
+                job_id, "Invalid session detected, forcing session refresh"
+            );
+        } else {
+            error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
+        }
+
+        if let Err(e) = self.unlock_job(job_id).await {
+            error!(worker_id = self.id, job_id, error = ?e, "Failed to unlock job");
+        }
+    }
 }

@@ -7,12 +7,16 @@ use serenity::Client;
 use serenity::all::{ActivityData, ClientBuilder, GatewayIntents};
 use std::sync::Arc;
 use std::time::Duration;
-use tracing::{debug, error, warn};
+use tokio::sync::{broadcast, Mutex};
+use tokio::task::JoinHandle;
+use tracing::{debug, error, info, warn};
 
 /// Discord bot service implementation
 pub struct BotService {
     client: Client,
     shard_manager: Arc<serenity::gateway::ShardManager>,
+    status_task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+    status_shutdown_tx: Option<broadcast::Sender<()>>,
 }
 
 impl BotService {
@@ -20,6 +24,8 @@ impl BotService {
     pub async fn create_client(
         config: &Config,
         app_state: AppState,
+        status_task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+        status_shutdown_rx: broadcast::Receiver<()>,
     ) -> Result<Client, anyhow::Error> {
         let intents = GatewayIntents::non_privileged();
         let bot_target_guild = config.bot_target_guild;
@@ -74,6 +80,7 @@ impl BotService {
             })
             .setup(move |ctx, _ready, framework| {
                 let app_state = app_state.clone();
+                let status_task_handle = status_task_handle.clone();
                 Box::pin(async move {
                     poise::builtins::register_in_guild(
                         ctx,
@@ -83,8 +90,9 @@ impl BotService {
                     .await?;
                     poise::builtins::register_globally(ctx, &framework.options().commands).await?;
 
-                    // Start status update task
-                    Self::start_status_update_task(ctx.clone(), app_state.clone()).await;
+                    // Start status update task with shutdown support
+                    let handle = Self::start_status_update_task(ctx.clone(), app_state.clone(), status_shutdown_rx);
+                    *status_task_handle.lock().await = Some(handle);
 
                     Ok(Data { app_state })
                 })
@@ -96,8 +104,12 @@ impl BotService {
             .await?)
     }
 
-    /// Start the status update task for the Discord bot
-    async fn start_status_update_task(ctx: serenity::client::Context, app_state: AppState) {
+    /// Start the status update task for the Discord bot with graceful shutdown support
+    fn start_status_update_task(
+        ctx: serenity::client::Context,
+        app_state: AppState,
+        mut shutdown_rx: broadcast::Receiver<()>,
+    ) -> JoinHandle<()> {
         tokio::spawn(async move {
             let max_interval = Duration::from_secs(300); // 5 minutes
             let base_interval = Duration::from_secs(30);
@@ -106,59 +118,72 @@
             // This runs once immediately on startup, then with adaptive intervals
             loop {
-                interval.tick().await;
+                tokio::select! {
+                    _ = interval.tick() => {
                         // Get the course count, update the activity if it has changed/hasn't been set this session
                         let course_count = app_state.get_course_count().await.unwrap();
                         if previous_course_count.is_none() || previous_course_count != Some(course_count) {
                             ctx.set_activity(Some(ActivityData::playing(format!(
                                 "Querying {:} classes",
                                 course_count.to_formatted_string(&Locale::en)
                             ))));
                         }
 
                         // Increase or reset the interval
                         interval = tokio::time::interval(
                             // Avoid logging the first 'change'
                             if course_count != previous_course_count.unwrap_or(0) {
                                 if previous_course_count.is_some() {
                                     debug!(
                                         new_course_count = course_count,
                                         last_interval = interval.period().as_secs(),
                                         "Course count changed, resetting interval"
                                     );
                                 }
 
                                 // Record the new course count
                                 previous_course_count = Some(course_count);
 
                                 // Reset to base interval
                                 base_interval
                             } else {
                                 // Increase interval by 10% (up to maximum)
                                 let new_interval = interval.period().mul_f32(1.1).min(max_interval);
                                 debug!(
                                     current_course_count = course_count,
                                     last_interval = interval.period().as_secs(),
                                     new_interval = new_interval.as_secs(),
                                     "Course count unchanged, increasing interval"
                                 );
                                 new_interval
                             },
                         );
 
                         // Reset the interval, otherwise it will tick again immediately
                         interval.reset();
+                    }
+                    _ = shutdown_rx.recv() => {
+                        info!("Status update task received shutdown signal");
+                        break;
+                    }
+                }
             }
-        });
+        })
     }
 
-    pub fn new(client: Client) -> Self {
+    pub fn new(
+        client: Client,
+        status_task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+        status_shutdown_tx: broadcast::Sender<()>,
+    ) -> Self {
         let shard_manager = client.shard_manager.clone();
         Self {
             client,
             shard_manager,
+            status_task_handle,
+            status_shutdown_tx: Some(status_shutdown_tx),
         }
     }
 }
@@ -183,6 +208,28 @@ impl Service for BotService {
     }
 
     async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
+        // Signal status update task to stop
+        if let Some(status_shutdown_tx) = self.status_shutdown_tx.take() {
+            let _ = status_shutdown_tx.send(());
+        }
+
+        // Wait for status update task to complete (with timeout)
+        let handle = self.status_task_handle.lock().await.take();
+        if let Some(handle) = handle {
+            match tokio::time::timeout(Duration::from_secs(2), handle).await {
+                Ok(Ok(())) => {
+                    debug!("Status update task completed gracefully");
+                }
+                Ok(Err(e)) => {
+                    warn!(error = ?e, "Status update task panicked");
+                }
+                Err(_) => {
+                    warn!("Status update task did not complete within 2s timeout");
+                }
+            }
+        }
+
+        // Shutdown Discord shards
         self.shard_manager.shutdown_all().await;
         Ok(())
     }

@@ -1,15 +1,16 @@
 use std::collections::HashMap;
 use std::time::Duration;
-use tokio::sync::broadcast;
-use tokio::task::JoinHandle;
-use tracing::{debug, error, info, trace, warn};
+use tokio::sync::{broadcast, mpsc};
+use tracing::{debug, info, trace, warn};
 
 use crate::services::{Service, ServiceResult, run_service};
 
 /// Manages multiple services and their lifecycle
 pub struct ServiceManager {
     registered_services: HashMap<String, Box<dyn Service>>,
-    running_services: HashMap<String, JoinHandle<ServiceResult>>,
+    service_handles: HashMap<String, tokio::task::AbortHandle>,
+    completion_rx: Option<mpsc::UnboundedReceiver<(String, ServiceResult)>>,
+    completion_tx: mpsc::UnboundedSender<(String, ServiceResult)>,
     shutdown_tx: broadcast::Sender<()>,
 }
@@ -22,9 +23,13 @@ impl Default for ServiceManager {
 impl ServiceManager {
     pub fn new() -> Self {
         let (shutdown_tx, _) = broadcast::channel(1);
+        let (completion_tx, completion_rx) = mpsc::unbounded_channel();
+
         Self {
             registered_services: HashMap::new(),
-            running_services: HashMap::new(),
+            service_handles: HashMap::new(),
+            completion_rx: Some(completion_rx),
+            completion_tx,
             shutdown_tx,
         }
     }
@@ -46,9 +51,19 @@ impl ServiceManager {
         for (name, service) in self.registered_services.drain() {
             let shutdown_rx = self.shutdown_tx.subscribe();
-            let handle = tokio::spawn(run_service(service, shutdown_rx));
+            let completion_tx = self.completion_tx.clone();
+            let name_clone = name.clone();
+
+            // Spawn service task
+            let handle = tokio::spawn(async move {
+                let result = run_service(service, shutdown_rx).await;
+                // Send completion notification
+                let _ = completion_tx.send((name_clone, result));
+            });
+
+            // Store abort handle for shutdown control
+            self.service_handles.insert(name.clone(), handle.abort_handle());
             debug!(service = name, id = ?handle.id(), "service spawned");
-            self.running_services.insert(name, handle);
         }
 
         info!(
@@ -62,7 +77,7 @@
     /// Run all services until one completes or fails
     /// Returns the first service that completes and its result
     pub async fn run(&mut self) -> (String, ServiceResult) {
-        if self.running_services.is_empty() {
+        if self.service_handles.is_empty() {
             return (
                 "none".to_string(),
                 ServiceResult::Error(anyhow::anyhow!("No services to run")),
@@ -71,99 +86,134 @@
         info!(
             "servicemanager running {} services",
-            self.running_services.len()
+            self.service_handles.len()
         );
 
-        // Wait for any service to complete
-        loop {
-            let mut completed_services = Vec::new();
-
-            for (name, handle) in &mut self.running_services {
-                if handle.is_finished() {
-                    completed_services.push(name.clone());
-                }
-            }
-
-            if let Some(completed_name) = completed_services.first() {
-                let handle = self.running_services.remove(completed_name).unwrap();
-                match handle.await {
-                    Ok(result) => {
-                        return (completed_name.clone(), result);
-                    }
-                    Err(e) => {
-                        error!(service = completed_name, "service task panicked: {e}");
-                        return (
-                            completed_name.clone(),
-                            ServiceResult::Error(anyhow::anyhow!("Task panic: {e}")),
-                        );
-                    }
-                }
-            }
-
-            // Small delay to prevent busy-waiting
-            tokio::time::sleep(Duration::from_millis(10)).await;
-        }
+        // Wait for any service to complete via the channel
+        let completion_rx = self
+            .completion_rx
+            .as_mut()
+            .expect("completion_rx should be available");
+
+        completion_rx
+            .recv()
+            .await
+            .map(|(name, result)| {
+                self.service_handles.remove(&name);
+                (name, result)
+            })
+            .unwrap_or_else(|| {
+                (
+                    "channel_closed".to_string(),
+                    ServiceResult::Error(anyhow::anyhow!("Completion channel closed")),
+                )
+            })
     }
 
     /// Shutdown all services gracefully with a timeout.
     ///
-    /// If any service fails to shutdown, it will return an error containing the names of the services that failed to shutdown.
-    /// If all services shutdown successfully, the function will return the duration elapsed.
+    /// All services receive the shutdown signal simultaneously and shut down in parallel.
+    /// Each service gets the full timeout duration (they don't share/consume from a budget).
+    /// If any service fails to shutdown within the timeout, it will be aborted.
+    ///
+    /// Returns the elapsed time if all succeed, or a list of failed service names.
    pub async fn shutdown(&mut self, timeout: Duration) -> Result<Duration, Vec<String>> {
-        let service_count = self.running_services.len();
-        let service_names: Vec<_> = self.running_services.keys().cloned().collect();
+        let service_count = self.service_handles.len();
+        let service_names: Vec<_> = self.service_handles.keys().cloned().collect();
 
         info!(
             service_count,
             services = ?service_names,
             timeout = format!("{:.2?}", timeout),
-            "shutting down {} services with {:?} timeout",
+            "shutting down {} services in parallel with {:?} timeout each",
             service_count,
             timeout
         );
 
-        // Send shutdown signal to all services
+        if service_count == 0 {
+            return Ok(Duration::ZERO);
+        }
+
+        // Send shutdown signal to all services simultaneously
         let _ = self.shutdown_tx.send(());
 
-        // Wait for all services to complete
         let start_time = std::time::Instant::now();
-        let mut pending_services = Vec::new();
 
-        for (name, handle) in self.running_services.drain() {
-            match tokio::time::timeout(timeout, handle).await {
-                Ok(Ok(_)) => {
-                    trace!(service = name, "service shutdown completed");
-                }
-                Ok(Err(e)) => {
-                    warn!(service = name, error = ?e, "service shutdown failed");
-                    pending_services.push(name);
-                }
-                Err(_) => {
-                    warn!(service = name, "service shutdown timed out");
-                    pending_services.push(name);
-                }
-            }
-        }
+        // Collect results from all services with timeout
+        let completion_rx = self
+            .completion_rx
+            .as_mut()
+            .expect("completion_rx should be available");
+
+        // Collect all completion results with a single timeout
+        let collect_future = async {
+            let mut collected: Vec<Option<(String, ServiceResult)>> = Vec::new();
+            for _ in 0..service_count {
+                if let Some(result) = completion_rx.recv().await {
+                    collected.push(Some(result));
+                } else {
+                    collected.push(None);
+                }
+            }
+            collected
+        };
+
+        let results = match tokio::time::timeout(timeout, collect_future).await {
+            Ok(results) => results,
+            Err(_) => {
+                // Timeout exceeded - abort all remaining services
+                warn!(
+                    timeout = format!("{:.2?}", timeout),
+                    "shutdown timeout exceeded - aborting all remaining services"
+                );
+                let failed: Vec<String> = self.service_handles.keys().cloned().collect();
+                for handle in self.service_handles.values() {
+                    handle.abort();
+                }
+                self.service_handles.clear();
+                return Err(failed);
+            }
+        };
+
+        // Process results and identify failures
+        let mut failed_services = Vec::new();
+        for (name, service_result) in results.into_iter().flatten() {
+            self.service_handles.remove(&name);
+
+            if matches!(service_result, ServiceResult::GracefulShutdown) {
+                trace!(service = name, "service shutdown completed");
+            } else {
+                warn!(
+                    service = name,
+                    result = ?service_result,
+                    "service shutdown with non-graceful result"
+                );
+                failed_services.push(name);
+            }
+        }
 
         let elapsed = start_time.elapsed();
 
-        if pending_services.is_empty() {
+        if failed_services.is_empty() {
             info!(
                 service_count,
                 elapsed = format!("{:.2?}", elapsed),
-                "services shutdown completed: {}",
+                "all services shutdown successfully: {}",
                 service_names.join(", ")
             );
             Ok(elapsed)
         } else {
             warn!(
-                pending_count = pending_services.len(),
-                pending_services = ?pending_services,
+                failed_count = failed_services.len(),
+                failed_services = ?failed_services,
                 elapsed = format!("{:.2?}", elapsed),
-                "services shutdown completed with {} pending: {}",
-                pending_services.len(),
-                pending_services.join(", ")
+                "{} service(s) failed to shutdown gracefully: {}",
+                failed_services.len(),
+                failed_services.join(", ")
             );
-            Err(pending_services)
+            Err(failed_services)
         }
     }
 }

@@ -23,7 +23,11 @@ pub trait Service: Send + Sync {
     /// Gracefully shutdown the service
     ///
-    /// An 'Ok' result does not mean the service has completed shutdown, it merely means that the service shutdown was initiated.
+    /// Implementations should initiate shutdown and MAY wait for completion.
+    /// Services are expected to respond to this call and begin cleanup promptly.
+    /// When managed by ServiceManager, the configured timeout (default 8s) applies to
+    /// ALL services combined, not per-service. Services should complete shutdown as
+    /// quickly as possible to avoid timeout.
     async fn shutdown(&mut self) -> Result<(), anyhow::Error>;
 }
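A minimal sketch of an implementation that honors this contract (only the shutdown method is shown; the real trait has further methods, and MyService and its fields are hypothetical examples assuming the async-trait, anyhow and tokio crates):

```rust
use tokio::{sync::broadcast, task::JoinHandle};

pub struct MyService {
    shutdown_tx: Option<broadcast::Sender<()>>,
    task: Option<JoinHandle<()>>,
}

#[async_trait::async_trait]
impl Service for MyService {
    async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
        // Signal promptly, then wait for the task; stay well under the manager's budget.
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        if let Some(task) = self.task.take() {
            task.await.map_err(|e| anyhow::anyhow!("task panicked: {e}"))?;
        }
        Ok(())
    }
}
```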

@@ -3,7 +3,7 @@ use crate::web::{BannerState, create_router};
 use std::net::SocketAddr;
 use tokio::net::TcpListener;
 use tokio::sync::broadcast;
-use tracing::{info, warn, trace};
+use tracing::{info, trace, warn};
 
 /// Web server service implementation
 pub struct WebService {
@@ -33,16 +33,12 @@ impl Service for WebService {
         let app = create_router(self.banner_state.clone());
         let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
 
-        info!(
-            service = "web",
-            link = format!("http://localhost:{}", addr.port()),
-            "starting web server",
-        );
-
         let listener = TcpListener::bind(addr).await?;
         info!(
             service = "web",
             address = %addr,
+            link = format!("http://localhost:{}", addr.port()),
             "web server listening"
         );
@@ -61,13 +57,16 @@
             })
             .await?;
 
+        trace!(service = "web", "graceful shutdown completed");
         info!(service = "web", "web server stopped");
         Ok(())
     }
 
     async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
         if let Some(shutdown_tx) = self.shutdown_tx.take() {
             let _ = shutdown_tx.send(());
+            trace!(service = "web", "sent shutdown signal to axum");
         } else {
             warn!(
                 service = "web",

src/utils/mod.rs (new file)

@@ -0,0 +1 @@
+pub mod shutdown;

src/utils/shutdown.rs (new file)

@@ -0,0 +1,32 @@
+use tokio::task::JoinHandle;
+use tracing::warn;
+
+/// Helper for joining multiple task handles with proper error handling.
+///
+/// This function waits for all tasks to complete and reports any that panicked.
+/// Returns an error if any task panicked, otherwise returns Ok.
+pub async fn join_tasks(handles: Vec<JoinHandle<()>>) -> Result<(), anyhow::Error> {
+    let results = futures::future::join_all(handles).await;
+    let failed = results.iter().filter(|r| r.is_err()).count();
+
+    if failed > 0 {
+        warn!(failed_count = failed, "Some tasks panicked during shutdown");
+        Err(anyhow::anyhow!("{} task(s) panicked", failed))
+    } else {
+        Ok(())
+    }
+}
+
+/// Helper for joining multiple task handles with a timeout.
+///
+/// Waits for all tasks to complete within the specified timeout.
+/// If timeout occurs, remaining tasks are aborted.
+pub async fn join_tasks_with_timeout(
+    handles: Vec<JoinHandle<()>>,
+    timeout: std::time::Duration,
+) -> Result<(), anyhow::Error> {
+    match tokio::time::timeout(timeout, join_tasks(handles)).await {
+        Ok(result) => result,
+        Err(_) => Err(anyhow::anyhow!("Task join timed out after {:?}", timeout)),
+    }
+}
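A hypothetical caller (not part of this diff) would typically signal its tasks first and then delegate the bounded join to these helpers:

```rust
use std::time::Duration;

use crate::utils::shutdown::join_tasks_with_timeout;

/// Sketch: tell every worker to stop, then wait up to 5s for all of them to finish.
async fn shutdown_workers(
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
    handles: Vec<tokio::task::JoinHandle<()>>,
) -> Result<(), anyhow::Error> {
    let _ = shutdown_tx.send(());
    join_tasks_with_timeout(handles, Duration::from_secs(5)).await
}
```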