mirror of
https://github.com/Xevion/banner.git
synced 2025-12-08 08:06:30 -06:00
feat!: begin rust rewrite
service scheduling, configs, all dependencies, tracing, graceful shutdown, concurrency
This commit is contained in:
11
.gitignore
vendored
11
.gitignore
vendored
@@ -1,10 +1,3 @@
|
||||
.env
|
||||
cover.cov
|
||||
/banner
|
||||
.*.go
|
||||
dumps/
|
||||
js/
|
||||
.vscode/
|
||||
*.prof
|
||||
.task/
|
||||
bin/
|
||||
/target
|
||||
/go/
|
||||
3663
Cargo.lock
generated
Normal file
3663
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
21
Cargo.toml
Normal file
21
Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "banner"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.47.1", features = ["full"] }
|
||||
axum = "0.8.4"
|
||||
serenity = { version = "0.12.4", features = ["rustls_backend"] }
|
||||
reqwest = { version = "0.12.23", features = ["json", "cookies"] }
|
||||
diesel = { version = "2.2.12", features = ["chrono", "postgres", "uuid"] }
|
||||
redis = { version = "0.32.5", features = ["tokio-comp"] }
|
||||
figment = { version = "0.10.19", features = ["toml", "env"] }
|
||||
serde_json = "1.0.143"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
governor = "0.10.1"
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
dotenvy = "0.15.7"
|
||||
poise = "0.6.1"
|
||||
async-trait = "0.1"
|
||||
46
Taskfile.yml
46
Taskfile.yml
@@ -1,46 +0,0 @@
|
||||
version: "3"
|
||||
|
||||
tasks:
|
||||
build:
|
||||
desc: Build the application
|
||||
cmds:
|
||||
- go build -o bin/banner ./cmd/banner
|
||||
sources:
|
||||
- ./cmd/banner/**/*.go
|
||||
- ./internal/**/*.go
|
||||
generates:
|
||||
- bin/banner
|
||||
|
||||
run:
|
||||
desc: Run the application
|
||||
cmds:
|
||||
- go run ./cmd/banner
|
||||
deps: [build]
|
||||
|
||||
test:
|
||||
desc: Run tests
|
||||
cmds:
|
||||
- go test ./tests/...
|
||||
env:
|
||||
ENVIRONMENT: test
|
||||
|
||||
test-coverage:
|
||||
desc: Run tests with coverage
|
||||
cmds:
|
||||
- go test -coverpkg=./internal/... -cover ./tests/...
|
||||
env:
|
||||
ENVIRONMENT: test
|
||||
|
||||
clean:
|
||||
desc: Clean build artifacts
|
||||
cmds:
|
||||
- rm -rf bin/
|
||||
- go clean -cache
|
||||
- go clean -modcache
|
||||
|
||||
dev:
|
||||
desc: Run in development mode
|
||||
cmds:
|
||||
- go run ./cmd/banner
|
||||
env:
|
||||
ENVIRONMENT: development
|
||||
@@ -1,299 +0,0 @@
|
||||
// Package main is the entry point for the banner application.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
_ "time/tzdata"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/rs/zerolog/pkgerrors"
|
||||
"github.com/samber/lo"
|
||||
"resty.dev/v3"
|
||||
|
||||
"banner/internal"
|
||||
"banner/internal/api"
|
||||
"banner/internal/bot"
|
||||
"banner/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
Session *discordgo.Session
|
||||
)
|
||||
|
||||
const (
|
||||
ICalTimestampFormatUtc = "20060102T150405Z"
|
||||
ICalTimestampFormatLocal = "20060102T150405"
|
||||
CentralTimezoneName = "America/Chicago"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Load environment variables
|
||||
if err := godotenv.Load(); err != nil {
|
||||
log.Debug().Err(err).Msg("Error loading .env file")
|
||||
}
|
||||
|
||||
// Set zerolog's timestamp function to use the central timezone
|
||||
zerolog.TimestampFunc = func() time.Time {
|
||||
// TODO: Move this to config
|
||||
loc, err := time.LoadLocation(CentralTimezoneName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return time.Now().In(loc)
|
||||
}
|
||||
|
||||
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
|
||||
|
||||
// Use the custom console writer if we're in development
|
||||
isDevelopment := internal.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT")
|
||||
if isDevelopment == "" {
|
||||
isDevelopment = "development"
|
||||
}
|
||||
|
||||
if isDevelopment == "development" {
|
||||
log.Logger = zerolog.New(config.NewConsoleWriter()).With().Timestamp().Logger()
|
||||
} else {
|
||||
log.Logger = zerolog.New(config.LogSplitter{Std: os.Stdout, Err: os.Stderr}).With().Timestamp().Logger()
|
||||
}
|
||||
log.Debug().Str("environment", isDevelopment).Msg("Loggers Setup")
|
||||
|
||||
// Set discordgo's logger to use zerolog
|
||||
discordgo.Logger = internal.DiscordGoLogger
|
||||
}
|
||||
|
||||
// initRedis initializes the Redis client and pings the server to ensure a connection.
|
||||
func initRedis(cfg *config.Config) {
|
||||
// Setup redis
|
||||
redisUrl := internal.GetFirstEnv("REDIS_URL", "REDIS_PRIVATE_URL")
|
||||
if redisUrl == "" {
|
||||
log.Fatal().Stack().Msg("REDIS_URL/REDIS_PRIVATE_URL not set")
|
||||
}
|
||||
|
||||
// Parse URL and create client
|
||||
options, err := redis.ParseURL(redisUrl)
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot parse redis url")
|
||||
}
|
||||
kv := redis.NewClient(options)
|
||||
cfg.SetRedis(kv)
|
||||
|
||||
var lastPingErr error
|
||||
pingCount := 0 // Nth ping being attempted
|
||||
totalPings := 5 // Total pings to attempt
|
||||
|
||||
// Wait for private networking to kick in (production only)
|
||||
if !cfg.IsDevelopment {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Test the redis instance, try to ping every 2 seconds 5 times, otherwise panic
|
||||
for {
|
||||
pingCount++
|
||||
if pingCount > totalPings {
|
||||
log.Fatal().Stack().Err(lastPingErr).Msg("Reached ping limit while trying to connect")
|
||||
}
|
||||
|
||||
// Ping redis
|
||||
pong, err := cfg.KV.Ping(cfg.Ctx).Result()
|
||||
|
||||
// Failed; log error and wait 2 seconds
|
||||
if err != nil {
|
||||
lastPingErr = err
|
||||
log.Warn().Err(err).Int("pings", pingCount).Int("remaining", totalPings-pingCount).Msg("Cannot ping redis")
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug().Str("ping", pong).Msg("Redis connection successful")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
cfg, err := config.New()
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot create config")
|
||||
}
|
||||
|
||||
// Try to grab the environment variable, or default to development
|
||||
environment := internal.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT")
|
||||
if environment == "" {
|
||||
environment = "development"
|
||||
}
|
||||
cfg.SetEnvironment(environment)
|
||||
|
||||
initRedis(cfg)
|
||||
|
||||
if strings.EqualFold(os.Getenv("PPROF_ENABLE"), "true") {
|
||||
// Start pprof server with graceful shutdown
|
||||
go func() {
|
||||
port := os.Getenv("PORT")
|
||||
log.Info().Str("port", port).Msg("Starting pprof server")
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + port,
|
||||
}
|
||||
|
||||
// Start server in a separate goroutine
|
||||
go func() {
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot start pprof server")
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for context cancellation and then shutdown
|
||||
<-cfg.Ctx.Done()
|
||||
log.Info().Msg("Shutting down pprof server")
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
log.Error().Err(err).Msg("Pprof server forced to shutdown")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Create cookie jar
|
||||
cookies, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Cannot create cookie jar")
|
||||
}
|
||||
|
||||
// Create Resty client with timeout and cookie jar
|
||||
baseURL := os.Getenv("BANNER_BASE_URL")
|
||||
client := resty.New().
|
||||
SetBaseURL(baseURL).
|
||||
SetTimeout(30*time.Second).
|
||||
SetCookieJar(cookies).
|
||||
SetHeader("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36").
|
||||
AddResponseMiddleware(api.SessionMiddleware)
|
||||
|
||||
cfg.SetClient(client)
|
||||
cfg.SetBaseURL(baseURL)
|
||||
|
||||
apiInstance := api.New(cfg)
|
||||
apiInstance.Setup()
|
||||
|
||||
// Create discord session
|
||||
session, err := discordgo.New("Bot " + os.Getenv("BOT_TOKEN"))
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Invalid bot parameters")
|
||||
}
|
||||
|
||||
botInstance := bot.New(session, apiInstance, cfg)
|
||||
botInstance.RegisterHandlers()
|
||||
|
||||
// Open discord session
|
||||
session.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) {
|
||||
log.Info().Str("username", r.User.Username).Str("discriminator", r.User.Discriminator).Str("id", r.User.ID).Str("session", s.State.SessionID).Msg("Bot is logged in")
|
||||
})
|
||||
err = session.Open()
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot open the session")
|
||||
}
|
||||
|
||||
// Setup command handlers
|
||||
// Register commands with discord
|
||||
arr := zerolog.Arr()
|
||||
lo.ForEach(bot.CommandDefinitions, func(cmd *discordgo.ApplicationCommand, _ int) {
|
||||
arr.Str(cmd.Name)
|
||||
})
|
||||
log.Info().Array("commands", arr).Msg("Registering commands")
|
||||
|
||||
// In development, use test server, otherwise empty (global) for command registration
|
||||
guildTarget := ""
|
||||
if cfg.IsDevelopment {
|
||||
guildTarget = os.Getenv("BOT_TARGET_GUILD")
|
||||
}
|
||||
|
||||
// Register commands
|
||||
existingCommands, err := session.ApplicationCommands(session.State.User.ID, guildTarget)
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot get existing commands")
|
||||
}
|
||||
newCommands, err := session.ApplicationCommandBulkOverwrite(session.State.User.ID, guildTarget, bot.CommandDefinitions)
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot register commands")
|
||||
}
|
||||
|
||||
// Compare existing commands with new commands
|
||||
for _, newCommand := range newCommands {
|
||||
existingCommand, found := lo.Find(existingCommands, func(cmd *discordgo.ApplicationCommand) bool {
|
||||
return cmd.Name == newCommand.Name
|
||||
})
|
||||
|
||||
// New command
|
||||
if !found {
|
||||
log.Info().Str("commandName", newCommand.Name).Msg("Registered new command")
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if newCommand.Version != existingCommand.Version {
|
||||
log.Info().Str("commandName", newCommand.Name).
|
||||
Str("oldVersion", existingCommand.Version).Str("newVersion", newCommand.Version).
|
||||
Msg("Command Updated")
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch terms on startup
|
||||
err = apiInstance.TryReloadTerms()
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Cannot fetch terms on startup")
|
||||
}
|
||||
|
||||
// Launch a goroutine to scrape the banner system periodically
|
||||
go func() {
|
||||
ticker := time.NewTicker(3 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Ctx.Done():
|
||||
log.Info().Msg("Periodic scraper stopped due to context cancellation")
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := apiInstance.Scrape()
|
||||
if err != nil {
|
||||
log.Err(err).Stack().Msg("Periodic Scrape Failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Close session, ensure Resty client closes
|
||||
defer session.Close()
|
||||
defer client.Close()
|
||||
|
||||
// Setup signal handler channel
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt) // Ctrl+C signal
|
||||
signal.Notify(stop, syscall.SIGTERM) // Container stop signal
|
||||
|
||||
// Wait for signal (indefinite)
|
||||
closingSignal := <-stop
|
||||
botInstance.SetClosing() // TODO: Switch to atomic lock with forced close after 10 seconds
|
||||
|
||||
// Cancel the context to signal all operations to stop
|
||||
cfg.CancelFunc()
|
||||
|
||||
// Defers are called after this
|
||||
log.Warn().Str("signal", closingSignal.String()).Msg("Gracefully shutting down")
|
||||
}
|
||||
27
go.mod
27
go.mod
@@ -1,27 +0,0 @@
|
||||
module banner
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/samber/lo v1.51.0
|
||||
resty.dev/v3 v3.0.0-beta.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
)
|
||||
52
go.sum
52
go.sum
@@ -1,52 +0,0 @@
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg=
|
||||
github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/samber/lo v1.51.0 h1:kysRYLbHy/MB7kQZf5DSN50JHmMsNEdeY24VzJFu7wI=
|
||||
github.com/samber/lo v1.51.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
resty.dev/v3 v3.0.0-beta.3 h1:3kEwzEgCnnS6Ob4Emlk94t+I/gClyoah7SnNi67lt+E=
|
||||
resty.dev/v3 v3.0.0-beta.3/go.mod h1:OgkqiPvTDtOuV4MGZuUDhwOpkY8enjOsjjMzeOHefy4=
|
||||
@@ -1,491 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"banner/internal"
|
||||
"banner/internal/config"
|
||||
"banner/internal/models"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"resty.dev/v3"
|
||||
)
|
||||
|
||||
// API provides a client for interacting with the Banner API.
|
||||
type API struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// New creates a new API client with the given configuration.
|
||||
func New(config *config.Config) *API {
|
||||
return &API{config: config}
|
||||
}
|
||||
|
||||
var (
|
||||
latestSession string
|
||||
sessionTime time.Time
|
||||
expiryTime = 25 * time.Minute
|
||||
)
|
||||
|
||||
// SessionMiddleware creates a Resty middleware that resets the session timer on each successful Banner API call.
|
||||
func SessionMiddleware(_ *resty.Client, r *resty.Response) error {
|
||||
// log.Debug().Str("url", r.Request.RawRequest.URL.Path).Msg("Session middleware")
|
||||
|
||||
// Reset session timer on successful requests to Banner API endpoints
|
||||
if r.IsSuccess() && strings.HasPrefix(r.Request.RawRequest.URL.Path, "StudentRegistrationSsb/ssb/classSearch/") {
|
||||
// Only reset the session time if the session is still valid
|
||||
if time.Since(sessionTime) <= expiryTime {
|
||||
sessionTime = time.Now()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSession generates a new session ID for use with the Banner API.
|
||||
// This function should not be used directly; use EnsureSession instead.
|
||||
func GenerateSession() string {
|
||||
return internal.RandomString(5) + internal.Nonce()
|
||||
}
|
||||
|
||||
// DefaultTerm returns the default term, which is the current term if it exists, otherwise the next term.
|
||||
func (a *API) DefaultTerm(t time.Time) config.Term {
|
||||
currentTerm, nextTerm := config.GetCurrentTerm(*a.config.SeasonRanges, t)
|
||||
if currentTerm == nil {
|
||||
return *nextTerm
|
||||
}
|
||||
return *currentTerm
|
||||
}
|
||||
|
||||
var terms []BannerTerm
|
||||
var lastTermUpdate time.Time
|
||||
|
||||
// TryReloadTerms attempts to reload the terms if they are not loaded or if the last update was more than 24 hours ago.
|
||||
func (a *API) TryReloadTerms() error {
|
||||
if len(terms) > 0 && time.Since(lastTermUpdate) < 24*time.Hour {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load the terms
|
||||
var err error
|
||||
terms, err = a.GetTerms("", 1, 100)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load terms: %w", err)
|
||||
}
|
||||
|
||||
lastTermUpdate = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsTermArchived checks if the given term is archived (view only).
|
||||
//
|
||||
// TODO: Add error handling for when a term does not exist.
|
||||
func (a *API) IsTermArchived(term string) bool {
|
||||
// Ensure the terms are loaded
|
||||
err := a.TryReloadTerms()
|
||||
if err != nil {
|
||||
log.Err(err).Stack().Msg("Failed to reload terms")
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if the term is in the list of terms
|
||||
bannerTerm, exists := lo.Find(terms, func(t BannerTerm) bool {
|
||||
return t.Code == term
|
||||
})
|
||||
|
||||
if !exists {
|
||||
log.Warn().Str("term", term).Msg("Term does not exist")
|
||||
return true
|
||||
}
|
||||
|
||||
return bannerTerm.Archived()
|
||||
}
|
||||
|
||||
// EnsureSession ensures that a valid session is available, creating one if necessary.
|
||||
func (a *API) EnsureSession() string {
|
||||
if latestSession == "" || time.Since(sessionTime) >= expiryTime {
|
||||
latestSession = GenerateSession()
|
||||
sessionTime = time.Now()
|
||||
}
|
||||
return latestSession
|
||||
}
|
||||
|
||||
// Pair represents a key-value pair from the Banner API.
|
||||
type Pair struct {
|
||||
Code string `json:"code"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// BannerTerm represents a term in the Banner system.
|
||||
type BannerTerm Pair
|
||||
|
||||
// Instructor represents an instructor in the Banner system.
|
||||
type Instructor Pair
|
||||
|
||||
// Archived returns true if the term is in an archival (view-only) state.
|
||||
func (term BannerTerm) Archived() bool {
|
||||
return strings.Contains(term.Description, "View Only")
|
||||
}
|
||||
|
||||
// GetTerms retrieves a list of terms from the Banner API.
|
||||
// The page number must be at least 1.
|
||||
func (a *API) GetTerms(search string, page int, maxResults int) ([]BannerTerm, error) {
|
||||
// Ensure offset is valid
|
||||
if page <= 0 {
|
||||
return nil, errors.New("offset must be greater than 0")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("searchTerm", search).
|
||||
SetQueryParam("offset", strconv.Itoa(page)).
|
||||
SetQueryParam("max", strconv.Itoa(maxResults)).
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&[]BannerTerm{})
|
||||
|
||||
res, err := req.Get("/classSearch/getTerms")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get terms: %w", err)
|
||||
}
|
||||
|
||||
terms, ok := res.Result().(*[]BannerTerm)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("terms parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return *terms, nil
|
||||
}
|
||||
|
||||
// SelectTerm selects a term in the Banner system for the given session.
|
||||
// This is required before other API calls can be made.
|
||||
func (a *API) SelectTerm(term string, sessionID string) error {
|
||||
form := url.Values{
|
||||
"term": {term},
|
||||
"studyPath": {""},
|
||||
"studyPathText": {""},
|
||||
"startDatepicker": {""},
|
||||
"endDatepicker": {""},
|
||||
"uniqueSessionId": {sessionID},
|
||||
}
|
||||
|
||||
type RedirectResponse struct {
|
||||
FwdURL string `json:"fwdUrl"`
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetResult(&RedirectResponse{}).
|
||||
SetQueryParam("mode", "search").
|
||||
SetBody(form.Encode()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := req.Post("/term/search")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select term: %w", err)
|
||||
}
|
||||
|
||||
redirectResponse := res.Result().(*RedirectResponse)
|
||||
|
||||
// TODO: Mild validation to ensure the redirect is appropriate
|
||||
|
||||
// Make a GET request to the fwdUrl
|
||||
req = a.config.Client.NewRequest()
|
||||
res, err = req.Get(redirectResponse.FwdURL)
|
||||
|
||||
// Assert that the response is OK (200)
|
||||
if res.StatusCode() != 200 {
|
||||
return fmt.Errorf("redirect response was not OK: %d", res.StatusCode())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPartOfTerms retrieves a list of parts of a term from the Banner API.
|
||||
// The page number must be at least 1.
|
||||
func (a *API) GetPartOfTerms(search string, term int, offset int, maxResults int) ([]BannerTerm, error) {
|
||||
// Ensure offset is valid
|
||||
if offset <= 0 {
|
||||
return nil, errors.New("offset must be greater than 0")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("searchTerm", search).
|
||||
SetQueryParam("term", strconv.Itoa(term)).
|
||||
SetQueryParam("offset", strconv.Itoa(offset)).
|
||||
SetQueryParam("max", strconv.Itoa(maxResults)).
|
||||
SetQueryParam("uniqueSessionId", a.EnsureSession()).
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&[]BannerTerm{})
|
||||
|
||||
res, err := req.Get("/classSearch/get_partOfTerm")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get part of terms: %w", err)
|
||||
}
|
||||
|
||||
terms, ok := res.Result().(*[]BannerTerm)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("term parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return *terms, nil
|
||||
}
|
||||
|
||||
// GetInstructors retrieves a list of instructors from the Banner API.
|
||||
func (a *API) GetInstructors(search string, term string, offset int, maxResults int) ([]Instructor, error) {
|
||||
// Ensure offset is valid
|
||||
if offset <= 0 {
|
||||
return nil, errors.New("offset must be greater than 0")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("searchTerm", search).
|
||||
SetQueryParam("term", term).
|
||||
SetQueryParam("offset", strconv.Itoa(offset)).
|
||||
SetQueryParam("max", strconv.Itoa(maxResults)).
|
||||
SetQueryParam("uniqueSessionId", a.EnsureSession()).
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&[]Instructor{})
|
||||
|
||||
res, err := req.Get("/classSearch/get_instructor")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get instructors: %w", err)
|
||||
}
|
||||
|
||||
instructors, ok := res.Result().(*[]Instructor)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("instructor parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return *instructors, nil
|
||||
}
|
||||
|
||||
// ClassDetails represents the detailed information for a class.
|
||||
//
|
||||
// TODO: Implement this struct and the associated GetCourseDetails function.
|
||||
type ClassDetails struct {
|
||||
}
|
||||
|
||||
// GetCourseDetails retrieves the details for a specific course.
|
||||
func (a *API) GetCourseDetails(term int, crn int) (*ClassDetails, error) {
|
||||
body, err := json.Marshal(map[string]string{
|
||||
"term": strconv.Itoa(term),
|
||||
"courseReferenceNumber": strconv.Itoa(crn),
|
||||
"first": "first", // TODO: What is this?
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Failed to marshal body")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetBody(body).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&ClassDetails{})
|
||||
|
||||
res, err := req.Get("/searchResults/getClassDetails")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get course details: %w", err)
|
||||
}
|
||||
|
||||
details, ok := res.Result().(*ClassDetails)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("course details parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return details, nil
|
||||
}
|
||||
|
||||
// Search performs a search for courses with the given query and returns the results.
|
||||
func (a *API) Search(term string, query *Query, sort string, sortDescending bool) (*models.SearchResult, error) {
|
||||
a.ResetDataForm()
|
||||
|
||||
params := query.Paramify()
|
||||
|
||||
params["txt_term"] = term
|
||||
params["uniqueSessionId"] = a.EnsureSession()
|
||||
params["sortColumn"] = sort
|
||||
params["sortDirection"] = "asc"
|
||||
|
||||
// These dates are not available for usage anywhere in the UI, but are included in every query
|
||||
params["startDatepicker"] = ""
|
||||
params["endDatepicker"] = ""
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParams(params).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&models.SearchResult{})
|
||||
|
||||
res, err := req.Get("/searchResults/searchResults")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search: %w", err)
|
||||
}
|
||||
|
||||
searchResult, ok := res.Result().(*models.SearchResult)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("search result parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return searchResult, nil
|
||||
}
|
||||
|
||||
// GetSubjects retrieves a list of subjects from the Banner API.
|
||||
// The page number must be at least 1.
|
||||
func (a *API) GetSubjects(search string, term string, offset int, maxResults int) ([]Pair, error) {
|
||||
// Ensure offset is valid
|
||||
if offset <= 0 {
|
||||
return nil, errors.New("offset must be greater than 0")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("searchTerm", search).
|
||||
SetQueryParam("term", term).
|
||||
SetQueryParam("offset", strconv.Itoa(offset)).
|
||||
SetQueryParam("max", strconv.Itoa(maxResults)).
|
||||
SetQueryParam("uniqueSessionId", a.EnsureSession()).
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&[]Pair{})
|
||||
|
||||
res, err := req.Get("/classSearch/get_subject")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get subjects: %w", err)
|
||||
}
|
||||
|
||||
subjects, ok := res.Result().(*[]Pair)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("subjects parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return *subjects, nil
|
||||
}
|
||||
|
||||
// GetCampuses retrieves a list of campuses from the Banner API.
|
||||
// The page number must be at least 1.
|
||||
func (a *API) GetCampuses(search string, term int, offset int, maxResults int) ([]Pair, error) {
|
||||
// Ensure offset is valid
|
||||
if offset <= 0 {
|
||||
return nil, errors.New("offset must be greater than 0")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("searchTerm", search).
|
||||
SetQueryParam("term", strconv.Itoa(term)).
|
||||
SetQueryParam("offset", strconv.Itoa(offset)).
|
||||
SetQueryParam("max", strconv.Itoa(maxResults)).
|
||||
SetQueryParam("uniqueSessionId", a.EnsureSession()).
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&[]Pair{})
|
||||
|
||||
res, err := req.Get("/classSearch/get_campus")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get campuses: %w", err)
|
||||
}
|
||||
|
||||
campuses, ok := res.Result().(*[]Pair)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("campuses parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return *campuses, nil
|
||||
}
|
||||
|
||||
// GetInstructionalMethods retrieves a list of instructional methods from the Banner API.
|
||||
// The page number must be at least 1.
|
||||
func (a *API) GetInstructionalMethods(search string, term string, offset int, maxResults int) ([]Pair, error) {
|
||||
// Ensure offset is valid
|
||||
if offset <= 0 {
|
||||
return nil, errors.New("offset must be greater than 0")
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("searchTerm", search).
|
||||
SetQueryParam("term", term).
|
||||
SetQueryParam("offset", strconv.Itoa(offset)).
|
||||
SetQueryParam("max", strconv.Itoa(maxResults)).
|
||||
SetQueryParam("uniqueSessionId", a.EnsureSession()).
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&[]Pair{})
|
||||
|
||||
res, err := req.Get("/classSearch/get_instructionalMethod")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get instructional methods: %w", err)
|
||||
}
|
||||
|
||||
methods, ok := res.Result().(*[]Pair)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("instructional methods parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
return *methods, nil
|
||||
}
|
||||
|
||||
// GetCourseMeetingTime retrieves the meeting time information for a course.
|
||||
func (a *API) GetCourseMeetingTime(term int, crn int) ([]models.MeetingTimeResponse, error) {
|
||||
type responseWrapper struct {
|
||||
Fmt []models.MeetingTimeResponse `json:"fmt"`
|
||||
}
|
||||
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("term", strconv.Itoa(term)).
|
||||
SetQueryParam("courseReferenceNumber", strconv.Itoa(crn)).
|
||||
SetExpectResponseContentType("application/json").
|
||||
SetResult(&responseWrapper{})
|
||||
|
||||
res, err := req.Get("/searchResults/getFacultyMeetingTimes")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get meeting time: %w", err)
|
||||
}
|
||||
|
||||
result, ok := res.Result().(*responseWrapper)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("meeting times parsing failed to cast: %v", res.Result())
|
||||
}
|
||||
|
||||
return result.Fmt, nil
|
||||
}
|
||||
|
||||
// ResetDataForm resets the search form in the Banner system.
|
||||
// This must be called before a new search can be performed.
|
||||
func (a *API) ResetDataForm() {
|
||||
req := a.config.Client.NewRequest()
|
||||
|
||||
_, err := req.Post("/classSearch/resetDataForm")
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Err(err).Msg("Failed to reset data form")
|
||||
}
|
||||
}
|
||||
|
||||
// GetCourse retrieves course information from the Redis cache.
|
||||
func (a *API) GetCourse(crn string) (*models.Course, error) {
|
||||
// Create a timeout context for Redis operations
|
||||
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Retrieve raw data
|
||||
result, err := a.config.KV.Get(ctx, fmt.Sprintf("class:%s", crn)).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, fmt.Errorf("course not found: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get course: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal the raw data
|
||||
var course models.Course
|
||||
err = json.Unmarshal([]byte(result), &course)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal course: %w", err)
|
||||
}
|
||||
|
||||
return &course, nil
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
// Package api provides the core functionality for interacting with the Banner API.
|
||||
package api
|
||||
|
||||
import (
|
||||
"banner/internal"
|
||||
"banner/internal/models"
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxPageSize is the maximum number of courses one can scrape per page.
|
||||
MaxPageSize = 500
|
||||
)
|
||||
|
||||
var (
|
||||
// PriorityMajors is a list of majors that are considered to be high priority for scraping.
|
||||
// This list is used to determine which majors to scrape first/most often.
|
||||
PriorityMajors = []string{"CS", "CPE", "MAT", "EE", "IS"}
|
||||
// AncillaryMajors is a list of majors that are considered to be low priority for scraping.
|
||||
// This list will not contain any majors that are in PriorityMajors.
|
||||
AncillaryMajors []string
|
||||
// AllMajors is a list of all majors that are available in the Banner system.
|
||||
AllMajors []string
|
||||
)
|
||||
|
||||
// Scrape retrieves all courses from the Banner API and stores them in Redis.
|
||||
// This is a long-running process that should be run in a goroutine.
|
||||
//
|
||||
// TODO: Switch from hardcoded term to dynamic term
|
||||
func (a *API) Scrape() error {
|
||||
// For each subject, retrieve all courses
|
||||
// For each course, get the details and store it in redis
|
||||
// Make sure to handle pagination
|
||||
subjects, err := a.GetSubjects("", "202510", 1, 100)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get subjects: %w", err)
|
||||
}
|
||||
|
||||
// Ensure subjects were found
|
||||
if len(subjects) == 0 {
|
||||
return fmt.Errorf("no subjects found")
|
||||
}
|
||||
|
||||
// Extract major code name
|
||||
for _, subject := range subjects {
|
||||
// Add to AncillaryMajors if not in PriorityMajors
|
||||
if !lo.Contains(PriorityMajors, subject.Code) {
|
||||
AncillaryMajors = append(AncillaryMajors, subject.Code)
|
||||
}
|
||||
}
|
||||
|
||||
AllMajors = lo.Flatten([][]string{PriorityMajors, AncillaryMajors})
|
||||
|
||||
expiredSubjects, err := a.GetExpiredSubjects()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get scrapable majors: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Strs("majors", expiredSubjects).Msg("Scraping majors")
|
||||
for _, subject := range expiredSubjects {
|
||||
err := a.ScrapeMajor(subject)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scrape major %s: %w", subject, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExpiredSubjects returns a list of subjects that have expired and should be scraped again.
|
||||
// It checks Redis for the "scraped" status of each major for the current term.
|
||||
func (a *API) GetExpiredSubjects() ([]string, error) {
|
||||
term := a.DefaultTerm(time.Now()).ToString()
|
||||
subjects := make([]string, 0)
|
||||
|
||||
// Create a timeout context for Redis operations
|
||||
ctx, cancel := context.WithTimeout(a.config.Ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Get all subjects
|
||||
values, err := a.config.KV.MGet(ctx, lo.Map(AllMajors, func(major string, _ int) string {
|
||||
return fmt.Sprintf("scraped:%s:%s", major, term)
|
||||
})...).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get all subjects: %w", err)
|
||||
}
|
||||
|
||||
// Extract expired subjects
|
||||
for i, value := range values {
|
||||
subject := AllMajors[i]
|
||||
|
||||
// If the value is nil or "0", then the subject is expired
|
||||
if value == nil || value == "0" {
|
||||
subjects = append(subjects, subject)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Strs("majors", subjects).Msg("Expired Subjects")
|
||||
|
||||
return subjects, nil
|
||||
}
|
||||
|
||||
// ScrapeMajor scrapes all courses for a specific major.
|
||||
// This function does not check whether scraping is required at this time; it is assumed that the caller has already done so.
|
||||
func (a *API) ScrapeMajor(subject string) error {
|
||||
offset := 0
|
||||
totalClassCount := 0
|
||||
|
||||
for {
|
||||
// Build & execute the query
|
||||
query := NewQuery().Offset(offset).MaxResults(MaxPageSize * 2).Subject(subject)
|
||||
term := a.DefaultTerm(time.Now()).ToString()
|
||||
result, err := a.Search(term, query, "subjectDescription", false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("search failed: %w (%s)", err, query.String())
|
||||
}
|
||||
|
||||
// Isn't it bullshit that they decided not to leave an actual 'reason' field for the failure?
|
||||
if !result.Success {
|
||||
return fmt.Errorf("result marked unsuccessful when searching for classes (%s)", query.String())
|
||||
}
|
||||
|
||||
classCount := len(result.Data)
|
||||
totalClassCount += classCount
|
||||
log.Debug().Str("subject", subject).Int("count", classCount).Int("offset", offset).Msg("Placing classes in Redis")
|
||||
|
||||
// Process each class and store it in Redis
|
||||
for _, course := range result.Data {
|
||||
// Store class in Redis
|
||||
err := a.IntakeCourse(course)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to store class in Redis")
|
||||
}
|
||||
}
|
||||
|
||||
// Increment and continue if the results are full
|
||||
if classCount >= MaxPageSize {
|
||||
// This is unlikely to happen, but log it just in case
|
||||
if classCount > MaxPageSize {
|
||||
log.Warn().Int("page", offset).Int("count", classCount).Msg("Results exceed MaxPageSize")
|
||||
}
|
||||
|
||||
offset += MaxPageSize
|
||||
|
||||
// TODO: Replace sleep with smarter rate limiting
|
||||
log.Debug().Str("subject", subject).Int("nextOffset", offset).Msg("Sleeping before next page")
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
// Log the number of classes scraped
|
||||
log.Info().Str("subject", subject).Int("total", totalClassCount).Msgf("Subject %s Scraped", subject)
|
||||
break
|
||||
}
|
||||
|
||||
term := a.DefaultTerm(time.Now()).ToString()
|
||||
|
||||
// Calculate the expiry time for the scrape (1 hour for every 200 classes, random +-15%) with a minimum of 1 hour
|
||||
var scrapeExpiry time.Duration
|
||||
if totalClassCount == 0 {
|
||||
scrapeExpiry = time.Hour * 12
|
||||
} else {
|
||||
scrapeExpiry = a.CalculateExpiry(term, totalClassCount, lo.Contains(PriorityMajors, subject))
|
||||
}
|
||||
|
||||
// Mark the major as scraped
|
||||
if totalClassCount == 0 {
|
||||
totalClassCount = -1
|
||||
}
|
||||
|
||||
// Create a timeout context for Redis operations
|
||||
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := a.config.KV.Set(ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to mark major as scraped")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateExpiry calculates the expiry time until the next scrape for a major.
|
||||
// The duration is based on the number of courses, whether the major is a priority, and if the term is archived.
|
||||
func (a *API) CalculateExpiry(term string, count int, priority bool) time.Duration {
|
||||
// An hour for every 100 classes
|
||||
baseExpiry := time.Hour * time.Duration(count/100)
|
||||
|
||||
// Subjects with less than 50 classes have a reversed expiry (less classes, longer interval)
|
||||
// 1 class => 12 hours, 49 classes => 1 hour
|
||||
if count < 50 {
|
||||
hours := internal.Slope(internal.Point{X: 1, Y: 12}, internal.Point{X: 49, Y: 1}, float64(count)).Y
|
||||
baseExpiry = time.Duration(hours * float64(time.Hour))
|
||||
}
|
||||
|
||||
// If the subject is a priority, then the expiry is halved without variance
|
||||
if priority {
|
||||
return baseExpiry / 3
|
||||
}
|
||||
|
||||
// If the term is considered "view only" or "archived", then the expiry is multiplied by 5
|
||||
var expiry = baseExpiry
|
||||
if a.IsTermArchived(term) {
|
||||
expiry *= 5
|
||||
}
|
||||
|
||||
// Add minor variance to the expiry
|
||||
expiryVariance := baseExpiry.Seconds() * (rand.Float64() * 0.15) // Between 0 and 15% of the total
|
||||
if rand.Intn(2) == 0 {
|
||||
expiry -= time.Duration(expiryVariance) * time.Second
|
||||
} else {
|
||||
expiry += time.Duration(expiryVariance) * time.Second
|
||||
}
|
||||
|
||||
// Ensure the expiry is at least 1 hour with up to 15 extra minutes
|
||||
if expiry < time.Hour {
|
||||
baseExpiry = time.Hour + time.Duration(rand.Intn(60*15))*time.Second
|
||||
}
|
||||
|
||||
return baseExpiry
|
||||
}
|
||||
|
||||
// IntakeCourse stores a course in Redis.
|
||||
// This function will be used to handle change identification, notifications, and SQLite upserts in the future.
|
||||
func (a *API) IntakeCourse(course models.Course) error {
|
||||
// Create a timeout context for Redis operations
|
||||
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := a.config.KV.Set(ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store class in Redis: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,350 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
paramSubject = "txt_subject"
|
||||
paramTitle = "txt_courseTitle"
|
||||
paramKeywords = "txt_keywordlike"
|
||||
paramOpenOnly = "chk_open_only"
|
||||
paramTermPart = "txt_partOfTerm"
|
||||
paramCampus = "txt_campus"
|
||||
paramAttributes = "txt_attribute"
|
||||
paramInstructor = "txt_instructor"
|
||||
paramStartTimeHour = "select_start_hour"
|
||||
paramStartTimeMinute = "select_start_min"
|
||||
paramStartTimeMeridiem = "select_start_ampm"
|
||||
paramEndTimeHour = "select_end_hour"
|
||||
paramEndTimeMinute = "select_end_min"
|
||||
paramEndTimeMeridiem = "select_end_ampm"
|
||||
paramMinCredits = "txt_credithourlow"
|
||||
paramMaxCredits = "txt_credithourhigh"
|
||||
paramCourseNumberLow = "txt_course_number_range"
|
||||
paramCourseNumberHigh = "txt_course_number_range_to"
|
||||
paramOffset = "pageOffset"
|
||||
paramMaxResults = "pageMaxSize"
|
||||
)
|
||||
|
||||
// Query represents a search query for courses.
|
||||
// It is a builder that allows for chaining methods to construct a query.
|
||||
type Query struct {
|
||||
subject *string
|
||||
title *string
|
||||
keywords *[]string
|
||||
openOnly *bool
|
||||
termPart *[]string // e.g. [1, B6, 8, J]
|
||||
campus *[]string // e.g. [9, 1DT, 1LR]
|
||||
instructionalMethod *[]string // e.g. [HB]
|
||||
attributes *[]string // e.g. [060, 010]
|
||||
instructor *[]uint64 // e.g. [27957, 27961]
|
||||
startTime *time.Duration
|
||||
endTime *time.Duration
|
||||
minCredits *int
|
||||
maxCredits *int
|
||||
offset int
|
||||
maxResults int
|
||||
courseNumberRange *Range
|
||||
}
|
||||
|
||||
// NewQuery creates a new Query with default values.
|
||||
func NewQuery() *Query {
|
||||
return &Query{maxResults: 8, offset: 0}
|
||||
}
|
||||
|
||||
// Subject sets the subject for the query.
|
||||
func (q *Query) Subject(subject string) *Query {
|
||||
q.subject = &subject
|
||||
return q
|
||||
}
|
||||
|
||||
// Title sets the title for the query.
|
||||
func (q *Query) Title(title string) *Query {
|
||||
q.title = &title
|
||||
return q
|
||||
}
|
||||
|
||||
// Keywords sets the keywords for the query.
|
||||
func (q *Query) Keywords(keywords []string) *Query {
|
||||
q.keywords = &keywords
|
||||
return q
|
||||
}
|
||||
|
||||
// Keyword adds a keyword to the query.
|
||||
func (q *Query) Keyword(keyword string) *Query {
|
||||
if q.keywords == nil {
|
||||
q.keywords = &[]string{keyword}
|
||||
} else {
|
||||
*q.keywords = append(*q.keywords, keyword)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// OpenOnly sets whether to search for open courses only.
|
||||
func (q *Query) OpenOnly(openOnly bool) *Query {
|
||||
q.openOnly = &openOnly
|
||||
return q
|
||||
}
|
||||
|
||||
// TermPart sets the term part for the query.
|
||||
func (q *Query) TermPart(termPart []string) *Query {
|
||||
q.termPart = &termPart
|
||||
return q
|
||||
}
|
||||
|
||||
// Campus sets the campuses for the query.
|
||||
func (q *Query) Campus(campus []string) *Query {
|
||||
q.campus = &campus
|
||||
return q
|
||||
}
|
||||
|
||||
// InstructionalMethod sets the instructional methods for the query.
|
||||
func (q *Query) InstructionalMethod(instructionalMethod []string) *Query {
|
||||
q.instructionalMethod = &instructionalMethod
|
||||
return q
|
||||
}
|
||||
|
||||
// Attributes sets the attributes for the query.
|
||||
func (q *Query) Attributes(attributes []string) *Query {
|
||||
q.attributes = &attributes
|
||||
return q
|
||||
}
|
||||
|
||||
// Instructor sets the instructors for the query.
|
||||
func (q *Query) Instructor(instructor []uint64) *Query {
|
||||
q.instructor = &instructor
|
||||
return q
|
||||
}
|
||||
|
||||
// StartTime sets the start time for the query.
|
||||
func (q *Query) StartTime(startTime time.Duration) *Query {
|
||||
q.startTime = &startTime
|
||||
return q
|
||||
}
|
||||
|
||||
// EndTime sets the end time for the query.
|
||||
func (q *Query) EndTime(endTime time.Duration) *Query {
|
||||
q.endTime = &endTime
|
||||
return q
|
||||
}
|
||||
|
||||
// Credits sets the credit range for the query.
|
||||
func (q *Query) Credits(low int, high int) *Query {
|
||||
q.minCredits = &low
|
||||
q.maxCredits = &high
|
||||
return q
|
||||
}
|
||||
|
||||
// MinCredits sets the minimum credits for the query.
|
||||
func (q *Query) MinCredits(value int) *Query {
|
||||
q.minCredits = &value
|
||||
return q
|
||||
}
|
||||
|
||||
// MaxCredits sets the maximum credits for the query.
|
||||
func (q *Query) MaxCredits(value int) *Query {
|
||||
q.maxCredits = &value
|
||||
return q
|
||||
}
|
||||
|
||||
// CourseNumbers sets the course number range for the query.
|
||||
func (q *Query) CourseNumbers(low int, high int) *Query {
|
||||
q.courseNumberRange = &Range{low, high}
|
||||
return q
|
||||
}
|
||||
|
||||
// Offset sets the offset for pagination.
|
||||
func (q *Query) Offset(offset int) *Query {
|
||||
q.offset = offset
|
||||
return q
|
||||
}
|
||||
|
||||
// MaxResults sets the maximum number of results to return.
|
||||
func (q *Query) MaxResults(maxResults int) *Query {
|
||||
q.maxResults = maxResults
|
||||
return q
|
||||
}
|
||||
|
||||
// Range represents a range of two integers.
|
||||
type Range struct {
|
||||
Low int
|
||||
High int
|
||||
}
|
||||
|
||||
// FormatTimeParameter formats a time.Duration into a tuple of strings for use in a POST request.
|
||||
// It returns the hour, minute, and meridiem (AM/PM) as separate strings.
|
||||
func FormatTimeParameter(d time.Duration) (string, string, string) {
|
||||
hourParameter, minuteParameter, meridiemParameter := "", "", ""
|
||||
|
||||
hours := int64(d.Hours())
|
||||
minutes := int64(d.Minutes()) % 60
|
||||
|
||||
minuteParameter = strconv.FormatInt(minutes, 10)
|
||||
|
||||
if hours >= 12 {
|
||||
hourParameter = "PM"
|
||||
|
||||
// Exceptional case: 12PM = 12, 1PM = 1, 2PM = 2
|
||||
if hours >= 13 {
|
||||
hourParameter = strconv.FormatInt(hours-12, 10) // 13 - 12 = 1, 14 - 12 = 2
|
||||
} else {
|
||||
hourParameter = strconv.FormatInt(hours, 10)
|
||||
}
|
||||
} else {
|
||||
meridiemParameter = "AM"
|
||||
hourParameter = strconv.FormatInt(hours, 10)
|
||||
}
|
||||
|
||||
return hourParameter, minuteParameter, meridiemParameter
|
||||
}
|
||||
|
||||
// Paramify converts a Query into a map of parameters for a POST request.
|
||||
// This function assumes each query key only appears once.
|
||||
func (q *Query) Paramify() map[string]string {
|
||||
params := map[string]string{}
|
||||
|
||||
if q.subject != nil {
|
||||
params[paramSubject] = *q.subject
|
||||
}
|
||||
|
||||
if q.title != nil {
|
||||
// Whitespace can prevent valid queries from succeeding
|
||||
params[paramTitle] = strings.TrimSpace(*q.title)
|
||||
}
|
||||
|
||||
if q.keywords != nil {
|
||||
params[paramKeywords] = strings.Join(*q.keywords, " ")
|
||||
}
|
||||
|
||||
if q.openOnly != nil {
|
||||
params[paramOpenOnly] = "true"
|
||||
}
|
||||
|
||||
if q.termPart != nil {
|
||||
params[paramTermPart] = strings.Join(*q.termPart, ",")
|
||||
}
|
||||
|
||||
if q.campus != nil {
|
||||
params[paramCampus] = strings.Join(*q.campus, ",")
|
||||
}
|
||||
|
||||
if q.attributes != nil {
|
||||
params[paramAttributes] = strings.Join(*q.attributes, ",")
|
||||
}
|
||||
|
||||
if q.instructor != nil {
|
||||
params[paramInstructor] = strings.Join(lo.Map(*q.instructor, func(i uint64, _ int) string {
|
||||
return strconv.FormatUint(i, 10)
|
||||
}), ",")
|
||||
}
|
||||
|
||||
if q.startTime != nil {
|
||||
hour, minute, meridiem := FormatTimeParameter(*q.startTime)
|
||||
params[paramStartTimeHour] = hour
|
||||
params[paramStartTimeMinute] = minute
|
||||
params[paramStartTimeMeridiem] = meridiem
|
||||
}
|
||||
|
||||
if q.endTime != nil {
|
||||
hour, minute, meridiem := FormatTimeParameter(*q.endTime)
|
||||
params[paramEndTimeHour] = hour
|
||||
params[paramEndTimeMinute] = minute
|
||||
params[paramEndTimeMeridiem] = meridiem
|
||||
}
|
||||
|
||||
if q.minCredits != nil {
|
||||
params[paramMinCredits] = strconv.Itoa(*q.minCredits)
|
||||
}
|
||||
|
||||
if q.maxCredits != nil {
|
||||
params[paramMaxCredits] = strconv.Itoa(*q.maxCredits)
|
||||
}
|
||||
|
||||
if q.courseNumberRange != nil {
|
||||
params[paramCourseNumberLow] = strconv.Itoa(q.courseNumberRange.Low)
|
||||
params[paramCourseNumberHigh] = strconv.Itoa(q.courseNumberRange.High)
|
||||
}
|
||||
|
||||
params[paramOffset] = strconv.Itoa(q.offset)
|
||||
params[paramMaxResults] = strconv.Itoa(q.maxResults)
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// String returns a string representation of the query, ideal for debugging & logging.
|
||||
func (q *Query) String() string {
|
||||
var sb strings.Builder
|
||||
|
||||
if q.subject != nil {
|
||||
fmt.Fprintf(&sb, "subject=%s, ", *q.subject)
|
||||
}
|
||||
|
||||
if q.title != nil {
|
||||
// Whitespace can prevent valid queries from succeeding
|
||||
fmt.Fprintf(&sb, "title=%s, ", strings.TrimSpace(*q.title))
|
||||
}
|
||||
|
||||
if q.keywords != nil {
|
||||
fmt.Fprintf(&sb, "keywords=%s, ", strings.Join(*q.keywords, " "))
|
||||
}
|
||||
|
||||
if q.openOnly != nil {
|
||||
fmt.Fprintf(&sb, "openOnly=%t, ", *q.openOnly)
|
||||
}
|
||||
|
||||
if q.termPart != nil {
|
||||
fmt.Fprintf(&sb, "termPart=%s, ", strings.Join(*q.termPart, ","))
|
||||
}
|
||||
|
||||
if q.campus != nil {
|
||||
fmt.Fprintf(&sb, "campus=%s, ", strings.Join(*q.campus, ","))
|
||||
}
|
||||
|
||||
if q.attributes != nil {
|
||||
fmt.Fprintf(&sb, "attributes=%s, ", strings.Join(*q.attributes, ","))
|
||||
}
|
||||
|
||||
if q.instructor != nil {
|
||||
fmt.Fprintf(&sb, "instructor=%s, ", strings.Join(lo.Map(*q.instructor, func(i uint64, _ int) string {
|
||||
return strconv.FormatUint(i, 10)
|
||||
}), ","))
|
||||
}
|
||||
|
||||
if q.startTime != nil {
|
||||
hour, minute, meridiem := FormatTimeParameter(*q.startTime)
|
||||
fmt.Fprintf(&sb, "startTime=%s:%s%s, ", hour, minute, meridiem)
|
||||
}
|
||||
|
||||
if q.endTime != nil {
|
||||
hour, minute, meridiem := FormatTimeParameter(*q.endTime)
|
||||
fmt.Fprintf(&sb, "endTime=%s:%s%s, ", hour, minute, meridiem)
|
||||
}
|
||||
|
||||
if q.minCredits != nil {
|
||||
fmt.Fprintf(&sb, "minCredits=%d, ", *q.minCredits)
|
||||
}
|
||||
|
||||
if q.maxCredits != nil {
|
||||
fmt.Fprintf(&sb, "maxCredits=%d, ", *q.maxCredits)
|
||||
}
|
||||
|
||||
if q.courseNumberRange != nil {
|
||||
fmt.Fprintf(&sb, "courseNumberRange=%d-%d, ", q.courseNumberRange.Low, q.courseNumberRange.High)
|
||||
}
|
||||
|
||||
fmt.Fprintf(&sb, "offset=%d, ", q.offset)
|
||||
fmt.Fprintf(&sb, "maxResults=%d", q.maxResults)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Dict returns a map representation of the query, ideal for debugging & logging.
|
||||
// This dict is represented with zerolog's Event type.
|
||||
// func (q *Query) Dict() *zerolog.Event {
|
||||
// }
|
||||
@@ -1,64 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"banner/internal"
|
||||
"net/url"
|
||||
|
||||
log "github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Setup makes the initial requests to set up the session cookies for the application.
|
||||
func (a *API) Setup() {
|
||||
// Makes the initial requests that sets up the session cookies for the rest of the application
|
||||
log.Info().Msg("Setting up session...")
|
||||
|
||||
requestQueue := []string{
|
||||
"/registration/registration",
|
||||
"/selfServiceMenu/data",
|
||||
}
|
||||
|
||||
for _, path := range requestQueue {
|
||||
req := a.config.Client.NewRequest().
|
||||
SetQueryParam("_", internal.Nonce()).
|
||||
SetExpectResponseContentType("application/json")
|
||||
|
||||
res, err := req.Get(path)
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Str("path", path).Err(err).Msg("Failed to make request")
|
||||
}
|
||||
|
||||
if res.StatusCode() != 200 {
|
||||
log.Fatal().Stack().Str("path", path).Int("status", res.StatusCode()).Msg("Failed to make request")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that cookies were set
|
||||
baseURLParsed, err := url.Parse(a.config.BaseURL)
|
||||
if err != nil {
|
||||
log.Fatal().Stack().Str("baseURL", a.config.BaseURL).Err(err).Msg("Failed to parse baseURL")
|
||||
}
|
||||
|
||||
currentCookies := a.config.Client.CookieJar().Cookies(baseURLParsed)
|
||||
requiredCookies := map[string]bool{
|
||||
"JSESSIONID": false,
|
||||
"SSB_COOKIE": false,
|
||||
}
|
||||
|
||||
for _, cookie := range currentCookies {
|
||||
_, present := requiredCookies[cookie.Name]
|
||||
// Check if this cookie is required
|
||||
if present {
|
||||
requiredCookies[cookie.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all required cookies were set
|
||||
for cookieName, cookieSet := range requiredCookies {
|
||||
if !cookieSet {
|
||||
log.Warn().Str("cookieName", cookieName).Msg("Required cookie not set")
|
||||
}
|
||||
}
|
||||
log.Debug().Msg("All required cookies set, session setup complete")
|
||||
|
||||
// TODO: Validate that the session allows access to termSelection
|
||||
}
|
||||
@@ -1,649 +0,0 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"banner/internal"
|
||||
"banner/internal/api"
|
||||
"banner/internal/models"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
// ICalTimestampLayoutUtc is the formatting layout for timestamps in the UTC timezone.
|
||||
ICalTimestampLayoutUtc = "20060102T150405Z"
|
||||
// ICalTimestampLayoutLocal is the formatting layout for timestamps in the local timezone.
|
||||
ICalTimestampLayoutLocal = "20060102T150405"
|
||||
)
|
||||
|
||||
// CommandHandler is a function that handles a slash command interaction.
|
||||
type CommandHandler func(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error
|
||||
|
||||
var (
|
||||
// CommandDefinitions is a list of all the bot's command definitions.
|
||||
CommandDefinitions = []*discordgo.ApplicationCommand{TermCommandDefinition, TimeCommandDefinition, SearchCommandDefinition, IcsCommandDefinition, GCalCommandDefinition}
|
||||
// CommandHandlers is a map of command names to their handlers.
|
||||
CommandHandlers = map[string]CommandHandler{
|
||||
TimeCommandDefinition.Name: TimeCommandHandler,
|
||||
TermCommandDefinition.Name: TermCommandHandler,
|
||||
SearchCommandDefinition.Name: SearchCommandHandler,
|
||||
IcsCommandDefinition.Name: IcsCommandHandler,
|
||||
GCalCommandDefinition.Name: GCalCommandHandler,
|
||||
}
|
||||
)
|
||||
|
||||
var SearchCommandDefinition = &discordgo.ApplicationCommand{
|
||||
Name: "search",
|
||||
Description: "Search for a course",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
MinLength: internal.GetIntPointer(0),
|
||||
MaxLength: 48,
|
||||
Name: "title",
|
||||
Description: "Course Title (exact, use autocomplete)",
|
||||
Required: false,
|
||||
Autocomplete: true,
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
Name: "code",
|
||||
MinLength: internal.GetIntPointer(4),
|
||||
Description: "Course Code (e.g. 3743, 3000-3999, 3xxx, 3000-)",
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionInteger,
|
||||
Name: "max",
|
||||
Description: "Maximum number of results",
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
Name: "keywords",
|
||||
Description: "Keywords in Title or Description (space separated)",
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
Name: "instructor",
|
||||
Description: "Instructor Name",
|
||||
Required: false,
|
||||
Autocomplete: true,
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
Name: "subject",
|
||||
Description: "Subject (e.g. Computer Science/CS, Mathematics/MAT)",
|
||||
Required: false,
|
||||
Autocomplete: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// SearchCommandHandler handles the /search command, which allows users to search for courses.
|
||||
func SearchCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
|
||||
data := i.ApplicationCommandData()
|
||||
query := api.NewQuery().Credits(3, 6)
|
||||
|
||||
for _, option := range data.Options {
|
||||
switch option.Name {
|
||||
case "title":
|
||||
query.Title(option.StringValue())
|
||||
case "code":
|
||||
var (
|
||||
low = -1
|
||||
high = -1
|
||||
)
|
||||
var err error
|
||||
valueRaw := strings.TrimSpace(option.StringValue())
|
||||
|
||||
// Partially/fully specified range
|
||||
if strings.Contains(valueRaw, "-") {
|
||||
match := regexp.MustCompile(`(\d{1,4})-(\d{1,4})?`).FindSubmatch([]byte(valueRaw))
|
||||
|
||||
if match == nil {
|
||||
return fmt.Errorf("invalid range format: %s", valueRaw)
|
||||
}
|
||||
|
||||
// If not 2 or 3 matches, it's invalid
|
||||
if len(match) != 3 && len(match) != 4 {
|
||||
return fmt.Errorf("invalid range format: %s", match[0])
|
||||
}
|
||||
|
||||
low, err = strconv.Atoi(string(match[1]))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing course code (low)")
|
||||
}
|
||||
|
||||
// If there's not a high value, set it to max (open ended)
|
||||
if len(match) == 2 || len(match[2]) == 0 {
|
||||
high = 9999
|
||||
} else {
|
||||
high, err = strconv.Atoi(string(match[2]))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing course code (high)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #xxx, ##xx, ###x format (34xx -> 3400-3499)
|
||||
if strings.Contains(valueRaw, "x") {
|
||||
if len(valueRaw) != 4 {
|
||||
return fmt.Errorf("code range format invalid: must be 1 or more digits followed by x's (%s)", valueRaw)
|
||||
}
|
||||
|
||||
match := regexp.MustCompile(`\d{1,}([xX]{1,3})`).Match([]byte(valueRaw))
|
||||
if !match {
|
||||
return fmt.Errorf("code range format invalid: must be 1 or more digits followed by x's (%s)", valueRaw)
|
||||
}
|
||||
|
||||
// Replace x's with 0's
|
||||
low, err = strconv.Atoi(strings.Replace(valueRaw, "x", "0", -1))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing implied course code (low)")
|
||||
}
|
||||
|
||||
// Replace x's with 9's
|
||||
high, err = strconv.Atoi(strings.Replace(valueRaw, "x", "9", -1))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing implied course code (high)")
|
||||
}
|
||||
} else if len(valueRaw) == 4 {
|
||||
// 4 digit code
|
||||
low, err = strconv.Atoi(valueRaw)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing course code")
|
||||
}
|
||||
|
||||
high = low
|
||||
}
|
||||
|
||||
if low == -1 || high == -1 {
|
||||
return fmt.Errorf("course code range invalid (%s)", valueRaw)
|
||||
}
|
||||
|
||||
if low > high {
|
||||
return fmt.Errorf("course code range is invalid: low is greater than high (%d > %d)", low, high)
|
||||
}
|
||||
|
||||
if low < 1000 || high < 1000 || low > 9999 || high > 9999 {
|
||||
return fmt.Errorf("course code range is invalid: must be 1000-9999 (%d-%d)", low, high)
|
||||
}
|
||||
|
||||
query.CourseNumbers(low, high)
|
||||
case "keywords":
|
||||
query.Keywords(
|
||||
strings.Split(option.StringValue(), " "),
|
||||
)
|
||||
case "max":
|
||||
query.MaxResults(
|
||||
min(8, int(option.IntValue())),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
term, err := b.GetSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
courses, err := b.API.Search(term, query, "", false)
|
||||
if err != nil {
|
||||
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Content: "Error searching for courses",
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
fetchTime := time.Now()
|
||||
fields := []*discordgo.MessageEmbedField{}
|
||||
|
||||
for _, course := range courses.Data {
|
||||
// Safe instructor name handling
|
||||
displayName := "TBA"
|
||||
if len(course.Faculty) > 0 {
|
||||
displayName = course.Faculty[0].DisplayName
|
||||
}
|
||||
|
||||
categoryLink := fmt.Sprintf("[%s](https://catalog.utsa.edu/undergraduate/coursedescriptions/%s/)", course.Subject, strings.ToLower(course.Subject))
|
||||
classLink := fmt.Sprintf("[%s-%s](https://catalog.utsa.edu/search/?P=%s%%20%s)", course.CourseNumber, course.SequenceNumber, course.Subject, course.CourseNumber)
|
||||
professorLink := fmt.Sprintf("[%s](https://www.ratemyprofessors.com/search/professors/1516?q=%s)", displayName, url.QueryEscape(displayName))
|
||||
|
||||
identifierText := fmt.Sprintf("%s %s (CRN %s)\n%s", categoryLink, classLink, course.CourseReferenceNumber, professorLink)
|
||||
|
||||
// Safe meeting time handling
|
||||
meetingTime := "No scheduled meetings"
|
||||
if len(course.MeetingsFaculty) > 0 {
|
||||
meetingTime = course.MeetingsFaculty[0].String()
|
||||
}
|
||||
|
||||
fields = append(fields, &discordgo.MessageEmbedField{
|
||||
Name: "Identifier",
|
||||
Value: identifierText,
|
||||
Inline: true,
|
||||
}, &discordgo.MessageEmbedField{
|
||||
Name: "Name",
|
||||
Value: course.CourseTitle,
|
||||
Inline: true,
|
||||
}, &discordgo.MessageEmbedField{
|
||||
Name: "Meeting Time",
|
||||
Value: meetingTime,
|
||||
Inline: true,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Blue if there are results, orange if there are none
|
||||
color := 0x0073FF
|
||||
if courses.TotalCount == 0 {
|
||||
color = 0xFF6500
|
||||
}
|
||||
|
||||
err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Embeds: []*discordgo.MessageEmbed{
|
||||
{
|
||||
Footer: internal.GetFetchedFooter(b.Config, fetchTime),
|
||||
Description: fmt.Sprintf("%d Class%s", courses.TotalCount, internal.Plural(courses.TotalCount)),
|
||||
Fields: fields[:min(25, len(fields))],
|
||||
Color: color,
|
||||
},
|
||||
},
|
||||
AllowedMentions: &discordgo.MessageAllowedMentions{},
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
var TermCommandDefinition = &discordgo.ApplicationCommand{
|
||||
Name: "terms",
|
||||
Description: "Guess the current term, or search for a specific term",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
MinLength: internal.GetIntPointer(0),
|
||||
MaxLength: 8,
|
||||
Name: "search",
|
||||
Description: "Term to search for",
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionInteger,
|
||||
Name: "page",
|
||||
Description: "Page Number",
|
||||
Required: false,
|
||||
MinValue: internal.GetFloatPointer(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TermCommandHandler handles the /terms command, which allows users to search for terms.
|
||||
func TermCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
|
||||
data := i.ApplicationCommandData()
|
||||
|
||||
searchTerm := ""
|
||||
pageNumber := 1
|
||||
|
||||
for _, option := range data.Options {
|
||||
switch option.Name {
|
||||
case "search":
|
||||
searchTerm = option.StringValue()
|
||||
case "page":
|
||||
pageNumber = int(option.IntValue())
|
||||
default:
|
||||
log.Warn().Str("option", option.Name).Msg("Unexpected option in term command")
|
||||
}
|
||||
}
|
||||
|
||||
termResult, err := b.API.GetTerms(searchTerm, pageNumber, 25)
|
||||
|
||||
if err != nil {
|
||||
internal.RespondError(s, i.Interaction, "Error while fetching terms", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fields := []*discordgo.MessageEmbedField{}
|
||||
|
||||
for _, t := range termResult {
|
||||
fields = append(fields, &discordgo.MessageEmbedField{
|
||||
Name: t.Description,
|
||||
Value: t.Code,
|
||||
Inline: true,
|
||||
})
|
||||
}
|
||||
|
||||
fetchTime := time.Now()
|
||||
|
||||
if len(fields) > 25 {
|
||||
log.Warn().Int("count", len(fields)).Msg("Too many fields in term command (trimmed)")
|
||||
}
|
||||
|
||||
err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Embeds: []*discordgo.MessageEmbed{
|
||||
{
|
||||
Footer: internal.GetFetchedFooter(b.Config, fetchTime),
|
||||
Description: fmt.Sprintf("%d term%s (page %d)", len(termResult), internal.Plural(len(termResult)), pageNumber),
|
||||
Fields: fields[:min(25, len(fields))],
|
||||
},
|
||||
},
|
||||
AllowedMentions: &discordgo.MessageAllowedMentions{},
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
var TimeCommandDefinition = &discordgo.ApplicationCommand{
|
||||
Name: "time",
|
||||
Description: "Get Class Meeting Time",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionInteger,
|
||||
Name: "crn",
|
||||
Description: "Course Reference Number",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TimeCommandHandler handles the /time command, which allows users to get the meeting times for a course.
|
||||
func TimeCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
|
||||
fetchTime := time.Now()
|
||||
crn := i.ApplicationCommandData().Options[0].IntValue()
|
||||
|
||||
// Fix static term
|
||||
meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn))
|
||||
if err != nil {
|
||||
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Content: "Error getting meeting time",
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
if len(meetingTimes) == 0 {
|
||||
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Content: "No meeting times found for this course",
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("no meeting times found for CRN %d", crn)
|
||||
}
|
||||
|
||||
meetingTime := meetingTimes[0]
|
||||
duration := meetingTime.EndTime().Sub(meetingTime.StartTime())
|
||||
|
||||
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Embeds: []*discordgo.MessageEmbed{
|
||||
{
|
||||
Footer: internal.GetFetchedFooter(b.Config, fetchTime),
|
||||
Description: "",
|
||||
Fields: []*discordgo.MessageEmbedField{
|
||||
{
|
||||
Name: "Start Date",
|
||||
Value: meetingTime.StartDay().Format("Monday, January 2, 2006"),
|
||||
},
|
||||
{
|
||||
Name: "End Date",
|
||||
Value: meetingTime.EndDay().Format("Monday, January 2, 2006"),
|
||||
},
|
||||
{
|
||||
Name: "Start/End Time",
|
||||
Value: fmt.Sprintf("%s - %s (%d min)", meetingTime.StartTime().String(), meetingTime.EndTime().String(), int64(duration.Minutes())),
|
||||
},
|
||||
{
|
||||
Name: "Days of Week",
|
||||
Value: internal.WeekdaysToString(meetingTime.Days()),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
AllowedMentions: &discordgo.MessageAllowedMentions{},
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
var IcsCommandDefinition = &discordgo.ApplicationCommand{
|
||||
Name: "ics",
|
||||
Description: "Generate an ICS file for a course",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionInteger,
|
||||
Name: "crn",
|
||||
Description: "Course Reference Number",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var GCalCommandDefinition = &discordgo.ApplicationCommand{
|
||||
Name: "gcal",
|
||||
Description: "Generate a link to create a Google Calendar event for a course",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionInteger,
|
||||
Name: "crn",
|
||||
Description: "Course Reference Number",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// GCalCommandHandler handles the /gcal command, which allows users to generate a link to create a Google Calendar event for a course.
|
||||
func GCalCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
|
||||
// Parse all options
|
||||
options := internal.ParseOptions(i.ApplicationCommandData().Options)
|
||||
crn := options.GetInt("crn")
|
||||
|
||||
course, err := b.API.GetCourse(strconv.Itoa(int(crn)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error retrieving course data: %w", err)
|
||||
}
|
||||
|
||||
meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error requesting meeting time: %w", err)
|
||||
}
|
||||
|
||||
if len(meetingTimes) == 0 {
|
||||
return fmt.Errorf("unexpected - no meeting time data found for course")
|
||||
}
|
||||
|
||||
// Check if the course has any meeting times
|
||||
meetingTime, exists := lo.Find(meetingTimes, func(mt models.MeetingTimeResponse) bool {
|
||||
switch mt.MeetingTime.MeetingType {
|
||||
case "ID", "OA":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
|
||||
if !exists {
|
||||
internal.RespondError(s, i.Interaction, "The course requested does not meet at a defined moment in time.", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
startDay := meetingTime.StartDay()
|
||||
startTime := meetingTime.StartTime()
|
||||
endTime := meetingTime.EndTime()
|
||||
|
||||
// Create timestamps in UTC
|
||||
dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
|
||||
dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
|
||||
|
||||
// Format times in UTC for Google Calendar
|
||||
startStr := dtStart.UTC().Format(ICalTimestampLayoutUtc)
|
||||
endStr := dtEnd.UTC().Format(ICalTimestampLayoutUtc)
|
||||
|
||||
// Generate RRULE for recurrence
|
||||
rrule := meetingTime.RRule()
|
||||
recurRule := fmt.Sprintf("FREQ=WEEKLY;BYDAY=%s;UNTIL=%s", rrule.ByDay, rrule.Until)
|
||||
|
||||
// Build calendar URL
|
||||
params := url.Values{}
|
||||
params.Add("action", "TEMPLATE")
|
||||
params.Add("text", fmt.Sprintf("%s %s - %s", course.Subject, course.CourseNumber, course.CourseTitle))
|
||||
params.Add("dates", fmt.Sprintf("%s/%s", startStr, endStr))
|
||||
params.Add("details", fmt.Sprintf("CRN: %s\nInstructor: %s\nDays: %s", course.CourseReferenceNumber, meetingTime.Faculty[0].DisplayName, internal.WeekdaysToString(meetingTime.Days())))
|
||||
params.Add("location", meetingTime.PlaceString())
|
||||
params.Add("trp", "true")
|
||||
params.Add("ctz", b.Config.CentralTimeLocation.String())
|
||||
params.Add("recur", "RRULE:"+recurRule)
|
||||
|
||||
calendarURL := "https://calendar.google.com/calendar/render?" + params.Encode()
|
||||
|
||||
err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Content: fmt.Sprintf("[Add to Google Calendar](<%s>)", calendarURL),
|
||||
AllowedMentions: &discordgo.MessageAllowedMentions{},
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// IcsCommandHandler handles the /ics command, which allows users to generate an ICS file for a course.
|
||||
func IcsCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
|
||||
// Parse all options
|
||||
options := internal.ParseOptions(i.ApplicationCommandData().Options)
|
||||
crn := options.GetInt("crn")
|
||||
|
||||
course, err := b.API.GetCourse(strconv.Itoa(int(crn)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error retrieving course data: %w", err)
|
||||
}
|
||||
|
||||
// Fix static term
|
||||
meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error requesting meeting time: %w", err)
|
||||
}
|
||||
|
||||
if len(meetingTimes) == 0 {
|
||||
return fmt.Errorf("unexpected - no meeting time data found for course")
|
||||
}
|
||||
|
||||
// Check if the course has any meeting times
|
||||
_, exists := lo.Find(meetingTimes, func(mt models.MeetingTimeResponse) bool {
|
||||
switch mt.MeetingTime.MeetingType {
|
||||
case "ID", "OA":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
|
||||
if !exists {
|
||||
log.Warn().Str("crn", course.CourseReferenceNumber).Msg("Non-meeting course requested for ICS file")
|
||||
internal.RespondError(s, i.Interaction, "The course requested does not meet at a defined moment in time.", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
events := []string{}
|
||||
for _, meeting := range meetingTimes {
|
||||
now := time.Now().In(b.Config.CentralTimeLocation)
|
||||
uid := fmt.Sprintf("%d-%s@ical.banner.xevion.dev", now.Unix(), meeting.CourseReferenceNumber)
|
||||
|
||||
startDay := meeting.StartDay()
|
||||
startTime := meeting.StartTime()
|
||||
endTime := meeting.EndTime()
|
||||
dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
|
||||
dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
|
||||
|
||||
// endDay := meeting.EndDay()
|
||||
// until := time.Date(endDay.Year(), endDay.Month(), endDay.Day(), 23, 59, 59, 0, b.Config.CentralTimeLocation)
|
||||
|
||||
summary := fmt.Sprintf("%s %s %s", course.Subject, course.CourseNumber, course.CourseTitle)
|
||||
|
||||
// Safe instructor name handling
|
||||
instructorName := "TBA"
|
||||
if len(course.Faculty) > 0 {
|
||||
instructorName = course.Faculty[0].DisplayName
|
||||
}
|
||||
|
||||
description := fmt.Sprintf("Instructor: %s\nSection: %s\nCRN: %s", instructorName, course.SequenceNumber, meeting.CourseReferenceNumber)
|
||||
location := meeting.PlaceString()
|
||||
|
||||
rrule := meeting.RRule()
|
||||
|
||||
event := fmt.Sprintf(`BEGIN:VEVENT
|
||||
DTSTAMP:%s
|
||||
UID:%s
|
||||
DTSTART;TZID=America/Chicago:%s
|
||||
RRULE:FREQ=WEEKLY;BYDAY=%s;UNTIL=%s
|
||||
DTEND;TZID=America/Chicago:%s
|
||||
SUMMARY:%s
|
||||
DESCRIPTION:%s
|
||||
LOCATION:%s
|
||||
END:VEVENT`, now.Format(ICalTimestampLayoutLocal), uid, dtStart.Format(ICalTimestampLayoutLocal), rrule.ByDay, rrule.Until, dtEnd.Format(ICalTimestampLayoutLocal), summary, strings.Replace(description, "\n", `\n`, -1), location)
|
||||
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
// TODO: Make this dynamically requested, parsed & cached from tzurl.org
|
||||
vTimezone := `BEGIN:VTIMEZONE
|
||||
TZID:America/Chicago
|
||||
LAST-MODIFIED:20231222T233358Z
|
||||
TZURL:https://www.tzurl.org/zoneinfo-outlook/America/Chicago
|
||||
X-LIC-LOCATION:America/Chicago
|
||||
BEGIN:DAYLIGHT
|
||||
TZNAME:CDT
|
||||
TZOFFSETFROM:-0600
|
||||
TZOFFSETTO:-0500
|
||||
DTSTART:19700308T020000
|
||||
RRULE:FREQ=YEARLY;BYMONTH=3;BYDAY=2SU
|
||||
END:DAYLIGHT
|
||||
BEGIN:STANDARD
|
||||
TZNAME:CST
|
||||
TZOFFSETFROM:-0500
|
||||
TZOFFSETTO:-0600
|
||||
DTSTART:19701101T020000
|
||||
RRULE:FREQ=YEARLY;BYMONTH=11;BYDAY=1SU
|
||||
END:STANDARD
|
||||
END:VTIMEZONE`
|
||||
|
||||
ics := fmt.Sprintf(`BEGIN:VCALENDAR
|
||||
VERSION:2.0
|
||||
PRODID:-//xevion//Banner Discord Bot//EN
|
||||
CALSCALE:GREGORIAN
|
||||
%s
|
||||
%s
|
||||
END:VCALENDAR`, vTimezone, strings.Join(events, "\n"))
|
||||
|
||||
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Files: []*discordgo.File{
|
||||
{
|
||||
Name: fmt.Sprintf("%s-%s-%s_%s.ics", course.Subject, course.CourseNumber, course.SequenceNumber, course.CourseReferenceNumber),
|
||||
ContentType: "text/calendar",
|
||||
Reader: strings.NewReader(ics),
|
||||
},
|
||||
},
|
||||
AllowedMentions: &discordgo.MessageAllowedMentions{},
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"banner/internal"
|
||||
"fmt"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// RegisterHandlers registers the bot's command handlers.
|
||||
func (b *Bot) RegisterHandlers() {
|
||||
b.Session.AddHandler(func(internalSession *discordgo.Session, interaction *discordgo.InteractionCreate) {
|
||||
// Handle commands during restart (highly unlikely, but just in case)
|
||||
if b.isClosing {
|
||||
err := internal.RespondError(internalSession, interaction.Interaction, "Bot is currently restarting, try again later.", nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to respond with restart error feedback")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
name := interaction.ApplicationCommandData().Name
|
||||
if handler, ok := CommandHandlers[name]; ok {
|
||||
// Build dict of options for the log
|
||||
options := zerolog.Dict()
|
||||
for _, option := range interaction.ApplicationCommandData().Options {
|
||||
options.Str(option.Name, fmt.Sprintf("%v", option.Value))
|
||||
}
|
||||
|
||||
event := log.Info().Str("name", name).Str("user", internal.GetUser(interaction).Username).Dict("options", options)
|
||||
|
||||
// If the command was invoked in a guild, add guild & channel info to the log
|
||||
if interaction.Member != nil {
|
||||
guild := zerolog.Dict()
|
||||
guild.Str("id", interaction.GuildID)
|
||||
guild.Str("name", internal.GetGuildName(b.Config, internalSession, interaction.GuildID))
|
||||
event.Dict("guild", guild)
|
||||
|
||||
channel := zerolog.Dict()
|
||||
channel.Str("id", interaction.ChannelID)
|
||||
guild.Str("name", internal.GetChannelName(b.Config, internalSession, interaction.ChannelID))
|
||||
event.Dict("channel", channel)
|
||||
} else {
|
||||
// If the command was invoked in a DM, add the user info to the log
|
||||
user := zerolog.Dict()
|
||||
user.Str("id", interaction.User.ID)
|
||||
user.Str("name", interaction.User.Username)
|
||||
event.Dict("user", user)
|
||||
}
|
||||
|
||||
// Log command invocation
|
||||
event.Msg("Command Invoked")
|
||||
|
||||
// Prepare to recover
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Error().Stack().Str("commandName", name).Interface("detail", err).Msg("Command Handler Panic")
|
||||
|
||||
// Respond with error
|
||||
err := internal.RespondError(internalSession, interaction.Interaction, "Unexpected Error: command handler panic", nil)
|
||||
if err != nil {
|
||||
log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with panic error feedback")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Call handler
|
||||
err := handler(b, internalSession, interaction)
|
||||
|
||||
// Log & respond error
|
||||
if err != nil {
|
||||
// TODO: Find a way to merge the response with the handler's error
|
||||
log.Error().Str("commandName", name).Err(err).Msg("Command Handler Error")
|
||||
|
||||
// Respond with error
|
||||
err = internal.RespondError(internalSession, interaction.Interaction, fmt.Sprintf("Unexpected Error: %s", err.Error()), nil)
|
||||
if err != nil {
|
||||
log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with error feedback")
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
log.Error().Stack().Str("commandName", name).Msg("Command Interaction Has No Handler")
|
||||
|
||||
// Respond with error
|
||||
internal.RespondError(internalSession, interaction.Interaction, "Unexpected Error: interaction has no handler", nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
// Package bot provides the core functionality for the Discord bot.
|
||||
package bot
|
||||
|
||||
import (
|
||||
"banner/internal/api"
|
||||
"banner/internal/config"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Bot represents the state of the Discord bot.
|
||||
type Bot struct {
|
||||
Session *discordgo.Session
|
||||
API *api.API
|
||||
Config *config.Config
|
||||
isClosing bool
|
||||
}
|
||||
|
||||
// New creates a new Bot instance.
|
||||
func New(s *discordgo.Session, a *api.API, c *config.Config) *Bot {
|
||||
return &Bot{Session: s, API: a, Config: c}
|
||||
}
|
||||
|
||||
// SetClosing marks the bot as closing, preventing new commands from being processed.
|
||||
func (b *Bot) SetClosing() {
|
||||
b.isClosing = true
|
||||
}
|
||||
|
||||
// GetSession ensures a valid session is available and selects the default term.
|
||||
func (b *Bot) GetSession() (string, error) {
|
||||
sessionID := b.API.EnsureSession()
|
||||
term := b.API.DefaultTerm(time.Now()).ToString()
|
||||
|
||||
log.Info().Str("term", term).Str("sessionID", sessionID).Msg("Setting selected term")
|
||||
err := b.API.SelectTerm(term, sessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to select term while generating session ID: %w", err)
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"resty.dev/v3"
|
||||
)
|
||||
|
||||
// Config holds the application's configuration.
|
||||
type Config struct {
|
||||
// Ctx is the application's root context.
|
||||
Ctx context.Context
|
||||
// CancelFunc cancels the application's root context.
|
||||
CancelFunc context.CancelFunc
|
||||
// KV provides access to the Redis cache.
|
||||
KV *redis.Client
|
||||
// Client is the HTTP client for making API requests.
|
||||
Client *resty.Client
|
||||
// IsDevelopment is true if the application is running in a development environment.
|
||||
IsDevelopment bool
|
||||
// BaseURL is the base URL for the Banner API.
|
||||
BaseURL string
|
||||
// Environment is the application's running environment (e.g. "development").
|
||||
Environment string
|
||||
// CentralTimeLocation is the time.Location for US Central Time.
|
||||
CentralTimeLocation *time.Location
|
||||
// SeasonRanges is the time.Location for US Central Time.
|
||||
SeasonRanges *SeasonRanges
|
||||
}
|
||||
|
||||
// New creates a new Config instance with a cancellable context.
|
||||
func New() (*Config, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
loc, err := time.LoadLocation("America/Chicago")
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seasonRanges := GetYearDayRange(loc, uint16(time.Now().Year()))
|
||||
|
||||
return &Config{
|
||||
Ctx: ctx,
|
||||
CancelFunc: cancel,
|
||||
CentralTimeLocation: loc,
|
||||
SeasonRanges: &seasonRanges,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetBaseURL sets the base URL for the Banner API.
|
||||
func (c *Config) SetBaseURL(url string) {
|
||||
c.BaseURL = url
|
||||
}
|
||||
|
||||
// SetEnvironment sets the application's environment.
|
||||
func (c *Config) SetEnvironment(env string) {
|
||||
c.Environment = env
|
||||
c.IsDevelopment = env == "development"
|
||||
}
|
||||
|
||||
// SetClient sets the Resty client for making HTTP requests.
|
||||
func (c *Config) SetClient(client *resty.Client) {
|
||||
c.Client = client
|
||||
}
|
||||
|
||||
// SetRedis sets the Redis client for caching.
|
||||
func (c *Config) SetRedis(r *redis.Client) {
|
||||
c.KV = r
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
// Package config provides the configuration and logging setup for the application.
|
||||
package config
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const timeFormat = "2006-01-02 15:04:05"
|
||||
|
||||
// NewConsoleWriter creates a new console writer that splits logs between stdout and stderr.
|
||||
func NewConsoleWriter() zerolog.LevelWriter {
|
||||
return &ConsoleLogSplitter{
|
||||
stdConsole: zerolog.ConsoleWriter{
|
||||
Out: os.Stdout,
|
||||
TimeFormat: timeFormat,
|
||||
NoColor: false,
|
||||
PartsOrder: []string{zerolog.TimestampFieldName, zerolog.LevelFieldName, zerolog.MessageFieldName},
|
||||
PartsExclude: []string{},
|
||||
FieldsExclude: []string{},
|
||||
},
|
||||
errConsole: zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: timeFormat,
|
||||
NoColor: false,
|
||||
PartsOrder: []string{zerolog.TimestampFieldName, zerolog.LevelFieldName, zerolog.MessageFieldName},
|
||||
PartsExclude: []string{},
|
||||
FieldsExclude: []string{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ConsoleLogSplitter is a zerolog.LevelWriter that writes to stdout for info/debug logs and stderr for warn/error logs, with console-friendly formatting.
|
||||
type ConsoleLogSplitter struct {
|
||||
stdConsole zerolog.ConsoleWriter
|
||||
errConsole zerolog.ConsoleWriter
|
||||
}
|
||||
|
||||
// Write is a passthrough to the standard console writer and should not be called directly.
|
||||
func (c *ConsoleLogSplitter) Write(p []byte) (n int, err error) {
|
||||
return c.stdConsole.Write(p)
|
||||
}
|
||||
|
||||
// WriteLevel writes to the appropriate output (stdout or stderr) with console formatting based on the log level.
|
||||
func (c *ConsoleLogSplitter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
|
||||
if level <= zerolog.WarnLevel {
|
||||
return c.stdConsole.Write(p)
|
||||
}
|
||||
return c.errConsole.Write(p)
|
||||
}
|
||||
|
||||
// LogSplitter is a zerolog.LevelWriter that writes to stdout for info/debug logs and stderr for warn/error logs.
|
||||
type LogSplitter struct {
|
||||
Std io.Writer
|
||||
Err io.Writer
|
||||
}
|
||||
|
||||
// Write is a passthrough to the standard writer and should not be called directly.
|
||||
func (l LogSplitter) Write(p []byte) (n int, err error) {
|
||||
return l.Std.Write(p)
|
||||
}
|
||||
|
||||
// WriteLevel writes to the appropriate output (stdout or stderr) based on the log level.
|
||||
func (l LogSplitter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
|
||||
if level <= zerolog.WarnLevel {
|
||||
return l.Std.Write(p)
|
||||
}
|
||||
return l.Err.Write(p)
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Term selection should yield smart results based on the current time, as well as the input provided.
|
||||
// Fall 2024, "spring" => Spring 2025
|
||||
// Fall 2024, "fall" => Fall 2025
|
||||
// Summer 2024, "fall" => Fall 2024
|
||||
|
||||
const (
|
||||
// Fall is the first term of the school year.
|
||||
Fall = iota
|
||||
// Spring is the second term of the school year.
|
||||
Spring
|
||||
// Summer is the third term of the school year.
|
||||
Summer
|
||||
)
|
||||
|
||||
// Term represents a school term, consisting of a year and a season.
|
||||
type Term struct {
|
||||
Year uint16
|
||||
Season uint8
|
||||
}
|
||||
|
||||
// SeasonRanges represents the start and end day of each term within a year.
|
||||
type SeasonRanges struct {
|
||||
Spring YearDayRange
|
||||
Summer YearDayRange
|
||||
Fall YearDayRange
|
||||
}
|
||||
|
||||
// YearDayRange represents the start and end day of a term within a year.
|
||||
type YearDayRange struct {
|
||||
Start uint16
|
||||
End uint16
|
||||
}
|
||||
|
||||
// GetYearDayRange returns the start and end day of each term for the given year.
|
||||
// The ranges are inclusive of the start day and exclusive of the end day.
|
||||
func GetYearDayRange(loc *time.Location, year uint16) SeasonRanges {
|
||||
springStart := time.Date(int(year), time.January, 14, 0, 0, 0, 0, loc).YearDay()
|
||||
springEnd := time.Date(int(year), time.May, 1, 0, 0, 0, 0, loc).YearDay()
|
||||
summerStart := time.Date(int(year), time.May, 25, 0, 0, 0, 0, loc).YearDay()
|
||||
summerEnd := time.Date(int(year), time.August, 15, 0, 0, 0, 0, loc).YearDay()
|
||||
fallStart := time.Date(int(year), time.August, 18, 0, 0, 0, 0, loc).YearDay()
|
||||
fallEnd := time.Date(int(year), time.December, 10, 0, 0, 0, 0, loc).YearDay()
|
||||
|
||||
return SeasonRanges{
|
||||
Spring: YearDayRange{
|
||||
Start: uint16(springStart),
|
||||
End: uint16(springEnd),
|
||||
},
|
||||
Summer: YearDayRange{
|
||||
Start: uint16(summerStart),
|
||||
End: uint16(summerEnd),
|
||||
},
|
||||
Fall: YearDayRange{
|
||||
Start: uint16(fallStart),
|
||||
End: uint16(fallEnd),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentTerm returns the current and next terms based on the provided time.
|
||||
// The current term can be nil if the time falls between terms.
|
||||
// The 'year' in the term corresponds to the academic year, which may differ from the calendar year.
|
||||
func GetCurrentTerm(ranges SeasonRanges, now time.Time) (*Term, *Term) {
|
||||
literalYear := uint16(now.Year())
|
||||
dayOfYear := uint16(now.YearDay())
|
||||
|
||||
// If we're past the end of the summer term, we're 'in' the next school year.
|
||||
var termYear uint16
|
||||
if dayOfYear > ranges.Summer.End {
|
||||
termYear = literalYear + 1
|
||||
} else {
|
||||
termYear = literalYear
|
||||
}
|
||||
|
||||
if (dayOfYear < ranges.Spring.Start) || (dayOfYear >= ranges.Fall.End) {
|
||||
// Fall over, Spring not yet begun
|
||||
return nil, &Term{Year: termYear, Season: Spring}
|
||||
} else if (dayOfYear >= ranges.Spring.Start) && (dayOfYear < ranges.Spring.End) {
|
||||
// Spring
|
||||
return &Term{Year: termYear, Season: Spring}, &Term{Year: termYear, Season: Summer}
|
||||
} else if dayOfYear < ranges.Summer.Start {
|
||||
// Spring over, Summer not yet begun
|
||||
return nil, &Term{Year: termYear, Season: Summer}
|
||||
} else if (dayOfYear >= ranges.Summer.Start) && (dayOfYear < ranges.Summer.End) {
|
||||
// Summer
|
||||
return &Term{Year: termYear, Season: Summer}, &Term{Year: termYear, Season: Fall}
|
||||
} else if dayOfYear < ranges.Fall.Start {
|
||||
// Summer over, Fall not yet begun
|
||||
return nil, &Term{Year: termYear, Season: Fall}
|
||||
} else if (dayOfYear >= ranges.Fall.Start) && (dayOfYear < ranges.Fall.End) {
|
||||
// Fall
|
||||
return &Term{Year: termYear, Season: Fall}, nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("Impossible Code Reached (dayOfYear: %d)", dayOfYear))
|
||||
}
|
||||
|
||||
// ParseTerm converts a Banner term code string to a Term struct.
|
||||
func ParseTerm(code string) Term {
|
||||
year, _ := strconv.ParseUint(code[0:4], 10, 16)
|
||||
|
||||
var season uint8
|
||||
termCode := code[4:6]
|
||||
switch termCode {
|
||||
case "10":
|
||||
season = Fall
|
||||
case "20":
|
||||
season = Spring
|
||||
case "30":
|
||||
season = Summer
|
||||
}
|
||||
|
||||
return Term{
|
||||
Year: uint16(year),
|
||||
Season: season,
|
||||
}
|
||||
}
|
||||
|
||||
// ToString converts a Term struct to a Banner term code string.
|
||||
func (term Term) ToString() string {
|
||||
var season string
|
||||
switch term.Season {
|
||||
case Fall:
|
||||
season = "10"
|
||||
case Spring:
|
||||
season = "20"
|
||||
case Summer:
|
||||
season = "30"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d%s", term.Year, season)
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package internal
|
||||
|
||||
import "fmt"
|
||||
|
||||
// UnexpectedContentTypeError is returned when the Content-Type header of a response does not match the expected value.
|
||||
type UnexpectedContentTypeError struct {
|
||||
Expected string
|
||||
Actual string
|
||||
}
|
||||
|
||||
func (e *UnexpectedContentTypeError) Error() string {
|
||||
return fmt.Sprintf("Expected content type '%s', received '%s'", e.Expected, e.Actual)
|
||||
}
|
||||
@@ -1,376 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/rs/zerolog"
|
||||
log "github.com/rs/zerolog/log"
|
||||
"resty.dev/v3"
|
||||
|
||||
"banner/internal/config"
|
||||
)
|
||||
|
||||
// Options is a map of options from a Discord command.
|
||||
type Options map[string]*discordgo.ApplicationCommandInteractionDataOption
|
||||
|
||||
// GetInt returns the integer value of an option, or 0 if it doesn't exist.
|
||||
func (o Options) GetInt(key string) int64 {
|
||||
if opt, ok := o[key]; ok {
|
||||
return opt.IntValue()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ParseOptions parses slash command options into a map for easier access.
|
||||
func ParseOptions(options []*discordgo.ApplicationCommandInteractionDataOption) Options {
|
||||
optionMap := make(Options)
|
||||
for _, opt := range options {
|
||||
optionMap[opt.Name] = opt
|
||||
}
|
||||
return optionMap
|
||||
}
|
||||
|
||||
// AddUserAgent adds a consistent user agent to the request to mimic a real browser.
|
||||
func AddUserAgent(req *http.Request) {
|
||||
req.Header.Add("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36")
|
||||
}
|
||||
|
||||
// ContentTypeMatch checks if a Resty response has the given content type.
|
||||
func ContentTypeMatch(res *resty.Response, expectedContentType string) bool {
|
||||
contentType := res.Header().Get("Content-Type")
|
||||
if contentType == "" {
|
||||
return expectedContentType == "application/octect-stream"
|
||||
}
|
||||
return strings.HasPrefix(contentType, expectedContentType)
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
// RandomString returns a random string of length n.
|
||||
// The character set is chosen to mimic Ellucian's Banner session ID generation.
|
||||
func RandomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// DiscordGoLogger is a helper function that implements discordgo's logging interface, directing all logs to zerolog.
|
||||
func DiscordGoLogger(msgL, caller int, format string, a ...interface{}) {
|
||||
pc, file, line, _ := runtime.Caller(caller)
|
||||
|
||||
files := strings.Split(file, "/")
|
||||
file = files[len(files)-1]
|
||||
|
||||
name := runtime.FuncForPC(pc).Name()
|
||||
fns := strings.Split(name, ".")
|
||||
name = fns[len(fns)-1]
|
||||
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
|
||||
var event *zerolog.Event
|
||||
switch msgL {
|
||||
case 0:
|
||||
event = log.Debug()
|
||||
case 1:
|
||||
event = log.Info()
|
||||
case 2:
|
||||
event = log.Warn()
|
||||
case 3:
|
||||
event = log.Error()
|
||||
default:
|
||||
event = log.Info()
|
||||
}
|
||||
|
||||
event.Str("file", file).Int("line", line).Str("function", name).Msg(msg)
|
||||
}
|
||||
|
||||
// Nonce returns the current time in milliseconds since the Unix epoch as a string.
|
||||
// This is typically used as a query parameter to prevent request caching.
|
||||
func Nonce() string {
|
||||
return strconv.Itoa(int(time.Now().UnixMilli()))
|
||||
}
|
||||
|
||||
// Plural returns "s" if n is not 1.
|
||||
func Plural(n int) string {
|
||||
if n == 1 {
|
||||
return ""
|
||||
}
|
||||
return "s"
|
||||
}
|
||||
|
||||
// Plurale returns "es" if n is not 1.
|
||||
func Plurale(n int) string {
|
||||
if n == 1 {
|
||||
return ""
|
||||
}
|
||||
return "es"
|
||||
}
|
||||
|
||||
// WeekdaysToString converts a map of weekdays to a compact string representation (e.g., "MWF").
|
||||
func WeekdaysToString(days map[time.Weekday]bool) string {
|
||||
// If no days are present
|
||||
numDays := len(days)
|
||||
if numDays == 0 {
|
||||
return "None"
|
||||
}
|
||||
|
||||
// If all days are present
|
||||
if numDays == 7 {
|
||||
return "Everyday"
|
||||
}
|
||||
|
||||
str := ""
|
||||
|
||||
if days[time.Monday] {
|
||||
str += "M"
|
||||
}
|
||||
|
||||
if days[time.Tuesday] {
|
||||
str += "Tu"
|
||||
}
|
||||
|
||||
if days[time.Wednesday] {
|
||||
str += "W"
|
||||
}
|
||||
|
||||
if days[time.Thursday] {
|
||||
str += "Th"
|
||||
}
|
||||
|
||||
if days[time.Friday] {
|
||||
str += "F"
|
||||
}
|
||||
|
||||
if days[time.Saturday] {
|
||||
str += "Sa"
|
||||
}
|
||||
|
||||
if days[time.Sunday] {
|
||||
str += "Su"
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// NaiveTime represents a time of day without a date or timezone.
|
||||
type NaiveTime struct {
|
||||
Hours uint
|
||||
Minutes uint
|
||||
}
|
||||
|
||||
// Sub returns the duration between two NaiveTime instances.
|
||||
func (nt *NaiveTime) Sub(other *NaiveTime) time.Duration {
|
||||
return time.Hour*time.Duration(nt.Hours-other.Hours) + time.Minute*time.Duration(nt.Minutes-other.Minutes)
|
||||
}
|
||||
|
||||
// ParseNaiveTime converts an integer representation of time (e.g., 1430) to a NaiveTime struct.
|
||||
func ParseNaiveTime(integer uint64) *NaiveTime {
|
||||
minutes := uint(integer % 100)
|
||||
hours := uint(integer / 100)
|
||||
|
||||
return &NaiveTime{Hours: hours, Minutes: minutes}
|
||||
}
|
||||
|
||||
// String returns a string representation of the NaiveTime in 12-hour format (e.g., "2:30PM").
|
||||
func (nt NaiveTime) String() string {
|
||||
meridiem := "AM"
|
||||
hour := nt.Hours
|
||||
if nt.Hours >= 12 {
|
||||
meridiem = "PM"
|
||||
if nt.Hours > 12 {
|
||||
hour -= 12
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%d:%02d%s", hour, nt.Minutes, meridiem)
|
||||
}
|
||||
|
||||
// GetFirstEnv returns the value of the first environment variable that is set.
|
||||
func GetFirstEnv(key ...string) string {
|
||||
for _, k := range key {
|
||||
if v := os.Getenv(k); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetIntPointer returns a pointer to the given integer.
|
||||
func GetIntPointer(value int) *int {
|
||||
return &value
|
||||
}
|
||||
|
||||
// GetFloatPointer returns a pointer to the given float.
|
||||
func GetFloatPointer(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
var extensionMap = map[string]string{
|
||||
"text/plain": "txt",
|
||||
"application/json": "json",
|
||||
"text/html": "html",
|
||||
"text/css": "css",
|
||||
"text/csv": "csv",
|
||||
"text/calendar": "ics",
|
||||
"text/markdown": "md",
|
||||
"text/xml": "xml",
|
||||
"text/yaml": "yaml",
|
||||
"text/javascript": "js",
|
||||
"text/vtt": "vtt",
|
||||
"image/jpeg": "jpg",
|
||||
"image/png": "png",
|
||||
"image/gif": "gif",
|
||||
"image/webp": "webp",
|
||||
"image/tiff": "tiff",
|
||||
"image/svg+xml": "svg",
|
||||
"image/bmp": "bmp",
|
||||
"image/vnd.microsoft.icon": "ico",
|
||||
"image/x-icon": "ico",
|
||||
"image/x-xbitmap": "xbm",
|
||||
"image/x-xpixmap": "xpm",
|
||||
"image/x-xwindowdump": "xwd",
|
||||
"image/avif": "avif",
|
||||
"image/apng": "apng",
|
||||
"image/jxl": "jxl",
|
||||
}
|
||||
|
||||
// GuessExtension guesses the file extension for a given content type.
|
||||
func GuessExtension(contentType string) string {
|
||||
ext, ok := extensionMap[strings.ToLower(contentType)]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return ext
|
||||
}
|
||||
|
||||
// DumpResponse dumps the body of a Resty response to a file for debugging.
|
||||
func DumpResponse(res *resty.Response) {
|
||||
contentType := res.Header().Get("Content-Type")
|
||||
ext := GuessExtension(contentType)
|
||||
|
||||
// Use current time as filename + /dumps/ prefix
|
||||
filename := fmt.Sprintf("dumps/%d.%s", time.Now().Unix(), ext)
|
||||
file, err := os.Create(filename)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err).Stack().Msg("Error creating file")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
log.Err(err).Stack().Msg("Error reading response body")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = file.Write(body)
|
||||
if err != nil {
|
||||
log.Err(err).Stack().Msg("Error writing response body")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Str("filename", filename).Str("content-type", contentType).Msg("Dumped response body")
|
||||
}
|
||||
|
||||
// RespondError responds to an interaction with a formatted error message.
|
||||
func RespondError(session *discordgo.Session, interaction *discordgo.Interaction, message string, err error) error {
|
||||
// Optional: log the error
|
||||
if err != nil {
|
||||
log.Err(err).Stack().Msg(message)
|
||||
}
|
||||
|
||||
return session.InteractionRespond(interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
Data: &discordgo.InteractionResponseData{
|
||||
Embeds: []*discordgo.MessageEmbed{
|
||||
{
|
||||
Footer: &discordgo.MessageEmbedFooter{
|
||||
Text: fmt.Sprintf("Occurred at %s", time.Now().Format("Monday, January 2, 2006 at 3:04:05PM")),
|
||||
},
|
||||
Description: message,
|
||||
Color: 0xff0000,
|
||||
},
|
||||
},
|
||||
AllowedMentions: &discordgo.MessageAllowedMentions{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetFetchedFooter returns a standard footer for embeds, indicating when the data was fetched.
|
||||
func GetFetchedFooter(cfg *config.Config, time time.Time) *discordgo.MessageEmbedFooter {
|
||||
return &discordgo.MessageEmbedFooter{
|
||||
Text: fmt.Sprintf("Fetched at %s", time.In(cfg.CentralTimeLocation).Format("Monday, January 2, 2006 at 3:04:05PM")),
|
||||
}
|
||||
}
|
||||
|
||||
// GetUser returns the user from an interaction, regardless of whether it was in a guild or a DM.
|
||||
func GetUser(interaction *discordgo.InteractionCreate) *discordgo.User {
|
||||
// If the interaction is in a guild, the user is in the Member field
|
||||
if interaction.Member != nil {
|
||||
return interaction.Member.User
|
||||
}
|
||||
|
||||
// If the interaction is in a DM, the user is in the User field
|
||||
return interaction.User
|
||||
}
|
||||
|
||||
// EncodeParams encodes a map of parameters into a URL-encoded string, sorted by key.
|
||||
func EncodeParams(params map[string]*[]string) string {
|
||||
// Escape hatch for nil
|
||||
if params == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Sort the keys
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var buf strings.Builder
|
||||
for _, k := range keys {
|
||||
// Multiple values are allowed, so extract the slice & prepare the key
|
||||
values := params[k]
|
||||
keyEscaped := url.QueryEscape(k)
|
||||
|
||||
for _, v := range *values {
|
||||
// If any parameters have been written, add the ampersand
|
||||
if buf.Len() > 0 {
|
||||
buf.WriteByte('&')
|
||||
}
|
||||
|
||||
// Write the key and value
|
||||
buf.WriteString(keyEscaped)
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(url.QueryEscape(v))
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Point represents a point in 2D space.
|
||||
type Point struct {
|
||||
X, Y float64
|
||||
}
|
||||
|
||||
// Slope calculates the y-coordinate of a point on a line given two other points and an x-coordinate.
|
||||
func Slope(p1 Point, p2 Point, x float64) Point {
|
||||
slope := (p2.Y - p1.Y) / (p2.X - p1.X)
|
||||
newY := slope*(x-p1.X) + p1.Y
|
||||
return Point{X: x, Y: newY}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
// Package internal provides shared functionality for the banner application.
|
||||
package internal
|
||||
|
||||
import (
|
||||
"banner/internal/config"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/redis/go-redis/v9"
|
||||
log "github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// GetGuildName returns the name of a guild by its ID, using Redis for caching.
|
||||
func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string {
|
||||
// Create a timeout context for Redis operations
|
||||
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check Redis for the guild name
|
||||
guildName, err := cfg.KV.Get(ctx, "guild:"+guildID+":name").Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
log.Error().Stack().Err(err).Msg("Error getting guild name from Redis")
|
||||
return "err"
|
||||
}
|
||||
|
||||
// If the guild name is invalid (1 character long), then return "unknown"
|
||||
if len(guildName) == 1 {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// If the guild name isn't in Redis, get it from Discord and cache it
|
||||
guild, err := session.Guild(guildID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("Error getting guild name")
|
||||
|
||||
// Store an invalid value in Redis so we don't keep trying to get the guild name
|
||||
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||
defer cancel2()
|
||||
_, err := cfg.KV.Set(ctx2, "guild:"+guildID+":name", "x", time.Minute*5).Result()
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis")
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Cache the guild name in Redis
|
||||
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||
defer cancel3()
|
||||
cfg.KV.Set(ctx3, "guild:"+guildID+":name", guild.Name, time.Hour*3)
|
||||
|
||||
return guild.Name
|
||||
}
|
||||
|
||||
// GetChannelName returns the name of a channel by its ID, using Redis for caching.
|
||||
func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string {
|
||||
// Create a timeout context for Redis operations
|
||||
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check Redis for the channel name
|
||||
channelName, err := cfg.KV.Get(ctx, "channel:"+channelID+":name").Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
log.Error().Stack().Err(err).Msg("Error getting channel name from Redis")
|
||||
return "err"
|
||||
}
|
||||
|
||||
// If the channel name is invalid (1 character long), then return "unknown"
|
||||
if len(channelName) == 1 {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// If the channel name isn't in Redis, get it from Discord and cache it
|
||||
channel, err := session.Channel(channelID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("Error getting channel name")
|
||||
|
||||
// Store an invalid value in Redis so we don't keep trying to get the channel name
|
||||
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||
defer cancel2()
|
||||
_, err := cfg.KV.Set(ctx2, "channel:"+channelID+":name", "x", time.Minute*5).Result()
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis")
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Cache the channel name in Redis
|
||||
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
|
||||
defer cancel3()
|
||||
cfg.KV.Set(ctx3, "channel:"+channelID+":name", channel.Name, time.Hour*3)
|
||||
|
||||
return channel.Name
|
||||
}
|
||||
@@ -1,323 +0,0 @@
|
||||
// Package models provides the data structures for the Banner API.
|
||||
package models
|
||||
|
||||
import (
|
||||
"banner/internal"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FacultyItem represents a faculty member associated with a course.
|
||||
type FacultyItem struct {
|
||||
BannerID string `json:"bannerId"`
|
||||
Category *string `json:"category"`
|
||||
Class string `json:"class"`
|
||||
CourseReferenceNumber string `json:"courseReferenceNumber"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"emailAddress"`
|
||||
Primary bool `json:"primaryIndicator"`
|
||||
Term string `json:"term"`
|
||||
}
|
||||
|
||||
// MeetingTimeResponse represents the meeting time information for a course.
|
||||
type MeetingTimeResponse struct {
|
||||
Category *string `json:"category"`
|
||||
Class string `json:"class"`
|
||||
CourseReferenceNumber string `json:"courseReferenceNumber"`
|
||||
Faculty []FacultyItem
|
||||
MeetingTime struct {
|
||||
Category string `json:"category"`
|
||||
// Some sort of metadata used internally by Banner (net.hedtech.banner.student.schedule.SectionSessionDecorator)
|
||||
Class string `json:"class"`
|
||||
// The start date of the meeting time in MM/DD/YYYY format (e.g. 01/16/2024)
|
||||
StartDate string `json:"startDate"`
|
||||
// The end date of the meeting time in MM/DD/YYYY format (e.g. 05/10/2024)
|
||||
EndDate string `json:"endDate"`
|
||||
// The start time of the meeting time in 24-hour format, hours & minutes, digits only (e.g. 1630)
|
||||
BeginTime string `json:"beginTime"`
|
||||
// The end time of the meeting time in 24-hour format, hours & minutes, digits only (e.g. 1745)
|
||||
EndTime string `json:"endTime"`
|
||||
// The room number within the building this course takes place at (e.g. 3.01.08, 200A)
|
||||
Room string `json:"room"`
|
||||
// The internal identifier for the term this course takes place in (e.g. 202420)
|
||||
Term string `json:"term"`
|
||||
// The internal identifier for the building this course takes place at (e.g. SP1)
|
||||
Building string `json:"building"`
|
||||
// The long name of the building this course takes place at (e.g. San Pedro I - Data Science)
|
||||
BuildingDescription string `json:"buildingDescription"`
|
||||
// The internal identifier for the campus this course takes place at (e.g. 1DT)
|
||||
Campus string `json:"campus"`
|
||||
// The long name of the campus this course takes place at (e.g. Main Campus, Downtown Campus)
|
||||
CampusDescription string `json:"campusDescription"`
|
||||
CourseReferenceNumber string `json:"courseReferenceNumber"`
|
||||
// The number of credit hours this class is worth (assumably)
|
||||
CreditHourSession float64 `json:"creditHourSession"`
|
||||
// The number of hours per week this class meets (e.g. 2.5)
|
||||
HoursWeek float64 `json:"hoursWeek"`
|
||||
// Unknown meaning - e.g. AFF, AIN, AHB, FFF, AFF, EFF, DFF, IFF, EHB, JFF, KFF, BFF, BIN
|
||||
MeetingScheduleType string `json:"meetingScheduleType"`
|
||||
// The short identifier for the meeting type (e.g. FF, HB, OS, OA)
|
||||
MeetingType string `json:"meetingType"`
|
||||
// The long name of the meeting type (e.g. Traditional in-person)
|
||||
MeetingTypeDescription string `json:"meetingTypeDescription"`
|
||||
// A boolean indicating if the class will meet on each Monday of the term
|
||||
Monday bool `json:"monday"`
|
||||
// A boolean indicating if the class will meet on each Tuesday of the term
|
||||
Tuesday bool `json:"tuesday"`
|
||||
// A boolean indicating if the class will meet on each Wednesday of the term
|
||||
Wednesday bool `json:"wednesday"`
|
||||
// A boolean indicating if the class will meet on each Thursday of the term
|
||||
Thursday bool `json:"thursday"`
|
||||
// A boolean indicating if the class will meet on each Friday of the term
|
||||
Friday bool `json:"friday"`
|
||||
// A boolean indicating if the class will meet on each Saturday of the term
|
||||
Saturday bool `json:"saturday"`
|
||||
// A boolean indicating if the class will meet on each Sunday of the term
|
||||
Sunday bool `json:"sunday"`
|
||||
} `json:"meetingTime"`
|
||||
Term string `json:"term"`
|
||||
}
|
||||
|
||||
// String returns a formatted string representation of the meeting time.
|
||||
func (m *MeetingTimeResponse) String() string {
|
||||
switch m.MeetingTime.MeetingType {
|
||||
case "HB":
|
||||
return fmt.Sprintf("%s\nHybrid %s", m.TimeString(), m.PlaceString())
|
||||
case "H2":
|
||||
return fmt.Sprintf("%s\nHybrid %s", m.TimeString(), m.PlaceString())
|
||||
case "H1":
|
||||
return fmt.Sprintf("%s\nHybrid %s", m.TimeString(), m.PlaceString())
|
||||
case "OS":
|
||||
return fmt.Sprintf("%s\nOnline Only", m.TimeString())
|
||||
case "OA":
|
||||
return "No Time\nOnline Asynchronous"
|
||||
case "OH":
|
||||
return fmt.Sprintf("%s\nOnline Partial", m.TimeString())
|
||||
case "ID":
|
||||
return "To Be Arranged"
|
||||
case "FF":
|
||||
return fmt.Sprintf("%s\n%s", m.TimeString(), m.PlaceString())
|
||||
}
|
||||
|
||||
// TODO: Add error log
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
// TimeString returns a formatted string of the meeting times (e.g., "MWF 1:00PM-2:15PM").
|
||||
func (m *MeetingTimeResponse) TimeString() string {
|
||||
startTime := m.StartTime()
|
||||
endTime := m.EndTime()
|
||||
|
||||
if startTime == nil || endTime == nil {
|
||||
return "???"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s-%s", internal.WeekdaysToString(m.Days()), m.StartTime().String(), m.EndTime().String())
|
||||
}
|
||||
|
||||
// PlaceString returns a formatted string representing the location of the meeting.
|
||||
func (m *MeetingTimeResponse) PlaceString() string {
|
||||
mt := m.MeetingTime
|
||||
|
||||
// TODO: Add format case for partial online classes
|
||||
if mt.Room == "" {
|
||||
return "Online"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s | %s | %s %s", mt.CampusDescription, mt.BuildingDescription, mt.Building, mt.Room)
|
||||
}
|
||||
|
||||
// Days returns a map of weekdays on which the course meets.
|
||||
func (m *MeetingTimeResponse) Days() map[time.Weekday]bool {
|
||||
days := map[time.Weekday]bool{}
|
||||
|
||||
days[time.Monday] = m.MeetingTime.Monday
|
||||
days[time.Tuesday] = m.MeetingTime.Tuesday
|
||||
days[time.Wednesday] = m.MeetingTime.Wednesday
|
||||
days[time.Thursday] = m.MeetingTime.Thursday
|
||||
days[time.Friday] = m.MeetingTime.Friday
|
||||
days[time.Saturday] = m.MeetingTime.Saturday
|
||||
|
||||
return days
|
||||
}
|
||||
|
||||
// ByDay returns a comma-separated string of two-letter day abbreviations for the iCalendar RRule.
|
||||
func (m *MeetingTimeResponse) ByDay() string {
|
||||
days := []string{}
|
||||
|
||||
if m.MeetingTime.Sunday {
|
||||
days = append(days, "SU")
|
||||
}
|
||||
if m.MeetingTime.Monday {
|
||||
days = append(days, "MO")
|
||||
}
|
||||
if m.MeetingTime.Tuesday {
|
||||
days = append(days, "TU")
|
||||
}
|
||||
if m.MeetingTime.Wednesday {
|
||||
days = append(days, "WE")
|
||||
}
|
||||
if m.MeetingTime.Thursday {
|
||||
days = append(days, "TH")
|
||||
}
|
||||
if m.MeetingTime.Friday {
|
||||
days = append(days, "FR")
|
||||
}
|
||||
if m.MeetingTime.Saturday {
|
||||
days = append(days, "SA")
|
||||
}
|
||||
|
||||
return strings.Join(days, ",")
|
||||
}
|
||||
|
||||
const layout = "01/02/2006"
|
||||
|
||||
// StartDay returns the start date of the meeting as a time.Time object.
|
||||
// This method is not cached and will panic if the date cannot be parsed.
|
||||
func (m *MeetingTimeResponse) StartDay() time.Time {
|
||||
t, err := time.Parse(layout, m.MeetingTime.StartDate)
|
||||
if err != nil {
|
||||
log.Panic().Stack().Err(err).Str("raw", m.MeetingTime.StartDate).Msg("Cannot parse start date")
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// EndDay returns the end date of the meeting as a time.Time object.
|
||||
// This method is not cached and will panic if the date cannot be parsed.
|
||||
func (m *MeetingTimeResponse) EndDay() time.Time {
|
||||
t, err := time.Parse(layout, m.MeetingTime.EndDate)
|
||||
if err != nil {
|
||||
log.Panic().Stack().Err(err).Str("raw", m.MeetingTime.EndDate).Msg("Cannot parse end date")
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// StartTime returns the start time of the meeting as a NaiveTime object.
|
||||
// This method is not cached and will panic if the time cannot be parsed.
|
||||
func (m *MeetingTimeResponse) StartTime() *internal.NaiveTime {
|
||||
raw := m.MeetingTime.BeginTime
|
||||
if raw == "" {
|
||||
log.Panic().Stack().Msg("Start time is empty")
|
||||
}
|
||||
|
||||
value, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
log.Panic().Stack().Err(err).Str("raw", raw).Msg("Cannot parse start time integer")
|
||||
}
|
||||
|
||||
return internal.ParseNaiveTime(value)
|
||||
}
|
||||
|
||||
// EndTime returns the end time of the meeting as a NaiveTime object.
|
||||
// This method is not cached and will panic if the time cannot be parsed.
|
||||
func (m *MeetingTimeResponse) EndTime() *internal.NaiveTime {
|
||||
raw := m.MeetingTime.EndTime
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
value, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
log.Panic().Stack().Err(err).Str("raw", raw).Msg("Cannot parse end time integer")
|
||||
}
|
||||
|
||||
return internal.ParseNaiveTime(value)
|
||||
}
|
||||
|
||||
// RRule represents a recurrence rule for an iCalendar event.
|
||||
type RRule struct {
|
||||
Until string
|
||||
ByDay string
|
||||
}
|
||||
|
||||
// RRule converts the meeting time to a struct that satisfies the iCalendar RRule format.
|
||||
func (m *MeetingTimeResponse) RRule() RRule {
|
||||
return RRule{
|
||||
Until: m.EndDay().UTC().Format("20060102T150405Z"),
|
||||
ByDay: m.ByDay(),
|
||||
}
|
||||
}
|
||||
|
||||
// SearchResult represents the result of a course search.
|
||||
type SearchResult struct {
|
||||
Success bool `json:"success"`
|
||||
TotalCount int `json:"totalCount"`
|
||||
PageOffset int `json:"pageOffset"`
|
||||
PageMaxSize int `json:"pageMaxSize"`
|
||||
PathMode string `json:"pathMode"`
|
||||
SearchResultsConfig []struct {
|
||||
Config string `json:"config"`
|
||||
Display string `json:"display"`
|
||||
} `json:"searchResultsConfig"`
|
||||
Data []Course `json:"data"`
|
||||
}
|
||||
|
||||
// Course represents a single course returned from a search.
|
||||
type Course struct {
|
||||
// ID is an internal identifier not used outside of the Banner system.
|
||||
ID int `json:"id"`
|
||||
// Term is the internal identifier for the term this class is in (e.g. 202420).
|
||||
Term string `json:"term"`
|
||||
// TermDesc is the human-readable name of the term this class is in (e.g. Fall 2021).
|
||||
TermDesc string `json:"termDesc"`
|
||||
// CourseReferenceNumber is the unique identifier for a course within a term.
|
||||
CourseReferenceNumber string `json:"courseReferenceNumber"`
|
||||
// PartOfTerm specifies which part of the term the course is in (e.g. B6, B5).
|
||||
PartOfTerm string `json:"partOfTerm"`
|
||||
// CourseNumber is the 4-digit code for the course (e.g. 3743).
|
||||
CourseNumber string `json:"courseNumber"`
|
||||
// Subject is the subject acronym (e.g. CS, AEPI).
|
||||
Subject string `json:"subject"`
|
||||
// SubjectDescription is the full name of the course subject.
|
||||
SubjectDescription string `json:"subjectDescription"`
|
||||
// SequenceNumber is the course section (e.g. 001, 002).
|
||||
SequenceNumber string `json:"sequenceNumber"`
|
||||
CampusDescription string `json:"campusDescription"`
|
||||
// ScheduleTypeDescription is the type of schedule for the course (e.g. Lecture, Seminar).
|
||||
ScheduleTypeDescription string `json:"scheduleTypeDescription"`
|
||||
CourseTitle string `json:"courseTitle"`
|
||||
CreditHours int `json:"creditHours"`
|
||||
// MaximumEnrollment is the maximum number of students that can enroll.
|
||||
MaximumEnrollment int `json:"maximumEnrollment"`
|
||||
Enrollment int `json:"enrollment"`
|
||||
SeatsAvailable int `json:"seatsAvailable"`
|
||||
WaitCapacity int `json:"waitCapacity"`
|
||||
WaitCount int `json:"waitCount"`
|
||||
CrossList *string `json:"crossList"`
|
||||
CrossListCapacity *int `json:"crossListCapacity"`
|
||||
CrossListCount *int `json:"crossListCount"`
|
||||
CrossListAvailable *int `json:"crossListAvailable"`
|
||||
CreditHourHigh *int `json:"creditHourHigh"`
|
||||
CreditHourLow *int `json:"creditHourLow"`
|
||||
CreditHourIndicator *string `json:"creditHourIndicator"`
|
||||
OpenSection bool `json:"openSection"`
|
||||
LinkIdentifier *string `json:"linkIdentifier"`
|
||||
IsSectionLinked bool `json:"isSectionLinked"`
|
||||
// SubjectCourse is the combination of the subject and course number (e.g. CS3443).
|
||||
SubjectCourse string `json:"subjectCourse"`
|
||||
ReservedSeatSummary *string `json:"reservedSeatSummary"`
|
||||
InstructionalMethod string `json:"instructionalMethod"`
|
||||
InstructionalMethodDescription string `json:"instructionalMethodDescription"`
|
||||
SectionAttributes []struct {
|
||||
// Class is an internal API class identifier used by Banner.
|
||||
Class string `json:"class"`
|
||||
CourseReferenceNumber string `json:"courseReferenceNumber"`
|
||||
// Code for the attribute (e.g., UPPR, ZIEP, AIS).
|
||||
Code string `json:"code"`
|
||||
Description string `json:"description"`
|
||||
TermCode string `json:"termCode"`
|
||||
IsZtcAttribute bool `json:"isZTCAttribute"`
|
||||
} `json:"sectionAttributes"`
|
||||
Faculty []FacultyItem `json:"faculty"`
|
||||
MeetingsFaculty []MeetingTimeResponse `json:"meetingsFaculty"`
|
||||
}
|
||||
|
||||
// MarshalBinary implements the encoding.BinaryMarshaler interface.
|
||||
func (course Course) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(course)
|
||||
}
|
||||
16
src/bot/mod.rs
Normal file
16
src/bot/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use poise::serenity_prelude as serenity;
|
||||
pub struct Data {} // User data, which is stored and accessible in all command invocations
|
||||
pub type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
pub type Context<'a> = poise::Context<'a, Data, Error>;
|
||||
|
||||
/// Displays your or another user's account creation date
|
||||
#[poise::command(slash_command, prefix_command)]
|
||||
pub async fn age(
|
||||
ctx: Context<'_>,
|
||||
#[description = "Selected user"] user: Option<serenity::User>,
|
||||
) -> Result<(), Error> {
|
||||
let u = user.as_ref().unwrap_or_else(|| ctx.author());
|
||||
let response = format!("{}'s account was created at {}", u.name, u.created_at());
|
||||
ctx.say(response).await?;
|
||||
Ok(())
|
||||
}
|
||||
0
src/config/mod.rs
Normal file
0
src/config/mod.rs
Normal file
0
src/error.rs
Normal file
0
src/error.rs
Normal file
1
src/lib.rs
Normal file
1
src/lib.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod bot;
|
||||
342
src/main.rs
Normal file
342
src/main.rs
Normal file
@@ -0,0 +1,342 @@
|
||||
use serde::Deserialize;
|
||||
use serenity::all::{ClientBuilder, GatewayIntents};
|
||||
use std::time::Duration;
|
||||
use tokio::{signal, sync::broadcast, task::JoinSet};
|
||||
use tracing::{error, info, warn};
|
||||
use tracing_subscriber::{EnvFilter, FmtSubscriber};
|
||||
|
||||
use crate::bot::{Data, age};
|
||||
use figment::{Figment, providers::Env};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
bot_token: String,
|
||||
database_url: String,
|
||||
redis_url: String,
|
||||
banner_base_url: String,
|
||||
bot_target_guild: u64,
|
||||
bot_app_id: u64,
|
||||
}
|
||||
|
||||
mod bot;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ServiceResult {
|
||||
GracefulShutdown,
|
||||
NormalCompletion,
|
||||
Error(Box<dyn std::error::Error + Send + Sync>),
|
||||
}
|
||||
|
||||
/// Common trait for all services in the application
|
||||
#[async_trait::async_trait]
|
||||
trait Service: Send + Sync {
|
||||
/// The name of the service for logging
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Run the service's main work loop
|
||||
async fn run(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Gracefully shutdown the service
|
||||
async fn shutdown(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||
}
|
||||
|
||||
/// Generic service runner that handles the lifecycle
|
||||
async fn run_service(
|
||||
mut service: Box<dyn Service>,
|
||||
mut shutdown_rx: broadcast::Receiver<()>,
|
||||
) -> ServiceResult {
|
||||
let name = service.name();
|
||||
info!(service = name, "Service started");
|
||||
|
||||
let work = async {
|
||||
match service.run().await {
|
||||
Ok(()) => {
|
||||
warn!(service = name, "Service completed unexpectedly");
|
||||
ServiceResult::NormalCompletion
|
||||
}
|
||||
Err(e) => {
|
||||
error!(service = name, "Service failed: {e}");
|
||||
ServiceResult::Error(e)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
result = work => result,
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!(service = name, "Shutting down...");
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
match service.shutdown().await {
|
||||
Ok(()) => {
|
||||
let elapsed = start_time.elapsed();
|
||||
info!(service = name, "Shutdown completed in {elapsed:.2?}");
|
||||
ServiceResult::GracefulShutdown
|
||||
}
|
||||
Err(e) => {
|
||||
let elapsed = start_time.elapsed();
|
||||
error!(service = name, "Shutdown failed after {elapsed:.2?}: {e}");
|
||||
ServiceResult::Error(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shutdown coordinator for managing graceful shutdown of multiple services
|
||||
struct ShutdownCoordinator {
|
||||
shutdown_tx: broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
impl ShutdownCoordinator {
|
||||
fn new() -> Self {
|
||||
let (shutdown_tx, _) = broadcast::channel(1);
|
||||
Self { shutdown_tx }
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<()> {
|
||||
self.shutdown_tx.subscribe()
|
||||
}
|
||||
|
||||
fn shutdown(&self) {
|
||||
let _ = self.shutdown_tx.send(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Discord bot service implementation
|
||||
struct BotService {
|
||||
client: serenity::Client,
|
||||
shard_manager: std::sync::Arc<serenity::gateway::ShardManager>,
|
||||
}
|
||||
|
||||
impl BotService {
|
||||
fn new(client: serenity::Client) -> Self {
|
||||
let shard_manager = client.shard_manager.clone();
|
||||
Self {
|
||||
client,
|
||||
shard_manager,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Service for BotService {
|
||||
fn name(&self) -> &'static str {
|
||||
"bot"
|
||||
}
|
||||
|
||||
async fn run(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self.client.start().await {
|
||||
Ok(()) => {
|
||||
warn!(service = "bot", "Stopped early.");
|
||||
Err("bot stopped early".into())
|
||||
}
|
||||
Err(e) => {
|
||||
error!(service = "bot", "Error: {e:?}");
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.shard_manager.shutdown_all().await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Dummy service implementation for demonstration
|
||||
struct DummyService {
|
||||
name: &'static str,
|
||||
}
|
||||
|
||||
impl DummyService {
|
||||
fn new(name: &'static str) -> Self {
|
||||
Self { name }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Service for DummyService {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
async fn run(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut counter = 0;
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
counter += 1;
|
||||
info!(service = self.name, "Service heartbeat ({counter})");
|
||||
|
||||
// Simulate service failure after 60 seconds for demo
|
||||
if counter >= 6 {
|
||||
error!(service = self.name, "Service encountered an error");
|
||||
return Err("Service error".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Simulate cleanup work
|
||||
tokio::time::sleep(Duration::from_millis(3500)).await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// Configure logging
|
||||
let filter =
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn,banner=debug"));
|
||||
let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish();
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
||||
let config: Config = Figment::new()
|
||||
.merge(Env::prefixed("APP_"))
|
||||
.extract()
|
||||
.expect("Failed to load config");
|
||||
|
||||
// Configure the client with your Discord bot token in the environment.
|
||||
let intents = GatewayIntents::non_privileged();
|
||||
|
||||
let framework = poise::Framework::builder()
|
||||
.options(poise::FrameworkOptions {
|
||||
commands: vec![age()],
|
||||
..Default::default()
|
||||
})
|
||||
.setup(|ctx, _ready, framework| {
|
||||
Box::pin(async move {
|
||||
poise::builtins::register_globally(ctx, &framework.options().commands).await?;
|
||||
Ok(Data {})
|
||||
})
|
||||
})
|
||||
.build();
|
||||
|
||||
let client = ClientBuilder::new(config.bot_token, intents)
|
||||
.framework(framework)
|
||||
.await
|
||||
.expect("Failed to build client");
|
||||
|
||||
let shutdown_coordinator = ShutdownCoordinator::new();
|
||||
|
||||
// Create services
|
||||
let bot_service = Box::new(BotService::new(client));
|
||||
let dummy_service = Box::new(DummyService::new("background"));
|
||||
|
||||
// Start services using the unified runner
|
||||
let bot_handle = {
|
||||
let shutdown_rx = shutdown_coordinator.subscribe();
|
||||
tokio::spawn(run_service(bot_service, shutdown_rx))
|
||||
};
|
||||
|
||||
let dummy_handle = {
|
||||
let shutdown_rx = shutdown_coordinator.subscribe();
|
||||
tokio::spawn(run_service(dummy_service, shutdown_rx))
|
||||
};
|
||||
|
||||
// Set up signal handling
|
||||
let signal_handle = {
|
||||
let coordinator = shutdown_coordinator.shutdown_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("Failed to install CTRL+C signal handler");
|
||||
info!("Received CTRL+C, initiating shutdown...");
|
||||
let _ = coordinator.send(());
|
||||
ServiceResult::GracefulShutdown
|
||||
})
|
||||
};
|
||||
|
||||
// Put all services in a JoinSet for unified handling
|
||||
let mut services = JoinSet::new();
|
||||
services.spawn(bot_handle);
|
||||
services.spawn(dummy_handle);
|
||||
services.spawn(signal_handle);
|
||||
|
||||
// Wait for any service to complete or signal
|
||||
let mut exit_code = 0;
|
||||
let first_completion = services.join_next().await;
|
||||
|
||||
let service_result = match first_completion {
|
||||
Some(Ok(Ok(service_result))) => {
|
||||
// A service completed successfully
|
||||
match &service_result {
|
||||
ServiceResult::GracefulShutdown => {
|
||||
// This means CTRL+C was pressed
|
||||
}
|
||||
ServiceResult::NormalCompletion => {
|
||||
warn!("A service completed unexpectedly");
|
||||
exit_code = 1;
|
||||
}
|
||||
ServiceResult::Error(e) => {
|
||||
error!("Service failure: {e}");
|
||||
exit_code = 1;
|
||||
}
|
||||
}
|
||||
service_result
|
||||
}
|
||||
Some(Ok(Err(e))) => {
|
||||
error!("Service task panicked: {e}");
|
||||
exit_code = 1;
|
||||
ServiceResult::Error("Task panic".into())
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("JoinSet error: {e}");
|
||||
exit_code = 1;
|
||||
ServiceResult::Error("JoinSet error".into())
|
||||
}
|
||||
None => {
|
||||
warn!("No services running");
|
||||
exit_code = 1;
|
||||
ServiceResult::Error("No services".into())
|
||||
}
|
||||
};
|
||||
|
||||
// Signal all services to shut down
|
||||
shutdown_coordinator.shutdown();
|
||||
|
||||
// Wait for graceful shutdown with timeout
|
||||
let remaining_count = services.len();
|
||||
if remaining_count > 0 {
|
||||
info!("Waiting for {remaining_count} remaining services to shutdown (5s timeout)...");
|
||||
let shutdown_result = tokio::time::timeout(Duration::from_secs(5), async {
|
||||
while let Some(result) = services.join_next().await {
|
||||
match result {
|
||||
Ok(Ok(ServiceResult::GracefulShutdown)) => {
|
||||
// Service shutdown logged by the service itself
|
||||
}
|
||||
Ok(Ok(ServiceResult::NormalCompletion)) => {
|
||||
warn!("Service completed normally during shutdown");
|
||||
}
|
||||
Ok(Ok(ServiceResult::Error(e))) => {
|
||||
error!("Service error during shutdown: {e}");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("Service panic during shutdown: {e}");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Service join error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match shutdown_result {
|
||||
Ok(()) => {
|
||||
info!("All services shutdown completed");
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Shutdown timeout - some services may not have completed");
|
||||
exit_code = if exit_code == 0 { 2 } else { exit_code };
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("No remaining services to shutdown");
|
||||
}
|
||||
|
||||
info!("Application shutdown complete (exit code: {})", exit_code);
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"banner/internal/config"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetCurrentTerm(t *testing.T) {
|
||||
// Initialize location for testing
|
||||
loc, _ := time.LoadLocation("America/Chicago")
|
||||
|
||||
// Use current year to avoid issues with global state
|
||||
currentYear := uint16(time.Now().Year())
|
||||
ranges := config.GetYearDayRange(loc, currentYear)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
date time.Time
|
||||
expectedCurrent *config.Term
|
||||
expectedNext *config.Term
|
||||
}{
|
||||
{
|
||||
name: "Spring term",
|
||||
date: time.Date(int(currentYear), 3, 15, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: &config.Term{Year: currentYear, Season: config.Spring},
|
||||
expectedNext: &config.Term{Year: currentYear, Season: config.Summer},
|
||||
},
|
||||
{
|
||||
name: "Summer term",
|
||||
date: time.Date(int(currentYear), 6, 15, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: &config.Term{Year: currentYear, Season: config.Summer},
|
||||
expectedNext: &config.Term{Year: currentYear, Season: config.Fall},
|
||||
},
|
||||
{
|
||||
name: "Fall term",
|
||||
date: time.Date(int(currentYear), 9, 15, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: &config.Term{Year: currentYear + 1, Season: config.Fall},
|
||||
expectedNext: nil,
|
||||
},
|
||||
{
|
||||
name: "Between Spring and Summer",
|
||||
date: time.Date(int(currentYear), 5, 20, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: nil,
|
||||
expectedNext: &config.Term{Year: currentYear, Season: config.Summer},
|
||||
},
|
||||
{
|
||||
name: "Between Summer and Fall",
|
||||
date: time.Date(int(currentYear), 8, 16, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: nil,
|
||||
expectedNext: &config.Term{Year: currentYear + 1, Season: config.Fall},
|
||||
},
|
||||
{
|
||||
name: "Between Fall and Spring",
|
||||
date: time.Date(int(currentYear), 12, 15, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: nil,
|
||||
expectedNext: &config.Term{Year: currentYear + 1, Season: config.Spring},
|
||||
},
|
||||
{
|
||||
name: "Early January before Spring",
|
||||
date: time.Date(int(currentYear), 1, 10, 12, 0, 0, 0, loc),
|
||||
expectedCurrent: nil,
|
||||
expectedNext: &config.Term{Year: currentYear, Season: config.Spring},
|
||||
},
|
||||
{
|
||||
name: "Spring start date",
|
||||
date: time.Date(int(currentYear), 1, 14, 0, 0, 0, 0, loc),
|
||||
expectedCurrent: &config.Term{Year: currentYear, Season: config.Spring},
|
||||
expectedNext: &config.Term{Year: currentYear, Season: config.Summer},
|
||||
},
|
||||
{
|
||||
name: "Summer start date",
|
||||
date: time.Date(int(currentYear), 5, 25, 0, 0, 0, 0, loc),
|
||||
expectedCurrent: &config.Term{Year: currentYear, Season: config.Summer},
|
||||
expectedNext: &config.Term{Year: currentYear, Season: config.Fall},
|
||||
},
|
||||
{
|
||||
name: "Fall start date",
|
||||
date: time.Date(int(currentYear), 8, 18, 0, 0, 0, 0, loc),
|
||||
expectedCurrent: &config.Term{Year: currentYear + 1, Season: config.Fall},
|
||||
expectedNext: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
current, next := config.GetCurrentTerm(ranges, tt.date)
|
||||
|
||||
if !termsEqual(current, tt.expectedCurrent) {
|
||||
t.Errorf("GetCurrentTerm() current = %v, want %v", current, tt.expectedCurrent)
|
||||
}
|
||||
|
||||
if !termsEqual(next, tt.expectedNext) {
|
||||
t.Errorf("GetCurrentTerm() next = %v, want %v", next, tt.expectedNext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetYearDayRange(t *testing.T) {
|
||||
loc, _ := time.LoadLocation("America/Chicago")
|
||||
|
||||
ranges := config.GetYearDayRange(loc, 2024)
|
||||
|
||||
// Verify Spring range (Jan 14 to May 1)
|
||||
expectedSpringStart := time.Date(2024, 1, 14, 0, 0, 0, 0, loc).YearDay()
|
||||
expectedSpringEnd := time.Date(2024, 5, 1, 0, 0, 0, 0, loc).YearDay()
|
||||
|
||||
if ranges.Spring.Start != uint16(expectedSpringStart) {
|
||||
t.Errorf("Spring start = %d, want %d", ranges.Spring.Start, expectedSpringStart)
|
||||
}
|
||||
if ranges.Spring.End != uint16(expectedSpringEnd) {
|
||||
t.Errorf("Spring end = %d, want %d", ranges.Spring.End, expectedSpringEnd)
|
||||
}
|
||||
|
||||
// Verify Summer range (May 25 to Aug 15)
|
||||
expectedSummerStart := time.Date(2024, 5, 25, 0, 0, 0, 0, loc).YearDay()
|
||||
expectedSummerEnd := time.Date(2024, 8, 15, 0, 0, 0, 0, loc).YearDay()
|
||||
|
||||
if ranges.Summer.Start != uint16(expectedSummerStart) {
|
||||
t.Errorf("Summer start = %d, want %d", ranges.Summer.Start, expectedSummerStart)
|
||||
}
|
||||
if ranges.Summer.End != uint16(expectedSummerEnd) {
|
||||
t.Errorf("Summer end = %d, want %d", ranges.Summer.End, expectedSummerEnd)
|
||||
}
|
||||
|
||||
// Verify Fall range (Aug 18 to Dec 10)
|
||||
expectedFallStart := time.Date(2024, 8, 18, 0, 0, 0, 0, loc).YearDay()
|
||||
expectedFallEnd := time.Date(2024, 12, 10, 0, 0, 0, 0, loc).YearDay()
|
||||
|
||||
if ranges.Fall.Start != uint16(expectedFallStart) {
|
||||
t.Errorf("Fall start = %d, want %d", ranges.Fall.Start, expectedFallStart)
|
||||
}
|
||||
if ranges.Fall.End != uint16(expectedFallEnd) {
|
||||
t.Errorf("Fall end = %d, want %d", ranges.Fall.End, expectedFallEnd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTerm(t *testing.T) {
|
||||
tests := []struct {
|
||||
code string
|
||||
expected config.Term
|
||||
}{
|
||||
{"202410", config.Term{Year: 2024, Season: config.Fall}},
|
||||
{"202420", config.Term{Year: 2024, Season: config.Spring}},
|
||||
{"202430", config.Term{Year: 2024, Season: config.Summer}},
|
||||
{"202510", config.Term{Year: 2025, Season: config.Fall}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.code, func(t *testing.T) {
|
||||
result := config.ParseTerm(tt.code)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ParseTerm(%s) = %v, want %v", tt.code, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTermToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
term config.Term
|
||||
expected string
|
||||
}{
|
||||
{config.Term{Year: 2024, Season: config.Fall}, "202410"},
|
||||
{config.Term{Year: 2024, Season: config.Spring}, "202420"},
|
||||
{config.Term{Year: 2024, Season: config.Summer}, "202430"},
|
||||
{config.Term{Year: 2025, Season: config.Fall}, "202510"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := tt.term.ToString()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Term{Year: %d, Season: %d}.ToString() = %s, want %s",
|
||||
tt.term.Year, tt.term.Season, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTerm(t *testing.T) {
|
||||
loc, _ := time.LoadLocation("America/Chicago")
|
||||
ranges := config.GetYearDayRange(loc, 2024)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
date time.Time
|
||||
expected config.Term
|
||||
}{
|
||||
{
|
||||
name: "During Spring term",
|
||||
date: time.Date(2024, 3, 15, 12, 0, 0, 0, loc),
|
||||
expected: config.Term{Year: 2024, Season: config.Spring},
|
||||
},
|
||||
{
|
||||
name: "Between terms - returns next term",
|
||||
date: time.Date(2024, 5, 20, 12, 0, 0, 0, loc),
|
||||
expected: config.Term{Year: 2024, Season: config.Summer},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
currentTerm, nextTerm := config.GetCurrentTerm(ranges, tt.date)
|
||||
var result config.Term
|
||||
if currentTerm == nil {
|
||||
result = *nextTerm
|
||||
} else {
|
||||
result = *currentTerm
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("DefaultTerm() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare terms, handling nil cases
|
||||
func termsEqual(a, b *config.Term) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
Reference in New Issue
Block a user