karyon_net/
endpoint.rs

1use std::{
2    net::{IpAddr, SocketAddr},
3    path::PathBuf,
4    str::FromStr,
5};
6
7#[cfg(all(target_family = "unix", feature = "unix"))]
8use std::os::unix::net::SocketAddr as UnixSocketAddr;
9
10use bincode::{Decode, Encode};
11use url::Url;
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use crate::{Error, Result};
17
/// Network port number; an alias for `u16`, so any value in `0..=65535`.
pub type Port = u16;
20
21/// Endpoint defines generic network endpoints for karyon.
22///
23/// # Example
24///
25/// ```
26/// use std::net::SocketAddr;
27///
28/// use karyon_net::Endpoint;
29///
30/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
31///
32/// let socketaddr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
33/// let endpoint =  Endpoint::new_udp_addr(socketaddr);
34///
35/// ```
36///
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
39#[cfg_attr(feature = "serde", serde(into = "String"))]
40pub enum Endpoint {
41    Udp(Addr, Port),
42    Tcp(Addr, Port),
43    Tls(Addr, Port),
44    Ws(Addr, Port),
45    Wss(Addr, Port),
46    Unix(PathBuf),
47}
48
49impl Endpoint {
50    /// Creates a new TCP endpoint from a `SocketAddr`.
51    pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
52        Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
53    }
54
55    /// Creates a new UDP endpoint from a `SocketAddr`.
56    pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
57        Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
58    }
59
60    /// Creates a new TLS endpoint from a `SocketAddr`.
61    pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
62        Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
63    }
64
65    /// Creates a new WS endpoint from a `SocketAddr`.
66    pub fn new_ws_addr(addr: SocketAddr) -> Endpoint {
67        Endpoint::Ws(Addr::Ip(addr.ip()), addr.port())
68    }
69
70    /// Creates a new WSS endpoint from a `SocketAddr`.
71    pub fn new_wss_addr(addr: SocketAddr) -> Endpoint {
72        Endpoint::Wss(Addr::Ip(addr.ip()), addr.port())
73    }
74
75    /// Creates a new Unix endpoint from a `UnixSocketAddr`.
76    pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
77        Endpoint::Unix(addr.to_path_buf())
78    }
79
80    #[inline]
81    /// Checks if the `Endpoint` is of type `Tcp`.
82    pub fn is_tcp(&self) -> bool {
83        matches!(self, Endpoint::Tcp(..))
84    }
85
86    #[inline]
87    /// Checks if the `Endpoint` is of type `Tls`.
88    pub fn is_tls(&self) -> bool {
89        matches!(self, Endpoint::Tls(..))
90    }
91
92    #[inline]
93    /// Checks if the `Endpoint` is of type `Ws`.
94    pub fn is_ws(&self) -> bool {
95        matches!(self, Endpoint::Ws(..))
96    }
97
98    #[inline]
99    /// Checks if the `Endpoint` is of type `Wss`.
100    pub fn is_wss(&self) -> bool {
101        matches!(self, Endpoint::Wss(..))
102    }
103
104    #[inline]
105    /// Checks if the `Endpoint` is of type `Udp`.
106    pub fn is_udp(&self) -> bool {
107        matches!(self, Endpoint::Udp(..))
108    }
109
110    #[inline]
111    /// Checks if the `Endpoint` is of type `Unix`.
112    pub fn is_unix(&self) -> bool {
113        matches!(self, Endpoint::Unix(..))
114    }
115
116    /// Returns the `Port` of the endpoint.
117    pub fn port(&self) -> Result<&Port> {
118        match self {
119            Endpoint::Tcp(_, port)
120            | Endpoint::Udp(_, port)
121            | Endpoint::Tls(_, port)
122            | Endpoint::Ws(_, port)
123            | Endpoint::Wss(_, port) => Ok(port),
124            _ => Err(Error::TryFromEndpoint),
125        }
126    }
127
128    /// Returns the `Addr` of the endpoint.
129    pub fn addr(&self) -> Result<&Addr> {
130        match self {
131            Endpoint::Tcp(addr, _)
132            | Endpoint::Udp(addr, _)
133            | Endpoint::Tls(addr, _)
134            | Endpoint::Ws(addr, _)
135            | Endpoint::Wss(addr, _) => Ok(addr),
136            _ => Err(Error::TryFromEndpoint),
137        }
138    }
139}
140
141impl std::fmt::Display for Endpoint {
142    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
143        match self {
144            Endpoint::Udp(ip, port) => {
145                write!(f, "udp://{}:{}", ip, port)
146            }
147            Endpoint::Tcp(ip, port) => {
148                write!(f, "tcp://{}:{}", ip, port)
149            }
150            Endpoint::Tls(ip, port) => {
151                write!(f, "tls://{}:{}", ip, port)
152            }
153            Endpoint::Ws(ip, port) => {
154                write!(f, "ws://{}:{}", ip, port)
155            }
156            Endpoint::Wss(ip, port) => {
157                write!(f, "wss://{}:{}", ip, port)
158            }
159            Endpoint::Unix(path) => {
160                write!(f, "unix:/{}", path.to_string_lossy())
161            }
162        }
163    }
164}
165impl From<Endpoint> for String {
166    fn from(endpoint: Endpoint) -> String {
167        endpoint.to_string()
168    }
169}
170
171impl TryFrom<Endpoint> for SocketAddr {
172    type Error = Error;
173    fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
174        match endpoint {
175            Endpoint::Udp(ip, port)
176            | Endpoint::Tcp(ip, port)
177            | Endpoint::Tls(ip, port)
178            | Endpoint::Ws(ip, port)
179            | Endpoint::Wss(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
180            Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
181        }
182    }
183}
184
185impl TryFrom<Endpoint> for PathBuf {
186    type Error = Error;
187    fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> {
188        match endpoint {
189            Endpoint::Unix(path) => Ok(PathBuf::from(&path)),
190            _ => Err(Error::TryFromEndpoint),
191        }
192    }
193}
194
#[cfg(all(feature = "unix", target_family = "unix"))]
impl TryFrom<Endpoint> for UnixSocketAddr {
    type Error = Error;

    /// Converts a `Unix` endpoint into a pathname-based Unix socket address.
    ///
    /// # Errors
    ///
    /// - `Error::TryFromEndpoint` for network (non-`Unix`) endpoints.
    /// - The I/O error from `UnixSocketAddr::from_pathname`, converted with `?`.
    fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddr, Self::Error> {
        if let Endpoint::Unix(path) = endpoint {
            let socket_addr = UnixSocketAddr::from_pathname(path)?;
            Ok(socket_addr)
        } else {
            Err(Error::TryFromEndpoint)
        }
    }
}
205
206impl FromStr for Endpoint {
207    type Err = Error;
208
209    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
210        let url: Url = match s.parse() {
211            Ok(u) => u,
212            Err(err) => return Err(Error::ParseEndpoint(err.to_string())),
213        };
214
215        if url.has_host() {
216            let host = url.host_str().unwrap();
217
218            let addr = match host.parse::<IpAddr>() {
219                Ok(addr) => Addr::Ip(addr),
220                Err(_) => Addr::Domain(host.to_string()),
221            };
222
223            let port = match url.port() {
224                Some(p) => p,
225                None => return Err(Error::ParseEndpoint(format!("port missing: {s}"))),
226            };
227
228            match url.scheme() {
229                "tcp" => Ok(Endpoint::Tcp(addr, port)),
230                "udp" => Ok(Endpoint::Udp(addr, port)),
231                "tls" => Ok(Endpoint::Tls(addr, port)),
232                "ws" => Ok(Endpoint::Ws(addr, port)),
233                "wss" => Ok(Endpoint::Wss(addr, port)),
234                _ => Err(Error::UnsupportedEndpoint(s.to_string())),
235            }
236        } else {
237            if url.path().is_empty() {
238                return Err(Error::UnsupportedEndpoint(s.to_string()));
239            }
240
241            match url.scheme() {
242                "unix" => Ok(Endpoint::Unix(url.path().into())),
243                _ => Err(Error::UnsupportedEndpoint(s.to_string())),
244            }
245        }
246    }
247}
248
249/// Addr defines a type for an address, either IP or domain.
250#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
251#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
252pub enum Addr {
253    Ip(IpAddr),
254    Domain(String),
255}
256
257impl TryFrom<Addr> for IpAddr {
258    type Error = std::io::Error;
259    fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> {
260        match addr {
261            Addr::Ip(ip) => Ok(ip),
262            Addr::Domain(_) => Err(std::io::ErrorKind::Unsupported.into()),
263        }
264    }
265}
266
267impl std::fmt::Display for Addr {
268    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
269        match self {
270            Addr::Ip(ip) => {
271                write!(f, "{}", ip)
272            }
273            Addr::Domain(d) => {
274                write!(f, "{}", d)
275            }
276        }
277    }
278}
279
280pub trait ToEndpoint {
281    fn to_endpoint(&self) -> Result<Endpoint>;
282}
283
284impl ToEndpoint for String {
285    fn to_endpoint(&self) -> Result<Endpoint> {
286        Endpoint::from_str(self)
287    }
288}
289
290impl ToEndpoint for Endpoint {
291    fn to_endpoint(&self) -> Result<Endpoint> {
292        Ok(self.clone())
293    }
294}
295
296impl ToEndpoint for &str {
297    fn to_endpoint(&self) -> Result<Endpoint> {
298        Endpoint::from_str(self)
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::path::PathBuf;

    /// `127.0.0.1` as an endpoint address.
    fn localhost() -> Addr {
        Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
    }

    #[test]
    fn test_endpoint_from_str() {
        let cases = [
            ("tcp://127.0.0.1:3000", Endpoint::Tcp(localhost(), 3000)),
            ("udp://127.0.0.1:4000", Endpoint::Udp(localhost(), 4000)),
            ("tls://127.0.0.1:5000", Endpoint::Tls(localhost(), 5000)),
            ("ws://127.0.0.1:6000", Endpoint::Ws(localhost(), 6000)),
            ("wss://127.0.0.1:7000", Endpoint::Wss(localhost(), 7000)),
            (
                "tcp://example.com:3000",
                Endpoint::Tcp(Addr::Domain("example.com".to_string()), 3000),
            ),
            (
                "unix:/home/x/s.socket",
                Endpoint::Unix(PathBuf::from("/home/x/s.socket")),
            ),
        ];

        for (input, expected) in cases.iter() {
            let parsed: Endpoint = input.parse().unwrap();
            assert_eq!(&parsed, expected, "parsing {input}");
        }
    }

    #[test]
    fn test_endpoint_from_str_errors() {
        // Host given but no port, on a scheme without a default port.
        assert!("tcp://127.0.0.1".parse::<Endpoint>().is_err());
        // Scheme that karyon does not support.
        assert!("foo://127.0.0.1:3000".parse::<Endpoint>().is_err());
        // Not a URL at all.
        assert!("127.0.0.1:3000".parse::<Endpoint>().is_err());
    }

    #[test]
    fn test_endpoint_accessors() {
        let tcp = Endpoint::Tcp(localhost(), 3000);
        assert_eq!(*tcp.port().unwrap(), 3000);
        assert_eq!(tcp.addr().unwrap(), &localhost());

        let unix = Endpoint::Unix(PathBuf::from("/tmp/s.socket"));
        assert!(unix.port().is_err());
        assert!(unix.addr().is_err());
        assert_eq!(
            PathBuf::try_from(unix).unwrap(),
            PathBuf::from("/tmp/s.socket")
        );
    }

    #[test]
    fn test_endpoint_to_socket_addr() {
        let addr = SocketAddr::try_from(Endpoint::Tcp(localhost(), 3000)).unwrap();
        assert_eq!(addr, "127.0.0.1:3000".parse::<SocketAddr>().unwrap());

        let domain = Endpoint::Tcp(Addr::Domain("example.com".to_string()), 3000);
        assert!(SocketAddr::try_from(domain).is_err());
    }

    #[test]
    fn test_endpoint_display_round_trip() {
        let endpoint = Endpoint::Tcp(localhost(), 3000);
        assert_eq!(endpoint.to_string(), "tcp://127.0.0.1:3000");
        assert_eq!(endpoint.to_string().parse::<Endpoint>().unwrap(), endpoint);
    }
}