diff --git a/Cargo.lock b/Cargo.lock index afb7441..ae76ae0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -318,6 +318,7 @@ dependencies = [ "rand", "salvo", "serde", + "serde_json", "tokio", "tokio-stream", "tracing", @@ -958,6 +959,15 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "memchr" version = "2.7.4" @@ -1354,8 +1364,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -1366,9 +1385,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -2276,10 +2301,14 @@ version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] diff --git a/Cargo.toml b/Cargo.toml index 31d38ed..350599f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,8 @@ futures-util = "0.3.31" rand = "0.8.5" salvo = { version = "0.74.3", features = ["affix-state", "catch-panic", "cors", "logging", "serve-static", "websocket"] } serde = { version = "1.0.216", features = ["derive"] } +serde_json = "1.0.134" tokio = { version = "1", features = ["macros"] } tokio-stream = "0.1.17" tracing = "0.1" -tracing-subscriber = "0.3" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/main.rs b/src/main.rs index d849611..8c8211a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use std::env; use std::sync::LazyLock; use futures_util::{FutureExt, StreamExt}; +use models::IncomingMessage; use salvo::cors::Cors; use salvo::http::{HeaderValue, Method, StatusCode, StatusError}; use salvo::logging::Logger; @@ -14,6 +15,7 @@ use salvo::writing::Json; use salvo::Depot; use tokio::sync::{mpsc, Mutex}; use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing_subscriber::EnvFilter; use crate::models::State; @@ -31,11 +33,13 @@ async fn session_middleware(req: &mut Request, res: &mut Response, depot: &mut D Ok(session_id) => { let mut store = STORE.lock().await; if !store.sessions.contains_key(&session_id) { + tracing::debug!("Session provided in cookie, but does not exist"); let id = store.new_session(res).await; depot.insert("session_id", id); } } Err(_) => { + tracing::debug!("Session provided in cookie, but is not a valid number"); let mut store = STORE.lock().await; let id = store.new_session(res).await; @@ -44,6 +48,7 @@ async fn session_middleware(req: &mut Request, res: &mut Response, depot: &mut D } } None => { + tracing::debug!("Session was not provided in cookie"); let mut store = STORE.lock().await; let id = store.new_session(res).await; @@ -80,10 +85,20 @@ async fn handle_socket(session_id: usize, ws: WebSocket) { // Handle incoming messages let fut = async move { let mut store = STORE.lock().await; - let session = store.sessions.get_mut(&session_id).unwrap(); - session.tx = Some(tx); + // let session = store + // .sessions + // .get_mut(&session_id) + // .expect("Unable to get session"); + // tx.send(Ok(Message::ping("1"))) + // .expect("Unable to send message"); + // session.tx = Some(tx); drop(store); + tracing::info!( + "WebSocket connection established for session_id: {}", + session_id + ); + while let Some(result) = user_ws_rx.next().await { let msg = match result { Ok(msg) => msg, @@ -93,19 +108,52 @@ async fn handle_socket(session_id: usize, ws: WebSocket) { } }; - println!("Received message: {:?}", msg); + if msg.is_close() { + tracing::info!("WebSocket closing for Session {}", session_id); + break; + } + + if msg.is_text() { + let text = msg.to_str().unwrap(); + + // Deserialize + match serde_json::from_str::(text) { + Ok(message) => { + tracing::info!("Received message: {:?}", message); + } + Err(e) => { + tracing::error!("Error deserializing message: {} {}", text, e); + } + } + } } }; tokio::task::spawn(fut); } #[handler] -pub async fn download(req: &mut Request, res: &mut Response) { - let download_id = req.param::("id").unwrap(); +pub async fn download(req: &mut Request, res: &mut Response, depot: &mut Depot) { + let download_id = req + .param::("id") + .expect("Download ID required to download file"); - let store = STORE.lock().await; - let executable = store.executables.get(&download_id as &str).unwrap(); - let data = executable.with_key(b"test"); + let session_id = + get_session_id(req, depot).expect("Session ID could not be found via request or depot"); + + let store = &mut *STORE.lock().await; + + let session = store + .sessions + .get_mut(&session_id) + .expect("Session not found"); + let executable = store + .executables + .get(&download_id as &str) + .expect("Executable not found"); + + // Create a download for the session + let session_download = session.add_download(executable); + let data = executable.with_key(session_id.to_string().as_bytes()); if let Err(e) = res.write_body(data) { eprintln!("Error writing body: {}", e); @@ -115,8 +163,10 @@ pub async fn download(req: &mut Request, res: &mut Response) { res.headers.insert( "Content-Disposition", - HeaderValue::from_str(format!("attachment; filename=\"{}\"", executable.filename).as_str()) - .unwrap(), + HeaderValue::from_str( + format!("attachment; filename=\"{}\"", session_download.filename).as_str(), + ) + .expect("Unable to create header"), ); res.headers.insert( "Content-Type", @@ -144,6 +194,7 @@ pub async fn get_session(req: &mut Request, res: &mut Response, depot: &mut Depo } } +// Acquires the session id from the request, preferring the request Cookie fn get_session_id(req: &Request, depot: &Depot) -> Option { match req.cookie("Session") { Some(cookie) => match cookie.value().parse::() { @@ -161,11 +212,19 @@ fn get_session_id(req: &Request, depot: &Depot) -> Option { async fn main() { let port = std::env::var("PORT").unwrap_or_else(|_| "5800".to_string()); let addr = format!("0.0.0.0:{}", port); - tracing_subscriber::fmt().init(); + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::new(format!( + "info,dynamic_preauth={}", + // Only log our message in debug mode + match cfg!(debug_assertions) { + true => "debug", + false => "info", + } + ))) + .init(); // Check if we are deployed on Railway let is_railway = env::var("RAILWAY_PROJECT_ID").is_ok(); - if is_railway { let build_logs = format!( "https://railway.com/project/{}/service/{}?environmentId={}&id={}#build", @@ -175,7 +234,7 @@ async fn main() { env::var("RAILWAY_DEPLOYMENT_ID").unwrap() ); - println!("Build logs available here: {}", build_logs); + tracing::info!("Build logs available here: {}", build_logs); } // Add the executables to the store @@ -201,6 +260,7 @@ async fn main() { .allow_origin(&origin) .allow_methods(vec![Method::GET]) .into_handler(); + tracing::debug!("CORS Origin: {}", &origin); let static_dir = StaticDir::new(["./public"]).defaults("index.html"); diff --git a/src/models.rs b/src/models.rs index ce2cfd7..30ae55d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,23 +1,67 @@ use rand::Rng; use salvo::{http::cookie::Cookie, websocket::Message, Response}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::{collections::HashMap, path}; use tokio::sync::{mpsc::UnboundedSender, Mutex}; use crate::utility::search; -#[derive(Clone, Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct Session { - pub tokens: Vec, - pub last_seen: chrono::DateTime, + pub downloads: Vec, + pub first_seen: chrono::DateTime, + // The last time a request OR websocket message with this session was made + pub last_seen: chrono::DateTime, + // The last time a request was made with this session + pub last_request: chrono::DateTime, + + // The sender for the websocket connection #[serde(skip_serializing)] pub tx: Option>>, } -#[derive(Default, Clone, Debug)] +impl Session { + // Update the last seen time(s) for the session + pub fn seen(&mut self, socket: bool) { + self.last_seen = chrono::Utc::now(); + if !socket { + self.last_request = chrono::Utc::now(); + } + } + + // Add a download to the session + pub fn add_download(&mut self, exe: &Executable) -> &SessionDownload { + let mut rng = rand::thread_rng(); + let token: u64 = rng.gen(); + + let download = SessionDownload { + token, + filename: format!("{}-{:16x}{}", exe.name, token, exe.extension), + last_used: chrono::Utc::now(), + download_time: chrono::Utc::now(), + }; + + self.downloads.push(download); + return self.downloads.last().unwrap(); + } +} + +#[derive(Serialize, Debug, Clone)] +pub struct SessionDownload { + pub token: u64, + pub filename: String, + pub last_used: chrono::DateTime, + pub download_time: chrono::DateTime, +} + +impl SessionDownload {} + +#[derive(Clone, Debug)] pub struct State<'a> { + // A map of executables, keyed by their type/platform pub executables: HashMap<&'a str, Executable>, + // A map of sessions, keyed by their identifier (a random number) pub sessions: HashMap, } @@ -36,15 +80,17 @@ impl<'a> State<'a> { let key_start = search(&data, pattern.as_bytes(), 0).unwrap(); let key_end = key_start + pattern.len(); - let filename = path::Path::new(&exe_path) - .file_name() - .unwrap() - .to_string_lossy() - .into_owned(); + let filename = path::Path::new(&exe_path); + let name = filename.file_stem().unwrap().to_str().unwrap(); + let extension = match filename.extension() { + Some(s) => s.to_str().unwrap(), + None => "", + }; let exe = Executable { data, - filename, + name: name.to_string(), + extension: extension.to_string(), key_start: key_start, key_end: key_end, }; @@ -60,13 +106,16 @@ impl<'a> State<'a> { self.sessions.insert( id, Session { - tokens: vec![], + downloads: Vec::new(), last_seen: now, + last_request: now, first_seen: now, tx: None, }, ); + tracing::info!("New session created: {}", id); + res.add_cookie( Cookie::build(("Session", id.to_string())) .permanent() @@ -79,10 +128,11 @@ impl<'a> State<'a> { #[derive(Default, Clone, Debug)] pub struct Executable { - pub data: Vec, - pub filename: String, - pub key_start: usize, - pub key_end: usize, + pub data: Vec, // the raw data of the executable + pub name: String, // the name before the extension + pub extension: String, // may be empty string + pub key_start: usize, // the index of the byte where the key starts + pub key_end: usize, // the index of the byte where the key ends } impl Executable { @@ -104,3 +154,19 @@ impl Executable { return data; } } + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum IncomingMessage { + // A request from the client to delete a session token + DeleteSessionToken { id: u64 }, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +pub enum OutgoingMessage { + // An alert to the client that a session download has been used. + TokenAlert { token: u64 }, + // A message describing the current session state + State { session: Session }, +}