mirror of
https://github.com/Xevion/banner.git
synced 2025-12-10 18:06:35 -06:00
feat: proper context handeling, graceful cancellation & shutdown
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,6 +1,6 @@
|
|||||||
.env
|
.env
|
||||||
cover.cov
|
cover.cov
|
||||||
./banner
|
/banner
|
||||||
.*.go
|
.*.go
|
||||||
dumps/
|
dumps/
|
||||||
js/
|
js/
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
@@ -139,14 +140,31 @@ func main() {
|
|||||||
initRedis(cfg)
|
initRedis(cfg)
|
||||||
|
|
||||||
if strings.EqualFold(os.Getenv("PPROF_ENABLE"), "true") {
|
if strings.EqualFold(os.Getenv("PPROF_ENABLE"), "true") {
|
||||||
// Start pprof server
|
// Start pprof server with graceful shutdown
|
||||||
go func() {
|
go func() {
|
||||||
port := os.Getenv("PORT")
|
port := os.Getenv("PORT")
|
||||||
log.Info().Str("port", port).Msg("Starting pprof server")
|
log.Info().Str("port", port).Msg("Starting pprof server")
|
||||||
err := http.ListenAndServe(":"+port, nil)
|
|
||||||
|
|
||||||
if err != nil {
|
server := &http.Server{
|
||||||
log.Fatal().Stack().Err(err).Msg("Cannot start pprof server")
|
Addr: ":" + port,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start server in a separate goroutine
|
||||||
|
go func() {
|
||||||
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatal().Stack().Err(err).Msg("Cannot start pprof server")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for context cancellation and then shutdown
|
||||||
|
<-cfg.Ctx.Done()
|
||||||
|
log.Info().Msg("Shutting down pprof server")
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Error().Err(err).Msg("Pprof server forced to shutdown")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -157,8 +175,11 @@ func main() {
|
|||||||
log.Err(err).Msg("Cannot create cookie jar")
|
log.Err(err).Msg("Cannot create cookie jar")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create client, setup session (acquire cookies)
|
// Create client with timeout, setup session (acquire cookies)
|
||||||
client := &http.Client{Jar: cookies}
|
client := &http.Client{
|
||||||
|
Jar: cookies,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
cfg.SetClient(client)
|
cfg.SetClient(client)
|
||||||
|
|
||||||
baseURL := os.Getenv("BANNER_BASE_URL")
|
baseURL := os.Getenv("BANNER_BASE_URL")
|
||||||
@@ -237,13 +258,20 @@ func main() {
|
|||||||
|
|
||||||
// Launch a goroutine to scrape the banner system periodically
|
// Launch a goroutine to scrape the banner system periodically
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
ticker := time.NewTicker(3 * time.Minute)
|
||||||
err := apiInstance.Scrape()
|
defer ticker.Stop()
|
||||||
if err != nil {
|
|
||||||
log.Err(err).Stack().Msg("Periodic Scrape Failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(3 * time.Minute)
|
for {
|
||||||
|
select {
|
||||||
|
case <-cfg.Ctx.Done():
|
||||||
|
log.Info().Msg("Periodic scraper stopped due to context cancellation")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
err := apiInstance.Scrape()
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Stack().Msg("Periodic Scrape Failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -260,6 +288,9 @@ func main() {
|
|||||||
closingSignal := <-stop
|
closingSignal := <-stop
|
||||||
botInstance.SetClosing() // TODO: Switch to atomic lock with forced close after 10 seconds
|
botInstance.SetClosing() // TODO: Switch to atomic lock with forced close after 10 seconds
|
||||||
|
|
||||||
|
// Cancel the context to signal all operations to stop
|
||||||
|
cfg.CancelFunc()
|
||||||
|
|
||||||
// Defers are called after this
|
// Defers are called after this
|
||||||
log.Warn().Str("signal", closingSignal.String()).Msg("Gracefully shutting down")
|
log.Warn().Str("signal", closingSignal.String()).Msg("Gracefully shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"banner/internal/models"
|
"banner/internal/models"
|
||||||
"banner/internal/utils"
|
"banner/internal/utils"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -76,7 +77,14 @@ func (a *API) DoRequest(req *http.Request) (*http.Response, error) {
|
|||||||
Str("content-type", req.Header.Get("Content-Type")).
|
Str("content-type", req.Header.Get("Content-Type")).
|
||||||
Msg("Request")
|
Msg("Request")
|
||||||
|
|
||||||
res, err := a.config.Client.Do(req)
|
// Create a timeout context for this specific request
|
||||||
|
ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Clone the request with the timeout context
|
||||||
|
reqWithTimeout := req.Clone(ctx)
|
||||||
|
|
||||||
|
res, err := a.config.Client.Do(reqWithTimeout)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed")
|
log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed")
|
||||||
@@ -614,8 +622,12 @@ func (a *API) ResetDataForm() {
|
|||||||
// GetCourse retrieves the course information.
|
// GetCourse retrieves the course information.
|
||||||
// This course does not retrieve directly from the API, but rather uses scraped data stored in Redis.
|
// This course does not retrieve directly from the API, but rather uses scraped data stored in Redis.
|
||||||
func (a *API) GetCourse(crn string) (*models.Course, error) {
|
func (a *API) GetCourse(crn string) (*models.Course, error) {
|
||||||
|
// Create a timeout context for Redis operations
|
||||||
|
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Retrieve raw data
|
// Retrieve raw data
|
||||||
result, err := a.config.KV.Get(a.config.Ctx, fmt.Sprintf("class:%s", crn)).Result()
|
result, err := a.config.KV.Get(ctx, fmt.Sprintf("class:%s", crn)).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, fmt.Errorf("course not found: %w", err)
|
return nil, fmt.Errorf("course not found: %w", err)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package api
|
|||||||
import (
|
import (
|
||||||
"banner/internal/models"
|
"banner/internal/models"
|
||||||
"banner/internal/utils"
|
"banner/internal/utils"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
"time"
|
||||||
@@ -72,8 +73,12 @@ func (a *API) GetExpiredSubjects() ([]string, error) {
|
|||||||
term := utils.Default(time.Now()).ToString()
|
term := utils.Default(time.Now()).ToString()
|
||||||
subjects := make([]string, 0)
|
subjects := make([]string, 0)
|
||||||
|
|
||||||
|
// Create a timeout context for Redis operations
|
||||||
|
ctx, cancel := context.WithTimeout(a.config.Ctx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Get all subjects
|
// Get all subjects
|
||||||
values, err := a.config.KV.MGet(a.config.Ctx, lo.Map(AllMajors, func(major string, _ int) string {
|
values, err := a.config.KV.MGet(ctx, lo.Map(AllMajors, func(major string, _ int) string {
|
||||||
return fmt.Sprintf("scraped:%s:%s", major, term)
|
return fmt.Sprintf("scraped:%s:%s", major, term)
|
||||||
})...).Result()
|
})...).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -161,7 +166,12 @@ func (a *API) ScrapeMajor(subject string) error {
|
|||||||
if totalClassCount == 0 {
|
if totalClassCount == 0 {
|
||||||
totalClassCount = -1
|
totalClassCount = -1
|
||||||
}
|
}
|
||||||
err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
|
|
||||||
|
// Create a timeout context for Redis operations
|
||||||
|
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := a.config.KV.Set(ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to mark major as scraped")
|
log.Error().Err(err).Msg("failed to mark major as scraped")
|
||||||
}
|
}
|
||||||
@@ -214,7 +224,11 @@ func (a *API) CalculateExpiry(term string, count int, priority bool) time.Durati
|
|||||||
// IntakeCourse stores a course in Redis.
|
// IntakeCourse stores a course in Redis.
|
||||||
// This function is mostly a stub for now, but will be used to handle change identification, notifications, and SQLite upserts in the future.
|
// This function is mostly a stub for now, but will be used to handle change identification, notifications, and SQLite upserts in the future.
|
||||||
func (a *API) IntakeCourse(course models.Course) error {
|
func (a *API) IntakeCourse(course models.Course) error {
|
||||||
err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
|
// Create a timeout context for Redis operations
|
||||||
|
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := a.config.KV.Set(ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to store class in Redis: %w", err)
|
return fmt.Errorf("failed to store class in Redis: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
|
CancelFunc context.CancelFunc
|
||||||
KV *redis.Client
|
KV *redis.Client
|
||||||
Client *http.Client
|
Client *http.Client
|
||||||
IsDevelopment bool
|
IsDevelopment bool
|
||||||
@@ -23,15 +24,17 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func New() (*Config, error) {
|
func New() (*Config, error) {
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
loc, err := time.LoadLocation(CentralTimezoneName)
|
loc, err := time.LoadLocation(CentralTimezoneName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cancel()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
|
CancelFunc: cancel,
|
||||||
CentralTimeLocation: loc,
|
CentralTimeLocation: loc,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func BuildRequestWithBody(cfg *config.Config, method string, path string, params
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request, _ := http.NewRequest(method, requestUrl, body)
|
request, _ := http.NewRequestWithContext(cfg.Ctx, method, requestUrl, body)
|
||||||
AddUserAgent(request)
|
AddUserAgent(request)
|
||||||
return request
|
return request
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"banner/internal/config"
|
"banner/internal/config"
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
@@ -11,8 +12,12 @@ import (
|
|||||||
|
|
||||||
// GetGuildName returns the name of the guild with the given ID, utilizing Redis to cache the value
|
// GetGuildName returns the name of the guild with the given ID, utilizing Redis to cache the value
|
||||||
func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string {
|
func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string {
|
||||||
|
// Create a timeout context for Redis operations
|
||||||
|
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Check Redis for the guild name
|
// Check Redis for the guild name
|
||||||
guildName, err := cfg.KV.Get(cfg.Ctx, "guild:"+guildID+":name").Result()
|
guildName, err := cfg.KV.Get(ctx, "guild:"+guildID+":name").Result()
|
||||||
if err != nil && err != redis.Nil {
|
if err != nil && err != redis.Nil {
|
||||||
log.Error().Stack().Err(err).Msg("Error getting guild name from Redis")
|
log.Error().Stack().Err(err).Msg("Error getting guild name from Redis")
|
||||||
return "err"
|
return "err"
|
||||||
@@ -29,7 +34,9 @@ func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string
|
|||||||
log.Error().Stack().Err(err).Msg("Error getting guild name")
|
log.Error().Stack().Err(err).Msg("Error getting guild name")
|
||||||
|
|
||||||
// Store an invalid value in Redis so we don't keep trying to get the guild name
|
// Store an invalid value in Redis so we don't keep trying to get the guild name
|
||||||
_, err := cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result()
|
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
_, err := cfg.KV.Set(ctx2, "guild:"+guildID+":name", "x", time.Minute*5).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis")
|
log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis")
|
||||||
}
|
}
|
||||||
@@ -38,15 +45,21 @@ func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cache the guild name in Redis
|
// Cache the guild name in Redis
|
||||||
cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3)
|
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||||
|
defer cancel3()
|
||||||
|
cfg.KV.Set(ctx3, "guild:"+guildID+":name", guild.Name, time.Hour*3)
|
||||||
|
|
||||||
return guild.Name
|
return guild.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChannelName returns the name of the channel with the given ID, utilizing Redis to cache the value
|
// GetChannelName returns the name of the channel with the given ID, utilizing Redis to cache the value
|
||||||
func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string {
|
func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string {
|
||||||
|
// Create a timeout context for Redis operations
|
||||||
|
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Check Redis for the channel name
|
// Check Redis for the channel name
|
||||||
channelName, err := cfg.KV.Get(cfg.Ctx, "channel:"+channelID+":name").Result()
|
channelName, err := cfg.KV.Get(ctx, "channel:"+channelID+":name").Result()
|
||||||
if err != nil && err != redis.Nil {
|
if err != nil && err != redis.Nil {
|
||||||
log.Error().Stack().Err(err).Msg("Error getting channel name from Redis")
|
log.Error().Stack().Err(err).Msg("Error getting channel name from Redis")
|
||||||
return "err"
|
return "err"
|
||||||
@@ -63,7 +76,9 @@ func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID st
|
|||||||
log.Error().Stack().Err(err).Msg("Error getting channel name")
|
log.Error().Stack().Err(err).Msg("Error getting channel name")
|
||||||
|
|
||||||
// Store an invalid value in Redis so we don't keep trying to get the channel name
|
// Store an invalid value in Redis so we don't keep trying to get the channel name
|
||||||
_, err := cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result()
|
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
_, err := cfg.KV.Set(ctx2, "channel:"+channelID+":name", "x", time.Minute*5).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis")
|
log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis")
|
||||||
}
|
}
|
||||||
@@ -72,7 +87,9 @@ func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cache the channel name in Redis
|
// Cache the channel name in Redis
|
||||||
cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3)
|
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||||
|
defer cancel3()
|
||||||
|
cfg.KV.Set(ctx3, "channel:"+channelID+":name", channel.Name, time.Hour*3)
|
||||||
|
|
||||||
return channel.Name
|
return channel.Name
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user