1#[cfg(feature = "ws")]
2mod websocket;
3
4#[cfg(feature = "ws")]
5pub use websocket::{ReadWsStream, WriteWsStream, WsStream};
6
7use std::{
8 io::ErrorKind,
9 pin::Pin,
10 result::Result,
11 task::{Context, Poll},
12};
13
14use futures_util::{
15 ready,
16 stream::{Stream, StreamExt},
17 Sink,
18};
19use pin_project_lite::pin_project;
20
21use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
22
23use crate::codec::{Buffer, ByteBuffer, Decoder, Encoder};
24
25const BUFFER_CHUNK_SIZE: usize = 1024 * 1024; pub struct ReadStream<T, C> {
29 inner: T,
30 decoder: C,
31 buffer: ByteBuffer,
32}
33
34impl<T, C> ReadStream<T, C>
35where
36 T: AsyncRead + Unpin,
37 C: Decoder + Unpin,
38{
39 pub fn new(inner: T, decoder: C) -> Self {
40 Self {
41 inner,
42 decoder,
43 buffer: Buffer::new(),
44 }
45 }
46
47 pub async fn recv(&mut self) -> Result<C::DeMessage, C::DeError> {
48 match self.next().await {
49 Some(m) => m,
50 None => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted).into()),
51 }
52 }
53}
54
pin_project! {
    /// Writing half of a framed stream.
    ///
    /// Messages are encoded by `encoder` into `buffer`, and the buffered bytes
    /// are then written to `inner`.
    pub struct WriteStream<T, C> {
        // Underlying byte sink. It is structurally pinned so it can be polled
        // through `project()`.
        #[pin]
        inner: T,
        // Encoder used to serialize outgoing messages into `buffer`.
        encoder: C,
        // Encoded bytes that have not been written to `inner` yet.
        buffer: ByteBuffer,
    }
}
63
64impl<T, C> WriteStream<T, C>
65where
66 T: AsyncWrite + Unpin,
67 C: Encoder + Unpin,
68{
69 pub fn new(inner: T, encoder: C) -> Self {
70 Self {
71 inner,
72 encoder,
73 buffer: Buffer::new(),
74 }
75 }
76}
77
78impl<T, C> Stream for ReadStream<T, C>
79where
80 T: AsyncRead + Unpin,
81 C: Decoder + Unpin,
82{
83 type Item = Result<C::DeMessage, C::DeError>;
84
85 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
86 let this = &mut *self;
87
88 if let Some((n, item)) = this.decoder.decode(&mut this.buffer)? {
89 this.buffer.advance(n);
90 return Poll::Ready(Some(Ok(item)));
91 }
92
93 loop {
94 let mut buf = [0u8; BUFFER_CHUNK_SIZE];
95 #[cfg(feature = "tokio")]
96 let mut buf = tokio::io::ReadBuf::new(&mut buf);
97
98 #[cfg(feature = "smol")]
99 let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
100 #[cfg(feature = "smol")]
101 let bytes = &buf[..n];
102
103 #[cfg(feature = "tokio")]
104 ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
105 #[cfg(feature = "tokio")]
106 let bytes = buf.filled();
107 #[cfg(feature = "tokio")]
108 let n = bytes.len();
109
110 this.buffer.extend_from_slice(bytes);
111
112 match this.decoder.decode(&mut this.buffer)? {
113 Some((cn, item)) => {
114 this.buffer.advance(cn);
115 return Poll::Ready(Some(Ok(item)));
116 }
117 None if n == 0 => {
118 if this.buffer.is_empty() {
119 return Poll::Ready(None);
120 } else {
121 return Poll::Ready(Some(Err(std::io::Error::new(
122 std::io::ErrorKind::UnexpectedEof,
123 "bytes remaining in read stream",
124 )
125 .into())));
126 }
127 }
128 _ => continue,
129 }
130 }
131 }
132}
133
134impl<T, C> Sink<C::EnMessage> for WriteStream<T, C>
135where
136 T: AsyncWrite + Unpin,
137 C: Encoder + Unpin,
138{
139 type Error = C::EnError;
140
141 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
142 let this = &mut *self;
143 while !this.buffer.is_empty() {
144 let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?;
145
146 if n == 0 {
147 return Poll::Ready(Err(std::io::Error::from(ErrorKind::UnexpectedEof).into()));
148 }
149
150 this.buffer.advance(n);
151 }
152
153 Poll::Ready(Ok(()))
154 }
155
156 fn start_send(mut self: Pin<&mut Self>, item: C::EnMessage) -> Result<(), Self::Error> {
157 let this = &mut *self;
158 this.encoder.encode(&item, &mut this.buffer)?;
159 Ok(())
160 }
161
162 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
163 ready!(self.as_mut().poll_ready(cx))?;
164 self.project().inner.poll_flush(cx).map_err(Into::into)
165 }
166
167 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
168 ready!(self.as_mut().poll_flush(cx))?;
169 #[cfg(feature = "smol")]
170 return self.project().inner.poll_close(cx).map_err(|e| e.into());
171
172 #[cfg(feature = "tokio")]
173 return self.project().inner.poll_shutdown(cx).map_err(|e| e.into());
174 }
175}