diff --git a/src/app.rs b/src/app.rs new file mode 100644 index 0000000..ad0747a --- /dev/null +++ b/src/app.rs @@ -0,0 +1,145 @@ +use crate::banner::BannerApi; +use crate::cli::ServiceName; +use crate::config::Config; +use crate::scraper::ScraperService; +use crate::services::bot::BotService; +use crate::services::manager::ServiceManager; +use crate::services::web::WebService; +use crate::state::AppState; +use crate::web::routes::BannerState; +use figment::value::UncasedStr; +use figment::{Figment, providers::Env}; +use sqlx::postgres::PgPoolOptions; +use std::process::ExitCode; +use std::sync::Arc; +use tracing::{error, info}; + +/// Main application struct containing all necessary components +pub struct App { + config: Config, + db_pool: sqlx::PgPool, + banner_api: Arc, + app_state: AppState, + banner_state: BannerState, + service_manager: ServiceManager, +} + +impl App { + /// Create a new App instance with all necessary components initialized + pub async fn new() -> Result { + // Load configuration + let config: Config = Figment::new() + .merge(Env::raw().map(|k| { + if k == UncasedStr::new("RAILWAY_DEPLOYMENT_DRAINING_SECONDS") { + "SHUTDOWN_TIMEOUT".into() + } else { + k.into() + } + })) + .extract() + .expect("Failed to load config"); + + // Create database connection pool + let db_pool = PgPoolOptions::new() + .max_connections(10) + .connect(&config.database_url) + .await + .expect("Failed to create database pool"); + + info!( + port = config.port, + shutdown_timeout = format!("{:.2?}", config.shutdown_timeout), + banner_base_url = config.banner_base_url, + "configuration loaded" + ); + + // Create BannerApi and AppState + let banner_api = BannerApi::new_with_config( + config.banner_base_url.clone(), + config.rate_limiting.clone().into(), + ) + .expect("Failed to create BannerApi"); + + let banner_api_arc = Arc::new(banner_api); + let app_state = AppState::new(banner_api_arc.clone(), db_pool.clone()); + + // Create BannerState for web service + let banner_state = BannerState {}; + + Ok(App { + config, + db_pool, + banner_api: banner_api_arc, + app_state, + banner_state, + service_manager: ServiceManager::new(), + }) + } + + /// Setup and register services based on enabled service list + pub fn setup_services( + &mut self, + enabled_services: &[ServiceName], + ) -> Result<(), anyhow::Error> { + // Register enabled services with the manager + if enabled_services.contains(&ServiceName::Web) { + let web_service = + Box::new(WebService::new(self.config.port, self.banner_state.clone())); + self.service_manager + .register_service(ServiceName::Web.as_str(), web_service); + } + + if enabled_services.contains(&ServiceName::Scraper) { + let scraper_service = Box::new(ScraperService::new( + self.db_pool.clone(), + self.banner_api.clone(), + )); + self.service_manager + .register_service(ServiceName::Scraper.as_str(), scraper_service); + } + + if enabled_services.contains(&ServiceName::Bot) { + // Bot service will be set up separately in run() method since it's async + } + + // Check if any services are enabled + if !self.service_manager.has_services() && !enabled_services.contains(&ServiceName::Bot) { + error!("No services enabled. Cannot start application."); + return Err(anyhow::anyhow!("No services enabled")); + } + + Ok(()) + } + + /// Setup bot service if enabled + pub async fn setup_bot_service(&mut self) -> Result<(), anyhow::Error> { + let client = BotService::create_client(&self.config, self.app_state.clone()) + .await + .expect("Failed to create Discord client"); + let bot_service = Box::new(BotService::new(client)); + self.service_manager + .register_service(ServiceName::Bot.as_str(), bot_service); + Ok(()) + } + + /// Start all registered services + pub fn start_services(&mut self) { + self.service_manager.spawn_all(); + } + + /// Run the application and handle shutdown signals + pub async fn run(self) -> ExitCode { + use crate::signals::handle_shutdown_signals; + handle_shutdown_signals(self.service_manager, self.config.shutdown_timeout).await + } + + /// Get a reference to the configuration + pub fn config(&self) -> &Config { + &self.config + } + + /// Get a reference to the app state + pub fn app_state(&self) -> &AppState { + &self.app_state + } +} diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..f711999 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,104 @@ +use clap::Parser; + +/// Banner Discord Bot - Course availability monitoring +/// +/// This application runs multiple services that can be controlled via CLI arguments: +/// - bot: Discord bot for course monitoring commands +/// - web: HTTP server for web interface and API +/// - scraper: Background service for scraping course data +/// +/// Use --services to specify which services to run, or --disable-services to exclude specific services. +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// Log formatter to use + #[arg(long, value_enum, default_value_t = default_tracing_format())] + pub tracing: TracingFormat, + + /// Services to run (comma-separated). Default: all services + /// + /// Examples: + /// --services bot,web # Run only bot and web services + /// --services scraper # Run only the scraper service + #[arg(long, value_delimiter = ',', conflicts_with = "disable_services")] + pub services: Option>, + + /// Services to disable (comma-separated) + /// + /// Examples: + /// --disable-services bot # Run web and scraper only + /// --disable-services bot,web # Run only the scraper service + #[arg(long, value_delimiter = ',', conflicts_with = "services")] + pub disable_services: Option>, +} + +#[derive(clap::ValueEnum, Clone, Debug)] +pub enum TracingFormat { + /// Use pretty formatter (default in debug mode) + Pretty, + /// Use JSON formatter (default in release mode) + Json, +} + +#[derive(clap::ValueEnum, Clone, Debug, PartialEq)] +pub enum ServiceName { + /// Discord bot for course monitoring commands + Bot, + /// HTTP server for web interface and API + Web, + /// Background service for scraping course data + Scraper, +} + +impl ServiceName { + /// Get all available services + pub fn all() -> Vec { + vec![ServiceName::Bot, ServiceName::Web, ServiceName::Scraper] + } + + /// Convert to string for service registration + pub fn as_str(&self) -> &'static str { + match self { + ServiceName::Bot => "bot", + ServiceName::Web => "web", + ServiceName::Scraper => "scraper", + } + } +} + +/// Determine which services should be enabled based on CLI arguments +pub fn determine_enabled_services(args: &Args) -> Result, anyhow::Error> { + match (&args.services, &args.disable_services) { + (Some(services), None) => { + // User specified which services to run + Ok(services.clone()) + } + (None, Some(disabled)) => { + // User specified which services to disable + let enabled: Vec = ServiceName::all() + .into_iter() + .filter(|s| !disabled.contains(s)) + .collect(); + Ok(enabled) + } + (None, None) => { + // Default: run all services + Ok(ServiceName::all()) + } + (Some(_), Some(_)) => { + // This should be prevented by clap's conflicts_with, but just in case + Err(anyhow::anyhow!( + "Cannot specify both --services and --disable-services" + )) + } + } +} + +#[cfg(debug_assertions)] +const DEFAULT_TRACING_FORMAT: TracingFormat = TracingFormat::Pretty; +#[cfg(not(debug_assertions))] +const DEFAULT_TRACING_FORMAT: TracingFormat = TracingFormat::Json; + +fn default_tracing_format() -> TracingFormat { + DEFAULT_TRACING_FORMAT +} diff --git a/src/lib.rs b/src/lib.rs index e26d55e..2ed65f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,14 @@ +pub mod app; pub mod banner; pub mod bot; +pub mod cli; pub mod config; pub mod data; pub mod error; +pub mod formatter; +pub mod logging; pub mod scraper; pub mod services; +pub mod signals; pub mod state; pub mod web; diff --git a/src/logging.rs b/src/logging.rs new file mode 100644 index 0000000..6056e8b --- /dev/null +++ b/src/logging.rs @@ -0,0 +1,44 @@ +use crate::cli::TracingFormat; +use crate::config::Config; +use crate::formatter; +use tracing_subscriber::fmt::format::JsonFields; +use tracing_subscriber::{EnvFilter, FmtSubscriber}; + +/// Configure and initialize logging for the application +pub fn setup_logging(config: &Config, tracing_format: TracingFormat) { + // Configure logging based on config + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { + let base_level = &config.log_level; + EnvFilter::new(format!( + "warn,banner={},banner::rate_limiter=warn,banner::session=warn,banner::rate_limit_middleware=warn", + base_level + )) + }); + + // Select formatter based on CLI args + let use_pretty = match tracing_format { + TracingFormat::Pretty => true, + TracingFormat::Json => false, + }; + + let subscriber: Box = if use_pretty { + Box::new( + FmtSubscriber::builder() + .with_target(true) + .event_format(formatter::CustomPrettyFormatter) + .with_env_filter(filter) + .finish(), + ) + } else { + Box::new( + FmtSubscriber::builder() + .with_target(true) + .event_format(formatter::CustomJsonFormatter) + .fmt_fields(JsonFields::new()) + .with_env_filter(filter) + .finish(), + ) + }; + + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); +} diff --git a/src/main.rs b/src/main.rs index d01c8ca..39885a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,137 +1,27 @@ +use crate::app::App; +use crate::cli::{Args, ServiceName, determine_enabled_services}; +use crate::logging::setup_logging; use clap::Parser; -use figment::value::UncasedStr; -use num_format::{Locale, ToFormattedString}; -use serenity::all::{ActivityData, ClientBuilder, GatewayIntents}; -use tokio::signal; -use tracing::{debug, error, info, warn}; -use tracing_subscriber::fmt::format::JsonFields; -use tracing_subscriber::{EnvFilter, FmtSubscriber}; - -use crate::banner::BannerApi; -use crate::bot::{Data, get_commands}; -use crate::config::Config; -use crate::scraper::ScraperService; -use crate::services::manager::ServiceManager; -use crate::services::{ServiceResult, bot::BotService, web::WebService}; -use crate::state::AppState; -use crate::web::routes::BannerState; -use figment::{Figment, providers::Env}; -use sqlx::postgres::PgPoolOptions; -use std::sync::Arc; -use std::time::Duration; +use std::process::ExitCode; +use tracing::info; +mod app; mod banner; mod bot; +mod cli; mod config; mod data; mod error; mod formatter; +mod logging; mod scraper; mod services; +mod signals; mod state; mod web; -#[cfg(debug_assertions)] -const DEFAULT_TRACING_FORMAT: TracingFormat = TracingFormat::Pretty; -#[cfg(not(debug_assertions))] -const DEFAULT_TRACING_FORMAT: TracingFormat = TracingFormat::Json; - -/// Banner Discord Bot - Course availability monitoring -/// -/// This application runs multiple services that can be controlled via CLI arguments: -/// - bot: Discord bot for course monitoring commands -/// - web: HTTP server for web interface and API -/// - scraper: Background service for scraping course data -/// -/// Use --services to specify which services to run, or --disable-services to exclude specific services. -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Log formatter to use - #[arg(long, value_enum, default_value_t = DEFAULT_TRACING_FORMAT)] - tracing: TracingFormat, - - /// Services to run (comma-separated). Default: all services - /// - /// Examples: - /// --services bot,web # Run only bot and web services - /// --services scraper # Run only the scraper service - #[arg(long, value_delimiter = ',', conflicts_with = "disable_services")] - services: Option>, - - /// Services to disable (comma-separated) - /// - /// Examples: - /// --disable-services bot # Run web and scraper only - /// --disable-services bot,web # Run only the scraper service - #[arg(long, value_delimiter = ',', conflicts_with = "services")] - disable_services: Option>, -} - -#[derive(clap::ValueEnum, Clone, Debug)] -enum TracingFormat { - /// Use pretty formatter (default in debug mode) - Pretty, - /// Use JSON formatter (default in release mode) - Json, -} - -#[derive(clap::ValueEnum, Clone, Debug, PartialEq)] -enum ServiceName { - /// Discord bot for course monitoring commands - Bot, - /// HTTP server for web interface and API - Web, - /// Background service for scraping course data - Scraper, -} - -impl ServiceName { - /// Get all available services - fn all() -> Vec { - vec![ServiceName::Bot, ServiceName::Web, ServiceName::Scraper] - } - - /// Convert to string for service registration - fn as_str(&self) -> &'static str { - match self { - ServiceName::Bot => "bot", - ServiceName::Web => "web", - ServiceName::Scraper => "scraper", - } - } -} - -/// Determine which services should be enabled based on CLI arguments -fn determine_enabled_services(args: &Args) -> Result, anyhow::Error> { - match (&args.services, &args.disable_services) { - (Some(services), None) => { - // User specified which services to run - Ok(services.clone()) - } - (None, Some(disabled)) => { - // User specified which services to disable - let enabled: Vec = ServiceName::all() - .into_iter() - .filter(|s| !disabled.contains(s)) - .collect(); - Ok(enabled) - } - (None, None) => { - // Default: run all services - Ok(ServiceName::all()) - } - (Some(_), Some(_)) => { - // This should be prevented by clap's conflicts_with, but just in case - Err(anyhow::anyhow!( - "Cannot specify both --services and --disable-services" - )) - } - } -} - #[tokio::main] -async fn main() { +async fn main() -> ExitCode { dotenvy::dotenv().ok(); // Parse CLI arguments @@ -146,52 +36,11 @@ async fn main() { "services configuration loaded" ); - // Load configuration first to get log level - let config: Config = Figment::new() - .merge(Env::raw().map(|k| { - if k == UncasedStr::new("RAILWAY_DEPLOYMENT_DRAINING_SECONDS") { - "SHUTDOWN_TIMEOUT".into() - } else { - k.into() - } - })) - .extract() - .expect("Failed to load config"); + // Create and initialize the application + let mut app = App::new().await.expect("Failed to initialize application"); - // Configure logging based on config - let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { - let base_level = &config.log_level; - EnvFilter::new(format!( - "warn,banner={},banner::rate_limiter=warn,banner::session=warn,banner::rate_limit_middleware=warn", - base_level - )) - }); - - // Select formatter based on CLI args - let use_pretty = match args.tracing { - TracingFormat::Pretty => true, - TracingFormat::Json => false, - }; - - let subscriber: Box = if use_pretty { - Box::new( - FmtSubscriber::builder() - .with_target(true) - .event_format(formatter::CustomPrettyFormatter) - .with_env_filter(filter) - .finish(), - ) - } else { - Box::new( - FmtSubscriber::builder() - .with_target(true) - .event_format(formatter::CustomJsonFormatter) - .fmt_fields(JsonFields::new()) - .with_env_filter(filter) - .finish(), - ) - }; - tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + // Setup logging + setup_logging(app.config(), args.tracing); // Log application startup context info!( @@ -204,312 +53,18 @@ async fn main() { "starting banner" ); - // Create database connection pool - let db_pool = PgPoolOptions::new() - .max_connections(10) - .connect(&config.database_url) - .await - .expect("Failed to create database pool"); + // Setup services (web, scraper) + app.setup_services(&enabled_services) + .expect("Failed to setup services"); - info!( - port = config.port, - shutdown_timeout = format!("{:.2?}", config.shutdown_timeout), - banner_base_url = config.banner_base_url, - "configuration loaded" - ); - - // Create BannerApi and AppState - let banner_api = BannerApi::new_with_config( - config.banner_base_url.clone(), - config.rate_limiting.clone().into(), - ) - .expect("Failed to create BannerApi"); - - let banner_api_arc = Arc::new(banner_api); - let app_state = AppState::new(banner_api_arc.clone(), db_pool.clone()); - - // Create BannerState for web service - let banner_state = BannerState {}; - - // Configure the client with your Discord bot token in the environment - let intents = GatewayIntents::non_privileged(); - - let bot_target_guild = config.bot_target_guild; - - let framework = poise::Framework::builder() - .options(poise::FrameworkOptions { - commands: get_commands(), - pre_command: |ctx| { - Box::pin(async move { - let content = match ctx { - poise::Context::Application(_) => ctx.invocation_string(), - poise::Context::Prefix(prefix) => prefix.msg.content.to_string(), - }; - let channel_name = ctx - .channel_id() - .name(ctx.http()) - .await - .unwrap_or("unknown".to_string()); - - let span = tracing::Span::current(); - span.record("command_name", ctx.command().qualified_name.as_str()); - span.record("invocation", ctx.invocation_string()); - span.record("msg.content", content.as_str()); - span.record("msg.author", ctx.author().tag().as_str()); - span.record("msg.id", ctx.id()); - span.record("msg.channel_id", ctx.channel_id().get()); - span.record("msg.channel", channel_name.as_str()); - - tracing::info!( - command_name = ctx.command().qualified_name.as_str(), - invocation = ctx.invocation_string(), - msg.content = %content, - msg.author = %ctx.author().tag(), - msg.author_id = %ctx.author().id, - msg.id = %ctx.id(), - msg.channel = %channel_name.as_str(), - msg.channel_id = %ctx.channel_id(), - "{} invoked by {}", - ctx.command().name, - ctx.author().tag() - ); - }) - }, - on_error: |error| { - Box::pin(async move { - if let Err(e) = poise::builtins::on_error(error).await { - tracing::error!(error = %e, "Fatal error while sending error message"); - } - // error!(error = ?error, "command error"); - }) - }, - ..Default::default() - }) - .setup(move |ctx, _ready, framework| { - let app_state = app_state.clone(); - Box::pin(async move { - poise::builtins::register_in_guild( - ctx, - &framework.options().commands, - bot_target_guild.into(), - ) - .await?; - poise::builtins::register_globally(ctx, &framework.options().commands).await?; - - // Start status update task - let status_app_state = app_state.clone(); - let status_ctx = ctx.clone(); - tokio::spawn(async move { - let max_interval = Duration::from_secs(300); // 5 minutes - let base_interval = Duration::from_secs(30); - let mut interval = tokio::time::interval(base_interval); - let mut previous_course_count: Option = None; - - // This runs once immediately on startup, then with adaptive intervals - loop { - interval.tick().await; - - // Get the course count, update the activity if it has changed/hasn't been set this session - let course_count = status_app_state.get_course_count().await.unwrap(); - if previous_course_count.is_none() - || previous_course_count != Some(course_count) - { - status_ctx.set_activity(Some(ActivityData::playing(format!( - "Querying {:} classes", - course_count.to_formatted_string(&Locale::en) - )))); - } - - // Increase or reset the interval - interval = tokio::time::interval( - // Avoid logging the first 'change' - if course_count != previous_course_count.unwrap_or(0) { - if previous_course_count.is_some() { - debug!( - new_course_count = course_count, - last_interval = interval.period().as_secs(), - "Course count changed, resetting interval" - ); - } - - // Record the new course count - previous_course_count = Some(course_count); - - // Reset to base interval - base_interval - } else { - // Increase interval by 10% (up to maximum) - let new_interval = interval.period().mul_f32(1.1).min(max_interval); - debug!( - current_course_count = course_count, - last_interval = interval.period().as_secs(), - new_interval = new_interval.as_secs(), - "Course count unchanged, increasing interval" - ); - - new_interval - }, - ); - - // Reset the interval, otherwise it will tick again immediately - interval.reset(); - } - }); - - Ok(Data { app_state }) - }) - }) - .build(); - - let client = ClientBuilder::new(config.bot_token, intents) - .framework(framework) - .await - .expect("Failed to build client"); - - // Extract shutdown timeout before moving config - let shutdown_timeout = config.shutdown_timeout; - let port = config.port; - - // Create service manager - let mut service_manager = ServiceManager::new(); - - // Register enabled services with the manager + // Setup bot service if enabled if enabled_services.contains(&ServiceName::Bot) { - let bot_service = Box::new(BotService::new(client)); - service_manager.register_service(ServiceName::Bot.as_str(), bot_service); - } - - if enabled_services.contains(&ServiceName::Web) { - let web_service = Box::new(WebService::new(port, banner_state)); - service_manager.register_service(ServiceName::Web.as_str(), web_service); - } - - if enabled_services.contains(&ServiceName::Scraper) { - let scraper_service = - Box::new(ScraperService::new(db_pool.clone(), banner_api_arc.clone())); - service_manager.register_service(ServiceName::Scraper.as_str(), scraper_service); - } - - // Check if any services are enabled - if !service_manager.has_services() { - error!("No services enabled. Cannot start application."); - std::process::exit(1); - } - - // Spawn all registered services - service_manager.spawn_all(); - - // Set up signal handling for both SIGINT (Ctrl+C) and SIGTERM - let ctrl_c = async { - signal::ctrl_c() + app.setup_bot_service() .await - .expect("Failed to install CTRL+C signal handler"); - info!("received ctrl+c, gracefully shutting down..."); - }; - - #[cfg(unix)] - let sigterm = async { - use tokio::signal::unix::{SignalKind, signal}; - let mut sigterm_stream = - signal(SignalKind::terminate()).expect("Failed to install SIGTERM signal handler"); - sigterm_stream.recv().await; - info!("received SIGTERM, gracefully shutting down..."); - }; - - #[cfg(not(unix))] - let sigterm = async { - // On non-Unix systems, create a future that never completes - // This ensures the select! macro works correctly - std::future::pending::<()>().await; - }; - - // Main application loop - wait for services or signals - let mut exit_code = 0; - - tokio::select! { - (service_name, result) = service_manager.run() => { - // A service completed unexpectedly - match result { - ServiceResult::GracefulShutdown => { - info!(service = service_name, "service completed gracefully"); - } - ServiceResult::NormalCompletion => { - warn!(service = service_name, "service completed unexpectedly"); - exit_code = 1; - } - ServiceResult::Error(e) => { - error!(service = service_name, error = ?e, "service failed"); - exit_code = 1; - } - } - - // Shutdown remaining services - match service_manager.shutdown(shutdown_timeout).await { - Ok(elapsed) => { - info!( - remaining = format!("{:.2?}", shutdown_timeout - elapsed), - "graceful shutdown complete" - ); - } - Err(pending_services) => { - warn!( - pending_count = pending_services.len(), - pending_services = ?pending_services, - "graceful shutdown elapsed - {} service(s) did not complete", - pending_services.len() - ); - - // Non-zero exit code, default to 2 if not set - exit_code = if exit_code == 0 { 2 } else { exit_code }; - } - } - } - _ = ctrl_c => { - // User requested shutdown via Ctrl+C - info!("user requested shutdown via ctrl+c"); - match service_manager.shutdown(shutdown_timeout).await { - Ok(elapsed) => { - info!( - remaining = format!("{:.2?}", shutdown_timeout - elapsed), - "graceful shutdown complete" - ); - info!("graceful shutdown complete"); - } - Err(pending_services) => { - warn!( - pending_count = pending_services.len(), - pending_services = ?pending_services, - "graceful shutdown elapsed - {} service(s) did not complete", - pending_services.len() - ); - exit_code = 2; - } - } - } - _ = sigterm => { - // System requested shutdown via SIGTERM - info!("system requested shutdown via SIGTERM"); - match service_manager.shutdown(shutdown_timeout).await { - Ok(elapsed) => { - info!( - remaining = format!("{:.2?}", shutdown_timeout - elapsed), - "graceful shutdown complete" - ); - info!("graceful shutdown complete"); - } - Err(pending_services) => { - warn!( - pending_count = pending_services.len(), - pending_services = ?pending_services, - "graceful shutdown elapsed - {} service(s) did not complete", - pending_services.len() - ); - exit_code = 2; - } - } - } + .expect("Failed to setup bot service"); } - info!(exit_code, "application shutdown complete"); - std::process::exit(exit_code); + // Start all services and run the application + app.start_services(); + app.run().await } diff --git a/src/services/bot.rs b/src/services/bot.rs index cb16524..2409e5e 100644 --- a/src/services/bot.rs +++ b/src/services/bot.rs @@ -1,7 +1,13 @@ use super::Service; +use crate::bot::{Data, get_commands}; +use crate::config::Config; +use crate::state::AppState; +use num_format::{Locale, ToFormattedString}; use serenity::Client; +use serenity::all::{ActivityData, ClientBuilder, GatewayIntents}; use std::sync::Arc; -use tracing::{error, warn}; +use std::time::Duration; +use tracing::{debug, error, warn}; /// Discord bot service implementation pub struct BotService { @@ -10,6 +16,144 @@ pub struct BotService { } impl BotService { + /// Create a new Discord bot client with full configuration + pub async fn create_client( + config: &Config, + app_state: AppState, + ) -> Result { + let intents = GatewayIntents::non_privileged(); + let bot_target_guild = config.bot_target_guild; + + let framework = poise::Framework::builder() + .options(poise::FrameworkOptions { + commands: get_commands(), + pre_command: |ctx| { + Box::pin(async move { + let content = match ctx { + poise::Context::Application(_) => ctx.invocation_string(), + poise::Context::Prefix(prefix) => prefix.msg.content.to_string(), + }; + let channel_name = ctx + .channel_id() + .name(ctx.http()) + .await + .unwrap_or("unknown".to_string()); + + let span = tracing::Span::current(); + span.record("command_name", ctx.command().qualified_name.as_str()); + span.record("invocation", ctx.invocation_string()); + span.record("msg.content", content.as_str()); + span.record("msg.author", ctx.author().tag().as_str()); + span.record("msg.id", ctx.id()); + span.record("msg.channel_id", ctx.channel_id().get()); + span.record("msg.channel", channel_name.as_str()); + + tracing::info!( + command_name = ctx.command().qualified_name.as_str(), + invocation = ctx.invocation_string(), + msg.content = %content, + msg.author = %ctx.author().tag(), + msg.author_id = %ctx.author().id, + msg.id = %ctx.id(), + msg.channel = %channel_name.as_str(), + msg.channel_id = %ctx.channel_id(), + "{} invoked by {}", + ctx.command().name, + ctx.author().tag() + ); + }) + }, + on_error: |error| { + Box::pin(async move { + if let Err(e) = poise::builtins::on_error(error).await { + tracing::error!(error = %e, "Fatal error while sending error message"); + } + }) + }, + ..Default::default() + }) + .setup(move |ctx, _ready, framework| { + let app_state = app_state.clone(); + Box::pin(async move { + poise::builtins::register_in_guild( + ctx, + &framework.options().commands, + bot_target_guild.into(), + ) + .await?; + poise::builtins::register_globally(ctx, &framework.options().commands).await?; + + // Start status update task + Self::start_status_update_task(ctx.clone(), app_state.clone()).await; + + Ok(Data { app_state }) + }) + }) + .build(); + + Ok(ClientBuilder::new(config.bot_token.clone(), intents) + .framework(framework) + .await?) + } + + /// Start the status update task for the Discord bot + async fn start_status_update_task(ctx: serenity::client::Context, app_state: AppState) { + tokio::spawn(async move { + let max_interval = Duration::from_secs(300); // 5 minutes + let base_interval = Duration::from_secs(30); + let mut interval = tokio::time::interval(base_interval); + let mut previous_course_count: Option = None; + + // This runs once immediately on startup, then with adaptive intervals + loop { + interval.tick().await; + + // Get the course count, update the activity if it has changed/hasn't been set this session + let course_count = app_state.get_course_count().await.unwrap(); + if previous_course_count.is_none() || previous_course_count != Some(course_count) { + ctx.set_activity(Some(ActivityData::playing(format!( + "Querying {:} classes", + course_count.to_formatted_string(&Locale::en) + )))); + } + + // Increase or reset the interval + interval = tokio::time::interval( + // Avoid logging the first 'change' + if course_count != previous_course_count.unwrap_or(0) { + if previous_course_count.is_some() { + debug!( + new_course_count = course_count, + last_interval = interval.period().as_secs(), + "Course count changed, resetting interval" + ); + } + + // Record the new course count + previous_course_count = Some(course_count); + + // Reset to base interval + base_interval + } else { + // Increase interval by 10% (up to maximum) + let new_interval = interval.period().mul_f32(1.1).min(max_interval); + debug!( + current_course_count = course_count, + last_interval = interval.period().as_secs(), + new_interval = new_interval.as_secs(), + "Course count unchanged, increasing interval" + ); + + new_interval + }, + ); + + // Reset the interval, otherwise it will tick again immediately + interval.reset(); + } + }); + } + pub fn new(client: Client) -> Self { let shard_manager = client.shard_manager.clone(); Self { diff --git a/src/signals.rs b/src/signals.rs new file mode 100644 index 0000000..9a719a5 --- /dev/null +++ b/src/signals.rs @@ -0,0 +1,106 @@ +use crate::services::ServiceResult; +use crate::services::manager::ServiceManager; +use std::process::ExitCode; +use std::time::Duration; +use tokio::signal; +use tracing::{error, info, warn}; + +/// Handle application shutdown signals and graceful shutdown +pub async fn handle_shutdown_signals( + mut service_manager: ServiceManager, + shutdown_timeout: Duration, +) -> ExitCode { + // Set up signal handling for both SIGINT (Ctrl+C) and SIGTERM + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler"); + info!("received ctrl+c, gracefully shutting down..."); + }; + + #[cfg(unix)] + let sigterm = async { + use tokio::signal::unix::{SignalKind, signal}; + let mut sigterm_stream = + signal(SignalKind::terminate()).expect("Failed to install SIGTERM signal handler"); + sigterm_stream.recv().await; + info!("received SIGTERM, gracefully shutting down..."); + }; + + #[cfg(not(unix))] + let sigterm = async { + // On non-Unix systems, create a future that never completes + // This ensures the select! macro works correctly + std::future::pending::<()>().await; + }; + + // Main application loop - wait for services or signals + let mut exit_code = ExitCode::SUCCESS; + + tokio::select! { + (service_name, result) = service_manager.run() => { + // A service completed unexpectedly + match result { + ServiceResult::GracefulShutdown => { + info!(service = service_name, "service completed gracefully"); + } + ServiceResult::NormalCompletion => { + warn!(service = service_name, "service completed unexpectedly"); + exit_code = ExitCode::FAILURE; + } + ServiceResult::Error(e) => { + error!(service = service_name, error = ?e, "service failed"); + exit_code = ExitCode::FAILURE; + } + } + + // Shutdown remaining services + exit_code = handle_graceful_shutdown(service_manager, shutdown_timeout, exit_code).await; + } + _ = ctrl_c => { + // User requested shutdown via Ctrl+C + info!("user requested shutdown via ctrl+c"); + exit_code = handle_graceful_shutdown(service_manager, shutdown_timeout, ExitCode::SUCCESS).await; + } + _ = sigterm => { + // System requested shutdown via SIGTERM + info!("system requested shutdown via SIGTERM"); + exit_code = handle_graceful_shutdown(service_manager, shutdown_timeout, ExitCode::SUCCESS).await; + } + } + + info!(exit_code = ?exit_code, "application shutdown complete"); + exit_code +} + +/// Handle graceful shutdown of remaining services +async fn handle_graceful_shutdown( + mut service_manager: ServiceManager, + shutdown_timeout: Duration, + current_exit_code: ExitCode, +) -> ExitCode { + match service_manager.shutdown(shutdown_timeout).await { + Ok(elapsed) => { + info!( + remaining = format!("{:.2?}", shutdown_timeout - elapsed), + "graceful shutdown complete" + ); + current_exit_code + } + Err(pending_services) => { + warn!( + pending_count = pending_services.len(), + pending_services = ?pending_services, + "graceful shutdown elapsed - {} service(s) did not complete", + pending_services.len() + ); + + // Non-zero exit code, default to FAILURE if not set + if current_exit_code == ExitCode::SUCCESS { + ExitCode::FAILURE + } else { + current_exit_code + } + } + } +}