karyon_p2p/protocols/
ping.rs

1use std::{sync::Arc, time::Duration};
2
3use async_channel::{Receiver, Sender};
4use async_trait::async_trait;
5use bincode::{Decode, Encode};
6use log::trace;
7use rand::{rngs::OsRng, RngCore};
8
9use karyon_core::{
10    async_runtime::Executor,
11    async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult},
12    util::decode,
13};
14
15use crate::{
16    peer::Peer,
17    protocol::{Protocol, ProtocolEvent, ProtocolID},
18    version::Version,
19    Error, Result,
20};
21
/// Number of consecutive failed ping rounds (no matching Pong received in
/// time) after which the ping loop gives up with `Error::Timeout`.
const MAX_FAILURES: u32 = 3;

/// Misspelled alias of [`MAX_FAILURES`], kept so existing references keep
/// compiling; prefer `MAX_FAILURES` in new code.
const MAX_FAILUERS: u32 = MAX_FAILURES;
23
24#[derive(Clone, Debug, Encode, Decode)]
25enum PingProtocolMsg {
26    Ping([u8; 32]),
27    Pong([u8; 32]),
28}
29
30pub struct PingProtocol {
31    peer: Arc<Peer>,
32    ping_interval: u64,
33    ping_timeout: u64,
34    task_group: TaskGroup,
35}
36
37impl PingProtocol {
38    #[allow(clippy::new_ret_no_self)]
39    pub fn new(
40        peer: Arc<Peer>,
41        ping_interval: u64,
42        ping_timeout: u64,
43        executor: Executor,
44    ) -> Arc<dyn Protocol> {
45        Arc::new(Self {
46            peer,
47            ping_interval,
48            ping_timeout,
49            task_group: TaskGroup::with_executor(executor),
50        })
51    }
52
53    async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> {
54        loop {
55            let event = self.peer.recv::<Self>().await?;
56            let msg_payload = match event.clone() {
57                ProtocolEvent::Message(m) => m,
58                ProtocolEvent::Shutdown => {
59                    break;
60                }
61            };
62
63            let (msg, _) = decode::<PingProtocolMsg>(&msg_payload)?;
64
65            match msg {
66                PingProtocolMsg::Ping(nonce) => {
67                    trace!("Received Ping message {:?}", nonce);
68                    self.peer
69                        .send(Self::id(), &PingProtocolMsg::Pong(nonce))
70                        .await?;
71                    trace!("Send back Pong message {:?}", nonce);
72                }
73                PingProtocolMsg::Pong(nonce) => {
74                    pong_chan.send(nonce).await?;
75                }
76            }
77        }
78        Ok(())
79    }
80
81    async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> {
82        let rng = &mut OsRng;
83        let mut retry = 0;
84
85        while retry < MAX_FAILUERS {
86            sleep(Duration::from_secs(self.ping_interval)).await;
87
88            let mut ping_nonce: [u8; 32] = [0; 32];
89            rng.fill_bytes(&mut ping_nonce);
90
91            trace!("Send Ping message {:?}", ping_nonce);
92            self.peer
93                .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce))
94                .await?;
95
96            // Wait for Pong message
97            let d = Duration::from_secs(self.ping_timeout);
98            let pong_msg = match timeout(d, chan.recv()).await {
99                Ok(m) => m?,
100                Err(_) => {
101                    retry += 1;
102                    continue;
103                }
104            };
105            trace!("Received Pong message {:?}", pong_msg);
106
107            if pong_msg != ping_nonce {
108                retry += 1;
109                continue;
110            }
111
112            retry = 0;
113        }
114
115        Err(Error::Timeout)
116    }
117}
118
119#[async_trait]
120impl Protocol for PingProtocol {
121    async fn start(self: Arc<Self>) -> Result<()> {
122        trace!("Start Ping protocol");
123
124        let stop_signal = async_channel::bounded::<Result<()>>(1);
125        let (pong_chan, pong_chan_recv) = async_channel::bounded(1);
126
127        self.task_group.spawn(
128            {
129                let this = self.clone();
130                async move { this.ping_loop(pong_chan_recv.clone()).await }
131            },
132            |res| async move {
133                if let TaskResult::Completed(result) = res {
134                    let _ = stop_signal.0.send(result).await;
135                }
136            },
137        );
138
139        let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await;
140        self.task_group.cancel().await;
141
142        match result {
143            Either::Left(res) => {
144                trace!("Receive loop stopped {:?}", res);
145                res
146            }
147            Either::Right(res) => {
148                let res = res?;
149                trace!("Ping loop stopped {:?}", res);
150                res
151            }
152        }
153    }
154
155    fn version() -> Result<Version> {
156        "0.1.0".parse()
157    }
158
159    fn id() -> ProtocolID {
160        "PING".into()
161    }
162}