1use std::{net::SocketAddr, sync::Arc};
2
3use async_trait::async_trait;
4use futures_util::SinkExt;
5use rustls_pki_types as pki_types;
6
7use karyon_core::async_runtime::{
8 io::{split, ReadHalf, WriteHalf},
9 lock::Mutex,
10 net::{TcpListener, TcpStream},
11};
12
13#[cfg(feature = "tls")]
14use crate::async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
15
16use crate::{
17 codec::Codec,
18 connection::{Conn, Connection, ToConn},
19 endpoint::Endpoint,
20 listener::{ConnListener, Listener, ToListener},
21 stream::{ReadStream, WriteStream},
22 Result,
23};
24
25use super::tcp::TcpConfig;
26
27#[derive(Clone)]
29pub struct ServerTlsConfig {
30 pub tcp_config: TcpConfig,
31 pub server_config: rustls::ServerConfig,
32}
33
34#[derive(Clone)]
35pub struct ClientTlsConfig {
36 pub tcp_config: TcpConfig,
37 pub client_config: rustls::ClientConfig,
38 pub dns_name: String,
39}
40
41pub struct TlsConn<C> {
43 read_stream: Mutex<ReadStream<ReadHalf<TlsStream<TcpStream>>, C>>,
44 write_stream: Mutex<WriteStream<WriteHalf<TlsStream<TcpStream>>, C>>,
45 peer_endpoint: Endpoint,
46 local_endpoint: Endpoint,
47}
48
49impl<C> TlsConn<C>
50where
51 C: Codec + Clone,
52{
53 pub fn new(
55 conn: TlsStream<TcpStream>,
56 codec: C,
57 peer_endpoint: Endpoint,
58 local_endpoint: Endpoint,
59 ) -> Self {
60 let (read, write) = split(conn);
61 let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
62 let write_stream = Mutex::new(WriteStream::new(write, codec));
63 Self {
64 read_stream,
65 write_stream,
66 peer_endpoint,
67 local_endpoint,
68 }
69 }
70}
71
72#[async_trait]
73impl<C, E> Connection for TlsConn<C>
74where
75 C: Clone + Codec<Error = E>,
76{
77 type Message = C::Message;
78 type Error = E;
79 fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
80 Ok(self.peer_endpoint.clone())
81 }
82
83 fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
84 Ok(self.local_endpoint.clone())
85 }
86
87 async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
88 self.read_stream.lock().await.recv().await
89 }
90
91 async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
92 self.write_stream.lock().await.send(msg).await
93 }
94}
95
96pub async fn dial<C>(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result<TlsConn<C>>
98where
99 C: Codec + Clone,
100{
101 let addr = SocketAddr::try_from(endpoint.clone())?;
102
103 let connector = TlsConnector::from(Arc::new(config.client_config.clone()));
104
105 let socket = TcpStream::connect(addr).await?;
106 socket.set_nodelay(config.tcp_config.nodelay)?;
107
108 let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
109 let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
110
111 let altname = pki_types::ServerName::try_from(config.dns_name.clone())?;
112 let conn = connector.connect(altname, socket).await?;
113 Ok(TlsConn::new(
114 TlsStream::Client(conn),
115 codec,
116 peer_endpoint,
117 local_endpoint,
118 ))
119}
120
121pub struct TlsListener<C> {
123 inner: TcpListener,
124 acceptor: TlsAcceptor,
125 config: ServerTlsConfig,
126 codec: C,
127}
128
129impl<C> TlsListener<C>
130where
131 C: Codec + Clone,
132{
133 pub fn new(
134 acceptor: TlsAcceptor,
135 listener: TcpListener,
136 config: ServerTlsConfig,
137 codec: C,
138 ) -> Self {
139 Self {
140 inner: listener,
141 acceptor,
142 config: config.clone(),
143 codec,
144 }
145 }
146}
147
148#[async_trait]
149impl<C, E> ConnListener for TlsListener<C>
150where
151 C: Clone + Codec<Error = E> + 'static,
152 E: From<std::io::Error>,
153{
154 type Message = C::Message;
155 type Error = E;
156 fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
157 Ok(Endpoint::new_tls_addr(self.inner.local_addr()?))
158 }
159
160 async fn accept(&self) -> std::result::Result<Conn<C::Message, E>, Self::Error> {
161 let (socket, _) = self.inner.accept().await?;
162 socket.set_nodelay(self.config.tcp_config.nodelay)?;
163
164 let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
165 let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
166
167 let conn = self.acceptor.accept(socket).await?;
168 Ok(Box::new(TlsConn::new(
169 TlsStream::Server(conn),
170 self.codec.clone(),
171 peer_endpoint,
172 local_endpoint,
173 )))
174 }
175}
176
177pub async fn listen<C>(
179 endpoint: &Endpoint,
180 config: ServerTlsConfig,
181 codec: C,
182) -> Result<TlsListener<C>>
183where
184 C: Clone + Codec,
185{
186 let addr = SocketAddr::try_from(endpoint.clone())?;
187 let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone()));
188 let listener = TcpListener::bind(addr).await?;
189 Ok(TlsListener::new(acceptor, listener, config, codec))
190}
191
192impl<C, E> From<TlsListener<C>> for Listener<C::Message, E>
193where
194 C: Codec<Error = E> + Clone + 'static,
195 E: From<std::io::Error>,
196{
197 fn from(listener: TlsListener<C>) -> Self {
198 Box::new(listener)
199 }
200}
201
202impl<C, E> ToConn for TlsConn<C>
203where
204 C: Codec<Error = E> + Clone + 'static,
205{
206 type Message = C::Message;
207 type Error = E;
208 fn to_conn(self) -> Conn<Self::Message, Self::Error> {
209 Box::new(self)
210 }
211}
212
213impl<C, E> ToListener for TlsListener<C>
214where
215 C: Clone + Codec<Error = E> + 'static,
216 E: From<std::io::Error>,
217{
218 type Message = C::Message;
219 type Error = E;
220 fn to_listener(self) -> Listener<Self::Message, Self::Error> {
221 Box::new(self)
222 }
223}