feat: proper context handeling, graceful cancellation & shutdown

This commit is contained in:
2025-08-26 00:29:37 -05:00
parent 65fe4f101f
commit c01a112ec6
7 changed files with 103 additions and 26 deletions

View File

@@ -5,6 +5,7 @@ import (
"banner/internal/models"
"banner/internal/utils"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -76,7 +77,14 @@ func (a *API) DoRequest(req *http.Request) (*http.Response, error) {
Str("content-type", req.Header.Get("Content-Type")).
Msg("Request")
res, err := a.config.Client.Do(req)
// Create a timeout context for this specific request
ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
defer cancel()
// Clone the request with the timeout context
reqWithTimeout := req.Clone(ctx)
res, err := a.config.Client.Do(reqWithTimeout)
if err != nil {
log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed")
@@ -614,8 +622,12 @@ func (a *API) ResetDataForm() {
// GetCourse retrieves the course information.
// This course does not retrieve directly from the API, but rather uses scraped data stored in Redis.
func (a *API) GetCourse(crn string) (*models.Course, error) {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
defer cancel()
// Retrieve raw data
result, err := a.config.KV.Get(a.config.Ctx, fmt.Sprintf("class:%s", crn)).Result()
result, err := a.config.KV.Get(ctx, fmt.Sprintf("class:%s", crn)).Result()
if err != nil {
if err == redis.Nil {
return nil, fmt.Errorf("course not found: %w", err)

View File

@@ -3,6 +3,7 @@ package api
import (
"banner/internal/models"
"banner/internal/utils"
"context"
"fmt"
"math/rand"
"time"
@@ -72,8 +73,12 @@ func (a *API) GetExpiredSubjects() ([]string, error) {
term := utils.Default(time.Now()).ToString()
subjects := make([]string, 0)
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 10*time.Second)
defer cancel()
// Get all subjects
values, err := a.config.KV.MGet(a.config.Ctx, lo.Map(AllMajors, func(major string, _ int) string {
values, err := a.config.KV.MGet(ctx, lo.Map(AllMajors, func(major string, _ int) string {
return fmt.Sprintf("scraped:%s:%s", major, term)
})...).Result()
if err != nil {
@@ -161,7 +166,12 @@ func (a *API) ScrapeMajor(subject string) error {
if totalClassCount == 0 {
totalClassCount = -1
}
err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
defer cancel()
err := a.config.KV.Set(ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
if err != nil {
log.Error().Err(err).Msg("failed to mark major as scraped")
}
@@ -214,7 +224,11 @@ func (a *API) CalculateExpiry(term string, count int, priority bool) time.Durati
// IntakeCourse stores a course in Redis.
// This function is mostly a stub for now, but will be used to handle change identification, notifications, and SQLite upserts in the future.
func (a *API) IntakeCourse(course models.Course) error {
err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
defer cancel()
err := a.config.KV.Set(ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
if err != nil {
return fmt.Errorf("failed to store class in Redis: %w", err)
}

View File

@@ -10,6 +10,7 @@ import (
type Config struct {
Ctx context.Context
CancelFunc context.CancelFunc
KV *redis.Client
Client *http.Client
IsDevelopment bool
@@ -23,15 +24,17 @@ const (
)
func New() (*Config, error) {
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
loc, err := time.LoadLocation(CentralTimezoneName)
if err != nil {
cancel()
return nil, err
}
return &Config{
Ctx: ctx,
CancelFunc: cancel,
CentralTimeLocation: loc,
}, nil
}

View File

@@ -58,7 +58,7 @@ func BuildRequestWithBody(cfg *config.Config, method string, path string, params
}
}
request, _ := http.NewRequest(method, requestUrl, body)
request, _ := http.NewRequestWithContext(cfg.Ctx, method, requestUrl, body)
AddUserAgent(request)
return request
}

View File

@@ -2,6 +2,7 @@ package utils
import (
"banner/internal/config"
"context"
"time"
"github.com/bwmarrin/discordgo"
@@ -11,8 +12,12 @@ import (
// GetGuildName returns the name of the guild with the given ID, utilizing Redis to cache the value
func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel()
// Check Redis for the guild name
guildName, err := cfg.KV.Get(cfg.Ctx, "guild:"+guildID+":name").Result()
guildName, err := cfg.KV.Get(ctx, "guild:"+guildID+":name").Result()
if err != nil && err != redis.Nil {
log.Error().Stack().Err(err).Msg("Error getting guild name from Redis")
return "err"
@@ -29,7 +34,9 @@ func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string
log.Error().Stack().Err(err).Msg("Error getting guild name")
// Store an invalid value in Redis so we don't keep trying to get the guild name
_, err := cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result()
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel2()
_, err := cfg.KV.Set(ctx2, "guild:"+guildID+":name", "x", time.Minute*5).Result()
if err != nil {
log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis")
}
@@ -38,15 +45,21 @@ func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string
}
// Cache the guild name in Redis
cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3)
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel3()
cfg.KV.Set(ctx3, "guild:"+guildID+":name", guild.Name, time.Hour*3)
return guild.Name
}
// GetChannelName returns the name of the channel with the given ID, utilizing Redis to cache the value
func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel()
// Check Redis for the channel name
channelName, err := cfg.KV.Get(cfg.Ctx, "channel:"+channelID+":name").Result()
channelName, err := cfg.KV.Get(ctx, "channel:"+channelID+":name").Result()
if err != nil && err != redis.Nil {
log.Error().Stack().Err(err).Msg("Error getting channel name from Redis")
return "err"
@@ -63,7 +76,9 @@ func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID st
log.Error().Stack().Err(err).Msg("Error getting channel name")
// Store an invalid value in Redis so we don't keep trying to get the channel name
_, err := cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result()
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel2()
_, err := cfg.KV.Set(ctx2, "channel:"+channelID+":name", "x", time.Minute*5).Result()
if err != nil {
log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis")
}
@@ -72,7 +87,9 @@ func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID st
}
// Cache the channel name in Redis
cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3)
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel3()
cfg.KV.Set(ctx3, "channel:"+channelID+":name", channel.Name, time.Hour*3)
return channel.Name
}