Skip to main content

karyon_net/transports/
quic.rs

1use std::{
2    io,
3    net::SocketAddr,
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use rustls_pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::async_rustls::rustls;
13
14use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
15
16#[cfg(feature = "tokio")]
17use tokio::io::{AsyncWrite as TokioAsyncWrite, ReadBuf};
18
19#[cfg(feature = "smol")]
20use futures_util::io::{AsyncRead as FutAsyncRead, AsyncWrite as FutAsyncWrite};
21
22use crate::{stream::ByteStream, stream_mux::StreamMux, Bytes, Endpoint, Error, Result};
23
/// Default size, in bytes, of one read from a QUIC stream (1 MiB).
const DEFAULT_READ_CHUNK_SIZE: usize = 1 << 20;
26
27/// A QUIC send stream. Re-exported from quinn for use by higher layers.
28pub type QuicSendStream = quinn::SendStream;
29
30/// A QUIC receive stream. Re-exported from quinn for use by higher layers.
31pub type QuicRecvStream = quinn::RecvStream;
32
/// QUIC transport parameters used by both the server-side and client-side
/// configuration types in this module.
///
/// Derives `Debug` so a configuration can be logged or shown in assertion
/// messages.
#[derive(Clone, Debug)]
pub struct QuicConfig {
    /// Maximum number of concurrent bidirectional streams.
    pub max_bi_streams: u64,
    /// Maximum number of concurrent unidirectional streams.
    pub max_uni_streams: u64,
    /// Keep-alive interval. `None` disables keep-alives.
    pub keep_alive_interval: Option<Duration>,
    /// Idle timeout. `None` leaves the underlying transport's own default
    /// in place (this module only overrides it when a value is given).
    pub idle_timeout: Option<Duration>,
    /// Whether datagrams are enabled. When `false`, the datagram receive
    /// buffer is explicitly disabled on the transport.
    pub enable_datagrams: bool,
    /// Read chunk size for stream reads, in bytes.
    pub read_chunk_size: usize,
}
49
50impl Default for QuicConfig {
51    fn default() -> Self {
52        Self {
53            max_bi_streams: 100,
54            max_uni_streams: 100,
55            keep_alive_interval: Some(Duration::from_secs(5)),
56            idle_timeout: Some(Duration::from_secs(30)),
57            enable_datagrams: false,
58            read_chunk_size: DEFAULT_READ_CHUNK_SIZE,
59        }
60    }
61}
62
63/// Server-side QUIC configuration. Build from either a cert chain +
64/// private key (`new`) or a pre-built rustls config (`from_rustls`).
65#[derive(Clone)]
66pub struct ServerQuicConfig {
67    source: ServerSource,
68    config: QuicConfig,
69}
70
71#[derive(Clone)]
72enum ServerSource {
73    Certs {
74        cert_chain: Vec<CertificateDer<'static>>,
75        private_key: Arc<PrivateKeyDer<'static>>,
76    },
77    Rustls(rustls::ServerConfig),
78}
79
80impl ServerQuicConfig {
81    /// Create a config from a cert chain + private key.
82    pub fn new(
83        cert_chain: Vec<CertificateDer<'static>>,
84        private_key: Arc<PrivateKeyDer<'static>>,
85    ) -> Self {
86        Self {
87            source: ServerSource::Certs {
88                cert_chain,
89                private_key,
90            },
91            config: QuicConfig::default(),
92        }
93    }
94
95    /// Create a config from a pre-built rustls `ServerConfig` (for
96    /// custom verifiers, client-auth, etc).
97    pub fn from_rustls(rustls_config: rustls::ServerConfig) -> Self {
98        Self {
99            source: ServerSource::Rustls(rustls_config),
100            config: QuicConfig::default(),
101        }
102    }
103
104    /// Override the QUIC transport parameters.
105    pub fn with_config(mut self, config: QuicConfig) -> Self {
106        self.config = config;
107        self
108    }
109
110    pub(crate) fn build(self) -> Result<quinn::ServerConfig> {
111        let mut server_config = match self.source {
112            ServerSource::Certs {
113                cert_chain,
114                private_key,
115            } => quinn::ServerConfig::with_single_cert(cert_chain, private_key.clone_key())
116                .map_err(|e| Error::TlsConfig(e.to_string()))?,
117            ServerSource::Rustls(rustls_config) => {
118                let quic_config = quinn::crypto::rustls::QuicServerConfig::try_from(rustls_config)
119                    .map_err(|e| Error::QuicConfigError(e.to_string()))?;
120                quinn::ServerConfig::with_crypto(Arc::new(quic_config))
121            }
122        };
123        server_config.transport_config(Arc::new(build_transport_config(&self.config)));
124        Ok(server_config)
125    }
126}
127
128/// Client-side QUIC configuration. Build from either a root cert list
129/// (`new`) or a pre-built rustls config (`from_rustls`).
130#[derive(Clone)]
131pub struct ClientQuicConfig {
132    source: ClientSource,
133    server_name: String,
134    config: QuicConfig,
135}
136
137#[derive(Clone)]
138enum ClientSource {
139    Roots(Vec<CertificateDer<'static>>),
140    Rustls(Box<rustls::ClientConfig>),
141}
142
143impl ClientQuicConfig {
144    /// Create a config from a list of trusted root certs + server name.
145    pub fn new(root_certs: Vec<CertificateDer<'static>>, server_name: impl Into<String>) -> Self {
146        Self {
147            source: ClientSource::Roots(root_certs),
148            server_name: server_name.into(),
149            config: QuicConfig::default(),
150        }
151    }
152
153    /// Create a config from a pre-built rustls `ClientConfig` + server name
154    /// (for custom verifiers, client-auth certs, etc).
155    pub fn from_rustls(
156        rustls_config: rustls::ClientConfig,
157        server_name: impl Into<String>,
158    ) -> Self {
159        Self {
160            source: ClientSource::Rustls(Box::new(rustls_config)),
161            server_name: server_name.into(),
162            config: QuicConfig::default(),
163        }
164    }
165
166    /// Override the QUIC transport parameters.
167    pub fn with_config(mut self, config: QuicConfig) -> Self {
168        self.config = config;
169        self
170    }
171
172    pub fn server_name(&self) -> &str {
173        &self.server_name
174    }
175
176    pub(crate) fn build(self) -> Result<quinn::ClientConfig> {
177        let mut client_config = match self.source {
178            ClientSource::Roots(root_certs) => {
179                let mut root_store = rustls::RootCertStore::empty();
180                for cert in root_certs {
181                    root_store
182                        .add(cert)
183                        .map_err(|e| Error::TlsConfig(e.to_string()))?;
184                }
185                quinn::ClientConfig::with_root_certificates(Arc::new(root_store))
186                    .map_err(|e| Error::TlsConfig(e.to_string()))?
187            }
188            ClientSource::Rustls(rustls_config) => {
189                let quic_config = quinn::crypto::rustls::QuicClientConfig::try_from(*rustls_config)
190                    .map_err(|e| Error::QuicConfigError(e.to_string()))?;
191                quinn::ClientConfig::new(Arc::new(quic_config))
192            }
193        };
194        client_config.transport_config(Arc::new(build_transport_config(&self.config)));
195        Ok(client_config)
196    }
197}
198
199fn build_transport_config(config: &QuicConfig) -> quinn::TransportConfig {
200    let mut transport = quinn::TransportConfig::default();
201    transport.max_concurrent_bidi_streams(
202        quinn::VarInt::from_u64(config.max_bi_streams).unwrap_or(quinn::VarInt::from_u32(100u32)),
203    );
204    transport.max_concurrent_uni_streams(
205        quinn::VarInt::from_u64(config.max_uni_streams).unwrap_or(quinn::VarInt::from_u32(100u32)),
206    );
207    if let Some(interval) = config.keep_alive_interval {
208        transport.keep_alive_interval(Some(interval));
209    }
210    if let Some(timeout) = config.idle_timeout {
211        if let Ok(idle) = quinn::IdleTimeout::try_from(timeout) {
212            transport.max_idle_timeout(Some(idle));
213        }
214    }
215    // Disable datagrams explicitly when the caller opts out.
216    if !config.enable_datagrams {
217        transport.datagram_receive_buffer_size(None);
218    }
219    transport
220}
221
222/// A QUIC endpoint that can listen for and initiate connections.
223pub struct QuicEndpoint {
224    inner: quinn::Endpoint,
225    local_endpoint: Endpoint,
226}
227
228impl QuicEndpoint {
229    /// Bind to a local address and start listening with the given server config.
230    pub async fn listen(endpoint: &Endpoint, config: ServerQuicConfig) -> Result<Self> {
231        let addr = SocketAddr::try_from(endpoint.clone())?;
232        let server_config = config.build()?;
233        let inner = quinn::Endpoint::server(server_config, addr)?;
234        let local_addr = inner.local_addr()?;
235        let local_endpoint = Endpoint::new_quic_addr(local_addr);
236        Ok(Self {
237            inner,
238            local_endpoint,
239        })
240    }
241
242    /// Connect to a remote QUIC endpoint.
243    pub async fn dial(endpoint: &Endpoint, config: ClientQuicConfig) -> Result<QuicConn> {
244        let addr = SocketAddr::try_from(endpoint.clone())?;
245        let server_name = config.server_name.clone();
246        let client_config = config.build()?;
247        // Bind the client socket in the same address family as the target.
248        let bind_addr: SocketAddr = if addr.is_ipv6() {
249            "[::]:0".parse().unwrap()
250        } else {
251            "0.0.0.0:0".parse().unwrap()
252        };
253        let mut quinn_endpoint = quinn::Endpoint::client(bind_addr)?;
254        quinn_endpoint.set_default_client_config(client_config);
255
256        let connection = quinn_endpoint.connect(addr, &server_name)?.await?;
257        let peer_endpoint = Endpoint::new_quic_addr(connection.remote_address());
258        let local_addr = quinn_endpoint.local_addr()?;
259        let local_endpoint = Endpoint::new_quic_addr(local_addr);
260
261        Ok(QuicConn {
262            inner: connection,
263            peer_endpoint,
264            local_endpoint,
265        })
266    }
267
268    /// Accept an incoming QUIC connection.
269    pub async fn accept(&self) -> Result<QuicConn> {
270        let incoming = self.inner.accept().await.ok_or(Error::ConnectionClosed)?;
271
272        let connection = incoming.await?;
273        let peer_endpoint = Endpoint::new_quic_addr(connection.remote_address());
274        let local_endpoint = self.local_endpoint.clone();
275
276        Ok(QuicConn {
277            inner: connection,
278            peer_endpoint,
279            local_endpoint,
280        })
281    }
282
283    /// Returns the local endpoint.
284    pub fn local_endpoint(&self) -> Result<Endpoint> {
285        Ok(self.local_endpoint.clone())
286    }
287
288    /// Close the endpoint.
289    pub fn close(&self, code: u32, reason: &[u8]) {
290        self.inner.close(quinn::VarInt::from_u32(code), reason);
291    }
292}
293
294/// A QUIC connection. Manages streams and datagrams.
295/// This is NOT a single read/write channel — it is a stream factory.
296pub struct QuicConn {
297    inner: quinn::Connection,
298    peer_endpoint: Endpoint,
299    local_endpoint: Endpoint,
300}
301
302impl QuicConn {
303    /// Open a new bidirectional stream.
304    pub async fn open_bi(&self) -> Result<(QuicSendStream, QuicRecvStream)> {
305        let (send, recv) = self.inner.open_bi().await?;
306        Ok((send, recv))
307    }
308
309    /// Open a new unidirectional (send-only) stream.
310    pub async fn open_uni(&self) -> Result<QuicSendStream> {
311        let send = self.inner.open_uni().await?;
312        Ok(send)
313    }
314
315    /// Accept a bidirectional stream opened by the peer.
316    pub async fn accept_bi(&self) -> Result<(QuicSendStream, QuicRecvStream)> {
317        let (send, recv) = self.inner.accept_bi().await?;
318        Ok((send, recv))
319    }
320
321    /// Accept a unidirectional (receive-only) stream from the peer.
322    pub async fn accept_uni(&self) -> Result<QuicRecvStream> {
323        let recv = self.inner.accept_uni().await?;
324        Ok(recv)
325    }
326
327    /// Send an unreliable datagram over the connection. Zero-copy —
328    /// ownership of the `Bytes` allocation is passed to quinn.
329    pub fn send_datagram(&self, data: Bytes) -> Result<()> {
330        self.inner
331            .send_datagram(data.into_inner())
332            .map_err(|e| Error::QuicConfigError(e.to_string()))?;
333        Ok(())
334    }
335
336    /// Receive an unreliable datagram. Zero-copy — wraps the allocation
337    /// returned by quinn.
338    pub async fn recv_datagram(&self) -> Result<Bytes> {
339        let data = self.inner.read_datagram().await?;
340        Ok(Bytes::from_inner(data))
341    }
342
343    /// Maximum datagram size the peer supports, or None if unsupported.
344    pub fn max_datagram_size(&self) -> Option<usize> {
345        self.inner.max_datagram_size()
346    }
347
348    /// Remote peer's address.
349    pub fn peer_endpoint(&self) -> Result<Endpoint> {
350        Ok(self.peer_endpoint.clone())
351    }
352
353    /// Local address.
354    pub fn local_endpoint(&self) -> Result<Endpoint> {
355        Ok(self.local_endpoint.clone())
356    }
357
358    /// Current round-trip time estimate.
359    pub fn rtt(&self) -> Duration {
360        self.inner.rtt()
361    }
362
363    /// Close the connection gracefully.
364    pub fn close(&self, code: u32, reason: &[u8]) {
365        self.inner.close(quinn::VarInt::from_u32(code), reason);
366    }
367
368    /// Wait for the connection to be closed (by us or the peer).
369    pub async fn closed(&self) -> quinn::ConnectionError {
370        self.inner.closed().await
371    }
372
373    /// Returns a reference to the inner quinn connection.
374    pub fn inner(&self) -> &quinn::Connection {
375        &self.inner
376    }
377
378    /// Peer certificate chain from the QUIC TLS handshake. Quinn returns
379    /// a type-erased `Box<dyn Any>`; for rustls-based QUIC (which is what
380    /// karyon uses) it downcasts to `Vec<CertificateDer<'static>>`.
381    pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
382        let any = self.inner.peer_identity()?;
383        any.downcast::<Vec<CertificateDer<'static>>>()
384            .ok()
385            .map(|b| *b)
386    }
387}
388
389// -- StreamMux impl --
390
391impl StreamMux for QuicConn {
392    async fn open_stream(&self) -> Result<Box<dyn ByteStream>> {
393        let (send, recv) = self.open_bi().await?;
394        Ok(Box::new(QuicBiStream {
395            send,
396            recv,
397            peer_endpoint: self.peer_endpoint.clone(),
398            local_endpoint: self.local_endpoint.clone(),
399        }))
400    }
401
402    async fn accept_stream(&self) -> Result<Box<dyn ByteStream>> {
403        let (send, recv) = self.accept_bi().await?;
404        Ok(Box::new(QuicBiStream {
405            send,
406            recv,
407            peer_endpoint: self.peer_endpoint.clone(),
408            local_endpoint: self.local_endpoint.clone(),
409        }))
410    }
411
412    fn peer_endpoint(&self) -> Option<Endpoint> {
413        Some(self.peer_endpoint.clone())
414    }
415
416    fn local_endpoint(&self) -> Option<Endpoint> {
417        Some(self.local_endpoint.clone())
418    }
419}
420
421/// Single QUIC bidirectional stream as a ByteStream.
422pub struct QuicBiStream {
423    send: QuicSendStream,
424    recv: QuicRecvStream,
425    peer_endpoint: Endpoint,
426    local_endpoint: Endpoint,
427}
428
429impl ByteStream for QuicBiStream {
430    fn peer_endpoint(&self) -> Option<Endpoint> {
431        Some(self.peer_endpoint.clone())
432    }
433    fn local_endpoint(&self) -> Option<Endpoint> {
434        Some(self.local_endpoint.clone())
435    }
436}
437
438// Quinn streams use tokio IO traits natively.
439// For smol builds, delegate manually via poll methods.
440
#[cfg(feature = "tokio")]
impl AsyncRead for QuicBiStream {
    /// Reads by delegating to the stream's receive half.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let stream = self.get_mut();
        Pin::new(&mut stream.recv).poll_read(cx, buf)
    }
}
451
#[cfg(feature = "smol")]
impl AsyncRead for QuicBiStream {
    /// Reads by delegating to the receive half through the futures
    /// `AsyncRead` trait.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let receiver = Pin::new(&mut self.get_mut().recv);
        FutAsyncRead::poll_read(receiver, cx, buf)
    }
}
462
#[cfg(feature = "tokio")]
impl AsyncWrite for QuicBiStream {
    /// Writes by delegating to the stream's send half.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let sender = Pin::new(&mut self.get_mut().send);
        TokioAsyncWrite::poll_write(sender, cx, buf)
    }

    /// Flushes the send half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sender = Pin::new(&mut self.get_mut().send);
        TokioAsyncWrite::poll_flush(sender, cx)
    }

    /// Shuts down the send half.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sender = Pin::new(&mut self.get_mut().send);
        TokioAsyncWrite::poll_shutdown(sender, cx)
    }
}
481
#[cfg(feature = "smol")]
impl AsyncWrite for QuicBiStream {
    /// Writes by delegating to the send half through the futures
    /// `AsyncWrite` trait.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let sender = Pin::new(&mut self.get_mut().send);
        FutAsyncWrite::poll_write(sender, cx, buf)
    }

    /// Flushes the send half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sender = Pin::new(&mut self.get_mut().send);
        FutAsyncWrite::poll_flush(sender, cx)
    }

    /// Closes the send half.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sender = Pin::new(&mut self.get_mut().send);
        FutAsyncWrite::poll_close(sender, cx)
    }
}