mirror of
https://github.com/Xevion/banner.git
synced 2026-01-31 04:23:34 -06:00
fix: prevent session pool deadlock on acquire cancellation
Replace the `is_creating` mutex with an atomic flag plus an RAII guard, so the flag is always reset when an `acquire()` future is cancelled mid-creation. Previously, a cancelled creator left the flag set, permanently deadlocking all subsequent callers.
This commit is contained in:
@@ -325,6 +325,7 @@ mod tests {
|
|||||||
fn test_parse_json_with_context_null_value() {
|
fn test_parse_json_with_context_null_value() {
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct TestStruct {
|
struct TestStruct {
|
||||||
|
#[allow(dead_code)]
|
||||||
name: String,
|
name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,12 +364,14 @@ mod tests {
|
|||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[serde(rename = "courseTitle")]
|
#[serde(rename = "courseTitle")]
|
||||||
course_title: String,
|
course_title: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
faculty: Vec<Faculty>,
|
faculty: Vec<Faculty>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct Faculty {
|
struct Faculty {
|
||||||
#[serde(rename = "displayName")]
|
#[serde(rename = "displayName")]
|
||||||
|
#[allow(dead_code)]
|
||||||
display_name: String,
|
display_name: String,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
email: String,
|
email: String,
|
||||||
@@ -376,6 +379,7 @@ mod tests {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct SearchResult {
|
struct SearchResult {
|
||||||
|
#[allow(dead_code)]
|
||||||
data: Vec<Course>,
|
data: Vec<Course>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+112
-64
@@ -11,7 +11,9 @@ use rand::distr::{Alphanumeric, SampleString};
|
|||||||
use reqwest_middleware::ClientWithMiddleware;
|
use reqwest_middleware::ClientWithMiddleware;
|
||||||
use std::collections::{HashMap, VecDeque};
|
use std::collections::{HashMap, VecDeque};
|
||||||
|
|
||||||
|
use std::mem::ManuallyDrop;
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::ops::{Deref, DerefMut};
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::{Arc, LazyLock};
|
use std::sync::{Arc, LazyLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::{Mutex, Notify};
|
use tokio::sync::{Mutex, Notify};
|
||||||
@@ -121,6 +123,64 @@ impl BannerSession {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Verifies that cancelling `acquire()` mid-session-creation resets `is_creating`,
|
||||||
|
/// allowing subsequent callers to proceed rather than deadlocking.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_acquire_not_deadlocked_after_cancellation() {
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
let (tx, mut rx) = mpsc::channel::<()>(10);
|
||||||
|
|
||||||
|
// Local server: /registration signals arrival via `tx`, then hangs forever.
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let app = axum::Router::new().route(
|
||||||
|
"/StudentRegistrationSsb/registration",
|
||||||
|
axum::routing::get(move || {
|
||||||
|
let tx = tx.clone();
|
||||||
|
async move {
|
||||||
|
let _ = tx.send(()).await;
|
||||||
|
std::future::pending::<&str>().await
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let base_url = format!("http://{}/StudentRegistrationSsb", addr);
|
||||||
|
let client = reqwest_middleware::ClientBuilder::new(
|
||||||
|
reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(300))
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let pool = SessionPool::new(client, base_url);
|
||||||
|
let term: Term = "202620".parse().unwrap();
|
||||||
|
|
||||||
|
// First acquire: cancel once the request reaches the server.
|
||||||
|
tokio::select! {
|
||||||
|
_ = pool.acquire(term) => panic!("server hangs — acquire should never complete"),
|
||||||
|
_ = rx.recv() => {} // Request arrived; dropping the future simulates timeout cancellation.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second acquire: verify it reaches the server (i.e., is_creating was reset).
|
||||||
|
// The global rate limiter has a 10s period, so allow 15s for the second attempt.
|
||||||
|
tokio::select! {
|
||||||
|
_ = pool.acquire(term) => {}
|
||||||
|
result = tokio::time::timeout(Duration::from_secs(15), rx.recv()) => {
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"acquire() deadlocked — is_creating was not reset after cancellation"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_new_session_creates_session() {
|
fn test_new_session_creates_session() {
|
||||||
@@ -200,50 +260,53 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A smart pointer that returns a BannerSession to the pool when dropped.
|
/// A smart pointer that returns a `BannerSession` to the pool when dropped.
|
||||||
pub struct PooledSession {
|
pub struct PooledSession {
|
||||||
session: Option<BannerSession>,
|
session: ManuallyDrop<BannerSession>,
|
||||||
// This Arc points directly to the term-specific pool.
|
|
||||||
pool: Arc<TermPool>,
|
pool: Arc<TermPool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PooledSession {
|
|
||||||
pub fn been_used(&self) -> bool {
|
|
||||||
self.session.as_ref().unwrap().been_used()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Deref for PooledSession {
|
impl Deref for PooledSession {
|
||||||
type Target = BannerSession;
|
type Target = BannerSession;
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
// The option is only ever None after drop is called, so this is safe.
|
&self.session
|
||||||
self.session.as_ref().unwrap()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DerefMut for PooledSession {
|
impl DerefMut for PooledSession {
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
self.session.as_mut().unwrap()
|
&mut self.session
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The magic happens here: when the guard goes out of scope, this is called.
|
|
||||||
impl Drop for PooledSession {
|
impl Drop for PooledSession {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if let Some(session) = self.session.take() {
|
// SAFETY: `drop` is called exactly once by Rust's drop semantics,
|
||||||
let pool = self.pool.clone();
|
// so `ManuallyDrop::take` is guaranteed to see a valid value.
|
||||||
// Since drop() cannot be async, we spawn a task to return the session.
|
let session = unsafe { ManuallyDrop::take(&mut self.session) };
|
||||||
tokio::spawn(async move {
|
let pool = self.pool.clone();
|
||||||
pool.release(session).await;
|
tokio::spawn(async move {
|
||||||
});
|
pool.release(session).await;
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TermPool {
|
pub struct TermPool {
|
||||||
sessions: Mutex<VecDeque<BannerSession>>,
|
sessions: Mutex<VecDeque<BannerSession>>,
|
||||||
notifier: Notify,
|
notifier: Notify,
|
||||||
is_creating: Mutex<bool>,
|
is_creating: AtomicBool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RAII guard ensuring `is_creating` is reset on drop for cancellation safety.
|
||||||
|
/// Without this, a cancelled `acquire()` future would leave the flag set permanently,
|
||||||
|
/// deadlocking all subsequent callers.
|
||||||
|
struct CreatingGuard(Arc<TermPool>);
|
||||||
|
|
||||||
|
impl Drop for CreatingGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.0.is_creating.store(false, Ordering::Release);
|
||||||
|
self.0.notifier.notify_waiters();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TermPool {
|
impl TermPool {
|
||||||
@@ -251,7 +314,7 @@ impl TermPool {
|
|||||||
Self {
|
Self {
|
||||||
sessions: Mutex::new(VecDeque::new()),
|
sessions: Mutex::new(VecDeque::new()),
|
||||||
notifier: Notify::new(),
|
notifier: Notify::new(),
|
||||||
is_creating: Mutex::new(false),
|
is_creating: AtomicBool::new(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,7 +371,7 @@ impl SessionPool {
|
|||||||
if let Some(session) = queue.pop_front() {
|
if let Some(session) = queue.pop_front() {
|
||||||
if !session.is_expired() {
|
if !session.is_expired() {
|
||||||
return Ok(PooledSession {
|
return Ok(PooledSession {
|
||||||
session: Some(session),
|
session: ManuallyDrop::new(session),
|
||||||
pool: Arc::clone(&term_pool),
|
pool: Arc::clone(&term_pool),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
@@ -317,45 +380,38 @@ impl SessionPool {
|
|||||||
}
|
}
|
||||||
} // MutexGuard is dropped, lock is released.
|
} // MutexGuard is dropped, lock is released.
|
||||||
|
|
||||||
// Slow path: No sessions available. We must either wait or become the creator.
|
// Slow path: wait for an in-progress creation, or become the creator.
|
||||||
let mut is_creating_guard = term_pool.is_creating.lock().await;
|
if term_pool.is_creating.load(Ordering::Acquire) {
|
||||||
if *is_creating_guard {
|
|
||||||
// Another task is already creating a session. Release the lock and wait.
|
|
||||||
drop(is_creating_guard);
|
|
||||||
if !waited_for_creation {
|
if !waited_for_creation {
|
||||||
trace!("Waiting for another task to create session");
|
trace!("Waiting for another task to create session");
|
||||||
waited_for_creation = true;
|
waited_for_creation = true;
|
||||||
}
|
}
|
||||||
term_pool.notifier.notified().await;
|
term_pool.notifier.notified().await;
|
||||||
// Loop back to the top to try the fast path again.
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// This task is now the designated creator.
|
// CAS to become the designated creator.
|
||||||
*is_creating_guard = true;
|
if term_pool
|
||||||
drop(is_creating_guard);
|
.is_creating
|
||||||
|
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
continue; // Lost the race — loop back and wait.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Guard resets is_creating on drop (including cancellation).
|
||||||
|
let creating_guard = CreatingGuard(Arc::clone(&term_pool));
|
||||||
|
|
||||||
// Race: wait for a session to be returned OR for the rate limiter to allow a new one.
|
|
||||||
trace!("Pool empty, creating new session");
|
trace!("Pool empty, creating new session");
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = term_pool.notifier.notified() => {
|
_ = term_pool.notifier.notified() => {
|
||||||
// A session was returned while we were waiting!
|
// A session was returned — release creator role and race for it.
|
||||||
// We are no longer the creator. Reset the flag and loop to race for the new session.
|
drop(creating_guard);
|
||||||
let mut guard = term_pool.is_creating.lock().await;
|
|
||||||
*guard = false;
|
|
||||||
drop(guard);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
_ = SESSION_CREATION_RATE_LIMITER.until_ready() => {
|
_ = SESSION_CREATION_RATE_LIMITER.until_ready() => {
|
||||||
// The rate limit has elapsed. It's our job to create the session.
|
|
||||||
let new_session_result = self.create_session(&term).await;
|
let new_session_result = self.create_session(&term).await;
|
||||||
|
drop(creating_guard);
|
||||||
// After creation, we are no longer the creator. Reset the flag
|
|
||||||
// and notify all other waiting tasks.
|
|
||||||
let mut guard = term_pool.is_creating.lock().await;
|
|
||||||
*guard = false;
|
|
||||||
drop(guard);
|
|
||||||
term_pool.notifier.notify_waiters();
|
|
||||||
|
|
||||||
match new_session_result {
|
match new_session_result {
|
||||||
Ok(new_session) => {
|
Ok(new_session) => {
|
||||||
@@ -366,12 +422,11 @@ impl SessionPool {
|
|||||||
"Created new session"
|
"Created new session"
|
||||||
);
|
);
|
||||||
return Ok(PooledSession {
|
return Ok(PooledSession {
|
||||||
session: Some(new_session),
|
session: ManuallyDrop::new(new_session),
|
||||||
pool: term_pool,
|
pool: term_pool,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Propagate the error if session creation failed.
|
|
||||||
return Err(e.context("Failed to create new session in pool"));
|
return Err(e.context("Failed to create new session in pool"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -380,8 +435,8 @@ impl SessionPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets up initial session cookies by making required Banner API requests
|
/// Sets up initial session cookies by making required Banner API requests.
|
||||||
pub async fn create_session(&self, term: &Term) -> Result<BannerSession> {
|
async fn create_session(&self, term: &Term) -> Result<BannerSession> {
|
||||||
info!(term = %term, "setting up banner session");
|
info!(term = %term, "setting up banner session");
|
||||||
|
|
||||||
// The 'register' or 'search' registration page
|
// The 'register' or 'search' registration page
|
||||||
@@ -392,22 +447,15 @@ impl SessionPool {
|
|||||||
.await?;
|
.await?;
|
||||||
// TODO: Validate success
|
// TODO: Validate success
|
||||||
|
|
||||||
let cookies = initial_registration
|
let cookies: HashMap<String, String> = initial_registration
|
||||||
.headers()
|
.headers()
|
||||||
.get_all("Set-Cookie")
|
.get_all("Set-Cookie")
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|header_value| {
|
.filter_map(|v| {
|
||||||
if let Ok(cookie_str) = header_value.to_str() {
|
let c = Cookie::parse(v.to_str().ok()?).ok()?;
|
||||||
if let Ok(cookie) = Cookie::parse(cookie_str) {
|
Some((c.name().to_string(), c.value().to_string()))
|
||||||
Some((cookie.name().to_string(), cookie.value().to_string()))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect::<HashMap<String, String>>();
|
.collect();
|
||||||
|
|
||||||
let jsessionid = cookies
|
let jsessionid = cookies
|
||||||
.get("JSESSIONID")
|
.get("JSESSIONID")
|
||||||
@@ -494,8 +542,8 @@ impl SessionPool {
|
|||||||
Ok(terms)
|
Ok(terms)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Selects a term for the current session
|
/// Selects a term for the current session.
|
||||||
pub async fn select_term(
|
async fn select_term(
|
||||||
&self,
|
&self,
|
||||||
term: &str,
|
term: &str,
|
||||||
unique_session_id: &str,
|
unique_session_id: &str,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#[allow(dead_code)]
|
||||||
mod helpers;
|
mod helpers;
|
||||||
|
|
||||||
use banner::data::rmp::unmatch_instructor;
|
use banner::data::rmp::unmatch_instructor;
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#[allow(dead_code)]
|
||||||
mod helpers;
|
mod helpers;
|
||||||
|
|
||||||
use banner::data::batch::batch_upsert_courses;
|
use banner::data::batch::batch_upsert_courses;
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#[allow(dead_code)]
|
||||||
mod helpers;
|
mod helpers;
|
||||||
|
|
||||||
use banner::data::models::{ScrapePriority, TargetType};
|
use banner::data::models::{ScrapePriority, TargetType};
|
||||||
|
|||||||
Reference in New Issue
Block a user