Compare commits

...

37 Commits

Author SHA1 Message Date
752c855dec chore: drop env prefixed config vars 2025-09-12 22:39:32 -05:00
14b02df8f4 feat: much better JSON logging, project-wide logging improvements, better use of debug/trace levels, field attributes 2025-09-12 22:01:14 -05:00
00cb209052 fix: disable poor error snippet 2025-09-12 21:40:07 -05:00
dfc05a2789 feat: setup rate limiter middleware & config 2025-09-12 21:12:06 -05:00
fe798e1867 fix: avoid COPY of non existent dir, add .dockerignore 2025-09-12 20:57:33 -05:00
39688f800f chore: update Dockerfile rust to 1.89.0 2025-09-12 20:53:24 -05:00
b2b4bb67f0 chore: rustfmt 2025-09-12 20:52:07 -05:00
e5d8cec2d6 refactor: reorganize banner api files, fix clippy lints, reformat 2025-09-12 20:50:47 -05:00
e9a0558535 feat: asynchronous, rate limited term session acquisition 2025-09-12 20:35:12 -05:00
353c36bcf2 feat: 'search' example binary 2025-09-12 20:12:41 -05:00
2f853a7de9 feat: middleware headers, fix concurrent session cookies issue, middleware headers, invalid session details 2025-09-12 20:12:12 -05:00
dd212c3239 chore: update dependencies, add sqlx 'macros', add futures, add 'http' (explicit) 2025-09-12 20:11:13 -05:00
8ff3a18c3e feat: Dockerfile 2025-09-01 00:47:26 -05:00
43647096e9 feat: scraper system 2025-09-01 00:46:38 -05:00
1bdbd1d6d6 chore: remove unused dependencies 2025-09-01 00:26:20 -05:00
23be6035ed feat: much better, smarter session acquisition 2025-08-31 15:34:49 -05:00
139e4aa635 feat: translate over to sqlx, remove diesel 2025-08-31 15:34:49 -05:00
677bb05b87 chore: update & sort dependencies, add sqlx, remove 'migrations' 2025-08-29 12:52:46 -05:00
f2bd02c970 chore: add bacon config 2025-08-29 12:10:57 -05:00
8cdf969a53 feat: command logging, explicit builtin command error handler 2025-08-29 12:10:57 -05:00
4764d48ac9 feat: move scraper into separate module, begin building data models 2025-08-29 11:07:46 -05:00
e734e40347 feat: setup diesel & schema, course with metrics/audit tables 2025-08-27 18:57:43 -05:00
c7117f14a3 feat: smart day string, terse refactor and use types properly, work on unimplemented commands lightly, util modules, 2025-08-27 13:46:41 -05:00
cb8a595326 chore: solve lints, improve formatting 2025-08-27 12:43:43 -05:00
ac70306c04 feat: improve logging, solve lints, improve implementations, remove unused code, standardize things 2025-08-27 12:43:43 -05:00
9972357cf6 feat: implement simple web service, improve ServiceManager encapsulation 2025-08-27 11:58:57 -05:00
2ec899cf25 feat: by CRN querying, redis caching, fixed deserialization, gcal integration 2025-08-27 11:12:08 -05:00
ede064be87 feat: add current term identification, term point state machine 2025-08-27 02:36:59 -05:00
a17bcf0247 fix: broken recurrence, enhanced handling, simpler/terse form 2025-08-27 02:36:59 -05:00
c529bf9727 feat: sort meeting times in gcal command 2025-08-27 00:23:38 -05:00
5ace08327d refactor: clean up MeetingScheduleInfo methods and enhance Term season handling 2025-08-27 00:12:15 -05:00
a01a30d047 feat: continue work on gcal, better meetings schedule types 2025-08-26 23:57:06 -05:00
31ab29c2f1 feat!: first pass re-implementation of banner, gcal command 2025-08-26 21:40:18 -05:00
5018ad0d31 chore: remove dummy service 2025-08-26 19:16:26 -05:00
87100a57d5 feat: service manager for coordination, configureable smart graceful shutdown timeout 2025-08-26 19:16:26 -05:00
cff672b30a feat: use anyhow, refactor services & coordinator out of main.rs 2025-08-26 19:16:26 -05:00
d4c55a3fd8 feat!: begin rust rewrite
service scheduling, configs, all dependencies, tracing, graceful
shutdown, concurrency
2025-08-26 19:16:26 -05:00
69 changed files with 9482 additions and 3682 deletions

73
.dockerignore Normal file
View File

@@ -0,0 +1,73 @@
# Build artifacts
target/
**/target/
# Development files
.env
.env.local
.env.*.local
# IDE and editor files
.vscode/
.idea/
*.swp
*.swo
*~
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Git
.git/
.gitignore
# Documentation
README.md
docs/
*.md
# Go files (since this is a Rust project)
go/
# Database migrations (if not needed at runtime)
migrations/
diesel_migrations/
diesel.toml
# Development configuration
bacon.toml
.cargo/config.toml
# Logs
*.log
logs/
# Temporary files
tmp/
temp/
*.tmp
# Test files
tests/
**/tests/
*_test.rs
*_tests.rs
# Coverage reports
coverage/
*.gcov
*.gcno
*.gcda
# Profiling
*.prof
# Backup files
*.bak
*.backup

12
.gitignore vendored
View File

@@ -1,10 +1,4 @@
.env
cover.cov
/banner
.*.go
dumps/
js/
.vscode/
*.prof
.task/
bin/
/target
/go/
.cargo/config.toml

4312
Cargo.lock generated Normal file
View File

File diff suppressed because it is too large Load Diff

47
Cargo.toml Normal file
View File

@@ -0,0 +1,47 @@
[package]
name = "banner"
version = "0.1.0"
edition = "2024"
default-run = "banner"
[dependencies]
anyhow = "1.0.99"
async-trait = "0.1"
axum = "0.8.4"
bitflags = { version = "2.9.4", features = ["serde"] }
chrono = { version = "0.4.42", features = ["serde"] }
compile-time = "0.2.0"
cookie = "0.18.1"
dashmap = "6.1.0"
dotenvy = "0.15.7"
figment = { version = "0.10.19", features = ["toml", "env"] }
fundu = "2.0.1"
futures = "0.3"
http = "1.3.1"
poise = "0.6.1"
rand = "0.9.2"
redis = { version = "0.32.5", features = ["tokio-comp", "r2d2"] }
regex = "1.10"
reqwest = { version = "0.12.23", features = ["json", "cookies"] }
reqwest-middleware = { version = "0.4.2", features = ["json"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.143"
serenity = { version = "0.12.4", features = ["rustls_backend"] }
sqlx = { version = "0.8.6", features = [
"runtime-tokio-rustls",
"postgres",
"chrono",
"json",
"macros",
] }
thiserror = "2.0.16"
time = "0.3.41"
tokio = { version = "1.47.1", features = ["full"] }
tl = "0.7.8"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
url = "2.5"
governor = "0.10.1"
once_cell = "1.21.3"
[dev-dependencies]

76
Dockerfile Normal file
View File

@@ -0,0 +1,76 @@
# Build Stage
ARG RUST_VERSION=1.89.0
FROM rust:${RUST_VERSION}-bookworm AS builder
# Install build dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /usr/src
RUN USER=root cargo new --bin banner
WORKDIR /usr/src/banner
# Copy dependency files for better layer caching
COPY ./Cargo.toml ./Cargo.lock* ./
# Build empty app with downloaded dependencies to produce a stable image layer for next build
RUN cargo build --release
# Build web app with own code
RUN rm src/*.rs
COPY ./src ./src
RUN rm ./target/release/deps/banner*
RUN cargo build --release
# Strip the binary to reduce size
RUN strip target/release/banner
# Runtime Stage - Debian slim for glibc compatibility
FROM debian:12-slim
ARG APP=/usr/src/app
ARG APP_USER=appuser
ARG UID=1000
ARG GID=1000
# Install runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
tzdata \
wget \
&& rm -rf /var/lib/apt/lists/*
ARG TZ=Etc/UTC
ENV TZ=${TZ}
# Create user with specific UID/GID
RUN addgroup --gid $GID $APP_USER \
&& adduser --uid $UID --disabled-password --gecos "" --ingroup $APP_USER $APP_USER \
&& mkdir -p ${APP}
# Copy application files
COPY --from=builder --chown=$APP_USER:$APP_USER /usr/src/banner/target/release/banner ${APP}/banner
# Set proper permissions
RUN chmod +x ${APP}/banner
USER $APP_USER
WORKDIR ${APP}
# Build-time arg for PORT, default to 8000
ARG PORT=8000
# Runtime environment var for PORT, default to build-time arg
ENV PORT=${PORT}
EXPOSE ${PORT}
# Add health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:${PORT}/health || exit 1
# Can be explicitly overriden with different hosts & ports
ENV HOSTS=0.0.0.0,[::]
# Implicitly uses PORT environment variable
CMD ["sh", "-c", "exec ./banner --server ${HOSTS}"]

View File

@@ -1,46 +0,0 @@
version: "3"
tasks:
build:
desc: Build the application
cmds:
- go build -o bin/banner ./cmd/banner
sources:
- ./cmd/banner/**/*.go
- ./internal/**/*.go
generates:
- bin/banner
run:
desc: Run the application
cmds:
- go run ./cmd/banner
deps: [build]
test:
desc: Run tests
cmds:
- go test ./tests/...
env:
ENVIRONMENT: test
test-coverage:
desc: Run tests with coverage
cmds:
- go test -coverpkg=./internal/... -cover ./tests/...
env:
ENVIRONMENT: test
clean:
desc: Clean build artifacts
cmds:
- rm -rf bin/
- go clean -cache
- go clean -modcache
dev:
desc: Run in development mode
cmds:
- go run ./cmd/banner
env:
ENVIRONMENT: development

92
bacon.toml Normal file
View File

@@ -0,0 +1,92 @@
# This is a configuration file for the bacon tool
#
# Complete help on configuration: https://dystroy.org/bacon/config/
#
# You may check the current default at
# https://github.com/Canop/bacon/blob/main/defaults/default-bacon.toml
default_job = "check"
env.CARGO_TERM_COLOR = "always"
[jobs.check]
command = ["cargo", "check"]
need_stdout = false
[jobs.check-all]
command = ["cargo", "check", "--all-targets"]
need_stdout = false
# Run clippy on the default target
[jobs.clippy]
command = ["cargo", "clippy"]
need_stdout = false
# Run clippy on all targets
# To disable some lints, you may change the job this way:
# [jobs.clippy-all]
# command = [
# "cargo", "clippy",
# "--all-targets",
# "--",
# "-A", "clippy::bool_to_int_with_if",
# "-A", "clippy::collapsible_if",
# "-A", "clippy::derive_partial_eq_without_eq",
# ]
# need_stdout = false
[jobs.clippy-all]
command = ["cargo", "clippy", "--all-targets"]
need_stdout = false
# This job lets you run
# - all tests: bacon test
# - a specific test: bacon test -- config::test_default_files
# - the tests of a package: bacon test -- -- -p config
[jobs.test]
command = ["cargo", "test"]
need_stdout = true
[jobs.nextest]
command = [
"cargo", "nextest", "run",
"--hide-progress-bar", "--failure-output", "final"
]
need_stdout = true
analyzer = "nextest"
[jobs.doc]
command = ["cargo", "doc", "--no-deps"]
need_stdout = false
# If the doc compiles, then it opens in your browser and bacon switches
# to the previous job
[jobs.doc-open]
command = ["cargo", "doc", "--no-deps", "--open"]
need_stdout = false
on_success = "back" # so that we don't open the browser at each change
[jobs.run]
command = [
"cargo", "run",
]
need_stdout = true
allow_warnings = true
background = false
on_change_strategy = "kill_then_restart"
# kill = ["pkill", "-TERM", "-P"]'
# This parameterized job runs the example of your choice, as soon
# as the code compiles.
# Call it as
# bacon ex -- my-example
[jobs.ex]
command = ["cargo", "run", "--example"]
need_stdout = true
allow_warnings = true
# You may define here keybindings that would be specific to
# a project, for example a shortcut to launch a specific job.
# Shortcuts to internal functions (scrolling, toggling, etc.)
# should go in your personal global prefs.toml file instead.
[keybindings]
# alt-m = "job:my-job"
c = "job:clippy-all" # comment this to have 'c' run clippy on only the default target

View File

@@ -1,299 +0,0 @@
// Package main is the entry point for the banner application.
package main
import (
"context"
"flag"
"net/http"
"net/http/cookiejar"
_ "net/http/pprof"
"os"
"os/signal"
"strings"
"syscall"
"time"
_ "time/tzdata"
"github.com/bwmarrin/discordgo"
"github.com/joho/godotenv"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/rs/zerolog/pkgerrors"
"github.com/samber/lo"
"resty.dev/v3"
"banner/internal"
"banner/internal/api"
"banner/internal/bot"
"banner/internal/config"
)
var (
Session *discordgo.Session
)
const (
ICalTimestampFormatUtc = "20060102T150405Z"
ICalTimestampFormatLocal = "20060102T150405"
CentralTimezoneName = "America/Chicago"
)
func init() {
// Load environment variables
if err := godotenv.Load(); err != nil {
log.Debug().Err(err).Msg("Error loading .env file")
}
// Set zerolog's timestamp function to use the central timezone
zerolog.TimestampFunc = func() time.Time {
// TODO: Move this to config
loc, err := time.LoadLocation(CentralTimezoneName)
if err != nil {
panic(err)
}
return time.Now().In(loc)
}
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
// Use the custom console writer if we're in development
isDevelopment := internal.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT")
if isDevelopment == "" {
isDevelopment = "development"
}
if isDevelopment == "development" {
log.Logger = zerolog.New(config.NewConsoleWriter()).With().Timestamp().Logger()
} else {
log.Logger = zerolog.New(config.LogSplitter{Std: os.Stdout, Err: os.Stderr}).With().Timestamp().Logger()
}
log.Debug().Str("environment", isDevelopment).Msg("Loggers Setup")
// Set discordgo's logger to use zerolog
discordgo.Logger = internal.DiscordGoLogger
}
// initRedis initializes the Redis client and pings the server to ensure a connection.
func initRedis(cfg *config.Config) {
// Setup redis
redisUrl := internal.GetFirstEnv("REDIS_URL", "REDIS_PRIVATE_URL")
if redisUrl == "" {
log.Fatal().Stack().Msg("REDIS_URL/REDIS_PRIVATE_URL not set")
}
// Parse URL and create client
options, err := redis.ParseURL(redisUrl)
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot parse redis url")
}
kv := redis.NewClient(options)
cfg.SetRedis(kv)
var lastPingErr error
pingCount := 0 // Nth ping being attempted
totalPings := 5 // Total pings to attempt
// Wait for private networking to kick in (production only)
if !cfg.IsDevelopment {
time.Sleep(250 * time.Millisecond)
}
// Test the redis instance, try to ping every 2 seconds 5 times, otherwise panic
for {
pingCount++
if pingCount > totalPings {
log.Fatal().Stack().Err(lastPingErr).Msg("Reached ping limit while trying to connect")
}
// Ping redis
pong, err := cfg.KV.Ping(cfg.Ctx).Result()
// Failed; log error and wait 2 seconds
if err != nil {
lastPingErr = err
log.Warn().Err(err).Int("pings", pingCount).Int("remaining", totalPings-pingCount).Msg("Cannot ping redis")
time.Sleep(2 * time.Second)
continue
}
log.Debug().Str("ping", pong).Msg("Redis connection successful")
break
}
}
func main() {
flag.Parse()
cfg, err := config.New()
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot create config")
}
// Try to grab the environment variable, or default to development
environment := internal.GetFirstEnv("ENVIRONMENT", "RAILWAY_ENVIRONMENT")
if environment == "" {
environment = "development"
}
cfg.SetEnvironment(environment)
initRedis(cfg)
if strings.EqualFold(os.Getenv("PPROF_ENABLE"), "true") {
// Start pprof server with graceful shutdown
go func() {
port := os.Getenv("PORT")
log.Info().Str("port", port).Msg("Starting pprof server")
server := &http.Server{
Addr: ":" + port,
}
// Start server in a separate goroutine
go func() {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatal().Stack().Err(err).Msg("Cannot start pprof server")
}
}()
// Wait for context cancellation and then shutdown
<-cfg.Ctx.Done()
log.Info().Msg("Shutting down pprof server")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Pprof server forced to shutdown")
}
}()
}
// Create cookie jar
cookies, err := cookiejar.New(nil)
if err != nil {
log.Err(err).Msg("Cannot create cookie jar")
}
// Create Resty client with timeout and cookie jar
baseURL := os.Getenv("BANNER_BASE_URL")
client := resty.New().
SetBaseURL(baseURL).
SetTimeout(30*time.Second).
SetCookieJar(cookies).
SetHeader("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36").
AddResponseMiddleware(api.SessionMiddleware)
cfg.SetClient(client)
cfg.SetBaseURL(baseURL)
apiInstance := api.New(cfg)
apiInstance.Setup()
// Create discord session
session, err := discordgo.New("Bot " + os.Getenv("BOT_TOKEN"))
if err != nil {
log.Err(err).Msg("Invalid bot parameters")
}
botInstance := bot.New(session, apiInstance, cfg)
botInstance.RegisterHandlers()
// Open discord session
session.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) {
log.Info().Str("username", r.User.Username).Str("discriminator", r.User.Discriminator).Str("id", r.User.ID).Str("session", s.State.SessionID).Msg("Bot is logged in")
})
err = session.Open()
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot open the session")
}
// Setup command handlers
// Register commands with discord
arr := zerolog.Arr()
lo.ForEach(bot.CommandDefinitions, func(cmd *discordgo.ApplicationCommand, _ int) {
arr.Str(cmd.Name)
})
log.Info().Array("commands", arr).Msg("Registering commands")
// In development, use test server, otherwise empty (global) for command registration
guildTarget := ""
if cfg.IsDevelopment {
guildTarget = os.Getenv("BOT_TARGET_GUILD")
}
// Register commands
existingCommands, err := session.ApplicationCommands(session.State.User.ID, guildTarget)
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot get existing commands")
}
newCommands, err := session.ApplicationCommandBulkOverwrite(session.State.User.ID, guildTarget, bot.CommandDefinitions)
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot register commands")
}
// Compare existing commands with new commands
for _, newCommand := range newCommands {
existingCommand, found := lo.Find(existingCommands, func(cmd *discordgo.ApplicationCommand) bool {
return cmd.Name == newCommand.Name
})
// New command
if !found {
log.Info().Str("commandName", newCommand.Name).Msg("Registered new command")
continue
}
// Compare versions
if newCommand.Version != existingCommand.Version {
log.Info().Str("commandName", newCommand.Name).
Str("oldVersion", existingCommand.Version).Str("newVersion", newCommand.Version).
Msg("Command Updated")
}
}
// Fetch terms on startup
err = apiInstance.TryReloadTerms()
if err != nil {
log.Fatal().Stack().Err(err).Msg("Cannot fetch terms on startup")
}
// Launch a goroutine to scrape the banner system periodically
go func() {
ticker := time.NewTicker(3 * time.Minute)
defer ticker.Stop()
for {
select {
case <-cfg.Ctx.Done():
log.Info().Msg("Periodic scraper stopped due to context cancellation")
return
case <-ticker.C:
err := apiInstance.Scrape()
if err != nil {
log.Err(err).Stack().Msg("Periodic Scrape Failed")
}
}
}
}()
// Close session, ensure Resty client closes
defer session.Close()
defer client.Close()
// Setup signal handler channel
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt) // Ctrl+C signal
signal.Notify(stop, syscall.SIGTERM) // Container stop signal
// Wait for signal (indefinite)
closingSignal := <-stop
botInstance.SetClosing() // TODO: Switch to atomic lock with forced close after 10 seconds
// Cancel the context to signal all operations to stop
cfg.CancelFunc()
// Defers are called after this
log.Warn().Str("signal", closingSignal.String()).Msg("Gracefully shutting down")
}

9
diesel.toml Normal file
View File

@@ -0,0 +1,9 @@
# For documentation on how to configure this file,
# see https://diesel.rs/guides/configuring-diesel-cli
[print_schema]
file = "src/data/schema.rs"
custom_type_derives = ["diesel::query_builder::QueryId", "Clone"]
[migrations_directory]
dir = "migrations"

27
go.mod
View File

@@ -1,27 +0,0 @@
module banner
go 1.24.0
toolchain go1.24.2
require (
github.com/bwmarrin/discordgo v0.29.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/redis/go-redis/v9 v9.12.1
github.com/rs/zerolog v1.34.0
github.com/samber/lo v1.51.0
resty.dev/v3 v3.0.0-beta.3
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
)

52
go.sum
View File

@@ -1,52 +0,0 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg=
github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/samber/lo v1.51.0 h1:kysRYLbHy/MB7kQZf5DSN50JHmMsNEdeY24VzJFu7wI=
github.com/samber/lo v1.51.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
resty.dev/v3 v3.0.0-beta.3 h1:3kEwzEgCnnS6Ob4Emlk94t+I/gClyoah7SnNi67lt+E=
resty.dev/v3 v3.0.0-beta.3/go.mod h1:OgkqiPvTDtOuV4MGZuUDhwOpkY8enjOsjjMzeOHefy4=

View File

@@ -1,491 +0,0 @@
package api
import (
"banner/internal"
"banner/internal/config"
"banner/internal/models"
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"resty.dev/v3"
)
// API provides a client for interacting with the Banner API.
type API struct {
config *config.Config
}
// New creates a new API client with the given configuration.
func New(config *config.Config) *API {
return &API{config: config}
}
var (
latestSession string
sessionTime time.Time
expiryTime = 25 * time.Minute
)
// SessionMiddleware creates a Resty middleware that resets the session timer on each successful Banner API call.
func SessionMiddleware(_ *resty.Client, r *resty.Response) error {
// log.Debug().Str("url", r.Request.RawRequest.URL.Path).Msg("Session middleware")
// Reset session timer on successful requests to Banner API endpoints
if r.IsSuccess() && strings.HasPrefix(r.Request.RawRequest.URL.Path, "StudentRegistrationSsb/ssb/classSearch/") {
// Only reset the session time if the session is still valid
if time.Since(sessionTime) <= expiryTime {
sessionTime = time.Now()
}
}
return nil
}
// GenerateSession generates a new session ID for use with the Banner API.
// This function should not be used directly; use EnsureSession instead.
func GenerateSession() string {
return internal.RandomString(5) + internal.Nonce()
}
// DefaultTerm returns the default term, which is the current term if it exists, otherwise the next term.
func (a *API) DefaultTerm(t time.Time) config.Term {
currentTerm, nextTerm := config.GetCurrentTerm(*a.config.SeasonRanges, t)
if currentTerm == nil {
return *nextTerm
}
return *currentTerm
}
var terms []BannerTerm
var lastTermUpdate time.Time
// TryReloadTerms attempts to reload the terms if they are not loaded or if the last update was more than 24 hours ago.
func (a *API) TryReloadTerms() error {
if len(terms) > 0 && time.Since(lastTermUpdate) < 24*time.Hour {
return nil
}
// Load the terms
var err error
terms, err = a.GetTerms("", 1, 100)
if err != nil {
return fmt.Errorf("failed to load terms: %w", err)
}
lastTermUpdate = time.Now()
return nil
}
// IsTermArchived checks if the given term is archived (view only).
//
// TODO: Add error handling for when a term does not exist.
func (a *API) IsTermArchived(term string) bool {
// Ensure the terms are loaded
err := a.TryReloadTerms()
if err != nil {
log.Err(err).Stack().Msg("Failed to reload terms")
return true
}
// Check if the term is in the list of terms
bannerTerm, exists := lo.Find(terms, func(t BannerTerm) bool {
return t.Code == term
})
if !exists {
log.Warn().Str("term", term).Msg("Term does not exist")
return true
}
return bannerTerm.Archived()
}
// EnsureSession ensures that a valid session is available, creating one if necessary.
func (a *API) EnsureSession() string {
if latestSession == "" || time.Since(sessionTime) >= expiryTime {
latestSession = GenerateSession()
sessionTime = time.Now()
}
return latestSession
}
// Pair represents a key-value pair from the Banner API.
type Pair struct {
Code string `json:"code"`
Description string `json:"description"`
}
// BannerTerm represents a term in the Banner system.
type BannerTerm Pair
// Instructor represents an instructor in the Banner system.
type Instructor Pair
// Archived returns true if the term is in an archival (view-only) state.
func (term BannerTerm) Archived() bool {
return strings.Contains(term.Description, "View Only")
}
// GetTerms retrieves a list of terms from the Banner API.
// The page number must be at least 1.
func (a *API) GetTerms(search string, page int, maxResults int) ([]BannerTerm, error) {
// Ensure offset is valid
if page <= 0 {
return nil, errors.New("offset must be greater than 0")
}
req := a.config.Client.NewRequest().
SetQueryParam("searchTerm", search).
SetQueryParam("offset", strconv.Itoa(page)).
SetQueryParam("max", strconv.Itoa(maxResults)).
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json").
SetResult(&[]BannerTerm{})
res, err := req.Get("/classSearch/getTerms")
if err != nil {
return nil, fmt.Errorf("failed to get terms: %w", err)
}
terms, ok := res.Result().(*[]BannerTerm)
if !ok {
return nil, fmt.Errorf("terms parsing failed to cast: %v", res.Result())
}
return *terms, nil
}
// SelectTerm selects a term in the Banner system for the given session.
// This is required before other API calls can be made.
func (a *API) SelectTerm(term string, sessionID string) error {
form := url.Values{
"term": {term},
"studyPath": {""},
"studyPathText": {""},
"startDatepicker": {""},
"endDatepicker": {""},
"uniqueSessionId": {sessionID},
}
type RedirectResponse struct {
FwdURL string `json:"fwdUrl"`
}
req := a.config.Client.NewRequest().
SetResult(&RedirectResponse{}).
SetQueryParam("mode", "search").
SetBody(form.Encode()).
SetExpectResponseContentType("application/json").
SetHeader("Content-Type", "application/x-www-form-urlencoded")
res, err := req.Post("/term/search")
if err != nil {
return fmt.Errorf("failed to select term: %w", err)
}
redirectResponse := res.Result().(*RedirectResponse)
// TODO: Mild validation to ensure the redirect is appropriate
// Make a GET request to the fwdUrl
req = a.config.Client.NewRequest()
res, err = req.Get(redirectResponse.FwdURL)
// Assert that the response is OK (200)
if res.StatusCode() != 200 {
return fmt.Errorf("redirect response was not OK: %d", res.StatusCode())
}
return nil
}
// GetPartOfTerms retrieves a list of parts of a term from the Banner API.
// The page number must be at least 1.
func (a *API) GetPartOfTerms(search string, term int, offset int, maxResults int) ([]BannerTerm, error) {
// Ensure offset is valid
if offset <= 0 {
return nil, errors.New("offset must be greater than 0")
}
req := a.config.Client.NewRequest().
SetQueryParam("searchTerm", search).
SetQueryParam("term", strconv.Itoa(term)).
SetQueryParam("offset", strconv.Itoa(offset)).
SetQueryParam("max", strconv.Itoa(maxResults)).
SetQueryParam("uniqueSessionId", a.EnsureSession()).
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json").
SetResult(&[]BannerTerm{})
res, err := req.Get("/classSearch/get_partOfTerm")
if err != nil {
return nil, fmt.Errorf("failed to get part of terms: %w", err)
}
terms, ok := res.Result().(*[]BannerTerm)
if !ok {
return nil, fmt.Errorf("term parsing failed to cast: %v", res.Result())
}
return *terms, nil
}
// GetInstructors retrieves a list of instructors from the Banner API.
func (a *API) GetInstructors(search string, term string, offset int, maxResults int) ([]Instructor, error) {
// Ensure offset is valid
if offset <= 0 {
return nil, errors.New("offset must be greater than 0")
}
req := a.config.Client.NewRequest().
SetQueryParam("searchTerm", search).
SetQueryParam("term", term).
SetQueryParam("offset", strconv.Itoa(offset)).
SetQueryParam("max", strconv.Itoa(maxResults)).
SetQueryParam("uniqueSessionId", a.EnsureSession()).
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json").
SetResult(&[]Instructor{})
res, err := req.Get("/classSearch/get_instructor")
if err != nil {
return nil, fmt.Errorf("failed to get instructors: %w", err)
}
instructors, ok := res.Result().(*[]Instructor)
if !ok {
return nil, fmt.Errorf("instructor parsing failed to cast: %v", res.Result())
}
return *instructors, nil
}
// ClassDetails represents the detailed information for a class.
//
// TODO: Implement this struct and the associated GetCourseDetails function.
type ClassDetails struct {
}
// GetCourseDetails retrieves the details for a specific course.
func (a *API) GetCourseDetails(term int, crn int) (*ClassDetails, error) {
body, err := json.Marshal(map[string]string{
"term": strconv.Itoa(term),
"courseReferenceNumber": strconv.Itoa(crn),
"first": "first", // TODO: What is this?
})
if err != nil {
log.Fatal().Stack().Err(err).Msg("Failed to marshal body")
}
req := a.config.Client.NewRequest().
SetBody(body).
SetExpectResponseContentType("application/json").
SetResult(&ClassDetails{})
res, err := req.Get("/searchResults/getClassDetails")
if err != nil {
return nil, fmt.Errorf("failed to get course details: %w", err)
}
details, ok := res.Result().(*ClassDetails)
if !ok {
return nil, fmt.Errorf("course details parsing failed to cast: %v", res.Result())
}
return details, nil
}
// Search performs a search for courses with the given query and returns the results.
func (a *API) Search(term string, query *Query, sort string, sortDescending bool) (*models.SearchResult, error) {
a.ResetDataForm()
params := query.Paramify()
params["txt_term"] = term
params["uniqueSessionId"] = a.EnsureSession()
params["sortColumn"] = sort
params["sortDirection"] = "asc"
// These dates are not available for usage anywhere in the UI, but are included in every query
params["startDatepicker"] = ""
params["endDatepicker"] = ""
req := a.config.Client.NewRequest().
SetQueryParams(params).
SetExpectResponseContentType("application/json").
SetResult(&models.SearchResult{})
res, err := req.Get("/searchResults/searchResults")
if err != nil {
return nil, fmt.Errorf("failed to search: %w", err)
}
searchResult, ok := res.Result().(*models.SearchResult)
if !ok {
return nil, fmt.Errorf("search result parsing failed to cast: %v", res.Result())
}
return searchResult, nil
}
// GetSubjects retrieves a list of subjects from the Banner API.
// The page number must be at least 1.
func (a *API) GetSubjects(search string, term string, offset int, maxResults int) ([]Pair, error) {
// Ensure offset is valid
if offset <= 0 {
return nil, errors.New("offset must be greater than 0")
}
req := a.config.Client.NewRequest().
SetQueryParam("searchTerm", search).
SetQueryParam("term", term).
SetQueryParam("offset", strconv.Itoa(offset)).
SetQueryParam("max", strconv.Itoa(maxResults)).
SetQueryParam("uniqueSessionId", a.EnsureSession()).
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json").
SetResult(&[]Pair{})
res, err := req.Get("/classSearch/get_subject")
if err != nil {
return nil, fmt.Errorf("failed to get subjects: %w", err)
}
subjects, ok := res.Result().(*[]Pair)
if !ok {
return nil, fmt.Errorf("subjects parsing failed to cast: %v", res.Result())
}
return *subjects, nil
}
// GetCampuses retrieves a list of campuses from the Banner API.
// The page number must be at least 1.
func (a *API) GetCampuses(search string, term int, offset int, maxResults int) ([]Pair, error) {
// Ensure offset is valid
if offset <= 0 {
return nil, errors.New("offset must be greater than 0")
}
req := a.config.Client.NewRequest().
SetQueryParam("searchTerm", search).
SetQueryParam("term", strconv.Itoa(term)).
SetQueryParam("offset", strconv.Itoa(offset)).
SetQueryParam("max", strconv.Itoa(maxResults)).
SetQueryParam("uniqueSessionId", a.EnsureSession()).
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json").
SetResult(&[]Pair{})
res, err := req.Get("/classSearch/get_campus")
if err != nil {
return nil, fmt.Errorf("failed to get campuses: %w", err)
}
campuses, ok := res.Result().(*[]Pair)
if !ok {
return nil, fmt.Errorf("campuses parsing failed to cast: %v", res.Result())
}
return *campuses, nil
}
// GetInstructionalMethods retrieves a list of instructional methods from the Banner API.
// The page number must be at least 1.
func (a *API) GetInstructionalMethods(search string, term string, offset int, maxResults int) ([]Pair, error) {
// Ensure offset is valid
if offset <= 0 {
return nil, errors.New("offset must be greater than 0")
}
req := a.config.Client.NewRequest().
SetQueryParam("searchTerm", search).
SetQueryParam("term", term).
SetQueryParam("offset", strconv.Itoa(offset)).
SetQueryParam("max", strconv.Itoa(maxResults)).
SetQueryParam("uniqueSessionId", a.EnsureSession()).
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json").
SetResult(&[]Pair{})
res, err := req.Get("/classSearch/get_instructionalMethod")
if err != nil {
return nil, fmt.Errorf("failed to get instructional methods: %w", err)
}
methods, ok := res.Result().(*[]Pair)
if !ok {
return nil, fmt.Errorf("instructional methods parsing failed to cast: %v", res.Result())
}
return *methods, nil
}
// GetCourseMeetingTime retrieves the meeting time information for a course.
func (a *API) GetCourseMeetingTime(term int, crn int) ([]models.MeetingTimeResponse, error) {
type responseWrapper struct {
Fmt []models.MeetingTimeResponse `json:"fmt"`
}
req := a.config.Client.NewRequest().
SetQueryParam("term", strconv.Itoa(term)).
SetQueryParam("courseReferenceNumber", strconv.Itoa(crn)).
SetExpectResponseContentType("application/json").
SetResult(&responseWrapper{})
res, err := req.Get("/searchResults/getFacultyMeetingTimes")
if err != nil {
return nil, fmt.Errorf("failed to get meeting time: %w", err)
}
result, ok := res.Result().(*responseWrapper)
if !ok {
return nil, fmt.Errorf("meeting times parsing failed to cast: %v", res.Result())
}
return result.Fmt, nil
}
// ResetDataForm resets the search form in the Banner system.
// This must be called before a new search can be performed.
func (a *API) ResetDataForm() {
req := a.config.Client.NewRequest()
_, err := req.Post("/classSearch/resetDataForm")
if err != nil {
log.Fatal().Stack().Err(err).Msg("Failed to reset data form")
}
}
// GetCourse retrieves course information from the Redis cache.
func (a *API) GetCourse(crn string) (*models.Course, error) {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
defer cancel()
// Retrieve raw data
result, err := a.config.KV.Get(ctx, fmt.Sprintf("class:%s", crn)).Result()
if err != nil {
if err == redis.Nil {
return nil, fmt.Errorf("course not found: %w", err)
}
return nil, fmt.Errorf("failed to get course: %w", err)
}
// Unmarshal the raw data
var course models.Course
err = json.Unmarshal([]byte(result), &course)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal course: %w", err)
}
return &course, nil
}

View File

@@ -1,240 +0,0 @@
// Package api provides the core functionality for interacting with the Banner API.
package api
import (
"banner/internal"
"banner/internal/models"
"context"
"fmt"
"math/rand"
"time"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
)
const (
// MaxPageSize is the maximum number of courses one can scrape per page.
MaxPageSize = 500
)
var (
// PriorityMajors is a list of majors that are considered to be high priority for scraping.
// This list is used to determine which majors to scrape first/most often.
PriorityMajors = []string{"CS", "CPE", "MAT", "EE", "IS"}
// AncillaryMajors is a list of majors that are considered to be low priority for scraping.
// This list will not contain any majors that are in PriorityMajors.
AncillaryMajors []string
// AllMajors is a list of all majors that are available in the Banner system.
AllMajors []string
)
// Scrape retrieves all courses from the Banner API and stores them in Redis.
// This is a long-running process that should be run in a goroutine.
//
// TODO: Switch from hardcoded term to dynamic term
func (a *API) Scrape() error {
// For each subject, retrieve all courses
// For each course, get the details and store it in redis
// Make sure to handle pagination
subjects, err := a.GetSubjects("", "202510", 1, 100)
if err != nil {
return fmt.Errorf("failed to get subjects: %w", err)
}
// Ensure subjects were found
if len(subjects) == 0 {
return fmt.Errorf("no subjects found")
}
// Extract major code name
for _, subject := range subjects {
// Add to AncillaryMajors if not in PriorityMajors
if !lo.Contains(PriorityMajors, subject.Code) {
AncillaryMajors = append(AncillaryMajors, subject.Code)
}
}
AllMajors = lo.Flatten([][]string{PriorityMajors, AncillaryMajors})
expiredSubjects, err := a.GetExpiredSubjects()
if err != nil {
return fmt.Errorf("failed to get scrapable majors: %w", err)
}
log.Info().Strs("majors", expiredSubjects).Msg("Scraping majors")
for _, subject := range expiredSubjects {
err := a.ScrapeMajor(subject)
if err != nil {
return fmt.Errorf("failed to scrape major %s: %w", subject, err)
}
}
return nil
}
// GetExpiredSubjects returns a list of subjects that have expired and should be scraped again.
// It checks Redis for the "scraped" status of each major for the current term.
func (a *API) GetExpiredSubjects() ([]string, error) {
term := a.DefaultTerm(time.Now()).ToString()
subjects := make([]string, 0)
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 10*time.Second)
defer cancel()
// Get all subjects
values, err := a.config.KV.MGet(ctx, lo.Map(AllMajors, func(major string, _ int) string {
return fmt.Sprintf("scraped:%s:%s", major, term)
})...).Result()
if err != nil {
return nil, fmt.Errorf("failed to get all subjects: %w", err)
}
// Extract expired subjects
for i, value := range values {
subject := AllMajors[i]
// If the value is nil or "0", then the subject is expired
if value == nil || value == "0" {
subjects = append(subjects, subject)
}
}
log.Debug().Strs("majors", subjects).Msg("Expired Subjects")
return subjects, nil
}
// ScrapeMajor scrapes all courses for a specific major.
// This function does not check whether scraping is required at this time; it is assumed that the caller has already done so.
func (a *API) ScrapeMajor(subject string) error {
offset := 0
totalClassCount := 0
for {
// Build & execute the query
query := NewQuery().Offset(offset).MaxResults(MaxPageSize * 2).Subject(subject)
term := a.DefaultTerm(time.Now()).ToString()
result, err := a.Search(term, query, "subjectDescription", false)
if err != nil {
return fmt.Errorf("search failed: %w (%s)", err, query.String())
}
// Isn't it bullshit that they decided not to leave an actual 'reason' field for the failure?
if !result.Success {
return fmt.Errorf("result marked unsuccessful when searching for classes (%s)", query.String())
}
classCount := len(result.Data)
totalClassCount += classCount
log.Debug().Str("subject", subject).Int("count", classCount).Int("offset", offset).Msg("Placing classes in Redis")
// Process each class and store it in Redis
for _, course := range result.Data {
// Store class in Redis
err := a.IntakeCourse(course)
if err != nil {
log.Error().Err(err).Msg("failed to store class in Redis")
}
}
// Increment and continue if the results are full
if classCount >= MaxPageSize {
// This is unlikely to happen, but log it just in case
if classCount > MaxPageSize {
log.Warn().Int("page", offset).Int("count", classCount).Msg("Results exceed MaxPageSize")
}
offset += MaxPageSize
// TODO: Replace sleep with smarter rate limiting
log.Debug().Str("subject", subject).Int("nextOffset", offset).Msg("Sleeping before next page")
time.Sleep(time.Second * 3)
continue
}
// Log the number of classes scraped
log.Info().Str("subject", subject).Int("total", totalClassCount).Msgf("Subject %s Scraped", subject)
break
}
term := a.DefaultTerm(time.Now()).ToString()
// Calculate the expiry time for the scrape (1 hour for every 200 classes, random +-15%) with a minimum of 1 hour
var scrapeExpiry time.Duration
if totalClassCount == 0 {
scrapeExpiry = time.Hour * 12
} else {
scrapeExpiry = a.CalculateExpiry(term, totalClassCount, lo.Contains(PriorityMajors, subject))
}
// Mark the major as scraped
if totalClassCount == 0 {
totalClassCount = -1
}
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
defer cancel()
err := a.config.KV.Set(ctx, fmt.Sprintf("scraped:%s:%s", subject, term), totalClassCount, scrapeExpiry).Err()
if err != nil {
log.Error().Err(err).Msg("failed to mark major as scraped")
}
return nil
}
// CalculateExpiry calculates the expiry time until the next scrape for a major.
// The duration is based on the number of courses, whether the major is a priority, and if the term is archived.
func (a *API) CalculateExpiry(term string, count int, priority bool) time.Duration {
// An hour for every 100 classes
baseExpiry := time.Hour * time.Duration(count/100)
// Subjects with less than 50 classes have a reversed expiry (less classes, longer interval)
// 1 class => 12 hours, 49 classes => 1 hour
if count < 50 {
hours := internal.Slope(internal.Point{X: 1, Y: 12}, internal.Point{X: 49, Y: 1}, float64(count)).Y
baseExpiry = time.Duration(hours * float64(time.Hour))
}
// If the subject is a priority, then the expiry is halved without variance
if priority {
return baseExpiry / 3
}
// If the term is considered "view only" or "archived", then the expiry is multiplied by 5
var expiry = baseExpiry
if a.IsTermArchived(term) {
expiry *= 5
}
// Add minor variance to the expiry
expiryVariance := baseExpiry.Seconds() * (rand.Float64() * 0.15) // Between 0 and 15% of the total
if rand.Intn(2) == 0 {
expiry -= time.Duration(expiryVariance) * time.Second
} else {
expiry += time.Duration(expiryVariance) * time.Second
}
// Ensure the expiry is at least 1 hour with up to 15 extra minutes
if expiry < time.Hour {
baseExpiry = time.Hour + time.Duration(rand.Intn(60*15))*time.Second
}
return baseExpiry
}
// IntakeCourse stores a course in Redis.
// This function will be used to handle change identification, notifications, and SQLite upserts in the future.
func (a *API) IntakeCourse(course models.Course) error {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(a.config.Ctx, 5*time.Second)
defer cancel()
err := a.config.KV.Set(ctx, fmt.Sprintf("class:%s", course.CourseReferenceNumber), course, 0).Err()
if err != nil {
return fmt.Errorf("failed to store class in Redis: %w", err)
}
return nil
}

View File

@@ -1,350 +0,0 @@
package api
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/samber/lo"
)
const (
paramSubject = "txt_subject"
paramTitle = "txt_courseTitle"
paramKeywords = "txt_keywordlike"
paramOpenOnly = "chk_open_only"
paramTermPart = "txt_partOfTerm"
paramCampus = "txt_campus"
paramAttributes = "txt_attribute"
paramInstructor = "txt_instructor"
paramStartTimeHour = "select_start_hour"
paramStartTimeMinute = "select_start_min"
paramStartTimeMeridiem = "select_start_ampm"
paramEndTimeHour = "select_end_hour"
paramEndTimeMinute = "select_end_min"
paramEndTimeMeridiem = "select_end_ampm"
paramMinCredits = "txt_credithourlow"
paramMaxCredits = "txt_credithourhigh"
paramCourseNumberLow = "txt_course_number_range"
paramCourseNumberHigh = "txt_course_number_range_to"
paramOffset = "pageOffset"
paramMaxResults = "pageMaxSize"
)
// Query represents a search query for courses.
// It is a builder that allows for chaining methods to construct a query.
type Query struct {
subject *string
title *string
keywords *[]string
openOnly *bool
termPart *[]string // e.g. [1, B6, 8, J]
campus *[]string // e.g. [9, 1DT, 1LR]
instructionalMethod *[]string // e.g. [HB]
attributes *[]string // e.g. [060, 010]
instructor *[]uint64 // e.g. [27957, 27961]
startTime *time.Duration
endTime *time.Duration
minCredits *int
maxCredits *int
offset int
maxResults int
courseNumberRange *Range
}
// NewQuery creates a new Query with default values.
func NewQuery() *Query {
return &Query{maxResults: 8, offset: 0}
}
// Subject sets the subject for the query.
func (q *Query) Subject(subject string) *Query {
q.subject = &subject
return q
}
// Title sets the title for the query.
func (q *Query) Title(title string) *Query {
q.title = &title
return q
}
// Keywords sets the keywords for the query.
func (q *Query) Keywords(keywords []string) *Query {
q.keywords = &keywords
return q
}
// Keyword adds a keyword to the query.
func (q *Query) Keyword(keyword string) *Query {
if q.keywords == nil {
q.keywords = &[]string{keyword}
} else {
*q.keywords = append(*q.keywords, keyword)
}
return q
}
// OpenOnly sets whether to search for open courses only.
func (q *Query) OpenOnly(openOnly bool) *Query {
q.openOnly = &openOnly
return q
}
// TermPart sets the term part for the query.
func (q *Query) TermPart(termPart []string) *Query {
q.termPart = &termPart
return q
}
// Campus sets the campuses for the query.
func (q *Query) Campus(campus []string) *Query {
q.campus = &campus
return q
}
// InstructionalMethod sets the instructional methods for the query.
func (q *Query) InstructionalMethod(instructionalMethod []string) *Query {
q.instructionalMethod = &instructionalMethod
return q
}
// Attributes sets the attributes for the query.
func (q *Query) Attributes(attributes []string) *Query {
q.attributes = &attributes
return q
}
// Instructor sets the instructors for the query.
func (q *Query) Instructor(instructor []uint64) *Query {
q.instructor = &instructor
return q
}
// StartTime sets the start time for the query.
func (q *Query) StartTime(startTime time.Duration) *Query {
q.startTime = &startTime
return q
}
// EndTime sets the end time for the query.
func (q *Query) EndTime(endTime time.Duration) *Query {
q.endTime = &endTime
return q
}
// Credits sets the credit range for the query.
func (q *Query) Credits(low int, high int) *Query {
q.minCredits = &low
q.maxCredits = &high
return q
}
// MinCredits sets the minimum credits for the query.
func (q *Query) MinCredits(value int) *Query {
q.minCredits = &value
return q
}
// MaxCredits sets the maximum credits for the query.
func (q *Query) MaxCredits(value int) *Query {
q.maxCredits = &value
return q
}
// CourseNumbers sets the course number range for the query.
func (q *Query) CourseNumbers(low int, high int) *Query {
q.courseNumberRange = &Range{low, high}
return q
}
// Offset sets the offset for pagination.
func (q *Query) Offset(offset int) *Query {
q.offset = offset
return q
}
// MaxResults sets the maximum number of results to return.
func (q *Query) MaxResults(maxResults int) *Query {
q.maxResults = maxResults
return q
}
// Range represents a range of two integers.
type Range struct {
Low int
High int
}
// FormatTimeParameter formats a time.Duration into a tuple of strings for use in a POST request.
// It returns the hour, minute, and meridiem (AM/PM) as separate strings.
func FormatTimeParameter(d time.Duration) (string, string, string) {
hourParameter, minuteParameter, meridiemParameter := "", "", ""
hours := int64(d.Hours())
minutes := int64(d.Minutes()) % 60
minuteParameter = strconv.FormatInt(minutes, 10)
if hours >= 12 {
hourParameter = "PM"
// Exceptional case: 12PM = 12, 1PM = 1, 2PM = 2
if hours >= 13 {
hourParameter = strconv.FormatInt(hours-12, 10) // 13 - 12 = 1, 14 - 12 = 2
} else {
hourParameter = strconv.FormatInt(hours, 10)
}
} else {
meridiemParameter = "AM"
hourParameter = strconv.FormatInt(hours, 10)
}
return hourParameter, minuteParameter, meridiemParameter
}
// Paramify converts a Query into a map of parameters for a POST request.
// This function assumes each query key only appears once.
func (q *Query) Paramify() map[string]string {
params := map[string]string{}
if q.subject != nil {
params[paramSubject] = *q.subject
}
if q.title != nil {
// Whitespace can prevent valid queries from succeeding
params[paramTitle] = strings.TrimSpace(*q.title)
}
if q.keywords != nil {
params[paramKeywords] = strings.Join(*q.keywords, " ")
}
if q.openOnly != nil {
params[paramOpenOnly] = "true"
}
if q.termPart != nil {
params[paramTermPart] = strings.Join(*q.termPart, ",")
}
if q.campus != nil {
params[paramCampus] = strings.Join(*q.campus, ",")
}
if q.attributes != nil {
params[paramAttributes] = strings.Join(*q.attributes, ",")
}
if q.instructor != nil {
params[paramInstructor] = strings.Join(lo.Map(*q.instructor, func(i uint64, _ int) string {
return strconv.FormatUint(i, 10)
}), ",")
}
if q.startTime != nil {
hour, minute, meridiem := FormatTimeParameter(*q.startTime)
params[paramStartTimeHour] = hour
params[paramStartTimeMinute] = minute
params[paramStartTimeMeridiem] = meridiem
}
if q.endTime != nil {
hour, minute, meridiem := FormatTimeParameter(*q.endTime)
params[paramEndTimeHour] = hour
params[paramEndTimeMinute] = minute
params[paramEndTimeMeridiem] = meridiem
}
if q.minCredits != nil {
params[paramMinCredits] = strconv.Itoa(*q.minCredits)
}
if q.maxCredits != nil {
params[paramMaxCredits] = strconv.Itoa(*q.maxCredits)
}
if q.courseNumberRange != nil {
params[paramCourseNumberLow] = strconv.Itoa(q.courseNumberRange.Low)
params[paramCourseNumberHigh] = strconv.Itoa(q.courseNumberRange.High)
}
params[paramOffset] = strconv.Itoa(q.offset)
params[paramMaxResults] = strconv.Itoa(q.maxResults)
return params
}
// String returns a string representation of the query, ideal for debugging & logging.
func (q *Query) String() string {
var sb strings.Builder
if q.subject != nil {
fmt.Fprintf(&sb, "subject=%s, ", *q.subject)
}
if q.title != nil {
// Whitespace can prevent valid queries from succeeding
fmt.Fprintf(&sb, "title=%s, ", strings.TrimSpace(*q.title))
}
if q.keywords != nil {
fmt.Fprintf(&sb, "keywords=%s, ", strings.Join(*q.keywords, " "))
}
if q.openOnly != nil {
fmt.Fprintf(&sb, "openOnly=%t, ", *q.openOnly)
}
if q.termPart != nil {
fmt.Fprintf(&sb, "termPart=%s, ", strings.Join(*q.termPart, ","))
}
if q.campus != nil {
fmt.Fprintf(&sb, "campus=%s, ", strings.Join(*q.campus, ","))
}
if q.attributes != nil {
fmt.Fprintf(&sb, "attributes=%s, ", strings.Join(*q.attributes, ","))
}
if q.instructor != nil {
fmt.Fprintf(&sb, "instructor=%s, ", strings.Join(lo.Map(*q.instructor, func(i uint64, _ int) string {
return strconv.FormatUint(i, 10)
}), ","))
}
if q.startTime != nil {
hour, minute, meridiem := FormatTimeParameter(*q.startTime)
fmt.Fprintf(&sb, "startTime=%s:%s%s, ", hour, minute, meridiem)
}
if q.endTime != nil {
hour, minute, meridiem := FormatTimeParameter(*q.endTime)
fmt.Fprintf(&sb, "endTime=%s:%s%s, ", hour, minute, meridiem)
}
if q.minCredits != nil {
fmt.Fprintf(&sb, "minCredits=%d, ", *q.minCredits)
}
if q.maxCredits != nil {
fmt.Fprintf(&sb, "maxCredits=%d, ", *q.maxCredits)
}
if q.courseNumberRange != nil {
fmt.Fprintf(&sb, "courseNumberRange=%d-%d, ", q.courseNumberRange.Low, q.courseNumberRange.High)
}
fmt.Fprintf(&sb, "offset=%d, ", q.offset)
fmt.Fprintf(&sb, "maxResults=%d", q.maxResults)
return sb.String()
}
// Dict returns a map representation of the query, ideal for debugging & logging.
// This dict is represented with zerolog's Event type.
// func (q *Query) Dict() *zerolog.Event {
// }

View File

@@ -1,64 +0,0 @@
package api
import (
"banner/internal"
"net/url"
log "github.com/rs/zerolog/log"
)
// Setup makes the initial requests to set up the session cookies for the application.
func (a *API) Setup() {
// Makes the initial requests that sets up the session cookies for the rest of the application
log.Info().Msg("Setting up session...")
requestQueue := []string{
"/registration/registration",
"/selfServiceMenu/data",
}
for _, path := range requestQueue {
req := a.config.Client.NewRequest().
SetQueryParam("_", internal.Nonce()).
SetExpectResponseContentType("application/json")
res, err := req.Get(path)
if err != nil {
log.Fatal().Stack().Str("path", path).Err(err).Msg("Failed to make request")
}
if res.StatusCode() != 200 {
log.Fatal().Stack().Str("path", path).Int("status", res.StatusCode()).Msg("Failed to make request")
}
}
// Validate that cookies were set
baseURLParsed, err := url.Parse(a.config.BaseURL)
if err != nil {
log.Fatal().Stack().Str("baseURL", a.config.BaseURL).Err(err).Msg("Failed to parse baseURL")
}
currentCookies := a.config.Client.CookieJar().Cookies(baseURLParsed)
requiredCookies := map[string]bool{
"JSESSIONID": false,
"SSB_COOKIE": false,
}
for _, cookie := range currentCookies {
_, present := requiredCookies[cookie.Name]
// Check if this cookie is required
if present {
requiredCookies[cookie.Name] = true
}
}
// Check if all required cookies were set
for cookieName, cookieSet := range requiredCookies {
if !cookieSet {
log.Warn().Str("cookieName", cookieName).Msg("Required cookie not set")
}
}
log.Debug().Msg("All required cookies set, session setup complete")
// TODO: Validate that the session allows access to termSelection
}

View File

@@ -1,649 +0,0 @@
package bot
import (
"banner/internal"
"banner/internal/api"
"banner/internal/models"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
)
const (
// ICalTimestampLayoutUtc is the formatting layout for timestamps in the UTC timezone.
ICalTimestampLayoutUtc = "20060102T150405Z"
// ICalTimestampLayoutLocal is the formatting layout for timestamps in the local timezone.
ICalTimestampLayoutLocal = "20060102T150405"
)
// CommandHandler is a function that handles a slash command interaction.
type CommandHandler func(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error
var (
// CommandDefinitions is a list of all the bot's command definitions.
CommandDefinitions = []*discordgo.ApplicationCommand{TermCommandDefinition, TimeCommandDefinition, SearchCommandDefinition, IcsCommandDefinition, GCalCommandDefinition}
// CommandHandlers is a map of command names to their handlers.
CommandHandlers = map[string]CommandHandler{
TimeCommandDefinition.Name: TimeCommandHandler,
TermCommandDefinition.Name: TermCommandHandler,
SearchCommandDefinition.Name: SearchCommandHandler,
IcsCommandDefinition.Name: IcsCommandHandler,
GCalCommandDefinition.Name: GCalCommandHandler,
}
)
var SearchCommandDefinition = &discordgo.ApplicationCommand{
Name: "search",
Description: "Search for a course",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
MinLength: internal.GetIntPointer(0),
MaxLength: 48,
Name: "title",
Description: "Course Title (exact, use autocomplete)",
Required: false,
Autocomplete: true,
},
{
Type: discordgo.ApplicationCommandOptionString,
Name: "code",
MinLength: internal.GetIntPointer(4),
Description: "Course Code (e.g. 3743, 3000-3999, 3xxx, 3000-)",
Required: false,
},
{
Type: discordgo.ApplicationCommandOptionInteger,
Name: "max",
Description: "Maximum number of results",
Required: false,
},
{
Type: discordgo.ApplicationCommandOptionString,
Name: "keywords",
Description: "Keywords in Title or Description (space separated)",
},
{
Type: discordgo.ApplicationCommandOptionString,
Name: "instructor",
Description: "Instructor Name",
Required: false,
Autocomplete: true,
},
{
Type: discordgo.ApplicationCommandOptionString,
Name: "subject",
Description: "Subject (e.g. Computer Science/CS, Mathematics/MAT)",
Required: false,
Autocomplete: true,
},
},
}
// SearchCommandHandler handles the /search command, which allows users to search for courses.
func SearchCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
data := i.ApplicationCommandData()
query := api.NewQuery().Credits(3, 6)
for _, option := range data.Options {
switch option.Name {
case "title":
query.Title(option.StringValue())
case "code":
var (
low = -1
high = -1
)
var err error
valueRaw := strings.TrimSpace(option.StringValue())
// Partially/fully specified range
if strings.Contains(valueRaw, "-") {
match := regexp.MustCompile(`(\d{1,4})-(\d{1,4})?`).FindSubmatch([]byte(valueRaw))
if match == nil {
return fmt.Errorf("invalid range format: %s", valueRaw)
}
// If not 2 or 3 matches, it's invalid
if len(match) != 3 && len(match) != 4 {
return fmt.Errorf("invalid range format: %s", match[0])
}
low, err = strconv.Atoi(string(match[1]))
if err != nil {
return errors.Wrap(err, "error parsing course code (low)")
}
// If there's not a high value, set it to max (open ended)
if len(match) == 2 || len(match[2]) == 0 {
high = 9999
} else {
high, err = strconv.Atoi(string(match[2]))
if err != nil {
return errors.Wrap(err, "error parsing course code (high)")
}
}
}
// #xxx, ##xx, ###x format (34xx -> 3400-3499)
if strings.Contains(valueRaw, "x") {
if len(valueRaw) != 4 {
return fmt.Errorf("code range format invalid: must be 1 or more digits followed by x's (%s)", valueRaw)
}
match := regexp.MustCompile(`\d{1,}([xX]{1,3})`).Match([]byte(valueRaw))
if !match {
return fmt.Errorf("code range format invalid: must be 1 or more digits followed by x's (%s)", valueRaw)
}
// Replace x's with 0's
low, err = strconv.Atoi(strings.Replace(valueRaw, "x", "0", -1))
if err != nil {
return errors.Wrap(err, "error parsing implied course code (low)")
}
// Replace x's with 9's
high, err = strconv.Atoi(strings.Replace(valueRaw, "x", "9", -1))
if err != nil {
return errors.Wrap(err, "error parsing implied course code (high)")
}
} else if len(valueRaw) == 4 {
// 4 digit code
low, err = strconv.Atoi(valueRaw)
if err != nil {
return errors.Wrap(err, "error parsing course code")
}
high = low
}
if low == -1 || high == -1 {
return fmt.Errorf("course code range invalid (%s)", valueRaw)
}
if low > high {
return fmt.Errorf("course code range is invalid: low is greater than high (%d > %d)", low, high)
}
if low < 1000 || high < 1000 || low > 9999 || high > 9999 {
return fmt.Errorf("course code range is invalid: must be 1000-9999 (%d-%d)", low, high)
}
query.CourseNumbers(low, high)
case "keywords":
query.Keywords(
strings.Split(option.StringValue(), " "),
)
case "max":
query.MaxResults(
min(8, int(option.IntValue())),
)
}
}
term, err := b.GetSession()
if err != nil {
return err
}
courses, err := b.API.Search(term, query, "", false)
if err != nil {
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "Error searching for courses",
},
})
return err
}
fetchTime := time.Now()
fields := []*discordgo.MessageEmbedField{}
for _, course := range courses.Data {
// Safe instructor name handling
displayName := "TBA"
if len(course.Faculty) > 0 {
displayName = course.Faculty[0].DisplayName
}
categoryLink := fmt.Sprintf("[%s](https://catalog.utsa.edu/undergraduate/coursedescriptions/%s/)", course.Subject, strings.ToLower(course.Subject))
classLink := fmt.Sprintf("[%s-%s](https://catalog.utsa.edu/search/?P=%s%%20%s)", course.CourseNumber, course.SequenceNumber, course.Subject, course.CourseNumber)
professorLink := fmt.Sprintf("[%s](https://www.ratemyprofessors.com/search/professors/1516?q=%s)", displayName, url.QueryEscape(displayName))
identifierText := fmt.Sprintf("%s %s (CRN %s)\n%s", categoryLink, classLink, course.CourseReferenceNumber, professorLink)
// Safe meeting time handling
meetingTime := "No scheduled meetings"
if len(course.MeetingsFaculty) > 0 {
meetingTime = course.MeetingsFaculty[0].String()
}
fields = append(fields, &discordgo.MessageEmbedField{
Name: "Identifier",
Value: identifierText,
Inline: true,
}, &discordgo.MessageEmbedField{
Name: "Name",
Value: course.CourseTitle,
Inline: true,
}, &discordgo.MessageEmbedField{
Name: "Meeting Time",
Value: meetingTime,
Inline: true,
},
)
}
// Blue if there are results, orange if there are none
color := 0x0073FF
if courses.TotalCount == 0 {
color = 0xFF6500
}
err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: internal.GetFetchedFooter(b.Config, fetchTime),
Description: fmt.Sprintf("%d Class%s", courses.TotalCount, internal.Plural(courses.TotalCount)),
Fields: fields[:min(25, len(fields))],
Color: color,
},
},
AllowedMentions: &discordgo.MessageAllowedMentions{},
},
})
return err
}
var TermCommandDefinition = &discordgo.ApplicationCommand{
Name: "terms",
Description: "Guess the current term, or search for a specific term",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
MinLength: internal.GetIntPointer(0),
MaxLength: 8,
Name: "search",
Description: "Term to search for",
Required: false,
},
{
Type: discordgo.ApplicationCommandOptionInteger,
Name: "page",
Description: "Page Number",
Required: false,
MinValue: internal.GetFloatPointer(1),
},
},
}
// TermCommandHandler handles the /terms command, which allows users to search for terms.
func TermCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
data := i.ApplicationCommandData()
searchTerm := ""
pageNumber := 1
for _, option := range data.Options {
switch option.Name {
case "search":
searchTerm = option.StringValue()
case "page":
pageNumber = int(option.IntValue())
default:
log.Warn().Str("option", option.Name).Msg("Unexpected option in term command")
}
}
termResult, err := b.API.GetTerms(searchTerm, pageNumber, 25)
if err != nil {
internal.RespondError(s, i.Interaction, "Error while fetching terms", err)
return err
}
fields := []*discordgo.MessageEmbedField{}
for _, t := range termResult {
fields = append(fields, &discordgo.MessageEmbedField{
Name: t.Description,
Value: t.Code,
Inline: true,
})
}
fetchTime := time.Now()
if len(fields) > 25 {
log.Warn().Int("count", len(fields)).Msg("Too many fields in term command (trimmed)")
}
err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: internal.GetFetchedFooter(b.Config, fetchTime),
Description: fmt.Sprintf("%d term%s (page %d)", len(termResult), internal.Plural(len(termResult)), pageNumber),
Fields: fields[:min(25, len(fields))],
},
},
AllowedMentions: &discordgo.MessageAllowedMentions{},
},
})
return err
}
var TimeCommandDefinition = &discordgo.ApplicationCommand{
Name: "time",
Description: "Get Class Meeting Time",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionInteger,
Name: "crn",
Description: "Course Reference Number",
Required: true,
},
},
}
// TimeCommandHandler handles the /time command, which allows users to get the meeting times for a course.
func TimeCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
fetchTime := time.Now()
crn := i.ApplicationCommandData().Options[0].IntValue()
// Fix static term
meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn))
if err != nil {
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "Error getting meeting time",
},
})
return err
}
if len(meetingTimes) == 0 {
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "No meeting times found for this course",
},
})
return fmt.Errorf("no meeting times found for CRN %d", crn)
}
meetingTime := meetingTimes[0]
duration := meetingTime.EndTime().Sub(meetingTime.StartTime())
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: internal.GetFetchedFooter(b.Config, fetchTime),
Description: "",
Fields: []*discordgo.MessageEmbedField{
{
Name: "Start Date",
Value: meetingTime.StartDay().Format("Monday, January 2, 2006"),
},
{
Name: "End Date",
Value: meetingTime.EndDay().Format("Monday, January 2, 2006"),
},
{
Name: "Start/End Time",
Value: fmt.Sprintf("%s - %s (%d min)", meetingTime.StartTime().String(), meetingTime.EndTime().String(), int64(duration.Minutes())),
},
{
Name: "Days of Week",
Value: internal.WeekdaysToString(meetingTime.Days()),
},
},
},
},
AllowedMentions: &discordgo.MessageAllowedMentions{},
},
})
return nil
}
var IcsCommandDefinition = &discordgo.ApplicationCommand{
Name: "ics",
Description: "Generate an ICS file for a course",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionInteger,
Name: "crn",
Description: "Course Reference Number",
Required: true,
},
},
}
var GCalCommandDefinition = &discordgo.ApplicationCommand{
Name: "gcal",
Description: "Generate a link to create a Google Calendar event for a course",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionInteger,
Name: "crn",
Description: "Course Reference Number",
Required: true,
},
},
}
// GCalCommandHandler handles the /gcal command, which allows users to generate a link to create a Google Calendar event for a course.
func GCalCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
// Parse all options
options := internal.ParseOptions(i.ApplicationCommandData().Options)
crn := options.GetInt("crn")
course, err := b.API.GetCourse(strconv.Itoa(int(crn)))
if err != nil {
return fmt.Errorf("Error retrieving course data: %w", err)
}
meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn))
if err != nil {
return fmt.Errorf("Error requesting meeting time: %w", err)
}
if len(meetingTimes) == 0 {
return fmt.Errorf("unexpected - no meeting time data found for course")
}
// Check if the course has any meeting times
meetingTime, exists := lo.Find(meetingTimes, func(mt models.MeetingTimeResponse) bool {
switch mt.MeetingTime.MeetingType {
case "ID", "OA":
return false
default:
return true
}
})
if !exists {
internal.RespondError(s, i.Interaction, "The course requested does not meet at a defined moment in time.", nil)
return nil
}
startDay := meetingTime.StartDay()
startTime := meetingTime.StartTime()
endTime := meetingTime.EndTime()
// Create timestamps in UTC
dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
// Format times in UTC for Google Calendar
startStr := dtStart.UTC().Format(ICalTimestampLayoutUtc)
endStr := dtEnd.UTC().Format(ICalTimestampLayoutUtc)
// Generate RRULE for recurrence
rrule := meetingTime.RRule()
recurRule := fmt.Sprintf("FREQ=WEEKLY;BYDAY=%s;UNTIL=%s", rrule.ByDay, rrule.Until)
// Build calendar URL
params := url.Values{}
params.Add("action", "TEMPLATE")
params.Add("text", fmt.Sprintf("%s %s - %s", course.Subject, course.CourseNumber, course.CourseTitle))
params.Add("dates", fmt.Sprintf("%s/%s", startStr, endStr))
params.Add("details", fmt.Sprintf("CRN: %s\nInstructor: %s\nDays: %s", course.CourseReferenceNumber, meetingTime.Faculty[0].DisplayName, internal.WeekdaysToString(meetingTime.Days())))
params.Add("location", meetingTime.PlaceString())
params.Add("trp", "true")
params.Add("ctz", b.Config.CentralTimeLocation.String())
params.Add("recur", "RRULE:"+recurRule)
calendarURL := "https://calendar.google.com/calendar/render?" + params.Encode()
err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: fmt.Sprintf("[Add to Google Calendar](<%s>)", calendarURL),
AllowedMentions: &discordgo.MessageAllowedMentions{},
},
})
return err
}
// IcsCommandHandler handles the /ics command, which allows users to generate an ICS file for a course.
func IcsCommandHandler(b *Bot, s *discordgo.Session, i *discordgo.InteractionCreate) error {
// Parse all options
options := internal.ParseOptions(i.ApplicationCommandData().Options)
crn := options.GetInt("crn")
course, err := b.API.GetCourse(strconv.Itoa(int(crn)))
if err != nil {
return fmt.Errorf("Error retrieving course data: %w", err)
}
// Fix static term
meetingTimes, err := b.API.GetCourseMeetingTime(202510, int(crn))
if err != nil {
return fmt.Errorf("Error requesting meeting time: %w", err)
}
if len(meetingTimes) == 0 {
return fmt.Errorf("unexpected - no meeting time data found for course")
}
// Check if the course has any meeting times
_, exists := lo.Find(meetingTimes, func(mt models.MeetingTimeResponse) bool {
switch mt.MeetingTime.MeetingType {
case "ID", "OA":
return false
default:
return true
}
})
if !exists {
log.Warn().Str("crn", course.CourseReferenceNumber).Msg("Non-meeting course requested for ICS file")
internal.RespondError(s, i.Interaction, "The course requested does not meet at a defined moment in time.", nil)
return nil
}
events := []string{}
for _, meeting := range meetingTimes {
now := time.Now().In(b.Config.CentralTimeLocation)
uid := fmt.Sprintf("%d-%s@ical.banner.xevion.dev", now.Unix(), meeting.CourseReferenceNumber)
startDay := meeting.StartDay()
startTime := meeting.StartTime()
endTime := meeting.EndTime()
dtStart := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(startTime.Hours), int(startTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
dtEnd := time.Date(startDay.Year(), startDay.Month(), startDay.Day(), int(endTime.Hours), int(endTime.Minutes), 0, 0, b.Config.CentralTimeLocation)
// endDay := meeting.EndDay()
// until := time.Date(endDay.Year(), endDay.Month(), endDay.Day(), 23, 59, 59, 0, b.Config.CentralTimeLocation)
summary := fmt.Sprintf("%s %s %s", course.Subject, course.CourseNumber, course.CourseTitle)
// Safe instructor name handling
instructorName := "TBA"
if len(course.Faculty) > 0 {
instructorName = course.Faculty[0].DisplayName
}
description := fmt.Sprintf("Instructor: %s\nSection: %s\nCRN: %s", instructorName, course.SequenceNumber, meeting.CourseReferenceNumber)
location := meeting.PlaceString()
rrule := meeting.RRule()
event := fmt.Sprintf(`BEGIN:VEVENT
DTSTAMP:%s
UID:%s
DTSTART;TZID=America/Chicago:%s
RRULE:FREQ=WEEKLY;BYDAY=%s;UNTIL=%s
DTEND;TZID=America/Chicago:%s
SUMMARY:%s
DESCRIPTION:%s
LOCATION:%s
END:VEVENT`, now.Format(ICalTimestampLayoutLocal), uid, dtStart.Format(ICalTimestampLayoutLocal), rrule.ByDay, rrule.Until, dtEnd.Format(ICalTimestampLayoutLocal), summary, strings.Replace(description, "\n", `\n`, -1), location)
events = append(events, event)
}
// TODO: Make this dynamically requested, parsed & cached from tzurl.org
vTimezone := `BEGIN:VTIMEZONE
TZID:America/Chicago
LAST-MODIFIED:20231222T233358Z
TZURL:https://www.tzurl.org/zoneinfo-outlook/America/Chicago
X-LIC-LOCATION:America/Chicago
BEGIN:DAYLIGHT
TZNAME:CDT
TZOFFSETFROM:-0600
TZOFFSETTO:-0500
DTSTART:19700308T020000
RRULE:FREQ=YEARLY;BYMONTH=3;BYDAY=2SU
END:DAYLIGHT
BEGIN:STANDARD
TZNAME:CST
TZOFFSETFROM:-0500
TZOFFSETTO:-0600
DTSTART:19701101T020000
RRULE:FREQ=YEARLY;BYMONTH=11;BYDAY=1SU
END:STANDARD
END:VTIMEZONE`
ics := fmt.Sprintf(`BEGIN:VCALENDAR
VERSION:2.0
PRODID:-//xevion//Banner Discord Bot//EN
CALSCALE:GREGORIAN
%s
%s
END:VCALENDAR`, vTimezone, strings.Join(events, "\n"))
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Files: []*discordgo.File{
{
Name: fmt.Sprintf("%s-%s-%s_%s.ics", course.Subject, course.CourseNumber, course.SequenceNumber, course.CourseReferenceNumber),
ContentType: "text/calendar",
Reader: strings.NewReader(ics),
},
},
AllowedMentions: &discordgo.MessageAllowedMentions{},
},
})
return nil
}

View File

@@ -1,91 +0,0 @@
package bot
import (
"banner/internal"
"fmt"
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
// RegisterHandlers registers the bot's command handlers.
func (b *Bot) RegisterHandlers() {
b.Session.AddHandler(func(internalSession *discordgo.Session, interaction *discordgo.InteractionCreate) {
// Handle commands during restart (highly unlikely, but just in case)
if b.isClosing {
err := internal.RespondError(internalSession, interaction.Interaction, "Bot is currently restarting, try again later.", nil)
if err != nil {
log.Error().Err(err).Msg("Failed to respond with restart error feedback")
}
return
}
name := interaction.ApplicationCommandData().Name
if handler, ok := CommandHandlers[name]; ok {
// Build dict of options for the log
options := zerolog.Dict()
for _, option := range interaction.ApplicationCommandData().Options {
options.Str(option.Name, fmt.Sprintf("%v", option.Value))
}
event := log.Info().Str("name", name).Str("user", internal.GetUser(interaction).Username).Dict("options", options)
// If the command was invoked in a guild, add guild & channel info to the log
if interaction.Member != nil {
guild := zerolog.Dict()
guild.Str("id", interaction.GuildID)
guild.Str("name", internal.GetGuildName(b.Config, internalSession, interaction.GuildID))
event.Dict("guild", guild)
channel := zerolog.Dict()
channel.Str("id", interaction.ChannelID)
guild.Str("name", internal.GetChannelName(b.Config, internalSession, interaction.ChannelID))
event.Dict("channel", channel)
} else {
// If the command was invoked in a DM, add the user info to the log
user := zerolog.Dict()
user.Str("id", interaction.User.ID)
user.Str("name", interaction.User.Username)
event.Dict("user", user)
}
// Log command invocation
event.Msg("Command Invoked")
// Prepare to recover
defer func() {
if err := recover(); err != nil {
log.Error().Stack().Str("commandName", name).Interface("detail", err).Msg("Command Handler Panic")
// Respond with error
err := internal.RespondError(internalSession, interaction.Interaction, "Unexpected Error: command handler panic", nil)
if err != nil {
log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with panic error feedback")
}
}
}()
// Call handler
err := handler(b, internalSession, interaction)
// Log & respond error
if err != nil {
// TODO: Find a way to merge the response with the handler's error
log.Error().Str("commandName", name).Err(err).Msg("Command Handler Error")
// Respond with error
err = internal.RespondError(internalSession, interaction.Interaction, fmt.Sprintf("Unexpected Error: %s", err.Error()), nil)
if err != nil {
log.Error().Stack().Str("commandName", name).Err(err).Msg("Failed to respond with error feedback")
}
}
} else {
log.Error().Stack().Str("commandName", name).Msg("Command Interaction Has No Handler")
// Respond with error
internal.RespondError(internalSession, interaction.Interaction, "Unexpected Error: interaction has no handler", nil)
}
})
}

View File

@@ -1,44 +0,0 @@
// Package bot provides the core functionality for the Discord bot.
package bot
import (
"banner/internal/api"
"banner/internal/config"
"fmt"
"time"
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog/log"
)
// Bot represents the state of the Discord bot.
type Bot struct {
Session *discordgo.Session
API *api.API
Config *config.Config
isClosing bool
}
// New creates a new Bot instance.
func New(s *discordgo.Session, a *api.API, c *config.Config) *Bot {
return &Bot{Session: s, API: a, Config: c}
}
// SetClosing marks the bot as closing, preventing new commands from being processed.
func (b *Bot) SetClosing() {
b.isClosing = true
}
// GetSession ensures a valid session is available and selects the default term.
func (b *Bot) GetSession() (string, error) {
sessionID := b.API.EnsureSession()
term := b.API.DefaultTerm(time.Now()).ToString()
log.Info().Str("term", term).Str("sessionID", sessionID).Msg("Setting selected term")
err := b.API.SelectTerm(term, sessionID)
if err != nil {
return "", fmt.Errorf("failed to select term while generating session ID: %w", err)
}
return sessionID, nil
}

View File

@@ -1,72 +0,0 @@
package config
import (
"context"
"time"
"github.com/redis/go-redis/v9"
"resty.dev/v3"
)
// Config holds the application's configuration.
type Config struct {
// Ctx is the application's root context.
Ctx context.Context
// CancelFunc cancels the application's root context.
CancelFunc context.CancelFunc
// KV provides access to the Redis cache.
KV *redis.Client
// Client is the HTTP client for making API requests.
Client *resty.Client
// IsDevelopment is true if the application is running in a development environment.
IsDevelopment bool
// BaseURL is the base URL for the Banner API.
BaseURL string
// Environment is the application's running environment (e.g. "development").
Environment string
// CentralTimeLocation is the time.Location for US Central Time.
CentralTimeLocation *time.Location
// SeasonRanges is the time.Location for US Central Time.
SeasonRanges *SeasonRanges
}
// New creates a new Config instance with a cancellable context.
func New() (*Config, error) {
ctx, cancel := context.WithCancel(context.Background())
loc, err := time.LoadLocation("America/Chicago")
if err != nil {
cancel()
return nil, err
}
seasonRanges := GetYearDayRange(loc, uint16(time.Now().Year()))
return &Config{
Ctx: ctx,
CancelFunc: cancel,
CentralTimeLocation: loc,
SeasonRanges: &seasonRanges,
}, nil
}
// SetBaseURL sets the base URL for the Banner API.
func (c *Config) SetBaseURL(url string) {
c.BaseURL = url
}
// SetEnvironment sets the application's environment.
func (c *Config) SetEnvironment(env string) {
c.Environment = env
c.IsDevelopment = env == "development"
}
// SetClient sets the Resty client for making HTTP requests.
func (c *Config) SetClient(client *resty.Client) {
c.Client = client
}
// SetRedis sets the Redis client for caching.
func (c *Config) SetRedis(r *redis.Client) {
c.KV = r
}

View File

@@ -1,71 +0,0 @@
// Package config provides the configuration and logging setup for the application.
package config
import (
"io"
"os"
"github.com/rs/zerolog"
)
const timeFormat = "2006-01-02 15:04:05"
// NewConsoleWriter creates a new console writer that splits logs between stdout and stderr.
func NewConsoleWriter() zerolog.LevelWriter {
return &ConsoleLogSplitter{
stdConsole: zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: timeFormat,
NoColor: false,
PartsOrder: []string{zerolog.TimestampFieldName, zerolog.LevelFieldName, zerolog.MessageFieldName},
PartsExclude: []string{},
FieldsExclude: []string{},
},
errConsole: zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: timeFormat,
NoColor: false,
PartsOrder: []string{zerolog.TimestampFieldName, zerolog.LevelFieldName, zerolog.MessageFieldName},
PartsExclude: []string{},
FieldsExclude: []string{},
},
}
}
// ConsoleLogSplitter is a zerolog.LevelWriter that writes to stdout for info/debug logs and stderr for warn/error logs, with console-friendly formatting.
type ConsoleLogSplitter struct {
stdConsole zerolog.ConsoleWriter
errConsole zerolog.ConsoleWriter
}
// Write is a passthrough to the standard console writer and should not be called directly.
func (c *ConsoleLogSplitter) Write(p []byte) (n int, err error) {
return c.stdConsole.Write(p)
}
// WriteLevel writes to the appropriate output (stdout or stderr) with console formatting based on the log level.
func (c *ConsoleLogSplitter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
if level <= zerolog.WarnLevel {
return c.stdConsole.Write(p)
}
return c.errConsole.Write(p)
}
// LogSplitter is a zerolog.LevelWriter that writes to stdout for info/debug logs and stderr for warn/error logs.
type LogSplitter struct {
Std io.Writer
Err io.Writer
}
// Write is a passthrough to the standard writer and should not be called directly.
func (l LogSplitter) Write(p []byte) (n int, err error) {
return l.Std.Write(p)
}
// WriteLevel writes to the appropriate output (stdout or stderr) based on the log level.
func (l LogSplitter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
if level <= zerolog.WarnLevel {
return l.Std.Write(p)
}
return l.Err.Write(p)
}

View File

@@ -1,140 +0,0 @@
package config
import (
"fmt"
"strconv"
"time"
)
// Term selection should yield smart results based on the current time, as well as the input provided.
// Fall 2024, "spring" => Spring 2025
// Fall 2024, "fall" => Fall 2025
// Summer 2024, "fall" => Fall 2024
const (
// Fall is the first term of the school year.
Fall = iota
// Spring is the second term of the school year.
Spring
// Summer is the third term of the school year.
Summer
)
// Term represents a school term, consisting of a year and a season.
type Term struct {
Year uint16
Season uint8
}
// SeasonRanges represents the start and end day of each term within a year.
type SeasonRanges struct {
Spring YearDayRange
Summer YearDayRange
Fall YearDayRange
}
// YearDayRange represents the start and end day of a term within a year.
type YearDayRange struct {
Start uint16
End uint16
}
// GetYearDayRange returns the start and end day of each term for the given year.
// The ranges are inclusive of the start day and exclusive of the end day.
func GetYearDayRange(loc *time.Location, year uint16) SeasonRanges {
springStart := time.Date(int(year), time.January, 14, 0, 0, 0, 0, loc).YearDay()
springEnd := time.Date(int(year), time.May, 1, 0, 0, 0, 0, loc).YearDay()
summerStart := time.Date(int(year), time.May, 25, 0, 0, 0, 0, loc).YearDay()
summerEnd := time.Date(int(year), time.August, 15, 0, 0, 0, 0, loc).YearDay()
fallStart := time.Date(int(year), time.August, 18, 0, 0, 0, 0, loc).YearDay()
fallEnd := time.Date(int(year), time.December, 10, 0, 0, 0, 0, loc).YearDay()
return SeasonRanges{
Spring: YearDayRange{
Start: uint16(springStart),
End: uint16(springEnd),
},
Summer: YearDayRange{
Start: uint16(summerStart),
End: uint16(summerEnd),
},
Fall: YearDayRange{
Start: uint16(fallStart),
End: uint16(fallEnd),
},
}
}
// GetCurrentTerm returns the current and next terms based on the provided time.
// The current term can be nil if the time falls between terms.
// The 'year' in the term corresponds to the academic year, which may differ from the calendar year.
func GetCurrentTerm(ranges SeasonRanges, now time.Time) (*Term, *Term) {
literalYear := uint16(now.Year())
dayOfYear := uint16(now.YearDay())
// If we're past the end of the summer term, we're 'in' the next school year.
var termYear uint16
if dayOfYear > ranges.Summer.End {
termYear = literalYear + 1
} else {
termYear = literalYear
}
if (dayOfYear < ranges.Spring.Start) || (dayOfYear >= ranges.Fall.End) {
// Fall over, Spring not yet begun
return nil, &Term{Year: termYear, Season: Spring}
} else if (dayOfYear >= ranges.Spring.Start) && (dayOfYear < ranges.Spring.End) {
// Spring
return &Term{Year: termYear, Season: Spring}, &Term{Year: termYear, Season: Summer}
} else if dayOfYear < ranges.Summer.Start {
// Spring over, Summer not yet begun
return nil, &Term{Year: termYear, Season: Summer}
} else if (dayOfYear >= ranges.Summer.Start) && (dayOfYear < ranges.Summer.End) {
// Summer
return &Term{Year: termYear, Season: Summer}, &Term{Year: termYear, Season: Fall}
} else if dayOfYear < ranges.Fall.Start {
// Summer over, Fall not yet begun
return nil, &Term{Year: termYear, Season: Fall}
} else if (dayOfYear >= ranges.Fall.Start) && (dayOfYear < ranges.Fall.End) {
// Fall
return &Term{Year: termYear, Season: Fall}, nil
}
panic(fmt.Sprintf("Impossible Code Reached (dayOfYear: %d)", dayOfYear))
}
// ParseTerm converts a Banner term code string to a Term struct.
func ParseTerm(code string) Term {
year, _ := strconv.ParseUint(code[0:4], 10, 16)
var season uint8
termCode := code[4:6]
switch termCode {
case "10":
season = Fall
case "20":
season = Spring
case "30":
season = Summer
}
return Term{
Year: uint16(year),
Season: season,
}
}
// ToString converts a Term struct to a Banner term code string.
func (term Term) ToString() string {
var season string
switch term.Season {
case Fall:
season = "10"
case Spring:
season = "20"
case Summer:
season = "30"
}
return fmt.Sprintf("%d%s", term.Year, season)
}

View File

@@ -1,13 +0,0 @@
package internal
import "fmt"
// UnexpectedContentTypeError is returned when the Content-Type header of a response does not match the expected value.
type UnexpectedContentTypeError struct {
Expected string
Actual string
}
func (e *UnexpectedContentTypeError) Error() string {
return fmt.Sprintf("Expected content type '%s', received '%s'", e.Expected, e.Actual)
}

View File

@@ -1,376 +0,0 @@
package internal
import (
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"os"
"runtime"
"sort"
"strconv"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
log "github.com/rs/zerolog/log"
"resty.dev/v3"
"banner/internal/config"
)
// Options is a map of options from a Discord command.
type Options map[string]*discordgo.ApplicationCommandInteractionDataOption
// GetInt returns the integer value of an option, or 0 if it doesn't exist.
func (o Options) GetInt(key string) int64 {
if opt, ok := o[key]; ok {
return opt.IntValue()
}
return 0
}
// ParseOptions parses slash command options into a map for easier access.
func ParseOptions(options []*discordgo.ApplicationCommandInteractionDataOption) Options {
optionMap := make(Options)
for _, opt := range options {
optionMap[opt.Name] = opt
}
return optionMap
}
// AddUserAgent adds a consistent user agent to the request to mimic a real browser.
func AddUserAgent(req *http.Request) {
req.Header.Add("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36")
}
// ContentTypeMatch checks if a Resty response has the given content type.
func ContentTypeMatch(res *resty.Response, expectedContentType string) bool {
contentType := res.Header().Get("Content-Type")
if contentType == "" {
return expectedContentType == "application/octect-stream"
}
return strings.HasPrefix(contentType, expectedContentType)
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
// RandomString returns a random string of length n.
// The character set is chosen to mimic Ellucian's Banner session ID generation.
func RandomString(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}
// DiscordGoLogger is a helper function that implements discordgo's logging interface, directing all logs to zerolog.
func DiscordGoLogger(msgL, caller int, format string, a ...interface{}) {
pc, file, line, _ := runtime.Caller(caller)
files := strings.Split(file, "/")
file = files[len(files)-1]
name := runtime.FuncForPC(pc).Name()
fns := strings.Split(name, ".")
name = fns[len(fns)-1]
msg := fmt.Sprintf(format, a...)
var event *zerolog.Event
switch msgL {
case 0:
event = log.Debug()
case 1:
event = log.Info()
case 2:
event = log.Warn()
case 3:
event = log.Error()
default:
event = log.Info()
}
event.Str("file", file).Int("line", line).Str("function", name).Msg(msg)
}
// Nonce returns the current time in milliseconds since the Unix epoch as a string.
// This is typically used as a query parameter to prevent request caching.
func Nonce() string {
return strconv.Itoa(int(time.Now().UnixMilli()))
}
// Plural returns "s" if n is not 1.
func Plural(n int) string {
if n == 1 {
return ""
}
return "s"
}
// Plurale returns "es" if n is not 1.
func Plurale(n int) string {
if n == 1 {
return ""
}
return "es"
}
// WeekdaysToString converts a map of weekdays to a compact string representation (e.g., "MWF").
func WeekdaysToString(days map[time.Weekday]bool) string {
// If no days are present
numDays := len(days)
if numDays == 0 {
return "None"
}
// If all days are present
if numDays == 7 {
return "Everyday"
}
str := ""
if days[time.Monday] {
str += "M"
}
if days[time.Tuesday] {
str += "Tu"
}
if days[time.Wednesday] {
str += "W"
}
if days[time.Thursday] {
str += "Th"
}
if days[time.Friday] {
str += "F"
}
if days[time.Saturday] {
str += "Sa"
}
if days[time.Sunday] {
str += "Su"
}
return str
}
// NaiveTime represents a time of day without a date or timezone.
type NaiveTime struct {
Hours uint
Minutes uint
}
// Sub returns the duration between two NaiveTime instances.
func (nt *NaiveTime) Sub(other *NaiveTime) time.Duration {
return time.Hour*time.Duration(nt.Hours-other.Hours) + time.Minute*time.Duration(nt.Minutes-other.Minutes)
}
// ParseNaiveTime converts an integer representation of time (e.g., 1430) to a NaiveTime struct.
func ParseNaiveTime(integer uint64) *NaiveTime {
minutes := uint(integer % 100)
hours := uint(integer / 100)
return &NaiveTime{Hours: hours, Minutes: minutes}
}
// String returns a string representation of the NaiveTime in 12-hour format (e.g., "2:30PM").
func (nt NaiveTime) String() string {
meridiem := "AM"
hour := nt.Hours
if nt.Hours >= 12 {
meridiem = "PM"
if nt.Hours > 12 {
hour -= 12
}
}
return fmt.Sprintf("%d:%02d%s", hour, nt.Minutes, meridiem)
}
// GetFirstEnv returns the value of the first environment variable that is set.
func GetFirstEnv(key ...string) string {
for _, k := range key {
if v := os.Getenv(k); v != "" {
return v
}
}
return ""
}
// GetIntPointer returns a pointer to the given integer.
func GetIntPointer(value int) *int {
return &value
}
// GetFloatPointer returns a pointer to the given float.
func GetFloatPointer(value float64) *float64 {
return &value
}
var extensionMap = map[string]string{
"text/plain": "txt",
"application/json": "json",
"text/html": "html",
"text/css": "css",
"text/csv": "csv",
"text/calendar": "ics",
"text/markdown": "md",
"text/xml": "xml",
"text/yaml": "yaml",
"text/javascript": "js",
"text/vtt": "vtt",
"image/jpeg": "jpg",
"image/png": "png",
"image/gif": "gif",
"image/webp": "webp",
"image/tiff": "tiff",
"image/svg+xml": "svg",
"image/bmp": "bmp",
"image/vnd.microsoft.icon": "ico",
"image/x-icon": "ico",
"image/x-xbitmap": "xbm",
"image/x-xpixmap": "xpm",
"image/x-xwindowdump": "xwd",
"image/avif": "avif",
"image/apng": "apng",
"image/jxl": "jxl",
}
// GuessExtension guesses the file extension for a given content type.
func GuessExtension(contentType string) string {
ext, ok := extensionMap[strings.ToLower(contentType)]
if !ok {
return ""
}
return ext
}
// DumpResponse dumps the body of a Resty response to a file for debugging.
func DumpResponse(res *resty.Response) {
contentType := res.Header().Get("Content-Type")
ext := GuessExtension(contentType)
// Use current time as filename + /dumps/ prefix
filename := fmt.Sprintf("dumps/%d.%s", time.Now().Unix(), ext)
file, err := os.Create(filename)
if err != nil {
log.Err(err).Stack().Msg("Error creating file")
return
}
defer file.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
log.Err(err).Stack().Msg("Error reading response body")
return
}
_, err = file.Write(body)
if err != nil {
log.Err(err).Stack().Msg("Error writing response body")
return
}
log.Info().Str("filename", filename).Str("content-type", contentType).Msg("Dumped response body")
}
// RespondError responds to an interaction with a formatted error message.
func RespondError(session *discordgo.Session, interaction *discordgo.Interaction, message string, err error) error {
// Optional: log the error
if err != nil {
log.Err(err).Stack().Msg(message)
}
return session.InteractionRespond(interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Embeds: []*discordgo.MessageEmbed{
{
Footer: &discordgo.MessageEmbedFooter{
Text: fmt.Sprintf("Occurred at %s", time.Now().Format("Monday, January 2, 2006 at 3:04:05PM")),
},
Description: message,
Color: 0xff0000,
},
},
AllowedMentions: &discordgo.MessageAllowedMentions{},
},
})
}
// GetFetchedFooter returns a standard footer for embeds, indicating when the data was fetched.
func GetFetchedFooter(cfg *config.Config, time time.Time) *discordgo.MessageEmbedFooter {
return &discordgo.MessageEmbedFooter{
Text: fmt.Sprintf("Fetched at %s", time.In(cfg.CentralTimeLocation).Format("Monday, January 2, 2006 at 3:04:05PM")),
}
}
// GetUser returns the user from an interaction, regardless of whether it was in a guild or a DM.
func GetUser(interaction *discordgo.InteractionCreate) *discordgo.User {
// If the interaction is in a guild, the user is in the Member field
if interaction.Member != nil {
return interaction.Member.User
}
// If the interaction is in a DM, the user is in the User field
return interaction.User
}
// EncodeParams encodes a map of parameters into a URL-encoded string, sorted by key.
func EncodeParams(params map[string]*[]string) string {
// Escape hatch for nil
if params == nil {
return ""
}
// Sort the keys
keys := make([]string, 0, len(params))
for k := range params {
keys = append(keys, k)
}
sort.Strings(keys)
var buf strings.Builder
for _, k := range keys {
// Multiple values are allowed, so extract the slice & prepare the key
values := params[k]
keyEscaped := url.QueryEscape(k)
for _, v := range *values {
// If any parameters have been written, add the ampersand
if buf.Len() > 0 {
buf.WriteByte('&')
}
// Write the key and value
buf.WriteString(keyEscaped)
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(v))
}
}
return buf.String()
}
// Point represents a point in 2D space.
type Point struct {
X, Y float64
}
// Slope calculates the y-coordinate of a point on a line given two other points and an x-coordinate.
func Slope(p1 Point, p2 Point, x float64) Point {
slope := (p2.Y - p1.Y) / (p2.X - p1.X)
newY := slope*(x-p1.X) + p1.Y
return Point{X: x, Y: newY}
}

View File

@@ -1,96 +0,0 @@
// Package internal provides shared functionality for the banner application.
package internal
import (
"banner/internal/config"
"context"
"time"
"github.com/bwmarrin/discordgo"
"github.com/redis/go-redis/v9"
log "github.com/rs/zerolog/log"
)
// GetGuildName returns the name of a guild by its ID, using Redis for caching.
func GetGuildName(cfg *config.Config, session *discordgo.Session, guildID string) string {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel()
// Check Redis for the guild name
guildName, err := cfg.KV.Get(ctx, "guild:"+guildID+":name").Result()
if err != nil && err != redis.Nil {
log.Error().Stack().Err(err).Msg("Error getting guild name from Redis")
return "err"
}
// If the guild name is invalid (1 character long), then return "unknown"
if len(guildName) == 1 {
return "unknown"
}
// If the guild name isn't in Redis, get it from Discord and cache it
guild, err := session.Guild(guildID)
if err != nil {
log.Error().Stack().Err(err).Msg("Error getting guild name")
// Store an invalid value in Redis so we don't keep trying to get the guild name
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel2()
_, err := cfg.KV.Set(ctx2, "guild:"+guildID+":name", "x", time.Minute*5).Result()
if err != nil {
log.Error().Stack().Err(err).Msg("Error setting false guild name in Redis")
}
return "unknown"
}
// Cache the guild name in Redis
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel3()
cfg.KV.Set(ctx3, "guild:"+guildID+":name", guild.Name, time.Hour*3)
return guild.Name
}
// GetChannelName returns the name of a channel by its ID, using Redis for caching.
func GetChannelName(cfg *config.Config, session *discordgo.Session, channelID string) string {
// Create a timeout context for Redis operations
ctx, cancel := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel()
// Check Redis for the channel name
channelName, err := cfg.KV.Get(ctx, "channel:"+channelID+":name").Result()
if err != nil && err != redis.Nil {
log.Error().Stack().Err(err).Msg("Error getting channel name from Redis")
return "err"
}
// If the channel name is invalid (1 character long), then return "unknown"
if len(channelName) == 1 {
return "unknown"
}
// If the channel name isn't in Redis, get it from Discord and cache it
channel, err := session.Channel(channelID)
if err != nil {
log.Error().Stack().Err(err).Msg("Error getting channel name")
// Store an invalid value in Redis so we don't keep trying to get the channel name
ctx2, cancel2 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel2()
_, err := cfg.KV.Set(ctx2, "channel:"+channelID+":name", "x", time.Minute*5).Result()
if err != nil {
log.Error().Stack().Err(err).Msg("Error setting false channel name in Redis")
}
return "unknown"
}
// Cache the channel name in Redis
ctx3, cancel3 := context.WithTimeout(cfg.Ctx, 5*time.Second)
defer cancel3()
cfg.KV.Set(ctx3, "channel:"+channelID+":name", channel.Name, time.Hour*3)
return channel.Name
}

View File

@@ -1,323 +0,0 @@
// Package models provides the data structures for the Banner API.
package models
import (
"banner/internal"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
log "github.com/rs/zerolog/log"
)
// FacultyItem represents a faculty member associated with a course.
type FacultyItem struct {
BannerID string `json:"bannerId"`
Category *string `json:"category"`
Class string `json:"class"`
CourseReferenceNumber string `json:"courseReferenceNumber"`
DisplayName string `json:"displayName"`
Email string `json:"emailAddress"`
Primary bool `json:"primaryIndicator"`
Term string `json:"term"`
}
// MeetingTimeResponse represents the meeting time information for a course.
type MeetingTimeResponse struct {
Category *string `json:"category"`
Class string `json:"class"`
CourseReferenceNumber string `json:"courseReferenceNumber"`
Faculty []FacultyItem
MeetingTime struct {
Category string `json:"category"`
// Some sort of metadata used internally by Banner (net.hedtech.banner.student.schedule.SectionSessionDecorator)
Class string `json:"class"`
// The start date of the meeting time in MM/DD/YYYY format (e.g. 01/16/2024)
StartDate string `json:"startDate"`
// The end date of the meeting time in MM/DD/YYYY format (e.g. 05/10/2024)
EndDate string `json:"endDate"`
// The start time of the meeting time in 24-hour format, hours & minutes, digits only (e.g. 1630)
BeginTime string `json:"beginTime"`
// The end time of the meeting time in 24-hour format, hours & minutes, digits only (e.g. 1745)
EndTime string `json:"endTime"`
// The room number within the building this course takes place at (e.g. 3.01.08, 200A)
Room string `json:"room"`
// The internal identifier for the term this course takes place in (e.g. 202420)
Term string `json:"term"`
// The internal identifier for the building this course takes place at (e.g. SP1)
Building string `json:"building"`
// The long name of the building this course takes place at (e.g. San Pedro I - Data Science)
BuildingDescription string `json:"buildingDescription"`
// The internal identifier for the campus this course takes place at (e.g. 1DT)
Campus string `json:"campus"`
// The long name of the campus this course takes place at (e.g. Main Campus, Downtown Campus)
CampusDescription string `json:"campusDescription"`
CourseReferenceNumber string `json:"courseReferenceNumber"`
// The number of credit hours this class is worth (assumably)
CreditHourSession float64 `json:"creditHourSession"`
// The number of hours per week this class meets (e.g. 2.5)
HoursWeek float64 `json:"hoursWeek"`
// Unknown meaning - e.g. AFF, AIN, AHB, FFF, AFF, EFF, DFF, IFF, EHB, JFF, KFF, BFF, BIN
MeetingScheduleType string `json:"meetingScheduleType"`
// The short identifier for the meeting type (e.g. FF, HB, OS, OA)
MeetingType string `json:"meetingType"`
// The long name of the meeting type (e.g. Traditional in-person)
MeetingTypeDescription string `json:"meetingTypeDescription"`
// A boolean indicating if the class will meet on each Monday of the term
Monday bool `json:"monday"`
// A boolean indicating if the class will meet on each Tuesday of the term
Tuesday bool `json:"tuesday"`
// A boolean indicating if the class will meet on each Wednesday of the term
Wednesday bool `json:"wednesday"`
// A boolean indicating if the class will meet on each Thursday of the term
Thursday bool `json:"thursday"`
// A boolean indicating if the class will meet on each Friday of the term
Friday bool `json:"friday"`
// A boolean indicating if the class will meet on each Saturday of the term
Saturday bool `json:"saturday"`
// A boolean indicating if the class will meet on each Sunday of the term
Sunday bool `json:"sunday"`
} `json:"meetingTime"`
Term string `json:"term"`
}
// String returns a formatted string representation of the meeting time.
func (m *MeetingTimeResponse) String() string {
switch m.MeetingTime.MeetingType {
case "HB":
return fmt.Sprintf("%s\nHybrid %s", m.TimeString(), m.PlaceString())
case "H2":
return fmt.Sprintf("%s\nHybrid %s", m.TimeString(), m.PlaceString())
case "H1":
return fmt.Sprintf("%s\nHybrid %s", m.TimeString(), m.PlaceString())
case "OS":
return fmt.Sprintf("%s\nOnline Only", m.TimeString())
case "OA":
return "No Time\nOnline Asynchronous"
case "OH":
return fmt.Sprintf("%s\nOnline Partial", m.TimeString())
case "ID":
return "To Be Arranged"
case "FF":
return fmt.Sprintf("%s\n%s", m.TimeString(), m.PlaceString())
}
// TODO: Add error log
return "Unknown"
}
// TimeString returns a formatted string of the meeting times (e.g., "MWF 1:00PM-2:15PM").
func (m *MeetingTimeResponse) TimeString() string {
startTime := m.StartTime()
endTime := m.EndTime()
if startTime == nil || endTime == nil {
return "???"
}
return fmt.Sprintf("%s %s-%s", internal.WeekdaysToString(m.Days()), m.StartTime().String(), m.EndTime().String())
}
// PlaceString returns a formatted string representing the location of the meeting.
func (m *MeetingTimeResponse) PlaceString() string {
mt := m.MeetingTime
// TODO: Add format case for partial online classes
if mt.Room == "" {
return "Online"
}
return fmt.Sprintf("%s | %s | %s %s", mt.CampusDescription, mt.BuildingDescription, mt.Building, mt.Room)
}
// Days returns a map of weekdays on which the course meets.
func (m *MeetingTimeResponse) Days() map[time.Weekday]bool {
days := map[time.Weekday]bool{}
days[time.Monday] = m.MeetingTime.Monday
days[time.Tuesday] = m.MeetingTime.Tuesday
days[time.Wednesday] = m.MeetingTime.Wednesday
days[time.Thursday] = m.MeetingTime.Thursday
days[time.Friday] = m.MeetingTime.Friday
days[time.Saturday] = m.MeetingTime.Saturday
return days
}
// ByDay returns a comma-separated string of two-letter day abbreviations for the iCalendar RRule.
func (m *MeetingTimeResponse) ByDay() string {
days := []string{}
if m.MeetingTime.Sunday {
days = append(days, "SU")
}
if m.MeetingTime.Monday {
days = append(days, "MO")
}
if m.MeetingTime.Tuesday {
days = append(days, "TU")
}
if m.MeetingTime.Wednesday {
days = append(days, "WE")
}
if m.MeetingTime.Thursday {
days = append(days, "TH")
}
if m.MeetingTime.Friday {
days = append(days, "FR")
}
if m.MeetingTime.Saturday {
days = append(days, "SA")
}
return strings.Join(days, ",")
}
const layout = "01/02/2006"
// StartDay returns the start date of the meeting as a time.Time object.
// This method is not cached and will panic if the date cannot be parsed.
func (m *MeetingTimeResponse) StartDay() time.Time {
t, err := time.Parse(layout, m.MeetingTime.StartDate)
if err != nil {
log.Panic().Stack().Err(err).Str("raw", m.MeetingTime.StartDate).Msg("Cannot parse start date")
}
return t
}
// EndDay returns the end date of the meeting as a time.Time object.
// This method is not cached and will panic if the date cannot be parsed.
func (m *MeetingTimeResponse) EndDay() time.Time {
t, err := time.Parse(layout, m.MeetingTime.EndDate)
if err != nil {
log.Panic().Stack().Err(err).Str("raw", m.MeetingTime.EndDate).Msg("Cannot parse end date")
}
return t
}
// StartTime returns the start time of the meeting as a NaiveTime object.
// This method is not cached and will panic if the time cannot be parsed.
func (m *MeetingTimeResponse) StartTime() *internal.NaiveTime {
raw := m.MeetingTime.BeginTime
if raw == "" {
log.Panic().Stack().Msg("Start time is empty")
}
value, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
log.Panic().Stack().Err(err).Str("raw", raw).Msg("Cannot parse start time integer")
}
return internal.ParseNaiveTime(value)
}
// EndTime returns the end time of the meeting as a NaiveTime object.
// This method is not cached and will panic if the time cannot be parsed.
func (m *MeetingTimeResponse) EndTime() *internal.NaiveTime {
raw := m.MeetingTime.EndTime
if raw == "" {
return nil
}
value, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
log.Panic().Stack().Err(err).Str("raw", raw).Msg("Cannot parse end time integer")
}
return internal.ParseNaiveTime(value)
}
// RRule represents a recurrence rule for an iCalendar event.
type RRule struct {
Until string
ByDay string
}
// RRule converts the meeting time to a struct that satisfies the iCalendar RRule format.
func (m *MeetingTimeResponse) RRule() RRule {
return RRule{
Until: m.EndDay().UTC().Format("20060102T150405Z"),
ByDay: m.ByDay(),
}
}
// SearchResult represents the result of a course search.
type SearchResult struct {
Success bool `json:"success"`
TotalCount int `json:"totalCount"`
PageOffset int `json:"pageOffset"`
PageMaxSize int `json:"pageMaxSize"`
PathMode string `json:"pathMode"`
SearchResultsConfig []struct {
Config string `json:"config"`
Display string `json:"display"`
} `json:"searchResultsConfig"`
Data []Course `json:"data"`
}
// Course represents a single course returned from a search.
type Course struct {
// ID is an internal identifier not used outside of the Banner system.
ID int `json:"id"`
// Term is the internal identifier for the term this class is in (e.g. 202420).
Term string `json:"term"`
// TermDesc is the human-readable name of the term this class is in (e.g. Fall 2021).
TermDesc string `json:"termDesc"`
// CourseReferenceNumber is the unique identifier for a course within a term.
CourseReferenceNumber string `json:"courseReferenceNumber"`
// PartOfTerm specifies which part of the term the course is in (e.g. B6, B5).
PartOfTerm string `json:"partOfTerm"`
// CourseNumber is the 4-digit code for the course (e.g. 3743).
CourseNumber string `json:"courseNumber"`
// Subject is the subject acronym (e.g. CS, AEPI).
Subject string `json:"subject"`
// SubjectDescription is the full name of the course subject.
SubjectDescription string `json:"subjectDescription"`
// SequenceNumber is the course section (e.g. 001, 002).
SequenceNumber string `json:"sequenceNumber"`
CampusDescription string `json:"campusDescription"`
// ScheduleTypeDescription is the type of schedule for the course (e.g. Lecture, Seminar).
ScheduleTypeDescription string `json:"scheduleTypeDescription"`
CourseTitle string `json:"courseTitle"`
CreditHours int `json:"creditHours"`
// MaximumEnrollment is the maximum number of students that can enroll.
MaximumEnrollment int `json:"maximumEnrollment"`
Enrollment int `json:"enrollment"`
SeatsAvailable int `json:"seatsAvailable"`
WaitCapacity int `json:"waitCapacity"`
WaitCount int `json:"waitCount"`
CrossList *string `json:"crossList"`
CrossListCapacity *int `json:"crossListCapacity"`
CrossListCount *int `json:"crossListCount"`
CrossListAvailable *int `json:"crossListAvailable"`
CreditHourHigh *int `json:"creditHourHigh"`
CreditHourLow *int `json:"creditHourLow"`
CreditHourIndicator *string `json:"creditHourIndicator"`
OpenSection bool `json:"openSection"`
LinkIdentifier *string `json:"linkIdentifier"`
IsSectionLinked bool `json:"isSectionLinked"`
// SubjectCourse is the combination of the subject and course number (e.g. CS3443).
SubjectCourse string `json:"subjectCourse"`
ReservedSeatSummary *string `json:"reservedSeatSummary"`
InstructionalMethod string `json:"instructionalMethod"`
InstructionalMethodDescription string `json:"instructionalMethodDescription"`
SectionAttributes []struct {
// Class is an internal API class identifier used by Banner.
Class string `json:"class"`
CourseReferenceNumber string `json:"courseReferenceNumber"`
// Code for the attribute (e.g., UPPR, ZIEP, AIS).
Code string `json:"code"`
Description string `json:"description"`
TermCode string `json:"termCode"`
IsZtcAttribute bool `json:"isZTCAttribute"`
} `json:"sectionAttributes"`
Faculty []FacultyItem `json:"faculty"`
MeetingsFaculty []MeetingTimeResponse `json:"meetingsFaculty"`
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (course Course) MarshalBinary() ([]byte, error) {
return json.Marshal(course)
}

View File

@@ -0,0 +1,56 @@
-- Drop all old tables
DROP TABLE IF EXISTS scrape_jobs;
DROP TABLE IF EXISTS course_metrics;
DROP TABLE IF EXISTS course_audits;
DROP TABLE IF EXISTS courses;
-- Enums for scrape_jobs
CREATE TYPE scrape_priority AS ENUM ('Low', 'Medium', 'High', 'Critical');
CREATE TYPE target_type AS ENUM ('Subject', 'CourseRange', 'CrnList', 'SingleCrn');
-- Main course data table
CREATE TABLE courses (
id SERIAL PRIMARY KEY,
crn VARCHAR NOT NULL,
subject VARCHAR NOT NULL,
course_number VARCHAR NOT NULL,
title VARCHAR NOT NULL,
term_code VARCHAR NOT NULL,
enrollment INTEGER NOT NULL,
max_enrollment INTEGER NOT NULL,
wait_count INTEGER NOT NULL,
wait_capacity INTEGER NOT NULL,
last_scraped_at TIMESTAMPTZ NOT NULL,
UNIQUE(crn, term_code)
);
-- Time-series data for course enrollment
CREATE TABLE course_metrics (
id SERIAL PRIMARY KEY,
course_id INTEGER NOT NULL REFERENCES courses(id) ON DELETE CASCADE,
timestamp TIMESTAMPTZ NOT NULL,
enrollment INTEGER NOT NULL,
wait_count INTEGER NOT NULL,
seats_available INTEGER NOT NULL
);
-- Audit trail for changes to course data
CREATE TABLE course_audits (
id SERIAL PRIMARY KEY,
course_id INTEGER NOT NULL REFERENCES courses(id) ON DELETE CASCADE,
timestamp TIMESTAMPTZ NOT NULL,
field_changed VARCHAR NOT NULL,
old_value TEXT NOT NULL,
new_value TEXT NOT NULL
);
-- Job queue for the scraper
CREATE TABLE scrape_jobs (
id SERIAL PRIMARY KEY,
target_type target_type NOT NULL,
target_payload JSONB NOT NULL,
priority scrape_priority NOT NULL,
execute_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
locked_at TIMESTAMPTZ
);

384
src/banner/api.rs Normal file
View File

@@ -0,0 +1,384 @@
//! Main Banner API client implementation.
use std::{
collections::{HashMap, VecDeque},
sync::{Arc, Mutex},
time::Instant,
};
use crate::banner::{
BannerSession, SessionPool, create_shared_rate_limiter_with_config,
errors::BannerApiError,
json::parse_json_with_context,
middleware::TransparentMiddleware,
models::*,
nonce,
query::SearchQuery,
rate_limit_middleware::RateLimitMiddleware,
rate_limiter::{RateLimitConfig, SharedRateLimiter, create_shared_rate_limiter},
util::user_agent,
};
use anyhow::{Context, Result, anyhow};
use cookie::Cookie;
use dashmap::DashMap;
use http::HeaderValue;
use reqwest::{Client, Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use serde_json;
use tl;
use tracing::{Level, Metadata, Span, debug, error, field::ValueSet, info, span, trace, warn};
/// Main Banner API client.
pub struct BannerApi {
pub sessions: SessionPool,
http: ClientWithMiddleware,
base_url: String,
}
impl BannerApi {
/// Creates a new Banner API client.
pub fn new(base_url: String) -> Result<Self> {
Self::new_with_config(base_url, RateLimitConfig::default())
}
/// Creates a new Banner API client with custom rate limiting configuration.
pub fn new_with_config(base_url: String, rate_limit_config: RateLimitConfig) -> Result<Self> {
let rate_limiter = create_shared_rate_limiter_with_config(rate_limit_config);
let http = ClientBuilder::new(
Client::builder()
.cookie_store(false)
.user_agent(user_agent())
.tcp_keepalive(Some(std::time::Duration::from_secs(60 * 5)))
.read_timeout(std::time::Duration::from_secs(10))
.connect_timeout(std::time::Duration::from_secs(10))
.timeout(std::time::Duration::from_secs(30))
.build()
.context("Failed to create HTTP client")?,
)
.with(TransparentMiddleware)
.with(RateLimitMiddleware::new(rate_limiter.clone()))
.build();
Ok(Self {
sessions: SessionPool::new(http.clone(), base_url.clone()),
http,
base_url,
})
}
/// Validates offset parameter for search methods.
fn validate_offset(offset: i32) -> Result<()> {
if offset <= 0 {
Err(anyhow::anyhow!("Offset must be greater than 0"))
} else {
Ok(())
}
}
/// Builds common search parameters for list endpoints.
fn build_list_params(
&self,
search: &str,
term: &str,
offset: i32,
max_results: i32,
session_id: &str,
) -> Vec<(&str, String)> {
vec![
("searchTerm", search.to_string()),
("term", term.to_string()),
("offset", offset.to_string()),
("max", max_results.to_string()),
("uniqueSessionId", session_id.to_string()),
("_", nonce()),
]
}
/// Makes a GET request to a list endpoint and parses JSON response.
async fn get_list_endpoint<T>(
&self,
endpoint: &str,
search: &str,
term: &str,
offset: i32,
max_results: i32,
) -> Result<Vec<T>>
where
T: for<'de> serde::Deserialize<'de>,
{
Self::validate_offset(offset)?;
let session = self.sessions.acquire(term.parse()?).await?;
let url = format!("{}/classSearch/{}", self.base_url, endpoint);
let params = self.build_list_params(search, term, offset, max_results, &session.id());
let response = self
.http
.get(&url)
.query(&params)
.send()
.await
.with_context(|| format!("Failed to get {}", endpoint))?;
let data: Vec<T> = response
.json()
.await
.with_context(|| format!("Failed to parse {} response", endpoint))?;
Ok(data)
}
/// Builds search parameters for course search methods.
fn build_search_params(
&self,
query: &SearchQuery,
term: &str,
session_id: &str,
sort: &str,
sort_descending: bool,
) -> HashMap<String, String> {
let mut params = query.to_params();
params.insert("txt_term".to_string(), term.to_string());
params.insert("uniqueSessionId".to_string(), session_id.to_string());
params.insert("sortColumn".to_string(), sort.to_string());
params.insert(
"sortDirection".to_string(),
if sort_descending { "desc" } else { "asc" }.to_string(),
);
params.insert("startDatepicker".to_string(), String::new());
params.insert("endDatepicker".to_string(), String::new());
params
}
/// Performs a course search and handles common response processing.
async fn perform_search(
&self,
term: &str,
query: &SearchQuery,
sort: &str,
sort_descending: bool,
) -> Result<SearchResult, BannerApiError> {
let mut session = self.sessions.acquire(term.parse()?).await?;
if session.been_used() {
self.http
.post(format!("{}/classSearch/resetDataForm", self.base_url))
.header("Cookie", session.cookie())
.send()
.await
.map_err(|e| BannerApiError::RequestFailed(e.into()))?;
}
session.touch();
let params = self.build_search_params(query, term, &session.id(), sort, sort_descending);
debug!(
term = term,
subject = query.get_subject().map(|s| s.as_str()).unwrap_or("all"),
max_results = query.get_max_results(),
"Searching for courses"
);
let response = self
.http
.get(format!("{}/searchResults/searchResults", self.base_url))
.header("Cookie", session.cookie())
.query(&params)
.send()
.await
.context("Failed to search courses")?;
let status = response.status();
let url = response.url().clone();
let body = response
.text()
.await
.with_context(|| format!("Failed to read body (status={status})"))?;
let search_result: SearchResult = parse_json_with_context(&body).map_err(|e| {
BannerApiError::RequestFailed(anyhow!(
"Failed to parse search response (status={status}, url={url}): {e}"
))
})?;
// Check for signs of an invalid session
if search_result.path_mode.is_none() {
return Err(BannerApiError::InvalidSession(
"Search result path mode is none".to_string(),
));
} else if search_result.data.is_none() {
return Err(BannerApiError::InvalidSession(
"Search result data is none".to_string(),
));
}
if !search_result.success {
return Err(BannerApiError::RequestFailed(anyhow!(
"Search marked as unsuccessful by Banner API"
)));
}
Ok(search_result)
}
/// Retrieves a list of subjects from the Banner API.
pub async fn get_subjects(
&self,
search: &str,
term: &str,
offset: i32,
max_results: i32,
) -> Result<Vec<Pair>> {
self.get_list_endpoint("get_subject", search, term, offset, max_results)
.await
}
/// Retrieves a list of instructors from the Banner API.
pub async fn get_instructors(
&self,
search: &str,
term: &str,
offset: i32,
max_results: i32,
) -> Result<Vec<Instructor>> {
self.get_list_endpoint("get_instructor", search, term, offset, max_results)
.await
}
/// Retrieves a list of campuses from the Banner API.
pub async fn get_campuses(
&self,
search: &str,
term: &str,
offset: i32,
max_results: i32,
) -> Result<Vec<Pair>> {
self.get_list_endpoint("get_campus", search, term, offset, max_results)
.await
}
/// Retrieves meeting time information for a course.
pub async fn get_course_meeting_time(
&self,
term: &str,
crn: &str,
) -> Result<Vec<MeetingScheduleInfo>> {
let url = format!("{}/searchResults/getFacultyMeetingTimes", self.base_url);
let params = [("term", term), ("courseReferenceNumber", crn)];
let response = self
.http
.get(&url)
.query(&params)
.send()
.await
.context("Failed to get meeting times")?;
if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to get meeting times: {}",
response.status()
));
} else if !response
.headers()
.get("Content-Type")
.unwrap_or(&HeaderValue::from_static(""))
.to_str()
.unwrap_or("")
.starts_with("application/json")
{
return Err(anyhow::anyhow!(
"Unexpected content type: {:?}",
response
.headers()
.get("Content-Type")
.unwrap_or(&HeaderValue::from_static("(empty)"))
.to_str()
.unwrap_or("(non-ascii)")
));
}
let response: MeetingTimesApiResponse =
response.json().await.context("Failed to parse response")?;
Ok(response
.fmt
.into_iter()
.map(|m| m.schedule_info())
.collect())
}
/// Performs a search for courses.
pub async fn search(
&self,
term: &str,
query: &SearchQuery,
sort: &str,
sort_descending: bool,
) -> Result<SearchResult, BannerApiError> {
debug!(
term = term,
subject = query.get_subject().map(|s| s.as_str()).unwrap_or("all"),
max_results = query.get_max_results(),
"Starting course search"
);
self.perform_search(term, query, sort, sort_descending)
.await
}
/// Retrieves a single course by CRN by issuing a minimal search
pub async fn get_course_by_crn(
&self,
term: &str,
crn: &str,
) -> Result<Option<Course>, BannerApiError> {
debug!(term = term, crn = crn, "Looking up course by CRN");
let query = SearchQuery::new()
.course_reference_number(crn)
.max_results(1);
let search_result = self
.perform_search(term, &query, "subjectDescription", false)
.await?;
// Additional validation for CRN search
if search_result.path_mode == Some("registration".to_string())
&& search_result.data.is_none()
{
return Err(BannerApiError::InvalidSession(
"Search result path mode is registration and data is none".to_string(),
));
}
Ok(search_result
.data
.and_then(|courses| courses.into_iter().next()))
}
/// Gets course details (placeholder - needs implementation).
pub async fn get_course_details(&self, term: &str, crn: &str) -> Result<ClassDetails> {
let body = serde_json::json!({
"term": term,
"courseReferenceNumber": crn,
"first": "first"
});
let url = format!("{}/searchResults/getClassDetails", self.base_url);
let response = self
.http
.post(&url)
.json(&body)
.send()
.await
.context("Failed to get course details")?;
let details: ClassDetails = response
.json()
.await
.context("Failed to parse course details response")?;
Ok(details)
}
}

11
src/banner/errors.rs Normal file
View File

@@ -0,0 +1,11 @@
//! Error types for the Banner API client.
use thiserror::Error;
#[derive(Debug, thiserror::Error)]
pub enum BannerApiError {
#[error("Banner session is invalid or expired: {0}")]
InvalidSession(String),
#[error(transparent)]
RequestFailed(#[from] anyhow::Error),
}

37
src/banner/json.rs Normal file
View File

@@ -0,0 +1,37 @@
//! JSON parsing utilities for the Banner API client.
use anyhow::Result;
/// Attempt to parse JSON and, on failure, include a contextual snippet of the
/// line where the error occurred. This prevents dumping huge JSON bodies to logs.
pub fn parse_json_with_context<T: serde::de::DeserializeOwned>(body: &str) -> Result<T> {
match serde_json::from_str::<T>(body) {
Ok(value) => Ok(value),
Err(err) => {
let (line, column) = (err.line(), err.column());
// let snippet = build_error_snippet(body, line, column, 80);
Err(anyhow::anyhow!("{err} at line {line}, column {column}",))
}
}
}
fn build_error_snippet(body: &str, line: usize, column: usize, context_len: usize) -> String {
let target_line = body.lines().nth(line.saturating_sub(1)).unwrap_or("");
if target_line.is_empty() {
return "(empty line)".to_string();
}
// column is 1-based, convert to 0-based for slicing
let error_idx = column.saturating_sub(1);
let half_len = context_len / 2;
let start = error_idx.saturating_sub(half_len);
let end = (error_idx + half_len).min(target_line.len());
let slice = &target_line[start..end];
let indicator_pos = error_idx - start;
let indicator = " ".repeat(indicator_pos) + "^";
format!("...{slice}...\n {indicator}")
}

49
src/banner/middleware.rs Normal file
View File

@@ -0,0 +1,49 @@
//! HTTP middleware for the Banner API client.
use http::Extensions;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next};
use tracing::{trace, warn};
pub struct TransparentMiddleware;
#[async_trait::async_trait]
impl Middleware for TransparentMiddleware {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> std::result::Result<Response, reqwest_middleware::Error> {
trace!(
domain = req.url().domain(),
headers = ?req.headers(),
"{method} {path}",
method = req.method().to_string(),
path = req.url().path(),
);
let response_result = next.run(req, extensions).await;
match response_result {
Ok(response) => {
if response.status().is_success() {
trace!(
"{code} {reason} {path}",
code = response.status().as_u16(),
reason = response.status().canonical_reason().unwrap_or("??"),
path = response.url().path(),
);
Ok(response)
} else {
let e = response.error_for_status_ref().unwrap_err();
warn!(error = ?e, "Request failed (server)");
Ok(response)
}
}
Err(error) => {
warn!(error = ?error, "Request failed (middleware)");
Err(error)
}
}
}
}

27
src/banner/mod.rs Normal file
View File

@@ -0,0 +1,27 @@
#![allow(unused_imports)]
//! Banner API module for interacting with Ellucian Banner systems.
//!
//! This module provides functionality to:
//! - Search for courses and retrieve course information
//! - Manage Banner API sessions and authentication
//! - Scrape course data and cache it in Redis
//! - Generate ICS files and calendar links
pub mod api;
pub mod errors;
pub mod json;
pub mod middleware;
pub mod models;
pub mod query;
pub mod rate_limiter;
pub mod rate_limit_middleware;
pub mod session;
pub mod util;
pub use api::*;
pub use errors::*;
pub use models::*;
pub use query::*;
pub use rate_limiter::*;
pub use session::*;

View File

@@ -0,0 +1,21 @@
use serde::{Deserialize, Serialize};
/// Represents a key-value pair from the Banner API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pair {
pub code: String,
pub description: String,
}
/// Represents a term in the Banner system
pub type BannerTerm = Pair;
/// Represents an instructor in the Banner system
pub type Instructor = Pair;
impl BannerTerm {
/// Returns true if the term is in an archival (view-only) state
pub fn is_archived(&self) -> bool {
self.description.contains("View Only")
}
}

View File

@@ -0,0 +1,84 @@
use serde::{Deserialize, Serialize};
use super::meetings::FacultyItem;
use super::meetings::MeetingTimeResponse;
/// Course section attribute
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SectionAttribute {
pub class: String,
pub course_reference_number: String,
pub code: String,
pub description: String,
pub term_code: String,
#[serde(rename = "isZTCAttribute")]
pub is_ztc_attribute: bool,
}
/// Represents a single course returned from a search
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Course {
pub id: i32,
pub term: String,
pub term_desc: String,
pub course_reference_number: String,
pub part_of_term: String,
pub course_number: String,
pub subject: String,
pub subject_description: String,
pub sequence_number: String,
pub campus_description: String,
pub schedule_type_description: String,
pub course_title: String,
pub credit_hours: Option<i32>,
pub maximum_enrollment: i32,
pub enrollment: i32,
pub seats_available: i32,
pub wait_capacity: i32,
pub wait_count: i32,
pub cross_list: Option<String>,
pub cross_list_capacity: Option<i32>,
pub cross_list_count: Option<i32>,
pub cross_list_available: Option<i32>,
pub credit_hour_high: Option<i32>,
pub credit_hour_low: Option<i32>,
pub credit_hour_indicator: Option<String>,
pub open_section: bool,
pub link_identifier: Option<String>,
pub is_section_linked: bool,
pub subject_course: String,
pub reserved_seat_summary: Option<String>,
pub instructional_method: String,
pub instructional_method_description: String,
pub section_attributes: Vec<SectionAttribute>,
#[serde(default)]
pub faculty: Vec<FacultyItem>,
#[serde(default)]
pub meetings_faculty: Vec<MeetingTimeResponse>,
}
impl Course {
/// Returns the course title in the format "SUBJ #### - Course Title"
pub fn display_title(&self) -> String {
format!(
"{} {} - {}",
self.subject, self.course_number, self.course_title
)
}
/// Returns the name of the primary instructor, or "Unknown" if not available
pub fn primary_instructor_name(&self) -> &str {
self.faculty
.first()
.map(|f| f.display_name.as_str())
.unwrap_or("Unknown")
}
}
/// Class details (to be implemented)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassDetails {
// TODO: Implement based on Banner API response
}

View File

@@ -0,0 +1,570 @@
use bitflags::{Flags, bitflags};
use chrono::{DateTime, NaiveDate, NaiveTime, Timelike, Utc};
use serde::{Deserialize, Deserializer, Serialize};
use std::{cmp::Ordering, collections::HashSet, fmt::Display, str::FromStr};
use super::terms::Term;
/// Deserialize a string field into a u32
fn deserialize_string_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
s.parse::<u32>().map_err(serde::de::Error::custom)
}
/// Deserialize a string field into a Term
fn deserialize_string_to_term<'de, D>(deserializer: D) -> Result<Term, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
Term::from_str(&s).map_err(serde::de::Error::custom)
}
/// Represents a faculty member associated with a course
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FacultyItem {
pub banner_id: String, // e.g "@01647907" (can contain @ symbol)
pub category: Option<String>, // zero-padded digits
pub class: String, // internal class name
#[serde(deserialize_with = "deserialize_string_to_u32")]
pub course_reference_number: u32, // CRN, e.g 27294
pub display_name: String, // "LastName, FirstName"
pub email_address: String, // e.g. FirstName.LastName@utsaedu
pub primary_indicator: bool,
pub term: String, // e.g "202420"
}
/// Meeting time information for a course
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MeetingTime {
pub start_date: String, // MM/DD/YYYY, e.g 08/26/2025
pub end_date: String, // MM/DD/YYYY, e.g 08/26/2025
pub begin_time: Option<String>, // HHMM, e.g 1000
pub end_time: Option<String>, // HHMM, e.g 1100
pub category: String, // unknown meaning, e.g. 01, 02, etc
pub class: String, // internal class name, e.g. net.hedtech.banner.general.overallMeetingTimeDecorator
pub monday: bool, // true if the meeting time occurs on Monday
pub tuesday: bool, // true if the meeting time occurs on Tuesday
pub wednesday: bool, // true if the meeting time occurs on Wednesday
pub thursday: bool, // true if the meeting time occurs on Thursday
pub friday: bool, // true if the meeting time occurs on Friday
pub saturday: bool, // true if the meeting time occurs on Saturday
pub sunday: bool, // true if the meeting time occurs on Sunday
pub room: Option<String>, // e.g. 1.238
#[serde(deserialize_with = "deserialize_string_to_term")]
pub term: Term, // e.g 202510
pub building: Option<String>, // e.g NPB
pub building_description: Option<String>, // e.g North Paseo Building
pub campus: Option<String>, // campus code, e.g 11
pub campus_description: Option<String>, // name of campus, e.g Main Campus
pub course_reference_number: String, // CRN, e.g 27294
pub credit_hour_session: f64, // e.g. 30
pub hours_week: f64, // e.g. 30
pub meeting_schedule_type: String, // e.g AFF
pub meeting_type: String, // e.g HB, H2, H1, OS, OA, OH, ID, FF
pub meeting_type_description: String,
}
bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MeetingDays: u8 {
const Monday = 1 << 0;
const Tuesday = 1 << 1;
const Wednesday = 1 << 2;
const Thursday = 1 << 3;
const Friday = 1 << 4;
const Saturday = 1 << 5;
const Sunday = 1 << 6;
}
}
impl MeetingDays {
/// Convert from the boolean flags in the raw API response
pub fn from_meeting_time(meeting_time: &MeetingTime) -> MeetingDays {
let mut days = MeetingDays::empty();
if meeting_time.monday {
days.insert(MeetingDays::Monday);
}
if meeting_time.tuesday {
days.insert(MeetingDays::Tuesday);
}
if meeting_time.wednesday {
days.insert(MeetingDays::Wednesday);
}
if meeting_time.thursday {
days.insert(MeetingDays::Thursday);
}
if meeting_time.friday {
days.insert(MeetingDays::Friday);
}
if meeting_time.saturday {
days.insert(MeetingDays::Saturday);
}
if meeting_time.sunday {
days.insert(MeetingDays::Sunday);
}
days
}
}
impl PartialOrd for MeetingDays {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.bits().cmp(&other.bits()))
}
}
impl From<DayOfWeek> for MeetingDays {
fn from(day: DayOfWeek) -> Self {
match day {
DayOfWeek::Monday => MeetingDays::Monday,
DayOfWeek::Tuesday => MeetingDays::Tuesday,
DayOfWeek::Wednesday => MeetingDays::Wednesday,
DayOfWeek::Thursday => MeetingDays::Thursday,
DayOfWeek::Friday => MeetingDays::Friday,
DayOfWeek::Saturday => MeetingDays::Saturday,
DayOfWeek::Sunday => MeetingDays::Sunday,
}
}
}
/// Days of the week for meeting schedules
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DayOfWeek {
Monday,
Tuesday,
Wednesday,
Thursday,
Friday,
Saturday,
Sunday,
}
impl DayOfWeek {
/// Convert to short string representation
pub fn to_short_string(self) -> &'static str {
match self {
DayOfWeek::Monday => "Mo",
DayOfWeek::Tuesday => "Tu",
DayOfWeek::Wednesday => "We",
DayOfWeek::Thursday => "Th",
DayOfWeek::Friday => "Fr",
DayOfWeek::Saturday => "Sa",
DayOfWeek::Sunday => "Su",
}
}
/// Convert to full string representation
pub fn to_full_string(self) -> &'static str {
match self {
DayOfWeek::Monday => "Monday",
DayOfWeek::Tuesday => "Tuesday",
DayOfWeek::Wednesday => "Wednesday",
DayOfWeek::Thursday => "Thursday",
DayOfWeek::Friday => "Friday",
DayOfWeek::Saturday => "Saturday",
DayOfWeek::Sunday => "Sunday",
}
}
}
impl TryFrom<MeetingDays> for DayOfWeek {
type Error = anyhow::Error;
fn try_from(days: MeetingDays) -> Result<Self, Self::Error> {
if days.contains_unknown_bits() {
return Err(anyhow::anyhow!("Unknown days: {:?}", days));
}
let count = days.into_iter().count();
if count == 1 {
return Ok(match days {
MeetingDays::Monday => DayOfWeek::Monday,
MeetingDays::Tuesday => DayOfWeek::Tuesday,
MeetingDays::Wednesday => DayOfWeek::Wednesday,
MeetingDays::Thursday => DayOfWeek::Thursday,
MeetingDays::Friday => DayOfWeek::Friday,
MeetingDays::Saturday => DayOfWeek::Saturday,
MeetingDays::Sunday => DayOfWeek::Sunday,
_ => unreachable!(),
});
}
Err(anyhow::anyhow!(
"Cannot convert multiple days to a single day: {days:?}"
))
}
}
/// Time range for meetings
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TimeRange {
pub start: NaiveTime,
pub end: NaiveTime,
}
impl TimeRange {
/// Parse time range from HHMM format strings
pub fn from_hhmm(start: &str, end: &str) -> Option<Self> {
let start_time = Self::parse_hhmm(start)?;
let end_time = Self::parse_hhmm(end)?;
Some(TimeRange {
start: start_time,
end: end_time,
})
}
/// Parse HHMM format string to NaiveTime
fn parse_hhmm(time_str: &str) -> Option<NaiveTime> {
if time_str.len() != 4 {
return None;
}
let hours = time_str[..2].parse::<u32>().ok()?;
let minutes = time_str[2..].parse::<u32>().ok()?;
if hours > 23 || minutes > 59 {
return None;
}
NaiveTime::from_hms_opt(hours, minutes, 0)
}
/// Format time in 12-hour format
pub fn format_12hr(&self) -> String {
format!(
"{}-{}",
Self::format_time_12hr(self.start),
Self::format_time_12hr(self.end)
)
}
/// Format a single time in 12-hour format
fn format_time_12hr(time: NaiveTime) -> String {
let hour = time.hour();
let minute = time.minute();
let meridiem = if hour < 12 { "AM" } else { "PM" };
format!("{hour}:{minute:02}{meridiem}")
}
/// Get duration in minutes
pub fn duration_minutes(&self) -> i64 {
let start_minutes = self.start.hour() as i64 * 60 + self.start.minute() as i64;
let end_minutes = self.end.hour() as i64 * 60 + self.end.minute() as i64;
end_minutes - start_minutes
}
}
impl PartialOrd for TimeRange {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.start.cmp(&other.start))
}
}
/// Date range for meetings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DateRange {
pub start: NaiveDate,
pub end: NaiveDate,
}
impl DateRange {
/// Parse date range from MM/DD/YYYY format strings
pub fn from_mm_dd_yyyy(start: &str, end: &str) -> Option<Self> {
let start_date = Self::parse_mm_dd_yyyy(start)?;
let end_date = Self::parse_mm_dd_yyyy(end)?;
Some(DateRange {
start: start_date,
end: end_date,
})
}
/// Parse MM/DD/YYYY format string to NaiveDate
fn parse_mm_dd_yyyy(date_str: &str) -> Option<NaiveDate> {
NaiveDate::parse_from_str(date_str, "%m/%d/%Y").ok()
}
/// Get the number of weeks between start and end dates
pub fn weeks_duration(&self) -> u32 {
let duration = self.end.signed_duration_since(self.start);
duration.num_weeks() as u32
}
/// Check if a specific date falls within this range
pub fn contains_date(&self, date: NaiveDate) -> bool {
date >= self.start && date <= self.end
}
}
/// Meeting schedule type enum
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MeetingType {
HybridBlended, // HB, H2, H1
OnlineSynchronous, // OS
OnlineAsynchronous, // OA
OnlineHybrid, // OH
IndependentStudy, // ID
FaceToFace, // FF
Unknown(String),
}
impl MeetingType {
/// Parse from the meeting type string
pub fn from_string(s: &str) -> Self {
match s {
"HB" | "H2" | "H1" => MeetingType::HybridBlended,
"OS" => MeetingType::OnlineSynchronous,
"OA" => MeetingType::OnlineAsynchronous,
"OH" => MeetingType::OnlineHybrid,
"ID" => MeetingType::IndependentStudy,
"FF" => MeetingType::FaceToFace,
other => MeetingType::Unknown(other.to_string()),
}
}
/// Get description for the meeting type
pub fn description(&self) -> &'static str {
match self {
MeetingType::HybridBlended => "Hybrid",
MeetingType::OnlineSynchronous => "Online Only",
MeetingType::OnlineAsynchronous => "Online Asynchronous",
MeetingType::OnlineHybrid => "Online Partial",
MeetingType::IndependentStudy => "To Be Arranged",
MeetingType::FaceToFace => "Face to Face",
MeetingType::Unknown(_) => "Unknown",
}
}
}
/// Meeting location information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MeetingLocation {
Online,
InPerson {
campus: String,
campus_description: String,
building: String,
building_description: String,
room: String,
},
}
impl MeetingLocation {
/// Create from raw MeetingTime data
pub fn from_meeting_time(meeting_time: &MeetingTime) -> Self {
if meeting_time.campus.is_none()
|| meeting_time.building.is_none()
|| meeting_time.building_description.is_none()
|| meeting_time.room.is_none()
|| meeting_time.campus_description.is_none()
|| meeting_time
.campus_description
.eq(&Some("Internet".to_string()))
{
return MeetingLocation::Online;
}
MeetingLocation::InPerson {
campus: meeting_time.campus.as_ref().unwrap().clone(),
campus_description: meeting_time.campus_description.as_ref().unwrap().clone(),
building: meeting_time.building.as_ref().unwrap().clone(),
building_description: meeting_time.building_description.as_ref().unwrap().clone(),
room: meeting_time.room.as_ref().unwrap().clone(),
}
}
}
impl Display for MeetingLocation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MeetingLocation::Online => write!(f, "Online"),
MeetingLocation::InPerson {
campus,
building,
building_description,
room,
..
} => write!(
f,
"{campus} | {building_name} | {building_code} {room}",
building_name = building_description,
building_code = building,
),
}
}
}
/// Clean, parsed meeting schedule information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeetingScheduleInfo {
pub days: MeetingDays,
pub time_range: Option<TimeRange>,
pub date_range: DateRange,
pub meeting_type: MeetingType,
pub location: MeetingLocation,
pub duration_weeks: u32,
}
impl MeetingScheduleInfo {
/// Create from raw MeetingTime data
pub fn from_meeting_time(meeting_time: &MeetingTime) -> Self {
let days = MeetingDays::from_meeting_time(meeting_time);
let time_range = match (&meeting_time.begin_time, &meeting_time.end_time) {
(Some(begin), Some(end)) => TimeRange::from_hhmm(begin, end),
_ => None,
};
let date_range =
DateRange::from_mm_dd_yyyy(&meeting_time.start_date, &meeting_time.end_date)
.unwrap_or_else(|| {
// Fallback to current date if parsing fails
let now = chrono::Utc::now().naive_utc().date();
DateRange {
start: now,
end: now,
}
});
let meeting_type = MeetingType::from_string(&meeting_time.meeting_type);
let location = MeetingLocation::from_meeting_time(meeting_time);
let duration_weeks = date_range.weeks_duration();
MeetingScheduleInfo {
days,
time_range,
date_range,
meeting_type,
location,
duration_weeks,
}
}
/// Convert the meeting days bitset to a enum vector
pub fn days_of_week(&self) -> Vec<DayOfWeek> {
self.days
.iter()
.map(|day| <MeetingDays as TryInto<DayOfWeek>>::try_into(day).unwrap())
.collect()
}
/// Get formatted days string
pub fn days_string(&self) -> Option<String> {
if self.days.is_empty() {
return None;
}
if self.days.is_all() {
return Some("Everyday".to_string());
}
let days_of_week = self.days_of_week();
if days_of_week.len() == 1 {
return Some(days_of_week[0].to_full_string().to_string());
}
// Mapper function to get the short string representation of the day of week
let mapper = {
let ambiguous = self.days.intersects(
MeetingDays::Tuesday
| MeetingDays::Thursday
| MeetingDays::Saturday
| MeetingDays::Sunday,
);
if ambiguous {
|day: &DayOfWeek| day.to_short_string().to_string()
} else {
|day: &DayOfWeek| day.to_short_string().chars().next().unwrap().to_string()
}
};
Some(days_of_week.iter().map(mapper).collect::<String>())
}
/// Returns a formatted string representing the location of the meeting
pub fn place_string(&self) -> String {
match &self.location {
MeetingLocation::Online => "Online".to_string(),
MeetingLocation::InPerson {
campus,
building,
building_description,
room,
..
} => format!(
"{} | {} | {} {}",
campus, building_description, building, room
),
}
}
/// Get the start and end date times for the meeting
///
/// Uses the start and end times of the meeting if available, otherwise defaults to midnight (00:00:00.000).
///
/// The returned times are in UTC.
pub fn datetime_range(&self) -> (DateTime<Utc>, DateTime<Utc>) {
let (start, end) = if let Some(time_range) = &self.time_range {
let start = self.date_range.start.and_time(time_range.start);
let end = self.date_range.end.and_time(time_range.end);
(start, end)
} else {
(
self.date_range.start.and_hms_opt(0, 0, 0).unwrap(),
self.date_range.end.and_hms_opt(0, 0, 0).unwrap(),
)
};
(start.and_utc(), end.and_utc())
}
}
impl PartialEq for MeetingScheduleInfo {
fn eq(&self, other: &Self) -> bool {
self.days == other.days && self.time_range == other.time_range
}
}
impl PartialOrd for MeetingScheduleInfo {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (&self.time_range, &other.time_range) {
(Some(self_time), Some(other_time)) => self_time.partial_cmp(other_time),
(None, None) => Some(self.days.partial_cmp(&other.days).unwrap()),
(Some(_), None) => Some(Ordering::Less),
(None, Some(_)) => Some(Ordering::Greater),
}
}
}
/// API response wrapper for meeting times
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeetingTimesApiResponse {
pub fmt: Vec<MeetingTimeResponse>,
}
/// Meeting time response wrapper
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MeetingTimeResponse {
pub category: Option<String>,
pub class: String,
pub course_reference_number: String,
#[serde(default)]
pub faculty: Vec<FacultyItem>,
pub meeting_time: MeetingTime,
pub term: String,
}
impl MeetingTimeResponse {
/// Get parsed meeting schedule information
pub fn schedule_info(&self) -> MeetingScheduleInfo {
MeetingScheduleInfo::from_meeting_time(&self.meeting_time)
}
}

14
src/banner/models/mod.rs Normal file
View File

@@ -0,0 +1,14 @@
//! Data models for the Banner API.
pub mod common;
pub mod courses;
pub mod meetings;
pub mod search;
pub mod terms;
// Re-export commonly used types
pub use common::*;
pub use courses::*;
pub use meetings::*;
pub use search::*;
pub use terms::*;

View File

@@ -0,0 +1,23 @@
use serde::{Deserialize, Serialize};
use super::courses::Course;
/// Search result wrapper
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SearchResult {
pub success: bool,
pub total_count: i32,
pub page_offset: i32,
pub page_max_size: i32,
pub path_mode: Option<String>,
pub search_results_config: Option<Vec<SearchResultConfig>>,
pub data: Option<Vec<Course>>,
}
/// Search result configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResultConfig {
pub config: String,
pub display: String,
}

247
src/banner/models/terms.rs Normal file
View File

@@ -0,0 +1,247 @@
use std::{ops::RangeInclusive, str::FromStr};
use anyhow::Context;
use chrono::{Datelike, Local, NaiveDate};
use serde::{Deserialize, Serialize};
/// The current year at the time of compilation
const CURRENT_YEAR: u32 = compile_time::date!().year() as u32;
/// The valid years for terms
/// We set a semi-static upper limit to avoid having to update this value while also keeping a tight bound
/// TODO: Recheck the lower bound, it's just a guess right now.
const VALID_YEARS: RangeInclusive<u32> = 2007..=(CURRENT_YEAR + 10);
/// Represents a term in the Banner system
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Term {
pub year: u32, // 2024, 2025, etc
pub season: Season,
}
/// Represents the term status at a specific point in time
#[derive(Debug, Clone)]
pub enum TermPoint {
/// Currently in a term
InTerm { current: Term },
/// Between terms, with the next term specified
BetweenTerms { next: Term },
}
/// Represents a season within a term
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Season {
Fall,
Spring,
Summer,
}
impl Term {
/// Returns the current term status - either currently in a term or between terms
pub fn get_current() -> TermPoint {
let now = Local::now().naive_local();
Self::get_status_for_date(now.date())
}
/// Returns the current term status for a specific date
pub fn get_status_for_date(date: NaiveDate) -> TermPoint {
let literal_year = date.year() as u32;
let day_of_year = date.ordinal();
let ranges = Self::get_season_ranges(literal_year);
// If we're past the end of the summer term, we're 'in' the next school year.
let term_year = if day_of_year > ranges.summer.end {
literal_year + 1
} else {
literal_year
};
if (day_of_year < ranges.spring.start) || (day_of_year >= ranges.fall.end) {
// Fall over, Spring not yet begun
TermPoint::BetweenTerms {
next: Term {
year: term_year,
season: Season::Spring,
},
}
} else if (day_of_year >= ranges.spring.start) && (day_of_year < ranges.spring.end) {
// Spring
TermPoint::InTerm {
current: Term {
year: term_year,
season: Season::Spring,
},
}
} else if day_of_year < ranges.summer.start {
// Spring over, Summer not yet begun
TermPoint::BetweenTerms {
next: Term {
year: term_year,
season: Season::Summer,
},
}
} else if (day_of_year >= ranges.summer.start) && (day_of_year < ranges.summer.end) {
// Summer
TermPoint::InTerm {
current: Term {
year: term_year,
season: Season::Summer,
},
}
} else if day_of_year < ranges.fall.start {
// Summer over, Fall not yet begun
TermPoint::BetweenTerms {
next: Term {
year: term_year,
season: Season::Fall,
},
}
} else if (day_of_year >= ranges.fall.start) && (day_of_year < ranges.fall.end) {
// Fall
TermPoint::InTerm {
current: Term {
year: term_year,
season: Season::Fall,
},
}
} else {
// This should never happen, but Rust requires exhaustive matching
panic!("Impossible code reached (dayOfYear: {})", day_of_year);
}
}
/// Returns the start and end day of each term for the given year.
/// The ranges are inclusive of the start day and exclusive of the end day.
fn get_season_ranges(year: u32) -> SeasonRanges {
let spring_start = NaiveDate::from_ymd_opt(year as i32, 1, 14)
.unwrap()
.ordinal();
let spring_end = NaiveDate::from_ymd_opt(year as i32, 5, 1)
.unwrap()
.ordinal();
let summer_start = NaiveDate::from_ymd_opt(year as i32, 5, 25)
.unwrap()
.ordinal();
let summer_end = NaiveDate::from_ymd_opt(year as i32, 8, 15)
.unwrap()
.ordinal();
let fall_start = NaiveDate::from_ymd_opt(year as i32, 8, 18)
.unwrap()
.ordinal();
let fall_end = NaiveDate::from_ymd_opt(year as i32, 12, 10)
.unwrap()
.ordinal();
SeasonRanges {
spring: YearDayRange {
start: spring_start,
end: spring_end,
},
summer: YearDayRange {
start: summer_start,
end: summer_end,
},
fall: YearDayRange {
start: fall_start,
end: fall_end,
},
}
}
/// Returns a long string representation of the term (e.g., "Fall 2025")
pub fn to_long_string(&self) -> String {
format!("{} {}", self.season, self.year)
}
}
impl TermPoint {
/// Returns the inner Term regardless of the status
pub fn inner(&self) -> &Term {
match self {
TermPoint::InTerm { current } => current,
TermPoint::BetweenTerms { next } => next,
}
}
}
/// Represents the start and end day of each term within a year
#[derive(Debug, Clone)]
struct SeasonRanges {
spring: YearDayRange,
summer: YearDayRange,
fall: YearDayRange,
}
/// Represents the start and end day of a term within a year
#[derive(Debug, Clone)]
struct YearDayRange {
start: u32,
end: u32,
}
impl std::fmt::Display for Term {
/// Returns the term in the format YYYYXX, where YYYY is the year and XX is the season code
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{year}{season}",
year = self.year,
season = self.season.to_str()
)
}
}
impl Season {
/// Returns the season code as a string
fn to_str(self) -> &'static str {
match self {
Season::Fall => "10",
Season::Spring => "20",
Season::Summer => "30",
}
}
}
impl std::fmt::Display for Season {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Season::Fall => write!(f, "Fall"),
Season::Spring => write!(f, "Spring"),
Season::Summer => write!(f, "Summer"),
}
}
}
impl FromStr for Season {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let season = match s {
"10" => Season::Fall,
"20" => Season::Spring,
"30" => Season::Summer,
_ => return Err(anyhow::anyhow!("Invalid season: {s}")),
};
Ok(season)
}
}
impl FromStr for Term {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.len() != 6 {
return Err(anyhow::anyhow!("Term string must be 6 characters"));
}
let year = s[0..4].parse::<u32>().context("Failed to parse year")?;
if !VALID_YEARS.contains(&year) {
return Err(anyhow::anyhow!("Year out of range"));
}
let season =
Season::from_str(&s[4..6]).map_err(|e| anyhow::anyhow!("Invalid season: {}", e))?;
Ok(Term { year, season })
}
}

334
src/banner/query.rs Normal file
View File

@@ -0,0 +1,334 @@
//! Query builder for Banner API course searches.
use std::collections::HashMap;
use std::time::Duration;
/// Range of two integers
#[derive(Debug, Clone)]
pub struct Range {
pub low: i32,
pub high: i32,
}
/// Builder for constructing Banner API search queries
#[derive(Debug, Clone, Default)]
pub struct SearchQuery {
subject: Option<String>,
title: Option<String>,
keywords: Option<Vec<String>>,
course_reference_number: Option<String>,
open_only: Option<bool>,
term_part: Option<Vec<String>>,
campus: Option<Vec<String>>,
instructional_method: Option<Vec<String>>,
attributes: Option<Vec<String>>,
instructor: Option<Vec<u64>>,
start_time: Option<Duration>,
end_time: Option<Duration>,
min_credits: Option<i32>,
max_credits: Option<i32>,
offset: i32,
max_results: i32,
course_number_range: Option<Range>,
}
impl SearchQuery {
/// Creates a new SearchQuery with default values
pub fn new() -> Self {
Self {
max_results: 8,
offset: 0,
..Default::default()
}
}
/// Sets the subject for the query
pub fn subject<S: Into<String>>(mut self, subject: S) -> Self {
self.subject = Some(subject.into());
self
}
/// Sets the title for the query
pub fn title<S: Into<String>>(mut self, title: S) -> Self {
self.title = Some(title.into());
self
}
/// Sets the course reference number (CRN) for the query
pub fn course_reference_number<S: Into<String>>(mut self, crn: S) -> Self {
self.course_reference_number = Some(crn.into());
self
}
/// Sets the keywords for the query
pub fn keywords(mut self, keywords: Vec<String>) -> Self {
self.keywords = Some(keywords);
self
}
/// Adds a keyword to the query
pub fn keyword<S: Into<String>>(mut self, keyword: S) -> Self {
match &mut self.keywords {
Some(keywords) => keywords.push(keyword.into()),
None => self.keywords = Some(vec![keyword.into()]),
}
self
}
/// Sets whether to search for open courses only
pub fn open_only(mut self, open_only: bool) -> Self {
self.open_only = Some(open_only);
self
}
/// Sets the term part for the query
pub fn term_part(mut self, term_part: Vec<String>) -> Self {
self.term_part = Some(term_part);
self
}
/// Sets the campuses for the query
pub fn campus(mut self, campus: Vec<String>) -> Self {
self.campus = Some(campus);
self
}
/// Sets the instructional methods for the query
pub fn instructional_method(mut self, instructional_method: Vec<String>) -> Self {
self.instructional_method = Some(instructional_method);
self
}
/// Sets the attributes for the query
pub fn attributes(mut self, attributes: Vec<String>) -> Self {
self.attributes = Some(attributes);
self
}
/// Sets the instructors for the query
pub fn instructor(mut self, instructor: Vec<u64>) -> Self {
self.instructor = Some(instructor);
self
}
/// Sets the start time for the query
pub fn start_time(mut self, start_time: Duration) -> Self {
self.start_time = Some(start_time);
self
}
/// Sets the end time for the query
pub fn end_time(mut self, end_time: Duration) -> Self {
self.end_time = Some(end_time);
self
}
/// Sets the credit range for the query
pub fn credits(mut self, low: i32, high: i32) -> Self {
self.min_credits = Some(low);
self.max_credits = Some(high);
self
}
/// Sets the minimum credits for the query
pub fn min_credits(mut self, value: i32) -> Self {
self.min_credits = Some(value);
self
}
/// Sets the maximum credits for the query
pub fn max_credits(mut self, value: i32) -> Self {
self.max_credits = Some(value);
self
}
/// Sets the course number range for the query
pub fn course_numbers(mut self, low: i32, high: i32) -> Self {
self.course_number_range = Some(Range { low, high });
self
}
/// Sets the offset for pagination
pub fn offset(mut self, offset: i32) -> Self {
self.offset = offset;
self
}
/// Sets the maximum number of results to return
pub fn max_results(mut self, max_results: i32) -> Self {
self.max_results = max_results;
self
}
/// Gets the subject field
pub fn get_subject(&self) -> Option<&String> {
self.subject.as_ref()
}
/// Gets the max_results field
pub fn get_max_results(&self) -> i32 {
self.max_results
}
/// Converts the query into URL parameters for the Banner API
pub fn to_params(&self) -> HashMap<String, String> {
let mut params = HashMap::new();
if let Some(ref subject) = self.subject {
params.insert("txt_subject".to_string(), subject.clone());
}
if let Some(ref title) = self.title {
params.insert("txt_courseTitle".to_string(), title.trim().to_string());
}
if let Some(ref crn) = self.course_reference_number {
params.insert("txt_courseReferenceNumber".to_string(), crn.clone());
}
if let Some(ref keywords) = self.keywords {
params.insert("txt_keywordlike".to_string(), keywords.join(" "));
}
if self.open_only.is_some() {
params.insert("chk_open_only".to_string(), "true".to_string());
}
if let Some(ref term_part) = self.term_part {
params.insert("txt_partOfTerm".to_string(), term_part.join(","));
}
if let Some(ref campus) = self.campus {
params.insert("txt_campus".to_string(), campus.join(","));
}
if let Some(ref attributes) = self.attributes {
params.insert("txt_attribute".to_string(), attributes.join(","));
}
if let Some(ref instructor) = self.instructor {
let instructor_str = instructor
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(",");
params.insert("txt_instructor".to_string(), instructor_str);
}
if let Some(start_time) = self.start_time {
let (hour, minute, meridiem) = format_time_parameter(start_time);
params.insert("select_start_hour".to_string(), hour);
params.insert("select_start_min".to_string(), minute);
params.insert("select_start_ampm".to_string(), meridiem);
}
if let Some(end_time) = self.end_time {
let (hour, minute, meridiem) = format_time_parameter(end_time);
params.insert("select_end_hour".to_string(), hour);
params.insert("select_end_min".to_string(), minute);
params.insert("select_end_ampm".to_string(), meridiem);
}
if let Some(min_credits) = self.min_credits {
params.insert("txt_credithourlow".to_string(), min_credits.to_string());
}
if let Some(max_credits) = self.max_credits {
params.insert("txt_credithourhigh".to_string(), max_credits.to_string());
}
if let Some(ref range) = self.course_number_range {
params.insert("txt_course_number_range".to_string(), range.low.to_string());
params.insert(
"txt_course_number_range_to".to_string(),
range.high.to_string(),
);
}
params.insert("pageOffset".to_string(), self.offset.to_string());
params.insert("pageMaxSize".to_string(), self.max_results.to_string());
params
}
}
/// Formats a Duration into hour, minute, and meridiem strings for Banner API
fn format_time_parameter(duration: Duration) -> (String, String, String) {
let total_minutes = duration.as_secs() / 60;
let hours = total_minutes / 60;
let minutes = total_minutes % 60;
let minute_str = minutes.to_string();
if hours >= 12 {
let meridiem = "PM".to_string();
let hour_str = if hours >= 13 {
(hours - 12).to_string()
} else {
hours.to_string()
};
(hour_str, minute_str, meridiem)
} else {
let meridiem = "AM".to_string();
let hour_str = hours.to_string();
(hour_str, minute_str, meridiem)
}
}
impl std::fmt::Display for SearchQuery {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut parts = Vec::new();
if let Some(ref subject) = self.subject {
parts.push(format!("subject={subject}"));
}
if let Some(ref title) = self.title {
parts.push(format!("title={}", title.trim()));
}
if let Some(ref keywords) = self.keywords {
parts.push(format!("keywords={}", keywords.join(" ")));
}
if self.open_only.is_some() {
parts.push("openOnly=true".to_string());
}
if let Some(ref term_part) = self.term_part {
parts.push(format!("termPart={}", term_part.join(",")));
}
if let Some(ref campus) = self.campus {
parts.push(format!("campus={}", campus.join(",")));
}
if let Some(ref attributes) = self.attributes {
parts.push(format!("attributes={}", attributes.join(",")));
}
if let Some(ref instructor) = self.instructor {
let instructor_str = instructor
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(",");
parts.push(format!("instructor={instructor_str}"));
}
if let Some(start_time) = self.start_time {
let (hour, minute, meridiem) = format_time_parameter(start_time);
parts.push(format!("startTime={hour}:{minute}:{meridiem}"));
}
if let Some(end_time) = self.end_time {
let (hour, minute, meridiem) = format_time_parameter(end_time);
parts.push(format!("endTime={hour}:{minute}:{meridiem}"));
}
if let Some(min_credits) = self.min_credits {
parts.push(format!("minCredits={min_credits}"));
}
if let Some(max_credits) = self.max_credits {
parts.push(format!("maxCredits={max_credits}"));
}
if let Some(ref range) = self.course_number_range {
parts.push(format!("courseNumberRange={}-{}", range.low, range.high));
}
parts.push(format!("offset={}", self.offset));
parts.push(format!("maxResults={}", self.max_results));
write!(f, "{}", parts.join(", "))
}
}

View File

@@ -0,0 +1,101 @@
//! HTTP middleware that enforces rate limiting for Banner API requests.
use crate::banner::rate_limiter::{RequestType, SharedRateLimiter};
use http::Extensions;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next};
use tracing::{debug, warn, trace};
use url::Url;
/// Middleware that enforces rate limiting based on request URL patterns
pub struct RateLimitMiddleware {
rate_limiter: SharedRateLimiter,
}
impl RateLimitMiddleware {
/// Creates a new rate limiting middleware
pub fn new(rate_limiter: SharedRateLimiter) -> Self {
Self { rate_limiter }
}
/// Determines the request type based on the URL path
fn get_request_type(url: &Url) -> RequestType {
let path = url.path();
if path.contains("/registration")
|| path.contains("/selfServiceMenu")
|| path.contains("/term/termSelection")
{
RequestType::Session
} else if path.contains("/searchResults") || path.contains("/classSearch") {
RequestType::Search
} else if path.contains("/getTerms")
|| path.contains("/getSubjects")
|| path.contains("/getCampuses")
{
RequestType::Metadata
} else if path.contains("/resetDataForm") {
RequestType::Reset
} else {
// Default to search for unknown endpoints
RequestType::Search
}
}
}
#[async_trait::async_trait]
impl Middleware for RateLimitMiddleware {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> std::result::Result<Response, reqwest_middleware::Error> {
let request_type = Self::get_request_type(req.url());
trace!(
url = %req.url(),
request_type = ?request_type,
"Rate limiting request"
);
// Wait for permission to make the request
self.rate_limiter.wait_for_permission(request_type).await;
trace!(
url = %req.url(),
request_type = ?request_type,
"Rate limit permission granted, making request"
);
// Make the actual request
let response_result = next.run(req, extensions).await;
match response_result {
Ok(response) => {
if response.status().is_success() {
trace!(
url = %response.url(),
status = response.status().as_u16(),
"Request completed successfully"
);
} else {
warn!(
url = %response.url(),
status = response.status().as_u16(),
"Request completed with error status"
);
}
Ok(response)
}
Err(error) => {
warn!(
url = %error.url().unwrap_or(&Url::parse("unknown").unwrap()),
error = ?error,
"Request failed"
);
Err(error)
}
}
}
}

169
src/banner/rate_limiter.rs Normal file
View File

@@ -0,0 +1,169 @@
//! Rate limiting for Banner API requests to prevent overwhelming the server.
use governor::{
Quota, RateLimiter,
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, trace, warn};
/// Different types of Banner API requests with different rate limits
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RequestType {
/// Session creation and management (very conservative)
Session,
/// Course search requests (moderate)
Search,
/// Term and metadata requests (moderate)
Metadata,
/// Data form resets (low priority)
Reset,
}
/// Rate limiter configuration for different request types
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
/// Requests per minute for session operations
pub session_rpm: u32,
/// Requests per minute for search operations
pub search_rpm: u32,
/// Requests per minute for metadata operations
pub metadata_rpm: u32,
/// Requests per minute for reset operations
pub reset_rpm: u32,
/// Burst allowance (extra requests allowed in short bursts)
pub burst_allowance: u32,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
// Very conservative for session creation
session_rpm: 6, // 1 every 10 seconds
// Moderate for search operations
search_rpm: 30, // 1 every 2 seconds
// Moderate for metadata
metadata_rpm: 20, // 1 every 3 seconds
// Low for resets
reset_rpm: 10, // 1 every 6 seconds
// Allow small bursts
burst_allowance: 3,
}
}
}
/// A rate limiter that manages different request types with different limits
pub struct BannerRateLimiter {
session_limiter: RateLimiter<NotKeyed, InMemoryState, DefaultClock>,
search_limiter: RateLimiter<NotKeyed, InMemoryState, DefaultClock>,
metadata_limiter: RateLimiter<NotKeyed, InMemoryState, DefaultClock>,
reset_limiter: RateLimiter<NotKeyed, InMemoryState, DefaultClock>,
config: RateLimitConfig,
}
impl BannerRateLimiter {
/// Creates a new rate limiter with the given configuration
pub fn new(config: RateLimitConfig) -> Self {
let session_quota = Quota::with_period(Duration::from_secs(60) / config.session_rpm)
.unwrap()
.allow_burst(NonZeroU32::new(config.burst_allowance).unwrap());
let search_quota = Quota::with_period(Duration::from_secs(60) / config.search_rpm)
.unwrap()
.allow_burst(NonZeroU32::new(config.burst_allowance).unwrap());
let metadata_quota = Quota::with_period(Duration::from_secs(60) / config.metadata_rpm)
.unwrap()
.allow_burst(NonZeroU32::new(config.burst_allowance).unwrap());
let reset_quota = Quota::with_period(Duration::from_secs(60) / config.reset_rpm)
.unwrap()
.allow_burst(NonZeroU32::new(config.burst_allowance).unwrap());
Self {
session_limiter: RateLimiter::direct(session_quota),
search_limiter: RateLimiter::direct(search_quota),
metadata_limiter: RateLimiter::direct(metadata_quota),
reset_limiter: RateLimiter::direct(reset_quota),
config,
}
}
/// Waits for permission to make a request of the given type
pub async fn wait_for_permission(&self, request_type: RequestType) {
let limiter = match request_type {
RequestType::Session => &self.session_limiter,
RequestType::Search => &self.search_limiter,
RequestType::Metadata => &self.metadata_limiter,
RequestType::Reset => &self.reset_limiter,
};
trace!(request_type = ?request_type, "Waiting for rate limit permission");
// Wait until we can make the request
limiter.until_ready().await;
trace!(request_type = ?request_type, "Rate limit permission granted");
}
/// Checks if a request of the given type would be allowed immediately
pub fn check_permission(&self, request_type: RequestType) -> bool {
let limiter = match request_type {
RequestType::Session => &self.session_limiter,
RequestType::Search => &self.search_limiter,
RequestType::Metadata => &self.metadata_limiter,
RequestType::Reset => &self.reset_limiter,
};
limiter.check().is_ok()
}
/// Gets the current configuration
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
/// Updates the rate limit configuration
pub fn update_config(&mut self, config: RateLimitConfig) {
self.config = config;
// Note: In a production system, you'd want to recreate the limiters
// with the new configuration, but for simplicity we'll just update
// the config field here.
warn!("Rate limit configuration updated - restart required for full effect");
}
}
impl Default for BannerRateLimiter {
fn default() -> Self {
Self::new(RateLimitConfig::default())
}
}
/// A shared rate limiter instance
pub type SharedRateLimiter = Arc<BannerRateLimiter>;
/// Creates a new shared rate limiter with default configuration
pub fn create_shared_rate_limiter() -> SharedRateLimiter {
Arc::new(BannerRateLimiter::default())
}
/// Creates a new shared rate limiter with custom configuration
pub fn create_shared_rate_limiter_with_config(config: RateLimitConfig) -> SharedRateLimiter {
Arc::new(BannerRateLimiter::new(config))
}
/// Conversion from config module's RateLimitingConfig to this module's RateLimitConfig
impl From<crate::config::RateLimitingConfig> for RateLimitConfig {
fn from(config: crate::config::RateLimitingConfig) -> Self {
Self {
session_rpm: config.session_rpm,
search_rpm: config.search_rpm,
metadata_rpm: config.metadata_rpm,
reset_rpm: config.reset_rpm,
burst_allowance: config.burst_allowance,
}
}
}

460
src/banner/session.rs Normal file
View File

@@ -0,0 +1,460 @@
//! Session management for Banner API.
use crate::banner::BannerTerm;
use crate::banner::models::Term;
use anyhow::{Context, Result};
use cookie::Cookie;
use dashmap::DashMap;
use governor::state::InMemoryState;
use governor::{Quota, RateLimiter};
use once_cell::sync::Lazy;
use rand::distr::{Alphanumeric, SampleString};
use reqwest_middleware::ClientWithMiddleware;
use std::collections::{HashMap, VecDeque};
use std::num::NonZeroU32;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify};
use tracing::{debug, info, trace};
use url::Url;
const SESSION_EXPIRY: Duration = Duration::from_secs(25 * 60); // 25 minutes
// A global rate limiter to ensure we only try to create one new session every 10 seconds,
// preventing us from overwhelming the server with session creation requests.
static SESSION_CREATION_RATE_LIMITER: Lazy<
RateLimiter<governor::state::direct::NotKeyed, InMemoryState, governor::clock::DefaultClock>,
> = Lazy::new(|| RateLimiter::direct(Quota::with_period(Duration::from_secs(10)).unwrap()));
/// Represents an active anonymous session within the Banner API.
/// Identified by multiple persistent cookies, as well as a client-generated "unique session ID".
#[derive(Debug, Clone)]
pub struct BannerSession {
// Randomly generated
pub unique_session_id: String,
// Timestamp of creation
created_at: Instant,
// Timestamp of last activity
last_activity: Option<Instant>,
// Cookie values from initial registration page
jsessionid: String,
ssb_cookie: String,
}
/// Generates a new session ID mimicking Banner's format
fn generate_session_id() -> String {
let random_part = Alphanumeric.sample_string(&mut rand::rng(), 5);
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
format!("{}{}", random_part, timestamp)
}
/// Generates a timestamp-based nonce
pub fn nonce() -> String {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis()
.to_string()
}
impl BannerSession {
/// Creates a new session
pub async fn new(unique_session_id: &str, jsessionid: &str, ssb_cookie: &str) -> Result<Self> {
let now = Instant::now();
Ok(Self {
created_at: now,
last_activity: None,
unique_session_id: unique_session_id.to_string(),
jsessionid: jsessionid.to_string(),
ssb_cookie: ssb_cookie.to_string(),
})
}
/// Returns the unique session ID
pub fn id(&self) -> String {
self.unique_session_id.clone()
}
/// Updates the last activity timestamp
pub fn touch(&mut self) {
trace!(id = self.unique_session_id, "Session was used");
self.last_activity = Some(Instant::now());
}
/// Returns true if the session is expired
pub fn is_expired(&self) -> bool {
self.last_activity.unwrap_or(self.created_at).elapsed() > SESSION_EXPIRY
}
/// Returns a string used to for the "Cookie" header
pub fn cookie(&self) -> String {
format!(
"JSESSIONID={}; SSB_COOKIE={}",
self.jsessionid, self.ssb_cookie
)
}
pub fn been_used(&self) -> bool {
self.last_activity.is_some()
}
}
/// A smart pointer that returns a BannerSession to the pool when dropped.
pub struct PooledSession {
session: Option<BannerSession>,
// This Arc points directly to the term-specific pool.
pool: Arc<TermPool>,
}
impl PooledSession {
pub fn been_used(&self) -> bool {
self.session.as_ref().unwrap().been_used()
}
}
impl Deref for PooledSession {
type Target = BannerSession;
fn deref(&self) -> &Self::Target {
// The option is only ever None after drop is called, so this is safe.
self.session.as_ref().unwrap()
}
}
impl DerefMut for PooledSession {
fn deref_mut(&mut self) -> &mut Self::Target {
self.session.as_mut().unwrap()
}
}
/// The magic happens here: when the guard goes out of scope, this is called.
impl Drop for PooledSession {
fn drop(&mut self) {
if let Some(session) = self.session.take() {
let pool = self.pool.clone();
// Since drop() cannot be async, we spawn a task to return the session.
tokio::spawn(async move {
pool.release(session).await;
});
}
}
}
pub struct TermPool {
sessions: Mutex<VecDeque<BannerSession>>,
notifier: Notify,
is_creating: Mutex<bool>,
}
impl TermPool {
fn new() -> Self {
Self {
sessions: Mutex::new(VecDeque::new()),
notifier: Notify::new(),
is_creating: Mutex::new(false),
}
}
async fn release(&self, session: BannerSession) {
let id = session.unique_session_id.clone();
if session.is_expired() {
trace!(id = id, "Session is now expired, dropping.");
// Wake up a waiter, as it might need to create a new session
// if this was the last one.
self.notifier.notify_one();
return;
}
let mut queue = self.sessions.lock().await;
queue.push_back(session);
let queue_size = queue.len();
drop(queue); // Release lock before notifying
trace!(id = id, queue_size, "Session returned to pool");
self.notifier.notify_one();
}
}
pub struct SessionPool {
sessions: DashMap<Term, Arc<TermPool>>,
http: ClientWithMiddleware,
base_url: String,
}
impl SessionPool {
pub fn new(http: ClientWithMiddleware, base_url: String) -> Self {
Self {
sessions: DashMap::new(),
http,
base_url,
}
}
/// Acquires a session from the pool.
/// If no sessions are available, a new one is created on demand,
/// respecting the global rate limit.
pub async fn acquire(&self, term: Term) -> Result<PooledSession> {
let term_pool = self
.sessions
.entry(term)
.or_insert_with(|| Arc::new(TermPool::new()))
.clone();
loop {
// Fast path: Try to get an existing, non-expired session.
{
let mut queue = term_pool.sessions.lock().await;
if let Some(session) = queue.pop_front() {
if !session.is_expired() {
trace!(id = session.unique_session_id, "Reusing session from pool");
return Ok(PooledSession {
session: Some(session),
pool: Arc::clone(&term_pool),
});
} else {
trace!(
id = session.unique_session_id,
"Popped an expired session, discarding."
);
}
}
} // MutexGuard is dropped, lock is released.
// Slow path: No sessions available. We must either wait or become the creator.
let mut is_creating_guard = term_pool.is_creating.lock().await;
if *is_creating_guard {
// Another task is already creating a session. Release the lock and wait.
drop(is_creating_guard);
trace!("Another task is creating a session, waiting for notification...");
term_pool.notifier.notified().await;
// Loop back to the top to try the fast path again.
continue;
}
// This task is now the designated creator.
*is_creating_guard = true;
drop(is_creating_guard);
// Race: wait for a session to be returned OR for the rate limiter to allow a new one.
trace!("Pool empty, racing notifier vs rate limiter...");
tokio::select! {
_ = term_pool.notifier.notified() => {
// A session was returned while we were waiting!
// We are no longer the creator. Reset the flag and loop to race for the new session.
trace!("Notified that a session was returned. Looping to retry.");
let mut guard = term_pool.is_creating.lock().await;
*guard = false;
drop(guard);
continue;
}
_ = SESSION_CREATION_RATE_LIMITER.until_ready() => {
// The rate limit has elapsed. It's our job to create the session.
trace!("Rate limiter ready. Proceeding to create a new session.");
let new_session_result = self.create_session(&term).await;
// After creation, we are no longer the creator. Reset the flag
// and notify all other waiting tasks.
let mut guard = term_pool.is_creating.lock().await;
*guard = false;
drop(guard);
term_pool.notifier.notify_waiters();
match new_session_result {
Ok(new_session) => {
debug!(id = new_session.unique_session_id, "Successfully created new session");
return Ok(PooledSession {
session: Some(new_session),
pool: term_pool,
});
}
Err(e) => {
// Propagate the error if session creation failed.
return Err(e.context("Failed to create new session in pool"));
}
}
}
}
}
}
/// Sets up initial session cookies by making required Banner API requests
pub async fn create_session(&self, term: &Term) -> Result<BannerSession> {
info!(term = %term, "setting up banner session");
// The 'register' or 'search' registration page
let initial_registration = self
.http
.get(format!("{}/registration", self.base_url))
.send()
.await?;
// TODO: Validate success
let cookies = initial_registration
.headers()
.get_all("Set-Cookie")
.iter()
.filter_map(|header_value| {
if let Ok(cookie) = Cookie::parse(header_value.to_str().unwrap()) {
Some((cookie.name().to_string(), cookie.value().to_string()))
} else {
None
}
})
.collect::<HashMap<String, String>>();
if !cookies.contains_key("JSESSIONID") || !cookies.contains_key("SSB_COOKIE") {
return Err(anyhow::anyhow!("Failed to get cookies"));
}
let jsessionid = cookies.get("JSESSIONID").unwrap();
let ssb_cookie = cookies.get("SSB_COOKIE").unwrap();
let cookie_header = format!("JSESSIONID={}; SSB_COOKIE={}", jsessionid, ssb_cookie);
trace!(
jsessionid = jsessionid,
ssb_cookie = ssb_cookie,
"New session cookies acquired"
);
self.http
.get(format!("{}/selfServiceMenu/data", self.base_url))
.header("Cookie", &cookie_header)
.send()
.await?
.error_for_status()
.context("Failed to get data page")?;
self.http
.get(format!("{}/term/termSelection", self.base_url))
.header("Cookie", &cookie_header)
.query(&[("mode", "search")])
.send()
.await?
.error_for_status()
.context("Failed to get term selection page")?;
// TOOD: Validate success
let terms = self.get_terms("", 1, 10).await?;
if !terms.iter().any(|t| t.code == term.to_string()) {
return Err(anyhow::anyhow!("Failed to get term search response"));
}
let specific_term_search_response = self.get_terms(&term.to_string(), 1, 10).await?;
if !specific_term_search_response
.iter()
.any(|t| t.code == term.to_string())
{
return Err(anyhow::anyhow!("Failed to get term search response"));
}
let unique_session_id = generate_session_id();
self.select_term(&term.to_string(), &unique_session_id, &cookie_header)
.await?;
BannerSession::new(&unique_session_id, jsessionid, ssb_cookie).await
}
/// Retrieves a list of terms from the Banner API.
pub async fn get_terms(
&self,
search: &str,
page: i32,
max_results: i32,
) -> Result<Vec<BannerTerm>> {
if page <= 0 {
return Err(anyhow::anyhow!("Page must be greater than 0"));
}
let url = format!("{}/classSearch/getTerms", self.base_url);
let params = [
("searchTerm", search),
("offset", &page.to_string()),
("max", &max_results.to_string()),
("_", &nonce()),
];
let response = self
.http
.get(&url)
.query(&params)
.send()
.await
.with_context(|| "Failed to get terms".to_string())?;
let terms: Vec<BannerTerm> = response
.json()
.await
.context("Failed to parse terms response")?;
Ok(terms)
}
/// Selects a term for the current session
pub async fn select_term(
&self,
term: &str,
unique_session_id: &str,
cookie_header: &str,
) -> Result<()> {
let form_data = [
("term", term),
("studyPath", ""),
("studyPathText", ""),
("startDatepicker", ""),
("endDatepicker", ""),
("uniqueSessionId", unique_session_id),
];
let url = format!("{}/term/search", self.base_url);
let response = self
.http
.post(&url)
.header("Cookie", cookie_header)
.query(&[("mode", "search")])
.form(&form_data)
.send()
.await?;
if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to select term {}: {}",
term,
response.status()
));
}
#[derive(serde::Deserialize)]
struct RedirectResponse {
#[serde(rename = "fwdURL")]
fwd_url: String,
}
let redirect: RedirectResponse = response.json().await?;
let base_url_path = self.base_url.parse::<Url>().unwrap().path().to_string();
let non_overlap_redirect = redirect.fwd_url.strip_prefix(&base_url_path).unwrap();
// Follow the redirect
let redirect_url = format!("{}{}", self.base_url, non_overlap_redirect);
let redirect_response = self
.http
.get(&redirect_url)
.header("Cookie", cookie_header)
.send()
.await?;
if !redirect_response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to follow redirect: {}",
redirect_response.status()
));
}
trace!(term = term, "successfully selected term");
Ok(())
}
}

6
src/banner/util.rs Normal file
View File

@@ -0,0 +1,6 @@
//! Utility functions for the Banner module.
/// Returns a browser-like user agent string.
pub fn user_agent() -> &'static str {
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36"
}

132
src/bin/search.rs Normal file
View File

@@ -0,0 +1,132 @@
use banner::banner::{BannerApi, SearchQuery, Term};
use banner::config::Config;
use banner::error::Result;
use figment::{Figment, providers::Env};
use futures::future;
use tracing::{error, info};
use tracing_subscriber::{EnvFilter, FmtSubscriber};
#[tokio::main]
async fn main() -> Result<()> {
// Configure logging
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info,banner=trace,reqwest=debug,hyper=info"));
let subscriber = FmtSubscriber::builder()
.with_env_filter(filter)
.with_target(true)
.finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
info!("Starting Banner search test");
dotenvy::dotenv().ok();
// Load configuration
let config: Config = Figment::new()
.merge(Env::raw())
.extract()
.expect("Failed to load config");
info!(
banner_base_url = config.banner_base_url,
"Configuration loaded"
);
// Create Banner API client
let banner_api =
BannerApi::new_with_config(config.banner_base_url, config.rate_limiting.into())
.expect("Failed to create BannerApi");
// Get current term
let term = Term::get_current().inner().to_string();
info!(term = term, "Using current term");
// Define multiple search queries
let queries = vec![
(
"CS Courses",
SearchQuery::new().subject("CS").max_results(10),
),
(
"Math Courses",
SearchQuery::new().subject("MAT").max_results(10),
),
(
"3000-level CS",
SearchQuery::new()
.subject("CS")
.course_numbers(3000, 3999)
.max_results(8),
),
(
"High Credit Courses",
SearchQuery::new().credits(4, 6).max_results(8),
),
(
"Programming Courses",
SearchQuery::new().keyword("programming").max_results(6),
),
];
info!(query_count = queries.len(), "Executing concurrent searches");
// Execute all searches concurrently
let search_futures = queries.into_iter().map(|(label, query)| {
info!(label = %label, "Starting search");
let banner_api = &banner_api;
let term = &term;
async move {
let result = banner_api
.search(term, &query, "subjectDescription", false)
.await;
(label, result)
}
});
// Wait for all searches to complete
let search_results = future::join_all(search_futures)
.await
.into_iter()
.filter_map(|(label, result)| match result {
Ok(search_result) => {
info!(
label = label,
success = search_result.success,
total_count = search_result.total_count,
"Search completed successfully"
);
Some((label, search_result))
}
Err(e) => {
error!(label = label, error = ?e, "Search failed");
None
}
})
.collect::<Vec<_>>();
// Process and display results
for (label, search_result) in search_results {
println!("\n=== {} ===", label);
if let Some(courses) = &search_result.data {
if courses.is_empty() {
println!(" No courses found");
} else {
println!(" Found {} courses:", courses.len());
for course in courses {
println!(
" {} {} - {} (CRN: {})",
course.subject,
course.course_number,
course.course_title,
course.course_reference_number
);
}
}
} else {
println!(" No courses found");
}
}
info!("Search test completed");
Ok(())
}

151
src/bot/commands/gcal.rs Normal file
View File

@@ -0,0 +1,151 @@
//! Google Calendar command implementation.
use crate::banner::{Course, DayOfWeek, MeetingScheduleInfo};
use crate::bot::{Context, Error, utils};
use chrono::NaiveDate;
use std::collections::HashMap;
use tracing::info;
use url::Url;
/// Generate a link to create a Google Calendar event for a course
#[poise::command(slash_command)]
pub async fn gcal(
ctx: Context<'_>,
#[description = "Course Reference Number (CRN)"] crn: i32,
) -> Result<(), Error> {
let user = ctx.author();
info!(source = user.name, target = crn, "gcal command invoked");
ctx.defer().await?;
let course = utils::get_course_by_crn(&ctx, crn).await?;
let term = course.term.clone();
// Get meeting times
let meeting_times = ctx
.data()
.app_state
.banner_api
.get_course_meeting_time(&term, &crn.to_string())
.await?;
struct LinkDetail {
link: String,
detail: String,
}
let response: Vec<LinkDetail> = match meeting_times.len() {
0 => Err(anyhow::anyhow!("No meeting times found for this course.")),
1.. => {
// Sort meeting times by start time of their TimeRange
let mut sorted_meeting_times = meeting_times.to_vec();
sorted_meeting_times.sort_unstable_by(|a, b| {
// Primary sort: by start time
match (&a.time_range, &b.time_range) {
(Some(a_time), Some(b_time)) => a_time.start.cmp(&b_time.start),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => a.days.bits().cmp(&b.days.bits()),
}
});
let links = sorted_meeting_times
.iter()
.map(|m| {
let link = generate_gcal_url(&course, m)?;
let detail = match &m.time_range {
Some(range) => {
format!("{} {}", m.days_string().unwrap(), range.format_12hr())
}
None => m.days_string().unwrap(),
};
Ok(LinkDetail { link, detail })
})
.collect::<Result<Vec<LinkDetail>, anyhow::Error>>()?;
Ok(links)
}
}?;
ctx.say(
response
.iter()
.map(|LinkDetail { link, detail }| {
format!("[Add to Google Calendar](<{link}>) ({detail})")
})
.collect::<Vec<String>>()
.join("\n"),
)
.await?;
info!(crn = %crn, "gcal command completed");
Ok(())
}
/// Generate Google Calendar URL for a course
fn generate_gcal_url(
course: &Course,
meeting_time: &MeetingScheduleInfo,
) -> Result<String, anyhow::Error> {
let course_text = course.display_title();
let dates_text = {
let (start, end) = meeting_time.datetime_range();
format!(
"{}/{}",
start.format("%Y%m%dT%H%M%S"),
end.format("%Y%m%dT%H%M%S")
)
};
// Get instructor name
let instructor_name = course.primary_instructor_name();
// The event description
let details_text = format!(
"CRN: {}\nInstructor: {}\nDays: {}",
course.course_reference_number,
instructor_name,
meeting_time.days_string().unwrap()
);
// The event location
let location_text = meeting_time.place_string();
// The event recurrence rule
let recur_text = generate_rrule(meeting_time, meeting_time.date_range.end);
let mut params = HashMap::new();
params.insert("action", "TEMPLATE");
params.insert("text", &course_text);
params.insert("dates", &dates_text);
params.insert("details", &details_text);
params.insert("location", &location_text);
params.insert("trp", "true");
params.insert("ctz", "America/Chicago");
params.insert("recur", &recur_text);
Ok(Url::parse_with_params("https://calendar.google.com/calendar/render", &params)?.to_string())
}
/// Generate RRULE for recurrence
fn generate_rrule(meeting_time: &MeetingScheduleInfo, end_date: NaiveDate) -> String {
let days_of_week = meeting_time.days_of_week();
let by_day = days_of_week
.iter()
.map(|day| match day {
DayOfWeek::Monday => "MO",
DayOfWeek::Tuesday => "TU",
DayOfWeek::Wednesday => "WE",
DayOfWeek::Thursday => "TH",
DayOfWeek::Friday => "FR",
DayOfWeek::Saturday => "SA",
DayOfWeek::Sunday => "SU",
})
.collect::<Vec<&str>>()
.join(",");
// Format end date for RRULE (YYYYMMDD format)
let until = end_date.format("%Y%m%dT000000Z").to_string();
format!("RRULE:FREQ=WEEKLY;BYDAY={by_day};UNTIL={until}")
}

25
src/bot/commands/ics.rs Normal file
View File

@@ -0,0 +1,25 @@
//! ICS command implementation for generating calendar files.
use crate::bot::{Context, Error, utils};
use tracing::info;
/// Generate an ICS file for a course
#[poise::command(slash_command, prefix_command)]
pub async fn ics(
ctx: Context<'_>,
#[description = "Course Reference Number (CRN)"] crn: i32,
) -> Result<(), Error> {
ctx.defer().await?;
let course = utils::get_course_by_crn(&ctx, crn).await?;
// TODO: Implement actual ICS file generation
ctx.say(format!(
"ICS generation for '{}' is not yet implemented.",
course.display_title()
))
.await?;
info!(crn = %crn, "ics command completed");
Ok(())
}

13
src/bot/commands/mod.rs Normal file
View File

@@ -0,0 +1,13 @@
//! Bot commands module.
pub mod gcal;
pub mod ics;
pub mod search;
pub mod terms;
pub mod time;
pub use gcal::gcal;
pub use ics::ics;
pub use search::search;
pub use terms::terms;
pub use time::time;

140
src/bot/commands/search.rs Normal file
View File

@@ -0,0 +1,140 @@
//! Course search command implementation.
use crate::banner::{SearchQuery, Term};
use crate::bot::{Context, Error};
use anyhow::anyhow;
use regex::Regex;
use tracing::info;
/// Search for courses with various filters
#[poise::command(slash_command, prefix_command)]
pub async fn search(
ctx: Context<'_>,
#[description = "Course title (exact, use autocomplete)"] title: Option<String>,
#[description = "Course code (e.g. 3743, 3000-3999, 3xxx, 3000-)"] code: Option<String>,
#[description = "Maximum number of results"] max: Option<i32>,
#[description = "Keywords in title or description (space separated)"] keywords: Option<String>,
// #[description = "Instructor name"] instructor: Option<String>,
// #[description = "Subject (e.g Computer Science/CS, Mathematics/MAT)"] subject: Option<String>,
) -> Result<(), Error> {
// Defer the response since this might take a while
ctx.defer().await?;
// Build the search query
let mut query = SearchQuery::new().credits(3, 6);
if let Some(title) = title {
query = query.title(title);
}
if let Some(code) = code {
let (low, high) = parse_course_code(&code)?;
query = query.course_numbers(low, high);
}
if let Some(keywords) = keywords {
let keyword_list: Vec<String> =
keywords.split_whitespace().map(|s| s.to_string()).collect();
query = query.keywords(keyword_list);
}
if let Some(max_results) = max {
query = query.max_results(max_results.min(25)); // Cap at 25
}
let term = Term::get_current().inner().to_string();
let search_result = ctx
.data()
.app_state
.banner_api
.search(&term, &query, "subjectDescription", false)
.await?;
let response = if let Some(courses) = search_result.data {
if courses.is_empty() {
"No courses found with the specified criteria.".to_string()
} else {
courses
.iter()
.map(|course| {
format!(
"**{}**: {} ({})",
course.display_title(),
course.primary_instructor_name(),
course.course_reference_number
)
})
.collect::<Vec<_>>()
.join("\n")
}
} else {
"No courses found with the specified criteria.".to_string()
};
ctx.say(response).await?;
info!("search command completed");
Ok(())
}
/// Parse course code input (e.g, "3743", "3000-3999", "3xxx", "3000-")
fn parse_course_code(input: &str) -> Result<(i32, i32), Error> {
let input = input.trim();
// Handle range format (e.g, "3000-3999")
if input.contains('-') {
let re = Regex::new(r"(\d{1,4})-(\d{1,4})?").unwrap();
if let Some(captures) = re.captures(input) {
let low: i32 = captures[1].parse()?;
let high = if captures.get(2).is_some() {
captures[2].parse()?
} else {
9999 // Open-ended range
};
if low > high {
return Err(anyhow!("Invalid range: low value greater than high value"));
}
if low < 1000 || high > 9999 {
return Err(anyhow!("Course codes must be between 1000 and 9999"));
}
return Ok((low, high));
}
return Err(anyhow!("Invalid range format"));
}
// Handle wildcard format (e.g, "34xx")
if input.contains('x') {
if input.len() != 4 {
return Err(anyhow!("Wildcard format must be exactly 4 characters"));
}
let re = Regex::new(r"(\d+)(x+)").unwrap();
if let Some(captures) = re.captures(input) {
let prefix: i32 = captures[1].parse()?;
let x_count = captures[2].len();
let low = prefix * 10_i32.pow(x_count as u32);
let high = low + 10_i32.pow(x_count as u32) - 1;
if low < 1000 || high > 9999 {
return Err(anyhow!("Course codes must be between 1000 and 9999"));
}
return Ok((low, high));
}
return Err(anyhow!("Invalid wildcard format"));
}
// Handle single course code
if input.len() == 4 {
let code: i32 = input.parse()?;
if !(1000..=9999).contains(&code) {
return Err(anyhow!("Course codes must be between 1000 and 9999"));
}
return Ok((code, code));
}
Err(anyhow!("Invalid course code format"))
}

59
src/bot/commands/terms.rs Normal file
View File

@@ -0,0 +1,59 @@
//! Terms command implementation.
use crate::banner::{BannerTerm, Term};
use crate::bot::{Context, Error};
use tracing::info;
/// List available terms or search for a specific term
#[poise::command(slash_command, prefix_command)]
pub async fn terms(
ctx: Context<'_>,
#[description = "Term to search for"] search: Option<String>,
#[description = "Page number"] page: Option<i32>,
) -> Result<(), Error> {
ctx.defer().await?;
let search_term = search.unwrap_or_default();
let page_number = page.unwrap_or(1).max(1);
let max_results = 10;
let terms = ctx
.data()
.app_state
.banner_api
.sessions
.get_terms(&search_term, page_number, max_results)
.await?;
let response = if terms.is_empty() {
"No terms found.".to_string()
} else {
let current_term_code = Term::get_current().inner().to_string();
terms
.iter()
.map(|term| format_term(term, &current_term_code))
.collect::<Vec<_>>()
.join("\n")
};
ctx.say(response).await?;
info!("terms command completed");
Ok(())
}
fn format_term(term: &BannerTerm, current_term_code: &str) -> String {
let is_current = if term.code == current_term_code {
" (current)"
} else {
""
};
let is_archived = if term.is_archived() {
" (archived)"
} else {
""
};
format!(
"- `{}`: {}{}{}",
term.code, term.description, is_current, is_archived
)
}

25
src/bot/commands/time.rs Normal file
View File

@@ -0,0 +1,25 @@
//! Time command implementation for course meeting times.
use crate::bot::{Context, Error, utils};
use tracing::info;
/// Get meeting times for a specific course
#[poise::command(slash_command, prefix_command)]
pub async fn time(
ctx: Context<'_>,
#[description = "Course Reference Number (CRN)"] crn: i32,
) -> Result<(), Error> {
ctx.defer().await?;
let course = utils::get_course_by_crn(&ctx, crn).await?;
// TODO: Implement actual meeting time retrieval and display
ctx.say(format!(
"Meeting time display for '{}' is not yet implemented.",
course.display_title()
))
.await?;
info!(crn = %crn, "time command completed");
Ok(())
}

21
src/bot/mod.rs Normal file
View File

@@ -0,0 +1,21 @@
use crate::error::Error;
use crate::state::AppState;
pub mod commands;
pub mod utils;
pub struct Data {
pub app_state: AppState,
} // User data, which is stored and accessible in all command invocations
pub type Context<'a> = poise::Context<'a, Data, Error>;
/// Get all available commands
pub fn get_commands() -> Vec<poise::Command<Data, Error>> {
vec![
commands::search(),
commands::terms(),
commands::time(),
commands::ics(),
commands::gcal(),
]
}

24
src/bot/utils.rs Normal file
View File

@@ -0,0 +1,24 @@
//! Bot command utilities.
use crate::banner::{Course, Term};
use crate::bot::Context;
use crate::error::Result;
use tracing::error;
/// Gets a course by its CRN for the current term.
pub async fn get_course_by_crn(ctx: &Context<'_>, crn: i32) -> Result<Course> {
let app_state = &ctx.data().app_state;
// Get current term dynamically
let current_term_status = Term::get_current();
let term = current_term_status.inner();
// Fetch live course data from Redis cache via AppState
app_state
.get_course_or_fetch(&term.to_string(), &crn.to_string())
.await
.map_err(|e| {
error!(error = %e, crn = %crn, "failed to fetch course data");
e
})
}

205
src/config/mod.rs Normal file
View File

@@ -0,0 +1,205 @@
//! Configuration module for the banner application.
//!
//! This module handles loading and parsing configuration from environment variables
//! using the figment crate. It supports flexible duration parsing that accepts both
//! numeric values (interpreted as seconds) and duration strings with units.
use fundu::{DurationParser, TimeUnit};
use serde::{Deserialize, Deserializer};
use std::time::Duration;
/// Main application configuration containing all sub-configurations
#[derive(Deserialize)]
pub struct Config {
/// Log level for the application
///
/// This value is used to set the log level for this application's target specifically.
/// e.g. "debug" would be similar to "warn,banner=debug,..."
///
/// Valid values are: "trace", "debug", "info", "warn", "error"
/// Defaults to "info" if not specified
#[serde(default = "default_log_level")]
pub log_level: String,
/// Port for the web server
#[serde(default = "default_port")]
pub port: u16,
/// Database connection URL
pub database_url: String,
/// Redis connection URL
pub redis_url: String,
/// Graceful shutdown timeout duration
///
/// Accepts both numeric values (seconds) and duration strings
/// Defaults to 8 seconds if not specified
#[serde(
default = "default_shutdown_timeout",
deserialize_with = "deserialize_duration"
)]
pub shutdown_timeout: Duration,
/// Discord bot token for authentication
pub bot_token: String,
/// Target Discord guild ID where the bot operates
pub bot_target_guild: u64,
/// Base URL for banner generation service
pub banner_base_url: String,
/// Rate limiting configuration for Banner API requests
#[serde(default = "default_rate_limiting")]
pub rate_limiting: RateLimitingConfig,
}
/// Default log level of "info"
fn default_log_level() -> String {
"info".to_string()
}
/// Default port of 3000
fn default_port() -> u16 {
3000
}
/// Default shutdown timeout of 8 seconds
fn default_shutdown_timeout() -> Duration {
Duration::from_secs(8)
}
/// Rate limiting configuration for Banner API requests
#[derive(Deserialize, Clone, Debug)]
pub struct RateLimitingConfig {
/// Requests per minute for session operations (very conservative)
#[serde(default = "default_session_rpm")]
pub session_rpm: u32,
/// Requests per minute for search operations (moderate)
#[serde(default = "default_search_rpm")]
pub search_rpm: u32,
/// Requests per minute for metadata operations (moderate)
#[serde(default = "default_metadata_rpm")]
pub metadata_rpm: u32,
/// Requests per minute for reset operations (low priority)
#[serde(default = "default_reset_rpm")]
pub reset_rpm: u32,
/// Burst allowance (extra requests allowed in short bursts)
#[serde(default = "default_burst_allowance")]
pub burst_allowance: u32,
}
/// Default rate limiting configuration
fn default_rate_limiting() -> RateLimitingConfig {
RateLimitingConfig {
session_rpm: default_session_rpm(),
search_rpm: default_search_rpm(),
metadata_rpm: default_metadata_rpm(),
reset_rpm: default_reset_rpm(),
burst_allowance: default_burst_allowance(),
}
}
/// Default session requests per minute (6 = 1 every 10 seconds)
fn default_session_rpm() -> u32 {
6
}
/// Default search requests per minute (30 = 1 every 2 seconds)
fn default_search_rpm() -> u32 {
30
}
/// Default metadata requests per minute (20 = 1 every 3 seconds)
fn default_metadata_rpm() -> u32 {
20
}
/// Default reset requests per minute (10 = 1 every 6 seconds)
fn default_reset_rpm() -> u32 {
10
}
/// Default burst allowance (3 extra requests)
fn default_burst_allowance() -> u32 {
3
}
/// Duration parser configured to handle various time units with seconds as default
///
/// Supports:
/// - Seconds (s) - default unit
/// - Milliseconds (ms)
/// - Minutes (m)
/// - Hours (h)
///
/// Does not support fractions, exponents, or infinity values
/// Allows for whitespace between the number and the time unit
/// Allows for multiple time units to be specified (summed together, e.g "10s 2m" = 120 + 10 = 130 seconds)
const DURATION_PARSER: DurationParser<'static> = DurationParser::builder()
.time_units(&[TimeUnit::Second, TimeUnit::MilliSecond, TimeUnit::Minute])
.parse_multiple(None)
.allow_time_unit_delimiter()
.disable_infinity()
.disable_fraction()
.disable_exponent()
.default_unit(TimeUnit::Second)
.build();
/// Custom deserializer for duration fields that accepts both numeric and string values
///
/// This deserializer handles the flexible duration parsing by accepting:
/// - Unsigned integers (interpreted as seconds)
/// - Signed integers (interpreted as seconds, must be non-negative)
/// - Strings (parsed using the fundu duration parser)
///
/// # Examples
///
/// - `1` -> 1 second
/// - `"30s"` -> 30 seconds
/// - `"2 m"` -> 2 minutes
/// - `"1500ms"` -> 15 seconds
fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Visitor;
struct DurationVisitor;
impl<'de> Visitor<'de> for DurationVisitor {
type Value = Duration;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a duration string or number")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
DURATION_PARSER.parse(value)
.map_err(|e| {
serde::de::Error::custom(format!(
"Invalid duration format '{}': {}. Examples: '5' (5 seconds), '3500ms', '30s', '2m', '1.5h'",
value, e
))
})?
.try_into()
.map_err(|e| serde::de::Error::custom(format!("Duration conversion error: {}", e)))
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Duration::from_secs(value))
}
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if value < 0 {
return Err(serde::de::Error::custom("Duration cannot be negative"));
}
Ok(Duration::from_secs(value as u64))
}
}
deserializer.deserialize_any(DurationVisitor)
}

3
src/data/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
//! Database models and schema.
pub mod models;

71
src/data/models.rs Normal file
View File

@@ -0,0 +1,71 @@
//! `sqlx` models for the database schema.
use chrono::{DateTime, Utc};
use serde_json::Value;
#[derive(sqlx::FromRow, Debug, Clone)]
pub struct Course {
pub id: i32,
pub crn: String,
pub subject: String,
pub course_number: String,
pub title: String,
pub term_code: String,
pub enrollment: i32,
pub max_enrollment: i32,
pub wait_count: i32,
pub wait_capacity: i32,
pub last_scraped_at: DateTime<Utc>,
}
#[derive(sqlx::FromRow, Debug, Clone)]
pub struct CourseMetric {
pub id: i32,
pub course_id: i32,
pub timestamp: DateTime<Utc>,
pub enrollment: i32,
pub wait_count: i32,
pub seats_available: i32,
}
#[derive(sqlx::FromRow, Debug, Clone)]
pub struct CourseAudit {
pub id: i32,
pub course_id: i32,
pub timestamp: DateTime<Utc>,
pub field_changed: String,
pub old_value: String,
pub new_value: String,
}
/// The priority level of a scrape job.
#[derive(sqlx::Type, Copy, Debug, Clone)]
#[sqlx(type_name = "scrape_priority", rename_all = "PascalCase")]
pub enum ScrapePriority {
Low,
Medium,
High,
Critical,
}
/// The type of target for a scrape job, determining how the payload is interpreted.
#[derive(sqlx::Type, Copy, Debug, Clone)]
#[sqlx(type_name = "target_type", rename_all = "PascalCase")]
pub enum TargetType {
Subject,
CourseRange,
CrnList,
SingleCrn,
}
/// Represents a queryable job from the database.
#[derive(sqlx::FromRow, Debug, Clone)]
pub struct ScrapeJob {
pub id: i32,
pub target_type: TargetType,
pub target_payload: Value,
pub priority: ScrapePriority,
pub execute_at: DateTime<Utc>,
pub created_at: DateTime<Utc>,
pub locked_at: Option<DateTime<Utc>>,
}

4
src/error.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Application-specific error types.
pub type Error = anyhow::Error;
pub type Result<T, E = Error> = anyhow::Result<T, E>;

243
src/formatter.rs Normal file
View File

@@ -0,0 +1,243 @@
//! Custom tracing formatter
use serde::Serialize;
use serde_json::{Map, Value};
use std::fmt;
use time::macros::format_description;
use time::{OffsetDateTime, format_description::FormatItem};
use tracing::field::{Field, Visit};
use tracing::{Event, Level, Subscriber};
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields, FormattedFields};
use tracing_subscriber::registry::LookupSpan;
/// Cached format description for timestamps
/// Uses 3 subsecond digits on Emscripten, 5 otherwise for better performance
#[cfg(target_os = "emscripten")]
const TIMESTAMP_FORMAT: &[FormatItem<'static>] =
format_description!("[hour]:[minute]:[second].[subsecond digits:3]");
#[cfg(not(target_os = "emscripten"))]
const TIMESTAMP_FORMAT: &[FormatItem<'static>] =
format_description!("[hour]:[minute]:[second].[subsecond digits:5]");
/// A custom formatter with enhanced timestamp formatting
///
/// Re-implementation of the Full formatter with improved timestamp display.
pub struct CustomFormatter;
/// A custom JSON formatter that flattens fields to root level
///
/// Outputs logs in the format: { "message": "...", "level": "...", "customAttribute": "..." }
pub struct CustomJsonFormatter;
impl<S, N> FormatEvent<S, N> for CustomFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(
&self,
ctx: &FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &Event<'_>,
) -> fmt::Result {
let meta = event.metadata();
// 1) Timestamp (dimmed when ANSI)
let now = OffsetDateTime::now_utc();
let formatted_time = now.format(&TIMESTAMP_FORMAT).map_err(|e| {
eprintln!("Failed to format timestamp: {}", e);
fmt::Error
})?;
write_dimmed(&mut writer, formatted_time)?;
writer.write_char(' ')?;
// 2) Colored 5-char level like Full
write_colored_level(&mut writer, meta.level())?;
writer.write_char(' ')?;
// 3) Span scope chain (bold names, fields in braces, dimmed ':')
if let Some(scope) = ctx.event_scope() {
let mut saw_any = false;
for span in scope.from_root() {
write_bold(&mut writer, span.metadata().name())?;
saw_any = true;
let ext = span.extensions();
if let Some(fields) = &ext.get::<FormattedFields<N>>() {
if !fields.is_empty() {
write_bold(&mut writer, "{")?;
write!(writer, "{}", fields)?;
write_bold(&mut writer, "}")?;
}
}
if writer.has_ansi_escapes() {
write!(writer, "\x1b[2m:\x1b[0m")?;
} else {
writer.write_char(':')?;
}
}
if saw_any {
writer.write_char(' ')?;
}
}
// 4) Target (dimmed), then a space
if writer.has_ansi_escapes() {
write!(writer, "\x1b[2m{}\x1b[0m\x1b[2m:\x1b[0m ", meta.target())?;
} else {
write!(writer, "{}: ", meta.target())?;
}
// 5) Event fields
ctx.format_fields(writer.by_ref(), event)?;
// 6) Newline
writeln!(writer)
}
}
impl<S, N> FormatEvent<S, N> for CustomJsonFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(
&self,
_ctx: &FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &Event<'_>,
) -> fmt::Result {
let meta = event.metadata();
#[derive(Serialize)]
struct EventFields {
message: String,
level: String,
target: String,
#[serde(flatten)]
fields: Map<String, Value>,
}
let (message, fields) = {
let mut message: Option<String> = None;
let mut fields: Map<String, Value> = Map::new();
struct FieldVisitor<'a> {
message: &'a mut Option<String>,
fields: &'a mut Map<String, Value>,
}
impl<'a> Visit for FieldVisitor<'a> {
fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
let key = field.name();
if key == "message" {
*self.message = Some(format!("{:?}", value));
} else {
// Use typed methods for better performance
self.fields
.insert(key.to_string(), Value::String(format!("{:?}", value)));
}
}
fn record_str(&mut self, field: &Field, value: &str) {
let key = field.name();
if key == "message" {
*self.message = Some(value.to_string());
} else {
self.fields
.insert(key.to_string(), Value::String(value.to_string()));
}
}
fn record_i64(&mut self, field: &Field, value: i64) {
let key = field.name();
if key != "message" {
self.fields.insert(
key.to_string(),
Value::Number(serde_json::Number::from(value)),
);
}
}
fn record_u64(&mut self, field: &Field, value: u64) {
let key = field.name();
if key != "message" {
self.fields.insert(
key.to_string(),
Value::Number(serde_json::Number::from(value)),
);
}
}
fn record_bool(&mut self, field: &Field, value: bool) {
let key = field.name();
if key != "message" {
self.fields.insert(key.to_string(), Value::Bool(value));
}
}
}
let mut visitor = FieldVisitor {
message: &mut message,
fields: &mut fields,
};
event.record(&mut visitor);
(message, fields)
};
let json = EventFields {
message: message.unwrap_or_default(),
level: meta.level().to_string(),
target: meta.target().to_string(),
fields,
};
writeln!(
writer,
"{}",
serde_json::to_string(&json).unwrap_or_else(|_| "{}".to_string())
)
}
}
/// Write the verbosity level with the same coloring/alignment as the Full formatter.
fn write_colored_level(writer: &mut Writer<'_>, level: &Level) -> fmt::Result {
if writer.has_ansi_escapes() {
// Basic ANSI color sequences; reset with \x1b[0m
let (color, text) = match *level {
Level::TRACE => ("\x1b[35m", "TRACE"), // purple
Level::DEBUG => ("\x1b[34m", "DEBUG"), // blue
Level::INFO => ("\x1b[32m", " INFO"), // green, note leading space
Level::WARN => ("\x1b[33m", " WARN"), // yellow, note leading space
Level::ERROR => ("\x1b[31m", "ERROR"), // red
};
write!(writer, "{}{}\x1b[0m", color, text)
} else {
// Right-pad to width 5 like Full's non-ANSI mode
match *level {
Level::TRACE => write!(writer, "{:>5}", "TRACE"),
Level::DEBUG => write!(writer, "{:>5}", "DEBUG"),
Level::INFO => write!(writer, "{:>5}", " INFO"),
Level::WARN => write!(writer, "{:>5}", " WARN"),
Level::ERROR => write!(writer, "{:>5}", "ERROR"),
}
}
}
fn write_dimmed(writer: &mut Writer<'_>, s: impl fmt::Display) -> fmt::Result {
if writer.has_ansi_escapes() {
write!(writer, "\x1b[2m{}\x1b[0m", s)
} else {
write!(writer, "{}", s)
}
}
fn write_bold(writer: &mut Writer<'_>, s: impl fmt::Display) -> fmt::Result {
if writer.has_ansi_escapes() {
write!(writer, "\x1b[1m{}\x1b[0m", s)
} else {
write!(writer, "{}", s)
}
}

9
src/lib.rs Normal file
View File

@@ -0,0 +1,9 @@
pub mod banner;
pub mod bot;
pub mod config;
pub mod data;
pub mod error;
pub mod scraper;
pub mod services;
pub mod state;
pub mod web;

274
src/main.rs Normal file
View File

@@ -0,0 +1,274 @@
use serenity::all::{ClientBuilder, GatewayIntents};
use tokio::signal;
use tracing::{error, info, warn};
use tracing_subscriber::{EnvFilter, FmtSubscriber};
use crate::banner::BannerApi;
use crate::bot::{Data, get_commands};
use crate::config::Config;
use crate::scraper::ScraperService;
use crate::services::manager::ServiceManager;
use crate::services::{ServiceResult, bot::BotService, web::WebService};
use crate::state::AppState;
use crate::web::routes::BannerState;
use figment::{Figment, providers::Env};
use sqlx::postgres::PgPoolOptions;
use std::sync::Arc;
mod banner;
mod bot;
mod config;
mod data;
mod error;
mod formatter;
mod scraper;
mod services;
mod state;
mod web;
#[tokio::main]
async fn main() {
dotenvy::dotenv().ok();
// Load configuration first to get log level
let config: Config = Figment::new()
.merge(Env::raw())
.extract()
.expect("Failed to load config");
// Configure logging based on config
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
let base_level = &config.log_level;
EnvFilter::new(&format!(
"warn,banner={},banner::rate_limiter=warn,banner::session=warn,banner::rate_limit_middleware=warn",
base_level
))
});
let subscriber = {
#[cfg(debug_assertions)]
{
FmtSubscriber::builder()
.with_target(true)
.event_format(formatter::CustomFormatter)
}
#[cfg(not(debug_assertions))]
{
FmtSubscriber::builder()
.with_target(true)
.event_format(formatter::CustomJsonFormatter)
}
}
.with_env_filter(filter)
.finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
// Log application startup context
info!(
version = env!("CARGO_PKG_VERSION"),
environment = if cfg!(debug_assertions) {
"development"
} else {
"production"
},
"starting banner"
);
// Create database connection pool
let db_pool = PgPoolOptions::new()
.max_connections(10)
.connect(&config.database_url)
.await
.expect("Failed to create database pool");
info!(
port = config.port,
shutdown_timeout = format!("{:.2?}", config.shutdown_timeout),
banner_base_url = config.banner_base_url,
"configuration loaded"
);
// Create BannerApi and AppState
let banner_api = BannerApi::new_with_config(
config.banner_base_url.clone(),
config.rate_limiting.clone().into(),
)
.expect("Failed to create BannerApi");
let banner_api_arc = Arc::new(banner_api);
let app_state = AppState::new(banner_api_arc.clone(), &config.redis_url)
.expect("Failed to create AppState");
// Create BannerState for web service
let banner_state = BannerState {
api: banner_api_arc.clone(),
};
// Configure the client with your Discord bot token in the environment
let intents = GatewayIntents::non_privileged();
let bot_target_guild = config.bot_target_guild;
let framework = poise::Framework::builder()
.options(poise::FrameworkOptions {
commands: get_commands(),
pre_command: |ctx| {
Box::pin(async move {
let content = match ctx {
poise::Context::Application(_) => ctx.invocation_string(),
poise::Context::Prefix(prefix) => prefix.msg.content.to_string(),
};
let channel_name = ctx
.channel_id()
.name(ctx.http())
.await
.unwrap_or("unknown".to_string());
let span = tracing::Span::current();
span.record("command_name", ctx.command().qualified_name.as_str());
span.record("invocation", ctx.invocation_string());
span.record("msg.content", content.as_str());
span.record("msg.author", ctx.author().tag().as_str());
span.record("msg.id", ctx.id());
span.record("msg.channel_id", ctx.channel_id().get());
span.record("msg.channel", &channel_name.as_str());
tracing::info!(
command_name = ctx.command().qualified_name.as_str(),
invocation = ctx.invocation_string(),
msg.content = %content,
msg.author = %ctx.author().tag(),
msg.author_id = %ctx.author().id,
msg.id = %ctx.id(),
msg.channel = %channel_name.as_str(),
msg.channel_id = %ctx.channel_id(),
"{} invoked by {}",
ctx.command().name,
ctx.author().tag()
);
})
},
on_error: |error| {
Box::pin(async move {
if let Err(e) = poise::builtins::on_error(error).await {
tracing::error!(error = %e, "Fatal error while sending error message");
}
// error!(error = ?error, "command error");
})
},
..Default::default()
})
.setup(move |ctx, _ready, framework| {
let app_state = app_state.clone();
Box::pin(async move {
poise::builtins::register_in_guild(
ctx,
&framework.options().commands,
bot_target_guild.into(),
)
.await?;
poise::builtins::register_globally(ctx, &framework.options().commands).await?;
Ok(Data { app_state })
})
})
.build();
let client = ClientBuilder::new(config.bot_token, intents)
.framework(framework)
.await
.expect("Failed to build client");
// Extract shutdown timeout before moving config
let shutdown_timeout = config.shutdown_timeout;
let port = config.port;
// Create service manager
let mut service_manager = ServiceManager::new();
// Register services with the manager
let bot_service = Box::new(BotService::new(client));
let web_service = Box::new(WebService::new(port, banner_state));
let scraper_service = Box::new(ScraperService::new(db_pool.clone(), banner_api_arc.clone()));
service_manager.register_service("bot", bot_service);
service_manager.register_service("web", web_service);
service_manager.register_service("scraper", scraper_service);
// Spawn all registered services
service_manager.spawn_all();
// Set up CTRL+C signal handling
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("Failed to install CTRL+C signal handler");
info!("received ctrl+c, gracefully shutting down...");
};
// Main application loop - wait for services or CTRL+C
let mut exit_code = 0;
tokio::select! {
(service_name, result) = service_manager.run() => {
// A service completed unexpectedly
match result {
ServiceResult::GracefulShutdown => {
info!(service = service_name, "service completed gracefully");
}
ServiceResult::NormalCompletion => {
warn!(service = service_name, "service completed unexpectedly");
exit_code = 1;
}
ServiceResult::Error(e) => {
error!(service = service_name, error = ?e, "service failed");
exit_code = 1;
}
}
// Shutdown remaining services
match service_manager.shutdown(shutdown_timeout).await {
Ok(elapsed) => {
info!(
remaining = format!("{:.2?}", shutdown_timeout - elapsed),
"graceful shutdown complete"
);
}
Err(pending_services) => {
warn!(
pending_count = pending_services.len(),
pending_services = ?pending_services,
"graceful shutdown elapsed - {} service(s) did not complete",
pending_services.len()
);
// Non-zero exit code, default to 2 if not set
exit_code = if exit_code == 0 { 2 } else { exit_code };
}
}
}
_ = ctrl_c => {
// User requested shutdown
info!("user requested shutdown via ctrl+c");
match service_manager.shutdown(shutdown_timeout).await {
Ok(elapsed) => {
info!(
remaining = format!("{:.2?}", shutdown_timeout - elapsed),
"graceful shutdown complete"
);
info!("graceful shutdown complete");
}
Err(pending_services) => {
warn!(
pending_count = pending_services.len(),
pending_services = ?pending_services,
"graceful shutdown elapsed - {} service(s) did not complete",
pending_services.len()
);
exit_code = 2;
}
}
}
}
info!(exit_code, "application shutdown complete");
std::process::exit(exit_code);
}

90
src/scraper/mod.rs Normal file
View File

@@ -0,0 +1,90 @@
pub mod scheduler;
pub mod worker;
use crate::banner::BannerApi;
use sqlx::PgPool;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tracing::info;
use self::scheduler::Scheduler;
use self::worker::Worker;
use crate::services::Service;
/// The main service that will be managed by the application's `ServiceManager`.
///
/// It holds the shared resources (database pool, API client) and manages the
/// lifecycle of the Scheduler and Worker tasks.
pub struct ScraperService {
db_pool: PgPool,
banner_api: Arc<BannerApi>,
scheduler_handle: Option<JoinHandle<()>>,
worker_handles: Vec<JoinHandle<()>>,
}
impl ScraperService {
/// Creates a new `ScraperService`.
pub fn new(db_pool: PgPool, banner_api: Arc<BannerApi>) -> Self {
Self {
db_pool,
banner_api,
scheduler_handle: None,
worker_handles: Vec::new(),
}
}
/// Starts the scheduler and a pool of workers.
pub fn start(&mut self) {
info!("ScraperService starting");
let scheduler = Scheduler::new(self.db_pool.clone(), self.banner_api.clone());
let scheduler_handle = tokio::spawn(async move {
scheduler.run().await;
});
self.scheduler_handle = Some(scheduler_handle);
info!("Scheduler task spawned");
let worker_count = 4; // This could be configurable
for i in 0..worker_count {
let worker = Worker::new(i, self.db_pool.clone(), self.banner_api.clone());
let worker_handle = tokio::spawn(async move {
worker.run().await;
});
self.worker_handles.push(worker_handle);
}
info!(
worker_count = self.worker_handles.len(),
"Spawned worker tasks"
);
}
/// Signals all child tasks to gracefully shut down.
pub async fn shutdown(&mut self) {
info!("Shutting down scraper service");
if let Some(handle) = self.scheduler_handle.take() {
handle.abort();
}
for handle in self.worker_handles.drain(..) {
handle.abort();
}
info!("Scraper service shutdown");
}
}
#[async_trait::async_trait]
impl Service for ScraperService {
fn name(&self) -> &'static str {
"scraper"
}
async fn run(&mut self) -> Result<(), anyhow::Error> {
self.start();
std::future::pending::<()>().await;
Ok(())
}
async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
self.shutdown().await;
Ok(())
}
}

84
src/scraper/scheduler.rs Normal file
View File

@@ -0,0 +1,84 @@
use crate::banner::{BannerApi, Term};
use crate::data::models::{ScrapePriority, TargetType};
use crate::error::Result;
use serde_json::json;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use tracing::{error, info, debug, trace};
/// Periodically analyzes data and enqueues prioritized scrape jobs.
pub struct Scheduler {
db_pool: PgPool,
banner_api: Arc<BannerApi>,
}
impl Scheduler {
pub fn new(db_pool: PgPool, banner_api: Arc<BannerApi>) -> Self {
Self {
db_pool,
banner_api,
}
}
/// Runs the scheduler's main loop.
pub async fn run(&self) {
info!("Scheduler service started");
let mut interval = time::interval(Duration::from_secs(60)); // Runs every minute
loop {
interval.tick().await;
// Scheduler analyzing data...
if let Err(e) = self.schedule_jobs().await {
error!(error = ?e, "Failed to schedule jobs");
}
}
}
/// The core logic for deciding what jobs to create.
async fn schedule_jobs(&self) -> Result<()> {
// For now, we will implement a simple baseline scheduling strategy:
// 1. Get a list of all subjects from the Banner API.
// 2. For each subject, check if an active (not locked, not completed) job already exists.
// 3. If no job exists, create a new, low-priority job to be executed in the near future.
let term = Term::get_current().inner().to_string();
debug!(term = term, "Enqueuing subject jobs");
let subjects = self.banner_api.get_subjects("", &term, 1, 500).await?;
debug!(subject_count = subjects.len(), "Retrieved subjects from API");
for subject in subjects {
let payload = json!({ "subject": subject.code });
let existing_job: Option<(i32,)> = sqlx::query_as(
"SELECT id FROM scrape_jobs WHERE target_type = $1 AND target_payload = $2 AND locked_at IS NULL"
)
.bind(TargetType::Subject)
.bind(&payload)
.fetch_optional(&self.db_pool)
.await?;
if existing_job.is_some() {
trace!(subject = subject.code, "Job already exists, skipping");
continue;
}
sqlx::query(
"INSERT INTO scrape_jobs (target_type, target_payload, priority, execute_at) VALUES ($1, $2, $3, $4)"
)
.bind(TargetType::Subject)
.bind(&payload)
.bind(ScrapePriority::Low)
.bind(chrono::Utc::now())
.execute(&self.db_pool)
.await?;
debug!(subject = subject.code, "New job enqueued for subject");
}
debug!("Job scheduling complete");
Ok(())
}
}

205
src/scraper/worker.rs Normal file
View File

@@ -0,0 +1,205 @@
use crate::banner::{BannerApi, BannerApiError, Course, SearchQuery, Term};
use crate::data::models::ScrapeJob;
use crate::error::Result;
use serde_json::Value;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use tracing::{debug, error, info, trace, warn};
/// A single worker instance.
///
/// Each worker runs in its own asynchronous task and continuously polls the
/// database for scrape jobs to execute.
pub struct Worker {
id: usize, // For logging purposes
db_pool: PgPool,
banner_api: Arc<BannerApi>,
}
impl Worker {
pub fn new(id: usize, db_pool: PgPool, banner_api: Arc<BannerApi>) -> Self {
Self {
id,
db_pool,
banner_api,
}
}
/// Runs the worker's main loop.
pub async fn run(&self) {
info!(worker_id = self.id, "Worker started.");
loop {
match self.fetch_and_lock_job().await {
Ok(Some(job)) => {
let job_id = job.id;
debug!(worker_id = self.id, job_id = job.id, "Processing job");
if let Err(e) = self.process_job(job).await {
// Check if the error is due to an invalid session
if let Some(BannerApiError::InvalidSession(_)) =
e.downcast_ref::<BannerApiError>()
{
warn!(
worker_id = self.id,
job_id, "Invalid session detected. Forcing session refresh."
);
} else {
error!(worker_id = self.id, job_id, error = ?e, "Failed to process job");
}
// Unlock the job so it can be retried
if let Err(unlock_err) = self.unlock_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?unlock_err,
"Failed to unlock job"
);
}
} else {
debug!(worker_id = self.id, job_id, "Job completed");
// If successful, delete the job.
if let Err(delete_err) = self.delete_job(job_id).await {
error!(
worker_id = self.id,
job_id,
?delete_err,
"Failed to delete job"
);
}
}
}
Ok(None) => {
// No job found, wait for a bit before polling again.
trace!(worker_id = self.id, "No jobs available, waiting");
time::sleep(Duration::from_secs(5)).await;
}
Err(e) => {
warn!(worker_id = self.id, error = ?e, "Failed to fetch job");
// Wait before retrying to avoid spamming errors.
time::sleep(Duration::from_secs(10)).await;
}
}
}
}
/// Atomically fetches a job from the queue, locking it for processing.
///
/// This uses a `FOR UPDATE SKIP LOCKED` query to ensure that multiple
/// workers can poll the queue concurrently without conflicts.
async fn fetch_and_lock_job(&self) -> Result<Option<ScrapeJob>> {
let mut tx = self.db_pool.begin().await?;
let job = sqlx::query_as::<_, ScrapeJob>(
"SELECT * FROM scrape_jobs WHERE locked_at IS NULL AND execute_at <= NOW() ORDER BY priority DESC, execute_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED"
)
.fetch_optional(&mut *tx)
.await?;
if let Some(ref job) = job {
sqlx::query("UPDATE scrape_jobs SET locked_at = NOW() WHERE id = $1")
.bind(job.id)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(job)
}
async fn process_job(&self, job: ScrapeJob) -> Result<()> {
match job.target_type {
crate::data::models::TargetType::Subject => {
self.process_subject_job(&job.target_payload).await
}
_ => {
warn!(worker_id = self.id, job_id = job.id, "unhandled job type");
Ok(())
}
}
}
async fn process_subject_job(&self, payload: &Value) -> Result<()> {
let subject_code = payload["subject"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Invalid subject payload"))?;
info!(
worker_id = self.id,
subject = subject_code,
"Scraping subject"
);
let term = Term::get_current().inner().to_string();
let query = SearchQuery::new().subject(subject_code).max_results(500);
let search_result = self
.banner_api
.search(&term, &query, "subjectDescription", false)
.await?;
if let Some(courses_from_api) = search_result.data {
info!(
worker_id = self.id,
subject = subject_code,
count = courses_from_api.len(),
"Found courses"
);
for course in courses_from_api {
self.upsert_course(&course).await?;
}
}
Ok(())
}
async fn upsert_course(&self, course: &Course) -> Result<()> {
sqlx::query(
r#"
INSERT INTO courses (crn, subject, course_number, title, term_code, enrollment, max_enrollment, wait_count, wait_capacity, last_scraped_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (crn, term_code) DO UPDATE SET
subject = EXCLUDED.subject,
course_number = EXCLUDED.course_number,
title = EXCLUDED.title,
enrollment = EXCLUDED.enrollment,
max_enrollment = EXCLUDED.max_enrollment,
wait_count = EXCLUDED.wait_count,
wait_capacity = EXCLUDED.wait_capacity,
last_scraped_at = EXCLUDED.last_scraped_at
"#,
)
.bind(&course.course_reference_number)
.bind(&course.subject)
.bind(&course.course_number)
.bind(&course.course_title)
.bind(&course.term)
.bind(course.enrollment)
.bind(course.maximum_enrollment)
.bind(course.wait_count)
.bind(course.wait_capacity)
.bind(chrono::Utc::now())
.execute(&self.db_pool)
.await?;
Ok(())
}
async fn delete_job(&self, job_id: i32) -> Result<()> {
sqlx::query("DELETE FROM scrape_jobs WHERE id = $1")
.bind(job_id)
.execute(&self.db_pool)
.await?;
Ok(())
}
async fn unlock_job(&self, job_id: i32) -> Result<()> {
sqlx::query("UPDATE scrape_jobs SET locked_at = NULL WHERE id = $1")
.bind(job_id)
.execute(&self.db_pool)
.await?;
info!(worker_id = self.id, job_id, "Job unlocked for retry");
Ok(())
}
}

45
src/services/bot.rs Normal file
View File

@@ -0,0 +1,45 @@
use super::Service;
use serenity::Client;
use std::sync::Arc;
use tracing::{error, warn};
/// Discord bot service implementation
pub struct BotService {
client: Client,
shard_manager: Arc<serenity::gateway::ShardManager>,
}
impl BotService {
pub fn new(client: Client) -> Self {
let shard_manager = client.shard_manager.clone();
Self {
client,
shard_manager,
}
}
}
#[async_trait::async_trait]
impl Service for BotService {
fn name(&self) -> &'static str {
"bot"
}
async fn run(&mut self) -> Result<(), anyhow::Error> {
match self.client.start().await {
Ok(()) => {
warn!(service = "bot", "stopped early");
Err(anyhow::anyhow!("bot stopped early"))
}
Err(e) => {
error!(service = "bot", "error: {e:?}");
Err(e.into())
}
}
}
async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
self.shard_manager.shutdown_all().await;
Ok(())
}
}

164
src/services/manager.rs Normal file
View File

@@ -0,0 +1,164 @@
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, trace, warn};
use crate::services::{Service, ServiceResult, run_service};
/// Manages multiple services and their lifecycle
pub struct ServiceManager {
registered_services: HashMap<String, Box<dyn Service>>,
running_services: HashMap<String, JoinHandle<ServiceResult>>,
shutdown_tx: broadcast::Sender<()>,
}
impl Default for ServiceManager {
fn default() -> Self {
Self::new()
}
}
impl ServiceManager {
pub fn new() -> Self {
let (shutdown_tx, _) = broadcast::channel(1);
Self {
registered_services: HashMap::new(),
running_services: HashMap::new(),
shutdown_tx,
}
}
/// Register a service to be managed (not yet spawned)
pub fn register_service(&mut self, name: &str, service: Box<dyn Service>) {
self.registered_services.insert(name.to_string(), service);
}
/// Spawn all registered services
pub fn spawn_all(&mut self) {
let service_count = self.registered_services.len();
let service_names: Vec<_> = self.registered_services.keys().cloned().collect();
for (name, service) in self.registered_services.drain() {
let shutdown_rx = self.shutdown_tx.subscribe();
let handle = tokio::spawn(run_service(service, shutdown_rx));
debug!(service = name, id = ?handle.id(), "service spawned");
self.running_services.insert(name, handle);
}
info!(
service_count,
services = ?service_names,
"spawned {} services",
service_count
);
}
/// Run all services until one completes or fails
/// Returns the first service that completes and its result
pub async fn run(&mut self) -> (String, ServiceResult) {
if self.running_services.is_empty() {
return (
"none".to_string(),
ServiceResult::Error(anyhow::anyhow!("No services to run")),
);
}
info!(
"servicemanager running {} services",
self.running_services.len()
);
// Wait for any service to complete
loop {
let mut completed_services = Vec::new();
for (name, handle) in &mut self.running_services {
if handle.is_finished() {
completed_services.push(name.clone());
}
}
if let Some(completed_name) = completed_services.first() {
let handle = self.running_services.remove(completed_name).unwrap();
match handle.await {
Ok(result) => {
return (completed_name.clone(), result);
}
Err(e) => {
error!(service = completed_name, "service task panicked: {e}");
return (
completed_name.clone(),
ServiceResult::Error(anyhow::anyhow!("Task panic: {e}")),
);
}
}
}
// Small delay to prevent busy-waiting
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
/// Shutdown all services gracefully with a timeout.
///
/// If any service fails to shutdown, it will return an error containing the names of the services that failed to shutdown.
/// If all services shutdown successfully, the function will return the duration elapsed.
pub async fn shutdown(&mut self, timeout: Duration) -> Result<Duration, Vec<String>> {
let service_count = self.running_services.len();
let service_names: Vec<_> = self.running_services.keys().cloned().collect();
info!(
service_count,
services = ?service_names,
timeout = format!("{:.2?}", timeout),
"shutting down {} services with {:?} timeout",
service_count,
timeout
);
// Send shutdown signal to all services
let _ = self.shutdown_tx.send(());
// Wait for all services to complete
let start_time = std::time::Instant::now();
let mut pending_services = Vec::new();
for (name, handle) in self.running_services.drain() {
match tokio::time::timeout(timeout, handle).await {
Ok(Ok(_)) => {
trace!(service = name, "service shutdown completed");
}
Ok(Err(e)) => {
warn!(service = name, error = ?e, "service shutdown failed");
pending_services.push(name);
}
Err(_) => {
warn!(service = name, "service shutdown timed out");
pending_services.push(name);
}
}
}
let elapsed = start_time.elapsed();
if pending_services.is_empty() {
info!(
service_count,
elapsed = format!("{:.2?}", elapsed),
"services shutdown completed: {}",
service_names.join(", ")
);
Ok(elapsed)
} else {
warn!(
pending_count = pending_services.len(),
pending_services = ?pending_services,
elapsed = format!("{:.2?}", elapsed),
"services shutdown completed with {} pending: {}",
pending_services.len(),
pending_services.join(", ")
);
Err(pending_services)
}
}
}

71
src/services/mod.rs Normal file
View File

@@ -0,0 +1,71 @@
use tokio::sync::broadcast;
use tracing::{error, info, warn};
pub mod bot;
pub mod manager;
pub mod web;
#[derive(Debug)]
pub enum ServiceResult {
GracefulShutdown,
NormalCompletion,
Error(anyhow::Error),
}
/// Common trait for all services in the application
#[async_trait::async_trait]
pub trait Service: Send + Sync {
/// The name of the service for logging
fn name(&self) -> &'static str;
/// Run the service's main work loop
async fn run(&mut self) -> Result<(), anyhow::Error>;
/// Gracefully shutdown the service
///
/// An 'Ok' result does not mean the service has completed shutdown, it merely means that the service shutdown was initiated.
async fn shutdown(&mut self) -> Result<(), anyhow::Error>;
}
/// Generic service runner that handles the lifecycle
pub async fn run_service(
mut service: Box<dyn Service>,
mut shutdown_rx: broadcast::Receiver<()>,
) -> ServiceResult {
let name = service.name();
info!(service = name, "service started");
let work = async {
match service.run().await {
Ok(()) => {
warn!(service = name, "service completed unexpectedly");
ServiceResult::NormalCompletion
}
Err(e) => {
error!(service = name, "service failed: {e}");
ServiceResult::Error(e)
}
}
};
tokio::select! {
result = work => result,
_ = shutdown_rx.recv() => {
info!(service = name, "shutting down...");
let start_time = std::time::Instant::now();
match service.shutdown().await {
Ok(()) => {
let elapsed = start_time.elapsed();
info!(service = name, "shutdown completed in {elapsed:.2?}");
ServiceResult::GracefulShutdown
}
Err(e) => {
let elapsed = start_time.elapsed();
error!(service = name, "shutdown failed after {elapsed:.2?}: {e}");
ServiceResult::Error(e)
}
}
}
}
}

79
src/services/web.rs Normal file
View File

@@ -0,0 +1,79 @@
use super::Service;
use crate::web::{BannerState, create_router};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tracing::{info, warn, trace};
/// Web server service implementation
pub struct WebService {
port: u16,
banner_state: BannerState,
shutdown_tx: Option<broadcast::Sender<()>>,
}
impl WebService {
pub fn new(port: u16, banner_state: BannerState) -> Self {
Self {
port,
banner_state,
shutdown_tx: None,
}
}
}
#[async_trait::async_trait]
impl Service for WebService {
fn name(&self) -> &'static str {
"web"
}
async fn run(&mut self) -> Result<(), anyhow::Error> {
// Create the main router with Banner API routes
let app = create_router(self.banner_state.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
info!(
service = "web",
link = format!("http://localhost:{}", addr.port()),
"starting web server",
);
let listener = TcpListener::bind(addr).await?;
info!(
service = "web",
address = %addr,
"web server listening"
);
// Create internal shutdown channel for axum graceful shutdown
let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1);
self.shutdown_tx = Some(shutdown_tx);
// Use axum's graceful shutdown with the internal shutdown signal
axum::serve(listener, app)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.recv().await;
trace!(
service = "web",
"received shutdown signal, starting graceful shutdown"
);
})
.await?;
info!(service = "web", "web server stopped");
Ok(())
}
async fn shutdown(&mut self) -> Result<(), anyhow::Error> {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
} else {
warn!(
service = "web",
"no shutdown channel found, cannot trigger graceful shutdown"
);
}
Ok(())
}
}

48
src/state.rs Normal file
View File

@@ -0,0 +1,48 @@
//! Application state shared across components (bot, web, scheduler).
use crate::banner::BannerApi;
use crate::banner::Course;
use anyhow::Result;
use redis::AsyncCommands;
use redis::Client;
use std::sync::Arc;
#[derive(Clone)]
pub struct AppState {
pub banner_api: Arc<BannerApi>,
pub redis: Arc<Client>,
}
impl AppState {
pub fn new(
banner_api: Arc<BannerApi>,
redis_url: &str,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let redis_client = Client::open(redis_url)?;
Ok(Self {
banner_api,
redis: Arc::new(redis_client),
})
}
/// Get a course by CRN with Redis cache fallback to Banner API
pub async fn get_course_or_fetch(&self, term: &str, crn: &str) -> Result<Course> {
let mut conn = self.redis.get_multiplexed_async_connection().await?;
let key = format!("class:{crn}");
if let Some(serialized) = conn.get::<_, Option<String>>(&key).await? {
let course: Course = serde_json::from_str(&serialized)?;
return Ok(course);
}
// Fallback: fetch from Banner API
if let Some(course) = self.banner_api.get_course_by_crn(term, crn).await? {
let serialized = serde_json::to_string(&course)?;
let _: () = conn.set(&key, serialized).await?;
return Ok(course);
}
Err(anyhow::anyhow!("Course not found for CRN {crn}"))
}
}

5
src/web/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
//! Web API module for the banner application.
pub mod routes;
pub use routes::*;

87
src/web/routes.rs Normal file
View File

@@ -0,0 +1,87 @@
//! Web API endpoints for Banner bot monitoring and metrics.
use axum::{Router, extract::State, response::Json, routing::get};
use serde_json::{Value, json};
use std::sync::Arc;
use tracing::info;
use crate::banner::BannerApi;
/// Shared application state for web server
#[derive(Clone)]
pub struct BannerState {
pub api: Arc<BannerApi>,
}
/// Creates the web server router
pub fn create_router(state: BannerState) -> Router {
Router::new()
.route("/", get(root))
.route("/health", get(health))
.route("/status", get(status))
.route("/metrics", get(metrics))
.with_state(state)
}
async fn root() -> Json<Value> {
Json(json!({
"message": "Banner Discord Bot API",
"version": "0.1.0",
"endpoints": {
"health": "/health",
"status": "/status",
"metrics": "/metrics"
}
}))
}
/// Health check endpoint
async fn health() -> Json<Value> {
info!("health check requested");
Json(json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}
/// Status endpoint showing bot and system status
async fn status(State(_state): State<BannerState>) -> Json<Value> {
// For now, return basic status without accessing private fields
Json(json!({
"status": "operational",
"bot": {
"status": "running",
"uptime": "TODO: implement uptime tracking"
},
"cache": {
"status": "connected",
"courses": "TODO: implement course counting",
"subjects": "TODO: implement subject counting"
},
"banner_api": {
"status": "connected"
},
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}
/// Metrics endpoint for monitoring
async fn metrics(State(_state): State<BannerState>) -> Json<Value> {
// For now, return basic metrics structure
Json(json!({
"redis": {
"status": "connected",
"connected_clients": "TODO: implement client counting",
"used_memory": "TODO: implement memory tracking"
},
"cache": {
"courses": {
"count": "TODO: implement course counting"
},
"subjects": {
"count": "TODO: implement subject counting"
}
},
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}

View File

@@ -1,229 +0,0 @@
package config_test
import (
"banner/internal/config"
"testing"
"time"
)
func TestGetCurrentTerm(t *testing.T) {
// Initialize location for testing
loc, _ := time.LoadLocation("America/Chicago")
// Use current year to avoid issues with global state
currentYear := uint16(time.Now().Year())
ranges := config.GetYearDayRange(loc, currentYear)
tests := []struct {
name string
date time.Time
expectedCurrent *config.Term
expectedNext *config.Term
}{
{
name: "Spring term",
date: time.Date(int(currentYear), 3, 15, 12, 0, 0, 0, loc),
expectedCurrent: &config.Term{Year: currentYear, Season: config.Spring},
expectedNext: &config.Term{Year: currentYear, Season: config.Summer},
},
{
name: "Summer term",
date: time.Date(int(currentYear), 6, 15, 12, 0, 0, 0, loc),
expectedCurrent: &config.Term{Year: currentYear, Season: config.Summer},
expectedNext: &config.Term{Year: currentYear, Season: config.Fall},
},
{
name: "Fall term",
date: time.Date(int(currentYear), 9, 15, 12, 0, 0, 0, loc),
expectedCurrent: &config.Term{Year: currentYear + 1, Season: config.Fall},
expectedNext: nil,
},
{
name: "Between Spring and Summer",
date: time.Date(int(currentYear), 5, 20, 12, 0, 0, 0, loc),
expectedCurrent: nil,
expectedNext: &config.Term{Year: currentYear, Season: config.Summer},
},
{
name: "Between Summer and Fall",
date: time.Date(int(currentYear), 8, 16, 12, 0, 0, 0, loc),
expectedCurrent: nil,
expectedNext: &config.Term{Year: currentYear + 1, Season: config.Fall},
},
{
name: "Between Fall and Spring",
date: time.Date(int(currentYear), 12, 15, 12, 0, 0, 0, loc),
expectedCurrent: nil,
expectedNext: &config.Term{Year: currentYear + 1, Season: config.Spring},
},
{
name: "Early January before Spring",
date: time.Date(int(currentYear), 1, 10, 12, 0, 0, 0, loc),
expectedCurrent: nil,
expectedNext: &config.Term{Year: currentYear, Season: config.Spring},
},
{
name: "Spring start date",
date: time.Date(int(currentYear), 1, 14, 0, 0, 0, 0, loc),
expectedCurrent: &config.Term{Year: currentYear, Season: config.Spring},
expectedNext: &config.Term{Year: currentYear, Season: config.Summer},
},
{
name: "Summer start date",
date: time.Date(int(currentYear), 5, 25, 0, 0, 0, 0, loc),
expectedCurrent: &config.Term{Year: currentYear, Season: config.Summer},
expectedNext: &config.Term{Year: currentYear, Season: config.Fall},
},
{
name: "Fall start date",
date: time.Date(int(currentYear), 8, 18, 0, 0, 0, 0, loc),
expectedCurrent: &config.Term{Year: currentYear + 1, Season: config.Fall},
expectedNext: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
current, next := config.GetCurrentTerm(ranges, tt.date)
if !termsEqual(current, tt.expectedCurrent) {
t.Errorf("GetCurrentTerm() current = %v, want %v", current, tt.expectedCurrent)
}
if !termsEqual(next, tt.expectedNext) {
t.Errorf("GetCurrentTerm() next = %v, want %v", next, tt.expectedNext)
}
})
}
}
func TestGetYearDayRange(t *testing.T) {
loc, _ := time.LoadLocation("America/Chicago")
ranges := config.GetYearDayRange(loc, 2024)
// Verify Spring range (Jan 14 to May 1)
expectedSpringStart := time.Date(2024, 1, 14, 0, 0, 0, 0, loc).YearDay()
expectedSpringEnd := time.Date(2024, 5, 1, 0, 0, 0, 0, loc).YearDay()
if ranges.Spring.Start != uint16(expectedSpringStart) {
t.Errorf("Spring start = %d, want %d", ranges.Spring.Start, expectedSpringStart)
}
if ranges.Spring.End != uint16(expectedSpringEnd) {
t.Errorf("Spring end = %d, want %d", ranges.Spring.End, expectedSpringEnd)
}
// Verify Summer range (May 25 to Aug 15)
expectedSummerStart := time.Date(2024, 5, 25, 0, 0, 0, 0, loc).YearDay()
expectedSummerEnd := time.Date(2024, 8, 15, 0, 0, 0, 0, loc).YearDay()
if ranges.Summer.Start != uint16(expectedSummerStart) {
t.Errorf("Summer start = %d, want %d", ranges.Summer.Start, expectedSummerStart)
}
if ranges.Summer.End != uint16(expectedSummerEnd) {
t.Errorf("Summer end = %d, want %d", ranges.Summer.End, expectedSummerEnd)
}
// Verify Fall range (Aug 18 to Dec 10)
expectedFallStart := time.Date(2024, 8, 18, 0, 0, 0, 0, loc).YearDay()
expectedFallEnd := time.Date(2024, 12, 10, 0, 0, 0, 0, loc).YearDay()
if ranges.Fall.Start != uint16(expectedFallStart) {
t.Errorf("Fall start = %d, want %d", ranges.Fall.Start, expectedFallStart)
}
if ranges.Fall.End != uint16(expectedFallEnd) {
t.Errorf("Fall end = %d, want %d", ranges.Fall.End, expectedFallEnd)
}
}
func TestParseTerm(t *testing.T) {
tests := []struct {
code string
expected config.Term
}{
{"202410", config.Term{Year: 2024, Season: config.Fall}},
{"202420", config.Term{Year: 2024, Season: config.Spring}},
{"202430", config.Term{Year: 2024, Season: config.Summer}},
{"202510", config.Term{Year: 2025, Season: config.Fall}},
}
for _, tt := range tests {
t.Run(tt.code, func(t *testing.T) {
result := config.ParseTerm(tt.code)
if result != tt.expected {
t.Errorf("ParseTerm(%s) = %v, want %v", tt.code, result, tt.expected)
}
})
}
}
func TestTermToString(t *testing.T) {
tests := []struct {
term config.Term
expected string
}{
{config.Term{Year: 2024, Season: config.Fall}, "202410"},
{config.Term{Year: 2024, Season: config.Spring}, "202420"},
{config.Term{Year: 2024, Season: config.Summer}, "202430"},
{config.Term{Year: 2025, Season: config.Fall}, "202510"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := tt.term.ToString()
if result != tt.expected {
t.Errorf("Term{Year: %d, Season: %d}.ToString() = %s, want %s",
tt.term.Year, tt.term.Season, result, tt.expected)
}
})
}
}
func TestDefaultTerm(t *testing.T) {
loc, _ := time.LoadLocation("America/Chicago")
ranges := config.GetYearDayRange(loc, 2024)
tests := []struct {
name string
date time.Time
expected config.Term
}{
{
name: "During Spring term",
date: time.Date(2024, 3, 15, 12, 0, 0, 0, loc),
expected: config.Term{Year: 2024, Season: config.Spring},
},
{
name: "Between terms - returns next term",
date: time.Date(2024, 5, 20, 12, 0, 0, 0, loc),
expected: config.Term{Year: 2024, Season: config.Summer},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
currentTerm, nextTerm := config.GetCurrentTerm(ranges, tt.date)
var result config.Term
if currentTerm == nil {
result = *nextTerm
} else {
result = *currentTerm
}
if result != tt.expected {
t.Errorf("DefaultTerm() = %v, want %v", result, tt.expected)
}
})
}
}
// Helper function to compare terms, handling nil cases
func termsEqual(a, b *config.Term) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}