1use std::{
2 net::{IpAddr, SocketAddr, ToSocketAddrs},
3 path::PathBuf,
4 str::FromStr,
5};
6
7#[cfg(all(target_family = "unix", feature = "unix"))]
8use std::os::unix::net::SocketAddr as UnixSocketAddr;
9
10use bincode::{Decode, Encode};
11use url::Url;
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use crate::{Error, Result};
17
/// Port number carried by the host/port based [`Endpoint`] variants.
pub type Port = u16;
20
/// A network endpoint, tagged by the scheme used to reach it.
///
/// Host/port variants store an [`Addr`] and a [`Port`]; URL variants keep the whole
/// [`Url`] (path and query included); the Unix variant stores a filesystem path.
///
/// With the `serde` feature enabled the type serializes through its `String`
/// conversion (`serde(into = "String")`).
// NOTE(review): `into` only customises serialization; `Deserialize` is still the plain
// derived enum representation, so the two directions look asymmetric — confirm whether
// a matching `try_from = "String"` was intended.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(into = "String"))]
pub enum Endpoint {
    /// `udp://host:port`
    Udp(Addr, Port),
    /// `tcp://host:port`
    Tcp(Addr, Port),
    /// `tls://host:port`
    Tls(Addr, Port),
    /// `quic://host:port`
    Quic(Addr, Port),
    /// An `http` URL, stored as parsed.
    Http(Url),
    /// An `https` URL, stored as parsed.
    Https(Url),
    /// A `ws` URL, stored as parsed.
    Ws(Url),
    /// A `wss` URL, stored as parsed.
    Wss(Url),
    /// Path of a Unix domain socket.
    Unix(PathBuf),
}
55
56impl Endpoint {
57 pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
59 Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
60 }
61
62 pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
64 Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
65 }
66
67 pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
69 Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
70 }
71
72 pub fn new_quic_addr(addr: SocketAddr) -> Endpoint {
74 Endpoint::Quic(Addr::Ip(addr.ip()), addr.port())
75 }
76
77 pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
79 Endpoint::Unix(addr.to_path_buf())
80 }
81
82 #[inline]
83 pub fn is_tcp(&self) -> bool {
85 matches!(self, Endpoint::Tcp(..))
86 }
87
88 #[inline]
89 pub fn is_tls(&self) -> bool {
91 matches!(self, Endpoint::Tls(..))
92 }
93
94 #[inline]
95 pub fn is_ws(&self) -> bool {
97 matches!(self, Endpoint::Ws(..))
98 }
99
100 #[inline]
101 pub fn is_wss(&self) -> bool {
103 matches!(self, Endpoint::Wss(..))
104 }
105
106 #[inline]
107 pub fn is_quic(&self) -> bool {
109 matches!(self, Endpoint::Quic(..))
110 }
111
112 #[inline]
113 pub fn is_udp(&self) -> bool {
115 matches!(self, Endpoint::Udp(..))
116 }
117
118 #[inline]
119 pub fn is_http(&self) -> bool {
121 matches!(self, Endpoint::Http(..))
122 }
123
124 #[inline]
125 pub fn is_https(&self) -> bool {
127 matches!(self, Endpoint::Https(..))
128 }
129
130 #[inline]
131 pub fn is_unix(&self) -> bool {
133 matches!(self, Endpoint::Unix(..))
134 }
135
136 pub fn port(&self) -> Result<Port> {
139 match self {
140 Endpoint::Tcp(_, port)
141 | Endpoint::Udp(_, port)
142 | Endpoint::Tls(_, port)
143 | Endpoint::Quic(_, port) => Ok(*port),
144 Endpoint::Http(url) | Endpoint::Https(url) | Endpoint::Ws(url) | Endpoint::Wss(url) => {
145 url.port_or_known_default()
146 .ok_or_else(|| Error::ParseEndpoint(format!("port missing: {url}")))
147 }
148 Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
149 }
150 }
151
152 pub fn addr(&self) -> Result<Addr> {
154 match self {
155 Endpoint::Tcp(addr, _)
156 | Endpoint::Udp(addr, _)
157 | Endpoint::Tls(addr, _)
158 | Endpoint::Quic(addr, _) => Ok(addr.clone()),
159 Endpoint::Http(url) | Endpoint::Https(url) | Endpoint::Ws(url) | Endpoint::Wss(url) => {
160 url_to_addr(url)
161 }
162 Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
163 }
164 }
165}
166
167fn url_to_addr(url: &Url) -> Result<Addr> {
168 let host = url
169 .host_str()
170 .ok_or_else(|| Error::ParseEndpoint(format!("host missing: {url}")))?;
171 match host.parse::<IpAddr>() {
172 Ok(ip) => Ok(Addr::Ip(ip)),
173 Err(_) => Ok(Addr::Domain(host.to_string())),
174 }
175}
176
177impl std::fmt::Display for Endpoint {
178 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
179 match self {
180 Endpoint::Udp(ip, port) => write!(f, "udp://{ip}:{port}"),
181 Endpoint::Tcp(ip, port) => write!(f, "tcp://{ip}:{port}"),
182 Endpoint::Tls(ip, port) => write!(f, "tls://{ip}:{port}"),
183 Endpoint::Quic(ip, port) => write!(f, "quic://{ip}:{port}"),
184 Endpoint::Http(url) | Endpoint::Https(url) | Endpoint::Ws(url) | Endpoint::Wss(url) => {
185 write!(f, "{url}")
186 }
187 Endpoint::Unix(path) => write!(f, "unix:/{}", path.to_string_lossy()),
188 }
189 }
190}
191
192impl From<Endpoint> for String {
193 fn from(endpoint: Endpoint) -> String {
194 endpoint.to_string()
195 }
196}
197
198impl TryFrom<Endpoint> for SocketAddr {
199 type Error = Error;
200 fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
201 match endpoint {
202 Endpoint::Udp(addr, port)
203 | Endpoint::Tcp(addr, port)
204 | Endpoint::Tls(addr, port)
205 | Endpoint::Quic(addr, port) => resolve(addr, port),
206 Endpoint::Http(ref url)
207 | Endpoint::Https(ref url)
208 | Endpoint::Ws(ref url)
209 | Endpoint::Wss(ref url) => {
210 let addr = url_to_addr(url)?;
211 let port = endpoint.port()?;
212 resolve(addr, port)
213 }
214 Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
215 }
216 }
217}
218
219fn resolve(addr: Addr, port: Port) -> Result<SocketAddr> {
222 match addr {
223 Addr::Ip(ip) => Ok(SocketAddr::new(ip, port)),
224 Addr::Domain(d) => (d.as_str(), port)
225 .to_socket_addrs()?
226 .next()
227 .ok_or_else(|| Error::ParseEndpoint(format!("could not resolve {d}:{port}"))),
228 }
229}
230
231impl TryFrom<Endpoint> for PathBuf {
232 type Error = Error;
233 fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> {
234 match endpoint {
235 Endpoint::Unix(path) => Ok(PathBuf::from(&path)),
236 _ => Err(Error::TryFromEndpoint),
237 }
238 }
239}
240
#[cfg(all(feature = "unix", target_family = "unix"))]
impl TryFrom<Endpoint> for UnixSocketAddr {
    type Error = Error;

    /// Builds a Unix socket address from the path of an [`Endpoint::Unix`].
    ///
    /// # Errors
    /// Returns [`Error::TryFromEndpoint`] for any other variant, or the converted I/O
    /// error when `from_pathname` rejects the path.
    fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddr, Self::Error> {
        let path = match endpoint {
            Endpoint::Unix(path) => path,
            _ => return Err(Error::TryFromEndpoint),
        };
        let sock = UnixSocketAddr::from_pathname(path)?;
        Ok(sock)
    }
}
251
252impl TryFrom<Endpoint> for Url {
253 type Error = Error;
254
255 fn try_from(ep: Endpoint) -> Result<Self> {
256 match ep {
257 Endpoint::Http(u) | Endpoint::Https(u) | Endpoint::Ws(u) | Endpoint::Wss(u) => Ok(u),
258 other => other
259 .to_string()
260 .parse::<Url>()
261 .map_err(|e| Error::ParseEndpoint(e.to_string())),
262 }
263 }
264}
265
266impl TryFrom<Url> for Endpoint {
267 type Error = Error;
268
269 fn try_from(url: Url) -> Result<Self> {
270 match url.scheme() {
271 "http" => Ok(Endpoint::Http(url)),
272 "https" => Ok(Endpoint::Https(url)),
273 "ws" => Ok(Endpoint::Ws(url)),
274 "wss" => Ok(Endpoint::Wss(url)),
275 "tcp" | "udp" | "tls" | "quic" => {
276 let addr = url_to_addr(&url)?;
277 let port = url
278 .port()
279 .ok_or_else(|| Error::ParseEndpoint(format!("port missing: {url}")))?;
280 Ok(match url.scheme() {
281 "tcp" => Endpoint::Tcp(addr, port),
282 "udp" => Endpoint::Udp(addr, port),
283 "tls" => Endpoint::Tls(addr, port),
284 "quic" => Endpoint::Quic(addr, port),
285 _ => unreachable!(),
286 })
287 }
288 "unix" => {
289 if url.path().is_empty() {
290 return Err(Error::UnsupportedEndpoint(url.to_string()));
291 }
292 Ok(Endpoint::Unix(url.path().into()))
293 }
294 _ => Err(Error::UnsupportedEndpoint(url.to_string())),
295 }
296 }
297}
298
299impl FromStr for Endpoint {
300 type Err = Error;
301
302 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
303 let url: Url = s
304 .parse()
305 .map_err(|err: url::ParseError| Error::ParseEndpoint(err.to_string()))?;
306 Endpoint::try_from(url)
307 }
308}
309
/// Host part of an [`Endpoint`]: either a literal IP address or a domain name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Addr {
    /// A literal IPv4 or IPv6 address.
    Ip(IpAddr),
    /// A host name, stored as written.
    Domain(String),
}
317
318impl TryFrom<Addr> for IpAddr {
319 type Error = std::io::Error;
320 fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> {
321 match addr {
322 Addr::Ip(ip) => Ok(ip),
323 Addr::Domain(_) => Err(std::io::Error::new(
324 std::io::ErrorKind::InvalidInput,
325 "Addr::Domain cannot be converted to IpAddr without a port; \
326 use SocketAddr::try_from(Endpoint) instead",
327 )),
328 }
329 }
330}
331
332impl std::fmt::Display for Addr {
333 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
334 match self {
335 Addr::Ip(ip) => write!(f, "{ip}"),
336 Addr::Domain(d) => write!(f, "{d}"),
337 }
338 }
339}
340
/// Conversion of a value into an [`Endpoint`].
pub trait ToEndpoint {
    /// Returns the endpoint described by `self`, or an error if the conversion fails.
    fn to_endpoint(&self) -> Result<Endpoint>;
}
344
345impl ToEndpoint for String {
346 fn to_endpoint(&self) -> Result<Endpoint> {
347 Endpoint::from_str(self)
348 }
349}
350
351impl ToEndpoint for Endpoint {
352 fn to_endpoint(&self) -> Result<Endpoint> {
353 Ok(self.clone())
354 }
355}
356
357impl ToEndpoint for &str {
358 fn to_endpoint(&self) -> Result<Endpoint> {
359 Endpoint::from_str(self)
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::path::PathBuf;

    /// Parses `input` as an endpoint, panicking when it is rejected.
    fn parse(input: &str) -> Endpoint {
        input.parse().unwrap()
    }

    /// Host/port and Unix endpoints parse into the expected variants.
    #[test]
    fn test_endpoint_from_str() {
        let localhost = Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        assert_eq!(
            parse("tcp://127.0.0.1:3000"),
            Endpoint::Tcp(localhost.clone(), 3000)
        );
        assert_eq!(parse("udp://127.0.0.1:4000"), Endpoint::Udp(localhost, 4000));
        assert_eq!(
            parse("tcp://example.com:3000"),
            Endpoint::Tcp(Addr::Domain("example.com".to_string()), 3000)
        );
        assert_eq!(
            parse("unix:/home/x/s.socket"),
            Endpoint::Unix(PathBuf::from("/home/x/s.socket"))
        );
    }

    /// URL endpoints keep their path, query and explicit port.
    #[test]
    fn test_endpoint_url_preserves_path() {
        let endpoint = parse("http://example.com:8080/rpc?x=1");

        let url = match &endpoint {
            Endpoint::Http(url) => url,
            _ => panic!("expected Http"),
        };
        assert_eq!(url.path(), "/rpc");
        assert_eq!(url.query(), Some("x=1"));
        assert_eq!(url.port(), Some(8080));

        assert_eq!(endpoint.port().unwrap(), 8080);
    }

    /// URL endpoints without an explicit port report the scheme's default.
    #[test]
    fn test_endpoint_url_default_port() {
        let secure = parse("https://example.com/");
        assert_eq!(secure.port().unwrap(), 443);

        let websocket = parse("ws://example.com/");
        assert_eq!(websocket.port().unwrap(), 80);
    }
}