From a37fbeb22426d426329c409e40846ec63c47b992 Mon Sep 17 00:00:00 2001 From: Xevion Date: Tue, 26 Aug 2025 00:19:43 -0500 Subject: [PATCH] fix: proper configuration handling across submodules --- cmd/banner/main.go | 184 ++++++++++++-------------------------- internal/api/api.go | 128 +++++++++++++------------- internal/api/scrape.go | 77 ++++++++-------- internal/api/session.go | 13 ++- internal/bot/bot.go | 40 +++++++++ internal/bot/commands.go | 69 ++++++++------ internal/bot/handlers.go | 90 +++++++++++++++++++ internal/config/config.go | 50 +++++------ internal/models/types.go | 19 ++-- internal/utils/helpers.go | 32 +++++-- internal/utils/meta.go | 16 ++-- internal/utils/term.go | 17 ++-- 12 files changed, 408 insertions(+), 327 deletions(-) create mode 100644 internal/bot/bot.go create mode 100644 internal/bot/handlers.go diff --git a/cmd/banner/main.go b/cmd/banner/main.go index 6c21942..7236f62 100644 --- a/cmd/banner/main.go +++ b/cmd/banner/main.go @@ -1,9 +1,7 @@ package main import ( - "context" "flag" - "fmt" "net/http" "net/http/cookiejar" _ "net/http/pprof" @@ -25,21 +23,13 @@ import ( "banner/internal/api" "banner/internal/bot" + "banner/internal/config" "banner/internal/utils" ) var ( - ctx context.Context - kv *redis.Client - Session *discordgo.Session - client http.Client - cookies http.CookieJar - isDevelopment bool - baseURL string // Base URL for all requests to the banner system - environment string - p *message.Printer = message.NewPrinter(message.MatchLanguage("en")) - CentralTimeLocation *time.Location - isClosing bool = false + Session *discordgo.Session + p *message.Printer = message.NewPrinter(message.MatchLanguage("en")) ) const ( @@ -54,43 +44,36 @@ func init() { log.Debug().Err(err).Msg("Error loading .env file") } - ctx = context.Background() - - var err error - CentralTimeLocation, err = time.LoadLocation(CentralTimezoneName) - if err != nil { - panic(err) - } - // Set zerolog's timestamp function to use the central timezone zerolog.TimestampFunc = func() time.Time { - return time.Now().In(CentralTimeLocation) + // TODO: Move this to config + loc, err := time.LoadLocation(CentralTimezoneName) + if err != nil { + panic(err) + } + return time.Now().In(loc) } zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack - // Try to grab the environment variable, or default to development - environment = utils.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT") - if environment == "" { - environment = "development" + // Use the custom console writer if we're in development + isDevelopment := utils.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT") + if isDevelopment == "" { + isDevelopment = "development" } - // Use the custom console writer if we're in development - isDevelopment = environment == "development" - if isDevelopment { + if isDevelopment == "development" { log.Logger = zerolog.New(utils.LogSplitter{Std: os.Stdout, Err: os.Stderr}).With().Timestamp().Logger() } else { log.Logger = zerolog.New(utils.LogSplitter{Std: os.Stdout, Err: os.Stderr}).With().Timestamp().Logger() } - log.Debug().Str("environment", environment).Msg("Loggers Setup") + log.Debug().Str("environment", isDevelopment).Msg("Loggers Setup") // Set discordgo's logger to use zerolog discordgo.Logger = utils.DiscordGoLogger - - baseURL = os.Getenv("BANNER_BASE_URL") } -func initRedis() { +func initRedis(cfg *config.Config) { // Setup redis redisUrl := utils.GetFirstEnv("REDIS_URL", "REDIS_PRIVATE_URL") if redisUrl == "" { @@ -102,14 +85,15 @@ func initRedis() { if err != nil { log.Fatal().Stack().Err(err).Msg("Cannot parse redis url") } - kv = redis.NewClient(options) + kv := redis.NewClient(options) + cfg.SetRedis(kv) var lastPingErr error pingCount := 0 // Nth ping being attempted totalPings := 5 // Total pings to attempt // Wait for private networking to kick in (production only) - if !isDevelopment { + if !cfg.IsDevelopment { time.Sleep(250 * time.Millisecond) } @@ -121,7 +105,7 @@ func initRedis() { } // Ping redis - pong, err := kv.Ping(ctx).Result() + pong, err := cfg.KV.Ping(cfg.Ctx).Result() // Failed; log error and wait 2 seconds if err != nil { @@ -140,7 +124,19 @@ func initRedis() { func main() { flag.Parse() - initRedis() + cfg, err := config.New() + if err != nil { + log.Fatal().Stack().Err(err).Msg("Cannot create config") + } + + // Try to grab the environment variable, or default to development + environment := utils.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT") + if environment == "" { + environment = "development" + } + cfg.SetEnvironment(environment) + + initRedis(cfg) if strings.EqualFold(os.Getenv("PPROF_ENABLE"), "true") { // Start pprof server @@ -156,110 +152,40 @@ func main() { } // Create cookie jar - var err error - cookies, err = cookiejar.New(nil) + cookies, err := cookiejar.New(nil) if err != nil { log.Err(err).Msg("Cannot create cookie jar") } // Create client, setup session (acquire cookies) - client = http.Client{Jar: cookies} - api.Setup() + client := &http.Client{Jar: cookies} + cfg.SetClient(client) + + baseURL := os.Getenv("BANNER_BASE_URL") + cfg.SetBaseURL(baseURL) + + apiInstance := api.New(cfg) + apiInstance.Setup() // Create discord session - Session, err = discordgo.New("Bot " + os.Getenv("BOT_TOKEN")) + session, err := discordgo.New("Bot " + os.Getenv("BOT_TOKEN")) if err != nil { log.Err(err).Msg("Invalid bot parameters") } + botInstance := bot.New(session, apiInstance, cfg) + botInstance.RegisterHandlers() + // Open discord session - Session.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) { + session.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) { log.Info().Str("username", r.User.Username).Str("discriminator", r.User.Discriminator).Str("id", r.User.ID).Str("session", s.State.SessionID).Msg("Bot is logged in") }) - err = Session.Open() + err = session.Open() if err != nil { log.Fatal().Stack().Err(err).Msg("Cannot open the session") } // Setup command handlers - Session.AddHandler(func(internalSession *discordgo.Session, interaction *discordgo.InteractionCreate) { - // Handle commands during restart (highly unlikely, but just in case) - if isClosing { - err := utils.RespondError(internalSession, interaction.Interaction, "Bot is currently restarting, try again later.", nil) - if err != nil { - log.Error().Err(err).Msg("Failed to respond with restart error feedback") - } - return - } - - name := interaction.ApplicationCommandData().Name - if handler, ok := bot.CommandHandlers[name]; ok { - // Build dict of options for the log - options := zerolog.Dict() - for _, option := range interaction.ApplicationCommandData().Options { - options.Str(option.Name, fmt.Sprintf("%v", option.Value)) - } - - event := log.Info().Str("name", name).Str("user", utils.GetUser(interaction).Username).Dict("options", options) - - // If the command was invoked in a guild, add guild & channel info to the log - if interaction.Member != nil { - guild := zerolog.Dict() - guild.Str("id", interaction.GuildID) - guild.Str("name", utils.GetGuildName(internalSession, interaction.GuildID)) - event.Dict("guild", guild) - - channel := zerolog.Dict() - channel.Str("id", interaction.ChannelID) - guild.Str("name", utils.GetChannelName(internalSession, interaction.ChannelID)) - event.Dict("channel", channel) - } else { - // If the command was invoked in a DM, add the user info to the log - user := zerolog.Dict() - user.Str("id", interaction.User.ID) - user.Str("name", interaction.User.Username) - event.Dict("user", user) - } - - // Log command invocation - event.Msg("Command Invoked") - - // Prepare to recover - defer func() { - if err := recover(); err != nil { - log.Error().Stack().Str("commandName", name).Interface("detail", err).Msg("Command Handler Panic") - - // Respond with error - err := utils.RespondError(internalSession, interaction.Interaction, "Unexpected Error: command handler panic", nil) - if err != nil { - log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with panic error feedback") - } - } - }() - - // Call handler - err := handler(internalSession, interaction) - - // Log & respond error - if err != nil { - // TODO: Find a way to merge the response with the handler's error - log.Error().Str("commandName", name).Err(err).Msg("Command Handler Error") - - // Respond with error - err = utils.RespondError(internalSession, interaction.Interaction, fmt.Sprintf("Unexpected Error: %s", err.Error()), nil) - if err != nil { - log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with error feedback") - } - } - - } else { - log.Error().Stack().Str("commandName", name).Msg("Command Interaction Has No Handler") - - // Respond with error - utils.RespondError(internalSession, interaction.Interaction, "Unexpected Error: interaction has no handler", nil) - } - }) - // Register commands with discord arr := zerolog.Arr() lo.ForEach(bot.CommandDefinitions, func(cmd *discordgo.ApplicationCommand, _ int) { @@ -269,16 +195,16 @@ func main() { // In development, use test server, otherwise empty (global) for command registration guildTarget := "" - if isDevelopment { + if cfg.IsDevelopment { guildTarget = os.Getenv("BOT_TARGET_GUILD") } // Register commands - existingCommands, err := Session.ApplicationCommands(Session.State.User.ID, guildTarget) + existingCommands, err := session.ApplicationCommands(session.State.User.ID, guildTarget) if err != nil { log.Fatal().Stack().Err(err).Msg("Cannot get existing commands") } - newCommands, err := Session.ApplicationCommandBulkOverwrite(Session.State.User.ID, guildTarget, bot.CommandDefinitions) + newCommands, err := session.ApplicationCommandBulkOverwrite(session.State.User.ID, guildTarget, bot.CommandDefinitions) if err != nil { log.Fatal().Stack().Err(err).Msg("Cannot register commands") } @@ -304,7 +230,7 @@ func main() { } // Fetch terms on startup - err = api.TryReloadTerms() + err = apiInstance.TryReloadTerms() if err != nil { log.Fatal().Stack().Err(err).Msg("Cannot fetch terms on startup") } @@ -312,7 +238,7 @@ func main() { // Launch a goroutine to scrape the banner system periodically go func() { for { - err := api.Scrape() + err := apiInstance.Scrape() if err != nil { log.Err(err).Stack().Msg("Periodic Scrape Failed") } @@ -322,7 +248,7 @@ func main() { }() // Close session, ensure http client closes idle connections - defer Session.Close() + defer session.Close() defer client.CloseIdleConnections() // Setup signal handler channel @@ -332,7 +258,7 @@ func main() { // Wait for signal (indefinite) closingSignal := <-stop - isClosing = true // TODO: Switch to atomic lock with forced close after 10 seconds + botInstance.SetClosing() // TODO: Switch to atomic lock with forced close after 10 seconds // Defers are called after this log.Warn().Str("signal", closingSignal.String()).Msg("Gracefully shutting down") diff --git a/internal/api/api.go b/internal/api/api.go index 0438148..aec4beb 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -22,6 +22,14 @@ import ( "github.com/samber/lo" ) +type API struct { + config *config.Config +} + +func New(config *config.Config) *API { + return &API{config: config} +} + var ( latestSession string sessionTime time.Time @@ -44,7 +52,7 @@ func GenerateSession() string { } // DoRequest performs & logs the request, logging and returning the response -func DoRequest(req *http.Request) (*http.Response, error) { +func (a *API) DoRequest(req *http.Request) (*http.Response, error) { headerSize := 0 for key, values := range req.Header { for _, value := range values { @@ -68,7 +76,7 @@ func DoRequest(req *http.Request) (*http.Response, error) { Str("content-type", req.Header.Get("Content-Type")). Msg("Request") - res, err := config.Client.Do(req) + res, err := a.config.Client.Do(req) if err != nil { log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed") @@ -98,14 +106,14 @@ var terms []BannerTerm var lastTermUpdate time.Time // TryReloadTerms attempts to reload the terms if they are not loaded or the last update was more than 24 hours ago -func TryReloadTerms() error { +func (a *API) TryReloadTerms() error { if len(terms) > 0 && time.Since(lastTermUpdate) < 24*time.Hour { return nil } // Load the terms var err error - terms, err = GetTerms("", 1, 100) + terms, err = a.GetTerms("", 1, 100) if err != nil { return fmt.Errorf("failed to load terms: %w", err) } @@ -116,9 +124,9 @@ func TryReloadTerms() error { // IsTermArchived checks if the given term is archived // TODO: Add error, switch missing term logic to error -func IsTermArchived(term string) bool { +func (a *API) IsTermArchived(term string) bool { // Ensure the terms are loaded - err := TryReloadTerms() + err := a.TryReloadTerms() if err != nil { log.Err(err).Stack().Msg("Failed to reload terms") return true @@ -137,26 +145,12 @@ func IsTermArchived(term string) bool { return bannerTerm.Archived() } -// GetSession retrieves the current session ID if it's still valid. -// If the session ID is invalid or has expired, a new one is generated and returned. -// SessionIDs are valid for 30 minutes, but we'll be conservative and regenerate every 25 minutes. -func GetSession() string { - // Check if a reset is required +// EnsureSession ensures that a valid session is available, creating one if necessary. +func (a *API) EnsureSession() string { if latestSession == "" || time.Since(sessionTime) >= expiryTime { - // Generate a new session identifier latestSession = GenerateSession() - - // Select the current term - term := utils.Default(time.Now()).ToString() - log.Info().Str("term", term).Str("sessionID", latestSession).Msg("Setting selected term") - err := SelectTerm(term, latestSession) - if err != nil { - log.Fatal().Stack().Err(err).Msg("Failed to select term while generating session ID") - } - sessionTime = time.Now() } - return latestSession } @@ -175,13 +169,13 @@ func (term BannerTerm) Archived() bool { // GetTerms retrieves and parses the term information for a given search term. // Page number must be at least 1. -func GetTerms(search string, page int, maxResults int) ([]BannerTerm, error) { +func (a *API) GetTerms(search string, page int, maxResults int) ([]BannerTerm, error) { // Ensure offset is valid if page <= 0 { return nil, errors.New("offset must be greater than 0") } - req := utils.BuildRequest("GET", "/classSearch/getTerms", map[string]string{ + req := utils.BuildRequest(a.config, "GET", "/classSearch/getTerms", map[string]string{ "searchTerm": search, // Page vs Offset is not a mistake here, the API uses "offset" as the page number "offset": strconv.Itoa(page), @@ -193,7 +187,7 @@ func GetTerms(search string, page int, maxResults int) ([]BannerTerm, error) { return nil, errors.New("Offset must be greater than 0") } - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get terms: %w", err) } @@ -225,7 +219,7 @@ func GetTerms(search string, page int, maxResults int) ([]BannerTerm, error) { // SelectTerm selects the given term in the Banner system. // This function completes the initial term selection process, which is required before any other API calls can be made with the session ID. -func SelectTerm(term string, sessionID string) error { +func (a *API) SelectTerm(term string, sessionID string) error { form := url.Values{ "term": {term}, "studyPath": {""}, @@ -239,10 +233,10 @@ func SelectTerm(term string, sessionID string) error { "mode": "search", } - req := utils.BuildRequestWithBody("POST", "/term/search", params, bytes.NewBufferString(form.Encode())) + req := utils.BuildRequestWithBody(a.config, "POST", "/term/search", params, bytes.NewBufferString(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return fmt.Errorf("failed to select term: %w", err) } @@ -265,8 +259,8 @@ func SelectTerm(term string, sessionID string) error { json.Unmarshal(body, &redirectResponse) // Make a GET request to the fwdUrl - req = utils.BuildRequest("GET", redirectResponse.FwdURL, nil) - res, err = DoRequest(req) + req = utils.BuildRequest(a.config, "GET", redirectResponse.FwdURL, nil) + res, err = a.DoRequest(req) if err != nil { return fmt.Errorf("failed to follow redirect: %w", err) } @@ -281,22 +275,22 @@ func SelectTerm(term string, sessionID string) error { // GetPartOfTerms retrieves and parses the part of term information for a given term. // Ensure that the offset is greater than 0. -func GetPartOfTerms(search string, term int, offset int, maxResults int) ([]BannerTerm, error) { +func (a *API) GetPartOfTerms(search string, term int, offset int, maxResults int) ([]BannerTerm, error) { // Ensure offset is valid if offset <= 0 { return nil, errors.New("offset must be greater than 0") } - req := utils.BuildRequest("GET", "/classSearch/get_partOfTerm", map[string]string{ + req := utils.BuildRequest(a.config, "GET", "/classSearch/get_partOfTerm", map[string]string{ "searchTerm": search, "term": strconv.Itoa(term), "offset": strconv.Itoa(offset), "max": strconv.Itoa(maxResults), - "uniqueSessionId": GetSession(), + "uniqueSessionId": a.EnsureSession(), "_": utils.Nonce(), }) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get part of terms: %w", err) } @@ -325,22 +319,22 @@ func GetPartOfTerms(search string, term int, offset int, maxResults int) ([]Bann // In my opinion, it is unclear what providing the term does, as the results should be the same regardless of the term. // This function is included for completeness, but probably isn't useful. // Ensure that the offset is greater than 0. -func GetInstructors(search string, term string, offset int, maxResults int) ([]Instructor, error) { +func (a *API) GetInstructors(search string, term string, offset int, maxResults int) ([]Instructor, error) { // Ensure offset is valid if offset <= 0 { return nil, errors.New("offset must be greater than 0") } - req := utils.BuildRequest("GET", "/classSearch/get_instructor", map[string]string{ + req := utils.BuildRequest(a.config, "GET", "/classSearch/get_instructor", map[string]string{ "searchTerm": search, "term": term, "offset": strconv.Itoa(offset), "max": strconv.Itoa(maxResults), - "uniqueSessionId": GetSession(), + "uniqueSessionId": a.EnsureSession(), "_": utils.Nonce(), }) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get instructors: %w", err) } @@ -370,7 +364,7 @@ func GetInstructors(search string, term string, offset int, maxResults int) ([]I type ClassDetails struct { } -func GetCourseDetails(term int, crn int) *ClassDetails { +func (a *API) GetCourseDetails(term int, crn int) *ClassDetails { body, err := json.Marshal(map[string]string{ "term": strconv.Itoa(term), "courseReferenceNumber": strconv.Itoa(crn), @@ -379,9 +373,9 @@ func GetCourseDetails(term int, crn int) *ClassDetails { if err != nil { log.Fatal().Stack().Err(err).Msg("Failed to marshal body") } - req := utils.BuildRequestWithBody("GET", "/searchResults/getClassDetails", nil, bytes.NewBuffer(body)) + req := utils.BuildRequestWithBody(a.config, "GET", "/searchResults/getClassDetails", nil, bytes.NewBuffer(body)) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil } @@ -395,13 +389,13 @@ func GetCourseDetails(term int, crn int) *ClassDetails { } // Search invokes a search on the Banner system with the given query and returns the results. -func Search(query *Query, sort string, sortDescending bool) (*models.SearchResult, error) { - ResetDataForm() +func (a *API) Search(term string, query *Query, sort string, sortDescending bool) (*models.SearchResult, error) { + a.ResetDataForm() params := query.Paramify() - params["txt_term"] = "202510" // TODO: Make this automatic but dynamically specifiable - params["uniqueSessionId"] = GetSession() + params["txt_term"] = term + params["uniqueSessionId"] = a.EnsureSession() params["sortColumn"] = sort params["sortDirection"] = "asc" @@ -409,9 +403,9 @@ func Search(query *Query, sort string, sortDescending bool) (*models.SearchResul params["startDatepicker"] = "" params["endDatepicker"] = "" - req := utils.BuildRequest("GET", "/searchResults/searchResults", params) + req := utils.BuildRequest(a.config, "GET", "/searchResults/searchResults", params) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to search: %w", err) } @@ -445,22 +439,22 @@ func Search(query *Query, sort string, sortDescending bool) (*models.SearchResul // GetSubjects retrieves and parses the subject information for a given search term. // The results of this response shouldn't change much, but technically could as new majors are developed, or old ones are removed. // Ensure that the offset is greater than 0. -func GetSubjects(search string, term string, offset int, maxResults int) ([]Pair, error) { +func (a *API) GetSubjects(search string, term string, offset int, maxResults int) ([]Pair, error) { // Ensure offset is valid if offset <= 0 { return nil, errors.New("offset must be greater than 0") } - req := utils.BuildRequest("GET", "/classSearch/get_subject", map[string]string{ + req := utils.BuildRequest(a.config, "GET", "/classSearch/get_subject", map[string]string{ "searchTerm": search, "term": term, "offset": strconv.Itoa(offset), "max": strconv.Itoa(maxResults), - "uniqueSessionId": GetSession(), + "uniqueSessionId": a.EnsureSession(), "_": utils.Nonce(), }) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get subjects: %w", err) } @@ -489,22 +483,22 @@ func GetSubjects(search string, term string, offset int, maxResults int) ([]Pair // In my opinion, it is unclear what providing the term does, as the results should be the same regardless of the term. // This function is included for completeness, but probably isn't useful. // Ensure that the offset is greater than 0. -func GetCampuses(search string, term int, offset int, maxResults int) ([]Pair, error) { +func (a *API) GetCampuses(search string, term int, offset int, maxResults int) ([]Pair, error) { // Ensure offset is valid if offset <= 0 { return nil, errors.New("offset must be greater than 0") } - req := utils.BuildRequest("GET", "/classSearch/get_campus", map[string]string{ + req := utils.BuildRequest(a.config, "GET", "/classSearch/get_campus", map[string]string{ "searchTerm": search, "term": strconv.Itoa(term), "offset": strconv.Itoa(offset), "max": strconv.Itoa(maxResults), - "uniqueSessionId": GetSession(), + "uniqueSessionId": a.EnsureSession(), "_": utils.Nonce(), }) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get campuses: %w", err) } @@ -533,22 +527,22 @@ func GetCampuses(search string, term int, offset int, maxResults int) ([]Pair, e // In my opinion, it is unclear what providing the term does, as the results should be the same regardless of the term. // This function is included for completeness, but probably isn't useful. // Ensure that the offset is greater than 0. -func GetInstructionalMethods(search string, term string, offset int, maxResults int) ([]Pair, error) { +func (a *API) GetInstructionalMethods(search string, term string, offset int, maxResults int) ([]Pair, error) { // Ensure offset is valid if offset <= 0 { return nil, errors.New("offset must be greater than 0") } - req := utils.BuildRequest("GET", "/classSearch/get_instructionalMethod", map[string]string{ + req := utils.BuildRequest(a.config, "GET", "/classSearch/get_instructionalMethod", map[string]string{ "searchTerm": search, "term": term, "offset": strconv.Itoa(offset), "max": strconv.Itoa(maxResults), - "uniqueSessionId": GetSession(), + "uniqueSessionId": a.EnsureSession(), "_": utils.Nonce(), }) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get instructional methods: %w", err) } @@ -573,13 +567,13 @@ func GetInstructionalMethods(search string, term string, offset int, maxResults // GetCourseMeetingTime retrieves the meeting time information for a course based on the given term and course reference number (CRN). // It makes an HTTP GET request to the appropriate API endpoint and parses the response to extract the meeting time data. // The function returns a MeetingTimeResponse struct containing the extracted information. -func GetCourseMeetingTime(term int, crn int) ([]models.MeetingTimeResponse, error) { - req := utils.BuildRequest("GET", "/searchResults/getFacultyMeetingTimes", map[string]string{ +func (a *API) GetCourseMeetingTime(term int, crn int) ([]models.MeetingTimeResponse, error) { + req := utils.BuildRequest(a.config, "GET", "/searchResults/getFacultyMeetingTimes", map[string]string{ "term": strconv.Itoa(term), "courseReferenceNumber": strconv.Itoa(crn), }) - res, err := DoRequest(req) + res, err := a.DoRequest(req) if err != nil { return nil, fmt.Errorf("failed to get meeting time: %w", err) } @@ -609,9 +603,9 @@ func GetCourseMeetingTime(term int, crn int) ([]models.MeetingTimeResponse, erro } // ResetDataForm makes a POST request that needs to be made upon before new search requests can be made. -func ResetDataForm() { - req := utils.BuildRequest("POST", "/classSearch/resetDataForm", nil) - _, err := DoRequest(req) +func (a *API) ResetDataForm() { + req := utils.BuildRequest(a.config, "POST", "/classSearch/resetDataForm", nil) + _, err := a.DoRequest(req) if err != nil { log.Fatal().Stack().Err(err).Msg("Failed to reset data form") } @@ -619,9 +613,9 @@ func ResetDataForm() { // GetCourse retrieves the course information. // This course does not retrieve directly from the API, but rather uses scraped data stored in Redis. -func GetCourse(crn string) (*models.Course, error) { +func (a *API) GetCourse(crn string) (*models.Course, error) { // Retrieve raw data - result, err := config.KV.Get(config.Ctx, fmt.Sprintf("class:%s", crn)).Result() + result, err := a.config.KV.Get(a.config.Ctx, fmt.Sprintf("class:%s", crn)).Result() if err != nil { if err == redis.Nil { return nil, fmt.Errorf("course not found: %w", err) diff --git a/internal/api/scrape.go b/internal/api/scrape.go index d5dbbd9..e78f8f1 100644 --- a/internal/api/scrape.go +++ b/internal/api/scrape.go @@ -1,7 +1,6 @@ package api import ( - "banner/internal/config" "banner/internal/models" "banner/internal/utils" "fmt" @@ -25,40 +24,41 @@ var ( AllMajors []string ) -// Scrape is the general scraping invocation (best called within/as a goroutine) that should be called regularly to initiate scraping of the Banner system. -func Scrape() error { - // Populate AllMajors if it is empty - if len(AncillaryMajors) == 0 { - term := utils.Default(time.Now()).ToString() - subjects, err := GetSubjects("", term, 1, 99) - if err != nil { - return fmt.Errorf("failed to get subjects: %w", err) - } - - // Ensure subjects were found - if len(subjects) == 0 { - return fmt.Errorf("no subjects found") - } - - // Extract major code name - for _, subject := range subjects { - // Add to AncillaryMajors if not in PriorityMajors - if !lo.Contains(PriorityMajors, subject.Code) { - AncillaryMajors = append(AncillaryMajors, subject.Code) - } - } - - AllMajors = lo.Flatten([][]string{PriorityMajors, AncillaryMajors}) +// Scrape scrapes the API for all courses and stores them in Redis. +// This is a long-running process that should be run in a goroutine. +// TODO: Switch from hardcoded term to dynamic term +func (a *API) Scrape() error { + // For each subject, retrieve all courses + // For each course, get the details and store it in redis + // Make sure to handle pagination + subjects, err := a.GetSubjects("", "202510", 1, 100) + if err != nil { + return fmt.Errorf("failed to get subjects: %w", err) } - expiredSubjects, err := GetExpiredSubjects() + // Ensure subjects were found + if len(subjects) == 0 { + return fmt.Errorf("no subjects found") + } + + // Extract major code name + for _, subject := range subjects { + // Add to AncillaryMajors if not in PriorityMajors + if !lo.Contains(PriorityMajors, subject.Code) { + AncillaryMajors = append(AncillaryMajors, subject.Code) + } + } + + AllMajors = lo.Flatten([][]string{PriorityMajors, AncillaryMajors}) + + expiredSubjects, err := a.GetExpiredSubjects() if err != nil { return fmt.Errorf("failed to get scrapable majors: %w", err) } log.Info().Strs("majors", expiredSubjects).Msg("Scraping majors") for _, subject := range expiredSubjects { - err := ScrapeMajor(subject) + err := a.ScrapeMajor(subject) if err != nil { return fmt.Errorf("failed to scrape major %s: %w", subject, err) } @@ -68,12 +68,12 @@ func Scrape() error { } // GetExpiredSubjects returns a list of subjects that are expired and should be scraped. -func GetExpiredSubjects() ([]string, error) { +func (a *API) GetExpiredSubjects() ([]string, error) { term := utils.Default(time.Now()).ToString() subjects := make([]string, 0) // Get all subjects - values, err := config.KV.MGet(config.Ctx, lo.Map(AllMajors, func(major string, _ int) string { + values, err := a.config.KV.MGet(a.config.Ctx, lo.Map(AllMajors, func(major string, _ int) string { return fmt.Sprintf("scraped:%s:%s", major, term) })...).Result() if err != nil { @@ -97,14 +97,15 @@ func GetExpiredSubjects() ([]string, error) { // ScrapeMajor is the scraping invocation for a specific major. // This function does not check whether scraping is required at this time, it is assumed that the caller has already done so. -func ScrapeMajor(subject string) error { +func (a *API) ScrapeMajor(subject string) error { offset := 0 totalClassCount := 0 for { // Build & execute the query query := NewQuery().Offset(offset).MaxResults(MaxPageSize * 2).Subject(subject) - result, err := Search(query, "subjectDescription", false) + term := utils.Default(time.Now()).ToString() + result, err := a.Search(term, query, "subjectDescription", false) if err != nil { return fmt.Errorf("search failed: %w (%s)", err, query.String()) } @@ -121,7 +122,7 @@ func ScrapeMajor(subject string) error { // Process each class and store it in Redis for _, course := range result.Data { // Store class in Redis - err := IntakeCourse(course) + err := a.IntakeCourse(course) if err != nil { log.Error().Err(err).Msg("failed to store class in Redis") } @@ -153,14 +154,14 @@ func ScrapeMajor(subject string) error { if totalClassCount == 0 { scrapeExpiry = time.Hour * 12 } else { - scrapeExpiry = CalculateExpiry(term, totalClassCount, lo.Contains(PriorityMajors, subject)) + scrapeExpiry = a.CalculateExpiry(term, totalClassCount, lo.Contains(PriorityMajors, subject)) } // Mark the major as scraped if totalClassCount == 0 { totalClassCount = -1 } - err := config.KV.Set(config.Ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err() + err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err() if err != nil { log.Error().Err(err).Msg("failed to mark major as scraped") } @@ -172,7 +173,7 @@ func ScrapeMajor(subject string) error { // term is the term for which the relevant course is occurring within. // count is the number of courses that were scraped. // priority is a boolean indicating whether the major is a priority major. -func CalculateExpiry(term string, count int, priority bool) time.Duration { +func (a *API) CalculateExpiry(term string, count int, priority bool) time.Duration { // An hour for every 100 classes baseExpiry := time.Hour * time.Duration(count/100) @@ -190,7 +191,7 @@ func CalculateExpiry(term string, count int, priority bool) time.Duration { // If the term is considered "view only" or "archived", then the expiry is multiplied by 5 var expiry = baseExpiry - if IsTermArchived(term) { + if a.IsTermArchived(term) { expiry *= 5 } @@ -212,8 +213,8 @@ func CalculateExpiry(term string, count int, priority bool) time.Duration { // IntakeCourse stores a course in Redis. // This function is mostly a stub for now, but will be used to handle change identification, notifications, and SQLite upserts in the future. -func IntakeCourse(course models.Course) error { - err := config.KV.Set(config.Ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err() +func (a *API) IntakeCourse(course models.Course) error { + err := a.config.KV.Set(a.config.Ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err() if err != nil { return fmt.Errorf("failed to store class in Redis: %w", err) } diff --git a/internal/api/session.go b/internal/api/session.go index bc8be18..8b7586f 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -1,14 +1,13 @@ package api import ( - "banner/internal/config" "banner/internal/utils" "net/url" log "github.com/rs/zerolog/log" ) -func Setup() { +func (a *API) Setup() { // Makes the initial requests that sets up the session cookies for the rest of the application log.Info().Msg("Setting up session...") @@ -18,17 +17,17 @@ func Setup() { } for _, path := range requestQueue { - req := utils.BuildRequest("GET", path, nil) - DoRequest(req) + req := utils.BuildRequest(a.config, "GET", path, nil) + a.DoRequest(req) } // Validate that cookies were set - baseURLParsed, err := url.Parse(config.BaseURL) + baseURLParsed, err := url.Parse(a.config.BaseURL) if err != nil { - log.Fatal().Stack().Str("baseURL", config.BaseURL).Err(err).Msg("Failed to parse baseURL") + log.Fatal().Stack().Str("baseURL", a.config.BaseURL).Err(err).Msg("Failed to parse baseURL") } - currentCookies := config.Client.Jar.Cookies(baseURLParsed) + currentCookies := a.config.Client.Jar.Cookies(baseURLParsed) requiredCookies := map[string]bool{ "JSESSIONID": false, "SSB_COOKIE": false, diff --git a/internal/bot/bot.go b/internal/bot/bot.go new file mode 100644 index 0000000..abd6b5e --- /dev/null +++ b/internal/bot/bot.go @@ -0,0 +1,40 @@ +package bot + +import ( + "banner/internal/api" + "banner/internal/config" + "banner/internal/utils" + "fmt" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog/log" +) + +type Bot struct { + Session *discordgo.Session + API *api.API + Config *config.Config + isClosing bool +} + +func New(s *discordgo.Session, a *api.API, c *config.Config) *Bot { + return &Bot{Session: s, API: a, Config: c} +} + +func (b *Bot) SetClosing() { + b.isClosing = true +} + +func (b *Bot) GetSession() (string, error) { + sessionID := b.API.EnsureSession() + term := utils.Default(time.Now()).ToString() + + log.Info().Str("term", term).Str("sessionID", sessionID).Msg("Setting selected term") + err := b.API.SelectTerm(term, sessionID) + if err != nil { + return "", fmt.Errorf("failed to select term while generating session ID: %w", err) + } + + return sessionID, nil +} diff --git a/internal/bot/commands.go b/internal/bot/commands.go index 3a97177..458d5c2 100644 --- a/internal/bot/commands.go +++ b/internal/bot/commands.go @@ -2,7 +2,6 @@ package bot import ( "banner/internal/api" - "banner/internal/config" "banner/internal/models" "banner/internal/utils" "fmt" @@ -18,9 +17,16 @@ import ( "github.com/samber/lo" ) +const ( + ICalTimestampFormatUtc = "20060102T150405Z" + ICalTimestampFormatLocal = "20060102T150405" +) + +type CommandHandler func(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error + var ( CommandDefinitions = []*discordgo.ApplicationCommand{TermCommandDefinition, TimeCommandDefinition, SearchCommandDefinition, IcsCommandDefinition} - CommandHandlers = map[string]func(s *discordgo.Session, i *discordgo.InteractionCreate) error{ + CommandHandlers = map[string]CommandHandler{ TimeCommandDefinition.Name: TimeCommandHandler, TermCommandDefinition.Name: TermCommandHandler, SearchCommandDefinition.Name: SearchCommandHandler, @@ -76,8 +82,8 @@ var SearchCommandDefinition = &discordgo.ApplicationCommand{ }, } -func SearchCommandHandler(session *discordgo.Session, interaction *discordgo.InteractionCreate) error { - data := interaction.ApplicationCommandData() +func SearchCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error { + data := i.ApplicationCommandData() query := api.NewQuery().Credits(3, 6) for _, option := range data.Options { @@ -177,9 +183,14 @@ func SearchCommandHandler(session *discordgo.Session, interaction *discordgo.Int } } - courses, err := api.Search(query, "", false) + term, err := b.GetSession() if err != nil { - session.InteractionRespond(interaction.Interaction, &discordgo.InteractionResponse{ + return err + } + + courses, err := b.API.Search(term, query, "", false) + if err != nil { + s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, Data: &discordgo.InteractionResponseData{ Content: "Error searching for courses", @@ -222,12 +233,12 @@ func SearchCommandHandler(session *discordgo.Session, interaction *discordgo.Int color = 0xFF6500 } - err = session.InteractionRespond(interaction.Interaction, &discordgo.InteractionResponse{ + err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, Data: &discordgo.InteractionResponseData{ Embeds: []*discordgo.MessageEmbed{ { - Footer: utils.GetFetchedFooter(fetch_time), + Footer: utils.GetFetchedFooter(b.Config, fetch_time), Description: fmt.Sprintf("%d Class%s", courses.TotalCount, utils.Plural(courses.TotalCount)), Fields: fields[:min(25, len(fields))], Color: color, @@ -262,8 +273,8 @@ var TermCommandDefinition = &discordgo.ApplicationCommand{ }, } -func TermCommandHandler(session *discordgo.Session, interaction *discordgo.InteractionCreate) error { - data := interaction.ApplicationCommandData() +func TermCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error { + data := i.ApplicationCommandData() searchTerm := "" pageNumber := 1 @@ -279,10 +290,10 @@ func TermCommandHandler(session *discordgo.Session, interaction *discordgo.Inter } } - termResult, err := api.GetTerms(searchTerm, pageNumber, 25) + termResult, err := b.API.GetTerms(searchTerm, pageNumber, 25) if err != nil { - utils.RespondError(session, interaction.Interaction, "Error while fetching terms", err) + utils.RespondError(s, i.Interaction, "Error while fetching terms", err) return err } @@ -302,12 +313,12 @@ func TermCommandHandler(session *discordgo.Session, interaction *discordgo.Inter log.Warn().Int("count", len(fields)).Msg("Too many fields in term command (trimmed)") } - err = session.InteractionRespond(interaction.Interaction, &discordgo.InteractionResponse{ + err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, Data: &discordgo.InteractionResponseData{ Embeds: []*discordgo.MessageEmbed{ { - Footer: utils.GetFetchedFooter(fetch_time), + Footer: utils.GetFetchedFooter(b.Config, fetch_time), Description: fmt.Sprintf("%d term%s (page %d)", len(termResult), utils.Plural(len(termResult)), pageNumber), Fields: fields[:min(25, len(fields))], }, @@ -332,12 +343,12 @@ var TimeCommandDefinition = &discordgo.ApplicationCommand{ }, } -func TimeCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) error { +func TimeCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error { fetch_time := time.Now() crn := i.ApplicationCommandData().Options[0].IntValue() // Fix static term - meetingTimes, err := api.GetCourseMeetingTime(202510, int(crn)) + meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn)) if err != nil { s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, @@ -356,7 +367,7 @@ func TimeCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) er Data: &discordgo.InteractionResponseData{ Embeds: []*discordgo.MessageEmbed{ { - Footer: utils.GetFetchedFooter(fetch_time), + Footer: utils.GetFetchedFooter(b.Config, fetch_time), Description: "", Fields: []*discordgo.MessageEmbedField{ { @@ -397,16 +408,18 @@ var IcsCommandDefinition = &discordgo.ApplicationCommand{ }, } -func IcsCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) error { - crn := i.ApplicationCommandData().Options[0].IntValue() +func IcsCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error { + // Parse all options + options := utils.ParseOptions(i.ApplicationCommandData().Options) + crn := options.GetInt("crn") - course, err := api.GetCourse(strconv.Itoa(int(crn))) + course, err := b.API.GetCourse(strconv.Itoa(int(crn))) if err != nil { return fmt.Errorf("Error retrieving course data: %w", err) } // Fix static term - meetingTimes, err := api.GetCourseMeetingTime(202510, int(crn)) + meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn)) if err != nil { return fmt.Errorf("Error requesting meeting time: %w", err) } @@ -433,22 +446,24 @@ func IcsCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) err events := []string{} for _, meeting := range meetingTimes { - now := time.Now().In(config.CentralTimeLocation) + now := time.Now().In(b.Config.CentralTimeLocation) uid := fmt.Sprintf("%d-%s@ical.banner.xevion.dev", now.Unix(), meeting.CourseReferenceNumber) startDay := meeting.StartDay() startTime := meeting.StartTime() endTime := meeting.EndTime() - dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, config.CentralTimeLocation) - dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, config.CentralTimeLocation) + dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, b.Config.CentralTimeLocation) + dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, b.Config.CentralTimeLocation) - endDay := meeting.EndDay() - until := time.Date(endDay.Year(), endDay.Month(), endDay.Day(), 23, 59, 59, 0, config.CentralTimeLocation) + // endDay := meeting.EndDay() + // until := time.Date(endDay.Year(), endDay.Month(), endDay.Day(), 23, 59, 59, 0, b.Config.CentralTimeLocation) summary := fmt.Sprintf("%s %s %s", course.Subject, course.CourseNumber, course.CourseTitle) description := fmt.Sprintf("Instructor: %s\nSection: %s\nCRN: %s", course.Faculty[0].DisplayName, course.SequenceNumber, meeting.CourseReferenceNumber) location := meeting.PlaceString() + rrule := meeting.RRule() + event := fmt.Sprintf(`BEGIN:VEVENT DTSTAMP:%s UID:%s @@ -458,7 +473,7 @@ DTEND;TZID=America/Chicago:%s SUMMARY:%s DESCRIPTION:%s LOCATION:%s -END:VEVENT`, now.Format(config.ICalTimestampFormatLocal), uid, dtStart.Format(config.ICalTimestampFormatLocal), meeting.ByDay(), until.Format(config.ICalTimestampFormatLocal), dtEnd.Format(config.ICalTimestampFormatLocal), summary, strings.Replace(description, "\n", `\n`, -1), location) +END:VEVENT`, now.Format(ICalTimestampFormatLocal), uid, dtStart.Format(ICalTimestampFormatLocal), rrule.ByDay, rrule.Until, dtEnd.Format(ICalTimestampFormatLocal), summary, strings.Replace(description, "\n", `\n`, -1), location) events = append(events, event) } diff --git a/internal/bot/handlers.go b/internal/bot/handlers.go new file mode 100644 index 0000000..cc9b00d --- /dev/null +++ b/internal/bot/handlers.go @@ -0,0 +1,90 @@ +package bot + +import ( + "banner/internal/utils" + "fmt" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func (b *Bot) RegisterHandlers() { + b.Session.AddHandler(func(internalSession *discordgo.Session, interaction *discordgo.InteractionCreate) { + // Handle commands during restart (highly unlikely, but just in case) + if b.isClosing { + err := utils.RespondError(internalSession, interaction.Interaction, "Bot is currently restarting, try again later.", nil) + if err != nil { + log.Error().Err(err).Msg("Failed to respond with restart error feedback") + } + return + } + + name := interaction.ApplicationCommandData().Name + if handler, ok := CommandHandlers[name]; ok { + // Build dict of options for the log + options := zerolog.Dict() + for _, option := range interaction.ApplicationCommandData().Options { + options.Str(option.Name, fmt.Sprintf("%v", option.Value)) + } + + event := log.Info().Str("name", name).Str("user", utils.GetUser(interaction).Username).Dict("options", options) + + // If the command was invoked in a guild, add guild & channel info to the log + if interaction.Member != nil { + guild := zerolog.Dict() + guild.Str("id", interaction.GuildID) + guild.Str("name", utils.GetGuildName(b.Config, internalSession, interaction.GuildID)) + event.Dict("guild", guild) + + channel := zerolog.Dict() + channel.Str("id", interaction.ChannelID) + guild.Str("name", utils.GetChannelName(b.Config, internalSession, interaction.ChannelID)) + event.Dict("channel", channel) + } else { + // If the command was invoked in a DM, add the user info to the log + user := zerolog.Dict() + user.Str("id", interaction.User.ID) + user.Str("name", interaction.User.Username) + event.Dict("user", user) + } + + // Log command invocation + event.Msg("Command Invoked") + + // Prepare to recover + defer func() { + if err := recover(); err != nil { + log.Error().Stack().Str("commandName", name).Interface("detail", err).Msg("Command Handler Panic") + + // Respond with error + err := utils.RespondError(internalSession, interaction.Interaction, "Unexpected Error: command handler panic", nil) + if err != nil { + log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with panic error feedback") + } + } + }() + + // Call handler + err := handler(b, internalSession, interaction) + + // Log & respond error + if err != nil { + // TODO: Find a way to merge the response with the handler's error + log.Error().Str("commandName", name).Err(err).Msg("Command Handler Error") + + // Respond with error + err = utils.RespondError(internalSession, interaction.Interaction, fmt.Sprintf("Unexpected Error: %s", err.Error()), nil) + if err != nil { + log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with error feedback") + } + } + + } else { + log.Error().Stack().Str("commandName", name).Msg("Command Interaction Has No Handler") + + // Respond with error + utils.RespondError(internalSession, interaction.Interaction, "Unexpected Error: interaction has no handler", nil) + } + }) +} diff --git a/internal/config/config.go b/internal/config/config.go index d2d6622..545b629 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,57 +8,51 @@ import ( "github.com/redis/go-redis/v9" ) -var ( - // Global variables that need to be accessible across packages +type Config struct { Ctx context.Context KV *redis.Client - Client http.Client - Cookies http.CookieJar + Client *http.Client IsDevelopment bool BaseURL string Environment string CentralTimeLocation *time.Location - IsClosing bool = false -) +} const ( - ICalTimestampFormatUtc = "20060102T150405Z" - ICalTimestampFormatLocal = "20060102T150405" - CentralTimezoneName = "America/Chicago" + CentralTimezoneName = "America/Chicago" ) -func init() { - Ctx = context.Background() +func New() (*Config, error) { + ctx := context.Background() - var err error - CentralTimeLocation, err = time.LoadLocation(CentralTimezoneName) + loc, err := time.LoadLocation(CentralTimezoneName) if err != nil { - panic(err) + return nil, err } + + return &Config{ + Ctx: ctx, + CentralTimeLocation: loc, + }, nil } // SetBaseURL sets the base URL for API requests -func SetBaseURL(url string) { - BaseURL = url +func (c *Config) SetBaseURL(url string) { + c.BaseURL = url } // SetEnvironment sets the environment -func SetEnvironment(env string) { - Environment = env - IsDevelopment = env == "development" +func (c *Config) SetEnvironment(env string) { + c.Environment = env + c.IsDevelopment = env == "development" } // SetClient sets the HTTP client -func SetClient(c http.Client) { - Client = c -} - -// SetCookies sets the cookie jar -func SetCookies(cj http.CookieJar) { - Cookies = cj +func (c *Config) SetClient(client *http.Client) { + c.Client = client } // SetRedis sets the Redis client -func SetRedis(r *redis.Client) { - KV = r +func (c *Config) SetRedis(r *redis.Client) { + c.KV = r } diff --git a/internal/models/types.go b/internal/models/types.go index 579b263..d7c5017 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -1,7 +1,6 @@ package models import ( - "banner/internal/config" "banner/internal/utils" "encoding/json" "fmt" @@ -226,15 +225,17 @@ func (m *MeetingTimeResponse) EndTime() *utils.NaiveTime { return utils.ParseNaiveTime(value) } +type RRule struct { + Until string + ByDay string +} + // Converts the meeting time to a string that satisfies the iCalendar RRule format -func (m *MeetingTimeResponse) RRule() string { - sb := strings.Builder{} - - sb.WriteString("FREQ=WEEKLY;") - sb.WriteString(fmt.Sprintf("UNTIL=%s;", m.EndDay().UTC().Format(config.ICalTimestampFormatUtc))) - sb.WriteString(fmt.Sprintf("BYDAY=%s;", m.ByDay())) - - return sb.String() +func (m *MeetingTimeResponse) RRule() RRule { + return RRule{ + Until: m.EndDay().UTC().Format("20060102T150405Z"), + ByDay: m.ByDay(), + } } type SearchResult struct { diff --git a/internal/utils/helpers.go b/internal/utils/helpers.go index c8fb01b..aeb8916 100644 --- a/internal/utils/helpers.go +++ b/internal/utils/helpers.go @@ -20,10 +20,30 @@ import ( "banner/internal/config" ) +// Options is a map of options from a discord command. +type Options map[string]*discordgo.ApplicationCommandInteractionDataOption + +// GetInt returns the integer value of an option. +func (o Options) GetInt(key string) int64 { + if opt, ok := o[key]; ok { + return opt.IntValue() + } + return 0 +} + +// ParseOptions parses slash command options into a map. +func ParseOptions(options []*discordgo.ApplicationCommandInteractionDataOption) Options { + optionMap := make(Options) + for _, opt := range options { + optionMap[opt.Name] = opt + } + return optionMap +} + // BuildRequestWithBody builds a request with the given method, path, parameters, and body -func BuildRequestWithBody(method string, path string, params map[string]string, body io.Reader) *http.Request { +func BuildRequestWithBody(cfg *config.Config, method string, path string, params map[string]string, body io.Reader) *http.Request { // Builds a URL for the given path and parameters - requestUrl := config.BaseURL + path + requestUrl := cfg.BaseURL + path if params != nil { takenFirst := false @@ -44,8 +64,8 @@ func BuildRequestWithBody(method string, path string, params map[string]string, } // BuildRequest builds a request with the given method, path, and parameters and an empty body -func BuildRequest(method string, path string, params map[string]string) *http.Request { - return BuildRequestWithBody(method, path, params, nil) +func BuildRequest(cfg *config.Config, method string, path string, params map[string]string) *http.Request { + return BuildRequestWithBody(cfg, method, path, params, nil) } // AddUserAgent adds a false but consistent user agent to the request @@ -309,9 +329,9 @@ func RespondError(session *discordgo.Session, interaction *discordgo.Interaction }) } -func GetFetchedFooter(time time.Time) *discordgo.MessageEmbedFooter { +func GetFetchedFooter(cfg *config.Config, time time.Time) *discordgo.MessageEmbedFooter { return &discordgo.MessageEmbedFooter{ - Text: fmt.Sprintf("Fetched at %s", time.In(config.CentralTimeLocation).Format("Monday, January 2, 2006 at 3:04:05PM")), + Text: fmt.Sprintf("Fetched at %s", time.In(cfg.CentralTimeLocation).Format("Monday, January 2, 2006 at 3:04:05PM")), } } diff --git a/internal/utils/meta.go b/internal/utils/meta.go index bad899d..3e6457f 100644 --- a/internal/utils/meta.go +++ b/internal/utils/meta.go @@ -10,9 +10,9 @@ import ( ) // GetGuildName returns the name of the guild with the given ID, utilizing Redis to cache the value -func GetGuildName(session *discordgo.Session, guildID string) string { +func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string { // Check Redis for the guild name - guildName, err := config.KV.Get(config.Ctx, "guild:"+guildID+":name").Result() + guildName, err := cfg.KV.Get(cfg.Ctx, "guild:"+guildID+":name").Result() if err != nil && err != redis.Nil { log.Error().Stack().Err(err).Msg("Error getting guild name from Redis") return "err" @@ -29,7 +29,7 @@ func GetGuildName(session *discordgo.Session, guildID string) string { log.Error().Stack().Err(err).Msg("Error getting guild name") // Store an invalid value in Redis so we don't keep trying to get the guild name - _, err := config.KV.Set(config.Ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result() + _, err := cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result() if err != nil { log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis") } @@ -38,15 +38,15 @@ func GetGuildName(session *discordgo.Session, guildID string) string { } // Cache the guild name in Redis - config.KV.Set(config.Ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3) + cfg.KV.Set(cfg.Ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3) return guild.Name } // GetChannelName returns the name of the channel with the given ID, utilizing Redis to cache the value -func GetChannelName(session *discordgo.Session, channelID string) string { +func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string { // Check Redis for the channel name - channelName, err := config.KV.Get(config.Ctx, "channel:"+channelID+":name").Result() + channelName, err := cfg.KV.Get(cfg.Ctx, "channel:"+channelID+":name").Result() if err != nil && err != redis.Nil { log.Error().Stack().Err(err).Msg("Error getting channel name from Redis") return "err" @@ -63,7 +63,7 @@ func GetChannelName(session *discordgo.Session, channelID string) string { log.Error().Stack().Err(err).Msg("Error getting channel name") // Store an invalid value in Redis so we don't keep trying to get the channel name - _, err := config.KV.Set(config.Ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result() + _, err := cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result() if err != nil { log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis") } @@ -72,7 +72,7 @@ func GetChannelName(session *discordgo.Session, channelID string) string { } // Cache the channel name in Redis - config.KV.Set(config.Ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3) + cfg.KV.Set(cfg.Ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3) return channel.Name } diff --git a/internal/utils/term.go b/internal/utils/term.go index 1386ece..5fd79a8 100644 --- a/internal/utils/term.go +++ b/internal/utils/term.go @@ -30,7 +30,8 @@ var ( ) func init() { - SpringRange, SummerRange, FallRange = GetYearDayRange(uint16(time.Now().Year())) + loc, _ := time.LoadLocation(config.CentralTimezoneName) + SpringRange, SummerRange, FallRange = GetYearDayRange(loc, uint16(time.Now().Year())) currentTerm, nextTerm := GetCurrentTerm(time.Now()) log.Debug().Str("CurrentTerm", fmt.Sprintf("%+v", currentTerm)).Str("NextTerm", fmt.Sprintf("%+v", nextTerm)).Msg("GetCurrentTerm") @@ -46,13 +47,13 @@ type YearDayRange struct { // Spring: January 14th to May // Summer: May 25th - August 15th // Fall: August 18th - December 10th -func GetYearDayRange(year uint16) (YearDayRange, YearDayRange, YearDayRange) { - springStart := time.Date(int(year), time.January, 14, 0, 0, 0, 0, config.CentralTimeLocation).YearDay() - springEnd := time.Date(int(year), time.May, 1, 0, 0, 0, 0, config.CentralTimeLocation).YearDay() - summerStart := time.Date(int(year), time.May, 25, 0, 0, 0, 0, config.CentralTimeLocation).YearDay() - summerEnd := time.Date(int(year), time.August, 15, 0, 0, 0, 0, config.CentralTimeLocation).YearDay() - fallStart := time.Date(int(year), time.August, 18, 0, 0, 0, 0, config.CentralTimeLocation).YearDay() - fallEnd := time.Date(int(year), time.December, 10, 0, 0, 0, 0, config.CentralTimeLocation).YearDay() +func GetYearDayRange(loc *time.Location, year uint16) (YearDayRange, YearDayRange, YearDayRange) { + springStart := time.Date(int(year), time.January, 14, 0, 0, 0, 0, loc).YearDay() + springEnd := time.Date(int(year), time.May, 1, 0, 0, 0, 0, loc).YearDay() + summerStart := time.Date(int(year), time.May, 25, 0, 0, 0, 0, loc).YearDay() + summerEnd := time.Date(int(year), time.August, 15, 0, 0, 0, 0, loc).YearDay() + fallStart := time.Date(int(year), time.August, 18, 0, 0, 0, 0, loc).YearDay() + fallEnd := time.Date(int(year), time.December, 10, 0, 0, 0, 0, loc).YearDay() return YearDayRange{ Start: uint16(springStart),