refactor(scraper): implement graceful shutdown with broadcast channels

Replace task abortion with broadcast-based graceful shutdown for scheduler and workers. Implement cancellation tokens for in-progress work with 5s timeout. Add tokio-util dependency for CancellationToken support. Update ServiceManager to use completion channels and abort handles for better service lifecycle control.
Author: Ryan Walters
Date: 2025-11-03 01:22:12 -06:00
Parent: 020a00254f
Commit: 8af9b0a1a2
9 changed files with 343 additions and 172 deletions
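At its core, the change replaces `JoinHandle::abort()` with a broadcast shutdown signal that every task races against via `tokio::select!`, bounded by a timeout. Below is a minimal, self-contained sketch of that pattern, not code from this commit: the worker body, names, and timings are illustrative, and it assumes the `tokio` and `futures` crates.

```rust
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time;

// Illustrative worker loop: races a unit of "work" against the shutdown signal.
async fn run_worker(id: usize, mut shutdown_rx: broadcast::Receiver<()>) {
    loop {
        tokio::select! {
            // Pretend work; in the real service this is fetching and processing a job.
            _ = time::sleep(Duration::from_secs(1)) => {
                println!("worker {id}: finished a unit of work");
            }
            // Shutdown wins the race as soon as the signal is broadcast.
            _ = shutdown_rx.recv() => {
                println!("worker {id}: shutting down gracefully");
                break;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (shutdown_tx, _) = broadcast::channel(1);

    // Spawn a few workers, each with its own receiver.
    let handles: Vec<_> = (0..4)
        .map(|id| tokio::spawn(run_worker(id, shutdown_tx.subscribe())))
        .collect();

    // Let them run briefly, then broadcast shutdown and wait with a bound,
    // mirroring the 5s timeout used in the commit.
    time::sleep(Duration::from_secs(3)).await;
    let _ = shutdown_tx.send(());
    let _ = time::timeout(Duration::from_secs(5), futures::future::join_all(handles)).await;
}
```

The real services layer the same idea: each spawned task gets its own `broadcast::Receiver`, and the owning service waits on `join_all` with a 5s bound so a stuck task cannot stall shutdown.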

Cargo.lock (generated)

@@ -254,6 +254,7 @@ dependencies = [
"time",
"tl",
"tokio",
"tokio-util",
"tower-http",
"tracing",
"tracing-subscriber",


@@ -36,6 +36,7 @@ sqlx = { version = "0.8.6", features = [
thiserror = "2.0.16"
time = "0.3.43"
tokio = { version = "1.47.1", features = ["full"] }
tokio-util = "0.7"
tl = "0.7.8"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }


@@ -42,11 +42,7 @@ impl App {
// Check if the database URL is via private networking
let is_private = config.database_url.contains("railway.internal");
let slow_threshold = if is_private {
Duration::from_millis(200)
} else {
Duration::from_millis(500)
};
let slow_threshold = Duration::from_millis(if is_private { 200 } else { 500 });
// Create database connection pool
let db_pool = PgPoolOptions::new()
@@ -108,10 +104,6 @@ impl App {
.register_service(ServiceName::Scraper.as_str(), scraper_service);
}
if services.contains(&ServiceName::Bot) {
// Bot service will be set up separately in run() method since it's async
}
// Check if any services are enabled
if !self.service_manager.has_services() && !services.contains(&ServiceName::Bot) {
error!("No services enabled. Cannot start application.");
@@ -147,9 +139,4 @@ impl App {
pub fn config(&self) -> &Config {
&self.config
}
/// Get a reference to the app state
pub fn app_state(&self) -> &AppState {
&self.app_state
}
}


@@ -5,8 +5,10 @@ pub mod worker;
use crate::banner::BannerApi;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tracing::info;
use tracing::{info, warn};
use self::scheduler::Scheduler;
use self::worker::Worker;
@@ -21,6 +23,7 @@ pub struct ScraperService {
banner_api: Arc<BannerApi>,
scheduler_handle: Option<JoinHandle<()>>,
worker_handles: Vec<JoinHandle<()>>,
shutdown_tx: Option<broadcast::Sender<()>>,
}
impl ScraperService {
@@ -31,6 +34,7 @@ impl ScraperService {
banner_api,
scheduler_handle: None,
worker_handles: Vec::new(),
shutdown_tx: None,
}
}
@@ -38,9 +42,14 @@ impl ScraperService {
pub fn start(&mut self) {
info!("ScraperService starting");
// Create shutdown channel
let (shutdown_tx, _) = broadcast::channel(1);
self.shutdown_tx = Some(shutdown_tx.clone());
let scheduler = Scheduler::new(self.db_pool.clone(), self.banner_api.clone());
let shutdown_rx = shutdown_tx.subscribe();
let scheduler_handle = tokio::spawn(async move {
scheduler.run().await;
scheduler.run(shutdown_rx).await;
});
self.scheduler_handle = Some(scheduler_handle);
info!("Scheduler task spawned");
@@ -48,8 +57,9 @@ impl ScraperService {
let worker_count = 4; // This could be configurable
for i in 0..worker_count {
let worker = Worker::new(i, self.db_pool.clone(), self.banner_api.clone());
let shutdown_rx = shutdown_tx.subscribe();
let worker_handle = tokio::spawn(async move {
worker.run().await;
worker.run(shutdown_rx).await;
});
self.worker_handles.push(worker_handle);
}
@@ -59,17 +69,6 @@ impl ScraperService {
);
}
/// Signals all child tasks to gracefully shut down.
pub async fn shutdown(&mut self) {
info!("Shutting down scraper service");
if let Some(handle) = self.scheduler_handle.take() {
handle.abort();
}
for handle in self.worker_handles.drain(..) {
handle.abort();
}
info!("Scraper service shutdown");
}
}
#[async_trait::async_trait]
@@ -85,7 +84,47 @@ impl Service for ScraperService {
}
async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
self.shutdown().await;
info!("Shutting down scraper service");
// Send shutdown signal to all tasks
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
} else {
warn!("No shutdown channel found for scraper service");
}
// Collect all handles
let mut all_handles = Vec::new();
if let Some(handle) = self.scheduler_handle.take() {
all_handles.push(handle);
}
all_handles.append(&mut self.worker_handles);
// Wait for all tasks to complete with a timeout
let timeout_duration = Duration::from_secs(5);
match tokio::time::timeout(
timeout_duration,
futures::future::join_all(all_handles),
)
.await
{
Ok(results) => {
let failed = results.iter().filter(|r| r.is_err()).count();
if failed > 0 {
warn!(failed_count = failed, "Some scraper tasks failed during shutdown");
} else {
info!("All scraper tasks shutdown gracefully");
}
}
Err(_) => {
warn!(
timeout = format!("{:.2?}", timeout_duration),
"Scraper service shutdown timed out"
);
}
}
Ok(())
}
}


@@ -6,8 +6,10 @@ use serde_json::json;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time;
use tracing::{debug, error, info, trace};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn};
/// Periodically analyzes data and enqueues prioritized scrape jobs.
pub struct Scheduler {
@@ -24,21 +26,72 @@ impl Scheduler {
}
/// Runs the scheduler's main loop.
pub async fn run(&self) {
pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
info!("Scheduler service started");
let mut interval = time::interval(Duration::from_secs(60)); // Runs every minute
let work_interval = Duration::from_secs(60);
let mut next_run = time::Instant::now();
let mut current_work: Option<(tokio::task::JoinHandle<()>, CancellationToken)> = None;
loop {
interval.tick().await;
// Scheduler analyzing data...
if let Err(e) = self.schedule_jobs().await {
error!(error = ?e, "Failed to schedule jobs");
tokio::select! {
// Sleep until next scheduled run - instantly cancellable
_ = time::sleep_until(next_run) => {
// Create cancellation token for graceful task cancellation
let cancel_token = CancellationToken::new();
// Spawn scheduling work in a separate task for cancellability
let work_handle = tokio::spawn({
let db_pool = self.db_pool.clone();
let banner_api = self.banner_api.clone();
let cancel_token = cancel_token.clone();
async move {
// Check for cancellation while running
tokio::select! {
result = Self::schedule_jobs_impl(&db_pool, &banner_api) => {
if let Err(e) = result {
error!(error = ?e, "Failed to schedule jobs");
}
}
_ = cancel_token.cancelled() => {
debug!("Scheduling work cancelled gracefully");
}
}
}
});
current_work = Some((work_handle, cancel_token));
next_run = time::Instant::now() + work_interval;
}
_ = shutdown_rx.recv() => {
info!("Scheduler received shutdown signal");
// Gracefully cancel any in-progress work
if let Some((handle, cancel_token)) = current_work.take() {
// Signal cancellation
cancel_token.cancel();
// Wait for graceful completion with timeout
match time::timeout(Duration::from_secs(5), handle).await {
Ok(_) => {
debug!("Scheduling work completed gracefully");
}
Err(_) => {
warn!("Scheduling work did not complete within 5s timeout, may have been aborted");
}
}
}
info!("Scheduler exiting gracefully");
break;
}
}
}
}
/// The core logic for deciding what jobs to create.
async fn schedule_jobs(&self) -> Result<()> {
async fn schedule_jobs_impl(db_pool: &PgPool, banner_api: &BannerApi) -> Result<()> {
// For now, we will implement a simple baseline scheduling strategy:
// 1. Get a list of all subjects from the Banner API.
// 2. Query existing jobs for all subjects in a single query.
@@ -47,7 +100,7 @@ impl Scheduler {
debug!(term = term, "Enqueuing subject jobs");
let subjects = self.banner_api.get_subjects("", &term, 1, 500).await?;
let subjects = banner_api.get_subjects("", &term, 1, 500).await?;
debug!(
subject_count = subjects.len(),
"Retrieved subjects from API"
@@ -66,7 +119,7 @@ impl Scheduler {
)
.bind(TargetType::Subject)
.bind(&subject_payloads)
.fetch_all(&self.db_pool)
.fetch_all(db_pool)
.await?;
// Convert to a HashSet for efficient lookup
@@ -95,7 +148,7 @@ impl Scheduler {
// Insert all new jobs in a single batch
if !new_jobs.is_empty() {
let now = chrono::Utc::now();
let mut tx = self.db_pool.begin().await?;
let mut tx = db_pool.begin().await?;
for (payload, subject_code) in new_jobs {
sqlx::query(
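The scheduler above also introduces `tokio_util::sync::CancellationToken` so in-flight scheduling work can be asked to stop rather than aborted. A minimal sketch of that token pattern under stated assumptions: `do_work` is a hypothetical stand-in for the real scheduling routine, and `tokio-util` 0.7 is required.

```rust
use std::time::Duration;
use tokio::time;
use tokio_util::sync::CancellationToken;

// Hypothetical long-running work standing in for the scheduler's enqueue pass.
async fn do_work() {
    time::sleep(Duration::from_secs(10)).await;
}

#[tokio::main]
async fn main() {
    let cancel_token = CancellationToken::new();

    // The spawned task races its work against cancellation.
    let handle = tokio::spawn({
        let cancel_token = cancel_token.clone();
        async move {
            tokio::select! {
                _ = do_work() => println!("work finished"),
                _ = cancel_token.cancelled() => println!("work cancelled gracefully"),
            }
        }
    });

    // On shutdown: let the work start, signal cancellation, then wait up to 5s.
    time::sleep(Duration::from_millis(100)).await;
    cancel_token.cancel();
    match time::timeout(Duration::from_secs(5), handle).await {
        Ok(_) => println!("task exited within the timeout"),
        Err(_) => println!("task did not exit within 5s"),
    }
}
```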


@@ -5,6 +5,7 @@ use crate::scraper::jobs::{JobError, JobType};
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time;
use tracing::{debug, error, info, trace, warn};
@@ -28,77 +29,97 @@ impl Worker {
}
/// Runs the worker's main loop.
pub async fn run(&self) {
pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) {
info!(worker_id = self.id, "Worker started.");
loop {
match self.fetch_and_lock_job().await {
Ok(Some(job)) => {
let job_id = job.id;
debug!(worker_id = self.id, job_id = job.id, "Processing job");
match self.process_job(job).await {
Ok(()) => {
debug!(worker_id = self.id, job_id, "Job completed");
// If successful, delete the job.
if let Err(delete_err) = self.delete_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?delete_err,
"Failed to delete job"
);
}
}
Err(JobError::Recoverable(e)) => {
// Check if the error is due to an invalid session
if let Some(BannerApiError::InvalidSession(_)) =
e.downcast_ref::<BannerApiError>()
{
warn!(
worker_id = self.id,
job_id, "Invalid session detected. Forcing session refresh."
);
} else {
error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
}
// Unlock the job so it can be retried
if let Err(unlock_err) = self.unlock_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?unlock_err,
"Failed to unlock job"
);
}
loop {
// Fetch and lock a job, racing against shutdown signal
let job = tokio::select! {
_ = shutdown_rx.recv() => {
info!(worker_id = self.id, "Worker received shutdown signal");
info!(worker_id = self.id, "Worker exiting gracefully");
break;
}
result = self.fetch_and_lock_job() => {
match result {
Ok(Some(job)) => job,
Ok(None) => {
// No job found, wait for a bit before polling again
trace!(worker_id = self.id, "No jobs available, waiting");
time::sleep(Duration::from_secs(5)).await;
continue;
}
Err(JobError::Unrecoverable(e)) => {
error!(
worker_id = self.id,
job_id,
error = ?e,
"Job corrupted, deleting"
);
// Parse errors are unrecoverable - delete the job
if let Err(delete_err) = self.delete_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?delete_err,
"Failed to delete corrupted job"
);
}
Err(e) => {
warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
// Wait before retrying to avoid spamming errors
time::sleep(Duration::from_secs(10)).await;
continue;
}
}
}
Ok(None) => {
// No job found, wait for a bit before polling again.
trace!(worker_id = self.id, "No jobs available, waiting");
time::sleep(Duration::from_secs(5)).await;
};
let job_id = job.id;
debug!(worker_id = self.id, job_id, "Processing job");
// Process the job, racing against shutdown signal
let process_result = tokio::select! {
_ = shutdown_rx.recv() => {
info!(worker_id = self.id, job_id, "Shutdown received during job processing");
// Unlock the job so it can be retried
if let Err(e) = self.unlock_job(job_id).await {
warn!(
worker_id = self.id,
job_id,
error = ?e,
"Failed to unlock job during shutdown"
);
} else {
debug!(worker_id = self.id, job_id, "Job unlocked during shutdown");
}
info!(worker_id = self.id, "Worker exiting gracefully");
break;
}
Err(e) => {
warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
// Wait before retrying to avoid spamming errors.
time::sleep(Duration::from_secs(10)).await;
result = self.process_job(job) => {
result
}
};
// Handle the job processing result
match process_result {
Ok(()) => {
debug!(worker_id = self.id, job_id, "Job completed");
// If successful, delete the job
if let Err(delete_err) = self.delete_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?delete_err,
"Failed to delete job"
);
}
}
Err(JobError::Recoverable(e)) => {
self.handle_recoverable_error(job_id, e).await;
}
Err(JobError::Unrecoverable(e)) => {
error!(
worker_id = self.id,
job_id,
error = ?e,
"Job corrupted, deleting"
);
// Parse errors are unrecoverable - delete the job
if let Err(delete_err) = self.delete_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?delete_err,
"Failed to delete corrupted job"
);
}
}
}
}
@@ -169,4 +190,25 @@ impl Worker {
info!(worker_id = self.id, job_id, "Job unlocked for retry");
Ok(())
}
/// Handle recoverable errors by logging appropriately and unlocking the job
async fn handle_recoverable_error(&self, job_id: i32, e: anyhow::Error) {
if let Some(BannerApiError::InvalidSession(_)) = e.downcast_ref::<BannerApiError>() {
warn!(
worker_id = self.id,
job_id, "Invalid session detected. Forcing session refresh."
);
} else {
error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
}
if let Err(unlock_err) = self.unlock_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?unlock_err,
"Failed to unlock job"
);
}
}
}
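The worker's extra wrinkle is cleanup on interruption: if the shutdown signal arrives while a job is mid-flight, the job is unlocked so another worker can retry it later. A small sketch of that select-and-unlock flow; `process_job` and `unlock_job` below are hypothetical stand-ins, not the methods from this file.

```rust
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time;

// Stand-in for long-running job processing.
async fn process_job(job_id: i32) {
    println!("processing job {job_id}");
    time::sleep(Duration::from_secs(30)).await;
}

// Stand-in for releasing the job's lock so another worker can pick it up.
async fn unlock_job(job_id: i32) {
    println!("job {job_id} unlocked for retry");
}

#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel::<()>(1);
    let job_id = 42;

    // Simulate a shutdown request arriving while the job is still running.
    tokio::spawn(async move {
        time::sleep(Duration::from_secs(1)).await;
        let _ = shutdown_tx.send(());
    });

    tokio::select! {
        _ = process_job(job_id) => {
            println!("job {job_id} completed; delete it");
        }
        _ = shutdown_rx.recv() => {
            // Processing was interrupted: release the lock so the job can be retried.
            unlock_job(job_id).await;
            println!("worker exiting gracefully");
        }
    }
}
```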


@@ -1,15 +1,16 @@
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, trace, warn};
use tokio::sync::{broadcast, mpsc};
use tracing::{debug, info, trace, warn};
use crate::services::{Service, ServiceResult, run_service};
/// Manages multiple services and their lifecycle
pub struct ServiceManager {
registered_services: HashMap<String, Box<dyn Service>>,
running_services: HashMap<String, JoinHandle<ServiceResult>>,
service_handles: HashMap<String, tokio::task::AbortHandle>,
completion_rx: Option<mpsc::UnboundedReceiver<(String, ServiceResult)>>,
completion_tx: mpsc::UnboundedSender<(String, ServiceResult)>,
shutdown_tx: broadcast::Sender<()>,
}
@@ -22,9 +23,13 @@ impl Default for ServiceManager {
impl ServiceManager {
pub fn new() -> Self {
let (shutdown_tx, _) = broadcast::channel(1);
let (completion_tx, completion_rx) = mpsc::unbounded_channel();
Self {
registered_services: HashMap::new(),
running_services: HashMap::new(),
service_handles: HashMap::new(),
completion_rx: Some(completion_rx),
completion_tx,
shutdown_tx,
}
}
@@ -46,9 +51,19 @@ impl ServiceManager {
for (name, service) in self.registered_services.drain() {
let shutdown_rx = self.shutdown_tx.subscribe();
let handle = tokio::spawn(run_service(service, shutdown_rx));
let completion_tx = self.completion_tx.clone();
let name_clone = name.clone();
// Spawn service task
let handle = tokio::spawn(async move {
let result = run_service(service, shutdown_rx).await;
// Send completion notification
let _ = completion_tx.send((name_clone, result));
});
// Store abort handle for shutdown control
self.service_handles.insert(name.clone(), handle.abort_handle());
debug!(service = name, id = ?handle.id(), "service spawned");
self.running_services.insert(name, handle);
}
info!(
@@ -62,7 +77,7 @@ impl ServiceManager {
/// Run all services until one completes or fails
/// Returns the first service that completes and its result
pub async fn run(&mut self) -> (String, ServiceResult) {
if self.running_services.is_empty() {
if self.service_handles.is_empty() {
return (
"none".to_string(),
ServiceResult::Error(anyhow::anyhow!("No services to run")),
@@ -71,82 +86,112 @@ impl ServiceManager {
info!(
"servicemanager running {} services",
self.running_services.len()
self.service_handles.len()
);
// Wait for any service to complete
loop {
let mut completed_services = Vec::new();
// Wait for any service to complete via the channel
let completion_rx = self
.completion_rx
.as_mut()
.expect("completion_rx should be available");
for (name, handle) in &mut self.running_services {
if handle.is_finished() {
completed_services.push(name.clone());
}
}
if let Some(completed_name) = completed_services.first() {
let handle = self.running_services.remove(completed_name).unwrap();
match handle.await {
Ok(result) => {
return (completed_name.clone(), result);
}
Err(e) => {
error!(service = completed_name, "service task panicked: {e}");
return (
completed_name.clone(),
ServiceResult::Error(anyhow::anyhow!("Task panic: {e}")),
);
}
}
}
// Small delay to prevent busy-waiting
tokio::time::sleep(Duration::from_millis(10)).await;
}
completion_rx
.recv()
.await
.map(|(name, result)| {
self.service_handles.remove(&name);
(name, result)
})
.unwrap_or_else(|| {
(
"channel_closed".to_string(),
ServiceResult::Error(anyhow::anyhow!("Completion channel closed")),
)
})
}
/// Shutdown all services gracefully with a timeout.
///
/// If any service fails to shutdown, it will return an error containing the names of the services that failed to shutdown.
/// All services receive the shutdown signal simultaneously and must complete within the
/// specified timeout (combined, not per-service). If any service fails to shutdown within
/// the timeout, it will be aborted and included in the error result.
///
/// If all services shutdown successfully, the function will return the duration elapsed.
pub async fn shutdown(&mut self, timeout: Duration) -> Result<Duration, Vec<String>> {
let service_count = self.running_services.len();
let service_names: Vec<_> = self.running_services.keys().cloned().collect();
let service_count = self.service_handles.len();
let service_names: Vec<_> = self.service_handles.keys().cloned().collect();
info!(
service_count,
services = ?service_names,
timeout = format!("{:.2?}", timeout),
"shutting down {} services with {:?} timeout",
"shutting down {} services with {:?} total timeout",
service_count,
timeout
);
// Send shutdown signal to all services
if service_count == 0 {
return Ok(Duration::ZERO);
}
// Send shutdown signal to all services simultaneously
let _ = self.shutdown_tx.send(());
// Wait for all services to complete
let start_time = std::time::Instant::now();
let mut pending_services = Vec::new();
let mut completed = 0;
let mut failed_services = Vec::new();
for (name, handle) in self.running_services.drain() {
match tokio::time::timeout(timeout, handle).await {
Ok(Ok(_)) => {
trace!(service = name, "service shutdown completed");
// Borrow the receiver mutably (don't take ownership to allow reuse)
let completion_rx = self
.completion_rx
.as_mut()
.expect("completion_rx should be available");
// Wait for all services to complete with timeout
while completed < service_count {
match tokio::time::timeout(
timeout.saturating_sub(start_time.elapsed()),
completion_rx.recv(),
)
.await
{
Ok(Some((name, result))) => {
completed += 1;
self.service_handles.remove(&name);
if matches!(result, ServiceResult::GracefulShutdown) {
trace!(service = name, "service shutdown completed");
} else {
warn!(service = name, "service shutdown with non-graceful result");
failed_services.push(name);
}
}
Ok(Err(e)) => {
warn!(service = name, error = ?e, "service shutdown failed");
pending_services.push(name);
Ok(None) => {
// Channel closed - shouldn't happen but handle it
warn!("completion channel closed during shutdown");
break;
}
Err(_) => {
warn!(service = name, "service shutdown timed out");
pending_services.push(name);
// Timeout - abort all remaining services
warn!(
timeout = format!("{:.2?}", timeout),
elapsed = format!("{:.2?}", start_time.elapsed()),
remaining = service_count - completed,
"shutdown timeout - aborting remaining services"
);
for (name, handle) in self.service_handles.drain() {
handle.abort();
failed_services.push(name);
}
break;
}
}
}
let elapsed = start_time.elapsed();
if pending_services.is_empty() {
if failed_services.is_empty() {
info!(
service_count,
elapsed = format!("{:.2?}", elapsed),
@@ -156,14 +201,14 @@ impl ServiceManager {
Ok(elapsed)
} else {
warn!(
pending_count = pending_services.len(),
pending_services = ?pending_services,
failed_count = failed_services.len(),
failed_services = ?failed_services,
elapsed = format!("{:.2?}", elapsed),
"services shutdown completed with {} pending: {}",
pending_services.len(),
pending_services.join(", ")
"services shutdown completed with {} failed: {}",
failed_services.len(),
failed_services.join(", ")
);
Err(pending_services)
Err(failed_services)
}
}
}
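On the ServiceManager side, the old poll-`is_finished()` loop gives way to a completion channel: each spawned service sends `(name, result)` over an unbounded mpsc sender when it finishes, and the manager simply awaits the receiver. A sketch of that pattern with a simplified stand-in for `ServiceResult`:

```rust
use tokio::sync::mpsc;

// Simplified stand-in for the real ServiceResult enum.
#[allow(dead_code)]
#[derive(Debug)]
enum ServiceResult {
    GracefulShutdown,
    Error(String),
}

#[tokio::main]
async fn main() {
    let (completion_tx, mut completion_rx) =
        mpsc::unbounded_channel::<(String, ServiceResult)>();

    for name in ["web", "scraper"] {
        let tx = completion_tx.clone();
        tokio::spawn(async move {
            // Placeholder for running the service to completion.
            let result = ServiceResult::GracefulShutdown;
            // Report completion; ignore the error if the manager is gone.
            let _ = tx.send((name.to_string(), result));
        });
    }

    // The first message on the channel identifies which service finished first.
    if let Some((name, result)) = completion_rx.recv().await {
        println!("{name} completed with {result:?}");
    }
}
```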


@@ -23,7 +23,11 @@ pub trait Service: Send + Sync {
/// Gracefully shutdown the service
///
/// An 'Ok' result does not mean the service has completed shutdown, it merely means that the service shutdown was initiated.
/// Implementations should initiate shutdown and MAY wait for completion.
/// Services are expected to respond to this call and begin cleanup promptly.
/// When managed by ServiceManager, the configured timeout (default 8s) applies to
/// ALL services combined, not per-service. Services should complete shutdown as
/// quickly as possible to avoid timeout.
async fn shutdown(&mut self) -> Result<(), anyhow::Error>;
}


@@ -3,7 +3,7 @@ use crate::web::{BannerState, create_router};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tracing::{info, warn, trace};
use tracing::{info, trace, warn};
/// Web server service implementation
pub struct WebService {
@@ -33,16 +33,12 @@ impl Service for WebService {
let app = create_router(self.banner_state.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
info!(
service = "web",
link = format!("http://localhost:{}", addr.port()),
"starting web server",
);
let listener = TcpListener::bind(addr).await?;
info!(
service = "web",
address = %addr,
link = format!("http://localhost:{}", addr.port()),
"web server listening"
);
@@ -61,13 +57,16 @@ impl Service for WebService {
})
.await?;
trace!(service = "web", "graceful shutdown completed");
info!(service = "web", "web server stopped");
Ok(())
}
async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
trace!(service = "web", "sent shutdown signal to axum");
} else {
warn!(
service = "web",