1use std::{
2 io,
3 net::{IpAddr, Ipv4Addr},
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use rustls_pki_types::{self as pki_types, CertificateDer};
10
11use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
12
13#[cfg(feature = "tokio")]
14use tokio::io::ReadBuf;
15
16use crate::{
17 async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream},
18 layer::{ClientLayer, ServerLayer},
19 transports::tcp::TcpListener,
20 Addr, ByteStream, Endpoint, Error, Result,
21};
22
23#[derive(Clone)]
25pub struct ClientTlsConfig {
26 pub client_config: rustls::ClientConfig,
27 pub dns_name: String,
28}
29
30#[derive(Clone)]
32pub struct ServerTlsConfig {
33 pub server_config: rustls::ServerConfig,
34}
35
36#[derive(Clone)]
55pub struct TlsLayer {
56 client_config: Option<ClientTlsConfig>,
57 server_config: Option<ServerTlsConfig>,
58}
59
60impl TlsLayer {
61 pub fn client(config: ClientTlsConfig) -> Self {
63 Self {
64 client_config: Some(config),
65 server_config: None,
66 }
67 }
68
69 pub fn server(config: ServerTlsConfig) -> Self {
71 Self {
72 client_config: None,
73 server_config: Some(config),
74 }
75 }
76}
77
78impl ClientLayer<Box<dyn ByteStream>, Box<dyn ByteStream>> for TlsLayer {
79 async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<Box<dyn ByteStream>> {
80 let config = self.client_config.as_ref().ok_or_else(|| {
81 Error::IO(io::Error::new(
82 io::ErrorKind::InvalidInput,
83 "missing TLS client config",
84 ))
85 })?;
86
87 let connector = TlsConnector::from(Arc::new(config.client_config.clone()));
88 let dns = pki_types::ServerName::try_from(config.dns_name.clone())?;
89
90 let peer = stream.peer_endpoint();
91 let local = stream.local_endpoint();
92 let tls = connector.connect(dns, stream).await?;
93
94 Ok(Box::new(TlsByteStream {
95 inner: TlsStream::Client(tls),
96 peer_endpoint: peer,
97 local_endpoint: local,
98 }))
99 }
100}
101
102impl ServerLayer<Box<dyn ByteStream>, Box<dyn ByteStream>> for TlsLayer {
103 async fn handshake(&self, stream: Box<dyn ByteStream>) -> Result<Box<dyn ByteStream>> {
104 let config = self.server_config.as_ref().ok_or_else(|| {
105 Error::IO(io::Error::new(
106 io::ErrorKind::InvalidInput,
107 "missing TLS server config",
108 ))
109 })?;
110
111 let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone()));
112
113 let peer = stream.peer_endpoint();
114 let local = stream.local_endpoint();
115 let tls = acceptor.accept(stream).await?;
116
117 Ok(Box::new(TlsByteStream {
118 inner: TlsStream::Server(tls),
119 peer_endpoint: peer,
120 local_endpoint: local,
121 }))
122 }
123}
124
125pub struct TlsByteStream {
127 inner: TlsStream<Box<dyn ByteStream>>,
128 peer_endpoint: Option<Endpoint>,
129 local_endpoint: Option<Endpoint>,
130}
131
132impl ByteStream for TlsByteStream {
133 fn peer_endpoint(&self) -> Option<Endpoint> {
134 self.peer_endpoint.clone()
135 }
136 fn local_endpoint(&self) -> Option<Endpoint> {
137 self.local_endpoint.clone()
138 }
139 fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
140 let (_, state) = self.inner.get_ref();
144 state
145 .peer_certificates()
146 .map(|certs| certs.iter().map(|c| c.clone().into_owned()).collect())
147 }
148}
149
#[cfg(feature = "smol")]
impl AsyncRead for TlsByteStream {
    /// Reads application data by delegating to the wrapped TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
162
#[cfg(feature = "tokio")]
impl AsyncRead for TlsByteStream {
    /// Reads application data into `buf` by delegating to the wrapped TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
173
#[cfg(feature = "smol")]
impl AsyncWrite for TlsByteStream {
    /// Writes application data by delegating to the wrapped TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    /// Flushes by delegating to the wrapped TLS stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Closes by delegating to the wrapped TLS stream's `poll_close`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_close(cx)
    }
}
192
#[cfg(feature = "tokio")]
impl AsyncWrite for TlsByteStream {
    /// Writes application data by delegating to the wrapped TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    /// Flushes by delegating to the wrapped TLS stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Shuts down by delegating to the wrapped TLS stream's `poll_shutdown`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}
211
212pub struct TlsListener {
214 inner: TcpListener,
215 layer: TlsLayer,
216}
217
218impl TlsListener {
219 pub fn new(inner: TcpListener, config: ServerTlsConfig) -> Self {
221 Self {
222 inner,
223 layer: TlsLayer::server(config),
224 }
225 }
226
227 pub async fn accept(&self) -> Result<Box<dyn ByteStream>> {
229 let stream = self.inner.accept().await?;
230 ServerLayer::handshake(&self.layer, stream).await
231 }
232
233 pub fn local_endpoint(&self) -> Result<Endpoint> {
235 let tcp_ep = self.inner.local_endpoint()?;
236 let port = tcp_ep.port().unwrap_or(0);
237 let addr = tcp_ep
238 .addr()
239 .unwrap_or(Addr::Ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
240 Ok(Endpoint::Tls(addr, port))
241 }
242}