karyon_net/transports/unix.rs

1use async_trait::async_trait;
2use futures_util::SinkExt;
3
4use karyon_core::async_runtime::{
5    io::{split, ReadHalf, WriteHalf},
6    lock::Mutex,
7    net::{UnixListener as AsyncUnixListener, UnixStream},
8};
9
10use crate::{
11    codec::Codec,
12    connection::{Conn, Connection, ToConn},
13    endpoint::Endpoint,
14    listener::{ConnListener, Listener, ToListener},
15    stream::{ReadStream, WriteStream},
16    Result,
17};
18
/// Configuration for Unix domain socket connections and listeners.
///
/// Currently carries no options; it exists so the Unix transport exposes the
/// same `dial`/`listen` signature shape as the other transports.
#[derive(Clone, Debug, Default)]
pub struct UnixConfig {}
22
/// Unix domain socket implementation of the [`Connection`] trait.
///
/// The underlying stream is split into independent read and write halves,
/// each behind its own async mutex, so `recv` and `send` can both be called
/// through `&self`.
pub struct UnixConn<C> {
    // Read half, framed by the codec; locked for the duration of each `recv`.
    read_stream: Mutex<ReadStream<ReadHalf<UnixStream>, C>>,
    // Write half, framed by the codec; locked for the duration of each `send`.
    write_stream: Mutex<WriteStream<WriteHalf<UnixStream>, C>>,
    // Peer socket path captured at construction; `None` if it was unavailable
    // or not a filesystem path.
    peer_endpoint: Option<Endpoint>,
    // Local socket path captured at construction; `None` if it was unavailable
    // or not a filesystem path.
    local_endpoint: Option<Endpoint>,
}
30
31impl<C> UnixConn<C>
32where
33    C: Codec + Clone,
34{
35    /// Creates a new TcpConn
36    pub fn new(conn: UnixStream, codec: C) -> Self {
37        let peer_endpoint = conn
38            .peer_addr()
39            .and_then(|a| {
40                Ok(Endpoint::new_unix_addr(
41                    a.as_pathname()
42                        .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
43                ))
44            })
45            .ok();
46        let local_endpoint = conn
47            .local_addr()
48            .and_then(|a| {
49                Ok(Endpoint::new_unix_addr(
50                    a.as_pathname()
51                        .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
52                ))
53            })
54            .ok();
55
56        let (read, write) = split(conn);
57        let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
58        let write_stream = Mutex::new(WriteStream::new(write, codec));
59        Self {
60            read_stream,
61            write_stream,
62            peer_endpoint,
63            local_endpoint,
64        }
65    }
66}
67
68#[async_trait]
69impl<C, E> Connection for UnixConn<C>
70where
71    C: Codec<Error = E> + Clone,
72    E: From<std::io::Error>,
73{
74    type Message = C::Message;
75    type Error = E;
76    fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
77        Ok(self
78            .peer_endpoint
79            .clone()
80            .ok_or(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))?)
81    }
82
83    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
84        Ok(self
85            .local_endpoint
86            .clone()
87            .ok_or(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))?)
88    }
89
90    async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
91        self.read_stream.lock().await.recv().await
92    }
93
94    async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
95        self.write_stream.lock().await.send(msg).await
96    }
97}
98
99#[allow(dead_code)]
100pub struct UnixListener<C> {
101    inner: AsyncUnixListener,
102    config: UnixConfig,
103    codec: C,
104}
105
106impl<C> UnixListener<C>
107where
108    C: Codec + Clone,
109{
110    pub fn new(listener: AsyncUnixListener, config: UnixConfig, codec: C) -> Self {
111        Self {
112            inner: listener,
113            config,
114            codec,
115        }
116    }
117}
118
119#[async_trait]
120impl<C, E> ConnListener for UnixListener<C>
121where
122    C: Codec<Error = E> + Clone + 'static,
123    E: From<std::io::Error>,
124{
125    type Message = C::Message;
126    type Error = E;
127    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
128        Ok(self.inner.local_addr().and_then(|a| {
129            Ok(Endpoint::new_unix_addr(
130                a.as_pathname()
131                    .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
132            ))
133        })?)
134    }
135
136    async fn accept(&self) -> std::result::Result<Conn<C::Message, E>, Self::Error> {
137        let (conn, _) = self.inner.accept().await?;
138        Ok(Box::new(UnixConn::new(conn, self.codec.clone())))
139    }
140}
141
142/// Connects to the given Unix socket path.
143pub async fn dial<C>(endpoint: &Endpoint, _config: UnixConfig, codec: C) -> Result<UnixConn<C>>
144where
145    C: Codec + Clone,
146{
147    let path: std::path::PathBuf = endpoint.clone().try_into()?;
148    let conn = UnixStream::connect(path).await?;
149    Ok(UnixConn::new(conn, codec))
150}
151
152/// Listens on the given Unix socket path.
153pub fn listen<C>(endpoint: &Endpoint, config: UnixConfig, codec: C) -> Result<UnixListener<C>>
154where
155    C: Codec + Clone,
156{
157    let path: std::path::PathBuf = endpoint.clone().try_into()?;
158    let listener = AsyncUnixListener::bind(path)?;
159    Ok(UnixListener::new(listener, config, codec))
160}
161
162// impl From<UnixStream> for Box<dyn Connection> {
163//     fn from(conn: UnixStream) -> Self {
164//         Box::new(UnixConn::new(conn))
165//     }
166// }
167
168impl<C, E> From<UnixListener<C>> for Listener<C::Message, E>
169where
170    C: Codec<Error = E> + Clone + 'static,
171    E: From<std::io::Error>,
172{
173    fn from(listener: UnixListener<C>) -> Self {
174        Box::new(listener)
175    }
176}
177
178impl<C, E> ToConn for UnixConn<C>
179where
180    C: Codec<Error = E> + Clone + 'static,
181    E: From<std::io::Error>,
182{
183    type Message = C::Message;
184    type Error = E;
185    fn to_conn(self) -> Conn<Self::Message, Self::Error> {
186        Box::new(self)
187    }
188}
189
190impl<C, E> ToListener for UnixListener<C>
191where
192    C: Codec<Error = E> + Clone + 'static,
193    E: From<std::io::Error>,
194{
195    type Message = C::Message;
196    type Error = E;
197    fn to_listener(self) -> Listener<Self::Message, Self::Error> {
198        Box::new(self)
199    }
200}