refactor: complete refactor into cmd/ & internal/ submodules

This commit is contained in:
2025-08-25 22:57:05 -05:00
parent 2d25bb8921
commit b16c2d51bc
14 changed files with 333 additions and 243 deletions

6
.gitignore vendored
View File

@@ -1,8 +1,10 @@
.env
cover.cov
banner
./banner
.*.go
dumps/
js/
.vscode/
*.prof
*.prof
.task/
bin/

View File

@@ -22,12 +22,16 @@ import (
"github.com/rs/zerolog/pkgerrors"
"github.com/samber/lo"
"golang.org/x/text/message"
"banner/internal/api"
"banner/internal/bot"
"banner/internal/utils"
)
var (
ctx context.Context
kv *redis.Client
session *discordgo.Session
Session *discordgo.Session
client http.Client
cookies http.CookieJar
isDevelopment bool
@@ -66,7 +70,7 @@ func init() {
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
// Try to grab the environment variable, or default to development
environment = GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT")
environment = utils.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT")
if environment == "" {
environment = "development"
}
@@ -74,21 +78,21 @@ func init() {
// Use the custom console writer if we're in development
isDevelopment = environment == "development"
if isDevelopment {
log.Logger = zerolog.New(logSplitter{std: stdConsole, err: errConsole}).With().Timestamp().Logger()
log.Logger = zerolog.New(utils.LogSplitter{Std: os.Stdout, Err: os.Stderr}).With().Timestamp().Logger()
} else {
log.Logger = zerolog.New(logSplitter{std: os.Stdout, err: os.Stderr}).With().Timestamp().Logger()
log.Logger = zerolog.New(utils.LogSplitter{Std: os.Stdout, Err: os.Stderr}).With().Timestamp().Logger()
}
log.Debug().Str("environment", environment).Msg("Loggers Setup")
// Set discordgo's logger to use zerolog
discordgo.Logger = DiscordGoLogger
discordgo.Logger = utils.DiscordGoLogger
baseURL = os.Getenv("BANNER_BASE_URL")
}
func initRedis() {
// Setup redis
redisUrl := GetFirstEnv("REDIS_URL", "REDIS_PRIVATE_URL")
redisUrl := utils.GetFirstEnv("REDIS_URL", "REDIS_PRIVATE_URL")
if redisUrl == "" {
log.Fatal().Stack().Msg("REDIS_URL/REDIS_PRIVATE_URL not set")
}
@@ -160,28 +164,28 @@ func main() {
// Create client, setup session (acquire cookies)
client = http.Client{Jar: cookies}
setup()
api.Setup()
// Create discord session
session, err = discordgo.New("Bot " + os.Getenv("BOT_TOKEN"))
Session, err = discordgo.New("Bot " + os.Getenv("BOT_TOKEN"))
if err != nil {
log.Err(err).Msg("Invalid bot parameters")
}
// Open discord session
session.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) {
Session.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) {
log.Info().Str("username", r.User.Username).Str("discriminator", r.User.Discriminator).Str("id", r.User.ID).Str("session", s.State.SessionID).Msg("Bot is logged in")
})
err = session.Open()
err = Session.Open()
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot open the session")
}
// Setup command handlers
session.AddHandler(func(internalSession *discordgo.Session, interaction *discordgo.InteractionCreate) {
Session.AddHandler(func(internalSession *discordgo.Session, interaction *discordgo.InteractionCreate) {
// Handle commands during restart (highly unlikely, but just in case)
if isClosing {
err := RespondError(internalSession, interaction.Interaction, "Bot is currently restarting, try again later.", nil)
err := utils.RespondError(internalSession, interaction.Interaction, "Bot is currently restarting, try again later.", nil)
if err != nil {
log.Error().Err(err).Msg("Failed to respond with restart error feedback")
}
@@ -189,25 +193,25 @@ func main() {
}
name := interaction.ApplicationCommandData().Name
if handler, ok := commandHandlers[name]; ok {
if handler, ok := bot.CommandHandlers[name]; ok {
// Build dict of options for the log
options := zerolog.Dict()
for _, option := range interaction.ApplicationCommandData().Options {
options.Str(option.Name, fmt.Sprintf("%v", option.Value))
}
event := log.Info().Str("name", name).Str("user", GetUser(interaction).Username).Dict("options", options)
event := log.Info().Str("name", name).Str("user", utils.GetUser(interaction).Username).Dict("options", options)
// If the command was invoked in a guild, add guild & channel info to the log
if interaction.Member != nil {
guild := zerolog.Dict()
guild.Str("id", interaction.GuildID)
guild.Str("name", GetGuildName(interaction.GuildID))
guild.Str("name", utils.GetGuildName(internalSession, interaction.GuildID))
event.Dict("guild", guild)
channel := zerolog.Dict()
channel.Str("id", interaction.ChannelID)
guild.Str("name", GetChannelName(interaction.ChannelID))
guild.Str("name", utils.GetChannelName(internalSession, interaction.ChannelID))
event.Dict("channel", channel)
} else {
// If the command was invoked in a DM, add the user info to the log
@@ -226,7 +230,7 @@ func main() {
log.Error().Stack().Str("commandName", name).Interface("detail", err).Msg("Command Handler Panic")
// Respond with error
err := RespondError(internalSession, interaction.Interaction, "Unexpected Error: command handler panic", nil)
err := utils.RespondError(internalSession, interaction.Interaction, "Unexpected Error: command handler panic", nil)
if err != nil {
log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with panic error feedback")
}
@@ -242,7 +246,7 @@ func main() {
log.Error().Str("commandName", name).Err(err).Msg("Command Handler Error")
// Respond with error
err = RespondError(internalSession, interaction.Interaction, fmt.Sprintf("Unexpected Error: %s", err.Error()), nil)
err = utils.RespondError(internalSession, interaction.Interaction, fmt.Sprintf("Unexpected Error: %s", err.Error()), nil)
if err != nil {
log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with error feedback")
}
@@ -252,13 +256,13 @@ func main() {
log.Error().Stack().Str("commandName", name).Msg("Command Interaction Has No Handler")
// Respond with error
RespondError(internalSession, interaction.Interaction, "Unexpected Error: interaction has no handler", nil)
utils.RespondError(internalSession, interaction.Interaction, "Unexpected Error: interaction has no handler", nil)
}
})
// Register commands with discord
arr := zerolog.Arr()
lo.ForEach(commandDefinitions, func(cmd *discordgo.ApplicationCommand, _ int) {
lo.ForEach(bot.CommandDefinitions, func(cmd *discordgo.ApplicationCommand, _ int) {
arr.Str(cmd.Name)
})
log.Info().Array("commands", arr).Msg("Registering commands")
@@ -270,11 +274,11 @@ func main() {
}
// Register commands
existingCommands, err := session.ApplicationCommands(session.State.User.ID, guildTarget)
existingCommands, err := Session.ApplicationCommands(Session.State.User.ID, guildTarget)
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot get existing commands")
}
newCommands, err := session.ApplicationCommandBulkOverwrite(session.State.User.ID, guildTarget, commandDefinitions)
newCommands, err := Session.ApplicationCommandBulkOverwrite(Session.State.User.ID, guildTarget, bot.CommandDefinitions)
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot register commands")
}
@@ -300,7 +304,7 @@ func main() {
}
// Fetch terms on startup
err = TryReloadTerms()
err = api.TryReloadTerms()
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot fetch terms on startup")
}
@@ -308,7 +312,7 @@ func main() {
// Launch a goroutine to scrape the banner system periodically
go func() {
for {
err := Scrape()
err := api.Scrape()
if err != nil {
log.Err(err).Stack().Msg("Periodic Scrape Failed")
}
@@ -318,7 +322,7 @@ func main() {
}()
// Close session, ensure http client closes idle connections
defer session.Close()
defer Session.Close()
defer client.CloseIdleConnections()
// Setup signal handler channel

View File

@@ -1,11 +1,15 @@
package main
package api
import (
"banner/internal/config"
"banner/internal/models"
"banner/internal/utils"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
@@ -13,7 +17,9 @@ import (
"time"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
)
var (
@@ -34,7 +40,101 @@ func ResetSessionTimer() {
// GenerateSession generates a new session ID (nonce) for use with the Banner API.
// Don't use this function directly, use GetSession instead.
func GenerateSession() string {
return RandomString(5) + Nonce()
return utils.RandomString(5) + utils.Nonce()
}
// DoRequest performs & logs the request, logging and returning the response
func DoRequest(req *http.Request) (*http.Response, error) {
headerSize := 0
for key, values := range req.Header {
for _, value := range values {
headerSize += len(key)
headerSize += len(value)
}
}
bodySize := int64(0)
if req.Body != nil {
bodySize, _ = io.Copy(io.Discard, req.Body)
}
size := zerolog.Dict().Int64("body", bodySize).Int("header", headerSize).Int("url", len(req.URL.String()))
log.Debug().
Dict("size", size).
Str("method", strings.TrimRight(req.Method, " ")).
Str("url", req.URL.String()).
Str("query", req.URL.RawQuery).
Str("content-type", req.Header.Get("Content-Type")).
Msg("Request")
res, err := config.Client.Do(req)
if err != nil {
log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed")
} else {
contentLengthHeader := res.Header.Get("Content-Length")
contentLength := int64(-1)
// If this request was a Banner API request, reset the session timer
if strings.HasPrefix(req.URL.Path, "StudentRegistrationSsb/ssb/classSearch/") {
ResetSessionTimer()
}
// Get the content length
if contentLengthHeader != "" {
contentLength, err = strconv.ParseInt(contentLengthHeader, 10, 64)
if err != nil {
contentLength = -1
}
}
log.Debug().Int("status", res.StatusCode).Int64("content-length", contentLength).Strs("content-type", res.Header["Content-Type"]).Msg("Response")
}
return res, err
}
var terms []BannerTerm
var lastTermUpdate time.Time
// TryReloadTerms attempts to reload the terms if they are not loaded or the last update was more than 24 hours ago
func TryReloadTerms() error {
if len(terms) > 0 && time.Since(lastTermUpdate) < 24*time.Hour {
return nil
}
// Load the terms
var err error
terms, err = GetTerms("", 1, 100)
if err != nil {
return fmt.Errorf("failed to load terms: %w", err)
}
lastTermUpdate = time.Now()
return nil
}
// IsTermArchived checks if the given term is archived
// TODO: Add error, switch missing term logic to error
func IsTermArchived(term string) bool {
// Ensure the terms are loaded
err := TryReloadTerms()
if err != nil {
log.Err(err).Stack().Msg("Failed to reload terms")
return true
}
// Check if the term is in the list of terms
bannerTerm, exists := lo.Find(terms, func(t BannerTerm) bool {
return t.Code == term
})
if !exists {
log.Warn().Str("term", term).Msg("Term does not exist")
return true
}
return bannerTerm.Archived()
}
// GetSession retrieves the current session ID if it's still valid.
@@ -47,7 +147,7 @@ func GetSession() string {
latestSession = GenerateSession()
// Select the current term
term := Default(time.Now()).ToString()
term := utils.Default(time.Now()).ToString()
log.Info().Str("term", term).Str("sessionID", latestSession).Msg("Setting selected term")
err := SelectTerm(term, latestSession)
if err != nil {
@@ -81,12 +181,12 @@ func GetTerms(search string, page int, max int) ([]BannerTerm, error) {
return nil, errors.New("offset must be greater than 0")
}
req := BuildRequest("GET", "/classSearch/getTerms", map[string]string{
req := utils.BuildRequest("GET", "/classSearch/getTerms", map[string]string{
"searchTerm": search,
// Page vs Offset is not a mistake here, the API uses "offset" as the page number
"offset": strconv.Itoa(page),
"max": strconv.Itoa(max),
"_": Nonce(),
"_": utils.Nonce(),
})
if page <= 0 {
@@ -99,9 +199,9 @@ func GetTerms(search string, page int, max int) ([]BannerTerm, error) {
}
// Assert that the response is JSON
if contentType := res.Header.Get("Content-Type"); !strings.Contains(contentType, JsonContentType) {
return nil, &UnexpectedContentTypeError{
Expected: JsonContentType,
if contentType := res.Header.Get("Content-Type"); !strings.Contains(contentType, models.JsonContentType) {
return nil, &utils.UnexpectedContentTypeError{
Expected: models.JsonContentType,
Actual: contentType,
}
}
@@ -139,7 +239,7 @@ func SelectTerm(term string, sessionId string) error {
"mode": "search",
}
req := BuildRequestWithBody("POST", "/term/search", params, bytes.NewBufferString(form.Encode()))
req := utils.BuildRequestWithBody("POST", "/term/search", params, bytes.NewBufferString(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
res, err := DoRequest(req)
@@ -148,7 +248,7 @@ func SelectTerm(term string, sessionId string) error {
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
return fmt.Errorf("response was not JSON: %w", res.Header.Get("Content-Type"))
}
@@ -165,7 +265,7 @@ func SelectTerm(term string, sessionId string) error {
json.Unmarshal(body, &redirectResponse)
// Make a GET request to the fwdUrl
req = BuildRequest("GET", redirectResponse.FwdUrl, nil)
req = utils.BuildRequest("GET", redirectResponse.FwdUrl, nil)
res, err = DoRequest(req)
if err != nil {
return fmt.Errorf("failed to follow redirect: %w", err)
@@ -187,13 +287,13 @@ func GetPartOfTerms(search string, term int, offset int, max int) ([]BannerTerm,
return nil, errors.New("offset must be greater than 0")
}
req := BuildRequest("GET", "/classSearch/get_partOfTerm", map[string]string{
req := utils.BuildRequest("GET", "/classSearch/get_partOfTerm", map[string]string{
"searchTerm": search,
"term": strconv.Itoa(term),
"offset": strconv.Itoa(offset),
"max": strconv.Itoa(max),
"uniqueSessionId": GetSession(),
"_": Nonce(),
"_": utils.Nonce(),
})
res, err := DoRequest(req)
@@ -202,7 +302,7 @@ func GetPartOfTerms(search string, term int, offset int, max int) ([]BannerTerm,
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Panic().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -231,13 +331,13 @@ func GetInstructors(search string, term string, offset int, max int) ([]Instruct
return nil, errors.New("offset must be greater than 0")
}
req := BuildRequest("GET", "/classSearch/get_instructor", map[string]string{
req := utils.BuildRequest("GET", "/classSearch/get_instructor", map[string]string{
"searchTerm": search,
"term": term,
"offset": strconv.Itoa(offset),
"max": strconv.Itoa(max),
"uniqueSessionId": GetSession(),
"_": Nonce(),
"_": utils.Nonce(),
})
res, err := DoRequest(req)
@@ -246,7 +346,7 @@ func GetInstructors(search string, term string, offset int, max int) ([]Instruct
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Fatal().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -279,7 +379,7 @@ func GetCourseDetails(term int, crn int) *ClassDetails {
if err != nil {
log.Fatal().Stack().Err(err).Msg("Failed to marshal body")
}
req := BuildRequestWithBody("GET", "/searchResults/getClassDetails", nil, bytes.NewBuffer(body))
req := utils.BuildRequestWithBody("GET", "/searchResults/getClassDetails", nil, bytes.NewBuffer(body))
res, err := DoRequest(req)
if err != nil {
@@ -287,7 +387,7 @@ func GetCourseDetails(term int, crn int) *ClassDetails {
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Fatal().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -295,7 +395,7 @@ func GetCourseDetails(term int, crn int) *ClassDetails {
}
// Search invokes a search on the Banner system with the given query and returns the results.
func Search(query *Query, sort string, sortDescending bool) (*SearchResult, error) {
func Search(query *Query, sort string, sortDescending bool) (*models.SearchResult, error) {
ResetDataForm()
params := query.Paramify()
@@ -309,7 +409,7 @@ func Search(query *Query, sort string, sortDescending bool) (*SearchResult, erro
params["startDatepicker"] = ""
params["endDatepicker"] = ""
req := BuildRequest("GET", "/searchResults/searchResults", params)
req := utils.BuildRequest("GET", "/searchResults/searchResults", params)
res, err := DoRequest(req)
if err != nil {
@@ -321,7 +421,7 @@ func Search(query *Query, sort string, sortDescending bool) (*SearchResult, erro
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
// for server 500 errors, parse for the error with '#dialog-message > div.message'
log.Error().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -332,7 +432,7 @@ func Search(query *Query, sort string, sortDescending bool) (*SearchResult, erro
return nil, fmt.Errorf("failed to read response body: %w", err)
}
var result SearchResult
var result models.SearchResult
err = json.Unmarshal(body, &result)
if err != nil {
@@ -351,13 +451,13 @@ func GetSubjects(search string, term string, offset int, max int) ([]Pair, error
return nil, errors.New("offset must be greater than 0")
}
req := BuildRequest("GET", "/classSearch/get_subject", map[string]string{
req := utils.BuildRequest("GET", "/classSearch/get_subject", map[string]string{
"searchTerm": search,
"term": term,
"offset": strconv.Itoa(offset),
"max": strconv.Itoa(max),
"uniqueSessionId": GetSession(),
"_": Nonce(),
"_": utils.Nonce(),
})
res, err := DoRequest(req)
@@ -366,7 +466,7 @@ func GetSubjects(search string, term string, offset int, max int) ([]Pair, error
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Fatal().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -395,13 +495,13 @@ func GetCampuses(search string, term int, offset int, max int) ([]Pair, error) {
return nil, errors.New("offset must be greater than 0")
}
req := BuildRequest("GET", "/classSearch/get_campus", map[string]string{
req := utils.BuildRequest("GET", "/classSearch/get_campus", map[string]string{
"searchTerm": search,
"term": strconv.Itoa(term),
"offset": strconv.Itoa(offset),
"max": strconv.Itoa(max),
"uniqueSessionId": GetSession(),
"_": Nonce(),
"_": utils.Nonce(),
})
res, err := DoRequest(req)
@@ -410,7 +510,7 @@ func GetCampuses(search string, term int, offset int, max int) ([]Pair, error) {
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Fatal().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -439,13 +539,13 @@ func GetInstructionalMethods(search string, term string, offset int, max int) ([
return nil, errors.New("offset must be greater than 0")
}
req := BuildRequest("GET", "/classSearch/get_instructionalMethod", map[string]string{
req := utils.BuildRequest("GET", "/classSearch/get_instructionalMethod", map[string]string{
"searchTerm": search,
"term": term,
"offset": strconv.Itoa(offset),
"max": strconv.Itoa(max),
"uniqueSessionId": GetSession(),
"_": Nonce(),
"_": utils.Nonce(),
})
res, err := DoRequest(req)
@@ -454,7 +554,7 @@ func GetInstructionalMethods(search string, term string, offset int, max int) ([
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Fatal().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -473,8 +573,8 @@ func GetInstructionalMethods(search string, term string, offset int, max int) ([
// GetCourseMeetingTime retrieves the meeting time information for a course based on the given term and course reference number (CRN).
// It makes an HTTP GET request to the appropriate API endpoint and parses the response to extract the meeting time data.
// The function returns a MeetingTimeResponse struct containing the extracted information.
func GetCourseMeetingTime(term int, crn int) ([]MeetingTimeResponse, error) {
req := BuildRequest("GET", "/searchResults/getFacultyMeetingTimes", map[string]string{
func GetCourseMeetingTime(term int, crn int) ([]models.MeetingTimeResponse, error) {
req := utils.BuildRequest("GET", "/searchResults/getFacultyMeetingTimes", map[string]string{
"term": strconv.Itoa(term),
"courseReferenceNumber": strconv.Itoa(crn),
})
@@ -485,7 +585,7 @@ func GetCourseMeetingTime(term int, crn int) ([]MeetingTimeResponse, error) {
}
// Assert that the response is JSON
if !ContentTypeMatch(res, "application/json") {
if !utils.ContentTypeMatch(res, "application/json") {
log.Fatal().Stack().Str("content-type", res.Header.Get("Content-Type")).Msg("Response was not JSON")
}
@@ -498,7 +598,7 @@ func GetCourseMeetingTime(term int, crn int) ([]MeetingTimeResponse, error) {
// Parse the JSON into a MeetingTimeResponse struct
var meetingTime struct {
Inner []MeetingTimeResponse `json:"fmt"`
Inner []models.MeetingTimeResponse `json:"fmt"`
}
err = json.Unmarshal(body, &meetingTime)
if err != nil {
@@ -510,7 +610,7 @@ func GetCourseMeetingTime(term int, crn int) ([]MeetingTimeResponse, error) {
// ResetDataForm makes a POST request that needs to be made upon before new search requests can be made.
func ResetDataForm() {
req := BuildRequest("POST", "/classSearch/resetDataForm", nil)
req := utils.BuildRequest("POST", "/classSearch/resetDataForm", nil)
_, err := DoRequest(req)
if err != nil {
log.Fatal().Stack().Err(err).Msg("Failed to reset data form")
@@ -519,9 +619,9 @@ func ResetDataForm() {
// GetCourse retrieves the course information.
// This course does not retrieve directly from the API, but rather uses scraped data stored in Redis.
func GetCourse(crn string) (*Course, error) {
func GetCourse(crn string) (*models.Course, error) {
// Retrieve raw data
result, err := kv.Get(ctx, fmt.Sprintf("class:%s", crn)).Result()
result, err := config.KV.Get(config.Ctx, fmt.Sprintf("class:%s", crn)).Result()
if err != nil {
if err == redis.Nil {
return nil, fmt.Errorf("course not found: %w", err)
@@ -530,7 +630,7 @@ func GetCourse(crn string) (*Course, error) {
}
// Unmarshal the raw data
var course Course
var course models.Course
err = json.Unmarshal([]byte(result), &course)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal course: %w", err)

View File

@@ -1,6 +1,9 @@
package main
package api
import (
"banner/internal/config"
"banner/internal/models"
"banner/internal/utils"
"fmt"
"math/rand"
"time"
@@ -26,7 +29,7 @@ var (
func Scrape() error {
// Populate AllMajors if it is empty
if len(AncillaryMajors) == 0 {
term := Default(time.Now()).ToString()
term := utils.Default(time.Now()).ToString()
subjects, err := GetSubjects("", term, 1, 99)
if err != nil {
return fmt.Errorf("failed to get subjects: %w", err)
@@ -66,11 +69,11 @@ func Scrape() error {
// GetExpiredSubjects returns a list of subjects that are expired and should be scraped.
func GetExpiredSubjects() ([]string, error) {
term := Default(time.Now()).ToString()
term := utils.Default(time.Now()).ToString()
subjects := make([]string, 0)
// Get all subjects
values, err := kv.MGet(ctx, lo.Map(AllMajors, func(major string, _ int) string {
values, err := config.KV.MGet(config.Ctx, lo.Map(AllMajors, func(major string, _ int) string {
return fmt.Sprintf("scraped:%s:%s", major, term)
})...).Result()
if err != nil {
@@ -144,7 +147,7 @@ func ScrapeMajor(subject string) error {
}
}
term := Default(time.Now()).ToString()
term := utils.Default(time.Now()).ToString()
// Calculate the expiry time for the scrape (1 hour for every 200 classes, random +-15%) with a minimum of 1 hour
var scrapeExpiry time.Duration
@@ -158,7 +161,7 @@ func ScrapeMajor(subject string) error {
if totalClassCount == 0 {
totalClassCount = -1
}
err := kv.Set(ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
err := config.KV.Set(config.Ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
if err != nil {
log.Error().Err(err).Msg("failed to mark major as scraped")
}
@@ -177,7 +180,7 @@ func CalculateExpiry(term string, count int, priority bool) time.Duration {
// Subjects with less than 50 classes have a reversed expiry (less classes, longer interval)
// 1 class => 12 hours, 49 classes => 1 hour
if count < 50 {
hours := Slope(Point{1, 12}, Point{49, 1}, float64(count)).Y
hours := utils.Slope(utils.Point{1, 12}, utils.Point{49, 1}, float64(count)).Y
baseExpiry = time.Duration(hours * float64(time.Hour))
}
@@ -210,8 +213,8 @@ func CalculateExpiry(term string, count int, priority bool) time.Duration {
// IntakeCourse stores a course in Redis.
// This function is mostly a stub for now, but will be used to handle change identification, notifications, and SQLite upserts in the future.
func IntakeCourse(course Course) error {
err := kv.Set(ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
func IntakeCourse(course models.Course) error {
err := config.KV.Set(config.Ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
if err != nil {
return fmt.Errorf("failed to store class in Redis: %w", err)
}

View File

@@ -1,4 +1,4 @@
package main
package api
import (
"fmt"

View File

@@ -1,12 +1,14 @@
package main
package api
import (
"banner/internal/config"
"banner/internal/utils"
"net/url"
log "github.com/rs/zerolog/log"
)
func setup() {
func Setup() {
// Makes the initial requests that sets up the session cookies for the rest of the application
log.Info().Msg("Setting up session...")
@@ -16,17 +18,17 @@ func setup() {
}
for _, path := range request_queue {
req := BuildRequest("GET", path, nil)
req := utils.BuildRequest("GET", path, nil)
DoRequest(req)
}
// Validate that cookies were set
baseUrlParsed, err := url.Parse(baseURL)
baseUrlParsed, err := url.Parse(config.BaseURL)
if err != nil {
log.Fatal().Stack().Str("baseURL", baseURL).Err(err).Msg("Failed to parse baseURL")
log.Fatal().Stack().Str("baseURL", config.BaseURL).Err(err).Msg("Failed to parse baseURL")
}
current_cookies := client.Jar.Cookies(baseUrlParsed)
current_cookies := config.Client.Jar.Cookies(baseUrlParsed)
required_cookies := map[string]bool{
"JSESSIONID": false,
"SSB_COOKIE": false,

View File

@@ -1,6 +1,10 @@
package main
package bot
import (
"banner/internal/api"
"banner/internal/config"
"banner/internal/models"
"banner/internal/utils"
"fmt"
"net/url"
"regexp"
@@ -15,8 +19,8 @@ import (
)
var (
commandDefinitions = []*discordgo.ApplicationCommand{TermCommandDefinition, TimeCommandDefinition, SearchCommandDefinition, IcsCommandDefinition}
commandHandlers = map[string]func(s *discordgo.Session, i *discordgo.InteractionCreate) error{
CommandDefinitions = []*discordgo.ApplicationCommand{TermCommandDefinition, TimeCommandDefinition, SearchCommandDefinition, IcsCommandDefinition}
CommandHandlers = map[string]func(s *discordgo.Session, i *discordgo.InteractionCreate) error{
TimeCommandDefinition.Name: TimeCommandHandler,
TermCommandDefinition.Name: TermCommandHandler,
SearchCommandDefinition.Name: SearchCommandHandler,
@@ -30,7 +34,7 @@ var SearchCommandDefinition = &discordgo.ApplicationCommand{
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
MinLength: GetIntPointer(0),
MinLength: utils.GetIntPointer(0),
MaxLength: 48,
Name: "title",
Description: "Course Title (exact, use autocomplete)",
@@ -40,7 +44,7 @@ var SearchCommandDefinition = &discordgo.ApplicationCommand{
{
Type: discordgo.ApplicationCommandOptionString,
Name: "code",
MinLength: GetIntPointer(4),
MinLength: utils.GetIntPointer(4),
Description: "Course Code (e.g. 3743, 3000-3999, 3xxx, 3000-)",
Required: false,
},
@@ -74,7 +78,7 @@ var SearchCommandDefinition = &discordgo.ApplicationCommand{
func SearchCommandHandler(session *discordgo.Session, interaction *discordgo.InteractionCreate) error {
data := interaction.ApplicationCommandData()
query := NewQuery().Credits(3, 6)
query := api.NewQuery().Credits(3, 6)
for _, option := range data.Options {
switch option.Name {
@@ -173,7 +177,7 @@ func SearchCommandHandler(session *discordgo.Session, interaction *discordgo.Int
}
}
courses, err := Search(query, "", false)
courses, err := api.Search(query, "", false)
if err != nil {
session.InteractionRespond(interaction.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@@ -223,8 +227,8 @@ func SearchCommandHandler(session *discordgo.Session, interaction *discordgo.Int
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: GetFetchedFooter(fetch_time),
Description: p.Sprintf("%d Class%s", courses.TotalCount, Plurale(courses.TotalCount)),
Footer: utils.GetFetchedFooter(fetch_time),
Description: fmt.Sprintf("%d Class%s", courses.TotalCount, utils.Plural(courses.TotalCount)),
Fields: fields[:min(25, len(fields))],
Color: color,
},
@@ -242,7 +246,7 @@ var TermCommandDefinition = &discordgo.ApplicationCommand{
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
MinLength: GetIntPointer(0),
MinLength: utils.GetIntPointer(0),
MaxLength: 8,
Name: "search",
Description: "Term to search for",
@@ -253,7 +257,7 @@ var TermCommandDefinition = &discordgo.ApplicationCommand{
Name: "page",
Description: "Page Number",
Required: false,
MinValue: GetFloatPointer(1),
MinValue: utils.GetFloatPointer(1),
},
},
}
@@ -275,10 +279,10 @@ func TermCommandHandler(session *discordgo.Session, interaction *discordgo.Inter
}
}
termResult, err := GetTerms(searchTerm, pageNumber, 25)
termResult, err := api.GetTerms(searchTerm, pageNumber, 25)
if err != nil {
RespondError(session, interaction.Interaction, "Error while fetching terms", err)
utils.RespondError(session, interaction.Interaction, "Error while fetching terms", err)
return err
}
@@ -303,8 +307,8 @@ func TermCommandHandler(session *discordgo.Session, interaction *discordgo.Inter
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: GetFetchedFooter(fetch_time),
Description: p.Sprintf("%d of %d term%s (page %d)", len(termResult), len(terms), Plural(len(terms)), pageNumber),
Footer: utils.GetFetchedFooter(fetch_time),
Description: fmt.Sprintf("%d term%s (page %d)", len(termResult), utils.Plural(len(termResult)), pageNumber),
Fields: fields[:min(25, len(fields))],
},
},
@@ -333,7 +337,7 @@ func TimeCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) er
crn := i.ApplicationCommandData().Options[0].IntValue()
// Fix static term
meetingTimes, err := GetCourseMeetingTime(202510, int(crn))
meetingTimes, err := api.GetCourseMeetingTime(202510, int(crn))
if err != nil {
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@@ -352,7 +356,7 @@ func TimeCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) er
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: GetFetchedFooter(fetch_time),
Footer: utils.GetFetchedFooter(fetch_time),
Description: "",
Fields: []*discordgo.MessageEmbedField{
{
@@ -369,7 +373,7 @@ func TimeCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) er
},
{
Name: "Days of Week",
Value: WeekdaysToString(meetingTime.Days()),
Value: utils.WeekdaysToString(meetingTime.Days()),
},
},
},
@@ -396,13 +400,13 @@ var IcsCommandDefinition = &discordgo.ApplicationCommand{
func IcsCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) error {
crn := i.ApplicationCommandData().Options[0].IntValue()
course, err := GetCourse(strconv.Itoa(int(crn)))
course, err := api.GetCourse(strconv.Itoa(int(crn)))
if err != nil {
return fmt.Errorf("Error retrieving course data: %w", err)
}
// Fix static term
meetingTimes, err := GetCourseMeetingTime(202510, int(crn))
meetingTimes, err := api.GetCourseMeetingTime(202510, int(crn))
if err != nil {
return fmt.Errorf("Error requesting meeting time: %w", err)
}
@@ -412,7 +416,7 @@ func IcsCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) err
}
// Check if the course has any meeting times
_, exists := lo.Find(meetingTimes, func(mt MeetingTimeResponse) bool {
_, exists := lo.Find(meetingTimes, func(mt models.MeetingTimeResponse) bool {
switch mt.MeetingTime.MeetingType {
case "ID", "OA":
return false
@@ -423,23 +427,23 @@ func IcsCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) err
if !exists {
log.Warn().Str("crn", course.CourseReferenceNumber).Msg("Non-meeting course requested for ICS file")
RespondError(s, i.Interaction, "The course requested does not meet at a defined moment in time.", nil)
utils.RespondError(s, i.Interaction, "The course requested does not meet at a defined moment in time.", nil)
return nil
}
events := []string{}
for _, meeting := range meetingTimes {
now := time.Now().In(CentralTimeLocation)
now := time.Now().In(config.CentralTimeLocation)
uid := fmt.Sprintf("%d-%s@ical.banner.xevion.dev", now.Unix(), meeting.CourseReferenceNumber)
startDay := meeting.StartDay()
startTime := meeting.StartTime()
endTime := meeting.EndTime()
dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, CentralTimeLocation)
dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, CentralTimeLocation)
dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, config.CentralTimeLocation)
dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, config.CentralTimeLocation)
endDay := meeting.EndDay()
until := time.Date(endDay.Year(), endDay.Month(), endDay.Day(), 23, 59, 59, 0, CentralTimeLocation)
until := time.Date(endDay.Year(), endDay.Month(), endDay.Day(), 23, 59, 59, 0, config.CentralTimeLocation)
summary := fmt.Sprintf("%s %s %s", course.Subject, course.CourseNumber, course.CourseTitle)
description := fmt.Sprintf("Instructor: %s\nSection: %s\nCRN: %s", course.Faculty[0].DisplayName, course.SequenceNumber, meeting.CourseReferenceNumber)
@@ -454,7 +458,7 @@ DTEND;TZID=America/Chicago:%s
SUMMARY:%s
DESCRIPTION:%s
LOCATION:%s
END:VEVENT`, now.Format(ICalTimestampFormatLocal), uid, dtStart.Format(ICalTimestampFormatLocal), meeting.ByDay(), until.Format(ICalTimestampFormatLocal), dtEnd.Format(ICalTimestampFormatLocal), summary, strings.Replace(description, "\n", `\n`, -1), location)
END:VEVENT`, now.Format(config.ICalTimestampFormatLocal), uid, dtStart.Format(config.ICalTimestampFormatLocal), meeting.ByDay(), until.Format(config.ICalTimestampFormatLocal), dtEnd.Format(config.ICalTimestampFormatLocal), summary, strings.Replace(description, "\n", `\n`, -1), location)
events = append(events, event)
}
@@ -489,7 +493,7 @@ CALSCALE:GREGORIAN
%s
END:VCALENDAR`, vTimezone, strings.Join(events, "\n"))
session.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Files: []*discordgo.File{

64
internal/config/config.go Normal file
View File

@@ -0,0 +1,64 @@
package config
import (
"context"
"net/http"
"time"
"github.com/redis/go-redis/v9"
)
var (
// Global variables that need to be accessible across packages
Ctx context.Context
KV *redis.Client
Client http.Client
Cookies http.CookieJar
IsDevelopment bool
BaseURL string
Environment string
CentralTimeLocation *time.Location
IsClosing bool = false
)
const (
ICalTimestampFormatUtc = "20060102T150405Z"
ICalTimestampFormatLocal = "20060102T150405"
CentralTimezoneName = "America/Chicago"
)
func init() {
Ctx = context.Background()
var err error
CentralTimeLocation, err = time.LoadLocation(CentralTimezoneName)
if err != nil {
panic(err)
}
}
// SetBaseURL sets the base URL for API requests
func SetBaseURL(url string) {
BaseURL = url
}
// SetEnvironment sets the environment
func SetEnvironment(env string) {
Environment = env
IsDevelopment = env == "development"
}
// SetClient sets the HTTP client
func SetClient(c http.Client) {
Client = c
}
// SetCookies sets the cookie jar
func SetCookies(cj http.CookieJar) {
Cookies = cj
}
// SetRedis sets the Redis client
func SetRedis(r *redis.Client) {
KV = r
}

View File

@@ -1,6 +1,8 @@
package main
package models
import (
"banner/internal/config"
"banner/internal/utils"
"encoding/json"
"fmt"
"strconv"
@@ -113,7 +115,7 @@ func (m *MeetingTimeResponse) TimeString() string {
return "???"
}
return fmt.Sprintf("%s %s-%s", WeekdaysToString(m.Days()), m.StartTime().String(), m.EndTime().String())
return fmt.Sprintf("%s %s-%s", utils.WeekdaysToString(m.Days()), m.StartTime().String(), m.EndTime().String())
}
// PlaceString returns a formatted string best representing the place of the meeting time
@@ -194,7 +196,7 @@ func (m *MeetingTimeResponse) EndDay() time.Time {
// StartTime returns the start time of the meeting time as a NaiveTime object
// This is not cached and is parsed on each invocation. It may also panic without handling.
func (m *MeetingTimeResponse) StartTime() *NaiveTime {
func (m *MeetingTimeResponse) StartTime() *utils.NaiveTime {
raw := m.MeetingTime.BeginTime
if raw == "" {
log.Panic().Stack().Msg("Start time is empty")
@@ -205,12 +207,12 @@ func (m *MeetingTimeResponse) StartTime() *NaiveTime {
log.Panic().Stack().Err(err).Str("raw", raw).Msg("Cannot parse start time integer")
}
return ParseNaiveTime(value)
return utils.ParseNaiveTime(value)
}
// EndTime returns the end time of the meeting time as a NaiveTime object
// This is not cached and is parsed on each invocation. It may also panic without handling.
func (m *MeetingTimeResponse) EndTime() *NaiveTime {
func (m *MeetingTimeResponse) EndTime() *utils.NaiveTime {
raw := m.MeetingTime.EndTime
if raw == "" {
return nil
@@ -221,7 +223,7 @@ func (m *MeetingTimeResponse) EndTime() *NaiveTime {
log.Panic().Stack().Err(err).Str("raw", raw).Msg("Cannot parse end time integer")
}
return ParseNaiveTime(value)
return utils.ParseNaiveTime(value)
}
// Converts the meeting time to a string that satisfies the iCalendar RRule format
@@ -229,7 +231,7 @@ func (m *MeetingTimeResponse) RRule() string {
sb := strings.Builder{}
sb.WriteString("FREQ=WEEKLY;")
sb.WriteString(fmt.Sprintf("UNTIL=%s;", m.EndDay().UTC().Format(ICalTimestampFormatUtc)))
sb.WriteString(fmt.Sprintf("UNTIL=%s;", m.EndDay().UTC().Format(config.ICalTimestampFormatUtc)))
sb.WriteString(fmt.Sprintf("BYDAY=%s;", m.ByDay()))
return sb.String()

View File

@@ -1,4 +1,4 @@
package main
package utils
import "fmt"

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"fmt"
@@ -14,16 +14,16 @@ import (
"time"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
"github.com/rs/zerolog"
log "github.com/rs/zerolog/log"
"github.com/samber/lo"
"banner/internal/config"
)
// BuildRequestWithBody builds a request with the given method, path, parameters, and body
func BuildRequestWithBody(method string, path string, params map[string]string, body io.Reader) *http.Request {
// Builds a URL for the given path and parameters
requestUrl := baseURL + path
requestUrl := config.BaseURL + path
if params != nil {
takenFirst := false
@@ -112,57 +112,6 @@ func Nonce() string {
return strconv.Itoa(int(time.Now().UnixMilli()))
}
// DoRequest performs & logs the request, logging and returning the response
func DoRequest(req *http.Request) (*http.Response, error) {
headerSize := 0
for key, values := range req.Header {
for _, value := range values {
headerSize += len(key)
headerSize += len(value)
}
}
bodySize := int64(0)
if req.Body != nil {
bodySize, _ = io.Copy(io.Discard, req.Body)
}
size := zerolog.Dict().Int64("body", bodySize).Int("header", headerSize).Int("url", len(req.URL.String()))
log.Debug().
Dict("size", size).
Str("method", strings.TrimRight(req.Method, " ")).
Str("url", req.URL.String()).
Str("query", req.URL.RawQuery).
Str("content-type", req.Header.Get("Content-Type")).
Msg("Request")
res, err := client.Do(req)
if err != nil {
log.Err(err).Stack().Str("method", req.Method).Msg("Request Failed")
} else {
contentLengthHeader := res.Header.Get("Content-Length")
contentLength := int64(-1)
// If this request was a Banner API request, reset the session timer
if strings.HasPrefix(req.URL.Path, "StudentRegistrationSsb/ssb/classSearch/") {
ResetSessionTimer()
}
// Get the content length
if contentLengthHeader != "" {
contentLength, err = strconv.ParseInt(contentLengthHeader, 10, 64)
if err != nil {
contentLength = -1
}
}
log.Debug().Int("status", res.StatusCode).Int64("content-length", contentLength).Strs("content-type", res.Header["Content-Type"]).Msg("Response")
}
return res, err
}
// Plural is a simple helper function that returns an empty string if n is 1, and "s" otherwise.
func Plural(n int) string {
if n == 1 {
@@ -362,7 +311,7 @@ func RespondError(session *discordgo.Session, interaction *discordgo.Interaction
func GetFetchedFooter(time time.Time) *discordgo.MessageEmbedFooter {
return &discordgo.MessageEmbedFooter{
Text: fmt.Sprintf("Fetched at %s", time.In(CentralTimeLocation).Format("Monday, January 2, 2006 at 3:04:05PM")),
Text: fmt.Sprintf("Fetched at %s", time.In(config.CentralTimeLocation).Format("Monday, January 2, 2006 at 3:04:05PM")),
}
}
@@ -415,49 +364,6 @@ func EncodeParams(params map[string]*[]string) string {
return buf.String()
}
var terms []BannerTerm
var lastTermUpdate time.Time
// TryReloadTerms attempts to reload the terms if they are not loaded or the last update was more than 24 hours ago
func TryReloadTerms() error {
if len(terms) > 0 && time.Since(lastTermUpdate) < 24*time.Hour {
return nil
}
// Load the terms
var err error
terms, err = GetTerms("", 1, 100)
if err != nil {
return errors.Wrap(err, "failed to load terms")
}
lastTermUpdate = time.Now()
return nil
}
// IsTermArchived checks if the given term is archived
// TODO: Add error, switch missing term logic to error
func IsTermArchived(term string) bool {
// Ensure the terms are loaded
err := TryReloadTerms()
if err != nil {
log.Err(err).Stack().Msg("Failed to reload terms")
return true
}
// Check if the term is in the list of terms
bannerTerm, exists := lo.Find(terms, func(t BannerTerm) bool {
return t.Code == term
})
if !exists {
log.Warn().Str("term", term).Msg("Term does not exist")
return true
}
return bannerTerm.Archived()
}
// Point represents a point in 2D space
type Point struct {
X, Y float64

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"io"
@@ -15,21 +15,21 @@ var (
)
// logSplitter implements zerolog.LevelWriter
type logSplitter struct {
std io.Writer
err io.Writer
type LogSplitter struct {
Std io.Writer
Err io.Writer
}
// Write should not be called
func (l logSplitter) Write(p []byte) (n int, err error) {
return l.std.Write(p)
func (l LogSplitter) Write(p []byte) (n int, err error) {
return l.Std.Write(p)
}
// WriteLevel write to the appropriate output
func (l logSplitter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
func (l LogSplitter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
if level <= zerolog.WarnLevel {
return l.std.Write(p)
return l.Std.Write(p)
} else {
return l.err.Write(p)
return l.Err.Write(p)
}
}

View File

@@ -1,16 +1,18 @@
package main
package utils
import (
"banner/internal/config"
"time"
"github.com/bwmarrin/discordgo"
"github.com/redis/go-redis/v9"
log "github.com/rs/zerolog/log"
)
// GetGuildName returns the name of the guild with the given ID, utilizing Redis to cache the value
func GetGuildName(guildID string) string {
func GetGuildName(session *discordgo.Session, guildID string) string {
// Check Redis for the guild name
guildName, err := kv.Get(ctx, "guild:"+guildID+":name").Result()
guildName, err := config.KV.Get(config.Ctx, "guild:"+guildID+":name").Result()
if err != nil && err != redis.Nil {
log.Error().Stack().Err(err).Msg("Error getting guild name from Redis")
return "err"
@@ -27,7 +29,7 @@ func GetGuildName(guildID string) string {
log.Error().Stack().Err(err).Msg("Error getting guild name")
// Store an invalid value in Redis so we don't keep trying to get the guild name
_, err := kv.Set(ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result()
_, err := config.KV.Set(config.Ctx, "guild:"+guildID+":name", "x", time.Minute*5).Result()
if err != nil {
log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis")
}
@@ -36,15 +38,15 @@ func GetGuildName(guildID string) string {
}
// Cache the guild name in Redis
kv.Set(ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3)
config.KV.Set(config.Ctx, "guild:"+guildID+":name", guild.Name, time.Hour*3)
return guild.Name
}
// GetChannelName returns the name of the channel with the given ID, utilizing Redis to cache the value
func GetChannelName(channelID string) string {
func GetChannelName(session *discordgo.Session, channelID string) string {
// Check Redis for the channel name
channelName, err := kv.Get(ctx, "channel:"+channelID+":name").Result()
channelName, err := config.KV.Get(config.Ctx, "channel:"+channelID+":name").Result()
if err != nil && err != redis.Nil {
log.Error().Stack().Err(err).Msg("Error getting channel name from Redis")
return "err"
@@ -61,7 +63,7 @@ func GetChannelName(channelID string) string {
log.Error().Stack().Err(err).Msg("Error getting channel name")
// Store an invalid value in Redis so we don't keep trying to get the channel name
_, err := kv.Set(ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result()
_, err := config.KV.Set(config.Ctx, "channel:"+channelID+":name", "x", time.Minute*5).Result()
if err != nil {
log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis")
}
@@ -70,7 +72,7 @@ func GetChannelName(channelID string) string {
}
// Cache the channel name in Redis
kv.Set(ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3)
config.KV.Set(config.Ctx, "channel:"+channelID+":name", channel.Name, time.Hour*3)
return channel.Name
}

View File

@@ -1,6 +1,7 @@
package main
package utils
import (
"banner/internal/config"
"fmt"
"strconv"
"time"
@@ -46,12 +47,12 @@ type YearDayRange struct {
// Summer: May 25th - August 15th
// Fall: August 18th - December 10th
func GetYearDayRange(year uint16) (YearDayRange, YearDayRange, YearDayRange) {
springStart := time.Date(int(year), time.January, 14, 0, 0, 0, 0, CentralTimeLocation).YearDay()
springEnd := time.Date(int(year), time.May, 1, 0, 0, 0, 0, CentralTimeLocation).YearDay()
summerStart := time.Date(int(year), time.May, 25, 0, 0, 0, 0, CentralTimeLocation).YearDay()
summerEnd := time.Date(int(year), time.August, 15, 0, 0, 0, 0, CentralTimeLocation).YearDay()
fallStart := time.Date(int(year), time.August, 18, 0, 0, 0, 0, CentralTimeLocation).YearDay()
fallEnd := time.Date(int(year), time.December, 10, 0, 0, 0, 0, CentralTimeLocation).YearDay()
springStart := time.Date(int(year), time.January, 14, 0, 0, 0, 0, config.CentralTimeLocation).YearDay()
springEnd := time.Date(int(year), time.May, 1, 0, 0, 0, 0, config.CentralTimeLocation).YearDay()
summerStart := time.Date(int(year), time.May, 25, 0, 0, 0, 0, config.CentralTimeLocation).YearDay()
summerEnd := time.Date(int(year), time.August, 15, 0, 0, 0, 0, config.CentralTimeLocation).YearDay()
fallStart := time.Date(int(year), time.August, 18, 0, 0, 0, 0, config.CentralTimeLocation).YearDay()
fallEnd := time.Date(int(year), time.December, 10, 0, 0, 0, 0, config.CentralTimeLocation).YearDay()
return YearDayRange{
Start: uint16(springStart),