// karyon_net/transports/tls.rs

1use std::{net::SocketAddr, sync::Arc};
2
3use async_trait::async_trait;
4use futures_util::SinkExt;
5use rustls_pki_types as pki_types;
6
7use karyon_core::async_runtime::{
8    io::{split, ReadHalf, WriteHalf},
9    lock::Mutex,
10    net::{TcpListener, TcpStream},
11};
12
13#[cfg(feature = "tls")]
14use crate::async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
15
16use crate::{
17    codec::Codec,
18    connection::{Conn, Connection, ToConn},
19    endpoint::Endpoint,
20    listener::{ConnListener, Listener, ToListener},
21    stream::{ReadStream, WriteStream},
22    Result,
23};
24
25use super::tcp::TcpConfig;
26
27/// TLS configuration
28#[derive(Clone)]
29pub struct ServerTlsConfig {
30    pub tcp_config: TcpConfig,
31    pub server_config: rustls::ServerConfig,
32}
33
34#[derive(Clone)]
35pub struct ClientTlsConfig {
36    pub tcp_config: TcpConfig,
37    pub client_config: rustls::ClientConfig,
38    pub dns_name: String,
39}
40
41/// TLS network connection implementation of the [`Connection`] trait.
42pub struct TlsConn<C> {
43    read_stream: Mutex<ReadStream<ReadHalf<TlsStream<TcpStream>>, C>>,
44    write_stream: Mutex<WriteStream<WriteHalf<TlsStream<TcpStream>>, C>>,
45    peer_endpoint: Endpoint,
46    local_endpoint: Endpoint,
47}
48
49impl<C> TlsConn<C>
50where
51    C: Codec + Clone,
52{
53    /// Creates a new TlsConn
54    pub fn new(
55        conn: TlsStream<TcpStream>,
56        codec: C,
57        peer_endpoint: Endpoint,
58        local_endpoint: Endpoint,
59    ) -> Self {
60        let (read, write) = split(conn);
61        let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
62        let write_stream = Mutex::new(WriteStream::new(write, codec));
63        Self {
64            read_stream,
65            write_stream,
66            peer_endpoint,
67            local_endpoint,
68        }
69    }
70}
71
72#[async_trait]
73impl<C, E> Connection for TlsConn<C>
74where
75    C: Clone + Codec<Error = E>,
76{
77    type Message = C::Message;
78    type Error = E;
79    fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
80        Ok(self.peer_endpoint.clone())
81    }
82
83    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
84        Ok(self.local_endpoint.clone())
85    }
86
87    async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
88        self.read_stream.lock().await.recv().await
89    }
90
91    async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
92        self.write_stream.lock().await.send(msg).await
93    }
94}
95
96/// Connects to the given TLS address and port.
97pub async fn dial<C>(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result<TlsConn<C>>
98where
99    C: Codec + Clone,
100{
101    let addr = SocketAddr::try_from(endpoint.clone())?;
102
103    let connector = TlsConnector::from(Arc::new(config.client_config.clone()));
104
105    let socket = TcpStream::connect(addr).await?;
106    socket.set_nodelay(config.tcp_config.nodelay)?;
107
108    let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
109    let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
110
111    let altname = pki_types::ServerName::try_from(config.dns_name.clone())?;
112    let conn = connector.connect(altname, socket).await?;
113    Ok(TlsConn::new(
114        TlsStream::Client(conn),
115        codec,
116        peer_endpoint,
117        local_endpoint,
118    ))
119}
120
121/// Tls network listener implementation of the `Listener` [`ConnListener`] trait.
122pub struct TlsListener<C> {
123    inner: TcpListener,
124    acceptor: TlsAcceptor,
125    config: ServerTlsConfig,
126    codec: C,
127}
128
129impl<C> TlsListener<C>
130where
131    C: Codec + Clone,
132{
133    pub fn new(
134        acceptor: TlsAcceptor,
135        listener: TcpListener,
136        config: ServerTlsConfig,
137        codec: C,
138    ) -> Self {
139        Self {
140            inner: listener,
141            acceptor,
142            config: config.clone(),
143            codec,
144        }
145    }
146}
147
148#[async_trait]
149impl<C, E> ConnListener for TlsListener<C>
150where
151    C: Clone + Codec<Error = E> + 'static,
152    E: From<std::io::Error>,
153{
154    type Message = C::Message;
155    type Error = E;
156    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
157        Ok(Endpoint::new_tls_addr(self.inner.local_addr()?))
158    }
159
160    async fn accept(&self) -> std::result::Result<Conn<C::Message, E>, Self::Error> {
161        let (socket, _) = self.inner.accept().await?;
162        socket.set_nodelay(self.config.tcp_config.nodelay)?;
163
164        let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
165        let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
166
167        let conn = self.acceptor.accept(socket).await?;
168        Ok(Box::new(TlsConn::new(
169            TlsStream::Server(conn),
170            self.codec.clone(),
171            peer_endpoint,
172            local_endpoint,
173        )))
174    }
175}
176
177/// Listens on the given TLS address and port.
178pub async fn listen<C>(
179    endpoint: &Endpoint,
180    config: ServerTlsConfig,
181    codec: C,
182) -> Result<TlsListener<C>>
183where
184    C: Clone + Codec,
185{
186    let addr = SocketAddr::try_from(endpoint.clone())?;
187    let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone()));
188    let listener = TcpListener::bind(addr).await?;
189    Ok(TlsListener::new(acceptor, listener, config, codec))
190}
191
192impl<C, E> From<TlsListener<C>> for Listener<C::Message, E>
193where
194    C: Codec<Error = E> + Clone + 'static,
195    E: From<std::io::Error>,
196{
197    fn from(listener: TlsListener<C>) -> Self {
198        Box::new(listener)
199    }
200}
201
202impl<C, E> ToConn for TlsConn<C>
203where
204    C: Codec<Error = E> + Clone + 'static,
205{
206    type Message = C::Message;
207    type Error = E;
208    fn to_conn(self) -> Conn<Self::Message, Self::Error> {
209        Box::new(self)
210    }
211}
212
213impl<C, E> ToListener for TlsListener<C>
214where
215    C: Clone + Codec<Error = E> + 'static,
216    E: From<std::io::Error>,
217{
218    type Message = C::Message;
219    type Error = E;
220    fn to_listener(self) -> Listener<Self::Message, Self::Error> {
221        Box::new(self)
222    }
223}