1use std::io;
2
3use async_tungstenite::tungstenite::Message as TungMessage;
4use async_tungstenite::{WebSocketReceiver, WebSocketSender, WebSocketStream};
5use futures_util::StreamExt;
6
7#[cfg(feature = "tokio")]
8use async_tungstenite::tokio::{accept_async, client_async, TokioAdapter};
9#[cfg(feature = "smol")]
10use async_tungstenite::{accept_async, client_async};
11
12use crate::{
13 codec::Codec,
14 layer::{ClientLayer, ServerLayer},
15 message::{MessageRx, MessageTx},
16 ByteStream, Endpoint, Error, Result,
17};
18
/// Transport wrapped by [`WebSocketStream`] when the `smol` feature is enabled:
/// the boxed byte stream is used directly.
#[cfg(feature = "smol")]
type WsInner = Box<dyn ByteStream>;
/// Transport wrapped by [`WebSocketStream`] when the `tokio` feature is enabled:
/// the boxed byte stream is wrapped in `TokioAdapter`.
///
/// NOTE(review): presumably `TokioAdapter` bridges tokio's I/O traits to the
/// ones async-tungstenite expects — confirm. Exactly one of `tokio`/`smol`
/// appears to be expected: enabling both defines this alias twice, enabling
/// neither leaves it undefined.
#[cfg(feature = "tokio")]
type WsInner = TokioAdapter<Box<dyn ByteStream>>;
23
/// A WebSocket message in the form handed to a codec for decoding and filled
/// in by a codec when encoding.
///
/// This is a simplified view of the protocol: `Close` carries no close code
/// or reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A UTF-8 text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping control frame with its payload.
    Ping(Vec<u8>),
    /// A pong control frame with its payload.
    Pong(Vec<u8>),
    /// A close control frame (code and reason are not represented).
    Close,
}
34
35impl Message {
36 pub fn into_bytes(self) -> Vec<u8> {
38 match self {
39 Message::Text(s) => s.into_bytes(),
40 Message::Binary(b) => b,
41 Message::Ping(b) => b,
42 Message::Pong(b) => b,
43 Message::Close => Vec::new(),
44 }
45 }
46
47 pub fn is_text(&self) -> bool {
48 matches!(self, Message::Text(_))
49 }
50
51 pub fn is_binary(&self) -> bool {
52 matches!(self, Message::Binary(_))
53 }
54
55 pub fn is_ping(&self) -> bool {
56 matches!(self, Message::Ping(_))
57 }
58
59 pub fn is_pong(&self) -> bool {
60 matches!(self, Message::Pong(_))
61 }
62
63 pub fn is_close(&self) -> bool {
64 matches!(self, Message::Close)
65 }
66}
67
68#[derive(Clone)]
72pub struct WsLayer<C> {
73 url: Option<String>,
74 codec: C,
75}
76
77impl<C> WsLayer<C>
78where
79 C: Codec<Message> + Clone,
80{
81 pub fn client(url: &str, codec: C) -> Self {
83 Self {
84 url: Some(url.to_string()),
85 codec,
86 }
87 }
88
89 pub fn server(codec: C) -> Self {
91 Self { url: None, codec }
92 }
93}
94
95impl<C> ClientLayer<Box<dyn ByteStream>, WsConn<C>> for WsLayer<C>
96where
97 C: Codec<Message> + Clone + Send + Sync + 'static,
98 C::Message: Send + Sync + 'static,
99 C::Error: From<io::Error> + Into<Error> + Send + Sync,
100{
101 async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<WsConn<C>> {
102 let url = self
103 .url
104 .as_ref()
105 .ok_or_else(|| Error::MissingConfig("WS client layer requires a URL".into()))?;
106
107 let peer = stream.peer_endpoint();
108 let local = stream.local_endpoint();
109
110 let (ws, _) = client_async(url.as_str(), stream)
111 .await
112 .map_err(|e| Error::IO(io::Error::other(e)))?;
113
114 Ok(WsConn::new(ws, self.codec.clone(), peer, local))
115 }
116}
117
118impl<C> ServerLayer<Box<dyn ByteStream>, WsConn<C>> for WsLayer<C>
119where
120 C: Codec<Message> + Clone + Send + Sync + 'static,
121 C::Message: Send + Sync + 'static,
122 C::Error: From<io::Error> + Into<Error> + Send + Sync,
123{
124 async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<WsConn<C>> {
125 let peer = stream.peer_endpoint();
126 let local = stream.local_endpoint();
127
128 let ws = accept_async(stream)
129 .await
130 .map_err(|e| Error::IO(io::Error::other(e)))?;
131
132 Ok(WsConn::new(ws, self.codec.clone(), peer, local))
133 }
134}
135
136fn from_tung(msg: TungMessage) -> Message {
139 match msg {
140 TungMessage::Text(t) => Message::Text(t.to_string()),
141 TungMessage::Binary(b) => Message::Binary(b.to_vec()),
142 TungMessage::Ping(b) => Message::Ping(b.to_vec()),
143 TungMessage::Pong(b) => Message::Pong(b.to_vec()),
144 TungMessage::Close(_) => Message::Close,
145 _ => Message::Close,
146 }
147}
148
149fn into_tung(msg: Message) -> TungMessage {
150 match msg {
151 Message::Text(s) => TungMessage::Text(s.into()),
152 Message::Binary(b) => TungMessage::Binary(b.into()),
153 Message::Ping(b) => TungMessage::Ping(b.into()),
154 Message::Pong(b) => TungMessage::Pong(b.into()),
155 Message::Close => TungMessage::Close(None),
156 }
157}
158
159pub struct WsConn<C> {
161 reader: WsReader<C>,
162 writer: WsWriter<C>,
163}
164
165impl<C: Clone> WsConn<C> {
166 fn new(
167 ws: WebSocketStream<WsInner>,
168 codec: C,
169 peer_endpoint: Option<Endpoint>,
170 local_endpoint: Option<Endpoint>,
171 ) -> Self {
172 let (sender, receiver) = ws.split();
173 Self {
174 reader: WsReader {
175 receiver,
176 codec: codec.clone(),
177 peer_endpoint: peer_endpoint.clone(),
178 local_endpoint: local_endpoint.clone(),
179 },
180 writer: WsWriter {
181 sender,
182 codec,
183 peer_endpoint,
184 local_endpoint,
185 },
186 }
187 }
188}
189
190impl<C> WsConn<C>
191where
192 C: Codec<Message> + Clone + Send + Sync + 'static,
193 C::Message: Send + Sync + 'static,
194 C::Error: From<io::Error> + Into<Error> + Send + Sync,
195{
196 pub async fn recv_msg(&mut self) -> Result<C::Message> {
198 self.reader.recv_msg().await
199 }
200
201 pub async fn send_msg(&mut self, msg: C::Message) -> Result<()> {
203 self.writer.send_msg(msg).await
204 }
205
206 pub fn peer_endpoint(&self) -> Option<Endpoint> {
208 self.reader.peer_endpoint.clone()
209 }
210
211 pub fn local_endpoint(&self) -> Option<Endpoint> {
213 self.reader.local_endpoint.clone()
214 }
215
216 pub fn split(self) -> (WsReader<C>, WsWriter<C>) {
218 (self.reader, self.writer)
219 }
220}
221
222pub struct WsReader<C> {
224 receiver: WebSocketReceiver<WsInner>,
225 codec: C,
226 peer_endpoint: Option<Endpoint>,
227 local_endpoint: Option<Endpoint>,
228}
229
230impl<C> WsReader<C>
231where
232 C: Codec<Message> + Send + Sync,
233 C::Message: Send + Sync,
234 C::Error: From<io::Error> + Into<Error> + Send + Sync,
235{
236 pub async fn recv_msg(&mut self) -> Result<C::Message> {
238 loop {
239 let raw = self
240 .receiver
241 .next()
242 .await
243 .ok_or(Error::ConnectionClosed)?
244 .map_err(|e| Error::IO(io::Error::other(e)))?;
245
246 let mut ws_msg = from_tung(raw);
247 match self.codec.decode(&mut ws_msg).map_err(Into::into)? {
248 Some((_, item)) => return Ok(item),
249 None => continue,
250 }
251 }
252 }
253
254 pub fn peer_endpoint(&self) -> Option<Endpoint> {
256 self.peer_endpoint.clone()
257 }
258
259 pub fn local_endpoint(&self) -> Option<Endpoint> {
261 self.local_endpoint.clone()
262 }
263}
264
265impl<C> MessageRx for WsReader<C>
266where
267 C: Codec<Message> + Send + Sync,
268 C::Message: Send + Sync,
269 C::Error: From<io::Error> + Into<Error> + Send + Sync,
270{
271 type Message = C::Message;
272
273 fn recv_msg(&mut self) -> impl std::future::Future<Output = Result<Self::Message>> + Send {
274 WsReader::recv_msg(self)
275 }
276
277 fn peer_endpoint(&self) -> Option<Endpoint> {
278 WsReader::peer_endpoint(self)
279 }
280}
281
282pub struct WsWriter<C> {
284 sender: WebSocketSender<WsInner>,
285 codec: C,
286 peer_endpoint: Option<Endpoint>,
287 local_endpoint: Option<Endpoint>,
288}
289
290impl<C> WsWriter<C>
291where
292 C: Codec<Message> + Send + Sync,
293 C::Message: Send + Sync,
294 C::Error: From<io::Error> + Into<Error> + Send + Sync,
295{
296 pub async fn send_msg(&mut self, msg: C::Message) -> Result<()> {
298 let mut ws_msg = Message::Binary(Vec::new());
299 self.codec.encode(&msg, &mut ws_msg).map_err(Into::into)?;
300
301 self.sender
302 .send(into_tung(ws_msg))
303 .await
304 .map_err(|e| Error::IO(io::Error::other(e)))?;
305 Ok(())
306 }
307
308 pub fn peer_endpoint(&self) -> Option<Endpoint> {
310 self.peer_endpoint.clone()
311 }
312
313 pub fn local_endpoint(&self) -> Option<Endpoint> {
315 self.local_endpoint.clone()
316 }
317}
318
319impl<C> MessageTx for WsWriter<C>
320where
321 C: Codec<Message> + Send + Sync,
322 C::Message: Send + Sync,
323 C::Error: From<io::Error> + Into<Error> + Send + Sync,
324{
325 type Message = C::Message;
326
327 fn send_msg(
328 &mut self,
329 msg: Self::Message,
330 ) -> impl std::future::Future<Output = Result<()>> + Send {
331 WsWriter::send_msg(self, msg)
332 }
333
334 fn peer_endpoint(&self) -> Option<Endpoint> {
335 WsWriter::peer_endpoint(self)
336 }
337}