karyon_net/transports/
ws.rs

1use std::net::SocketAddr;
2
3#[cfg(feature = "tls")]
4use std::sync::Arc;
5
6use async_trait::async_trait;
7#[cfg(feature = "tls")]
8use rustls_pki_types as pki_types;
9
10use async_tungstenite::tungstenite::Error;
11
12#[cfg(feature = "smol")]
13use async_tungstenite::{accept_async, client_async};
14
15#[cfg(feature = "tokio")]
16use async_tungstenite::tokio::{accept_async, client_async};
17
18use karyon_core::async_runtime::{
19    lock::Mutex,
20    net::{TcpListener, TcpStream},
21};
22
23#[cfg(feature = "tls")]
24use crate::async_rustls::{rustls, TlsAcceptor, TlsConnector};
25
26use crate::{
27    codec::WebSocketCodec,
28    connection::{Conn, Connection, ToConn},
29    endpoint::Endpoint,
30    listener::{ConnListener, Listener, ToListener},
31    stream::{ReadWsStream, WriteWsStream, WsStream},
32    Result,
33};
34
35use super::tcp::TcpConfig;
36
/// Server-side TLS settings for secure WebSocket (`wss`) listeners.
///
/// The struct only carries data when the crate is built with the `tls`
/// feature; without it the struct has no fields.
#[derive(Clone)]
pub struct ServerWssConfig {
    /// rustls server configuration from which the listener's TLS acceptor is
    /// built.
    #[cfg(feature = "tls")]
    pub server_config: rustls::ServerConfig,
}
43
/// Client-side TLS settings for dialing secure WebSocket (`wss`) endpoints.
#[derive(Clone)]
pub struct ClientWssConfig {
    /// rustls client configuration from which the TLS connector is built
    /// (only present with the `tls` feature).
    #[cfg(feature = "tls")]
    pub client_config: rustls::ClientConfig,
    /// Server name used for the TLS handshake. `dial` converts it into a
    /// rustls `ServerName` and returns an error if the conversion fails.
    pub dns_name: String,
}
51
52/// WS configuration
53#[derive(Clone, Default)]
54pub struct ServerWsConfig {
55    pub tcp_config: TcpConfig,
56    pub wss_config: Option<ServerWssConfig>,
57}
58
59/// WS configuration
60#[derive(Clone, Default)]
61pub struct ClientWsConfig {
62    pub tcp_config: TcpConfig,
63    pub wss_config: Option<ClientWssConfig>,
64}
65
66/// WS network connection implementation of the [`Connection`] trait.
67pub struct WsConn<C> {
68    read_stream: Mutex<ReadWsStream<C>>,
69    write_stream: Mutex<WriteWsStream<C>>,
70    peer_endpoint: Endpoint,
71    local_endpoint: Endpoint,
72}
73
74impl<C> WsConn<C>
75where
76    C: WebSocketCodec + Clone,
77{
78    /// Creates a new WsConn
79    pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self {
80        let (read, write) = ws.split();
81        Self {
82            read_stream: Mutex::new(read),
83            write_stream: Mutex::new(write),
84            peer_endpoint,
85            local_endpoint,
86        }
87    }
88}
89
90#[async_trait]
91impl<C, E> Connection for WsConn<C>
92where
93    C: WebSocketCodec<Error = E>,
94    E: From<Error>,
95{
96    type Message = C::Message;
97    type Error = E;
98    fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
99        Ok(self.peer_endpoint.clone())
100    }
101
102    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
103        Ok(self.local_endpoint.clone())
104    }
105
106    async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
107        self.read_stream.lock().await.recv().await
108    }
109
110    async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
111        self.write_stream.lock().await.send(msg).await
112    }
113}
114
115/// Ws network listener implementation of the `Listener` [`ConnListener`] trait.
116pub struct WsListener<C> {
117    inner: TcpListener,
118    config: ServerWsConfig,
119    codec: C,
120    #[cfg(feature = "tls")]
121    tls_acceptor: Option<TlsAcceptor>,
122}
123
124#[async_trait]
125impl<C, E> ConnListener for WsListener<C>
126where
127    C: WebSocketCodec<Error = E> + Clone + 'static,
128    E: From<Error> + From<std::io::Error>,
129{
130    type Message = C::Message;
131    type Error = E;
132    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
133        match self.config.wss_config {
134            Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)),
135            None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)),
136        }
137    }
138
139    async fn accept(&self) -> std::result::Result<Conn<Self::Message, Self::Error>, Self::Error> {
140        let (socket, _) = self.inner.accept().await?;
141        socket.set_nodelay(self.config.tcp_config.nodelay)?;
142
143        match &self.config.wss_config {
144            #[cfg(feature = "tls")]
145            Some(_) => match &self.tls_acceptor {
146                Some(acceptor) => {
147                    let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
148                    let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
149
150                    let tls_conn = acceptor.accept(socket).await?.into();
151                    let conn = accept_async(tls_conn).await?;
152                    Ok(Box::new(WsConn::new(
153                        WsStream::new_wss(conn, self.codec.clone()),
154                        peer_endpoint,
155                        local_endpoint,
156                    )))
157                }
158                None => unreachable!(),
159            },
160            None => {
161                let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
162                let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
163
164                let conn = accept_async(socket).await?;
165
166                Ok(Box::new(WsConn::new(
167                    WsStream::new_ws(conn, self.codec.clone()),
168                    peer_endpoint,
169                    local_endpoint,
170                )))
171            }
172            #[cfg(not(feature = "tls"))]
173            _ => unreachable!(),
174        }
175    }
176}
177
178/// Connects to the given WS address and port.
179pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>>
180where
181    C: WebSocketCodec + Clone,
182{
183    let addr = SocketAddr::try_from(endpoint.clone())?;
184    let socket = TcpStream::connect(addr).await?;
185    socket.set_nodelay(config.tcp_config.nodelay)?;
186
187    match &config.wss_config {
188        #[cfg(feature = "tls")]
189        Some(conf) => {
190            let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
191            let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
192
193            let connector = TlsConnector::from(Arc::new(conf.client_config.clone()));
194
195            let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?;
196            let tls_conn = connector.connect(altname, socket).await?.into();
197            let (conn, _resp) = client_async(endpoint.to_string(), tls_conn)
198                .await
199                .map_err(Box::new)?;
200            Ok(WsConn::new(
201                WsStream::new_wss(conn, codec),
202                peer_endpoint,
203                local_endpoint,
204            ))
205        }
206        None => {
207            let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
208            let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
209            let (conn, _resp) = client_async(endpoint.to_string(), socket)
210                .await
211                .map_err(Box::new)?;
212            Ok(WsConn::new(
213                WsStream::new_ws(conn, codec),
214                peer_endpoint,
215                local_endpoint,
216            ))
217        }
218        #[cfg(not(feature = "tls"))]
219        _ => unreachable!(),
220    }
221}
222
223/// Listens on the given WS address and port.
224pub async fn listen<C>(
225    endpoint: &Endpoint,
226    config: ServerWsConfig,
227    codec: C,
228) -> Result<WsListener<C>> {
229    let addr = SocketAddr::try_from(endpoint.clone())?;
230
231    let listener = TcpListener::bind(addr).await?;
232    match &config.wss_config {
233        #[cfg(feature = "tls")]
234        Some(conf) => {
235            let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone()));
236            Ok(WsListener {
237                inner: listener,
238                config,
239                codec,
240                tls_acceptor: Some(acceptor),
241            })
242        }
243        None => Ok(WsListener {
244            inner: listener,
245            config,
246            codec,
247            #[cfg(feature = "tls")]
248            tls_acceptor: None,
249        }),
250        #[cfg(not(feature = "tls"))]
251        _ => unreachable!(),
252    }
253}
254
255impl<C, E> From<WsListener<C>> for Listener<C::Message, E>
256where
257    C: WebSocketCodec<Error = E> + Clone + 'static,
258    E: From<Error> + From<std::io::Error>,
259{
260    fn from(listener: WsListener<C>) -> Self {
261        Box::new(listener)
262    }
263}
264
265impl<C, E> ToConn for WsConn<C>
266where
267    C: WebSocketCodec<Error = E> + 'static,
268    E: From<Error>,
269{
270    type Message = C::Message;
271    type Error = E;
272    fn to_conn(self) -> Conn<Self::Message, Self::Error> {
273        Box::new(self)
274    }
275}
276
277impl<C, E> ToListener for WsListener<C>
278where
279    C: WebSocketCodec<Error = E> + Clone + 'static,
280    E: From<Error> + From<std::io::Error>,
281{
282    type Message = C::Message;
283    type Error = E;
284    fn to_listener(self) -> Listener<Self::Message, Self::Error> {
285        Box::new(self)
286    }
287}