Skip to main content

karyon_net/layers/ws.rs

1use std::io;
2
3use async_tungstenite::tungstenite::Message as TungMessage;
4use async_tungstenite::{WebSocketReceiver, WebSocketSender, WebSocketStream};
5use futures_util::StreamExt;
6
7#[cfg(feature = "tokio")]
8use async_tungstenite::tokio::{accept_async, client_async, TokioAdapter};
9#[cfg(feature = "smol")]
10use async_tungstenite::{accept_async, client_async};
11
12use crate::{
13    codec::Codec,
14    layer::{ClientLayer, ServerLayer},
15    message::{MessageRx, MessageTx},
16    ByteStream, Endpoint, Error, Result,
17};
18
/// Transport type wrapped by the WebSocket stream when built with `smol`:
/// the boxed byte stream is used as-is.
#[cfg(feature = "smol")]
type WsInner = Box<dyn ByteStream>;
/// Transport type wrapped by the WebSocket stream when built with `tokio`:
/// the boxed byte stream is wrapped in `async_tungstenite::tokio::TokioAdapter`.
#[cfg(feature = "tokio")]
type WsInner = TokioAdapter<Box<dyn ByteStream>>;
23
/// WebSocket message types. Wraps the underlying WS protocol
/// message kinds without exposing the third-party library.
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly,
/// e.g. inside codec implementations and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping control message with its payload.
    Ping(Vec<u8>),
    /// A pong control message with its payload.
    Pong(Vec<u8>),
    /// A close message. No close code or reason is carried.
    Close,
}

impl Message {
    /// Consume the message and return its payload as bytes.
    ///
    /// `Text` yields the UTF-8 bytes of its string. `Binary`, `Ping` and
    /// `Pong` yield their payload unchanged. `Close` yields an empty vector.
    pub fn into_bytes(self) -> Vec<u8> {
        match self {
            Message::Text(s) => s.into_bytes(),
            Message::Binary(b) | Message::Ping(b) | Message::Pong(b) => b,
            Message::Close => Vec::new(),
        }
    }

    /// Returns `true` if this is a [`Message::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, Message::Text(_))
    }

    /// Returns `true` if this is a [`Message::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, Message::Binary(_))
    }

    /// Returns `true` if this is a [`Message::Ping`].
    pub fn is_ping(&self) -> bool {
        matches!(self, Message::Ping(_))
    }

    /// Returns `true` if this is a [`Message::Pong`].
    pub fn is_pong(&self) -> bool {
        matches!(self, Message::Pong(_))
    }

    /// Returns `true` if this is a [`Message::Close`].
    pub fn is_close(&self) -> bool {
        matches!(self, Message::Close)
    }
}
67
/// WebSocket layer that upgrades a raw byte stream into a `WsConn`.
///
/// WS messages are encoded and decoded directly by a `Codec<Message>`.
#[derive(Clone)]
pub struct WsLayer<C> {
    /// Target URL for the client handshake; `None` for server layers.
    url: Option<String>,
    /// Codec cloned into every connection this layer produces.
    codec: C,
}
76
77impl<C> WsLayer<C>
78where
79    C: Codec<Message> + Clone,
80{
81    /// Create a client WS layer with the target URL and codec.
82    pub fn client(url: &str, codec: C) -> Self {
83        Self {
84            url: Some(url.to_string()),
85            codec,
86        }
87    }
88
89    /// Create a server WS layer with codec.
90    pub fn server(codec: C) -> Self {
91        Self { url: None, codec }
92    }
93}
94
95impl<C> ClientLayer<Box<dyn ByteStream>, WsConn<C>> for WsLayer<C>
96where
97    C: Codec<Message> + Clone + Send + Sync + 'static,
98    C::Message: Send + Sync + 'static,
99    C::Error: From<io::Error> + Into<Error> + Send + Sync,
100{
101    async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<WsConn<C>> {
102        let url = self
103            .url
104            .as_ref()
105            .ok_or_else(|| Error::MissingConfig("WS client layer requires a URL".into()))?;
106
107        let peer = stream.peer_endpoint();
108        let local = stream.local_endpoint();
109
110        let (ws, _) = client_async(url.as_str(), stream)
111            .await
112            .map_err(|e| Error::IO(io::Error::other(e)))?;
113
114        Ok(WsConn::new(ws, self.codec.clone(), peer, local))
115    }
116}
117
118impl<C> ServerLayer<Box<dyn ByteStream>, WsConn<C>> for WsLayer<C>
119where
120    C: Codec<Message> + Clone + Send + Sync + 'static,
121    C::Message: Send + Sync + 'static,
122    C::Error: From<io::Error> + Into<Error> + Send + Sync,
123{
124    async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<WsConn<C>> {
125        let peer = stream.peer_endpoint();
126        let local = stream.local_endpoint();
127
128        let ws = accept_async(stream)
129            .await
130            .map_err(|e| Error::IO(io::Error::other(e)))?;
131
132        Ok(WsConn::new(ws, self.codec.clone(), peer, local))
133    }
134}
135
136// -- Convert between ws::Message and tungstenite Message --
137
138fn from_tung(msg: TungMessage) -> Message {
139    match msg {
140        TungMessage::Text(t) => Message::Text(t.to_string()),
141        TungMessage::Binary(b) => Message::Binary(b.to_vec()),
142        TungMessage::Ping(b) => Message::Ping(b.to_vec()),
143        TungMessage::Pong(b) => Message::Pong(b.to_vec()),
144        TungMessage::Close(_) => Message::Close,
145        _ => Message::Close,
146    }
147}
148
149fn into_tung(msg: Message) -> TungMessage {
150    match msg {
151        Message::Text(s) => TungMessage::Text(s.into()),
152        Message::Binary(b) => TungMessage::Binary(b.into()),
153        Message::Ping(b) => TungMessage::Ping(b.into()),
154        Message::Pong(b) => TungMessage::Pong(b.into()),
155        Message::Close => TungMessage::Close(None),
156    }
157}
158
159/// WebSocket message connection.
160pub struct WsConn<C> {
161    reader: WsReader<C>,
162    writer: WsWriter<C>,
163}
164
165impl<C: Clone> WsConn<C> {
166    fn new(
167        ws: WebSocketStream<WsInner>,
168        codec: C,
169        peer_endpoint: Option<Endpoint>,
170        local_endpoint: Option<Endpoint>,
171    ) -> Self {
172        let (sender, receiver) = ws.split();
173        Self {
174            reader: WsReader {
175                receiver,
176                codec: codec.clone(),
177                peer_endpoint: peer_endpoint.clone(),
178                local_endpoint: local_endpoint.clone(),
179            },
180            writer: WsWriter {
181                sender,
182                codec,
183                peer_endpoint,
184                local_endpoint,
185            },
186        }
187    }
188}
189
190impl<C> WsConn<C>
191where
192    C: Codec<Message> + Clone + Send + Sync + 'static,
193    C::Message: Send + Sync + 'static,
194    C::Error: From<io::Error> + Into<Error> + Send + Sync,
195{
196    /// Receive one complete message.
197    pub async fn recv_msg(&mut self) -> Result<C::Message> {
198        self.reader.recv_msg().await
199    }
200
201    /// Send one complete message.
202    pub async fn send_msg(&mut self, msg: C::Message) -> Result<()> {
203        self.writer.send_msg(msg).await
204    }
205
206    /// Remote peer address.
207    pub fn peer_endpoint(&self) -> Option<Endpoint> {
208        self.reader.peer_endpoint.clone()
209    }
210
211    /// Local address.
212    pub fn local_endpoint(&self) -> Option<Endpoint> {
213        self.reader.local_endpoint.clone()
214    }
215
216    /// Split into independent reader and writer halves.
217    pub fn split(self) -> (WsReader<C>, WsWriter<C>) {
218        (self.reader, self.writer)
219    }
220}
221
222/// Read half of a WebSocket connection.
223pub struct WsReader<C> {
224    receiver: WebSocketReceiver<WsInner>,
225    codec: C,
226    peer_endpoint: Option<Endpoint>,
227    local_endpoint: Option<Endpoint>,
228}
229
230impl<C> WsReader<C>
231where
232    C: Codec<Message> + Send + Sync,
233    C::Message: Send + Sync,
234    C::Error: From<io::Error> + Into<Error> + Send + Sync,
235{
236    /// Receive one complete message.
237    pub async fn recv_msg(&mut self) -> Result<C::Message> {
238        loop {
239            let raw = self
240                .receiver
241                .next()
242                .await
243                .ok_or(Error::ConnectionClosed)?
244                .map_err(|e| Error::IO(io::Error::other(e)))?;
245
246            let mut ws_msg = from_tung(raw);
247            match self.codec.decode(&mut ws_msg).map_err(Into::into)? {
248                Some((_, item)) => return Ok(item),
249                None => continue,
250            }
251        }
252    }
253
254    /// Remote peer address.
255    pub fn peer_endpoint(&self) -> Option<Endpoint> {
256        self.peer_endpoint.clone()
257    }
258
259    /// Local address.
260    pub fn local_endpoint(&self) -> Option<Endpoint> {
261        self.local_endpoint.clone()
262    }
263}
264
265impl<C> MessageRx for WsReader<C>
266where
267    C: Codec<Message> + Send + Sync,
268    C::Message: Send + Sync,
269    C::Error: From<io::Error> + Into<Error> + Send + Sync,
270{
271    type Message = C::Message;
272
273    fn recv_msg(&mut self) -> impl std::future::Future<Output = Result<Self::Message>> + Send {
274        WsReader::recv_msg(self)
275    }
276
277    fn peer_endpoint(&self) -> Option<Endpoint> {
278        WsReader::peer_endpoint(self)
279    }
280}
281
282/// Write half of a WebSocket connection.
283pub struct WsWriter<C> {
284    sender: WebSocketSender<WsInner>,
285    codec: C,
286    peer_endpoint: Option<Endpoint>,
287    local_endpoint: Option<Endpoint>,
288}
289
290impl<C> WsWriter<C>
291where
292    C: Codec<Message> + Send + Sync,
293    C::Message: Send + Sync,
294    C::Error: From<io::Error> + Into<Error> + Send + Sync,
295{
296    /// Send one complete message.
297    pub async fn send_msg(&mut self, msg: C::Message) -> Result<()> {
298        let mut ws_msg = Message::Binary(Vec::new());
299        self.codec.encode(&msg, &mut ws_msg).map_err(Into::into)?;
300
301        self.sender
302            .send(into_tung(ws_msg))
303            .await
304            .map_err(|e| Error::IO(io::Error::other(e)))?;
305        Ok(())
306    }
307
308    /// Remote peer address.
309    pub fn peer_endpoint(&self) -> Option<Endpoint> {
310        self.peer_endpoint.clone()
311    }
312
313    /// Local address.
314    pub fn local_endpoint(&self) -> Option<Endpoint> {
315        self.local_endpoint.clone()
316    }
317}
318
319impl<C> MessageTx for WsWriter<C>
320where
321    C: Codec<Message> + Send + Sync,
322    C::Message: Send + Sync,
323    C::Error: From<io::Error> + Into<Error> + Send + Sync,
324{
325    type Message = C::Message;
326
327    fn send_msg(
328        &mut self,
329        msg: Self::Message,
330    ) -> impl std::future::Future<Output = Result<()>> + Send {
331        WsWriter::send_msg(self, msg)
332    }
333
334    fn peer_endpoint(&self) -> Option<Endpoint> {
335        WsWriter::peer_endpoint(self)
336    }
337}