diff --git a/Cargo.lock b/Cargo.lock index eb5bfbf..f459ceb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,7 +253,6 @@ dependencies = [ "sqlx", "thiserror 2.0.16", "time", - "tl", "tokio", "tokio-util", "tower-http", @@ -3368,12 +3367,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tl" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b130bd8a58c163224b44e217b4239ca7b927d82bf6cc2fea1fc561d15056e3f7" - [[package]] name = "tokio" version = "1.47.1" diff --git a/Cargo.toml b/Cargo.toml index ebed4ca..42faba7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,6 @@ thiserror = "2.0.16" time = "0.3.43" tokio = { version = "1.47.1", features = ["full"] } tokio-util = "0.7" -tl = "0.7.8" tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } url = "2.5" diff --git a/Justfile b/Justfile index 1d6443d..c13eec2 100644 --- a/Justfile +++ b/Justfile @@ -195,3 +195,6 @@ test-smoke port="18080": alias b := bun bun *ARGS: cd web && bun {{ ARGS }} + +sql *ARGS: + lazysql ${DATABASE_URL} diff --git a/src/app.rs b/src/app.rs index 5f18b3e..69f93b2 100644 --- a/src/app.rs +++ b/src/app.rs @@ -12,6 +12,7 @@ use sqlx::postgres::PgPoolOptions; use std::process::ExitCode; use std::sync::Arc; use std::time::Duration; +use anyhow::Context; use tracing::{error, info}; /// Main application struct containing all necessary components @@ -36,7 +37,7 @@ impl App { } })) .extract() - .expect("Failed to load config"); + .context("Failed to load config")?; // Check if the database URL is via private networking let is_private = config.database_url.contains("railway.internal"); @@ -52,7 +53,7 @@ impl App { .max_lifetime(Duration::from_secs(60 * 30)) .connect(&config.database_url) .await - .expect("Failed to create database pool"); + .context("Failed to create database pool")?; info!( is_private = is_private, @@ -65,7 +66,7 @@ impl App { sqlx::migrate!("./migrations") .run(&db_pool) .await - .expect("Failed to run database migrations"); + .context("Failed to run database migrations")?; info!("Database migrations completed successfully"); // Create BannerApi and AppState @@ -73,7 +74,7 @@ impl App { config.banner_base_url.clone(), config.rate_limiting.clone(), ) - .expect("Failed to create BannerApi"); + .context("Failed to create BannerApi")?; let banner_api_arc = Arc::new(banner_api); let app_state = AppState::new(banner_api_arc.clone(), db_pool.clone()); @@ -91,7 +92,7 @@ impl App { pub fn setup_services(&mut self, services: &[ServiceName]) -> Result<(), anyhow::Error> { // Register enabled services with the manager if services.contains(&ServiceName::Web) { - let web_service = Box::new(WebService::new(self.config.port)); + let web_service = Box::new(WebService::new(self.config.port, self.app_state.clone())); self.service_manager .register_service(ServiceName::Web.as_str(), web_service); } @@ -100,6 +101,7 @@ impl App { let scraper_service = Box::new(ScraperService::new( self.db_pool.clone(), self.banner_api.clone(), + self.app_state.service_statuses.clone(), )); self.service_manager .register_service(ServiceName::Scraper.as_str(), scraper_service); @@ -130,12 +132,13 @@ impl App { status_shutdown_rx, ) .await - .expect("Failed to create Discord client"); + .context("Failed to create Discord client")?; let bot_service = Box::new(BotService::new( client, status_task_handle, status_shutdown_tx, + self.app_state.service_statuses.clone(), )); self.service_manager diff --git a/src/banner/api.rs b/src/banner/api.rs index 2e30270..04ce5cf 100644 --- a/src/banner/api.rs +++ b/src/banner/api.rs @@ -21,9 +21,9 @@ pub struct BannerApi { base_url: String, } -#[allow(dead_code)] impl BannerApi { /// Creates a new Banner API client. + #[allow(dead_code)] pub fn new(base_url: String) -> Result { Self::new_with_config(base_url, RateLimitingConfig::default()) } @@ -231,30 +231,6 @@ impl BannerApi { .await } - /// Retrieves a list of instructors from the Banner API. - pub async fn get_instructors( - &self, - search: &str, - term: &str, - offset: i32, - max_results: i32, - ) -> Result> { - self.get_list_endpoint("get_instructor", search, term, offset, max_results) - .await - } - - /// Retrieves a list of campuses from the Banner API. - pub async fn get_campuses( - &self, - search: &str, - term: &str, - offset: i32, - max_results: i32, - ) -> Result> { - self.get_list_endpoint("get_campus", search, term, offset, max_results) - .await - } - /// Retrieves meeting time information for a course. pub async fn get_course_meeting_time( &self, diff --git a/src/banner/models/common.rs b/src/banner/models/common.rs index a3f3a69..029dc14 100644 --- a/src/banner/models/common.rs +++ b/src/banner/models/common.rs @@ -11,6 +11,7 @@ pub struct Pair { pub type BannerTerm = Pair; /// Represents an instructor in the Banner system +#[allow(dead_code)] pub type Instructor = Pair; impl BannerTerm { diff --git a/src/banner/models/meetings.rs b/src/banner/models/meetings.rs index 1a06998..677bf25 100644 --- a/src/banner/models/meetings.rs +++ b/src/banner/models/meetings.rs @@ -1,8 +1,8 @@ -use bitflags::{Flags, bitflags}; +use bitflags::{bitflags, Flags}; use chrono::{DateTime, NaiveDate, NaiveTime, Timelike, Utc, Weekday}; use extension_traits::extension; use serde::{Deserialize, Deserializer, Serialize}; -use std::{cmp::Ordering, fmt::Display, str::FromStr}; +use std::{cmp::Ordering, str::FromStr}; use super::terms::Term; @@ -394,26 +394,6 @@ impl MeetingLocation { } } -impl Display for MeetingLocation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MeetingLocation::Online => write!(f, "Online"), - MeetingLocation::InPerson { - campus, - building, - building_description, - room, - .. - } => write!( - f, - "{campus} | {building_name} | {building_code} {room}", - building_name = building_description, - building_code = building, - ), - } - } -} - /// Clean, parsed meeting schedule information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MeetingScheduleInfo { diff --git a/src/banner/query.rs b/src/banner/query.rs index 8900cb0..ea710bd 100644 --- a/src/banner/query.rs +++ b/src/banner/query.rs @@ -32,7 +32,6 @@ pub struct SearchQuery { course_number_range: Option, } -#[allow(dead_code)] impl SearchQuery { /// Creates a new SearchQuery with default values pub fn new() -> Self { @@ -68,6 +67,7 @@ impl SearchQuery { } /// Adds a keyword to the query + #[allow(dead_code)] pub fn keyword>(mut self, keyword: S) -> Self { match &mut self.keywords { Some(keywords) => keywords.push(keyword.into()), @@ -77,54 +77,63 @@ impl SearchQuery { } /// Sets whether to search for open courses only + #[allow(dead_code)] pub fn open_only(mut self, open_only: bool) -> Self { self.open_only = Some(open_only); self } /// Sets the term part for the query + #[allow(dead_code)] pub fn term_part(mut self, term_part: Vec) -> Self { self.term_part = Some(term_part); self } /// Sets the campuses for the query + #[allow(dead_code)] pub fn campus(mut self, campus: Vec) -> Self { self.campus = Some(campus); self } /// Sets the instructional methods for the query + #[allow(dead_code)] pub fn instructional_method(mut self, instructional_method: Vec) -> Self { self.instructional_method = Some(instructional_method); self } /// Sets the attributes for the query + #[allow(dead_code)] pub fn attributes(mut self, attributes: Vec) -> Self { self.attributes = Some(attributes); self } /// Sets the instructors for the query + #[allow(dead_code)] pub fn instructor(mut self, instructor: Vec) -> Self { self.instructor = Some(instructor); self } /// Sets the start time for the query + #[allow(dead_code)] pub fn start_time(mut self, start_time: Duration) -> Self { self.start_time = Some(start_time); self } /// Sets the end time for the query + #[allow(dead_code)] pub fn end_time(mut self, end_time: Duration) -> Self { self.end_time = Some(end_time); self } /// Sets the credit range for the query + #[allow(dead_code)] pub fn credits(mut self, low: i32, high: i32) -> Self { self.min_credits = Some(low); self.max_credits = Some(high); @@ -132,12 +141,14 @@ impl SearchQuery { } /// Sets the minimum credits for the query + #[allow(dead_code)] pub fn min_credits(mut self, value: i32) -> Self { self.min_credits = Some(value); self } /// Sets the maximum credits for the query + #[allow(dead_code)] pub fn max_credits(mut self, value: i32) -> Self { self.max_credits = Some(value); self @@ -150,6 +161,7 @@ impl SearchQuery { } /// Sets the offset for pagination + #[allow(dead_code)] pub fn offset(mut self, offset: i32) -> Self { self.offset = offset; self @@ -253,27 +265,25 @@ impl SearchQuery { } } -/// Formats a Duration into hour, minute, and meridiem strings for Banner API +/// Formats a Duration into hour, minute, and meridiem strings for Banner API. +/// +/// Uses 12-hour format: midnight = 12:00 AM, noon = 12:00 PM. fn format_time_parameter(duration: Duration) -> (String, String, String) { let total_minutes = duration.as_secs() / 60; let hours = total_minutes / 60; let minutes = total_minutes % 60; - let minute_str = minutes.to_string(); + let meridiem = if hours >= 12 { "PM" } else { "AM" }; + let hour_12 = match hours % 12 { + 0 => 12, + h => h, + }; - if hours >= 12 { - let meridiem = "PM".to_string(); - let hour_str = if hours >= 13 { - (hours - 12).to_string() - } else { - hours.to_string() - }; - (hour_str, minute_str, meridiem) - } else { - let meridiem = "AM".to_string(); - let hour_str = hours.to_string(); - (hour_str, minute_str, meridiem) - } + ( + hour_12.to_string(), + minutes.to_string(), + meridiem.to_string(), + ) } #[cfg(test)] @@ -394,7 +404,7 @@ mod tests { #[test] fn test_format_time_midnight() { let (h, m, mer) = format_time_parameter(Duration::from_secs(0)); - assert_eq!(h, "0"); + assert_eq!(h, "12"); assert_eq!(m, "0"); assert_eq!(mer, "AM"); } diff --git a/src/bot/commands/ics.rs b/src/bot/commands/ics.rs index 78e0727..ddf2b4f 100644 --- a/src/bot/commands/ics.rs +++ b/src/bot/commands/ics.rs @@ -2,116 +2,78 @@ use crate::banner::{Course, MeetingDays, MeetingScheduleInfo, WeekdayExt}; use crate::bot::{Context, Error, utils}; -use chrono::{Datelike, NaiveDate, Utc}; +use chrono::{Datelike, Duration, NaiveDate, Utc, Weekday}; use serenity::all::CreateAttachment; use tracing::info; -/// Represents a holiday or special day that should be excluded from class schedules -#[derive(Debug, Clone)] -enum Holiday { - /// A single-day holiday - Single { month: u32, day: u32 }, - /// A multi-day holiday range - Range { - month: u32, - start_day: u32, - end_day: u32, - }, +/// Find the nth occurrence of a weekday in a given month/year (1-based). +fn nth_weekday_of_month(year: i32, month: u32, weekday: Weekday, n: u32) -> Option { + let first = NaiveDate::from_ymd_opt(year, month, 1)?; + let days_ahead = (weekday.num_days_from_monday() as i64 + - first.weekday().num_days_from_monday() as i64) + .rem_euclid(7) as u32; + let day = 1 + days_ahead + 7 * (n - 1); + NaiveDate::from_ymd_opt(year, month, day) } -impl Holiday { - /// Check if a specific date falls within this holiday - fn contains_date(&self, date: NaiveDate) -> bool { - match self { - Holiday::Single { month, day, .. } => date.month() == *month && date.day() == *day, - Holiday::Range { - month, - start_day, - end_day, - .. - } => date.month() == *month && date.day() >= *start_day && date.day() <= *end_day, - } - } - - /// Get all dates in this holiday for a given year - fn get_dates_for_year(&self, year: i32) -> Vec { - match self { - Holiday::Single { month, day, .. } => { - if let Some(date) = NaiveDate::from_ymd_opt(year, *month, *day) { - vec![date] - } else { - Vec::new() - } - } - Holiday::Range { - month, - start_day, - end_day, - .. - } => { - let mut dates = Vec::new(); - for day in *start_day..=*end_day { - if let Some(date) = NaiveDate::from_ymd_opt(year, *month, day) { - dates.push(date); - } - } - dates - } - } - } +/// Compute a consecutive range of dates starting from `start` for `count` days. +fn date_range(start: NaiveDate, count: i64) -> Vec { + (0..count).filter_map(|i| start.checked_add_signed(Duration::days(i))).collect() } -/// University holidays excluded from class schedules. +/// Compute university holidays for a given year. /// -/// WARNING: These dates are specific to the UTSA 2024-2025 academic calendar and must be -/// updated each academic year. Many of these holidays fall on different dates annually -/// (e.g., Labor Day is the first Monday of September, Thanksgiving is the fourth Thursday -/// of November). Ideally these would be loaded from a configuration file or computed -/// dynamically from federal/university calendar rules. -// TODO: Load holiday dates from configuration or compute dynamically per academic year. -const UNIVERSITY_HOLIDAYS: &[(&str, Holiday)] = &[ - ("Labor Day", Holiday::Single { month: 9, day: 1 }), - ( - "Fall Break", - Holiday::Range { - month: 10, - start_day: 13, - end_day: 14, - }, - ), - ( - "Unspecified Holiday", - Holiday::Single { month: 11, day: 26 }, - ), - ( - "Thanksgiving", - Holiday::Range { - month: 11, - start_day: 28, - end_day: 29, - }, - ), - ("Student Study Day", Holiday::Single { month: 12, day: 5 }), - ( - "Winter Holiday", - Holiday::Range { - month: 12, - start_day: 23, - end_day: 31, - }, - ), - ("New Year's Day", Holiday::Single { month: 1, day: 1 }), - ("MLK Day", Holiday::Single { month: 1, day: 20 }), - ( - "Spring Break", - Holiday::Range { - month: 3, - start_day: 10, - end_day: 15, - }, - ), - ("Student Study Day", Holiday::Single { month: 5, day: 9 }), -]; +/// Federal holidays use weekday-of-month rules so they're correct for any year. +/// University-specific breaks (Fall Break, Spring Break, Winter Holiday) are derived +/// from anchoring federal holidays or using UTSA's typical scheduling patterns. +fn compute_holidays_for_year(year: i32) -> Vec<(&'static str, Vec)> { + let mut holidays = Vec::new(); + + // Labor Day: 1st Monday of September + if let Some(d) = nth_weekday_of_month(year, 9, Weekday::Mon, 1) { + holidays.push(("Labor Day", vec![d])); + } + + // Fall Break: Mon-Tue of Columbus Day week (2nd Monday of October + Tuesday) + if let Some(mon) = nth_weekday_of_month(year, 10, Weekday::Mon, 2) { + holidays.push(("Fall Break", date_range(mon, 2))); + } + + // Day before Thanksgiving: Wednesday before 4th Thursday of November + if let Some(thu) = nth_weekday_of_month(year, 11, Weekday::Thu, 4) + && let Some(wed) = thu.checked_sub_signed(Duration::days(1)) + { + holidays.push(("Day Before Thanksgiving", vec![wed])); + } + + // Thanksgiving: 4th Thursday of November + Friday + if let Some(thu) = nth_weekday_of_month(year, 11, Weekday::Thu, 4) { + holidays.push(("Thanksgiving", date_range(thu, 2))); + } + + // Winter Holiday: Dec 23-31 + if let Some(start) = NaiveDate::from_ymd_opt(year, 12, 23) { + holidays.push(("Winter Holiday", date_range(start, 9))); + } + + // New Year's Day: January 1 + if let Some(d) = NaiveDate::from_ymd_opt(year, 1, 1) { + holidays.push(("New Year's Day", vec![d])); + } + + // MLK Day: 3rd Monday of January + if let Some(d) = nth_weekday_of_month(year, 1, Weekday::Mon, 3) { + holidays.push(("MLK Day", vec![d])); + } + + // Spring Break: full week (Mon-Sat) starting the 2nd or 3rd Monday of March + // UTSA typically uses the 2nd full week of March + if let Some(mon) = nth_weekday_of_month(year, 3, Weekday::Mon, 2) { + holidays.push(("Spring Break", date_range(mon, 6))); + } + + holidays +} /// Generate an ICS file for a course #[poise::command(slash_command, prefix_command)] @@ -329,10 +291,16 @@ fn generate_event_content( } // Collect holiday names for reporting + let start_year = meeting_time.date_range.start.year(); + let end_year = meeting_time.date_range.end.year(); + let all_holidays: Vec<_> = (start_year..=end_year) + .flat_map(compute_holidays_for_year) + .collect(); + let mut holiday_names = Vec::new(); - for (holiday_name, holiday) in UNIVERSITY_HOLIDAYS { + for (holiday_name, holiday_dates) in &all_holidays { for &exception_date in &holiday_exceptions { - if holiday.contains_date(exception_date) { + if holiday_dates.contains(&exception_date) { holiday_names.push(format!( "{} ({})", holiday_name, @@ -344,6 +312,7 @@ fn generate_event_content( holiday_names.sort(); holiday_names.dedup(); + event_content.push_str("END:VEVENT\r\n"); return Ok((event_content, holiday_names)); } } @@ -362,32 +331,18 @@ fn class_meets_on_date(meeting_time: &MeetingScheduleInfo, date: NaiveDate) -> b /// Get holiday dates that fall within the course date range and would conflict with class meetings fn get_holiday_exceptions(meeting_time: &MeetingScheduleInfo) -> Vec { - let mut exceptions = Vec::new(); - - // Get the year range from the course date range let start_year = meeting_time.date_range.start.year(); let end_year = meeting_time.date_range.end.year(); - for (_, holiday) in UNIVERSITY_HOLIDAYS { - // Check for the holiday in each year of the course - for year in start_year..=end_year { - let holiday_dates = holiday.get_dates_for_year(year); - - for holiday_date in holiday_dates { - // Check if the holiday falls within the course date range - if holiday_date >= meeting_time.date_range.start - && holiday_date <= meeting_time.date_range.end - { - // Check if the class would actually meet on this day - if class_meets_on_date(meeting_time, holiday_date) { - exceptions.push(holiday_date); - } - } - } - } - } - - exceptions + (start_year..=end_year) + .flat_map(compute_holidays_for_year) + .flat_map(|(_, dates)| dates) + .filter(|&date| { + date >= meeting_time.date_range.start + && date <= meeting_time.date_range.end + && class_meets_on_date(meeting_time, date) + }) + .collect() } /// Generate EXDATE property for holiday exceptions diff --git a/src/bot/commands/search.rs b/src/bot/commands/search.rs index d536e3c..44c759c 100644 --- a/src/bot/commands/search.rs +++ b/src/bot/commands/search.rs @@ -24,8 +24,8 @@ pub async fn search( // Defer the response since this might take a while ctx.defer().await?; - // Build the search query - let mut query = SearchQuery::new().credits(3, 6); + // Build the search query — no default credit filter so all courses are visible + let mut query = SearchQuery::new(); if let Some(title) = title { query = query.title(title); diff --git a/src/data/scrape_jobs.rs b/src/data/scrape_jobs.rs index c0e0975..5f3a175 100644 --- a/src/data/scrape_jobs.rs +++ b/src/data/scrape_jobs.rs @@ -90,7 +90,7 @@ pub async fn unlock_and_increment_retry( "UPDATE scrape_jobs SET locked_at = NULL, retry_count = retry_count + 1 WHERE id = $1 - RETURNING CASE WHEN retry_count + 1 < $2 THEN retry_count + 1 ELSE NULL END", + RETURNING CASE WHEN retry_count < $2 THEN retry_count ELSE NULL END", ) .bind(job_id) .bind(max_retries) diff --git a/src/lib.rs b/src/lib.rs index 241783d..840f68d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,5 +11,6 @@ pub mod scraper; pub mod services; pub mod signals; pub mod state; +pub mod status; pub mod utils; pub mod web; diff --git a/src/main.rs b/src/main.rs index 39885a6..ab176e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,8 @@ mod scraper; mod services; mod signals; mod state; +#[allow(dead_code)] +mod status; mod web; #[tokio::main] @@ -31,17 +33,17 @@ async fn main() -> ExitCode { let enabled_services: Vec = determine_enabled_services(&args).expect("Failed to determine enabled services"); + // Create and initialize the application + let mut app = App::new().await.expect("Failed to initialize application"); + + // Setup logging — must happen before any info!() calls to avoid silently dropped logs + setup_logging(app.config(), args.tracing); + info!( enabled_services = ?enabled_services, "services configuration loaded" ); - // Create and initialize the application - let mut app = App::new().await.expect("Failed to initialize application"); - - // Setup logging - setup_logging(app.config(), args.tracing); - // Log application startup context info!( version = env!("CARGO_PKG_VERSION"), diff --git a/src/scraper/mod.rs b/src/scraper/mod.rs index a9e1bd5..5926e27 100644 --- a/src/scraper/mod.rs +++ b/src/scraper/mod.rs @@ -4,6 +4,7 @@ pub mod worker; use crate::banner::BannerApi; use crate::services::Service; +use crate::status::{ServiceStatus, ServiceStatusRegistry}; use sqlx::PgPool; use std::sync::Arc; use tokio::sync::broadcast; @@ -20,6 +21,7 @@ use self::worker::Worker; pub struct ScraperService { db_pool: PgPool, banner_api: Arc, + service_statuses: ServiceStatusRegistry, scheduler_handle: Option>, worker_handles: Vec>, shutdown_tx: Option>, @@ -27,10 +29,11 @@ pub struct ScraperService { impl ScraperService { /// Creates a new `ScraperService`. - pub fn new(db_pool: PgPool, banner_api: Arc) -> Self { + pub fn new(db_pool: PgPool, banner_api: Arc, service_statuses: ServiceStatusRegistry) -> Self { Self { db_pool, banner_api, + service_statuses, scheduler_handle: None, worker_handles: Vec::new(), shutdown_tx: None, @@ -66,6 +69,7 @@ impl ScraperService { worker_count = self.worker_handles.len(), "Spawned worker tasks" ); + self.service_statuses.set("scraper", ServiceStatus::Active); } } @@ -82,6 +86,7 @@ impl Service for ScraperService { } async fn shutdown(&mut self) -> Result<(), anyhow::Error> { + self.service_statuses.set("scraper", ServiceStatus::Disabled); info!("Shutting down scraper service"); // Send shutdown signal to all tasks diff --git a/src/services/bot.rs b/src/services/bot.rs index 4f950c1..79377b8 100644 --- a/src/services/bot.rs +++ b/src/services/bot.rs @@ -2,6 +2,7 @@ use super::Service; use crate::bot::{Data, get_commands}; use crate::config::Config; use crate::state::AppState; +use crate::status::{ServiceStatus, ServiceStatusRegistry}; use num_format::{Locale, ToFormattedString}; use serenity::Client; use serenity::all::{ActivityData, ClientBuilder, GatewayIntents}; @@ -17,6 +18,7 @@ pub struct BotService { shard_manager: Arc, status_task_handle: Arc>>>, status_shutdown_tx: Option>, + service_statuses: ServiceStatusRegistry, } impl BotService { @@ -98,6 +100,8 @@ impl BotService { ); *status_task_handle.lock().await = Some(handle); + app_state.service_statuses.set("bot", ServiceStatus::Active); + Ok(Data { app_state }) }) }) @@ -186,6 +190,7 @@ impl BotService { client: Client, status_task_handle: Arc>>>, status_shutdown_tx: broadcast::Sender<()>, + service_statuses: ServiceStatusRegistry, ) -> Self { let shard_manager = client.shard_manager.clone(); @@ -194,6 +199,7 @@ impl BotService { shard_manager, status_task_handle, status_shutdown_tx: Some(status_shutdown_tx), + service_statuses, } } } @@ -218,6 +224,7 @@ impl Service for BotService { } async fn shutdown(&mut self) -> Result<(), anyhow::Error> { + self.service_statuses.set("bot", ServiceStatus::Disabled); // Signal status update task to stop if let Some(status_shutdown_tx) = self.status_shutdown_tx.take() { let _ = status_shutdown_tx.send(()); diff --git a/src/services/web.rs b/src/services/web.rs index 1db2111..e53cc8c 100644 --- a/src/services/web.rs +++ b/src/services/web.rs @@ -1,4 +1,6 @@ use super::Service; +use crate::state::AppState; +use crate::status::ServiceStatus; use crate::web::create_router; use std::net::SocketAddr; use tokio::net::TcpListener; @@ -8,16 +10,47 @@ use tracing::{info, trace, warn}; /// Web server service implementation pub struct WebService { port: u16, + app_state: AppState, shutdown_tx: Option>, } impl WebService { - pub fn new(port: u16) -> Self { + pub fn new(port: u16, app_state: AppState) -> Self { Self { port, + app_state, shutdown_tx: None, } } + /// Periodically pings the database and updates the "database" service status. + async fn db_health_check_loop( + state: AppState, + mut shutdown_rx: broadcast::Receiver<()>, + ) { + use std::time::Duration; + let mut interval = tokio::time::interval(Duration::from_secs(30)); + + loop { + tokio::select! { + _ = interval.tick() => { + let status = match sqlx::query_scalar::<_, i32>("SELECT 1") + .fetch_one(&state.db_pool) + .await + { + Ok(_) => ServiceStatus::Connected, + Err(e) => { + warn!(error = %e, "DB health check failed"); + ServiceStatus::Error + } + }; + state.service_statuses.set("database", status); + } + _ = shutdown_rx.recv() => { + break; + } + } + } + } } #[async_trait::async_trait] @@ -28,11 +61,12 @@ impl Service for WebService { async fn run(&mut self) -> Result<(), anyhow::Error> { // Create the main router with Banner API routes - let app = create_router(); + let app = create_router(self.app_state.clone()); let addr = SocketAddr::from(([0, 0, 0, 0], self.port)); let listener = TcpListener::bind(addr).await?; + self.app_state.service_statuses.set("web", ServiceStatus::Active); info!( service = "web", address = %addr, @@ -42,7 +76,14 @@ impl Service for WebService { // Create internal shutdown channel for axum graceful shutdown let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1); - self.shutdown_tx = Some(shutdown_tx); + self.shutdown_tx = Some(shutdown_tx.clone()); + + // Spawn background DB health check + let health_state = self.app_state.clone(); + let health_shutdown_rx = shutdown_tx.subscribe(); + tokio::spawn(async move { + Self::db_health_check_loop(health_state, health_shutdown_rx).await; + }); // Use axum's graceful shutdown with the internal shutdown signal axum::serve(listener, app) diff --git a/src/state.rs b/src/state.rs index f0187d0..a5be5e1 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2,6 +2,7 @@ use crate::banner::BannerApi; use crate::banner::Course; +use crate::status::ServiceStatusRegistry; use anyhow::Result; use sqlx::PgPool; use std::sync::Arc; @@ -10,6 +11,7 @@ use std::sync::Arc; pub struct AppState { pub banner_api: Arc, pub db_pool: PgPool, + pub service_statuses: ServiceStatusRegistry, } impl AppState { @@ -17,6 +19,7 @@ impl AppState { Self { banner_api, db_pool, + service_statuses: ServiceStatusRegistry::new(), } } diff --git a/src/status.rs b/src/status.rs new file mode 100644 index 0000000..a6b191a --- /dev/null +++ b/src/status.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; +use std::time::Instant; + +use dashmap::DashMap; +use serde::Serialize; + +/// Health status of a service. +#[derive(Debug, Clone, Serialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ServiceStatus { + Starting, + Active, + Connected, + Disabled, + Error, +} + +/// A timestamped status entry for a service. +#[derive(Debug, Clone)] +pub struct StatusEntry { + pub status: ServiceStatus, + pub updated_at: Instant, +} + +/// Thread-safe registry for services to self-report their health status. +#[derive(Debug, Clone, Default)] +pub struct ServiceStatusRegistry { + inner: Arc>, +} + +impl ServiceStatusRegistry { + /// Creates a new empty registry. + pub fn new() -> Self { + Self::default() + } + + /// Inserts or updates the status for a named service. + pub fn set(&self, name: &str, status: ServiceStatus) { + self.inner.insert( + name.to_owned(), + StatusEntry { + status, + updated_at: Instant::now(), + }, + ); + } + + /// Returns the current status of a named service, if present. + pub fn get(&self, name: &str) -> Option { + self.inner.get(name).map(|entry| entry.status.clone()) + } + + /// Returns a snapshot of all service statuses. + pub fn all(&self) -> Vec<(String, ServiceStatus)> { + self.inner + .iter() + .map(|entry| (entry.key().clone(), entry.value().status.clone())) + .collect() + } +} diff --git a/src/web/routes.rs b/src/web/routes.rs index a48ad1c..b58c067 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -3,7 +3,7 @@ use axum::{ Router, body::Body, - extract::Request, + extract::{Request, State}, response::{Json, Response}, routing::get, }; @@ -17,6 +17,9 @@ use http::header; use serde::Serialize; use serde_json::{Value, json}; use std::{collections::BTreeMap, time::Duration}; + +use crate::state::AppState; +use crate::status::ServiceStatus; #[cfg(not(feature = "embed-assets"))] use tower_http::cors::{Any, CorsLayer}; use tower_http::{classify::ServerErrorsFailureClass, timeout::TimeoutLayer, trace::TraceLayer}; @@ -63,11 +66,12 @@ fn set_caching_headers(response: &mut Response, path: &str, etag: &str) { } /// Creates the web server router -pub fn create_router() -> Router { +pub fn create_router(app_state: AppState) -> Router { let api_router = Router::new() .route("/health", get(health)) .route("/status", get(status)) - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + .with_state(app_state); let mut router = Router::new().nest("/api", api_router); @@ -155,7 +159,7 @@ async fn handle_spa_fallback_with_headers(uri: Uri, request_headers: HeaderMap) // Check if client has a matching ETag (conditional request) if let Some(etag) = request_headers.get(header::IF_NONE_MATCH) - && metadata.etag_matches(etag.to_str().unwrap()) + && etag.to_str().is_ok_and(|s| metadata.etag_matches(s)) { return StatusCode::NOT_MODIFIED.into_response(); } @@ -191,7 +195,7 @@ async fn handle_spa_fallback_with_headers(uri: Uri, request_headers: HeaderMap) // Check if client has a matching ETag for index.html if let Some(etag) = request_headers.get(header::IF_NONE_MATCH) - && metadata.etag_matches(etag.to_str().unwrap()) + && etag.to_str().is_ok_and(|s| metadata.etag_matches(s)) { return StatusCode::NOT_MODIFIED.into_response(); } @@ -217,70 +221,46 @@ async fn health() -> Json { })) } -#[derive(Serialize)] -enum Status { - Disabled, - Connected, - Active, - Healthy, - Error, -} - #[derive(Serialize)] struct ServiceInfo { name: String, - status: Status, + status: ServiceStatus, } #[derive(Serialize)] struct StatusResponse { - status: Status, + status: ServiceStatus, version: String, commit: String, services: BTreeMap, } /// Status endpoint showing bot and system status -async fn status() -> Json { +async fn status(State(state): State) -> Json { let mut services = BTreeMap::new(); - // Bot service status - hardcoded as disabled for now - services.insert( - "bot".to_string(), - ServiceInfo { - name: "Bot".to_string(), - status: Status::Disabled, - }, - ); + for (name, svc_status) in state.service_statuses.all() { + services.insert( + name.clone(), + ServiceInfo { + name, + status: svc_status, + }, + ); + } - // Banner API status - always connected for now - services.insert( - "banner".to_string(), - ServiceInfo { - name: "Banner".to_string(), - status: Status::Connected, - }, - ); - - // Discord status - hardcoded as disabled for now - services.insert( - "discord".to_string(), - ServiceInfo { - name: "Discord".to_string(), - status: Status::Disabled, - }, - ); - - let overall_status = if services.values().any(|s| matches!(s.status, Status::Error)) { - Status::Error - } else if services - .values() - .all(|s| matches!(s.status, Status::Active | Status::Connected)) + let overall_status = if services.values().any(|s| matches!(s.status, ServiceStatus::Error)) { + ServiceStatus::Error + } else if !services.is_empty() + && services + .values() + .all(|s| matches!(s.status, ServiceStatus::Active | ServiceStatus::Connected)) { - Status::Active + ServiceStatus::Active + } else if services.is_empty() { + ServiceStatus::Disabled } else { - // If we have any Disabled services but no errors, show as Healthy - Status::Healthy + ServiceStatus::Active }; Json(StatusResponse { diff --git a/web/src/lib/api.test.ts b/web/src/lib/api.test.ts index 3b4f437..ec098cc 100644 --- a/web/src/lib/api.test.ts +++ b/web/src/lib/api.test.ts @@ -31,11 +31,13 @@ describe("BannerApiClient", () => { it("should fetch status data", async () => { const mockStatus = { - status: "operational", - bot: { status: "running", uptime: "1h" }, - cache: { status: "connected", courses: "100", subjects: "50" }, - banner_api: { status: "connected" }, - timestamp: "2024-01-01T00:00:00Z", + status: "active", + version: "0.3.4", + commit: "abc1234", + services: { + web: { name: "web", status: "active" }, + database: { name: "database", status: "connected" }, + }, }; vi.mocked(fetch).mockResolvedValueOnce({ diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 5995d56..b67f3dc 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -6,7 +6,7 @@ export interface HealthResponse { timestamp: string; } -export type Status = "Disabled" | "Connected" | "Active" | "Healthy" | "Error"; +export type Status = "starting" | "active" | "connected" | "disabled" | "error"; export interface ServiceInfo { name: string; diff --git a/web/src/routes/index.tsx b/web/src/routes/index.tsx index bf5dffd..ed0371f 100644 --- a/web/src/routes/index.tsx +++ b/web/src/routes/index.tsx @@ -37,6 +37,9 @@ const SERVICE_ICONS: Record = { bot: Bot, banner: Globe, discord: MessageCircle, + database: Activity, + web: Globe, + scraper: Clock, }; interface ResponseTiming { @@ -80,11 +83,11 @@ const formatNumber = (num: number): string => { const getStatusIcon = (status: Status | "Unreachable"): StatusIcon => { const statusMap: Record = { - Active: { icon: CheckCircle, color: "green" }, - Connected: { icon: CheckCircle, color: "green" }, - Healthy: { icon: CheckCircle, color: "green" }, - Disabled: { icon: Circle, color: "gray" }, - Error: { icon: XCircle, color: "red" }, + active: { icon: CheckCircle, color: "green" }, + connected: { icon: CheckCircle, color: "green" }, + starting: { icon: Hourglass, color: "orange" }, + disabled: { icon: Circle, color: "gray" }, + error: { icon: XCircle, color: "red" }, Unreachable: { icon: WifiOff, color: "red" }, }; @@ -93,9 +96,9 @@ const getStatusIcon = (status: Status | "Unreachable"): StatusIcon => { const getOverallHealth = (state: StatusState): Status | "Unreachable" => { if (state.mode === "timeout") return "Unreachable"; - if (state.mode === "error") return "Error"; + if (state.mode === "error") return "error"; if (state.mode === "response") return state.status.status; - return "Error"; + return "error"; }; const getServices = (state: StatusState): Service[] => { @@ -116,8 +119,8 @@ const StatusDisplay = ({ status }: { status: Status | "Unreachable" }) => { {status}