// karyon_net/transports/udp.rs

use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

use async_trait::async_trait;
use karyon_core::async_runtime::net::UdpSocket;

use crate::{
    codec::{Buffer, Codec},
    connection::{Conn, Connection, ToConn},
    endpoint::Endpoint,
    Result,
};

/// Size, in bytes, of the buffer allocated for every datagram handled by
/// `recv` and `send` (64 KiB, i.e. 2^16 bytes).
const BUFFER_SIZE: usize = 1 << 16;

/// UDP configuration.
///
/// The transport currently exposes no options, so [`Default`] yields the only
/// possible value. `Debug` and `Clone` are derived so the config can be logged
/// and duplicated like other public configuration types.
#[derive(Debug, Clone, Default)]
pub struct UdpConfig {}

19/// UDP network connection implementation of the [`Connection`] trait.
20#[allow(dead_code)]
21pub struct UdpConn<C> {
22    inner: UdpSocket,
23    codec: C,
24    config: UdpConfig,
25}
26
27impl<C> UdpConn<C>
28where
29    C: Codec + Clone,
30{
31    /// Creates a new UdpConn
32    fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self {
33        Self {
34            inner: socket,
35            codec,
36            config,
37        }
38    }
39}
40
41#[async_trait]
42impl<C, E> Connection for UdpConn<C>
43where
44    C: Codec<Error = E> + Clone,
45    E: From<std::io::Error>,
46{
47    type Message = (C::Message, Endpoint);
48    type Error = E;
49    fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
50        Ok(self.inner.peer_addr().map(Endpoint::new_udp_addr)?)
51    }
52
53    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
54        Ok(self.inner.local_addr().map(Endpoint::new_udp_addr)?)
55    }
56
57    async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
58        let mut buf = Buffer::new(BUFFER_SIZE);
59        let (_, addr) = self.inner.recv_from(buf.as_mut()).await?;
60        match self.codec.decode(&mut buf)? {
61            Some((_, msg)) => Ok((msg, Endpoint::new_udp_addr(addr))),
62            None => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted).into()),
63        }
64    }
65
66    async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
67        let (msg, out_addr) = msg;
68        let mut buf = Buffer::new(BUFFER_SIZE);
69        self.codec.encode(&msg, &mut buf)?;
70        let addr: SocketAddr = out_addr
71            .try_into()
72            .map_err(|_| std::io::Error::other("Convert Endpoint to SocketAddress"))?;
73        self.inner.send_to(buf.as_ref(), addr).await?;
74        Ok(())
75    }
76}
77
78/// Connects to the given UDP address and port.
79pub async fn dial<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
80where
81    C: Codec + Clone,
82{
83    let addr = SocketAddr::try_from(endpoint.clone())?;
84
85    // Let the operating system assign an available port to this socket
86    let conn = UdpSocket::bind("[::]:0").await?;
87    conn.connect(addr).await?;
88    Ok(UdpConn::new(conn, config, codec))
89}
90
91/// Listens on the given UDP address and port.
92pub async fn listen<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
93where
94    C: Codec + Clone,
95{
96    let addr = SocketAddr::try_from(endpoint.clone())?;
97    let conn = UdpSocket::bind(addr).await?;
98    Ok(UdpConn::new(conn, config, codec))
99}
100
101impl<C, E> ToConn for UdpConn<C>
102where
103    C: Codec<Error = E> + Clone + 'static,
104    E: From<std::io::Error>,
105{
106    type Message = (C::Message, Endpoint);
107    type Error = E;
108    fn to_conn(self) -> Conn<Self::Message, Self::Error> {
109        Box::new(self)
110    }
111}