diff --git a/src/main.rs b/src/main.rs index 7e419bb..cd1fca6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,12 +35,35 @@ const DEFAULT_TRACING_FORMAT: TracingFormat = TracingFormat::Pretty; const DEFAULT_TRACING_FORMAT: TracingFormat = TracingFormat::Json; /// Banner Discord Bot - Course availability monitoring +/// +/// This application runs multiple services that can be controlled via CLI arguments: +/// - bot: Discord bot for course monitoring commands +/// - web: HTTP server for web interface and API +/// - scraper: Background service for scraping course data +/// +/// Use --services to specify which services to run, or --disable-services to exclude specific services. #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Log formatter to use #[arg(long, value_enum, default_value_t = DEFAULT_TRACING_FORMAT)] tracing: TracingFormat, + + /// Services to run (comma-separated). Default: all services + /// + /// Examples: + /// --services bot,web # Run only bot and web services + /// --services scraper # Run only the scraper service + #[arg(long, value_delimiter = ',', conflicts_with = "disable_services")] + services: Option>, + + /// Services to disable (comma-separated) + /// + /// Examples: + /// --disable-services bot # Run web and scraper only + /// --disable-services bot,web # Run only the scraper service + #[arg(long, value_delimiter = ',', conflicts_with = "services")] + disable_services: Option>, } #[derive(clap::ValueEnum, Clone, Debug)] @@ -51,6 +74,58 @@ enum TracingFormat { Json, } +#[derive(clap::ValueEnum, Clone, Debug, PartialEq)] +enum ServiceName { + /// Discord bot for course monitoring commands + Bot, + /// HTTP server for web interface and API + Web, + /// Background service for scraping course data + Scraper, +} + +impl ServiceName { + /// Get all available services + fn all() -> Vec { + vec![ServiceName::Bot, ServiceName::Web, ServiceName::Scraper] + } + + /// Convert to string for service registration + fn as_str(&self) -> &'static str { + match self { + ServiceName::Bot => "bot", + ServiceName::Web => "web", + ServiceName::Scraper => "scraper", + } + } +} + +/// Determine which services should be enabled based on CLI arguments +fn determine_enabled_services(args: &Args) -> Result, anyhow::Error> { + match (&args.services, &args.disable_services) { + (Some(services), None) => { + // User specified which services to run + Ok(services.clone()) + } + (None, Some(disabled)) => { + // User specified which services to disable + let enabled: Vec = ServiceName::all() + .into_iter() + .filter(|s| !disabled.contains(s)) + .collect(); + Ok(enabled) + } + (None, None) => { + // Default: run all services + Ok(ServiceName::all()) + } + (Some(_), Some(_)) => { + // This should be prevented by clap's conflicts_with, but just in case + Err(anyhow::anyhow!("Cannot specify both --services and --disable-services")) + } + } +} + async fn update_bot_status(ctx: &Context, app_state: &AppState) -> Result<(), anyhow::Error> { let course_count = app_state.get_course_count().await?; @@ -70,6 +145,15 @@ async fn main() { // Parse CLI arguments let args = Args::parse(); + // Determine which services should be enabled + let enabled_services: Vec = + determine_enabled_services(&args).expect("Failed to determine enabled services"); + + info!( + enabled_services = ?enabled_services, + "services configuration loaded" + ); + // Load configuration first to get log level let config: Config = Figment::new() .merge(Env::raw().map(|k| { @@ -85,7 +169,7 @@ async fn main() { // Configure logging based on config let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { let base_level = &config.log_level; - EnvFilter::new(&format!( + EnvFilter::new(format!( "warn,banner={},banner::rate_limiter=warn,banner::session=warn,banner::rate_limit_middleware=warn", base_level )) @@ -183,7 +267,7 @@ async fn main() { span.record("msg.author", ctx.author().tag().as_str()); span.record("msg.id", ctx.id()); span.record("msg.channel_id", ctx.channel_id().get()); - span.record("msg.channel", &channel_name.as_str()); + span.record("msg.channel", channel_name.as_str()); tracing::info!( command_name = ctx.command().qualified_name.as_str(), @@ -258,14 +342,28 @@ async fn main() { // Create service manager let mut service_manager = ServiceManager::new(); - // Register services with the manager - let bot_service = Box::new(BotService::new(client)); - let web_service = Box::new(WebService::new(port, banner_state)); - let scraper_service = Box::new(ScraperService::new(db_pool.clone(), banner_api_arc.clone())); + // Register enabled services with the manager + if enabled_services.contains(&ServiceName::Bot) { + let bot_service = Box::new(BotService::new(client)); + service_manager.register_service(ServiceName::Bot.as_str(), bot_service); + } - service_manager.register_service("bot", bot_service); - service_manager.register_service("web", web_service); - service_manager.register_service("scraper", scraper_service); + if enabled_services.contains(&ServiceName::Web) { + let web_service = Box::new(WebService::new(port, banner_state)); + service_manager.register_service(ServiceName::Web.as_str(), web_service); + } + + if enabled_services.contains(&ServiceName::Scraper) { + let scraper_service = + Box::new(ScraperService::new(db_pool.clone(), banner_api_arc.clone())); + service_manager.register_service(ServiceName::Scraper.as_str(), scraper_service); + } + + // Check if any services are enabled + if !service_manager.has_services() { + error!("No services enabled. Cannot start application."); + std::process::exit(1); + } // Spawn all registered services service_manager.spawn_all(); diff --git a/src/services/manager.rs b/src/services/manager.rs index 5e47a8e..89779c4 100644 --- a/src/services/manager.rs +++ b/src/services/manager.rs @@ -34,6 +34,11 @@ impl ServiceManager { self.registered_services.insert(name.to_string(), service); } + /// Check if there are any registered services + pub fn has_services(&self) -> bool { + !self.registered_services.is_empty() + } + /// Spawn all registered services pub fn spawn_all(&mut self) { let service_count = self.registered_services.len();