karyon_net/transports/
udp.rs

1use std::net::SocketAddr;
2
3use async_trait::async_trait;
4use karyon_core::async_runtime::net::UdpSocket;
5
6use crate::{
7    codec::{Buffer, Codec},
8    connection::{Conn, Connection, ToConn},
9    endpoint::Endpoint,
10    Result,
11};
12
/// UDP configuration.
///
/// Currently carries no options; it exists so the UDP transport shares the
/// same `dial`/`listen` signature shape as other transports.
#[derive(Debug, Clone, Default)]
pub struct UdpConfig {}
16
17/// UDP network connection implementation of the [`Connection`] trait.
18#[allow(dead_code)]
19pub struct UdpConn<C> {
20    inner: UdpSocket,
21    codec: C,
22    config: UdpConfig,
23}
24
25impl<C> UdpConn<C>
26where
27    C: Codec + Clone,
28{
29    /// Creates a new UdpConn
30    fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self {
31        Self {
32            inner: socket,
33            codec,
34            config,
35        }
36    }
37}
38
39#[async_trait]
40impl<C, E> Connection for UdpConn<C>
41where
42    C: Codec<Error = E> + Clone,
43    E: From<std::io::Error>,
44{
45    type Message = (C::Message, Endpoint);
46    type Error = E;
47    fn peer_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
48        Ok(self.inner.peer_addr().map(Endpoint::new_udp_addr)?)
49    }
50
51    fn local_endpoint(&self) -> std::result::Result<Endpoint, Self::Error> {
52        Ok(self.inner.local_addr().map(Endpoint::new_udp_addr)?)
53    }
54
55    async fn recv(&self) -> std::result::Result<Self::Message, Self::Error> {
56        let mut buf = Buffer::new();
57        let (_, addr) = self.inner.recv_from(buf.as_mut()).await?;
58        match self.codec.decode(&mut buf)? {
59            Some((_, msg)) => Ok((msg, Endpoint::new_udp_addr(addr))),
60            None => Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted).into()),
61        }
62    }
63
64    async fn send(&self, msg: Self::Message) -> std::result::Result<(), Self::Error> {
65        let (msg, out_addr) = msg;
66        let mut buf = Buffer::new();
67        self.codec.encode(&msg, &mut buf)?;
68        let addr: SocketAddr = out_addr
69            .try_into()
70            .map_err(|_| std::io::Error::other("Convert Endpoint to SocketAddress"))?;
71        self.inner.send_to(buf.as_ref(), addr).await?;
72        Ok(())
73    }
74}
75
76/// Connects to the given UDP address and port.
77pub async fn dial<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
78where
79    C: Codec + Clone,
80{
81    let addr = SocketAddr::try_from(endpoint.clone())?;
82
83    // Let the operating system assign an available port to this socket
84    let conn = UdpSocket::bind("[::]:0").await?;
85    conn.connect(addr).await?;
86    Ok(UdpConn::new(conn, config, codec))
87}
88
89/// Listens on the given UDP address and port.
90pub async fn listen<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
91where
92    C: Codec + Clone,
93{
94    let addr = SocketAddr::try_from(endpoint.clone())?;
95    let conn = UdpSocket::bind(addr).await?;
96    Ok(UdpConn::new(conn, config, codec))
97}
98
99impl<C, E> ToConn for UdpConn<C>
100where
101    C: Codec<Error = E> + Clone + 'static,
102    E: From<std::io::Error>,
103{
104    type Message = (C::Message, Endpoint);
105    type Error = E;
106    fn to_conn(self) -> Conn<Self::Message, Self::Error> {
107        Box::new(self)
108    }
109}