feat: extract database operations module and add extensive test suite

This commit is contained in:
2026-01-28 17:32:27 -06:00
parent 992263205c
commit 1733ee5f86
14 changed files with 1539 additions and 80 deletions
+206
View File
@@ -240,3 +240,209 @@ impl FromStr for Term {
Ok(Term { year, season })
}
}
#[cfg(test)]
mod tests {
use super::*;
// --- Season::from_str ---
#[test]
fn test_season_from_str_fall() {
assert_eq!(Season::from_str("10").unwrap(), Season::Fall);
}
#[test]
fn test_season_from_str_spring() {
assert_eq!(Season::from_str("20").unwrap(), Season::Spring);
}
#[test]
fn test_season_from_str_summer() {
assert_eq!(Season::from_str("30").unwrap(), Season::Summer);
}
#[test]
fn test_season_from_str_invalid() {
for input in ["00", "40", "1", ""] {
assert!(
Season::from_str(input).is_err(),
"expected Err for {input:?}"
);
}
}
// --- Season Display ---
#[test]
fn test_season_display() {
assert_eq!(Season::Fall.to_string(), "Fall");
assert_eq!(Season::Spring.to_string(), "Spring");
assert_eq!(Season::Summer.to_string(), "Summer");
}
#[test]
fn test_season_to_str_roundtrip() {
for season in [Season::Fall, Season::Spring, Season::Summer] {
assert_eq!(Season::from_str(season.to_str()).unwrap(), season);
}
}
// --- Term::from_str ---
#[test]
fn test_term_from_str_valid_fall() {
let term = Term::from_str("202510").unwrap();
assert_eq!(term.year, 2025);
assert_eq!(term.season, Season::Fall);
}
#[test]
fn test_term_from_str_valid_spring() {
let term = Term::from_str("202520").unwrap();
assert_eq!(term.year, 2025);
assert_eq!(term.season, Season::Spring);
}
#[test]
fn test_term_from_str_valid_summer() {
let term = Term::from_str("202530").unwrap();
assert_eq!(term.year, 2025);
assert_eq!(term.season, Season::Summer);
}
#[test]
fn test_term_from_str_too_short() {
assert!(Term::from_str("20251").is_err());
}
#[test]
fn test_term_from_str_too_long() {
assert!(Term::from_str("2025100").is_err());
}
#[test]
fn test_term_from_str_empty() {
assert!(Term::from_str("").is_err());
}
#[test]
fn test_term_from_str_invalid_year_chars() {
assert!(Term::from_str("abcd10").is_err());
}
#[test]
fn test_term_from_str_invalid_season() {
assert!(Term::from_str("202540").is_err());
}
#[test]
fn test_term_from_str_year_below_range() {
assert!(Term::from_str("200010").is_err());
}
#[test]
fn test_term_display_roundtrip() {
for code in ["202510", "202520", "202530"] {
let term = Term::from_str(code).unwrap();
assert_eq!(term.to_string(), code);
}
}
// --- Term::get_status_for_date ---
#[test]
fn test_status_mid_spring() {
let date = NaiveDate::from_ymd_opt(2025, 2, 15).unwrap();
let status = Term::get_status_for_date(date);
assert!(
matches!(status, TermPoint::InTerm { current } if current.season == Season::Spring)
);
}
#[test]
fn test_status_mid_summer() {
let date = NaiveDate::from_ymd_opt(2025, 7, 1).unwrap();
let status = Term::get_status_for_date(date);
assert!(
matches!(status, TermPoint::InTerm { current } if current.season == Season::Summer)
);
}
#[test]
fn test_status_mid_fall() {
let date = NaiveDate::from_ymd_opt(2025, 10, 15).unwrap();
let status = Term::get_status_for_date(date);
assert!(matches!(status, TermPoint::InTerm { current } if current.season == Season::Fall));
}
#[test]
fn test_status_between_fall_and_spring() {
let date = NaiveDate::from_ymd_opt(2025, 1, 1).unwrap();
let status = Term::get_status_for_date(date);
assert!(
matches!(status, TermPoint::BetweenTerms { next } if next.season == Season::Spring)
);
}
#[test]
fn test_status_between_spring_and_summer() {
let date = NaiveDate::from_ymd_opt(2025, 5, 15).unwrap();
let status = Term::get_status_for_date(date);
assert!(
matches!(status, TermPoint::BetweenTerms { next } if next.season == Season::Summer)
);
}
#[test]
fn test_status_between_summer_and_fall() {
let date = NaiveDate::from_ymd_opt(2025, 8, 16).unwrap();
let status = Term::get_status_for_date(date);
assert!(matches!(status, TermPoint::BetweenTerms { next } if next.season == Season::Fall));
}
#[test]
fn test_status_after_fall_end() {
let date = NaiveDate::from_ymd_opt(2025, 12, 15).unwrap();
let status = Term::get_status_for_date(date);
assert!(
matches!(status, TermPoint::BetweenTerms { next } if next.season == Season::Spring)
);
// Year should roll over: fall 2025 ends → next spring is 2026
let next_term = status.inner();
assert_eq!(next_term.year, 2026);
}
// --- TermPoint::inner ---
#[test]
fn test_term_point_inner() {
let in_term = TermPoint::InTerm {
current: Term {
year: 2025,
season: Season::Fall,
},
};
assert_eq!(
in_term.inner(),
&Term {
year: 2025,
season: Season::Fall
}
);
let between = TermPoint::BetweenTerms {
next: Term {
year: 2026,
season: Season::Spring,
},
};
assert_eq!(
between.inner(),
&Term {
year: 2026,
season: Season::Spring
}
);
}
}
+123
View File
@@ -85,3 +85,126 @@ pub type SharedRateLimiter = Arc<BannerRateLimiter>;
pub fn create_shared_rate_limiter(config: Option<RateLimitingConfig>) -> SharedRateLimiter {
Arc::new(BannerRateLimiter::new(config.unwrap_or_default()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_with_default_config() {
let _limiter = BannerRateLimiter::new(RateLimitingConfig::default());
}
#[test]
fn test_new_with_custom_config() {
let config = RateLimitingConfig {
session_rpm: 10,
search_rpm: 30,
metadata_rpm: 20,
reset_rpm: 15,
burst_allowance: 5,
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
fn test_new_with_minimum_valid_values() {
let config = RateLimitingConfig {
session_rpm: 1,
search_rpm: 1,
metadata_rpm: 1,
reset_rpm: 1,
burst_allowance: 1,
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
fn test_new_with_high_rpm_values() {
let config = RateLimitingConfig {
session_rpm: 10000,
search_rpm: 10000,
metadata_rpm: 10000,
reset_rpm: 10000,
burst_allowance: 1,
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
fn test_default_impl() {
let _limiter = BannerRateLimiter::default();
}
#[test]
#[should_panic]
fn test_new_panics_on_zero_session_rpm() {
let config = RateLimitingConfig {
session_rpm: 0,
..RateLimitingConfig::default()
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
#[should_panic]
fn test_new_panics_on_zero_search_rpm() {
let config = RateLimitingConfig {
search_rpm: 0,
..RateLimitingConfig::default()
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
#[should_panic]
fn test_new_panics_on_zero_metadata_rpm() {
let config = RateLimitingConfig {
metadata_rpm: 0,
..RateLimitingConfig::default()
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
#[should_panic]
fn test_new_panics_on_zero_reset_rpm() {
let config = RateLimitingConfig {
reset_rpm: 0,
..RateLimitingConfig::default()
};
let _limiter = BannerRateLimiter::new(config);
}
#[test]
#[should_panic]
fn test_new_panics_on_zero_burst_allowance() {
let config = RateLimitingConfig {
burst_allowance: 0,
..RateLimitingConfig::default()
};
let _limiter = BannerRateLimiter::new(config);
}
#[tokio::test]
async fn test_wait_for_permission_completes() {
let limiter = BannerRateLimiter::default();
let timeout_duration = std::time::Duration::from_secs(1);
for request_type in [
RequestType::Session,
RequestType::Search,
RequestType::Metadata,
RequestType::Reset,
] {
let result =
tokio::time::timeout(timeout_duration, limiter.wait_for_permission(request_type))
.await;
assert!(
result.is_ok(),
"wait_for_permission timed out for {:?}",
request_type
);
}
}
}
+99
View File
@@ -101,6 +101,105 @@ impl BannerSession {
pub fn been_used(&self) -> bool {
self.last_activity.is_some()
}
#[cfg(test)]
pub(crate) fn new_with_created_at(
unique_session_id: &str,
jsessionid: &str,
ssb_cookie: &str,
created_at: Instant,
) -> Self {
Self {
unique_session_id: unique_session_id.to_string(),
created_at,
last_activity: None,
jsessionid: jsessionid.to_string(),
ssb_cookie: ssb_cookie.to_string(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_session_returns_ok() {
let session = BannerSession::new("sess-1", "JSID123", "SSB456");
assert!(session.is_ok());
assert_eq!(session.unwrap().id(), "sess-1");
}
#[test]
fn test_fresh_session_not_expired() {
let session = BannerSession::new("sess-1", "JSID123", "SSB456").unwrap();
assert!(!session.is_expired());
}
#[test]
fn test_fresh_session_not_been_used() {
let session = BannerSession::new("sess-1", "JSID123", "SSB456").unwrap();
assert!(!session.been_used());
}
#[test]
fn test_touch_marks_used() {
let mut session = BannerSession::new("sess-1", "JSID123", "SSB456").unwrap();
session.touch();
assert!(session.been_used());
}
#[test]
fn test_touched_session_not_expired() {
let mut session = BannerSession::new("sess-1", "JSID123", "SSB456").unwrap();
session.touch();
assert!(!session.is_expired());
}
#[test]
fn test_cookie_format() {
let session = BannerSession::new("sess-1", "JSID123", "SSB456").unwrap();
assert_eq!(session.cookie(), "JSESSIONID=JSID123; SSB_COOKIE=SSB456");
}
#[test]
fn test_id_returns_unique_session_id() {
let session = BannerSession::new("my-unique-id", "JSID123", "SSB456").unwrap();
assert_eq!(session.id(), "my-unique-id");
}
#[test]
fn test_expired_session() {
let session = BannerSession::new_with_created_at(
"sess-old",
"JSID123",
"SSB456",
Instant::now() - Duration::from_secs(26 * 60),
);
assert!(session.is_expired());
}
#[test]
fn test_not_quite_expired_session() {
let session = BannerSession::new_with_created_at(
"sess-recent",
"JSID123",
"SSB456",
Instant::now() - Duration::from_secs(24 * 60),
);
assert!(!session.is_expired());
}
#[test]
fn test_session_at_expiry_boundary() {
let session = BannerSession::new_with_created_at(
"sess-boundary",
"JSID123",
"SSB456",
Instant::now() - Duration::from_secs(25 * 60 + 1),
);
assert!(session.is_expired());
}
}
/// A smart pointer that returns a BannerSession to the pool when dropped.
+107
View File
@@ -140,3 +140,110 @@ fn parse_course_code(input: &str) -> Result<(i32, i32), Error> {
Err(anyhow!("Invalid course code format"))
}
#[cfg(test)]
mod tests {
use super::*;
// --- Single codes ---
#[test]
fn test_parse_single_code() {
assert_eq!(parse_course_code("3743").unwrap(), (3743, 3743));
}
#[test]
fn test_parse_single_code_boundaries() {
assert_eq!(parse_course_code("1000").unwrap(), (1000, 1000));
assert_eq!(parse_course_code("9999").unwrap(), (9999, 9999));
}
#[test]
fn test_parse_single_code_below_range() {
assert!(parse_course_code("0999").is_err());
}
#[test]
fn test_parse_single_code_wrong_length() {
assert!(parse_course_code("123").is_err());
}
#[test]
fn test_parse_single_code_non_numeric() {
assert!(parse_course_code("abcd").is_err());
}
#[test]
fn test_parse_single_code_trimmed() {
assert_eq!(parse_course_code(" 3743 ").unwrap(), (3743, 3743));
}
// --- Ranges ---
#[test]
fn test_parse_range_full() {
assert_eq!(parse_course_code("3000-3999").unwrap(), (3000, 3999));
}
#[test]
fn test_parse_range_same() {
assert_eq!(parse_course_code("3000-3000").unwrap(), (3000, 3000));
}
#[test]
fn test_parse_range_open() {
assert_eq!(parse_course_code("3000-").unwrap(), (3000, 9999));
}
#[test]
fn test_parse_range_inverted() {
assert!(parse_course_code("5000-3000").is_err());
}
#[test]
fn test_parse_range_below_1000() {
assert!(parse_course_code("500-999").is_err());
}
#[test]
fn test_parse_range_above_9999() {
assert!(parse_course_code("9000-10000").is_err());
}
#[test]
fn test_parse_range_full_valid() {
assert_eq!(parse_course_code("1000-9999").unwrap(), (1000, 9999));
}
// --- Wildcards ---
#[test]
fn test_parse_wildcard_one_x() {
assert_eq!(parse_course_code("300x").unwrap(), (3000, 3009));
}
#[test]
fn test_parse_wildcard_two_x() {
assert_eq!(parse_course_code("30xx").unwrap(), (3000, 3099));
}
#[test]
fn test_parse_wildcard_three_x() {
assert_eq!(parse_course_code("3xxx").unwrap(), (3000, 3999));
}
#[test]
fn test_parse_wildcard_9xxx() {
assert_eq!(parse_course_code("9xxx").unwrap(), (9000, 9999));
}
#[test]
fn test_parse_wildcard_wrong_length() {
assert!(parse_course_code("3xxxx").is_err());
}
#[test]
fn test_parse_wildcard_0xxx() {
assert!(parse_course_code("0xxx").is_err());
}
}
-13
View File
@@ -120,16 +120,3 @@ pub async fn batch_upsert_courses(courses: &[Course], db_pool: &PgPool) -> Resul
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty_batch_returns_ok() {
// This is a basic compile-time test
// Runtime tests would require sqlx::test macro and a test database
let courses: Vec<Course> = vec![];
assert_eq!(courses.len(), 0);
}
}
+1
View File
@@ -2,3 +2,4 @@
pub mod batch;
pub mod models;
pub mod scrape_jobs;
+170
View File
@@ -0,0 +1,170 @@
//! Database operations for scrape job queue management.
use crate::data::models::{ScrapeJob, ScrapePriority, TargetType};
use crate::error::Result;
use sqlx::PgPool;
use std::collections::HashSet;
/// Atomically fetch and lock the next available scrape job.
///
/// Uses `FOR UPDATE SKIP LOCKED` to allow multiple workers to poll the queue
/// concurrently without conflicts. Only jobs that are unlocked and ready to
/// execute (based on `execute_at`) are considered.
///
/// # Arguments
/// * `db_pool` - PostgreSQL connection pool
///
/// # Returns
/// * `Ok(Some(job))` if a job was successfully fetched and locked
/// * `Ok(None)` if no jobs are available
pub async fn fetch_and_lock_job(db_pool: &PgPool) -> Result<Option<ScrapeJob>> {
let mut tx = db_pool.begin().await?;
let job = sqlx::query_as::<_, ScrapeJob>(
"SELECT * FROM scrape_jobs WHERE locked_at IS NULL AND execute_at <= NOW() ORDER BY priority DESC, execute_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED"
)
.fetch_optional(&mut *tx)
.await?;
if let Some(ref job) = job {
sqlx::query("UPDATE scrape_jobs SET locked_at = NOW() WHERE id = $1")
.bind(job.id)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(job)
}
/// Delete a scrape job by ID.
///
/// Typically called after a job has been successfully processed or permanently failed.
///
/// # Arguments
/// * `job_id` - The database ID of the job to delete
/// * `db_pool` - PostgreSQL connection pool
pub async fn delete_job(job_id: i32, db_pool: &PgPool) -> Result<()> {
sqlx::query("DELETE FROM scrape_jobs WHERE id = $1")
.bind(job_id)
.execute(db_pool)
.await?;
Ok(())
}
/// Unlock a scrape job by clearing its `locked_at` timestamp.
///
/// Used to release a job back to the queue, e.g. during graceful shutdown.
///
/// # Arguments
/// * `job_id` - The database ID of the job to unlock
/// * `db_pool` - PostgreSQL connection pool
pub async fn unlock_job(job_id: i32, db_pool: &PgPool) -> Result<()> {
sqlx::query("UPDATE scrape_jobs SET locked_at = NULL WHERE id = $1")
.bind(job_id)
.execute(db_pool)
.await?;
Ok(())
}
/// Atomically unlock a job and increment its retry count.
///
/// Returns whether the job still has retries remaining. This is determined
/// atomically in the database to avoid race conditions between workers.
///
/// # Arguments
/// * `job_id` - The database ID of the job
/// * `max_retries` - Maximum number of retries allowed for this job
/// * `db_pool` - PostgreSQL connection pool
///
/// # Returns
/// * `Ok(true)` if the job was unlocked and retries remain
/// * `Ok(false)` if the job has exhausted its retries
pub async fn unlock_and_increment_retry(
job_id: i32,
max_retries: i32,
db_pool: &PgPool,
) -> Result<bool> {
let result = sqlx::query_scalar::<_, Option<i32>>(
"UPDATE scrape_jobs
SET locked_at = NULL, retry_count = retry_count + 1
WHERE id = $1
RETURNING CASE WHEN retry_count + 1 < $2 THEN retry_count + 1 ELSE NULL END",
)
.bind(job_id)
.bind(max_retries)
.fetch_one(db_pool)
.await?;
Ok(result.is_some())
}
/// Find existing unlocked job payloads matching the given target type and candidates.
///
/// Returns a set of stringified JSON payloads that already exist in the queue,
/// used for deduplication when scheduling new jobs.
///
/// # Arguments
/// * `target_type` - The target type to filter by
/// * `candidate_payloads` - Candidate payloads to check against existing jobs
/// * `db_pool` - PostgreSQL connection pool
///
/// # Returns
/// A `HashSet` of stringified JSON payloads that already have pending jobs
pub async fn find_existing_job_payloads(
target_type: TargetType,
candidate_payloads: &[serde_json::Value],
db_pool: &PgPool,
) -> Result<HashSet<String>> {
let existing_jobs: Vec<(serde_json::Value,)> = sqlx::query_as(
"SELECT target_payload FROM scrape_jobs
WHERE target_type = $1 AND target_payload = ANY($2) AND locked_at IS NULL",
)
.bind(target_type)
.bind(candidate_payloads)
.fetch_all(db_pool)
.await?;
let existing_payloads = existing_jobs
.into_iter()
.map(|(payload,)| payload.to_string())
.collect();
Ok(existing_payloads)
}
/// Batch insert scrape jobs in a single transaction.
///
/// All jobs are inserted with `execute_at` set to the current time.
///
/// # Arguments
/// * `jobs` - Slice of `(payload, target_type, priority)` tuples to insert
/// * `db_pool` - PostgreSQL connection pool
pub async fn batch_insert_jobs(
jobs: &[(serde_json::Value, TargetType, ScrapePriority)],
db_pool: &PgPool,
) -> Result<()> {
if jobs.is_empty() {
return Ok(());
}
let now = chrono::Utc::now();
let mut tx = db_pool.begin().await?;
for (payload, target_type, priority) in jobs {
sqlx::query(
"INSERT INTO scrape_jobs (target_type, target_payload, priority, execute_at) VALUES ($1, $2, $3, $4)"
)
.bind(target_type)
.bind(payload)
.bind(priority)
.bind(now)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
+80
View File
@@ -102,3 +102,83 @@ impl JobType {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// --- Valid dispatch ---
#[test]
fn test_from_target_subject_valid() {
let result =
JobType::from_target_type_and_payload(TargetType::Subject, json!({"subject": "CS"}));
assert!(matches!(result, Ok(JobType::Subject(_))));
}
#[test]
fn test_from_target_subject_empty_string() {
let result =
JobType::from_target_type_and_payload(TargetType::Subject, json!({"subject": ""}));
assert!(matches!(result, Ok(JobType::Subject(_))));
}
// --- Invalid JSON ---
#[test]
fn test_from_target_subject_missing_field() {
let result = JobType::from_target_type_and_payload(TargetType::Subject, json!({}));
assert!(matches!(result, Err(JobParseError::InvalidJson(_))));
}
#[test]
fn test_from_target_subject_wrong_type() {
let result =
JobType::from_target_type_and_payload(TargetType::Subject, json!({"subject": 123}));
assert!(matches!(result, Err(JobParseError::InvalidJson(_))));
}
#[test]
fn test_from_target_subject_null_payload() {
let result = JobType::from_target_type_and_payload(TargetType::Subject, json!(null));
assert!(matches!(result, Err(JobParseError::InvalidJson(_))));
}
// --- Unsupported target types ---
#[test]
fn test_from_target_unsupported_variants() {
let unsupported = [
TargetType::CourseRange,
TargetType::CrnList,
TargetType::SingleCrn,
];
for target_type in unsupported {
let result =
JobType::from_target_type_and_payload(target_type, json!({"subject": "CS"}));
assert!(
matches!(result, Err(JobParseError::UnsupportedTargetType(_))),
"expected UnsupportedTargetType for {target_type:?}"
);
}
}
// --- Error Display ---
#[test]
fn test_job_parse_error_display() {
let invalid_json_err =
JobType::from_target_type_and_payload(TargetType::Subject, json!(null)).unwrap_err();
let display = invalid_json_err.to_string();
assert!(display.contains("Invalid JSON"), "got: {display}");
let unsupported_err =
JobType::from_target_type_and_payload(TargetType::CrnList, json!({})).unwrap_err();
let display = unsupported_err.to_string();
assert!(
display.contains("Unsupported target type"),
"got: {display}"
);
}
}
+12 -27
View File
@@ -1,5 +1,6 @@
use crate::banner::{BannerApi, Term};
use crate::data::models::{ScrapePriority, TargetType};
use crate::data::scrape_jobs;
use crate::error::Result;
use crate::scraper::jobs::subject::SubjectJob;
use serde_json::json;
@@ -123,21 +124,13 @@ impl Scheduler {
.collect();
// Query existing jobs for all subjects in a single query
let existing_jobs: Vec<(serde_json::Value,)> = sqlx::query_as(
"SELECT target_payload FROM scrape_jobs
WHERE target_type = $1 AND target_payload = ANY($2) AND locked_at IS NULL",
let existing_payloads = scrape_jobs::find_existing_job_payloads(
TargetType::Subject,
&subject_payloads,
db_pool,
)
.bind(TargetType::Subject)
.bind(&subject_payloads)
.fetch_all(db_pool)
.await?;
// Convert to a HashSet for efficient lookup
let existing_payloads: std::collections::HashSet<String> = existing_jobs
.into_iter()
.map(|(payload,)| payload.to_string())
.collect();
// Filter out subjects that already have jobs and prepare new jobs
let mut skipped_count = 0;
let new_jobs: Vec<_> = subjects
@@ -162,24 +155,16 @@ impl Scheduler {
// Insert all new jobs in a single batch
if !new_jobs.is_empty() {
let now = chrono::Utc::now();
let mut tx = db_pool.begin().await?;
for (payload, subject_code) in new_jobs {
sqlx::query(
"INSERT INTO scrape_jobs (target_type, target_payload, priority, execute_at) VALUES ($1, $2, $3, $4)"
)
.bind(TargetType::Subject)
.bind(&payload)
.bind(ScrapePriority::Low)
.bind(now)
.execute(&mut *tx)
.await?;
for (_, subject_code) in &new_jobs {
debug!(subject = subject_code, "New job enqueued for subject");
}
tx.commit().await?;
let jobs: Vec<_> = new_jobs
.into_iter()
.map(|(payload, _)| (payload, TargetType::Subject, ScrapePriority::Low))
.collect();
scrape_jobs::batch_insert_jobs(&jobs, db_pool).await?;
}
debug!("Job scheduling complete");
+5 -40
View File
@@ -1,5 +1,6 @@
use crate::banner::{BannerApi, BannerApiError};
use crate::data::models::ScrapeJob;
use crate::data::scrape_jobs;
use crate::error::Result;
use crate::scraper::jobs::{JobError, JobType};
use sqlx::PgPool;
@@ -83,24 +84,7 @@ impl Worker {
/// This uses a `FOR UPDATE SKIP LOCKED` query to ensure that multiple
/// workers can poll the queue concurrently without conflicts.
async fn fetch_and_lock_job(&self) -> Result<Option<ScrapeJob>> {
let mut tx = self.db_pool.begin().await?;
let job = sqlx::query_as::<_, ScrapeJob>(
"SELECT * FROM scrape_jobs WHERE locked_at IS NULL AND execute_at <= NOW() ORDER BY priority DESC, execute_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED"
)
.fetch_optional(&mut *tx)
.await?;
if let Some(ref job) = job {
sqlx::query("UPDATE scrape_jobs SET locked_at = NOW() WHERE id = $1")
.bind(job.id)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(job)
scrape_jobs::fetch_and_lock_job(&self.db_pool).await
}
async fn process_job(&self, job: ScrapeJob) -> Result<(), JobError> {
@@ -139,34 +123,15 @@ impl Worker {
}
async fn delete_job(&self, job_id: i32) -> Result<()> {
sqlx::query("DELETE FROM scrape_jobs WHERE id = $1")
.bind(job_id)
.execute(&self.db_pool)
.await?;
Ok(())
scrape_jobs::delete_job(job_id, &self.db_pool).await
}
async fn unlock_job(&self, job_id: i32) -> Result<()> {
sqlx::query("UPDATE scrape_jobs SET locked_at = NULL WHERE id = $1")
.bind(job_id)
.execute(&self.db_pool)
.await?;
Ok(())
scrape_jobs::unlock_job(job_id, &self.db_pool).await
}
async fn unlock_and_increment_retry(&self, job_id: i32, max_retries: i32) -> Result<bool> {
let result = sqlx::query_scalar::<_, Option<i32>>(
"UPDATE scrape_jobs
SET locked_at = NULL, retry_count = retry_count + 1
WHERE id = $1
RETURNING CASE WHEN retry_count + 1 < $2 THEN retry_count + 1 ELSE NULL END",
)
.bind(job_id)
.bind(max_retries)
.fetch_one(&self.db_pool)
.await?;
Ok(result.is_some())
scrape_jobs::unlock_and_increment_retry(job_id, max_retries, &self.db_pool).await
}
/// Handle shutdown signal received during job processing