diff --git a/.gitignore b/.gitignore index 34487fa..24554e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ .env cover.cov -./banner +/banner .*.go dumps/ js/ diff --git a/cmd/banner/main.go b/cmd/banner/main.go index 1daa3c9..ae6bdfa 100644 --- a/cmd/banner/main.go +++ b/cmd/banner/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "net/http" "net/http/cookiejar" @@ -139,14 +140,31 @@ func main() { initRedis(cfg) if strings.EqualFold(os.Getenv("PPROF_ENABLE"), "true") { - // Start pprof server + // Start pprof server with graceful shutdown go func() { port := os.Getenv("PORT") log.Info().Str("port", port).Msg("Starting pprof server") - err := http.ListenAndServe(":"+port, nil) - if err != nil { - log.Fatal().Stack().Err(err).Msg("Cannot start pprof server") + server := &http.Server{ + Addr: ":" + port, + } + + // Start server in a separate goroutine + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatal().Stack().Err(err).Msg("Cannot start pprof server") + } + }() + + // Wait for context cancellation and then shutdown + <-cfg.Ctx.Done() + log.Info().Msg("Shutting down pprof server") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Error().Err(err).Msg("Pprof server forced to shutdown") } }() } @@ -157,8 +175,11 @@ func main() { log.Err(err).Msg("Cannot create cookie jar") } - // Create client, setup session (acquire cookies) - client := &http.Client{Jar: cookies} + // Create client with timeout, setup session (acquire cookies) + client := &http.Client{ + Jar: cookies, + Timeout: 30 * time.Second, + } cfg.SetClient(client) baseURL := os.Getenv("BANNER_BASE_URL") @@ -237,13 +258,20 @@ func main() { // Launch a goroutine to scrape the banner system periodically go func() { - for { - err := apiInstance.Scrape() - if err != nil { - log.Err(err).Stack().Msg("Periodic Scrape Failed") - } + ticker := time.NewTicker(3 * time.Minute) + defer ticker.Stop() - time.Sleep(3 * time.Minute) + for { + select { + case <-cfg.Ctx.Done(): + log.Info().Msg("Periodic scraper stopped due to context cancellation") + return + case <-ticker.C: + err := apiInstance.Scrape() + if err != nil { + log.Err(err).Stack().Msg("Periodic Scrape Failed") + } + } } }() @@ -260,6 +288,9 @@ func main() { closingSignal := <-stop botInstance.SetClosing() // TODO: Switch to atomic lock with forced close after 10 seconds + // Cancel the context to signal all operations to stop + cfg.CancelFunc() + // Defers are called after this log.Warn().Str("signal", closingSignal.String()).Msg("Gracefully shutting down") } diff --git a/internal/api/api.go b/internal/api/api.go index aec4beb..2676586 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -5,6 +5,7 @@ import ( "banner/internal/models" "banner/internal/utils" "bytes" + "context" "encoding/json" "errors" "fmt" @@ -76,7 +77,14 @@ func (a *API) DoRequest(req *http.Request) (*http.Response, error) { Str("content-type", req.Header.Get("Content-Type")). Msg("Request") - res, err := a.config.Client.Do(req) + // Create a timeout context for this specific request + ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second) + defer cancel() + + // Clone the request with the timeout context + reqWithTimeout := req.Clone(ctx) + + res, err := a.config.Client.Do(reqWithTimeout) if err != nil { log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed") @@ -614,8 +622,12 @@ func (a *API) ResetDataForm() { // GetCourse retrieves the course information. // This course does not retrieve directly from the API, but rather uses scraped data stored in Redis. func (a *API) GetCourse(crn string) (*models.Course, error) { + // Create a timeout context for Redis operations + ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second) + defer cancel() + // Retrieve raw data - result, err := a.config.KV.Get(a.config.Ctx, fmt.Sprintf("class:%s", crn)).Result() + result, err := a.config.KV.Get(ctx, fmt.Sprintf("class:%s", crn)).Result() if err != nil { if err == redis.Nil { return nil, fmt.Errorf("course not found: %w", err) diff --git a/internal/api/scrape.go b/internal/api/scrape.go index e78f8f1..be4810b 100644 --- a/internal/api/scrape.go +++ b/internal/api/scrape.go @@ -3,6 +3,7 @@ package api import ( "banner/internal/models" "banner/internal/utils" + "context" "fmt" "math/rand" "time" @@ -72,8 +73,12 @@ func (a *API) GetExpiredSubjects() ([]string, error) { term := utils.Default(time.Now()).ToString() subjects := make([]string, 0) + // Create a timeout context for Redis operations + ctx, cancel := context.WithTimeout(a.config.Ctx, 10*time.Second) + defer cancel() + // Get all subjects - values, err := a.config.KV.MGet(a.config.Ctx, lo.Map(AllMajors, func(major string, _ int) string { + values, err := a.config.KV.MGet(ctx, lo.Map(AllMajors, func(major string, _ int) string { return fmt.Sprintf("scraped:%s:%s", major, term) })...).Result() if err != nil { @@ -161,7 +166,12 @@ func (a *API) ScrapeMajor(subject string) error { if totalClassCount == 0 { totalClassCount = -1 } - err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err() + + // Create a timeout context for Redis operations + ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second) + defer cancel() + + err := a.config.KV.Set(ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err() if err != nil { log.Error().Err(err).Msg("failed to mark major as scraped") } @@ -214,7 +224,11 @@ func (a *API) CalculateExpiry(term string, count int, priority bool) time.Durati // IntakeCourse stores a course in Redis. // This function is mostly a stub for now, but will be used to handle change identification, notifications, and SQLite upserts in the future. func (a *API) IntakeCourse(course models.Course) error { - err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err() + // Create a timeout context for Redis operations + ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second) + defer cancel() + + err := a.config.KV.Set(ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err() if err != nil { return fmt.Errorf("failed to store class in Redis: %w", err) } diff --git a/internal/config/config.go b/internal/config/config.go index 545b629..084b0fa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,6 +10,7 @@ import ( type Config struct { Ctx context.Context + CancelFunc context.CancelFunc KV *redis.Client Client *http.Client IsDevelopment bool @@ -23,15 +24,17 @@ const ( ) func New() (*Config, error) { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) loc, err := time.LoadLocation(CentralTimezoneName) if err != nil { + cancel() return nil, err } return &Config{ Ctx: ctx, + CancelFunc: cancel, CentralTimeLocation: loc, }, nil } diff --git a/internal/utils/helpers.go b/internal/utils/helpers.go index aeb8916..1db077f 100644 --- a/internal/utils/helpers.go +++ b/internal/utils/helpers.go @@ -58,7 +58,7 @@ func BuildRequestWithBody(cfg *config.Config, method string, path string, params } } - request, _ := http.NewRequest(method, requestUrl, body) + request, _ := http.NewRequestWithContext(cfg.Ctx, method, requestUrl, body) AddUserAgent(request) return request } diff --git a/internal/utils/meta.go b/internal/utils/meta.go index 3e6457f..c6713f1 100644 --- a/internal/utils/meta.go +++ b/internal/utils/meta.go @@ -2,6 +2,7 @@ package utils import ( "banner/internal/config" + "context" "time" "github.com/bwmarrin/discordgo" @@ -11,8 +12,12 @@ import ( // GetGuildName returns the name of the guild with the given ID, utilizing Redis to cache the value func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string { + // Create a timeout context for Redis operations + ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second) + defer cancel() + // Check Redis for the guild name - guildName, err := cfg.KV.Get(cfg.Ctx, "guild:"+guildID+":name").Result() + guildName, err := cfg.KV.Get(ctx, "guild:"+guildID+":name").Result() if err != nil && err != redis.Nil { log.Error().Stack().Err(err).Msg("Error getting guild name from Redis") return "err" @@ -29,7 +34,9 @@ func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string log.Error().Stack().Err(err).Msg("Error getting guild name") // Store an invalid value in Redis so we don't keep trying to get the guild name - _, err := cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result() + ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second) + defer cancel2() + _, err := cfg.KV.Set(ctx2, "guild:"+guildID+":name", "x", time.Minute*5).Result() if err != nil { log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis") } @@ -38,15 +45,21 @@ func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string } // Cache the guild name in Redis - cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3) + ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second) + defer cancel3() + cfg.KV.Set(ctx3, "guild:"+guildID+":name", guild.Name, time.Hour*3) return guild.Name } // GetChannelName returns the name of the channel with the given ID, utilizing Redis to cache the value func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string { + // Create a timeout context for Redis operations + ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second) + defer cancel() + // Check Redis for the channel name - channelName, err := cfg.KV.Get(cfg.Ctx, "channel:"+channelID+":name").Result() + channelName, err := cfg.KV.Get(ctx, "channel:"+channelID+":name").Result() if err != nil && err != redis.Nil { log.Error().Stack().Err(err).Msg("Error getting channel name from Redis") return "err" @@ -63,7 +76,9 @@ func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID st log.Error().Stack().Err(err).Msg("Error getting channel name") // Store an invalid value in Redis so we don't keep trying to get the channel name - _, err := cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result() + ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second) + defer cancel2() + _, err := cfg.KV.Set(ctx2, "channel:"+channelID+":name", "x", time.Minute*5).Result() if err != nil { log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis") } @@ -72,7 +87,9 @@ func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID st } // Cache the channel name in Redis - cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3) + ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second) + defer cancel3() + cfg.KV.Set(ctx3, "channel:"+channelID+":name", channel.Name, time.Hour*3) return channel.Name }