fix: prevent session pool deadlock on acquire cancellation

Replace is_creating mutex with atomic flag and RAII guard to ensure
proper cleanup when acquire() futures are cancelled mid-creation,
preventing permanent deadlock for subsequent callers.
This commit is contained in:
2026-01-30 20:19:10 -06:00
parent 1954166db6
commit 5d7d60cd96
5 changed files with 119 additions and 64 deletions
+4
View File
@@ -325,6 +325,7 @@ mod tests {
fn test_parse_json_with_context_null_value() { fn test_parse_json_with_context_null_value() {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct TestStruct { struct TestStruct {
#[allow(dead_code)]
name: String, name: String,
} }
@@ -363,12 +364,14 @@ mod tests {
#[allow(dead_code)] #[allow(dead_code)]
#[serde(rename = "courseTitle")] #[serde(rename = "courseTitle")]
course_title: String, course_title: String,
#[allow(dead_code)]
faculty: Vec<Faculty>, faculty: Vec<Faculty>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Faculty { struct Faculty {
#[serde(rename = "displayName")] #[serde(rename = "displayName")]
#[allow(dead_code)]
display_name: String, display_name: String,
#[allow(dead_code)] #[allow(dead_code)]
email: String, email: String,
@@ -376,6 +379,7 @@ mod tests {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct SearchResult { struct SearchResult {
#[allow(dead_code)]
data: Vec<Course>, data: Vec<Course>,
} }
+112 -64
View File
@@ -11,7 +11,9 @@ use rand::distr::{Alphanumeric, SampleString};
use reqwest_middleware::ClientWithMiddleware; use reqwest_middleware::ClientWithMiddleware;
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, LazyLock}; use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify}; use tokio::sync::{Mutex, Notify};
@@ -121,6 +123,64 @@ impl BannerSession {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::time::Duration;
/// Verifies that cancelling `acquire()` mid-session-creation resets `is_creating`,
/// allowing subsequent callers to proceed rather than deadlocking.
///
/// A local server accepts the registration request, signals its arrival over a
/// channel, and then never responds. Because of that, an `acquire()` call that
/// completes at all (with `Ok` or `Err`) means the request did not reach the
/// hanging handler, and the test fails instead of passing vacuously.
#[tokio::test]
async fn test_acquire_not_deadlocked_after_cancellation() {
    use tokio::sync::mpsc;

    let (tx, mut rx) = mpsc::channel::<()>(10);

    // Local server: /registration signals arrival via `tx`, then hangs forever.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let app = axum::Router::new().route(
        "/StudentRegistrationSsb/registration",
        axum::routing::get(move || {
            let tx = tx.clone();
            async move {
                let _ = tx.send(()).await;
                std::future::pending::<&str>().await
            }
        }),
    );
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });

    let base_url = format!("http://{}/StudentRegistrationSsb", addr);
    let client = reqwest_middleware::ClientBuilder::new(
        reqwest::Client::builder()
            .timeout(Duration::from_secs(300))
            .build()
            .unwrap(),
    )
    .build();
    let pool = SessionPool::new(client, base_url);
    let term: Term = "202620".parse().unwrap();

    // First acquire: cancel once the request reaches the server.
    tokio::select! {
        _ = pool.acquire(term) => panic!("server hangs — acquire should never complete"),
        _ = rx.recv() => {} // Request arrived; dropping the future simulates timeout cancellation.
    }

    // Second acquire: verify it reaches the server (i.e., is_creating was reset).
    // The global rate limiter has a 10s period, so allow 15s for the second attempt.
    tokio::select! {
        // The handler never responds, so a completed acquire means the request
        // never reached it; treat that as a failure rather than a silent pass.
        _ = pool.acquire(term) => {
            panic!("server hangs — second acquire should never complete")
        }
        result = tokio::time::timeout(Duration::from_secs(15), rx.recv()) => {
            match result {
                // The second request arrived at the server: creation was not blocked.
                Ok(Some(())) => {}
                // All senders were dropped, so arrival can no longer be observed.
                Ok(None) => panic!("arrival channel closed before the second request reached the server"),
                // Timed out waiting for the request to arrive.
                Err(_) => panic!(
                    "acquire() deadlocked — is_creating was not reset after cancellation"
                ),
            }
        }
    }
}
#[test] #[test]
fn test_new_session_creates_session() { fn test_new_session_creates_session() {
@@ -200,50 +260,53 @@ mod tests {
} }
} }
/// A smart pointer that returns a BannerSession to the pool when dropped. /// A smart pointer that returns a `BannerSession` to the pool when dropped.
pub struct PooledSession { pub struct PooledSession {
session: Option<BannerSession>, session: ManuallyDrop<BannerSession>,
// This Arc points directly to the term-specific pool.
pool: Arc<TermPool>, pool: Arc<TermPool>,
} }
impl PooledSession {
pub fn been_used(&self) -> bool {
self.session.as_ref().unwrap().been_used()
}
}
impl Deref for PooledSession { impl Deref for PooledSession {
type Target = BannerSession; type Target = BannerSession;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
// The option is only ever None after drop is called, so this is safe. &self.session
self.session.as_ref().unwrap()
} }
} }
impl DerefMut for PooledSession { impl DerefMut for PooledSession {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
self.session.as_mut().unwrap() &mut self.session
} }
} }
/// The magic happens here: when the guard goes out of scope, this is called.
impl Drop for PooledSession { impl Drop for PooledSession {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(session) = self.session.take() { // SAFETY: `drop` is called exactly once by Rust's drop semantics,
let pool = self.pool.clone(); // so `ManuallyDrop::take` is guaranteed to see a valid value.
// Since drop() cannot be async, we spawn a task to return the session. let session = unsafe { ManuallyDrop::take(&mut self.session) };
tokio::spawn(async move { let pool = self.pool.clone();
pool.release(session).await; tokio::spawn(async move {
}); pool.release(session).await;
} });
} }
} }
pub struct TermPool { pub struct TermPool {
sessions: Mutex<VecDeque<BannerSession>>, sessions: Mutex<VecDeque<BannerSession>>,
notifier: Notify, notifier: Notify,
is_creating: Mutex<bool>, is_creating: AtomicBool,
}
/// RAII guard ensuring `is_creating` is reset on drop for cancellation safety.
/// Without this, a cancelled `acquire()` future would leave the flag set permanently,
/// deadlocking all subsequent callers.
///
/// Holds an `Arc<TermPool>` so the drop handler can reach both the flag and the notifier.
struct CreatingGuard(Arc<TermPool>);
impl Drop for CreatingGuard {
fn drop(&mut self) {
// Clear the flag before notifying, so a task woken below sees `false` if it re-checks the flag.
self.0.is_creating.store(false, Ordering::Release);
// NOTE(review): `notify_waiters` wakes only `notified()` futures that already exist and
// stores no permit. A task that loaded `is_creating == true` but has not yet created its
// `notified()` future may miss this wakeup — confirm another notification path covers it.
self.0.notifier.notify_waiters();
}
} }
impl TermPool { impl TermPool {
@@ -251,7 +314,7 @@ impl TermPool {
Self { Self {
sessions: Mutex::new(VecDeque::new()), sessions: Mutex::new(VecDeque::new()),
notifier: Notify::new(), notifier: Notify::new(),
is_creating: Mutex::new(false), is_creating: AtomicBool::new(false),
} }
} }
@@ -308,7 +371,7 @@ impl SessionPool {
if let Some(session) = queue.pop_front() { if let Some(session) = queue.pop_front() {
if !session.is_expired() { if !session.is_expired() {
return Ok(PooledSession { return Ok(PooledSession {
session: Some(session), session: ManuallyDrop::new(session),
pool: Arc::clone(&term_pool), pool: Arc::clone(&term_pool),
}); });
} else { } else {
@@ -317,45 +380,38 @@ impl SessionPool {
} }
} // MutexGuard is dropped, lock is released. } // MutexGuard is dropped, lock is released.
// Slow path: No sessions available. We must either wait or become the creator. // Slow path: wait for an in-progress creation, or become the creator.
let mut is_creating_guard = term_pool.is_creating.lock().await; if term_pool.is_creating.load(Ordering::Acquire) {
if *is_creating_guard {
// Another task is already creating a session. Release the lock and wait.
drop(is_creating_guard);
if !waited_for_creation { if !waited_for_creation {
trace!("Waiting for another task to create session"); trace!("Waiting for another task to create session");
waited_for_creation = true; waited_for_creation = true;
} }
term_pool.notifier.notified().await; term_pool.notifier.notified().await;
// Loop back to the top to try the fast path again.
continue; continue;
} }
// This task is now the designated creator. // CAS to become the designated creator.
*is_creating_guard = true; if term_pool
drop(is_creating_guard); .is_creating
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
continue; // Lost the race — loop back and wait.
}
// Guard resets is_creating on drop (including cancellation).
let creating_guard = CreatingGuard(Arc::clone(&term_pool));
// Race: wait for a session to be returned OR for the rate limiter to allow a new one.
trace!("Pool empty, creating new session"); trace!("Pool empty, creating new session");
tokio::select! { tokio::select! {
_ = term_pool.notifier.notified() => { _ = term_pool.notifier.notified() => {
// A session was returned while we were waiting! // A session was returned — release creator role and race for it.
// We are no longer the creator. Reset the flag and loop to race for the new session. drop(creating_guard);
let mut guard = term_pool.is_creating.lock().await;
*guard = false;
drop(guard);
continue; continue;
} }
_ = SESSION_CREATION_RATE_LIMITER.until_ready() => { _ = SESSION_CREATION_RATE_LIMITER.until_ready() => {
// The rate limit has elapsed. It's our job to create the session.
let new_session_result = self.create_session(&term).await; let new_session_result = self.create_session(&term).await;
drop(creating_guard);
// After creation, we are no longer the creator. Reset the flag
// and notify all other waiting tasks.
let mut guard = term_pool.is_creating.lock().await;
*guard = false;
drop(guard);
term_pool.notifier.notify_waiters();
match new_session_result { match new_session_result {
Ok(new_session) => { Ok(new_session) => {
@@ -366,12 +422,11 @@ impl SessionPool {
"Created new session" "Created new session"
); );
return Ok(PooledSession { return Ok(PooledSession {
session: Some(new_session), session: ManuallyDrop::new(new_session),
pool: term_pool, pool: term_pool,
}); });
} }
Err(e) => { Err(e) => {
// Propagate the error if session creation failed.
return Err(e.context("Failed to create new session in pool")); return Err(e.context("Failed to create new session in pool"));
} }
} }
@@ -380,8 +435,8 @@ impl SessionPool {
} }
} }
/// Sets up initial session cookies by making required Banner API requests /// Sets up initial session cookies by making required Banner API requests.
pub async fn create_session(&self, term: &Term) -> Result<BannerSession> { async fn create_session(&self, term: &Term) -> Result<BannerSession> {
info!(term = %term, "setting up banner session"); info!(term = %term, "setting up banner session");
// The 'register' or 'search' registration page // The 'register' or 'search' registration page
@@ -392,22 +447,15 @@ impl SessionPool {
.await?; .await?;
// TODO: Validate success // TODO: Validate success
let cookies = initial_registration let cookies: HashMap<String, String> = initial_registration
.headers() .headers()
.get_all("Set-Cookie") .get_all("Set-Cookie")
.iter() .iter()
.filter_map(|header_value| { .filter_map(|v| {
if let Ok(cookie_str) = header_value.to_str() { let c = Cookie::parse(v.to_str().ok()?).ok()?;
if let Ok(cookie) = Cookie::parse(cookie_str) { Some((c.name().to_string(), c.value().to_string()))
Some((cookie.name().to_string(), cookie.value().to_string()))
} else {
None
}
} else {
None
}
}) })
.collect::<HashMap<String, String>>(); .collect();
let jsessionid = cookies let jsessionid = cookies
.get("JSESSIONID") .get("JSESSIONID")
@@ -494,8 +542,8 @@ impl SessionPool {
Ok(terms) Ok(terms)
} }
/// Selects a term for the current session /// Selects a term for the current session.
pub async fn select_term( async fn select_term(
&self, &self,
term: &str, term: &str,
unique_session_id: &str, unique_session_id: &str,
+1
View File
@@ -1,3 +1,4 @@
#[allow(dead_code)]
mod helpers; mod helpers;
use banner::data::rmp::unmatch_instructor; use banner::data::rmp::unmatch_instructor;
+1
View File
@@ -1,3 +1,4 @@
#[allow(dead_code)]
mod helpers; mod helpers;
use banner::data::batch::batch_upsert_courses; use banner::data::batch::batch_upsert_courses;
+1
View File
@@ -1,3 +1,4 @@
#[allow(dead_code)]
mod helpers; mod helpers;
use banner::data::models::{ScrapePriority, TargetType}; use banner::data::models::{ScrapePriority, TargetType};