karyon_p2p/protocols/
ping.rs1use std::{sync::Arc, time::Duration};
2
3use async_channel::{Receiver, Sender};
4use async_trait::async_trait;
5use bincode::{Decode, Encode};
6use log::trace;
7use rand::{rngs::OsRng, RngCore};
8
9use karyon_core::{
10 async_runtime::Executor,
11 async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult},
12 util::decode,
13};
14
15use crate::{
16 peer::Peer,
17 protocol::{Protocol, ProtocolEvent, ProtocolID},
18 version::Version,
19 Error, Result,
20};
21
/// Number of consecutive failed pings after which the ping loop gives up
/// and returns `Error::Timeout`.
const MAX_FAILURES: u32 = 3;

/// Misspelled alias of [`MAX_FAILURES`], kept so existing references keep
/// compiling. Prefer `MAX_FAILURES` in new code.
const MAX_FAILUERS: u32 = MAX_FAILURES;
23
24#[derive(Clone, Debug, Encode, Decode)]
25enum PingProtocolMsg {
26 Ping([u8; 32]),
27 Pong([u8; 32]),
28}
29
30pub struct PingProtocol {
31 peer: Arc<Peer>,
32 ping_interval: u64,
33 ping_timeout: u64,
34 task_group: TaskGroup,
35}
36
37impl PingProtocol {
38 #[allow(clippy::new_ret_no_self)]
39 pub fn new(
40 peer: Arc<Peer>,
41 ping_interval: u64,
42 ping_timeout: u64,
43 executor: Executor,
44 ) -> Arc<dyn Protocol> {
45 Arc::new(Self {
46 peer,
47 ping_interval,
48 ping_timeout,
49 task_group: TaskGroup::with_executor(executor),
50 })
51 }
52
53 async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> {
54 loop {
55 let event = self.peer.recv::<Self>().await?;
56 let msg_payload = match event.clone() {
57 ProtocolEvent::Message(m) => m,
58 ProtocolEvent::Shutdown => {
59 break;
60 }
61 };
62
63 let (msg, _) = decode::<PingProtocolMsg>(&msg_payload)?;
64
65 match msg {
66 PingProtocolMsg::Ping(nonce) => {
67 trace!("Received Ping message {:?}", nonce);
68 self.peer
69 .send(Self::id(), &PingProtocolMsg::Pong(nonce))
70 .await?;
71 trace!("Send back Pong message {:?}", nonce);
72 }
73 PingProtocolMsg::Pong(nonce) => {
74 pong_chan.send(nonce).await?;
75 }
76 }
77 }
78 Ok(())
79 }
80
81 async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> {
82 let rng = &mut OsRng;
83 let mut retry = 0;
84
85 while retry < MAX_FAILUERS {
86 sleep(Duration::from_secs(self.ping_interval)).await;
87
88 let mut ping_nonce: [u8; 32] = [0; 32];
89 rng.fill_bytes(&mut ping_nonce);
90
91 trace!("Send Ping message {:?}", ping_nonce);
92 self.peer
93 .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce))
94 .await?;
95
96 let d = Duration::from_secs(self.ping_timeout);
98 let pong_msg = match timeout(d, chan.recv()).await {
99 Ok(m) => m?,
100 Err(_) => {
101 retry += 1;
102 continue;
103 }
104 };
105 trace!("Received Pong message {:?}", pong_msg);
106
107 if pong_msg != ping_nonce {
108 retry += 1;
109 continue;
110 }
111
112 retry = 0;
113 }
114
115 Err(Error::Timeout)
116 }
117}
118
119#[async_trait]
120impl Protocol for PingProtocol {
121 async fn start(self: Arc<Self>) -> Result<()> {
122 trace!("Start Ping protocol");
123
124 let stop_signal = async_channel::bounded::<Result<()>>(1);
125 let (pong_chan, pong_chan_recv) = async_channel::bounded(1);
126
127 self.task_group.spawn(
128 {
129 let this = self.clone();
130 async move { this.ping_loop(pong_chan_recv.clone()).await }
131 },
132 |res| async move {
133 if let TaskResult::Completed(result) = res {
134 let _ = stop_signal.0.send(result).await;
135 }
136 },
137 );
138
139 let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await;
140 self.task_group.cancel().await;
141
142 match result {
143 Either::Left(res) => {
144 trace!("Receive loop stopped {:?}", res);
145 res
146 }
147 Either::Right(res) => {
148 let res = res?;
149 trace!("Ping loop stopped {:?}", res);
150 res
151 }
152 }
153 }
154
155 fn version() -> Result<Version> {
156 "0.1.0".parse()
157 }
158
159 fn id() -> ProtocolID {
160 "PING".into()
161 }
162}