feat: middleware headers, fix concurrent session cookies issue, middleware headers, invalid session details

This commit is contained in:
2025-09-12 20:12:12 -05:00
parent dd212c3239
commit 2f853a7de9
4 changed files with 131 additions and 46 deletions

View File

@@ -10,18 +10,19 @@ use crate::banner::{
BannerSession, SessionPool, models::*, nonce, query::SearchQuery, util::user_agent,
};
use anyhow::{Context, Result, anyhow};
use axum::http::{Extensions, HeaderValue};
use cookie::Cookie;
use dashmap::DashMap;
use http::{Extensions, HeaderValue};
use reqwest::{Client, Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use serde_json;
use tracing::{Level, Metadata, Span, debug, error, field::ValueSet, info, span};
use tl;
use tracing::{Level, Metadata, Span, debug, error, field::ValueSet, info, span, trace, warn};
#[derive(Debug, thiserror::Error)]
pub enum BannerApiError {
#[error("Banner session is invalid or expired")]
InvalidSession,
#[error("Banner session is invalid or expired: {0}")]
InvalidSession(String),
#[error(transparent)]
RequestFailed(#[from] anyhow::Error),
}
@@ -43,29 +44,36 @@ impl Middleware for TransparentMiddleware {
extensions: &mut Extensions,
next: Next<'_>,
) -> std::result::Result<Response, reqwest_middleware::Error> {
debug!(
trace!(
domain = req.url().domain(),
headers = ?req.headers(),
"{method} {path}",
method = req.method().to_string(),
path = req.url().path(),
);
let response = next.run(req, extensions).await;
let response_result = next.run(req, extensions).await;
match &response {
match response_result {
Ok(response) => {
debug!(
"{code} {reason} {path}",
code = response.status().as_u16(),
reason = response.status().canonical_reason().unwrap_or("??"),
path = response.url().path(),
);
if response.status().is_success() {
trace!(
"{code} {reason} {path}",
code = response.status().as_u16(),
reason = response.status().canonical_reason().unwrap_or("??"),
path = response.url().path(),
);
Ok(response)
} else {
let e = response.error_for_status_ref().unwrap_err();
warn!(error = ?e, "Request failed (server)");
Ok(response)
}
}
Err(error) => {
debug!("!!! {error}");
warn!(?error, "Request failed (middleware)");
Err(error)
}
}
response
}
}
@@ -74,7 +82,7 @@ impl BannerApi {
pub fn new(base_url: String) -> Result<Self> {
let http = ClientBuilder::new(
Client::builder()
.cookie_store(true)
.cookie_store(false)
.user_agent(user_agent())
.tcp_keepalive(Some(std::time::Duration::from_secs(60 * 5)))
.read_timeout(std::time::Duration::from_secs(10))
@@ -285,10 +293,24 @@ impl BannerApi {
params.insert("startDatepicker".to_string(), String::new());
params.insert("endDatepicker".to_string(), String::new());
let url = format!("{}/searchResults/searchResults", self.base_url);
if session.been_used() {
self.http
.post(&format!("{}/classSearch/resetDataForm", self.base_url))
.send()
.await
.map_err(|e| BannerApiError::RequestFailed(e.into()))?;
}
debug!(
term = term,
query = ?query,
sort = sort,
sort_descending = sort_descending,
"Searching for courses with params: {:?}", params);
let response = self
.http
.get(&url)
.get(format!("{}/searchResults/searchResults", self.base_url))
.header("Cookie", session.cookie())
.query(&params)
.send()
@@ -309,8 +331,14 @@ impl BannerApi {
})?;
// Check for signs of an invalid session, based on docs/Sessions.md
if search_result.path_mode.is_none() || search_result.data.is_none() {
return Err(BannerApiError::InvalidSession);
if search_result.path_mode.is_none() {
return Err(BannerApiError::InvalidSession(
"Search result path mode is none".to_string(),
));
} else if search_result.data.is_none() {
return Err(BannerApiError::InvalidSession(
"Search result data is none".to_string(),
));
}
if !search_result.success {
@@ -373,7 +401,9 @@ impl BannerApi {
if search_result.path_mode == Some("registration".to_string())
&& search_result.data.is_none()
{
return Err(BannerApiError::InvalidSession);
return Err(BannerApiError::InvalidSession(
"Search result path mode is registration and data is none".to_string(),
));
}
if !search_result.success {

View File

@@ -22,7 +22,7 @@ const SESSION_EXPIRY: Duration = Duration::from_secs(25 * 60); // 25 minutes
#[derive(Debug, Clone)]
pub struct BannerSession {
// Randomly generated
unique_session_id: String,
pub unique_session_id: String,
// Timestamp of creation
created_at: Instant,
// Timestamp of last activity
@@ -72,7 +72,7 @@ impl BannerSession {
/// Updates the last activity timestamp
pub fn touch(&mut self) {
debug!("Session {} is being used", self.unique_session_id);
debug!(id = self.unique_session_id, "Session was used");
self.last_activity = Some(Instant::now());
}
@@ -88,6 +88,10 @@ impl BannerSession {
self.jsessionid, self.ssb_cookie
)
}
pub fn been_used(&self) -> bool {
self.last_activity.is_some()
}
}
/// A smart pointer that returns a BannerSession to the pool when dropped.
@@ -97,6 +101,12 @@ pub struct PooledSession {
pool: Arc<Mutex<VecDeque<BannerSession>>>,
}
impl PooledSession {
pub fn been_used(&self) -> bool {
self.session.as_ref().unwrap().been_used()
}
}
impl Deref for PooledSession {
type Target = BannerSession;
fn deref(&self) -> &Self::Target {
@@ -117,17 +127,24 @@ impl Drop for PooledSession {
if let Some(session) = self.session.take() {
// Don't return expired sessions to the pool.
if session.is_expired() {
debug!("Session {} expired, dropping.", session.unique_session_id);
debug!(
id = session.unique_session_id,
"Session is now expired, dropping."
);
return;
}
// This is a synchronous lock, so it's allowed in drop().
// It blocks the current thread briefly to return the session.
let mut queue = self.pool.lock().unwrap();
let id = session.unique_session_id.clone();
queue.push_back(session);
debug!(
"Session returned to pool. Queue size is now {}.",
queue.len()
id = id,
"Session returned to pool. Queue size is now {queue_size}.",
queue_size = queue.len(),
);
}
}
@@ -168,7 +185,7 @@ impl SessionPool {
if let Some(mut session) = session_option {
// We got a session, check if it's expired.
if !session.is_expired() {
debug!("Reusing session {}", session.unique_session_id);
debug!(id = session.unique_session_id, "Reusing session");
session.touch();
return Ok(PooledSession {
@@ -177,8 +194,8 @@ impl SessionPool {
});
} else {
debug!(
"Popped an expired session {}, discarding.",
session.unique_session_id
id = session.unique_session_id,
"Popped an expired session, discarding.",
);
// The session is expired, so we loop again to try and get another one.
}
@@ -197,7 +214,7 @@ impl SessionPool {
/// Sets up initial session cookies by making required Banner API requests
pub async fn create_session(&self, term: &Term) -> Result<BannerSession> {
info!("setting up banner session...");
info!("setting up banner session for term {term}");
// The 'register' or 'search' registration page
let initial_registration = self
@@ -220,32 +237,53 @@ impl SessionPool {
})
.collect::<HashMap<String, String>>();
if cookies.get("JSESSIONID").is_none() || cookies.get("SSB_COOKIE").is_none() {
return Err(anyhow::anyhow!("Failed to get cookies"));
}
let jsessionid = cookies.get("JSESSIONID").unwrap();
let ssb_cookie = cookies.get("SSB_COOKIE").unwrap();
let cookie_header = format!("JSESSIONID={}; SSB_COOKIE={}", jsessionid, ssb_cookie);
let data_page_response = self
.http
debug!(
jsessionid = jsessionid,
ssb_cookie = ssb_cookie,
"New session cookies acquired"
);
self.http
.get(format!("{}/selfServiceMenu/data", self.base_url))
.header("Cookie", &cookie_header)
.send()
.await?;
// TODO: Validate success
.await?
.error_for_status()
.context("Failed to get data page")?;
let term_selection_page_response = self
.http
self.http
.get(format!("{}/term/termSelection", self.base_url))
.header("Cookie", &cookie_header)
.query(&[("mode", "search")])
.send()
.await?;
.await?
.error_for_status()
.context("Failed to get term selection page")?;
// TOOD: Validate success
let term_search_response = self.get_terms("", 1, 10).await?;
// TODO: Validate that the term search response contains the term we want
/*let terms = self.get_terms("", 1, 10).await?;
if !terms.iter().any(|t| t.code == term.to_string()) {
return Err(anyhow::anyhow!("Failed to get term search response"));
}
let specific_term_search_response = self.get_terms(&term.to_string(), 1, 10).await?;
// TODO: Validate that the term response contains the term we want
if !specific_term_search_response
.iter()
.any(|t| t.code == term.to_string())
{
return Err(anyhow::anyhow!("Failed to get term search response"));
}*/
let unique_session_id = generate_session_id();
self.select_term(&term.to_string(), &unique_session_id)
self.select_term(&term.to_string(), &unique_session_id, &cookie_header)
.await?;
BannerSession::new(&unique_session_id, jsessionid, ssb_cookie).await
@@ -287,7 +325,12 @@ impl SessionPool {
}
/// Selects a term for the current session
pub async fn select_term(&self, term: &str, unique_session_id: &str) -> Result<()> {
pub async fn select_term(
&self,
term: &str,
unique_session_id: &str,
cookie_header: &str,
) -> Result<()> {
let form_data = [
("term", term),
("studyPath", ""),
@@ -301,6 +344,7 @@ impl SessionPool {
let response = self
.http
.post(&url)
.header("Cookie", cookie_header)
.query(&[("mode", "search")])
.form(&form_data)
.send()
@@ -327,7 +371,12 @@ impl SessionPool {
// Follow the redirect
let redirect_url = format!("{}{}", self.base_url, non_overlap_redirect);
let redirect_response = self.http.get(&redirect_url).send().await?;
let redirect_response = self
.http
.get(&redirect_url)
.header("Cookie", cookie_header)
.send()
.await?;
if !redirect_response.status().is_success() {
return Err(anyhow::anyhow!(
@@ -336,7 +385,7 @@ impl SessionPool {
));
}
debug!("successfully selected term: {}", term);
debug!(term = term, "successfully selected term");
Ok(())
}
}

View File

@@ -18,6 +18,7 @@ pub struct Config {
///
/// Valid values are: "trace", "debug", "info", "warn", "error"
/// Defaults to "info" if not specified
#[serde(default = "default_log_level")]
pub log_level: String,
/// Discord bot token for authentication
pub bot_token: String,
@@ -43,6 +44,11 @@ pub struct Config {
pub shutdown_timeout: Duration,
}
/// Default log level of "info"
fn default_log_level() -> String {
"info".to_string()
}
/// Default port of 3000
fn default_port() -> u16 {
3000

View File

@@ -37,7 +37,7 @@ impl Worker {
info!(worker_id = self.id, job_id = job.id, "Processing job");
if let Err(e) = self.process_job(job).await {
// Check if the error is due to an invalid session
if let Some(BannerApiError::InvalidSession) =
if let Some(BannerApiError::InvalidSession(_)) =
e.downcast_ref::<BannerApiError>()
{
warn!(