From 5d7d60cd9689af95ea83eab7bfad6cc7d7d99e24 Mon Sep 17 00:00:00 2001 From: Xevion Date: Fri, 30 Jan 2026 20:19:10 -0600 Subject: [PATCH] fix: prevent session pool deadlock on acquire cancellation Replace is_creating mutex with atomic flag and RAII guard to ensure proper cleanup when acquire() futures are cancelled mid-creation, preventing permanent deadlock for subsequent callers. --- src/banner/json.rs | 4 + src/banner/session.rs | 176 +++++++++++++++++++++++++-------------- tests/admin_rmp.rs | 1 + tests/db_batch_upsert.rs | 1 + tests/db_scrape_jobs.rs | 1 + 5 files changed, 119 insertions(+), 64 deletions(-) diff --git a/src/banner/json.rs b/src/banner/json.rs index f973436..ee94f4d 100644 --- a/src/banner/json.rs +++ b/src/banner/json.rs @@ -325,6 +325,7 @@ mod tests { fn test_parse_json_with_context_null_value() { #[derive(Debug, Deserialize)] struct TestStruct { + #[allow(dead_code)] name: String, } @@ -363,12 +364,14 @@ mod tests { #[allow(dead_code)] #[serde(rename = "courseTitle")] course_title: String, + #[allow(dead_code)] faculty: Vec, } #[derive(Debug, Deserialize)] struct Faculty { #[serde(rename = "displayName")] + #[allow(dead_code)] display_name: String, #[allow(dead_code)] email: String, @@ -376,6 +379,7 @@ mod tests { #[derive(Debug, Deserialize)] struct SearchResult { + #[allow(dead_code)] data: Vec, } diff --git a/src/banner/session.rs b/src/banner/session.rs index 7e32420..8c9adca 100644 --- a/src/banner/session.rs +++ b/src/banner/session.rs @@ -11,7 +11,9 @@ use rand::distr::{Alphanumeric, SampleString}; use reqwest_middleware::ClientWithMiddleware; use std::collections::{HashMap, VecDeque}; +use std::mem::ManuallyDrop; use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; use tokio::sync::{Mutex, Notify}; @@ -121,6 +123,64 @@ impl BannerSession { #[cfg(test)] mod tests { use super::*; + use std::time::Duration; + + /// Verifies that cancelling `acquire()` mid-session-creation resets `is_creating`, + /// allowing subsequent callers to proceed rather than deadlocking. + #[tokio::test] + async fn test_acquire_not_deadlocked_after_cancellation() { + use tokio::sync::mpsc; + + let (tx, mut rx) = mpsc::channel::<()>(10); + + // Local server: /registration signals arrival via `tx`, then hangs forever. + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route( + "/StudentRegistrationSsb/registration", + axum::routing::get(move || { + let tx = tx.clone(); + async move { + let _ = tx.send(()).await; + std::future::pending::<&str>().await + } + }), + ); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let base_url = format!("http://{}/StudentRegistrationSsb", addr); + let client = reqwest_middleware::ClientBuilder::new( + reqwest::Client::builder() + .timeout(Duration::from_secs(300)) + .build() + .unwrap(), + ) + .build(); + + let pool = SessionPool::new(client, base_url); + let term: Term = "202620".parse().unwrap(); + + // First acquire: cancel once the request reaches the server. + tokio::select! { + _ = pool.acquire(term) => panic!("server hangs — acquire should never complete"), + _ = rx.recv() => {} // Request arrived; dropping the future simulates timeout cancellation. + } + + // Second acquire: verify it reaches the server (i.e., is_creating was reset). + // The global rate limiter has a 10s period, so allow 15s for the second attempt. + tokio::select! { + _ = pool.acquire(term) => {} + result = tokio::time::timeout(Duration::from_secs(15), rx.recv()) => { + assert!( + result.is_ok(), + "acquire() deadlocked — is_creating was not reset after cancellation" + ); + } + } + } #[test] fn test_new_session_creates_session() { @@ -200,50 +260,53 @@ mod tests { } } -/// A smart pointer that returns a BannerSession to the pool when dropped. +/// A smart pointer that returns a `BannerSession` to the pool when dropped. pub struct PooledSession { - session: Option, - // This Arc points directly to the term-specific pool. + session: ManuallyDrop, pool: Arc, } -impl PooledSession { - pub fn been_used(&self) -> bool { - self.session.as_ref().unwrap().been_used() - } -} - impl Deref for PooledSession { type Target = BannerSession; fn deref(&self) -> &Self::Target { - // The option is only ever None after drop is called, so this is safe. - self.session.as_ref().unwrap() + &self.session } } impl DerefMut for PooledSession { fn deref_mut(&mut self) -> &mut Self::Target { - self.session.as_mut().unwrap() + &mut self.session } } -/// The magic happens here: when the guard goes out of scope, this is called. impl Drop for PooledSession { fn drop(&mut self) { - if let Some(session) = self.session.take() { - let pool = self.pool.clone(); - // Since drop() cannot be async, we spawn a task to return the session. - tokio::spawn(async move { - pool.release(session).await; - }); - } + // SAFETY: `drop` is called exactly once by Rust's drop semantics, + // so `ManuallyDrop::take` is guaranteed to see a valid value. + let session = unsafe { ManuallyDrop::take(&mut self.session) }; + let pool = self.pool.clone(); + tokio::spawn(async move { + pool.release(session).await; + }); } } pub struct TermPool { sessions: Mutex>, notifier: Notify, - is_creating: Mutex, + is_creating: AtomicBool, +} + +/// RAII guard ensuring `is_creating` is reset on drop for cancellation safety. +/// Without this, a cancelled `acquire()` future would leave the flag set permanently, +/// deadlocking all subsequent callers. +struct CreatingGuard(Arc); + +impl Drop for CreatingGuard { + fn drop(&mut self) { + self.0.is_creating.store(false, Ordering::Release); + self.0.notifier.notify_waiters(); + } } impl TermPool { @@ -251,7 +314,7 @@ impl TermPool { Self { sessions: Mutex::new(VecDeque::new()), notifier: Notify::new(), - is_creating: Mutex::new(false), + is_creating: AtomicBool::new(false), } } @@ -308,7 +371,7 @@ impl SessionPool { if let Some(session) = queue.pop_front() { if !session.is_expired() { return Ok(PooledSession { - session: Some(session), + session: ManuallyDrop::new(session), pool: Arc::clone(&term_pool), }); } else { @@ -317,45 +380,38 @@ impl SessionPool { } } // MutexGuard is dropped, lock is released. - // Slow path: No sessions available. We must either wait or become the creator. - let mut is_creating_guard = term_pool.is_creating.lock().await; - if *is_creating_guard { - // Another task is already creating a session. Release the lock and wait. - drop(is_creating_guard); + // Slow path: wait for an in-progress creation, or become the creator. + if term_pool.is_creating.load(Ordering::Acquire) { if !waited_for_creation { trace!("Waiting for another task to create session"); waited_for_creation = true; } term_pool.notifier.notified().await; - // Loop back to the top to try the fast path again. continue; } - // This task is now the designated creator. - *is_creating_guard = true; - drop(is_creating_guard); + // CAS to become the designated creator. + if term_pool + .is_creating + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + continue; // Lost the race — loop back and wait. + } + + // Guard resets is_creating on drop (including cancellation). + let creating_guard = CreatingGuard(Arc::clone(&term_pool)); - // Race: wait for a session to be returned OR for the rate limiter to allow a new one. trace!("Pool empty, creating new session"); tokio::select! { _ = term_pool.notifier.notified() => { - // A session was returned while we were waiting! - // We are no longer the creator. Reset the flag and loop to race for the new session. - let mut guard = term_pool.is_creating.lock().await; - *guard = false; - drop(guard); + // A session was returned — release creator role and race for it. + drop(creating_guard); continue; } _ = SESSION_CREATION_RATE_LIMITER.until_ready() => { - // The rate limit has elapsed. It's our job to create the session. let new_session_result = self.create_session(&term).await; - - // After creation, we are no longer the creator. Reset the flag - // and notify all other waiting tasks. - let mut guard = term_pool.is_creating.lock().await; - *guard = false; - drop(guard); - term_pool.notifier.notify_waiters(); + drop(creating_guard); match new_session_result { Ok(new_session) => { @@ -366,12 +422,11 @@ impl SessionPool { "Created new session" ); return Ok(PooledSession { - session: Some(new_session), + session: ManuallyDrop::new(new_session), pool: term_pool, }); } Err(e) => { - // Propagate the error if session creation failed. return Err(e.context("Failed to create new session in pool")); } } @@ -380,8 +435,8 @@ impl SessionPool { } } - /// Sets up initial session cookies by making required Banner API requests - pub async fn create_session(&self, term: &Term) -> Result { + /// Sets up initial session cookies by making required Banner API requests. + async fn create_session(&self, term: &Term) -> Result { info!(term = %term, "setting up banner session"); // The 'register' or 'search' registration page @@ -392,22 +447,15 @@ impl SessionPool { .await?; // TODO: Validate success - let cookies = initial_registration + let cookies: HashMap = initial_registration .headers() .get_all("Set-Cookie") .iter() - .filter_map(|header_value| { - if let Ok(cookie_str) = header_value.to_str() { - if let Ok(cookie) = Cookie::parse(cookie_str) { - Some((cookie.name().to_string(), cookie.value().to_string())) - } else { - None - } - } else { - None - } + .filter_map(|v| { + let c = Cookie::parse(v.to_str().ok()?).ok()?; + Some((c.name().to_string(), c.value().to_string())) }) - .collect::>(); + .collect(); let jsessionid = cookies .get("JSESSIONID") @@ -494,8 +542,8 @@ impl SessionPool { Ok(terms) } - /// Selects a term for the current session - pub async fn select_term( + /// Selects a term for the current session. + async fn select_term( &self, term: &str, unique_session_id: &str, diff --git a/tests/admin_rmp.rs b/tests/admin_rmp.rs index 5d7631a..54a7929 100644 --- a/tests/admin_rmp.rs +++ b/tests/admin_rmp.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod helpers; use banner::data::rmp::unmatch_instructor; diff --git a/tests/db_batch_upsert.rs b/tests/db_batch_upsert.rs index 1c85b19..7c71ac7 100644 --- a/tests/db_batch_upsert.rs +++ b/tests/db_batch_upsert.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod helpers; use banner::data::batch::batch_upsert_courses; diff --git a/tests/db_scrape_jobs.rs b/tests/db_scrape_jobs.rs index ae69b7e..53e7aa2 100644 --- a/tests/db_scrape_jobs.rs +++ b/tests/db_scrape_jobs.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] mod helpers; use banner::data::models::{ScrapePriority, TargetType};