Tracing, SessionDownload, Incoming/Outgoing message serde, Executable handling, EnvFilter

This commit is contained in:
2024-12-23 16:47:31 -06:00
parent 9315fbd985
commit 70dc064a4c
4 changed files with 189 additions and 33 deletions

35
Cargo.lock generated
View File

@@ -318,6 +318,7 @@ dependencies = [
"rand",
"salvo",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tracing",
@@ -958,6 +959,15 @@ version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "memchr"
version = "2.7.4"
@@ -1354,8 +1364,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
@@ -1366,9 +1385,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
"regex-syntax 0.8.5",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.5"
@@ -2276,10 +2301,14 @@ version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]

View File

@@ -9,7 +9,8 @@ futures-util = "0.3.31"
rand = "0.8.5"
salvo = { version = "0.74.3", features = ["affix-state", "catch-panic", "cors", "logging", "serve-static", "websocket"] }
serde = { version = "1.0.216", features = ["derive"] }
serde_json = "1.0.134"
tokio = { version = "1", features = ["macros"] }
tokio-stream = "0.1.17"
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -2,6 +2,7 @@ use std::env;
use std::sync::LazyLock;
use futures_util::{FutureExt, StreamExt};
use models::IncomingMessage;
use salvo::cors::Cors;
use salvo::http::{HeaderValue, Method, StatusCode, StatusError};
use salvo::logging::Logger;
@@ -14,6 +15,7 @@ use salvo::writing::Json;
use salvo::Depot;
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing_subscriber::EnvFilter;
use crate::models::State;
@@ -31,11 +33,13 @@ async fn session_middleware(req: &mut Request, res: &mut Response, depot: &mut D
Ok(session_id) => {
let mut store = STORE.lock().await;
if !store.sessions.contains_key(&session_id) {
tracing::debug!("Session provided in cookie, but does not exist");
let id = store.new_session(res).await;
depot.insert("session_id", id);
}
}
Err(_) => {
tracing::debug!("Session provided in cookie, but is not a valid number");
let mut store = STORE.lock().await;
let id = store.new_session(res).await;
@@ -44,6 +48,7 @@ async fn session_middleware(req: &mut Request, res: &mut Response, depot: &mut D
}
}
None => {
tracing::debug!("Session was not provided in cookie");
let mut store = STORE.lock().await;
let id = store.new_session(res).await;
@@ -80,10 +85,20 @@ async fn handle_socket(session_id: usize, ws: WebSocket) {
// Handle incoming messages
let fut = async move {
let mut store = STORE.lock().await;
let session = store.sessions.get_mut(&session_id).unwrap();
session.tx = Some(tx);
// let session = store
// .sessions
// .get_mut(&session_id)
// .expect("Unable to get session");
// tx.send(Ok(Message::ping("1")))
// .expect("Unable to send message");
// session.tx = Some(tx);
drop(store);
tracing::info!(
"WebSocket connection established for session_id: {}",
session_id
);
while let Some(result) = user_ws_rx.next().await {
let msg = match result {
Ok(msg) => msg,
@@ -93,19 +108,52 @@ async fn handle_socket(session_id: usize, ws: WebSocket) {
}
};
println!("Received message: {:?}", msg);
if msg.is_close() {
tracing::info!("WebSocket closing for Session {}", session_id);
break;
}
if msg.is_text() {
let text = msg.to_str().unwrap();
// Deserialize
match serde_json::from_str::<IncomingMessage>(text) {
Ok(message) => {
tracing::info!("Received message: {:?}", message);
}
Err(e) => {
tracing::error!("Error deserializing message: {} {}", text, e);
}
}
}
}
};
tokio::task::spawn(fut);
}
#[handler]
pub async fn download(req: &mut Request, res: &mut Response) {
let download_id = req.param::<String>("id").unwrap();
pub async fn download(req: &mut Request, res: &mut Response, depot: &mut Depot) {
let download_id = req
.param::<String>("id")
.expect("Download ID required to download file");
let store = STORE.lock().await;
let executable = store.executables.get(&download_id as &str).unwrap();
let data = executable.with_key(b"test");
let session_id =
get_session_id(req, depot).expect("Session ID could not be found via request or depot");
let store = &mut *STORE.lock().await;
let session = store
.sessions
.get_mut(&session_id)
.expect("Session not found");
let executable = store
.executables
.get(&download_id as &str)
.expect("Executable not found");
// Create a download for the session
let session_download = session.add_download(executable);
let data = executable.with_key(session_id.to_string().as_bytes());
if let Err(e) = res.write_body(data) {
eprintln!("Error writing body: {}", e);
@@ -115,8 +163,10 @@ pub async fn download(req: &mut Request, res: &mut Response) {
res.headers.insert(
"Content-Disposition",
HeaderValue::from_str(format!("attachment; filename=\"{}\"", executable.filename).as_str())
.unwrap(),
HeaderValue::from_str(
format!("attachment; filename=\"{}\"", session_download.filename).as_str(),
)
.expect("Unable to create header"),
);
res.headers.insert(
"Content-Type",
@@ -144,6 +194,7 @@ pub async fn get_session(req: &mut Request, res: &mut Response, depot: &mut Depo
}
}
// Acquires the session id from the request, preferring the request Cookie
fn get_session_id(req: &Request, depot: &Depot) -> Option<usize> {
match req.cookie("Session") {
Some(cookie) => match cookie.value().parse::<usize>() {
@@ -161,11 +212,19 @@ fn get_session_id(req: &Request, depot: &Depot) -> Option<usize> {
async fn main() {
let port = std::env::var("PORT").unwrap_or_else(|_| "5800".to_string());
let addr = format!("0.0.0.0:{}", port);
tracing_subscriber::fmt().init();
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::new(format!(
"info,dynamic_preauth={}",
// Only log our message in debug mode
match cfg!(debug_assertions) {
true => "debug",
false => "info",
}
)))
.init();
// Check if we are deployed on Railway
let is_railway = env::var("RAILWAY_PROJECT_ID").is_ok();
if is_railway {
let build_logs = format!(
"https://railway.com/project/{}/service/{}?environmentId={}&id={}#build",
@@ -175,7 +234,7 @@ async fn main() {
env::var("RAILWAY_DEPLOYMENT_ID").unwrap()
);
println!("Build logs available here: {}", build_logs);
tracing::info!("Build logs available here: {}", build_logs);
}
// Add the executables to the store
@@ -201,6 +260,7 @@ async fn main() {
.allow_origin(&origin)
.allow_methods(vec![Method::GET])
.into_handler();
tracing::debug!("CORS Origin: {}", &origin);
let static_dir = StaticDir::new(["./public"]).defaults("index.html");

View File

@@ -1,23 +1,67 @@
use rand::Rng;
use salvo::{http::cookie::Cookie, websocket::Message, Response};
use serde::Serialize;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, path};
use tokio::sync::{mpsc::UnboundedSender, Mutex};
use crate::utility::search;
#[derive(Clone, Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct Session {
pub tokens: Vec<String>,
pub last_seen: chrono::DateTime<chrono::Utc>,
pub downloads: Vec<SessionDownload>,
pub first_seen: chrono::DateTime<chrono::Utc>,
// The last time a request OR websocket message with this session was made
pub last_seen: chrono::DateTime<chrono::Utc>,
// The last time a request was made with this session
pub last_request: chrono::DateTime<chrono::Utc>,
// The sender for the websocket connection
#[serde(skip_serializing)]
pub tx: Option<UnboundedSender<Result<Message, salvo::Error>>>,
}
#[derive(Default, Clone, Debug)]
impl Session {
// Update the last seen time(s) for the session
pub fn seen(&mut self, socket: bool) {
self.last_seen = chrono::Utc::now();
if !socket {
self.last_request = chrono::Utc::now();
}
}
// Add a download to the session
pub fn add_download(&mut self, exe: &Executable) -> &SessionDownload {
let mut rng = rand::thread_rng();
let token: u64 = rng.gen();
let download = SessionDownload {
token,
filename: format!("{}-{:16x}{}", exe.name, token, exe.extension),
last_used: chrono::Utc::now(),
download_time: chrono::Utc::now(),
};
self.downloads.push(download);
return self.downloads.last().unwrap();
}
}
#[derive(Serialize, Debug, Clone)]
pub struct SessionDownload {
pub token: u64,
pub filename: String,
pub last_used: chrono::DateTime<chrono::Utc>,
pub download_time: chrono::DateTime<chrono::Utc>,
}
impl SessionDownload {}
#[derive(Clone, Debug)]
pub struct State<'a> {
// A map of executables, keyed by their type/platform
pub executables: HashMap<&'a str, Executable>,
// A map of sessions, keyed by their identifier (a random number)
pub sessions: HashMap<usize, Session>,
}
@@ -36,15 +80,17 @@ impl<'a> State<'a> {
let key_start = search(&data, pattern.as_bytes(), 0).unwrap();
let key_end = key_start + pattern.len();
let filename = path::Path::new(&exe_path)
.file_name()
.unwrap()
.to_string_lossy()
.into_owned();
let filename = path::Path::new(&exe_path);
let name = filename.file_stem().unwrap().to_str().unwrap();
let extension = match filename.extension() {
Some(s) => s.to_str().unwrap(),
None => "",
};
let exe = Executable {
data,
filename,
name: name.to_string(),
extension: extension.to_string(),
key_start: key_start,
key_end: key_end,
};
@@ -60,13 +106,16 @@ impl<'a> State<'a> {
self.sessions.insert(
id,
Session {
tokens: vec![],
downloads: Vec::new(),
last_seen: now,
last_request: now,
first_seen: now,
tx: None,
},
);
tracing::info!("New session created: {}", id);
res.add_cookie(
Cookie::build(("Session", id.to_string()))
.permanent()
@@ -79,10 +128,11 @@ impl<'a> State<'a> {
#[derive(Default, Clone, Debug)]
pub struct Executable {
pub data: Vec<u8>,
pub filename: String,
pub key_start: usize,
pub key_end: usize,
pub data: Vec<u8>, // the raw data of the executable
pub name: String, // the name before the extension
pub extension: String, // may be empty string
pub key_start: usize, // the index of the byte where the key starts
pub key_end: usize, // the index of the byte where the key ends
}
impl Executable {
@@ -104,3 +154,19 @@ impl Executable {
return data;
}
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum IncomingMessage {
// A request from the client to delete a session token
DeleteSessionToken { id: u64 },
}
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
pub enum OutgoingMessage {
// An alert to the client that a session download has been used.
TokenAlert { token: u64 },
// A message describing the current session state
State { session: Session },
}