karyon_p2p/protocols/
ping.rs1use std::collections::VecDeque;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_channel::{Receiver, Sender};
6use async_trait::async_trait;
7use bincode::{Decode, Encode};
8use log::trace;
9use rand::{rngs::OsRng, TryRngCore};
10
11use karyon_core::async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult};
12
13use crate::{
14 protocol::{PeerConn, Protocol, ProtocolID, ProtocolKind},
15 util::{decode, encode},
16 version::Version,
17 Error, Result,
18};
19
/// Number of consecutive failed ping rounds tolerated before the ping loop
/// gives up with `Error::Timeout`.
const MAX_FAILURES: u32 = 3;

/// How many recently answered Ping nonces are remembered so that a replayed
/// Ping can be recognised and dropped.
const PING_DEDUP_WINDOW: usize = 64;

/// Identifier under which the ping protocol is registered.
pub(crate) const PING_PROTO_ID: &str = "PING";
27
28#[derive(Clone, Debug, Encode, Decode)]
29enum PingProtocolMsg {
30 Ping([u8; 32]),
31 Pong([u8; 32]),
32}
33
34pub struct PingProtocol {
35 peer: PeerConn,
36 ping_interval: u64,
37 ping_timeout: u64,
38 task_group: TaskGroup,
39}
40
41impl PingProtocol {
42 pub(crate) fn new(peer: PeerConn) -> Arc<Self> {
43 let cfg = peer.inner().config();
44 let executor = peer.inner().executor();
45 Arc::new(Self {
46 ping_interval: cfg.ping_interval,
47 ping_timeout: cfg.ping_timeout,
48 task_group: TaskGroup::with_executor(executor),
49 peer,
50 })
51 }
52
53 async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> {
54 let mut seen: VecDeque<[u8; 32]> = VecDeque::with_capacity(PING_DEDUP_WINDOW);
57
58 loop {
59 let payload = match self.peer.recv().await {
60 Ok(bytes) => bytes,
61 Err(Error::PeerShutdown) => break,
62 Err(e) => return Err(e),
63 };
64
65 let (msg, _) = decode::<PingProtocolMsg>(&payload)?;
66 match msg {
67 PingProtocolMsg::Ping(nonce) => {
68 if seen.iter().any(|n| n == &nonce) {
69 trace!("Drop replayed Ping {nonce:?}");
70 continue;
71 }
72 if seen.len() == PING_DEDUP_WINDOW {
73 seen.pop_front();
74 }
75 seen.push_back(nonce);
76
77 trace!("Received Ping {nonce:?}");
78 let bytes = encode(&PingProtocolMsg::Pong(nonce))?;
79 self.peer.send(bytes).await?;
80 trace!("Sent Pong {nonce:?}");
81 }
82 PingProtocolMsg::Pong(nonce) => {
83 pong_chan.send(nonce).await?;
84 }
85 }
86 }
87 Ok(())
88 }
89
90 async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> {
91 let mut retry = 0;
92
93 while retry < MAX_FAILURES {
94 sleep(Duration::from_secs(self.ping_interval)).await;
95
96 let mut nonce: [u8; 32] = [0; 32];
97 OsRng.try_fill_bytes(&mut nonce)?;
98
99 trace!("Send Ping {nonce:?}");
100 let bytes = encode(&PingProtocolMsg::Ping(nonce))?;
101 self.peer.send(bytes).await?;
102
103 let d = Duration::from_secs(self.ping_timeout);
104 let pong = match timeout(d, chan.recv()).await {
105 Ok(m) => m?,
106 Err(_) => {
107 retry += 1;
108 continue;
109 }
110 };
111 trace!("Received Pong {pong:?}");
112
113 if pong != nonce {
114 retry += 1;
115 continue;
116 }
117 retry = 0;
118 }
119
120 Err(Error::Timeout)
121 }
122}
123
124#[async_trait]
125impl Protocol for PingProtocol {
126 async fn start(self: Arc<Self>) -> Result<()> {
127 trace!("Start Ping protocol");
128
129 let stop_signal = async_channel::bounded::<Result<()>>(1);
130 let (pong_tx, pong_rx) = async_channel::bounded(1);
131
132 self.task_group.spawn(
133 {
134 let this = self.clone();
135 async move { this.ping_loop(pong_rx).await }
136 },
137 |res| async move {
138 if let TaskResult::Completed(result) = res {
139 let _ = stop_signal.0.send(result).await;
140 }
141 },
142 );
143
144 let result = select(self.recv_loop(pong_tx), stop_signal.1.recv()).await;
145 self.task_group.cancel().await;
146
147 match result {
148 Either::Left(res) => {
149 trace!("Receive loop stopped {res:?}");
150 res
151 }
152 Either::Right(res) => {
153 let res = res?;
154 trace!("Ping loop stopped {res:?}");
155 res
156 }
157 }
158 }
159
160 fn version() -> Result<Version> {
161 "0.1.0".parse()
162 }
163
164 fn id() -> ProtocolID {
165 PING_PROTO_ID.into()
166 }
167
168 fn kind() -> ProtocolKind {
169 ProtocolKind::Mandatory
170 }
171}