1#[cfg(feature = "ws")]
2mod websocket;
3
4#[cfg(feature = "ws")]
5pub use websocket::{ReadWsStream, WriteWsStream, WsStream};
6
7use std::{
8 io::ErrorKind,
9 pin::Pin,
10 result::Result,
11 task::{Context, Poll},
12};
13
14use futures_util::{
15 ready,
16 stream::{Stream, StreamExt},
17 Sink,
18};
19use pin_project_lite::pin_project;
20
21use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
22
23use crate::codec::{Buffer, ByteBuffer, Decoder, Encoder};
24
/// Default maximum buffer size (4096 * 4096 bytes = 16 MiB).
///
/// Used when `KARYON_MAX_BUFFER_SIZE` is unset or does not parse as a `usize`.
const DEFAULT_MAX_BUFFER_SIZE: usize = 4096 * 4096;

/// Size of the scratch chunk handed to each `poll_read` on the underlying reader (1 MiB).
const BUFFER_CHUNK_SIZE: usize = 1024 * 1024;

/// Returns the size passed to `Buffer::new` for the read and write stream buffers.
///
/// The value comes from the `KARYON_MAX_BUFFER_SIZE` environment variable as captured by
/// `option_env!`, i.e. at **compile time**, not when the program runs. If the variable is
/// absent or is not a valid `usize`, `DEFAULT_MAX_BUFFER_SIZE` is returned.
fn get_max_buffer_size() -> usize {
    std::option_env!("KARYON_MAX_BUFFER_SIZE")
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(DEFAULT_MAX_BUFFER_SIZE)
}
40
/// Reads bytes from `inner` and turns them into messages with `decoder`.
///
/// Bytes are accumulated in `buffer` until `decoder` can produce a complete message.
/// Messages are obtained through the `Stream` implementation or the `recv` helper.
pub struct ReadStream<T, C> {
    /// The underlying reader.
    inner: T,
    /// Decodes buffered bytes into messages.
    decoder: C,
    /// Bytes read from `inner` that the decoder has not consumed yet.
    buffer: ByteBuffer,
}
46
47impl<T, C> ReadStream<T, C>
48where
49 T: AsyncRead + Unpin,
50 C: Decoder + Unpin,
51{
52 pub fn new(inner: T, decoder: C) -> Self {
53 Self {
54 inner,
55 decoder,
56 buffer: Buffer::new(get_max_buffer_size()),
57 }
58 }
59
60 pub async fn recv(&mut self) -> Result<C::DeMessage, C::DeError> {
61 match self.next().await {
62 Some(m) => m,
63 None => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted).into()),
64 }
65 }
66}
67
pin_project! {
    /// Encodes messages with `encoder` into an internal buffer and writes the buffered
    /// bytes to `inner`.
    ///
    /// Messages are sent through the `Sink` implementation.
    pub struct WriteStream<T, C> {
        // The underlying writer. It is structurally pinned so that `poll_flush` and
        // `poll_close` can reach it via `project()`.
        #[pin]
        inner: T,
        // Serializes outgoing messages into `buffer`.
        encoder: C,
        // Initialized to 131072 in `new`; presumably a byte threshold for buffered output
        // (NOTE(review): confirm against the `Sink` implementation).
        high_water_mark: usize,
        // Encoded bytes not yet written to `inner`.
        buffer: ByteBuffer,
    }
}
77
78impl<T, C> WriteStream<T, C>
79where
80 T: AsyncWrite + Unpin,
81 C: Encoder + Unpin,
82{
83 pub fn new(inner: T, encoder: C) -> Self {
84 Self {
85 inner,
86 encoder,
87 high_water_mark: 131072,
88 buffer: Buffer::new(get_max_buffer_size()),
89 }
90 }
91}
92
93impl<T, C> Stream for ReadStream<T, C>
94where
95 T: AsyncRead + Unpin,
96 C: Decoder + Unpin,
97{
98 type Item = Result<C::DeMessage, C::DeError>;
99
100 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
101 let this = &mut *self;
102
103 if let Some((n, item)) = this.decoder.decode(&mut this.buffer)? {
104 this.buffer.advance(n);
105 return Poll::Ready(Some(Ok(item)));
106 }
107
108 let mut buf = [0u8; BUFFER_CHUNK_SIZE];
109 #[cfg(feature = "tokio")]
110 let mut buf = tokio::io::ReadBuf::new(&mut buf);
111
112 loop {
113 #[cfg(feature = "smol")]
114 let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
115 #[cfg(feature = "smol")]
116 let bytes = &buf[..n];
117
118 #[cfg(feature = "tokio")]
119 ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
120 #[cfg(feature = "tokio")]
121 let bytes = buf.filled();
122 #[cfg(feature = "tokio")]
123 let n = bytes.len();
124
125 this.buffer.extend_from_slice(bytes);
126
127 #[cfg(feature = "tokio")]
128 buf.clear();
129
130 match this.decoder.decode(&mut this.buffer)? {
131 Some((cn, item)) => {
132 this.buffer.advance(cn);
133 return Poll::Ready(Some(Ok(item)));
134 }
135 None if n == 0 => {
136 if this.buffer.is_empty() {
137 return Poll::Ready(None);
138 } else {
139 return Poll::Ready(Some(Err(std::io::Error::new(
140 std::io::ErrorKind::UnexpectedEof,
141 "bytes remaining in read stream",
142 )
143 .into())));
144 }
145 }
146 _ => continue,
147 }
148 }
149 }
150}
151
152impl<T, C> Sink<C::EnMessage> for WriteStream<T, C>
153where
154 T: AsyncWrite + Unpin,
155 C: Encoder + Unpin,
156{
157 type Error = C::EnError;
158
159 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
160 let this = &mut *self;
161 while !this.buffer.is_empty() {
162 let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?;
163
164 if n == 0 {
165 return Poll::Ready(Err(std::io::Error::new(
166 ErrorKind::UnexpectedEof,
167 "End of file",
168 )
169 .into()));
170 }
171
172 this.buffer.advance(n);
173 }
174
175 Poll::Ready(Ok(()))
176 }
177
178 fn start_send(mut self: Pin<&mut Self>, item: C::EnMessage) -> Result<(), Self::Error> {
179 let this = &mut *self;
180 this.encoder.encode(&item, &mut this.buffer)?;
181 Ok(())
182 }
183
184 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
185 ready!(self.as_mut().poll_ready(cx))?;
186 self.project().inner.poll_flush(cx).map_err(Into::into)
187 }
188
189 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
190 ready!(self.as_mut().poll_flush(cx))?;
191 #[cfg(feature = "smol")]
192 return self.project().inner.poll_close(cx).map_err(|e| e.into());
193
194 #[cfg(feature = "tokio")]
195 return self.project().inner.poll_shutdown(cx).map_err(|e| e.into());
196 }
197}