feat: middleware headers, fix concurrent session cookies issue, middleware headers, invalid session details

This commit is contained in:
2025-09-12 20:12:12 -05:00
parent dd212c3239
commit 2f853a7de9
4 changed files with 131 additions and 46 deletions

View File

@@ -10,18 +10,19 @@ use crate::banner::{
BannerSession, SessionPool, models::*, nonce, query::SearchQuery, util::user_agent, BannerSession, SessionPool, models::*, nonce, query::SearchQuery, util::user_agent,
}; };
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, Result, anyhow};
use axum::http::{Extensions, HeaderValue};
use cookie::Cookie; use cookie::Cookie;
use dashmap::DashMap; use dashmap::DashMap;
use http::{Extensions, HeaderValue};
use reqwest::{Client, Request, Response}; use reqwest::{Client, Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use serde_json; use serde_json;
use tracing::{Level, Metadata, Span, debug, error, field::ValueSet, info, span}; use tl;
use tracing::{Level, Metadata, Span, debug, error, field::ValueSet, info, span, trace, warn};
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum BannerApiError { pub enum BannerApiError {
#[error("Banner session is invalid or expired")] #[error("Banner session is invalid or expired: {0}")]
InvalidSession, InvalidSession(String),
#[error(transparent)] #[error(transparent)]
RequestFailed(#[from] anyhow::Error), RequestFailed(#[from] anyhow::Error),
} }
@@ -43,29 +44,36 @@ impl Middleware for TransparentMiddleware {
extensions: &mut Extensions, extensions: &mut Extensions,
next: Next<'_>, next: Next<'_>,
) -> std::result::Result<Response, reqwest_middleware::Error> { ) -> std::result::Result<Response, reqwest_middleware::Error> {
debug!( trace!(
domain = req.url().domain(), domain = req.url().domain(),
headers = ?req.headers(),
"{method} {path}", "{method} {path}",
method = req.method().to_string(), method = req.method().to_string(),
path = req.url().path(), path = req.url().path(),
); );
let response = next.run(req, extensions).await; let response_result = next.run(req, extensions).await;
match &response { match response_result {
Ok(response) => { Ok(response) => {
debug!( if response.status().is_success() {
"{code} {reason} {path}", trace!(
code = response.status().as_u16(), "{code} {reason} {path}",
reason = response.status().canonical_reason().unwrap_or("??"), code = response.status().as_u16(),
path = response.url().path(), reason = response.status().canonical_reason().unwrap_or("??"),
); path = response.url().path(),
);
Ok(response)
} else {
let e = response.error_for_status_ref().unwrap_err();
warn!(error = ?e, "Request failed (server)");
Ok(response)
}
} }
Err(error) => { Err(error) => {
debug!("!!! {error}"); warn!(?error, "Request failed (middleware)");
Err(error)
} }
} }
response
} }
} }
@@ -74,7 +82,7 @@ impl BannerApi {
pub fn new(base_url: String) -> Result<Self> { pub fn new(base_url: String) -> Result<Self> {
let http = ClientBuilder::new( let http = ClientBuilder::new(
Client::builder() Client::builder()
.cookie_store(true) .cookie_store(false)
.user_agent(user_agent()) .user_agent(user_agent())
.tcp_keepalive(Some(std::time::Duration::from_secs(60 * 5))) .tcp_keepalive(Some(std::time::Duration::from_secs(60 * 5)))
.read_timeout(std::time::Duration::from_secs(10)) .read_timeout(std::time::Duration::from_secs(10))
@@ -285,10 +293,24 @@ impl BannerApi {
params.insert("startDatepicker".to_string(), String::new()); params.insert("startDatepicker".to_string(), String::new());
params.insert("endDatepicker".to_string(), String::new()); params.insert("endDatepicker".to_string(), String::new());
let url = format!("{}/searchResults/searchResults", self.base_url); if session.been_used() {
self.http
.post(&format!("{}/classSearch/resetDataForm", self.base_url))
.send()
.await
.map_err(|e| BannerApiError::RequestFailed(e.into()))?;
}
debug!(
term = term,
query = ?query,
sort = sort,
sort_descending = sort_descending,
"Searching for courses with params: {:?}", params);
let response = self let response = self
.http .http
.get(&url) .get(format!("{}/searchResults/searchResults", self.base_url))
.header("Cookie", session.cookie()) .header("Cookie", session.cookie())
.query(&params) .query(&params)
.send() .send()
@@ -309,8 +331,14 @@ impl BannerApi {
})?; })?;
// Check for signs of an invalid session, based on docs/Sessions.md // Check for signs of an invalid session, based on docs/Sessions.md
if search_result.path_mode.is_none() || search_result.data.is_none() { if search_result.path_mode.is_none() {
return Err(BannerApiError::InvalidSession); return Err(BannerApiError::InvalidSession(
"Search result path mode is none".to_string(),
));
} else if search_result.data.is_none() {
return Err(BannerApiError::InvalidSession(
"Search result data is none".to_string(),
));
} }
if !search_result.success { if !search_result.success {
@@ -373,7 +401,9 @@ impl BannerApi {
if search_result.path_mode == Some("registration".to_string()) if search_result.path_mode == Some("registration".to_string())
&& search_result.data.is_none() && search_result.data.is_none()
{ {
return Err(BannerApiError::InvalidSession); return Err(BannerApiError::InvalidSession(
"Search result path mode is registration and data is none".to_string(),
));
} }
if !search_result.success { if !search_result.success {

View File

@@ -22,7 +22,7 @@ const SESSION_EXPIRY: Duration = Duration::from_secs(25 * 60); // 25 minutes
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BannerSession { pub struct BannerSession {
// Randomly generated // Randomly generated
unique_session_id: String, pub unique_session_id: String,
// Timestamp of creation // Timestamp of creation
created_at: Instant, created_at: Instant,
// Timestamp of last activity // Timestamp of last activity
@@ -72,7 +72,7 @@ impl BannerSession {
/// Updates the last activity timestamp /// Updates the last activity timestamp
pub fn touch(&mut self) { pub fn touch(&mut self) {
debug!("Session {} is being used", self.unique_session_id); debug!(id = self.unique_session_id, "Session was used");
self.last_activity = Some(Instant::now()); self.last_activity = Some(Instant::now());
} }
@@ -88,6 +88,10 @@ impl BannerSession {
self.jsessionid, self.ssb_cookie self.jsessionid, self.ssb_cookie
) )
} }
pub fn been_used(&self) -> bool {
self.last_activity.is_some()
}
} }
/// A smart pointer that returns a BannerSession to the pool when dropped. /// A smart pointer that returns a BannerSession to the pool when dropped.
@@ -97,6 +101,12 @@ pub struct PooledSession {
pool: Arc<Mutex<VecDeque<BannerSession>>>, pool: Arc<Mutex<VecDeque<BannerSession>>>,
} }
impl PooledSession {
pub fn been_used(&self) -> bool {
self.session.as_ref().unwrap().been_used()
}
}
impl Deref for PooledSession { impl Deref for PooledSession {
type Target = BannerSession; type Target = BannerSession;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@@ -117,17 +127,24 @@ impl Drop for PooledSession {
if let Some(session) = self.session.take() { if let Some(session) = self.session.take() {
// Don't return expired sessions to the pool. // Don't return expired sessions to the pool.
if session.is_expired() { if session.is_expired() {
debug!("Session {} expired, dropping.", session.unique_session_id); debug!(
id = session.unique_session_id,
"Session is now expired, dropping."
);
return; return;
} }
// This is a synchronous lock, so it's allowed in drop(). // This is a synchronous lock, so it's allowed in drop().
// It blocks the current thread briefly to return the session. // It blocks the current thread briefly to return the session.
let mut queue = self.pool.lock().unwrap(); let mut queue = self.pool.lock().unwrap();
let id = session.unique_session_id.clone();
queue.push_back(session); queue.push_back(session);
debug!( debug!(
"Session returned to pool. Queue size is now {}.", id = id,
queue.len() "Session returned to pool. Queue size is now {queue_size}.",
queue_size = queue.len(),
); );
} }
} }
@@ -168,7 +185,7 @@ impl SessionPool {
if let Some(mut session) = session_option { if let Some(mut session) = session_option {
// We got a session, check if it's expired. // We got a session, check if it's expired.
if !session.is_expired() { if !session.is_expired() {
debug!("Reusing session {}", session.unique_session_id); debug!(id = session.unique_session_id, "Reusing session");
session.touch(); session.touch();
return Ok(PooledSession { return Ok(PooledSession {
@@ -177,8 +194,8 @@ impl SessionPool {
}); });
} else { } else {
debug!( debug!(
"Popped an expired session {}, discarding.", id = session.unique_session_id,
session.unique_session_id "Popped an expired session, discarding.",
); );
// The session is expired, so we loop again to try and get another one. // The session is expired, so we loop again to try and get another one.
} }
@@ -197,7 +214,7 @@ impl SessionPool {
/// Sets up initial session cookies by making required Banner API requests /// Sets up initial session cookies by making required Banner API requests
pub async fn create_session(&self, term: &Term) -> Result<BannerSession> { pub async fn create_session(&self, term: &Term) -> Result<BannerSession> {
info!("setting up banner session..."); info!("setting up banner session for term {term}");
// The 'register' or 'search' registration page // The 'register' or 'search' registration page
let initial_registration = self let initial_registration = self
@@ -220,32 +237,53 @@ impl SessionPool {
}) })
.collect::<HashMap<String, String>>(); .collect::<HashMap<String, String>>();
if cookies.get("JSESSIONID").is_none() || cookies.get("SSB_COOKIE").is_none() {
return Err(anyhow::anyhow!("Failed to get cookies"));
}
let jsessionid = cookies.get("JSESSIONID").unwrap(); let jsessionid = cookies.get("JSESSIONID").unwrap();
let ssb_cookie = cookies.get("SSB_COOKIE").unwrap(); let ssb_cookie = cookies.get("SSB_COOKIE").unwrap();
let cookie_header = format!("JSESSIONID={}; SSB_COOKIE={}", jsessionid, ssb_cookie);
let data_page_response = self debug!(
.http jsessionid = jsessionid,
ssb_cookie = ssb_cookie,
"New session cookies acquired"
);
self.http
.get(format!("{}/selfServiceMenu/data", self.base_url)) .get(format!("{}/selfServiceMenu/data", self.base_url))
.header("Cookie", &cookie_header)
.send() .send()
.await?; .await?
// TODO: Validate success .error_for_status()
.context("Failed to get data page")?;
let term_selection_page_response = self self.http
.http
.get(format!("{}/term/termSelection", self.base_url)) .get(format!("{}/term/termSelection", self.base_url))
.header("Cookie", &cookie_header)
.query(&[("mode", "search")]) .query(&[("mode", "search")])
.send() .send()
.await?; .await?
.error_for_status()
.context("Failed to get term selection page")?;
// TOOD: Validate success // TOOD: Validate success
let term_search_response = self.get_terms("", 1, 10).await?; /*let terms = self.get_terms("", 1, 10).await?;
// TODO: Validate that the term search response contains the term we want if !terms.iter().any(|t| t.code == term.to_string()) {
return Err(anyhow::anyhow!("Failed to get term search response"));
}
let specific_term_search_response = self.get_terms(&term.to_string(), 1, 10).await?; let specific_term_search_response = self.get_terms(&term.to_string(), 1, 10).await?;
// TODO: Validate that the term response contains the term we want if !specific_term_search_response
.iter()
.any(|t| t.code == term.to_string())
{
return Err(anyhow::anyhow!("Failed to get term search response"));
}*/
let unique_session_id = generate_session_id(); let unique_session_id = generate_session_id();
self.select_term(&term.to_string(), &unique_session_id) self.select_term(&term.to_string(), &unique_session_id, &cookie_header)
.await?; .await?;
BannerSession::new(&unique_session_id, jsessionid, ssb_cookie).await BannerSession::new(&unique_session_id, jsessionid, ssb_cookie).await
@@ -287,7 +325,12 @@ impl SessionPool {
} }
/// Selects a term for the current session /// Selects a term for the current session
pub async fn select_term(&self, term: &str, unique_session_id: &str) -> Result<()> { pub async fn select_term(
&self,
term: &str,
unique_session_id: &str,
cookie_header: &str,
) -> Result<()> {
let form_data = [ let form_data = [
("term", term), ("term", term),
("studyPath", ""), ("studyPath", ""),
@@ -301,6 +344,7 @@ impl SessionPool {
let response = self let response = self
.http .http
.post(&url) .post(&url)
.header("Cookie", cookie_header)
.query(&[("mode", "search")]) .query(&[("mode", "search")])
.form(&form_data) .form(&form_data)
.send() .send()
@@ -327,7 +371,12 @@ impl SessionPool {
// Follow the redirect // Follow the redirect
let redirect_url = format!("{}{}", self.base_url, non_overlap_redirect); let redirect_url = format!("{}{}", self.base_url, non_overlap_redirect);
let redirect_response = self.http.get(&redirect_url).send().await?; let redirect_response = self
.http
.get(&redirect_url)
.header("Cookie", cookie_header)
.send()
.await?;
if !redirect_response.status().is_success() { if !redirect_response.status().is_success() {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
@@ -336,7 +385,7 @@ impl SessionPool {
)); ));
} }
debug!("successfully selected term: {}", term); debug!(term = term, "successfully selected term");
Ok(()) Ok(())
} }
} }

View File

@@ -18,6 +18,7 @@ pub struct Config {
/// ///
/// Valid values are: "trace", "debug", "info", "warn", "error" /// Valid values are: "trace", "debug", "info", "warn", "error"
/// Defaults to "info" if not specified /// Defaults to "info" if not specified
#[serde(default = "default_log_level")]
pub log_level: String, pub log_level: String,
/// Discord bot token for authentication /// Discord bot token for authentication
pub bot_token: String, pub bot_token: String,
@@ -43,6 +44,11 @@ pub struct Config {
pub shutdown_timeout: Duration, pub shutdown_timeout: Duration,
} }
/// Default log level of "info"
fn default_log_level() -> String {
"info".to_string()
}
/// Default port of 3000 /// Default port of 3000
fn default_port() -> u16 { fn default_port() -> u16 {
3000 3000

View File

@@ -37,7 +37,7 @@ impl Worker {
info!(worker_id = self.id, job_id = job.id, "Processing job"); info!(worker_id = self.id, job_id = job.id, "Processing job");
if let Err(e) = self.process_job(job).await { if let Err(e) = self.process_job(job).await {
// Check if the error is due to an invalid session // Check if the error is due to an invalid session
if let Some(BannerApiError::InvalidSession) = if let Some(BannerApiError::InvalidSession(_)) =
e.downcast_ref::<BannerApiError>() e.downcast_ref::<BannerApiError>()
{ {
warn!( warn!(