byteor_adapters/
websocket.rs

1use std::net::{SocketAddr, TcpListener, TcpStream};
2use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
3use std::sync::mpsc::{self, Receiver, Sender};
4use std::sync::Arc;
5use std::thread::JoinHandle;
6use std::time::Duration;
7
8use futures_util::SinkExt;
9use tokio::runtime::Runtime;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest;
11use tokio_tungstenite::tungstenite::Message;
12
13use crate::{Adapter, AdapterError, EgressAdapter, IngressAdapter};
14
/// Egress state code: a connection attempt is in progress.
const WEBSOCKET_EGRESS_CONNECTING: u8 = 1;
/// Egress state code: the handshake completed and the payload is being sent.
const WEBSOCKET_EGRESS_CONNECTED: u8 = 2;
/// Egress state code: the last payload was sent and its connection closed.
const WEBSOCKET_EGRESS_CLOSED: u8 = 3;
/// Egress state code: the most recent attempt failed.
const WEBSOCKET_EGRESS_ERROR: u8 = 4;

/// Translate a raw egress state code into the label reported by `connection_state` and
/// `status_fields`.
///
/// Any code that is not one of the `WEBSOCKET_EGRESS_*` constants (notably 0, the value a
/// default-initialised `AtomicU8` starts with) is reported as `"idle"`.
fn websocket_egress_state_label(state: u8) -> &'static str {
    const LABELS: [(u8, &str); 4] = [
        (WEBSOCKET_EGRESS_CONNECTING, "connecting"),
        (WEBSOCKET_EGRESS_CONNECTED, "connected"),
        (WEBSOCKET_EGRESS_CLOSED, "closed"),
        (WEBSOCKET_EGRESS_ERROR, "error"),
    ];
    LABELS
        .iter()
        .find(|&&(code, _)| code == state)
        .map_or("idle", |&(_, label)| label)
}
29
30#[derive(Default)]
31struct WebSocketIngressStatus {
32    active_sessions: AtomicU64,
33    total_sessions: AtomicU64,
34}
35
36impl WebSocketIngressStatus {
37    fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
38        let active_sessions = self.active_sessions.load(Ordering::Relaxed);
39        let mut details = serde_json::Map::new();
40        details.insert(
41            "socket_state".to_string(),
42            serde_json::Value::String("listening".to_string()),
43        );
44        details.insert(
45            "session_state".to_string(),
46            serde_json::Value::String(if active_sessions > 0 { "active" } else { "idle" }.to_string()),
47        );
48        details.insert(
49            "active_sessions".to_string(),
50            serde_json::Value::from(active_sessions),
51        );
52        details.insert(
53            "total_sessions".to_string(),
54            serde_json::Value::from(self.total_sessions.load(Ordering::Relaxed)),
55        );
56        details
57    }
58}
59
60#[derive(Default)]
61struct WebSocketEgressStatus {
62    state: AtomicU8,
63    last_message_bytes: AtomicU64,
64}
65
66impl WebSocketEgressStatus {
67    fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
68        let state = websocket_egress_state_label(self.state.load(Ordering::Relaxed));
69        let mut details = serde_json::Map::new();
70        details.insert(
71            "socket_state".to_string(),
72            serde_json::Value::String(state.to_string()),
73        );
74        details.insert(
75            "session_state".to_string(),
76            serde_json::Value::String(state.to_string()),
77        );
78        details.insert(
79            "last_message_bytes".to_string(),
80            serde_json::Value::from(self.last_message_bytes.load(Ordering::Relaxed)),
81        );
82        details
83    }
84}
85
86fn recv_websocket_connection(
87    stream: TcpStream,
88    tx: Sender<Vec<u8>>,
89    max_message_bytes: usize,
90    status: Arc<WebSocketIngressStatus>,
91) {
92    let mut websocket = match tokio_tungstenite::tungstenite::accept(stream) {
93        Ok(websocket) => websocket,
94        Err(_) => return,
95    };
96
97    status.active_sessions.fetch_add(1, Ordering::Relaxed);
98    status.total_sessions.fetch_add(1, Ordering::Relaxed);
99
100    loop {
101        match websocket.read() {
102            Ok(Message::Binary(payload)) => {
103                if payload.len() > max_message_bytes {
104                    let _ = websocket.close(None);
105                    break;
106                }
107                if tx.send(payload).is_err() {
108                    break;
109                }
110            }
111            Ok(Message::Text(payload)) => {
112                let payload = payload.to_string().into_bytes();
113                if payload.len() > max_message_bytes {
114                    let _ = websocket.close(None);
115                    break;
116                }
117                if tx.send(payload).is_err() {
118                    break;
119                }
120            }
121            Ok(Message::Ping(payload)) => {
122                if websocket.send(Message::Pong(payload)).is_err() {
123                    break;
124                }
125            }
126            Ok(Message::Close(_)) => break,
127            Ok(Message::Frame(_)) => {}
128            Err(_) => break,
129            Ok(_) => {}
130        }
131    }
132
133    status.active_sessions.fetch_sub(1, Ordering::Relaxed);
134}
135
136/// Minimal blocking WebSocket ingress adapter.
137///
138/// This adapter binds a WebSocket listener and surfaces each binary or text frame as one ingress
139/// message.
140pub struct WebSocketIngress {
141    name: String,
142    addr: SocketAddr,
143    path: String,
144    status: Arc<WebSocketIngressStatus>,
145    rx: Receiver<Vec<u8>>,
146    shutdown: Option<mpsc::Sender<()>>,
147    thread: Option<JoinHandle<()>>,
148}
149
150impl WebSocketIngress {
151    /// Bind a blocking WebSocket ingress adapter.
152    pub fn bind(
153        name: impl Into<String>,
154        bind: &str,
155        path: impl Into<String>,
156        max_message_bytes: usize,
157    ) -> Result<Self, AdapterError> {
158        let listener = TcpListener::bind(bind).map_err(|source| AdapterError::Io {
159            op: "websocket_bind",
160            source,
161        })?;
162        listener
163            .set_nonblocking(true)
164            .map_err(|source| AdapterError::Io {
165                op: "websocket_set_nonblocking",
166                source,
167            })?;
168        let addr = listener.local_addr().map_err(|source| AdapterError::Io {
169            op: "websocket_local_addr",
170            source,
171        })?;
172
173        let path = path.into();
174        if !path.starts_with('/') {
175            return Err(AdapterError::Config {
176                detail: "websocket ingress path must start with '/'".into(),
177            });
178        }
179
180        let (tx, rx) = mpsc::channel();
181        let (shutdown_tx, shutdown_rx) = mpsc::channel();
182        let status = Arc::new(WebSocketIngressStatus::default());
183        let thread_status = Arc::clone(&status);
184        let thread = std::thread::spawn(move || loop {
185            if shutdown_rx.try_recv().is_ok() {
186                break;
187            }
188
189            match listener.accept() {
190                Ok((stream, _)) => {
191                    let tx = tx.clone();
192                    let connection_status = Arc::clone(&thread_status);
193                    std::thread::spawn(move || {
194                        recv_websocket_connection(stream, tx, max_message_bytes, connection_status)
195                    });
196                }
197                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
198                    std::thread::sleep(Duration::from_millis(10));
199                }
200                Err(_) => break,
201            }
202        });
203
204        Ok(Self {
205            name: name.into(),
206            addr,
207            path,
208            status,
209            rx,
210            shutdown: Some(shutdown_tx),
211            thread: Some(thread),
212        })
213    }
214
215    /// Return the bound socket address.
216    pub fn local_addr(&self) -> SocketAddr {
217        self.addr
218    }
219
220    /// Return a client endpoint URL for the adapter.
221    pub fn endpoint(&self) -> String {
222        format!("ws://{}{}", self.addr, self.path)
223    }
224}
225
226impl Drop for WebSocketIngress {
227    fn drop(&mut self) {
228        if let Some(shutdown) = self.shutdown.take() {
229            let _ = shutdown.send(());
230        }
231        if let Some(thread) = self.thread.take() {
232            let _ = thread.join();
233        }
234    }
235}
236
237impl Adapter for WebSocketIngress {
238    fn name(&self) -> &str {
239        &self.name
240    }
241
242    fn transport_kind(&self) -> &str {
243        "websocket"
244    }
245
246    fn connection_state(&self) -> &str {
247        if self.status.active_sessions.load(Ordering::Relaxed) > 0 {
248            "connected"
249        } else {
250            "ready"
251        }
252    }
253
254    fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
255        self.status.status_fields()
256    }
257}
258
259impl IngressAdapter for WebSocketIngress {
260    fn read_next(&mut self, out: &mut [u8]) -> Result<Option<usize>, AdapterError> {
261        let payload = match self.rx.recv() {
262            Ok(payload) => payload,
263            Err(_) => return Ok(None),
264        };
265        if payload.len() > out.len() {
266            return Err(AdapterError::WebSocket {
267                op: "websocket_read_next",
268                detail: format!("payload too large len={} buf={}", payload.len(), out.len()),
269            });
270        }
271        out[..payload.len()].copy_from_slice(&payload);
272        Ok(Some(payload.len()))
273    }
274}
275
276/// Minimal blocking WebSocket egress adapter.
277///
278/// Each `write_msg()` opens a WebSocket connection, sends one binary frame, and then closes the
279/// connection.
280pub struct WebSocketEgress {
281    name: String,
282    url: String,
283    headers: Vec<(String, String)>,
284    connect_timeout: Duration,
285    max_retries: usize,
286    runtime: Runtime,
287    status: Arc<WebSocketEgressStatus>,
288}
289
290impl WebSocketEgress {
291    /// Create a blocking WebSocket egress adapter.
292    pub fn new(
293        name: impl Into<String>,
294        url: impl Into<String>,
295        headers: Vec<(String, String)>,
296        connect_timeout: Duration,
297        max_retries: usize,
298    ) -> Result<Self, AdapterError> {
299        let url = url.into();
300        if !(url.starts_with("ws://") || url.starts_with("wss://")) {
301            return Err(AdapterError::Config {
302                detail: "websocket egress url must start with ws:// or wss://".into(),
303            });
304        }
305        let runtime = tokio::runtime::Builder::new_current_thread()
306            .enable_all()
307            .build()
308            .map_err(|err| AdapterError::WebSocket {
309                op: "websocket_runtime_build",
310                detail: err.to_string(),
311            })?;
312        Ok(Self {
313            name: name.into(),
314            url,
315            headers,
316            connect_timeout,
317            max_retries,
318            runtime,
319            status: Arc::new(WebSocketEgressStatus::default()),
320        })
321    }
322}
323
324impl Adapter for WebSocketEgress {
325    fn name(&self) -> &str {
326        &self.name
327    }
328
329    fn transport_kind(&self) -> &str {
330        "websocket"
331    }
332
333    fn connection_state(&self) -> &str {
334        websocket_egress_state_label(self.status.state.load(Ordering::Relaxed))
335    }
336
337    fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
338        self.status.status_fields()
339    }
340}
341
342impl EgressAdapter for WebSocketEgress {
343    fn write_msg(&mut self, msg: &[u8]) -> Result<(), AdapterError> {
344        let mut attempt = 0usize;
345        loop {
346            attempt = attempt.saturating_add(1);
347
348            let url = self.url.clone();
349            let headers = self.headers.clone();
350            let connect_timeout = self.connect_timeout;
351            let payload = msg.to_vec();
352            let status = Arc::clone(&self.status);
353
354            status.state.store(WEBSOCKET_EGRESS_CONNECTING, Ordering::Relaxed);
355            status
356                .last_message_bytes
357                .store(payload.len() as u64, Ordering::Relaxed);
358
359            let result = self.runtime.block_on(async move {
360                let mut request =
361                    url.into_client_request()
362                        .map_err(|err| AdapterError::WebSocket {
363                            op: "websocket_request_build",
364                            detail: err.to_string(),
365                        })?;
366                for (name, value) in headers {
367                    let header_name =
368                        tokio_tungstenite::tungstenite::http::header::HeaderName::from_bytes(
369                            name.as_bytes(),
370                        )
371                        .map_err(|err| AdapterError::WebSocket {
372                            op: "websocket_header_name",
373                            detail: err.to_string(),
374                        })?;
375                    let header_value =
376                        tokio_tungstenite::tungstenite::http::HeaderValue::from_str(&value)
377                            .map_err(|err| AdapterError::WebSocket {
378                                op: "websocket_header_value",
379                                detail: err.to_string(),
380                            })?;
381                    request.headers_mut().append(header_name, header_value);
382                }
383
384                let (mut websocket, _) = tokio::time::timeout(
385                    connect_timeout,
386                    tokio_tungstenite::connect_async(request),
387                )
388                .await
389                .map_err(|_| AdapterError::WebSocket {
390                    op: "websocket_connect",
391                    detail: format!("timed out after {} ms", connect_timeout.as_millis()),
392                })?
393                .map_err(|err| AdapterError::WebSocket {
394                    op: "websocket_connect",
395                    detail: err.to_string(),
396                })?;
397
398                status
399                    .state
400                    .store(WEBSOCKET_EGRESS_CONNECTED, Ordering::Relaxed);
401
402                websocket
403                    .send(Message::Binary(payload))
404                    .await
405                    .map_err(|err| AdapterError::WebSocket {
406                        op: "websocket_send",
407                        detail: err.to_string(),
408                    })?;
409                let _ = websocket.close(None).await;
410                status
411                    .state
412                    .store(WEBSOCKET_EGRESS_CLOSED, Ordering::Relaxed);
413                Ok(())
414            });
415
416            match result {
417                Ok(()) => return Ok(()),
418                Err(err) if attempt <= self.max_retries => {
419                    self.status
420                        .state
421                        .store(WEBSOCKET_EGRESS_ERROR, Ordering::Relaxed);
422                    continue;
423                }
424                Err(err) => {
425                    self.status
426                        .state
427                        .store(WEBSOCKET_EGRESS_ERROR, Ordering::Relaxed);
428                    return Err(err);
429                }
430            }
431        }
432    }
433
434    fn flush(&mut self) -> Result<(), AdapterError> {
435        Ok(())
436    }
437}