diff --git a/src/main.rs b/src/main.rs index 7e43c3c..037ce08 100644 --- a/src/main.rs +++ b/src/main.rs @@ -178,7 +178,7 @@ pub async fn download(req: &mut Request, res: &mut Response, depot: &mut Depot) // Create a download for the session let session_download = session.add_download(executable); tracing::info!(session_id, type = download_id, dl_token = session_download.token, "Download created"); - let data = executable.with_key(session_id.to_string().as_bytes()); + let data = executable.with_key(session_download.token.to_string().as_bytes()); if let Err(e) = res.write_body(data) { tracing::error!("Error writing body: {}", e); @@ -210,39 +210,54 @@ pub async fn download(req: &mut Request, res: &mut Response, depot: &mut Depot) pub async fn notify(req: &mut Request, res: &mut Response) { let key = req.query::("key"); - match key { - Some(key) => { - // Parse key into u32 - let key = match key.parse::() { - Ok(k) => k, - Err(e) => { - tracing::error!("Error parsing key: {}", e); - res.status_code(StatusCode::BAD_REQUEST); - return; - } - }; + if key.is_none() { + res.status_code(StatusCode::BAD_REQUEST); + return; + } - let store = &mut *STORE.lock().await; - let session = store.sessions.get_mut(&key); + let key = key.unwrap(); - match session { - Some(session) => { - let message = OutgoingMessage::TokenAlert { token: key }; - session - .send_message(message) - .expect("Failed to buffer token alert message"); + if !key.starts_with("0x") { + res.status_code(StatusCode::BAD_REQUEST); + return; + } - res.render("Notification sent"); - } - None => { - tracing::warn!("Session not found for key while attempting notify: {}", key); - res.status_code(StatusCode::UNAUTHORIZED); - return; - } + // Parse key into u32 + let key = match u32::from_str_radix(key.trim_start_matches("0x"), 16) { + Ok(k) => k, + Err(e) => { + tracing::error!("Error parsing key: {}", e); + res.status_code(StatusCode::BAD_REQUEST); + return; + } + }; + + let store = &mut *STORE.lock().await; + + let target_session = store + .sessions + .iter_mut() + .find(|(_, session)| session.downloads.iter().find(|d| d.token == key).is_some()); + + match target_session { + Some((_, session)) => { + let message = OutgoingMessage::TokenAlert { token: key }; + + if let Err(e) = session.send_message(message) { + tracing::warn!( + error = e.to_string(), + "Session did not have a receiving WebSocket available, notify ignored.", + ); + res.status_code(StatusCode::NOT_MODIFIED); + return; } + + res.render("Notification sent"); } None => { - res.status_code(StatusCode::BAD_REQUEST); + tracing::warn!("Session not found for key while attempting notify: {}", key); + res.status_code(StatusCode::UNAUTHORIZED); + return; } } }