fix: prevent session pool deadlock on acquire cancellation

Replace is_creating mutex with atomic flag and RAII guard to ensure
proper cleanup when acquire() futures are cancelled mid-creation,
preventing permanent deadlock for subsequent callers.
This commit is contained in:
2026-01-30 20:19:10 -06:00
parent 1954166db6
commit 5d7d60cd96
5 changed files with 119 additions and 64 deletions
+4
View File
@@ -325,6 +325,7 @@ mod tests {
fn test_parse_json_with_context_null_value() {
// Deserialization target used by this test; it has a single required string field.
#[derive(Debug, Deserialize)]
struct TestStruct {
// NOTE(review): presumably the field is only filled in by serde and never read back,
// which is why the dead_code lint is silenced — confirm against the test body.
#[allow(dead_code)]
name: String,
}
@@ -363,12 +364,14 @@ mod tests {
#[allow(dead_code)]
#[serde(rename = "courseTitle")]
course_title: String,
#[allow(dead_code)]
faculty: Vec<Faculty>,
}
#[derive(Debug, Deserialize)]
struct Faculty {
#[serde(rename = "displayName")]
#[allow(dead_code)]
display_name: String,
#[allow(dead_code)]
email: String,
@@ -376,6 +379,7 @@ mod tests {
// Deserialization target whose only field is a `data` list of `Course` records.
#[derive(Debug, Deserialize)]
struct SearchResult {
// NOTE(review): presumably the field is only filled in by serde and never read back,
// which is why the dead_code lint is silenced — confirm against the test body.
#[allow(dead_code)]
data: Vec<Course>,
}
+112 -64
View File
@@ -11,7 +11,9 @@ use rand::distr::{Alphanumeric, SampleString};
use reqwest_middleware::ClientWithMiddleware;
use std::collections::{HashMap, VecDeque};
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify};
@@ -121,6 +123,64 @@ impl BannerSession {
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
/// Regression test: dropping an `acquire()` future while it is in the middle of creating
/// a session must reset `is_creating`, so that the next caller starts its own creation
/// attempt instead of waiting forever.
#[tokio::test]
async fn test_acquire_not_deadlocked_after_cancellation() {
    // Every request that reaches the fake server is reported on this channel.
    let (arrival_tx, mut arrival_rx) = tokio::sync::mpsc::channel::<()>(10);

    // Handler for /registration: report the arrival, then never produce a response.
    let signal_then_hang = move || {
        let arrival_tx = arrival_tx.clone();
        async move {
            let _ = arrival_tx.send(()).await;
            std::future::pending::<&str>().await
        }
    };
    let router = axum::Router::new().route(
        "/StudentRegistrationSsb/registration",
        axum::routing::get(signal_then_hang),
    );

    // Serve on an OS-assigned loopback port in a background task.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let server_addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(listener, router).await.unwrap();
    });

    // Pool under test, pointed at the fake server. The client timeout is far longer than
    // the test, so the only way the first request ends is by being cancelled.
    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(300))
        .build()
        .unwrap();
    let client = reqwest_middleware::ClientBuilder::new(http).build();
    let pool = SessionPool::new(client, format!("http://{server_addr}/StudentRegistrationSsb"));
    let term = "202620".parse::<Term>().unwrap();

    // Attempt 1: as soon as the request shows up at the server, `select!` drops the
    // `acquire()` future, which is what a caller-side timeout would do.
    tokio::select! {
        _ = pool.acquire(term) => panic!("server hangs — acquire should never complete"),
        _ = arrival_rx.recv() => {}
    }

    // Attempt 2: a second request must reach the server, proving the creator flag was
    // cleared. The global rate limiter has a 10s period, hence the 15s allowance.
    tokio::select! {
        _ = pool.acquire(term) => {}
        second_arrival = tokio::time::timeout(Duration::from_secs(15), arrival_rx.recv()) => {
            assert!(
                second_arrival.is_ok(),
                "acquire() deadlocked — is_creating was not reset after cancellation"
            );
        }
    }
}
#[test]
fn test_new_session_creates_session() {
@@ -200,50 +260,53 @@ mod tests {
}
}
/// A smart pointer that returns a BannerSession to the pool when dropped.
/// A smart pointer that returns a `BannerSession` to the pool when dropped.
pub struct PooledSession {
session: Option<BannerSession>,
// This Arc points directly to the term-specific pool.
session: ManuallyDrop<BannerSession>,
pool: Arc<TermPool>,
}
impl PooledSession {
pub fn been_used(&self) -> bool {
self.session.as_ref().unwrap().been_used()
}
}
impl Deref for PooledSession {
type Target = BannerSession;
fn deref(&self) -> &Self::Target {
// The option is only ever None after drop is called, so this is safe.
self.session.as_ref().unwrap()
&self.session
}
}
impl DerefMut for PooledSession {
fn deref_mut(&mut self) -> &mut Self::Target {
self.session.as_mut().unwrap()
&mut self.session
}
}
/// The magic happens here: when the guard goes out of scope, this is called.
impl Drop for PooledSession {
fn drop(&mut self) {
if let Some(session) = self.session.take() {
let pool = self.pool.clone();
// Since drop() cannot be async, we spawn a task to return the session.
tokio::spawn(async move {
pool.release(session).await;
});
}
// SAFETY: `drop` is called exactly once by Rust's drop semantics,
// so `ManuallyDrop::take` is guaranteed to see a valid value.
let session = unsafe { ManuallyDrop::take(&mut self.session) };
let pool = self.pool.clone();
tokio::spawn(async move {
pool.release(session).await;
});
}
}
pub struct TermPool {
sessions: Mutex<VecDeque<BannerSession>>,
notifier: Notify,
is_creating: Mutex<bool>,
is_creating: AtomicBool,
}
/// Drop guard held by the task that has claimed the "creator" role for a term pool.
///
/// Dropping it clears `is_creating` and then wakes the tasks waiting on `notifier`.
/// Because the reset lives in `Drop`, it also runs when the owning `acquire()` future is
/// cancelled mid-creation; without it the flag would stay set and every later caller
/// would wait on a creation that is no longer in progress.
///
/// NOTE(review): `Notify::notify_waiters` stores no permit, so it only wakes tasks whose
/// `notified()` future already exists. A task that has just observed
/// `is_creating == true` but has not yet called `notified()` can miss this wake-up and
/// stall until the next notification — confirm whether the wait path should create its
/// `Notified` future before checking the flag.
#[must_use = "dropping the guard immediately releases the creator role"]
struct CreatingGuard(Arc<TermPool>);

impl Drop for CreatingGuard {
    fn drop(&mut self) {
        let pool = &self.0;
        // Clear the flag before waking anyone, so a woken task that re-checks it sees
        // `false`. `Release` pairs with the `Acquire` loads of `is_creating`.
        pool.is_creating.store(false, Ordering::Release);
        pool.notifier.notify_waiters();
    }
}
impl TermPool {
@@ -251,7 +314,7 @@ impl TermPool {
Self {
sessions: Mutex::new(VecDeque::new()),
notifier: Notify::new(),
is_creating: Mutex::new(false),
is_creating: AtomicBool::new(false),
}
}
@@ -308,7 +371,7 @@ impl SessionPool {
if let Some(session) = queue.pop_front() {
if !session.is_expired() {
return Ok(PooledSession {
session: Some(session),
session: ManuallyDrop::new(session),
pool: Arc::clone(&term_pool),
});
} else {
@@ -317,45 +380,38 @@ impl SessionPool {
}
} // MutexGuard is dropped, lock is released.
// Slow path: No sessions available. We must either wait or become the creator.
let mut is_creating_guard = term_pool.is_creating.lock().await;
if *is_creating_guard {
// Another task is already creating a session. Release the lock and wait.
drop(is_creating_guard);
// Slow path: wait for an in-progress creation, or become the creator.
if term_pool.is_creating.load(Ordering::Acquire) {
if !waited_for_creation {
trace!("Waiting for another task to create session");
waited_for_creation = true;
}
term_pool.notifier.notified().await;
// Loop back to the top to try the fast path again.
continue;
}
// This task is now the designated creator.
*is_creating_guard = true;
drop(is_creating_guard);
// CAS to become the designated creator.
if term_pool
.is_creating
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
continue; // Lost the race — loop back and wait.
}
// Guard resets is_creating on drop (including cancellation).
let creating_guard = CreatingGuard(Arc::clone(&term_pool));
// Race: wait for a session to be returned OR for the rate limiter to allow a new one.
trace!("Pool empty, creating new session");
tokio::select! {
_ = term_pool.notifier.notified() => {
// A session was returned while we were waiting!
// We are no longer the creator. Reset the flag and loop to race for the new session.
let mut guard = term_pool.is_creating.lock().await;
*guard = false;
drop(guard);
// A session was returned — release creator role and race for it.
drop(creating_guard);
continue;
}
_ = SESSION_CREATION_RATE_LIMITER.until_ready() => {
// The rate limit has elapsed. It's our job to create the session.
let new_session_result = self.create_session(&term).await;
// After creation, we are no longer the creator. Reset the flag
// and notify all other waiting tasks.
let mut guard = term_pool.is_creating.lock().await;
*guard = false;
drop(guard);
term_pool.notifier.notify_waiters();
drop(creating_guard);
match new_session_result {
Ok(new_session) => {
@@ -366,12 +422,11 @@ impl SessionPool {
"Created new session"
);
return Ok(PooledSession {
session: Some(new_session),
session: ManuallyDrop::new(new_session),
pool: term_pool,
});
}
Err(e) => {
// Propagate the error if session creation failed.
return Err(e.context("Failed to create new session in pool"));
}
}
@@ -380,8 +435,8 @@ impl SessionPool {
}
}
/// Sets up initial session cookies by making required Banner API requests
pub async fn create_session(&self, term: &Term) -> Result<BannerSession> {
/// Sets up initial session cookies by making required Banner API requests.
async fn create_session(&self, term: &Term) -> Result<BannerSession> {
info!(term = %term, "setting up banner session");
// The 'register' or 'search' registration page
@@ -392,22 +447,15 @@ impl SessionPool {
.await?;
// TODO: Validate success
let cookies = initial_registration
let cookies: HashMap<String, String> = initial_registration
.headers()
.get_all("Set-Cookie")
.iter()
.filter_map(|header_value| {
if let Ok(cookie_str) = header_value.to_str() {
if let Ok(cookie) = Cookie::parse(cookie_str) {
Some((cookie.name().to_string(), cookie.value().to_string()))
} else {
None
}
} else {
None
}
.filter_map(|v| {
let c = Cookie::parse(v.to_str().ok()?).ok()?;
Some((c.name().to_string(), c.value().to_string()))
})
.collect::<HashMap<String, String>>();
.collect();
let jsessionid = cookies
.get("JSESSIONID")
@@ -494,8 +542,8 @@ impl SessionPool {
Ok(terms)
}
/// Selects a term for the current session
pub async fn select_term(
/// Selects a term for the current session.
async fn select_term(
&self,
term: &str,
unique_session_id: &str,
+1
View File
@@ -1,3 +1,4 @@
// NOTE(review): presumably `helpers` is shared by several test files and this one does not
// use every item in it, so unused-item warnings are silenced here — confirm.
#[allow(dead_code)]
mod helpers;
use banner::data::rmp::unmatch_instructor;
+1
View File
@@ -1,3 +1,4 @@
// NOTE(review): presumably `helpers` is shared by several test files and this one does not
// use every item in it, so unused-item warnings are silenced here — confirm.
#[allow(dead_code)]
mod helpers;
use banner::data::batch::batch_upsert_courses;
+1
View File
@@ -1,3 +1,4 @@
// NOTE(review): presumably `helpers` is shared by several test files and this one does not
// use every item in it, so unused-item warnings are silenced here — confirm.
#[allow(dead_code)]
mod helpers;
use banner::data::models::{ScrapePriority, TargetType};