1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use rustls_pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::async_rustls::rustls;
13
14use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
15
16#[cfg(feature = "tokio")]
17use tokio::io::{AsyncWrite as TokioAsyncWrite, ReadBuf};
18
19#[cfg(feature = "smol")]
20use futures_util::io::{AsyncRead as FutAsyncRead, AsyncWrite as FutAsyncWrite};
21
22use crate::{stream::ByteStream, stream_mux::StreamMux, Bytes, Endpoint, Error, Result};
23
24const DEFAULT_READ_CHUNK_SIZE: usize = 1024 * 1024; pub type QuicSendStream = quinn::SendStream;
29
30pub type QuicRecvStream = quinn::RecvStream;
32
/// Transport-level settings shared by server and client QUIC configurations.
///
/// Every field is a plain value, so the type also derives `Debug`,
/// `PartialEq` and `Eq`. That lets configurations be logged and compared.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QuicConfig {
    /// Concurrent bidirectional stream limit, passed to quinn's
    /// `TransportConfig::max_concurrent_bidi_streams`.
    pub max_bi_streams: u64,
    /// Concurrent unidirectional stream limit, passed to quinn's
    /// `TransportConfig::max_concurrent_uni_streams`.
    pub max_uni_streams: u64,
    /// Keep-alive interval. `None` leaves quinn's keep-alive setting untouched.
    pub keep_alive_interval: Option<Duration>,
    /// Maximum idle time before the connection times out. `None` leaves
    /// quinn's default idle timeout untouched.
    pub idle_timeout: Option<Duration>,
    /// When `false`, the datagram receive buffer is disabled
    /// (`datagram_receive_buffer_size(None)`).
    pub enable_datagrams: bool,
    /// Read chunk size in bytes.
    // NOTE(review): not consumed anywhere in this section — confirm where it is used.
    pub read_chunk_size: usize,
}
49
50impl Default for QuicConfig {
51 fn default() -> Self {
52 Self {
53 max_bi_streams: 100,
54 max_uni_streams: 100,
55 keep_alive_interval: Some(Duration::from_secs(5)),
56 idle_timeout: Some(Duration::from_secs(30)),
57 enable_datagrams: false,
58 read_chunk_size: DEFAULT_READ_CHUNK_SIZE,
59 }
60 }
61}
62
63#[derive(Clone)]
66pub struct ServerQuicConfig {
67 source: ServerSource,
68 config: QuicConfig,
69}
70
71#[derive(Clone)]
72enum ServerSource {
73 Certs {
74 cert_chain: Vec<CertificateDer<'static>>,
75 private_key: Arc<PrivateKeyDer<'static>>,
76 },
77 Rustls(rustls::ServerConfig),
78}
79
80impl ServerQuicConfig {
81 pub fn new(
83 cert_chain: Vec<CertificateDer<'static>>,
84 private_key: Arc<PrivateKeyDer<'static>>,
85 ) -> Self {
86 Self {
87 source: ServerSource::Certs {
88 cert_chain,
89 private_key,
90 },
91 config: QuicConfig::default(),
92 }
93 }
94
95 pub fn from_rustls(rustls_config: rustls::ServerConfig) -> Self {
98 Self {
99 source: ServerSource::Rustls(rustls_config),
100 config: QuicConfig::default(),
101 }
102 }
103
104 pub fn with_config(mut self, config: QuicConfig) -> Self {
106 self.config = config;
107 self
108 }
109
110 pub(crate) fn build(self) -> Result<quinn::ServerConfig> {
111 let mut server_config = match self.source {
112 ServerSource::Certs {
113 cert_chain,
114 private_key,
115 } => quinn::ServerConfig::with_single_cert(cert_chain, private_key.clone_key())
116 .map_err(|e| Error::TlsConfig(e.to_string()))?,
117 ServerSource::Rustls(rustls_config) => {
118 let quic_config = quinn::crypto::rustls::QuicServerConfig::try_from(rustls_config)
119 .map_err(|e| Error::QuicConfigError(e.to_string()))?;
120 quinn::ServerConfig::with_crypto(Arc::new(quic_config))
121 }
122 };
123 server_config.transport_config(Arc::new(build_transport_config(&self.config)));
124 Ok(server_config)
125 }
126}
127
128#[derive(Clone)]
131pub struct ClientQuicConfig {
132 source: ClientSource,
133 server_name: String,
134 config: QuicConfig,
135}
136
137#[derive(Clone)]
138enum ClientSource {
139 Roots(Vec<CertificateDer<'static>>),
140 Rustls(Box<rustls::ClientConfig>),
141}
142
143impl ClientQuicConfig {
144 pub fn new(root_certs: Vec<CertificateDer<'static>>, server_name: impl Into<String>) -> Self {
146 Self {
147 source: ClientSource::Roots(root_certs),
148 server_name: server_name.into(),
149 config: QuicConfig::default(),
150 }
151 }
152
153 pub fn from_rustls(
156 rustls_config: rustls::ClientConfig,
157 server_name: impl Into<String>,
158 ) -> Self {
159 Self {
160 source: ClientSource::Rustls(Box::new(rustls_config)),
161 server_name: server_name.into(),
162 config: QuicConfig::default(),
163 }
164 }
165
166 pub fn with_config(mut self, config: QuicConfig) -> Self {
168 self.config = config;
169 self
170 }
171
172 pub fn server_name(&self) -> &str {
173 &self.server_name
174 }
175
176 pub(crate) fn build(self) -> Result<quinn::ClientConfig> {
177 let mut client_config = match self.source {
178 ClientSource::Roots(root_certs) => {
179 let mut root_store = rustls::RootCertStore::empty();
180 for cert in root_certs {
181 root_store
182 .add(cert)
183 .map_err(|e| Error::TlsConfig(e.to_string()))?;
184 }
185 quinn::ClientConfig::with_root_certificates(Arc::new(root_store))
186 .map_err(|e| Error::TlsConfig(e.to_string()))?
187 }
188 ClientSource::Rustls(rustls_config) => {
189 let quic_config = quinn::crypto::rustls::QuicClientConfig::try_from(*rustls_config)
190 .map_err(|e| Error::QuicConfigError(e.to_string()))?;
191 quinn::ClientConfig::new(Arc::new(quic_config))
192 }
193 };
194 client_config.transport_config(Arc::new(build_transport_config(&self.config)));
195 Ok(client_config)
196 }
197}
198
199fn build_transport_config(config: &QuicConfig) -> quinn::TransportConfig {
200 let mut transport = quinn::TransportConfig::default();
201 transport.max_concurrent_bidi_streams(
202 quinn::VarInt::from_u64(config.max_bi_streams).unwrap_or(quinn::VarInt::from_u32(100u32)),
203 );
204 transport.max_concurrent_uni_streams(
205 quinn::VarInt::from_u64(config.max_uni_streams).unwrap_or(quinn::VarInt::from_u32(100u32)),
206 );
207 if let Some(interval) = config.keep_alive_interval {
208 transport.keep_alive_interval(Some(interval));
209 }
210 if let Some(timeout) = config.idle_timeout {
211 if let Ok(idle) = quinn::IdleTimeout::try_from(timeout) {
212 transport.max_idle_timeout(Some(idle));
213 }
214 }
215 if !config.enable_datagrams {
217 transport.datagram_receive_buffer_size(None);
218 }
219 transport
220}
221
222pub struct QuicEndpoint {
224 inner: quinn::Endpoint,
225 local_endpoint: Endpoint,
226}
227
228impl QuicEndpoint {
229 pub async fn listen(endpoint: &Endpoint, config: ServerQuicConfig) -> Result<Self> {
231 let addr = SocketAddr::try_from(endpoint.clone())?;
232 let server_config = config.build()?;
233 let inner = quinn::Endpoint::server(server_config, addr)?;
234 let local_addr = inner.local_addr()?;
235 let local_endpoint = Endpoint::new_quic_addr(local_addr);
236 Ok(Self {
237 inner,
238 local_endpoint,
239 })
240 }
241
242 pub async fn dial(endpoint: &Endpoint, config: ClientQuicConfig) -> Result<QuicConn> {
244 let addr = SocketAddr::try_from(endpoint.clone())?;
245 let server_name = config.server_name.clone();
246 let client_config = config.build()?;
247 let bind_addr: SocketAddr = if addr.is_ipv6() {
249 "[::]:0".parse().unwrap()
250 } else {
251 "0.0.0.0:0".parse().unwrap()
252 };
253 let mut quinn_endpoint = quinn::Endpoint::client(bind_addr)?;
254 quinn_endpoint.set_default_client_config(client_config);
255
256 let connection = quinn_endpoint.connect(addr, &server_name)?.await?;
257 let peer_endpoint = Endpoint::new_quic_addr(connection.remote_address());
258 let local_addr = quinn_endpoint.local_addr()?;
259 let local_endpoint = Endpoint::new_quic_addr(local_addr);
260
261 Ok(QuicConn {
262 inner: connection,
263 peer_endpoint,
264 local_endpoint,
265 })
266 }
267
268 pub async fn accept(&self) -> Result<QuicConn> {
270 let incoming = self.inner.accept().await.ok_or(Error::ConnectionClosed)?;
271
272 let connection = incoming.await?;
273 let peer_endpoint = Endpoint::new_quic_addr(connection.remote_address());
274 let local_endpoint = self.local_endpoint.clone();
275
276 Ok(QuicConn {
277 inner: connection,
278 peer_endpoint,
279 local_endpoint,
280 })
281 }
282
283 pub fn local_endpoint(&self) -> Result<Endpoint> {
285 Ok(self.local_endpoint.clone())
286 }
287
288 pub fn close(&self, code: u32, reason: &[u8]) {
290 self.inner.close(quinn::VarInt::from_u32(code), reason);
291 }
292}
293
294pub struct QuicConn {
297 inner: quinn::Connection,
298 peer_endpoint: Endpoint,
299 local_endpoint: Endpoint,
300}
301
302impl QuicConn {
303 pub async fn open_bi(&self) -> Result<(QuicSendStream, QuicRecvStream)> {
305 let (send, recv) = self.inner.open_bi().await?;
306 Ok((send, recv))
307 }
308
309 pub async fn open_uni(&self) -> Result<QuicSendStream> {
311 let send = self.inner.open_uni().await?;
312 Ok(send)
313 }
314
315 pub async fn accept_bi(&self) -> Result<(QuicSendStream, QuicRecvStream)> {
317 let (send, recv) = self.inner.accept_bi().await?;
318 Ok((send, recv))
319 }
320
321 pub async fn accept_uni(&self) -> Result<QuicRecvStream> {
323 let recv = self.inner.accept_uni().await?;
324 Ok(recv)
325 }
326
327 pub fn send_datagram(&self, data: Bytes) -> Result<()> {
330 self.inner
331 .send_datagram(data.into_inner())
332 .map_err(|e| Error::QuicConfigError(e.to_string()))?;
333 Ok(())
334 }
335
336 pub async fn recv_datagram(&self) -> Result<Bytes> {
339 let data = self.inner.read_datagram().await?;
340 Ok(Bytes::from_inner(data))
341 }
342
343 pub fn max_datagram_size(&self) -> Option<usize> {
345 self.inner.max_datagram_size()
346 }
347
348 pub fn peer_endpoint(&self) -> Result<Endpoint> {
350 Ok(self.peer_endpoint.clone())
351 }
352
353 pub fn local_endpoint(&self) -> Result<Endpoint> {
355 Ok(self.local_endpoint.clone())
356 }
357
358 pub fn rtt(&self) -> Duration {
360 self.inner.rtt()
361 }
362
363 pub fn close(&self, code: u32, reason: &[u8]) {
365 self.inner.close(quinn::VarInt::from_u32(code), reason);
366 }
367
368 pub async fn closed(&self) -> quinn::ConnectionError {
370 self.inner.closed().await
371 }
372
373 pub fn inner(&self) -> &quinn::Connection {
375 &self.inner
376 }
377
378 pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
382 let any = self.inner.peer_identity()?;
383 any.downcast::<Vec<CertificateDer<'static>>>()
384 .ok()
385 .map(|b| *b)
386 }
387}
388
389impl StreamMux for QuicConn {
392 async fn open_stream(&self) -> Result<Box<dyn ByteStream>> {
393 let (send, recv) = self.open_bi().await?;
394 Ok(Box::new(QuicBiStream {
395 send,
396 recv,
397 peer_endpoint: self.peer_endpoint.clone(),
398 local_endpoint: self.local_endpoint.clone(),
399 }))
400 }
401
402 async fn accept_stream(&self) -> Result<Box<dyn ByteStream>> {
403 let (send, recv) = self.accept_bi().await?;
404 Ok(Box::new(QuicBiStream {
405 send,
406 recv,
407 peer_endpoint: self.peer_endpoint.clone(),
408 local_endpoint: self.local_endpoint.clone(),
409 }))
410 }
411
412 fn peer_endpoint(&self) -> Option<Endpoint> {
413 Some(self.peer_endpoint.clone())
414 }
415
416 fn local_endpoint(&self) -> Option<Endpoint> {
417 Some(self.local_endpoint.clone())
418 }
419}
420
421pub struct QuicBiStream {
423 send: QuicSendStream,
424 recv: QuicRecvStream,
425 peer_endpoint: Endpoint,
426 local_endpoint: Endpoint,
427}
428
429impl ByteStream for QuicBiStream {
430 fn peer_endpoint(&self) -> Option<Endpoint> {
431 Some(self.peer_endpoint.clone())
432 }
433 fn local_endpoint(&self) -> Option<Endpoint> {
434 Some(self.local_endpoint.clone())
435 }
436}
437
#[cfg(feature = "tokio")]
impl AsyncRead for QuicBiStream {
    /// Reads by delegating to the receiving half of the stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        AsyncRead::poll_read(Pin::new(&mut this.recv), cx, buf)
    }
}
451
#[cfg(feature = "smol")]
impl AsyncRead for QuicBiStream {
    /// Reads by delegating to the receiving half via the `futures` I/O trait.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let recv = Pin::new(&mut self.get_mut().recv);
        <QuicRecvStream as FutAsyncRead>::poll_read(recv, cx, buf)
    }
}
462
#[cfg(feature = "tokio")]
impl AsyncWrite for QuicBiStream {
    /// Writes by delegating to the sending half of the stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let send = Pin::new(&mut self.get_mut().send);
        <QuicSendStream as TokioAsyncWrite>::poll_write(send, cx, buf)
    }

    /// Flushes the sending half of the stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let send = Pin::new(&mut self.get_mut().send);
        <QuicSendStream as TokioAsyncWrite>::poll_flush(send, cx)
    }

    /// Shuts down the sending half; the receiving half is unaffected.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let send = Pin::new(&mut self.get_mut().send);
        <QuicSendStream as TokioAsyncWrite>::poll_shutdown(send, cx)
    }
}
481
#[cfg(feature = "smol")]
impl AsyncWrite for QuicBiStream {
    /// Writes by delegating to the sending half via the `futures` I/O trait.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let send = Pin::new(&mut self.get_mut().send);
        <QuicSendStream as FutAsyncWrite>::poll_write(send, cx, buf)
    }

    /// Flushes the sending half of the stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let send = Pin::new(&mut self.get_mut().send);
        <QuicSendStream as FutAsyncWrite>::poll_flush(send, cx)
    }

    /// Closes the sending half; the receiving half is unaffected.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let send = Pin::new(&mut self.get_mut().send);
        <QuicSendStream as FutAsyncWrite>::poll_close(send, cx)
    }
}