1use std::net::{SocketAddr, TcpListener, TcpStream};
2use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
3use std::sync::mpsc::{self, Receiver, Sender};
4use std::sync::Arc;
5use std::thread::JoinHandle;
6use std::time::Duration;
7
8use futures_util::SinkExt;
9use tokio::runtime::Runtime;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest;
11use tokio_tungstenite::tungstenite::Message;
12
13use crate::{Adapter, AdapterError, EgressAdapter, IngressAdapter};
14
// Numeric codes kept in the egress status `state` atomic. No constant is
// defined for the initial value; any code not listed here is labelled "idle".
const WEBSOCKET_EGRESS_CONNECTING: u8 = 1;
const WEBSOCKET_EGRESS_CONNECTED: u8 = 2;
const WEBSOCKET_EGRESS_CLOSED: u8 = 3;
const WEBSOCKET_EGRESS_ERROR: u8 = 4;

/// Translates an egress state code into its human-readable label.
///
/// Codes without a matching `WEBSOCKET_EGRESS_*` constant map to `"idle"`.
fn websocket_egress_state_label(state: u8) -> &'static str {
    // Code/label pairs; looked up linearly since the table is tiny.
    const LABELS: [(u8, &str); 4] = [
        (WEBSOCKET_EGRESS_CONNECTING, "connecting"),
        (WEBSOCKET_EGRESS_CONNECTED, "connected"),
        (WEBSOCKET_EGRESS_CLOSED, "closed"),
        (WEBSOCKET_EGRESS_ERROR, "error"),
    ];

    LABELS
        .iter()
        .find(|(code, _)| *code == state)
        .map_or("idle", |(_, label)| *label)
}
29
30#[derive(Default)]
31struct WebSocketIngressStatus {
32 active_sessions: AtomicU64,
33 total_sessions: AtomicU64,
34}
35
36impl WebSocketIngressStatus {
37 fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
38 let active_sessions = self.active_sessions.load(Ordering::Relaxed);
39 let mut details = serde_json::Map::new();
40 details.insert(
41 "socket_state".to_string(),
42 serde_json::Value::String("listening".to_string()),
43 );
44 details.insert(
45 "session_state".to_string(),
46 serde_json::Value::String(if active_sessions > 0 { "active" } else { "idle" }.to_string()),
47 );
48 details.insert(
49 "active_sessions".to_string(),
50 serde_json::Value::from(active_sessions),
51 );
52 details.insert(
53 "total_sessions".to_string(),
54 serde_json::Value::from(self.total_sessions.load(Ordering::Relaxed)),
55 );
56 details
57 }
58}
59
60#[derive(Default)]
61struct WebSocketEgressStatus {
62 state: AtomicU8,
63 last_message_bytes: AtomicU64,
64}
65
66impl WebSocketEgressStatus {
67 fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
68 let state = websocket_egress_state_label(self.state.load(Ordering::Relaxed));
69 let mut details = serde_json::Map::new();
70 details.insert(
71 "socket_state".to_string(),
72 serde_json::Value::String(state.to_string()),
73 );
74 details.insert(
75 "session_state".to_string(),
76 serde_json::Value::String(state.to_string()),
77 );
78 details.insert(
79 "last_message_bytes".to_string(),
80 serde_json::Value::from(self.last_message_bytes.load(Ordering::Relaxed)),
81 );
82 details
83 }
84}
85
86fn recv_websocket_connection(
87 stream: TcpStream,
88 tx: Sender<Vec<u8>>,
89 max_message_bytes: usize,
90 status: Arc<WebSocketIngressStatus>,
91) {
92 let mut websocket = match tokio_tungstenite::tungstenite::accept(stream) {
93 Ok(websocket) => websocket,
94 Err(_) => return,
95 };
96
97 status.active_sessions.fetch_add(1, Ordering::Relaxed);
98 status.total_sessions.fetch_add(1, Ordering::Relaxed);
99
100 loop {
101 match websocket.read() {
102 Ok(Message::Binary(payload)) => {
103 if payload.len() > max_message_bytes {
104 let _ = websocket.close(None);
105 break;
106 }
107 if tx.send(payload).is_err() {
108 break;
109 }
110 }
111 Ok(Message::Text(payload)) => {
112 let payload = payload.to_string().into_bytes();
113 if payload.len() > max_message_bytes {
114 let _ = websocket.close(None);
115 break;
116 }
117 if tx.send(payload).is_err() {
118 break;
119 }
120 }
121 Ok(Message::Ping(payload)) => {
122 if websocket.send(Message::Pong(payload)).is_err() {
123 break;
124 }
125 }
126 Ok(Message::Close(_)) => break,
127 Ok(Message::Frame(_)) => {}
128 Err(_) => break,
129 Ok(_) => {}
130 }
131 }
132
133 status.active_sessions.fetch_sub(1, Ordering::Relaxed);
134}
135
/// WebSocket ingress adapter backed by a background listener thread.
///
/// Payloads received by connection threads arrive through `rx`; dropping the
/// value signals the listener thread to stop and joins it.
pub struct WebSocketIngress {
    /// Adapter name returned by `Adapter::name`.
    name: String,
    /// Local address the listener is bound to.
    addr: SocketAddr,
    /// URL path used when formatting `endpoint()`.
    path: String,
    /// Session counters shared with the listener and connection threads.
    status: Arc<WebSocketIngressStatus>,
    /// Receiving end of the channel that connection threads send payloads into.
    rx: Receiver<Vec<u8>>,
    /// Sender used on drop to tell the listener thread to exit.
    shutdown: Option<mpsc::Sender<()>>,
    /// Handle of the listener thread, joined on drop.
    thread: Option<JoinHandle<()>>,
}
149
150impl WebSocketIngress {
151 pub fn bind(
153 name: impl Into<String>,
154 bind: &str,
155 path: impl Into<String>,
156 max_message_bytes: usize,
157 ) -> Result<Self, AdapterError> {
158 let listener = TcpListener::bind(bind).map_err(|source| AdapterError::Io {
159 op: "websocket_bind",
160 source,
161 })?;
162 listener
163 .set_nonblocking(true)
164 .map_err(|source| AdapterError::Io {
165 op: "websocket_set_nonblocking",
166 source,
167 })?;
168 let addr = listener.local_addr().map_err(|source| AdapterError::Io {
169 op: "websocket_local_addr",
170 source,
171 })?;
172
173 let path = path.into();
174 if !path.starts_with('/') {
175 return Err(AdapterError::Config {
176 detail: "websocket ingress path must start with '/'".into(),
177 });
178 }
179
180 let (tx, rx) = mpsc::channel();
181 let (shutdown_tx, shutdown_rx) = mpsc::channel();
182 let status = Arc::new(WebSocketIngressStatus::default());
183 let thread_status = Arc::clone(&status);
184 let thread = std::thread::spawn(move || loop {
185 if shutdown_rx.try_recv().is_ok() {
186 break;
187 }
188
189 match listener.accept() {
190 Ok((stream, _)) => {
191 let tx = tx.clone();
192 let connection_status = Arc::clone(&thread_status);
193 std::thread::spawn(move || {
194 recv_websocket_connection(stream, tx, max_message_bytes, connection_status)
195 });
196 }
197 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
198 std::thread::sleep(Duration::from_millis(10));
199 }
200 Err(_) => break,
201 }
202 });
203
204 Ok(Self {
205 name: name.into(),
206 addr,
207 path,
208 status,
209 rx,
210 shutdown: Some(shutdown_tx),
211 thread: Some(thread),
212 })
213 }
214
215 pub fn local_addr(&self) -> SocketAddr {
217 self.addr
218 }
219
220 pub fn endpoint(&self) -> String {
222 format!("ws://{}{}", self.addr, self.path)
223 }
224}
225
226impl Drop for WebSocketIngress {
227 fn drop(&mut self) {
228 if let Some(shutdown) = self.shutdown.take() {
229 let _ = shutdown.send(());
230 }
231 if let Some(thread) = self.thread.take() {
232 let _ = thread.join();
233 }
234 }
235}
236
237impl Adapter for WebSocketIngress {
238 fn name(&self) -> &str {
239 &self.name
240 }
241
242 fn transport_kind(&self) -> &str {
243 "websocket"
244 }
245
246 fn connection_state(&self) -> &str {
247 if self.status.active_sessions.load(Ordering::Relaxed) > 0 {
248 "connected"
249 } else {
250 "ready"
251 }
252 }
253
254 fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
255 self.status.status_fields()
256 }
257}
258
259impl IngressAdapter for WebSocketIngress {
260 fn read_next(&mut self, out: &mut [u8]) -> Result<Option<usize>, AdapterError> {
261 let payload = match self.rx.recv() {
262 Ok(payload) => payload,
263 Err(_) => return Ok(None),
264 };
265 if payload.len() > out.len() {
266 return Err(AdapterError::WebSocket {
267 op: "websocket_read_next",
268 detail: format!("payload too large len={} buf={}", payload.len(), out.len()),
269 });
270 }
271 out[..payload.len()].copy_from_slice(&payload);
272 Ok(Some(payload.len()))
273 }
274}
275
/// WebSocket egress adapter that delivers each message over a fresh connection.
pub struct WebSocketEgress {
    /// Adapter name returned by `Adapter::name`.
    name: String,
    /// Target URL; checked at construction to start with `ws://` or `wss://`.
    url: String,
    /// Extra `(name, value)` headers appended to the handshake request.
    headers: Vec<(String, String)>,
    /// Upper bound on how long a single connection attempt may take.
    connect_timeout: Duration,
    /// Number of additional attempts made after a failed `write_msg` attempt.
    max_retries: usize,
    /// Current-thread Tokio runtime used to drive the async client.
    runtime: Runtime,
    /// Connection state and last payload size reported through `Adapter`.
    status: Arc<WebSocketEgressStatus>,
}
289
290impl WebSocketEgress {
291 pub fn new(
293 name: impl Into<String>,
294 url: impl Into<String>,
295 headers: Vec<(String, String)>,
296 connect_timeout: Duration,
297 max_retries: usize,
298 ) -> Result<Self, AdapterError> {
299 let url = url.into();
300 if !(url.starts_with("ws://") || url.starts_with("wss://")) {
301 return Err(AdapterError::Config {
302 detail: "websocket egress url must start with ws:// or wss://".into(),
303 });
304 }
305 let runtime = tokio::runtime::Builder::new_current_thread()
306 .enable_all()
307 .build()
308 .map_err(|err| AdapterError::WebSocket {
309 op: "websocket_runtime_build",
310 detail: err.to_string(),
311 })?;
312 Ok(Self {
313 name: name.into(),
314 url,
315 headers,
316 connect_timeout,
317 max_retries,
318 runtime,
319 status: Arc::new(WebSocketEgressStatus::default()),
320 })
321 }
322}
323
324impl Adapter for WebSocketEgress {
325 fn name(&self) -> &str {
326 &self.name
327 }
328
329 fn transport_kind(&self) -> &str {
330 "websocket"
331 }
332
333 fn connection_state(&self) -> &str {
334 websocket_egress_state_label(self.status.state.load(Ordering::Relaxed))
335 }
336
337 fn status_fields(&self) -> serde_json::Map<String, serde_json::Value> {
338 self.status.status_fields()
339 }
340}
341
342impl EgressAdapter for WebSocketEgress {
343 fn write_msg(&mut self, msg: &[u8]) -> Result<(), AdapterError> {
344 let mut attempt = 0usize;
345 loop {
346 attempt = attempt.saturating_add(1);
347
348 let url = self.url.clone();
349 let headers = self.headers.clone();
350 let connect_timeout = self.connect_timeout;
351 let payload = msg.to_vec();
352 let status = Arc::clone(&self.status);
353
354 status.state.store(WEBSOCKET_EGRESS_CONNECTING, Ordering::Relaxed);
355 status
356 .last_message_bytes
357 .store(payload.len() as u64, Ordering::Relaxed);
358
359 let result = self.runtime.block_on(async move {
360 let mut request =
361 url.into_client_request()
362 .map_err(|err| AdapterError::WebSocket {
363 op: "websocket_request_build",
364 detail: err.to_string(),
365 })?;
366 for (name, value) in headers {
367 let header_name =
368 tokio_tungstenite::tungstenite::http::header::HeaderName::from_bytes(
369 name.as_bytes(),
370 )
371 .map_err(|err| AdapterError::WebSocket {
372 op: "websocket_header_name",
373 detail: err.to_string(),
374 })?;
375 let header_value =
376 tokio_tungstenite::tungstenite::http::HeaderValue::from_str(&value)
377 .map_err(|err| AdapterError::WebSocket {
378 op: "websocket_header_value",
379 detail: err.to_string(),
380 })?;
381 request.headers_mut().append(header_name, header_value);
382 }
383
384 let (mut websocket, _) = tokio::time::timeout(
385 connect_timeout,
386 tokio_tungstenite::connect_async(request),
387 )
388 .await
389 .map_err(|_| AdapterError::WebSocket {
390 op: "websocket_connect",
391 detail: format!("timed out after {} ms", connect_timeout.as_millis()),
392 })?
393 .map_err(|err| AdapterError::WebSocket {
394 op: "websocket_connect",
395 detail: err.to_string(),
396 })?;
397
398 status
399 .state
400 .store(WEBSOCKET_EGRESS_CONNECTED, Ordering::Relaxed);
401
402 websocket
403 .send(Message::Binary(payload))
404 .await
405 .map_err(|err| AdapterError::WebSocket {
406 op: "websocket_send",
407 detail: err.to_string(),
408 })?;
409 let _ = websocket.close(None).await;
410 status
411 .state
412 .store(WEBSOCKET_EGRESS_CLOSED, Ordering::Relaxed);
413 Ok(())
414 });
415
416 match result {
417 Ok(()) => return Ok(()),
418 Err(err) if attempt <= self.max_retries => {
419 self.status
420 .state
421 .store(WEBSOCKET_EGRESS_ERROR, Ordering::Relaxed);
422 continue;
423 }
424 Err(err) => {
425 self.status
426 .state
427 .store(WEBSOCKET_EGRESS_ERROR, Ordering::Relaxed);
428 return Err(err);
429 }
430 }
431 }
432 }
433
434 fn flush(&mut self) -> Result<(), AdapterError> {
435 Ok(())
436 }
437}