karyon_net/stream/
websocket.rs

1use std::{
2    io::ErrorKind,
3    pin::Pin,
4    result::Result,
5    task::{Context, Poll},
6};
7
8use async_tungstenite::tungstenite::Message;
9use futures_util::{
10    stream::{SplitSink, SplitStream},
11    Sink, SinkExt, Stream, StreamExt, TryStreamExt,
12};
13use pin_project_lite::pin_project;
14
15use async_tungstenite::tungstenite::Error;
16
17#[cfg(feature = "tokio")]
18type WebSocketStream<T> =
19    async_tungstenite::WebSocketStream<async_tungstenite::tokio::TokioAdapter<T>>;
20#[cfg(feature = "smol")]
21use async_tungstenite::WebSocketStream;
22
23use karyon_core::async_runtime::net::TcpStream;
24
25#[cfg(feature = "tls")]
26use crate::async_rustls::TlsStream;
27
28use crate::codec::WebSocketCodec;
29
30pub struct WsStream<C> {
31    inner: InnerWSConn,
32    codec: C,
33}
34
35impl<C, E> WsStream<C>
36where
37    C: WebSocketCodec<Error = E> + Clone,
38{
39    pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
40        Self {
41            inner: InnerWSConn::Plain(Box::new(conn)),
42            codec,
43        }
44    }
45
46    #[cfg(feature = "tls")]
47    pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self {
48        Self {
49            inner: InnerWSConn::Tls(Box::new(conn)),
50            codec,
51        }
52    }
53
54    pub fn split(self) -> (ReadWsStream<C>, WriteWsStream<C>) {
55        let (write, read) = self.inner.split();
56
57        (
58            ReadWsStream {
59                codec: self.codec.clone(),
60                inner: read,
61            },
62            WriteWsStream {
63                inner: write,
64                codec: self.codec,
65            },
66        )
67    }
68}
69
70pin_project! {
71    pub struct ReadWsStream<C> {
72        #[pin]
73        inner: SplitStream<InnerWSConn>,
74        codec: C,
75    }
76}
77
78pin_project! {
79    pub struct WriteWsStream<C> {
80        #[pin]
81        inner: SplitSink<InnerWSConn, Message>,
82        codec: C,
83    }
84}
85
86impl<C, E> ReadWsStream<C>
87where
88    C: WebSocketCodec<Error = E>,
89    E: From<Error>,
90{
91    pub async fn recv(&mut self) -> Result<C::Message, E> {
92        match self.inner.next().await {
93            Some(msg) => match self.codec.decode(&msg?)? {
94                Some(m) => Ok(m),
95                None => todo!(),
96            },
97            None => Err(Error::Io(std::io::Error::from(ErrorKind::ConnectionAborted)).into()),
98        }
99    }
100}
101
102impl<C, E> WriteWsStream<C>
103where
104    C: WebSocketCodec<Error = E>,
105    E: From<Error>,
106{
107    pub async fn send(&mut self, msg: C::Message) -> Result<(), E> {
108        let ws_msg = self.codec.encode(&msg)?;
109        Ok(self.inner.send(ws_msg).await?)
110    }
111}
112
113impl<C> Sink<Message> for WriteWsStream<C> {
114    type Error = Error;
115
116    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        self.project().inner.poll_ready(cx)
118    }
119
120    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
121        self.project().inner.start_send(item)
122    }
123
124    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125        self.project().inner.poll_flush(cx)
126    }
127
128    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129        self.project().inner.poll_close(cx)
130    }
131}
132
133impl<C> Stream for ReadWsStream<C> {
134    type Item = Result<Message, Error>;
135
136    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137        self.inner.try_poll_next_unpin(cx)
138    }
139}
140
141enum InnerWSConn {
142    Plain(Box<WebSocketStream<TcpStream>>),
143    #[cfg(feature = "tls")]
144    Tls(Box<WebSocketStream<TlsStream<TcpStream>>>),
145}
146
147impl Sink<Message> for InnerWSConn {
148    type Error = Error;
149
150    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
151        match &mut *self {
152            InnerWSConn::Plain(s) => Pin::new(s.as_mut()).poll_ready(cx),
153            #[cfg(feature = "tls")]
154            InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_ready(cx),
155        }
156    }
157
158    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
159        match &mut *self {
160            InnerWSConn::Plain(s) => Pin::new(s.as_mut()).start_send(item),
161            #[cfg(feature = "tls")]
162            InnerWSConn::Tls(s) => Pin::new(s.as_mut()).start_send(item),
163        }
164    }
165
166    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167        match &mut *self {
168            InnerWSConn::Plain(s) => Pin::new(s.as_mut()).poll_flush(cx),
169            #[cfg(feature = "tls")]
170            InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
171        }
172    }
173
174    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175        match &mut *self {
176            InnerWSConn::Plain(s) => Pin::new(s.as_mut()).poll_close(cx),
177            #[cfg(feature = "tls")]
178            InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_close(cx),
179        }
180    }
181}
182
183impl Stream for InnerWSConn {
184    type Item = Result<Message, Error>;
185
186    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187        match &mut *self {
188            InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx),
189            #[cfg(feature = "tls")]
190            InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_next(cx),
191        }
192    }
193}