diff --git a/Cargo.lock b/Cargo.lock index 9fbe288..b874132 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,7 +182,9 @@ dependencies = [ "figment", "fundu", "futures", + "governor", "http 1.3.1", + "once_cell", "poise", "rand 0.9.2", "redis", @@ -878,6 +880,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -933,9 +941,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasi 0.14.3+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -950,6 +960,29 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "governor" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "444405bbb1a762387aa22dd569429533b54a1d8759d35d3b64cb39b0293eaa19" +dependencies = [ + "cfg-if", + "dashmap 6.1.0", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.3", + "hashbrown 0.15.5", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.9.2", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "h2" version = "0.3.27" @@ -1617,6 +1650,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -1883,6 +1922,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + [[package]] name = "potential_utf" version = "0.1.3" @@ -1956,6 +2001,21 @@ dependencies = [ "unicase", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.1+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quote" version = "1.0.40" @@ -2041,6 +2101,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.9.4", +] + [[package]] name = "redis" version = "0.32.5" @@ -2682,6 +2751,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3719,6 +3797,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.25.4" @@ -3753,6 +3841,22 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.10" @@ -3762,6 +3866,12 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.61.2" diff --git a/Cargo.toml b/Cargo.toml index 8c89d18..accc2f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,5 +41,7 @@ tl = "0.7.8" tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } url = "2.5" +governor = "0.10.1" +once_cell = "1.21.3" [dev-dependencies] diff --git a/src/banner/api.rs b/src/banner/api.rs index 1ed826c..8bb3f58 100644 --- a/src/banner/api.rs +++ b/src/banner/api.rs @@ -279,7 +279,19 @@ impl BannerApi { ) -> Result { // self.sessions.reset_data_form().await?; - let session = self.sessions.acquire(term.parse()?).await?; + let mut session = self.sessions.acquire(term.parse()?).await?; + + if session.been_used() { + self.http + .post(&format!("{}/classSearch/resetDataForm", self.base_url)) + .header("Cookie", session.cookie()) + .send() + .await + .map_err(|e| BannerApiError::RequestFailed(e.into()))?; + } + + session.touch(); + let mut params = query.to_params(); // Add additional parameters @@ -293,14 +305,6 @@ impl BannerApi { params.insert("startDatepicker".to_string(), String::new()); params.insert("endDatepicker".to_string(), String::new()); - if session.been_used() { - self.http - .post(&format!("{}/classSearch/resetDataForm", self.base_url)) - .send() - .await - .map_err(|e| BannerApiError::RequestFailed(e.into()))?; - } - debug!( term = term, query = ?query, diff --git a/src/banner/session.rs b/src/banner/session.rs index a9119be..cd18c9c 100644 --- a/src/banner/session.rs +++ b/src/banner/session.rs @@ -5,18 +5,28 @@ use crate::banner::models::Term; use anyhow::{Context, Result}; use cookie::Cookie; use dashmap::DashMap; +use governor::state::InMemoryState; +use governor::{Quota, RateLimiter}; +use once_cell::sync::Lazy; use rand::distr::{Alphanumeric, SampleString}; -use reqwest::Client; use reqwest_middleware::ClientWithMiddleware; use std::collections::{HashMap, VecDeque}; +use std::num::NonZeroU32; use std::ops::{Deref, DerefMut}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::{Mutex, Notify}; use tracing::{debug, info}; use url::Url; const SESSION_EXPIRY: Duration = Duration::from_secs(25 * 60); // 25 minutes +// A global rate limiter to ensure we only try to create one new session every 10 seconds, +// preventing us from overwhelming the server with session creation requests. +static SESSION_CREATION_RATE_LIMITER: Lazy< + RateLimiter, +> = Lazy::new(|| RateLimiter::direct(Quota::with_period(Duration::from_secs(10)).unwrap())); + /// Represents an active anonymous session within the Banner API. /// Identified by multiple persistent cookies, as well as a client-generated "unique session ID". #[derive(Debug, Clone)] @@ -97,8 +107,8 @@ impl BannerSession { /// A smart pointer that returns a BannerSession to the pool when dropped. pub struct PooledSession { session: Option, - // This Arc points directly to the queue the session belongs to. - pool: Arc>>, + // This Arc points directly to the term-specific pool. + pool: Arc, } impl PooledSession { @@ -125,33 +135,55 @@ impl DerefMut for PooledSession { impl Drop for PooledSession { fn drop(&mut self) { if let Some(session) = self.session.take() { - // Don't return expired sessions to the pool. - if session.is_expired() { - debug!( - id = session.unique_session_id, - "Session is now expired, dropping." - ); - return; - } - - // This is a synchronous lock, so it's allowed in drop(). - // It blocks the current thread briefly to return the session. - let mut queue = self.pool.lock().unwrap(); - - let id = session.unique_session_id.clone(); - queue.push_back(session); - - debug!( - id = id, - "Session returned to pool. Queue size is now {queue_size}.", - queue_size = queue.len(), - ); + let pool = self.pool.clone(); + // Since drop() cannot be async, we spawn a task to return the session. + tokio::spawn(async move { + pool.release(session).await; + }); } } } +pub struct TermPool { + sessions: Mutex>, + notifier: Notify, + is_creating: Mutex, +} + +impl TermPool { + fn new() -> Self { + Self { + sessions: Mutex::new(VecDeque::new()), + notifier: Notify::new(), + is_creating: Mutex::new(false), + } + } + + async fn release(&self, session: BannerSession) { + let id = session.unique_session_id.clone(); + if session.is_expired() { + debug!(id = id, "Session is now expired, dropping."); + // Wake up a waiter, as it might need to create a new session + // if this was the last one. + self.notifier.notify_one(); + return; + } + + let mut queue = self.sessions.lock().await; + queue.push_back(session); + let queue_size = queue.len(); + drop(queue); // Release lock before notifying + + debug!( + id = id, + "Session returned to pool. Queue size is now {queue_size}." + ); + self.notifier.notify_one(); + } +} + pub struct SessionPool { - sessions: DashMap>>>, + sessions: DashMap>, http: ClientWithMiddleware, base_url: String, } @@ -166,48 +198,88 @@ impl SessionPool { } /// Acquires a session from the pool. - /// If no sessions are available, a new one is created on demand. + /// If no sessions are available, a new one is created on demand, + /// respecting the global rate limit. pub async fn acquire(&self, term: Term) -> Result { - // Get the queue for the given term, or insert a new empty one. - let pool_entry = self + let term_pool = self .sessions .entry(term.clone()) - .or_insert_with(|| Arc::new(Mutex::new(VecDeque::new()))) + .or_insert_with(|| Arc::new(TermPool::new())) .clone(); loop { - // Lock the specific queue for this term - let session_option = { - let mut queue = pool_entry.lock().unwrap(); - queue.pop_front() // Try to get a session - }; - - if let Some(mut session) = session_option { - // We got a session, check if it's expired. - if !session.is_expired() { - debug!(id = session.unique_session_id, "Reusing session"); - - session.touch(); - return Ok(PooledSession { - session: Some(session), - pool: pool_entry, - }); - } else { - debug!( - id = session.unique_session_id, - "Popped an expired session, discarding.", - ); - // The session is expired, so we loop again to try and get another one. + // Fast path: Try to get an existing, non-expired session. + { + let mut queue = term_pool.sessions.lock().await; + if let Some(session) = queue.pop_front() { + if !session.is_expired() { + debug!(id = session.unique_session_id, "Reusing session from pool"); + return Ok(PooledSession { + session: Some(session), + pool: Arc::clone(&term_pool), + }); + } else { + debug!( + id = session.unique_session_id, + "Popped an expired session, discarding." + ); + } } - } else { - // Queue was empty, so we create a new session. - let mut new_session = self.create_session(&term).await?; - new_session.touch(); + } // MutexGuard is dropped, lock is released. - return Ok(PooledSession { - session: Some(new_session), - pool: pool_entry, - }); + // Slow path: No sessions available. We must either wait or become the creator. + let mut is_creating_guard = term_pool.is_creating.lock().await; + if *is_creating_guard { + // Another task is already creating a session. Release the lock and wait. + drop(is_creating_guard); + debug!("Another task is creating a session, waiting for notification..."); + term_pool.notifier.notified().await; + // Loop back to the top to try the fast path again. + continue; + } + + // This task is now the designated creator. + *is_creating_guard = true; + drop(is_creating_guard); + + // Race: wait for a session to be returned OR for the rate limiter to allow a new one. + debug!("Pool empty, racing notifier vs rate limiter..."); + tokio::select! { + _ = term_pool.notifier.notified() => { + // A session was returned while we were waiting! + // We are no longer the creator. Reset the flag and loop to race for the new session. + debug!("Notified that a session was returned. Looping to retry."); + let mut guard = term_pool.is_creating.lock().await; + *guard = false; + drop(guard); + continue; + } + _ = SESSION_CREATION_RATE_LIMITER.until_ready() => { + // The rate limit has elapsed. It's our job to create the session. + debug!("Rate limiter ready. Proceeding to create a new session."); + let new_session_result = self.create_session(&term).await; + + // After creation, we are no longer the creator. Reset the flag + // and notify all other waiting tasks. + let mut guard = term_pool.is_creating.lock().await; + *guard = false; + drop(guard); + term_pool.notifier.notify_waiters(); + + match new_session_result { + Ok(new_session) => { + debug!(id = new_session.unique_session_id, "Successfully created new session"); + return Ok(PooledSession { + session: Some(new_session), + pool: term_pool, + }); + } + Err(e) => { + // Propagate the error if session creation failed. + return Err(e.context("Failed to create new session in pool")); + } + } + } } } } @@ -269,7 +341,7 @@ impl SessionPool { .context("Failed to get term selection page")?; // TOOD: Validate success - /*let terms = self.get_terms("", 1, 10).await?; + let terms = self.get_terms("", 1, 10).await?; if !terms.iter().any(|t| t.code == term.to_string()) { return Err(anyhow::anyhow!("Failed to get term search response")); } @@ -280,7 +352,7 @@ impl SessionPool { .any(|t| t.code == term.to_string()) { return Err(anyhow::anyhow!("Failed to get term search response")); - }*/ + } let unique_session_id = generate_session_id(); self.select_term(&term.to_string(), &unique_session_id, &cookie_header)