diff --git a/Cargo.lock b/Cargo.lock index 690d162..a6ea202 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -173,6 +173,7 @@ dependencies = [ "diesel", "dotenvy", "figment", + "fundu", "governor", "poise", "redis", @@ -704,6 +705,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fundu" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ce12752fc64f35be3d53e0a57017cd30970f0cffd73f62c791837d8845badbd" +dependencies = [ + "fundu-core", +] + +[[package]] +name = "fundu-core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e463452e2d8b7600d38dcea1ed819773a57f0d710691bfc78db3961bd3f4c3ba" + [[package]] name = "futures" version = "0.3.31" diff --git a/Cargo.toml b/Cargo.toml index 72663d9..2c63ebf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,5 +19,6 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } dotenvy = "0.15.7" poise = "0.6.1" async-trait = "0.1" +fundu = "2.0.1" anyhow = "1.0.99" -thiserror = "2.0.16" +thiserror = "2.0.16" \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index f49febd..e832543 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,11 +1,127 @@ -use serde::Deserialize; +//! Configuration module for the banner application. +//! +//! This module handles loading and parsing configuration from environment variables +//! using the figment crate. It supports flexible duration parsing that accepts both +//! numeric values (interpreted as seconds) and duration strings with units. +//! +//! All configuration is loaded from environment variables with the `APP_` prefix: +use fundu::{DurationParser, TimeUnit}; +use serde::{Deserialize, Deserializer}; +use std::time::Duration; + +/// Application configuration loaded from environment variables. #[derive(Deserialize)] pub struct Config { + /// Discord bot token for authentication pub bot_token: String, + /// Database connection URL pub database_url: String, + /// Redis connection URL pub redis_url: String, + /// Base URL for banner generation service pub banner_base_url: String, + /// Target Discord guild ID where the bot operates pub bot_target_guild: u64, + /// Discord application ID pub bot_app_id: u64, + /// Graceful shutdown timeout duration + /// + /// Accepts both numeric values (seconds) and duration strings. + /// Defaults to 8 seconds if not specified. + #[serde( + default = "default_shutdown_timeout", + deserialize_with = "deserialize_duration" + )] + pub shutdown_timeout: Duration, +} + +/// Default shutdown timeout of 8 seconds. +fn default_shutdown_timeout() -> Duration { + Duration::from_secs(8) +} + +/// Duration parser configured to handle various time units with seconds as default. +/// +/// Supports: +/// - Seconds (s) - default unit +/// - Milliseconds (ms) +/// - Minutes (m) +/// - Hours (h) +/// +/// Does not support fractions, exponents, or infinity values. +/// Allows for whitespace between the number and the time unit. +/// Allows for multiple time units to be specified (summed together, e.g. "10s 2m" = 120 + 10 = 130 seconds) +const DURATION_PARSER: DurationParser<'static> = DurationParser::builder() + .time_units(&[TimeUnit::Second, TimeUnit::MilliSecond, TimeUnit::Minute]) + .parse_multiple(None) + .allow_time_unit_delimiter() + .disable_infinity() + .disable_fraction() + .disable_exponent() + .default_unit(TimeUnit::Second) + .build(); + +/// Custom deserializer for duration fields that accepts both numeric and string values. +/// +/// This deserializer handles the flexible duration parsing by accepting: +/// - Unsigned integers (interpreted as seconds) +/// - Signed integers (interpreted as seconds, must be non-negative) +/// - Strings (parsed using the fundu duration parser) +/// +/// # Examples +/// +/// - `1` -> 1 second +/// - `"30s"` -> 30 seconds +/// - `"2 m"` -> 2 minutes +/// - `"1500ms"` -> 1.5 seconds +fn deserialize_duration<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + use serde::de::Visitor; + + struct DurationVisitor; + + impl<'de> Visitor<'de> for DurationVisitor { + type Value = Duration; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a duration string or number") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + DURATION_PARSER.parse(value) + .map_err(|e| { + serde::de::Error::custom(format!( + "Invalid duration format '{}': {}. Examples: '5' (5 seconds), '3500ms', '30s', '2m', '1.5h'", + value, e + )) + })? + .try_into() + .map_err(|e| serde::de::Error::custom(format!("Duration conversion error: {}", e))) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Duration::from_secs(value)) + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + if value < 0 { + return Err(serde::de::Error::custom("Duration cannot be negative")); + } + Ok(Duration::from_secs(value as u64)) + } + } + + deserializer.deserialize_any(DurationVisitor) } diff --git a/src/main.rs b/src/main.rs index 952e015..9c03c60 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,17 @@ use serenity::all::{ClientBuilder, GatewayIntents}; -use std::time::Duration; -use tokio::{signal, task::JoinSet}; +use tokio::signal; use tracing::{error, info, warn}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use crate::bot::{Data, age}; use crate::config::Config; +use crate::services::manager::ServiceManager; use crate::services::{ServiceResult, bot::BotService, dummy::DummyService, run_service}; -use crate::shutdown::ShutdownCoordinator; use figment::{Figment, providers::Env}; mod bot; mod config; mod services; -mod shutdown; #[tokio::main] async fn main() { @@ -22,7 +20,18 @@ async fn main() { // Configure logging let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn,banner=debug")); - let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish(); + let subscriber = { + #[cfg(debug_assertions)] + { + FmtSubscriber::builder() + } + #[cfg(not(debug_assertions))] + { + FmtSubscriber::builder().json() + } + } + .with_env_filter(filter) + .finish(); tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); let config: Config = Figment::new() @@ -51,118 +60,87 @@ async fn main() { .await .expect("Failed to build client"); - let shutdown_coordinator = ShutdownCoordinator::new(); + // Extract shutdown timeout before moving config + let shutdown_timeout = config.shutdown_timeout; - // Create services + // Create service manager + let mut service_manager = ServiceManager::new(); + + // Create and add services let bot_service = Box::new(BotService::new(client)); let dummy_service = Box::new(DummyService::new("background")); - // Start services using the unified runner - let bot_handle = { - let shutdown_rx = shutdown_coordinator.subscribe(); - tokio::spawn(run_service(bot_service, shutdown_rx)) + let bot_handle = tokio::spawn(run_service(bot_service, service_manager.subscribe())); + let dummy_handle = tokio::spawn(run_service(dummy_service, service_manager.subscribe())); + + service_manager.add_service("bot".to_string(), bot_handle); + service_manager.add_service("background".to_string(), dummy_handle); + + // Set up CTRL+C signal handling + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler"); + info!("Received CTRL+C, gracefully shutting down..."); }; - let dummy_handle = { - let shutdown_rx = shutdown_coordinator.subscribe(); - tokio::spawn(run_service(dummy_service, shutdown_rx)) - }; - - // Set up signal handling - let signal_handle = { - let shutdown_tx = shutdown_coordinator.shutdown_tx(); - tokio::spawn(async move { - signal::ctrl_c() - .await - .expect("Failed to install CTRL+C signal handler"); - info!("Received CTRL+C, initiating shutdown..."); - let _ = shutdown_tx.send(()); - ServiceResult::GracefulShutdown - }) - }; - - // Put all services in a JoinSet for unified handling - let mut services = JoinSet::new(); - services.spawn(bot_handle); - services.spawn(dummy_handle); - services.spawn(signal_handle); - - // Wait for any service to complete or signal + // Main application loop - wait for services or CTRL+C let mut exit_code = 0; - let first_completion = services.join_next().await; - match first_completion { - Some(Ok(Ok(service_result))) => { - // A service completed successfully - match &service_result { + let join = |strings: Vec| { + strings + .iter() + .map(|s| format!("\"{}\"", s)) + .collect::>() + .join(", ") + }; + + tokio::select! { + (service_name, result) = service_manager.run() => { + // A service completed unexpectedly + match result { ServiceResult::GracefulShutdown => { - // This means CTRL+C was pressed + info!(service = service_name, "Service completed gracefully"); } ServiceResult::NormalCompletion => { - warn!("A service completed unexpectedly"); + warn!(service = service_name, "Service completed unexpectedly"); exit_code = 1; } ServiceResult::Error(e) => { - error!("Service failure: {e}"); + error!(service = service_name, "Service failed: {e}"); exit_code = 1; } } - } - Some(Ok(Err(e))) => { - error!("Service task panicked: {e}"); - exit_code = 1; - } - Some(Err(e)) => { - error!("JoinSet error: {e}"); - exit_code = 1; - } - None => { - warn!("No services running"); - exit_code = 1; - } - }; - // Signal all services to shut down - shutdown_coordinator.shutdown(); - - // Wait for graceful shutdown with timeout - let remaining_count = services.len(); - if remaining_count > 0 { - info!("Waiting for {remaining_count} remaining services to shutdown (5s timeout)..."); - let shutdown_result = tokio::time::timeout(Duration::from_secs(5), async { - while let Some(result) = services.join_next().await { - match result { - Ok(Ok(ServiceResult::GracefulShutdown)) => { - // Service shutdown logged by the service itself - } - Ok(Ok(ServiceResult::NormalCompletion)) => { - warn!("Service completed normally during shutdown"); - } - Ok(Ok(ServiceResult::Error(e))) => { - error!("Service error during shutdown: {e}"); - } - Ok(Err(e)) => { - error!("Service panic during shutdown: {e}"); - } - Err(e) => { - error!("Service join error: {e}"); - } + // Shutdown remaining services + match service_manager.shutdown(shutdown_timeout).await { + Ok(()) => { + info!("Graceful shutdown complete"); + } + Err(pending_services) => { + warn!( + "Graceful shutdown elapsed - the following service(s) did not complete: {}", + join(pending_services) + ); + exit_code = if exit_code == 0 { 2 } else { exit_code }; } } - }) - .await; - - match shutdown_result { - Ok(()) => { - info!("All services shutdown completed"); - } - Err(_) => { - warn!("Shutdown timeout - some services may not have completed"); - exit_code = if exit_code == 0 { 2 } else { exit_code }; + } + _ = ctrl_c => { + // User requested shutdown + match service_manager.shutdown(shutdown_timeout).await { + Ok(()) => { + info!("Graceful shutdown complete"); + } + Err(pending_services) => { + warn!( + "Graceful shutdown elapsed - the following service(s) did not complete: {}", + join(pending_services) + ); + exit_code = 2; + } } } - } else { - info!("No remaining services to shutdown"); } info!("Application shutdown complete (exit code: {})", exit_code); diff --git a/src/services/bot.rs b/src/services/bot.rs index 2ad098a..38aa84c 100644 --- a/src/services/bot.rs +++ b/src/services/bot.rs @@ -1,4 +1,4 @@ -use super::{Service, ServiceResult}; +use super::Service; use serenity::Client; use std::sync::Arc; use tracing::{error, warn}; diff --git a/src/services/manager.rs b/src/services/manager.rs new file mode 100644 index 0000000..011bee1 --- /dev/null +++ b/src/services/manager.rs @@ -0,0 +1,155 @@ +use std::collections::HashMap; +use std::time::Duration; +use tokio::sync::broadcast; +use tokio::task::JoinHandle; +use tracing::{error, info, warn}; + +use crate::services::ServiceResult; + +/// Manages multiple services and their lifecycle +pub struct ServiceManager { + services: HashMap>, + shutdown_tx: broadcast::Sender<()>, +} + +impl ServiceManager { + pub fn new() -> Self { + let (shutdown_tx, _) = broadcast::channel(1); + Self { + services: HashMap::new(), + shutdown_tx, + } + } + + /// Add a service to be managed + pub fn add_service(&mut self, name: String, handle: JoinHandle) { + self.services.insert(name, handle); + } + + /// Get a shutdown receiver for services to subscribe to + pub fn subscribe(&self) -> broadcast::Receiver<()> { + self.shutdown_tx.subscribe() + } + + /// Run all services until one completes or fails + /// Returns the first service that completes and its result + pub async fn run(&mut self) -> (String, ServiceResult) { + if self.services.is_empty() { + return ( + "none".to_string(), + ServiceResult::Error(anyhow::anyhow!("No services to run")), + ); + } + + info!("ServiceManager running {} services", self.services.len()); + + // Wait for any service to complete + loop { + let mut completed_services = Vec::new(); + + for (name, handle) in &mut self.services { + if handle.is_finished() { + completed_services.push(name.clone()); + } + } + + if let Some(completed_name) = completed_services.first() { + let handle = self.services.remove(completed_name).unwrap(); + match handle.await { + Ok(result) => { + return (completed_name.clone(), result); + } + Err(e) => { + error!(service = completed_name, "Service task panicked: {e}"); + return ( + completed_name.clone(), + ServiceResult::Error(anyhow::anyhow!("Task panic: {e}")), + ); + } + } + } + + // Small delay to prevent busy-waiting + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + /// Shutdown all services gracefully with a timeout + /// Returns Ok(()) if all services shut down, or Err(Vec) with names of services that timed out + pub async fn shutdown(mut self, timeout: Duration) -> Result<(), Vec> { + if self.services.is_empty() { + info!("No services to shutdown"); + return Ok(()); + } + + info!( + "Shutting down {} services with {}s timeout", + self.services.len(), + timeout.as_secs() + ); + + // Signal all services to shutdown + let _ = self.shutdown_tx.send(()); + + // Wait for all services to complete with timeout + let shutdown_result = tokio::time::timeout(timeout, async { + let mut completed = Vec::new(); + let mut failed = Vec::new(); + + while !self.services.is_empty() { + let mut to_remove = Vec::new(); + + for (name, handle) in &mut self.services { + if handle.is_finished() { + to_remove.push(name.clone()); + } + } + + for name in to_remove { + let handle = self.services.remove(&name).unwrap(); + match handle.await { + Ok(ServiceResult::GracefulShutdown) => { + completed.push(name); + } + Ok(ServiceResult::NormalCompletion) => { + warn!(service = name, "Service completed normally during shutdown"); + completed.push(name); + } + Ok(ServiceResult::Error(e)) => { + error!(service = name, "Service error during shutdown: {e}"); + failed.push(name); + } + Err(e) => { + error!(service = name, "Service panic during shutdown: {e}"); + failed.push(name); + } + } + } + + if !self.services.is_empty() { + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + (completed, failed) + }) + .await; + + match shutdown_result { + Ok((completed, failed)) => { + if !completed.is_empty() { + info!("Services shutdown completed: {}", completed.join(", ")); + } + if !failed.is_empty() { + warn!("Services had errors during shutdown: {}", failed.join(", ")); + } + Ok(()) + } + Err(_) => { + // Timeout occurred - return names of services that didn't complete + let pending_services: Vec = self.services.keys().cloned().collect(); + Err(pending_services) + } + } + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index c845b7b..61826d6 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,9 +1,9 @@ -use std::time::Duration; use tokio::sync::broadcast; use tracing::{error, info, warn}; pub mod bot; pub mod dummy; +pub mod manager; #[derive(Debug)] pub enum ServiceResult { diff --git a/src/shutdown.rs b/src/shutdown.rs deleted file mode 100644 index a071997..0000000 --- a/src/shutdown.rs +++ /dev/null @@ -1,25 +0,0 @@ -use tokio::sync::broadcast; - -/// Shutdown coordinator for managing graceful shutdown of multiple services -pub struct ShutdownCoordinator { - shutdown_tx: broadcast::Sender<()>, -} - -impl ShutdownCoordinator { - pub fn new() -> Self { - let (shutdown_tx, _) = broadcast::channel(1); - Self { shutdown_tx } - } - - pub fn subscribe(&self) -> broadcast::Receiver<()> { - self.shutdown_tx.subscribe() - } - - pub fn shutdown(&self) { - let _ = self.shutdown_tx.send(()); - } - - pub fn shutdown_tx(&self) -> broadcast::Sender<()> { - self.shutdown_tx.clone() - } -}