karyon_net/transports/tcp.rs

1use std::net::SocketAddr;
2
3use async_trait::async_trait;
4use futures_util::SinkExt;
5
6use karyon_core::async_runtime::{
7    io::{split, ReadHalf, WriteHalf},
8    lock::Mutex,
9    net::{TcpListener as AsyncTcpListener, TcpStream},
10};
11
12use crate::{
13    codec::Codec,
14    connection::{Conn, Connection, ToConn},
15    endpoint::Endpoint,
16    listener::{ConnListener, Listener, ToListener},
17    stream::{ReadStream, WriteStream},
18    Result,
19};
20
/// Socket options applied to TCP connections created by this transport.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TcpConfig {
    /// Value passed to `set_nodelay` on every dialed or accepted socket.
    /// `true` disables Nagle's algorithm so small writes are sent immediately.
    pub nodelay: bool,
}

impl Default for TcpConfig {
    /// Returns a configuration with `nodelay` enabled.
    fn default() -> Self {
        Self { nodelay: true }
    }
}
32
33/// TCP connection implementation of the [`Connection`] trait.
34pub struct TcpConn<C> {
35    read_stream: Mutex<ReadStream<ReadHalf<TcpStream>, C>>,
36    write_stream: Mutex<WriteStream<WriteHalf<TcpStream>, C>>,
37    peer_endpoint: Endpoint,
38    local_endpoint: Endpoint,
39}
40
41impl<C> TcpConn<C>
42where
43    C: Codec + Clone,
44{
45    /// Creates a new TcpConn
46    pub fn new(
47        socket: TcpStream,
48        codec: C,
49        peer_endpoint: Endpoint,
50        local_endpoint: Endpoint,
51    ) -> Self {
52        let (read, write) = split(socket);
53        let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
54        let write_stream = Mutex::new(WriteStream::new(write, codec));
55        Self {
56            read_stream,
57            write_stream,
58            peer_endpoint,
59            local_endpoint,
60        }
61    }
62}
63
64#[async_trait]
65impl<C, E> Connection for TcpConn<C>
66where
67    C: Codec<Error = E> + Clone,
68{
69    type Message = C::Message;
70    type Error = E;
71    fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
72        Ok(self.peer_endpoint.clone())
73    }
74
75    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
76        Ok(self.local_endpoint.clone())
77    }
78
79    async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
80        self.read_stream.lock().await.recv().await
81    }
82
83    async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
84        self.write_stream.lock().await.send(msg).await
85    }
86}
87
88pub struct TcpListener<C> {
89    inner: AsyncTcpListener,
90    config: TcpConfig,
91    codec: C,
92}
93
94impl<C> TcpListener<C>
95where
96    C: Codec,
97{
98    pub fn new(listener: AsyncTcpListener, config: TcpConfig, codec: C) -> Self {
99        Self {
100            inner: listener,
101            config: config.clone(),
102            codec,
103        }
104    }
105}
106
107#[async_trait]
108impl<C, E> ConnListener for TcpListener<C>
109where
110    C: Codec<Error = E> + Clone + 'static,
111    E: From<std::io::Error>,
112{
113    type Message = C::Message;
114    type Error = E;
115    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
116        Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?))
117    }
118
119    async fn accept(&self) -> std::result::Result<Conn<C::Message, C::Error>, Self::Error> {
120        let (socket, _) = self.inner.accept().await?;
121        socket.set_nodelay(self.config.nodelay)?;
122
123        let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?;
124        let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?;
125
126        Ok(Box::new(TcpConn::new(
127            socket,
128            self.codec.clone(),
129            peer_endpoint,
130            local_endpoint,
131        )))
132    }
133}
134
135/// Connects to the given TCP address and port.
136pub async fn dial<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpConn<C>>
137where
138    C: Codec + Clone,
139{
140    let addr = SocketAddr::try_from(endpoint.clone())?;
141    let socket = TcpStream::connect(addr).await?;
142    socket.set_nodelay(config.nodelay)?;
143
144    let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?;
145    let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?;
146
147    Ok(TcpConn::new(socket, codec, peer_endpoint, local_endpoint))
148}
149
150/// Listens on the given TCP address and port.
151pub async fn listen<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpListener<C>>
152where
153    C: Codec,
154{
155    let addr = SocketAddr::try_from(endpoint.clone())?;
156    let listener = AsyncTcpListener::bind(addr).await?;
157    Ok(TcpListener::new(listener, config, codec))
158}
159
160impl<C, E> From<TcpListener<C>> for Box<dyn ConnListener<Message = C::Message, Error = E>>
161where
162    C: Clone + Codec<Error = E> + 'static,
163    E: From<std::io::Error>,
164{
165    fn from(listener: TcpListener<C>) -> Self {
166        Box::new(listener)
167    }
168}
169
170impl<C, E> ToConn for TcpConn<C>
171where
172    C: Codec<Error = E> + Clone + 'static,
173{
174    type Message = C::Message;
175    type Error = E;
176    fn to_conn(self) -> Conn<Self::Message, Self::Error> {
177        Box::new(self)
178    }
179}
180
181impl<C, E> ToListener for TcpListener<C>
182where
183    C: Clone + Codec<Error = E> + 'static,
184    E: From<std::io::Error>,
185{
186    type Message = C::Message;
187    type Error = E;
188    fn to_listener(self) -> Listener<Self::Message, Self::Error> {
189        Box::new(self)
190    }
191}