karyon_net/transports/
tcp.rs1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use karyon_core::async_runtime::{
9 io::{AsyncRead, AsyncWrite},
10 net::{TcpListener as AsyncTcpListener, TcpStream},
11};
12
13#[cfg(feature = "tokio")]
14use tokio::io::ReadBuf;
15
16use crate::{ByteStream, Endpoint, Result};
17
/// Socket options applied to every TCP stream created by this module,
/// whether it was opened with [`connect`] or obtained from
/// [`TcpListener::accept`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpConfig {
    /// Value passed to `set_nodelay` on each socket. `true` sets
    /// `TCP_NODELAY`, disabling Nagle's algorithm so small writes are sent
    /// immediately instead of being coalesced.
    pub nodelay: bool,
}

impl Default for TcpConfig {
    /// Returns a configuration with `nodelay` enabled.
    fn default() -> Self {
        Self { nodelay: true }
    }
}
29
30pub struct TcpByteStream {
32 inner: TcpStream,
33 peer_endpoint: Endpoint,
34 local_endpoint: Endpoint,
35}
36
37impl ByteStream for TcpByteStream {
38 fn peer_endpoint(&self) -> Option<Endpoint> {
39 Some(self.peer_endpoint.clone())
40 }
41 fn local_endpoint(&self) -> Option<Endpoint> {
42 Some(self.local_endpoint.clone())
43 }
44}
45
46pub async fn connect(endpoint: &Endpoint, config: TcpConfig) -> Result<Box<dyn ByteStream>> {
48 let addr = SocketAddr::try_from(endpoint.clone())?;
49 let socket = TcpStream::connect(addr).await?;
50 socket.set_nodelay(config.nodelay)?;
51 let peer = Endpoint::new_tcp_addr(socket.peer_addr()?);
52 let local = Endpoint::new_tcp_addr(socket.local_addr()?);
53 Ok(Box::new(TcpByteStream {
54 inner: socket,
55 peer_endpoint: peer,
56 local_endpoint: local,
57 }))
58}
59
60pub struct TcpListener {
62 inner: AsyncTcpListener,
63 config: TcpConfig,
64}
65
66impl TcpListener {
67 pub async fn bind(endpoint: &Endpoint, config: TcpConfig) -> Result<Self> {
69 let addr = SocketAddr::try_from(endpoint.clone())?;
70 let inner = AsyncTcpListener::bind(addr).await?;
71 Ok(Self { inner, config })
72 }
73
74 pub async fn accept(&self) -> Result<Box<dyn ByteStream>> {
76 let (socket, _) = self.inner.accept().await?;
77 socket.set_nodelay(self.config.nodelay)?;
78 let peer = Endpoint::new_tcp_addr(socket.peer_addr()?);
79 let local = Endpoint::new_tcp_addr(socket.local_addr()?);
80 Ok(Box::new(TcpByteStream {
81 inner: socket,
82 peer_endpoint: peer,
83 local_endpoint: local,
84 }))
85 }
86
87 pub fn local_endpoint(&self) -> Result<Endpoint> {
89 Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?))
90 }
91}
92
#[cfg(feature = "smol")]
impl AsyncRead for TcpByteStream {
    /// Reads from the underlying TCP socket into `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
105
#[cfg(feature = "tokio")]
impl AsyncRead for TcpByteStream {
    /// Reads from the underlying TCP socket into `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
116
#[cfg(feature = "smol")]
impl AsyncWrite for TcpByteStream {
    /// Writes `buf` to the underlying TCP socket.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwards vectored writes to the underlying socket.
    ///
    /// Without this override, the trait's default implementation writes only
    /// the first non-empty buffer per call. Forwarding lets the inner socket
    /// handle all slices (e.g. with a single vectored write, if it supports
    /// one).
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Flushes the underlying TCP socket.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Closes the underlying TCP socket.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}
135
#[cfg(feature = "tokio")]
impl AsyncWrite for TcpByteStream {
    /// Writes `buf` to the underlying TCP socket.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwards vectored writes to the underlying socket.
    ///
    /// Without this override, the trait's default implementation writes only
    /// the first non-empty buffer per call.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Reports whether the underlying socket has an efficient vectored-write
    /// implementation.
    ///
    /// Without this, the wrapper would report `false`, and callers that check
    /// it would skip vectored writes.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Flushes the underlying TCP socket.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Shuts down the write half of the underlying TCP socket.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}