diff --git a/frontend/src/components/Demo.tsx b/frontend/src/components/Demo.tsx index b122fc6..babb2b5 100644 --- a/frontend/src/components/Demo.tsx +++ b/frontend/src/components/Demo.tsx @@ -42,7 +42,7 @@ const Demo = ({ class: className }: DemoProps) => {
Your session is{" "} - {id} + {"0x" + id?.toString(16).toUpperCase()} . You have{" "} diff --git a/frontend/src/components/useSocket.ts b/frontend/src/components/useSocket.ts index d80b9d1..4c4aaed 100644 --- a/frontend/src/components/useSocket.ts +++ b/frontend/src/components/useSocket.ts @@ -14,14 +14,14 @@ interface Executable { } interface UseSocketResult { - id: string | null; + id: number | null; executables: Executable[]; downloads: Download[] | null; deleteDownload: (id: string) => void; } function useSocket(): UseSocketResult { - const [id, setId] = useState(null); + const [id, setId] = useState(null); const [downloads, setDownloads] = useState(null); const [executables, setExecutables] = useState(null); @@ -44,9 +44,8 @@ function useSocket(): UseSocketResult { switch (data.type) { case "state": - const downloads = data.downloads as Download[]; - setId(data.session); - setDownloads(downloads); + setId(data.id as number); + setDownloads(data.session.downloads as Download[]); break; case "executables": setExecutables(data.executables as Executable[]); @@ -56,12 +55,13 @@ function useSocket(): UseSocketResult { } }; - socket.onclose = () => { - console.log("WebSocket connection closed"); + socket.onclose = (event) => { + console.log("WebSocket connection closed", event); }; return () => { // Close the socket when the component is unmounted + console.log("Unmounting, closing WebSocket connection"); socket.close(); }; }, []); diff --git a/src/main.rs b/src/main.rs index 6ace9cf..f5920bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use std::env; use std::sync::LazyLock; use futures_util::{FutureExt, StreamExt}; -use models::IncomingMessage; +use models::{IncomingMessage, OutgoingMessage}; use salvo::cors::Cors; use salvo::http::{HeaderValue, Method, StatusCode, StatusError}; use salvo::logging::Logger; @@ -10,7 +10,7 @@ use salvo::prelude::{ handler, CatchPanic, Listener, Request, Response, Router, Server, Service, StaticDir, TcpListener, WebSocketUpgrade, }; -use salvo::websocket::WebSocket; +use salvo::websocket::{Message, WebSocket}; use salvo::writing::Json; use salvo::Depot; use tokio::sync::{mpsc, Mutex}; @@ -75,39 +75,41 @@ async fn connect(req: &mut Request, res: &mut Response, depot: &Depot) -> Result .await } -async fn handle_socket(session_id: usize, ws: WebSocket) { +async fn handle_socket(session_id: usize, websocket: WebSocket) { // Split the socket into a sender and receive of messages. - let (user_ws_tx, mut user_ws_rx) = ws.split(); + let (socket_tx, mut socket_rx) = websocket.split(); - // Use an unbounded channel to handle buffering and flushing of messages - // to the websocket... - let (tx, rx) = mpsc::unbounded_channel(); - let rx = UnboundedReceiverStream::new(rx); - let fut = rx.forward(user_ws_tx).map(|result| { + // Use an unbounded channel to handle buffering and flushing of messages to the websocket... + let (tx_channel, tx_channel_rx) = mpsc::unbounded_channel(); + let transmit = UnboundedReceiverStream::new(tx_channel_rx); + let fut_handle_tx_buffer = transmit.forward(socket_tx).map(|result| { if let Err(e) = result { tracing::error!(error = ?e, "websocket send error"); } }); - tokio::task::spawn(fut); + tokio::task::spawn(fut_handle_tx_buffer); + + let mut store = STORE.lock().await; + let session = store + .sessions + .get_mut(&session_id) + .expect("Unable to get session"); + session.tx = Some(tx_channel); + + session.send_message(OutgoingMessage::State { + id: session_id, + session: session.clone(), + }); + drop(store); // Handle incoming messages let fut = async move { - let mut store = STORE.lock().await; - // let session = store - // .sessions - // .get_mut(&session_id) - // .expect("Unable to get session"); - // tx.send(Ok(Message::ping("1"))) - // .expect("Unable to send message"); - // session.tx = Some(tx); - drop(store); - tracing::info!( "WebSocket connection established for session_id: {}", session_id ); - while let Some(result) = user_ws_rx.next().await { + while let Some(result) = socket_rx.next().await { let msg = match result { Ok(msg) => msg, Err(error) => { diff --git a/src/models.rs b/src/models.rs index 5710325..ed31fca 100644 --- a/src/models.rs +++ b/src/models.rs @@ -45,6 +45,20 @@ impl Session { self.downloads.push(download); return self.downloads.last().unwrap(); } + + pub fn send_message(&mut self, message: OutgoingMessage) { + // TODO: Error handling, check tx exists + + let result = self + .tx + .as_ref() + .unwrap() + .send(Ok(Message::text(serde_json::to_string(&message).unwrap()))); + + if let Err(e) = result { + tracing::error!("Failed to initial session state: {}", e); + } + } } #[derive(Serialize, Debug, Clone)] @@ -166,12 +180,12 @@ pub enum IncomingMessage { } #[derive(Debug, Serialize)] -#[serde(tag = "type")] +#[serde(tag = "type", rename_all = "lowercase")] pub enum OutgoingMessage { // An alert to the client that a session download has been used. TokenAlert { token: u64 }, // A message describing the current session state - State { session: Session }, + State { session: Session, id: usize }, Executables { executables: Vec }, }