log railway build logs url, add CORS & CatchPanic middleware, tx session property, ws upgrade handler

This commit is contained in:
2024-12-23 12:43:40 -06:00
parent 54ddf4496c
commit c0e99b5f94
4 changed files with 151 additions and 27 deletions

25
Cargo.lock generated
View File

@@ -314,10 +314,12 @@ name = "dynamic-preauth"
version = "0.1.0"
dependencies = [
"chrono",
"futures-util",
"rand",
"salvo",
"serde",
"tokio",
"tokio-stream",
"tracing",
"tracing-subscriber",
]
@@ -1551,6 +1553,7 @@ version = "0.74.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf71b51a4d651ddf3d660db7ae483baab3f1bd81f1f82f177731bf43f6052c34"
dependencies = [
"salvo-cors",
"salvo-jwt-auth",
"salvo-proxy",
"salvo-serve-static",
@@ -1558,6 +1561,17 @@ dependencies = [
"salvo_extra",
]
[[package]]
name = "salvo-cors"
version = "0.74.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5751058f858759ef8b29b669ca9a67aa8a9800eb1be43a286ab293a00121d5ec"
dependencies = [
"bytes",
"salvo_core",
"tracing",
]
[[package]]
name = "salvo-jwt-auth"
version = "0.74.3"
@@ -2131,6 +2145,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.24.0"

View File

@@ -5,9 +5,11 @@ edition = "2021"
[dependencies]
chrono = { version = "0.4.39", features = ["serde"] }
futures-util = "0.3.31"
rand = "0.8.5"
salvo = { version = "0.74.3", features = ["affix-state", "logging", "serve-static"] }
salvo = { version = "0.74.3", features = ["affix-state", "catch-panic", "cors", "logging", "serve-static", "websocket"] }
serde = { version = "1.0.216", features = ["derive"] }
tokio = { version = "1", features = ["macros"] }
tokio-stream = "0.1.17"
tracing = "0.1"
tracing-subscriber = "0.3"

View File

@@ -1,13 +1,19 @@
use std::env;
use std::sync::LazyLock;
use salvo::http::{HeaderValue, StatusCode};
use futures_util::{FutureExt, StreamExt};
use salvo::cors::Cors;
use salvo::http::{HeaderValue, Method, StatusCode, StatusError};
use salvo::logging::Logger;
use salvo::prelude::{
handler, Listener, Request, Response, Router, Server, Service, StaticDir, TcpListener,
handler, CatchPanic, Listener, Request, Response, Router, Server, Service, StaticDir,
TcpListener, WebSocketUpgrade,
};
use salvo::websocket::WebSocket;
use salvo::writing::Json;
use salvo::Depot;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::models::State;
@@ -46,18 +52,67 @@ async fn session_middleware(req: &mut Request, res: &mut Response, depot: &mut D
}
}
#[handler]
async fn connect(req: &mut Request, res: &mut Response, depot: &Depot) -> Result<(), StatusError> {
let session_id = get_session_id(req, depot).unwrap();
WebSocketUpgrade::new()
.upgrade(req, res, move |ws| async move {
handle_socket(session_id, ws).await;
})
.await
}
async fn handle_socket(session_id: usize, ws: WebSocket) {
// Split the socket into a sender and receive of messages.
let (user_ws_tx, mut user_ws_rx) = ws.split();
// Use an unbounded channel to handle buffering and flushing of messages
// to the websocket...
let (tx, rx) = mpsc::unbounded_channel();
let rx = UnboundedReceiverStream::new(rx);
let fut = rx.forward(user_ws_tx).map(|result| {
if let Err(e) = result {
tracing::error!(error = ?e, "websocket send error");
}
});
tokio::task::spawn(fut);
// Handle incoming messages
let fut = async move {
let mut store = STORE.lock().await;
let session = store.sessions.get_mut(&session_id).unwrap();
session.tx = Some(tx);
drop(store);
while let Some(result) = user_ws_rx.next().await {
let msg = match result {
Ok(msg) => msg,
Err(_) => {
// eprintln!("websocket error(uid={}): {}", my_id, e);
break;
}
};
println!("Received message: {:?}", msg);
}
};
tokio::task::spawn(fut);
}
#[handler]
pub async fn download(req: &mut Request, res: &mut Response) {
let article_id = req.param::<String>("id").unwrap();
let download_id = req.param::<String>("id").unwrap();
let store = STORE.lock().await;
let executable = store.executables.get(&article_id as &str).unwrap();
let executable = store.executables.get(&download_id as &str).unwrap();
let data = executable.with_key(b"test");
if let Err(e) = res.write_body(data) {
eprintln!("Error writing body: {}", e);
}
// TODO: Send the notify message via websocket
res.headers.insert(
"Content-Disposition",
HeaderValue::from_str(format!("attachment; filename=\"{}\"", executable.filename).as_str())
@@ -73,24 +128,13 @@ pub async fn download(req: &mut Request, res: &mut Response) {
pub async fn get_session(req: &mut Request, res: &mut Response, depot: &mut Depot) {
let store = STORE.lock().await;
let session_id = match req.cookie("Session") {
Some(cookie) => match cookie.value().parse::<usize>() {
Ok(id) => id,
_ => {
res.status_code(StatusCode::BAD_REQUEST);
return;
}
},
None => match depot.get::<usize>("session_id") {
Ok(id) => *id,
_ => {
res.status_code(StatusCode::BAD_REQUEST);
return;
}
},
};
let session_id = get_session_id(req, depot);
if session_id.is_none() {
res.status_code(StatusCode::BAD_REQUEST);
return;
}
match store.sessions.get(&session_id) {
match store.sessions.get(&session_id.unwrap()) {
Some(session) => {
res.render(Json(&session));
}
@@ -100,23 +144,73 @@ pub async fn get_session(req: &mut Request, res: &mut Response, depot: &mut Depo
}
}
fn get_session_id(req: &Request, depot: &Depot) -> Option<usize> {
match req.cookie("Session") {
Some(cookie) => match cookie.value().parse::<usize>() {
Ok(id) => Some(id),
_ => None,
},
None => match depot.get::<usize>("session_id") {
Ok(id) => Some(*id),
_ => None,
},
}
}
#[tokio::main]
async fn main() {
let port = std::env::var("PORT").unwrap_or_else(|_| "5800".to_string());
let addr = format!("0.0.0.0:{}", port);
tracing_subscriber::fmt().init();
// Check if we are deployed on Railway
let is_railway = env::var("RAILWAY_PROJECT_ID").is_ok();
if is_railway {
let build_logs = format!(
"https://railway.com/project/{}/service/{}?environmentId={}&id={}#build",
env::var("RAILWAY_PROJECT_ID").unwrap(),
env::var("RAILWAY_SERVICE_ID").unwrap(),
env::var("RAILWAY_ENVIRONMENT_ID").unwrap(),
env::var("RAILWAY_DEPLOYMENT_ID").unwrap()
);
println!("Build logs available here: {}", build_logs);
}
// Add the executables to the store
let mut store = STORE.lock().await;
store.add_executable("windows", "./demo-windows.exe");
store.add_executable("linux", "./demo-linux");
drop(store);
drop(store); // critical: Drop the lock to avoid deadlock, otherwise the server will hang
// Allow all origins if: debug mode or RAILWAY_PUBLIC_DOMAIN is not set
let origin = if cfg!(debug_assertions) | env::var_os("RAILWAY_PUBLIC_DOMAIN").is_none() {
"*".to_string()
} else {
format!(
"https://{}",
env::var_os("RAILWAY_PUBLIC_DOMAIN")
.unwrap()
.to_str()
.unwrap()
)
};
let cors = Cors::new()
.allow_origin(&origin)
.allow_methods(vec![Method::GET])
.into_handler();
let static_dir = StaticDir::new(["./public"]).defaults("index.html");
let router = Router::new()
.hoop(CatchPanic::new())
.hoop(cors)
.hoop(session_middleware)
.push(Router::with_path("download/<id>").get(download))
.push(Router::with_path("session").get(get_session))
.push(Router::with_path("ws").goal(connect))
.push(Router::with_path("<**path>").get(static_dir));
let service = Service::new(router).hoop(Logger::new());

View File

@@ -1,8 +1,8 @@
use rand::{distributions::Alphanumeric, Rng};
use salvo::{http::cookie::Cookie, Response};
use rand::Rng;
use salvo::{http::cookie::Cookie, websocket::Message, Response};
use serde::Serialize;
use std::{collections::HashMap, path};
use tokio::sync::Mutex;
use tokio::sync::{mpsc::UnboundedSender, Mutex};
use crate::utility::search;
@@ -11,6 +11,8 @@ pub struct Session {
pub tokens: Vec<String>,
pub last_seen: chrono::DateTime<chrono::Utc>,
pub first_seen: chrono::DateTime<chrono::Utc>,
#[serde(skip_serializing)]
pub tx: Option<UnboundedSender<Result<Message, salvo::Error>>>,
}
#[derive(Default, Clone, Debug)]
@@ -61,6 +63,7 @@ impl<'a> State<'a> {
tokens: vec![],
last_seen: now,
first_seen: now,
tx: None,
},
);