karyon_net/stream/
mod.rs

1#[cfg(feature = "ws")]
2mod websocket;
3
4#[cfg(feature = "ws")]
5pub use websocket::{ReadWsStream, WriteWsStream, WsStream};
6
7use std::{
8    io::ErrorKind,
9    pin::Pin,
10    result::Result,
11    task::{Context, Poll},
12};
13
14use futures_util::{
15    ready,
16    stream::{Stream, StreamExt},
17    Sink,
18};
19use pin_project_lite::pin_project;
20
21use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
22
23use crate::codec::{Buffer, ByteBuffer, Decoder, Encoder};
24
/// Default upper bound on the size of a stream buffer: 16 MiB.
///
/// A build can choose a different bound by defining the
/// `KARYON_MAX_BUFFER_SIZE` environment variable at compile time; see
/// `get_max_buffer_size`.
const DEFAULT_MAX_BUFFER_SIZE: usize = 16 * 1024 * 1024;

/// Upper bound on how many bytes a single read pulls from the underlying
/// reader while filling the stream buffer: 1 MiB.
const BUFFER_CHUNK_SIZE: usize = 1024 * 1024;

/// Returns the maximum stream buffer size.
///
/// The value of `KARYON_MAX_BUFFER_SIZE` captured at compile time is used
/// when it is present and parses as a `usize`; in every other case this falls
/// back to `DEFAULT_MAX_BUFFER_SIZE`.
fn get_max_buffer_size() -> usize {
    option_env!("KARYON_MAX_BUFFER_SIZE")
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(DEFAULT_MAX_BUFFER_SIZE)
}
40
41pub struct ReadStream<T, C> {
42    inner: T,
43    decoder: C,
44    buffer: ByteBuffer,
45}
46
47impl<T, C> ReadStream<T, C>
48where
49    T: AsyncRead + Unpin,
50    C: Decoder + Unpin,
51{
52    pub fn new(inner: T, decoder: C) -> Self {
53        Self {
54            inner,
55            decoder,
56            buffer: Buffer::new(get_max_buffer_size()),
57        }
58    }
59
60    pub async fn recv(&mut self) -> Result<C::DeMessage, C::DeError> {
61        match self.next().await {
62            Some(m) => m,
63            None => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted).into()),
64        }
65    }
66}
67
68pin_project! {
69    pub struct WriteStream<T, C> {
70        #[pin]
71        inner: T,
72        encoder: C,
73        high_water_mark: usize,
74        buffer: ByteBuffer,
75    }
76}
77
78impl<T, C> WriteStream<T, C>
79where
80    T: AsyncWrite + Unpin,
81    C: Encoder + Unpin,
82{
83    pub fn new(inner: T, encoder: C) -> Self {
84        Self {
85            inner,
86            encoder,
87            high_water_mark: 131072,
88            buffer: Buffer::new(get_max_buffer_size()),
89        }
90    }
91}
92
93impl<T, C> Stream for ReadStream<T, C>
94where
95    T: AsyncRead + Unpin,
96    C: Decoder + Unpin,
97{
98    type Item = Result<C::DeMessage, C::DeError>;
99
100    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
101        let this = &mut *self;
102
103        if let Some((n, item)) = this.decoder.decode(&mut this.buffer)? {
104            this.buffer.advance(n);
105            return Poll::Ready(Some(Ok(item)));
106        }
107
108        let mut buf = [0u8; BUFFER_CHUNK_SIZE];
109        #[cfg(feature = "tokio")]
110        let mut buf = tokio::io::ReadBuf::new(&mut buf);
111
112        loop {
113            #[cfg(feature = "smol")]
114            let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
115            #[cfg(feature = "smol")]
116            let bytes = &buf[..n];
117
118            #[cfg(feature = "tokio")]
119            ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
120            #[cfg(feature = "tokio")]
121            let bytes = buf.filled();
122            #[cfg(feature = "tokio")]
123            let n = bytes.len();
124
125            this.buffer.extend_from_slice(bytes);
126
127            #[cfg(feature = "tokio")]
128            buf.clear();
129
130            match this.decoder.decode(&mut this.buffer)? {
131                Some((cn, item)) => {
132                    this.buffer.advance(cn);
133                    return Poll::Ready(Some(Ok(item)));
134                }
135                None if n == 0 => {
136                    if this.buffer.is_empty() {
137                        return Poll::Ready(None);
138                    } else {
139                        return Poll::Ready(Some(Err(std::io::Error::new(
140                            std::io::ErrorKind::UnexpectedEof,
141                            "bytes remaining in read stream",
142                        )
143                        .into())));
144                    }
145                }
146                _ => continue,
147            }
148        }
149    }
150}
151
152impl<T, C> Sink<C::EnMessage> for WriteStream<T, C>
153where
154    T: AsyncWrite + Unpin,
155    C: Encoder + Unpin,
156{
157    type Error = C::EnError;
158
159    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
160        let this = &mut *self;
161        while !this.buffer.is_empty() {
162            let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?;
163
164            if n == 0 {
165                return Poll::Ready(Err(std::io::Error::new(
166                    ErrorKind::UnexpectedEof,
167                    "End of file",
168                )
169                .into()));
170            }
171
172            this.buffer.advance(n);
173        }
174
175        Poll::Ready(Ok(()))
176    }
177
178    fn start_send(mut self: Pin<&mut Self>, item: C::EnMessage) -> Result<(), Self::Error> {
179        let this = &mut *self;
180        this.encoder.encode(&item, &mut this.buffer)?;
181        Ok(())
182    }
183
184    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
185        ready!(self.as_mut().poll_ready(cx))?;
186        self.project().inner.poll_flush(cx).map_err(Into::into)
187    }
188
189    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
190        ready!(self.as_mut().poll_flush(cx))?;
191        #[cfg(feature = "smol")]
192        return self.project().inner.poll_close(cx).map_err(|e| e.into());
193
194        #[cfg(feature = "tokio")]
195        return self.project().inner.poll_shutdown(cx).map_err(|e| e.into());
196    }
197}