karyon_net/stream/
websocket.rs1use std::{
2 io::ErrorKind,
3 pin::Pin,
4 result::Result,
5 task::{Context, Poll},
6};
7
8use async_tungstenite::tungstenite::Message;
9use futures_util::{
10 stream::{SplitSink, SplitStream},
11 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
12};
13use pin_project_lite::pin_project;
14
15use async_tungstenite::tungstenite::Error;
16
// Runtime-specific WebSocket stream type.
// Under `tokio`, the transport is wrapped in async_tungstenite's `TokioAdapter`.
#[cfg(feature = "tokio")]
type WebSocketStream<S> =
    async_tungstenite::WebSocketStream<async_tungstenite::tokio::TokioAdapter<S>>;
// Under `smol`, async_tungstenite's stream type is used as-is.
#[cfg(feature = "smol")]
use async_tungstenite::WebSocketStream;
22
23use karyon_core::async_runtime::net::TcpStream;
24
25#[cfg(feature = "tls")]
26use crate::async_rustls::TlsStream;
27
28use crate::codec::WebSocketCodec;
29
30pub struct WsStream<C> {
31 inner: InnerWSConn,
32 codec: C,
33}
34
35impl<C, E> WsStream<C>
36where
37 C: WebSocketCodec<Error = E> + Clone,
38{
39 pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
40 Self {
41 inner: InnerWSConn::Plain(Box::new(conn)),
42 codec,
43 }
44 }
45
46 #[cfg(feature = "tls")]
47 pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self {
48 Self {
49 inner: InnerWSConn::Tls(Box::new(conn)),
50 codec,
51 }
52 }
53
54 pub fn split(self) -> (ReadWsStream<C>, WriteWsStream<C>) {
55 let (write, read) = self.inner.split();
56
57 (
58 ReadWsStream {
59 codec: self.codec.clone(),
60 inner: read,
61 },
62 WriteWsStream {
63 inner: write,
64 codec: self.codec,
65 },
66 )
67 }
68}
69
70pin_project! {
71 pub struct ReadWsStream<C> {
72 #[pin]
73 inner: SplitStream<InnerWSConn>,
74 codec: C,
75 }
76}
77
78pin_project! {
79 pub struct WriteWsStream<C> {
80 #[pin]
81 inner: SplitSink<InnerWSConn, Message>,
82 codec: C,
83 }
84}
85
86impl<C, E> ReadWsStream<C>
87where
88 C: WebSocketCodec<Error = E>,
89 E: From<Error>,
90{
91 pub async fn recv(&mut self) -> Result<C::Message, E> {
92 match self.inner.next().await {
93 Some(msg) => match self.codec.decode(&msg?)? {
94 Some(m) => Ok(m),
95 None => todo!(),
96 },
97 None => Err(Error::Io(std::io::Error::from(ErrorKind::ConnectionAborted)).into()),
98 }
99 }
100}
101
102impl<C, E> WriteWsStream<C>
103where
104 C: WebSocketCodec<Error = E>,
105 E: From<Error>,
106{
107 pub async fn send(&mut self, msg: C::Message) -> Result<(), E> {
108 let ws_msg = self.codec.encode(&msg)?;
109 Ok(self.inner.send(ws_msg).await?)
110 }
111}
112
113impl<C> Sink<Message> for WriteWsStream<C> {
114 type Error = Error;
115
116 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117 self.project().inner.poll_ready(cx)
118 }
119
120 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
121 self.project().inner.start_send(item)
122 }
123
124 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 self.project().inner.poll_flush(cx)
126 }
127
128 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129 self.project().inner.poll_close(cx)
130 }
131}
132
133impl<C> Stream for ReadWsStream<C> {
134 type Item = Result<Message, Error>;
135
136 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137 self.inner.try_poll_next_unpin(cx)
138 }
139}
140
141enum InnerWSConn {
142 Plain(Box<WebSocketStream<TcpStream>>),
143 #[cfg(feature = "tls")]
144 Tls(Box<WebSocketStream<TlsStream<TcpStream>>>),
145}
146
147impl Sink<Message> for InnerWSConn {
148 type Error = Error;
149
150 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
151 match &mut *self {
152 InnerWSConn::Plain(s) => Pin::new(s.as_mut()).poll_ready(cx),
153 #[cfg(feature = "tls")]
154 InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_ready(cx),
155 }
156 }
157
158 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
159 match &mut *self {
160 InnerWSConn::Plain(s) => Pin::new(s.as_mut()).start_send(item),
161 #[cfg(feature = "tls")]
162 InnerWSConn::Tls(s) => Pin::new(s.as_mut()).start_send(item),
163 }
164 }
165
166 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167 match &mut *self {
168 InnerWSConn::Plain(s) => Pin::new(s.as_mut()).poll_flush(cx),
169 #[cfg(feature = "tls")]
170 InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
171 }
172 }
173
174 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175 match &mut *self {
176 InnerWSConn::Plain(s) => Pin::new(s.as_mut()).poll_close(cx),
177 #[cfg(feature = "tls")]
178 InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_close(cx),
179 }
180 }
181}
182
183impl Stream for InnerWSConn {
184 type Item = Result<Message, Error>;
185
186 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187 match &mut *self {
188 InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx),
189 #[cfg(feature = "tls")]
190 InnerWSConn::Tls(s) => Pin::new(s.as_mut()).poll_next(cx),
191 }
192 }
193}