From c0e99b5f942b9f12321b17654521dc67093e378c Mon Sep 17 00:00:00 2001 From: Xevion Date: Mon, 23 Dec 2024 12:43:40 -0600 Subject: [PATCH] log railway build logs url, add CORS & CatchPanic middleware, tx session property, ws upgrade handler --- Cargo.lock | 25 +++++++++ Cargo.toml | 4 +- src/main.rs | 140 +++++++++++++++++++++++++++++++++++++++++--------- src/models.rs | 9 ++-- 4 files changed, 151 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fec9aac..afb7441 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,10 +314,12 @@ name = "dynamic-preauth" version = "0.1.0" dependencies = [ "chrono", + "futures-util", "rand", "salvo", "serde", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", ] @@ -1551,6 +1553,7 @@ version = "0.74.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf71b51a4d651ddf3d660db7ae483baab3f1bd81f1f82f177731bf43f6052c34" dependencies = [ + "salvo-cors", "salvo-jwt-auth", "salvo-proxy", "salvo-serve-static", @@ -1558,6 +1561,17 @@ dependencies = [ "salvo_extra", ] +[[package]] +name = "salvo-cors" +version = "0.74.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5751058f858759ef8b29b669ca9a67aa8a9800eb1be43a286ab293a00121d5ec" +dependencies = [ + "bytes", + "salvo_core", + "tracing", +] + [[package]] name = "salvo-jwt-auth" version = "0.74.3" @@ -2131,6 +2145,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-tungstenite" version = "0.24.0" diff --git a/Cargo.toml b/Cargo.toml index 12b6321..31d38ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,9 +5,11 @@ edition = "2021" [dependencies] chrono = { version = "0.4.39", features = ["serde"] } +futures-util = "0.3.31" rand = "0.8.5" -salvo = { version = "0.74.3", features = ["affix-state", "logging", "serve-static"] } +salvo = { version = "0.74.3", features = ["affix-state", "catch-panic", "cors", "logging", "serve-static", "websocket"] } serde = { version = "1.0.216", features = ["derive"] } tokio = { version = "1", features = ["macros"] } +tokio-stream = "0.1.17" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/src/main.rs b/src/main.rs index 26858f5..abf9247 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,19 @@ +use std::env; use std::sync::LazyLock; -use salvo::http::{HeaderValue, StatusCode}; +use futures_util::{FutureExt, StreamExt}; +use salvo::cors::Cors; +use salvo::http::{HeaderValue, Method, StatusCode, StatusError}; use salvo::logging::Logger; use salvo::prelude::{ - handler, Listener, Request, Response, Router, Server, Service, StaticDir, TcpListener, + handler, CatchPanic, Listener, Request, Response, Router, Server, Service, StaticDir, + TcpListener, WebSocketUpgrade, }; +use salvo::websocket::WebSocket; use salvo::writing::Json; use salvo::Depot; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; +use tokio_stream::wrappers::UnboundedReceiverStream; use crate::models::State; @@ -46,18 +52,67 @@ async fn session_middleware(req: &mut Request, res: &mut Response, depot: &mut D } } +#[handler] +async fn connect(req: &mut Request, res: &mut Response, depot: &Depot) -> Result<(), StatusError> { + let session_id = get_session_id(req, depot).unwrap(); + WebSocketUpgrade::new() + .upgrade(req, res, move |ws| async move { + handle_socket(session_id, ws).await; + }) + .await +} + +async fn handle_socket(session_id: usize, ws: WebSocket) { + // Split the socket into a sender and receive of messages. + let (user_ws_tx, mut user_ws_rx) = ws.split(); + + // Use an unbounded channel to handle buffering and flushing of messages + // to the websocket... + let (tx, rx) = mpsc::unbounded_channel(); + let rx = UnboundedReceiverStream::new(rx); + let fut = rx.forward(user_ws_tx).map(|result| { + if let Err(e) = result { + tracing::error!(error = ?e, "websocket send error"); + } + }); + tokio::task::spawn(fut); + + // Handle incoming messages + let fut = async move { + let mut store = STORE.lock().await; + let session = store.sessions.get_mut(&session_id).unwrap(); + session.tx = Some(tx); + drop(store); + + while let Some(result) = user_ws_rx.next().await { + let msg = match result { + Ok(msg) => msg, + Err(_) => { + // eprintln!("websocket error(uid={}): {}", my_id, e); + break; + } + }; + + println!("Received message: {:?}", msg); + } + }; + tokio::task::spawn(fut); +} + #[handler] pub async fn download(req: &mut Request, res: &mut Response) { - let article_id = req.param::("id").unwrap(); + let download_id = req.param::("id").unwrap(); let store = STORE.lock().await; - let executable = store.executables.get(&article_id as &str).unwrap(); + let executable = store.executables.get(&download_id as &str).unwrap(); let data = executable.with_key(b"test"); if let Err(e) = res.write_body(data) { eprintln!("Error writing body: {}", e); } + // TODO: Send the notify message via websocket + res.headers.insert( "Content-Disposition", HeaderValue::from_str(format!("attachment; filename=\"{}\"", executable.filename).as_str()) @@ -73,24 +128,13 @@ pub async fn download(req: &mut Request, res: &mut Response) { pub async fn get_session(req: &mut Request, res: &mut Response, depot: &mut Depot) { let store = STORE.lock().await; - let session_id = match req.cookie("Session") { - Some(cookie) => match cookie.value().parse::() { - Ok(id) => id, - _ => { - res.status_code(StatusCode::BAD_REQUEST); - return; - } - }, - None => match depot.get::("session_id") { - Ok(id) => *id, - _ => { - res.status_code(StatusCode::BAD_REQUEST); - return; - } - }, - }; + let session_id = get_session_id(req, depot); + if session_id.is_none() { + res.status_code(StatusCode::BAD_REQUEST); + return; + } - match store.sessions.get(&session_id) { + match store.sessions.get(&session_id.unwrap()) { Some(session) => { res.render(Json(&session)); } @@ -100,23 +144,73 @@ pub async fn get_session(req: &mut Request, res: &mut Response, depot: &mut Depo } } +fn get_session_id(req: &Request, depot: &Depot) -> Option { + match req.cookie("Session") { + Some(cookie) => match cookie.value().parse::() { + Ok(id) => Some(id), + _ => None, + }, + None => match depot.get::("session_id") { + Ok(id) => Some(*id), + _ => None, + }, + } +} + #[tokio::main] async fn main() { let port = std::env::var("PORT").unwrap_or_else(|_| "5800".to_string()); let addr = format!("0.0.0.0:{}", port); tracing_subscriber::fmt().init(); + // Check if we are deployed on Railway + let is_railway = env::var("RAILWAY_PROJECT_ID").is_ok(); + + if is_railway { + let build_logs = format!( + "https://railway.com/project/{}/service/{}?environmentId={}&id={}#build", + env::var("RAILWAY_PROJECT_ID").unwrap(), + env::var("RAILWAY_SERVICE_ID").unwrap(), + env::var("RAILWAY_ENVIRONMENT_ID").unwrap(), + env::var("RAILWAY_DEPLOYMENT_ID").unwrap() + ); + + println!("Build logs available here: {}", build_logs); + } + + // Add the executables to the store let mut store = STORE.lock().await; store.add_executable("windows", "./demo-windows.exe"); store.add_executable("linux", "./demo-linux"); - drop(store); + drop(store); // critical: Drop the lock to avoid deadlock, otherwise the server will hang + + // Allow all origins if: debug mode or RAILWAY_PUBLIC_DOMAIN is not set + let origin = if cfg!(debug_assertions) | env::var_os("RAILWAY_PUBLIC_DOMAIN").is_none() { + "*".to_string() + } else { + format!( + "https://{}", + env::var_os("RAILWAY_PUBLIC_DOMAIN") + .unwrap() + .to_str() + .unwrap() + ) + }; + + let cors = Cors::new() + .allow_origin(&origin) + .allow_methods(vec![Method::GET]) + .into_handler(); let static_dir = StaticDir::new(["./public"]).defaults("index.html"); let router = Router::new() + .hoop(CatchPanic::new()) + .hoop(cors) .hoop(session_middleware) .push(Router::with_path("download/").get(download)) .push(Router::with_path("session").get(get_session)) + .push(Router::with_path("ws").goal(connect)) .push(Router::with_path("<**path>").get(static_dir)); let service = Service::new(router).hoop(Logger::new()); diff --git a/src/models.rs b/src/models.rs index 84006f5..ce2cfd7 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,8 @@ -use rand::{distributions::Alphanumeric, Rng}; -use salvo::{http::cookie::Cookie, Response}; +use rand::Rng; +use salvo::{http::cookie::Cookie, websocket::Message, Response}; use serde::Serialize; use std::{collections::HashMap, path}; -use tokio::sync::Mutex; +use tokio::sync::{mpsc::UnboundedSender, Mutex}; use crate::utility::search; @@ -11,6 +11,8 @@ pub struct Session { pub tokens: Vec, pub last_seen: chrono::DateTime, pub first_seen: chrono::DateTime, + #[serde(skip_serializing)] + pub tx: Option>>, } #[derive(Default, Clone, Debug)] @@ -61,6 +63,7 @@ impl<'a> State<'a> { tokens: vec![], last_seen: now, first_seen: now, + tx: None, }, );