Skip to main content

karyon_p2p/discovery/kademlia/
refresh.rs

1use std::{
2    collections::HashMap,
3    net::{IpAddr, SocketAddr},
4    sync::Arc,
5    time::{Duration, Instant},
6};
7
8use log::{error, info, trace};
9use rand::{rngs::OsRng, TryRngCore};
10
11use karyon_core::{
12    async_runtime::Executor,
13    async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
14};
15
16use karyon_net::{udp, Endpoint};
17
18use crate::{
19    discovery::kademlia::{
20        messages::RefreshMsg,
21        routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
22    },
23    message::{pick_endpoint, Protocol},
24    monitor::{ConnectionKind, DiscoveryKind, Monitor},
25    util::{decode, encode},
26    Config, Error, Result,
27};
28
/// Number of failed refresh attempts after which an entry is evicted
/// from the routing table.
pub const MAX_FAILURES: u32 = 3;

/// Size, in bytes, of the buffers used to receive refresh (ping/pong)
/// datagrams.
const MAX_UDP_BUF: usize = 1024;

/// Maximum number of entries taken from each bucket in a single refresh
/// round.
const REFRESH_PER_BUCKET: usize = 8;
37
// Per-IP rate limit on incoming pings, implemented as a token bucket:
// up to `RL_CAPACITY` pings may arrive back-to-back, after which a sender
// earns one new ping every `1 / RL_REFILL_PER_SEC` seconds.
const RL_CAPACITY: u32 = 5;
const RL_REFILL_PER_SEC: f64 = 0.5;

/// Token bucket tracking how many pings a single IP may still send;
/// consulted by `listen_to_ping_msg` to drop floods.
struct RateBucket {
    tokens: f64,
    last_refill: Instant,
}

impl RateBucket {
    /// Creates a bucket that starts full, so a new sender can burst
    /// immediately.
    fn new() -> Self {
        let last_refill = Instant::now();
        Self {
            tokens: f64::from(RL_CAPACITY),
            last_refill,
        }
    }

    /// Credits the tokens earned since the previous call (capped at
    /// `RL_CAPACITY`), then spends one if available.
    ///
    /// Returns `true` when the ping may be processed and `false` when it
    /// should be dropped.
    fn allow(&mut self) -> bool {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * RL_REFILL_PER_SEC;
        self.last_refill = now;
        self.tokens = f64::min(self.tokens + earned, f64::from(RL_CAPACITY));

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }
}
71
72pub struct RefreshService {
73    /// Routing table
74    table: Arc<RoutingTable>,
75
76    /// UDP listen endpoint (None when no udp discovery endpoint is configured).
77    listen_endpoint: Option<Endpoint>,
78
79    /// Managing spawned tasks.
80    task_group: TaskGroup,
81
82    /// Holds the configuration for the P2P network.
83    config: Arc<Config>,
84
85    /// Responsible for network and system monitoring.
86    monitor: Arc<Monitor>,
87}
88
89impl RefreshService {
90    /// Creates a new refresh service
91    pub fn new(
92        config: Arc<Config>,
93        table: Arc<RoutingTable>,
94        monitor: Arc<Monitor>,
95        listen_endpoint: Option<Endpoint>,
96        executor: Executor,
97    ) -> Self {
98        Self {
99            table,
100            listen_endpoint,
101            task_group: TaskGroup::with_executor(executor.clone()),
102            config,
103            monitor,
104        }
105    }
106
107    /// Start the refresh service
108    pub async fn start(self: &Arc<Self>) -> Result<()> {
109        if let Some(endpoint) = self.listen_endpoint.clone() {
110            self.task_group.spawn(
111                {
112                    let this = self.clone();
113                    async move { this.listen_loop(endpoint).await }
114                },
115                |res| async move {
116                    if let TaskResult::Completed(Err(err)) = res {
117                        error!("Listen loop stopped: {err}");
118                    }
119                },
120            );
121        }
122
123        self.task_group.spawn(
124            {
125                let this = self.clone();
126                async move { this.refresh_loop().await }
127            },
128            |res| async move {
129                if let TaskResult::Completed(Err(err)) = res {
130                    error!("Refresh loop stopped: {err}");
131                }
132            },
133        );
134
135        Ok(())
136    }
137
138    /// Shuts down the refresh service
139    pub async fn shutdown(&self) {
140        self.task_group.cancel().await;
141    }
142
143    /// Periodically refresh the routing table by pinging the oldest entries.
144    async fn refresh_loop(self: Arc<Self>) -> Result<()> {
145        loop {
146            sleep(Duration::from_secs(self.config.refresh_interval)).await;
147            trace!("Start refreshing the routing table...");
148
149            self.monitor.notify(DiscoveryKind::RefreshStarted).await;
150
151            let entries = self.table.refresh_candidates(REFRESH_PER_BUCKET);
152            let succeeded = self.clone().do_refresh(&entries).await;
153
154            // Empty sweep counts as success with 0 — fail only when at
155            // least one entry was tried and all failed.
156            if !entries.is_empty() && succeeded == 0 {
157                self.monitor.notify(DiscoveryKind::RefreshFailed).await;
158            } else {
159                self.monitor
160                    .notify(DiscoveryKind::RefreshSucceeded(succeeded))
161                    .await;
162            }
163        }
164    }
165
166    /// Iterates over the entries and initiates a connection. Returns
167    /// the number of entries refreshed successfully.
168    async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry]) -> usize {
169        use futures_util::stream::{FuturesUnordered, StreamExt};
170        let mut succeeded = 0;
171        for chunk in entries.chunks(16) {
172            let mut tasks = FuturesUnordered::new();
173            for bucket_entry in chunk {
174                if bucket_entry.failures >= MAX_FAILURES {
175                    let pid = bucket_entry.entry.key.into();
176                    self.table.remove_entry(&bucket_entry.entry.key);
177                    self.monitor.notify(DiscoveryKind::EntryEvicted(pid)).await;
178                    continue;
179                }
180                tasks.push(self.clone().refresh_entry(bucket_entry.clone()))
181            }
182            while let Some(ok) = tasks.next().await {
183                if ok {
184                    succeeded += 1;
185                }
186            }
187        }
188        succeeded
189    }
190
191    /// Refresh a specific entry by pinging it over UDP. Returns true
192    /// if the ping succeeded.
193    async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) -> bool {
194        let key = &bucket_entry.entry.key;
195        match self.connect(&bucket_entry.entry).await {
196            Ok(_) => {
197                self.table.update_entry(key, PENDING_ENTRY);
198                true
199            }
200            Err(err) => {
201                trace!("Failed to refresh entry {key:?}: {err}");
202                if bucket_entry.failures >= MAX_FAILURES {
203                    let pid = (*key).into();
204                    self.table.remove_entry(key);
205                    self.monitor.notify(DiscoveryKind::EntryEvicted(pid)).await;
206                    return false;
207                }
208                self.table.update_entry(key, UNREACHABLE_ENTRY);
209                false
210            }
211        }
212    }
213
214    /// Send a ping over UDP with retries.
215    async fn connect(&self, entry: &Entry) -> Result<()> {
216        let mut retry = 0;
217        let supported = [Protocol::Udp];
218        let endpoint = pick_endpoint(&entry.discovery_addrs, &supported)
219            .ok_or(Error::Lookup("No UDP discovery address available".into()))?;
220        let conn = udp::dial(&endpoint, Default::default()).await?;
221        let peer_addr = SocketAddr::try_from(endpoint.clone())?;
222        let backoff = Backoff::new(100, 5000);
223        while retry < self.config.refresh_connect_retries {
224            match self.send_ping_msg(&conn, peer_addr).await {
225                Ok(()) => return Ok(()),
226                Err(Error::Timeout) => {
227                    retry += 1;
228                    backoff.sleep().await;
229                }
230                Err(err) => {
231                    return Err(err);
232                }
233            }
234        }
235        Err(Error::Timeout)
236    }
237
238    /// Listen on UDP for Ping messages.
239    async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
240        let conn = match udp::listen(&endpoint, Default::default()).await {
241            Ok(c) => {
242                self.monitor
243                    .notify(ConnectionKind::Listening(endpoint.clone()))
244                    .await;
245                c
246            }
247            Err(err) => {
248                self.monitor
249                    .notify(ConnectionKind::ListenFailed(endpoint.clone()))
250                    .await;
251                return Err(err.into());
252            }
253        };
254        info!("Start listening on {endpoint}");
255
256        // Per-IP rate limit. Lives on the listen task, no lock needed:
257        // listen_to_ping_msg is only ever called from this loop.
258        let mut rate_limiter: HashMap<IpAddr, RateBucket> = HashMap::new();
259
260        loop {
261            let res = self.listen_to_ping_msg(&conn, &mut rate_limiter).await;
262            if let Err(err) = res {
263                trace!("Failed to handle ping msg {err}");
264                self.monitor.notify(ConnectionKind::AcceptFailed).await;
265            }
266        }
267    }
268
269    /// Listen for a Ping message and respond with a Pong message.
270    /// Drops pings from sources not in the routing table or that
271    /// exceed the per-IP rate limit.
272    async fn listen_to_ping_msg(
273        &self,
274        conn: &udp::UdpConn,
275        rate_limiter: &mut HashMap<IpAddr, RateBucket>,
276    ) -> Result<()> {
277        let mut buf = vec![0u8; MAX_UDP_BUF];
278        let (n, sender) = conn.recv_from(&mut buf).await?;
279
280        let sender_ip = sender.ip();
281        if !self.table.has_discovery_ip(&sender_ip) {
282            trace!("Drop refresh ping from unknown source {sender_ip}");
283            return Ok(());
284        }
285
286        let allowed = rate_limiter
287            .entry(sender_ip)
288            .or_insert_with(RateBucket::new)
289            .allow();
290        if !allowed {
291            trace!("Drop rate-limited refresh ping from {sender_ip}");
292            return Ok(());
293        }
294
295        let sender_ep = Endpoint::new_udp_addr(sender);
296        self.monitor
297            .notify(ConnectionKind::Accepted(sender_ep.clone()))
298            .await;
299
300        let (msg, _) = decode::<RefreshMsg>(&buf[..n])?;
301        match msg {
302            RefreshMsg::Ping(m) => {
303                let pong_msg = RefreshMsg::Pong(m);
304                let encoded = encode(&pong_msg)?;
305                conn.send_to(&encoded, sender).await?;
306            }
307            RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
308        }
309
310        self.monitor
311            .notify(ConnectionKind::Disconnected(sender_ep))
312            .await;
313        Ok(())
314    }
315
316    /// Sends a Ping msg and waits for the Pong response.
317    async fn send_ping_msg(&self, conn: &udp::UdpConn, peer_addr: SocketAddr) -> Result<()> {
318        let mut nonce: [u8; 32] = [0; 32];
319        OsRng.try_fill_bytes(&mut nonce)?;
320
321        let ping = RefreshMsg::Ping(nonce);
322        let encoded = encode(&ping)?;
323        conn.send_to(&encoded, peer_addr).await?;
324
325        let t = Duration::from_secs(self.config.refresh_response_timeout);
326        let mut buf = vec![0u8; MAX_UDP_BUF];
327        let (n, _) = timeout(t, conn.recv_from(&mut buf)).await??;
328        let (msg, _) = decode::<RefreshMsg>(&buf[..n])?;
329
330        match msg {
331            RefreshMsg::Pong(n) => {
332                if n != nonce {
333                    return Err(Error::InvalidPongMsg);
334                }
335                Ok(())
336            }
337            _ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
338        }
339    }
340}