1use std::{
2 net::{IpAddr, SocketAddr},
3 path::PathBuf,
4 str::FromStr,
5};
6
7#[cfg(all(target_family = "unix", feature = "unix"))]
8use std::os::unix::net::SocketAddr as UnixSocketAddr;
9
10use bincode::{Decode, Encode};
11use url::Url;
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use crate::{Error, Result};
17
/// Network port number (a `u16`, so 0–65535).
pub type Port = u16;
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
39#[cfg_attr(feature = "serde", serde(into = "String"))]
40pub enum Endpoint {
41 Udp(Addr, Port),
42 Tcp(Addr, Port),
43 Tls(Addr, Port),
44 Ws(Addr, Port),
45 Wss(Addr, Port),
46 Unix(PathBuf),
47}
48
49impl Endpoint {
50 pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
52 Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
53 }
54
55 pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
57 Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
58 }
59
60 pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
62 Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
63 }
64
65 pub fn new_ws_addr(addr: SocketAddr) -> Endpoint {
67 Endpoint::Ws(Addr::Ip(addr.ip()), addr.port())
68 }
69
70 pub fn new_wss_addr(addr: SocketAddr) -> Endpoint {
72 Endpoint::Wss(Addr::Ip(addr.ip()), addr.port())
73 }
74
75 pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
77 Endpoint::Unix(addr.to_path_buf())
78 }
79
80 #[inline]
81 pub fn is_tcp(&self) -> bool {
83 matches!(self, Endpoint::Tcp(..))
84 }
85
86 #[inline]
87 pub fn is_tls(&self) -> bool {
89 matches!(self, Endpoint::Tls(..))
90 }
91
92 #[inline]
93 pub fn is_ws(&self) -> bool {
95 matches!(self, Endpoint::Ws(..))
96 }
97
98 #[inline]
99 pub fn is_wss(&self) -> bool {
101 matches!(self, Endpoint::Wss(..))
102 }
103
104 #[inline]
105 pub fn is_udp(&self) -> bool {
107 matches!(self, Endpoint::Udp(..))
108 }
109
110 #[inline]
111 pub fn is_unix(&self) -> bool {
113 matches!(self, Endpoint::Unix(..))
114 }
115
116 pub fn port(&self) -> Result<&Port> {
118 match self {
119 Endpoint::Tcp(_, port)
120 | Endpoint::Udp(_, port)
121 | Endpoint::Tls(_, port)
122 | Endpoint::Ws(_, port)
123 | Endpoint::Wss(_, port) => Ok(port),
124 _ => Err(Error::TryFromEndpoint),
125 }
126 }
127
128 pub fn addr(&self) -> Result<&Addr> {
130 match self {
131 Endpoint::Tcp(addr, _)
132 | Endpoint::Udp(addr, _)
133 | Endpoint::Tls(addr, _)
134 | Endpoint::Ws(addr, _)
135 | Endpoint::Wss(addr, _) => Ok(addr),
136 _ => Err(Error::TryFromEndpoint),
137 }
138 }
139}
140
141impl std::fmt::Display for Endpoint {
142 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
143 match self {
144 Endpoint::Udp(ip, port) => {
145 write!(f, "udp://{}:{}", ip, port)
146 }
147 Endpoint::Tcp(ip, port) => {
148 write!(f, "tcp://{}:{}", ip, port)
149 }
150 Endpoint::Tls(ip, port) => {
151 write!(f, "tls://{}:{}", ip, port)
152 }
153 Endpoint::Ws(ip, port) => {
154 write!(f, "ws://{}:{}", ip, port)
155 }
156 Endpoint::Wss(ip, port) => {
157 write!(f, "wss://{}:{}", ip, port)
158 }
159 Endpoint::Unix(path) => {
160 write!(f, "unix:/{}", path.to_string_lossy())
161 }
162 }
163 }
164}
165impl From<Endpoint> for String {
166 fn from(endpoint: Endpoint) -> String {
167 endpoint.to_string()
168 }
169}
170
171impl TryFrom<Endpoint> for SocketAddr {
172 type Error = Error;
173 fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
174 match endpoint {
175 Endpoint::Udp(ip, port)
176 | Endpoint::Tcp(ip, port)
177 | Endpoint::Tls(ip, port)
178 | Endpoint::Ws(ip, port)
179 | Endpoint::Wss(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
180 Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
181 }
182 }
183}
184
185impl TryFrom<Endpoint> for PathBuf {
186 type Error = Error;
187 fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> {
188 match endpoint {
189 Endpoint::Unix(path) => Ok(PathBuf::from(&path)),
190 _ => Err(Error::TryFromEndpoint),
191 }
192 }
193}
194
/// Converts a `Unix` endpoint into a standard-library Unix socket address.
///
/// # Errors
/// - [`Error::TryFromEndpoint`] for any non-`Unix` endpoint.
/// - The I/O error returned by `UnixSocketAddr::from_pathname` (converted into
///   [`Error`] by `?`) when that function rejects the path.
#[cfg(all(feature = "unix", target_family = "unix"))]
impl TryFrom<Endpoint> for UnixSocketAddr {
    type Error = Error;
    fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddr, Self::Error> {
        if let Endpoint::Unix(path) = endpoint {
            let socket_addr = UnixSocketAddr::from_pathname(path)?;
            Ok(socket_addr)
        } else {
            Err(Error::TryFromEndpoint)
        }
    }
}
205
206impl FromStr for Endpoint {
207 type Err = Error;
208
209 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
210 let url: Url = match s.parse() {
211 Ok(u) => u,
212 Err(err) => return Err(Error::ParseEndpoint(err.to_string())),
213 };
214
215 if url.has_host() {
216 let host = url.host_str().unwrap();
217
218 let addr = match host.parse::<IpAddr>() {
219 Ok(addr) => Addr::Ip(addr),
220 Err(_) => Addr::Domain(host.to_string()),
221 };
222
223 let port = match url.port() {
224 Some(p) => p,
225 None => return Err(Error::ParseEndpoint(format!("port missing: {s}"))),
226 };
227
228 match url.scheme() {
229 "tcp" => Ok(Endpoint::Tcp(addr, port)),
230 "udp" => Ok(Endpoint::Udp(addr, port)),
231 "tls" => Ok(Endpoint::Tls(addr, port)),
232 "ws" => Ok(Endpoint::Ws(addr, port)),
233 "wss" => Ok(Endpoint::Wss(addr, port)),
234 _ => Err(Error::UnsupportedEndpoint(s.to_string())),
235 }
236 } else {
237 if url.path().is_empty() {
238 return Err(Error::UnsupportedEndpoint(s.to_string()));
239 }
240
241 match url.scheme() {
242 "unix" => Ok(Endpoint::Unix(url.path().into())),
243 _ => Err(Error::UnsupportedEndpoint(s.to_string())),
244 }
245 }
246 }
247}
248
/// Host part of a network [`Endpoint`]: either a literal IP address or a
/// domain name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Addr {
    /// A literal IPv4 or IPv6 address.
    Ip(IpAddr),
    /// A host name. Endpoint parsing stores any host that is not an IP literal
    /// here. Converting it to an `IpAddr` via `TryFrom` fails; no lookup is done.
    Domain(String),
}
256
257impl TryFrom<Addr> for IpAddr {
258 type Error = std::io::Error;
259 fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> {
260 match addr {
261 Addr::Ip(ip) => Ok(ip),
262 Addr::Domain(_) => Err(std::io::ErrorKind::Unsupported.into()),
263 }
264 }
265}
266
267impl std::fmt::Display for Addr {
268 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
269 match self {
270 Addr::Ip(ip) => {
271 write!(f, "{}", ip)
272 }
273 Addr::Domain(d) => {
274 write!(f, "{}", d)
275 }
276 }
277 }
278}
279
/// Conversion of a value into an [`Endpoint`].
///
/// In this module it is implemented for `String` and `&str` (parsed via
/// `FromStr`) and for `Endpoint` itself (cloned).
pub trait ToEndpoint {
    /// Produces an [`Endpoint`] from `self`, returning an error if `self`
    /// cannot be interpreted as one.
    fn to_endpoint(&self) -> Result<Endpoint>;
}
283
284impl ToEndpoint for String {
285 fn to_endpoint(&self) -> Result<Endpoint> {
286 Endpoint::from_str(self)
287 }
288}
289
290impl ToEndpoint for Endpoint {
291 fn to_endpoint(&self) -> Result<Endpoint> {
292 Ok(self.clone())
293 }
294}
295
296impl ToEndpoint for &str {
297 fn to_endpoint(&self) -> Result<Endpoint> {
298 Endpoint::from_str(self)
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::path::PathBuf;

    #[test]
    fn test_endpoint_from_str() {
        // 127.0.0.1 wrapped as an `Addr`.
        let loopback = || Addr::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST));

        let cases = [
            ("tcp://127.0.0.1:3000", Endpoint::Tcp(loopback(), 3000)),
            ("udp://127.0.0.1:4000", Endpoint::Udp(loopback(), 4000)),
            (
                "tcp://example.com:3000",
                Endpoint::Tcp(Addr::Domain("example.com".to_string()), 3000),
            ),
            (
                "unix:/home/x/s.socket",
                Endpoint::Unix(PathBuf::from("/home/x/s.socket")),
            ),
        ];

        for (input, expected) in cases {
            let parsed: Endpoint = input.parse().unwrap();
            assert_eq!(parsed, expected, "parsing {input}");
        }
    }
}