1use std::net::SocketAddr;
2
3#[cfg(feature = "tls")]
4use std::sync::Arc;
5
6use async_trait::async_trait;
7#[cfg(feature = "tls")]
8use rustls_pki_types as pki_types;
9
10use async_tungstenite::tungstenite::Error;
11
12#[cfg(feature = "smol")]
13use async_tungstenite::{accept_async, client_async};
14
15#[cfg(feature = "tokio")]
16use async_tungstenite::tokio::{accept_async, client_async};
17
18use karyon_core::async_runtime::{
19 lock::Mutex,
20 net::{TcpListener, TcpStream},
21};
22
23#[cfg(feature = "tls")]
24use crate::async_rustls::{rustls, TlsAcceptor, TlsConnector};
25
26use crate::{
27 codec::WebSocketCodec,
28 connection::{Conn, Connection, ToConn},
29 endpoint::Endpoint,
30 listener::{ConnListener, Listener, ToListener},
31 stream::{ReadWsStream, WriteWsStream, WsStream},
32 Result,
33};
34
35use super::tcp::TcpConfig;
36
/// TLS settings for a WebSocket listener that accepts secure (`wss`) connections.
#[derive(Clone)]
pub struct ServerWssConfig {
    /// rustls server configuration used to build the TLS acceptor.
    ///
    /// This field only exists when the `tls` feature is enabled.
    #[cfg(feature = "tls")]
    pub server_config: rustls::ServerConfig,
}
43
/// TLS settings for a WebSocket client that dials secure (`wss`) endpoints.
#[derive(Clone)]
pub struct ClientWssConfig {
    /// rustls client configuration used to build the TLS connector.
    ///
    /// This field only exists when the `tls` feature is enabled.
    #[cfg(feature = "tls")]
    pub client_config: rustls::ClientConfig,

    /// Name that `dial` converts into the TLS server name for the handshake.
    pub dns_name: String,
}
51
52#[derive(Clone, Default)]
54pub struct ServerWsConfig {
55 pub tcp_config: TcpConfig,
56 pub wss_config: Option<ServerWssConfig>,
57}
58
59#[derive(Clone, Default)]
61pub struct ClientWsConfig {
62 pub tcp_config: TcpConfig,
63 pub wss_config: Option<ClientWssConfig>,
64}
65
66pub struct WsConn<C> {
68 read_stream: Mutex<ReadWsStream<C>>,
69 write_stream: Mutex<WriteWsStream<C>>,
70 peer_endpoint: Endpoint,
71 local_endpoint: Endpoint,
72}
73
74impl<C> WsConn<C>
75where
76 C: WebSocketCodec + Clone,
77{
78 pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self {
80 let (read, write) = ws.split();
81 Self {
82 read_stream: Mutex::new(read),
83 write_stream: Mutex::new(write),
84 peer_endpoint,
85 local_endpoint,
86 }
87 }
88}
89
90#[async_trait]
91impl<C, E> Connection for WsConn<C>
92where
93 C: WebSocketCodec<Error = E>,
94 E: From<Error>,
95{
96 type Message = C::Message;
97 type Error = E;
98 fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
99 Ok(self.peer_endpoint.clone())
100 }
101
102 fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
103 Ok(self.local_endpoint.clone())
104 }
105
106 async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
107 self.read_stream.lock().await.recv().await
108 }
109
110 async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
111 self.write_stream.lock().await.send(msg).await
112 }
113}
114
115pub struct WsListener<C> {
117 inner: TcpListener,
118 config: ServerWsConfig,
119 codec: C,
120 #[cfg(feature = "tls")]
121 tls_acceptor: Option<TlsAcceptor>,
122}
123
124#[async_trait]
125impl<C, E> ConnListener for WsListener<C>
126where
127 C: WebSocketCodec<Error = E> + Clone + 'static,
128 E: From<Error> + From<std::io::Error>,
129{
130 type Message = C::Message;
131 type Error = E;
132 fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
133 match self.config.wss_config {
134 Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)),
135 None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)),
136 }
137 }
138
139 async fn accept(&self) -> std::result::Result<Conn<Self::Message, Self::Error>, Self::Error> {
140 let (socket, _) = self.inner.accept().await?;
141 socket.set_nodelay(self.config.tcp_config.nodelay)?;
142
143 match &self.config.wss_config {
144 #[cfg(feature = "tls")]
145 Some(_) => match &self.tls_acceptor {
146 Some(acceptor) => {
147 let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
148 let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
149
150 let tls_conn = acceptor.accept(socket).await?.into();
151 let conn = accept_async(tls_conn).await?;
152 Ok(Box::new(WsConn::new(
153 WsStream::new_wss(conn, self.codec.clone()),
154 peer_endpoint,
155 local_endpoint,
156 )))
157 }
158 None => unreachable!(),
159 },
160 None => {
161 let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
162 let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
163
164 let conn = accept_async(socket).await?;
165
166 Ok(Box::new(WsConn::new(
167 WsStream::new_ws(conn, self.codec.clone()),
168 peer_endpoint,
169 local_endpoint,
170 )))
171 }
172 #[cfg(not(feature = "tls"))]
173 _ => unreachable!(),
174 }
175 }
176}
177
178pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>>
180where
181 C: WebSocketCodec + Clone,
182{
183 let addr = SocketAddr::try_from(endpoint.clone())?;
184 let socket = TcpStream::connect(addr).await?;
185 socket.set_nodelay(config.tcp_config.nodelay)?;
186
187 match &config.wss_config {
188 #[cfg(feature = "tls")]
189 Some(conf) => {
190 let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
191 let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
192
193 let connector = TlsConnector::from(Arc::new(conf.client_config.clone()));
194
195 let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?;
196 let tls_conn = connector.connect(altname, socket).await?.into();
197 let (conn, _resp) = client_async(endpoint.to_string(), tls_conn)
198 .await
199 .map_err(Box::new)?;
200 Ok(WsConn::new(
201 WsStream::new_wss(conn, codec),
202 peer_endpoint,
203 local_endpoint,
204 ))
205 }
206 None => {
207 let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
208 let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
209 let (conn, _resp) = client_async(endpoint.to_string(), socket)
210 .await
211 .map_err(Box::new)?;
212 Ok(WsConn::new(
213 WsStream::new_ws(conn, codec),
214 peer_endpoint,
215 local_endpoint,
216 ))
217 }
218 #[cfg(not(feature = "tls"))]
219 _ => unreachable!(),
220 }
221}
222
223pub async fn listen<C>(
225 endpoint: &Endpoint,
226 config: ServerWsConfig,
227 codec: C,
228) -> Result<WsListener<C>> {
229 let addr = SocketAddr::try_from(endpoint.clone())?;
230
231 let listener = TcpListener::bind(addr).await?;
232 match &config.wss_config {
233 #[cfg(feature = "tls")]
234 Some(conf) => {
235 let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone()));
236 Ok(WsListener {
237 inner: listener,
238 config,
239 codec,
240 tls_acceptor: Some(acceptor),
241 })
242 }
243 None => Ok(WsListener {
244 inner: listener,
245 config,
246 codec,
247 #[cfg(feature = "tls")]
248 tls_acceptor: None,
249 }),
250 #[cfg(not(feature = "tls"))]
251 _ => unreachable!(),
252 }
253}
254
255impl<C, E> From<WsListener<C>> for Listener<C::Message, E>
256where
257 C: WebSocketCodec<Error = E> + Clone + 'static,
258 E: From<Error> + From<std::io::Error>,
259{
260 fn from(listener: WsListener<C>) -> Self {
261 Box::new(listener)
262 }
263}
264
265impl<C, E> ToConn for WsConn<C>
266where
267 C: WebSocketCodec<Error = E> + 'static,
268 E: From<Error>,
269{
270 type Message = C::Message;
271 type Error = E;
272 fn to_conn(self) -> Conn<Self::Message, Self::Error> {
273 Box::new(self)
274 }
275}
276
277impl<C, E> ToListener for WsListener<C>
278where
279 C: WebSocketCodec<Error = E> + Clone + 'static,
280 E: From<Error> + From<std::io::Error>,
281{
282 type Message = C::Message;
283 type Error = E;
284 fn to_listener(self) -> Listener<Self::Message, Self::Error> {
285 Box::new(self)
286 }
287}