Skip to main content

karyon_net/transports/
tcp.rs

1use std::{
2    io,
3    net::SocketAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use karyon_core::async_runtime::{
9    io::{AsyncRead, AsyncWrite},
10    net::{TcpListener as AsyncTcpListener, TcpStream},
11};
12
13#[cfg(feature = "tokio")]
14use tokio::io::ReadBuf;
15
16use crate::{ByteStream, Endpoint, Result};
17
/// TCP socket options applied to every stream created by [`connect`] and
/// [`TcpListener::accept`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TcpConfig {
    /// Value passed to `set_nodelay` on each stream. `true` disables Nagle's
    /// algorithm, so small writes are sent immediately instead of being
    /// coalesced.
    pub nodelay: bool,
}

impl Default for TcpConfig {
    /// Returns a config with `nodelay: true`.
    fn default() -> Self {
        Self { nodelay: true }
    }
}
29
30/// TCP stream implementing `ByteStream`.
31pub struct TcpByteStream {
32    inner: TcpStream,
33    peer_endpoint: Endpoint,
34    local_endpoint: Endpoint,
35}
36
37impl ByteStream for TcpByteStream {
38    fn peer_endpoint(&self) -> Option<Endpoint> {
39        Some(self.peer_endpoint.clone())
40    }
41    fn local_endpoint(&self) -> Option<Endpoint> {
42        Some(self.local_endpoint.clone())
43    }
44}
45
46/// Connect to a TCP endpoint.
47pub async fn connect(endpoint: &Endpoint, config: TcpConfig) -> Result<Box<dyn ByteStream>> {
48    let addr = SocketAddr::try_from(endpoint.clone())?;
49    let socket = TcpStream::connect(addr).await?;
50    socket.set_nodelay(config.nodelay)?;
51    let peer = Endpoint::new_tcp_addr(socket.peer_addr()?);
52    let local = Endpoint::new_tcp_addr(socket.local_addr()?);
53    Ok(Box::new(TcpByteStream {
54        inner: socket,
55        peer_endpoint: peer,
56        local_endpoint: local,
57    }))
58}
59
60/// TCP listener. Accepts connections as `Box<dyn ByteStream>`.
61pub struct TcpListener {
62    inner: AsyncTcpListener,
63    config: TcpConfig,
64}
65
66impl TcpListener {
67    /// Bind to a TCP endpoint.
68    pub async fn bind(endpoint: &Endpoint, config: TcpConfig) -> Result<Self> {
69        let addr = SocketAddr::try_from(endpoint.clone())?;
70        let inner = AsyncTcpListener::bind(addr).await?;
71        Ok(Self { inner, config })
72    }
73
74    /// Accept a new connection.
75    pub async fn accept(&self) -> Result<Box<dyn ByteStream>> {
76        let (socket, _) = self.inner.accept().await?;
77        socket.set_nodelay(self.config.nodelay)?;
78        let peer = Endpoint::new_tcp_addr(socket.peer_addr()?);
79        let local = Endpoint::new_tcp_addr(socket.local_addr()?);
80        Ok(Box::new(TcpByteStream {
81            inner: socket,
82            peer_endpoint: peer,
83            local_endpoint: local,
84        }))
85    }
86
87    /// Local endpoint this listener is bound to.
88    pub fn local_endpoint(&self) -> Result<Endpoint> {
89        Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?))
90    }
91}
92
// -- AsyncRead / AsyncWrite: delegate to the inner runtime `TcpStream` --
94
/// Forwards reads to the inner smol `TcpStream`.
#[cfg(feature = "smol")]
impl AsyncRead for TcpByteStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }

    /// Forwarded explicitly. The trait's default implementation reads into
    /// only the first non-empty buffer, which would hide any scatter-read
    /// support of the inner stream.
    fn poll_read_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [io::IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_read_vectored(cx, bufs)
    }
}
104}
105
/// Tokio read path: every poll is handed straight to the wrapped stream.
#[cfg(feature = "tokio")]
impl AsyncRead for TcpByteStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `TcpByteStream` is `Unpin`, so the pin can be unwrapped to reach the field.
        let this = self.get_mut();
        let stream = Pin::new(&mut this.inner);
        stream.poll_read(cx, buf)
    }
}
116
/// Forwards writes, flushes and close to the inner smol `TcpStream`.
#[cfg(feature = "smol")]
impl AsyncWrite for TcpByteStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwarded explicitly. The trait's default implementation writes only
    /// the first non-empty buffer, which would prevent batched buffers from
    /// reaching the socket together.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}
135
/// Forwards writes, flushes and shutdown to the inner tokio `TcpStream`.
#[cfg(feature = "tokio")]
impl AsyncWrite for TcpByteStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwarded explicitly. The trait's default implementation writes only
    /// the first non-empty buffer, which would prevent batched buffers from
    /// reaching the socket together.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Reports the inner stream's capability instead of the trait default
    /// (`false`), so callers that check it can choose vectored writes.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}
153}