// karyon_net/stream/mod.rs

1#[cfg(feature = "ws")]
2mod websocket;
3
4#[cfg(feature = "ws")]
5pub use websocket::{ReadWsStream, WriteWsStream, WsStream};
6
7use std::{
8    io::ErrorKind,
9    pin::Pin,
10    result::Result,
11    task::{Context, Poll},
12};
13
14use futures_util::{
15    ready,
16    stream::{Stream, StreamExt},
17    Sink,
18};
19use pin_project_lite::pin_project;
20
21use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
22
23use crate::codec::{Buffer, ByteBuffer, Decoder, Encoder};
24
/// Upper bound on how many bytes a single read pulls from the underlying
/// reader while filling the read stream's buffer (1 MiB).
const BUFFER_CHUNK_SIZE: usize = 1 << 20;

28pub struct ReadStream<T, C> {
29    inner: T,
30    decoder: C,
31    buffer: ByteBuffer,
32}
33
34impl<T, C> ReadStream<T, C>
35where
36    T: AsyncRead + Unpin,
37    C: Decoder + Unpin,
38{
39    pub fn new(inner: T, decoder: C) -> Self {
40        Self {
41            inner,
42            decoder,
43            buffer: Buffer::new(),
44        }
45    }
46
47    pub async fn recv(&mut self) -> Result<C::DeMessage, C::DeError> {
48        match self.next().await {
49            Some(m) => m,
50            None => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted).into()),
51        }
52    }
53}
54
55pin_project! {
56    pub struct WriteStream<T, C> {
57        #[pin]
58        inner: T,
59        encoder: C,
60        buffer: ByteBuffer,
61    }
62}
63
64impl<T, C> WriteStream<T, C>
65where
66    T: AsyncWrite + Unpin,
67    C: Encoder + Unpin,
68{
69    pub fn new(inner: T, encoder: C) -> Self {
70        Self {
71            inner,
72            encoder,
73            buffer: Buffer::new(),
74        }
75    }
76}
77
78impl<T, C> Stream for ReadStream<T, C>
79where
80    T: AsyncRead + Unpin,
81    C: Decoder + Unpin,
82{
83    type Item = Result<C::DeMessage, C::DeError>;
84
85    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
86        let this = &mut *self;
87
88        if let Some((n, item)) = this.decoder.decode(&mut this.buffer)? {
89            this.buffer.advance(n);
90            return Poll::Ready(Some(Ok(item)));
91        }
92
93        loop {
94            let mut buf = [0u8; BUFFER_CHUNK_SIZE];
95            #[cfg(feature = "tokio")]
96            let mut buf = tokio::io::ReadBuf::new(&mut buf);
97
98            #[cfg(feature = "smol")]
99            let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
100            #[cfg(feature = "smol")]
101            let bytes = &buf[..n];
102
103            #[cfg(feature = "tokio")]
104            ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
105            #[cfg(feature = "tokio")]
106            let bytes = buf.filled();
107            #[cfg(feature = "tokio")]
108            let n = bytes.len();
109
110            this.buffer.extend_from_slice(bytes);
111
112            match this.decoder.decode(&mut this.buffer)? {
113                Some((cn, item)) => {
114                    this.buffer.advance(cn);
115                    return Poll::Ready(Some(Ok(item)));
116                }
117                None if n == 0 => {
118                    if this.buffer.is_empty() {
119                        return Poll::Ready(None);
120                    } else {
121                        return Poll::Ready(Some(Err(std::io::Error::new(
122                            std::io::ErrorKind::UnexpectedEof,
123                            "bytes remaining in read stream",
124                        )
125                        .into())));
126                    }
127                }
128                _ => continue,
129            }
130        }
131    }
132}
133
134impl<T, C> Sink<C::EnMessage> for WriteStream<T, C>
135where
136    T: AsyncWrite + Unpin,
137    C: Encoder + Unpin,
138{
139    type Error = C::EnError;
140
141    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
142        let this = &mut *self;
143        while !this.buffer.is_empty() {
144            let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?;
145
146            if n == 0 {
147                return Poll::Ready(Err(std::io::Error::from(ErrorKind::UnexpectedEof).into()));
148            }
149
150            this.buffer.advance(n);
151        }
152
153        Poll::Ready(Ok(()))
154    }
155
156    fn start_send(mut self: Pin<&mut Self>, item: C::EnMessage) -> Result<(), Self::Error> {
157        let this = &mut *self;
158        this.encoder.encode(&item, &mut this.buffer)?;
159        Ok(())
160    }
161
162    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
163        ready!(self.as_mut().poll_ready(cx))?;
164        self.project().inner.poll_flush(cx).map_err(Into::into)
165    }
166
167    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
168        ready!(self.as_mut().poll_flush(cx))?;
169        #[cfg(feature = "smol")]
170        return self.project().inner.poll_close(cx).map_err(|e| e.into());
171
172        #[cfg(feature = "tokio")]
173        return self.project().inner.poll_shutdown(cx).map_err(|e| e.into());
174    }
175}