diff --git a/frontend/src/components/Demo.tsx b/frontend/src/components/Demo.tsx
index b122fc6..babb2b5 100644
--- a/frontend/src/components/Demo.tsx
+++ b/frontend/src/components/Demo.tsx
@@ -42,7 +42,7 @@ const Demo = ({ class: className }: DemoProps) => {
Your session is{" "}
- {id}
+ {"0x" + id?.toString(16).toUpperCase()}
. You have{" "}
diff --git a/frontend/src/components/useSocket.ts b/frontend/src/components/useSocket.ts
index d80b9d1..4c4aaed 100644
--- a/frontend/src/components/useSocket.ts
+++ b/frontend/src/components/useSocket.ts
@@ -14,14 +14,14 @@ interface Executable {
}
interface UseSocketResult {
- id: string | null;
+ id: number | null;
executables: Executable[];
downloads: Download[] | null;
deleteDownload: (id: string) => void;
}
function useSocket(): UseSocketResult {
- const [id, setId] = useState(null);
+ const [id, setId] = useState(null);
const [downloads, setDownloads] = useState(null);
const [executables, setExecutables] = useState(null);
@@ -44,9 +44,8 @@ function useSocket(): UseSocketResult {
switch (data.type) {
case "state":
- const downloads = data.downloads as Download[];
- setId(data.session);
- setDownloads(downloads);
+ setId(data.id as number);
+ setDownloads(data.session.downloads as Download[]);
break;
case "executables":
setExecutables(data.executables as Executable[]);
@@ -56,12 +55,13 @@ function useSocket(): UseSocketResult {
}
};
- socket.onclose = () => {
- console.log("WebSocket connection closed");
+ socket.onclose = (event) => {
+ console.log("WebSocket connection closed", event);
};
return () => {
// Close the socket when the component is unmounted
+ console.log("Unmounting, closing WebSocket connection");
socket.close();
};
}, []);
diff --git a/src/main.rs b/src/main.rs
index 6ace9cf..f5920bc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,7 +2,7 @@ use std::env;
use std::sync::LazyLock;
use futures_util::{FutureExt, StreamExt};
-use models::IncomingMessage;
+use models::{IncomingMessage, OutgoingMessage};
use salvo::cors::Cors;
use salvo::http::{HeaderValue, Method, StatusCode, StatusError};
use salvo::logging::Logger;
@@ -10,7 +10,7 @@ use salvo::prelude::{
handler, CatchPanic, Listener, Request, Response, Router, Server, Service, StaticDir,
TcpListener, WebSocketUpgrade,
};
-use salvo::websocket::WebSocket;
+use salvo::websocket::{Message, WebSocket};
use salvo::writing::Json;
use salvo::Depot;
use tokio::sync::{mpsc, Mutex};
@@ -75,39 +75,41 @@ async fn connect(req: &mut Request, res: &mut Response, depot: &Depot) -> Result
.await
}
-async fn handle_socket(session_id: usize, ws: WebSocket) {
+async fn handle_socket(session_id: usize, websocket: WebSocket) {
// Split the socket into a sender and receive of messages.
- let (user_ws_tx, mut user_ws_rx) = ws.split();
+ let (socket_tx, mut socket_rx) = websocket.split();
- // Use an unbounded channel to handle buffering and flushing of messages
- // to the websocket...
- let (tx, rx) = mpsc::unbounded_channel();
- let rx = UnboundedReceiverStream::new(rx);
- let fut = rx.forward(user_ws_tx).map(|result| {
+ // Use an unbounded channel to handle buffering and flushing of messages to the websocket...
+ let (tx_channel, tx_channel_rx) = mpsc::unbounded_channel();
+ let transmit = UnboundedReceiverStream::new(tx_channel_rx);
+ let fut_handle_tx_buffer = transmit.forward(socket_tx).map(|result| {
if let Err(e) = result {
tracing::error!(error = ?e, "websocket send error");
}
});
- tokio::task::spawn(fut);
+ tokio::task::spawn(fut_handle_tx_buffer);
+
+ let mut store = STORE.lock().await;
+ let session = store
+ .sessions
+ .get_mut(&session_id)
+ .expect("Unable to get session");
+ session.tx = Some(tx_channel);
+
+ session.send_message(OutgoingMessage::State {
+ id: session_id,
+ session: session.clone(),
+ });
+ drop(store);
// Handle incoming messages
let fut = async move {
- let mut store = STORE.lock().await;
- // let session = store
- // .sessions
- // .get_mut(&session_id)
- // .expect("Unable to get session");
- // tx.send(Ok(Message::ping("1")))
- // .expect("Unable to send message");
- // session.tx = Some(tx);
- drop(store);
-
tracing::info!(
"WebSocket connection established for session_id: {}",
session_id
);
- while let Some(result) = user_ws_rx.next().await {
+ while let Some(result) = socket_rx.next().await {
let msg = match result {
Ok(msg) => msg,
Err(error) => {
diff --git a/src/models.rs b/src/models.rs
index 5710325..ed31fca 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -45,6 +45,20 @@ impl Session {
self.downloads.push(download);
return self.downloads.last().unwrap();
}
+
+ pub fn send_message(&mut self, message: OutgoingMessage) {
+ // TODO: Error handling, check tx exists
+
+ let result = self
+ .tx
+ .as_ref()
+ .unwrap()
+ .send(Ok(Message::text(serde_json::to_string(&message).unwrap())));
+
+ if let Err(e) = result {
+ tracing::error!("Failed to initial session state: {}", e);
+ }
+ }
}
#[derive(Serialize, Debug, Clone)]
@@ -166,12 +180,12 @@ pub enum IncomingMessage {
}
#[derive(Debug, Serialize)]
-#[serde(tag = "type")]
+#[serde(tag = "type", rename_all = "lowercase")]
pub enum OutgoingMessage {
// An alert to the client that a session download has been used.
TokenAlert { token: u64 },
// A message describing the current session state
- State { session: Session },
+ State { session: Session, id: usize },
Executables { executables: Vec },
}