diff --git a/Cargo.lock b/Cargo.lock
index e2e7948..6735235 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -254,6 +254,7 @@ dependencies = [
  "time",
  "tl",
  "tokio",
+ "tokio-util",
  "tower-http",
  "tracing",
  "tracing-subscriber",
diff --git a/Cargo.toml b/Cargo.toml
index cb3b349..fe4db01 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,6 +36,7 @@ sqlx = { version = "0.8.6", features = [
 thiserror = "2.0.16"
 time = "0.3.43"
 tokio = { version = "1.47.1", features = ["full"] }
+tokio-util = "0.7"
 tl = "0.7.8"
 tracing = "0.1.41"
 tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
diff --git a/src/app.rs b/src/app.rs
index 589f8ed..df42579 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -42,11 +42,7 @@ impl App {
         // Check if the database URL is via private networking
         let is_private = config.database_url.contains("railway.internal");

-        let slow_threshold = if is_private {
-            Duration::from_millis(200)
-        } else {
-            Duration::from_millis(500)
-        };
+        let slow_threshold = Duration::from_millis(if is_private { 200 } else { 500 });

         // Create database connection pool
         let db_pool = PgPoolOptions::new()
@@ -108,10 +104,6 @@ impl App {
                 .register_service(ServiceName::Scraper.as_str(), scraper_service);
         }

-        if services.contains(&ServiceName::Bot) {
-            // Bot service will be set up separately in run() method since it's async
-        }
-
         // Check if any services are enabled
         if !self.service_manager.has_services() && !services.contains(&ServiceName::Bot) {
             error!("No services enabled. Cannot start application.");
@@ -147,9 +139,4 @@ impl App {
     pub fn config(&self) -> &Config {
         &self.config
     }
-
-    /// Get a reference to the app state
-    pub fn app_state(&self) -> &AppState {
-        &self.app_state
-    }
 }
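
Note: the next file replaces abort-based teardown with a cooperative shutdown signal. For readers unfamiliar with the primitive, here is a minimal, self-contained sketch of the `tokio::sync::broadcast` pattern the service adopts: one sender, one subscription per spawned task, and a unit message that tells every task to stop. This is a standalone illustration only; the task bodies and timings are made up, not code from this repository.

```rust
use std::time::Duration;
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    // One sender; every spawned task gets its own receiver via subscribe().
    let (shutdown_tx, _) = broadcast::channel::<()>(1);

    let mut handles = Vec::new();
    for id in 0..4u32 {
        let mut shutdown_rx = shutdown_tx.subscribe();
        handles.push(tokio::spawn(async move {
            loop {
                tokio::select! {
                    // A received signal ends the loop.
                    _ = shutdown_rx.recv() => break,
                    _ = tokio::time::sleep(Duration::from_millis(100)) => {
                        println!("worker {id}: one unit of work");
                    }
                }
            }
        }));
    }

    tokio::time::sleep(Duration::from_millis(350)).await;
    let _ = shutdown_tx.send(()); // broadcast shutdown to all subscribers
    for handle in handles {
        let _ = handle.await; // wait for each task to exit gracefully
    }
}
```
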
diff --git a/src/scraper/mod.rs b/src/scraper/mod.rs
index 6d05867..d428430 100644
--- a/src/scraper/mod.rs
+++ b/src/scraper/mod.rs
@@ -5,8 +5,10 @@ pub mod worker;
 use crate::banner::BannerApi;
 use sqlx::PgPool;
 use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::broadcast;
 use tokio::task::JoinHandle;
-use tracing::info;
+use tracing::{info, warn};

 use self::scheduler::Scheduler;
 use self::worker::Worker;
@@ -21,6 +23,7 @@ pub struct ScraperService {
     banner_api: Arc<BannerApi>,
     scheduler_handle: Option<JoinHandle<()>>,
     worker_handles: Vec<JoinHandle<()>>,
+    shutdown_tx: Option<broadcast::Sender<()>>,
 }

 impl ScraperService {
@@ -31,6 +34,7 @@
             banner_api,
             scheduler_handle: None,
             worker_handles: Vec::new(),
+            shutdown_tx: None,
         }
     }

@@ -38,9 +42,14 @@
     pub fn start(&mut self) {
         info!("ScraperService starting");

+        // Create shutdown channel
+        let (shutdown_tx, _) = broadcast::channel(1);
+        self.shutdown_tx = Some(shutdown_tx.clone());
+
         let scheduler = Scheduler::new(self.db_pool.clone(), self.banner_api.clone());
+        let shutdown_rx = shutdown_tx.subscribe();
         let scheduler_handle = tokio::spawn(async move {
-            scheduler.run().await;
+            scheduler.run(shutdown_rx).await;
         });
         self.scheduler_handle = Some(scheduler_handle);
         info!("Scheduler task spawned");
@@ -48,8 +57,9 @@
         let worker_count = 4; // This could be configurable
         for i in 0..worker_count {
             let worker = Worker::new(i, self.db_pool.clone(), self.banner_api.clone());
+            let shutdown_rx = shutdown_tx.subscribe();
             let worker_handle = tokio::spawn(async move {
-                worker.run().await;
+                worker.run(shutdown_rx).await;
             });
             self.worker_handles.push(worker_handle);
         }
@@ -59,17 +69,6 @@
         );
     }

-    /// Signals all child tasks to gracefully shut down.
-    pub async fn shutdown(&mut self) {
-        info!("Shutting down scraper service");
-        if let Some(handle) = self.scheduler_handle.take() {
-            handle.abort();
-        }
-        for handle in self.worker_handles.drain(..) {
-            handle.abort();
-        }
-        info!("Scraper service shutdown");
-    }
 }

 #[async_trait::async_trait]
@@ -85,7 +84,47 @@ impl Service for ScraperService {
     }

     async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
-        self.shutdown().await;
+        info!("Shutting down scraper service");
+
+        // Send shutdown signal to all tasks
+        if let Some(shutdown_tx) = self.shutdown_tx.take() {
+            let _ = shutdown_tx.send(());
+        } else {
+            warn!("No shutdown channel found for scraper service");
+        }
+
+        // Collect all handles
+        let mut all_handles = Vec::new();
+        if let Some(handle) = self.scheduler_handle.take() {
+            all_handles.push(handle);
+        }
+        all_handles.append(&mut self.worker_handles);
+
+        // Wait for all tasks to complete with a timeout
+        let timeout_duration = Duration::from_secs(5);
+
+        match tokio::time::timeout(
+            timeout_duration,
+            futures::future::join_all(all_handles),
+        )
+        .await
+        {
+            Ok(results) => {
+                let failed = results.iter().filter(|r| r.is_err()).count();
+                if failed > 0 {
+                    warn!(failed_count = failed, "Some scraper tasks failed during shutdown");
+                } else {
+                    info!("All scraper tasks shutdown gracefully");
+                }
+            }
+            Err(_) => {
+                warn!(
+                    timeout = format!("{:.2?}", timeout_duration),
+                    "Scraper service shutdown timed out"
+                );
+            }
+        }
+
         Ok(())
     }
 }
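
Note: the scheduler rewrite below swaps a fixed `interval.tick()` loop for a `select!` over `sleep_until` and the shutdown receiver, and hands each scheduling pass a `tokio_util::sync::CancellationToken` so an in-flight pass can be asked to stop. A reduced sketch of that shape follows, assuming `tokio-util` with `CancellationToken` available (on some versions this requires the crate's `sync` feature); `do_pass` is a hypothetical stand-in for the real scheduling work.

```rust
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;

// Hypothetical stand-in for a scheduling pass; it polls the token while it works.
async fn do_pass(cancel: CancellationToken) {
    tokio::select! {
        _ = time::sleep(Duration::from_secs(3)) => println!("pass finished"),
        _ = cancel.cancelled() => println!("pass cancelled"),
    }
}

pub async fn run_loop(mut shutdown_rx: broadcast::Receiver<()>) {
    let interval = Duration::from_secs(60);
    let mut next_run = Instant::now();
    let mut current: Option<(tokio::task::JoinHandle<()>, CancellationToken)> = None;

    loop {
        tokio::select! {
            // Waiting with sleep_until keeps the loop instantly interruptible.
            _ = time::sleep_until(next_run) => {
                let cancel = CancellationToken::new();
                current = Some((tokio::spawn(do_pass(cancel.clone())), cancel));
                next_run = Instant::now() + interval;
            }
            _ = shutdown_rx.recv() => {
                if let Some((handle, cancel)) = current.take() {
                    cancel.cancel(); // ask the in-flight pass to stop
                    let _ = time::timeout(Duration::from_secs(5), handle).await;
                }
                break;
            }
        }
    }
}
```
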
diff --git a/src/scraper/scheduler.rs b/src/scraper/scheduler.rs
index f15637b..36f17cc 100644
--- a/src/scraper/scheduler.rs
+++ b/src/scraper/scheduler.rs
@@ -6,8 +6,10 @@ use serde_json::json;
 use sqlx::PgPool;
 use std::sync::Arc;
 use std::time::Duration;
+use tokio::sync::broadcast;
 use tokio::time;
-use tracing::{debug, error, info, trace};
+use tokio_util::sync::CancellationToken;
+use tracing::{debug, error, info, trace, warn};

 /// Periodically analyzes data and enqueues prioritized scrape jobs.
 pub struct Scheduler {
@@ -24,21 +26,72 @@ impl Scheduler {
     }

     /// Runs the scheduler's main loop.
-    pub async fn run(&self) {
+    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
         info!("Scheduler service started");
-        let mut interval = time::interval(Duration::from_secs(60)); // Runs every minute
+
+        let work_interval = Duration::from_secs(60);
+        let mut next_run = time::Instant::now();
+        let mut current_work: Option<(tokio::task::JoinHandle<()>, CancellationToken)> = None;

         loop {
-            interval.tick().await;
-            // Scheduler analyzing data...
-            if let Err(e) = self.schedule_jobs().await {
-                error!(error = ?e, "Failed to schedule jobs");
+            tokio::select! {
+                // Sleep until next scheduled run - instantly cancellable
+                _ = time::sleep_until(next_run) => {
+                    // Create cancellation token for graceful task cancellation
+                    let cancel_token = CancellationToken::new();
+
+                    // Spawn scheduling work in a separate task for cancellability
+                    let work_handle = tokio::spawn({
+                        let db_pool = self.db_pool.clone();
+                        let banner_api = self.banner_api.clone();
+                        let cancel_token = cancel_token.clone();
+
+                        async move {
+                            // Check for cancellation while running
+                            tokio::select! {
+                                result = Self::schedule_jobs_impl(&db_pool, &banner_api) => {
+                                    if let Err(e) = result {
+                                        error!(error = ?e, "Failed to schedule jobs");
+                                    }
+                                }
+                                _ = cancel_token.cancelled() => {
+                                    debug!("Scheduling work cancelled gracefully");
+                                }
+                            }
+                        }
+                    });
+
+                    current_work = Some((work_handle, cancel_token));
+                    next_run = time::Instant::now() + work_interval;
+                }
+                _ = shutdown_rx.recv() => {
+                    info!("Scheduler received shutdown signal");
+
+                    // Gracefully cancel any in-progress work
+                    if let Some((handle, cancel_token)) = current_work.take() {
+                        // Signal cancellation
+                        cancel_token.cancel();
+
+                        // Wait for graceful completion with timeout
+                        match time::timeout(Duration::from_secs(5), handle).await {
+                            Ok(_) => {
+                                debug!("Scheduling work completed gracefully");
+                            }
+                            Err(_) => {
+                                warn!("Scheduling work did not complete within 5s timeout, may have been aborted");
+                            }
+                        }
+                    }
+
+                    info!("Scheduler exiting gracefully");
+                    break;
+                }
             }
         }
     }

     /// The core logic for deciding what jobs to create.
-    async fn schedule_jobs(&self) -> Result<()> {
+    async fn schedule_jobs_impl(db_pool: &PgPool, banner_api: &BannerApi) -> Result<()> {
         // For now, we will implement a simple baseline scheduling strategy:
         // 1. Get a list of all subjects from the Banner API.
         // 2. Query existing jobs for all subjects in a single query.
@@ -47,7 +100,7 @@
         debug!(term = term, "Enqueuing subject jobs");

-        let subjects = self.banner_api.get_subjects("", &term, 1, 500).await?;
+        let subjects = banner_api.get_subjects("", &term, 1, 500).await?;
         debug!(
             subject_count = subjects.len(),
             "Retrieved subjects from API"
@@ -61,12 +114,12 @@

         // Query existing jobs for all subjects in a single query
         let existing_jobs: Vec<(serde_json::Value,)> = sqlx::query_as(
-            "SELECT target_payload FROM scrape_jobs 
+            "SELECT target_payload FROM scrape_jobs
             WHERE target_type = $1 AND target_payload = ANY($2) AND locked_at IS NULL",
         )
         .bind(TargetType::Subject)
         .bind(&subject_payloads)
-        .fetch_all(&self.db_pool)
+        .fetch_all(db_pool)
         .await?;

         // Convert to a HashSet for efficient lookup
@@ -95,7 +148,7 @@
         // Insert all new jobs in a single batch
         if !new_jobs.is_empty() {
             let now = chrono::Utc::now();
-            let mut tx = self.db_pool.begin().await?;
+            let mut tx = db_pool.begin().await?;

             for (payload, subject_code) in new_jobs {
                 sqlx::query(
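
Note: the worker change that follows applies the same idea at two points in the loop: `select!` races the shutdown receiver against job acquisition, and again against job processing, so a signal that arrives mid-job can unlock the row for a later retry. Below is a compressed sketch of that control flow; `fetch_job`, `process`, and `unlock` are hypothetical placeholders, not this repository's API, and note that losing the race simply drops the in-flight future.

```rust
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time;

async fn fetch_job() -> anyhow::Result<Option<i64>> { Ok(Some(42)) } // placeholder
async fn process(_job: i64) { time::sleep(Duration::from_secs(1)).await } // placeholder
async fn unlock(_job: i64) {} // placeholder: release the row lock for retry

pub async fn worker_loop(mut shutdown_rx: broadcast::Receiver<()>) {
    loop {
        // Race job acquisition against shutdown.
        let job = tokio::select! {
            _ = shutdown_rx.recv() => break,
            fetched = fetch_job() => match fetched {
                Ok(Some(job)) => job,
                Ok(None) => { time::sleep(Duration::from_secs(5)).await; continue; }
                Err(_) => { time::sleep(Duration::from_secs(10)).await; continue; }
            },
        };

        // Race processing against shutdown; if shutdown wins, put the job back.
        tokio::select! {
            _ = shutdown_rx.recv() => { unlock(job).await; break; }
            _ = process(job) => { /* delete on success, unlock on recoverable error */ }
        }
    }
}
```
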
diff --git a/src/scraper/worker.rs b/src/scraper/worker.rs
index 9d7c987..e25c07a 100644
--- a/src/scraper/worker.rs
+++ b/src/scraper/worker.rs
@@ -5,6 +5,7 @@ use crate::scraper::jobs::{JobError, JobType};
 use sqlx::PgPool;
 use std::sync::Arc;
 use std::time::Duration;
+use tokio::sync::broadcast;
 use tokio::time;
 use tracing::{debug, error, info, trace, warn};

@@ -28,77 +29,97 @@ impl Worker {
     }

     /// Runs the worker's main loop.
-    pub async fn run(&self) {
+    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
         info!(worker_id = self.id, "Worker started.");
-        loop {
-            match self.fetch_and_lock_job().await {
-                Ok(Some(job)) => {
-                    let job_id = job.id;
-                    debug!(worker_id = self.id, job_id = job.id, "Processing job");
-                    match self.process_job(job).await {
-                        Ok(()) => {
-                            debug!(worker_id = self.id, job_id, "Job completed");
-                            // If successful, delete the job.
-                            if let Err(delete_err) = self.delete_job(job_id).await {
-                                error!(
-                                    worker_id = self.id,
-                                    job_id,
-                                    ?delete_err,
-                                    "Failed to delete job"
-                                );
-                            }
-                        }
-                        Err(JobError::Recoverable(e)) => {
-                            // Check if the error is due to an invalid session
-                            if let Some(BannerApiError::InvalidSession(_)) =
-                                e.downcast_ref::<BannerApiError>()
-                            {
-                                warn!(
-                                    worker_id = self.id,
-                                    job_id, "Invalid session detected. Forcing session refresh."
-                                );
-                            } else {
-                                error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
-                            }
-                            // Unlock the job so it can be retried
-                            if let Err(unlock_err) = self.unlock_job(job_id).await {
-                                error!(
-                                    worker_id = self.id,
-                                    job_id,
-                                    ?unlock_err,
-                                    "Failed to unlock job"
-                                );
-                            }
+        loop {
+            // Fetch and lock a job, racing against shutdown signal
+            let job = tokio::select! {
+                _ = shutdown_rx.recv() => {
+                    info!(worker_id = self.id, "Worker received shutdown signal");
+                    info!(worker_id = self.id, "Worker exiting gracefully");
+                    break;
+                }
+                result = self.fetch_and_lock_job() => {
+                    match result {
+                        Ok(Some(job)) => job,
+                        Ok(None) => {
+                            // No job found, wait for a bit before polling again
+                            trace!(worker_id = self.id, "No jobs available, waiting");
+                            time::sleep(Duration::from_secs(5)).await;
+                            continue;
                         }
-                        Err(JobError::Unrecoverable(e)) => {
-                            error!(
-                                worker_id = self.id,
-                                job_id,
-                                error = ?e,
-                                "Job corrupted, deleting"
-                            );
-                            // Parse errors are unrecoverable - delete the job
-                            if let Err(delete_err) = self.delete_job(job_id).await {
-                                error!(
-                                    worker_id = self.id,
-                                    job_id,
-                                    ?delete_err,
-                                    "Failed to delete corrupted job"
-                                );
-                            }
+                        Err(e) => {
+                            warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
+                            // Wait before retrying to avoid spamming errors
+                            time::sleep(Duration::from_secs(10)).await;
+                            continue;
                         }
                     }
                 }
-                Ok(None) => {
-                    // No job found, wait for a bit before polling again.
-                    trace!(worker_id = self.id, "No jobs available, waiting");
-                    time::sleep(Duration::from_secs(5)).await;
+            };
+
+            let job_id = job.id;
+            debug!(worker_id = self.id, job_id, "Processing job");
+
+            // Process the job, racing against shutdown signal
+            let process_result = tokio::select! {
+                _ = shutdown_rx.recv() => {
+                    info!(worker_id = self.id, job_id, "Shutdown received during job processing");
+
+                    // Unlock the job so it can be retried
+                    if let Err(e) = self.unlock_job(job_id).await {
+                        warn!(
+                            worker_id = self.id,
+                            job_id,
+                            error = ?e,
+                            "Failed to unlock job during shutdown"
+                        );
+                    } else {
+                        debug!(worker_id = self.id, job_id, "Job unlocked during shutdown");
+                    }
+
+                    info!(worker_id = self.id, "Worker exiting gracefully");
+                    break;
                 }
-                Err(e) => {
-                    warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
-                    // Wait before retrying to avoid spamming errors.
-                    time::sleep(Duration::from_secs(10)).await;
+                result = self.process_job(job) => {
+                    result
+                }
+            };
+
+            // Handle the job processing result
+            match process_result {
+                Ok(()) => {
+                    debug!(worker_id = self.id, job_id, "Job completed");
+                    // If successful, delete the job
+                    if let Err(delete_err) = self.delete_job(job_id).await {
+                        error!(
+                            worker_id = self.id,
+                            job_id,
+                            ?delete_err,
+                            "Failed to delete job"
+                        );
+                    }
+                }
+                Err(JobError::Recoverable(e)) => {
+                    self.handle_recoverable_error(job_id, e).await;
+                }
+                Err(JobError::Unrecoverable(e)) => {
+                    error!(
+                        worker_id = self.id,
+                        job_id,
+                        error = ?e,
+                        "Job corrupted, deleting"
+                    );
+                    // Parse errors are unrecoverable - delete the job
+                    if let Err(delete_err) = self.delete_job(job_id).await {
+                        error!(
+                            worker_id = self.id,
+                            job_id,
+                            ?delete_err,
+                            "Failed to delete corrupted job"
+                        );
+                    }
+                }
             }
         }
     }

@@ -169,4 +190,25 @@
         info!(worker_id = self.id, job_id, "Job unlocked for retry");
         Ok(())
     }
+
+    /// Handle recoverable errors by logging appropriately and unlocking the job
+    async fn handle_recoverable_error(&self, job_id: i32, e: anyhow::Error) {
+        if let Some(BannerApiError::InvalidSession(_)) = e.downcast_ref::<BannerApiError>() {
+            warn!(
+                worker_id = self.id,
+                job_id, "Invalid session detected. Forcing session refresh."
+            );
+        } else {
+            error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
+        }
+
+        if let Err(unlock_err) = self.unlock_job(job_id).await {
+            error!(
+                worker_id = self.id,
+                job_id,
+                ?unlock_err,
+                "Failed to unlock job"
+            );
+        }
+    }
 }
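
Note: the service manager rewrite below drops the `is_finished()` polling loop in favor of a completion channel: every spawned service task pushes `(name, result)` into an `mpsc::unbounded_channel` when it returns, so `run()` can simply await the receiver. A small standalone sketch of that mechanism follows; the `ServiceResult` enum here is a simplified stand-in for the real type.

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum ServiceResult {
    GracefulShutdown,
    Error(String),
}

#[tokio::main]
async fn main() {
    let (completion_tx, mut completion_rx) =
        mpsc::unbounded_channel::<(String, ServiceResult)>();

    for name in ["web", "scraper"] {
        let tx = completion_tx.clone();
        tokio::spawn(async move {
            // ... run the service to completion here ...
            let _ = tx.send((name.to_string(), ServiceResult::GracefulShutdown));
        });
    }

    // No busy-wait: the first service to finish wakes the manager directly.
    if let Some((name, result)) = completion_rx.recv().await {
        println!("first service to finish: {name} ({result:?})");
    }
}
```
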
diff --git a/src/services/manager.rs b/src/services/manager.rs
index 89779c4..c9d3008 100644
--- a/src/services/manager.rs
+++ b/src/services/manager.rs
@@ -1,15 +1,16 @@
 use std::collections::HashMap;
 use std::time::Duration;
-use tokio::sync::broadcast;
-use tokio::task::JoinHandle;
-use tracing::{debug, error, info, trace, warn};
+use tokio::sync::{broadcast, mpsc};
+use tracing::{debug, info, trace, warn};

 use crate::services::{Service, ServiceResult, run_service};

 /// Manages multiple services and their lifecycle
 pub struct ServiceManager {
     registered_services: HashMap<String, Box<dyn Service>>,
-    running_services: HashMap<String, JoinHandle<ServiceResult>>,
+    service_handles: HashMap<String, tokio::task::AbortHandle>,
+    completion_rx: Option<mpsc::UnboundedReceiver<(String, ServiceResult)>>,
+    completion_tx: mpsc::UnboundedSender<(String, ServiceResult)>,
     shutdown_tx: broadcast::Sender<()>,
 }

@@ -22,9 +23,13 @@ impl Default for ServiceManager {
 impl ServiceManager {
     pub fn new() -> Self {
         let (shutdown_tx, _) = broadcast::channel(1);
+        let (completion_tx, completion_rx) = mpsc::unbounded_channel();
+
         Self {
             registered_services: HashMap::new(),
-            running_services: HashMap::new(),
+            service_handles: HashMap::new(),
+            completion_rx: Some(completion_rx),
+            completion_tx,
             shutdown_tx,
         }
     }
@@ -46,9 +51,19 @@
         for (name, service) in self.registered_services.drain() {
             let shutdown_rx = self.shutdown_tx.subscribe();
-            let handle = tokio::spawn(run_service(service, shutdown_rx));
+            let completion_tx = self.completion_tx.clone();
+            let name_clone = name.clone();
+
+            // Spawn service task
+            let handle = tokio::spawn(async move {
+                let result = run_service(service, shutdown_rx).await;
+                // Send completion notification
+                let _ = completion_tx.send((name_clone, result));
+            });
+
+            // Store abort handle for shutdown control
+            self.service_handles.insert(name.clone(), handle.abort_handle());
             debug!(service = name, id = ?handle.id(), "service spawned");
-            self.running_services.insert(name, handle);
         }

         info!(
@@ -62,7 +77,7 @@
     /// Run all services until one completes or fails
     /// Returns the first service that completes and its result
     pub async fn run(&mut self) -> (String, ServiceResult) {
-        if self.running_services.is_empty() {
+        if self.service_handles.is_empty() {
             return (
                 "none".to_string(),
                 ServiceResult::Error(anyhow::anyhow!("No services to run")),
@@ -71,82 +86,112 @@ impl ServiceManager {
         info!(
             "servicemanager running {} services",
-            self.running_services.len()
+            self.service_handles.len()
         );

-        // Wait for any service to complete
-        loop {
-            let mut completed_services = Vec::new();
+        // Wait for any service to complete via the channel
+        let completion_rx = self
+            .completion_rx
+            .as_mut()
+            .expect("completion_rx should be available");

-            for (name, handle) in &mut self.running_services {
-                if handle.is_finished() {
-                    completed_services.push(name.clone());
-                }
-            }
-
-            if let Some(completed_name) = completed_services.first() {
-                let handle = self.running_services.remove(completed_name).unwrap();
-                match handle.await {
-                    Ok(result) => {
-                        return (completed_name.clone(), result);
-                    }
-                    Err(e) => {
-                        error!(service = completed_name, "service task panicked: {e}");
-                        return (
-                            completed_name.clone(),
-                            ServiceResult::Error(anyhow::anyhow!("Task panic: {e}")),
-                        );
-                    }
-                }
-            }
-
-            // Small delay to prevent busy-waiting
-            tokio::time::sleep(Duration::from_millis(10)).await;
-        }
+        completion_rx
+            .recv()
+            .await
+            .map(|(name, result)| {
+                self.service_handles.remove(&name);
+                (name, result)
+            })
+            .unwrap_or_else(|| {
+                (
+                    "channel_closed".to_string(),
+                    ServiceResult::Error(anyhow::anyhow!("Completion channel closed")),
+                )
+            })
     }

     /// Shutdown all services gracefully with a timeout.
     ///
-    /// If any service fails to shutdown, it will return an error containing the names of the services that failed to shutdown.
+    /// All services receive the shutdown signal simultaneously and must complete within the
+    /// specified timeout (combined, not per-service). If any service fails to shutdown within
+    /// the timeout, it will be aborted and included in the error result.
+    ///
     /// If all services shutdown successfully, the function will return the duration elapsed.
     pub async fn shutdown(&mut self, timeout: Duration) -> Result<Duration, Vec<String>> {
-        let service_count = self.running_services.len();
-        let service_names: Vec<_> = self.running_services.keys().cloned().collect();
+        let service_count = self.service_handles.len();
+        let service_names: Vec<_> = self.service_handles.keys().cloned().collect();

         info!(
             service_count,
             services = ?service_names,
             timeout = format!("{:.2?}", timeout),
-            "shutting down {} services with {:?} timeout",
+            "shutting down {} services with {:?} total timeout",
             service_count,
             timeout
         );

-        // Send shutdown signal to all services
+        if service_count == 0 {
+            return Ok(Duration::ZERO);
+        }
+
+        // Send shutdown signal to all services simultaneously
         let _ = self.shutdown_tx.send(());

-        // Wait for all services to complete
         let start_time = std::time::Instant::now();
-        let mut pending_services = Vec::new();
+        let mut completed = 0;
+        let mut failed_services = Vec::new();

-        for (name, handle) in self.running_services.drain() {
-            match tokio::time::timeout(timeout, handle).await {
-                Ok(Ok(_)) => {
-                    trace!(service = name, "service shutdown completed");
+        // Borrow the receiver mutably (don't take ownership to allow reuse)
+        let completion_rx = self
+            .completion_rx
+            .as_mut()
+            .expect("completion_rx should be available");
+
+        // Wait for all services to complete with timeout
+        while completed < service_count {
+            match tokio::time::timeout(
+                timeout.saturating_sub(start_time.elapsed()),
+                completion_rx.recv(),
+            )
+            .await
+            {
+                Ok(Some((name, result))) => {
+                    completed += 1;
+                    self.service_handles.remove(&name);
+
+                    if matches!(result, ServiceResult::GracefulShutdown) {
+                        trace!(service = name, "service shutdown completed");
+                    } else {
+                        warn!(service = name, "service shutdown with non-graceful result");
+                        failed_services.push(name);
+                    }
                 }
-                Ok(Err(e)) => {
-                    warn!(service = name, error = ?e, "service shutdown failed");
-                    pending_services.push(name);
+                Ok(None) => {
+                    // Channel closed - shouldn't happen but handle it
+                    warn!("completion channel closed during shutdown");
+                    break;
                 }
                 Err(_) => {
-                    warn!(service = name, "service shutdown timed out");
-                    pending_services.push(name);
+                    // Timeout - abort all remaining services
+                    warn!(
+                        timeout = format!("{:.2?}", timeout),
+                        elapsed = format!("{:.2?}", start_time.elapsed()),
+                        remaining = service_count - completed,
+                        "shutdown timeout - aborting remaining services"
+                    );
+
+                    for (name, handle) in self.service_handles.drain() {
+                        handle.abort();
+                        failed_services.push(name);
+                    }
+                    break;
                 }
             }
         }

         let elapsed = start_time.elapsed();
-        if pending_services.is_empty() {
+
+        if failed_services.is_empty() {
             info!(
                 service_count,
                 elapsed = format!("{:.2?}", elapsed),
@@ -156,14 +201,14 @@ impl ServiceManager {
             Ok(elapsed)
         } else {
             warn!(
-                pending_count = pending_services.len(),
-                pending_services = ?pending_services,
+                failed_count = failed_services.len(),
+                failed_services = ?failed_services,
                 elapsed = format!("{:.2?}", elapsed),
-                "services shutdown completed with {} pending: {}",
-                pending_services.len(),
-                pending_services.join(", ")
+                "services shutdown completed with {} failed: {}",
+                failed_services.len(),
+                failed_services.join(", ")
             );
-            Err(pending_services)
+            Err(failed_services)
         }
     }
 }
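
Note: one detail of the shutdown loop above is easy to miss: the timeout is a single budget shared across all services, so each `recv` only waits for whatever remains of it. The sketch below isolates that deadline arithmetic; the channel of service names is hypothetical, but the tokio primitives are the same ones used above.

```rust
use std::time::{Duration, Instant};
use tokio::sync::mpsc;

async fn drain_with_budget(
    mut completion_rx: mpsc::UnboundedReceiver<String>,
    expected: usize,
    budget: Duration,
) {
    let start = Instant::now();
    let mut completed = 0;
    while completed < expected {
        // The remaining budget shrinks as services finish; saturating_sub never underflows.
        let remaining = budget.saturating_sub(start.elapsed());
        match tokio::time::timeout(remaining, completion_rx.recv()).await {
            Ok(Some(name)) => {
                completed += 1;
                println!("{name} shut down ({:?} elapsed)", start.elapsed());
            }
            Ok(None) => break, // channel closed: no more services will report
            Err(_) => {
                println!("budget exhausted; abort whatever is still running");
                break;
            }
        }
    }
}
```
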
diff --git a/src/services/mod.rs b/src/services/mod.rs
index 1c5f84a..7d063a7 100644
--- a/src/services/mod.rs
+++ b/src/services/mod.rs
@@ -23,7 +23,11 @@ pub trait Service: Send + Sync {
     /// Gracefully shutdown the service
     ///
-    /// An 'Ok' result does not mean the service has completed shutdown, it merely means that the service shutdown was initiated.
+    /// Implementations should initiate shutdown and MAY wait for completion.
+    /// Services are expected to respond to this call and begin cleanup promptly.
+    /// When managed by ServiceManager, the configured timeout (default 8s) applies to
+    /// ALL services combined, not per-service. Services should complete shutdown as
+    /// quickly as possible to avoid timeout.
     async fn shutdown(&mut self) -> Result<(), anyhow::Error>;
 }
diff --git a/src/services/web.rs b/src/services/web.rs
index f23cde6..692f0d2 100644
--- a/src/services/web.rs
+++ b/src/services/web.rs
@@ -3,7 +3,7 @@ use crate::web::{BannerState, create_router};
 use std::net::SocketAddr;
 use tokio::net::TcpListener;
 use tokio::sync::broadcast;
-use tracing::{info, warn, trace};
+use tracing::{info, trace, warn};

 /// Web server service implementation
 pub struct WebService {
@@ -33,16 +33,12 @@
         let app = create_router(self.banner_state.clone());
         let addr = SocketAddr::from(([0, 0, 0, 0], self.port));

-        info!(
-            service = "web",
-            link = format!("http://localhost:{}", addr.port()),
-            "starting web server",
-        );
         let listener = TcpListener::bind(addr).await?;

         info!(
             service = "web",
             address = %addr,
+            link = format!("http://localhost:{}", addr.port()),
             "web server listening"
         );
@@ -61,13 +57,16 @@
             })
             .await?;

+        trace!(service = "web", "graceful shutdown completed");
         info!(service = "web", "web server stopped");
+
         Ok(())
     }

     async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
         if let Some(shutdown_tx) = self.shutdown_tx.take() {
             let _ = shutdown_tx.send(());
+            trace!(service = "web", "sent shutdown signal to axum");
         } else {
             warn!(
                 service = "web",