Skip to main content

karyon_net/
endpoint.rs

1use std::{
2    net::{IpAddr, SocketAddr, ToSocketAddrs},
3    path::PathBuf,
4    str::FromStr,
5};
6
7#[cfg(all(target_family = "unix", feature = "unix"))]
8use std::os::unix::net::SocketAddr as UnixSocketAddr;
9
10use bincode::{Decode, Encode};
11use url::Url;
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use crate::{Error, Result};
17
18/// Port defined as a u16.
19pub type Port = u16;
20
21/// Endpoint defines generic network endpoints for karyon.
22///
23/// Transport schemes (TCP, UDP, TLS, QUIC, Unix) carry just address + port.
24/// URL schemes (HTTP, HTTPS, WS, WSS) carry a full `Url` so path, query,
25/// and other URL semantics are preserved.
26///
27/// # Example
28///
29/// ```
30/// use std::net::SocketAddr;
31///
32/// use karyon_net::Endpoint;
33///
34/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
35///
36/// let socketaddr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
37/// let endpoint = Endpoint::new_udp_addr(socketaddr);
38///
39/// let endpoint: Endpoint = "http://example.com:8080/rpc".parse().unwrap();
40/// ```
41#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
43#[cfg_attr(feature = "serde", serde(into = "String"))]
44pub enum Endpoint {
45    Udp(Addr, Port),
46    Tcp(Addr, Port),
47    Tls(Addr, Port),
48    Quic(Addr, Port),
49    Http(Url),
50    Https(Url),
51    Ws(Url),
52    Wss(Url),
53    Unix(PathBuf),
54}
55
56impl Endpoint {
57    /// Creates a new TCP endpoint from a `SocketAddr`.
58    pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
59        Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
60    }
61
62    /// Creates a new UDP endpoint from a `SocketAddr`.
63    pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
64        Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
65    }
66
67    /// Creates a new TLS endpoint from a `SocketAddr`.
68    pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
69        Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
70    }
71
72    /// Creates a new QUIC endpoint from a `SocketAddr`.
73    pub fn new_quic_addr(addr: SocketAddr) -> Endpoint {
74        Endpoint::Quic(Addr::Ip(addr.ip()), addr.port())
75    }
76
77    /// Creates a new Unix endpoint from a `UnixSocketAddr`.
78    pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
79        Endpoint::Unix(addr.to_path_buf())
80    }
81
82    #[inline]
83    /// Checks if the `Endpoint` is of type `Tcp`.
84    pub fn is_tcp(&self) -> bool {
85        matches!(self, Endpoint::Tcp(..))
86    }
87
88    #[inline]
89    /// Checks if the `Endpoint` is of type `Tls`.
90    pub fn is_tls(&self) -> bool {
91        matches!(self, Endpoint::Tls(..))
92    }
93
94    #[inline]
95    /// Checks if the `Endpoint` is of type `Ws`.
96    pub fn is_ws(&self) -> bool {
97        matches!(self, Endpoint::Ws(..))
98    }
99
100    #[inline]
101    /// Checks if the `Endpoint` is of type `Wss`.
102    pub fn is_wss(&self) -> bool {
103        matches!(self, Endpoint::Wss(..))
104    }
105
106    #[inline]
107    /// Checks if the `Endpoint` is of type `Quic`.
108    pub fn is_quic(&self) -> bool {
109        matches!(self, Endpoint::Quic(..))
110    }
111
112    #[inline]
113    /// Checks if the `Endpoint` is of type `Udp`.
114    pub fn is_udp(&self) -> bool {
115        matches!(self, Endpoint::Udp(..))
116    }
117
118    #[inline]
119    /// Checks if the `Endpoint` is of type `Http`.
120    pub fn is_http(&self) -> bool {
121        matches!(self, Endpoint::Http(..))
122    }
123
124    #[inline]
125    /// Checks if the `Endpoint` is of type `Https`.
126    pub fn is_https(&self) -> bool {
127        matches!(self, Endpoint::Https(..))
128    }
129
130    #[inline]
131    /// Checks if the `Endpoint` is of type `Unix`.
132    pub fn is_unix(&self) -> bool {
133        matches!(self, Endpoint::Unix(..))
134    }
135
136    /// Returns the port of the endpoint. For URL-family endpoints the
137    /// scheme's default port is returned if the URL omits a port.
138    pub fn port(&self) -> Result<Port> {
139        match self {
140            Endpoint::Tcp(_, port)
141            | Endpoint::Udp(_, port)
142            | Endpoint::Tls(_, port)
143            | Endpoint::Quic(_, port) => Ok(*port),
144            Endpoint::Http(url) | Endpoint::Https(url) | Endpoint::Ws(url) | Endpoint::Wss(url) => {
145                url.port_or_known_default()
146                    .ok_or_else(|| Error::ParseEndpoint(format!("port missing: {url}")))
147            }
148            Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
149        }
150    }
151
152    /// Returns the address of the endpoint.
153    pub fn addr(&self) -> Result<Addr> {
154        match self {
155            Endpoint::Tcp(addr, _)
156            | Endpoint::Udp(addr, _)
157            | Endpoint::Tls(addr, _)
158            | Endpoint::Quic(addr, _) => Ok(addr.clone()),
159            Endpoint::Http(url) | Endpoint::Https(url) | Endpoint::Ws(url) | Endpoint::Wss(url) => {
160                url_to_addr(url)
161            }
162            Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
163        }
164    }
165}
166
167fn url_to_addr(url: &Url) -> Result<Addr> {
168    let host = url
169        .host_str()
170        .ok_or_else(|| Error::ParseEndpoint(format!("host missing: {url}")))?;
171    match host.parse::<IpAddr>() {
172        Ok(ip) => Ok(Addr::Ip(ip)),
173        Err(_) => Ok(Addr::Domain(host.to_string())),
174    }
175}
176
177impl std::fmt::Display for Endpoint {
178    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
179        match self {
180            Endpoint::Udp(ip, port) => write!(f, "udp://{ip}:{port}"),
181            Endpoint::Tcp(ip, port) => write!(f, "tcp://{ip}:{port}"),
182            Endpoint::Tls(ip, port) => write!(f, "tls://{ip}:{port}"),
183            Endpoint::Quic(ip, port) => write!(f, "quic://{ip}:{port}"),
184            Endpoint::Http(url) | Endpoint::Https(url) | Endpoint::Ws(url) | Endpoint::Wss(url) => {
185                write!(f, "{url}")
186            }
187            Endpoint::Unix(path) => write!(f, "unix:/{}", path.to_string_lossy()),
188        }
189    }
190}
191
192impl From<Endpoint> for String {
193    fn from(endpoint: Endpoint) -> String {
194        endpoint.to_string()
195    }
196}
197
198impl TryFrom<Endpoint> for SocketAddr {
199    type Error = Error;
200    fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
201        match endpoint {
202            Endpoint::Udp(addr, port)
203            | Endpoint::Tcp(addr, port)
204            | Endpoint::Tls(addr, port)
205            | Endpoint::Quic(addr, port) => resolve(addr, port),
206            Endpoint::Http(ref url)
207            | Endpoint::Https(ref url)
208            | Endpoint::Ws(ref url)
209            | Endpoint::Wss(ref url) => {
210                let addr = url_to_addr(url)?;
211                let port = endpoint.port()?;
212                resolve(addr, port)
213            }
214            Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
215        }
216    }
217}
218
219/// Resolve an `Addr` + port to a `SocketAddr`. Domains are resolved
220/// through the system DNS with the given port.
221fn resolve(addr: Addr, port: Port) -> Result<SocketAddr> {
222    match addr {
223        Addr::Ip(ip) => Ok(SocketAddr::new(ip, port)),
224        Addr::Domain(d) => (d.as_str(), port)
225            .to_socket_addrs()?
226            .next()
227            .ok_or_else(|| Error::ParseEndpoint(format!("could not resolve {d}:{port}"))),
228    }
229}
230
231impl TryFrom<Endpoint> for PathBuf {
232    type Error = Error;
233    fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> {
234        match endpoint {
235            Endpoint::Unix(path) => Ok(PathBuf::from(&path)),
236            _ => Err(Error::TryFromEndpoint),
237        }
238    }
239}
240
#[cfg(all(feature = "unix", target_family = "unix"))]
impl TryFrom<Endpoint> for UnixSocketAddr {
    type Error = Error;

    /// Builds a pathname-based Unix socket address from a `Unix` endpoint.
    ///
    /// Any other variant is rejected with `Error::TryFromEndpoint`. An error
    /// from `UnixSocketAddr::from_pathname` is converted via `?`.
    fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddr, Self::Error> {
        if let Endpoint::Unix(path) = endpoint {
            let socket_addr = UnixSocketAddr::from_pathname(path)?;
            Ok(socket_addr)
        } else {
            Err(Error::TryFromEndpoint)
        }
    }
}
251
252impl TryFrom<Endpoint> for Url {
253    type Error = Error;
254
255    fn try_from(ep: Endpoint) -> Result<Self> {
256        match ep {
257            Endpoint::Http(u) | Endpoint::Https(u) | Endpoint::Ws(u) | Endpoint::Wss(u) => Ok(u),
258            other => other
259                .to_string()
260                .parse::<Url>()
261                .map_err(|e| Error::ParseEndpoint(e.to_string())),
262        }
263    }
264}
265
266impl TryFrom<Url> for Endpoint {
267    type Error = Error;
268
269    fn try_from(url: Url) -> Result<Self> {
270        match url.scheme() {
271            "http" => Ok(Endpoint::Http(url)),
272            "https" => Ok(Endpoint::Https(url)),
273            "ws" => Ok(Endpoint::Ws(url)),
274            "wss" => Ok(Endpoint::Wss(url)),
275            "tcp" | "udp" | "tls" | "quic" => {
276                let addr = url_to_addr(&url)?;
277                let port = url
278                    .port()
279                    .ok_or_else(|| Error::ParseEndpoint(format!("port missing: {url}")))?;
280                Ok(match url.scheme() {
281                    "tcp" => Endpoint::Tcp(addr, port),
282                    "udp" => Endpoint::Udp(addr, port),
283                    "tls" => Endpoint::Tls(addr, port),
284                    "quic" => Endpoint::Quic(addr, port),
285                    _ => unreachable!(),
286                })
287            }
288            "unix" => {
289                if url.path().is_empty() {
290                    return Err(Error::UnsupportedEndpoint(url.to_string()));
291                }
292                Ok(Endpoint::Unix(url.path().into()))
293            }
294            _ => Err(Error::UnsupportedEndpoint(url.to_string())),
295        }
296    }
297}
298
299impl FromStr for Endpoint {
300    type Err = Error;
301
302    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
303        let url: Url = s
304            .parse()
305            .map_err(|err: url::ParseError| Error::ParseEndpoint(err.to_string()))?;
306        Endpoint::try_from(url)
307    }
308}
309
310/// Addr defines a type for an address, either IP or domain.
311#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
312#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
313pub enum Addr {
314    Ip(IpAddr),
315    Domain(String),
316}
317
318impl TryFrom<Addr> for IpAddr {
319    type Error = std::io::Error;
320    fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> {
321        match addr {
322            Addr::Ip(ip) => Ok(ip),
323            Addr::Domain(_) => Err(std::io::Error::new(
324                std::io::ErrorKind::InvalidInput,
325                "Addr::Domain cannot be converted to IpAddr without a port; \
326                 use SocketAddr::try_from(Endpoint) instead",
327            )),
328        }
329    }
330}
331
332impl std::fmt::Display for Addr {
333    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
334        match self {
335            Addr::Ip(ip) => write!(f, "{ip}"),
336            Addr::Domain(d) => write!(f, "{d}"),
337        }
338    }
339}
340
341pub trait ToEndpoint {
342    fn to_endpoint(&self) -> Result<Endpoint>;
343}
344
345impl ToEndpoint for String {
346    fn to_endpoint(&self) -> Result<Endpoint> {
347        Endpoint::from_str(self)
348    }
349}
350
351impl ToEndpoint for Endpoint {
352    fn to_endpoint(&self) -> Result<Endpoint> {
353        Ok(self.clone())
354    }
355}
356
357impl ToEndpoint for &str {
358    fn to_endpoint(&self) -> Result<Endpoint> {
359        Endpoint::from_str(self)
360    }
361}
362
363#[cfg(test)]
364mod tests {
365    use super::*;
366    use std::net::Ipv4Addr;
367    use std::path::PathBuf;
368
369    #[test]
370    fn test_endpoint_from_str() {
371        let endpoint_str: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
372        let endpoint = Endpoint::Tcp(Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), 3000);
373        assert_eq!(endpoint_str, endpoint);
374
375        let endpoint_str: Endpoint = "udp://127.0.0.1:4000".parse().unwrap();
376        let endpoint = Endpoint::Udp(Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), 4000);
377        assert_eq!(endpoint_str, endpoint);
378
379        let endpoint_str: Endpoint = "tcp://example.com:3000".parse().unwrap();
380        let endpoint = Endpoint::Tcp(Addr::Domain("example.com".to_string()), 3000);
381        assert_eq!(endpoint_str, endpoint);
382
383        let endpoint_str = "unix:/home/x/s.socket".parse::<Endpoint>().unwrap();
384        let endpoint = Endpoint::Unix(PathBuf::from_str("/home/x/s.socket").unwrap());
385        assert_eq!(endpoint_str, endpoint);
386    }
387
388    #[test]
389    fn test_endpoint_url_preserves_path() {
390        let endpoint: Endpoint = "http://example.com:8080/rpc?x=1".parse().unwrap();
391        match &endpoint {
392            Endpoint::Http(url) => {
393                assert_eq!(url.path(), "/rpc");
394                assert_eq!(url.query(), Some("x=1"));
395                assert_eq!(url.port(), Some(8080));
396            }
397            _ => panic!("expected Http"),
398        }
399        assert_eq!(endpoint.port().unwrap(), 8080);
400    }
401
402    #[test]
403    fn test_endpoint_url_default_port() {
404        let endpoint: Endpoint = "https://example.com/".parse().unwrap();
405        assert_eq!(endpoint.port().unwrap(), 443);
406
407        let endpoint: Endpoint = "ws://example.com/".parse().unwrap();
408        assert_eq!(endpoint.port().unwrap(), 80);
409    }
410}