karyon_p2p/discovery/kademlia/
refresh.rs1use std::{
2 collections::HashMap,
3 net::{IpAddr, SocketAddr},
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use log::{error, info, trace};
9use rand::{rngs::OsRng, TryRngCore};
10
11use karyon_core::{
12 async_runtime::Executor,
13 async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
14};
15
16use karyon_net::{udp, Endpoint};
17
18use crate::{
19 discovery::kademlia::{
20 messages::RefreshMsg,
21 routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
22 },
23 message::{pick_endpoint, Protocol},
24 monitor::{ConnectionKind, DiscoveryKind, Monitor},
25 util::{decode, encode},
26 Config, Error, Result,
27};
28
/// Recorded failure count at which an entry is evicted from the routing table
/// instead of being pinged again.
pub const MAX_FAILURES: u32 = 3;

/// Size of the buffer used to receive a single refresh (ping/pong) datagram.
const MAX_UDP_BUF: usize = 1024;

/// Limit handed to `RoutingTable::refresh_candidates` on every refresh round
/// (presumably the number of candidates taken per bucket).
const REFRESH_PER_BUCKET: usize = 8;

/// Token-bucket capacity of the per-source-IP limiter for incoming refresh
/// pings, i.e. the largest burst a single IP may send.
const RL_CAPACITY: u32 = 5;

/// Tokens returned to a bucket per second (0.5 => one ping every two seconds
/// in steady state).
const RL_REFILL_PER_SEC: f64 = 0.5;

/// Token bucket tracking how many refresh pings one source IP may still send.
struct RateBucket {
    /// Tokens currently available; one is spent per served ping.
    tokens: f64,
    /// Moment at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl RateBucket {
    /// Creates a full bucket, allowing an initial burst of `RL_CAPACITY` pings.
    fn new() -> Self {
        let tokens = f64::from(RL_CAPACITY);
        let last_refill = Instant::now();
        Self {
            tokens,
            last_refill,
        }
    }

    /// Credits the tokens earned since the previous call (capped at
    /// `RL_CAPACITY`) and then tries to spend one.
    ///
    /// Returns `true` when a token was available and the ping may be served.
    fn allow(&mut self) -> bool {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * RL_REFILL_PER_SEC;
        self.last_refill = now;
        self.tokens = f64::min(self.tokens + earned, f64::from(RL_CAPACITY));

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }
}
71
72pub struct RefreshService {
73 table: Arc<RoutingTable>,
75
76 listen_endpoint: Option<Endpoint>,
78
79 task_group: TaskGroup,
81
82 config: Arc<Config>,
84
85 monitor: Arc<Monitor>,
87}
88
89impl RefreshService {
90 pub fn new(
92 config: Arc<Config>,
93 table: Arc<RoutingTable>,
94 monitor: Arc<Monitor>,
95 listen_endpoint: Option<Endpoint>,
96 executor: Executor,
97 ) -> Self {
98 Self {
99 table,
100 listen_endpoint,
101 task_group: TaskGroup::with_executor(executor.clone()),
102 config,
103 monitor,
104 }
105 }
106
107 pub async fn start(self: &Arc<Self>) -> Result<()> {
109 if let Some(endpoint) = self.listen_endpoint.clone() {
110 self.task_group.spawn(
111 {
112 let this = self.clone();
113 async move { this.listen_loop(endpoint).await }
114 },
115 |res| async move {
116 if let TaskResult::Completed(Err(err)) = res {
117 error!("Listen loop stopped: {err}");
118 }
119 },
120 );
121 }
122
123 self.task_group.spawn(
124 {
125 let this = self.clone();
126 async move { this.refresh_loop().await }
127 },
128 |res| async move {
129 if let TaskResult::Completed(Err(err)) = res {
130 error!("Refresh loop stopped: {err}");
131 }
132 },
133 );
134
135 Ok(())
136 }
137
138 pub async fn shutdown(&self) {
140 self.task_group.cancel().await;
141 }
142
143 async fn refresh_loop(self: Arc<Self>) -> Result<()> {
145 loop {
146 sleep(Duration::from_secs(self.config.refresh_interval)).await;
147 trace!("Start refreshing the routing table...");
148
149 self.monitor.notify(DiscoveryKind::RefreshStarted).await;
150
151 let entries = self.table.refresh_candidates(REFRESH_PER_BUCKET);
152 let succeeded = self.clone().do_refresh(&entries).await;
153
154 if !entries.is_empty() && succeeded == 0 {
157 self.monitor.notify(DiscoveryKind::RefreshFailed).await;
158 } else {
159 self.monitor
160 .notify(DiscoveryKind::RefreshSucceeded(succeeded))
161 .await;
162 }
163 }
164 }
165
166 async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry]) -> usize {
169 use futures_util::stream::{FuturesUnordered, StreamExt};
170 let mut succeeded = 0;
171 for chunk in entries.chunks(16) {
172 let mut tasks = FuturesUnordered::new();
173 for bucket_entry in chunk {
174 if bucket_entry.failures >= MAX_FAILURES {
175 let pid = bucket_entry.entry.key.into();
176 self.table.remove_entry(&bucket_entry.entry.key);
177 self.monitor.notify(DiscoveryKind::EntryEvicted(pid)).await;
178 continue;
179 }
180 tasks.push(self.clone().refresh_entry(bucket_entry.clone()))
181 }
182 while let Some(ok) = tasks.next().await {
183 if ok {
184 succeeded += 1;
185 }
186 }
187 }
188 succeeded
189 }
190
191 async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) -> bool {
194 let key = &bucket_entry.entry.key;
195 match self.connect(&bucket_entry.entry).await {
196 Ok(_) => {
197 self.table.update_entry(key, PENDING_ENTRY);
198 true
199 }
200 Err(err) => {
201 trace!("Failed to refresh entry {key:?}: {err}");
202 if bucket_entry.failures >= MAX_FAILURES {
203 let pid = (*key).into();
204 self.table.remove_entry(key);
205 self.monitor.notify(DiscoveryKind::EntryEvicted(pid)).await;
206 return false;
207 }
208 self.table.update_entry(key, UNREACHABLE_ENTRY);
209 false
210 }
211 }
212 }
213
214 async fn connect(&self, entry: &Entry) -> Result<()> {
216 let mut retry = 0;
217 let supported = [Protocol::Udp];
218 let endpoint = pick_endpoint(&entry.discovery_addrs, &supported)
219 .ok_or(Error::Lookup("No UDP discovery address available".into()))?;
220 let conn = udp::dial(&endpoint, Default::default()).await?;
221 let peer_addr = SocketAddr::try_from(endpoint.clone())?;
222 let backoff = Backoff::new(100, 5000);
223 while retry < self.config.refresh_connect_retries {
224 match self.send_ping_msg(&conn, peer_addr).await {
225 Ok(()) => return Ok(()),
226 Err(Error::Timeout) => {
227 retry += 1;
228 backoff.sleep().await;
229 }
230 Err(err) => {
231 return Err(err);
232 }
233 }
234 }
235 Err(Error::Timeout)
236 }
237
238 async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
240 let conn = match udp::listen(&endpoint, Default::default()).await {
241 Ok(c) => {
242 self.monitor
243 .notify(ConnectionKind::Listening(endpoint.clone()))
244 .await;
245 c
246 }
247 Err(err) => {
248 self.monitor
249 .notify(ConnectionKind::ListenFailed(endpoint.clone()))
250 .await;
251 return Err(err.into());
252 }
253 };
254 info!("Start listening on {endpoint}");
255
256 let mut rate_limiter: HashMap<IpAddr, RateBucket> = HashMap::new();
259
260 loop {
261 let res = self.listen_to_ping_msg(&conn, &mut rate_limiter).await;
262 if let Err(err) = res {
263 trace!("Failed to handle ping msg {err}");
264 self.monitor.notify(ConnectionKind::AcceptFailed).await;
265 }
266 }
267 }
268
269 async fn listen_to_ping_msg(
273 &self,
274 conn: &udp::UdpConn,
275 rate_limiter: &mut HashMap<IpAddr, RateBucket>,
276 ) -> Result<()> {
277 let mut buf = vec![0u8; MAX_UDP_BUF];
278 let (n, sender) = conn.recv_from(&mut buf).await?;
279
280 let sender_ip = sender.ip();
281 if !self.table.has_discovery_ip(&sender_ip) {
282 trace!("Drop refresh ping from unknown source {sender_ip}");
283 return Ok(());
284 }
285
286 let allowed = rate_limiter
287 .entry(sender_ip)
288 .or_insert_with(RateBucket::new)
289 .allow();
290 if !allowed {
291 trace!("Drop rate-limited refresh ping from {sender_ip}");
292 return Ok(());
293 }
294
295 let sender_ep = Endpoint::new_udp_addr(sender);
296 self.monitor
297 .notify(ConnectionKind::Accepted(sender_ep.clone()))
298 .await;
299
300 let (msg, _) = decode::<RefreshMsg>(&buf[..n])?;
301 match msg {
302 RefreshMsg::Ping(m) => {
303 let pong_msg = RefreshMsg::Pong(m);
304 let encoded = encode(&pong_msg)?;
305 conn.send_to(&encoded, sender).await?;
306 }
307 RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
308 }
309
310 self.monitor
311 .notify(ConnectionKind::Disconnected(sender_ep))
312 .await;
313 Ok(())
314 }
315
316 async fn send_ping_msg(&self, conn: &udp::UdpConn, peer_addr: SocketAddr) -> Result<()> {
318 let mut nonce: [u8; 32] = [0; 32];
319 OsRng.try_fill_bytes(&mut nonce)?;
320
321 let ping = RefreshMsg::Ping(nonce);
322 let encoded = encode(&ping)?;
323 conn.send_to(&encoded, peer_addr).await?;
324
325 let t = Duration::from_secs(self.config.refresh_response_timeout);
326 let mut buf = vec![0u8; MAX_UDP_BUF];
327 let (n, _) = timeout(t, conn.recv_from(&mut buf)).await??;
328 let (msg, _) = decode::<RefreshMsg>(&buf[..n])?;
329
330 match msg {
331 RefreshMsg::Pong(n) => {
332 if n != nonce {
333 return Err(Error::InvalidPongMsg);
334 }
335 Ok(())
336 }
337 _ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
338 }
339 }
340}