karyon_p2p/discovery/
refresh.rs

1use std::{sync::Arc, time::Duration};
2
3use log::{error, info, trace};
4use parking_lot::RwLock;
5use rand::{rngs::OsRng, RngCore};
6
7use karyon_core::{
8    async_runtime::Executor,
9    async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
10};
11
12use karyon_net::{udp, Connection, Endpoint};
13
14use crate::{
15    codec::RefreshMsgCodec,
16    message::RefreshMsg,
17    monitor::{ConnEvent, DiscvEvent, Monitor},
18    routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
19    Config, Error, Result,
20};
21
/// Upper bound on recorded refresh failures for a routing-table entry; an
/// entry whose failure count reaches this value is evicted from the table.
pub const MAX_FAILURES: u32 = 3;
24
25pub struct RefreshService {
26    /// Routing table
27    table: Arc<RoutingTable>,
28
29    /// Resolved listen endpoint
30    listen_endpoint: RwLock<Option<Endpoint>>,
31
32    /// Managing spawned tasks.
33    task_group: TaskGroup,
34
35    /// Holds the configuration for the P2P network.
36    config: Arc<Config>,
37
38    /// Responsible for network and system monitoring.
39    monitor: Arc<Monitor>,
40}
41
42impl RefreshService {
43    /// Creates a new refresh service
44    pub fn new(
45        config: Arc<Config>,
46        table: Arc<RoutingTable>,
47        monitor: Arc<Monitor>,
48        executor: Executor,
49    ) -> Self {
50        Self {
51            table,
52            listen_endpoint: RwLock::new(None),
53            task_group: TaskGroup::with_executor(executor.clone()),
54            config,
55            monitor,
56        }
57    }
58
59    /// Start the refresh service
60    pub async fn start(self: &Arc<Self>) -> Result<()> {
61        if let Some(endpoint) = self.listen_endpoint.read().as_ref() {
62            let endpoint = endpoint.clone();
63            self.task_group.spawn(
64                {
65                    let this = self.clone();
66                    async move { this.listen_loop(endpoint).await }
67                },
68                |res| async move {
69                    if let TaskResult::Completed(Err(err)) = res {
70                        error!("Listen loop stopped: {err}");
71                    }
72                },
73            );
74        }
75
76        self.task_group.spawn(
77            {
78                let this = self.clone();
79                async move { this.refresh_loop().await }
80            },
81            |res| async move {
82                if let TaskResult::Completed(Err(err)) = res {
83                    error!("Refresh loop stopped: {err}");
84                }
85            },
86        );
87
88        Ok(())
89    }
90
91    /// Set the resolved listen endpoint.
92    pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
93        let resolved_endpoint = Endpoint::Udp(
94            resolved_endpoint.addr()?.clone(),
95            self.config.discovery_port,
96        );
97        *self.listen_endpoint.write() = Some(resolved_endpoint);
98        Ok(())
99    }
100
101    /// Shuts down the refresh service
102    pub async fn shutdown(&self) {
103        self.task_group.cancel().await;
104    }
105
106    /// Initiates periodic refreshing of the routing table. This function will
107    /// selects the first 8 entries (oldest entries) from each bucket in the
108    /// routing table and starts sending Ping messages to the collected entries.
109    async fn refresh_loop(self: Arc<Self>) -> Result<()> {
110        loop {
111            sleep(Duration::from_secs(self.config.refresh_interval)).await;
112            trace!("Start refreshing the routing table...");
113
114            self.monitor.notify(DiscvEvent::RefreshStarted).await;
115
116            let mut entries: Vec<BucketEntry> = vec![];
117            for bucket in self.table.buckets() {
118                for entry in bucket
119                    .iter()
120                    .filter(|e| !e.is_connected() && !e.is_incompatible())
121                    .take(8)
122                {
123                    entries.push(entry.clone())
124                }
125            }
126
127            self.clone().do_refresh(&entries).await;
128        }
129    }
130
131    /// Iterates over the entries and initiates a connection.
132    async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry]) {
133        use futures_util::stream::{FuturesUnordered, StreamExt};
134        // Enforce a maximum of 16 connections.
135        for chunk in entries.chunks(16) {
136            let mut tasks = FuturesUnordered::new();
137            for bucket_entry in chunk {
138                if bucket_entry.failures >= MAX_FAILURES {
139                    self.table.remove_entry(&bucket_entry.entry.key);
140                    continue;
141                }
142
143                tasks.push(self.clone().refresh_entry(bucket_entry.clone()))
144            }
145
146            while tasks.next().await.is_some() {}
147        }
148    }
149
150    /// Initiates refresh for a specific entry within the routing table. It
151    /// updates the routing table according to the result.
152    async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) {
153        let key = &bucket_entry.entry.key;
154        match self.connect(&bucket_entry.entry).await {
155            Ok(_) => {
156                self.table.update_entry(key, PENDING_ENTRY);
157            }
158            Err(err) => {
159                trace!("Failed to refresh entry {key:?}: {err}");
160                if bucket_entry.failures >= MAX_FAILURES {
161                    self.table.remove_entry(key);
162                    return;
163                }
164                self.table.update_entry(key, UNREACHABLE_ENTRY);
165            }
166        }
167    }
168
169    /// Initiates a UDP connection with the entry and attempts to send a Ping
170    /// message. If it fails, it retries according to the allowed retries
171    /// specified in the Config, with backoff between each retry.
172    async fn connect(&self, entry: &Entry) -> Result<()> {
173        let mut retry = 0;
174        let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port);
175        let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?;
176        let backoff = Backoff::new(100, 5000);
177        while retry < self.config.refresh_connect_retries {
178            match self.send_ping_msg(&conn, &endpoint).await {
179                Ok(()) => return Ok(()),
180                Err(Error::Timeout) => {
181                    retry += 1;
182                    backoff.sleep().await;
183                }
184                Err(err) => {
185                    return Err(err);
186                }
187            }
188        }
189
190        Err(Error::Timeout)
191    }
192
193    /// Set up a UDP listener and start listening for Ping messages from other
194    /// peers.
195    async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
196        let conn = match udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await {
197            Ok(c) => {
198                self.monitor
199                    .notify(ConnEvent::Listening(endpoint.clone()))
200                    .await;
201                c
202            }
203            Err(err) => {
204                self.monitor
205                    .notify(ConnEvent::ListenFailed(endpoint.clone()))
206                    .await;
207                return Err(err.into());
208            }
209        };
210        info!("Start listening on {endpoint}");
211
212        loop {
213            let res = self.listen_to_ping_msg(&conn).await;
214            if let Err(err) = res {
215                trace!("Failed to handle ping msg {err}");
216                self.monitor.notify(ConnEvent::AcceptFailed).await;
217            }
218        }
219    }
220
221    /// Listen to receive a Ping message and respond with a Pong message.
222    async fn listen_to_ping_msg(&self, conn: &udp::UdpConn<RefreshMsgCodec>) -> Result<()> {
223        let (msg, endpoint) = conn.recv().await?;
224        self.monitor
225            .notify(ConnEvent::Accepted(endpoint.clone()))
226            .await;
227
228        match msg {
229            RefreshMsg::Ping(m) => {
230                let pong_msg = RefreshMsg::Pong(m);
231                conn.send((pong_msg, endpoint.clone())).await?;
232            }
233            RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
234        }
235
236        self.monitor.notify(ConnEvent::Disconnected(endpoint)).await;
237        Ok(())
238    }
239
240    /// Sends a Ping msg and wait to receive the Pong message.
241    async fn send_ping_msg(
242        &self,
243        conn: &udp::UdpConn<RefreshMsgCodec>,
244        endpoint: &Endpoint,
245    ) -> Result<()> {
246        let mut nonce: [u8; 32] = [0; 32];
247        RngCore::fill_bytes(&mut OsRng, &mut nonce);
248        conn.send((RefreshMsg::Ping(nonce), endpoint.clone()))
249            .await?;
250
251        let t = Duration::from_secs(self.config.refresh_response_timeout);
252        let (msg, _) = timeout(t, conn.recv()).await??;
253
254        match msg {
255            RefreshMsg::Pong(n) => {
256                if n != nonce {
257                    return Err(Error::InvalidPongMsg);
258                }
259                Ok(())
260            }
261            _ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
262        }
263    }
264}