// karyon_net/layers/tls.rs

1use std::{
2    io,
3    net::{IpAddr, Ipv4Addr},
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use rustls_pki_types::{self as pki_types, CertificateDer};
10
11use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
12
13#[cfg(feature = "tokio")]
14use tokio::io::ReadBuf;
15
16use crate::{
17    async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream},
18    layer::{ClientLayer, ServerLayer},
19    transports::tcp::TcpListener,
20    Addr, ByteStream, Endpoint, Error, Result,
21};
22
23/// TLS client config.
24#[derive(Clone)]
25pub struct ClientTlsConfig {
26    pub client_config: rustls::ClientConfig,
27    pub dns_name: String,
28}
29
30/// TLS server config.
31#[derive(Clone)]
32pub struct ServerTlsConfig {
33    pub server_config: rustls::ServerConfig,
34}
35
36/// TLS middleware layer. Implements both `ClientLayer` and `ServerLayer`.
37///
38/// Wraps any `ByteStream` with TLS encryption.
39///
40/// # Example
41///
42/// ```no_run
43/// use karyon_net::{tcp, ClientLayer, Endpoint};
44/// use karyon_net::tls::{TlsLayer, ClientTlsConfig};
45///
46/// async {
47///     let ep: Endpoint = "tcp://127.0.0.1:443".parse().unwrap();
48///     let stream = tcp::connect(&ep, Default::default()).await.unwrap();
49///     // let tls_stream = ClientLayer::handshake(
50///     //     &TlsLayer::client(config), stream
51///     // ).await.unwrap();
52/// };
53/// ```
54#[derive(Clone)]
55pub struct TlsLayer {
56    client_config: Option<ClientTlsConfig>,
57    server_config: Option<ServerTlsConfig>,
58}
59
60impl TlsLayer {
61    /// Create a TLS layer for client connections.
62    pub fn client(config: ClientTlsConfig) -> Self {
63        Self {
64            client_config: Some(config),
65            server_config: None,
66        }
67    }
68
69    /// Create a TLS layer for server connections.
70    pub fn server(config: ServerTlsConfig) -> Self {
71        Self {
72            client_config: None,
73            server_config: Some(config),
74        }
75    }
76}
77
78impl ClientLayer<Box<dyn ByteStream>, Box<dyn ByteStream>> for TlsLayer {
79    async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<Box<dyn ByteStream>> {
80        let config = self.client_config.as_ref().ok_or_else(|| {
81            Error::IO(io::Error::new(
82                io::ErrorKind::InvalidInput,
83                "missing TLS client config",
84            ))
85        })?;
86
87        let connector = TlsConnector::from(Arc::new(config.client_config.clone()));
88        let dns = pki_types::ServerName::try_from(config.dns_name.clone())?;
89
90        let peer = stream.peer_endpoint();
91        let local = stream.local_endpoint();
92        let tls = connector.connect(dns, stream).await?;
93
94        Ok(Box::new(TlsByteStream {
95            inner: TlsStream::Client(tls),
96            peer_endpoint: peer,
97            local_endpoint: local,
98        }))
99    }
100}
101
102impl ServerLayer<Box<dyn ByteStream>, Box<dyn ByteStream>> for TlsLayer {
103    async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<Box<dyn ByteStream>> {
104        let config = self.server_config.as_ref().ok_or_else(|| {
105            Error::IO(io::Error::new(
106                io::ErrorKind::InvalidInput,
107                "missing TLS server config",
108            ))
109        })?;
110
111        let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone()));
112
113        let peer = stream.peer_endpoint();
114        let local = stream.local_endpoint();
115        let tls = acceptor.accept(stream).await?;
116
117        Ok(Box::new(TlsByteStream {
118            inner: TlsStream::Server(tls),
119            peer_endpoint: peer,
120            local_endpoint: local,
121        }))
122    }
123}
124
125/// TLS stream wrapping any `ByteStream`.
126pub struct TlsByteStream {
127    inner: TlsStream<Box<dyn ByteStream>>,
128    peer_endpoint: Option<Endpoint>,
129    local_endpoint: Option<Endpoint>,
130}
131
132impl ByteStream for TlsByteStream {
133    fn peer_endpoint(&self) -> Option<Endpoint> {
134        self.peer_endpoint.clone()
135    }
136    fn local_endpoint(&self) -> Option<Endpoint> {
137        self.local_endpoint.clone()
138    }
139    fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
140        // TlsStream::get_ref() exposes the rustls connection state;
141        // peer_certificates() returns whatever the peer presented during
142        // the handshake (if any).
143        let (_, state) = self.inner.get_ref();
144        state
145            .peer_certificates()
146            .map(|certs| certs.iter().map(|c| c.clone().into_owned()).collect())
147    }
148}
149
150// -- AsyncRead / AsyncWrite delegation --
151
#[cfg(feature = "smol")]
impl AsyncRead for TlsByteStream {
    /// Delegates reads to the inner TLS session.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
162
#[cfg(feature = "tokio")]
impl AsyncRead for TlsByteStream {
    /// Delegates reads to the inner TLS session.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
173
#[cfg(feature = "smol")]
impl AsyncWrite for TlsByteStream {
    /// Delegates writes to the inner TLS session.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwards vectored writes to the inner TLS session. Without this, the
    /// trait's default would write only the first non-empty buffer per call
    /// and bypass any vectored support in the inner stream.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Delegates flushing to the inner TLS session.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Delegates closing to the inner TLS session.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}
192
#[cfg(feature = "tokio")]
impl AsyncWrite for TlsByteStream {
    /// Delegates writes to the inner TLS session.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwards vectored writes to the inner TLS session. Without this, the
    /// trait's default would write only the first non-empty buffer per call
    /// and bypass any vectored support in the inner stream.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Reports the inner TLS session's vectored-write capability, so callers
    /// see the same answer as for the wrapped stream (the default is `false`).
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Delegates flushing to the inner TLS session.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Delegates shutdown to the inner TLS session.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}
211
212/// TLS listener. Wraps a TCP listener with TLS acceptance.
213pub struct TlsListener {
214    inner: TcpListener,
215    layer: TlsLayer,
216}
217
218impl TlsListener {
219    /// Create a TLS listener from a TCP listener and TLS config.
220    pub fn new(inner: TcpListener, config: ServerTlsConfig) -> Self {
221        Self {
222            inner,
223            layer: TlsLayer::server(config),
224        }
225    }
226
227    /// Accept a new TLS connection.
228    pub async fn accept(&self) -> Result<Box<dyn ByteStream>> {
229        let stream = self.inner.accept().await?;
230        ServerLayer::handshake(&self.layer, stream).await
231    }
232
233    /// Local endpoint this listener is bound to.
234    pub fn local_endpoint(&self) -> Result<Endpoint> {
235        let tcp_ep = self.inner.local_endpoint()?;
236        let port = tcp_ep.port().unwrap_or(0);
237        let addr = tcp_ep
238            .addr()
239            .unwrap_or(Addr::Ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
240        Ok(Endpoint::Tls(addr, port))
241    }
242}