Skip to main content

karyon_p2p/protocols/
ping.rs

1use std::collections::VecDeque;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_channel::{Receiver, Sender};
6use async_trait::async_trait;
7use bincode::{Decode, Encode};
8use log::trace;
9use rand::{rngs::OsRng, TryRngCore};
10
11use karyon_core::async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult};
12
13use crate::{
14    protocol::{PeerConn, Protocol, ProtocolID, ProtocolKind},
15    util::{decode, encode},
16    version::Version,
17    Error, Result,
18};
19
/// Number of consecutive failed ping attempts tolerated before the ping
/// loop gives up and returns a timeout error.
const MAX_FAILURES: u32 = 3;

/// Maximum number of recently received Ping nonces the receive loop
/// remembers in order to drop replayed Pings.
const PING_DEDUP_WINDOW: usize = 64;

/// Protocol id for ping. Mandatory in every handshake.
pub(crate) const PING_PROTO_ID: &str = "PING";
27
/// Messages exchanged by the ping protocol.
///
/// Both variants carry a 32-byte nonce. A `Pong` is sent back with the
/// nonce of the `Ping` it answers, so the sender can match the reply.
#[derive(Clone, Debug, Encode, Decode)]
enum PingProtocolMsg {
    /// Liveness probe carrying a nonce.
    Ping([u8; 32]),
    /// Reply echoing the nonce of the `Ping` being answered.
    Pong([u8; 32]),
}
33
/// Ping protocol state for a single peer connection.
pub struct PingProtocol {
    /// Connection used to send and receive ping messages.
    peer: PeerConn,
    /// Seconds to sleep before each Ping is sent.
    ping_interval: u64,
    /// Seconds to wait for a Pong after a Ping has been sent.
    ping_timeout: u64,
    /// Task group running the ping loop; cancelled before `start` returns.
    task_group: TaskGroup,
}
40
41impl PingProtocol {
42    pub(crate) fn new(peer: PeerConn) -> Arc<Self> {
43        let cfg = peer.inner().config();
44        let executor = peer.inner().executor();
45        Arc::new(Self {
46            ping_interval: cfg.ping_interval,
47            ping_timeout: cfg.ping_timeout,
48            task_group: TaskGroup::with_executor(executor),
49            peer,
50        })
51    }
52
53    async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> {
54        // Ring buffer of recently-seen ping nonces. Drops replays so
55        // an attacker can't elicit unlimited Pongs by resending one.
56        let mut seen: VecDeque<[u8; 32]> = VecDeque::with_capacity(PING_DEDUP_WINDOW);
57
58        loop {
59            let payload = match self.peer.recv().await {
60                Ok(bytes) => bytes,
61                Err(Error::PeerShutdown) => break,
62                Err(e) => return Err(e),
63            };
64
65            let (msg, _) = decode::<PingProtocolMsg>(&payload)?;
66            match msg {
67                PingProtocolMsg::Ping(nonce) => {
68                    if seen.iter().any(|n| n == &nonce) {
69                        trace!("Drop replayed Ping {nonce:?}");
70                        continue;
71                    }
72                    if seen.len() == PING_DEDUP_WINDOW {
73                        seen.pop_front();
74                    }
75                    seen.push_back(nonce);
76
77                    trace!("Received Ping {nonce:?}");
78                    let bytes = encode(&PingProtocolMsg::Pong(nonce))?;
79                    self.peer.send(bytes).await?;
80                    trace!("Sent Pong {nonce:?}");
81                }
82                PingProtocolMsg::Pong(nonce) => {
83                    pong_chan.send(nonce).await?;
84                }
85            }
86        }
87        Ok(())
88    }
89
90    async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> {
91        let mut retry = 0;
92
93        while retry < MAX_FAILURES {
94            sleep(Duration::from_secs(self.ping_interval)).await;
95
96            let mut nonce: [u8; 32] = [0; 32];
97            OsRng.try_fill_bytes(&mut nonce)?;
98
99            trace!("Send Ping {nonce:?}");
100            let bytes = encode(&PingProtocolMsg::Ping(nonce))?;
101            self.peer.send(bytes).await?;
102
103            let d = Duration::from_secs(self.ping_timeout);
104            let pong = match timeout(d, chan.recv()).await {
105                Ok(m) => m?,
106                Err(_) => {
107                    retry += 1;
108                    continue;
109                }
110            };
111            trace!("Received Pong {pong:?}");
112
113            if pong != nonce {
114                retry += 1;
115                continue;
116            }
117            retry = 0;
118        }
119
120        Err(Error::Timeout)
121    }
122}
123
124#[async_trait]
125impl Protocol for PingProtocol {
126    async fn start(self: Arc<Self>) -> Result<()> {
127        trace!("Start Ping protocol");
128
129        let stop_signal = async_channel::bounded::<Result<()>>(1);
130        let (pong_tx, pong_rx) = async_channel::bounded(1);
131
132        self.task_group.spawn(
133            {
134                let this = self.clone();
135                async move { this.ping_loop(pong_rx).await }
136            },
137            |res| async move {
138                if let TaskResult::Completed(result) = res {
139                    let _ = stop_signal.0.send(result).await;
140                }
141            },
142        );
143
144        let result = select(self.recv_loop(pong_tx), stop_signal.1.recv()).await;
145        self.task_group.cancel().await;
146
147        match result {
148            Either::Left(res) => {
149                trace!("Receive loop stopped {res:?}");
150                res
151            }
152            Either::Right(res) => {
153                let res = res?;
154                trace!("Ping loop stopped {res:?}");
155                res
156            }
157        }
158    }
159
160    fn version() -> Result<Version> {
161        "0.1.0".parse()
162    }
163
164    fn id() -> ProtocolID {
165        PING_PROTO_ID.into()
166    }
167
168    fn kind() -> ProtocolKind {
169        ProtocolKind::Mandatory
170    }
171}