diff --git a/src/app.rs b/src/app.rs index df42579..045415e 100644 --- a/src/app.rs +++ b/src/app.rs @@ -115,10 +115,28 @@ impl App { /// Setup bot service if enabled pub async fn setup_bot_service(&mut self) -> Result<(), anyhow::Error> { - let client = BotService::create_client(&self.config, self.app_state.clone()) - .await - .expect("Failed to create Discord client"); - let bot_service = Box::new(BotService::new(client)); + use std::sync::Arc; + use tokio::sync::{broadcast, Mutex}; + + // Create shutdown channel for status update task + let (status_shutdown_tx, status_shutdown_rx) = broadcast::channel(1); + let status_task_handle = Arc::new(Mutex::new(None)); + + let client = BotService::create_client( + &self.config, + self.app_state.clone(), + status_task_handle.clone(), + status_shutdown_rx, + ) + .await + .expect("Failed to create Discord client"); + + let bot_service = Box::new(BotService::new( + client, + status_task_handle, + status_shutdown_tx, + )); + self.service_manager .register_service(ServiceName::Bot.as_str(), bot_service); Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 2ed65f3..241783d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,4 +11,5 @@ pub mod scraper; pub mod services; pub mod signals; pub mod state; +pub mod utils; pub mod web; diff --git a/src/scraper/mod.rs b/src/scraper/mod.rs index d428430..b985083 100644 --- a/src/scraper/mod.rs +++ b/src/scraper/mod.rs @@ -3,16 +3,15 @@ pub mod scheduler; pub mod worker; use crate::banner::BannerApi; +use crate::services::Service; use sqlx::PgPool; use std::sync::Arc; -use std::time::Duration; use tokio::sync::broadcast; use tokio::task::JoinHandle; use tracing::{info, warn}; use self::scheduler::Scheduler; use self::worker::Worker; -use crate::services::Service; /// The main service that will be managed by the application's `ServiceManager`. /// @@ -91,6 +90,7 @@ impl Service for ScraperService { let _ = shutdown_tx.send(()); } else { warn!("No shutdown channel found for scraper service"); + return Err(anyhow::anyhow!("No shutdown channel available")); } // Collect all handles @@ -100,31 +100,15 @@ impl Service for ScraperService { } all_handles.append(&mut self.worker_handles); - // Wait for all tasks to complete with a timeout - let timeout_duration = Duration::from_secs(5); - - match tokio::time::timeout( - timeout_duration, - futures::future::join_all(all_handles), - ) - .await - { - Ok(results) => { - let failed = results.iter().filter(|r| r.is_err()).count(); - if failed > 0 { - warn!(failed_count = failed, "Some scraper tasks failed during shutdown"); - } else { - info!("All scraper tasks shutdown gracefully"); - } - } - Err(_) => { - warn!( - timeout = format!("{:.2?}", timeout_duration), - "Scraper service shutdown timed out" - ); - } + // Wait for all tasks to complete (no internal timeout - let ServiceManager handle it) + let results = futures::future::join_all(all_handles).await; + let failed = results.iter().filter(|r| r.is_err()).count(); + if failed > 0 { + warn!(failed_count = failed, "Some scraper tasks panicked during shutdown"); + return Err(anyhow::anyhow!("{} task(s) panicked", failed)); } + info!("All scraper tasks shutdown gracefully"); Ok(()) } } diff --git a/src/scraper/scheduler.rs b/src/scraper/scheduler.rs index 36f17cc..3f61ab0 100644 --- a/src/scraper/scheduler.rs +++ b/src/scraper/scheduler.rs @@ -25,7 +25,15 @@ impl Scheduler { } } - /// Runs the scheduler's main loop. + /// Runs the scheduler's main loop with graceful shutdown support. 
+    ///
+    /// The scheduler wakes up every 60 seconds to analyze data and enqueue jobs.
+    /// When a shutdown signal is received:
+    /// 1. Any in-progress scheduling work is gracefully cancelled via CancellationToken
+    /// 2. The scheduler waits up to 5 seconds for work to complete
+    /// 3. If the timeout elapses, the task is abandoned (dropping the handle detaches it;
+    ///    the work still observes the cancellation token and stops on its own)
+    ///
+    /// This ensures that shutdown is responsive even if scheduling work is blocked.
     pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
         info!("Scheduler service started");
@@ -35,19 +43,17 @@
         loop {
             tokio::select! {
-                // Sleep until next scheduled run - instantly cancellable
                 _ = time::sleep_until(next_run) => {
-                    // Create cancellation token for graceful task cancellation
                     let cancel_token = CancellationToken::new();
 
-                    // Spawn scheduling work in a separate task for cancellability
+                    // Spawn work in separate task to allow graceful cancellation during shutdown.
+                    // Without this, shutdown would have to wait for the full scheduling cycle.
                     let work_handle = tokio::spawn({
                         let db_pool = self.db_pool.clone();
                         let banner_api = self.banner_api.clone();
                         let cancel_token = cancel_token.clone();
                         async move {
-                            // Check for cancellation while running
                             tokio::select! {
                                 result = Self::schedule_jobs_impl(&db_pool, &banner_api) => {
                                     if let Err(e) = result {
@@ -67,19 +73,14 @@
                 _ = shutdown_rx.recv() => {
                     info!("Scheduler received shutdown signal");
 
-                    // Gracefully cancel any in-progress work
                     if let Some((handle, cancel_token)) = current_work.take() {
-                        // Signal cancellation
                         cancel_token.cancel();
 
-                        // Wait for graceful completion with timeout
-                        match time::timeout(Duration::from_secs(5), handle).await {
-                            Ok(_) => {
-                                debug!("Scheduling work completed gracefully");
-                            }
-                            Err(_) => {
-                                warn!("Scheduling work did not complete within 5s timeout, may have been aborted");
-                            }
+                        // Wait briefly for graceful completion
+                        if tokio::time::timeout(Duration::from_secs(5), handle).await.is_err() {
+                            warn!("Scheduling work did not complete within 5s, abandoning");
+                        } else {
+                            debug!("Scheduling work completed gracefully");
                         }
                     }
 
@@ -90,7 +91,14 @@
         }
     }
 
-    /// The core logic for deciding what jobs to create.
+    /// Core scheduling logic that analyzes data and creates scrape jobs.
+    ///
+    /// Strategy:
+    /// 1. Fetch all subjects for the current term from Banner API
+    /// 2. Query existing jobs in a single batch query
+    /// 3. Create jobs only for subjects that don't have pending jobs
+    ///
+    /// This is an associated function (no `&self`) so it can be called from spawned tasks.
     async fn schedule_jobs_impl(db_pool: &PgPool, banner_api: &BannerApi) -> Result<()> {
         // For now, we will implement a simple baseline scheduling strategy:
         // 1. Get a list of all subjects from the Banner API.
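// A minimal, self-contained sketch of the cancel-then-wait shutdown pattern the scheduler
// uses above, assuming the `tokio-util` crate for `CancellationToken`; the function name
// and the simulated work are illustrative only, not part of the patch.
async fn cancel_then_wait_sketch() {
    use std::time::Duration;
    use tokio_util::sync::CancellationToken;

    let cancel_token = CancellationToken::new();
    let handle = tokio::spawn({
        let cancel_token = cancel_token.clone();
        async move {
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_secs(60)) => { /* work ran to completion */ }
                _ = cancel_token.cancelled() => { /* work observed cancellation and stopped early */ }
            }
        }
    });

    // Shutdown path: signal cancellation, wait briefly, then abandon the task.
    cancel_token.cancel();
    if tokio::time::timeout(Duration::from_secs(5), handle).await.is_err() {
        // Dropping the JoinHandle detaches the task; it is not aborted.
    }
}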
diff --git a/src/scraper/worker.rs b/src/scraper/worker.rs
index e25c07a..f252bb4 100644
--- a/src/scraper/worker.rs
+++ b/src/scraper/worker.rs
@@ -30,28 +30,25 @@ impl Worker {
     /// Runs the worker's main loop.
     pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
-        info!(worker_id = self.id, "Worker started.");
+        info!(worker_id = self.id, "Worker started");
 
         loop {
             // Fetch and lock a job, racing against shutdown signal
             let job = tokio::select! {
                 _ = shutdown_rx.recv() => {
-                    info!(worker_id = self.id, "Worker received shutdown signal");
-                    info!(worker_id = self.id, "Worker exiting gracefully");
+                    info!(worker_id = self.id, "Worker received shutdown signal, exiting gracefully");
                     break;
                 }
                 result = self.fetch_and_lock_job() => {
                     match result {
                         Ok(Some(job)) => job,
                         Ok(None) => {
-                            // No job found, wait for a bit before polling again
                             trace!(worker_id = self.id, "No jobs available, waiting");
                             time::sleep(Duration::from_secs(5)).await;
                             continue;
                         }
                         Err(e) => {
-                            warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
-                            // Wait before retrying to avoid spamming errors
+                            warn!(worker_id = self.id, error = ?e, "Failed to fetch job, waiting");
                             time::sleep(Duration::from_secs(10)).await;
                             continue;
                         }
@@ -65,63 +62,14 @@
             // Process the job, racing against shutdown signal
             let process_result = tokio::select! {
                 _ = shutdown_rx.recv() => {
-                    info!(worker_id = self.id, job_id, "Shutdown received during job processing");
-
-                    // Unlock the job so it can be retried
-                    if let Err(e) = self.unlock_job(job_id).await {
-                        warn!(
-                            worker_id = self.id,
-                            job_id,
-                            error = ?e,
-                            "Failed to unlock job during shutdown"
-                        );
-                    } else {
-                        debug!(worker_id = self.id, job_id, "Job unlocked during shutdown");
-                    }
-
-                    info!(worker_id = self.id, "Worker exiting gracefully");
+                    self.handle_shutdown_during_processing(job_id).await;
                     break;
                 }
-                result = self.process_job(job) => {
-                    result
-                }
+                result = self.process_job(job) => result
             };
 
             // Handle the job processing result
-            match process_result {
-                Ok(()) => {
-                    debug!(worker_id = self.id, job_id, "Job completed");
-                    // If successful, delete the job
-                    if let Err(delete_err) = self.delete_job(job_id).await {
-                        error!(
-                            worker_id = self.id,
-                            job_id,
-                            ?delete_err,
-                            "Failed to delete job"
-                        );
-                    }
-                }
-                Err(JobError::Recoverable(e)) => {
-                    self.handle_recoverable_error(job_id, e).await;
-                }
-                Err(JobError::Unrecoverable(e)) => {
-                    error!(
-                        worker_id = self.id,
-                        job_id,
-                        error = ?e,
-                        "Job corrupted, deleting"
-                    );
-                    // Parse errors are unrecoverable - delete the job
-                    if let Err(delete_err) = self.delete_job(job_id).await {
-                        error!(
-                            worker_id = self.id,
-                            job_id,
-                            ?delete_err,
-                            "Failed to delete corrupted job"
-                        );
-                    }
-                }
-            }
+            self.handle_job_result(job_id, process_result).await;
         }
     }
@@ -191,24 +139,58 @@
         Ok(())
     }
 
+    /// Handle shutdown signal received during job processing
+    async fn handle_shutdown_during_processing(&self, job_id: i32) {
+        info!(worker_id = self.id, job_id, "Shutdown received during job processing");
+
+        if let Err(e) = self.unlock_job(job_id).await {
+            warn!(
+                worker_id = self.id,
+                job_id,
+                error = ?e,
+                "Failed to unlock job during shutdown"
+            );
+        } else {
+            debug!(worker_id = self.id, job_id, "Job unlocked during shutdown");
+        }
+
+        info!(worker_id = self.id, "Worker exiting gracefully");
+    }
+
+    /// Handle the result of job processing
+    async fn handle_job_result(&self, job_id: i32, result: Result<(), JobError>) {
+        match result {
+            Ok(()) => {
+                debug!(worker_id = self.id, job_id, "Job completed successfully");
+                if let Err(e) = self.delete_job(job_id).await {
+                    error!(worker_id = self.id, job_id, error = ?e, "Failed to delete completed job");
+                }
+            }
+            Err(JobError::Recoverable(e)) => {
+                self.handle_recoverable_error(job_id, e).await;
+            }
+            Err(JobError::Unrecoverable(e)) => {
+                error!(worker_id = self.id, job_id, error = ?e, "Job corrupted, deleting");
+                if let Err(e) = self.delete_job(job_id).await {
+                    error!(worker_id = self.id, job_id, error = ?e, "Failed to delete corrupted job");
+                }
+            }
+        }
+    }
+
     /// Handle recoverable errors by logging appropriately and unlocking the job
     async fn handle_recoverable_error(&self, job_id: i32, e: anyhow::Error) {
         if let Some(BannerApiError::InvalidSession(_)) = e.downcast_ref::<BannerApiError>() {
             warn!(
                 worker_id = self.id,
-                job_id, "Invalid session detected. Forcing session refresh."
+                job_id, "Invalid session detected, forcing session refresh"
             );
         } else {
             error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
         }
 
-        if let Err(unlock_err) = self.unlock_job(job_id).await {
-            error!(
-                worker_id = self.id,
-                job_id,
-                ?unlock_err,
-                "Failed to unlock job"
-            );
+        if let Err(e) = self.unlock_job(job_id).await {
+            error!(worker_id = self.id, job_id, error = ?e, "Failed to unlock job");
         }
     }
 }
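// The shape of the error classification the worker relies on, as implied by the match
// arms above; the real `JobError` is defined elsewhere in the crate and may carry more
// context than this sketch.
enum JobError {
    /// Transient failure (network error, invalid session): unlock the job so it can be retried.
    Recoverable(anyhow::Error),
    /// Corrupt job payload (e.g. a parse error): delete the job, since retrying cannot help.
    Unrecoverable(anyhow::Error),
}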
diff --git a/src/services/bot.rs b/src/services/bot.rs
index 2409e5e..8ead7b5 100644
--- a/src/services/bot.rs
+++ b/src/services/bot.rs
@@ -7,12 +7,16 @@
 use serenity::Client;
 use serenity::all::{ActivityData, ClientBuilder, GatewayIntents};
 use std::sync::Arc;
 use std::time::Duration;
-use tracing::{debug, error, warn};
+use tokio::sync::{broadcast, Mutex};
+use tokio::task::JoinHandle;
+use tracing::{debug, error, info, warn};
 
 /// Discord bot service implementation
 pub struct BotService {
     client: Client,
     shard_manager: Arc<ShardManager>,
+    status_task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+    status_shutdown_tx: Option<broadcast::Sender<()>>,
 }
 
 impl BotService {
@@ -20,6 +24,8 @@
     pub async fn create_client(
         config: &Config,
         app_state: AppState,
+        status_task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+        status_shutdown_rx: broadcast::Receiver<()>,
     ) -> Result<Client, anyhow::Error> {
         let intents = GatewayIntents::non_privileged();
         let bot_target_guild = config.bot_target_guild;
@@ -74,6 +80,7 @@
             })
             .setup(move |ctx, _ready, framework| {
                 let app_state = app_state.clone();
+                let status_task_handle = status_task_handle.clone();
                 Box::pin(async move {
                     poise::builtins::register_in_guild(
                         ctx,
                         &framework.options().commands,
                         bot_target_guild,
                     )
                     .await?;
                     poise::builtins::register_globally(ctx, &framework.options().commands).await?;
 
-                    // Start status update task
-                    Self::start_status_update_task(ctx.clone(), app_state.clone()).await;
+                    // Start status update task with shutdown support
+                    let handle = Self::start_status_update_task(ctx.clone(), app_state.clone(), status_shutdown_rx);
+                    *status_task_handle.lock().await = Some(handle);
 
                     Ok(Data { app_state })
                 })
@@ -96,8 +104,12 @@
             .await?)
     }
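// A condensed sketch of the handle-stashing pattern introduced here: the slot is created
// empty before the client exists, filled once the setup callback spawns the status task,
// and drained during shutdown so the task can be awaited. The alias and function names
// are illustrative only.
type StatusTaskSlot = std::sync::Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>>;

async fn handle_stashing_sketch() {
    let slot: StatusTaskSlot = std::sync::Arc::new(tokio::sync::Mutex::new(None));

    // Whoever ends up spawning the task stores its handle in the shared slot.
    *slot.lock().await = Some(tokio::spawn(async { /* background status updates */ }));

    // On shutdown, take the handle out and await it, optionally under a timeout.
    if let Some(handle) = slot.lock().await.take() {
        let _ = tokio::time::timeout(std::time::Duration::from_secs(2), handle).await;
    }
}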
 
-    /// Start the status update task for the Discord bot
-    async fn start_status_update_task(ctx: serenity::client::Context, app_state: AppState) {
+    /// Start the status update task for the Discord bot with graceful shutdown support
+    fn start_status_update_task(
+        ctx: serenity::client::Context,
+        app_state: AppState,
+        mut shutdown_rx: broadcast::Receiver<()>,
+    ) -> JoinHandle<()> {
         tokio::spawn(async move {
             let max_interval = Duration::from_secs(300); // 5 minutes
             let base_interval = Duration::from_secs(30);
@@ -106,59 +118,72 @@
             // This runs once immediately on startup, then with adaptive intervals
             loop {
-                interval.tick().await;
-
-                // Get the course count, update the activity if it has changed/hasn't been set this session
-                let course_count = app_state.get_course_count().await.unwrap();
-                if previous_course_count.is_none() || previous_course_count != Some(course_count) {
-                    ctx.set_activity(Some(ActivityData::playing(format!(
-                        "Querying {:} classes",
-                        course_count.to_formatted_string(&Locale::en)
-                    ))));
-                }
-
-                // Increase or reset the interval
-                interval = tokio::time::interval(
-                    // Avoid logging the first 'change'
-                    if course_count != previous_course_count.unwrap_or(0) {
-                        if previous_course_count.is_some() {
-                            debug!(
-                                new_course_count = course_count,
-                                last_interval = interval.period().as_secs(),
-                                "Course count changed, resetting interval"
-                            );
-                        }
-
-                        // Record the new course count
-                        previous_course_count = Some(course_count);
-
-                        // Reset to base interval
-                        base_interval
-                    } else {
-                        // Increase interval by 10% (up to maximum)
-                        let new_interval = interval.period().mul_f32(1.1).min(max_interval);
-                        debug!(
-                            current_course_count = course_count,
-                            last_interval = interval.period().as_secs(),
-                            new_interval = new_interval.as_secs(),
-                            "Course count unchanged, increasing interval"
-                        );
-
-                        new_interval
-                    },
-                );
-
-                // Reset the interval, otherwise it will tick again immediately
-                interval.reset();
+                tokio::select! {
+                    _ = interval.tick() => {
+                        // Get the course count, update the activity if it has changed/hasn't been set this session
+                        let course_count = app_state.get_course_count().await.unwrap();
+                        if previous_course_count.is_none() || previous_course_count != Some(course_count) {
+                            ctx.set_activity(Some(ActivityData::playing(format!(
+                                "Querying {:} classes",
+                                course_count.to_formatted_string(&Locale::en)
+                            ))));
+                        }
+
+                        // Increase or reset the interval
+                        interval = tokio::time::interval(
+                            // Avoid logging the first 'change'
+                            if course_count != previous_course_count.unwrap_or(0) {
+                                if previous_course_count.is_some() {
+                                    debug!(
+                                        new_course_count = course_count,
+                                        last_interval = interval.period().as_secs(),
+                                        "Course count changed, resetting interval"
+                                    );
+                                }
+
+                                // Record the new course count
+                                previous_course_count = Some(course_count);
+
+                                // Reset to base interval
+                                base_interval
+                            } else {
+                                // Increase interval by 10% (up to maximum)
+                                let new_interval = interval.period().mul_f32(1.1).min(max_interval);
+                                debug!(
+                                    current_course_count = course_count,
+                                    last_interval = interval.period().as_secs(),
+                                    new_interval = new_interval.as_secs(),
+                                    "Course count unchanged, increasing interval"
+                                );
+
+                                new_interval
+                            },
+                        );
+
+                        // Reset the interval, otherwise it will tick again immediately
+                        interval.reset();
+                    }
+                    _ = shutdown_rx.recv() => {
+                        info!("Status update task received shutdown signal");
+                        break;
+                    }
+                }
             }
-        });
+        })
     }
 
-    pub fn new(client: Client) -> Self {
+    pub fn new(
+        client: Client,
+        status_task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+        status_shutdown_tx: 
broadcast::Sender<()>, + ) -> Self { let shard_manager = client.shard_manager.clone(); + Self { client, shard_manager, + status_task_handle, + status_shutdown_tx: Some(status_shutdown_tx), } } } @@ -183,6 +208,28 @@ impl Service for BotService { } async fn shutdown(&mut self) -> Result<(), anyhow::Error> { + // Signal status update task to stop + if let Some(status_shutdown_tx) = self.status_shutdown_tx.take() { + let _ = status_shutdown_tx.send(()); + } + + // Wait for status update task to complete (with timeout) + let handle = self.status_task_handle.lock().await.take(); + if let Some(handle) = handle { + match tokio::time::timeout(Duration::from_secs(2), handle).await { + Ok(Ok(())) => { + debug!("Status update task completed gracefully"); + } + Ok(Err(e)) => { + warn!(error = ?e, "Status update task panicked"); + } + Err(_) => { + warn!("Status update task did not complete within 2s timeout"); + } + } + } + + // Shutdown Discord shards self.shard_manager.shutdown_all().await; Ok(()) } diff --git a/src/services/manager.rs b/src/services/manager.rs index c9d3008..7cf9d61 100644 --- a/src/services/manager.rs +++ b/src/services/manager.rs @@ -112,11 +112,11 @@ impl ServiceManager { /// Shutdown all services gracefully with a timeout. /// - /// All services receive the shutdown signal simultaneously and must complete within the - /// specified timeout (combined, not per-service). If any service fails to shutdown within - /// the timeout, it will be aborted and included in the error result. + /// All services receive the shutdown signal simultaneously and shut down in parallel. + /// Each service gets the full timeout duration (they don't share/consume from a budget). + /// If any service fails to shutdown within the timeout, it will be aborted. /// - /// If all services shutdown successfully, the function will return the duration elapsed. + /// Returns the elapsed time if all succeed, or a list of failed service names. 
    pub async fn shutdown(&mut self, timeout: Duration) -> Result<Duration, Vec<String>> {
         let service_count = self.service_handles.len();
         let service_names: Vec<_> = self.service_handles.keys().cloned().collect();
@@ -125,7 +125,7 @@
             service_count,
             services = ?service_names,
             timeout = format!("{:.2?}", timeout),
-            "shutting down {} services with {:?} total timeout",
+            "shutting down {} services in parallel with {:?} timeout each",
             service_count,
             timeout
         );
@@ -138,54 +138,59 @@
         let _ = self.shutdown_tx.send(());
 
         let start_time = std::time::Instant::now();
-        let mut completed = 0;
-        let mut failed_services = Vec::new();
 
-        // Borrow the receiver mutably (don't take ownership to allow reuse)
+        // Collect results from all services with timeout
         let completion_rx = self
             .completion_rx
             .as_mut()
             .expect("completion_rx should be available");
 
-        // Wait for all services to complete with timeout
-        while completed < service_count {
-            match tokio::time::timeout(
-                timeout.saturating_sub(start_time.elapsed()),
-                completion_rx.recv(),
-            )
-            .await
-            {
-                Ok(Some((name, result))) => {
-                    completed += 1;
-                    self.service_handles.remove(&name);
-
-                    if matches!(result, ServiceResult::GracefulShutdown) {
-                        trace!(service = name, "service shutdown completed");
-                    } else {
-                        warn!(service = name, "service shutdown with non-graceful result");
-                        failed_services.push(name);
-                    }
-                }
-                Ok(None) => {
-                    // Channel closed - shouldn't happen but handle it
-                    warn!("completion channel closed during shutdown");
-                    break;
-                }
-                Err(_) => {
-                    // Timeout - abort all remaining services
-                    warn!(
-                        timeout = format!("{:.2?}", timeout),
-                        elapsed = format!("{:.2?}", start_time.elapsed()),
-                        remaining = service_count - completed,
-                        "shutdown timeout - aborting remaining services"
-                    );
-
-                    for (name, handle) in self.service_handles.drain() {
-                        handle.abort();
-                        failed_services.push(name);
-                    }
-                    break;
-                }
-            }
-        }
+        // Collect all completion results with a single timeout
+        let collect_future = async {
+            let mut collected: Vec<Option<(String, ServiceResult)>> = Vec::new();
+            for _ in 0..service_count {
+                if let Some(result) = completion_rx.recv().await {
+                    collected.push(Some(result));
+                } else {
+                    collected.push(None);
+                }
+            }
+            collected
+        };
+
+        let results = match tokio::time::timeout(timeout, collect_future).await {
+            Ok(results) => results,
+            Err(_) => {
+                // Timeout exceeded - abort all remaining services
+                warn!(
+                    timeout = format!("{:.2?}", timeout),
+                    "shutdown timeout exceeded - aborting all remaining services"
+                );
+
+                let failed: Vec<String> = self.service_handles.keys().cloned().collect();
+                for handle in self.service_handles.values() {
+                    handle.abort();
+                }
+                self.service_handles.clear();
+
+                return Err(failed);
+            }
+        };
+
+        // Process results and identify failures
+        let mut failed_services = Vec::new();
+        for (name, service_result) in results.into_iter().flatten() {
+            self.service_handles.remove(&name);
+
+            if matches!(service_result, ServiceResult::GracefulShutdown) {
+                trace!(service = name, "service shutdown completed");
+            } else {
+                warn!(
+                    service = name,
+                    result = ?service_result,
+                    "service shutdown with non-graceful result"
+                );
+                failed_services.push(name);
+            }
+        }
@@ -195,7 +200,7 @@
         info!(
             service_count,
             elapsed = format!("{:.2?}", elapsed),
-            "services shutdown completed: {}",
+            "all services shutdown successfully: {}",
             service_names.join(", ")
         );
         Ok(elapsed)
@@ -204,7 +209,7 @@
             failed_count = failed_services.len(),
             failed_services = ?failed_services,
             elapsed = format!("{:.2?}", elapsed),
-            "services shutdown completed with {} failed: {}",
+            "{} service(s) failed to shutdown gracefully: {}",
             failed_services.len(),
             failed_services.join(", ")
         );
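// Caller-side sketch of the new `Result<Duration, Vec<String>>` contract, assuming
// `ServiceManager` is reachable at this path (adjust to the crate's re-exports); the
// function name and 10-second timeout are illustrative.
async fn stop_all_services(manager: &mut crate::services::manager::ServiceManager) {
    match manager.shutdown(std::time::Duration::from_secs(10)).await {
        Ok(elapsed) => tracing::info!(elapsed = format!("{:.2?}", elapsed), "all services stopped"),
        Err(failed) => tracing::error!(failed_services = ?failed, "some services had to be aborted"),
    }
}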
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
new file mode 100644
index 0000000..de8498c
--- /dev/null
+++ b/src/utils/mod.rs
@@ -0,0 +1 @@
+pub mod shutdown;
diff --git a/src/utils/shutdown.rs b/src/utils/shutdown.rs
new file mode 100644
index 0000000..73a8740
--- /dev/null
+++ b/src/utils/shutdown.rs
@@ -0,0 +1,32 @@
+use tokio::task::JoinHandle;
+use tracing::warn;
+
+/// Helper for joining multiple task handles with proper error handling.
+///
+/// This function waits for all tasks to complete and reports any that panicked.
+/// Returns an error if any task panicked, otherwise returns Ok.
+pub async fn join_tasks(handles: Vec<JoinHandle<()>>) -> Result<(), anyhow::Error> {
+    let results = futures::future::join_all(handles).await;
+
+    let failed = results.iter().filter(|r| r.is_err()).count();
+    if failed > 0 {
+        warn!(failed_count = failed, "Some tasks panicked during shutdown");
+        Err(anyhow::anyhow!("{} task(s) panicked", failed))
+    } else {
+        Ok(())
+    }
+}
+
+/// Helper for joining multiple task handles with a timeout.
+///
+/// Waits for all tasks to complete within the specified timeout. If the timeout
+/// elapses, an error is returned and the unfinished tasks are left detached
+/// (dropping a `JoinHandle` does not abort its task).
+pub async fn join_tasks_with_timeout(
+    handles: Vec<JoinHandle<()>>,
+    timeout: std::time::Duration,
+) -> Result<(), anyhow::Error> {
+    match tokio::time::timeout(timeout, join_tasks(handles)).await {
+        Ok(result) => result,
+        Err(_) => Err(anyhow::anyhow!("Task join timed out after {:?}", timeout)),
+    }
+}
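// Usage sketch for the new helpers, e.g. from a service's shutdown path; the function
// below and the 5-second deadline are illustrative, and the handle collection mirrors
// what `ScraperService::shutdown` already gathers above.
async fn join_collected_tasks(handles: Vec<tokio::task::JoinHandle<()>>) -> Result<(), anyhow::Error> {
    // Returns an error if any task panicked or if the deadline passes first.
    crate::utils::shutdown::join_tasks_with_timeout(handles, std::time::Duration::from_secs(5)).await
}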